KMayanja commited on
Commit
ba88112
·
verified ·
1 Parent(s): c680888

Update app.py

Browse files

Updated app.py to use medical model finetuned from the sunbird/salt-nllb-200-1.3B

Files changed (1) hide show
  1. app.py +83 -47
app.py CHANGED
@@ -1,84 +1,120 @@
1
  import gradio as gr
2
- from transformers import NllbTokenizer, M2M100ForConditionalGeneration
3
  import torch
4
  import spaces
5
  from huggingface_hub import login, snapshot_download
6
  import os
7
 
8
- # Fix: Retrieve HF token from environment (set as a Space secret)
9
  hf_token = os.environ.get("HF_TOKEN")
10
  if hf_token:
11
  login(token=hf_token)
12
- os.environ["HF_TOKEN"] = hf_token
13
  else:
14
- raise ValueError("HF_TOKEN environment variable not set. Add it as a secret in your Space settings.")
15
 
16
- # Model name
17
- model_name = "Sunbird/translate-nllb-1.3b-salt"
18
 
19
- # Download the model files first to avoid issues during loading
20
  snapshot_download(repo_id=model_name, token=hf_token)
21
 
22
- # Load the tokenizer and model once at startup
23
  try:
24
- tokenizer = NllbTokenizer.from_pretrained(model_name, token=hf_token)
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- model = M2M100ForConditionalGeneration.from_pretrained(model_name, token=hf_token)
27
- model.to(device) # Fix: Move to device once here, not per request
28
- model.eval() # Set to evaluation mode for inference
 
 
 
 
 
 
 
 
 
29
  except Exception as e:
30
  print(f"Error loading model: {e}")
31
  raise
32
 
33
- # Supported languages and their tokens
34
- language_tokens = {
35
- 'eng': 256047,
36
- 'ach': 256111,
37
- 'lgg': 256008,
38
- 'lug': 256110,
39
- 'nyn': 256002,
40
- 'teo': 256006,
41
  }
42
 
43
- supported_languages = list(language_tokens.keys())
 
44
 
45
- @spaces.GPU
46
- def translate(text, source_language, target_language):
47
- if source_language not in supported_languages:
48
- raise ValueError(f"Source language '{source_language}' not supported. Supported: {supported_languages}")
49
- if target_language not in supported_languages:
50
- raise ValueError(f"Target language '{target_language}' not supported. Supported: {supported_languages}")
51
 
52
- # Fix: No need to move model hereit's already on device
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- inputs = tokenizer(text, return_tensors="pt").to(device)
55
- inputs['input_ids'][0][0] = language_tokens[source_language]
 
56
 
57
- translated_tokens = model.generate(
58
- **inputs,
59
- forced_bos_token_id=language_tokens[target_language],
60
- max_length=100,
61
- num_beams=5,
 
62
  )
63
 
64
- result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
65
- return result
66
 
67
- # Create Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  iface = gr.Interface(
69
  fn=translate,
70
  inputs=[
71
- gr.Textbox(label="Text to translate"),
72
- gr.Dropdown(choices=supported_languages, label="Source language (e.g., 'eng')", value='eng'), # Optional: Dropdown for easier UX
73
- gr.Dropdown(choices=supported_languages, label="Target language (e.g., 'lug')", value='lug'), # Optional: Dropdown for easier UX
 
 
 
 
 
 
 
 
 
 
 
74
  ],
75
- outputs=gr.Textbox(label="Translated text"),
76
- title="Test Translation API",
77
- description="Translate text using Sunbird/translate-nllb-1.3b-salt model(To be replaced later). Supported languages: eng (English), lug (Luganda).",
78
  )
79
 
80
- # Fix: Remove share=True—HF Spaces handles this
81
- # Launch the application
82
  if __name__ == "__main__":
