# app.py: Uganda Medical Translator Space (NLLB-1.3B base + LoRA adapter, English <-> Luganda)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import torch
import spaces
from huggingface_hub import login, snapshot_download
import os
# === HF Login ===
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("Add HF_TOKEN as a secret in your Space!")
login(token=hf_token)
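# The same token is passed to snapshot_download below so private or gated repos can be fetched.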
# === MODEL ===
BASE_MODEL = "Sunbird/translate-nllb-1.3b-salt"
LORA_ADAPTER = "KMayanja/sunbird-medical-luganda-bidirectional-v3-kaggle"
print("Downloading models...")
snapshot_download(repo_id=BASE_MODEL, token=hf_token)
snapshot_download(repo_id=LORA_ADAPTER, token=hf_token)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
print("Loading base model...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
print("Applying your LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
model.eval()
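# Optional (an assumption, not required for correctness): merging the LoRA weights into the
# base model with PEFT's merge_and_unload() can make inference slightly faster:
#   model = model.merge_and_unload()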
# === Language token IDs (works with the fast NLLB tokenizer) ===
def get_lang_id(lang_code: str) -> int:
    return tokenizer.convert_tokens_to_ids(lang_code)
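# Illustrative example (assumes the NLLB language codes exist as single tokens in the vocab):
#   get_lang_id("lug_Latn") -> vocabulary id of the "lug_Latn" token, used below as
#   forced_bos_token_id so generation starts in the target language.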
print("Model ready on:", "GPU" if torch.cuda.is_available() else "CPU")
# === Translation function ===
@spaces.GPU(duration=180)
def predict(text, source_language="eng_Latn", target_language="lug_Latn"):
    if not text.strip():
        return "Please enter text to translate."

    # Use the NLLB language codes passed directly from the UI dropdowns
    src = source_language
    tgt = target_language
    tokenizer.src_lang = src  # only needed for some NLLB tokenizer versions

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)  # move inputs to wherever the model lives (CPU or GPU)

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            forced_bos_token_id=get_lang_id(tgt),  # force decoding in the target language
            max_length=512,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
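# Illustrative call (the exact output depends on the fine-tuned adapter):
#   predict("The patient has a fever.", "eng_Latn", "lug_Latn")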
# === Gradio UI ===
with gr.Blocks(title="Medical Translator") as iface:
    gr.Markdown("# Uganda Medical Translator (English ↔ Luganda)")
    gr.Markdown("**Luganda medical model** — fine-tuned on 6.8k sentences by KMayanja")
    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(lines=6, label="Input Text", placeholder="Enter medical text here...")
        with gr.Column(scale=2):
            output = gr.Textbox(lines=6, label="Translation")
    with gr.Row():
        src_lang = gr.Dropdown(["eng_Latn", "lug_Latn"], value="eng_Latn", label="Source Language")
        tgt_lang = gr.Dropdown(["lug_Latn", "eng_Latn"], value="lug_Latn", label="Target Language")
    btn = gr.Button("Translate", variant="primary")
    btn.click(predict, inputs=[textbox, src_lang, tgt_lang], outputs=output)
    gr.Examples([
        ["The patient has severe malaria and needs immediate artesunate injection.", "eng_Latn", "lug_Latn"],
        ["Take two tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
        ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
    ], inputs=[textbox, src_lang, tgt_lang])

iface.launch(server_name="0.0.0.0", server_port=7860)