ErikDaska commited on
Commit
f2eb0c4
·
verified ·
1 Parent(s): 45fd469

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +12 -11
src/streamlit_app.py CHANGED
@@ -47,18 +47,19 @@ def instantiate_translation_model(model_name, text, src_lg, tgt_lg):
47
  return pipe(text)[0]["translation_text"]
48
 
49
  # ---- M2M100 ----
50
- elif "m2m100" in model_name:
51
- pipe = pipeline(
52
- "translation",
53
- model=model_path,
54
- tokenizer=model_path,
55
- token=token
56
- )
57
 
58
- pipe.tokenizer.src_lang = src_lg
 
 
 
 
 
 
 
 
59
  result = pipe(
60
  text,
61
- forced_bos_token_id=pipe.tokenizer.get_lang_id(tgt_lg)
62
  )
63
  return result[0]["translation_text"]
64
 
@@ -86,8 +87,8 @@ def build_translation_page(model_name):
86
 
87
  elif "m2m100" in model_name:
88
  lang_map = {
89
- "Português": "pt",
90
- "Kabuverdianu": "en" # m2m100 does NOT support kea
91
  }
92
 
93
  else: # mBART
 
47
  return pipe(text)[0]["translation_text"]
48
 
49
  # ---- M2M100 ----
 
 
 
 
 
 
 
50
 
51
+ elif "m2m100" in model_name:
52
+ pipe = load_pipeline("translation", model_path)
53
+
54
+ # Set the source language
55
+ pipe.tokenizer.src_lang = src_lg
56
+
57
+ # M2M100 requires the forced_bos_token_id to be the target lang token
58
+ tgt_lang_id = pipe.tokenizer.convert_tokens_to_ids(tgt_lg)
59
+
60
  result = pipe(
61
  text,
62
+ forced_bos_token_id=tgt_lang_id
63
  )
64
  return result[0]["translation_text"]
65
 
 
87
 
88
  elif "m2m100" in model_name:
89
  lang_map = {
90
+ "Português": "__pt__",
91
+ "Kabuverdianu": "__en__" # Proxying kea as __en__
92
  }
93
 
94
  else: # mBART