# (Hugging Face Spaces page banner removed during extraction — not part of the app source.)
# app.py — Persian Zero-Shot NER (CPU) with few-shot prompting + beams
import re, json, gradio as gr
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Seq2seq checkpoint prompted for NER; small enough for CPU inference.
MODEL_ID = "google/mt5-small" # try "google/mt5-base" on CPU if still empty (slower, better)

# Closed label set the model is allowed to emit; anything else is dropped in normalize_entities().
ALLOWED_LABELS: List[str] = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]

# Default Persian sentence pre-filled into the UI textbox.
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم."

# --- Few-shot examples (in Persian) to nudge the model ---
# Two worked (text -> JSON) demonstrations appended verbatim to every prompt.
# NOTE(review): the hard-coded start/end offsets below are hand-written — not re-verified here.
FEW_SHOT = """
نمونه ۱:
متن: من با علی در تهران در شرکت دیجیکالا جلسه داشتم.
خروجی:
{"entities":[
{"text":"علی","label":"PERSON","start":7,"end":10},
{"text":"تهران","label":"LOC","start":14,"end":19},
{"text":"دیجیکالا","label":"ORG","start":29,"end":37}
]}
نمونه ۲:
متن: سارا فردا ساعت ۱۰ در دانشگاه تهران سخنرانی دارد.
خروجی:
{"entities":[
{"text":"سارا","label":"PERSON","start":0,"end":4},
{"text":"فردا","label":"DATE","start":5,"end":9},
{"text":"۱۰","label":"TIME","start":15,"end":17},
{"text":"دانشگاه تهران","label":"ORG","start":21,"end":34}
]}
"""
def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian few-shot NER prompt for *text*.

    The prompt is: task instructions + allowed label list + JSON schema,
    then the module-level FEW_SHOT demonstrations, then the target text.
    The model is told to answer with JSON only.
    """
    instructions = (
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.\n"
        f"لیبلهای مجاز: {', '.join(labels)}.\n"
        "فقط JSON معتبر با اسکیمای زیر را برگردان:\n"
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
        "هیچ متن دیگری ننویس؛ فقط JSON.\n"
    )
    target = (
        "\nاکنون متن زیر را پردازش کن و فقط JSON بده:\n"
        f"متن: {text}\n"
        "خروجی:\n"
    )
    return instructions + FEW_SHOT + target
def extract_first_json(s: str) -> Dict[str, Any]:
    """Return the first parseable JSON object embedded in *s*.

    The previous greedy regex (``\\{[\\s\\S]*\\}``) spanned from the FIRST
    ``{`` to the LAST ``}``; when the model echoes the few-shot examples the
    span covers several JSON objects and parsing silently fails. Instead we
    scan each ``{`` with ``json.JSONDecoder.raw_decode``, which stops at the
    end of the first complete object (and handles braces inside strings).

    Falls back to a trailing-comma repair (models often emit ``,]`` / ``,}``),
    and returns ``{"entities": []}`` when nothing parses.
    """
    decoder = json.JSONDecoder()
    idx = s.find("{")
    while idx != -1:
        try:
            obj, _ = decoder.raw_decode(s, idx)
            if isinstance(obj, dict):
                return obj
        except ValueError:
            pass  # not valid JSON starting here; try the next brace
        idx = s.find("{", idx + 1)
    # Fallback: greedy span, then strip trailing commas before "}" or "]".
    m = re.search(r"\{[\s\S]*\}", s)
    if m:
        repaired = re.sub(r",\s*([}\]])", r"\1", m.group(0))
        try:
            obj = json.loads(repaired)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass  # unrecoverable output; report no entities
    return {"entities": []}
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and clean model-emitted entities against the source *text*.

    Keeps only entries with non-empty text and an allowed label (compared
    case-insensitively). Offsets are trusted only when ``text[start:end]``
    actually equals the entity surface; otherwise — including the previous
    behavior of accepting any pair of non-negative ints even when they
    pointed at the wrong characters — the span is recomputed with
    ``str.find`` and falls back to ``(0, 0)`` when the surface is absent.

    Returns ``{"entities": [...]}``; malformed entries are skipped, never fatal.
    """
    source = text or ""
    allowed = set(labels)  # O(1) membership instead of list scans
    out: List[Dict[str, Any]] = []
    for ent in data.get("entities", []):
        try:
            surface = str(ent.get("text", "")).strip()
            label = str(ent.get("label", "")).strip().upper()
            if not surface or label not in allowed:
                continue
            st, en = ent.get("start"), ent.get("end")
            span_ok = (
                isinstance(st, int) and isinstance(en, int)
                and 0 <= st < en <= len(source)
                and source[st:en] == surface
            )
            if not span_ok:
                # Model offsets are unreliable; anchor on the first occurrence.
                idx = source.find(surface)
                st, en = (idx, idx + len(surface)) if idx >= 0 else (0, 0)
            out.append({"text": surface, "label": label, "start": int(st), "end": int(en)})
        except Exception:
            continue  # skip malformed entries (non-dict items, odd types)
    return {"entities": out}
# lazy CPU load — module-level cache, populated on first use
_tokenizer = None
_model = None

def load_model():
    """Return the memoized (tokenizer, model) pair, loading them on first call.

    Uses the slow (sentencepiece) tokenizer and keeps everything on CPU.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
def ner_infer(text: str, max_new_tokens: int = 256, num_beams: int = 4) -> Dict[str, Any]:
    """Run prompted zero-shot NER on Persian *text* with mT5 on CPU.

    Pipeline: build the few-shot prompt, generate deterministically with beam
    search, extract the first JSON object from the decoded output, and
    normalize entities against the original text.

    Returns ``{"entities": [...]}`` (empty for blank input or unparseable output).
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}  # nothing to analyze
    tok, model = load_model()
    prompt = build_prompt(text, ALLOWED_LABELS)
    inputs = tok(prompt, return_tensors="pt") # CPU
    # int() casts: Gradio sliders deliver floats.
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=False, # deterministic
        num_beams=int(num_beams), # stronger decoding than greedy on CPU
        length_penalty=1.05,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out_text = tok.decode(gen_ids[0], skip_special_tokens=True)
    # Model may echo prose or the few-shot examples; pull out the JSON only.
    raw = extract_first_json(out_text)
    return normalize_entities(raw, text, ALLOWED_LABELS)
# --- Gradio UI (built at import time; launched only when run as a script) ---
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER — CPU (mT5) + Few-Shot Prompting")
    # Persian input textbox, pre-filled with the default example sentence.
    inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        # Generation controls mapped to ner_infer's keyword parameters.
        max_tok = gr.Slider(96, 512, value=256, step=16, label="حداکثر توکن خروجی")
        beams = gr.Slider(1, 8, value=4, step=1, label="Beam size")
    btn = gr.Button("استخراج موجودیتها")
    out = gr.JSON(label="خروجی JSON")
    # Wire the button: (text, max tokens, beams) -> JSON entities.
    btn.click(fn=ner_infer, inputs=[inp, max_tok, beams], outputs=out)
if __name__ == "__main__":
    demo.launch()