optimopium committed on
Commit
67ad485
·
verified ·
1 Parent(s): eb83f82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -47
app.py CHANGED
@@ -1,29 +1,45 @@
1
- # app.py
2
- # Persian Zero-Shot NER (CPU) — Hugging Face Spaces (Gradio)
3
- # Uses a lightweight Seq2Seq model (mT5-small) and slow tokenizer (no GPU deps).
4
-
5
- import re
6
- import json
7
- import gradio as gr
8
  from typing import Dict, Any, List
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
 
11
- # ---- Config (CPU-friendly) ----
12
- MODEL_ID = "google/mt5-small"
13
- ALLOWED_LABELS: List[str] = [
14
- "PERSON", "ORG", "LOC", "GPE", "DATE", "TIME", "PRODUCT", "EVENT"
15
- ]
16
  DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."
17
 
18
- # ---- Prompt & Parsing ----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def build_prompt(text: str, labels: List[str]) -> str:
20
  return (
21
  "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
22
  f"لیبل‌های مجاز: {', '.join(labels)}.\n"
23
- "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:\n"
24
  '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
25
- "هیچ متن دیگری ننویس؛ فقط JSON.\n\n"
 
 
26
  f"متن: {text}\n"
 
27
  )
28
 
29
  def extract_first_json(s: str) -> Dict[str, Any]:
@@ -31,11 +47,9 @@ def extract_first_json(s: str) -> Dict[str, Any]:
31
  if not m:
32
  return {"entities": []}
33
  raw = m.group(0)
34
- # try direct parse
35
  try:
36
  return json.loads(raw)
37
  except Exception:
38
- # quick repairs for trailing commas
39
  raw = re.sub(r",\s*}", "}", raw)
40
  raw = re.sub(r",\s*]", "]", raw)
41
  try:
@@ -44,72 +58,62 @@ def extract_first_json(s: str) -> Dict[str, Any]:
44
  return {"entities": []}
45
 
46
  def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
 
47
  out = []
48
  for e in data.get("entities", []):
49
  try:
50
- t = str(e.get("text", "")).strip()
51
- lab = str(e.get("label", "")).strip().upper()
52
- if not t or not lab:
53
  continue
54
- # keep only allowed labels
55
- if lab not in labels:
56
- continue
57
- st = e.get("start"); en = e.get("end")
58
  if not isinstance(st, int) or not isinstance(en, int) or st < 0 or en < 0:
59
- # fallback: first occurrence
60
- idx = text.find(t)
61
- if idx >= 0:
62
- st, en = idx, idx + len(t)
63
- else:
64
- st, en = 0, 0
65
  out.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
66
  except Exception:
67
- # ignore malformed entries
68
  pass
69
  return {"entities": out}
70
 
71
- # ---- Lazy model load (CPU) ----
72
  _tokenizer = None
73
  _model = None
74
-
75
  def load_model():
76
  global _tokenizer, _model
77
  if _tokenizer is None or _model is None:
78
- # IMPORTANT: use_fast=False to avoid SentencePiece fast-conversion issues on CPU
79
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
80
  _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
81
  return _tokenizer, _model
82
 
83
- # ---- Inference ----
84
- def ner_infer(text: str, max_new_tokens: int = 192) -> Dict[str, Any]:
85
  text = (text or "").strip()
86
  if not text:
87
  return {"entities": []}
88
  tok, model = load_model()
89
  prompt = build_prompt(text, ALLOWED_LABELS)
90
- inputs = tok(prompt, return_tensors="pt") # stays on CPU
91
  gen_ids = model.generate(
92
  **inputs,
93
  max_new_tokens=int(max_new_tokens),
94
- do_sample=False, # deterministic for stable outputs on CPU
95
- temperature=0.0,
 
96
  pad_token_id=tok.pad_token_id,
97
- eos_token_id=tok.eos_token_id
98
  )
99
  out_text = tok.decode(gen_ids[0], skip_special_tokens=True)
100
  raw = extract_first_json(out_text)
101
  return normalize_entities(raw, text, ALLOWED_LABELS)
102
 
103
- # ---- UI ----
104
  with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
105
- gr.Markdown("## Persian Zero-Shot NER (LLM) **CPU version (mT5-small)**")
106
- with gr.Row():
107
- inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
108
  with gr.Row():
109
- max_tok = gr.Slider(64, 512, value=192, step=16, label="حداکثر توکن خروجی (CPU)")
 
110
  btn = gr.Button("استخراج موجودیت‌ها")
111
- out = gr.JSON(label="خروجی JSON (entities)")
112
- btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)
113
 
114
  if __name__ == "__main__":
115
  demo.launch()
 
1
+ # app.py — Persian Zero-Shot NER (CPU) with few-shot prompting + beams
2
+ import re, json, gradio as gr
 
 
 
 
 
3
  from typing import Dict, Any, List
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
# --- CPU-friendly configuration ---
MODEL_ID = "google/mt5-small"  # try "google/mt5-base" on CPU if still empty (slower, better)

# Closed set of entity labels the model is allowed to emit.
ALLOWED_LABELS: List[str] = [
    "PERSON",
    "ORG",
    "LOC",
    "GPE",
    "DATE",
    "TIME",
    "PRODUCT",
    "EVENT",
]

# Pre-filled demo sentence shown in the UI textbox.
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."
9
 
