File size: 3,602 Bytes
fe0de4d
 
31a1007
ba88112
b742d70
31a1007
301a316
31a1007
 
 
ba88112
31a1007
fe0de4d
 
 
 
 
b742d70
d91b8a4
b742d70
fe0de4d
b742d70
 
 
fe0de4d
b742d70
 
fe0de4d
b742d70
 
fe0de4d
b742d70
 
 
 
fe0de4d
b742d70
 
31a1007
fe0de4d
 
 
 
 
 
 
 
8b7405e
fe0de4d
 
31a1007
8b7405e
 
 
c0758fe
fe0de4d
8b7405e
ba88112
 
 
 
 
 
fe0de4d
31a1007
ba88112
b742d70
ba88112
fe0de4d
ba88112
 
 
fe0de4d
 
ba88112
 
b742d70
ba88112
 
fe0de4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21d7116
fe0de4d
 
 
 
ba88112
fe0de4d
31a1007
fe0de4d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# app.py — Hugging Face Space entry point: English ↔ Luganda medical translator (NLLB base model + LoRA adapter)

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import torch
import spaces
from huggingface_hub import login, snapshot_download
import os

# === HF Login ===
# Read the Hub token from the Space secret and authenticate this process;
# refuse to start at all when the secret has not been configured.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    raise ValueError("Add HF_TOKEN as a secret in your Space!")

# === MODEL ===
BASE_MODEL = "Sunbird/translate-nllb-1.3b-salt"  # base seq2seq checkpoint
LORA_ADAPTER = "KMayanja/sunbird-medical-luganda-bidirectional-v3-kaggle"  # adapter applied with PeftModel

# Fetch both repositories into the local Hugging Face cache up front, in the
# same order as before (base model first, then the adapter).
print("Downloading models...")
for _repo in (BASE_MODEL, LORA_ADAPTER):
    snapshot_download(repo_id=_repo, token=hf_token)
del _repo  # keep the module namespace free of the loop variable

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

# Decide once whether a GPU is available so the dtype choice and the device
# placement below can never disagree.
_USE_CUDA = torch.cuda.is_available()

print("Loading base model...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if _USE_CUDA else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

print("Applying your LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER)
if _USE_CUDA:
    # No device_map is passed above, so the weights are loaded on CPU.
    # Without this move the float16 model would stay on CPU and generate()
    # would run there, where half precision is slow (or unsupported on some
    # PyTorch builds). On ZeroGPU Spaces the model must also be moved to
    # "cuda" at module level for @spaces.GPU functions to use the GPU.
    model.to("cuda")
model.eval()

# === Language code -> token id lookup ===
def get_lang_id(lang_code: str) -> int:
    """Return the vocabulary id of a language-code token such as ``"lug_Latn"``.

    The id is used as ``forced_bos_token_id`` so that generation starts in the
    requested target language.

    Args:
        lang_code: Language-code token to look up in the module-level tokenizer.

    Returns:
        The integer token id.

    Raises:
        ValueError: if ``lang_code`` is not in the tokenizer's vocabulary.
            ``convert_tokens_to_ids`` would otherwise silently return the
            unknown-token id (or ``None``), and generation would be forced to
            start with a meaningless token instead of failing loudly.
    """
    token_id = tokenizer.convert_tokens_to_ids(lang_code)
    if token_id is None or token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown language code: {lang_code!r}")
    return token_id

# Report where the weights actually live, not merely whether CUDA exists:
# the two differ whenever the model has not been moved onto the GPU.
print("Model ready on:", "GPU" if model.device.type == "cuda" else "CPU")

# === Translation function ===
@spaces.GPU(duration=180)
def predict(text, source_language="eng_Latn", target_language="lug_Latn"):
    """Translate ``text`` from ``source_language`` into ``target_language``.

    Args:
        text: Text to translate. ``None`` or whitespace-only input returns a
            prompt message instead of running the model.
        source_language: Language code of the input (e.g. ``"eng_Latn"``).
        target_language: Language code to translate into (e.g. ``"lug_Latn"``).

    Returns:
        The decoded translation with special tokens removed, or the prompt
        message when there is no text to translate.
    """
    # Guard None as well as blank strings: callers of the API endpoint (not
    # only the Textbox) can send a null value, and None.strip() would raise.
    if text is None or not text.strip():
        return "Please enter text to translate."

    # src_lang selects the source-language token the NLLB tokenizer adds to the
    # encoded input, so it must be set before tokenizing. This mutates the
    # shared module-level tokenizer.
    # NOTE(review): not safe if requests with different source languages run
    # concurrently — confirm the app's Gradio concurrency settings.
    tokenizer.src_lang = source_language

    # Inputs longer than 512 tokens are truncated, so very long texts are only
    # partially translated.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            # Force the first generated token to be the target-language code.
            forced_bos_token_id=get_lang_id(target_language),
            max_length=512,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)


# === Gradio UI ===
# Language codes offered in the dropdowns; the target list is the reverse of
# the source list, matching the default English -> Luganda direction.
_LANG_CODES = ["eng_Latn", "lug_Latn"]

# Example rows: [input text, source language, target language].
_EXAMPLES = [
    ["The patient has severe malaria and needs immediate artesunate injection.", "eng_Latn", "lug_Latn"],
    ["Take two tablets three times daily after meals.", "eng_Latn", "lug_Latn"],
    ["Omulwadde alina omusujja ogw’ekizungu era akennyamba okunywa amazzi.", "lug_Latn", "eng_Latn"],
]

with gr.Blocks(title="Medical Translator") as iface:
    gr.Markdown("# Uganda Medical Translator (English ↔ Luganda)")
    gr.Markdown("**Luganda medical model** — fine-tuned on 6.8k sentences by KMayanja")

    # Input and output panes side by side, equal width.
    with gr.Row():
        with gr.Column(scale=2):
            textbox = gr.Textbox(
                lines=6,
                label="Input Text",
                placeholder="Enter medical text here...",
            )
        with gr.Column(scale=2):
            output = gr.Textbox(lines=6, label="Translation")

    # Language selectors and the action button on one row.
    with gr.Row():
        src_lang = gr.Dropdown(list(_LANG_CODES), value="eng_Latn", label="Source Language")
        tgt_lang = gr.Dropdown(_LANG_CODES[::-1], value="lug_Latn", label="Target Language")
        btn = gr.Button("Translate", variant="primary")

    translate_inputs = [textbox, src_lang, tgt_lang]
    btn.click(fn=predict, inputs=translate_inputs, outputs=output)

    # No fn/outputs are given, so picking an example only fills the inputs.
    gr.Examples(examples=_EXAMPLES, inputs=translate_inputs)

iface.launch(server_name="0.0.0.0", server_port=7860)