File size: 4,834 Bytes
67ad485
 
eb83f82
53299a5
ee8ab7e
67ad485
 
eb83f82
ee8ab7e
67ad485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb83f82
ee8ab7e
 
 
67ad485
ee8ab7e
67ad485
 
 
ee8ab7e
67ad485
ee8ab7e
 
eb83f82
ee8ab7e
eb83f82
 
ee8ab7e
 
 
 
 
 
 
 
 
 
 
eb83f82
67ad485
eb83f82
 
 
67ad485
 
 
eb83f82
67ad485
eb83f82
67ad485
 
eb83f82
 
 
 
 
67ad485
ee8ab7e
 
 
 
 
eb83f82
 
ee8ab7e
 
67ad485
53299a5
 
ee8ab7e
 
eb83f82
67ad485
ee8ab7e
 
 
67ad485
 
 
53299a5
67ad485
ee8ab7e
eb83f82
 
 
ee8ab7e
53299a5
67ad485
 
eb83f82
67ad485
 
eb83f82
67ad485
 
ee8ab7e
 
53299a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# app.py — Persian Zero-Shot NER (CPU) with few-shot prompting + beams
import re, json, gradio as gr
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_ID = "google/mt5-small"   # try "google/mt5-base" on CPU if still empty (slower, better)
# Closed set of entity labels the pipeline accepts; anything else the model
# emits is filtered out downstream in normalize_entities.
ALLOWED_LABELS: List[str] = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
# Demo sentence pre-filled into the Gradio textbox ("I had a meeting with Ali
# yesterday in Tehran at the Digikala office.").
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."

# --- Few-shot examples (in Persian) to nudge the model ---
# NOTE(review): the start/end offsets below appear hand-written and do not all
# match Python string indices of the example sentences (e.g. "علی" starts at
# index 6, not 7, in the first sample) — verify before trusting offsets in
# model output; normalize_entities re-derives bad offsets via text.find().
FEW_SHOT = """
نمونه ۱:
متن: من با علی در تهران در شرکت دیجی‌کالا جلسه داشتم.
خروجی:
{"entities":[
  {"text":"علی","label":"PERSON","start":7,"end":10},
  {"text":"تهران","label":"LOC","start":14,"end":19},
  {"text":"دیجی‌کالا","label":"ORG","start":29,"end":37}
]}

نمونه ۲:
متن: سارا فردا ساعت ۱۰ در دانشگاه تهران سخنرانی دارد.
خروجی:
{"entities":[
  {"text":"سارا","label":"PERSON","start":0,"end":4},
  {"text":"فردا","label":"DATE","start":5,"end":9},
  {"text":"۱۰","label":"TIME","start":15,"end":17},
  {"text":"دانشگاه تهران","label":"ORG","start":21,"end":34}
]}
"""

