optimopium commited on
Commit
4865162
·
verified ·
1 Parent(s): c4885a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -137
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import gradio as gr
2
  from transformers import (
3
  AutoTokenizer,
4
- AutoModelForTokenClassification,
5
- AutoModelForSequenceClassification,
6
- AutoModelForQuestionAnswering,
7
  pipeline
8
  )
9
  import torch
@@ -12,8 +11,9 @@ import re
12
  from typing import List, Dict, Tuple
13
  import numpy as np
14
 
15
- # Set device
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
17
 
18
  # Load the NER model and tokenizer
19
  print("Loading NER model...")
@@ -31,23 +31,32 @@ ner_pipeline = pipeline(
31
  aggregation_strategy="simple"
32
  )
33
 
34
- # Load Persian LLM for context understanding
35
- print("Loading Persian context model...")
36
- # Using a Persian BERT model fine-tuned for question answering/text classification
37
- context_model_name = "persiannlp/mt5-small-parsinlu-multiple-choice" # Alternative model
38
- # You can also try: "HooshvareLab/bert-fa-base-uncased" with custom classification head
39
 
40
- # For stock symbol disambiguation, we'll use a QA approach
41
- qa_model_name = "mohammadehab/persian-qa-bert"
42
- qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
43
- qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
44
- qa_model.to(device)
 
 
 
 
 
 
45
 
46
- qa_pipeline = pipeline(
47
- "question-answering",
48
- model=qa_model,
49
- tokenizer=qa_tokenizer,
50
- device=0 if device == "cuda" else -1
 
 
 
 
 
51
  )
52
 
53
  # Load stock symbols from CSV
@@ -66,7 +75,7 @@ def load_stock_symbols(csv_path="symbols.csv"):
66
  return symbols_dict
67
  except Exception as e:
68
  print(f"Error loading symbols CSV: {e}")
69
- # Provide some default symbols for demo
70
  return {
71
  'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
72
  'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
@@ -74,98 +83,109 @@ def load_stock_symbols(csv_path="symbols.csv"):
74
  'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'},
75
  'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
76
  'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
 
 
77
  }
78
 
79
  # Load symbols
80
  STOCK_SYMBOLS = load_stock_symbols()
81
  SYMBOL_NAMES = set(STOCK_SYMBOLS.keys())
82
 
83
- # Market context keywords for Persian
84
  MARKET_KEYWORDS = {
85
  'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش',
86
  'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد',
87
  'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود',
88
- 'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت',
89
- 'کندل', 'نمودار', 'تحلیل', 'بنیادی', 'تکنیکال', 'حمایت', 'مقاومت'
90
  }
91
 
92
- def use_llm_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
93
  """
94
- Use Persian QA model to determine if a word is used as a stock symbol
95
  Returns confidence score (0-1)
96
  """
97
  try:
98
- # Create context-aware questions in Persian
99
- context = text
100
-
101
- # Ask multiple questions to get better understanding
102
- questions = [
103
- f"آیا {potential_symbol} در این متن نماد بورسی است؟",
104
- f"آیا منظور از {potential_symbol} سهام شرکت {symbol_info['company']} است؟",
105
- f"آیا {potential_symbol} در این متن به معاملات بورس اشاره دارد؟"
106
- ]
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- scores = []
109
- for question in questions:
110
- try:
111
- result = qa_pipeline(question=question, context=context)
112
- # Check if the answer contains positive indicators
113
- answer = result['answer'].lower()
114
- confidence = result['score']
115
-
116
- # Adjust confidence based on answer content
117
- if any(word in answer for word in ['بله', 'است', 'می‌باشد', 'درست']):
118
- scores.append(confidence)
119
- elif any(word in answer for word in ['خیر', 'نیست', 'نمی‌باشد']):
120
- scores.append(1 - confidence)
121
- else:
122
- scores.append(0.5) # Uncertain
123
-
124
- except Exception:
125
- scores.append(0.5) # Default to uncertain
126
 
127
- # Return average confidence
128
- return np.mean(scores) if scores else 0.5
129
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  except Exception as e:
131
- print(f"LLM disambiguation error: {e}")
132
- return 0.5 # Return neutral confidence on error
133
 
134
  def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
135
  """
136
  Check if a potential symbol is actually used as a stock symbol in context
137
- Returns (is_stock_symbol, confidence_score)
138
  """
139
- # First, do a heuristic check
140
- text_lower = text.lower()
141
-
142
- # Get surrounding context (window of 30 characters before and after)
143
  symbol_pos = text.find(potential_symbol)
144
  if symbol_pos == -1:
145
  return False, 0.0
146
 
147
- start_context = max(0, symbol_pos - 50)
148
- end_context = min(len(text), symbol_pos + len(potential_symbol) + 50)
149
  context_window = text[start_context:end_context]
150
 
151
- # Count market keywords in context
152
  words_in_context = context_window.split()
153
  market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS)
154
 
155
  # Calculate heuristic score
156
- heuristic_score = min(market_keyword_count * 0.3, 1.0)
157
 
158
- # If very strong or very weak signal from heuristics, use that
159
- if heuristic_score >= 0.9:
160
- return True, heuristic_score
161
- elif heuristic_score == 0 and len(words_in_context) > 5:
162
- return False, 0.1
163
 
164
- # For ambiguous cases, use LLM
165
- llm_score = use_llm_for_disambiguation(text, potential_symbol, symbol_info)
166
 
167
- # Combine heuristic and LLM scores
168
- final_score = (heuristic_score * 0.4 + llm_score * 0.6)
169
 
170
  # Decision threshold
171
  is_stock = final_score > 0.5
@@ -173,20 +193,20 @@ def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Di
173
  return is_stock, final_score
174
 
175
  def find_stock_symbols_in_text(text: str) -> List[Dict]:
176
- """Find and validate stock symbols in text using context analysis"""
177
  found_symbols = []
 
178
 
179
- # Use regex to find all potential symbols (Persian words)
180
- # This pattern matches Persian words that might be symbols
181
  pattern = r'\b[\u0600-\u06FF]+\b'
182
 
183
  for match in re.finditer(pattern, text):
184
  word = match.group()
185
 
186
- if word in SYMBOL_NAMES:
187
  symbol_info = STOCK_SYMBOLS[word]
188
 
189
- # Check context to determine if it's actually used as a stock symbol
190
  is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
191
 
192
  if is_stock:
@@ -200,26 +220,27 @@ def find_stock_symbols_in_text(text: str) -> List[Dict]:
200
  'bazaar': symbol_info['bazaar'],
201
  'bazaar_group': symbol_info['bazaar_group']
202
  })
 
203
 
204
  return found_symbols
205
 
206
  # Label colors and names
207
  label_colors = {
208
- "B-PER": "#FF6B6B", # Person - Red
209
  "I-PER": "#FFB3B3",
210
- "B-ORG": "#4ECDC4", # Organization - Teal
211
  "I-ORG": "#A7E9E4",
212
- "B-LOC": "#95E1D3", # Location - Green
213
  "I-LOC": "#C7F0E8",
214
- "B-DAT": "#FFA07A", # Date - Orange
215
  "I-DAT": "#FFDAB9",
216
- "B-TIM": "#DDA0DD", # Time - Purple
217
  "I-TIM": "#E6D0E6",
218
- "B-MON": "#FFD700", # Money - Gold
219
  "I-MON": "#FFEB99",
220
- "B-PCT": "#87CEEB", # Percent - Sky Blue
221
  "I-PCT": "#B3DFEF",
222
- "STOCK": "#00FA9A", # Stock Symbol - Medium Spring Green
223
  }
224
 
225
  label_names = {
@@ -237,14 +258,13 @@ def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict])
237
  """Merge entities, removing overlaps (stock symbols take precedence)"""
238
  all_entities = []
239
 
240
- # Add stock entities first (they have priority)
241
  all_entities.extend(stock_entities)
242
 
243
- # Add NER entities that don't overlap with stock entities
244
  for ner_ent in entities:
245
  overlap = False
246
  for stock_ent in stock_entities:
247
- # Check for overlap
248
  if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']):
249
  overlap = True
250
  break
@@ -258,7 +278,6 @@ def highlight_entities(text, all_entities):
258
  if not all_entities:
259
  return text
260
 
261
- # Sort entities by start position (reverse order)
262
  entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True)
263
 
264
  result = text
@@ -269,10 +288,8 @@ def highlight_entities(text, all_entities):
269
  word = text[start:end]
270
  score = entity['score']
271
 
272
- # Get color for this label
273
  color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
274
 
275
- # Add extra info for stock symbols
276
  tooltip_info = f"{label} (confidence: {score:.2f})"
277
  if label == 'STOCK':
278
  company = entity.get('company', '')
@@ -280,7 +297,6 @@ def highlight_entities(text, all_entities):
280
  if company:
281
  tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
282
 
283
- # Create highlighted span
284
  highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{tooltip_info}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
285
 
286
  result = result[:start] + highlighted + result[end:]
@@ -296,13 +312,13 @@ def perform_ner(text):
296
  # Perform standard NER
297
  entities = ner_pipeline(text)
298
 
299
- # Find stock symbols using Persian LLM
300
  stock_entities = find_stock_symbols_in_text(text)
301
 
302
- # Merge entities (remove overlaps)
303
  all_entities = merge_overlapping_entities(entities, stock_entities)
304
 
305
- # Create highlighted version
306
  highlighted_html = f"""
307
  <div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
308
  padding: 20px; border: 1px solid #ddd; border-radius: 5px;
@@ -317,7 +333,6 @@ def perform_ner(text):
317
  entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n"
318
  entity_info += "|:------------|:-----------|:------------------|:---------------------|\n"
319
 
320
- # Sort by position in text
321
  all_entities.sort(key=lambda x: x['start'])
322
 
323
  for ent in all_entities:
@@ -332,7 +347,7 @@ def perform_ner(text):
332
  else:
333
  entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
334
 
335
- # Add statistics
336
  stats = f"\n\n### آمار (Statistics):\n"
337
  stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n"
338
  stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n"
@@ -345,7 +360,7 @@ def perform_ner(text):
345
  except Exception as e:
346
  return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", str(e)
347
 
348
- # Enhanced examples
349
  examples = [
350
  ["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."],
351
  ["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."],
@@ -356,7 +371,7 @@ examples = [
356
  ["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."],
357
  ]
358
 
359
- # Create Gradio interface
360
  with gr.Blocks(
361
  title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی",
362
  theme=gr.themes.Soft(),
@@ -365,25 +380,19 @@ with gr.Blocks(
365
  """
366
  ) as demo:
367
  gr.Markdown("""
368
- # 🏦 Persian NER with Intelligent Stock Symbol Detection
369
- # شناسایی هوشمند موجودیت‌های نامدار و نمادهای بورس ایران
 
370
 
371
  <div class="rtl-text">
372
- این سیستم با استفاده از دو مدل هوش مصنوعی:
373
-
374
- 1. **ParsBERT-NER**: برای شناسایی موجودیت‌های نامدار (اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها)
375
- 2. **Persian BERT QA**: برای تشخیص هوشمند نمادهای بورسی با درک متن
376
-
377
- ویژگی خاص: تشخیص هوشمند کلماتی مثل «فولاد» که می‌تواند نماد بورسی یا کلمه عادی باشد
378
  </div>
379
-
380
- ---
381
  """)
382
 
383
  with gr.Row():
384
  with gr.Column(scale=6):
