Spaces:
Sleeping
Sleeping
File size: 3,434 Bytes
ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 ee8ab7e 53299a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import re, json, gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# LIGHTWEIGHT, CPU-FRIENDLY MODEL
# mT5-small is multilingual (covers Persian) and small enough to run on a
# free CPU Space; swapping MODEL_ID is the single knob for trying others.
MODEL_ID = "google/mt5-small"
# Closed set of entity labels offered to the model in the prompt
# (OntoNotes-style tag names). Output labels are expected to come from here,
# though nothing downstream enforces it — see ner_infer.
LABELS = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
def build_prompt(text, labels=LABELS):
    """Assemble the Persian zero-shot NER instruction prompt for *text*.

    The prompt asks the model to emit ONLY a JSON object following the
    ``{"entities": [...]}`` schema, restricted to *labels*.
    """
    pieces = [
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.\n",
        f"لیبلهای مجاز: {', '.join(labels)}.\n",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:\n",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n',
        "هیچ متن دیگری ننویس؛ فقط JSON.\n\n",
        f"متن: {text}\n",
    ]
    return "".join(pieces)
def extract_json(s: str):
    """Pull the first JSON object out of a raw LLM response string.

    The model is instructed to emit only JSON, but small models often wrap
    it in extra text or leave trailing commas. Strategy:
      1. grab the outermost ``{...}`` span (greedy, so nested objects stay whole);
      2. try to parse it as-is, then once more with trailing commas removed;
      3. validate the shape so callers always get ``{"entities": <list>}``.

    Returns:
        dict: the parsed object when it has a list-valued ``"entities"`` key,
        otherwise ``{"entities": []}``. Never raises.
    """
    match = re.search(r"\{[\s\S]*\}", s)
    if not match:
        return {"entities": []}
    raw = match.group(0)
    # Second candidate: same text with ",}" / ",]" collapsed — the most
    # common way small models break strict JSON.
    for candidate in (raw, re.sub(r",\s*([}\]])", r"\1", raw)):
        try:
            data = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        # Guard the shape: downstream iterates data["entities"], so a dict
        # without a list there (or junk like {"entities": "..."}) must not
        # leak through.
        if isinstance(data, dict) and isinstance(data.get("entities"), list):
            return data
        return {"entities": []}
    return {"entities": []}
# Lazy load on CPU: globals act as a process-wide memo so the model is
# fetched at most once per worker.
_tokenizer = None
_model = None


def load_model():
    """Return the (tokenizer, model) pair, loading them on first use.

    Subsequent calls reuse the cached module-level instances. The model is
    left on its default device (CPU on a free Space).
    """
    global _tokenizer, _model
    if _model is None or _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)  # CPU by default on Spaces
    return _tokenizer, _model
def ner_infer(text, max_new_tokens=192):
    """Run zero-shot NER over *text* and return normalized entity spans.

    Args:
        text: Persian input text; ``None``/blank yields an empty result.
        max_new_tokens: generation budget (Gradio slider value; coerced to int).

    Returns:
        dict: ``{"entities": [{"text", "label", "start", "end"}, ...]}``.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}
    tok, model = load_model()
    inputs = tok(build_prompt(text), return_tensors="pt")  # stays on CPU
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        # Greedy decoding for determinism on CPU. Do NOT pass temperature
        # here: with do_sample=False recent transformers versions warn that
        # sampling parameters are ignored (and may reject them).
        do_sample=False,
        # pad_token_id must be set for some T5/mT5 variants:
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    decoded = tok.decode(gen_ids[0], skip_special_tokens=True)
    data = extract_json(decoded)
    return {"entities": _normalize_entities(data.get("entities", []), text)}


def _normalize_entities(raw_entities, text):
    """Coerce model-emitted entity dicts into a clean, uniform shape.

    Skips malformed entries; when start/end are missing or not ints, falls
    back to the first occurrence of the entity text in *text* (or 0,0 when
    the text cannot be found).
    """
    ents = []
    # The model may emit "entities" as a non-list value; treat that as empty
    # rather than iterating e.g. a string character by character.
    if not isinstance(raw_entities, list):
        return ents
    for entry in raw_entities:
        try:
            span_text = str(entry.get("text", "")).strip()
            label = str(entry.get("label", "")).strip()
            if not span_text or not label:
                continue
            start, end = entry.get("start"), entry.get("end")
            if not isinstance(start, int) or not isinstance(end, int):
                idx = text.find(span_text)
                start, end = (idx, idx + len(span_text)) if idx >= 0 else (0, 0)
            ents.append(
                {"text": span_text, "label": label, "start": int(start), "end": int(end)}
            )
        except Exception:
            # Best-effort: one malformed entity (e.g. a non-dict entry) must
            # not discard the rest of the response.
            continue
    return ents
# --- Gradio UI: single textbox -> JSON entity output, wired to ner_infer ---
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — CPU version (mT5-small)")
    # Persian input, prefilled with a sample sentence containing several entities.
    inp = gr.Textbox(label="متن فارسی", lines=4, value="من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم.")
    # Generation budget slider; passed through to ner_infer(max_new_tokens=...).
    max_tok = gr.Slider(64, 512, value=192, step=16, label="Max new tokens (CPU)")
    btn = gr.Button("Extract Entities")
    out = gr.JSON(label="خروجی JSON")
    btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)
if __name__ == "__main__":
    demo.launch()
|