import gradio as gr from transformers import ( AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, pipeline ) import torch import pandas as pd import re from typing import List, Dict, Tuple import numpy as np # Set device device = "cuda" if torch.cuda.is_available() else "cpu" # Load the NER model and tokenizer print("Loading NER model...") ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased" ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name) ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name) ner_model.to(device) # Create NER pipeline ner_pipeline = pipeline( "ner", model=ner_model, tokenizer=ner_tokenizer, device=0 if device == "cuda" else -1, aggregation_strategy="simple" ) # Load Persian LLM for context understanding print("Loading Persian context model...") # Using a Persian BERT model fine-tuned for question answering/text classification context_model_name = "persiannlp/mt5-small-parsinlu-multiple-choice" # Alternative model # You can also try: "HooshvareLab/bert-fa-base-uncased" with custom classification head # For stock symbol disambiguation, we'll use a QA approach qa_model_name = "mohammadehab/persian-qa-bert" qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name) qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name) qa_model.to(device) qa_pipeline = pipeline( "question-answering", model=qa_model, tokenizer=qa_tokenizer, device=0 if device == "cuda" else -1 ) # Load stock symbols from CSV def load_stock_symbols(csv_path="symbols.csv"): """Load Iranian stock market symbols from CSV file""" try: df = pd.read_csv(csv_path, encoding='utf-8') symbols_dict = {} for _, row in df.iterrows(): symbol = row['symbol'] symbols_dict[symbol] = { 'company': row['company_name'], 'bazaar': row['bazaar'], 'bazaar_group': row['bazaar_group'] } return symbols_dict except Exception as e: print(f"Error loading symbols CSV: {e}") # Provide some default symbols for demo return { 'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'}, 'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'}, 'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'}, 'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'}, 'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'}, 'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'}, } # Load symbols STOCK_SYMBOLS = load_stock_symbols() SYMBOL_NAMES = set(STOCK_SYMBOLS.keys()) # Market context keywords for Persian MARKET_KEYWORDS = { 'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش', 'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد', 'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود', 'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت', 'کندل', 'نمودار', 'تحلیل', 'بنیادی', 'تکنیکال', 'حمایت', 'مقاومت' } def use_llm_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float: """ Use Persian QA model to determine if a word is used as a stock symbol Returns confidence score (0-1) """ try: # Create context-aware questions in Persian context = text # Ask multiple questions to get better understanding questions = [ f"آیا {potential_symbol} در این متن نماد بورسی است؟", f"آیا منظور از {potential_symbol} سهام شرکت {symbol_info['company']} است؟", f"آیا {potential_symbol} در این متن به معاملات بورس اشاره دارد؟" ] scores = [] for question in questions: try: result = qa_pipeline(question=question, context=context) # Check if the answer contains positive indicators answer = result['answer'].lower() confidence = result['score'] # Adjust confidence based on answer content if any(word in answer for word in ['بله', 'است', 'می‌باشد', 'درست']): scores.append(confidence) elif any(word in answer for word in ['خیر', 'نیست', 'نمی‌باشد']): scores.append(1 - confidence) else: scores.append(0.5) # Uncertain except Exception: scores.append(0.5) # Default to uncertain # Return average confidence return np.mean(scores) if scores else 0.5 except Exception as e: print(f"LLM disambiguation error: {e}") return 0.5 # Return neutral confidence on error def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]: """ Check if a potential symbol is actually used as a stock symbol in context Returns (is_stock_symbol, confidence_score) """ # First, do a heuristic check text_lower = text.lower() # Get surrounding context (window of 30 characters before and after) symbol_pos = text.find(potential_symbol) if symbol_pos == -1: return False, 0.0 start_context = max(0, symbol_pos - 50) end_context = min(len(text), symbol_pos + len(potential_symbol) + 50) context_window = text[start_context:end_context] # Count market keywords in context words_in_context = context_window.split() market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS) # Calculate heuristic score heuristic_score = min(market_keyword_count * 0.3, 1.0) # If very strong or very weak signal from heuristics, use that if heuristic_score >= 0.9: return True, heuristic_score elif heuristic_score == 0 and len(words_in_context) > 5: return False, 0.1 # For ambiguous cases, use LLM llm_score = use_llm_for_disambiguation(text, potential_symbol, symbol_info) # Combine heuristic and LLM scores final_score = (heuristic_score * 0.4 + llm_score * 0.6) # Decision threshold is_stock = final_score > 0.5 return is_stock, final_score def find_stock_symbols_in_text(text: str) -> List[Dict]: """Find and validate stock symbols in text using context analysis""" found_symbols = [] # Use regex to find all potential symbols (Persian words) # This pattern matches Persian words that might be symbols pattern = r'\b[\u0600-\u06FF]+\b' for match in re.finditer(pattern, text): word = match.group() if word in SYMBOL_NAMES: symbol_info = STOCK_SYMBOLS[word] # Check context to determine if it's actually used as a stock symbol is_stock, confidence = check_stock_symbol_context(text, word, symbol_info) if is_stock: found_symbols.append({ 'word': word, 'start': match.start(), 'end': match.end(), 'entity_group': 'STOCK', 'score': confidence, 'company': symbol_info['company'], 'bazaar': symbol_info['bazaar'], 'bazaar_group': symbol_info['bazaar_group'] }) return found_symbols # Label colors and names label_colors = { "B-PER": "#FF6B6B", # Person - Red "I-PER": "#FFB3B3", "B-ORG": "#4ECDC4", # Organization - Teal "I-ORG": "#A7E9E4", "B-LOC": "#95E1D3", # Location - Green "I-LOC": "#C7F0E8", "B-DAT": "#FFA07A", # Date - Orange "I-DAT": "#FFDAB9", "B-TIM": "#DDA0DD", # Time - Purple "I-TIM": "#E6D0E6", "B-MON": "#FFD700", # Money - Gold "I-MON": "#FFEB99", "B-PCT": "#87CEEB", # Percent - Sky Blue "I-PCT": "#B3DFEF", "STOCK": "#00FA9A", # Stock Symbol - Medium Spring Green } label_names = { "PER": "شخص (Person)", "ORG": "سازمان (Organization)", "LOC": "مکان (Location)", "DAT": "تاریخ (Date)", "TIM": "زمان (Time)", "MON": "پول (Money)", "PCT": "درصد (Percent)", "STOCK": "نماد بورسی (Stock Symbol)", } def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]: """Merge entities, removing overlaps (stock symbols take precedence)""" all_entities = [] # Add stock entities first (they have priority) all_entities.extend(stock_entities) # Add NER entities that don't overlap with stock entities for ner_ent in entities: overlap = False for stock_ent in stock_entities: # Check for overlap if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']): overlap = True break if not overlap: all_entities.append(ner_ent) return all_entities def highlight_entities(text, all_entities): """Create HTML with highlighted entities""" if not all_entities: return text # Sort entities by start position (reverse order) entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True) result = text for entity in entities_sorted: start = entity['start'] end = entity['end'] label = entity['entity_group'] word = text[start:end] score = entity['score'] # Get color for this label color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC") # Add extra info for stock symbols tooltip_info = f"{label} (confidence: {score:.2f})" if label == 'STOCK': company = entity.get('company', '') bazaar = entity.get('bazaar', '') if company: tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})" # Create highlighted span highlighted = f'{word} [{label}]' result = result[:start] + highlighted + result[end:] return result def perform_ner(text): """Perform integrated NER and stock symbol detection""" if not text.strip(): return "

