optimopium's picture
Update app.py
c4885a5 verified
raw
history blame
19.9 kB
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForSequenceClassification,
AutoModelForQuestionAnswering,
pipeline
)
import torch
import pandas as pd
import re
from typing import List, Dict, Tuple
import numpy as np
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the NER model and tokenizer
print("Loading NER model...")
ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model.to(device)
# Create NER pipeline
ner_pipeline = pipeline(
"ner",
model=ner_model,
tokenizer=ner_tokenizer,
device=0 if device == "cuda" else -1,
aggregation_strategy="simple"
)
# Load Persian LLM for context understanding
print("Loading Persian context model...")
# Using a Persian BERT model fine-tuned for question answering/text classification
context_model_name = "persiannlp/mt5-small-parsinlu-multiple-choice" # Alternative model
# You can also try: "HooshvareLab/bert-fa-base-uncased" with custom classification head
# For stock symbol disambiguation, we'll use a QA approach
qa_model_name = "mohammadehab/persian-qa-bert"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_model.to(device)
qa_pipeline = pipeline(
"question-answering",
model=qa_model,
tokenizer=qa_tokenizer,
device=0 if device == "cuda" else -1
)
# Load stock symbols from CSV
def load_stock_symbols(csv_path="symbols.csv"):
"""Load Iranian stock market symbols from CSV file"""
try:
df = pd.read_csv(csv_path, encoding='utf-8')
symbols_dict = {}
for _, row in df.iterrows():
symbol = row['symbol']
symbols_dict[symbol] = {
'company': row['company_name'],
'bazaar': row['bazaar'],
'bazaar_group': row['bazaar_group']
}
return symbols_dict
except Exception as e:
print(f"Error loading symbols CSV: {e}")
# Provide some default symbols for demo
return {
'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'},
'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'},
'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
}
# Load symbols
STOCK_SYMBOLS = load_stock_symbols()
SYMBOL_NAMES = set(STOCK_SYMBOLS.keys())
# Market context keywords for Persian
MARKET_KEYWORDS = {
'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش',
'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد',
'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود',
'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت',
'کندل', 'نمودار', 'تحلیل', 'بنیادی', 'تکنیکال', 'حمایت', 'مقاومت'
}
def use_llm_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
"""
Use Persian QA model to determine if a word is used as a stock symbol
Returns confidence score (0-1)
"""
try:
# Create context-aware questions in Persian
context = text
# Ask multiple questions to get better understanding
questions = [
f"آیا {potential_symbol} در این متن نماد بورسی است؟",
f"آیا منظور از {potential_symbol} سهام شرکت {symbol_info['company']} است؟",
f"آیا {potential_symbol} در این متن به معاملات بورس اشاره دارد؟"
]
scores = []
for question in questions:
try:
result = qa_pipeline(question=question, context=context)
# Check if the answer contains positive indicators
answer = result['answer'].lower()
confidence = result['score']
# Adjust confidence based on answer content
if any(word in answer for word in ['بله', 'است', 'می‌باشد', 'درست']):
scores.append(confidence)
elif any(word in answer for word in ['خیر', 'نیست', 'نمی‌باشد']):
scores.append(1 - confidence)
else:
scores.append(0.5) # Uncertain
except Exception:
scores.append(0.5) # Default to uncertain
# Return average confidence
return np.mean(scores) if scores else 0.5
except Exception as e:
print(f"LLM disambiguation error: {e}")
return 0.5 # Return neutral confidence on error
def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
"""
Check if a potential symbol is actually used as a stock symbol in context
Returns (is_stock_symbol, confidence_score)
"""
# First, do a heuristic check
text_lower = text.lower()
# Get surrounding context (window of 30 characters before and after)
symbol_pos = text.find(potential_symbol)
if symbol_pos == -1:
return False, 0.0
start_context = max(0, symbol_pos - 50)
end_context = min(len(text), symbol_pos + len(potential_symbol) + 50)
context_window = text[start_context:end_context]
# Count market keywords in context
words_in_context = context_window.split()
market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS)
# Calculate heuristic score
heuristic_score = min(market_keyword_count * 0.3, 1.0)
# If very strong or very weak signal from heuristics, use that
if heuristic_score >= 0.9:
return True, heuristic_score
elif heuristic_score == 0 and len(words_in_context) > 5:
return False, 0.1
# For ambiguous cases, use LLM
llm_score = use_llm_for_disambiguation(text, potential_symbol, symbol_info)
# Combine heuristic and LLM scores
final_score = (heuristic_score * 0.4 + llm_score * 0.6)
# Decision threshold
is_stock = final_score > 0.5
return is_stock, final_score
def find_stock_symbols_in_text(text: str) -> List[Dict]:
"""Find and validate stock symbols in text using context analysis"""
found_symbols = []
# Use regex to find all potential symbols (Persian words)
# This pattern matches Persian words that might be symbols
pattern = r'\b[\u0600-\u06FF]+\b'
for match in re.finditer(pattern, text):
word = match.group()
if word in SYMBOL_NAMES:
symbol_info = STOCK_SYMBOLS[word]
# Check context to determine if it's actually used as a stock symbol
is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
if is_stock:
found_symbols.append({
'word': word,
'start': match.start(),
'end': match.end(),
'entity_group': 'STOCK',
'score': confidence,
'company': symbol_info['company'],
'bazaar': symbol_info['bazaar'],
'bazaar_group': symbol_info['bazaar_group']
})
return found_symbols
# Label colors and names
label_colors = {
"B-PER": "#FF6B6B", # Person - Red
"I-PER": "#FFB3B3",
"B-ORG": "#4ECDC4", # Organization - Teal
"I-ORG": "#A7E9E4",
"B-LOC": "#95E1D3", # Location - Green
"I-LOC": "#C7F0E8",
"B-DAT": "#FFA07A", # Date - Orange
"I-DAT": "#FFDAB9",
"B-TIM": "#DDA0DD", # Time - Purple
"I-TIM": "#E6D0E6",
"B-MON": "#FFD700", # Money - Gold
"I-MON": "#FFEB99",
"B-PCT": "#87CEEB", # Percent - Sky Blue
"I-PCT": "#B3DFEF",
"STOCK": "#00FA9A", # Stock Symbol - Medium Spring Green
}
label_names = {
"PER": "شخص (Person)",
"ORG": "سازمان (Organization)",
"LOC": "مکان (Location)",
"DAT": "تاریخ (Date)",
"TIM": "زمان (Time)",
"MON": "پول (Money)",
"PCT": "درصد (Percent)",
"STOCK": "نماد بورسی (Stock Symbol)",
}
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
"""Merge entities, removing overlaps (stock symbols take precedence)"""
all_entities = []
# Add stock entities first (they have priority)
all_entities.extend(stock_entities)
# Add NER entities that don't overlap with stock entities
for ner_ent in entities:
overlap = False
for stock_ent in stock_entities:
# Check for overlap
if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']):
overlap = True
break
if not overlap:
all_entities.append(ner_ent)
return all_entities
def highlight_entities(text, all_entities):
"""Create HTML with highlighted entities"""
if not all_entities:
return text
# Sort entities by start position (reverse order)
entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True)
result = text
for entity in entities_sorted:
start = entity['start']
end = entity['end']
label = entity['entity_group']
word = text[start:end]
score = entity['score']
# Get color for this label
color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
# Add extra info for stock symbols
tooltip_info = f"{label} (confidence: {score:.2f})"
if label == 'STOCK':
company = entity.get('company', '')
bazaar = entity.get('bazaar', '')
if company:
tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
# Create highlighted span
highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{tooltip_info}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
result = result[:start] + highlighted + result[end:]
return result
def perform_ner(text):
"""Perform integrated NER and stock symbol detection"""
if not text.strip():
return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
try:
# Perform standard NER
entities = ner_pipeline(text)
# Find stock symbols using Persian LLM
stock_entities = find_stock_symbols_in_text(text)
# Merge entities (remove overlaps)
all_entities = merge_overlapping_entities(entities, stock_entities)
# Create highlighted version
highlighted_html = f"""
<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
padding: 20px; border: 1px solid #ddd; border-radius: 5px;
background-color: #f9f9f9; font-family: Tahoma, Arial;'>
{highlight_entities(text, all_entities)}
</div>
"""
# Create entities table
if all_entities:
entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n"
entity_info += "|:------------|:-----------|:------------------|:---------------------|\n"
# Sort by position in text
all_entities.sort(key=lambda x: x['start'])
for ent in all_entities:
label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
details = ""
if ent['entity_group'] == 'STOCK':
company = ent.get('company', '')
bazaar = ent.get('bazaar', '')
group = ent.get('bazaar_group', '')
details = f"{company}<br>{bazaar}<br>{group}"
entity_info += f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n"
else:
entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
# Add statistics
stats = f"\n\n### آمار (Statistics):\n"
stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n"
stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n"
stats += f"- اشخاص: {len([e for e in all_entities if e['entity_group'] == 'PER'])}\n"
stats += f"- سازمان‌ها: {len([e for e in all_entities if e['entity_group'] == 'ORG'])}\n"
stats += f"- مکان‌ها: {len([e for e in all_entities if e['entity_group'] == 'LOC'])}\n"
return highlighted_html, entity_info + stats
except Exception as e:
return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", str(e)
# Enhanced examples
examples = [
["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."],
["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."],
["صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است."],
["قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید."],
["بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد."],
["شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید می‌کند."],
["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."],
]
# Create Gradio interface
with gr.Blocks(
title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی",
theme=gr.themes.Soft(),
css="""
.rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
"""
) as demo:
gr.Markdown("""
# 🏦 Persian NER with Intelligent Stock Symbol Detection
# شناسایی هوشمند موجودیت‌های نامدار و نمادهای بورس ایران
<div class="rtl-text">
این سیستم با استفاده از دو مدل هوش مصنوعی:
1. **ParsBERT-NER**: برای شناسایی موجودیت‌های نامدار (اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها)
2. **Persian BERT QA**: برای تشخیص هوشمند نمادهای بورسی با درک متن
ویژگی خاص: تشخیص هوشمند کلماتی مثل «فولاد» که می‌تواند نماد بورسی یا کلمه عادی باشد
</div>
---
""")
with gr.Row():
with gr.Column(scale=6):
input_text = gr.Textbox(
label="متن فارسی (Persian Text)",
placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...",
lines=6,
rtl=True,
elem_classes=["rtl-text"]
)
with gr.Row():
submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2)
clear_btn = gr.Button("🗑️ پاک کردن", scale=1)
with gr.Column(scale=6):
output_html = gr.HTML(
label="نتیجه تحلیل (Analysis Result)",
elem_classes=["rtl-text"]
)
with gr.Row():
output_entities = gr.Markdown(
label="جدول موجودیت‌ها (Entity Table)",
elem_classes=["rtl-text"]
)
gr.Examples(
examples=examples,
inputs=input_text,
label="نمونه‌های آماده (Ready Examples)",
examples_per_page=4
)
# Color guide
with gr.Accordion("📖 راهنمای رنگ‌ها (Color Guide)", open=False):
gr.Markdown("""
<div class="rtl-text">
| رنگ | نوع موجودیت | توضیحات |
|:---:|:------------|:--------|
| 🔴 | **PER** | اشخاص و نام‌های افراد |
| 🔵 | **ORG** | سازمان‌ها و شرکت‌ها |
| 🟢 | **LOC** | مکان‌ها و نام‌های جغرافیایی |
| 🟠 | **DAT** | تاریخ‌ها |
| 🟣 | **TIM** | زمان‌ها |
| 🟡 | **MON** | مقادیر پولی |
| 🔷 | **PCT** | درصدها |
| 💚 | **STOCK** | نمادهای بورسی (با تحلیل هوشمند متن) |
</div>
""")
# Info section
with gr.Accordion("ℹ️ درباره سیستم (About)", open=False):
gr.Markdown("""
<div class="rtl-text">
### قابلیت‌های کلیدی:
- **تشخیص هوشمند نمادهای بورسی**: با استفاده از مدل زبانی فارسی، سیستم تشخیص می‌دهد که آیا کلمات مشابه نمادها (مثل فولاد، فارس) در متن به عنوان نماد بورسی استفاده شده‌اند یا معنای عادی دارند
- **ترکیب دو مدل**: استفاده همزمان از ParsBERT-NER برای NER کلاسیک و Persian BERT برای درک متن
- **اطلاعات کامل نمادها**: نمایش نام شرکت، بازار، و گروه صنعت برای هر نماد شناسایی شده
- **دقت بالا**: با ترکیب تحلیل‌های مبتنی بر قواعد و مدل زبانی
</div>
""")
# Event handlers
submit_btn.click(
fn=perform_ner,
inputs=input_text,
outputs=[output_html, output_entities]
)
clear_btn.click(
lambda: ("", "", ""),
outputs=[input_text, output_html, output_entities]
)
input_text.submit(
fn=perform_ner,
inputs=input_text,
outputs=[output_html, output_entities]
)
# Launch the app
if __name__ == "__main__":
print("Starting Persian NER + Stock Symbol Detection System...")
print(f"Using device: {device}")
print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
demo.launch(
share=False,
debug=True
)