# NOTE(review): the lines below were Hugging Face Spaces page furniture
# (Space status, file size, commit hashes, line-number gutter) captured by
# scraping the web UI — they are not part of the source file. Kept as
# comments so the module parses:
#   Space status: Sleeping
#   File size: 4,360 Bytes
# app.py
# Persian Zero-Shot NER (CPU) — Hugging Face Spaces (Gradio)
# Uses a lightweight Seq2Seq model (mT5-small) and slow tokenizer (no GPU deps).
import re
import json
import gradio as gr
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ---- Config (CPU-friendly) ----
# Lightweight seq2seq checkpoint so the Space runs on CPU (see header note).
MODEL_ID = "google/mt5-small"
# Closed label set injected into the prompt; normalize_entities discards any
# entity whose label is not in this list.
ALLOWED_LABELS: List[str] = [
"PERSON", "ORG", "LOC", "GPE", "DATE", "TIME", "PRODUCT", "EVENT"
]
# Persian example sentence pre-filled in the UI textbox.
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم."
# ---- Prompt & Parsing ----
def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian zero-shot NER instruction prompt.

    The prompt asks the model to analyze *text*, restrict itself to *labels*,
    and reply with JSON only, matching the schema shown verbatim below.
    """
    lines = [
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.",
        f"لیبلهای مجاز: {', '.join(labels)}.",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}',
        "هیچ متن دیگری ننویس؛ فقط JSON.",
        "",  # blank line separating the instructions from the input text
        f"متن: {text}",
    ]
    return "\n".join(lines) + "\n"
def extract_first_json(s: str) -> Dict[str, Any]:
    """Extract and parse the first JSON object embedded in model output *s*.

    Fixes a bug in the previous regex approach: ``\\{[\\s\\S]*\\}`` is greedy,
    matching from the first ``{`` to the LAST ``}`` in the string, so any
    stray brace after the JSON (chatter, a second object) made parsing fail
    and silently returned no entities. ``json.JSONDecoder.raw_decode`` instead
    parses a balanced object starting at a given index, so we try each ``{``
    position until one parses.

    The original trailing-comma repair (``,}`` / ``,]``) is kept as a second
    pass over a repaired copy of the text.

    Returns:
        The first parseable JSON object as a dict, or ``{"entities": []}``
        when none is found.
    """
    decoder = json.JSONDecoder()
    # Pass 1: raw text. Pass 2: same text with trailing commas stripped.
    for candidate in (s, re.sub(r",\s*([}\]])", r"\1", s)):
        idx = candidate.find("{")
        while idx != -1:
            try:
                obj, _end = decoder.raw_decode(candidate, idx)
            except ValueError:
                obj = None
            if isinstance(obj, dict):
                return obj
            idx = candidate.find("{", idx + 1)
    return {"entities": []}
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Validate and clean model-produced entities against *text* and *labels*.

    For each entry in ``data["entities"]``:
    - empty text or a label outside *labels* (case-insensitive) drops the entry;
    - a span is kept only when start/end are ints with
      ``0 <= start <= end <= len(text)`` — this also rejects inverted spans
      (``end < start``) and offsets past the end of the text, which the
      previous check (only ``>= 0``) let through;
    - invalid spans fall back to the first occurrence of the entity text in
      *text*, or ``(0, 0)`` when it does not occur;
    - non-dict entries are skipped.

    Returns:
        ``{"entities": [...]}`` with normalized
        ``{"text", "label", "start", "end"}`` dicts.
    """
    allowed = set(labels)
    out: List[Dict[str, Any]] = []
    for ent in data.get("entities", []):
        try:
            t = str(ent.get("text", "")).strip()
            lab = str(ent.get("label", "")).strip().upper()
        except AttributeError:
            # entry is not dict-like (e.g. a bare string); ignore it
            continue
        if not t or lab not in allowed:
            continue
        st = ent.get("start")
        en = ent.get("end")
        span_ok = (
            isinstance(st, int)
            and isinstance(en, int)
            and 0 <= st <= en <= len(text)
        )
        if not span_ok:
            # fallback: first occurrence of the entity text
            idx = text.find(t)
            st, en = (idx, idx + len(t)) if idx >= 0 else (0, 0)
        out.append({"text": t, "label": lab, "start": int(st), "end": int(en)})
    return {"entities": out}
# ---- Lazy model load (CPU) ----
_tokenizer = None
_model = None


def load_model():
    """Load tokenizer and model on first call; memoize them in module globals.

    Returns the ``(tokenizer, model)`` pair on every call.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    # IMPORTANT: use_fast=False to avoid SentencePiece fast-conversion issues on CPU
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
# ---- Inference ----
def ner_infer(text: str, max_new_tokens: int = 192) -> Dict[str, Any]:
    """Run zero-shot NER over Persian *text* and return an entities dict.

    Args:
        text: Input sentence; empty or whitespace-only input short-circuits
            to an empty result without loading the model.
        max_new_tokens: Generation budget (kept small for CPU latency).

    Returns:
        ``{"entities": [{"text", "label", "start", "end"}, ...]}`` after
        JSON extraction and label/span normalization.
    """
    text = (text or "").strip()
    if not text:
        return {"entities": []}
    tok, model = load_model()
    prompt = build_prompt(text, ALLOWED_LABELS)
    inputs = tok(prompt, return_tensors="pt")  # stays on CPU
    # Greedy decoding for deterministic output. The previous `temperature=0.0`
    # was dropped: temperature is ignored when do_sample=False, and recent
    # transformers versions warn about the conflicting sampling flag.
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    out_text = tok.decode(gen_ids[0], skip_special_tokens=True)
    raw = extract_first_json(out_text)
    return normalize_entities(raw, text, ALLOWED_LABELS)
# ---- UI ----
# NOTE(review): indentation was lost in extraction; widget nesting below is
# the conventional reconstruction — confirm against the deployed Space.
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER (LLM) — **CPU version (mT5-small)**")
    with gr.Row():
        inp = gr.Textbox(label="متن فارسی", lines=4, value=DEFAULT_EXAMPLE)
    with gr.Row():
        max_tok = gr.Slider(64, 512, value=192, step=16, label="حداکثر توکن خروجی (CPU)")
        btn = gr.Button("استخراج موجودیتها")
    out = gr.JSON(label="خروجی JSON (entities)")
    # Wire the button: (textbox, slider) -> ner_infer -> JSON panel.
    btn.click(fn=ner_infer, inputs=[inp, max_tok], outputs=out)

if __name__ == "__main__":
    demo.launch()