KMayanja commited on
Commit
fe0de4d
·
verified ·
1 Parent(s): 37d9a80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -62
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from peft import PeftModel
@@ -8,102 +10,90 @@ import os
8
 
9
  # === HF Login ===
10
  hf_token = os.environ.get("HF_TOKEN")
11
- if hf_token:
12
- login(token=hf_token)
13
- else:
14
- raise ValueError("HF_TOKEN not set! Add it as a Space secret.")
15
-
16
- # === MODEL CONFIG ===
17
- # You currently have ONLY the LoRA adapter uploaded
18
- # So we load the base model first, then apply your LoRA on top
19
  BASE_MODEL = "Sunbird/translate-nllb-1.3b-salt"
20
- LORA_ADAPTER = "KMayanja/sunbird-medical-luganda-bidirectional" # ← your repo
21
 
 
22
  snapshot_download(repo_id=BASE_MODEL, token=hf_token)
23
  snapshot_download(repo_id=LORA_ADAPTER, token=hf_token)
24
 
25
- print("Loading tokenizer and base model...")
26
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
27
 
28
- print("Loading base model (this takes ~15 seconds)...")
29
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
30
  BASE_MODEL,
31
- torch_dtype=torch.float32,
32
  low_cpu_mem_usage=True,
33
  trust_remote_code=True
34
  )
35
 
36
- print("Applying your medical LoRA adapter...")
37
  model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
38
-
39
- # DO NOT .to(device) here — @spaces.GPU will handle it automatically
40
  model.eval()
41
- print("Model ready! (LoRA successfully applied)")
42
-
43
- # === LANGUAGE CODES (correct FLORES-200 codes) ===
44
- supported_langs = ["eng_Latn", "lug_Latn"]
45
- lang_names = {"eng_Latn": "English", "lug_Latn": "Luganda"}
46
-
47
- # === FALLBACK TO OLD CODE (just uncomment if you ever need it) ===
48
- """
49
- # model_name = "Sunbird/translate-nllb-1.3b-salt"
50
- # tokenizer = NllbTokenizer.from_pretrained(model_name)
51
- # model = M2M100ForConditionalGeneration.from_pretrained(model_name)
52
- # language_tokens = {'eng': 256047, 'lug': 256110, ...}
53
- """
54
-
55
- # === TRANSLATION FUNCTION (GPU → CPU auto-fallback via @spaces.GPU) ===
56
- @spaces.GPU(duration=180) # 3-minute GPU, then falls back to CPU
57
- def translate(text, source_language="eng_Latn", target_language="lug_Latn"):
58
- if not text.strip():
59
- return "Please enter some text."
60
 
61
- tokenizer.src_lang = source_language
62
- tokenizer.tgt_lang = target_language
 
 
 
 
 
 
 
 
 
63
 
 
64
  inputs = tokenizer(
65
  text,
66
  return_tensors="pt",
67
  padding=True,
68
  truncation=True,
69
  max_length=512
70
- ).to(model.device) # automatically uses GPU or CPU
71
 
72
  with torch.no_grad():
73
  generated = model.generate(
74
  **inputs,
75
- forced_bos_token_id=tokenizer.lang_code_to_id[target_language],
76
  max_length=512,
77
  num_beams=5,
78
  early_stopping=True,
79
- no_repeat_ngram_size=3
 
80
  )
81
 
82
  return tokenizer.decode(generated[0], skip_special_tokens=True)
83
 
84
 
