optimopium committed on
Commit
c4885a5
·
verified ·
1 Parent(s): 3c57d2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +357 -251
app.py CHANGED
@@ -1,83 +1,225 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 
 
 
 
 
 
3
  import torch
 
4
  import re
5
- import csv
6
- import os
7
-
8
- # Set device to CPU explicitly
9
- device = "cpu"
10
 
11
- # Load the model and tokenizer
12
- model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
13
 
14
- print("Loading model and tokenizer...")
15
- tokenizer = AutoTokenizer.from_pretrained(model_name)
16
- model = AutoModelForTokenClassification.from_pretrained(model_name)
17
- model.to(device)
 
 
18
 
19
  # Create NER pipeline
20
  ner_pipeline = pipeline(
21
  "ner",
22
- model=model,
23
- tokenizer=tokenizer,
24
- device=-1, # -1 means CPU
25
- aggregation_strategy="simple" # Groups entities together
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  )
27
 
28
- # Load stock symbols from CSV file
29
- def load_stock_symbols_from_csv(csv_path='symbols.csv'):
30
- """Load stock symbols from CSV file"""
31
- stock_symbols = {}
32
  try:
33
- with open(csv_path, 'r', encoding='utf-8') as f:
34
- reader = csv.DictReader(f)
35
- for row in reader:
36
- symbol = row['symbol']
37
- company_name = row['company_name']
38
- stock_symbols[symbol] = company_name
39
- print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
40
- except FileNotFoundError:
41
- print(f"Warning: {csv_path} not found. Using default symbols.")
42
- return stock_symbols
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # Load stock symbols
45
- STOCK_SYMBOLS = load_stock_symbols_from_csv()
 
46
 
47
- # Hypernym patterns (generic terms that can be made more specific)
48
- HYPERNYM_PATTERNS = {
49
- "شرکت": "ORG",
50
- "سازمان": "ORG",
51
- "موسسه": "ORG",
52
- "بانک": "ORG",
53
- "دانشگاه": "ORG",
54
- "شهر": "LOC",
55
- "کشور": "LOC",
56
- "استان": "LOC",
57
- "آقای": "PER",
58
- "خانم": "PER",
59
- "دکتر": "PER",
60
- "مهندس": "PER",
61
  }
62
 
63
- # Label mapping for better readability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  label_colors = {
65
- "B-PER": "#FF6B6B",
66
  "I-PER": "#FFB3B3",
67
- "B-ORG": "#4ECDC4",
68
  "I-ORG": "#A7E9E4",
69
- "B-LOC": "#95E1D3",
70
  "I-LOC": "#C7F0E8",
71
- "B-DAT": "#FFA07A",
72
  "I-DAT": "#FFDAB9",
73
- "B-TIM": "#DDA0DD",
74
  "I-TIM": "#E6D0E6",
75
- "B-MON": "#FFD700",
76
  "I-MON": "#FFEB99",
77
- "B-PCT": "#87CEEB",
78
  "I-PCT": "#B3DFEF",
79
- "STK": "#FF1493", # Stock symbol - Deep Pink
80
- "HYP": "#A9A9A9", # Hypernym - Dark Gray
81
  }
82
 
