optimopium commited on
Commit
3c57d2d
·
verified ·
1 Parent(s): 6e7759a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -31
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  import torch
 
 
 
4
 
5
  # Set device to CPU explicitly
6
  device = "cpu"
@@ -22,22 +25,59 @@ ner_pipeline = pipeline(
22
  aggregation_strategy="simple" # Groups entities together
23
  )
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Label mapping for better readability
26
  label_colors = {
27
- "B-PER": "#FF6B6B", # Person - Red
28
- "I-PER": "#FFB3B3", # Person continuation - Light Red
29
- "B-ORG": "#4ECDC4", # Organization - Teal
30
- "I-ORG": "#A7E9E4", # Organization continuation - Light Teal
31
- "B-LOC": "#95E1D3", # Location - Green
32
- "I-LOC": "#C7F0E8", # Location continuation - Light Green
33
- "B-DAT": "#FFA07A", # Date - Orange
34
- "I-DAT": "#FFDAB9", # Date continuation - Light Orange
35
- "B-TIM": "#DDA0DD", # Time - Purple
36
- "I-TIM": "#E6D0E6", # Time continuation - Light Purple
37
- "B-MON": "#FFD700", # Money - Gold
38
- "I-MON": "#FFEB99", # Money continuation - Light Gold
39
- "B-PCT": "#87CEEB", # Percent - Sky Blue
40
- "I-PCT": "#B3DFEF", # Percent continuation - Light Sky Blue
 
 
41
  }
42
 
43
  label_names = {
@@ -48,8 +88,91 @@ label_names = {
48
  "TIM": "زمان (Time)",
49
  "MON": "پول (Money)",
50
  "PCT": "درصد (Percent)",
 
 
51
  }
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def highlight_entities(text, entities):
54
  """Create HTML with highlighted entities"""
55
  if not entities:
@@ -67,10 +190,20 @@ def highlight_entities(text, entities):
67
  score = entity['score']
68
 
69
  # Get color for this label
70
- color = label_colors.get(f"B-{label}", "#CCCCCC")
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Create highlighted span
73
- highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{label} (confidence: {score:.2f})">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
74
 
75
  result = result[:start] + highlighted + result[end:]
76
 
@@ -82,20 +215,36 @@ def perform_ner(text):
82
  return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
83
 
84
  try:
85
- # Perform NER
86
  entities = ner_pipeline(text)
87
 
 
 
 
 
 
 
 
 
 
88
  # Create highlighted version
89
- highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, entities)}</div>"
90
 
91
  # Create entities table
92
- if entities:
93
  entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
94
- entity_info += "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) |\n"
95
- entity_info += "|------------|-----------|---------------------|\n"
96
- for ent in entities:
97
  label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
98
- entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} |\n"
 
 
 
 
 
 
 
99
  else:
100
  entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
101
 
@@ -104,6 +253,16 @@ def perform_ner(text):
104
  except Exception as e:
105
  return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""
106
 
 
 
 
 
 
 
 
 
 
 
107
  # Example texts
108
  examples = [
109
  ["باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی می‌کرد."],
@@ -111,27 +270,55 @@ examples = [
111
  ["رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد."],
112
  ["دانشگاه تهران یکی از قدیمی‌ترین دانشگاه‌های ایران است."],
113
  ["علی و حسین به همراه مریم به مشهد سفر کردند."],
 
 
 
 
114
  ]
115
 
116
  # Create Gradio interface
117
  with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نامدار فارسی", theme=gr.themes.Soft()) as demo:
118
- gr.Markdown("""
119
- # 🇮🇷 Persian Named Entity Recognition
120
- # شناسایی موجودیت‌های نامدار فارسی
121
 
122
- این سیستم موجودیت‌های نامدار مانند اسامی اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها و ... را در متن فارسی شناسایی می‌کند.
123
 
124
- This system identifies named entities such as person names, organizations, locations, dates, etc. in Persian text.
125
 
126
- **Model:** ParsBERT-NER (HooshvareLab)
 
127
  **Running on:** CPU (may be slow for long texts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  """)
129
 
130
  with gr.Row():
131
  with gr.Column():
132
  input_text = gr.Textbox(
133
  label="متن فارسی خود را وارد کنید (Enter Persian Text)",
134
- placeholder="مثال: رضا در تهران زندگی می‌کند...",
135
  lines=5,
136
  rtl=True
137
  )
@@ -151,12 +338,20 @@ with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نا
151
  gr.Markdown("""
152
  ### راهنمای رنگ‌ها (Color Guide):
153
  - 🔴 **PER (شخص)**: اسامی اشخاص / Person names
154
- - 🔵 **ORG (سازمان)**: نام سازمان‌ها / Organizations
155
  - 🟢 **LOC (مکان)**: نام مکان‌ها / Locations
156
  - 🟠 **DAT (تاریخ)**: تاریخ‌ها / Dates
157
  - 🟣 **TIM (زمان)**: زمان‌ها / Times
158
  - 🟡 **MON (پول)**: مقادیر پولی / Money
159
  - 🔷 **PCT (درصد)**: درصدها / Percentages
 
 
 
 
 
 
 
 
160
  """)
161
 
162
  # Event handler
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  import torch
4
+ import re
5
+ import csv
6
+ import os
7
 
8
  # Set device to CPU explicitly
9
  device = "cpu"
 
25
  aggregation_strategy="simple" # Groups entities together
26
  )
