optimopium commited on
Commit
53299a5
·
verified ·
1 Parent(s): e1b936d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -28
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import re, json, gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
 
5
- MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
 
6
  LABELS = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
7
 
8
  def build_prompt(text, labels=LABELS):
@@ -29,7 +29,7 @@ def extract_json(s: str):
29
  except Exception:
30
  return {"entities": []}
31
 
32
- # lazy globals
33
  _tokenizer = None
34
  _model = None
35
 
@@ -37,51 +37,57 @@ def load_model():
37
  global _tokenizer, _model
38
  if _tokenizer is None or _model is None:
39
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
40
- _model = AutoModelForCausalLM.from_pretrained(
41
- MODEL_ID,
42
- torch_dtype=torch.float16 if torch.cuda.is_available() else None,
43
- device_map="auto"
44
- )
45
  return _tokenizer, _model
46
 
47
- def ner_infer(text, temperature=0.0, max_new_tokens=256):
48
- if not text.strip():
 
49
  return {"entities": []}
 
50
  tok, model = load_model()
51
  prompt = build_prompt(text)
52
- inputs = tok(prompt, return_tensors="pt").to(model.device)
53
 
54
  gen_ids = model.generate(
55
  **inputs,
56
  max_new_tokens=int(max_new_tokens),
57
- do_sample=(float(temperature) > 0),
58
- temperature=float(temperature),
59
- pad_token_id=tok.eos_token_id or tok.pad_token_id,
 
 
60
  )
61
- out = tok.decode(gen_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
62
  data = extract_json(out)
63
 
64
- # normalize schema
65
  ents = []
66
  for e in data.get("entities", []):
67
  try:
68
- t = e["text"]; lab = e["label"]
69
- st = int(e.get("start", 0)); en = int(e.get("end", st + len(t)))
70
- ents.append({"text": t, "label": lab, "start": st, "end": en})
 
 
 
 
 
 
 
 
 
71
  except Exception:
72
  pass
73
  return {"entities": ents}
74
 
75
- with gr.Blocks(title="Persian Zero-Shot NER (LLM)") as demo:
76
- gr.Markdown("## Persian Zero-Shot NER (LLM) — JSON output")
77
  inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.")
78
- with gr.Row():
79
- temp = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
80
- max_tok = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
81
  btn = gr.Button("Extract Entities")
82
  out = gr.JSON(label="خروجی JSON")
83
-
84
- btn.click(fn=ner_infer, inputs=[inp, temp, max_tok], outputs=out)
85
 
86
  if __name__ == "__main__":
87
- demo.launch()
 
1
  import re, json, gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
+ # LIGHTWEIGHT, CPU-FRIENDLY MODEL
5
+ MODEL_ID = "google/mt5-small"
6
  LABELS = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
7
 
8
  def build_prompt(text, labels=LABELS):
 
29
  except Exception:
30
  return {"entities": []}
31
 
32
+ # Lazy load on CPU
33
  _tokenizer = None
34
  _model = None
35
 
 
37
  global _tokenizer, _model
38
  if _tokenizer is None or _model is None:
39
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
40
+ _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID) # CPU by default on Spaces
 
 
 
 
41
  return _tokenizer, _model
42
 
43
+ def ner_infer(text, max_new_tokens=192):
44
+ text = (text or "").strip()
45
+ if not text:
46
  return {"entities": []}
47
+
48
  tok, model = load_model()
49
  prompt = build_prompt(text)
50
+ inputs = tok(prompt, return_tensors="pt") # stays on CPU
51
 
52
  gen_ids = model.generate(
53
  **inputs,
54
  max_new_tokens=int(max_new_tokens),
55
+ do_sample=False, # deterministic on CPU
56
+ temperature=0.0,
57
+ # pad_token_id must be set for some T5/mT5 variants:
58
+ pad_token_id=tok.pad_token_id,
59
+ eos_token_id=tok.eos_token_id
60
  )
61
+ out = tok.decode(gen_ids[0], skip_special_tokens=True)
62
  data = extract_json(out)
63
 
64
+ # normalize; if model omits start/end, compute first occurrence
65
  ents = []
66
  for e in data.get("entities", []):
67
  try:
68
+ t = str(e.get("text","")).strip()
69
+ lab = str(e.get("label","")).strip()
70
+ if not t or not lab:
71
+ continue
72
+ st = e.get("start"); en = e.get("end")
73
+ if not isinstance(st, int) or not isinstance(en, int):
74
+ idx = text.find(t)
75
+ if idx >= 0:
76
+ st, en = idx, idx + len(t)
77
+ else:
78
+ st, en = 0, 0
79
+ ents.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
80
  except Exception:
81
  pass
82
  return {"entities": ents}
83
 
84
+ with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
85
+ gr.Markdown("## Persian Zero-Shot NER (LLM) — CPU version (mT5-small)")
86
  inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم.")
87
+ max_tok = gr.Slider(64, 512, value=192, step=16, label="Max new tokens (CPU)")
 
 
88
  btn = gr.Button("Extract Entities")
89
  out = gr.JSON(label="خروجی JSON")
90
+ btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)
 
91
 
92
  if __name__ == "__main__":
93
+ demo.launch()