TiberiuCristianLeon committed on
Commit
a50fb3c
·
verified ·
1 Parent(s): c6ce97e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -14,19 +14,11 @@ options.extend(list(all_langs.keys()))
14
  models = ["Helsinki-NLP",
15
  "t5-small", "t5-base", "t5-large",
16
  "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl",
17
- "facebook/nllb-200-distilled-600M",
18
- "facebook/nllb-200-distilled-1.3B",
19
- "facebook/mbart-large-50-many-to-many-mmt",
20
- "bigscience/mt0-small",
21
- "bigscience/mt0-base",
22
- "bigscience/mt0-large",
23
- "bigscience/mt0-xl",
24
- "bigscience/bloomz-560m",
25
- "bigscience/bloomz-1b1",
26
- "bigscience/bloomz-1b7",
27
- "bigscience/bloomz-3b",
28
- "utter-project/EuroLLM-1.7B",
29
- "utter-project/EuroLLM-1.7B-Instruct",
30
  "Unbabel/Tower-Plus-2B",
31
  "Unbabel/TowerInstruct-7B-v0.2",
32
  "Unbabel/TowerInstruct-Mistral-7B-v0.2",
@@ -183,7 +175,7 @@ def unbabel(model_name, sl, tl, input_text):
183
  translated_text = translated_text.replace('Answer:', '', 1).strip() if translated_text.startswith('Answer:') else translated_text
184
  return translated_text
185
 
186
- def mbart(model_name, sl, tl, input_text):
187
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
188
  model = MBartForConditionalGeneration.from_pretrained(model_name)
189
  tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
@@ -196,6 +188,20 @@ def mbart(model_name, sl, tl, input_text):
196
  )
197
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  @spaces.GPU
200
  def translate_text(input_text: str, sselected_language: str, tselected_language: str, model_name: str) -> tuple[str, str]:
201
  """
@@ -255,8 +261,12 @@ def translate_text(input_text: str, sselected_language: str, tselected_language:
255
  translated_text = nllb(model_name, nnlbsl, nnlbtl, input_text)
256
  return translated_text, message_text
257
 
258
- elif model_name.startswith('facebook/mbart-large'):
259
- translated_text = mbart(model_name, sselected_language, tselected_language, input_text)
 
 
 
 
260
  return translated_text, message_text
261
 
262
  elif 'Unbabel' in model_name:
 
14
  models = ["Helsinki-NLP",
15
  "t5-small", "t5-base", "t5-large",
16
  "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl",
17
+ "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/nllb-200-1.3B",
18
+ "facebook/mbart-large-50-many-to-many-mmt", "facebook/mbart-large-50-one-to-many-mmt",
19
+ "bigscience/mt0-small", "bigscience/mt0-base", "bigscience/mt0-large", "bigscience/mt0-xl",
20
+ "bigscience/bloomz-560m", "bigscience/bloomz-1b1", "bigscience/bloomz-1b7", "bigscience/bloomz-3b",
21
+ "utter-project/EuroLLM-1.7B", "utter-project/EuroLLM-1.7B-Instruct",
 
 
 
 
 
 
 
 
22
  "Unbabel/Tower-Plus-2B",
23
  "Unbabel/TowerInstruct-7B-v0.2",
24
  "Unbabel/TowerInstruct-Mistral-7B-v0.2",
 
175
  translated_text = translated_text.replace('Answer:', '', 1).strip() if translated_text.startswith('Answer:') else translated_text
176
  return translated_text
177
 
178
+ def mbart_many_to_many(model_name, sl, tl, input_text):
179
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
180
  model = MBartForConditionalGeneration.from_pretrained(model_name)
181
  tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
 
188
  )
189
  return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
190
 
191
def mbart_one_to_many(model_name, sl, tl, input_text):
    """Translate English text with an mBART-50 one-to-many checkpoint.

    Args:
        model_name: Hugging Face model id, e.g.
            "facebook/mbart-large-50-one-to-many-mmt".
        sl: Source language name (unused — these checkpoints only translate
            *from* English, hence the fixed ``src_lang="en_XX"``).
        tl: Target language name; mapped to an mBART language code via
            ``languagecodes.mbart_large_languages``.
        input_text: English text to translate.

    Returns:
        The translated text as a single string (matches the return type of
        ``mbart_many_to_many``).

    Raises:
        KeyError: if ``tl`` is not present in
            ``languagecodes.mbart_large_languages``.
    """
    from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
    # Honor the caller-supplied model id instead of hard-coding the
    # checkpoint name, so any one-to-many mBART-50 model works.
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang="en_XX")
    model_inputs = tokenizer(input_text, return_tensors="pt")
    # Force the decoder to start with the target-language token.
    langid = languagecodes.mbart_large_languages[tl]
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id[langid],
    )
    # batch_decode returns a list; take the single element so callers that
    # assign this to a string (translate_text) get a str, not a list.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
204
+
205
  @spaces.GPU
206
  def translate_text(input_text: str, sselected_language: str, tselected_language: str, model_name: str) -> tuple[str, str]:
207
  """
 
261
  translated_text = nllb(model_name, nnlbsl, nnlbtl, input_text)
262
  return translated_text, message_text
263
 
264
+ elif model_name == "facebook/mbart-large-50-many-to-many-mmt":
265
+ translated_text = mbart_many_to_many(model_name, sselected_language, tselected_language, input_text)
266
+ return translated_text, message_text
267
+
268
+ elif model_name == "facebook/mbart-large-50-one-to-many-mmt":
269
+ translated_text = mbart_one_to_many(model_name, sselected_language, tselected_language, input_text)
270
  return translated_text, message_text
271
 
272
  elif 'Unbabel' in model_name: