optimopium committed on
Commit
a39d4c2
·
verified ·
1 Parent(s): 67ad485

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -107
app.py CHANGED
@@ -1,119 +1,177 @@
1
- # app.py Persian Zero-Shot NER (CPU) with few-shot prompting + beams
2
- import re, json, gradio as gr
3
- from typing import Dict, Any, List
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
# Seq2seq checkpoint used for generation. "google/mt5-base" is a slower but
# better alternative on CPU if this one keeps returning empty output.
MODEL_ID = "google/mt5-small"
# Entity labels advertised in the prompt and accepted by post-processing.
ALLOWED_LABELS: List[str] = ["PERSON","ORG","LOC","GPE","DATE","TIME","PRODUCT","EVENT"]
# Sentence pre-filled in the input textbox.
DEFAULT_EXAMPLE = "من دیروز با علی در تهران در دفتر دیجی‌کالا جلسه داشتم."

# --- Few-shot examples (in Persian) to nudge the model ---
# Offsets are 0-based, end-exclusive character indices into the example
# sentence, i.e. sentence[start:end] == entity text. Example 1 previously
# carried offsets that did not match its sentence, which contradicted
# example 2 and taught the model inconsistent spans.
FEW_SHOT = """
نمونه ۱:
متن: من با علی در تهران در شرکت دیجی‌کالا جلسه داشتم.
خروجی:
{"entities":[
{"text":"علی","label":"PERSON","start":6,"end":9},
{"text":"تهران","label":"LOC","start":13,"end":18},
{"text":"دیجی‌کالا","label":"ORG","start":27,"end":36}
]}

نمونه ۲:
متن: سارا فردا ساعت ۱۰ در دانشگاه تهران سخنرانی دارد.
خروجی:
{"entities":[
{"text":"سارا","label":"PERSON","start":0,"end":4},
{"text":"فردا","label":"DATE","start":5,"end":9},
{"text":"۱۰","label":"TIME","start":15,"end":17},
{"text":"دانشگاه تهران","label":"ORG","start":21,"end":34}
]}
"""
31
 
32
def build_prompt(text: str, labels: List[str]) -> str:
    """Assemble the Persian NER prompt sent to the seq2seq model.

    The prompt consists of the task instructions (including the
    comma-separated allowed labels and the expected JSON schema), the
    module-level ``FEW_SHOT`` examples, and finally the request for ``text``.

    Args:
        text: Sentence to analyse; inserted verbatim.
        labels: Label names listed as allowed in the instructions.

    Returns:
        The complete prompt string, ending with the output cue line.
    """
    allowed = ", ".join(labels)
    instructions = "".join([
        "متن زیر را برای شناسایی موجودیت‌های نامدار (NER) تحلیل کن.\n",
        f"لیبل‌های مجاز: {allowed}.\n",
        "فقط JSON معتبر با اسکیمای زیر را برگردان:\n",
        '{"entities":[{"text":"...", "label":"ORG|PERSON|...", "start":0, "end":0}]}\n',
        "هیچ متن دیگری ننویس؛ فقط JSON.\n",
    ])
    request = "".join([
        "\nاکنون متن زیر را پردازش کن و فقط JSON بده:\n",
        f"متن: {text}\n",
        "خروجی:\n",
    ])
    return instructions + FEW_SHOT + request
44
 
45
def extract_first_json(s: str) -> Dict[str, Any]:
    """Parse the first JSON object embedded in ``s``.

    Decoding starts at the first ``{`` and stops at the end of that object,
    so trailing text (including further braces or additional objects) no
    longer invalidates an otherwise valid first object. If strict decoding
    fails, trailing commas before ``}`` / ``]`` are stripped and decoding is
    retried once.

    Args:
        s: Raw text produced by the model.

    Returns:
        The decoded object, or ``{"entities": []}`` when no JSON object can
        be recovered.
    """
    start = s.find("{")
    if start < 0:
        return {"entities": []}

    tail = s[start:]
    # Second candidate: same text with dangling commas removed.
    repaired = re.sub(r",\s*([}\]])", r"\1", tail)
    decoder = json.JSONDecoder()
    for candidate in (tail, repaired):
        try:
            obj, _ = decoder.raw_decode(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(obj, dict):
            return obj
    return {"entities": []}
 
 
 
59
 
60
def normalize_entities(data: Dict[str, Any], text: str, labels: List[str]) -> Dict[str, Any]:
    """Filter and repair the entity list decoded from the model output.

    An entity is kept only if it is a dict with a non-empty ``text`` and a
    ``label`` that (upper-cased) is in ``labels``. Its ``start``/``end`` are
    trusted only when they are real integers whose slice of ``text`` equals
    the entity text; otherwise the span is recomputed from the first
    occurrence of the entity text, or set to ``(0, 0)`` when it does not
    occur.

    Args:
        data: Decoded JSON, expected to hold a list under ``"entities"``.
        text: The sentence the entities refer to (``None`` is treated as "").
        labels: Allowed label names.

    Returns:
        ``{"entities": [...]}`` where each item has ``text``, ``label``,
        ``start`` and ``end`` keys.
    """
    source = text or ""
    candidates = data.get("entities") if isinstance(data, dict) else None
    if not isinstance(candidates, list):
        return {"entities": []}

    def _is_index(value: Any) -> bool:
        # bool is a subclass of int; True/False are not usable offsets.
        return isinstance(value, int) and not isinstance(value, bool)

    cleaned = []
    for item in candidates:
        if not isinstance(item, dict):
            continue
        surface = str(item.get("text", "")).strip()
        label = str(item.get("label", "")).strip().upper()
        if not surface or label not in labels:
            continue

        start, end = item.get("start"), item.get("end")
        trusted = (
            _is_index(start)
            and _is_index(end)
            and 0 <= start <= end <= len(source)
            and source[start:end] == surface
        )
        if not trusted:
            # Generated offsets are unreliable; fall back to a text search.
            found = source.find(surface)
            start, end = (found, found + len(surface)) if found >= 0 else (0, 0)
        cleaned.append({"text": surface, "label": label, "start": int(start), "end": int(end)})
    return {"entities": cleaned}
77
 
78
# Module-level cache so the checkpoint is loaded at most once, on first use.
_tokenizer = None
_model = None


def load_model():
    """Return the ``(tokenizer, model)`` pair for ``MODEL_ID``.

    Both objects are created with ``from_pretrained`` on the first call
    (the tokenizer with ``use_fast=False``) and stored in the module
    globals; later calls return the stored objects.
    """
    global _tokenizer, _model
    already_loaded = _tokenizer is not None and _model is not None
    if not already_loaded:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
        _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    return _tokenizer, _model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
def ner_infer(text: str, max_new_tokens: int = 256, num_beams: int = 4) -> Dict[str, Any]:
    """Run prompt-based NER on ``text`` and return normalized entities.

    Args:
        text: Input sentence; ``None`` or whitespace-only input yields an
            empty result without loading the model.
        max_new_tokens: Generation length limit (cast to ``int``).
        num_beams: Beam-search width (cast to ``int``).

    Returns:
        ``{"entities": [...]}`` as produced by ``normalize_entities``.
    """
    cleaned = (text or "").strip()
    if cleaned == "":
        return {"entities": []}

    tok, model = load_model()
    encoded = tok(build_prompt(cleaned, ALLOWED_LABELS), return_tensors="pt")  # CPU

    # Deterministic beam search (no sampling).
    decoding = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
        num_beams=int(num_beams),
        length_penalty=1.05,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    generated = model.generate(**encoded, **decoding)

    decoded = tok.decode(generated[0], skip_special_tokens=True)
    parsed = extract_first_json(decoded)
    return normalize_entities(parsed, cleaned, ALLOWED_LABELS)
 
 
 
 
 
 
 
 
107
 
108
# Gradio page: text input, two decoding sliders, a run button and a JSON view.
with gr.Blocks(title="Persian Zero-Shot NER (CPU)") as demo:
    gr.Markdown("## Persian Zero-Shot NER — CPU (mT5) + Few-Shot Prompting")
    inp = gr.Textbox(value=DEFAULT_EXAMPLE, lines=4, label="متن فارسی")

    # Decoding controls share one row.
    with gr.Row():
        max_tok = gr.Slider(
            minimum=96, maximum=512, value=256, step=16, label="حداکثر توکن خروجی"
        )
        beams = gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Beam size")

    btn = gr.Button("استخراج موجودیت‌ها")
    out = gr.JSON(label="خروجی JSON")

    # Slider values are passed to ner_infer in signature order.
    btn.click(fn=ner_infer, inputs=[inp, max_tok, beams], outputs=out)


if __name__ == "__main__":
    demo.launch()
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
+ import torch
 
4
 
5
# Inference is pinned to the CPU.
device = "cpu"

# Checkpoint used for both the tokenizer and the token-classification model.
model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)

# NER pipeline built on the objects loaded above. device=-1 selects the CPU;
# aggregation_strategy="simple" groups token-level predictions into entities.
ner_pipeline = pipeline(
    task="ner",
    tokenizer=tokenizer,
    model=model,
    aggregation_strategy="simple",
    device=-1,
)
 
 
 
 
24
 
25
# Highlight colour per BIO tag: each entity type has a strong shade for its
# "B-" tag and a lighter shade for its "I-" (continuation) tag.
# NOTE(review): confirm these tag names match the labels the loaded model emits.
label_colors = {
    f"{prefix}-{tag}": shade
    for tag, shades in (
        ("PER", ("#FF6B6B", "#FFB3B3")),  # Person - red / light red
        ("ORG", ("#4ECDC4", "#A7E9E4")),  # Organization - teal / light teal
        ("LOC", ("#95E1D3", "#C7F0E8")),  # Location - green / light green
        ("DAT", ("#FFA07A", "#FFDAB9")),  # Date - orange / light orange
        ("TIM", ("#DDA0DD", "#E6D0E6")),  # Time - purple / light purple
        ("MON", ("#FFD700", "#FFEB99")),  # Money - gold / light gold
        ("PCT", ("#87CEEB", "#B3DFEF")),  # Percent - sky blue / light sky blue
    )
    for prefix, shade in zip("BI", shades)
}

