File size: 3,189 Bytes
ee8ab7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import re, json, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
# Entity types the model is allowed to emit (listed verbatim in the prompt).
LABELS = [
    "PERSON", "ORG", "LOC", "GPE",
    "DATE", "TIME", "PRODUCT", "EVENT",
]


def build_prompt(text, labels=LABELS):
    """Build the Persian zero-shot NER instruction for *text*.

    The prompt (written in Persian) asks the model to tag named entities in
    *text* using only *labels*, to answer with JSON shaped like
    ``{"entities": [{"text", "label", "start", "end"}]}``, and to emit nothing
    except that JSON.

    Args:
        text: The input passage, inserted verbatim on the final line.
        labels: Allowed entity labels, joined with ", " into the prompt.

    Returns:
        The prompt string, terminated by a newline.
    """
    allowed = ", ".join(labels)
    prompt_lines = [
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.",
        f"لیبل‌های مجاز: {allowed}.",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}',
        "هیچ متن دیگری ننویس؛ فقط JSON.",
        "",  # blank separator line before the input text
        f"متن: {text}",
    ]
    return "\n".join(prompt_lines) + "\n"

def extract_json(s: str):
    """Pull the entity JSON object out of raw LLM output.

    Strategy, in order:
      1. Parse the greedy span from the first ``{`` to the last ``}``. This
         tolerates code fences and chatter before/after the JSON.
      2. Retry after stripping trailing commas before ``}`` / ``]``, a common
         LLM formatting slip.
      3. If the greedy span still fails (e.g. prose containing braces follows
         the JSON), decode from each ``{`` in turn and return the first object
         that has an ``"entities"`` key.

    Args:
        s: Raw text generated by the model; ``None`` or empty is tolerated.

    Returns:
        The parsed dict, or ``{"entities": []}`` if nothing usable is found.
    """
    empty = {"entities": []}
    if not s:
        return empty
    m = re.search(r"\{[\s\S]*\}", s)
    if not m:
        return empty
    raw = m.group(0)
    # NOTE: regex-based repair; it would also touch ", }" inside string values,
    # which is acceptable for the short JSON this app asks for.
    repaired = re.sub(r",\s*]", "]", re.sub(r",\s*}", "}", raw))
    for candidate in (raw, repaired):
        try:
            return json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            pass

    # Greedy span is not valid JSON: look for a well-formed object starting at
    # some '{' (the outer object is tried before any nested ones).
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", repaired):
        try:
            obj, _ = decoder.raw_decode(repaired[brace.start():])
        except ValueError:
            continue
        if "entities" in obj:
            return obj
    return empty

# lazy globals
_tokenizer = None
_model = None

def load_model():
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if torch.cuda.is_available() else None,
            device_map="auto"
        )
    return _tokenizer, _model

def _generate_text(prompt, temperature, max_new_tokens):
    """Run the LLM on *prompt* and return only the newly generated text.

    Args:
        prompt: User-message text to send to the model.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
    """
    tok, model = load_model()
    # The checkpoint is an *Instruct* model that expects its chat format; a raw
    # prompt tends to be continued rather than answered. Fall back to plain
    # tokenization only when the tokenizer ships no chat template.
    if getattr(tok, "chat_template", None):
        inputs = tok.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
    else:
        inputs = tok(prompt, return_tensors="pt")
    inputs = inputs.to(model.device)

    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding: temperature is a sampling-only setting, so it is
        # not passed at all.
        gen_kwargs["do_sample"] = False
    # Compare against None: `eos or pad` would skip a legitimate token id 0.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    gen_ids = model.generate(**inputs, pad_token_id=pad_id, **gen_kwargs)
    prompt_len = inputs["input_ids"].shape[-1]
    # Decode only the continuation, not the echoed prompt tokens.
    return tok.decode(gen_ids[0][prompt_len:], skip_special_tokens=True)


def _align_span(text, ent_text, start, end):
    """Return ``(start, end)`` such that ``text[start:end] == ent_text`` when possible.

    LLM-reported offsets are unreliable (the prompt's schema example even shows
    ``start: 0, end: 0``). A span that does not reproduce the entity text is
    therefore re-anchored to an occurrence of that text, preferring one at or
    after the reported start. If the text cannot be found, the reported span is
    returned unchanged.
    """
    if 0 <= start <= end <= len(text) and text[start:end] == ent_text:
        return start, end
    if not ent_text:
        return start, end
    idx = text.find(ent_text, max(start, 0))
    if idx < 0:
        idx = text.find(ent_text)
    if idx < 0:
        return start, end
    return idx, idx + len(ent_text)


def _normalize_entities(data, text):
    """Coerce the parsed model JSON into a list of ``{text, label, start, end}`` dicts."""
    raw_ents = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(raw_ents, list):
        # e.g. {"entities": null} or a missing key: nothing usable.
        return []
    ents = []
    for e in raw_ents:
        if not isinstance(e, dict):
            continue
        try:
            t = e["text"]
            lab = e["label"]
            st = int(e.get("start", 0))
            en = int(e.get("end", st + len(t)))
        except (KeyError, TypeError, ValueError):
            # Malformed entity (missing key, non-numeric offset): skip it.
            continue
        if not isinstance(t, str):
            continue
        st, en = _align_span(text, t, st, en)
        ents.append({"text": t, "label": lab, "start": st, "end": en})
    return ents


def ner_infer(text, temperature=0.0, max_new_tokens=256):
    """Extract named entities from *text* with the LLM.

    Args:
        text: Input passage (Persian expected). Empty, whitespace-only or
            ``None`` input returns immediately without loading the model.
        temperature: ``0`` (default) for greedy decoding; ``> 0`` enables
            sampling at that temperature.
        max_new_tokens: Generation budget for the JSON answer.

    Returns:
        ``{"entities": [{"text", "label", "start", "end"}, ...]}``. Offsets are
        character indices into *text* (end exclusive) whenever the entity text
        can be located in *text*.
    """
    if not text or not text.strip():
        return {"entities": []}
    generated = _generate_text(build_prompt(text), temperature, max_new_tokens)
    data = extract_json(generated)
    return {"entities": _normalize_entities(data, text)}

# Gradio UI: a text box for the Persian input, two sliders for generation
# settings, a button, and a JSON viewer for the extracted entities.
# Components are laid out in the order they are created inside the Blocks
# context, so the statement order below defines the page layout.
with gr.Blocks(title="Persian Zero-Shot NER (LLM)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — JSON output")
    # Input box (label: "Persian text"), pre-filled with a sample sentence.
    inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.")
    with gr.Row():
        # Passed to ner_infer as `temperature`; 0.0 there means greedy decoding.
        temp = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
        # Passed to ner_infer as `max_new_tokens`.
        max_tok = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
    btn = gr.Button("Extract Entities")
    # Output panel (label: "JSON output") showing the dict returned by ner_infer.
    out = gr.JSON(label="خروجی JSON")

    # Input order must match ner_infer's (text, temperature, max_new_tokens).
    btn.click(fn=ner_infer, inputs=[inp, temp, max_tok], outputs=out)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()