Spaces:
Sleeping
Sleeping
File size: 9,438 Bytes
7993f15 270089c 7993f15 270089c 8584a8b 7993f15 270089c 7993f15 270089c 7993f15 8584a8b 7993f15 270089c 7993f15 270089c 7993f15 8584a8b 7993f15 270089c 7993f15 270089c 7993f15 270089c 7993f15 270089c 7993f15 270089c 7993f15 270089c 7993f15 8584a8b 7993f15 270089c 7993f15 270089c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
import os, re, gc, torch, platform
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
# ---- env (same intent as your notebook) ----
# Disable HF telemetry / W&B; let the CUDA allocator grow in expandable
# segments (reduces fragmentation-related OOMs on small GPUs).
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["WANDB_DISABLED"] = "true"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Keep CPU thread usage low so a small shared Space doesn't thrash.
torch.set_num_threads(2) # or 1 on the smallest CPU tier
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))
# On Colab T4 we can try FP16 on GPU; fallback to CPU if it fails.
COMPUTE_DTYPE = torch.float16
# ---- AgriParam (exact logic) ----
AGRI_ID = "bharatgenai/AgriParam"
ag_tok = AutoTokenizer.from_pretrained(AGRI_ID, use_fast=True, trust_remote_code=True)
if ag_tok.pad_token is None:
    # Some causal-LM tokenizers ship without a pad token; reuse EOS (or SEP).
    ag_tok.pad_token = ag_tok.eos_token or ag_tok.sep_token
