ErikDaska commited on
Commit
55ad4cc
·
verified ·
1 Parent(s): 216cd80

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -10
src/streamlit_app.py CHANGED
@@ -43,25 +43,41 @@ def instantiate_translation_model(model_name: str, text: str, src_lg: str, tgt_l
43
  tgt = "por_Latn" if "pt" in tgt_lg else "kea_Latn"
44
  pipe = pipeline("translation", model=model_path, token=token, src_lang=src, tgt_lang=tgt)
45
  elif "m2m100" in model_name:
46
- src = "kea_Latn" if "en" in src_lg else "pt"
47
- tgt = "por_Latn" if "pt" in tgt_lg else "en"
48
- pipe = pipeline("translation", model=model_path, token=token, src_lang=src_lg, tgt_lang=tgt_lg)
 
 
 
 
 
49
  else:
50
  # Standard logic for MBart
51
  pipe = pipeline("translation", model=model_path, token=token, src_lang=src_lg, tgt_lang=tgt_lg)
52
 
53
- result = pipe(text)
54
  return result[0]["translation_text"]
55
 
56
- # --- UI Build Functions ---
57
  def build_translation_page(model_name):
58
  st.title(f"🌍 {model_name}: Tradução")
59
-
60
- # Dynamic language mapping based on model
61
  if "nllb" in model_name:
62
- lang_map = {"Português": "por_Latn", "Kabuverdianu": "kea_Latn"}
63
- else:
64
- lang_map = {"Português": "pt_XX", "Kabuverdianu": "en_XX"} # MBart style
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  col1, col2 = st.columns(2)
67
  with col1:
 
43
  tgt = "por_Latn" if "pt" in tgt_lg else "kea_Latn"
44
  pipe = pipeline("translation", model=model_path, token=token, src_lang=src, tgt_lang=tgt)
45
  elif "m2m100" in model_name:
46
+ pipe = pipeline(
47
+ "translation",
48
+ model=model_path,
49
+ tokenizer=model_path,
50
+ token=token
51
+ )
52
+
53
+ pipe.tokenizer.src_lang = src_lg
54
  else:
55
  # Standard logic for MBart
56
  pipe = pipeline("translation", model=model_path, token=token, src_lang=src_lg, tgt_lang=tgt_lg)
57
 
58
+ result = pipe(text, forced_bos_token_id=pipe.tokenizer.get_lang_id(tgt_lg))
59
  return result[0]["translation_text"]
60
 
 
61
  def build_translation_page(model_name):
62
  st.title(f"🌍 {model_name}: Tradução")
63
+
 
64
  if "nllb" in model_name:
65
+ lang_map = {
66
+ "Português": "por_Latn",
67
+ "Kabuverdianu": "kea_Latn"
68
+ }
69
+
70
+ elif "m2m100" in model_name:
71
+ lang_map = {
72
+ "Português": "pt",
73
+ "Kabuverdianu": "en" # m2m100 does not support kea
74
+ }
75
+
76
+ else: # mBART
77
+ lang_map = {
78
+ "Português": "pt_XX",
79
+ "Kabuverdianu": "en_XX"
80
+ }
81
 
82
  col1, col2 = st.columns(2)
83
  with col1: