"""Persian zero-shot NER demo: prompts an mT5 model on CPU and serves it via Gradio."""

import json
import re

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Lightweight, CPU-friendly model.
MODEL_ID = "google/mt5-small"

LABELS = ["PERSON", "ORG", "LOC", "GPE", "DATE", "TIME", "PRODUCT", "EVENT"]

# Greedy match from the first "{" to the last "}" in the model output.
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}")
# A comma directly before a closing brace/bracket (invalid in strict JSON).
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")


def build_prompt(text: str, labels=LABELS) -> str:
    """Build the Persian instruction prompt asking the model for NER output as JSON.

    Args:
        text: The input text to analyse.
        labels: Iterable of allowed entity label names (defaults to ``LABELS``).

    Returns:
        The full prompt string, ending with the input text.
    """
    return (
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
        f"لیبل‌های مجاز: {', '.join(labels)}.\n"
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:\n"
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
        "هیچ متن دیگری ننویس؛ فقط JSON.\n\n"
        f"متن: {text}\n"
    )


def _try_parse_object(raw: str):
    """Return ``raw`` parsed as a JSON object, or ``None`` if it is not one."""
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError covers
        # pathologically nested output.
        return None
    return parsed if isinstance(parsed, dict) else None


def extract_json(s: str) -> dict:
    """Extract the JSON object embedded in free-form model output.

    The span from the first ``{`` to the last ``}`` is parsed. If strict
    parsing fails, trailing commas before ``}`` / ``]`` are removed and
    parsing is retried once.

    Args:
        s: Raw decoded model output (``None`` is treated as empty).

    Returns:
        The parsed dict, or ``{"entities": []}`` when nothing parseable is found.
    """
    match = _JSON_OBJECT_RE.search(s or "")
    if match is None:
        return {"entities": []}

    raw = match.group(0)
    parsed = _try_parse_object(raw)
    if parsed is None:
        parsed = _try_parse_object(_TRAILING_COMMA_RE.sub(r"\1", raw))
    return parsed if parsed is not None else {"entities": []}


# Lazily loaded singletons; populated on the first call to load_model().
_tokenizer = None
_model = None


def load_model():
    """Load the tokenizer and model on first use and cache them at module level.

    No device is selected explicitly, so the library default is used.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model


def _clean_str(value) -> str:
    """Return ``value`` as a stripped string; ``None`` becomes ``""``."""
    # Without the None check a JSON null would turn into the literal "None".
    return "" if value is None else str(value).strip()


def _is_offset(value) -> bool:
    """Return True for a non-negative int that is not a bool."""
    # bool is a subclass of int, so it has to be excluded explicitly.
    return isinstance(value, int) and not isinstance(value, bool) and value >= 0


def _resolve_span(text: str, mention: str, start, end):
    """Return ``(start, end)`` character offsets of ``mention`` within ``text``.

    Model-supplied offsets are kept only when they slice ``text`` to exactly
    ``mention``. Otherwise the first occurrence of ``mention`` is used, and
    ``(0, 0)`` is returned when the mention does not occur in ``text``.
    """
    if _is_offset(start) and _is_offset(end) and text[start:end] == mention:
        return start, end
    idx = text.find(mention)
    if idx >= 0:
        return idx, idx + len(mention)
    return 0, 0


def _normalize_entities(data, text: str) -> list:
    """Convert parsed model output into a clean list of entity dicts.

    Items that are not dicts, or that lack a non-empty ``text`` or ``label``,
    are skipped. A missing or non-list ``entities`` value yields ``[]``.
    """
    raw_entities = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(raw_entities, list):
        return []

    entities = []
    for item in raw_entities:
        if not isinstance(item, dict):
            continue
        mention = _clean_str(item.get("text"))
        label = _clean_str(item.get("label"))
        if not mention or not label:
            continue
        start, end = _resolve_span(text, mention, item.get("start"), item.get("end"))
        entities.append(
            {"text": mention, "label": label, "start": start, "end": end}
        )
    return entities


def ner_infer(text, max_new_tokens=192):
    """Run prompt-based NER on ``text`` and return the normalized entities.

    Args:
        text: Input text; ``None`` or blank input returns an empty result
            without loading the model.
        max_new_tokens: Generation budget; converted with ``int()`` because
            the UI slider may deliver a float.

    Returns:
        ``{"entities": [{"text", "label", "start", "end"}, ...]}``.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}

    tok, model = load_model()
    inputs = tok(build_prompt(text), return_tensors="pt")
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        # Greedy decoding is deterministic. `temperature` is deliberately not
        # passed: it only applies when sampling and 0.0 is not a valid value.
        do_sample=False,
        # pad_token_id must be set for some T5/mT5 variants.
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    decoded = tok.decode(gen_ids[0], skip_special_tokens=True)
    return {"entities": _normalize_entities(extract_json(decoded), text)}


with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — CPU version (mT5-small)")
    inp = gr.Textbox(
        label="متن فارسی",
        lines=4,
        value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.",
    )
    max_tok = gr.Slider(64, 512, value=192, step=16, label="Max new tokens (CPU)")
    btn = gr.Button("Extract Entities")
    out = gr.JSON(label="خروجی JSON")
    btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)

if __name__ == "__main__":
    demo.launch()