Spaces:
Sleeping
Sleeping
| import re, json, gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Small multilingual seq2seq checkpoint, chosen so the demo can run on CPU.
# NOTE(review): the Hugging Face mT5 docs state that mT5 checkpoints are only
# pre-trained (no supervised fine-tuning) and must be fine-tuned before use on
# a downstream task, so JSON-following output from this prompt is not
# guaranteed with this model — confirm before relying on results.
MODEL_ID = "google/mt5-small"

# Entity types offered to the model in the prompt (used as build_prompt's
# default label set).
LABELS = "PERSON ORG LOC GPE DATE TIME PRODUCT EVENT".split()
def build_prompt(text, labels=None):
    """Build the Persian instruction prompt that asks the model for NER as JSON.

    Args:
        text: The input text to analyse; it is interpolated verbatim at the
            end of the prompt.
        labels: Iterable of allowed entity labels. ``None`` (the default)
            means the module-level ``LABELS``. Each item is converted with
            ``str()`` before joining, so non-string labels (e.g. enum
            members) are accepted.

    Returns:
        The prompt string. It lists the allowed labels, shows the expected
        ``{"entities": [...]}`` schema, and asks for JSON only.
    """
    # Resolve the default at call time rather than binding the shared
    # module-level list object as a default argument value.
    if labels is None:
        labels = LABELS
    label_list = ", ".join(map(str, labels))
    return (
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.\n"
        f"لیبلهای مجاز: {label_list}.\n"
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:\n"
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
        "هیچ متن دیگری ننویس؛ فقط JSON.\n\n"
        f"متن: {text}\n"
    )
def extract_json(s: str):
    """Pull the first parseable JSON object out of raw model output.

    Strategies, tried in order:
      1. Parse the greedy span from the first ``{`` to the last ``}``.
      2. Parse that span again with trailing commas before ``}`` / ``]``
         removed (a frequent generation slip).
      3. Decode the first complete object starting at each ``{`` of the
         cleaned span. This recovers output where valid JSON is followed by
         prose that itself contains braces, which defeats the greedy span.

    Args:
        s: Raw decoded model text; ``None`` or ``""`` means "no output".

    Returns:
        The parsed dict, or ``{"entities": []}`` when nothing could be parsed.
        The dict is not schema-checked; callers must validate ``"entities"``.
    """
    m = re.search(r"\{[\s\S]*\}", s or "")
    if not m:
        return {"entities": []}
    raw = m.group(0)
    # NOTE: this cleanup is not string-aware; a literal ", }" inside a JSON
    # string value would also be rewritten. Acceptable for a fallback path.
    cleaned = re.sub(r",\s*]", "]", re.sub(r",\s*}", "}", raw))

    for candidate in (raw, cleaned):
        try:
            data = json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            continue
        if isinstance(data, dict):
            return data

    # Fallback: raw_decode parses one object and ignores whatever follows it.
    decoder = json.JSONDecoder()
    pos = cleaned.find("{")
    while pos != -1:
        try:
            data, _ = decoder.raw_decode(cleaned, pos)
        except (ValueError, RecursionError):
            pass
        else:
            if isinstance(data, dict):
                return data
        pos = cleaned.find("{", pos + 1)
    return {"entities": []}
# Lazily-populated module-level cache; nothing is loaded at import time, only
# on the first call to load_model().
_tokenizer = None
_model = None


def load_model():
    """Return the cached ``(tokenizer, model)`` pair, loading both on first use.

    If either object is still missing, both are (re)loaded from ``MODEL_ID``.
    No device placement is requested, so the model stays on whatever device
    ``from_pretrained`` uses by default (CPU unless configured otherwise).
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
def _span_matches(text, t, st, en):
    """Return True if ``st``/``en`` are real ints (not bools) and ``text[st:en] == t``."""
    return (
        isinstance(st, int)
        and isinstance(en, int)
        # bool is a subclass of int; JSON true/false are not offsets.
        and not isinstance(st, bool)
        and not isinstance(en, bool)
        and 0 <= st <= en <= len(text)
        and text[st:en] == t
    )


def _normalize_entities(data, text):
    """Validate parsed model entities and attach character offsets into *text*.

    Items that are not dicts, or that lack a non-empty ``text``/``label``, are
    dropped. Model-supplied ``start``/``end`` are kept only when they really
    delimit the entity string in *text*; otherwise the string is searched
    for, resuming after the previous match of the same string so repeated
    mentions map to successive occurrences. Strings not found in *text* get
    ``start == end == 0``.
    """
    raw = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(raw, list):
        # e.g. {"entities": null} or a non-list value: nothing usable.
        return []
    ents = []
    resume_at = {}  # entity string -> index just past its most recent match
    for e in raw:
        if not isinstance(e, dict):
            continue
        t = str(e.get("text", "")).strip()
        lab = str(e.get("label", "")).strip()
        if not t or not lab:
            continue
        st, en = e.get("start"), e.get("end")
        if not _span_matches(text, t, st, en):
            st = text.find(t, resume_at.get(t, 0))
            if st < 0:
                # Past the last occurrence (e.g. the model listed the same
                # entity twice): fall back to the first occurrence.
                st = text.find(t)
            if st < 0:
                st, en = 0, 0  # not present in the input text
            else:
                en = st + len(t)
        if en > st:
            resume_at[t] = en
        ents.append({"text": t, "label": lab, "start": st, "end": en})
    return ents


def ner_infer(text, max_new_tokens=192):
    """Run prompt-based NER on *text* and return ``{"entities": [...]}``.

    Each entity is a dict with ``text``, ``label``, ``start`` and ``end``.
    ``text[start:end]`` equals the entity string when it could be located in
    the input; otherwise ``start == end == 0``.

    Args:
        text: Input text; ``None`` or whitespace-only input yields no entities.
        max_new_tokens: Generation budget, coerced with ``int()`` (the UI
            slider value may arrive as a float).

    Returns:
        ``{"entities": [...]}``. Malformed model output produces fewer (or no)
        entities rather than an exception.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}
    tok, model = load_model()
    inputs = tok(build_prompt(text), return_tensors="pt")  # stays on CPU
    # Greedy decoding for deterministic output. `temperature` is deliberately
    # not passed: it only affects sampling, and transformers flags sampling
    # parameters combined with do_sample=False.
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
        # pad_token_id must be set for some T5/mT5 variants:
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out = tok.decode(gen_ids[0], skip_special_tokens=True)
    return {"entities": _normalize_entities(extract_json(out), text)}
# Pre-filled example sentence shown in the input box.
_SAMPLE_TEXT = "من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم."

# Single-page UI: a text box and a max-new-tokens slider feed ner_infer(),
# whose result dict is rendered in a JSON viewer.
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — CPU version (mT5-small)")
    inp = gr.Textbox(
        label="متن فارسی",
        lines=4,
        value=_SAMPLE_TEXT,
    )
    max_tok = gr.Slider(
        64,
        512,
        value=192,
        step=16,
        label="Max new tokens (CPU)",
    )
    btn = gr.Button("Extract Entities")
    out = gr.JSON(label="خروجی JSON")
    # (text, max_new_tokens) -> {"entities": [...]}
    btn.click(
        fn=ner_infer,
        inputs=[inp, max_tok],
        outputs=out,
    )


if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    demo.launch()