ag_mdl = None
try:
    # Try GPU (faster). If OOM or no GPU, we’ll drop to CPU automatically.
    ag_mdl = AutoModelForCausalLM.from_pretrained(
        AGRI_ID,
        torch_dtype=COMPUTE_DTYPE,
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    where = "GPU"
except Exception as e:
    # Any load failure (OOM, missing CUDA, dtype issues) falls back to fp32 CPU.
    print("GPU load failed ⇒ using CPU:", e)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    ag_mdl = AutoModelForCausalLM.from_pretrained(
        AGRI_ID,
        torch_dtype=torch.float32, # safe on CPU
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    where = "CPU"
# Force greedy decoding and make sure generate() has pad/eos ids to work with.
cfg = ag_mdl.generation_config
cfg.do_sample = False
cfg.temperature = None
if cfg.pad_token_id is None: cfg.pad_token_id = ag_tok.eos_token_id
if cfg.eos_token_id is None: cfg.eos_token_id = ag_tok.eos_token_id
print(f"AgriParam loaded on {where}")
# ---- NLLB translator (exact logic) ----
NLLB_ID = "facebook/nllb-200-distilled-600M"
mt_tok = AutoTokenizer.from_pretrained(NLLB_ID) # fast is fine here
# Translation model pinned to CPU in fp32 so it never competes with the
# (possibly GPU-resident) AgriParam model for VRAM.
mt_mdl = AutoModelForSeq2SeqLM.from_pretrained(
    NLLB_ID, device_map={"": "cpu"}, torch_dtype=torch.float32
).eval()
# Helper to get language token id across transformers versions
def _get_lang_id(tok, code: str) -> int:
if hasattr(tok, "lang_code_to_id") and code in tok.lang_code_to_id:
return tok.lang_code_to_id[code]
if hasattr(tok, "lang_token_to_id") and code in tok.lang_token_to_id:
return tok.lang_token_to_id[code]
try:
tid = tok.convert_tokens_to_ids(code)
if tid is not None and tid != tok.unk_token_id:
return tid
except Exception:
pass
if hasattr(tok, "get_added_vocab"):
added = tok.get_added_vocab()
if code in added:
return added[code]
raise KeyError(f"Language code not found: {code}")
def _nllb_translate(text: str, src_code: str, tgt_code: str, max_new: int = 256) -> str:
    """Translate *text* from *src_code* to *tgt_code* with the CPU NLLB model.

    Greedy single-beam decoding with a mild repetition penalty; the source
    is truncated to 768 tokens. Returns the stripped translation string.
    """
    mt_tok.src_lang = src_code
    batch = mt_tok(text, return_tensors="pt", truncation=True, max_length=768)
    forced_bos = _get_lang_id(mt_tok, tgt_code)
    with torch.inference_mode():
        generated = mt_mdl.generate(
            **batch,
            forced_bos_token_id=forced_bos,  # NLLB selects target language via BOS
            max_new_tokens=max_new,
            do_sample=False,
            num_beams=1,  # num_beams=4 was tried but is too slow on CPU
            repetition_penalty=1.02,
        )
    decoded = mt_tok.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].strip()
def so_to_en(t: str) -> str:
    """Somali → English translation (capped at 220 new tokens)."""
    return _nllb_translate(t, src_code="som_Latn", tgt_code="eng_Latn", max_new=220)
def en_to_so(t: str) -> str:
    """English → Somali translation plus light cleanup of degenerate repeats."""
    translated = _nllb_translate(t, "eng_Latn", "som_Latn", max_new=320)
    # Drop runs like "a) a) a)" that NLLB sometimes emits for list markers.
    translated = re.sub(r'(?:\b[a-z]\)\s*){2,}', '', translated, flags=re.IGNORECASE)
    # Collapse a word repeated 3+ times in a row down to a single copy.
    translated = re.sub(r'(\b\w+\b)(?:\s+\1){2,}', r'\1', translated)
    return translated.strip()
# Startup sanity check that the MT pipeline works end to end.
print("MT smoke:", so_to_en("Biyaha roobka badan; sidee u yareeyaa qudhunka basasha?"))
# ---- Your router + helpers (exact logic) ----
# System-style formatting instructions embedded in every AgriParam prompt.
AGRI_HINT = (
    "Respond with ONLY bullet points. Each line MUST start with '- '. "
    "No headings, no questions, no meta text. "
    "Include concrete rates (g/L, ml/L, kg/ha), intervals (days), thresholds, "
    "timings, and safety. For bacterial diseases: NO curative chemicals; "
    "emphasize sanitation, water management, curing/storage; copper products "
    "only as protectants per label."
)
# crop guess (very light)
# Somali crop words (indefinite/definite forms) mapped to English crop names.
CROP_HINTS = {
    "basal": "onions", "basasha": "onions",
    "yaanyo": "tomatoes", "yaanyada": "tomatoes",
    "digir": "beans", "digirta": "beans",
    "galley": "maize", "masaggo": "sorghum", "qamad": "wheat",
}

def guess_crop(so_q: str) -> str:
    """Return the English crop name hinted at by a Somali question, or ''.

    First whole-word match in ``CROP_HINTS`` insertion order wins.
    """
    lowered = so_q.lower()
    for somali_word, crop_en in CROP_HINTS.items():
        if re.search(rf"\b{somali_word}\b", lowered):
            return crop_en
    return ""
def polish_en_question(en_q_raw: str, crop: str) -> str:
    """Append bullet-formatting guidance to the translated English question.

    The user's intent is kept untouched. ``crop`` is accepted for interface
    compatibility but is not used here — the crop hint is injected into the
    model prompt elsewhere.
    """
    question = en_q_raw.strip()
    # A near-empty translation usually means the MT step failed; substitute
    # a generic ask so the agri model still produces something useful.
    if len(question) < 8:
        question = "Give practical, field-ready advice for this farm question."
    guidance = (" Answer ONLY in short bullet points that each start with '- '. "
                "Include specific rates (g/L, ml/L, kg/ha), days/intervals, timings and safety.")
    return question + " " + guidance
# Preserve chemical and unit tokens
CHEM_KEEP = [
r"\bNPK\b", r"\bDAP\b", r"\bMOP\b", r"\bK2O\b",
r"\bmancozeb\b", r"\bcopper oxychloride\b", r"\bstreptomycin\b",
r"\bETc\b", r"\bEC\b"
]
UNIT_PAT = [
r"\b\d+(\.\d+)?\s*(kg|g|ml|L)\s*/\s*(ha|L)\b",
r"\b\d+(\.\d+)?\s*ppm\b",
r"\b\d+(\.\d+)?\s*%\b",
r"\b\d+\s*-\s*\d+\s*(days?|maalmo?)\b",
r"\b\d+\s*(days?|maalmo?)\b",
]
def _protect_terms(s: str):
placeholders = {}
idx = 0
for pat in CHEM_KEEP + UNIT_PAT:
for m in re.finditer(pat, s, flags=re.IGNORECASE):
span = m.group(0)
key = f"__P{idx}__"
s = s.replace(span, key, 1)
placeholders[key] = span
idx += 1
return s, placeholders
def _restore_terms(s: str, placeholders: dict):
for k, v in placeholders.items():
s = s.replace(k, v)
return s
@torch.inference_mode()
def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=80) -> list[str]:
    """Ask AgriParam the English question; return up to 10 bullet strings.

    The Somali original (and optional crop hint) is embedded in the prompt
    as context. Output is parsed from '- ' lines with a sentence-split
    fallback, then deduplicated case-insensitively.
    """
    crop_ctx = f" Crop: {crop_hint}." if crop_hint else ""
    prompt = f"<context> Somali original: {so_original}.{crop_ctx} {AGRI_HINT} <user> {q_en} <assistant>"
    enc = ag_tok(prompt, return_tensors="pt")
    # Model may be on CPU or GPU already; just move inputs to the same device.
    inputs = {name: tensor.to(ag_mdl.device) for name, tensor in enc.items()}
    generated = ag_mdl.generate(
        **inputs,
        max_new_tokens=max_new,
        do_sample=False,
        repetition_penalty=1.05,
        eos_token_id=ag_tok.eos_token_id,
        pad_token_id=ag_tok.eos_token_id,
    )
    # Keep only the continuation (strip the echoed prompt tokens).
    continuation = generated[0][enc["input_ids"].shape[1]:]
    text = ag_tok.decode(continuation, skip_special_tokens=True).strip()
    # 1) Preferred: lines that start with "- ".
    bullets = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line.startswith("- ") and len(line[2:].strip()) > 2:
            bullets.append(line[2:].strip())
    # 2) Fallback: split the paragraph into short actionable pieces.
    if not bullets:
        pieces = re.split(r"[•\-\u2013\u2014]|[\.\n;]", text)
        bullets = [p.strip() for p in pieces if len(p.strip()) > 3][:10]
    # Deduplicate (case-insensitive) while preserving order.
    unique, seen = [], set()
    for bullet in bullets:
        key = bullet.lower()
        if key not in seen:
            seen.add(key)
            unique.append(bullet)
    return unique[:10]