لطفا متن فارسی وارد کنید (Please enter Persian text)

", "" try: # Perform standard NER entities = ner_pipeline(text) # Find stock symbols using Persian LLM stock_entities = find_stock_symbols_in_text(text) # Merge entities (remove overlaps) all_entities = merge_overlapping_entities(entities, stock_entities) # Create highlighted version highlighted_html = f"""
{highlight_entities(text, all_entities)}
""" # Create entities table if all_entities: entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n" entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n" entity_info += "|:------------|:-----------|:------------------|:---------------------|\n" # Sort by position in text all_entities.sort(key=lambda x: x['start']) for ent in all_entities: label_fa = label_names.get(ent['entity_group'], ent['entity_group']) details = "" if ent['entity_group'] == 'STOCK': company = ent.get('company', '') bazaar = ent.get('bazaar', '') group = ent.get('bazaar_group', '') details = f"{company}
{bazaar}
{group}" entity_info += f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n" else: entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)" # Add statistics stats = f"\n\n### آمار (Statistics):\n" stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n" stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n" stats += f"- اشخاص: {len([e for e in all_entities if e['entity_group'] == 'PER'])}\n" stats += f"- سازمان‌ها: {len([e for e in all_entities if e['entity_group'] == 'ORG'])}\n" stats += f"- مکان‌ها: {len([e for e in all_entities if e['entity_group'] == 'LOC'])}\n" return highlighted_html, entity_info + stats except Exception as e: return f"

خطا (Error): {str(e)}

", str(e) # Enhanced examples examples = [ ["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."], ["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."], ["صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است."], ["قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید."], ["بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد."], ["شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید می‌کند."], ["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."], ] # Create Gradio interface with gr.Blocks( title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی", theme=gr.themes.Soft(), css=""" .rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; } """ ) as demo: gr.Markdown(""" # 🏦 Persian NER with Intelligent Stock Symbol Detection # شناسایی هوشمند موجودیت‌های نامدار و نمادهای بورس ایران
این سیستم با استفاده از دو مدل هوش مصنوعی: 1. **ParsBERT-NER**: برای شناسایی موجودیت‌های نامدار (اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها) 2. **Persian BERT QA**: برای تشخیص هوشمند نمادهای بورسی با درک متن ویژگی خاص: تشخیص هوشمند کلماتی مثل «فولاد» که می‌تواند نماد بورسی یا کلمه عادی باشد
--- """) with gr.Row(): with gr.Column(scale=6): input_text = gr.Textbox( label="متن فارسی (Persian Text)", placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...", lines=6, rtl=True, elem_classes=["rtl-text"] ) with gr.Row(): submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2) clear_btn = gr.Button("🗑️ پاک کردن", scale=1) with gr.Column(scale=6): output_html = gr.HTML( label="نتیجه تحلیل (Analysis Result)", elem_classes=["rtl-text"] ) with gr.Row(): output_entities = gr.Markdown( label="جدول موجودیت‌ها (Entity Table)", elem_classes=["rtl-text"] ) gr.Examples( examples=examples, inputs=input_text, label="نمونه‌های آماده (Ready Examples)", examples_per_page=4 ) # Color guide with gr.Accordion("📖 راهنمای رنگ‌ها (Color Guide)", open=False): gr.Markdown("""
| رنگ | نوع موجودیت | توضیحات | |:---:|:------------|:--------| | 🔴 | **PER** | اشخاص و نام‌های افراد | | 🔵 | **ORG** | سازمان‌ها و شرکت‌ها | | 🟢 | **LOC** | مکان‌ها و نام‌های جغرافیایی | | 🟠 | **DAT** | تاریخ‌ها | | 🟣 | **TIM** | زمان‌ها | | 🟡 | **MON** | مقادیر پولی | | 🔷 | **PCT** | درصدها | | 💚 | **STOCK** | نمادهای بورسی (با تحلیل هوشمند متن) |
""") # Info section with gr.Accordion("ℹ️ درباره سیستم (About)", open=False): gr.Markdown("""
### قابلیت‌های کلیدی: - **تشخیص هوشمند نمادهای بورسی**: با استفاده از مدل زبانی فارسی، سیستم تشخیص می‌دهد که آیا کلمات مشابه نمادها (مثل فولاد، فارس) در متن به عنوان نماد بورسی استفاده شده‌اند یا معنای عادی دارند - **ترکیب دو مدل**: استفاده همزمان از ParsBERT-NER برای NER کلاسیک و Persian BERT برای درک متن - **اطلاعات کامل نمادها**: نمایش نام شرکت، بازار، و گروه صنعت برای هر نماد شناسایی شده - **دقت بالا**: با ترکیب تحلیل‌های مبتنی بر قواعد و مدل زبانی
""") # Event handlers submit_btn.click( fn=perform_ner, inputs=input_text, outputs=[output_html, output_entities] ) clear_btn.click( lambda: ("", "", ""), outputs=[input_text, output_html, output_entities] ) input_text.submit( fn=perform_ner, inputs=input_text, outputs=[output_html, output_entities] ) # Launch the app if __name__ == "__main__": print("Starting Persian NER + Stock Symbol Detection System...") print(f"Using device: {device}") print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols") demo.launch( share=False, debug=True )