83
  iface.launch(
84
  server_name="0.0.0.0",
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
  import spaces
5
  from huggingface_hub import login, snapshot_download
6
  import os
7
 
8
+ # === HF Login ===
9
  hf_token = os.environ.get("HF_TOKEN")
10
  if hf_token:
11
  login(token=hf_token)
 
12
  else:
13
+ raise ValueError("HF_TOKEN not set! Add it as a Space secret.")
14
 
15
+ # === MODEL CONFIG ===
16
+ model_name = "KMayanja/sunbird-medical-luganda-bidirectional"
17
 
18
+ # Optional: cache model locally on first load
19
  snapshot_download(repo_id=model_name, token=hf_token)
20
 
21
+ # === LOAD TOKENIZER & MODEL ONCE AT STARTUP ===
22
  try:
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
24
+
25
+ model = AutoModelForSeq2SeqLM.from_pretrained(
26
+ model_name,
27
+ torch_dtype=torch.float32, # Safe default (GPU will auto-upgrade to bfloat16 if possible)
28
+ low_cpu_mem_usage=True
29
+ )
30
+
31
+ # Let @spaces.GPU handle device placement — do NOT move model here
32
+ # model.to(device) ← removed on purpose
33
+
34
+ model.eval()
35
+ print("Model loaded successfully.")
36
+
37
  except Exception as e:
38
  print(f"Error loading model: {e}")
39
  raise
40
 
41
+ # === LANGUAGE CODES (correct ones for your fine-tuned model) ===
42
+ # These are the official FLORES-200 codes used by Sunbird & NLLB
43
+ lang_code_to_id = {
44
+ "eng_Latn": tokenizer.lang_code_to_id["eng_Latn"],
45
+ "lug_Latn": tokenizer.lang_code_to_id["lug_Latn"],
 
 
 
46
  }
47
 
48
+ supported_langs = ["eng_Latn", "lug_Latn"]
49
+ lang_names = {"eng_Latn": "English", "lug_Latn": "Luganda"}
50
 
 
 
 
 
 
 
51
 
52
+ # === FALLBACK: Old working code (commented outjust uncomment to revert) ===
53
+ """
54
+ # model_name = "Sunbird/translate-nllb-1.3b-salt"
55
+ # tokenizer = NllbTokenizer.from_pretrained(model_name, token=hf_token)
56
+ # model = M2M100ForConditionalGeneration.from_pretrained(model_name, token=hf_token)
57
+ # language_tokens = {'eng': 256047, 'lug': 256110, ...}
58
+ """
59
+
60
+ # === MAIN TRANSLATION FUNCTION WITH GPU AUTO-FALLBACK ===
61
+ @spaces.GPU(duration=120) # 2 minutes GPU, then auto-fallback to CPU
62
+ def translate(text, source_language, target_language):
63
+ if text.strip() == "":
64
+ return "Please enter text to translate."
65
 
66
+ # Set source & target language
67
+ tokenizer.src_lang = source_language
68
+ tokenizer.tgt_lang = target_language
69
 
70
+ inputs = tokenizer(
71
+ text,
72
+ return_tensors="pt",
73
+ padding=True,
74
+ truncation=True,
75
+ max_length=512
76
  )
77
 
78
+ # Move inputs to correct device (GPU if available, else CPU)
79
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
80
 
81
+ with torch.no_grad():
82
+ generated_ids = model.generate(
83
+ **inputs,
84
+ forced_bos_token_id=tokenizer.lang_code_to_id[target_language],
85
+ max_length=512,
86
+ num_beams=5,
87
+ early_stopping=True,
88
+ no_repeat_ngram_size=3
89
+ )
90
+
91
+ translation = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
92
+ return translation
93
+
94
+
95
+ # === GRADIO INTERFACE ===
96
  iface = gr.Interface(
97
  fn=translate,
98
  inputs=[
99
+ gr.Textbox(label="Text to translate", lines=4, placeholder="Enter medical text here..."),
100
+ gr.Dropdown(choices=supported_langs, value="eng_Latn", label="Source Language"),
101
+ gr.Dropdown(choices=supported_langs, value="lug_Latn", label="Target Language"),
102
+ ],
103
+ outputs=gr.Textbox(label="Translation", lines=4),
104
+ title="Luganda Medical Translator (Sunbird 1.3B Fine-tuned)",
105
+ description="""
106
+ State-of-the-art bidirectional English ↔ Luganda medical translator.<br>
107
+ Trained on 6.8k high-quality medical sentences. Best available model for healthcare in Uganda.
108
+ """,
109
+ examples=[
110
+ ["The patient has severe malaria and needs immediate treatment.", "eng_Latn", "lug_Latn"],
111
+ ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
112
+ ["Take two tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
113
  ],
114
+ allow_flagging="never"
 
 
115
  )
116
 
117
+ # === LAUNCH ===
 
118
  if __name__ == "__main__":
119
  iface.launch(
120
  server_name="0.0.0.0",