385
  input_text = gr.Textbox(
386
- label="متن فارسی (Persian Text)",
387
  placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...",
388
  lines=6,
389
  rtl=True,
@@ -412,39 +421,56 @@ with gr.Blocks(
412
  examples_per_page=4
413
  )
414
 
415
- # Color guide
416
- with gr.Accordion("📖 راهنمای رنگ‌ها (Color Guide)", open=False):
417
  gr.Markdown("""
418
  <div class="rtl-text">
419
 
420
- | رنگ | نوع موجودیت | توضیحات |
421
- |:---:|:------------|:--------|
422
- | 🔴 | **PER** | اشخاص و نام‌های افراد |
423
- | 🔵 | **ORG** | سازمان‌ها و شرکت‌ها |
424
- | 🟢 | **LOC** | مکان‌ها و نام‌های جغرافیایی |
425
- | 🟠 | **DAT** | تاریخ‌ها |
426
- | 🟣 | **TIM** | زمان‌ها |
427
- | 🟡 | **MON** | مقادیر پولی |
428
- | 🔷 | **PCT** | درصدها |
429
- | 💚 | **STOCK** | نمادهای بورسی (با تحلیل هوشمند متن) |
430
 
431
- </div>
432
- """)
433
-
434
- # Info section
435
- with gr.Accordion("ℹ️ درباره سیستم (About)", open=False):
436
- gr.Markdown("""
437
- <div class="rtl-text">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
- ### قابلیت‌های کلیدی:
440
 
441
- - **تشخیص هوشمند نمادهای بورسی**: با استفاده از مدل زبانی فارسی، سیستم تشخیص می‌دهد که آیا کلمات مشابه نمادها (مثل فولاد، فارس) در متن به عنوان نماد بورسی استفاده شده‌اند یا معنای عادی دارند
 
 
 
442
 
443
- - **ترکیب دو مدل**: استفاده همزمان از ParsBERT-NER برای NER کلاسیک و Persian BERT برای درک متن
444
 
445
- - **اطلاعات کامل نمادها**: نمایش نام شرکت، بازار، و گروه صنعت برای هر نماد شناسایی شده
 
446
 
447
- - **دقت بالا**: با ترکیب تحلیل‌های مبتنی بر قواعد و مدل زبانی
448
 
449
  </div>
450
  """)
@@ -467,11 +493,16 @@ with gr.Blocks(
467
  outputs=[output_html, output_entities]
468
  )
469
 
470
- # Launch the app
471
  if __name__ == "__main__":
472
  print("Starting Persian NER + Stock Symbol Detection System...")
473
  print(f"Using device: {device}")
474
  print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
 
 
 
 
 
475
  demo.launch(
476
  share=False,
477
  debug=True
 
1
  import gradio as gr
2
  from transformers import (
3
  AutoTokenizer,
4
+ AutoModelForTokenClassification,
5
+ AutoModelForCausalLM,
 
6
  pipeline
7
  )
8
  import torch
 
11
  from typing import List, Dict, Tuple
12
  import numpy as np
13
 
14
+ # Set device and dtype for optimization
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
17
 
18
  # Load the NER model and tokenizer
19
  print("Loading NER model...")
 
31
  aggregation_strategy="simple"
32
  )
33
 
34
+ # Load Gemma model for stock symbol detection
35
+ print("Loading Gemma-2-9b-it model for context understanding...")
36
+ gemma_model_name = "google/gemma-2-9b-it"
 
 
37
 
38
+ # Load with optimization settings for better performance
39
+ gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_name)
40
+ gemma_model = AutoModelForCausalLM.from_pretrained(
41
+ gemma_model_name,
42
+ torch_dtype=dtype,
43
+ device_map="auto" if torch.cuda.is_available() else None,
44
+ low_cpu_mem_usage=True
45
+ )
46
+
47
+ if device == "cpu":
48
+ gemma_model = gemma_model.to(device)
49
 
50
+ # Create text generation pipeline
51
+ gemma_pipeline = pipeline(
52
+ "text-generation",
53
+ model=gemma_model,
54
+ tokenizer=gemma_tokenizer,
55
+ device=0 if device == "cuda" else -1,
56
+ max_new_tokens=50,
57
+ temperature=0.1, # Low temperature for consistent outputs
58
+ do_sample=False, # Deterministic outputs
59
+ pad_token_id=gemma_tokenizer.eos_token_id
60
  )
61
 
62
  # Load stock symbols from CSV
 
75
  return symbols_dict
76
  except Exception as e:
77
  print(f"Error loading symbols CSV: {e}")
78
+ # Provide default symbols for demo
79
  return {
80
  'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
81
  'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
 
83
  'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'},
84
  'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
85
  'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
86
+ 'وتوسکا': {'company': 'سرمایه گذاری توسعه توکا', 'bazaar': 'بورس', 'bazaar_group': 'سرمایه گذاریها'},
87
+ 'پی پاد': {'company': 'پرداخت الکترونیک پاسارگاد', 'bazaar': 'بورس', 'bazaar_group': 'رایانه و فعالیت های وابسته'},
88
  }
89
 
90
  # Load symbols
91
  STOCK_SYMBOLS = load_stock_symbols()
92
  SYMBOL_NAMES = set(STOCK_SYMBOLS.keys())
93
 
94
+ # Market context keywords
95
  MARKET_KEYWORDS = {
96
  'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش',
97
  'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد',
98
  'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود',
99
+ 'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت'
 
100
  }
101
 
102
+ def use_gemma_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
103
  """
104
+ Use Gemma-2-9b-it to determine if a word is used as a stock symbol
105
  Returns confidence score (0-1)
106
  """
107
  try:
108
+ # Create a focused prompt for Gemma
109
+ prompt = f"""<bos><start_of_turn>user
110
+ You are a Persian financial text analyzer. Determine if the word "{potential_symbol}" in the following Persian text is used as a stock market symbol or as a regular word.
111
+
112
+ Context information:
113
+ - The word "{potential_symbol}" could be a stock symbol for "{symbol_info['company']}" (industry: {symbol_info['bazaar_group']})
114
+ - Stock symbols usually appear with financial terms like: سهام، بورس، معامله، قیمت، خرید، فروش
115
+
116
+ Text to analyze:
117
+ "{text}"
118
+
119
+ Answer ONLY with one of these:
120
+ 1. "STOCK" if it's used as a stock market symbol
121
+ 2. "WORD" if it's used as a regular word
122
+
123
+ Reasoning: Consider the surrounding context. If the text discusses trading, prices, or stock market activities, it's likely a stock symbol. If it discusses the general meaning (like فولاد meaning steel in manufacturing context), it's a regular word.
124
+
125
+ Answer:<end_of_turn>
126
+ <start_of_turn>model
127
+ """
128
 
129
+ # Generate response
130
+ response = gemma_pipeline(
131
+ prompt,
132
+ max_new_tokens=20,
133
+ temperature=0.1,
134
+ do_sample=False,
135
+ return_full_text=False
136
+ )
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # Extract the answer
139
+ answer = response[0]['generated_text'].strip().upper()
140
 
141
+ # Determine confidence based on response
142
+ if "STOCK" in answer:
143
+ return 0.9 # High confidence it's a stock symbol
144
+ elif "WORD" in answer:
145
+ return 0.1 # Low confidence it's a stock symbol
146
+ else:
147
+ # If unclear, analyze the response for clues
148
+ if any(keyword in answer.lower() for keyword in ['نماد', 'سهام', 'بورس']):
149
+ return 0.7
150
+ else:
151
+ return 0.3
152
+
153
  except Exception as e:
154
+ print(f"Gemma inference error: {e}")
155
+ return 0.5 # Neutral confidence on error
156
 
157
  def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
158
  """
159
  Check if a potential symbol is actually used as a stock symbol in context
160
+ Using both heuristics and Gemma model
161
  """
162
+ # Get surrounding context
 
 
 
163
  symbol_pos = text.find(potential_symbol)
164
  if symbol_pos == -1:
165
  return False, 0.0
166
 
167
+ start_context = max(0, symbol_pos - 100)
168
+ end_context = min(len(text), symbol_pos + len(potential_symbol) + 100)
169
  context_window = text[start_context:end_context]
170
 
171
+ # Count market keywords
172
  words_in_context = context_window.split()
173
  market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS)
174
 
175
  # Calculate heuristic score
176
+ heuristic_score = min(market_keyword_count * 0.2, 1.0)
177
 
178
+ # Strong heuristic signals
179
+ if market_keyword_count >= 5:
180
+ return True, 0.95
181
+ elif market_keyword_count == 0 and len(words_in_context) > 10:
182
+ return False, 0.05
183
 
184
+ # Use Gemma for disambiguation
185
+ gemma_score = use_gemma_for_disambiguation(context_window, potential_symbol, symbol_info)
186
 
187
+ # Combine scores (give more weight to Gemma as it understands context better)
188
+ final_score = (heuristic_score * 0.2 + gemma_score * 0.8)
189
 
190
  # Decision threshold
191
  is_stock = final_score > 0.5
 
193
  return is_stock, final_score
194
 
195
  def find_stock_symbols_in_text(text: str) -> List[Dict]:
196
+ """Find and validate stock symbols in text"""
197
  found_symbols = []
