# app.py
# Persian Zero-Shot NER (CPU) — Hugging Face Spaces (Gradio)
# Uses a lightweight Seq2Seq model (mT5-small) and slow tokenizer (no GPU deps).
import json
import os
import re
from typing import Any, Dict, List, Tuple

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# ---- Config (CPU-friendly) ----
# NOTE: google/mt5-small is a pre-trained-only checkpoint (no supervised or
# instruction tuning), so it is not expected to follow the JSON instruction in
# the prompt without fine-tuning. The checkpoint can be swapped without code
# changes through the NER_MODEL_ID environment variable, e.g. an
# instruction-tuned mT5 variant such as bigscience/mt0-small (same
# architecture/tokenizer family). TODO(review): evaluate quality on Persian.
MODEL_ID = os.environ.get("NER_MODEL_ID", "google/mt5-small")

# Entity labels offered to the model in the prompt and used to filter its
# output (both uses happen in ner_infer).
ALLOWED_LABELS: List[str] = [
    "PERSON", "ORG", "LOC", "GPE", "DATE", "TIME", "PRODUCT", "EVENT",
]

# Pre-filled UI example. English: "Yesterday I had a meeting with Ali in
# Tehran at the Digikala office."
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم."
# ---- Prompt & Parsing ----
def build_prompt(text: str, labels: List[str]) -> str:
    """Build the Persian instruction prompt sent to the seq2seq model.

    In English, the prompt says: analyze the following text for named
    entity recognition (NER); allowed labels: <labels>; output only valid
    JSON in the given schema; write nothing but JSON; then "Text: <text>".

    Args:
        text: The input sentence(s) to analyze, embedded verbatim.
        labels: Allowed entity labels, listed comma-separated in the prompt.

    Returns:
        The prompt string. Every line, including the last, ends with a
        newline, and a blank line separates the instructions from the text.
    """
    label_list = ", ".join(labels)
    prompt_lines = [
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.",
        f"لیبلهای مجاز: {label_list}.",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}',
        "هیچ متن دیگری ننویس؛ فقط JSON.",
        "",  # blank separator line before the input text
        f"متن: {text}",
    ]
    return "\n".join(prompt_lines) + "\n"
# Helpers for extract_first_json, built once at import time.
_JSON_DECODER = json.JSONDecoder()
_TRAILING_COMMA_IN_OBJECT = re.compile(r",\s*}")
_TRAILING_COMMA_IN_ARRAY = re.compile(r",\s*]")


def _strip_trailing_commas(raw: str) -> str:
    """Remove commas that directly precede a closing ``}`` or ``]``."""
    raw = _TRAILING_COMMA_IN_OBJECT.sub("}", raw)
    return _TRAILING_COMMA_IN_ARRAY.sub("]", raw)


def extract_first_json(s: str) -> Dict[str, Any]:
    """Return the first JSON object that can be decoded from *s*.

    Each ``{`` in *s* is tried left to right as the start of an object.
    ``raw_decode`` stops at the end of that object, so any trailing text
    (including further braces or a second object) is ignored. This fixes
    the previous greedy regex, which spanned from the first ``{`` to the
    *last* ``}`` and then failed to parse. If a candidate does not decode
    as-is, a repaired copy with trailing commas removed is tried before
    moving to the next ``{``.

    Args:
        s: Raw model output. It may contain prose around the JSON.

    Returns:
        The decoded object, which is always a dict because decoding starts at
        ``{``. Returns ``{"entities": []}`` when nothing decodes.
    """
    if not s:
        return {"entities": []}

    start = s.find("{")
    while start != -1:
        candidate = s[start:]
        try:
            return _JSON_DECODER.raw_decode(candidate)[0]
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            pass
        try:
            return _JSON_DECODER.raw_decode(_strip_trailing_commas(candidate))[0]
        except ValueError:
            pass
        start = s.find("{", start + 1)

    return {"entities": []}