# Bilingual (Persian / English) display name for each entity type.
label_names = dict(
    PER="شخص (Person)",
    ORG="سازمان (Organization)",
    LOC="مکان (Location)",
    DAT="تاریخ (Date)",
    TIM="زمان (Time)",
    MON="پول (Money)",
    PCT="درصد (Percent)",
)
 
 
 
 
 
 
 
 
52
 
53
+ def highlight_entities(text, entities):
54
+ """Create HTML with highlighted entities"""
55
+ if not entities:
56
+ return text
57
+
58
+ # Sort entities by start position (reverse order to replace from end to start)
59
+ entities_sorted = sorted(entities, key=lambda x: x['start'], reverse=True)
60
+
61
+ result = text
62
+ for entity in entities_sorted:
63
+ start = entity['start']
64
+ end = entity['end']
65
+ label = entity['entity_group']
66
+ word = text[start:end]
67
+ score = entity['score']
68
+
69
+ # Get color for this label
70
+ color = label_colors.get(f"B-{label}", "#CCCCCC")
71
+
72
+ # Create highlighted span
73
+ highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{label} (confidence: {score:.2f})">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
74
+
75
+ result = result[:start] + highlighted + result[end:]
76
+
77
+ return result
78
 
79
+ def perform_ner(text):
80
+ """Perform NER on input text"""
81
+ if not text.strip():
82
+ return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
83
+
84
+ try:
85
+ # Perform NER
86
+ entities = ner_pipeline(text)
87
+
88
+ # Create highlighted version
89
+ highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, entities)}</div>"
90
+
91
+ # Create entities table
92
+ if entities:
93
+ entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
94
+ entity_info += "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) |\n"
95
+ entity_info += "|------------|-----------|---------------------|\n"
96
+ for ent in entities:
97
+ label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
98
+ entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} |\n"
99
+ else:
100
+ entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
101
+
102
+ return highlighted_html, entity_info
103
+
104
+ except Exception as e:
105
+ return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""
106
 
107
# Sample sentences offered in the UI; each row holds the single textbox value.
examples = [
    [sentence]
    for sentence in (
        "باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی می‌کرد.",
        "شرکت گوگل در کالیفرنیا واقع شده است.",
        "رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد.",
        "دانشگاه تهران یکی از قدیمی‌ترین دانشگاه‌های ایران است.",
        "علی و حسین به همراه مریم به مشهد سفر کردند.",
    )
]
115
+
116
# Gradio page: intro text, an input column and an output column side by side,
# clickable examples, a colour legend, and the event wiring.
with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نامدار فارسی", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇷 Persian Named Entity Recognition
    # شناسایی موجودیت‌های نامدار فارسی

    این سیستم موجودیت‌های نامدار مانند اسامی اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها و ... را در متن فارسی شناسایی می‌کند.

    This system identifies named entities such as person names, organizations, locations, dates, etc. in Persian text.

    **Model:** ParsBERT-NER (HooshvareLab)
    **Running on:** CPU (may be slow for long texts)
    """)

    with gr.Row():
        # Left column: text entry and the analyse button.
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                rtl=True,
                label="متن فارسی خود را وارد کنید (Enter Persian Text)",
                placeholder="مثال: رضا در تهران زندگی می‌کند...",
            )
            submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")

        # Right column: highlighted text and the entity table.
        with gr.Column():
            output_html = gr.HTML(label="متن با موجودیت‌های برجسته (Text with Highlighted Entities)")
            output_entities = gr.Markdown(label="لیست موجودیت‌ها (Entity List)")

    gr.Examples(label="مثال‌ها (Examples)", examples=examples, inputs=input_text)

    # Legend
    gr.Markdown("""
    ### راهنمای رنگ‌ها (Color Guide):
    - 🔴 **PER (شخص)**: اسامی اشخاص / Person names
    - 🔵 **ORG (سازمان)**: نام سازمان‌ها / Organizations
    - 🟢 **LOC (مکان)**: نام مکان‌ها / Locations
    - 🟠 **DAT (تاریخ)**: تاریخ‌ها / Dates
    - 🟣 **TIM (زمان)**: زمان‌ها / Times
    - 🟡 **MON (پول)**: مقادیر پولی / Money
    - 🔷 **PCT (درصد)**: درصدها / Percentages
    """)

    # The button click and pressing Enter in the textbox run the same handler.
    for register in (submit_btn.click, input_text.submit):
        register(
            fn=perform_ner,
            inputs=input_text,
            outputs=[output_html, output_entities],
        )

# Start the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()