Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForTokenClassification, | |
| AutoModelForSequenceClassification, | |
| AutoModelForQuestionAnswering, | |
| pipeline | |
| ) | |
| import torch | |
| import pandas as pd | |
| import re | |
| from typing import List, Dict, Tuple | |
| import numpy as np | |
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Both pipelines below take the same device index: 0 on CUDA, -1 otherwise.
_pipeline_device = 0 if device == "cuda" else -1

# ---- Named-entity recognition model ----
print("Loading NER model...")
ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_model.to(device)
# "simple" aggregation makes the pipeline return grouped entities
# (the rest of the file reads their 'entity_group', 'start' and 'end').
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=_pipeline_device,
    aggregation_strategy="simple",
)

# ---- Persian context model ----
print("Loading Persian context model...")
# Alternative model name; it is only recorded here and not loaded below.
context_model_name = "persiannlp/mt5-small-parsinlu-multiple-choice"
# Another option: "HooshvareLab/bert-fa-base-uncased" with a custom
# classification head.
# Stock-symbol disambiguation is done with a question-answering model.
qa_model_name = "mohammadehab/persian-qa-bert"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_model.to(device)
qa_pipeline = pipeline(
    "question-answering",
    model=qa_model,
    tokenizer=qa_tokenizer,
    device=_pipeline_device,
)
| # Load stock symbols from CSV | |
def load_stock_symbols(csv_path="symbols.csv"):
    """Load Iranian stock market symbols from a CSV file.

    The CSV is expected to have the columns ``symbol``, ``company_name``,
    ``bazaar`` and ``bazaar_group``. Every cell is read as text, surrounding
    whitespace is stripped, rows without a symbol are skipped and blank
    descriptive cells become empty strings (instead of float NaN, which
    would later be rendered as "nan").

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A dict mapping each symbol to a dict with the keys ``company``,
        ``bazaar`` and ``bazaar_group``. If the file cannot be read or
        parsed (missing file, missing column, ...), the error is printed
        and a small built-in demo table is returned instead.
    """
    columns = ['symbol', 'company_name', 'bazaar', 'bazaar_group']
    try:
        # dtype=str keeps numeric-looking cells as text; utf-8-sig also
        # accepts files that start with a byte-order mark.
        df = pd.read_csv(csv_path, encoding='utf-8-sig', dtype=str)
        # Selecting the columns raises KeyError when one is missing, which
        # sends us to the fallback table below.
        table = df[columns].dropna(subset=['symbol']).fillna('')
        symbols_dict = {}
        for symbol, company, bazaar, group in table.itertuples(index=False, name=None):
            symbol = symbol.strip()
            if not symbol:
                continue
            symbols_dict[symbol] = {
                'company': company.strip(),
                'bazaar': bazaar.strip(),
                'bazaar_group': group.strip(),
            }
        return symbols_dict
    except Exception as e:
        # Deliberate best-effort: the app must still start without the CSV.
        print(f"Error loading symbols CSV: {e}")
        # Provide some default symbols for demo
        return {
            'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'},
            'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'},
            'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'},
            'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآوردههای نفتی'},
            'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'},
            'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'},
        }
# Symbol table read once at import time, plus the set of its keys.
STOCK_SYMBOLS = load_stock_symbols()
SYMBOL_NAMES = set(STOCK_SYMBOLS)

# Persian words whose presence near a candidate word counts as
# stock-market context.
MARKET_KEYWORDS = {
    'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات',
    'خرید', 'فروش', 'قیمت', 'ارزش', 'بازار', 'سرمایه',
    'سرمایهگذاری', 'پرتفوی', 'نماد', 'شاخص', 'حجم', 'عرضه',
    'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود', 'زیان',
    'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد',
    'افت', 'کندل', 'نمودار', 'تحلیل', 'بنیادی', 'تکنیکال',
    'حمایت', 'مقاومت',
}
def _score_qa_answer(answer: str, confidence: float) -> float:
    """Turn one QA answer span into an 'is a stock symbol' score in [0, 1]."""
    positive_words = {'بله', 'است', 'میباشد', 'درست'}
    negative_words = {'خیر', 'نیست', 'نمیباشد'}
    # Remove zero-width non-joiners so both spellings of a word compare equal.
    normalized = answer.lower().replace('\u200c', '')
    # Compare whole words: with substring matching the negative 'نمیباشد'
    # was counted as positive because it contains 'میباشد'.
    tokens = {tok.strip('.,;:!?()[]"\'«»،؛؟') for tok in normalized.split()}
    if tokens & positive_words:
        return confidence
    if tokens & negative_words:
        return 1 - confidence
    return 0.5  # Uncertain