83
  label_names = {
@@ -88,98 +230,36 @@ label_names = {
88
  "TIM": "زمان (Time)",
89
  "MON": "پول (Money)",
90
  "PCT": "درصد (Percent)",
91
- "STK": "نماد بورس (Stock Symbol)",
92
- "HYP": "واژه عمومی (Hypernym)",
93
  }
94
 
95
- def detect_stock_symbols(text):
96
- """Detect Persian stock market symbols in text"""
97
- stock_entities = []
98
-
99
- # Split text into words
100
- words = re.findall(r'[\u0600-\u06FF]+', text)
101
-
102
- for word in words:
103
- if word in STOCK_SYMBOLS:
104
- # Find all occurrences of this symbol in the text
105
- for match in re.finditer(re.escape(word), text):
106
- stock_entities.append({
107
- 'entity_group': 'STK',
108
- 'word': word,
109
- 'start': match.start(),
110
- 'end': match.end(),
111
- 'score': 0.99, # High confidence for dictionary match
112
- 'full_name': STOCK_SYMBOLS[word]
113
- })
114
-
115
- return stock_entities
116
-
117
- def detect_hypernyms(text, entities):
118
- """Detect hypernyms (general terms) in text and classify them"""
119
- hypernym_entities = []
120
 
121
- for hypernym, entity_type in HYPERNYM_PATTERNS.items():
122
- for match in re.finditer(re.escape(hypernym), text):
123
- start, end = match.start(), match.end()
124
-
125
- # Check if this position already has a specific entity
126
- is_covered = False
127
- for ent in entities:
128
- if start >= ent['start'] and end <= ent['end']:
129
- is_covered = True
130
- break
131
-
132
- if not is_covered:
133
- hypernym_entities.append({
134
- 'entity_group': 'HYP',
135
- 'word': hypernym,
136
- 'start': start,
137
- 'end': end,
138
- 'score': 0.95,
139
- 'base_type': entity_type,
140
- 'is_hypernym': True
141
- })
142
 
143
- return hypernym_entities
144
-
145
- def merge_entities(entities, stock_entities, hypernym_entities):
146
- """Merge all entity types and remove overlaps, prioritizing specific entities"""
147
- all_entities = entities + stock_entities + hypernym_entities
148
-
149
- # Sort by start position
150
- all_entities.sort(key=lambda x: x['start'])
151
-
152
- # Remove overlapping entities (keep higher priority)
153
- # Priority: STK > specific entities > HYP
154
- filtered_entities = []
155
- for entity in all_entities:
156
- overlaps = False
157
- for existing in filtered_entities:
158
  # Check for overlap
159
- if not (entity['end'] <= existing['start'] or entity['start'] >= existing['end']):
160
- overlaps = True
161
- # If new entity is stock symbol, replace existing
162
- if entity['entity_group'] == 'STK' and existing['entity_group'] != 'STK':
163
- filtered_entities.remove(existing)
164
- overlaps = False
165
- # If existing is hypernym and new is specific, replace
166
- elif existing['entity_group'] == 'HYP' and entity['entity_group'] != 'HYP':
167
- filtered_entities.remove(existing)
168
- overlaps = False
169
  break
170
-
171
- if not overlaps:
172
- filtered_entities.append(entity)
173
 
174
- return sorted(filtered_entities, key=lambda x: x['start'])
175
 
176
- def highlight_entities(text, entities):
177
  """Create HTML with highlighted entities"""
178
- if not entities:
179
  return text
180
 
181
- # Sort entities by start position (reverse order to replace from end to start)
182
- entities_sorted = sorted(entities, key=lambda x: x['start'], reverse=True)
183
 
184
  result = text
185
  for entity in entities_sorted:
@@ -190,177 +270,197 @@ def highlight_entities(text, entities):
190
  score = entity['score']
191
 
192
  # Get color for this label
193
- if label == 'STK':
194
- color = label_colors.get('STK')
195
- extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
196
- title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
197
- elif label == 'HYP':
198
- color = label_colors.get('HYP')
199
- base_type = entity.get('base_type', '')
200
- title_text = f"Hypernym (general term for {base_type})"
201
- else:
202
- color = label_colors.get(f"B-{label}", "#CCCCCC")
203
- title_text = f"{label} (confidence: {score:.2f})"
204
 
205
  # Create highlighted span
206
- highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{title_text}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
207
 
208
  result = result[:start] + highlighted + result[end:]
209
 
210
  return result
211
 
212
  def perform_ner(text):
213
- """Perform NER on input text"""
214
  if not text.strip():
215
  return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
216
 
217
  try:
218
- # Perform base NER
219
  entities = ner_pipeline(text)
220
 
221
- # Detect stock symbols
222
- stock_entities = detect_stock_symbols(text)
223
 
224
- # Detect hypernyms
225
- hypernym_entities = detect_hypernyms(text, entities)
226
-
227
- # Merge all entities
228
- all_entities = merge_entities(entities, stock_entities, hypernym_entities)
229
 
230
  # Create highlighted version
231
- highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, all_entities)}</div>"
 
 
 
 
 
 
232
 
233
  # Create entities table
234
  if all_entities:
235
  entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
236
- entity_info += "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) | اطلاعات اضافی (Additional Info) |\n"
237
- entity_info += "|------------|-----------|---------------------|----------------------------------|\n"
 
 
 
 
238
  for ent in all_entities:
239
  label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
240
- extra_info = ""
241
-
242
- if ent['entity_group'] == 'STK' and 'full_name' in ent:
243
- extra_info = ent['full_name']
244
- elif ent['entity_group'] == 'HYP':
245
- extra_info = f"Hypernym of {ent.get('base_type', '')}"
246
-
247
- entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} | {extra_info} |\n"
248
  else:
249
  entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
250
 
251
- return highlighted_html, entity_info
 
 
 
 
 
 
 
 
252
 
253
  except Exception as e:
254
- return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""
255
-
256
- # Save stock symbols to CSV function
257
- def save_symbols_to_csv(output_path='symbols.csv'):
258
- """Save current stock symbols to CSV file"""
259
- with open(output_path, 'w', encoding='utf-8', newline='') as f:
260
- writer = csv.writer(f)
261
- writer.writerow(['symbol', 'company_name'])
262
- for symbol, name in STOCK_SYMBOLS.items():
263
- writer.writerow([symbol, name])
264
- print(f"Saved {len(STOCK_SYMBOLS)} symbols to {output_path}")
265
 
266
- # Example texts
267
  examples = [
268
- ["باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی می‌کرد."],
269
- ["شرکت گوگل در کالیفرنیا واقع شده است."],
270
- ["رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد."],
271
- ["دانشگاه تهران یکی از قدیمی‌ترین دانشگاه‌های ایران است."],
272
- ["علی و حسین به همراه مریم به مشهد سفر کردند."],
273
- ["سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود."],
274
- ["بانک ملت و وتجارت در بازار بورس فعال هستند."],
275
- ["آقای احمدی مدیرعامل شرکت پتروشیمی است."],
276
- ["وبملت و فملی امروز در صف خرید قرار گرفتند."],
277
  ]
278
 
279
  # Create Gradio interface
280
- with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نامدار فارسی", theme=gr.themes.Soft()) as demo:
281
- gr.Markdown(f"""
282
- # 🇮🇷 Persian Named Entity Recognition + Stock Symbols
283
- # شناسایی موجودیت‌های نامدار فارسی + نمادهای بورس
 
 
 
 
 
 
284
 
285
- این سیستم موجودیت‌های نامدار مانند اسامی اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها، **نمادهای بورس** و **واژه‌های عمومی (Hypernyms)** را در متن فارسی شناسایی می‌کند.
 
286
 
287
- This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.
 
288
 
289
- **Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection
290
- **Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)
291
- **Running on:** CPU (may be slow for long texts)
292
 
293
  ---
294
-
295
- ### 📊 APIs for Updating Stock Symbols:
296
-
297
- **Recommended Python Libraries:**
298
- 1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
299
- 2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
300
- 3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
301
-
302
- **Example Usage:**
303
- ```python
304
- # Using tsetmc-api
305
- from tsetmc_api import market_watch
306
- stocks = market_watch.get_market_watch()
307
-
308
- # Using tehran-stocks
309
- from tehran_stocks import Stocks
310
- all_stocks = Stocks.query.all()
311
- ```
312
-
313
- **Official TSE Website:** https://tse.ir
314
- **TSETMC Data Portal:** http://www.tsetmc.com
315
  """)
316
 
317
  with gr.Row():
318
- with gr.Column():
319
  input_text = gr.Textbox(
320
- label="متن فارسی خود را وارد کنید (Enter Persian Text)",
321
- placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
322
- lines=5,
323
- rtl=True
 
324
  )
325
- submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")
 
 
326
 
327
- with gr.Column():
328
- output_html = gr.HTML(label="متن با موجودیت‌های برجسته (Text with Highlighted Entities)")
329
- output_entities = gr.Markdown(label="لیست موجودیت‌ها (Entity List)")
 
 
 
 
 
 
 
 
330
 
331
  gr.Examples(
332
  examples=examples,
333
  inputs=input_text,
334
- label="مثال‌ها (Examples)"
 
335
  )
336
 
