optimopium commited on
Commit
2b32561
·
verified ·
1 Parent(s): 4865162

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -95
app.py CHANGED
@@ -11,7 +11,7 @@ import re
11
  from typing import List, Dict, Tuple
12
  import numpy as np
13
 
14
- # Set device and dtype for optimization
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
17
 
@@ -31,33 +31,41 @@ ner_pipeline = pipeline(
31
  aggregation_strategy="simple"
32
  )
33
 
34
- # Load Gemma model for stock symbol detection
35
- print("Loading Gemma-2-9b-it model for context understanding...")
36
- gemma_model_name = "google/gemma-2-9b-it"
 
37
 
38
- # Load with optimization settings for better performance
39
- gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_name)
40
- gemma_model = AutoModelForCausalLM.from_pretrained(
41
- gemma_model_name,
42
- torch_dtype=dtype,
43
- device_map="auto" if torch.cuda.is_available() else None,
44
- low_cpu_mem_usage=True
45
- )
46
-
47
- if device == "cpu":
48
- gemma_model = gemma_model.to(device)
49
-
50
- # Create text generation pipeline
51
- gemma_pipeline = pipeline(
52
- "text-generation",
53
- model=gemma_model,
54
- tokenizer=gemma_tokenizer,
55
- device=0 if device == "cuda" else -1,
56
- max_new_tokens=50,
57
- temperature=0.1, # Low temperature for consistent outputs
58
- do_sample=False, # Deterministic outputs
59
- pad_token_id=gemma_tokenizer.eos_token_id
60
- )
 
 
 
 
 
 
 
61
 
62
  # Load stock symbols from CSV
63
  def load_stock_symbols(csv_path="symbols.csv"):
@@ -99,65 +107,101 @@ MARKET_KEYWORDS = {
99
  'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت'
100
  }
101
 
102
- def use_gemma_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
103
  """
104
- Use Gemma-2-9b-it to determine if a word is used as a stock symbol
105
  Returns confidence score (0-1)
106
  """
 
 
 
 
107
  try:
