TiberiuCristianLeon commited on
Commit
c6ff6a5
·
verified ·
1 Parent(s): 7e51f36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -1
app.py CHANGED
@@ -42,6 +42,7 @@ models = ["Helsinki-NLP", "QUICKMT", "Argos", "Lego-MT/Lego-MT", "HPLT", "HPLT-O
42
  "t5-small", "t5-base", "t5-large",
43
  "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl",
44
  "google/madlad400-3b-mt", "Heng666/madlad400-3b-mt-ct2", "Heng666/madlad400-3b-mt-ct2-int8", "Heng666/madlad400-7b-mt-ct2-int8",
 
45
  "BSC-LT/salamandraTA-2b-instruct", "BSC-LT/salamandraTA-7b-instruct",
46
  "utter-project/EuroLLM-1.7B", "utter-project/EuroLLM-1.7B-Instruct",
47
  "Unbabel/Tower-Plus-2B", "Unbabel/TowerInstruct-7B-v0.2", "Unbabel/TowerInstruct-Mistral-7B-v0.2",
@@ -61,6 +62,25 @@ class Translators:
61
  response = httpx.get(url)
62
  return response.json()[0][0][0]
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def hplt(self, opus = False):
65
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']
66
  hplt_models = ['ar-en', 'bs-en', 'ca-en', 'en-ar', 'en-bs', 'en-ca', 'en-et', 'en-eu', 'en-fi',
@@ -583,7 +603,10 @@ def translate_text(model_name: str, s_language: str, t_language: str, input_text
583
  translated_text, message_text = Translators(model_name, sl, tl, input_text).hplt(opus = True)
584
  else:
585
  translated_text, message_text = Translators(model_name, sl, tl, input_text).hplt()
586
-
 
 
 
587
  elif model_name == 'Argos':
588
  translated_text = Translators(model_name, sl, tl, input_text).argos()
589
 
 
42
  "t5-small", "t5-base", "t5-large",
43
  "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl",
44
  "google/madlad400-3b-mt", "Heng666/madlad400-3b-mt-ct2", "Heng666/madlad400-3b-mt-ct2-int8", "Heng666/madlad400-7b-mt-ct2-int8",
45
+ "naist-nlp/mitre_466m", "naist-nlp/mitre_913m",
46
  "BSC-LT/salamandraTA-2b-instruct", "BSC-LT/salamandraTA-7b-instruct",
47
  "utter-project/EuroLLM-1.7B", "utter-project/EuroLLM-1.7B-Instruct",
48
  "Unbabel/Tower-Plus-2B", "Unbabel/TowerInstruct-7B-v0.2", "Unbabel/TowerInstruct-Mistral-7B-v0.2",
 
62
  response = httpx.get(url)
63
  return response.json()[0][0][0]
64
 
65
+ def mitre(self):
66
+ from transformers import AutoModel, AutoTokenizer
67
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, use_fast=False)
68
+ model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
69
+ # model.half() # recommended
70
+ model.eval()
71
+
72
+ # Translating from one or several sentences to a sole language
73
+ src_tokens = tokenizer.encode_source_tokens_to_input_ids([self.input_text, ], target_language=self.tl)
74
+ src_tokens = src_tokens.to(self.device)
75
+ # Translating from one or several sentences to corresponding languages
76
+ # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
77
+ # generated_tokens = model.generate(src_tokensto(self.device))
78
+ # results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
79
+
80
+ with torch.no_grad():
81
+ generated_tokens = model.generate(src_tokens)
82
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
83
+
84
  def hplt(self, opus = False):
85
  # langs = ['ar', 'bs', 'ca', 'en', 'et', 'eu', 'fi', 'ga', 'gl', 'hi', 'hr', 'is', 'mt', 'nn', 'sq', 'sw', 'zh_hant']
86
  hplt_models = ['ar-en', 'bs-en', 'ca-en', 'en-ar', 'en-bs', 'en-ca', 'en-et', 'en-eu', 'en-fi',
 
603
  translated_text, message_text = Translators(model_name, sl, tl, input_text).hplt(opus = True)
604
  else:
605
  translated_text, message_text = Translators(model_name, sl, tl, input_text).hplt()
606
+
607
+ elif 'mitre' in model_name.lower():
608
+ translated_text = Translators(model_name, sl, tl, input_text).mitre()
609
+
610
  elif model_name == 'Argos':
611
  translated_text = Translators(model_name, sl, tl, input_text).argos()
612