# Hugging Face Spaces page header (scrape artifact) — Space status: Sleeping
# app.py — FINAL WORKING VERSION (deploy this now)
import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login, snapshot_download
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# === HF Login ===
# The token must be configured as a Space secret; fail fast if it is missing
# so the Space shows a clear error instead of an opaque 401 later.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("Add HF_TOKEN as a secret in your Space!")
login(token=hf_token)

# === MODEL ===
BASE_MODEL = "Sunbird/translate-nllb-1.3b-salt"
LORA_ADAPTER = "KMayanja/sunbird-medical-luganda-bidirectional-v3-kaggle"

print("Downloading models...")
# Pre-fetch both repos so the from_pretrained calls below hit the local cache.
snapshot_download(repo_id=BASE_MODEL, token=hf_token)
snapshot_download(repo_id=LORA_ADAPTER, token=hf_token)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

print("Loading base model...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    # fp16 only when a GPU is present; fp32 keeps CPU inference numerically safe.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

print("Applying your LoRA adapter...")
# Wrap the base model with the fine-tuned LoRA weights; inference only.
# NOTE(review): the model is never moved to CUDA here even though fp16 is
# selected when a GPU exists — presumably ZeroGPU (`import spaces`) is meant
# to handle device placement; confirm `@spaces.GPU` usage on the Space.
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
model.eval()
# === FIXED: Correct way to get language token IDs (works with fast tokenizer) ===
def get_lang_id(lang_code: str) -> int:
    """Return the vocabulary id of an NLLB language token (e.g. "lug_Latn").

    Raises:
        ValueError: if the tokenizer does not recognise ``lang_code`` — the
            fast tokenizer would otherwise silently return the unk token id,
            which would corrupt ``forced_bos_token_id`` during generation.
    """
    token_id = tokenizer.convert_tokens_to_ids(lang_code)
    if token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown NLLB language code: {lang_code!r}")
    return token_id

print("Model ready on:", "GPU" if torch.cuda.is_available() else "CPU")
# === Translation function ===
def predict(text, source_language="eng_Latn", target_language="lug_Latn"):
    """Translate ``text`` between the two UI languages.

    Args:
        text: Raw input text from the Gradio textbox (may be ``None`` when
            the box is cleared — handled as empty input).
        source_language: NLLB code of the input language (e.g. "eng_Latn").
        target_language: NLLB code of the output language (e.g. "lug_Latn").

    Returns:
        The translated string, or a prompt message when the input is empty.
    """
    # Guard against both empty strings and the None Gradio can deliver for a
    # cleared textbox (calling .strip() on None would raise AttributeError).
    if not text or not text.strip():
        return "Please enter text to translate."

    # Directly use NLLB language codes from the UI.
    src = source_language
    tgt = target_language
    tokenizer.src_lang = src  # only needed for some NLLB versions

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            # Force the decoder to start with the target-language token —
            # this is how NLLB selects the output language.
            forced_bos_token_id=get_lang_id(tgt),
            max_length=512,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
# === Gradio UI ===
# Two-column layout: input text on the left, translation on the right,
# with NLLB language-code dropdowns and clickable examples below.
with gr.Blocks(title="Medical Translator") as iface:
    gr.Markdown("# Uganda Medical Translator (English ↔ Luganda)")
    gr.Markdown("**Luganda medical model** — fine-tuned on 6.8k sentences by KMayanja")

    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(lines=6, label="Input Text", placeholder="Enter medical text here...")
        with gr.Column(scale=2):
            output = gr.Textbox(lines=6, label="Translation")

    with gr.Row():
        src_lang = gr.Dropdown(["eng_Latn", "lug_Latn"], value="eng_Latn", label="Source Language")
        tgt_lang = gr.Dropdown(["lug_Latn", "eng_Latn"], value="lug_Latn", label="Target Language")

    btn = gr.Button("Translate", variant="primary")
    btn.click(predict, inputs=[textbox, src_lang, tgt_lang], outputs=output)

    gr.Examples(
        [
            ["The patient has severe malaria and needs immediate artesunate injection.", "eng_Latn", "lug_Latn"],
            ["Take two tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
            ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
        ],
        inputs=[textbox, src_lang, tgt_lang],
    )

# Bind to all interfaces on the port Hugging Face Spaces expects (7860).
iface.launch(server_name="0.0.0.0", server_port=7860)