108
- # Create a focused prompt for Gemma
109
- prompt = f"""<bos><start_of_turn>user
110
- You are a Persian financial text analyzer. Determine if the word "{potential_symbol}" in the following Persian text is used as a stock market symbol or as a regular word.
111
 
112
- Context information:
113
- - The word "{potential_symbol}" could be a stock symbol for "{symbol_info['company']}" (industry: {symbol_info['bazaar_group']})
114
- - Stock symbols usually appear with financial terms like: سهام، بورس، معامله، قیمت، خرید، فروش
115
 
116
- Text to analyze:
117
- "{text}"
118
 
119
- Answer ONLY with one of these:
120
- 1. "STOCK" if it's used as a stock market symbol
121
- 2. "WORD" if it's used as a regular word
122
-
123
- Reasoning: Consider the surrounding context. If the text discusses trading, prices, or stock market activities, it's likely a stock symbol. If it discusses the general meaning (like فولاد meaning steel in manufacturing context), it's a regular word.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- Answer:<end_of_turn>
126
- <start_of_turn>model
127
- """
128
-
129
- # Generate response
130
- response = gemma_pipeline(
131
- prompt,
132
- max_new_tokens=20,
133
- temperature=0.1,
134
- do_sample=False,
135
- return_full_text=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
 
138
- # Extract the answer
139
- answer = response[0]['generated_text'].strip().upper()
 
140
 
141
- # Determine confidence based on response
142
- if "STOCK" in answer:
143
- return 0.9 # High confidence it's a stock symbol
144
- elif "WORD" in answer:
145
- return 0.1 # Low confidence it's a stock symbol
146
  else:
147
- # If unclear, analyze the response for clues
148
- if any(keyword in answer.lower() for keyword in ['نماد', 'سهام', 'بورس']):
149
- return 0.7
150
- else:
151
- return 0.3
152
-
153
  except Exception as e:
154
- print(f"Gemma inference error: {e}")
155
- return 0.5 # Neutral confidence on error
156
 
157
  def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
158
  """
159
  Check if a potential symbol is actually used as a stock symbol in context
160
- Using both heuristics and Gemma model
161
  """
162
  # Get surrounding context
163
  symbol_pos = text.find(potential_symbol)
@@ -181,11 +225,14 @@ def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Di
181
  elif market_keyword_count == 0 and len(words_in_context) > 10:
182
  return False, 0.05
183
 
184
- # Use Gemma for disambiguation
185
- gemma_score = use_gemma_for_disambiguation(context_window, potential_symbol, symbol_info)
 
 
 
186
 
187
- # Combine scores (give more weight to Gemma as it understands context better)
188
- final_score = (heuristic_score * 0.2 + gemma_score * 0.8)
189
 
190
  # Decision threshold
191
  is_stock = final_score > 0.5
@@ -195,7 +242,7 @@ def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Di
195
  def find_stock_symbols_in_text(text: str) -> List[Dict]:
196
  """Find and validate stock symbols in text"""
197
  found_symbols = []
198
- processed_positions = set() # To avoid duplicate processing
199
 
200
  # Pattern to match Persian/Arabic words
201
  pattern = r'\b[\u0600-\u06FF]+\b'
@@ -206,7 +253,7 @@ def find_stock_symbols_in_text(text: str) -> List[Dict]:
206
  if word in SYMBOL_NAMES and match.start() not in processed_positions:
207
  symbol_info = STOCK_SYMBOLS[word]
208
 
209
- # Check context using Gemma
210
  is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
211
 
212
  if is_stock:
@@ -255,13 +302,10 @@ label_names = {
255
  }
256
 
257
  def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
258
- """Merge entities, removing overlaps (stock symbols take precedence)"""
259
  all_entities = []
260
-
261
- # Add stock entities first
262
  all_entities.extend(stock_entities)
263
 
264
- # Add NER entities that don't overlap
265
  for ner_ent in entities:
266
  overlap = False
267
  for stock_ent in stock_entities:
@@ -312,7 +356,7 @@ def perform_ner(text):
312
  # Perform standard NER
313
  entities = ner_pipeline(text)
314
 
315
- # Find stock symbols using Gemma
316
  stock_entities = find_stock_symbols_in_text(text)
317
 
318
  # Merge entities
@@ -379,13 +423,13 @@ with gr.Blocks(
379
  .rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
380
  """
381
  ) as demo:
382
- gr.Markdown("""
383
  # 🏦 شناسایی هوشمند موجودیت‌ها و نمادهای بورس ایران
384
  ## Persian Named Entity Recognition with Stock Symbol Detection
385
- ### Powered by Google Gemma-2-9B-IT
386
 
387
  <div class="rtl-text">
388
- این برنامه با استفاده از مدل قدرتمند Gemma-2-9B، متن‌های فارسی را تحلیل کرده و موجودیت‌های مختلف را شناسایی می‌کند.
389
  </div>
390
  """)
391
 
@@ -447,12 +491,9 @@ with gr.Blocks(
447
  | 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
448
  | 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
449
 
450
- ## ویژگی خاص: تشخیص هوشمند با Gemma-2-9B
451
 
452
- این برنامه از **مدل Gemma-2-9B** گوگل استفاده می‌کند که:
453
- - درک عمیق از زبان فارسی دارد
454
- - متن را به صورت کامل تحلیل می‌کند
455
- - بین نماد بورسی و کلمه عادی تمایز قائل می‌شود
456
 
457
  **مثال:**
458
  - «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
@@ -468,9 +509,7 @@ with gr.Blocks(
468
  ## مدل‌های استفاده شده:
469
 
470
  - **ParsBERT NER**: شناسایی موجودیت‌های عمومی
471
- - **Google Gemma-2-9B-IT**: تحلیل هوشمند متن و تشخیص نمادهای بورسی
472
-
473
- ⚠️ **توجه**: مدل Gemma به دلیل حجم بالا (9 میلیارد پارامتر) ممکن است کمی کندتر باشد
474
 
475
  </div>
476
  """)
@@ -498,11 +537,9 @@ if __name__ == "__main__":
498
  print("Starting Persian NER + Stock Symbol Detection System...")
499
  print(f"Using device: {device}")
500
  print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
501
- print("Models:")
502
  print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased")
503
- print(" - Context Understanding: Google Gemma-2-9B-IT")
504
- print("\nNote: Gemma-2-9B is a large model. First run may take time to download.")
505
- print("For better performance, consider using GPU if available.")
506
  demo.launch(
507
  share=False,
508
  debug=True
 
11
  from typing import List, Dict, Tuple
12
  import numpy as np
13
 
14
+ # Set device
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
17
 
 
31
  aggregation_strategy="simple"
32
  )
33
 
34
+ # Load a smaller, open model for context understanding
35
+ print("Loading context understanding model...")
36
+ # Using Microsoft Phi-2 (small and efficient) or Mistral-7B-Instruct (if you have more resources)
37
+ context_model_name = "microsoft/phi-2" # 2.7B parameters, works well on CPU
38
 
39
+ try:
40
+ context_tokenizer = AutoTokenizer.from_pretrained(context_model_name, trust_remote_code=True)
41
+ context_model = AutoModelForCausalLM.from_pretrained(
42
+ context_model_name,
43
+ torch_dtype=dtype,
44
+ trust_remote_code=True,
45
+ device_map="auto" if torch.cuda.is_available() else None
46
+ )
47
+ if device == "cpu":
48
+ context_model = context_model.to(device)
49
+
50
+ # Set pad token if not set
51
+ if context_tokenizer.pad_token is None:
52
+ context_tokenizer.pad_token = context_tokenizer.eos_token
53
+
54
+ use_llm_model = True
55
+ print(f"Successfully loaded {context_model_name}")
56
+
57
+ except Exception as e:
58
+ print(f"Could not load Phi-2 model: {e}")
59
+ print("Falling back to zero-shot classification model...")
60
+
61
+ # Fallback to mDeBERTa for zero-shot classification
62
+ context_model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
63
+ classifier = pipeline(
64
+ "zero-shot-classification",
65
+ model=context_model_name,
66
+ device=0 if device == "cuda" else -1
67
+ )
68
+ use_llm_model = False
69
 
70
  # Load stock symbols from CSV
71
  def load_stock_symbols(csv_path="symbols.csv"):
 
107
  'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت'
108
  }
