gcaillaut commited on
Commit
aed1d52
·
1 Parent(s): fb8fc36

add more models

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -6,7 +6,9 @@ import itertools
6
  DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7
  MODEL_IDS = [
8
  "70M",
9
- # "160M",
 
 
10
  ]
11
  MODEL_MAPPING = {
12
  model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
@@ -157,16 +159,13 @@ def translate_with_all_models(text, tgt_lang, src_lang, domain):
157
  src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
158
  domain = DOMAIN_MAPPING[domain]
159
 
160
- res = {
161
- model_name: translate_with_model(model_name, text, tgt_lang, src_lang, domain)
162
- for model_name in MODEL_IDS
163
- }
164
- return list(
165
- itertools.chain.from_iterable(
166
- [res[model_id][k] for k in ("translation", "source_lang", "domain")]
167
- for model_id in MODEL_IDS
168
- )
169
- )
170
 
171
 
172
  with gr.Blocks() as demo:
 
6
  DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7
  MODEL_IDS = [
8
  "70M",
9
+ "160M",
10
+ "410M",
11
+ "610M",
12
  ]
13
  MODEL_MAPPING = {
14
  model_id: f"LinguaCustodia/multilingual-multidomain-fin-mt-{model_id}"
 
159
  src_lang = None if src_lang == "Auto" else LANG2CODE.get(src_lang)
160
  domain = DOMAIN_MAPPING[domain]
161
 
162
+ outputs = [None] * (3 * len(MODEL_IDS))
163
+ for i, model_id in enumerate(MODEL_IDS):
164
+ model_output = translate_with_model(model_id, text, tgt_lang, src_lang, domain)
165
+ outputs[i * 3] = model_output["translation"]
166
+ outputs[i * 3 + 1] = model_output["source_lang"]
167
+ outputs[i * 3 + 2] = model_output["domain"]
168
+ yield outputs
 
 
 
169
 
170
 
171
  with gr.Blocks() as demo: