# Hugging Face Space — Somali agricultural assistant (app source follows).
# Somali Agri Assistant — AgriParam (causal LM) + NLLB translation + Gradio UI.
#
# Environment knobs must be exported BEFORE importing torch/transformers:
# OMP_NUM_THREADS / MKL_NUM_THREADS are read when the math libraries
# initialise at import time, and HF_HUB_DISABLE_TELEMETRY is checked on
# transformers import — setting them afterwards (as the original code did)
# is silently ignored. PYTORCH_CUDA_ALLOC_CONF is read lazily at the first
# CUDA allocation, but grouping it here keeps all env config in one place.
import os

os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"   # no usage pings to the HF Hub
os.environ["WANDB_DISABLED"] = "true"          # never auto-start a W&B run
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # fewer fragmentation OOMs
os.environ["OMP_NUM_THREADS"] = "2"            # small CPU tier: cap math threads
os.environ["MKL_NUM_THREADS"] = "2"

import gc
import platform
import re

import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

import gradio as gr

torch.set_num_threads(2)  # or 1 on the smallest CPU tier

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))

# On a GPU (e.g. Colab T4) try FP16; the AgriParam loader below falls back
# to CPU/FP32 automatically if the FP16 GPU load fails.
COMPUTE_DTYPE = torch.float16
# ---- AgriParam (exact logic) ----
# Load the BharatGen AgriParam causal LM: try GPU FP16 first, fall back to
# CPU FP32 if anything goes wrong (no GPU, OOM, dtype issues).
AGRI_ID = "bharatgenai/AgriParam"
ag_tok = AutoTokenizer.from_pretrained(AGRI_ID, use_fast=True, trust_remote_code=True)
if ag_tok.pad_token is None:
    # Some causal-LM tokenizers ship without a pad token; borrow EOS (or SEP).
    ag_tok.pad_token = ag_tok.eos_token or ag_tok.sep_token
ag_mdl = None
try:
    # Try GPU (faster). If OOM or no GPU, we’ll drop to CPU automatically.
    ag_mdl = AutoModelForCausalLM.from_pretrained(
        AGRI_ID,
        torch_dtype=COMPUTE_DTYPE,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    where = "GPU"
except Exception as e:
    print("GPU load failed ⇒ using CPU:", e)
    # Free whatever the failed load left behind before retrying on CPU.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    ag_mdl = AutoModelForCausalLM.from_pretrained(
        AGRI_ID,
        torch_dtype=torch.float32,  # safe on CPU
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    where = "CPU"
# Force deterministic (greedy) decoding and make sure pad/eos ids are set,
# so generate() never warns or samples.
cfg = ag_mdl.generation_config
cfg.do_sample = False
cfg.temperature = None
if cfg.pad_token_id is None: cfg.pad_token_id = ag_tok.eos_token_id
if cfg.eos_token_id is None: cfg.eos_token_id = ag_tok.eos_token_id
print(f"AgriParam loaded on {where}")
# ---- NLLB translator (exact logic) ----
# NLLB is pinned to CPU/FP32: the distilled 600M model is small enough for
# CPU, and this keeps any GPU memory free for AgriParam.
NLLB_ID = "facebook/nllb-200-distilled-600M"
mt_tok = AutoTokenizer.from_pretrained(NLLB_ID)  # fast is fine here
mt_mdl = AutoModelForSeq2SeqLM.from_pretrained(
    NLLB_ID, device_map={"": "cpu"}, torch_dtype=torch.float32
).eval()
| # Helper to get language token id across transformers versions | |
| def _get_lang_id(tok, code: str) -> int: | |
| if hasattr(tok, "lang_code_to_id") and code in tok.lang_code_to_id: | |
| return tok.lang_code_to_id[code] | |
| if hasattr(tok, "lang_token_to_id") and code in tok.lang_token_to_id: | |
| return tok.lang_token_to_id[code] | |
| try: | |
| tid = tok.convert_tokens_to_ids(code) | |
| if tid is not None and tid != tok.unk_token_id: | |
| return tid | |
| except Exception: | |
| pass | |
| if hasattr(tok, "get_added_vocab"): | |
| added = tok.get_added_vocab() | |
| if code in added: | |
| return added[code] | |
| raise KeyError(f"Language code not found: {code}") | |
def _nllb_translate(text: str, src_code: str, tgt_code: str, max_new: int = 256) -> str:
    """Translate ``text`` from ``src_code`` to ``tgt_code`` with NLLB.

    Greedy decoding (num_beams=1) with a mild repetition penalty; input is
    truncated to 768 source tokens. Returns the stripped translation.
    """
    mt_tok.src_lang = src_code
    batch = mt_tok(text, return_tensors="pt", truncation=True, max_length=768)
    forced_bos = _get_lang_id(mt_tok, tgt_code)
    with torch.inference_mode():
        generated = mt_mdl.generate(
            **batch,
            forced_bos_token_id=forced_bos,
            max_new_tokens=max_new,
            do_sample=False,
            # num_beams=4,
            num_beams=1,
            repetition_penalty=1.02,
        )
    decoded = mt_tok.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].strip()
def so_to_en(t: str) -> str:
    """Somali → English via NLLB (shorter budget: questions are short)."""
    return _nllb_translate(t, "som_Latn", "eng_Latn", max_new=220)
def en_to_so(t: str) -> str:
    """English → Somali via NLLB, plus light cleanup of MT repeats."""
    translated = _nllb_translate(t, "eng_Latn", "som_Latn", max_new=320)
    # Collapse list-marker stutter such as "a) a) a)".
    translated = re.sub(r'(?:\b[a-z]\)\s*){2,}', '', translated, flags=re.IGNORECASE)
    # Collapse any word repeated three or more times in a row to one copy.
    translated = re.sub(r'(\b\w+\b)(?:\s+\1){2,}', r'\1', translated)
    return translated.strip()
# Startup smoke test: exercises the full NLLB path once so a broken MT stack
# fails loudly at boot instead of on the first user request.
print("MT smoke:", so_to_en("Biyaha roobka badan; sidee u yareeyaa qudhunka basasha?"))
# ---- Your router + helpers (exact logic) ----
# System-style instruction appended to every AgriParam prompt: forces terse
# '- ' bullet output with concrete agronomic numbers, and bans curative
# chemicals for bacterial diseases.
AGRI_HINT = (
    "Respond with ONLY bullet points. Each line MUST start with '- '. "
    "No headings, no questions, no meta text. "
    "Include concrete rates (g/L, ml/L, kg/ha), intervals (days), thresholds, "
    "timings, and safety. For bacterial diseases: NO curative chemicals; "
    "emphasize sanitation, water management, curing/storage; copper products "
    "only as protectants per label."
)
# crop guess (very light)
# Somali crop keywords (plain and definite-article forms) → English crop
# names. First match wins in guess_crop(), so insertion order matters.
CROP_HINTS = {
    "basal":"onions","basasha":"onions",
    "yaanyo":"tomatoes","yaanyada":"tomatoes",
    "digir":"beans","digirta":"beans",
    "galley":"maize","masaggo":"sorghum","qamad":"wheat",
}
def guess_crop(so_q: str) -> str:
    """Return the English crop name for the first Somali crop keyword found
    in ``so_q`` (whole-word match, case-insensitive), or "" if none match."""
    lowered = so_q.lower()
    return next(
        (crop for kw, crop in CROP_HINTS.items() if re.search(rf"\b{kw}\b", lowered)),
        "",
    )
def polish_en_question(en_q_raw: str, crop: str) -> str:
    """Append bullet-formatting guidance to the translated English question.

    ``crop`` is accepted for interface stability but deliberately unused
    here — the crop hint is injected into the model prompt downstream.
    A very short/empty translation is replaced by a generic request.
    """
    base = en_q_raw.strip()
    if len(base) < 8:  # translation came back empty or garbled
        base = "Give practical, field-ready advice for this farm question."
    guide = (" Answer ONLY in short bullet points that each start with '- '. "
             "Include specific rates (g/L, ml/L, kg/ha), days/intervals, timings and safety.")
    return f"{base} {guide}"
# Preserve chemical and unit tokens
# Regexes whose matches are masked with placeholders before machine
# translation (see _protect_terms) so NLLB cannot mangle product names,
# rates, or intervals, then restored afterwards.
CHEM_KEEP = [
    r"\bNPK\b", r"\bDAP\b", r"\bMOP\b", r"\bK2O\b",
    r"\bmancozeb\b", r"\bcopper oxychloride\b", r"\bstreptomycin\b",
    r"\bETc\b", r"\bEC\b"
]
# Numeric dose/interval patterns: "2 g/L", "500 ppm", "1.5 %", "7-10 days",
# "7 days" (and the Somali "maalmo" day forms).
UNIT_PAT = [
    r"\b\d+(\.\d+)?\s*(kg|g|ml|L)\s*/\s*(ha|L)\b",
    r"\b\d+(\.\d+)?\s*ppm\b",
    r"\b\d+(\.\d+)?\s*%\b",
    r"\b\d+\s*-\s*\d+\s*(days?|maalmo?)\b",
    r"\b\d+\s*(days?|maalmo?)\b",
]
def _protect_terms(s: str):
    """Mask protected chemical/unit spans with placeholders before MT.

    Returns ``(masked_text, {placeholder: original_span})``; the inverse is
    ``_restore_terms``.

    Uses ``re.sub`` with a callback so positions can never go stale: the
    original implementation iterated ``re.finditer`` over a snapshot of the
    string while mutating it with ``str.replace(..., 1)``, so once an earlier
    substitution shifted the text, a later match could replace the wrong
    occurrence (or record a placeholder that never landed in the string).
    """
    placeholders = {}

    def _mask(m):
        # Placeholder index mirrors the original's running counter.
        key = f"__P{len(placeholders)}__"
        placeholders[key] = m.group(0)
        return key

    for pat in CHEM_KEEP + UNIT_PAT:
        s = re.sub(pat, _mask, s, flags=re.IGNORECASE)
    return s, placeholders
