import gradio as gr from transformers import ( AutoTokenizer, AutoModelForTokenClassification, AutoModelForCausalLM, pipeline ) import torch import pandas as pd import re from typing import List, Dict, Tuple import numpy as np # Set device device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Load the NER model and tokenizer print("Loading NER model...") ner_model_name = "HooshvareLab/bert-base-parsbert-ner-uncased" ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name) ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name) ner_model.to(device) # Create NER pipeline ner_pipeline = pipeline( "ner", model=ner_model, tokenizer=ner_tokenizer, device=0 if device == "cuda" else -1, aggregation_strategy="simple" ) # Load a smaller, open model for context understanding print("Loading context understanding model...") # Using Microsoft Phi-2 (small and efficient) or Mistral-7B-Instruct (if you have more resources) context_model_name = "microsoft/phi-2" # 2.7B parameters, works well on CPU try: context_tokenizer = AutoTokenizer.from_pretrained(context_model_name, trust_remote_code=True) context_model = AutoModelForCausalLM.from_pretrained( context_model_name, torch_dtype=dtype, trust_remote_code=True, device_map="auto" if torch.cuda.is_available() else None ) if device == "cpu": context_model = context_model.to(device) # Set pad token if not set if context_tokenizer.pad_token is None: context_tokenizer.pad_token = context_tokenizer.eos_token use_llm_model = True print(f"Successfully loaded {context_model_name}") except Exception as e: print(f"Could not load Phi-2 model: {e}") print("Falling back to zero-shot classification model...") # Fallback to mDeBERTa for zero-shot classification context_model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli" classifier = pipeline( "zero-shot-classification", model=context_model_name, device=0 if device == "cuda" else -1 ) use_llm_model = False # Load stock symbols from CSV def load_stock_symbols(csv_path="symbols.csv"): """Load Iranian stock market symbols from CSV file""" try: df = pd.read_csv(csv_path, encoding='utf-8') symbols_dict = {} for _, row in df.iterrows(): symbol = row['symbol'] symbols_dict[symbol] = { 'company': row['company_name'], 'bazaar': row['bazaar'], 'bazaar_group': row['bazaar_group'] } return symbols_dict except Exception as e: print(f"Error loading symbols CSV: {e}") # Provide default symbols for demo return { 'وبصادر': {'company': 'بانک صادرات ایران', 'bazaar': 'بورس - بازار دوم', 'bazaar_group': 'بانکها و موسسات اعتباری'}, 'فولاد': {'company': 'فولاد مبارکه اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فلزات اساسی'}, 'فارس': {'company': 'پتروشیمی فارس', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'محصولات شیمیایی'}, 'شپنا': {'company': 'پالایش نفت اصفهان', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'فرآورده‌های نفتی'}, 'خودرو': {'company': 'ایران خودرو', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'خودرو'}, 'وبملت': {'company': 'بانک ملت', 'bazaar': 'بورس - بازار اول', 'bazaar_group': 'بانکها'}, 'وتوسکا': {'company': 'سرمایه گذاری توسعه توکا', 'bazaar': 'بورس', 'bazaar_group': 'سرمایه گذاریها'}, 'پی پاد': {'company': 'پرداخت الکترونیک پاسارگاد', 'bazaar': 'بورس', 'bazaar_group': 'رایانه و فعالیت های وابسته'}, } # Load symbols STOCK_SYMBOLS = load_stock_symbols() SYMBOL_NAMES = set(STOCK_SYMBOLS.keys()) # Market context keywords MARKET_KEYWORDS = { 'سهام', 'سهم', 'بورس', 'فرابورس', 'معامله', 'معاملات', 'خرید', 'فروش', 'قیمت', 'ارزش', 'بازار', 'سرمایه', 'سرمایه‌گذاری', 'پرتفوی', 'نماد', 'شاخص', 'حجم', 'عرضه', 'تقاضا', 'صف', 'نوسان', 'بازدهی', 'سود', 'زیان', 'ریال', 'تومان', 'میلیارد', 'میلیون', 'درصد', 'رشد', 'افت' } def use_phi_for_disambiguation(text: str, potential_symbol: str, symbol_info: Dict) -> float: """ Use Phi-2 model to determine if a word is used as a stock symbol Returns confidence score (0-1) """ if not use_llm_model: # Use zero-shot classification instead return use_zero_shot_classification(text, potential_symbol, symbol_info) try: # Create a simple prompt for Phi-2 prompt = f"""Analyze this Persian text and determine if "{potential_symbol}" is used as a stock market symbol. Context: "{potential_symbol}" could be a stock symbol for {symbol_info['company']} company. Text: {text} Answer with only "STOCK" if it's a stock symbol, or "WORD" if it's a regular word: Answer: """ # Tokenize and generate inputs = context_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = context_model.generate( **inputs, max_new_tokens=10, temperature=0.1, do_sample=False, pad_token_id=context_tokenizer.eos_token_id ) # Decode the response response = context_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True) response = response.strip().upper() # Parse response if "STOCK" in response: return 0.9 elif "WORD" in response: return 0.1 else: # Ambiguous response, use heuristics return 0.5 except Exception as e: print(f"Phi-2 inference error: {e}") return 0.5 def use_zero_shot_classification(text: str, potential_symbol: str, symbol_info: Dict) -> float: """ Fallback: Use zero-shot classification to determine if a word is a stock symbol """ try: # Get context around the symbol symbol_pos = text.find(potential_symbol) if symbol_pos == -1: return 0.5 start = max(0, symbol_pos - 100) end = min(len(text), symbol_pos + len(potential_symbol) + 100) context_text = text[start:end] # Define candidate labels candidate_labels = [ f"نماد بورسی {symbol_info['company']}", f"کلمه عادی {potential_symbol}", "stock market symbol", "regular word" ] # Perform classification result = classifier( context_text, candidate_labels=candidate_labels, multi_label=False ) # Check top label top_label = result['labels'][0] top_score = result['scores'][0] if 'نماد بورسی' in top_label or 'stock' in top_label.lower(): return top_score else: return 1 - top_score except Exception as e: print(f"Classification error: {e}") return 0.5 def check_stock_symbol_context(text: str, potential_symbol: str, symbol_info: Dict) -> Tuple[bool, float]: """ Check if a potential symbol is actually used as a stock symbol in context """ # Get surrounding context symbol_pos = text.find(potential_symbol) if symbol_pos == -1: return False, 0.0 start_context = max(0, symbol_pos - 100) end_context = min(len(text), symbol_pos + len(potential_symbol) + 100) context_window = text[start_context:end_context] # Count market keywords words_in_context = context_window.split() market_keyword_count = sum(1 for word in words_in_context if word in MARKET_KEYWORDS) # Calculate heuristic score heuristic_score = min(market_keyword_count * 0.2, 1.0) # Strong heuristic signals if market_keyword_count >= 5: return True, 0.95 elif market_keyword_count == 0 and len(words_in_context) > 10: return False, 0.05 # Use AI model for disambiguation if use_llm_model: ai_score = use_phi_for_disambiguation(context_window, potential_symbol, symbol_info) else: ai_score = use_zero_shot_classification(context_window, potential_symbol, symbol_info) # Combine scores final_score = (heuristic_score * 0.3 + ai_score * 0.7) # Decision threshold is_stock = final_score > 0.5 return is_stock, final_score def find_stock_symbols_in_text(text: str) -> List[Dict]: """Find and validate stock symbols in text""" found_symbols = [] processed_positions = set() # Pattern to match Persian/Arabic words pattern = r'\b[\u0600-\u06FF]+\b' for match in re.finditer(pattern, text): word = match.group() if word in SYMBOL_NAMES and match.start() not in processed_positions: symbol_info = STOCK_SYMBOLS[word] # Check context is_stock, confidence = check_stock_symbol_context(text, word, symbol_info) if is_stock: found_symbols.append({ 'word': word, 'start': match.start(), 'end': match.end(), 'entity_group': 'STOCK', 'score': confidence, 'company': symbol_info['company'], 'bazaar': symbol_info['bazaar'], 'bazaar_group': symbol_info['bazaar_group'] }) processed_positions.add(match.start()) return found_symbols # Label colors and names label_colors = { "B-PER": "#FF6B6B", "I-PER": "#FFB3B3", "B-ORG": "#4ECDC4", "I-ORG": "#A7E9E4", "B-LOC": "#95E1D3", "I-LOC": "#C7F0E8", "B-DAT": "#FFA07A", "I-DAT": "#FFDAB9", "B-TIM": "#DDA0DD", "I-TIM": "#E6D0E6", "B-MON": "#FFD700", "I-MON": "#FFEB99", "B-PCT": "#87CEEB", "I-PCT": "#B3DFEF", "STOCK": "#00FA9A", } label_names = { "PER": "شخص (Person)", "ORG": "سازمان (Organization)", "LOC": "مکان (Location)", "DAT": "تاریخ (Date)", "TIM": "زمان (Time)", "MON": "پول (Money)", "PCT": "درصد (Percent)", "STOCK": "نماد بورسی (Stock Symbol)", } def merge_overlapping_entities(entities: List[Dict], stock_entities: List[Dict]) -> List[Dict]: """Merge entities, removing overlaps""" all_entities = [] all_entities.extend(stock_entities) for ner_ent in entities: overlap = False for stock_ent in stock_entities: if not (ner_ent['end'] <= stock_ent['start'] or ner_ent['start'] >= stock_ent['end']): overlap = True break if not overlap: all_entities.append(ner_ent) return all_entities def highlight_entities(text, all_entities): """Create HTML with highlighted entities""" if not all_entities: return text entities_sorted = sorted(all_entities, key=lambda x: x['start'], reverse=True) result = text for entity in entities_sorted: start = entity['start'] end = entity['end'] label = entity['entity_group'] word = text[start:end] score = entity['score'] color = label_colors.get(label if label == 'STOCK' else f"B-{label}", "#CCCCCC") tooltip_info = f"{label} (confidence: {score:.2f})" if label == 'STOCK': company = entity.get('company', '') bazaar = entity.get('bazaar', '') if company: tooltip_info = f"{company} - {bazaar} (confidence: {score:.2f})" highlighted = f'{word} [{label}]' result = result[:start] + highlighted + result[end:] return result def perform_ner(text): """Perform integrated NER and stock symbol detection""" if not text.strip(): return "