def bullets_en_to_so(bullets_en: list[str]) -> list[str]:
    """Translate English bullets to Somali, shielding technical tokens.

    Chemical names and unit expressions are masked before translation and
    restored afterwards. Near-empty results are dropped, duplicates removed
    case-insensitively, and the list is capped at 10 items.
    """
    translated = []
    for bullet in bullets_en:
        masked, placeholders = _protect_terms(bullet)
        somali = _restore_terms(en_to_so(masked), placeholders)
        somali = somali.strip(" .;:•-—")
        if len(somali) >= 3:
            translated.append(somali)
    # final dedupe, preserving order
    deduped, seen = [], set()
    for item in translated:
        key = item.lower()
        if key not in seen:
            seen.add(key)
            deduped.append(item)
    return deduped[:10]
def answer_router(question_so: str) -> str:
    """Full pipeline: Somali question → English → AgriParam → Somali bullets.

    Returns a newline-joined bullet list, or a Somali "please rephrase"
    message when either stage produces nothing usable.
    """
    fallback = "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
    crop = guess_crop(question_so)
    # Somali → English
    en_question = polish_en_question(so_to_en(question_so), crop)
    # Ask the agri model (max_new kept at 80 for CPU latency; 200 was
    # used in the original notebook).
    bullets_en = agri_answer_bullets_en(en_question, question_so, crop_hint=crop, max_new=80)
    if not bullets_en:
        return fallback
    # Back to Somali
    bullets_so = bullets_en_to_so(bullets_en)
    if not bullets_so:
        return fallback
    return "\n".join(f"- {b}" for b in bullets_so)
# ---- Minimal Gradio wrapper (to run this code in a Space) ----
with gr.Blocks(title="Somali Agri (AgriParam + NLLB, CPU)") as demo:
    gr.Markdown("## Somali Agri Assistant\nKu qor su'aashaada Af-Soomaali.")
    q = gr.Textbox(label="Su’aal (Af-Soomaali)", placeholder="Tusaale: Sidee yaanyada loo bacrimiyaa?")
    btn = gr.Button("Soo saar talooyin")
    a = gr.Textbox(label="Jawaab")
    # Both the button click and pressing Enter in the textbox run the router.
    btn.click(answer_router, q, a)
    q.submit(answer_router, q, a)
if __name__ == "__main__":
    # queue() serializes requests so the single CPU-bound model isn't hit concurrently.
    demo.queue().launch()
|