# File-viewer metadata (Hugging Face Space): optimopium — "Update app.py", commit 67ad485 (verified), 4.83 kB
# app.py — Persian Zero-Shot NER (CPU) with few-shot prompting + beams
import re, json, gradio as gr
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Seq2seq checkpoint used for prompted NER; try "google/mt5-base" on CPU if output is still empty (slower, better)
MODEL_ID = "google/mt5-small"
# Entity tags accepted by the prompt and by normalize_entities (anything else is dropped)
ALLOWED_LABELS: List[str] = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
# Sample Persian sentence pre-filled into the UI textbox
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."
# --- Few-shot examples (in Persian) appended to every prompt to nudge the model
# toward the expected JSON schema ---
# NOTE(review): the start/end offsets in these examples look hand-written and may
# not match the example strings character-for-character — confirm; downstream,
# normalize_entities re-derives spans anyway.
FEW_SHOT = """
نمونه ۱:
متن: من با علی در تهران در شرکت دیجی‌کالا جلسه داشتم.
خروجی:
{"entities":[
{"text":"علی","label":"PERSON","start":7,"end":10},
{"text":"تهران","label":"LOC","start":14,"end":19},
{"text":"دیجی‌کالا","label":"ORG","start":29,"end":37}
]}
نمونه ۲:
متن: سارا فردا ساعت ۱۰ در دانشگاه تهران سخنرانی دارد.
خروجی:
{"entities":[
{"text":"سارا","label":"PERSON","start":0,"end":4},
{"text":"فردا","label":"DATE","start":5,"end":9},
{"text":"۱۰","label":"TIME","start":15,"end":17},
{"text":"دانشگاه تهران","label":"ORG","start":21,"end":34}
]}
"""
def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian NER instruction prompt for *text*.

    Layout: task instruction, allowed label list, required JSON schema,
    "JSON only" directive, the FEW_SHOT examples, then the target text.
    """
    label_list = ", ".join(labels)
    segments = [
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n",
        f"لیبل‌های مجاز: {label_list}.\n",
        "فقط JSON معتبر با اسکیمای زیر را برگردان:\n",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n',
        "هیچ متن دیگری ننویس؛ فقط JSON.\n",
        FEW_SHOT,
        "\nاکنون متن زیر را پردازش کن و فقط JSON بده:\n",
        f"متن: {text}\n",
        "خروجی:\n",
    ]
    return "".join(segments)
def extract_first_json(s: str) -> Dict[str, Any]:
    """Return the first parseable JSON object embedded in *s*.

    The previous implementation used the greedy regex ``\\{[\\s\\S]*\\}``,
    which matches from the FIRST ``{`` to the LAST ``}`` — so any trailing
    text or a second object after the wanted payload broke parsing and the
    function returned empty.  This version scans for balanced-brace
    candidates starting at each ``{`` and returns the first one that parses.

    As before, trailing commas before ``}`` / ``]`` (a common small-model
    artifact) are stripped on a second parse attempt.  Brace counting does
    not account for braces inside string values; that heuristic limitation
    matches the original regex approach.

    Returns ``{"entities": []}`` when no candidate parses.
    """
    for start in (i for i, ch in enumerate(s) if ch == "{"):
        depth = 0
        for end in range(start, len(s)):
            if s[end] == "{":
                depth += 1
            elif s[end] == "}":
                depth -= 1
                if depth == 0:
                    candidate = s[start:end + 1]
                    try:
                        return json.loads(candidate)
                    except Exception:
                        cleaned = re.sub(r",\s*}", "}", candidate)
                        cleaned = re.sub(r",\s*]", "]", cleaned)
                        try:
                            return json.loads(cleaned)
                        except Exception:
                            break  # give up on this start; try the next '{'
    return {"entities": []}
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and repair model-produced entities against the source text.

    For each candidate entity: drop it when the text is empty or the label
    (upper-cased) is not in *labels*; then keep the model-supplied span only
    when it is a genuine in-range slice whose substring equals the entity
    text.  Otherwise re-derive the span with ``str.find`` (first occurrence),
    falling back to ``(0, 0)`` when the text does not occur at all — the same
    fallback the original used.

    Fixes over the previous version: model offsets were accepted whenever
    both were non-negative ints, even when they pointed at the wrong
    substring, when ``end`` exceeded the text length or preceded ``start``,
    or when the values were bools (which pass an ``isinstance(..., int)``
    check).  Label membership is now an O(1) set lookup.

    Malformed entity records (non-dicts, etc.) are skipped best-effort,
    never fatal, matching the original behavior.
    """
    text_norm = text or ""
    allowed = set(labels)
    out: List[Dict[str, Any]] = []
    for e in data.get("entities", []):
        try:
            t = str(e.get("text", "")).strip()
            lab = str(e.get("label", "")).strip().upper()
            if not t or lab not in allowed:
                continue
            st, en = e.get("start"), e.get("end")
            span_ok = (
                isinstance(st, int) and not isinstance(st, bool)
                and isinstance(en, int) and not isinstance(en, bool)
                and 0 <= st <= en <= len(text_norm)
                and text_norm[st:en] == t
            )
            if not span_ok:
                idx = text_norm.find(t)
                st, en = (idx, idx + len(t)) if idx >= 0 else (0, 0)
            out.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
        except Exception:
            # Best-effort: one bad record must not sink the whole result.
            continue
    return {"entities": out}
# lazy CPU load
# Module-level cache for the tokenizer/model pair; populated on the first
# call to load_model() so the app starts quickly and downloads weights once.
_tokenizer = None
_model = None
def load_model():
    """Return the cached (tokenizer, model) pair, constructing it on first use.

    Downloads/loads MODEL_ID lazily so module import stays cheap.
    Uses the slow tokenizer (use_fast=False), as in the original setup.
    """
    global _tokenizer, _model
    if _model is None or _tokenizer is None:
        tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
        _tokenizer, _model = tok, mdl
    return _tokenizer, _model
def ner_infer(text: str, max_new_tokens: int = 256, num_beams: int = 4) -> Dict[str, Any]:
    """Full pipeline for one input: prompt -> generate -> parse JSON -> normalize.

    Returns {"entities": [...]}; empty when the input is blank or the model
    output contains no parseable JSON.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return {"entities": []}
    tokenizer, model = load_model()
    encoded = tokenizer(build_prompt(cleaned, ALLOWED_LABELS), return_tensors="pt")  # CPU tensors
    generation_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=False,            # deterministic decoding
        num_beams=int(num_beams),   # beam search: stronger than greedy on CPU
        length_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated = model.generate(**encoded, **generation_kwargs)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    parsed = extract_first_json(decoded)
    return normalize_entities(parsed, cleaned, ALLOWED_LABELS)
# --- Gradio UI: text in, JSON entities out ---
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER — CPU (mT5) + Few-Shot Prompting")
    text_box = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        # Generation controls forwarded straight to ner_infer
        token_slider = gr.Slider(96, 512, value=256, step=16, label="حداکثر توکن خروجی")
        beam_slider = gr.Slider(1, 8, value=4, step=1, label="Beam size")
    run_button = gr.Button("استخراج موجودیت‌ها")
    json_view = gr.JSON(label="خروجی JSON")
    run_button.click(fn=ner_infer, inputs=[text_box, token_slider, beam_slider], outputs=json_view)

if __name__ == "__main__":
    demo.launch()