337
- # Legend
338
- gr.Markdown("""
339
- ### راهنمای رنگ‌ها (Color Guide):
340
- - 🔴 **PER (شخص)**: اسامی اشخاص / Person names
341
- - 🔵 **ORG (سازمان)**: نام سازمان‌ها / Organizations
342
- - 🟢 **LOC (مکان)**: نام مکان‌ها / Locations
343
- - 🟠 **DAT (تاریخ)**: تاریخ‌ها / Dates
344
- - 🟣 **TIM (زمان)**: زمان‌ها / Times
345
- - 🟡 **MON (پول)**: مقادیر پولی / Money
346
- - 🔷 **PCT (درصد)**: درصدها / Percentages
347
- - 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
348
- - **HYP (واژه عمومی)**: واژه‌های عمومی / Hypernyms (general terms)
349
-
350
- ---
351
-
352
- ### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد
 
 
353
 
354
- *برای به‌روزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
355
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
- # Event handler
358
  submit_btn.click(
359
  fn=perform_ner,
360
  inputs=input_text,
361
  outputs=[output_html, output_entities]
362
  )
363
 
 
 
 
 
 
364
  input_text.submit(
365
  fn=perform_ner,
366
  inputs=input_text,
@@ -369,4 +469,10 @@ with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نا
369
 
370
  # Launch the app
371
  if __name__ == "__main__":
372
- demo.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForTokenClassification,
5
+ AutoModelForSequenceClassification,
6
+ AutoModelForQuestionAnswering,
7
+ pipeline
8
+ )
9
  import torch
10
+ import pandas as pd
11
  import re
12
+ from typing import List, Dict, Tuple
13
+ import numpy as np
 
 
 
14
 
15
+ # Set device
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ # Load the NER model and tokenizer
19
+ print("Loading NER model...")
20
+ ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
21
+ ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
22
+ ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
23
+ ner_model.to(device)
24
 
25
  # Create NER pipeline
26
  ner_pipeline = pipeline(
27
  "ner",
28
+ model=ner_model,
29
+ tokenizer=ner_tokenizer,
30
+ device=0 if device == "cuda" else -1,
31
+ aggregation_strategy="simple"
32
+ )
33
+
34
+ # Load Persian LLM for context understanding
35
+ print("Loading Persian context model...")
36
+ # Using a Persian BERT model fine-tuned for question answering/text classification
37
+ context_model_name = "persiannlp/mt5-small-parsinlu-multiple-choice" # Alternative model
38
+ # You can also try: "HooshvareLab/bert-fa-base-uncased" with custom classification head
39
+
40
+ # For stock symbol disambiguation, we'll use a QA approach
41
+ qa_model_name = "mohammadehab/persian-qa-bert"
42
+ qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
43
+ qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
44
+ qa_model.to(device)
45
+
46
+ qa_pipeline = pipeline(
47
+ "question-answering",
48
+ model=qa_model,
49
+ tokenizer=qa_tokenizer,
50
+ device=0 if device == "cuda" else -1
51
  )
52
 
53
+ # Load stock symbols from CSV
54
def load_stock_symbols(csv_path="symbols.csv"):
    """Load Iranian stock market symbols from a CSV file.

    The CSV must provide the columns ``symbol``, ``company_name``,
    ``bazaar`` and ``bazaar_group``.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A dict mapping each symbol to a dict with the keys ``company``,
        ``bazaar`` and ``bazaar_group``. When the file cannot be read or
        parsed (missing file, missing column, ...), a small built-in demo
        set is returned instead so the app can still start.
    """
    try:
        df = pd.read_csv(csv_path, encoding='utf-8')
        symbols_dict = {}
        # Iterating the columns in lockstep avoids building a Series per row
        # (which is what iterrows does).
        for symbol, company, bazaar, bazaar_group in zip(
            df['symbol'], df['company_name'], df['bazaar'], df['bazaar_group']
        ):
            # Rows without a symbol would otherwise insert a NaN key.
            if pd.isna(symbol):
                continue
            # Symbols are matched against tokens by exact equality, so stray
            # whitespace from the CSV would make a symbol unmatchable.
            symbol = str(symbol).strip()
            if not symbol:
                continue
            symbols_dict[symbol] = {
                'company': company,
                'bazaar': bazaar,
                'bazaar_group': bazaar_group,
            }
        return symbols_dict
    except Exception as e:
        # Deliberate best-effort: any load problem falls back to demo data.
        print(f"Error loading symbols CSV: {e}")
        # Provide some default symbols for demo
        return {
            'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
            'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
            'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'},
            'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'},
            'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
            'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
        }
78
 
79
+ # Load symbols
80
+ STOCK_SYMBOLS = load_stock_symbols()
81
+ SYMBOL_NAMES = set(STOCK_SYMBOLS.keys())
82
 
83
+ # Market context keywords for Persian
84
+ MARKET_KEYWORDS = {
85
+ 'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش',
86
+ 'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد',
87
+ 'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود',
88
+ 'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت',
89
+ 'کندل', 'نمودار', 'تحلیل', 'بنیادی', 'تکنیکال', 'حمایت', 'مقاومت'
 
 
 
 
 
 
 
90
  }
91
 
92
def use_llm_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
    """Estimate whether a word is used as a stock symbol, using the QA pipeline.

    Three yes/no style questions about ``potential_symbol`` are asked against
    ``text``. Each answer contributes one score: the pipeline confidence when
    the answer contains an affirmative marker, ``1 - confidence`` when it
    contains a negative marker, and 0.5 otherwise.

    Args:
        text: Full input text, used as the QA context.
        potential_symbol: The word that matched a known stock symbol.
        symbol_info: Symbol metadata; its ``'company'`` entry is read.

    Returns:
        The mean of the per-question scores as a plain ``float``. Any failure
        yields the neutral value 0.5 (per question, or overall).
    """
    positive_markers = ('بله', 'است', 'می‌باشد', 'درست')
    negative_markers = ('خیر', 'نیست', 'نمی‌باشد')

    try:
        # Ask multiple questions to get a better understanding
        questions = [
            f"آیا {potential_symbol} در این متن نماد بورسی است؟",
            f"آیا منظور از {potential_symbol} سهام شرکت {symbol_info['company']} است؟",
            f"آیا {potential_symbol} در این متن به معاملات بورس اشاره دارد؟"
        ]

        scores = []
        for question in questions:
            try:
                result = qa_pipeline(question=question, context=text)
                answer = result['answer'].lower()
                confidence = result['score']

                # Adjust confidence based on answer content
                if any(marker in answer for marker in positive_markers):
                    scores.append(confidence)
                elif any(marker in answer for marker in negative_markers):
                    scores.append(1 - confidence)
                else:
                    scores.append(0.5)  # Uncertain
            except Exception as e:
                # Best-effort: one failing question must not abort the others,
                # but the failure is reported instead of being hidden.
                print(f"LLM disambiguation question failed: {e}")
                scores.append(0.5)  # Default to uncertain

        # float(): np.mean returns a numpy scalar, the annotation promises float
        return float(np.mean(scores)) if scores else 0.5

    except Exception as e:
        print(f"LLM disambiguation error: {e}")
        return 0.5  # Return neutral confidence on error
133
+
134
def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict, position=None) -> Tuple[bool, float]:
    """Decide whether one occurrence of a word is used as a stock symbol.

    A keyword heuristic is evaluated on a window of 50 characters before and
    after the occurrence; ambiguous cases are passed to
    ``use_llm_for_disambiguation`` and both scores are blended (40% / 60%).

    Args:
        text: Full input text.
        potential_symbol: The word that matched a known stock symbol.
        symbol_info: Symbol metadata, forwarded to the LLM check.
        position: Character offset of the occurrence to inspect. When omitted,
            the first occurrence found by ``str.find`` is used.

    Returns:
        ``(is_stock_symbol, confidence_score)``. ``(False, 0.0)`` when the
        occurrence cannot be located.
    """
    symbol_pos = text.find(potential_symbol) if position is None else position
    if symbol_pos < 0:
        return False, 0.0

    # Surrounding context: 50 characters on each side of the occurrence
    start_context = max(0, symbol_pos - 50)
    end_context = min(len(text), symbol_pos + len(potential_symbol) + 50)
    context_window = text[start_context:end_context]

    # Count market keywords in context
    words_in_context = context_window.split()
    market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS)

    # round(): 3 * 0.3 is 0.8999999999999999 in binary floating point, which
    # would miss the 0.9 "strong signal" threshold below.
    heuristic_score = round(min(market_keyword_count * 0.3, 1.0), 2)

    # If very strong or very weak signal from heuristics, use that
    if heuristic_score >= 0.9:
        return True, heuristic_score
    if heuristic_score == 0 and len(words_in_context) > 5:
        return False, 0.1

    # For ambiguous cases, use LLM
    llm_score = use_llm_for_disambiguation(text, potential_symbol, symbol_info)

    # Combine heuristic and LLM scores
    final_score = heuristic_score * 0.4 + llm_score * 0.6

    # Decision threshold
    return final_score > 0.5, final_score


def find_stock_symbols_in_text(text: str) -> List[Dict]:
    """Find and validate stock symbols in text using context analysis.

    Every run of characters in the U+0600-U+06FF range is looked up in
    ``SYMBOL_NAMES``; each hit is validated with
    ``check_stock_symbol_context`` using the context of that specific
    occurrence.

    Args:
        text: Full input text.

    Returns:
        One dict per accepted occurrence with the keys ``word``, ``start``,
        ``end``, ``entity_group`` (always ``'STOCK'``), ``score``,
        ``company``, ``bazaar`` and ``bazaar_group``.
    """
    found_symbols = []

    # Candidate tokens: whole words made of Arabic-block characters
    pattern = r'\b[\u0600-\u06FF]+\b'

    for match in re.finditer(pattern, text):
        word = match.group()
        if word not in SYMBOL_NAMES:
            continue

        symbol_info = STOCK_SYMBOLS[word]

        # Pass the offset so repeated symbols are each judged by their own
        # surroundings, not by those of the first occurrence.
        is_stock, confidence = check_stock_symbol_context(
            text, word, symbol_info, position=match.start()
        )
        if not is_stock:
            continue

        found_symbols.append({
            'word': word,
            'start': match.start(),
            'end': match.end(),
            'entity_group': 'STOCK',
            'score': confidence,
            'company': symbol_info['company'],
            'bazaar': symbol_info['bazaar'],
            'bazaar_group': symbol_info['bazaar_group']
        })

    return found_symbols
205
+
206
+ # Label colors and names
207
  label_colors = {
208
+ "B-PER": "#FF6B6B", # Person - Red
209
  "I-PER": "#FFB3B3",
210
+ "B-ORG": "#4ECDC4", # Organization - Teal
211
  "I-ORG": "#A7E9E4",
212
+ "B-LOC": "#95E1D3", # Location - Green
213
  "I-LOC": "#C7F0E8",
214
+ "B-DAT": "#FFA07A", # Date - Orange
215
  "I-DAT": "#FFDAB9",
216
+ "B-TIM": "#DDA0DD", # Time - Purple
217
  "I-TIM": "#E6D0E6",
218
+ "B-MON": "#FFD700", # Money - Gold
219
  "I-MON": "#FFEB99",
220
+ "B-PCT": "#87CEEB", # Percent - Sky Blue
221
  "I-PCT": "#B3DFEF",
222
+ "STOCK": "#00FA9A", # Stock Symbol - Medium Spring Green
 
223
  }
224
 
225
  label_names = {
 
230
  "TIM": "زمان (Time)",
231
  "MON": "پول (Money)",
232
  "PCT": "درصد (Percent)",
233
+ "STOCK": "نماد بورسی (Stock Symbol)",
 
234
  }
235
 
236
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
    """Combine NER and stock entities, letting stock symbols win on overlap.

    Args:
        entities: Entities from the NER pipeline; each has ``start`` and ``end``.
        stock_entities: Detected stock-symbol entities with ``start`` and ``end``.

    Returns:
        A new list holding all stock entities first, followed by every NER
        entity whose span does not intersect any stock entity span. Spans
        that merely touch (one ends where the other starts) do not count as
        overlapping.
    """
    def _hits_stock_span(candidate: Dict) -> bool:
        # Two half-open spans intersect when each starts before the other ends.
        return any(
            candidate['end'] > taken['start'] and candidate['start'] < taken['end']
            for taken in stock_entities
        )

    merged = list(stock_entities)
    merged.extend(ent for ent in entities if not _hits_stock_span(ent))
    return merged
255
 
256
+ def highlight_entities(text, all_entities):
257
  """Create HTML with highlighted entities"""
258
+ if not all_entities:
259
  return text
260
 
261
+ # Sort entities by start position (reverse order)
262
+ entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True)
263
 
264
  result = text
265
  for entity in entities_sorted:
 
270
  score = entity['score']
271
 
272
  # Get color for this label
273
+ color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
274
+
275
+ # Add extra info for stock symbols
276
+ tooltip_info = f"{label} (confidence: {score:.2f})"
277
+ if label == 'STOCK':
278
+ company = entity.get('company', '')
279
+ bazaar = entity.get('bazaar', '')
280
+ if company:
281
+ tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
 
 
282
 
283
  # Create highlighted span
284
+ highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{tooltip_info}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
285
 
286
  result = result[:start] + highlighted + result[end:]
287
 
288
  return result
289
 
290
def perform_ner(text):
    """Perform integrated NER and stock symbol detection.

    Args:
        text: Text entered by the user; ``None``, empty and whitespace-only
            input are all treated as "nothing entered".

    Returns:
        A ``(html, markdown)`` tuple: the input rendered with highlighted
        entities, and a Markdown entity table followed by summary statistics.
        On empty input the HTML is a prompt and the Markdown is empty; on
        failure the HTML is an error message and the Markdown is the raw
        exception text.
    """
    import html  # local import: only needed to escape the error message

    # `not text` guards against None, which has no .strip()
    if not text or not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""

    try:
        # Perform standard NER
        entities = ner_pipeline(text)

        # Find stock symbols using Persian LLM
        stock_entities = find_stock_symbols_in_text(text)

        # Merge entities (remove overlaps)
        all_entities = merge_overlapping_entities(entities, stock_entities)

        # Create highlighted version
        highlighted_html = f"""
        <div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
        padding: 20px; border: 1px solid #ddd; border-radius: 5px;
        background-color: #f9f9f9; font-family: Tahoma, Arial;'>
        {highlight_entities(text, all_entities)}
        </div>
        """

        # Create entities table
        if all_entities:
            parts = [
                "### موجودیت‌های شناسایی شده (Detected Entities):\n\n",
                "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n",
                "|:------------|:-----------|:------------------|:---------------------|\n",
            ]

            # Sort by position in text
            all_entities.sort(key=lambda x: x['start'])

            for ent in all_entities:
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                details = ""
                if ent['entity_group'] == 'STOCK':
                    company = ent.get('company', '')
                    bazaar = ent.get('bazaar', '')
                    group = ent.get('bazaar_group', '')
                    details = f"{company}<br>{bazaar}<br>{group}"
                parts.append(f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n")
            entity_info = "".join(parts)
        else:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"

        # Tally entity groups in one pass instead of rescanning per category
        counts = {}
        for ent in all_entities:
            counts[ent['entity_group']] = counts.get(ent['entity_group'], 0) + 1

        # Add statistics
        stats = "\n\n### آمار (Statistics):\n"
        stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n"
        stats += f"- نمادهای بورسی: {counts.get('STOCK', 0)}\n"
        stats += f"- اشخاص: {counts.get('PER', 0)}\n"
        stats += f"- سازمان‌ها: {counts.get('ORG', 0)}\n"
        stats += f"- مکان‌ها: {counts.get('LOC', 0)}\n"

        return highlighted_html, entity_info + stats

    except Exception as e:
        # Escape the message: it is interpolated into HTML and may echo user input.
        return f"<p style='color: red;'>خطا (Error): {html.escape(str(e))}</p>", str(e)
 
 
 
 
 
 
 
 
 
 
347
 
348
+ # Enhanced examples
349
  examples = [
350
+ ["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."],
351
+ ["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."],
352
+ ["صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است."],
353
+ ["قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید."],
354
+ ["بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد."],
355
+ ["شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید می‌کند."],
356
+ ["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."],
 
 
357
  ]
358
 
359
  # Create Gradio interface
360
+ with gr.Blocks(
361
+ title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی",
362
+ theme=gr.themes.Soft(),
363
+ css="""
364
+ .rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
365
+ """
366
+ ) as demo:
367
+ gr.Markdown("""
368
+ # 🏦 Persian NER with Intelligent Stock Symbol Detection
369
+ # شناسایی هوشمند موجودیت‌های نامدار و نمادهای بورس ایران
370
 
371
+ <div class="rtl-text">
372
+ این سیستم با استفاده از دو مدل هوش مصنوعی:
373
 
374
+ 1. **ParsBERT-NER**: برای شناسایی موجودیت‌های نامدار (اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها)
375
+ 2. **Persian BERT QA**: برای تشخیص هوشمند نمادهای بورسی با درک متن
376
 
377
+ ویژگی خاص: تشخیص هوشمند کلماتی مثل «فولاد» که می‌تواند نماد بورسی یا کلمه عادی باشد
378
+ </div>
 
379
 
380
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  """)
382
 
383
  with gr.Row():
384
+ with gr.Column(scale=6):
385
  input_text = gr.Textbox(
386
+ label="متن فارسی (Persian Text)",
387
+ placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...",
388
+ lines=6,
389
+ rtl=True,
390
+ elem_classes=["rtl-text"]
391
  )
392
+ with gr.Row():
393
+ submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2)
394
+ clear_btn = gr.Button("🗑️ پاک کردن", scale=1)
395
 
396
+ with gr.Column(scale=6):
397
+ output_html = gr.HTML(
398
+ label="نتیجه تحلیل (Analysis Result)",
399
+ elem_classes=["rtl-text"]
400
+ )
401
+
402
+ with gr.Row():
403
+ output_entities = gr.Markdown(
404
+ label="جدول موجودیت‌ها (Entity Table)",
405
+ elem_classes=["rtl-text"]
406
+ )
407
 
408
  gr.Examples(
409
  examples=examples,
410
  inputs=input_text,
411
+ label="نمونه‌های آماده (Ready Examples)",
412
+ examples_per_page=4
413
  )
414
 
415
+ # Color guide
416
+ with gr.Accordion("📖 راهنمای رنگ‌ها (Color Guide)", open=False):
417
+ gr.Markdown("""
418
+ <div class="rtl-text">
419
+
420
+ | رنگ | نوع موجودیت | توضیحات |
421
+ |:---:|:------------|:--------|
422
+ | 🔴 | **PER** | اشخاص و نام‌های افراد |
423
+ | 🔵 | **ORG** | سازمان‌ها و شرکت‌ها |
424
+ | 🟢 | **LOC** | مکان‌ها و نام‌های جغرافیایی |
425
+ | 🟠 | **DAT** | تاریخ‌ها |
426
+ | 🟣 | **TIM** | زمان‌ها |
427
+ | 🟡 | **MON** | مقادیر پولی |
428
+ | 🔷 | **PCT** | درصدها |
429
+ | 💚 | **STOCK** | نمادهای بورسی (با تحلیل هوشمند متن) |
430
+
431
+ </div>
432
+ """)
433
 
434
+ # Info section
435
+ with gr.Accordion("ℹ️ درباره سیستم (About)", open=False):
436
+ gr.Markdown("""
437
+ <div class="rtl-text">
438
+
439
+ ### قابلیت‌های کلیدی:
440
+
441
+ - **تشخیص هوشمند نمادهای بورسی**: با استفاده از مدل زبانی فارسی، سیستم تشخیص می‌دهد که آیا کلمات مشابه نمادها (مثل فولاد، فارس) در متن به عنوان نماد بورسی استفاده شده‌اند یا معنای عادی دارند
442
+
443
+ - **ترکیب دو مدل**: استفاده همزمان از ParsBERT-NER برای NER کلاسیک و Persian BERT برای درک متن
444
+
445
+ - **اطلاعات کامل نمادها**: نمایش نام شرکت، بازار، و گروه صنعت برای هر نماد شناسایی شده
446
+
447
+ - **دقت بالا**: با ترکیب تحلیل‌های مبتنی بر قواعد و مدل زبانی
448
+
449
+ </div>
450
+ """)
451
 
452
+ # Event handlers
453
  submit_btn.click(
454
  fn=perform_ner,
455
  inputs=input_text,
456
  outputs=[output_html, output_entities]
457
  )
458
 
459
+ clear_btn.click(
460
+ lambda: ("", "", ""),
461
+ outputs=[input_text, output_html, output_entities]
462
+ )
463
+
464
  input_text.submit(
465
  fn=perform_ner,
466
  inputs=input_text,
 
469
 
470
  # Launch the app
471
  if __name__ == "__main__":
472
+ print("Starting Persian NER + Stock Symbol Detection System...")
473
+ print(f"Using device: {device}")
474
+ print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
475
+ demo.launch(
476
+ share=False,
477
+ debug=True
478
+ )