Spaces:
Sleeping
Sleeping
| import re, json, gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Hugging Face Hub id of the instruction-tuned causal LM used for zero-shot NER.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# Entity labels the model is allowed to emit; listed verbatim in the prompt.
# Also used as the default `labels` argument of build_prompt, so treat it as
# read-only (mutating it would change every later default call).
LABELS = [
    "PERSON",
    "ORG",
    "LOC",
    "GPE",
    "DATE",
    "TIME",
    "PRODUCT",
    "EVENT",
]


def build_prompt(text, labels=LABELS):
    """Build the Persian-language NER instruction prompt for ``text``.

    In English, the prompt says: "Analyse the text below for named entities
    (NER). Allowed labels: <labels>. Output only valid JSON with the following
    schema: <schema>. Write no other text; only JSON." followed by
    "Text: <text>".

    Args:
        text: The input sentence/paragraph to tag.
        labels: Iterable of label names, joined with ", " into the prompt.

    Returns:
        The prompt string; every line, including the last, ends with "\\n".
    """
    allowed_labels = ", ".join(labels)
    # Example schema shown to the model (a plain string, so braces are literal).
    schema_example = '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}'
    prompt_lines = [
        "متن زیر را برای شناسایی موجودیتهای نامدار (NER) تحلیل کن.",
        f"لیبلهای مجاز: {allowed_labels}.",
        "خروجی را فقط به صورت JSON معتبر با اسکیمای زیر بده:",
        schema_example,
        "هیچ متن دیگری ننویس؛ فقط JSON.",
        "",  # blank separator line before the input text
        f"متن: {text}",
    ]
    return "\n".join(prompt_lines) + "\n"
def extract_json(s: str):
    """Recover the first JSON object from raw LLM output.

    The model is asked for bare JSON, but it may wrap the object in prose or
    Markdown fences, leave trailing commas, or append commentary. Strategy:

    1. Take the widest ``{...}`` span (first ``{`` to last ``}``) and parse it,
       retrying once with trailing commas before ``}`` / ``]`` removed.
    2. If both fail (typically because the JSON is followed by text that itself
       contains braces, which the greedy span swallows), decode from each ``{``
       with ``json.JSONDecoder.raw_decode``. That call stops at the end of the
       first complete value and ignores the rest. The first object that has an
       ``"entities"`` key is returned.

    Args:
        s: Decoded model output. ``None``/empty is treated as "no JSON".

    Returns:
        The parsed ``dict``, or a fresh ``{"entities": []}`` when nothing usable
        can be recovered.
    """
    empty = {"entities": []}
    if not s:
        return empty
    m = re.search(r"\{[\s\S]*\}", s)
    if not m:
        return empty
    span = m.group(0)
    # Lenient repair for `{..., }` / `[..., ]`. It can also rewrite such
    # sequences inside string values, which is acceptable for this best-effort parse.
    repaired = re.sub(r",\s*]", "]", re.sub(r",\s*}", "}", span))
    candidates = (span, repaired)

    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            continue
        if isinstance(obj, dict):
            return obj

    decoder = json.JSONDecoder()
    for candidate in candidates:
        for start in (i for i, ch in enumerate(candidate) if ch == "{"):
            try:
                obj, _end = decoder.raw_decode(candidate, start)
            except (ValueError, RecursionError):
                continue
            # Inner entity objects also decode; only accept the top-level shape.
            if isinstance(obj, dict) and "entities" in obj:
                return obj
    return empty
# Lazily-initialised module-level cache for the tokenizer/model pair. They stay
# None until load_model() is first called, so importing this module does not
# load the model.
_tokenizer = None
_model = None


def load_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use.

    Both objects are loaded from ``MODEL_ID``. The model uses float16 when CUDA
    is available (otherwise ``torch_dtype=None``) and ``device_map="auto"`` for
    placement. If either cached object is missing, both are (re)loaded.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    half_on_gpu = torch.float16 if torch.cuda.is_available() else None
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=half_on_gpu,
        device_map="auto",
    )
    return _tokenizer, _model