27
 
28
+ # Load stock symbols from CSV file
29
+ def load_stock_symbols_from_csv(csv_path='symbols.csv'):
30
+ """Load stock symbols from CSV file"""
31
+ stock_symbols = {}
32
+ try:
33
+ with open(csv_path, 'r', encoding='utf-8') as f:
34
+ reader = csv.DictReader(f)
35
+ for row in reader:
36
+ symbol = row['symbol']
37
+ company_name = row['company_name']
38
+ stock_symbols[symbol] = company_name
39
+ print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
40
+ except FileNotFoundError:
41
+ print(f"Warning: {csv_path} not found. Using default symbols.")
42
+ return stock_symbols
43
+
44
+ # Load stock symbols
45
+ STOCK_SYMBOLS = load_stock_symbols_from_csv()
46
+
47
+ # Hypernym patterns (generic terms that can be made more specific)
48
+ HYPERNYM_PATTERNS = {
49
+ "شرکت": "ORG",
50
+ "سازمان": "ORG",
51
+ "موسسه": "ORG",
52
+ "بانک": "ORG",
53
+ "دانشگاه": "ORG",
54
+ "شهر": "LOC",
55
+ "کشور": "LOC",
56
+ "استان": "LOC",
57
+ "آقای": "PER",
58
+ "خانم": "PER",
59
+ "دکتر": "PER",
60
+ "مهندس": "PER",
61
+ }
62
+
63
  # Label mapping for better readability
64
  label_colors = {
65
+ "B-PER": "#FF6B6B",
66
+ "I-PER": "#FFB3B3",
67
+ "B-ORG": "#4ECDC4",
68
+ "I-ORG": "#A7E9E4",
69
+ "B-LOC": "#95E1D3",
70
+ "I-LOC": "#C7F0E8",
71
+ "B-DAT": "#FFA07A",
72
+ "I-DAT": "#FFDAB9",
73
+ "B-TIM": "#DDA0DD",
74
+ "I-TIM": "#E6D0E6",
75
+ "B-MON": "#FFD700",
76
+ "I-MON": "#FFEB99",
77
+ "B-PCT": "#87CEEB",
78
+ "I-PCT": "#B3DFEF",
79
+ "STK": "#FF1493", # Stock symbol - Deep Pink
80
+ "HYP": "#A9A9A9", # Hypernym - Dark Gray
81
  }
82
 
83
  label_names = {
 
88
  "TIM": "زمان (Time)",
89
  "MON": "پول (Money)",
90
  "PCT": "درصد (Percent)",
91
+ "STK": "نماد بورس (Stock Symbol)",
92
+ "HYP": "واژه عمومی (Hypernym)",
93
  }
94
 
