optimopium's picture
Update app.py
2b32561 verified
raw
history blame
22 kB
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
AutoModelForCausalLM,
pipeline
)
import torch
import pandas as pd
import re
from typing import List, Dict, Tuple
import numpy as np
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load the NER model and tokenizer
print("Loading NER model...")
ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model.to(device)
# Create NER pipeline
ner_pipeline = pipeline(
"ner",
model=ner_model,
tokenizer=ner_tokenizer,
device=0 if device == "cuda" else -1,
aggregation_strategy="simple"
)
# Load a smaller, open model for context understanding
print("Loading context understanding model...")
# Using Microsoft Phi-2 (small and efficient) or Mistral-7B-Instruct (if you have more resources)
context_model_name = "microsoft/phi-2" # 2.7B parameters, works well on CPU
try:
context_tokenizer = AutoTokenizer.from_pretrained(context_model_name, trust_remote_code=True)
context_model = AutoModelForCausalLM.from_pretrained(
context_model_name,
torch_dtype=dtype,
trust_remote_code=True,
device_map="auto" if torch.cuda.is_available() else None
)
if device == "cpu":
context_model = context_model.to(device)
# Set pad token if not set
if context_tokenizer.pad_token is None:
context_tokenizer.pad_token = context_tokenizer.eos_token
use_llm_model = True
print(f"Successfully loaded {context_model_name}")
except Exception as e:
print(f"Could not load Phi-2 model: {e}")
print("Falling back to zero-shot classification model...")
# Fallback to mDeBERTa for zero-shot classification
context_model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
classifier = pipeline(
"zero-shot-classification",
model=context_model_name,
device=0 if device == "cuda" else -1
)
use_llm_model = False
# Load stock symbols from CSV
def load_stock_symbols(csv_path="symbols.csv"):
"""Load Iranian stock market symbols from CSV file"""
try:
df = pd.read_csv(csv_path, encoding='utf-8')
symbols_dict = {}
for _, row in df.iterrows():
symbol = row['symbol']
symbols_dict[symbol] = {
'company': row['company_name'],
'bazaar': row['bazaar'],
'bazaar_group': row['bazaar_group']
}
return symbols_dict
except Exception as e:
print(f"Error loading symbols CSV: {e}")
# Provide default symbols for demo
return {
'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'},
'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'},
'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
'وتوسکا': {'company': 'سرمایه گذاری توسعه توکا', 'bazaar': 'بورس', 'bazaar_group': 'سرمایه گذاریها'},
'پی پاد': {'company': 'پرداخت الکترونیک پاسارگاد', 'bazaar': 'بورس', 'bazaar_group': 'رایانه و فعالیت های وابسته'},
}
# Load symbols
STOCK_SYMBOLS = load_stock_symbols()
SYMBOL_NAMES = set(STOCK_SYMBOLS.keys())
# Market context keywords
MARKET_KEYWORDS = {
'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش',
'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد',
'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود',
'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت'
}
def use_phi_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
"""
Use Phi-2 model to determine if a word is used as a stock symbol
Returns confidence score (0-1)
"""
if not use_llm_model:
# Use zero-shot classification instead
return use_zero_shot_classification(text, potential_symbol, symbol_info)
try:
# Create a simple prompt for Phi-2
prompt = f"""Analyze this Persian text and determine if "{potential_symbol}" is used as a stock market symbol.
Context: "{potential_symbol}" could be a stock symbol for {symbol_info['company']} company.
Text: {text}
Answer with only "STOCK" if it's a stock symbol, or "WORD" if it's a regular word:
Answer: """
# Tokenize and generate
inputs = context_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = context_model.generate(
**inputs,
max_new_tokens=10,
temperature=0.1,
do_sample=False,
pad_token_id=context_tokenizer.eos_token_id
)
# Decode the response
response = context_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
response = response.strip().upper()
# Parse response
if "STOCK" in response:
return 0.9
elif "WORD" in response:
return 0.1
else:
# Ambiguous response, use heuristics
return 0.5
except Exception as e:
print(f"Phi-2 inference error: {e}")
return 0.5
def use_zero_shot_classification(text: str, potential_symbol: str, symbol_info: Dict) -> float:
"""
Fallback: Use zero-shot classification to determine if a word is a stock symbol
"""
try:
# Get context around the symbol
symbol_pos = text.find(potential_symbol)
if symbol_pos == -1:
return 0.5
start = max(0, symbol_pos - 100)
end = min(len(text), symbol_pos + len(potential_symbol) + 100)
context_text = text[start:end]
# Define candidate labels
candidate_labels = [
f"نماد بورسی {symbol_info['company']}",
f"کلمه عادی {potential_symbol}",
"stock market symbol",
"regular word"
]
# Perform classification
result = classifier(
context_text,
candidate_labels=candidate_labels,
multi_label=False
)
# Check top label
top_label = result['labels'][0]
top_score = result['scores'][0]
if 'نماد بورسی' in top_label or 'stock' in top_label.lower():
return top_score
else:
return 1 - top_score
except Exception as e:
print(f"Classification error: {e}")
return 0.5
def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]:
"""
Check if a potential symbol is actually used as a stock symbol in context
"""
# Get surrounding context
symbol_pos = text.find(potential_symbol)
if symbol_pos == -1:
return False, 0.0
start_context = max(0, symbol_pos - 100)
end_context = min(len(text), symbol_pos + len(potential_symbol) + 100)
context_window = text[start_context:end_context]
# Count market keywords
words_in_context = context_window.split()
market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS)
# Calculate heuristic score
heuristic_score = min(market_keyword_count * 0.2, 1.0)
# Strong heuristic signals
if market_keyword_count >= 5:
return True, 0.95
elif market_keyword_count == 0 and len(words_in_context) > 10:
return False, 0.05
# Use AI model for disambiguation
if use_llm_model:
ai_score = use_phi_for_disambiguation(context_window, potential_symbol, symbol_info)
else:
ai_score = use_zero_shot_classification(context_window, potential_symbol, symbol_info)
# Combine scores
final_score = (heuristic_score * 0.3 + ai_score * 0.7)
# Decision threshold
is_stock = final_score > 0.5
return is_stock, final_score
def find_stock_symbols_in_text(text: str) -> List[Dict]:
"""Find and validate stock symbols in text"""
found_symbols = []
processed_positions = set()
# Pattern to match Persian/Arabic words
pattern = r'\b[\u0600-\u06FF]+\b'
for match in re.finditer(pattern, text):
word = match.group()
if word in SYMBOL_NAMES and match.start() not in processed_positions:
symbol_info = STOCK_SYMBOLS[word]
# Check context
is_stock, confidence = check_stock_symbol_context(text, word, symbol_info)
if is_stock:
found_symbols.append({
'word': word,
'start': match.start(),
'end': match.end(),
'entity_group': 'STOCK',
'score': confidence,
'company': symbol_info['company'],
'bazaar': symbol_info['bazaar'],
'bazaar_group': symbol_info['bazaar_group']
})
processed_positions.add(match.start())
return found_symbols
# Label colors and names
label_colors = {
"B-PER": "#FF6B6B",
"I-PER": "#FFB3B3",
"B-ORG": "#4ECDC4",
"I-ORG": "#A7E9E4",
"B-LOC": "#95E1D3",
"I-LOC": "#C7F0E8",
"B-DAT": "#FFA07A",
"I-DAT": "#FFDAB9",
"B-TIM": "#DDA0DD",
"I-TIM": "#E6D0E6",
"B-MON": "#FFD700",
"I-MON": "#FFEB99",
"B-PCT": "#87CEEB",
"I-PCT": "#B3DFEF",
"STOCK": "#00FA9A",
}
label_names = {
"PER": "شخص (Person)",
"ORG": "سازمان (Organization)",
"LOC": "مکان (Location)",
"DAT": "تاریخ (Date)",
"TIM": "زمان (Time)",
"MON": "پول (Money)",
"PCT": "درصد (Percent)",
"STOCK": "نماد بورسی (Stock Symbol)",
}
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
"""Merge entities, removing overlaps"""
all_entities = []
all_entities.extend(stock_entities)
for ner_ent in entities:
overlap = False
for stock_ent in stock_entities:
if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']):
overlap = True
break
if not overlap:
all_entities.append(ner_ent)
return all_entities
def highlight_entities(text, all_entities):
"""Create HTML with highlighted entities"""
if not all_entities:
return text
entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True)
result = text
for entity in entities_sorted:
start = entity['start']
end = entity['end']
label = entity['entity_group']
word = text[start:end]
score = entity['score']
color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")
tooltip_info = f"{label} (confidence: {score:.2f})"
if label == 'STOCK':
company = entity.get('company', '')
bazaar = entity.get('bazaar', '')
if company:
tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"
highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{tooltip_info}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
result = result[:start] + highlighted + result[end:]
return result
def perform_ner(text):
"""Perform integrated NER and stock symbol detection"""
if not text.strip():
return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
try:
# Perform standard NER
entities = ner_pipeline(text)
# Find stock symbols
stock_entities = find_stock_symbols_in_text(text)
# Merge entities
all_entities = merge_overlapping_entities(entities, stock_entities)
# Create highlighted HTML
highlighted_html = f"""
<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
padding: 20px; border: 1px solid #ddd; border-radius: 5px;
background-color: #f9f9f9; font-family: Tahoma, Arial;'>
{highlight_entities(text, all_entities)}
</div>
"""
# Create entities table
if all_entities:
entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n"
entity_info += "|:------------|:-----------|:------------------|:---------------------|\n"
all_entities.sort(key=lambda x: x['start'])
for ent in all_entities:
label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
details = ""
if ent['entity_group'] == 'STOCK':
company = ent.get('company', '')
bazaar = ent.get('bazaar', '')
group = ent.get('bazaar_group', '')
details = f"{company}<br>{bazaar}<br>{group}"
entity_info += f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n"
else:
entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
# Statistics
stats = f"\n\n### آمار (Statistics):\n"
stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n"
stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n"
stats += f"- اشخاص: {len([e for e in all_entities if e['entity_group'] == 'PER'])}\n"
stats += f"- سازمان‌ها: {len([e for e in all_entities if e['entity_group'] == 'ORG'])}\n"
stats += f"- مکان‌ها: {len([e for e in all_entities if e['entity_group'] == 'LOC'])}\n"
return highlighted_html, entity_info + stats
except Exception as e:
return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", str(e)
# Examples
examples = [
["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."],
["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."],
["صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است."],
["قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید."],
["بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد."],
["شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید می‌کند."],
["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."],
]
# Gradio interface
with gr.Blocks(
title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی",
theme=gr.themes.Soft(),
css="""
.rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
"""
) as demo:
gr.Markdown(f"""
# 🏦 شناسایی هوشمند موجودیت‌ها و نمادهای بورس ایران
## Persian Named Entity Recognition with Stock Symbol Detection
### Using {context_model_name.split('/')[-1]} for Context Understanding
<div class="rtl-text">
این برنامه متن‌های فارسی را تحلیل کرده و موجودیت‌های مختلف را شناسایی می‌کند.
</div>
""")
with gr.Row():
with gr.Column(scale=6):
input_text = gr.Textbox(
label="متن فارسی را وارد کنید (Enter Persian Text)",
placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...",
lines=6,
rtl=True,
elem_classes=["rtl-text"]
)
with gr.Row():
submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2)
clear_btn = gr.Button("🗑️ پاک کردن", scale=1)
with gr.Column(scale=6):
output_html = gr.HTML(
label="نتیجه تحلیل (Analysis Result)",
elem_classes=["rtl-text"]
)
with gr.Row():
output_entities = gr.Markdown(
label="جدول موجودیت‌ها (Entity Table)",
elem_classes=["rtl-text"]
)
gr.Examples(
examples=examples,
inputs=input_text,
label="نمونه‌های آماده (Ready Examples)",
examples_per_page=4
)
# User guide
with gr.Accordion("📖 راهنمای استفاده (User Guide)", open=True):
gr.Markdown("""
<div class="rtl-text">
## چگونه از این برنامه استفاده کنیم؟
1. **متن فارسی خود را در کادر بالا وارد کنید**
2. **دکمه «تحلیل متن» را بزنید**
3. **نتایج را در دو بخش مشاهده کنید:**
- متن با موجودیت‌های رنگی شده
- جدول کامل موجودیت‌ها با جزئیات
## انواع موجودیت‌هایی که شناسایی می‌شوند:
| رنگ | نوع | مثال |
|:---:|:----|:-----|
| 🔴 قرمز | **اشخاص** | علی احمدی، مریم رضایی |
| 🔵 آبی | **سازمان‌ها** | شرکت ملی نفت، بانک ملت |
| 🟢 سبز | **مکان‌ها** | تهران، اصفهان، ایران |
| 🟠 نارنجی | **تاریخ‌ها** | ۱۵ خرداد ۱۴۰۳ |
| 🟣 بنفش | **زمان‌ها** | ساعت ۱۰ صبح |
| 🟡 زرد | **مبالغ پولی** | ۱۰۰۰ ریال، ۵ میلیارد تومان |
| 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ |
| 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا |
## ویژگی خاص: تشخیص هوشمند نمادهای بورسی
برنامه با استفاده از **هوش مصنوعی** تشخیص می‌دهد که آیا یک کلمه نماد بورسی است یا خیر.
**مثال:**
- «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅
- «صنعت **فولاد** در کشور مهم است» ← فولاد = کلمه عادی ❌
## نحوه تفسیر نتایج:
- **رنگ‌ها**: نوع موجودیت را نشان می‌دهند
- **برچسب‌ها**: نوع موجودیت به صورت مختصر
- **درصد اطمینان**: میزان اطمینان سیستم (۰-۱۰۰٪)
- **جزئیات نمادها**: نام شرکت، بازار و گروه صنعت
## مدل‌های استفاده شده:
- **ParsBERT NER**: شناسایی موجودیت‌های عمومی
- **Microsoft Phi-2 / mDeBERTa**: تحلیل هوشمند متن برای تشخیص نمادهای بورسی
</div>
""")
# Event handlers
submit_btn.click(
fn=perform_ner,
inputs=input_text,
outputs=[output_html, output_entities]
)
clear_btn.click(
lambda: ("", "", ""),
outputs=[input_text, output_html, output_entities]
)
input_text.submit(
fn=perform_ner,
inputs=input_text,
outputs=[output_html, output_entities]
)
# Launch
if __name__ == "__main__":
print("Starting Persian NER + Stock Symbol Detection System...")
print(f"Using device: {device}")
print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols")
print("Models loaded:")
print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased")
print(f" - Context: {context_model_name}")
demo.launch(
share=False,
debug=True
)