Spaces:
Sleeping
Sleeping
# Standard library
import csv
import html
import os
import re

# Third-party
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Inference runs on CPU (no GPU is assumed in this Space).
device = "cpu"

# ParsBERT checkpoint fine-tuned for Persian named-entity tagging.
model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)

# Token-classification pipeline; "simple" aggregation merges word-piece
# predictions into whole-entity spans.
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 selects CPU
    aggregation_strategy="simple",
)
# Load stock symbols from CSV file
def load_stock_symbols_from_csv(csv_path='symbols.csv'):
    """Load Tehran Stock Exchange symbols from a CSV file.

    The file must have a header row with ``symbol`` and ``company_name``
    columns.  Malformed rows (missing columns, empty symbol) are skipped
    instead of raising KeyError, and surrounding whitespace is stripped.

    Args:
        csv_path: Path to the CSV file (default: ``symbols.csv``).

    Returns:
        dict: Mapping of symbol -> company name.  Empty when the file is
        missing or contains no usable rows.
    """
    stock_symbols = {}
    try:
        with open(csv_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Tolerate malformed rows: .get() avoids KeyError and
                # `or ''` covers a None value for a short row.
                symbol = (row.get('symbol') or '').strip()
                company_name = (row.get('company_name') or '').strip()
                if symbol:
                    stock_symbols[symbol] = company_name
        print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
    except FileNotFoundError:
        print(f"Warning: {csv_path} not found. Using default symbols.")
    return stock_symbols
# Populate the module-level symbol table once at import time.
STOCK_SYMBOLS = load_stock_symbols_from_csv()
# Generic Persian terms (hypernyms) mapped to the entity type they
# usually introduce; flagged when no specific entity covers the word.
HYPERNYM_PATTERNS = {
    # Organizations
    "شرکت": "ORG",
    "سازمان": "ORG",
    "موسسه": "ORG",
    "بانک": "ORG",
    "دانشگاه": "ORG",
    # Locations
    "شهر": "LOC",
    "کشور": "LOC",
    "استان": "LOC",
    # Person titles / honorifics
    "آقای": "PER",
    "خانم": "PER",
    "دکتر": "PER",
    "مهندس": "PER",
}

# Highlight colour per entity type: (B- colour, lighter I- colour).
_ENTITY_PALETTE = {
    "PER": ("#FF6B6B", "#FFB3B3"),
    "ORG": ("#4ECDC4", "#A7E9E4"),
    "LOC": ("#95E1D3", "#C7F0E8"),
    "DAT": ("#FFA07A", "#FFDAB9"),
    "TIM": ("#DDA0DD", "#E6D0E6"),
    "MON": ("#FFD700", "#FFEB99"),
    "PCT": ("#87CEEB", "#B3DFEF"),
}

# Expand the palette into the B-/I- keyed map the highlighter expects.
label_colors = {}
for _etype, (_begin_color, _inside_color) in _ENTITY_PALETTE.items():
    label_colors[f"B-{_etype}"] = _begin_color
    label_colors[f"I-{_etype}"] = _inside_color
label_colors["STK"] = "#FF1493"  # Stock symbol - Deep Pink
label_colors["HYP"] = "#A9A9A9"  # Hypernym - Dark Gray

# Human-readable (Persian + English) names for each entity group.
label_names = {
    "PER": "شخص (Person)",
    "ORG": "سازمان (Organization)",
    "LOC": "مکان (Location)",
    "DAT": "تاریخ (Date)",
    "TIM": "زمان (Time)",
    "MON": "پول (Money)",
    "PCT": "درصد (Percent)",
    "STK": "نماد بورس (Stock Symbol)",
    "HYP": "واژه عمومی (Hypernym)",
}
def detect_stock_symbols(text, symbols=None):
    """Detect Tehran Stock Exchange symbols in text via dictionary lookup.

    Each maximal run of Arabic-script letters is compared against the
    symbol table, so a symbol embedded inside a longer word (e.g.
    "فولاد" within "فولادسازی") is NOT tagged, and each occurrence is
    reported exactly once.  (The previous implementation re-scanned the
    whole text per word hit, which both matched substrings of longer
    words and emitted duplicate entities for repeated symbols.)

    Args:
        text: Input Persian text.
        symbols: Optional mapping of symbol -> company name; defaults to
            the module-level ``STOCK_SYMBOLS`` table.

    Returns:
        list[dict]: Pipeline-style entity dicts with ``entity_group``
        set to ``'STK'`` and the company's name under ``full_name``.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS
    stock_entities = []
    # finditer over letter runs gives us both the word and its offsets
    # in a single pass — no second search needed.
    for match in re.finditer(r'[\u0600-\u06FF]+', text):
        word = match.group()
        if word in symbols:
            stock_entities.append({
                'entity_group': 'STK',
                'word': word,
                'start': match.start(),
                'end': match.end(),
                'score': 0.99,  # High confidence for dictionary match
                'full_name': symbols[word],
            })
    return stock_entities
def detect_hypernyms(text, entities, patterns=None):
    """Detect generic Persian terms (hypernyms) not already covered.

    Matches are anchored on Arabic-script "word boundaries" (negative
    lookarounds on \\u0600-\\u06FF) so e.g. "شهر" does not fire inside a
    longer word such as "شهرستان".  A match fully contained in one of
    the supplied entities is suppressed, since a specific entity beats a
    generic term.

    Args:
        text: Input Persian text.
        entities: Already-detected entity dicts with start/end offsets.
        patterns: Optional mapping of hypernym -> base entity type;
            defaults to the module-level ``HYPERNYM_PATTERNS``.

    Returns:
        list[dict]: Entity dicts tagged ``'HYP'`` carrying the generic
        term's underlying type in ``base_type``.
    """
    if patterns is None:
        patterns = HYPERNYM_PATTERNS
    hypernym_entities = []
    for hypernym, entity_type in patterns.items():
        # Whole-word match only: no Persian letter directly before/after.
        pattern = (r'(?<![\u0600-\u06FF])' + re.escape(hypernym)
                   + r'(?![\u0600-\u06FF])')
        for match in re.finditer(pattern, text):
            start, end = match.start(), match.end()
            # Skip if a specific entity already covers this span.
            covered = any(start >= ent['start'] and end <= ent['end']
                          for ent in entities)
            if not covered:
                hypernym_entities.append({
                    'entity_group': 'HYP',
                    'word': hypernym,
                    'start': start,
                    'end': end,
                    'score': 0.95,
                    'base_type': entity_type,
                    'is_hypernym': True,
                })
    return hypernym_entities
def merge_entities(entities, stock_entities, hypernym_entities):
    """Merge NER, stock-symbol and hypernym entities, resolving overlaps.

    Priority when two spans overlap: STK > model entities > HYP.  The
    previous implementation called ``list.remove`` on the list it was
    iterating (undefined element skipping) and could re-flag ``overlaps``
    after a replacement; this version instead accepts entities in
    priority order, so a lower-priority overlapping span is simply
    dropped and no list is mutated mid-iteration.

    Args:
        entities: Base NER pipeline entity dicts.
        stock_entities: Output of ``detect_stock_symbols``.
        hypernym_entities: Output of ``detect_hypernyms``.

    Returns:
        list[dict]: Non-overlapping entities sorted by start offset.
    """
    def _priority(ent):
        # Higher number wins an overlap.
        group = ent['entity_group']
        if group == 'STK':
            return 2
        if group == 'HYP':
            return 0
        return 1  # specific model entities

    # Stable sort: within a tier the original relative order is kept.
    candidates = sorted(entities + stock_entities + hypernym_entities,
                        key=_priority, reverse=True)

    accepted = []
    for entity in candidates:
        overlaps = any(entity['start'] < other['end']
                       and other['start'] < entity['end']
                       for other in accepted)
        if not overlaps:
            accepted.append(entity)
    return sorted(accepted, key=lambda x: x['start'])
def highlight_entities(text, entities, colors=None):
    """Render *text* as HTML with each entity wrapped in a coloured span.

    Substitution runs from the end of the string backwards so recorded
    start/end offsets stay valid while earlier text is rewritten.  The
    entity word and tooltip are HTML-escaped so markup characters in the
    user's input cannot inject tags into the rendered output.

    Args:
        text: Original input text.
        entities: Non-overlapping entity dicts (start/end/entity_group/score).
        colors: Optional label -> CSS colour mapping; defaults to the
            module-level ``label_colors`` table.

    Returns:
        str: HTML string (the input text with highlight markup inserted).
    """
    if colors is None:
        colors = label_colors
    if not entities:
        return text
    # Replace right-to-left so untouched entities keep correct offsets.
    result = text
    for entity in sorted(entities, key=lambda x: x['start'], reverse=True):
        start, end = entity['start'], entity['end']
        label = entity['entity_group']
        word = text[start:end]
        score = entity['score']
        if label == 'STK':
            color = colors.get('STK')
            extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
            title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
        elif label == 'HYP':
            color = colors.get('HYP')
            base_type = entity.get('base_type', '')
            title_text = f"Hypernym (general term for {base_type})"
        else:
            color = colors.get(f"B-{label}", "#CCCCCC")
            title_text = f"{label} (confidence: {score:.2f})"
        # Escape user-controlled text before embedding it in markup.
        safe_word = html.escape(word)
        safe_title = html.escape(title_text, quote=True)
        highlighted = (
            f'<span style="background-color: {color}; padding: 2px 6px; '
            f'border-radius: 3px; margin: 0 2px; display: inline-block;" '
            f'title="{safe_title}">{safe_word} '
            f'<sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
        )
        result = result[:start] + highlighted + result[end:]
    return result
def perform_ner(text):
    """Run the full detection pipeline on *text* and format the results.

    Combines the ParsBERT NER pipeline with the stock-symbol and
    hypernym detectors, merges everything, and produces both a
    highlighted-HTML rendering and a markdown table of entities.

    Args:
        text: User-supplied Persian text.

    Returns:
        tuple[str, str]: (highlighted HTML, markdown entity table).
        On blank input or an exception the first element is an error
        message and the second is empty.
    """
    if not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
    try:
        base_entities = ner_pipeline(text)
        stocks = detect_stock_symbols(text)
        hypernyms = detect_hypernyms(text, base_entities)
        all_entities = merge_entities(base_entities, stocks, hypernyms)

        # RTL container around the highlighted text.
        highlighted_html = (
            "<div style='direction: rtl; text-align: right; font-size: 18px; "
            "line-height: 2; padding: 20px; border: 1px solid #ddd; "
            "border-radius: 5px; background-color: #f9f9f9;'>"
            f"{highlight_entities(text, all_entities)}</div>"
        )

        # Markdown table listing every detected entity.
        if not all_entities:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
        else:
            rows = [
                "### موجودیتهای شناسایی شده (Detected Entities):\n\n",
                "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) | اطلاعات اضافی (Additional Info) |\n",
                "|------------|-----------|---------------------|----------------------------------|\n",
            ]
            for ent in all_entities:
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                if ent['entity_group'] == 'STK' and 'full_name' in ent:
                    extra = ent['full_name']
                elif ent['entity_group'] == 'HYP':
                    extra = f"Hypernym of {ent.get('base_type', '')}"
                else:
                    extra = ""
                rows.append(f"| {ent['word']} | {label_fa} | {ent['score']:.2%} | {extra} |\n")
            entity_info = "".join(rows)

        return highlighted_html, entity_info
    except Exception as e:
        return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""
# Save stock symbols to CSV function
def save_symbols_to_csv(output_path='symbols.csv', symbols=None):
    """Write stock symbols out as a two-column CSV (symbol, company_name).

    Args:
        output_path: Destination CSV path (default: ``symbols.csv``).
        symbols: Optional mapping of symbol -> company name to persist;
            defaults to the module-level ``STOCK_SYMBOLS`` table, which
            preserves the original call signature and behaviour.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS
    # newline='' is required so csv controls line endings itself.
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['symbol', 'company_name'])
        for symbol, name in symbols.items():
            writer.writerow([symbol, name])
    print(f"Saved {len(symbols)} symbols to {output_path}")