def _resolve_span(start: Any, end: Any, entity: str, text: str) -> Tuple[int, int]:
    """Return ``(start, end)`` character offsets of *entity* within *text*.

    Model-supplied offsets are kept only when both are real ints (not
    bools), lie inside *text*, and actually slice out *entity*. Otherwise
    the first occurrence of *entity* in *text* is used, or ``(0, 0)`` when
    *entity* does not occur in *text* at all.
    """
    offsets_are_ints = all(
        isinstance(v, int) and not isinstance(v, bool) for v in (start, end)
    )
    if (
        offsets_are_ints
        and 0 <= start <= end <= len(text)
        and text[start:end] == entity
    ):
        return start, end
    idx = text.find(entity)
    if idx >= 0:
        return idx, idx + len(entity)
    return 0, 0


def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate raw model entities against the source text and allowed labels.

    Args:
        data: Parsed model output. Only ``data["entities"]`` is read, and it
            is used only if it is a list. Anything else yields no entities.
        text: The original input text that offsets refer to.
        labels: Allowed labels, matched case-insensitively. Kept labels are
            emitted upper-cased.

    Returns:
        ``{"entities": [...]}``. Each item has ``text`` (stripped),
        ``label``, ``start`` and ``end``. Entries are skipped when they are
        not dicts, have empty text or label, or have a disallowed label.
        Offsets are verified and repaired via ``_resolve_span``.
    """
    allowed = {str(label).strip().upper() for label in labels}
    raw_entities = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(raw_entities, list):
        return {"entities": []}

    out: List[Dict[str, Any]] = []
    for e in raw_entities:
        if not isinstance(e, dict):
            continue  # malformed entry (e.g. a bare string); skip it
        t = str(e.get("text", "")).strip()
        lab = str(e.get("label", "")).strip().upper()
        if not t or not lab or lab not in allowed:
            continue
        st, en = _resolve_span(e.get("start"), e.get("end"), t, text)
        out.append({"text": t, "label": lab, "start": st, "end": en})
    return {"entities": out}
# ---- Lazy model load (CPU) ----
# Module-level cache: filled on the first load_model() call instead of at
# import time. NOTE(review): there is no lock, so two concurrent first calls
# could each load the checkpoint.
_tokenizer = None
_model = None


def load_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on demand.

    When either cached object is still ``None``, both are (re)loaded from
    ``MODEL_ID`` via ``from_pretrained``. Otherwise the existing objects
    are returned unchanged.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    # use_fast=False selects the slow SentencePiece tokenizer. The original
    # author chose it to sidestep fast-tokenizer conversion problems.
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
# ---- Inference ----
def ner_infer(text: str, max_new_tokens: int = 192) -> Dict[str, Any]:
    """Run zero-shot NER on *text* and return ``{"entities": [...]}``.

    Args:
        text: Input text. ``None``, empty or whitespace-only input returns an
            empty result without loading the model.
        max_new_tokens: Generation budget. It is coerced to ``int`` (the
            Gradio slider may deliver a float) and clamped to at least 1.

    Returns:
        The result of ``normalize_entities`` applied to the first JSON object
        found in the decoded model output.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}

    tok, model = load_model()
    prompt = build_prompt(text, ALLOWED_LABELS)
    inputs = tok(prompt, return_tensors="pt")  # stays on CPU; the model is never moved
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=max(1, int(max_new_tokens)),
        # Greedy decoding for deterministic output. `temperature` is not
        # passed: it only affects sampling, and setting it together with
        # do_sample=False makes transformers warn on every call.
        do_sample=False,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out_text = tok.decode(gen_ids[0], skip_special_tokens=True)
    raw = extract_first_json(out_text)
    return normalize_entities(raw, text, ALLOWED_LABELS)
# ---- UI ----
# Gradio layout: a text input row, then a row with the token-budget slider
# and the run button, then a JSON viewer. Clicking the button sends the text
# and the slider value to ner_infer and shows its result.
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — **CPU version (mT5-small)**")

    with gr.Row():
        inp = gr.Textbox(value=DEFAULT_EXAMPLE, lines=4, label="متن فارسی")

    with gr.Row():
        max_tok = gr.Slider(
            minimum=64,
            maximum=512,
            value=192,
            step=16,
            label="حداکثر توکن خروجی (CPU)",
        )
        btn = gr.Button("استخراج موجودیتها")

    out = gr.JSON(label="خروجی JSON (entities)")
    btn.click(ner_infer, inputs=[inp, max_tok], outputs=out)

if __name__ == "__main__":
    demo.launch()