10
# --- Few-shot examples (Persian) to nudge the model toward valid JSON ---
# start/end are 0-based character offsets into the example sentence and have
# been verified so that sentence[start:end] == text for every entity.
# (The previous offsets in example 1 were off by one or more characters,
# which taught the model to produce wrong spans.)
FEW_SHOT = """
نمونه ۱:
متن: من با علی در تهران در شرکت دیجی‌کالا جلسه داشتم.
خروجی:
{"entities":[
{"text":"علی","label":"PERSON","start":6,"end":9},
{"text":"تهران","label":"LOC","start":13,"end":18},
{"text":"دیجی‌کالا","label":"ORG","start":27,"end":36}
]}

نمونه ۲:
متن: سارا فردا ساعت ۱۰ در دانشگاه تهران سخنرانی دارد.
خروجی:
{"entities":[
{"text":"سارا","label":"PERSON","start":0,"end":4},
{"text":"فردا","label":"DATE","start":5,"end":9},
{"text":"۱۰","label":"TIME","start":15,"end":17},
{"text":"دانشگاه تهران","label":"ORG","start":21,"end":34}
]}
"""

def build_prompt(text: str, labels: List[str]) -> str:
    """Build the Persian zero-shot NER instruction prompt.

    Args:
        text: The Persian input sentence to analyze.
        labels: Allowed entity labels, joined into the instruction line.

    Returns:
        A single prompt string: instructions + JSON schema + few-shot
        examples + the target text, ending with an output cue.
    """
    return (
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
        f"لیبل‌های مجاز: {', '.join(labels)}.\n"
        "فقط JSON معتبر با اسکیمای زیر را برگردان:\n"
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n'
        "هیچ متن دیگری ننویس؛ فقط JSON.\n"
        + FEW_SHOT +
        "\nاکنون متن زیر را پردازش کن و فقط JSON بده:\n"
        f"متن: {text}\n"
        "خروجی:\n"
    )
44
 
45
  def extract_first_json(s: str) -> Dict[str, Any]:
 
47
  if not m:
48
  return {"entities": []}
49
  raw = m.group(0)
 
50
  try:
51
  return json.loads(raw)
52
  except Exception:
 
53
  raw = re.sub(r",\s*}", "}", raw)
54
  raw = re.sub(r",\s*]", "]", raw)
55
  try:
 
58
  return {"entities": []}
59
 
60
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and repair raw model entities against the source text.

    Args:
        data: Parsed model output, expected to hold an "entities" list of
            dicts with "text", "label", "start", "end" keys (any may be
            missing or malformed).
        text: The original input text the offsets refer to.
        labels: Allowed (uppercase) labels; entities with other labels drop.

    Returns:
        {"entities": [...]} where every kept entity has a label from
        *labels* and integer offsets. Offsets are kept only when they are
        in-range ints that actually slice out the entity text; otherwise
        they are repaired via the first occurrence of the text, or (0, 0)
        when the text does not occur at all.
    """
    text_norm = text or ""
    out: List[Dict[str, Any]] = []
    for e in data.get("entities", []):
        try:
            t = str(e.get("text", "")).strip()
            lab = str(e.get("label", "")).strip().upper()
            if not t or lab not in labels:
                continue
            st, en = e.get("start"), e.get("end")
            # Trust model offsets only when they really slice out `t`;
            # LLMs frequently hallucinate offsets even when the entity
            # text itself is correct.
            valid = (
                isinstance(st, int)
                and isinstance(en, int)
                and 0 <= st < en <= len(text_norm)
                and text_norm[st:en] == t
            )
            if not valid:
                idx = text_norm.find(t)
                st, en = (idx, idx + len(t)) if idx >= 0 else (0, 0)
            out.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
        except Exception:
            # Drop individually malformed entries rather than failing the batch.
            pass
    return {"entities": out}
77
 
78
# --- Lazy, cached model load (CPU only) ---
_tokenizer = None
_model = None

def load_model():
    """Return a (tokenizer, model) pair, loading and caching on first call."""
    global _tokenizer, _model
    if _model is None or _tokenizer is None:
        # use_fast=False: the slow SentencePiece tokenizer sidesteps the
        # fast-tokenizer conversion path, which can fail on CPU-only installs.
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
87
 
88
def ner_infer(text: str, max_new_tokens: int = 256, num_beams: int = 4) -> Dict[str, Any]:
    """Run zero-shot NER over *text* and return {"entities": [...]}.

    Uses deterministic beam-search decoding on CPU, then parses the first
    JSON object out of the generation and normalizes it against the input.

    Args:
        text: Persian input sentence; empty/whitespace input short-circuits.
        max_new_tokens: Generation budget for the JSON output.
        num_beams: Beam width for decoding.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return {"entities": []}
    tokenizer, model = load_model()
    encoded = tokenizer(build_prompt(cleaned, ALLOWED_LABELS), return_tensors="pt")  # stays on CPU
    generated = model.generate(
        **encoded,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,            # deterministic output
        num_beams=int(num_beams),   # beam search decodes better than greedy here
        length_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return normalize_entities(extract_first_json(decoded), cleaned, ALLOWED_LABELS)
107
 
 
108
# --- Gradio UI ---
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER — CPU (mT5) + Few-Shot Prompting")
    text_input = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        tokens_slider = gr.Slider(96, 512, value=256, step=16, label="حداکثر توکن خروجی")
        beams_slider = gr.Slider(1, 8, value=4, step=1, label="Beam size")
    run_button = gr.Button("استخراج موجودیت‌ها")
    json_output = gr.JSON(label="خروجی JSON")
    # Wire the button: sliders map onto ner_infer's keyword defaults.
    run_button.click(fn=ner_infer, inputs=[text_input, tokens_slider, beams_slider], outputs=json_output)

if __name__ == "__main__":
    demo.launch()