File size: 4,360 Bytes
eb83f82
 
 
 
 
 
 
 
53299a5
ee8ab7e
eb83f82
53299a5
eb83f82
 
 
 
ee8ab7e
eb83f82
 
ee8ab7e
 
 
 
 
 
 
 
 
eb83f82
ee8ab7e
eb83f82
 
ee8ab7e
eb83f82
ee8ab7e
 
 
eb83f82
ee8ab7e
 
 
 
 
 
 
eb83f82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee8ab7e
 
 
 
 
 
eb83f82
 
 
ee8ab7e
 
eb83f82
 
53299a5
 
ee8ab7e
 
eb83f82
53299a5
ee8ab7e
 
 
eb83f82
53299a5
 
 
ee8ab7e
eb83f82
 
 
ee8ab7e
eb83f82
53299a5
eb83f82
 
 
 
 
 
 
53299a5
ee8ab7e
 
53299a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# app.py
# Persian Zero-Shot NER (CPU) — Hugging Face Spaces (Gradio)
# Uses a lightweight Seq2Seq model (mT5-small) and slow tokenizer (no GPU deps).

import re
import json
import gradio as gr
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---- Config (CPU-friendly) ----
# Hugging Face Hub checkpoint identifier used for both tokenizer and model.
MODEL_ID = "google/mt5-small"

# Entity labels that are advertised in the prompt and kept by post-processing.
ALLOWED_LABELS: List[str] = [
    "PERSON",
    "ORG",
    "LOC",
    "GPE",
    "DATE",
    "TIME",
    "PRODUCT",
    "EVENT",
]

# Sample Persian sentence used as the initial value of the input textbox.
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."

# ---- Prompt & Parsing ----
def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian instruction prompt sent to the model.

    The prompt asks for NER analysis of the text, lists the allowed labels
    (comma-separated), shows the expected JSON schema, instructs the model to
    emit JSON only, and ends with the text to analyse.

    Args:
        text: The sentence(s) to analyse; inserted verbatim.
        labels: Label names to advertise in the prompt.

    Returns:
        The complete prompt string, terminated by a newline.
    """
    label_list = ", ".join(labels)
    segments = [
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n",
        "لیبل‌های مجاز: " + label_list + ".\n",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:\n",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n',
        "هیچ متن دیگری ننویس؛ فقط JSON.\n\n",
        "متن: " + text + "\n",
    ]
    return "".join(segments)

def extract_first_json(s: str) -> Dict[str, Any]:
    """Return the first JSON object embedded in ``s``.

    Parsing starts at the first ``{`` and stops at the end of that object, so
    trailing text (including stray ``}`` characters or further objects) does
    not break extraction. If the text does not parse as-is, a second attempt
    is made after stripping trailing commas before ``}`` and ``]``.

    Args:
        s: Arbitrary text, typically raw model output.

    Returns:
        The decoded object as a dict, or ``{"entities": []}`` when no JSON
        object can be decoded.
    """
    start = s.find("{")
    if start < 0:
        return {"entities": []}

    candidate = s[start:]
    # Quick repair for trailing commas, a common defect in generated JSON.
    repaired = re.sub(r",\s*}", "}", candidate)
    repaired = re.sub(r",\s*]", "]", repaired)

    decoder = json.JSONDecoder()
    for attempt in (candidate, repaired):
        try:
            # raw_decode parses one value and ignores whatever follows it.
            obj, _end = decoder.raw_decode(attempt)
        except (ValueError, RecursionError):
            continue
        if isinstance(obj, dict):
            return obj
    return {"entities": []}

def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and clean the entity list parsed from model output.

    Entries are dropped when they are not dicts, have empty text or label, or
    use a label outside ``labels`` (labels are compared upper-cased). Offsets
    supplied by the model are kept only when they are integers that actually
    delimit the entity text inside ``text``; otherwise the first occurrence of
    the entity text is used, and ``(0, 0)`` when it does not occur at all.

    Args:
        data: Parsed JSON object, expected to hold a list under ``"entities"``.
        text: The original input text the offsets refer to.
        labels: Allowed label names.

    Returns:
        ``{"entities": [...]}`` where each item has ``text``, ``label``,
        ``start`` and ``end`` keys.
    """
    candidates = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(candidates, list):
        # Missing, null, or otherwise non-list payload: nothing to normalize.
        return {"entities": []}

    allowed = set(labels)
    out = []
    for e in candidates:
        if not isinstance(e, dict):
            continue  # ignore malformed entries
        raw_text = e.get("text")
        raw_label = e.get("label")
        t = "" if raw_text is None else str(raw_text).strip()
        lab = "" if raw_label is None else str(raw_label).strip().upper()
        if not t or not lab or lab not in allowed:
            continue

        st = e.get("start")
        en = e.get("end")
        offsets_ok = (
            isinstance(st, int)
            and isinstance(en, int)
            # bool is a subclass of int; True/False are not offsets.
            and not isinstance(st, bool)
            and not isinstance(en, bool)
            and 0 <= st <= en <= len(text)
            # Rejects placeholder offsets such as the 0/0 shown in the schema.
            and text[st:en] == t
        )
        if not offsets_ok:
            # Fallback: first occurrence of the entity text in the input.
            idx = text.find(t)
            st, en = (idx, idx + len(t)) if idx >= 0 else (0, 0)
        out.append({"text": t, "label": lab, "start": st, "end": en})
    return {"entities": out}

# ---- Lazy model load (CPU) ----
# Module-level cache; populated by the first call to load_model().
_tokenizer = None
_model = None


def load_model():
    """Return the ``(tokenizer, model)`` pair, loading it on first use.

    Both objects are cached in module globals so later calls reuse them
    instead of loading again.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    # IMPORTANT: use_fast=False to avoid SentencePiece fast-conversion issues on CPU
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model

# ---- Inference ----
def ner_infer(text: str, max_new_tokens: int = 192) -> Dict[str, Any]:
    """Run prompt-based NER on ``text`` and return normalized entities.

    The text is wrapped in the instruction prompt, passed through the
    seq2seq model with greedy decoding, and the first JSON object in the
    decoded output is parsed and normalized against ``ALLOWED_LABELS``.

    Args:
        text: Input text; ``None`` or whitespace-only input returns an empty
            result without loading the model.
        max_new_tokens: Upper bound on generated tokens; coerced with
            ``int()`` in case the caller supplies a float.

    Returns:
        ``{"entities": [...]}`` as produced by ``normalize_entities``.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}

    tok, model = load_model()
    prompt = build_prompt(text, ALLOWED_LABELS)
    inputs = tok(prompt, return_tensors="pt")  # no device transfer: stays on CPU
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        # Greedy decoding for deterministic output. Sampling-only flags such
        # as `temperature` are deliberately not passed: they have no effect
        # when do_sample=False.
        do_sample=False,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out_text = tok.decode(gen_ids[0], skip_special_tokens=True)
    raw = extract_first_json(out_text)
    return normalize_entities(raw, text, ALLOWED_LABELS)

# ---- UI ----
# Build the Gradio interface at import time; `demo` is the module-level app object.
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — **CPU version (mT5-small)**")
    with gr.Row():
        # Four-line input textbox, pre-filled with the sample sentence.
        inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        # Generation-length slider: 64..512 in steps of 16, default 192.
        max_tok = gr.Slider(64, 512, value=192, step=16, label="حداکثر توکن خروجی (CPU)")
    btn = gr.Button("استخراج موجودیت‌ها")
    out = gr.JSON(label="خروجی JSON (entities)")
    # Clicking the button calls ner_infer(textbox value, slider value) and
    # shows the returned dict in the JSON component.
    btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()