95
+ def detect_stock_symbols(text):
96
+ """Detect Persian stock market symbols in text"""
97
+ stock_entities = []
98
+
99
+ # Split text into words
100
+ words = re.findall(r'[\u0600-\u06FF]+', text)
101
+
102
+ for word in words:
103
+ if word in STOCK_SYMBOLS:
104
+ # Find all occurrences of this symbol in the text
105
+ for match in re.finditer(re.escape(word), text):
106
+ stock_entities.append({
107
+ 'entity_group': 'STK',
108
+ 'word': word,
109
+ 'start': match.start(),
110
+ 'end': match.end(),
111
+ 'score': 0.99, # High confidence for dictionary match
112
+ 'full_name': STOCK_SYMBOLS[word]
113
+ })
114
+
115
+ return stock_entities
116
+
117
+ def detect_hypernyms(text, entities):
118
+ """Detect hypernyms (general terms) in text and classify them"""
119
+ hypernym_entities = []
120
+
121
+ for hypernym, entity_type in HYPERNYM_PATTERNS.items():
122
+ for match in re.finditer(re.escape(hypernym), text):
123
+ start, end = match.start(), match.end()
124
+
125
+ # Check if this position already has a specific entity
126
+ is_covered = False
127
+ for ent in entities:
128
+ if start >= ent['start'] and end <= ent['end']:
129
+ is_covered = True
130
+ break
131
+
132
+ if not is_covered:
133
+ hypernym_entities.append({
134
+ 'entity_group': 'HYP',
135
+ 'word': hypernym,
136
+ 'start': start,
137
+ 'end': end,
138
+ 'score': 0.95,
139
+ 'base_type': entity_type,
140
+ 'is_hypernym': True
141
+ })
142
+
143
+ return hypernym_entities
144
+
145
+ def merge_entities(entities, stock_entities, hypernym_entities):
146
+ """Merge all entity types and remove overlaps, prioritizing specific entities"""
147
+ all_entities = entities + stock_entities + hypernym_entities
148
+
149
+ # Sort by start position
150
+ all_entities.sort(key=lambda x: x['start'])
151
+
152
+ # Remove overlapping entities (keep higher priority)
153
+ # Priority: STK > specific entities > HYP
154
+ filtered_entities = []
155
+ for entity in all_entities:
156
+ overlaps = False
157
+ for existing in filtered_entities:
158
+ # Check for overlap
159
+ if not (entity['end'] <= existing['start'] or entity['start'] >= existing['end']):
160
+ overlaps = True
161
+ # If new entity is stock symbol, replace existing
162
+ if entity['entity_group'] == 'STK' and existing['entity_group'] != 'STK':
163
+ filtered_entities.remove(existing)
164
+ overlaps = False
165
+ # If existing is hypernym and new is specific, replace
166
+ elif existing['entity_group'] == 'HYP' and entity['entity_group'] != 'HYP':
167
+ filtered_entities.remove(existing)
168
+ overlaps = False
169
+ break
170
+
171
+ if not overlaps:
172
+ filtered_entities.append(entity)
173
+
174
+ return sorted(filtered_entities, key=lambda x: x['start'])
175
+
176
  def highlight_entities(text, entities):
177
  """Create HTML with highlighted entities"""
178
  if not entities:
 
190
  score = entity['score']
191
 
192
  # Get color for this label
193
+ if label == 'STK':
194
+ color = label_colors.get('STK')
195
+ extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
196
+ title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
197
+ elif label == 'HYP':
198
+ color = label_colors.get('HYP')
199
+ base_type = entity.get('base_type', '')
200
+ title_text = f"Hypernym (general term for {base_type})"
201
+ else:
202
+ color = label_colors.get(f"B-{label}", "#CCCCCC")
203
+ title_text = f"{label} (confidence: {score:.2f})"
204
 
205
  # Create highlighted span
206
+ highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{title_text}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
207
 
208
  result = result[:start] + highlighted + result[end:]
209
 
 
215
  return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
216
 
217
  try:
218
+ # Perform base NER
219
  entities = ner_pipeline(text)
220
 
221
+ # Detect stock symbols
222
+ stock_entities = detect_stock_symbols(text)
223
+
224
+ # Detect hypernyms
225
+ hypernym_entities = detect_hypernyms(text, entities)
226
+
227
+ # Merge all entities
228
+ all_entities = merge_entities(entities, stock_entities, hypernym_entities)
229
+
230
  # Create highlighted version
231
+ highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, all_entities)}</div>"
232
 
233
  # Create entities table
234
+ if all_entities:
235
  entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
236
+ entity_info += "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) | اطلاعات اضافی (Additional Info) |\n"
237
+ entity_info += "|------------|-----------|---------------------|----------------------------------|\n"
238
+ for ent in all_entities:
239
  label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
