File size: 3,434 Bytes
ee8ab7e
53299a5
ee8ab7e
53299a5
 
ee8ab7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53299a5
ee8ab7e
 
 
 
 
 
 
53299a5
ee8ab7e
 
53299a5
 
 
ee8ab7e
53299a5
ee8ab7e
 
53299a5
ee8ab7e
 
 
 
53299a5
 
 
 
 
ee8ab7e
53299a5
ee8ab7e
 
53299a5
ee8ab7e
 
 
53299a5
 
 
 
 
 
 
 
 
 
 
 
ee8ab7e
 
 
 
53299a5
 
ee8ab7e
53299a5
ee8ab7e
 
53299a5
ee8ab7e
 
53299a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import re, json, gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# LIGHTWEIGHT, CPU-FRIENDLY MODEL
# Hugging Face model id loaded lazily by load_model().
MODEL_ID = "google/mt5-small"
# Closed label set advertised to the model inside the prompt (see build_prompt).
LABELS = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]

def build_prompt(text, labels=LABELS):
    """Assemble the Persian instruction prompt requesting JSON-only NER output.

    The prompt lists the allowed labels, shows the expected JSON schema,
    and appends the input text on its own line.
    """
    allowed = ", ".join(labels)
    parts = [
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.",
        f"لیبل‌های مجاز: {allowed}.",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}',
        "هیچ متن دیگری ننویس؛ فقط JSON.",
        "",
        f"متن: {text}",
    ]
    return "\n".join(parts) + "\n"

def extract_json(s: str):
    """Extract the first '{...}' span from *s* and parse it as JSON.

    On a parse failure, strips trailing commas before '}' / ']' (a common
    LLM output glitch) and retries once. Returns ``{"entities": []}`` when
    no parseable object can be recovered.
    """
    match = re.search(r"\{[\s\S]*\}", s)
    if match is None:
        return {"entities": []}

    candidate = match.group(0)
    for attempt in range(2):
        try:
            return json.loads(candidate)
        except Exception:
            if attempt == 0:
                # Drop trailing commas before closing braces/brackets, then retry.
                candidate = re.sub(r",\s*\}", "}", candidate)
                candidate = re.sub(r",\s*\]", "]", candidate)
    return {"entities": []}

# Lazy load on CPU
# Module-level cache for the tokenizer/model pair; populated by load_model()
# on first use so that importing this module stays cheap.
_tokenizer = None
_model = None

def load_model():
    """Return the cached (tokenizer, model) pair, loading both on first call.

    Both objects are stored in module-level globals so repeated inference
    calls reuse a single instance.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)  # CPU by default on Spaces
    return _tokenizer, _model

def ner_infer(text, max_new_tokens=192):
    """Run zero-shot NER over *text* and return ``{"entities": [...]}``.

    Each entity dict carries ``text``, ``label``, and character offsets
    (``start``/``end``). Offsets the model omits (or emits as non-ints)
    are recomputed from the first occurrence of the entity text in the
    input, falling back to ``(0, 0)`` when the text cannot be found.

    Args:
        text: Input sentence (Persian expected by the prompt); empty or
            whitespace-only input short-circuits to an empty result.
        max_new_tokens: Generation budget, coerced to int.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}

    tok, model = load_model()
    prompt = build_prompt(text)
    inputs = tok(prompt, return_tensors="pt")  # stays on CPU

    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,          # deterministic greedy decoding on CPU
        # NOTE: temperature is intentionally NOT passed. With do_sample=False,
        # recent transformers versions warn that sampling parameters are
        # ignored, so setting temperature=0.0 here only produced noise.
        # pad_token_id must be set for some T5/mT5 variants:
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out = tok.decode(gen_ids[0], skip_special_tokens=True)
    data = extract_json(out)

    # The LLM may return a malformed payload: guard against "entities"
    # being something other than a list before iterating.
    raw_entities = data.get("entities", [])
    if not isinstance(raw_entities, list):
        raw_entities = []

    # normalize; if model omits start/end, compute first occurrence
    ents = []
    for e in raw_entities:
        try:
            t = str(e.get("text", "")).strip()
            lab = str(e.get("label", "")).strip()
            if not t or not lab:
                continue
            st = e.get("start")
            en = e.get("end")
            if not isinstance(st, int) or not isinstance(en, int):
                idx = text.find(t)
                if idx >= 0:
                    st, en = idx, idx + len(t)
                else:
                    st, en = 0, 0
            ents.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
        except Exception:
            # Skip malformed entries (non-dicts, bad fields) rather than
            # failing the whole request.
            pass
    return {"entities": ents}

# Gradio UI: a single text box + token-budget slider feeding ner_infer,
# with the raw JSON result rendered below the button.
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — CPU version (mT5-small)")
    # Default value is a Persian sample sentence mentioning a person, a city,
    # and a company, so the demo produces entities out of the box.
    inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.")
    max_tok = gr.Slider(64, 512, value=192, step=16, label="Max new tokens (CPU)")
    btn = gr.Button("Extract Entities")
    out = gr.JSON(label="خروجی JSON")
    # Wire the button: slider value is passed through as max_new_tokens.
    btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)

if __name__ == "__main__":
    demo.launch()