def use_llm_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float:
    """Use the Persian QA pipeline to judge whether a word is a stock symbol.

    Three Persian yes/no style questions about ``potential_symbol`` are asked
    against ``text``; each answer is mapped to a score and the scores are
    averaged.

    Args:
        text: The full input text, used as the QA context.
        potential_symbol: The candidate word found in ``text``.
        symbol_info: Symbol metadata; its ``'company'`` entry is used in one
            of the questions.

    Returns:
        A builtin ``float`` confidence in [0, 1]. A question that fails
        contributes 0.5, and any other failure (for example a missing
        ``'company'`` key) yields 0.5 overall; failures are printed.
    """
    try:
        # Ask several questions to get a more stable signal.
        questions = [
            f"آیا {potential_symbol} در این متن نماد بورسی است؟",
            f"آیا منظور از {potential_symbol} سهام شرکت {symbol_info['company']} است؟",
            f"آیا {potential_symbol} در این متن به معاملات بورس اشاره دارد؟"
        ]
        scores = []
        for question in questions:
            try:
                result = qa_pipeline(question=question, context=text)
                scores.append(_score_qa_answer(result['answer'], result['score']))
            except Exception as e:
                # Report the failure instead of hiding it, then stay neutral.
                print(f"LLM disambiguation error: {e}")
                scores.append(0.5)
        # float(...) so callers get a builtin float, not numpy.float64.
        return float(np.mean(scores)) if scores else 0.5
    except Exception as e:
        print(f"LLM disambiguation error: {e}")
        return 0.5  # Return neutral confidence on error
# Characters of context taken on each side of the candidate word.
_CONTEXT_WINDOW_CHARS = 50
# Number of market keywords in the window that counts as a strong signal.
_STRONG_SIGNAL_KEYWORDS = 3
# Punctuation removed from the ends of a context word before keyword lookup.
_WORD_PUNCTUATION = '.,;:!?()[]{}"\'«»،؛؟'


def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict,
                               position=None, market_keywords=None) -> Tuple[bool, float]:
    """Decide whether one occurrence of a word is used as a stock symbol.

    A keyword heuristic runs first; only ambiguous cases call
    ``use_llm_for_disambiguation``.

    Args:
        text: The full input text.
        potential_symbol: The candidate word.
        symbol_info: Symbol metadata, forwarded to the LLM check.
        position: Optional start index of the occurrence to judge. When it
            is omitted, or does not point at ``potential_symbol``, the first
            occurrence found by ``str.find`` is used.
        market_keywords: Optional keyword collection; defaults to the
            module-level ``MARKET_KEYWORDS``.

    Returns:
        ``(is_stock_symbol, confidence_score)``; ``(False, 0.0)`` when the
        word does not occur in ``text``.
    """
    if market_keywords is None:
        market_keywords = MARKET_KEYWORDS

    symbol_len = len(potential_symbol)
    if (position is None or position < 0
            or text[position:position + symbol_len] != potential_symbol):
        position = text.find(potential_symbol)
    if position == -1:
        return False, 0.0

    # Context window around this occurrence.
    window_start = max(0, position - _CONTEXT_WINDOW_CHARS)
    window_end = min(len(text), position + symbol_len + _CONTEXT_WINDOW_CHARS)
    words_in_context = text[window_start:window_end].split()

    # Strip punctuation so a keyword followed by a comma or period still counts.
    market_keyword_count = sum(
        1 for word in words_in_context
        if word.strip(_WORD_PUNCTUATION) in market_keywords
    )
    heuristic_score = min(market_keyword_count * 0.3, 1.0)

    # Decide on the count, not the float: 3 * 0.3 == 0.8999999999999999, so a
    # ">= 0.9" comparison would miss the three-keyword case.
    if market_keyword_count >= _STRONG_SIGNAL_KEYWORDS:
        return True, heuristic_score
    if market_keyword_count == 0 and len(words_in_context) > 5:
        return False, 0.1

    # Ambiguous: blend the heuristic with the LLM score.
    llm_score = use_llm_for_disambiguation(text, potential_symbol, symbol_info)
    final_score = float(heuristic_score * 0.4 + llm_score * 0.6)
    return final_score > 0.5, final_score