240
+ extra_info = ""
241
+
242
+ if ent['entity_group'] == 'STK' and 'full_name' in ent:
243
+ extra_info = ent['full_name']
244
+ elif ent['entity_group'] == 'HYP':
245
+ extra_info = f"Hypernym of {ent.get('base_type', '')}"
246
+
247
+ entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} | {extra_info} |\n"
248
  else:
249
  entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
250
 
 
253
  except Exception as e:
254
  return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""
255
 
256
+ # Save stock symbols to CSV function
257
+ def save_symbols_to_csv(output_path='symbols.csv'):
258
+ """Save current stock symbols to CSV file"""
259
+ with open(output_path, 'w', encoding='utf-8', newline='') as f:
260
+ writer = csv.writer(f)
261
+ writer.writerow(['symbol', 'company_name'])
262
+ for symbol, name in STOCK_SYMBOLS.items():
263
+ writer.writerow([symbol, name])
264
+ print(f"Saved {len(STOCK_SYMBOLS)} symbols to {output_path}")
265
+
266
  # Example texts
267
  examples = [
268
  ["باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی می‌کرد."],
 
270
  ["رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد."],
271
  ["دانشگاه تهران یکی از قدیمی‌ترین دانشگاه‌های ایران است."],
272
  ["علی و حسین به همراه مریم به مشهد سفر کردند."],
273
+ ["سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود."],
274
+ ["بانک ملت و وتجارت در بازار بورس فعال هستند."],
275
+ ["آقای احمدی مدیرعامل شرکت پتروشیمی است."],
276
+ ["وبملت و فملی امروز در صف خرید قرار گرفتند."],
277
  ]
278
 
279
  # Create Gradio interface
280
  with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نامدار فارسی", theme=gr.themes.Soft()) as demo:
281
+ gr.Markdown(f"""
282
+ # 🇮🇷 Persian Named Entity Recognition + Stock Symbols
283
+ # شناسایی موجودیت‌های نامدار فارسی + نمادهای بورس
284
 
285
+ این سیستم موجودیت‌های نامدار مانند اسامی اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها، **نمادهای بورس** و **واژه‌های عمومی (Hypernyms)** را در متن فارسی شناسایی می‌کند.
286
 
287
+ This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.
288
 
289
+ **Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection
290
+ **Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)
291
  **Running on:** CPU (may be slow for long texts)
292
+
293
+ ---
294
+
295
+ ### 📊 APIs for Updating Stock Symbols:
296
+
297
+ **Recommended Python Libraries:**
298
+ 1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
299
+ 2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
300
+ 3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
301
+
302
+ **Example Usage:**
303
+ ```python
304
+ # Using tsetmc-api
305
+ from tsetmc_api import market_watch
306
+ stocks = market_watch.get_market_watch()
307
+
308
+ # Using tehran-stocks
309
+ from tehran_stocks import Stocks
310
+ all_stocks = Stocks.query.all()
311
+ ```
312
+
313
+ **Official TSE Website:** https://tse.ir
314
+ **TSETMC Data Portal:** http://www.tsetmc.com
315
  """)
316
 
317
  with gr.Row():
318
  with gr.Column():
319
  input_text = gr.Textbox(
320
  label="متن فارسی خود را وارد کنید (Enter Persian Text)",
321
+ placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
322
  lines=5,
323
  rtl=True
324
  )
 
338
  gr.Markdown("""
339
  ### راهنمای رنگ‌ها (Color Guide):
340
  - 🔴 **PER (شخص)**: اسامی اشخاص / Person names
341
+ - 🔵 **ORG (سازمان)**: نام سازمان‌ها / Organizations
342
  - 🟢 **LOC (مکان)**: نام مکان‌ها / Locations
343
  - 🟠 **DAT (تاریخ)**: تاریخ‌ها / Dates
344
  - 🟣 **TIM (زمان)**: زمان‌ها / Times
345
  - 🟡 **MON (پول)**: مقادیر پولی / Money
346
  - 🔷 **PCT (درصد)**: درصدها / Percentages
347
+ - 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
348
+ - ⚫ **HYP (واژه عمومی)**: واژه‌های عمومی / Hypernyms (general terms)
349
+
350
+ ---
351
+
352
+ ### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد
353
+
354
+ *برای به‌روزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
355
  """)
356
 
357
  # Event handler