# app.py — FINAL WORKING VERSION (deploy this now)
"""Gradio Space: Uganda medical translator (English ↔ Luganda).

Loads the Sunbird NLLB 1.3B base model, applies a LoRA adapter fine-tuned
on medical sentence pairs, and serves a simple Gradio translation UI.

Requires an ``HF_TOKEN`` secret to be configured on the Space.
"""
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import torch
import spaces
from huggingface_hub import login, snapshot_download
import os

# === HF Login ===
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("Add HF_TOKEN as a secret in your Space!")
login(token=hf_token)

# === MODEL ===
BASE_MODEL = "Sunbird/translate-nllb-1.3b-salt"
LORA_ADAPTER = "KMayanja/sunbird-medical-luganda-bidirectional-v3-kaggle"

print("Downloading models...")
snapshot_download(repo_id=BASE_MODEL, token=hf_token)
snapshot_download(repo_id=LORA_ADAPTER, token=hf_token)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

print("Loading base model...")
# fp16 only when a GPU is present; fp32 keeps CPU inference numerically safe.
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

print("Applying your LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
model.eval()


# === FIXED: Correct way to get language token IDs (works with fast tokenizer) ===
def get_lang_id(lang_code: str) -> int:
    """Return the vocabulary id of an NLLB language token (e.g. "lug_Latn").

    NOTE(review): if *lang_code* is not in the vocabulary this silently
    returns the unk-token id; the UI only offers valid codes, so that path
    should not trigger in practice.
    """
    return tokenizer.convert_tokens_to_ids(lang_code)


print("Model ready on:", "GPU" if torch.cuda.is_available() else "CPU")


# === Translation function ===
@spaces.GPU(duration=180)
def predict(text, source_language="eng_Latn", target_language="lug_Latn"):
    """Translate *text* from *source_language* to *target_language*.

    Language arguments are NLLB codes (``eng_Latn`` / ``lug_Latn``).
    Returns the decoded translation, or a prompt string for empty input.
    """
    if not text.strip():
        return "Please enter text to translate."
    # ✅ Directly use NLLB language codes from the UI
    src = source_language
    tgt = target_language
    tokenizer.src_lang = src  # only needed for some NLLB versions
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            forced_bos_token_id=get_lang_id(tgt),  # ← FIXED LINE
            max_length=512,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


# === Gradio UI ===
with gr.Blocks(title="Medical Translator") as iface:
    gr.Markdown("# Uganda Medical Translator (English ↔ Luganda)")
    gr.Markdown("**Luganda medical model** — fine-tuned on 6.8k sentences by KMayanja")
    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(lines=6, label="Input Text", placeholder="Enter medical text here...")
        with gr.Column(scale=2):
            output = gr.Textbox(lines=6, label="Translation")
    with gr.Row():
        src_lang = gr.Dropdown(["eng_Latn", "lug_Latn"], value="eng_Latn", label="Source Language")
        tgt_lang = gr.Dropdown(["lug_Latn", "eng_Latn"], value="lug_Latn", label="Target Language")
    btn = gr.Button("Translate", variant="primary")
    btn.click(predict, inputs=[textbox, src_lang, tgt_lang], outputs=output)
    gr.Examples([
        ["The patient has severe malaria and needs immediate artesunate injection.", "eng_Latn", "lug_Latn"],
        ["Take two tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
        ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
    ], inputs=[textbox, src_lang, tgt_lang])

iface.launch(server_name="0.0.0.0", server_port=7860)