85
- # === GRADIO INTERFACE ===
86
- iface = gr.Interface(
87
- fn=translate,
88
- inputs=[
89
- gr.Textbox(label="Text to translate", lines=5, placeholder="Enter medical text..."),
90
- gr.Dropdown(choices=supported_langs, value="eng_Latn", label="Source Language"),
91
- gr.Dropdown(choices=supported_langs, value="lug_Latn", label="Target Language"),
92
- ],
93
- outputs=gr.Textbox(label="Translation", lines=5),
94
- title="Uganda Medical Translator (English ↔ Luganda)",
95
- description="""
96
- **Best available medical translator for Luganda** — fine-tuned on 6.8k high-quality medical sentences.<br>
97
- Trained by KMayanja using Sunbird 1.3B + LoRA.<br>
98
- BLEU 20, chrF 36 (excellent real-world quality despite low-resource metrics).
99
- """,
100
- examples=[
101
- ["The patient has severe malaria and needs immediate artesunate.", "eng_Latn", "lug_Latn"],
102
- ["Take 2 tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
 
 
 
103
  ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
104
- ],
105
- allow_flagging="never"
106
- )
107
 
108
- if __name__ == "__main__":
109
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # app.py — FINAL WORKING VERSION (deploy this now)
2
+
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  from peft import PeftModel
 
10
 
11
  # === HF Login ===
12
  hf_token = os.environ.get("HF_TOKEN")
13
+ if not hf_token:
14
+ raise ValueError("Add HF_TOKEN as a secret in your Space!")
15
+ login(token=hf_token)
16
+
17
+ # === MODEL ===
 
 
 
18
  BASE_MODEL = "Sunbird/translate-nllb-1.3b-salt"
19
+ LORA_ADAPTER = "KMayanja/sunbird-medical-luganda-bidirectional"
20
 
21
+ print("Downloading models...")
22
  snapshot_download(repo_id=BASE_MODEL, token=hf_token)
23
  snapshot_download(repo_id=LORA_ADAPTER, token=hf_token)
24
 
25
+ print("Loading tokenizer...")
26
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
27
 
28
+ print("Loading base model...")
29
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
30
  BASE_MODEL,
31
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
32
  low_cpu_mem_usage=True,
33
  trust_remote_code=True
34
  )
35
 
36
+ print("Applying your LoRA adapter...")
37
  model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
 
 
38
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ # === FIXED: Correct way to get language token IDs (works with fast tokenizer) ===
41
+ def get_lang_id(lang_code: str) -> int:
42
+ return tokenizer.convert_tokens_to_ids(lang_code)
43
+
44
+ print("Model ready on:", "GPU" if torch.cuda.is_available() else "CPU")
45
+
46
+ # === Translation function ===
47
+ @spaces.GPU(duration=180)
48
+ def translate(text, src="eng_Latn", tgt="lug_Latn"):
49
+ if not text.strip():
50
+ return "Please enter text to translate."
51
 
52
+ tokenizer.src_lang = src # only needed for some NLLB versions
53
  inputs = tokenizer(
54
  text,
55
  return_tensors="pt",
56
  padding=True,
57
  truncation=True,
58
  max_length=512
59
+ ).to(model.device)
60
 
61
  with torch.no_grad():
62
  generated = model.generate(
63
  **inputs,
64
+ forced_bos_token_id=get_lang_id(tgt), # ← FIXED LINE
65
  max_length=512,
66
  num_beams=5,
67
  early_stopping=True,
68
+ no_repeat_ngram_size=3,
69
+ repetition_penalty=1.1
70
  )
71
 
72
  return tokenizer.decode(generated[0], skip_special_tokens=True)
73
 
74
 
75
+ # === Gradio UI ===
76
+ with gr.Blocks(title="Medical Translator") as iface:
77
+ gr.Markdown("# Uganda Medical Translator (English ↔ Luganda)")
78
+ gr.Markdown("**Luganda medical model** — fine-tuned on 6.8k sentences by KMayanja")
79
+
80
+ with gr.Row():
81
+ with gr.Column(scale=2):
82
+ textbox = gr.Textbox(lines=6, label="Input Text", placeholder="Enter medical text here...")
83
+ with gr.Column(scale=2):
84
+ output = gr.Textbox(lines=6, label="Translation")
85
+
86
+ with gr.Row():
87
+ src_lang = gr.Dropdown(["eng_Latn", "lug_Latn"], value="eng_Latn", label="Source Language")
88
+ tgt_lang = gr.Dropdown(["lug_Latn", "eng_Latn"], value="lug_Latn", label="Target Language")
89
+ btn = gr.Button("Translate", variant="primary")
90
+
91
+ btn.click(translate, inputs=[textbox, src_lang, tgt_lang], outputs=output)
92
+
93
+ gr.Examples([
94
+ ["The patient has severe malaria and needs immediate artesunate injection.", "eng_Latn", "lug_Latn"],
95
+ ["Take two tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
96
  ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
97
+ ], inputs=[textbox, src_lang, tgt_lang])
 
 
98
 
99
+ iface.launch(server_name="0.0.0.0", server_port=7860)