# translator/app.py
# Somali agriculture Q&A Space: Somali question -> English (NLLB) ->
# AgriParam advice (bullets) -> back to Somali (Gradio UI).
import os, re, gc, torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
# ---- env (same intent as your notebook) ----
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["WANDB_DISABLED"] = "true"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.set_num_threads(2) # or 1 on the smallest CPU tier
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))
# On Colab T4 we can try FP16 on GPU; fallback to CPU if it fails.
COMPUTE_DTYPE = torch.float16
# ---- AgriParam (exact logic) ----
AGRI_ID = "bharatgenai/AgriParam"
ag_tok = AutoTokenizer.from_pretrained(AGRI_ID, use_fast=True, trust_remote_code=True)
if ag_tok.pad_token is None:
    ag_tok.pad_token = ag_tok.eos_token or ag_tok.sep_token
ag_mdl = None
try:
    # Try GPU (faster). If OOM or no GPU, we'll drop to CPU automatically.
    ag_mdl = AutoModelForCausalLM.from_pretrained(
        AGRI_ID,
        torch_dtype=COMPUTE_DTYPE,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    where = "GPU"
except Exception as e:
    print("GPU load failed ⇒ using CPU:", e)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    ag_mdl = AutoModelForCausalLM.from_pretrained(
        AGRI_ID,
        torch_dtype=torch.float32,  # safe on CPU
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    where = "CPU"
cfg = ag_mdl.generation_config
cfg.do_sample = False
cfg.temperature = None
if cfg.pad_token_id is None: cfg.pad_token_id = ag_tok.eos_token_id
if cfg.eos_token_id is None: cfg.eos_token_id = ag_tok.eos_token_id
print(f"AgriParam loaded on {where}")
# ---- NLLB translator (exact logic) ----
NLLB_ID = "facebook/nllb-200-distilled-600M"
mt_tok = AutoTokenizer.from_pretrained(NLLB_ID) # fast is fine here
mt_mdl = AutoModelForSeq2SeqLM.from_pretrained(
    NLLB_ID, device_map={"": "cpu"}, torch_dtype=torch.float32
).eval()
# Helper to get language token id across transformers versions
def _get_lang_id(tok, code: str) -> int:
    if hasattr(tok, "lang_code_to_id") and code in tok.lang_code_to_id:
        return tok.lang_code_to_id[code]
    if hasattr(tok, "lang_token_to_id") and code in tok.lang_token_to_id:
        return tok.lang_token_to_id[code]
    try:
        tid = tok.convert_tokens_to_ids(code)
        if tid is not None and tid != tok.unk_token_id:
            return tid
    except Exception:
        pass
    if hasattr(tok, "get_added_vocab"):
        added = tok.get_added_vocab()
        if code in added:
            return added[code]
    raise KeyError(f"Language code not found: {code}")
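# Usage sketch (codes follow the NLLB-200 convention, e.g. "som_Latn", "eng_Latn"):
#   _get_lang_id(mt_tok, "som_Latn")  # id later passed as forced_bos_token_id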
def _nllb_translate(text: str, src_code: str, tgt_code: str, max_new: int = 256) -> str:
    mt_tok.src_lang = src_code
    enc = mt_tok(text, return_tensors="pt", truncation=True, max_length=768)
    tgt_id = _get_lang_id(mt_tok, tgt_code)
    with torch.inference_mode():
        out = mt_mdl.generate(
            **enc,
            forced_bos_token_id=tgt_id,
            max_new_tokens=max_new,
            do_sample=False,
            num_beams=1,  # num_beams=4 improves quality but is much slower on CPU
            repetition_penalty=1.02,
        )
    return mt_tok.batch_decode(out, skip_special_tokens=True)[0].strip()
def so_to_en(t: str) -> str:
    return _nllb_translate(t, "som_Latn", "eng_Latn", max_new=220)

def en_to_so(t: str) -> str:
    txt = _nllb_translate(t, "eng_Latn", "som_Latn", max_new=320)
    # light cleanup of repeats
    txt = re.sub(r'(?:\b[a-z]\)\s*){2,}', '', txt, flags=re.IGNORECASE)  # drop "a) a) a)"
    txt = re.sub(r'(\b\w+\b)(?:\s+\1){2,}', r'\1', txt)  # same word 3+ times -> once
    return txt.strip()
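# Cleanup sketch: the two re.sub calls above would turn a degenerate output
# like "a) a) a) biyaha biyaha biyaha yaree" into "biyaha yaree".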
print("MT smoke:", so_to_en("Biyaha roobka badan; sidee u yareeyaa qudhunka basasha?"))
# ---- Your router + helpers (exact logic) ----
AGRI_HINT = (
    "Respond with ONLY bullet points. Each line MUST start with '- '. "
    "No headings, no questions, no meta text. "
    "Include concrete rates (g/L, ml/L, kg/ha), intervals (days), thresholds, "
    "timings, and safety. For bacterial diseases: NO curative chemicals; "
    "emphasize sanitation, water management, curing/storage; copper products "
    "only as protectants per label."
)
# crop guess (very light)
CROP_HINTS = {
    "basal": "onions", "basasha": "onions",
    "yaanyo": "tomatoes", "yaanyada": "tomatoes",
    "digir": "beans", "digirta": "beans",
    "galley": "maize", "masaggo": "sorghum", "qamad": "wheat",
}
def guess_crop(so_q: str) -> str:
    q = so_q.lower()
    for k, v in CROP_HINTS.items():
        if re.search(rf"\b{k}\b", q):
            return v
    return ""
def polish_en_question(en_q_raw: str, crop: str) -> str:
    # Keep the user intent; just add formatting guidance.
    # NOTE: `crop` is currently unused here; the crop hint reaches the model
    # via agri_answer_bullets_en(..., crop_hint=...) instead.
    base = en_q_raw.strip()
    if len(base) < 8:
        base = "Give practical, field-ready advice for this farm question."
    guide = (" Answer ONLY in short bullet points that each start with '- '. "
             "Include specific rates (g/L, ml/L, kg/ha), days/intervals, timings and safety.")
    return base + " " + guide
# Preserve chemical and unit tokens
CHEM_KEEP = [
    r"\bNPK\b", r"\bDAP\b", r"\bMOP\b", r"\bK2O\b",
    r"\bmancozeb\b", r"\bcopper oxychloride\b", r"\bstreptomycin\b",
    r"\bETc\b", r"\bEC\b",
]
UNIT_PAT = [
    r"\b\d+(\.\d+)?\s*(kg|g|ml|L)\s*/\s*(ha|L)\b",
    r"\b\d+(\.\d+)?\s*ppm\b",
    r"\b\d+(\.\d+)?\s*%\b",
    r"\b\d+\s*-\s*\d+\s*(days?|maalmo?)\b",
    r"\b\d+\s*(days?|maalmo?)\b",
]
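# Examples of spans these patterns protect: "2 g/L", "50 ppm", "1.5%",
# "7-10 days", "5 maalmo" (Somali: days)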
def _protect_terms(s: str):
    placeholders = {}
    idx = 0
    for pat in CHEM_KEEP + UNIT_PAT:
        for m in re.finditer(pat, s, flags=re.IGNORECASE):
            span = m.group(0)
            key = f"__P{idx}__"
            s = s.replace(span, key, 1)
            placeholders[key] = span
            idx += 1
    return s, placeholders

def _restore_terms(s: str, placeholders: dict):
    for k, v in placeholders.items():
        s = s.replace(k, v)
    return s
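# Round-trip sketch:
#   s, ph = _protect_terms("Apply mancozeb 2 g/L every 7 days")
#   # s == "Apply __P0__ __P1__ every __P2__"; MT never sees the raw terms
#   _restore_terms(s, ph)  # -> "Apply mancozeb 2 g/L every 7 days"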
@torch.inference_mode()
def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=80) -> list[str]:
    crop = f" Crop: {crop_hint}." if crop_hint else ""
    prompt = f"<context> Somali original: {so_original}.{crop} {AGRI_HINT} <user> {q_en} <assistant>"
    enc = ag_tok(prompt, return_tensors="pt")
    # Model may be on CPU or GPU already; just move inputs to the same device.
    dev = {k: v.to(ag_mdl.device) for k, v in enc.items()}
    out = ag_mdl.generate(
        **dev,
        max_new_tokens=max_new,
        do_sample=False,
        repetition_penalty=1.05,
        eos_token_id=ag_tok.eos_token_id,
        pad_token_id=ag_tok.eos_token_id,
    )
    cont = out[0][enc["input_ids"].shape[1]:]
    txt = ag_tok.decode(cont, skip_special_tokens=True).strip()
    # 1) Preferred: lines that start with "- "
    lines = [ln.strip() for ln in txt.splitlines()]
    bullets = [ln[2:].strip() for ln in lines if ln.startswith("- ") and len(ln[2:].strip()) > 2]
    # 2) Fallback: split paragraph into short actionable pieces
    if not bullets:
        parts = [p.strip() for p in re.split(r"[•\-\u2013\u2014]|[\.\n;]", txt) if p.strip()]
        bullets = [p for p in parts if len(p) > 3][:10]
    # Deduplicate (case-insensitive) while preserving order
    seen, out_items = set(), []
    for b in bullets:
        k = b.lower()
        if k not in seen:
            seen.add(k)
            out_items.append(b)
    return out_items[:10]
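# Parsing sketch: "- Spray 2 g/L\n- Spray 2 g/L\n- Remove debris"
#   -> ["Spray 2 g/L", "Remove debris"] (dedupe keeps the first occurrence)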
def bullets_en_to_so(bullets_en: list[str]) -> list[str]:
    out = []
    for b in bullets_en:
        prot, ph = _protect_terms(b)
        so = en_to_so(prot)
        so = _restore_terms(so, ph)
        so = so.strip(" .;:•-—")
        if len(so) >= 3:
            out.append(so)
    # final dedupe
    seen, ded = set(), []
    for s in out:
        k = s.lower()
        if k not in seen:
            seen.add(k)
            ded.append(s)
    return ded[:10]
def answer_router(question_so: str) -> str:
    crop = guess_crop(question_so)
    # Somali → English
    en_q_raw = so_to_en(question_so)
    en_q = polish_en_question(en_q_raw, crop)
    # Ask the agri model (max_new=200 gives fuller answers but is slower on CPU)
    bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=80)
    # Fallback message below: "Not sure. Please restate the question briefly
    # (the crop, the problem, and the location)."
    if not bullets_en:
        return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
    # Back to Somali
    bullets_so = bullets_en_to_so(bullets_en)
    if not bullets_so:
        return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
    return "\n".join(f"- {b}" for b in bullets_so)
# ---- Minimal Gradio wrapper (to run this code in a Space) ----
with gr.Blocks(title="Somali Agri (AgriParam + NLLB, CPU)") as demo:
    # UI text is Somali: "Write your question in Somali." / "Question" /
    # "Example: How do I fertilize tomatoes?" / "Generate advice" / "Answer"
    gr.Markdown("## Somali Agri Assistant\nKu qor su'aashaada Af-Soomaali.")
    q = gr.Textbox(label="Su’aal (Af-Soomaali)", placeholder="Tusaale: Sidee yaanyada loo bacrimiyaa?")
    btn = gr.Button("Soo saar talooyin")
    a = gr.Textbox(label="Jawaab")
    btn.click(answer_router, q, a)
    q.submit(answer_router, q, a)
if __name__ == "__main__":
    demo.queue().launch()