def find_stock_symbols_in_text(text: str, symbols=None, market_keywords=None) -> List[Dict]:
    """Find words in ``text`` that are known symbols and used as such.

    Each Persian word that is a key of the symbol table is checked with
    ``check_stock_symbol_context`` at its own position, so repeated
    occurrences are judged on their own surrounding context.

    Args:
        text: The input text.
        symbols: Optional symbol table (symbol -> info dict); defaults to
            the module-level ``STOCK_SYMBOLS``.
        market_keywords: Optional keyword collection forwarded to
            ``check_stock_symbol_context``.

    Returns:
        A list of entity dicts with the keys ``word``, ``start``, ``end``,
        ``entity_group`` (always ``'STOCK'``), ``score``, ``company``,
        ``bazaar`` and ``bazaar_group``.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS

    found_symbols = []
    # Runs of characters from the Arabic Unicode block, on word boundaries.
    pattern = r'\b[\u0600-\u06FF]+\b'
    for match in re.finditer(pattern, text):
        word = match.group()
        if word not in symbols:
            continue
        symbol_info = symbols[word]
        is_stock, confidence = check_stock_symbol_context(
            text, word, symbol_info,
            position=match.start(), market_keywords=market_keywords,
        )
        if is_stock:
            found_symbols.append({
                'word': word,
                'start': match.start(),
                'end': match.end(),
                'entity_group': 'STOCK',
                'score': confidence,
                'company': symbol_info['company'],
                'bazaar': symbol_info['bazaar'],
                'bazaar_group': symbol_info['bazaar_group']
            })
    return found_symbols
# Highlight colour per label: B-/I- variants for the NER tags, plus STOCK.
label_colors = {
    # Person - red
    "B-PER": "#FF6B6B", "I-PER": "#FFB3B3",
    # Organization - teal
    "B-ORG": "#4ECDC4", "I-ORG": "#A7E9E4",
    # Location - green
    "B-LOC": "#95E1D3", "I-LOC": "#C7F0E8",
    # Date - orange
    "B-DAT": "#FFA07A", "I-DAT": "#FFDAB9",
    # Time - purple
    "B-TIM": "#DDA0DD", "I-TIM": "#E6D0E6",
    # Money - gold
    "B-MON": "#FFD700", "I-MON": "#FFEB99",
    # Percent - sky blue
    "B-PCT": "#87CEEB", "I-PCT": "#B3DFEF",
    # Stock symbol - medium spring green
    "STOCK": "#00FA9A",
}

# Bilingual display name for each entity group.
label_names = {
    "PER": "شخص (Person)",
    "ORG": "سازمان (Organization)",
    "LOC": "مکان (Location)",
    "DAT": "تاریخ (Date)",
    "TIM": "زمان (Time)",
    "MON": "پول (Money)",
    "PCT": "درصد (Percent)",
    "STOCK": "نماد بورسی (Stock Symbol)",
}
def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]:
    """Combine NER and stock entities, letting stock symbols win overlaps.

    Returns a new list holding every stock entity first, followed by each
    NER entity whose ``[start, end)`` span touches no stock entity span.
    """
    def clashes_with_stock(candidate: Dict) -> bool:
        # Two half-open spans intersect when each one starts before the
        # other one ends.
        return any(
            candidate['end'] > stock['start'] and candidate['start'] < stock['end']
            for stock in stock_entities
        )

    merged = list(stock_entities)
    merged.extend(ent for ent in entities if not clashes_with_stock(ent))
    return merged
def highlight_entities(text, all_entities):
    """Return ``text`` as HTML with each entity wrapped in a coloured span.

    The text comes from the user, so every piece taken from it (plain
    segments, entity words, tooltip and label) is HTML-escaped before it is
    placed in the markup.

    Args:
        text: The original input text.
        all_entities: Entity dicts with ``start``, ``end``, ``entity_group``
            and ``score``; STOCK entities may also carry ``company`` and
            ``bazaar`` for the tooltip.

    Returns:
        An HTML fragment. With no entities this is just the escaped text.
        An entity that starts inside a previously emitted one is skipped so
        the spans never interleave.
    """
    # Function-scope import keeps this block self-contained.
    import html

    if not all_entities:
        return html.escape(text)

    pieces = []
    cursor = 0  # index in `text` up to which output has been produced
    for entity in sorted(all_entities, key=lambda ent: ent['start']):
        start = entity['start']
        end = entity['end']
        if start < cursor:
            continue
        label = entity['entity_group']
        score = entity['score']

        # NER groups are coloured via their "B-" entry; STOCK has its own key.
        color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC")

        # Stock symbols get company/market details in the tooltip.
        tooltip_info = f"{label} (confidence: {score:.2f})"
        if label == 'STOCK':
            company = entity.get('company', '')
            bazaar = entity.get('bazaar', '')
            if company:
                tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})"

        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" '
            f'title="{html.escape(tooltip_info)}">{html.escape(text[start:end])} '
            f'<sup style="font-size: 0.7em; font-weight: bold;">[{html.escape(label)}]</sup></span>'
        )
        cursor = end

    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
def perform_ner(text):
    """Run NER plus stock-symbol detection and format the result for the UI.

    Args:
        text: The user-supplied text; ``None`` is treated like an empty
            string.

    Returns:
        ``(html, markdown)``: the highlighted text as an HTML fragment and a
        Markdown entity table followed by statistics. Blank input yields a
        red prompt and an empty string; any failure yields a red error
        message (HTML-escaped) and the raw error text.
    """
    # Function-scope imports keep this block self-contained.
    import html
    from collections import Counter

    if not text or not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
    try:
        # Standard NER, then context-validated stock symbols.
        entities = ner_pipeline(text)
        stock_entities = find_stock_symbols_in_text(text)
        # Stock symbols take precedence over overlapping NER entities.
        all_entities = merge_overlapping_entities(entities, stock_entities)

        highlighted_html = f"""
        <div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2.5;
        padding: 20px; border: 1px solid #ddd; border-radius: 5px;
        background-color: #f9f9f9; font-family: Tahoma, Arial;'>
        {highlight_entities(text, all_entities)}
        </div>
        """

        if all_entities:
            rows = [
                "### موجودیتهای شناسایی شده (Detected Entities):\n\n",
                "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n",
                "|:------------|:-----------|:------------------|:---------------------|\n",
            ]
            # List the entities in reading order.
            all_entities.sort(key=lambda x: x['start'])
            for ent in all_entities:
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                details = ""
                if ent['entity_group'] == 'STOCK':
                    company = ent.get('company', '')
                    bazaar = ent.get('bazaar', '')
                    group = ent.get('bazaar_group', '')
                    details = f"{company}<br>{bazaar}<br>{group}"
                rows.append(f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n")
            entity_info = "".join(rows)
        else:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"

        # One pass over the entities instead of one scan per category.
        counts = Counter(e['entity_group'] for e in all_entities)
        stats = f"\n\n### آمار (Statistics):\n"
        stats += f"- تعداد کل موجودیتها: {len(all_entities)}\n"
        stats += f"- نمادهای بورسی: {counts['STOCK']}\n"
        stats += f"- اشخاص: {counts['PER']}\n"
        stats += f"- سازمانها: {counts['ORG']}\n"
        stats += f"- مکانها: {counts['LOC']}\n"
        return highlighted_html, entity_info + stats
    except Exception as e:
        # The message lands in an HTML component, so escape it.
        return f"<p style='color: red;'>خطا (Error): {html.escape(str(e))}</p>", str(e)
# Sample inputs for the UI; each row holds the value of the single text input.
examples = [
    [sentence]
    for sentence in (
        "علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد.",
        "سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند.",
        "صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است.",
        "قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید.",
        "بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد.",
        "شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید میکند.",
        "من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم.",
    )
]
# Gradio UI: input column, result column, entity table, examples, help panels.
# NOTE(review): the nesting of the layout containers was reconstructed from a
# flattened source - confirm against the running app.
with gr.Blocks(
    title="Persian NER + Stock Symbols | شناسایی موجودیتها و نمادهای بورسی",
    theme=gr.themes.Soft(),
    css="""
    .rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; }
    """
) as demo:
    # Page header
    gr.Markdown("""
    # 🏦 Persian NER with Intelligent Stock Symbol Detection
    # شناسایی هوشمند موجودیتهای نامدار و نمادهای بورس ایران
    <div class="rtl-text">
    این سیستم با استفاده از دو مدل هوش مصنوعی:
    1. **ParsBERT-NER**: برای شناسایی موجودیتهای نامدار (اشخاص، سازمانها، مکانها، تاریخها)
    2. **Persian BERT QA**: برای تشخیص هوشمند نمادهای بورسی با درک متن
    ویژگی خاص: تشخیص هوشمند کلماتی مثل «فولاد» که میتواند نماد بورسی یا کلمه عادی باشد
    </div>
    ---
    """)

    with gr.Row():
        # Left column: text input and action buttons
        with gr.Column(scale=6):
            input_text = gr.Textbox(
                label="متن فارسی (Persian Text)",
                placeholder="مثال: سهام فولاد در بورس تهران معامله میشود...",
                lines=6,
                rtl=True,
                elem_classes=["rtl-text"]
            )
            with gr.Row():
                submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2)
                clear_btn = gr.Button("🗑️ پاک کردن", scale=1)
        # Right column: highlighted text
        with gr.Column(scale=6):
            output_html = gr.HTML(
                label="نتیجه تحلیل (Analysis Result)",
                elem_classes=["rtl-text"]
            )

    # Entity table and statistics
    with gr.Row():
        output_entities = gr.Markdown(
            label="جدول موجودیتها (Entity Table)",
            elem_classes=["rtl-text"]
        )

    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="نمونههای آماده (Ready Examples)",
        examples_per_page=4
    )

    # Collapsible colour legend
    with gr.Accordion("📖 راهنمای رنگها (Color Guide)", open=False):
        gr.Markdown("""
        <div class="rtl-text">
        | رنگ | نوع موجودیت | توضیحات |
        |:---:|:------------|:--------|
        | 🔴 | **PER** | اشخاص و نامهای افراد |
        | 🔵 | **ORG** | سازمانها و شرکتها |
        | 🟢 | **LOC** | مکانها و نامهای جغرافیایی |
        | 🟠 | **DAT** | تاریخها |
        | 🟣 | **TIM** | زمانها |
        | 🟡 | **MON** | مقادیر پولی |
        | 🔷 | **PCT** | درصدها |
        | 💚 | **STOCK** | نمادهای بورسی (با تحلیل هوشمند متن) |
        </div>
        """)

    # Collapsible "about" panel
    with gr.Accordion("ℹ️ درباره سیستم (About)", open=False):
        gr.Markdown("""
        <div class="rtl-text">
        ### قابلیتهای کلیدی:
        - **تشخیص هوشمند نمادهای بورسی**: با استفاده از مدل زبانی فارسی، سیستم تشخیص میدهد که آیا کلمات مشابه نمادها (مثل فولاد، فارس) در متن به عنوان نماد بورسی استفاده شدهاند یا معنای عادی دارند
        - **ترکیب دو مدل**: استفاده همزمان از ParsBERT-NER برای NER کلاسیک و Persian BERT برای درک متن
        - **اطلاعات کامل نمادها**: نمایش نام شرکت، بازار، و گروه صنعت برای هر نماد شناسایی شده
        - **دقت بالا**: با ترکیب تحلیلهای مبتنی بر قواعد و مدل زبانی
        </div>
        """)

    # Wiring: the button and pressing Enter both run the analysis; the clear
    # button empties the input and both outputs.
    submit_btn.click(fn=perform_ner, inputs=input_text, outputs=[output_html, output_entities])
    clear_btn.click(lambda: ("", "", ""), outputs=[input_text, output_html, output_entities])
    input_text.submit(fn=perform_ner, inputs=input_text, outputs=[output_html, output_entities])
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    for startup_line in (
        "Starting Persian NER + Stock Symbol Detection System...",
        f"Using device: {device}",
        f"Loaded {len(STOCK_SYMBOLS)} stock symbols",
    ):
        print(startup_line)
    demo.launch(share=False, debug=True)