# Sample sentences offered under the input box; Gradio expects one
# inner list per example row (a single Textbox input each).
_EXAMPLE_SENTENCES = (
    "باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی میکرد.",
    "شرکت گوگل در کالیفرنیا واقع شده است.",
    "رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد.",
    "دانشگاه تهران یکی از قدیمیترین دانشگاههای ایران است.",
    "علی و حسین به همراه مریم به مشهد سفر کردند.",
    "سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود.",
    "بانک ملت و وتجارت در بازار بورس فعال هستند.",
    "آقای احمدی مدیرعامل شرکت پتروشیمی است.",
    "وبملت و فملی امروز در صف خرید قرار گرفتند.",
)
examples = [[sentence] for sentence in _EXAMPLE_SENTENCES]
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Persian NER - شناسایی موجودیتهای نامدار فارسی", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
# 🇮🇷 Persian Named Entity Recognition + Stock Symbols
# شناسایی موجودیتهای نامدار فارسی + نمادهای بورس
این سیستم موجودیتهای نامدار مانند اسامی اشخاص، سازمانها، مکانها، تاریخها، **نمادهای بورس** و **واژههای عمومی (Hypernyms)** را در متن فارسی شناسایی میکند.
This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.
**Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection
**Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)
**Running on:** CPU (may be slow for long texts)
---
### 📊 APIs for Updating Stock Symbols:
**Recommended Python Libraries:**
1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
**Example Usage:**
```python
# Using tsetmc-api
from tsetmc_api import market_watch
stocks = market_watch.get_market_watch()
# Using tehran-stocks
from tehran_stocks import Stocks
all_stocks = Stocks.query.all()
```
**Official TSE Website:** https://tse.ir
**TSETMC Data Portal:** http://www.tsetmc.com
""")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="متن فارسی خود را وارد کنید (Enter Persian Text)",
                placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
                lines=5,
                rtl=True
            )
            submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")
        with gr.Column():
            output_html = gr.HTML(label="متن با موجودیتهای برجسته (Text with Highlighted Entities)")
            output_entities = gr.Markdown(label="لیست موجودیتها (Entity List)")

    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="مثالها (Examples)"
    )

    # Colour legend.  BUGFIX: this must be an f-string — the original
    # plain string rendered the literal "{len(STOCK_SYMBOLS)}" in the UI
    # instead of the actual symbol count.
    gr.Markdown(f"""
### راهنمای رنگها (Color Guide):
- 🔴 **PER (شخص)**: اسامی اشخاص / Person names
- 🔵 **ORG (سازمان)**: نام سازمانها / Organizations
- 🟢 **LOC (مکان)**: نام مکانها / Locations
- 🟠 **DAT (تاریخ)**: تاریخها / Dates
- 🟣 **TIM (زمان)**: زمانها / Times
- 🟡 **MON (پول)**: مقادیر پولی / Money
- 🔷 **PCT (درصد)**: درصدها / Percentages
- 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
- ⚫ **HYP (واژه عمومی)**: واژههای عمومی / Hypernyms (general terms)
---
### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد
*برای بهروزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
""")

    # Wire both the button click and Enter-in-textbox to one handler.
    submit_btn.click(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
    input_text.submit(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()