def _coerce_offset(value):
    """Return ``value`` as an int, or ``None`` if it is missing/non-numeric."""
    # bool is an int subclass; a JSON true/false is not a meaningful offset.
    if value is None or isinstance(value, bool):
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _locate_span(text, ent_text, start, end):
    """Return ``(start, end)`` character offsets of ``ent_text`` in ``text``.

    LLMs are unreliable at counting characters, so model-reported offsets are
    only kept when ``text[start:end]`` really equals ``ent_text``. Otherwise the
    occurrence of ``ent_text`` closest to the reported ``start`` (or the first
    one, if no start was given) is used. If the entity text does not occur
    verbatim, this falls back to the reported offsets, defaulting start to 0
    and end to ``start + len(ent_text)``.
    """
    if (
        start is not None
        and end is not None
        and 0 <= start <= end <= len(text)
        and text[start:end] == ent_text
    ):
        return start, end
    hits = [m.start() for m in re.finditer(re.escape(ent_text), text)]
    if hits:
        anchor = start if start is not None else 0
        best = min(hits, key=lambda i: abs(i - anchor))
        return best, best + len(ent_text)
    st = start if start is not None else 0
    en = end if end is not None else st + len(ent_text)
    return st, en


def _normalize_entities(data, text):
    """Validate parsed model output into a list of ``{text,label,start,end}`` dicts.

    Items that are not dicts, lack ``text``/``label``, or have a non-string or
    empty ``text`` are skipped. Offsets are re-anchored via ``_locate_span``.
    """
    raw_ents = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(raw_ents, list):  # e.g. "entities": null or a dict
        return []
    ents = []
    for e in raw_ents:
        if not isinstance(e, dict) or "text" not in e or "label" not in e:
            continue
        ent_text = e["text"]
        if not isinstance(ent_text, str) or not ent_text:
            continue
        st, en = _locate_span(
            text, ent_text, _coerce_offset(e.get("start")), _coerce_offset(e.get("end"))
        )
        ents.append({"text": ent_text, "label": e["label"], "start": st, "end": en})
    return ents


def _generate(text, temperature, max_new_tokens):
    """Run the LLM on the NER prompt for ``text`` and return only the new text."""
    tok, model = load_model()
    prompt = build_prompt(text)
    if getattr(tok, "chat_template", None):
        # Instruct models expect their chat format; wrap the prompt as a user
        # turn and append the assistant header so the model starts answering.
        prompt = tok.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already carries the BOS token; don't add another.
        inputs = tok(prompt, return_tensors="pt", add_special_tokens=False)
    else:
        inputs = tok(prompt, return_tensors="pt")
    inputs = inputs.to(model.device)

    temperature = float(temperature)
    gen_kwargs = {"max_new_tokens": int(max_new_tokens), "do_sample": temperature > 0}
    if gen_kwargs["do_sample"]:
        # Only meaningful when sampling; passing it for greedy decoding just warns.
        gen_kwargs["temperature"] = temperature
    # Explicit None check: a token id of 0 is valid but falsy.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    gen_ids = model.generate(**inputs, **gen_kwargs, pad_token_id=pad_id)
    prompt_len = inputs["input_ids"].shape[-1]
    # Decode only the continuation, not the echoed prompt.
    return tok.decode(gen_ids[0][prompt_len:], skip_special_tokens=True)


def ner_infer(text, temperature=0.0, max_new_tokens=256):
    """Extract named entities from ``text`` with the zero-shot LLM.

    Args:
        text: Input text. ``None``, empty, or whitespace-only text returns no
            entities without loading the model.
        temperature: 0 (or less) means greedy decoding; > 0 enables sampling.
        max_new_tokens: Generation budget. Too small a value can truncate the
            JSON, in which case no entities are returned.

    Returns:
        ``{"entities": [{"text", "label", "start", "end"}, ...]}``. ``start``
        and ``end`` are character offsets into ``text`` (end-exclusive),
        re-derived from the entity text whenever the model's own offsets do
        not match.
    """
    if not text or not text.strip():
        return {"entities": []}
    raw_output = _generate(text, temperature, max_new_tokens)
    data = extract_json(raw_output)
    return {"entities": _normalize_entities(data, text)}
# Gradio UI: a Persian text box, two generation controls, and a JSON viewer
# wired to ner_infer.
with gr.Blocks(title="Persian Zero-Shot NER (LLM)") as demo:
    gr.Markdown(value="## Persian Zero-Shot NER (LLM) — JSON output")
    inp = gr.Textbox(
        value="من دیروز با علی در تهران در دفتر دیجیکالا جلسه داشتم.",
        label="متن فارسی",
        lines=4,
    )
    with gr.Row():
        # ner_infer samples only when temperature > 0; 0.0 means greedy decoding.
        temp = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"
        )
        max_tok = gr.Slider(
            minimum=64, maximum=512, value=256, step=16, label="Max new tokens"
        )
    btn = gr.Button(value="Extract Entities")
    out = gr.JSON(label="خروجی JSON")
    # Input order matches ner_infer(text, temperature, max_new_tokens).
    btn.click(fn=ner_infer, inputs=[inp, temp, max_tok], outputs=out)


if __name__ == "__main__":
    demo.launch()