لطفا متن فارسی وارد کنید (Please enter Persian text)

", "" try: # Perform standard NER entities = ner_pipeline(text) # Find stock symbols stock_entities = find_stock_symbols_in_text(text) # Merge entities all_entities = merge_overlapping_entities(entities, stock_entities) # Create highlighted HTML highlighted_html = f"""
{highlight_entities(text, all_entities)}
""" # Create entities table if all_entities: entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n" entity_info += "| کلمه (Word) | نوع (Type) | جزئیات (Details) | اطمینان (Confidence) |\n" entity_info += "|:------------|:-----------|:------------------|:---------------------|\n" all_entities.sort(key=lambda x: x['start']) for ent in all_entities: label_fa = label_names.get(ent['entity_group'], ent['entity_group']) details = "" if ent['entity_group'] == 'STOCK': company = ent.get('company', '') bazaar = ent.get('bazaar', '') group = ent.get('bazaar_group', '') details = f"{company}
{bazaar}
{group}" entity_info += f"| **{ent['word']}** | {label_fa} | {details} | {ent['score']:.2%} |\n" else: entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)" # Statistics stats = f"\n\n### آمار (Statistics):\n" stats += f"- تعداد کل موجودیت‌ها: {len(all_entities)}\n" stats += f"- نمادهای بورسی: {len([e for e in all_entities if e['entity_group'] == 'STOCK'])}\n" stats += f"- اشخاص: {len([e for e in all_entities if e['entity_group'] == 'PER'])}\n" stats += f"- سازمان‌ها: {len([e for e in all_entities if e['entity_group'] == 'ORG'])}\n" stats += f"- مکان‌ها: {len([e for e in all_entities if e['entity_group'] == 'LOC'])}\n" return highlighted_html, entity_info + stats except Exception as e: return f"

خطا (Error): {str(e)}

", str(e) # Examples examples = [ ["علی احمدی دیروز در تهران با مدیر شرکت ملی نفت ایران دیدار کرد."], ["سهام وبصادر و فولاد در بورس امروز با افزایش قیمت مواجه شدند."], ["صنعت فولاد در اصفهان یکی از مهمترین صنایع کشور است."], ["قیمت سهام شپنا در معاملات امروز ۵ درصد رشد داشت و به ۱۲۰۰۰ ریال رسید."], ["بانک ملت اعلام کرد که سود سهام وبملت را در تاریخ ۱۵ خرداد ۱۴۰۳ پرداخت خواهد کرد."], ["شرکت فولاد مبارکه با نماد فولاد در بورس تهران فعال است و محصولات فولادی تولید می‌کند."], ["من دیروز ۱۰۰۰ سهم از وتوسکا خریدم و امیدوارم تا پایان هفته ۲۰ درصد سود کنم."], ] # Gradio interface with gr.Blocks( title="Persian NER + Stock Symbols | شناسایی موجودیت‌ها و نمادهای بورسی", theme=gr.themes.Soft(), css=""" .rtl-text { direction: rtl; text-align: right; font-family: 'B Nazanin', Tahoma, Arial; } """ ) as demo: gr.Markdown(f""" # 🏦 شناسایی هوشمند موجودیت‌ها و نمادهای بورس ایران ## Persian Named Entity Recognition with Stock Symbol Detection ### Using {context_model_name.split('/')[-1]} for Context Understanding
این برنامه متن‌های فارسی را تحلیل کرده و موجودیت‌های مختلف را شناسایی می‌کند.
""") with gr.Row(): with gr.Column(scale=6): input_text = gr.Textbox( label="متن فارسی را وارد کنید (Enter Persian Text)", placeholder="مثال: سهام فولاد در بورس تهران معامله می‌شود...", lines=6, rtl=True, elem_classes=["rtl-text"] ) with gr.Row(): submit_btn = gr.Button("🔍 تحلیل متن", variant="primary", scale=2) clear_btn = gr.Button("🗑️ پاک کردن", scale=1) with gr.Column(scale=6): output_html = gr.HTML( label="نتیجه تحلیل (Analysis Result)", elem_classes=["rtl-text"] ) with gr.Row(): output_entities = gr.Markdown( label="جدول موجودیت‌ها (Entity Table)", elem_classes=["rtl-text"] ) gr.Examples( examples=examples, inputs=input_text, label="نمونه‌های آماده (Ready Examples)", examples_per_page=4 ) # User guide with gr.Accordion("📖 راهنمای استفاده (User Guide)", open=True): gr.Markdown("""
## چگونه از این برنامه استفاده کنیم؟ 1. **متن فارسی خود را در کادر بالا وارد کنید** 2. **دکمه «تحلیل متن» را بزنید** 3. **نتایج را در دو بخش مشاهده کنید:** - متن با موجودیت‌های رنگی شده - جدول کامل موجودیت‌ها با جزئیات ## انواع موجودیت‌هایی که شناسایی می‌شوند: | رنگ | نوع | مثال | |:---:|:----|:-----| | 🔴 قرمز | **اشخاص** | علی احمدی، مریم رضایی | | 🔵 آبی | **سازمان‌ها** | شرکت ملی نفت، بانک ملت | | 🟢 سبز | **مکان‌ها** | تهران، اصفهان، ایران | | 🟠 نارنجی | **تاریخ‌ها** | ۱۵ خرداد ۱۴۰۳ | | 🟣 بنفش | **زمان‌ها** | ساعت ۱۰ صبح | | 🟡 زرد | **مبالغ پولی** | ۱۰۰۰ ریال، ۵ میلیارد تومان | | 🔷 آبی آسمانی | **درصدها** | ۲۰ درصد، ۵٪ | | 💚 سبز روشن | **نمادهای بورسی** | فولاد، وبملت، شپنا | ## ویژگی خاص: تشخیص هوشمند نمادهای بورسی برنامه با استفاده از **هوش مصنوعی** تشخیص می‌دهد که آیا یک کلمه نماد بورسی است یا خیر. **مثال:** - «سهام **فولاد** در بورس معامله شد» ← فولاد = نماد بورسی ✅ - «صنعت **فولاد** در کشور مهم است» ← فولاد = کلمه عادی ❌ ## نحوه تفسیر نتایج: - **رنگ‌ها**: نوع موجودیت را نشان می‌دهند - **برچسب‌ها**: نوع موجودیت به صورت مختصر - **درصد اطمینان**: میزان اطمینان سیستم (۰-۱۰۰٪) - **جزئیات نمادها**: نام شرکت، بازار و گروه صنعت ## مدل‌های استفاده شده: - **ParsBERT NER**: شناسایی موجودیت‌های عمومی - **Microsoft Phi-2 / mDeBERTa**: تحلیل هوشمند متن برای تشخیص نمادهای بورسی
""") # Event handlers submit_btn.click( fn=perform_ner, inputs=input_text, outputs=[output_html, output_entities] ) clear_btn.click( lambda: ("", "", ""), outputs=[input_text, output_html, output_entities] ) input_text.submit( fn=perform_ner, inputs=input_text, outputs=[output_html, output_entities] ) # Launch if __name__ == "__main__": print("Starting Persian NER + Stock Symbol Detection System...") print(f"Using device: {device}") print(f"Loaded {len(STOCK_SYMBOLS)} stock symbols") print("Models loaded:") print(" - NER: HooshvareLab/bert-base-parsbert-ner-uncased") print(f" - Context: {context_model_name}") demo.launch( share=False, debug=True )