import os, re, gc, torch, platform from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM import gradio as gr # ---- env (same intent as your notebook) ---- os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" os.environ["WANDB_DISABLED"] = "true" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" torch.set_num_threads(2) # or 1 on the smallest CPU tier os.environ["OMP_NUM_THREADS"] = "2" os.environ["MKL_NUM_THREADS"] = "2" print("CUDA available:", torch.cuda.is_available()) if torch.cuda.is_available(): print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0)) # On Colab T4 we can try FP16 on GPU; fallback to CPU if it fails. COMPUTE_DTYPE = torch.float16 # ---- AgriParam (exact logic) ---- AGRI_ID = "bharatgenai/AgriParam" ag_tok = AutoTokenizer.from_pretrained(AGRI_ID, use_fast=True, trust_remote_code=True) if ag_tok.pad_token is None: ag_tok.pad_token = ag_tok.eos_token or ag_tok.sep_token ag_mdl = None try: # Try GPU (faster). If OOM or no GPU, we’ll drop to CPU automatically. ag_mdl = AutoModelForCausalLM.from_pretrained( AGRI_ID, torch_dtype=COMPUTE_DTYPE, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=True, ).eval() where = "GPU" except Exception as e: print("GPU load failed ⇒ using CPU:", e) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() ag_mdl = AutoModelForCausalLM.from_pretrained( AGRI_ID, torch_dtype=torch.float32, # safe on CPU device_map={"": "cpu"}, low_cpu_mem_usage=True, trust_remote_code=True, ).eval() where = "CPU" cfg = ag_mdl.generation_config cfg.do_sample = False cfg.temperature = None if cfg.pad_token_id is None: cfg.pad_token_id = ag_tok.eos_token_id if cfg.eos_token_id is None: cfg.eos_token_id = ag_tok.eos_token_id print(f"AgriParam loaded on {where}") # ---- NLLB translator (exact logic) ---- NLLB_ID = "facebook/nllb-200-distilled-600M" mt_tok = AutoTokenizer.from_pretrained(NLLB_ID) # fast is fine here mt_mdl = AutoModelForSeq2SeqLM.from_pretrained( NLLB_ID, device_map={"": "cpu"}, torch_dtype=torch.float32 ).eval() # Helper to get language token id across transformers versions def _get_lang_id(tok, code: str) -> int: if hasattr(tok, "lang_code_to_id") and code in tok.lang_code_to_id: return tok.lang_code_to_id[code] if hasattr(tok, "lang_token_to_id") and code in tok.lang_token_to_id: return tok.lang_token_to_id[code] try: tid = tok.convert_tokens_to_ids(code) if tid is not None and tid != tok.unk_token_id: return tid except Exception: pass if hasattr(tok, "get_added_vocab"): added = tok.get_added_vocab() if code in added: return added[code] raise KeyError(f"Language code not found: {code}") def _nllb_translate(text: str, src_code: str, tgt_code: str, max_new: int = 256) -> str: mt_tok.src_lang = src_code enc = mt_tok(text, return_tensors="pt", truncation=True, max_length=768) tgt_id = _get_lang_id(mt_tok, tgt_code) with torch.inference_mode(): out = mt_mdl.generate( **enc, forced_bos_token_id=tgt_id, max_new_tokens=max_new, do_sample=False, # num_beams=4, num_beams=1, repetition_penalty=1.02, ) return mt_tok.batch_decode(out, skip_special_tokens=True)[0].strip() def so_to_en(t: str) -> str: return _nllb_translate(t, "som_Latn", "eng_Latn", max_new=220) def en_to_so(t: str) -> str: txt = _nllb_translate(t, "eng_Latn", "som_Latn", max_new=320) # light cleanup of repeats txt = re.sub(r'(?:\b[a-z]\)\s*){2,}', '', txt, flags=re.IGNORECASE) # drop "a) a) a)" txt = re.sub(r'(\b\w+\b)(?:\s+\1){2,}', r'\1', txt) # word x3+ -> single return txt.strip() print("MT smoke:", so_to_en("Biyaha roobka badan; sidee u yareeyaa qudhunka basasha?")) # ---- Your router + helpers (exact logic) ---- AGRI_HINT = ( "Respond with ONLY bullet points. Each line MUST start with '- '. " "No headings, no questions, no meta text. " "Include concrete rates (g/L, ml/L, kg/ha), intervals (days), thresholds, " "timings, and safety. For bacterial diseases: NO curative chemicals; " "emphasize sanitation, water management, curing/storage; copper products " "only as protectants per label." ) # crop guess (very light) CROP_HINTS = { "basal":"onions","basasha":"onions", "yaanyo":"tomatoes","yaanyada":"tomatoes", "digir":"beans","digirta":"beans", "galley":"maize","masaggo":"sorghum","qamad":"wheat", } def guess_crop(so_q: str) -> str: q = so_q.lower() for k,v in CROP_HINTS.items(): if re.search(rf"\b{k}\b", q): return v return "" def polish_en_question(en_q_raw: str, crop: str) -> str: # Keep the user intent; just add formatting guidance. base = en_q_raw.strip() if len(base) < 8: base = "Give practical, field-ready advice for this farm question." guide = (" Answer ONLY in short bullet points that each start with '- '. " "Include specific rates (g/L, ml/L, kg/ha), days/intervals, timings and safety.") return base + " " + guide # Preserve chemical and unit tokens CHEM_KEEP = [ r"\bNPK\b", r"\bDAP\b", r"\bMOP\b", r"\bK2O\b", r"\bmancozeb\b", r"\bcopper oxychloride\b", r"\bstreptomycin\b", r"\bETc\b", r"\bEC\b" ] UNIT_PAT = [ r"\b\d+(\.\d+)?\s*(kg|g|ml|L)\s*/\s*(ha|L)\b", r"\b\d+(\.\d+)?\s*ppm\b", r"\b\d+(\.\d+)?\s*%\b", r"\b\d+\s*-\s*\d+\s*(days?|maalmo?)\b", r"\b\d+\s*(days?|maalmo?)\b", ] def _protect_terms(s: str): placeholders = {} idx = 0 for pat in CHEM_KEEP + UNIT_PAT: for m in re.finditer(pat, s, flags=re.IGNORECASE): span = m.group(0) key = f"__P{idx}__" s = s.replace(span, key, 1) placeholders[key] = span idx += 1 return s, placeholders def _restore_terms(s: str, placeholders: dict): for k, v in placeholders.items(): s = s.replace(k, v) return s @torch.inference_mode() def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=80) -> list[str]: crop = f" Crop: {crop_hint}." if crop_hint else "" prompt = f" Somali original: {so_original}.{crop} {AGRI_HINT} {q_en} " enc = ag_tok(prompt, return_tensors="pt") # Model may be on CPU or GPU already; just move inputs to same device dev = {k: v.to(ag_mdl.device) for k, v in enc.items()} out = ag_mdl.generate( **dev, max_new_tokens=max_new, do_sample=False, repetition_penalty=1.05, eos_token_id=ag_tok.eos_token_id, pad_token_id=ag_tok.eos_token_id, ) cont = out[0][enc["input_ids"].shape[1]:] txt = ag_tok.decode(cont, skip_special_tokens=True).strip() # 1) Preferred: lines that start with "- " lines = [ln.strip() for ln in txt.splitlines()] bullets = [ln[2:].strip() for ln in lines if ln.startswith("- ") and len(ln[2:].strip()) > 2] # 2) Fallback: split paragraph into short actionable pieces if not bullets: parts = [p.strip() for p in re.split(r"[•\-\u2013\u2014]|[\.\n;]", txt) if p.strip()] bullets = [p for p in parts if len(p) > 3][:10] # Deduplicate (case-insensitive) while preserving order seen, out_items = set(), [] for b in bullets: k = b.lower() if k not in seen: seen.add(k) out_items.append(b) return out_items[:10] def bullets_en_to_so(bullets_en: list[str]) -> list[str]: out = [] for b in bullets_en: prot, ph = _protect_terms(b) so = en_to_so(prot) so = _restore_terms(so, ph) so = so.strip(" .;:•-—") if len(so) >= 3: out.append(so) # final dedupe seen, ded = set(), [] for s in out: k = s.lower() if k not in seen: seen.add(k) ded.append(s) return ded[:10] def answer_router(question_so: str) -> str: crop = guess_crop(question_so) # Somali → English en_q_raw = so_to_en(question_so) en_q = polish_en_question(en_q_raw, crop) # Ask the agri model # bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=200) bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=80) if not bullets_en: return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)." # Back to Somali bullets_so = bullets_en_to_so(bullets_en) if not bullets_so: return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)." return "\n".join(f"- {b}" for b in bullets_so) # ---- Minimal Gradio wrapper (to run this code in a Space) ---- with gr.Blocks(title="Somali Agri (AgriParam + NLLB, CPU)") as demo: gr.Markdown("## Somali Agri Assistant\nKu qor su'aashaada Af-Soomaali.") q = gr.Textbox(label="Su’aal (Af-Soomaali)", placeholder="Tusaale: Sidee yaanyada loo bacrimiyaa?") btn = gr.Button("Soo saar talooyin") a = gr.Textbox(label="Jawaab") btn.click(answer_router, q, a) q.submit(answer_router, q, a) if __name__ == "__main__": demo.queue().launch()