198
+ processed_positions = set() # To avoid duplicate processing
199
 
200
+ # Pattern to match Persian/Arabic words
 
201
  pattern = r'\b[\u0600-\u06FF]+\b'
202
 
203
  for match in re.finditer(pattern, text):
204
  word = match.group()
205
 
206
+ if word in SYMBOL_NAMES and match.start() not in processed_positions:
207
  symbol_info = STOCK_SYMBOLS[word]
208
 
209
+ # Check context using Gemma
210
  is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
211
 
212
  if is_stock:
 
220
  'bazaar': symbol_info['bazaar'],
221
  'bazaar_group': symbol_info['bazaar_group']
222
  })
223
+ processed_positions.add(match.start())
224
 
225
  return found_symbols
226
 
227
  # Label colors and names
228
  label_colors = {
229
+ "B-PER": "#FF6B6B",
230
  "I-PER": "#FFB3B3",
231
+ "B-ORG": "#4ECDC4",
232
  "I-ORG": "#A7E9E4",
233
+ "B-LOC": "#95E1D3",
234
  "I-LOC": "#C7F0E8",
235
+ "B-DAT": "#FFA07A",
236
  "I-DAT": "#FFDAB9",
237
+ "B-TIM": "#DDA0DD",
238
  "I-TIM": "#E6D0E6",
239
+ "B-MON": "#FFD700",
240
  "I-MON": "#FFEB99",
241
+ "B-PCT": "#87CEEB",
242
  "I-PCT": "#B3DFEF",
243
+ "STOCK": "#00FA9A",
244
  }
245
 
246
  label_names = {
 
258
  """Merge entities, removing overlaps (stock symbols take precedence)"""
259
  all_entities = []
260
 
261
+ # Add stock entities first
262
  all_entities.extend(stock_entities)
263
 
264
+ # Add NER entities that don't overlap
265
  for ner_ent in entities:
266
  overlap = False
267
  for stock_ent in stock_entities:
 
268
  if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']):
269
  overlap = True
270
  break
 
278
  if not all_entities:
279
  return text
280
 
 
281
  entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True)
282
 
283
  result = text
 
288
  word = text[start:end]
289
  score = entity['score']
290
 
 
291
  color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
292
 
 
293
  tooltip_info = f"{label} (confidence: {score:.2f})"
294
  if label == 'STOCK':
295
  company = entity.get('company', '')
 
297
  if company:
298
  tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
299
 
 
300
  highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{tooltip_info}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
301
 
302
  result = result[:start] + highlighted + result[end:]
 
312
  # Perform standard NER
313
  entities = ner_pipeline(text)
314
 
315
+ # Find stock symbols using Gemma
316
  stock_entities = find_stock_symbols_in_text(text)
317
 
318
+ # Merge entities
319
  all_entities = merge_overlapping_entities(entities, stock_entities)
320
 
321
+ # Create highlighted HTML
322
  highlighted_html = f"""
323
  <div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
324
  padding: 20px; border: 1px solid #ddd; border-radius: 5px;
 
333
  entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n"
334
  entity_info += "|:------------|:-----------|:------------------|:---------------------|\n"
335
 
 
336
  all_entities.sort(key=lambda x: x['start'])
337
 
338
  for ent in all_entities:
 
347
  else:
348
  entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
349
 
350
+ # Statistics
351
  stats = f"\n\n### آمار (Statistics):\n"
352
  stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n"
353
  stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n"
 
360
  except Exception as e:
361
  return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", str(e)
362
 
363
+ # Examples
364
  examples = [
365
  ["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."],
366
  ["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."],
 
371
  ["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."],
372
  ]
373
 
374
+ # Gradio interface
375
  with gr.Blocks(
376
  title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی",
377
  theme=gr.themes.Soft(),
 
380
  """
381
  ) as demo:
382
  gr.Markdown("""
383
+ # 🏦 شناسایی هوشمند موجودیت‌ها و نمادهای بورس ایران
384
+ ## Persian Named Entity Recognition with Stock Symbol Detection
385
+ ### Powered by Google Gemma-2-9B-IT
386
 
387
  <div class="rtl-text">
388
+ این برنامه با استفاده از مدل قدرتمند Gemma-2-9B، متن‌های فارسی را تحلیل کرده و موجودیت‌های مختلف را شناسایی می‌کند.
 
 
 
 
 
389
  </div>
 
 
390
  """)
