optimopium commited on
Commit
ee8ab7e
·
verified ·
1 Parent(s): a1a4c8c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, json, gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
6
+ LABELS = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
7
+
8
+ def build_prompt(text, labels=LABELS):
9
+ return (
10
+ "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
11
+ f"لیبل‌های مجاز: {', '.join(labels)}.\n"
12
+ "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:\n"
13
+ '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
14
+ "هیچ متن دیگری ننویس؛ فقط JSON.\n\n"
15
+ f"متن: {text}\n"
16
+ )
17
+
18
+ def extract_json(s: str):
19
+ m = re.search(r"\{[\s\S]*\}", s)
20
+ if not m: return {"entities": []}
21
+ raw = m.group(0)
22
+ try:
23
+ return json.loads(raw)
24
+ except Exception:
25
+ raw = re.sub(r",\s*}", "}", raw)
26
+ raw = re.sub(r",\s*]", "]", raw)
27
+ try:
28
+ return json.loads(raw)
29
+ except Exception:
30
+ return {"entities": []}
31
+
32
+ # lazy globals
33
+ _tokenizer = None
34
+ _model = None
35
+
36
+ def load_model():
37
+ global _tokenizer, _model
38
+ if _tokenizer is None or _model is None:
39
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
40
+ _model = AutoModelForCausalLM.from_pretrained(
41
+ MODEL_ID,
42
+ torch_dtype=torch.float16 if torch.cuda.is_available() else None,
43
+ device_map="auto"
44
+ )
45
+ return _tokenizer, _model
46
+
47
+ def ner_infer(text, temperature=0.0, max_new_tokens=256):
48
+ if not text.strip():
49
+ return {"entities": []}
50
+ tok, model = load_model()
51
+ prompt = build_prompt(text)
52
+ inputs = tok(prompt, return_tensors="pt").to(model.device)
53
+
54
+ gen_ids = model.generate(
55
+ **inputs,
56
+ max_new_tokens=int(max_new_tokens),
57
+ do_sample=(float(temperature) > 0),
58
+ temperature=float(temperature),
59
+ pad_token_id=tok.eos_token_id or tok.pad_token_id,
60
+ )
61
+ out = tok.decode(gen_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
62
+ data = extract_json(out)
63
+
64
+ # normalize schema
65
+ ents = []
66
+ for e in data.get("entities", []):
67
+ try:
68
+ t = e["text"]; lab = e["label"]
69
+ st = int(e.get("start", 0)); en = int(e.get("end", st + len(t)))
70
+ ents.append({"text": t, "label": lab, "start": st, "end": en})
71
+ except Exception:
72
+ pass
73
+ return {"entities": ents}
74
+
75
+ with gr.Blocks(title="Persian Zero-Shot NER (LLM)") as demo:
76
+ gr.Markdown("## Persian Zero-Shot NER (LLM) — JSON output")
77
+ inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.")
78
+ with gr.Row():
79
+ temp = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
80
+ max_tok = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
81
+ btn = gr.Button("Extract Entities")
82
+ out = gr.JSON(label="خروجی JSON")
83
+
84
+ btn.click(fn=ner_infer, inputs=[inp, temp, max_tok], outputs=out)
85
+
86
+ if __name__ == "__main__":
87
+ demo.launch()