LouisMonawe committed on
Commit
ebbe17f
·
1 Parent(s): ce97752

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -14
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import requests
2
  import gradio as gr
3
  from dotenv import load_dotenv
 
4
  import os
5
 
6
  # Load environment variables
@@ -8,24 +9,35 @@ load_dotenv()
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
10
 
11
- # Language to ISO 639-3 codes (used for NLLB-200)
12
  LANGUAGES = {
13
- "English → Afrikaans": "afr",
14
- "English → Xhosa": "xho",
15
- "English → Zulu": "zul",
16
- "English → Sesotho": "sot",
17
- "English → Tswana": "tsn",
18
- "English → Northern Sotho": "nso",
19
- "English → Swati": "ssw",
20
- "English → Tsonga": "tso",
21
- "English → Venda": "ven",
 
 
22
  }
23
 
24
  MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"
25
  API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
26
 
 
 
 
 
 
 
 
 
27
 
28
  def query(payload):
 
29
  response = requests.post(API_URL, headers=headers, json=payload)
30
 
31
  if response.status_code != 200:
@@ -40,10 +52,17 @@ def query(payload):
40
 
41
 
42
  def translate(input_text, language_label):
43
- language_code = LANGUAGES[language_label]
44
- formatted_input = f">>{language_code}<< {input_text}"
 
 
 
 
 
 
 
45
 
46
- response = query({"inputs": formatted_input, "options": {"wait_for_model": True}})
47
 
48
  if "error" in response:
49
  return f"Error: {response['error']}"
@@ -51,6 +70,7 @@ def translate(input_text, language_label):
51
  return response[0]["translation_text"]
52
 
53
 
 
54
  translator = gr.Interface(
55
  fn=translate,
56
  inputs=[
@@ -59,7 +79,7 @@ translator = gr.Interface(
59
  ],
60
  outputs=gr.Textbox(label="Translation"),
61
  title="Translademia",
62
- description="Translate English text to South African languages using Meta's NLLB-200 model.",
63
  )
64
 
65
  translator.launch()
 
1
  import requests
2
  import gradio as gr
3
  from dotenv import load_dotenv
4
+ from transformers import MBart50TokenizerFast
5
  import os
6
 
7
  # Load environment variables
 
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
11
 
12
+ # Correct mBART-50 language codes
13
  LANGUAGES = {
14
+ "English → Afrikaans": "af_ZA",
15
+ "English → Xhosa": "xh_ZA",
16
+ "English → Zulu": "zu_ZA",
17
+ "English → Sesotho": "st_ZA", # Southern Sotho
18
+ "English → Tswana": "tn_ZA",
19
+ # The following are *not officially* supported by mBART-50 and may raise errors
20
+ # You can remove them if not working
21
+ # "English → Northern Sotho": "nso_ZA",
22
+ # "English → Swati": "ss_ZA",
23
+ # "English → Tsonga": "ts_ZA",
24
+ # "English → Venda": "ve_ZA",
25
  }
26
 
27
  MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"
28
  API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
29
 
30
+ # Load tokenizer to get language token IDs
31
+ tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
32
+
33
+
34
+ def get_token_id(lang_code):
35
+ """Return the forced_bos_token_id for the target language."""
36
+ return tokenizer.lang_code_to_id[lang_code]
37
+
38
 
39
  def query(payload):
40
+ """Send the translation request to the Hugging Face API."""
41
  response = requests.post(API_URL, headers=headers, json=payload)
42
 
43
  if response.status_code != 200:
 
52
 
53
 
54
  def translate(input_text, language_label):
55
+ """Main translation function."""
56
+ target_lang_code = LANGUAGES[language_label]
57
+ token_id = get_token_id(target_lang_code)
58
+
59
+ payload = {
60
+ "inputs": input_text,
61
+ "parameters": {"forced_bos_token_id": token_id},
62
+ "options": {"wait_for_model": True},
63
+ }
64
 
65
+ response = query(payload)
66
 
67
  if "error" in response:
68
  return f"Error: {response['error']}"
 
70
  return response[0]["translation_text"]
71
 
72
 
73
+ # Gradio UI
74
  translator = gr.Interface(
75
  fn=translate,
76
  inputs=[
 
79
  ],
80
  outputs=gr.Textbox(label="Translation"),
81
  title="Translademia",
82
+ description="Translate English text to South African languages using Meta's mBART-50 model.",
83
  )
84
 
85
  translator.launch()