# app.py — Persian Zero-Shot NER (CPU) with few-shot prompting + beams
import re, json, gradio as gr
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Seq2seq checkpoint used for prompt-based zero-shot NER (small enough for CPU).
MODEL_ID = "google/mt5-small" # try "google/mt5-base" on CPU if still empty (slower, better)
# Closed set of entity labels; normalize_entities drops anything outside this list.
ALLOWED_LABELS: List[str] = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
# Default Persian sentence pre-filled in the UI textbox.
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم."
# --- Few-shot examples (in Persian) to nudge the model ---
# NOTE(review): the start/end offsets below are hand-written illustrations for
# the prompt; they may not be exact character offsets — confirm if relied upon.
FEW_SHOT = """
نمونه ۱:
متن: من با علی در تهران در شرکت دیجیکالا جلسه داشتم.
خروجی:
{"entities":[
{"text":"علی","label":"PERSON","start":7,"end":10},
{"text":"تهران","label":"LOC","start":14,"end":19},
{"text":"دیجیکالا","label":"ORG","start":29,"end":37}
]}
نمونه ۲:
متن: سارا فردا ساعت ۱۰ در دانشگاه تهران سخنرانی دارد.
خروجی:
{"entities":[
{"text":"سارا","label":"PERSON","start":0,"end":4},
{"text":"فردا","label":"DATE","start":5,"end":9},
{"text":"۱۰","label":"TIME","start":15,"end":17},
{"text":"دانشگاه تهران","label":"ORG","start":21,"end":34}
]}
"""
def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian few-shot NER prompt for *text*, restricted to *labels*.

    The prompt instructs the model to answer with JSON only, shows the schema,
    injects the FEW_SHOT demonstrations, then appends the target text.
    """
    pieces = [
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.\n",
        f"لیبلهای مجاز: {', '.join(labels)}.\n",
        "فقط JSON معتبر با اسکیمای زیر را برگردان:\n",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n',
        "هیچ متن دیگری ننویس؛ فقط JSON.\n",
        FEW_SHOT,
        "\nاکنون متن زیر را پردازش کن و فقط JSON بده:\n",
        f"متن: {text}\n",
        "خروجی:\n",
    ]
    return "".join(pieces)
def extract_first_json(s: str) -> Dict[str, Any]:
    """Extract and parse the first complete JSON object embedded in *s*.

    The previous greedy regex (``\\{[\\s\\S]*\\}``) captured from the first ``{``
    to the LAST ``}`` in the string, so any stray brace in trailing model
    chatter broke parsing. Here we scan for the first balanced object instead
    (string- and escape-aware), then parse it; trailing commas before ``}``/``]``
    are repaired as a fallback. Returns ``{"entities": []}`` on any failure.
    """
    start = s.find("{")
    if start == -1:
        return {"entities": []}
    depth = 0
    in_str = False   # inside a JSON string literal
    escaped = False  # previous char was a backslash inside a string
    end = -1
    for i in range(start, len(s)):
        ch = s[i]
        if escaped:
            escaped = False
        elif in_str and ch == "\\":
            escaped = True
        elif ch == '"':
            in_str = not in_str
        elif not in_str:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    end = i
                    break
    # If the object is truncated (model ran out of tokens), try what we have.
    raw = s[start:end + 1] if end != -1 else s[start:]
    try:
        return json.loads(raw)
    except Exception:
        # Common LLM artifact: trailing commas before a closing bracket/brace.
        raw = re.sub(r",\s*}", "}", raw)
        raw = re.sub(r",\s*]", "]", raw)
        try:
            return json.loads(raw)
        except Exception:
            return {"entities": []}
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and repair the model's raw entity list against *text*.

    Keeps only entities with non-empty text and a label in *labels*
    (case-insensitive; labels are upper-cased). Offsets are trusted only when
    both are genuine non-negative ints with ``start <= end`` — the previous
    check let inverted spans (end < start) and booleans through, since
    ``isinstance(True, int)`` is True. Untrusted offsets are recomputed via
    ``text.find``; if the surface form is absent, (0, 0) is kept as before.
    Malformed items (e.g. non-dict entries) are skipped, never fatal.
    """
    text_norm = text or ""
    allowed = set(labels)  # O(1) membership; labels list is small but this is free
    out: List[Dict[str, Any]] = []
    for e in data.get("entities", []):
        try:
            t = str(e.get("text", "")).strip()
            lab = str(e.get("label", "")).strip().upper()
            if not t or lab not in allowed:
                continue
            st, en = e.get("start"), e.get("end")
            span_ok = (
                isinstance(st, int) and not isinstance(st, bool)
                and isinstance(en, int) and not isinstance(en, bool)
                and 0 <= st <= en
            )
            if not span_ok:
                idx = text_norm.find(t)
                st, en = (idx, idx + len(t)) if idx >= 0 else (0, 0)
            out.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
        except Exception:
            # One bad entry (wrong type, missing keys) must not drop the rest.
            continue
    return {"entities": out}
# lazy CPU load
# Lazily-initialized singletons: the model is only pulled on the first request,
# keeping Space startup fast on CPU.
_tokenizer = None
_model = None

def load_model():
    """Return the cached (tokenizer, model) pair, loading both on first use."""
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    # use_fast=False: slow (SentencePiece) tokenizer for mT5.
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
def ner_infer(text: str, max_new_tokens: int = 256, num_beams: int = 4) -> Dict[str, Any]:
    """Run prompt-based NER over *text* and return normalized entity JSON.

    Empty/whitespace input short-circuits to an empty entity list. Decoding is
    deterministic (no sampling) with beam search; the raw generation is parsed
    with extract_first_json and cleaned by normalize_entities.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return {"entities": []}
    tokenizer, model = load_model()
    encoded = tokenizer(build_prompt(cleaned, ALLOWED_LABELS), return_tensors="pt")  # CPU tensors
    generated = model.generate(
        **encoded,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,               # deterministic
        num_beams=int(num_beams),      # beam search beats greedy here
        length_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return normalize_entities(extract_first_json(decoded), cleaned, ALLOWED_LABELS)
# --- Gradio UI: built at import time so `demo` exists for Spaces hosting ---
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER — CPU (mT5) + Few-Shot Prompting")
    # Persian input textbox, pre-filled with the default example sentence.
    inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        # Decoding knobs forwarded positionally to ner_infer(text, max_new_tokens, num_beams).
        max_tok = gr.Slider(96, 512, value=256, step=16, label="حداکثر توکن خروجی")
        beams = gr.Slider(1, 8, value=4, step=1, label="Beam size")
    btn = gr.Button("استخراج موجودیتها")
    out = gr.JSON(label="خروجی JSON")
    # Button click runs inference and renders the normalized entity JSON.
    btn.click(fn=ner_infer, inputs=[inp, max_tok, beams], outputs=out)
if __name__ == "__main__":
    demo.launch()