Update app.py
Browse files
app.py
CHANGED
|
@@ -1,83 +1,225 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
|
|
|
| 4 |
import re
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
# Set device to CPU explicitly
|
| 9 |
-
device = "cpu"
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# Create NER pipeline
|
| 20 |
ner_pipeline = pipeline(
|
| 21 |
"ner",
|
| 22 |
-
model=
|
| 23 |
-
tokenizer=
|
| 24 |
-
device
|
| 25 |
-
aggregation_strategy="simple"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
-
# Load stock symbols from CSV
|
| 29 |
-
def
|
| 30 |
-
"""Load stock symbols from CSV file"""
|
| 31 |
-
stock_symbols = {}
|
| 32 |
try:
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
# Load
|
| 45 |
-
STOCK_SYMBOLS =
|
|
|
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
"شهر": "LOC",
|
| 55 |
-
"کشور": "LOC",
|
| 56 |
-
"استان": "LOC",
|
| 57 |
-
"آقای": "PER",
|
| 58 |
-
"خانم": "PER",
|
| 59 |
-
"دکتر": "PER",
|
| 60 |
-
"مهندس": "PER",
|
| 61 |
}
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
label_colors = {
|
| 65 |
-
"B-PER": "#FF6B6B",
|
| 66 |
"I-PER": "#FFB3B3",
|
| 67 |
-
"B-ORG": "#4ECDC4",
|
| 68 |
"I-ORG": "#A7E9E4",
|
| 69 |
-
"B-LOC": "#95E1D3",
|
| 70 |
"I-LOC": "#C7F0E8",
|
| 71 |
-
"B-DAT": "#FFA07A",
|
| 72 |
"I-DAT": "#FFDAB9",
|
| 73 |
-
"B-TIM": "#DDA0DD",
|
| 74 |
"I-TIM": "#E6D0E6",
|
| 75 |
-
"B-MON": "#FFD700",
|
| 76 |
"I-MON": "#FFEB99",
|
| 77 |
-
"B-PCT": "#87CEEB",
|
| 78 |
"I-PCT": "#B3DFEF",
|
| 79 |
-
"
|
| 80 |
-
"HYP": "#A9A9A9", # Hypernym - Dark Gray
|
| 81 |
}
|
| 82 |
|
| 83 |
label_names = {
|
|
@@ -88,98 +230,36 @@ label_names = {
|
|
| 88 |
"TIM": "زمان (Time)",
|
| 89 |
"MON": "پول (Money)",
|
| 90 |
"PCT": "درصد (Percent)",
|
| 91 |
-
"
|
| 92 |
-
"HYP": "واژه عمومی (Hypernym)",
|
| 93 |
}
|
| 94 |
|
| 95 |
-
def
|
| 96 |
-
"""
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# Split text into words
|
| 100 |
-
words = re.findall(r'[\u0600-\u06FF]+', text)
|
| 101 |
-
|
| 102 |
-
for word in words:
|
| 103 |
-
if word in STOCK_SYMBOLS:
|
| 104 |
-
# Find all occurrences of this symbol in the text
|
| 105 |
-
for match in re.finditer(re.escape(word), text):
|
| 106 |
-
stock_entities.append({
|
| 107 |
-
'entity_group': 'STK',
|
| 108 |
-
'word': word,
|
| 109 |
-
'start': match.start(),
|
| 110 |
-
'end': match.end(),
|
| 111 |
-
'score': 0.99, # High confidence for dictionary match
|
| 112 |
-
'full_name': STOCK_SYMBOLS[word]
|
| 113 |
-
})
|
| 114 |
-
|
| 115 |
-
return stock_entities
|
| 116 |
-
|
| 117 |
-
def detect_hypernyms(text, entities):
|
| 118 |
-
"""Detect hypernyms (general terms) in text and classify them"""
|
| 119 |
-
hypernym_entities = []
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
start, end = match.start(), match.end()
|
| 124 |
-
|
| 125 |
-
# Check if this position already has a specific entity
|
| 126 |
-
is_covered = False
|
| 127 |
-
for ent in entities:
|
| 128 |
-
if start >= ent['start'] and end <= ent['end']:
|
| 129 |
-
is_covered = True
|
| 130 |
-
break
|
| 131 |
-
|
| 132 |
-
if not is_covered:
|
| 133 |
-
hypernym_entities.append({
|
| 134 |
-
'entity_group': 'HYP',
|
| 135 |
-
'word': hypernym,
|
| 136 |
-
'start': start,
|
| 137 |
-
'end': end,
|
| 138 |
-
'score': 0.95,
|
| 139 |
-
'base_type': entity_type,
|
| 140 |
-
'is_hypernym': True
|
| 141 |
-
})
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
all_entities = entities + stock_entities + hypernym_entities
|
| 148 |
-
|
| 149 |
-
# Sort by start position
|
| 150 |
-
all_entities.sort(key=lambda x: x['start'])
|
| 151 |
-
|
| 152 |
-
# Remove overlapping entities (keep higher priority)
|
| 153 |
-
# Priority: STK > specific entities > HYP
|
| 154 |
-
filtered_entities = []
|
| 155 |
-
for entity in all_entities:
|
| 156 |
-
overlaps = False
|
| 157 |
-
for existing in filtered_entities:
|
| 158 |
# Check for overlap
|
| 159 |
-
if not (
|
| 160 |
-
|
| 161 |
-
# If new entity is stock symbol, replace existing
|
| 162 |
-
if entity['entity_group'] == 'STK' and existing['entity_group'] != 'STK':
|
| 163 |
-
filtered_entities.remove(existing)
|
| 164 |
-
overlaps = False
|
| 165 |
-
# If existing is hypernym and new is specific, replace
|
| 166 |
-
elif existing['entity_group'] == 'HYP' and entity['entity_group'] != 'HYP':
|
| 167 |
-
filtered_entities.remove(existing)
|
| 168 |
-
overlaps = False
|
| 169 |
break
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
filtered_entities.append(entity)
|
| 173 |
|
| 174 |
-
return
|
| 175 |
|
| 176 |
-
def highlight_entities(text,
|
| 177 |
"""Create HTML with highlighted entities"""
|
| 178 |
-
if not
|
| 179 |
return text
|
| 180 |
|
| 181 |
-
# Sort entities by start position (reverse order
|
| 182 |
-
entities_sorted = sorted(
|
| 183 |
|
| 184 |
result = text
|
| 185 |
for entity in entities_sorted:
|
|
@@ -190,177 +270,197 @@ def highlight_entities(text, entities):
|
|
| 190 |
score = entity['score']
|
| 191 |
|
| 192 |
# Get color for this label
|
| 193 |
-
if label == '
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
color = label_colors.get(f"B-{label}", "#CCCCCC")
|
| 203 |
-
title_text = f"{label} (confidence: {score:.2f})"
|
| 204 |
|
| 205 |
# Create highlighted span
|
| 206 |
-
highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{
|
| 207 |
|
| 208 |
result = result[:start] + highlighted + result[end:]
|
| 209 |
|
| 210 |
return result
|
| 211 |
|
| 212 |
def perform_ner(text):
|
| 213 |
-
"""Perform NER
|
| 214 |
if not text.strip():
|
| 215 |
return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
|
| 216 |
|
| 217 |
try:
|
| 218 |
-
# Perform
|
| 219 |
entities = ner_pipeline(text)
|
| 220 |
|
| 221 |
-
#
|
| 222 |
-
stock_entities =
|
| 223 |
|
| 224 |
-
#
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
# Merge all entities
|
| 228 |
-
all_entities = merge_entities(entities, stock_entities, hypernym_entities)
|
| 229 |
|
| 230 |
# Create highlighted version
|
| 231 |
-
highlighted_html = f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
# Create entities table
|
| 234 |
if all_entities:
|
| 235 |
entity_info = "### موجودیتهای شناسایی شده (Detected Entities):\n\n"
|
| 236 |
-
entity_info += "| کلمه (Word) | نوع (Type) |
|
| 237 |
-
entity_info += "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
for ent in all_entities:
|
| 239 |
label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} | {extra_info} |\n"
|
| 248 |
else:
|
| 249 |
entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
|
| 250 |
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
except Exception as e:
|
| 254 |
-
return f"<p style='color: red;'>خطا (Error): {str(e)}</p>",
|
| 255 |
-
|
| 256 |
-
# Save stock symbols to CSV function
|
| 257 |
-
def save_symbols_to_csv(output_path='symbols.csv'):
|
| 258 |
-
"""Save current stock symbols to CSV file"""
|
| 259 |
-
with open(output_path, 'w', encoding='utf-8', newline='') as f:
|
| 260 |
-
writer = csv.writer(f)
|
| 261 |
-
writer.writerow(['symbol', 'company_name'])
|
| 262 |
-
for symbol, name in STOCK_SYMBOLS.items():
|
| 263 |
-
writer.writerow([symbol, name])
|
| 264 |
-
print(f"Saved {len(STOCK_SYMBOLS)} symbols to {output_path}")
|
| 265 |
|
| 266 |
-
#
|
| 267 |
examples = [
|
| 268 |
-
["
|
| 269 |
-
["
|
| 270 |
-
["
|
| 271 |
-
["
|
| 272 |
-
["
|
| 273 |
-
["
|
| 274 |
-
["
|
| 275 |
-
["آقای احمدی مدیرعامل شرکت پتروشیمی است."],
|
| 276 |
-
["وبملت و فملی امروز در صف خرید قرار گرفتند."],
|
| 277 |
]
|
| 278 |
|
| 279 |
# Create Gradio interface
|
| 280 |
-
with gr.Blocks(
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
-
|
|
|
|
| 286 |
|
| 287 |
-
|
|
|
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
**Running on:** CPU (may be slow for long texts)
|
| 292 |
|
| 293 |
---
|
| 294 |
-
|
| 295 |
-
### 📊 APIs for Updating Stock Symbols:
|
| 296 |
-
|
| 297 |
-
**Recommended Python Libraries:**
|
| 298 |
-
1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
|
| 299 |
-
2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
|
| 300 |
-
3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
|
| 301 |
-
|
| 302 |
-
**Example Usage:**
|
| 303 |
-
```python
|
| 304 |
-
# Using tsetmc-api
|
| 305 |
-
from tsetmc_api import market_watch
|
| 306 |
-
stocks = market_watch.get_market_watch()
|
| 307 |
-
|
| 308 |
-
# Using tehran-stocks
|
| 309 |
-
from tehran_stocks import Stocks
|
| 310 |
-
all_stocks = Stocks.query.all()
|
| 311 |
-
```
|
| 312 |
-
|
| 313 |
-
**Official TSE Website:** https://tse.ir
|
| 314 |
-
**TSETMC Data Portal:** http://www.tsetmc.com
|
| 315 |
""")
|
| 316 |
|
| 317 |
with gr.Row():
|
| 318 |
-
with gr.Column():
|
| 319 |
input_text = gr.Textbox(
|
| 320 |
-
label="متن فارسی
|
| 321 |
-
placeholder="مثال: سهام فولاد
|
| 322 |
-
lines=
|
| 323 |
-
rtl=True
|
|
|
|
| 324 |
)
|
| 325 |
-
|
|
|
|
|
|
|
| 326 |
|
| 327 |
-
with gr.Column():
|
| 328 |
-
output_html = gr.HTML(
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
gr.Examples(
|
| 332 |
examples=examples,
|
| 333 |
inputs=input_text,
|
| 334 |
-
label="
|
|
|
|
| 335 |
)
|
| 336 |
|
| 337 |
-
#
|
| 338 |
-
gr.
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
-
# Event
|
| 358 |
submit_btn.click(
|
| 359 |
fn=perform_ner,
|
| 360 |
inputs=input_text,
|
| 361 |
outputs=[output_html, output_entities]
|
| 362 |
)
|
| 363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
input_text.submit(
|
| 365 |
fn=perform_ner,
|
| 366 |
inputs=input_text,
|
|
@@ -369,4 +469,10 @@ with gr.Blocks(title="Persian NER - شناسایی موجودیتهای نا
|
|
| 369 |
|
| 370 |
# Launch the app
|
| 371 |
if __name__ == "__main__":
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from transformers import (
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForTokenClassification,
|
| 5 |
+
AutoModelForSequenceClassification,
|
| 6 |
+
AutoModelForQuestionAnswering,
|
| 7 |
+
pipeline
|
| 8 |
+
)
|
| 9 |
import torch
|
| 10 |
+
import pandas as pd
|
| 11 |
import re
|
| 12 |
+
from typing import List, Dict, Tuple
|
| 13 |
+
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
# Set device
|
| 16 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 17 |
|
| 18 |
+
# Load the NER model and tokenizer
|
| 19 |
+
print("Loading NER model...")
|
| 20 |
+
ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
|
| 21 |
+
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
|
| 22 |
+
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
|
| 23 |
+
ner_model.to(device)
|
| 24 |
|
| 25 |
# Create NER pipeline
|
| 26 |
ner_pipeline = pipeline(
|
| 27 |
"ner",
|
| 28 |
+
model=ner_model,
|
| 29 |
+
tokenizer=ner_tokenizer,
|
| 30 |
+
device=0 if device == "cuda" else -1,
|
| 31 |
+
aggregation_strategy="simple"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Load Persian LLM for context understanding
|
| 35 |
+
print("Loading Persian context model...")
|
| 36 |
+
# Using a Persian BERT model fine-tuned for question answering/text classification
|
| 37 |
+
context_model_name = "persiannlp/mt5-small-parsinlu-multiple-choice" # Alternative model
|
| 38 |
+
# You can also try: "HooshvareLab/bert-fa-base-uncased" with custom classification head
|
| 39 |
+
|
| 40 |
+
# For stock symbol disambiguation, we'll use a QA approach
|
| 41 |
+
qa_model_name = "mohammadehab/persian-qa-bert"
|
| 42 |
+
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
|
| 43 |
+
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
|
| 44 |
+
qa_model.to(device)
|
| 45 |
+
|
| 46 |
+
qa_pipeline = pipeline(
|
| 47 |
+
"question-answering",
|
| 48 |
+
model=qa_model,
|
| 49 |
+
tokenizer=qa_tokenizer,
|
| 50 |
+
device=0 if device == "cuda" else -1
|
| 51 |
)
|
| 52 |
|
| 53 |
+
# Load stock symbols from CSV
|
| 54 |
+
def load_stock_symbols(csv_path="symbols.csv"):
|
| 55 |
+
"""Load Iranian stock market symbols from CSV file"""
|
|
|
|
| 56 |
try:
|
| 57 |
+
df = pd.read_csv(csv_path, encoding='utf-8')
|
| 58 |
+
symbols_dict = {}
|
| 59 |
+
for _, row in df.iterrows():
|
| 60 |
+
symbol = row['symbol']
|
| 61 |
+
symbols_dict[symbol] = {
|
| 62 |
+
'company': row['company_name'],
|
| 63 |
+
'bazaar': row['bazaar'],
|
| 64 |
+
'bazaar_group': row['bazaar_group']
|
| 65 |
+
}
|
| 66 |
+
return symbols_dict
|
| 67 |
+
except Exception as e:
|
| 68 |
+
print(f"Error loading symbols CSV: {e}")
|
| 69 |
+
# Provide some default symbols for demo
|
| 70 |
+
return {
|
| 71 |
+
'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
|
| 72 |
+
'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
|
| 73 |
+
'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'},
|
| 74 |
+
'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآوردههای نفتی'},
|
| 75 |
+
'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
|
| 76 |
+
'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
|
| 77 |
+
}
|
| 78 |
|
| 79 |
+
# Load symbols
|
| 80 |
+
STOCK_SYMBOLS = load_stock_symbols()
|
| 81 |
+
SYMBOL_NAMES = set(STOCK_SYMBOLS.keys())
|
| 82 |
|
| 83 |
+
# Market context keywords for Persian
|
| 84 |
+
MARKET_KEYWORDS = {
|
| 85 |
+
'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش',
|
| 86 |
+
'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایهگذاری', 'پرتفوی', 'نماد',
|
| 87 |
+
'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود',
|
| 88 |
+
'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت',
|
| 89 |
+
'کندل', 'نمودار', 'تحلیل', 'بنیادی', 'تکنیکال', 'حمایت', 'مقاومت'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
}
|
| 91 |
|
| 92 |
+
def use_llm_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
|
| 93 |
+
"""
|
| 94 |
+
Use Persian QA model to determine if a word is used as a stock symbol
|
| 95 |
+
Returns confidence score (0-1)
|
| 96 |
+
"""
|
| 97 |
+
try:
|
| 98 |
+
# Create context-aware questions in Persian
|
| 99 |
+
context = text
|
| 100 |
+
|
| 101 |
+
# Ask multiple questions to get better understanding
|
| 102 |
+
questions = [
|
| 103 |
+
f"آیا {potential_symbol} در این متن نماد بورسی است؟",
|
| 104 |
+
f"آیا منظور از {potential_symbol} سهام شرکت {symbol_info['company']} است؟",
|
| 105 |
+
f"آیا {potential_symbol} در این متن به معاملات بورس اشاره دارد؟"
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
scores = []
|
| 109 |
+
for question in questions:
|
| 110 |
+
try:
|
| 111 |
+
result = qa_pipeline(question=question, context=context)
|
| 112 |
+
# Check if the answer contains positive indicators
|
| 113 |
+
answer = result['answer'].lower()
|
| 114 |
+
confidence = result['score']
|
| 115 |
+
|
| 116 |
+
# Adjust confidence based on answer content
|
| 117 |
+
if any(word in answer for word in ['بله', 'است', 'میباشد', 'درست']):
|
| 118 |
+
scores.append(confidence)
|
| 119 |
+
elif any(word in answer for word in ['خیر', 'نیست', 'نمیباشد']):
|
| 120 |
+
scores.append(1 - confidence)
|
| 121 |
+
else:
|
| 122 |
+
scores.append(0.5) # Uncertain
|
| 123 |
+
|
| 124 |
+
except Exception:
|
| 125 |
+
scores.append(0.5) # Default to uncertain
|
| 126 |
+
|
| 127 |
+
# Return average confidence
|
| 128 |
+
return np.mean(scores) if scores else 0.5
|
| 129 |
+
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f"LLM disambiguation error: {e}")
|
| 132 |
+
return 0.5 # Return neutral confidence on error
|
| 133 |
+
|
| 134 |
+
def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
|
| 135 |
+
"""
|
| 136 |
+
Check if a potential symbol is actually used as a stock symbol in context
|
| 137 |
+
Returns (is_stock_symbol, confidence_score)
|
| 138 |
+
"""
|
| 139 |
+
# First, do a heuristic check
|
| 140 |
+
text_lower = text.lower()
|
| 141 |
+
|
| 142 |
+
# Get surrounding context (window of 30 characters before and after)
|
| 143 |
+
symbol_pos = text.find(potential_symbol)
|
| 144 |
+
if symbol_pos == -1:
|
| 145 |
+
return False, 0.0
|
| 146 |
+
|
| 147 |
+
start_context = max(0, symbol_pos - 50)
|
| 148 |
+
end_context = min(len(text), symbol_pos + len(potential_symbol) + 50)
|
| 149 |
+
context_window = text[start_context:end_context]
|
| 150 |
+
|
| 151 |
+
# Count market keywords in context
|
| 152 |
+
words_in_context = context_window.split()
|
| 153 |
+
market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS)
|
| 154 |
+
|
| 155 |
+
# Calculate heuristic score
|
| 156 |
+
heuristic_score = min(market_keyword_count * 0.3, 1.0)
|
| 157 |
+
|
| 158 |
+
# If very strong or very weak signal from heuristics, use that
|
| 159 |
+
if heuristic_score >= 0.9:
|
| 160 |
+
return True, heuristic_score
|
| 161 |
+
elif heuristic_score == 0 and len(words_in_context) > 5:
|
| 162 |
+
return False, 0.1
|
| 163 |
+
|
| 164 |
+
# For ambiguous cases, use LLM
|
| 165 |
+
llm_score = use_llm_for_disambiguation(text, potential_symbol, symbol_info)
|
| 166 |
+
|
| 167 |
+
# Combine heuristic and LLM scores
|
| 168 |
+
final_score = (heuristic_score * 0.4 + llm_score * 0.6)
|
| 169 |
+
|
| 170 |
+
# Decision threshold
|
| 171 |
+
is_stock = final_score > 0.5
|
| 172 |
+
|
| 173 |
+
return is_stock, final_score
|
| 174 |
+
|
| 175 |
+
def find_stock_symbols_in_text(text: str) -> List[Dict]:
|
| 176 |
+
"""Find and validate stock symbols in text using context analysis"""
|
| 177 |
+
found_symbols = []
|
| 178 |
+
|
| 179 |
+
# Use regex to find all potential symbols (Persian words)
|
| 180 |
+
# This pattern matches Persian words that might be symbols
|
| 181 |
+
pattern = r'\b[\u0600-\u06FF]+\b'
|
| 182 |
+
|
| 183 |
+
for match in re.finditer(pattern, text):
|
| 184 |
+
word = match.group()
|
| 185 |
+
|
| 186 |
+
if word in SYMBOL_NAMES:
|
| 187 |
+
symbol_info = STOCK_SYMBOLS[word]
|
| 188 |
+
|
| 189 |
+
# Check context to determine if it's actually used as a stock symbol
|
| 190 |
+
is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
|
| 191 |
+
|
| 192 |
+
if is_stock:
|
| 193 |
+
found_symbols.append({
|
| 194 |
+
'word': word,
|
| 195 |
+
'start': match.start(),
|
| 196 |
+
'end': match.end(),
|
| 197 |
+
'entity_group': 'STOCK',
|
| 198 |
+
'score': confidence,
|
| 199 |
+
'company': symbol_info['company'],
|
| 200 |
+
'bazaar': symbol_info['bazaar'],
|
| 201 |
+
'bazaar_group': symbol_info['bazaar_group']
|
| 202 |
+
})
|
| 203 |
+
|
| 204 |
+
return found_symbols
|
| 205 |
+
|
| 206 |
+
# Label colors and names
|
| 207 |
label_colors = {
|
| 208 |
+
"B-PER": "#FF6B6B", # Person - Red
|
| 209 |
"I-PER": "#FFB3B3",
|
| 210 |
+
"B-ORG": "#4ECDC4", # Organization - Teal
|
| 211 |
"I-ORG": "#A7E9E4",
|
| 212 |
+
"B-LOC": "#95E1D3", # Location - Green
|
| 213 |
"I-LOC": "#C7F0E8",
|
| 214 |
+
"B-DAT": "#FFA07A", # Date - Orange
|
| 215 |
"I-DAT": "#FFDAB9",
|
| 216 |
+
"B-TIM": "#DDA0DD", # Time - Purple
|
| 217 |
"I-TIM": "#E6D0E6",
|
| 218 |
+
"B-MON": "#FFD700", # Money - Gold
|
| 219 |
"I-MON": "#FFEB99",
|
| 220 |
+
"B-PCT": "#87CEEB", # Percent - Sky Blue
|
| 221 |
"I-PCT": "#B3DFEF",
|
| 222 |
+
"STOCK": "#00FA9A", # Stock Symbol - Medium Spring Green
|
|
|
|
| 223 |
}
|
| 224 |
|
| 225 |
label_names = {
|
|
|
|
| 230 |
"TIM": "زمان (Time)",
|
| 231 |
"MON": "پول (Money)",
|
| 232 |
"PCT": "درصد (Percent)",
|
| 233 |
+
"STOCK": "نماد بورسی (Stock Symbol)",
|
|
|
|
| 234 |
}
|
| 235 |
|
| 236 |
+
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
|
| 237 |
+
"""Merge entities, removing overlaps (stock symbols take precedence)"""
|
| 238 |
+
all_entities = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
+
# Add stock entities first (they have priority)
|
| 241 |
+
all_entities.extend(stock_entities)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
# Add NER entities that don't overlap with stock entities
|
| 244 |
+
for ner_ent in entities:
|
| 245 |
+
overlap = False
|
| 246 |
+
for stock_ent in stock_entities:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
# Check for overlap
|
| 248 |
+
if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']):
|
| 249 |
+
overlap = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
break
|
| 251 |
+
if not overlap:
|
| 252 |
+
all_entities.append(ner_ent)
|
|
|
|
| 253 |
|
| 254 |
+
return all_entities
|
| 255 |
|
| 256 |
+
def highlight_entities(text, all_entities):
|
| 257 |
"""Create HTML with highlighted entities"""
|
| 258 |
+
if not all_entities:
|
| 259 |
return text
|
| 260 |
|
| 261 |
+
# Sort entities by start position (reverse order)
|
| 262 |
+
entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True)
|
| 263 |
|
| 264 |
result = text
|
| 265 |
for entity in entities_sorted:
|
|
|
|
| 270 |
score = entity['score']
|
| 271 |
|
| 272 |
# Get color for this label
|
| 273 |
+
color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
|
| 274 |
+
|
| 275 |
+
# Add extra info for stock symbols
|
| 276 |
+
tooltip_info = f"{label} (confidence: {score:.2f})"
|
| 277 |
+
if label == 'STOCK':
|
| 278 |
+
company = entity.get('company', '')
|
| 279 |
+
bazaar = entity.get('bazaar', '')
|
| 280 |
+
if company:
|
| 281 |
+
tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# Create highlighted span
|
| 284 |
+
highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{tooltip_info}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
|
| 285 |
|
| 286 |
result = result[:start] + highlighted + result[end:]
|
| 287 |
|
| 288 |
return result
|
| 289 |
|
| 290 |
def perform_ner(text):
|
| 291 |
+
"""Perform integrated NER and stock symbol detection"""
|
| 292 |
if not text.strip():
|
| 293 |
return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
|
| 294 |
|
| 295 |
try:
|
| 296 |
+
# Perform standard NER
|
| 297 |
entities = ner_pipeline(text)
|
| 298 |
|
| 299 |
+
# Find stock symbols using Persian LLM
|
| 300 |
+
stock_entities = find_stock_symbols_in_text(text)
|
| 301 |
|
| 302 |
+
# Merge entities (remove overlaps)
|
| 303 |
+
all_entities = merge_overlapping_entities(entities, stock_entities)
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
# Create highlighted version
|
| 306 |
+
highlighted_html = f"""
|
| 307 |
+
<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
|
| 308 |
+
padding: 20px; border: 1px solid #ddd; border-radius: 5px;
|
| 309 |
+
background-color: #f9f9f9; font-family: Tahoma, Arial;'>
|
| 310 |
+
{highlight_entities(text, all_entities)}
|
| 311 |
+
</div>
|
| 312 |
+
"""
|
| 313 |
|
| 314 |
# Create entities table
|
| 315 |
if all_entities:
|
| 316 |
entity_info = "### موجودیتهای شناسایی شده (Detected Entities):\n\n"
|
| 317 |
+
entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n"
|
| 318 |
+
entity_info += "|:------------|:-----------|:------------------|:---------------------|\n"
|
| 319 |
+
|
| 320 |
+
# Sort by position in text
|
| 321 |
+
all_entities.sort(key=lambda x: x['start'])
|
| 322 |
+
|
| 323 |
for ent in all_entities:
|
| 324 |
label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
|
| 325 |
+
details = ""
|
| 326 |
+
if ent['entity_group'] == 'STOCK':
|
| 327 |
+
company = ent.get('company', '')
|
| 328 |
+
bazaar = ent.get('bazaar', '')
|
| 329 |
+
group = ent.get('bazaar_group', '')
|
| 330 |
+
details = f"{company}<br>{bazaar}<br>{group}"
|
| 331 |
+
entity_info += f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n"
|
|
|
|
| 332 |
else:
|
| 333 |
entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
|
| 334 |
|
| 335 |
+
# Add statistics
|
| 336 |
+
stats = f"\n\n### آمار (Statistics):\n"
|
| 337 |
+
stats += f"- تعداد کل موجودیتها: {len(all_entities)}\n"
|
| 338 |
+
stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n"
|
| 339 |
+
stats += f"- اشخاص: {len([e for e in all_entities if e['entity_group'] == 'PER'])}\n"
|
| 340 |
+
stats += f"- سازمانها: {len([e for e in all_entities if e['entity_group'] == 'ORG'])}\n"
|
| 341 |
+
stats += f"- مکانها: {len([e for e in all_entities if e['entity_group'] == 'LOC'])}\n"
|
| 342 |
+
|
| 343 |
+
return highlighted_html, entity_info + stats
|
| 344 |
|
| 345 |
except Exception as e:
|
| 346 |
+
return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
+
# Enhanced examples
|
| 349 |
examples = [
|
| 350 |
+
["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."],
|
| 351 |
+
["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."],
|
| 352 |
+
["صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است."],
|
| 353 |
+
["قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید."],
|
| 354 |
+
["بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد."],
|
| 355 |
+
["شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید میکند."],
|
| 356 |
+
["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."],
|
|
|
|
|
|
|
| 357 |
]
|
| 358 |
|
| 359 |
# Create Gradio interface
|
| 360 |
+
with gr.Blocks(
|
| 361 |
+
title="Persian NER + Stock Symbols | شناسایی موجودیتها و نمادهای بورسی",
|
| 362 |
+
theme=gr.themes.Soft(),
|
| 363 |
+
css="""
|
| 364 |
+
.rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
|
| 365 |
+
"""
|
| 366 |
+
) as demo:
|
| 367 |
+
gr.Markdown("""
|
| 368 |
+
# 🏦 Persian NER with Intelligent Stock Symbol Detection
|
| 369 |
+
# شناسایی هوشمند موجودیتهای نامدار و نمادهای بورس ایران
|
| 370 |
|
| 371 |
+
<div class="rtl-text">
|
| 372 |
+
این سیستم با است��اده از دو مدل هوش مصنوعی:
|
| 373 |
|
| 374 |
+
1. **ParsBERT-NER**: برای شناسایی موجودیتهای نامدار (اشخاص، سازمانها، مکانها، تاریخها)
|
| 375 |
+
2. **Persian BERT QA**: برای تشخیص هوشمند نمادهای بورسی با درک متن
|
| 376 |
|
| 377 |
+
ویژگی خاص: تشخیص هوشمند کلماتی مثل «فولاد» که میتواند نماد بورسی یا کلمه عادی باشد
|
| 378 |
+
</div>
|
|
|
|
| 379 |
|
| 380 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
""")
|
| 382 |
|
| 383 |
with gr.Row():
|
| 384 |
+
with gr.Column(scale=6):
|
| 385 |
input_text = gr.Textbox(
|
| 386 |
+
label="متن فارسی (Persian Text)",
|
| 387 |
+
placeholder="مثال: سهام فولاد در بورس تهران معامله میشود...",
|
| 388 |
+
lines=6,
|
| 389 |
+
rtl=True,
|
| 390 |
+
elem_classes=["rtl-text"]
|
| 391 |
)
|
| 392 |
+
with gr.Row():
|
| 393 |
+
submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2)
|
| 394 |
+
clear_btn = gr.Button("🗑️ پاک کردن", scale=1)
|
| 395 |
|
| 396 |
+
with gr.Column(scale=6):
|
| 397 |
+
output_html = gr.HTML(
|
| 398 |
+
label="نتیجه تحلیل (Analysis Result)",
|
| 399 |
+
elem_classes=["rtl-text"]
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
with gr.Row():
|
| 403 |
+
output_entities = gr.Markdown(
|
| 404 |
+
label="جدول موجودیتها (Entity Table)",
|
| 405 |
+
elem_classes=["rtl-text"]
|
| 406 |
+
)
|
| 407 |
|
| 408 |
gr.Examples(
|
| 409 |
examples=examples,
|
| 410 |
inputs=input_text,
|
| 411 |
+
label="نمونههای آماده (Ready Examples)",
|
| 412 |
+
examples_per_page=4
|
| 413 |
)
|
| 414 |
|
| 415 |
+
# Color guide
|
| 416 |
+
with gr.Accordion("📖 راهنمای رنگها (Color Guide)", open=False):
|
| 417 |
+
gr.Markdown("""
|
| 418 |
+
<div class="rtl-text">
|
| 419 |
+
|
| 420 |
+
| رنگ | نوع موجودیت | توضیحات |
|
| 421 |
+
|:---:|:------------|:--------|
|
| 422 |
+
| 🔴 | **PER** | اشخاص و نامهای افراد |
|
| 423 |
+
| 🔵 | **ORG** | سازمانها و شرکتها |
|
| 424 |
+
| 🟢 | **LOC** | مکانها و نامهای جغرافیایی |
|
| 425 |
+
| 🟠 | **DAT** | تاریخها |
|
| 426 |
+
| 🟣 | **TIM** | زمانها |
|
| 427 |
+
| 🟡 | **MON** | مقادیر پولی |
|
| 428 |
+
| 🔷 | **PCT** | درصدها |
|
| 429 |
+
| 💚 | **STOCK** | نمادهای بورسی (با تحلیل هوشمند متن) |
|
| 430 |
+
|
| 431 |
+
</div>
|
| 432 |
+
""")
|
| 433 |
|
| 434 |
+
# Info section
|
| 435 |
+
with gr.Accordion("ℹ️ درباره سیستم (About)", open=False):
|
| 436 |
+
gr.Markdown("""
|
| 437 |
+
<div class="rtl-text">
|
| 438 |
+
|
| 439 |
+
### قابلیتهای کلیدی:
|
| 440 |
+
|
| 441 |
+
- **تشخیص هوشمند نمادهای بورسی**: با استفاده از مدل زبانی فارسی، سیستم تشخیص میدهد که آیا کلمات مشابه نمادها (مثل فولاد، فارس) در متن به عنوان نماد بورسی استفاده شدهاند یا معنای عادی دارند
|
| 442 |
+
|
| 443 |
+
- **ترکیب دو مدل**: استفاده همزمان از ParsBERT-NER برای NER کلاسیک و Persian BERT برای درک متن
|
| 444 |
+
|
| 445 |
+
- **اطلاعات کامل نمادها**: نمایش نام شرکت، بازار، و گروه صنعت برای هر نماد شناسایی شده
|
| 446 |
+
|
| 447 |
+
- **دقت بالا**: با ترکیب تحلیلهای مبتنی بر قواعد و مدل زبانی
|
| 448 |
+
|
| 449 |
+
</div>
|
| 450 |
+
""")
|
| 451 |
|
| 452 |
+
# Event handlers
|
| 453 |
submit_btn.click(
|
| 454 |
fn=perform_ner,
|
| 455 |
inputs=input_text,
|
| 456 |
outputs=[output_html, output_entities]
|
| 457 |
)
|
| 458 |
|
| 459 |
+
clear_btn.click(
|
| 460 |
+
lambda: ("", "", ""),
|
| 461 |
+
outputs=[input_text, output_html, output_entities]
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
input_text.submit(
|
| 465 |
fn=perform_ner,
|
| 466 |
inputs=input_text,
|
|
|
|
| 469 |
|
| 470 |
# Launch the app
|
| 471 |
if __name__ == "__main__":
|
| 472 |
+
print("Starting Persian NER + Stock Symbol Detection System...")
|
| 473 |
+
print(f"Using device: {device}")
|
| 474 |
+
print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
|
| 475 |
+
demo.launch(
|
| 476 |
+
share=False,
|
| 477 |
+
debug=True
|
| 478 |
+
)
|