| def _restore_terms(s: str, placeholders: dict): | |
| for k, v in placeholders.items(): | |
| s = s.replace(k, v) | |
| return s | |
def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=80) -> list[str]:
    """Ask AgriParam the English question; return up to 10 deduped bullets.

    q_en        -- polished English question (see polish_en_question).
    so_original -- raw Somali question, embedded in the prompt as context.
    crop_hint   -- optional English crop name appended to the context.
    max_new     -- generation budget in new tokens.
    """
    crop = f" Crop: {crop_hint}." if crop_hint else ""
    prompt = f"<context> Somali original: {so_original}.{crop} {AGRI_HINT} <user> {q_en} <assistant>"
    enc = ag_tok(prompt, return_tensors="pt")
    # Model may be on CPU or GPU already; just move inputs to same device
    dev = {k: v.to(ag_mdl.device) for k, v in enc.items()}
    # inference_mode matches _nllb_translate and stops generate() from
    # tracking autograd state (the original ran with grad enabled).
    with torch.inference_mode():
        out = ag_mdl.generate(
            **dev,
            max_new_tokens=max_new,
            do_sample=False,
            repetition_penalty=1.05,
            eos_token_id=ag_tok.eos_token_id,
            pad_token_id=ag_tok.eos_token_id,
        )
    # Keep only the continuation (strip the echoed prompt tokens).
    cont = out[0][enc["input_ids"].shape[1]:]
    txt = ag_tok.decode(cont, skip_special_tokens=True).strip()
    # 1) Preferred: lines that start with "- "
    lines = [ln.strip() for ln in txt.splitlines()]
    bullets = [ln[2:].strip() for ln in lines if ln.startswith("- ") and len(ln[2:].strip()) > 2]
    # 2) Fallback: split paragraph into short actionable pieces
    if not bullets:
        parts = [p.strip() for p in re.split(r"[•\-\u2013\u2014]|[\.\n;]", txt) if p.strip()]
        bullets = [p for p in parts if len(p) > 3][:10]
    # Deduplicate (case-insensitive) while preserving order
    seen, out_items = set(), []
    for b in bullets:
        k = b.lower()
        if k not in seen:
            seen.add(k)
            out_items.append(b)
    return out_items[:10]
def bullets_en_to_so(bullets_en: list[str]) -> list[str]:
    """Translate English bullets into Somali.

    Chemical/unit tokens are masked before MT and restored after, edge
    punctuation is stripped, tiny results are dropped, then the list is
    case-insensitively deduplicated (order preserved) and capped at 10.
    """
    translated = []
    for bullet in bullets_en:
        masked, mapping = _protect_terms(bullet)
        somali = _restore_terms(en_to_so(masked), mapping).strip(" .;:•-—")
        if len(somali) >= 3:
            translated.append(somali)
    # final dedupe
    unique, seen_keys = [], set()
    for item in translated:
        key = item.lower()
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique.append(item)
    return unique[:10]
def answer_router(question_so: str) -> str:
    """Full pipeline: Somali question → English → AgriParam → Somali bullets.

    Returns the answer as "- " bullet lines, or a Somali "not sure, please
    rephrase" message when either stage produces nothing usable.
    """
    fallback = "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
    crop = guess_crop(question_so)
    # Somali → English, then add bullet-formatting guidance for the model.
    en_question = polish_en_question(so_to_en(question_so), crop)
    # Ask the agri model
    # bullets_en = agri_answer_bullets_en(en_question, question_so, crop_hint=crop, max_new=200)
    bullets_en = agri_answer_bullets_en(en_question, question_so, crop_hint=crop, max_new=80)
    if not bullets_en:
        return fallback
    # Back to Somali
    bullets_so = bullets_en_to_so(bullets_en)
    if not bullets_so:
        return fallback
    return "\n".join(f"- {b}" for b in bullets_so)
# ---- Minimal Gradio wrapper (to run this code in a Space) ----
# One textbox in, one textbox out; both the button and Enter in the question
# box route through answer_router.
with gr.Blocks(title="Somali Agri (AgriParam + NLLB, CPU)") as demo:
    gr.Markdown("## Somali Agri Assistant\nKu qor su'aashaada Af-Soomaali.")
    question_box = gr.Textbox(label="Su’aal (Af-Soomaali)", placeholder="Tusaale: Sidee yaanyada loo bacrimiyaa?")
    ask_button = gr.Button("Soo saar talooyin")
    answer_box = gr.Textbox(label="Jawaab")
    ask_button.click(answer_router, question_box, answer_box)
    question_box.submit(answer_router, question_box, answer_box)

if __name__ == "__main__":
    # queue() serialises requests — required since the models are shared
    # module-level singletons.
    demo.queue().launch()