391
 
392
  with gr.Row():
393
  with gr.Column(scale=6):
394
  input_text = gr.Textbox(
395
+ label="متن فارسی را وارد کنید (Enter Persian Text)",
396
  placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...",
397
  lines=6,
398
  rtl=True,
 
421
  examples_per_page=4
422
  )
423
 
424
+ # User guide
425
+ with gr.Accordion("📖 راهنمای استفاده (User Guide)", open=True):
426
  gr.Markdown("""
427
  <div class="rtl-text">
428
 
429
+ ## چگونه از این برنامه استفاده کنیم؟
 
 
 
 
 
 
 
 
 
430
 
431
+ 1. **متن فارسی خود را در کادر بالا وارد کنید**
432
+ 2. **دکمه «تحلیل متن» را بزنید**
433
+ 3. **نتایج را در دو بخش مشاهده کنید:**
434
+ - متن با موجودیت‌های رنگی شده
435
+ - جدول کامل موجودیت‌ها با جزئیات
436
+
437
+ ## انواع موجودیت‌هایی که شناسایی می‌شوند:
438
+
439
+ | رنگ | نوع | مثال |
440
+ |:---:|:----|:-----|
441
+ | 🔴 قرمز | **اشخاص** | علی احمدی، مریم رضایی |
442
+ | 🔵 آبی | **سازمان‌ها** | شرکت ملی نفت، بانک ملت |
443
+ | 🟢 سبز | **مکان‌ها** | تهران، اصفهان، ایران |
444
+ | 🟠 نارنجی | **تاریخ‌ها** | ۱۵ خرداد ۱۴۰۳ |
445
+ | 🟣 بنفش | **زمان‌ها** | ساعت ۱۰ صبح |
446
+ | 🟡 زرد | **مبالغ پولی** | ۱۰۰۰ ریال، ۵ میلیارد تومان |
447
+ | 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
448
+ | 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
449
+
450
+ ## ویژگی خاص: تشخیص هوشمند با Gemma-2-9B
451
+
452
+ این برنامه از **مدل Gemma-2-9B** گوگل استفاده می‌کند که:
453
+ - درک عمیق از زبان فارسی دارد
454
+ - متن را به صورت کامل تحلیل می‌کند
455
+ - بین نماد بورسی و کلمه عادی تمایز قائل می‌شود
456
+
457
+ **مثال:**
458
+ - «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
459
+ - «صنعت **فولاد** در کشور مهم است» ← فولاد = کلمه عادی ❌
460
 
461
+ ## نحوه تفسیر نتایج:
462
 
463
+ - **رنگ‌ها**: نوع موجودیت را نشان می‌دهند
464
+ - **برچسب‌ها**: نوع موجودیت به صورت مختصر
465
+ - **درصد اطمینان**: میزان اطمینان سیستم (۰-۱۰۰٪)
466
+ - **جزئیات نمادها**: نام شرکت، بازار و گروه صنعت
467
 
468
+ ## مدل‌های استفاده شده:
469
 
470
+ - **ParsBERT NER**: شناسایی موجودیت‌های عمومی
471
+ - **Google Gemma-2-9B-IT**: تحلیل هوشمند متن و تشخیص نمادهای بورسی
472
 
473
+ ⚠️ **توجه**: مدل Gemma به دلیل حجم بالا (9 میلیارد پارامتر) ممکن است کمی کندتر باشد
474
 
475
  </div>
476
  """)
 
493
  outputs=[output_html, output_entities]
494
  )
495
 
496
+ # Launch
497
  if __name__ == "__main__":
498
  print("Starting Persian NER + Stock Symbol Detection System...")
499
  print(f"Using device: {device}")
500
  print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
501
+ print("Models:")
502
+ print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased")
503
+ print(" - Context Understanding: Google Gemma-2-9B-IT")
504
+ print("\nNote: Gemma-2-9B is a large model. First run may take time to download.")
505
+ print("For better performance, consider using GPU if available.")
506
  demo.launch(
507
  share=False,
508
  debug=True