109
 
110
+ def use_phi_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
111
  """
112
+ Use Phi-2 model to determine if a word is used as a stock symbol
113
  Returns confidence score (0-1)
114
  """
115
+ if not use_llm_model:
116
+ # Use zero-shot classification instead
117
+ return use_zero_shot_classification(text, potential_symbol, symbol_info)
118
+
119
  try:
120
+ # Create a simple prompt for Phi-2
121
+ prompt = f"""Analyze this Persian text and determine if "{potential_symbol}" is used as a stock market symbol.
 
122
 
123
+ Context: "{potential_symbol}" could be a stock symbol for {symbol_info['company']} company.
 
 
124
 
125
+ Text: {text}
 
126
 
127
+ Answer with only "STOCK" if it's a stock symbol, or "WORD" if it's a regular word:
128
+ Answer: """
129
+
130
+ # Tokenize and generate
131
+ inputs = context_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
132
+ inputs = {k: v.to(device) for k, v in inputs.items()}
133
+
134
+ with torch.no_grad():
135
+ outputs = context_model.generate(
136
+ **inputs,
137
+ max_new_tokens=10,
138
+ temperature=0.1,
139
+ do_sample=False,
140
+ pad_token_id=context_tokenizer.eos_token_id
141
+ )
142
+
143
+ # Decode the response
144
+ response = context_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
145
+ response = response.strip().upper()
146
+
147
+ # Parse response
148
+ if "STOCK" in response:
149
+ return 0.9
150
+ elif "WORD" in response:
151
+ return 0.1
152
+ else:
153
+ # Ambiguous response, use heuristics
154
+ return 0.5
155
+
156
+ except Exception as e:
157
+ print(f"Phi-2 inference error: {e}")
158
+ return 0.5
159
 
160
+ def use_zero_shot_classification(text: str, potential_symbol: str, symbol_info: Dict) -> float:
161
+ """
162
+ Fallback: Use zero-shot classification to determine if a word is a stock symbol
163
+ """
164
+ try:
165
+ # Get context around the symbol
166
+ symbol_pos = text.find(potential_symbol)
167
+ if symbol_pos == -1:
168
+ return 0.5
169
+
170
+ start = max(0, symbol_pos - 100)
171
+ end = min(len(text), symbol_pos + len(potential_symbol) + 100)
172
+ context_text = text[start:end]
173
+
174
+ # Define candidate labels
175
+ candidate_labels = [
176
+ f"نماد بورسی {symbol_info['company']}",
177
+ f"کلمه عادی {potential_symbol}",
178
+ "stock market symbol",
179
+ "regular word"
180
+ ]
181
+
182
+ # Perform classification
183
+ result = classifier(
184
+ context_text,
185
+ candidate_labels=candidate_labels,
186
+ multi_label=False
187
  )