def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian few-shot NER prompt for *text*.

    The prompt instructs the model (in Persian) to return only a JSON object
    using the schema shown, restricted to *labels*, sandwiching the FEW_SHOT
    examples between the instructions and the target text.
    """
    instructions = (
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
        f"لیبل‌های مجاز: {', '.join(labels)}.\n"
        "فقط JSON معتبر با اسکیمای زیر را برگردان:\n"
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
        "هیچ متن دیگری ننویس؛ فقط JSON.\n"
    )
    task = (
        "\nاکنون متن زیر را پردازش کن و فقط JSON بده:\n"
        f"متن: {text}\n"
        "خروجی:\n"
    )
    return instructions + FEW_SHOT + task

def extract_first_json(s: str) -> Dict[str, Any]:
    """Extract the first parseable JSON object embedded in *s*.

    The model frequently echoes prose or the few-shot examples around its
    answer. A greedy first-to-last-brace regex breaks whenever more than one
    JSON object appears (the span covers them all and is invalid JSON), so we
    instead scan each '{' and attempt a balanced decode from there with
    json.JSONDecoder.raw_decode. Trailing-comma glitches (",}" / ",]") are
    repaired before a second attempt.

    Returns {"entities": []} when no JSON object can be recovered.
    """
    decoder = json.JSONDecoder()
    for m in re.finditer(r"\{", s):
        start = m.start()
        try:
            obj, _ = decoder.raw_decode(s, start)
        except json.JSONDecodeError:
            # Retry after stripping trailing commas, a common model glitch.
            candidate = re.sub(r",\s*([}\]])", r"\1", s[start:])
            try:
                obj, _ = decoder.raw_decode(candidate)
            except json.JSONDecodeError:
                continue
        if isinstance(obj, dict):
            return obj
    return {"entities": []}

def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and normalize raw model entities against the source *text*.

    Rules:
    - Entities with empty text or a label outside *labels* (after upper-casing)
      are dropped.
    - Model-supplied (start, end) offsets are trusted only when both are
      genuine non-negative ints with start <= end. Booleans are rejected
      explicitly because bool is an int subclass and would otherwise pass the
      isinstance check. Invalid offsets fall back to the first occurrence of
      the entity text in *text*, or (0, 0) when absent.
    - Malformed entries (e.g. strings instead of dicts) are skipped silently —
      this is deliberately best-effort over noisy model output.

    Returns a {"entities": [...]} dict safe for the UI layer.
    """
    text_norm = text or ""
    out: List[Dict[str, Any]] = []
    for e in data.get("entities", []):
        try:
            t = str(e.get("text", "")).strip()
            lab = str(e.get("label", "")).strip().upper()
            if not t or lab not in labels:
                continue
            st, en = e.get("start"), e.get("end")
            span_ok = (
                isinstance(st, int) and not isinstance(st, bool)
                and isinstance(en, int) and not isinstance(en, bool)
                and 0 <= st <= en
            )
            if not span_ok:
                idx = text_norm.find(t)
                st, en = (idx, idx + len(t)) if idx >= 0 else (0, 0)
            out.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
        except Exception:
            # Best-effort: a single malformed entity must not kill the batch.
            continue
    return {"entities": out}

# Lazily-initialized CPU model handles; populated on the first inference call
# so that importing this module stays fast.
_tokenizer = None
_model = None
def load_model():
    """Return (tokenizer, model), downloading/loading them on first use.

    Subsequent calls reuse the cached module-level instances. The slow
    sentencepiece tokenizer is forced (use_fast=False) as in the original
    configuration; everything stays on CPU.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model

def ner_infer(text: str, max_new_tokens: int = 256, num_beams: int = 4) -> Dict[str, Any]:
    """Run zero-shot NER over *text* on CPU and return normalized entities.

    Empty or whitespace-only input short-circuits to {"entities": []}.
    Decoding is deterministic (no sampling) with beam search, which helps the
    small model produce valid JSON more often than greedy decoding. The
    int() casts guard against Gradio sliders delivering float values.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return {"entities": []}
    tok, model = load_model()
    encoded = tok(build_prompt(cleaned, ALLOWED_LABELS), return_tensors="pt")  # CPU tensors
    generated = model.generate(
        **encoded,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,           # deterministic decoding
        num_beams=int(num_beams),  # beam search beats greedy on this model
        length_penalty=1.05,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    decoded = tok.decode(generated[0], skip_special_tokens=True)
    return normalize_entities(extract_first_json(decoded), cleaned, ALLOWED_LABELS)

# --- Gradio UI wiring (module-level; widget labels are user-facing Persian) ---
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER — CPU (mT5) + Few-Shot Prompting")
    # Input textbox (label: "Persian text"), pre-filled with the demo sentence.
    inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        # Slider (label: "maximum output tokens") feeding ner_infer's max_new_tokens.
        max_tok = gr.Slider(96, 512, value=256, step=16, label="حداکثر توکن خروجی")
        beams = gr.Slider(1, 8, value=4, step=1, label="Beam size")
    # Button (label: "extract entities") triggers inference.
    btn = gr.Button("استخراج موجودیت‌ها")
    # Output panel (label: "JSON output") renders the normalized entity dict.
    out = gr.JSON(label="خروجی JSON")
    # Slider values arrive as floats; ner_infer casts them with int() itself.
    btn.click(fn=ner_infer, inputs=[inp, max_tok, beams], outputs=out)

if __name__ == "__main__":
    demo.launch()