optimopium committed on
Commit
eb83f82
·
verified ·
1 Parent(s): 30352ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -42
app.py CHANGED
@@ -1,11 +1,22 @@
1
- import re, json, gradio as gr
 
 
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
- # LIGHTWEIGHT, CPU-FRIENDLY MODEL
5
  MODEL_ID = "google/mt5-small"
6
- LABELS = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
 
 
 
7
 
8
- def build_prompt(text, labels=LABELS):
 
9
  return (
10
  "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
11
  f"لیبل‌های مجاز: {', '.join(labels)}.\n"
@@ -15,13 +26,16 @@ def build_prompt(text, labels=LABELS):
15
  f"متن: {text}\n"
16
  )
17
 
18
- def extract_json(s: str):
19
  m = re.search(r"\{[\s\S]*\}", s)
20
- if not m: return {"entities": []}
 
21
  raw = m.group(0)
 
22
  try:
23
  return json.loads(raw)
24
  except Exception:
 
25
  raw = re.sub(r",\s*}", "}", raw)
26
  raw = re.sub(r",\s*]", "]", raw)
27
  try:
@@ -29,64 +43,72 @@ def extract_json(s: str):
29
  except Exception:
30
  return {"entities": []}
31
 
32
- # Lazy load on CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  _tokenizer = None
34
  _model = None
35
 
36
  def load_model():
37
  global _tokenizer, _model
38
  if _tokenizer is None or _model is None:
39
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
40
- _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID) # CPU by default on Spaces
 
41
  return _tokenizer, _model
42
 
43
- def ner_infer(text, max_new_tokens=192):
 
44
  text = (text or "").strip()
45
  if not text:
46
  return {"entities": []}
47
-
48
  tok, model = load_model()
49
- prompt = build_prompt(text)
50
  inputs = tok(prompt, return_tensors="pt") # stays on CPU
51
-
52
  gen_ids = model.generate(
53
  **inputs,
54
  max_new_tokens=int(max_new_tokens),
55
- do_sample=False, # deterministic on CPU
56
  temperature=0.0,
57
- # pad_token_id must be set for some T5/mT5 variants:
58
  pad_token_id=tok.pad_token_id,
59
  eos_token_id=tok.eos_token_id
60
  )
61
- out = tok.decode(gen_ids[0], skip_special_tokens=True)
62
- data = extract_json(out)
63
-
64
- # normalize; if model omits start/end, compute first occurrence
65
- ents = []
66
- for e in data.get("entities", []):
67
- try:
68
- t = str(e.get("text","")).strip()
69
- lab = str(e.get("label","")).strip()
70
- if not t or not lab:
71
- continue
72
- st = e.get("start"); en = e.get("end")
73
- if not isinstance(st, int) or not isinstance(en, int):
74
- idx = text.find(t)
75
- if idx >= 0:
76
- st, en = idx, idx + len(t)
77
- else:
78
- st, en = 0, 0
79
- ents.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
80
- except Exception:
81
- pass
82
- return {"entities": ents}
83
 
 
84
  with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
85
- gr.Markdown("## Persian Zero-Shot NER (LLM) — CPU version (mT5-small)")
86
- inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.")
87
- max_tok = gr.Slider(64, 512, value=192, step=16, label="Max new tokens (CPU)")
88
- btn = gr.Button("Extract Entities")
89
- out = gr.JSON(label="خروجی JSON")
 
 
90
  btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)
91
 
92
  if __name__ == "__main__":
 
1
+ # app.py
2
+ # Persian Zero-Shot NER (CPU) — Hugging Face Spaces (Gradio)
3
+ # Uses a lightweight Seq2Seq model (mT5-small) and slow tokenizer (no GPU deps).
4
+
5
+ import re
6
+ import json
7
+ import gradio as gr
8
+ from typing import Dict, Any, List
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
 
11
+ # ---- Config (CPU-friendly) ----
12
  MODEL_ID = "google/mt5-small"
13
+ ALLOWED_LABELS: List[str] = [
14
+ "PERSON", "ORG", "LOC", "GPE", "DATE", "TIME", "PRODUCT", "EVENT"
15
+ ]
16
+ DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."
17
 
18
+ # ---- Prompt & Parsing ----
19
+ def build_prompt(text: str, labels: List[str]) -> str:
20
  return (
21
  "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n"
22
  f"لیبل‌های مجاز: {', '.join(labels)}.\n"
 
26
  f"متن: {text}\n"
27
  )
28
 
29
+ def extract_first_json(s: str) -> Dict[str, Any]:
30
  m = re.search(r"\{[\s\S]*\}", s)
31
+ if not m:
32
+ return {"entities": []}
33
  raw = m.group(0)
34
+ # try direct parse
35
  try:
36
  return json.loads(raw)
37
  except Exception:
38
+ # quick repairs for trailing commas
39
  raw = re.sub(r",\s*}", "}", raw)
40
  raw = re.sub(r",\s*]", "]", raw)
41
  try:
 
43
  except Exception:
44
  return {"entities": []}
45
 
46
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and normalize model-produced entities against the source text.

    Args:
        data: Parsed model output; expected to be a dict with an "entities"
            list of dicts carrying "text", "label" and optionally
            "start"/"end". Any other shape yields an empty result.
        text: The text the entities were extracted from; used to verify or
            recompute character offsets.
        labels: Allowed label names. Entity labels are upper-cased before
            being compared against this collection.

    Returns:
        {"entities": [...]} where every item has a non-empty "text", an
        allowed upper-cased "label", and integer "start"/"end" offsets.
    """
    raw_entities = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(raw_entities, list):
        # Missing, null or non-list "entities": nothing usable to normalize.
        return {"entities": []}

    allowed = set(labels)
    out = []
    for e in raw_entities:
        # Skip malformed entries explicitly rather than via a blanket except.
        if not isinstance(e, dict):
            continue
        raw_text = e.get("text")
        raw_label = e.get("label")
        if raw_text is None or raw_label is None:
            # str(None) would otherwise produce the literal text "None".
            continue
        t = str(raw_text).strip()
        lab = str(raw_label).strip().upper()
        if not t or not lab:
            continue
        # keep only allowed labels
        if lab not in allowed:
            continue

        st, en = e.get("start"), e.get("end")
        # Trust the supplied offsets only when they are real ints (bool is an
        # int subclass, so it is excluded), lie inside the text, and actually
        # delimit the entity string.
        offsets_ok = (
            isinstance(st, int) and not isinstance(st, bool)
            and isinstance(en, int) and not isinstance(en, bool)
            and 0 <= st <= en <= len(text)
            and text[st:en] == t
        )
        if not offsets_ok:
            # fallback: first occurrence of the entity text, else (0, 0)
            idx = text.find(t)
            if idx >= 0:
                st, en = idx, idx + len(t)
            else:
                st, en = 0, 0
        out.append({"text": t, "label": lab, "start": st, "end": en})
    return {"entities": out}
70
+
71
+ # ---- Lazy model load (CPU) ----
72
  _tokenizer = None
73
  _model = None
74
 
75
  def load_model():
76
  global _tokenizer, _model
77
  if _tokenizer is None or _model is None:
78
+ # IMPORTANT: use_fast=False to avoid SentencePiece fast-conversion issues on CPU
79
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
80
+ _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
81
  return _tokenizer, _model
82
 
83
+ # ---- Inference ----
84
def ner_infer(text: str, max_new_tokens: int = 192) -> Dict[str, Any]:
    """Run zero-shot NER on `text` and return the normalized entities.

    Args:
        text: Input text. None, empty or whitespace-only input short-circuits
            to an empty result without loading the model.
        max_new_tokens: Upper bound on generated tokens; coerced with int()
            before being passed to generate().

    Returns:
        {"entities": [...]} as produced by normalize_entities().
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}

    tok, model = load_model()
    prompt = build_prompt(text, ALLOWED_LABELS)
    inputs = tok(prompt, return_tensors="pt")  # tensors are used as returned; no device move

    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        # Greedy decoding for stable outputs. `temperature` is a sampling-only
        # setting, so it is deliberately not passed alongside do_sample=False
        # (0.0 is not a valid sampling temperature and only triggers warnings).
        do_sample=False,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out_text = tok.decode(gen_ids[0], skip_special_tokens=True)
    raw = extract_first_json(out_text)
    return normalize_entities(raw, text, ALLOWED_LABELS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ # ---- UI ----
104
  with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
105
+ gr.Markdown("## Persian Zero-Shot NER (LLM) — **CPU version (mT5-small)**")
106
+ with gr.Row():
107
+ inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
108
+ with gr.Row():
109
+ max_tok = gr.Slider(64, 512, value=192, step=16, label="حداکثر توکن خروجی (CPU)")
110
+ btn = gr.Button("استخراج موجودیت‌ها")
111
+ out = gr.JSON(label="خروجی JSON (entities)")
112
  btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)
113
 
114
  if __name__ == "__main__":