188
 
189
+ # Check top label
190
+ top_label = result['labels'][0]
191
+ top_score = result['scores'][0]
192
 
193
+ if 'نماد بورسی' in top_label or 'stock' in top_label.lower():
194
+ return top_score
 
 
 
195
  else:
196
+ return 1 - top_score
197
+
 
 
 
 
198
  except Exception as e:
199
+ print(f"Classification error: {e}")
200
+ return 0.5
201
 
202
  def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
203
  """
204
  Check if a potential symbol is actually used as a stock symbol in context
 
205
  """
206
  # Get surrounding context
207
  symbol_pos = text.find(potential_symbol)
 
225
  elif market_keyword_count == 0 and len(words_in_context) > 10:
226
  return False, 0.05
227
 
228
+ # Use AI model for disambiguation
229
+ if use_llm_model:
230
+ ai_score = use_phi_for_disambiguation(context_window, potential_symbol, symbol_info)
231
+ else:
232
+ ai_score = use_zero_shot_classification(context_window, potential_symbol, symbol_info)
233
 
234
+ # Combine scores
235
+ final_score = (heuristic_score * 0.3 + ai_score * 0.7)
236
 
237
  # Decision threshold
238
  is_stock = final_score > 0.5
 
242
  def find_stock_symbols_in_text(text: str) -> List[Dict]:
243
  """Find and validate stock symbols in text"""
244
  found_symbols = []
245
+ processed_positions = set()
246
 
247
  # Pattern to match Persian/Arabic words
248
  pattern = r'\b[\u0600-\u06FF]+\b'
 
253
  if word in SYMBOL_NAMES and match.start() not in processed_positions:
254
  symbol_info = STOCK_SYMBOLS[word]
255
 
256
+ # Check context
257
  is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
258
 
259
  if is_stock:
 
302
  }
303
 
304
  def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
305
+ """Merge entities, removing overlaps"""
306
  all_entities = []
 
 
307
  all_entities.extend(stock_entities)
308
 
 
309
  for ner_ent in entities:
310
  overlap = False
311
  for stock_ent in stock_entities:
 
356
  # Perform standard NER
357
  entities = ner_pipeline(text)
358
 
359
+ # Find stock symbols
360
  stock_entities = find_stock_symbols_in_text(text)
361
 
362
  # Merge entities
 
423
  .rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
424
  """
425
  ) as demo:
426
+ gr.Markdown(f"""
427
  # 🏦 شناسایی هوشمند موجودیت‌ها و نمادهای بورس ایران
428
  ## Persian Named Entity Recognition with Stock Symbol Detection
429
+ ### Using {context_model_name.split('/')[-1]} for Context Understanding
430
 
431
  <div class="rtl-text">
432
+ این برنامه متن‌های فارسی را تحلیل کرده و موجودیت‌های مختلف را شناسایی می‌کند.
433
  </div>
434
  """)
435
 
 
491
  | 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
492
  | 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
493
 
494
+ ## ویژگی خاص: تشخیص هوشمند نمادهای بورسی
495
 
496
+ برنامه با استفاده از **هوش مصنوعی** تشخیص می‌دهد که آیا یک کلمه نماد بورسی است یا خیر.
 
 
 
497
 
498
  **مثال:**
499
  - «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
 
509
  ## مدل‌های استفاده شده:
510
 
511
  - **ParsBERT NER**: شناسایی موجودیت‌های عمومی
512
+ - **Microsoft Phi-2 / mDeBERTa**: تحلیل هوشمند متن برای تشخیص نمادهای بورسی
 
 
513
 
514
  </div>
515
  """)
 
537
  print("Starting Persian NER + Stock Symbol Detection System...")
538
  print(f"Using device: {device}")
539
  print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
540
+ print("Models loaded:")
541
  print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased")
542
+ print(f" - Context: {context_model_name}")
 
 
543
  demo.launch(
544
  share=False,
545
  debug=True