# app.py (14.6 kB) - commit 3c57d2d (verified), "Update app.py", by optimopium.
# Copied from the hosting site's file viewer ("raw" / "history blame" view).
import csv
import html
import os
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Inference is pinned to the CPU.
device = "cpu"

# Checkpoint used for both the tokenizer and the token-classification model.
model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ``.to()`` returns the module itself, so the load and the move can be chained.
model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)

# NER pipeline: device=-1 selects the CPU, and the "simple" aggregation
# strategy groups word pieces into whole entities.
ner_pipeline = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,
    aggregation_strategy="simple",
)
# Load stock symbols from CSV file
def load_stock_symbols_from_csv(csv_path='symbols.csv'):
    """Load the stock-symbol lookup table from a CSV file.

    The file is expected to start with a header row that contains the
    ``symbol`` and ``company_name`` columns.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A dict mapping each symbol to its company name. The dict is empty
        when the file does not exist. Rows without a symbol are skipped and
        a missing company name is stored as an empty string.
    """
    stock_symbols = {}
    try:
        # 'utf-8-sig' drops a leading BOM (common in spreadsheet exports) so
        # the first header is 'symbol' rather than '\ufeffsymbol';
        # newline='' lets the csv module handle line endings itself.
        with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
            for row in csv.DictReader(f):
                # Short or blank rows yield None for missing columns.
                symbol = (row.get('symbol') or '').strip()
                if not symbol:
                    continue
                stock_symbols[symbol] = (row.get('company_name') or '').strip()
    except FileNotFoundError:
        # There is no built-in fallback list: the table simply stays empty.
        print(f"Warning: {csv_path} not found. No stock symbols loaded.")
    else:
        print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
    return stock_symbols
# Load the symbol table once, at import time, from the loader's default path
# ('symbols.csv'). The loader returns an empty dict when that file is missing,
# in which case no word can match a stock symbol.
STOCK_SYMBOLS = load_stock_symbols_from_csv()
# Generic terms (hypernyms) mapped to the entity type they stand for.
# They are listed per type and flattened into one term -> type dict.
HYPERNYM_PATTERNS = {
    term: entity_type
    for entity_type, terms in (
        ("ORG", ("شرکت", "سازمان", "موسسه", "بانک", "دانشگاه")),
        ("LOC", ("شهر", "کشور", "استان")),
        ("PER", ("آقای", "خانم", "دکتر", "مهندس")),
    )
    for term in terms
}
# Highlight colours per label. Each model tag has a shade for its "B-"
# (begin) key and a lighter one for its "I-" (inside) key.
label_colors = {
    f"{prefix}-{tag}": shade
    for tag, begin_shade, inside_shade in (
        ("PER", "#FF6B6B", "#FFB3B3"),
        ("ORG", "#4ECDC4", "#A7E9E4"),
        ("LOC", "#95E1D3", "#C7F0E8"),
        ("DAT", "#FFA07A", "#FFDAB9"),
        ("TIM", "#DDA0DD", "#E6D0E6"),
        ("MON", "#FFD700", "#FFEB99"),
        ("PCT", "#87CEEB", "#B3DFEF"),
    )
    for prefix, shade in (("B", begin_shade), ("I", inside_shade))
}
label_colors["STK"] = "#FF1493"  # Stock symbol - Deep Pink
label_colors["HYP"] = "#A9A9A9"  # Hypernym - Dark Gray
# Bilingual (Persian / English) display name for each entity group.
label_names = dict(
    PER="شخص (Person)",
    ORG="سازمان (Organization)",
    LOC="مکان (Location)",
    DAT="تاریخ (Date)",
    TIM="زمان (Time)",
    MON="پول (Money)",
    PCT="درصد (Percent)",
    STK="نماد بورس (Stock Symbol)",
    HYP="واژه عمومی (Hypernym)",
)
def detect_stock_symbols(text):
    """Detect Persian stock market symbols in text.

    The text is split into runs of characters from the Arabic Unicode block
    (U+0600-U+06FF) and every run that is a key of ``STOCK_SYMBOLS`` is
    reported once, at its own position.

    Args:
        text: The input string to scan.

    Returns:
        A list of entity dicts in order of appearance, each with the keys
        ``entity_group`` ('STK'), ``word``, ``start``, ``end``, ``score`` and
        ``full_name`` (the company name from ``STOCK_SYMBOLS``).
    """
    stock_entities = []
    # Take positions straight from the tokenising match. Searching the whole
    # text again for each hit would report a repeated symbol several times
    # and would also match the symbol inside longer words.
    for match in re.finditer(r'[\u0600-\u06FF]+', text):
        word = match.group()
        if word not in STOCK_SYMBOLS:
            continue
        stock_entities.append({
            'entity_group': 'STK',
            'word': word,
            'start': match.start(),
            'end': match.end(),
            'score': 0.99,  # High confidence for dictionary match
            'full_name': STOCK_SYMBOLS[word]
        })
    return stock_entities
def detect_hypernyms(text, entities):
    """Find generic terms from ``HYPERNYM_PATTERNS`` that occur in the text.

    An occurrence is dropped when it lies completely inside the span of one
    of the already detected ``entities``.

    Args:
        text: The input string to scan.
        entities: Entity dicts providing ``start`` and ``end`` offsets.

    Returns:
        A list of 'HYP' entity dicts, grouped by pattern in the iteration
        order of ``HYPERNYM_PATTERNS``; ``base_type`` holds the mapped type.
    """
    found = []
    for term, base in HYPERNYM_PATTERNS.items():
        for hit in re.finditer(re.escape(term), text):
            begin, finish = hit.span()
            # Skip occurrences that a specific entity already contains.
            inside_specific = any(
                ent['start'] <= begin and finish <= ent['end']
                for ent in entities
            )
            if inside_specific:
                continue
            found.append({
                'entity_group': 'HYP',
                'word': term,
                'start': begin,
                'end': finish,
                'score': 0.95,
                'base_type': base,
                'is_hypernym': True
            })
    return found
def merge_entities(entities, stock_entities, hypernym_entities):
    """Combine the three entity lists into one list without overlaps.

    Candidates are visited in order of ``start``. When a candidate overlaps
    an entity that was already kept, the candidate replaces it if the
    candidate is a stock symbol and the kept one is not, or if the kept one
    is a hypernym and the candidate is not; otherwise the candidate is
    dropped. Only the first overlapping kept entity is considered.

    Returns:
        The surviving entity dicts sorted by ``start``.
    """
    candidates = sorted(
        entities + stock_entities + hypernym_entities,
        key=lambda item: item['start'],
    )
    kept = []
    for candidate in candidates:
        # First kept entity whose span intersects the candidate's, if any.
        clash = next(
            (
                prev for prev in kept
                if candidate['end'] > prev['start'] and candidate['start'] < prev['end']
            ),
            None,
        )
        if clash is not None:
            candidate_wins = (
                (candidate['entity_group'] == 'STK' and clash['entity_group'] != 'STK')
                or (clash['entity_group'] == 'HYP' and candidate['entity_group'] != 'HYP')
            )
            if not candidate_wins:
                continue
            kept.remove(clash)
        kept.append(candidate)
    return sorted(kept, key=lambda item: item['start'])
def highlight_entities(text, entities):
    """Create HTML with highlighted entities.

    Each entity span is wrapped in a coloured ``<span>`` carrying a tooltip
    and a superscript label. All text taken from the input or from entity
    data is HTML-escaped, so markup typed by the user or stored in the symbol
    table is displayed literally instead of being interpreted by the browser.

    Args:
        text: The original input string.
        entities: Entity dicts with ``start``, ``end``, ``entity_group`` and
            ``score``; 'STK' entries may carry ``full_name`` and 'HYP'
            entries ``base_type``.

    Returns:
        An HTML fragment (without an enclosing container element).
    """
    if not entities:
        return html.escape(text)
    pieces = []
    cursor = 0  # end offset of the last span already emitted
    for entity in sorted(entities, key=lambda x: x['start']):
        start = entity['start']
        end = entity['end']
        if start < cursor:
            # Overlaps the previously emitted span: nesting the two would
            # corrupt the markup, so the later one is left unhighlighted.
            continue
        label = entity['entity_group']
        score = entity['score']
        # Pick the colour and tooltip for this label
        if label == 'STK':
            color = label_colors.get('STK')
            extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
            title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
        elif label == 'HYP':
            color = label_colors.get('HYP')
            base_type = entity.get('base_type', '')
            title_text = f"Hypernym (general term for {base_type})"
        else:
            # Model labels are coloured by their "B-" variant.
            color = label_colors.get(f"B-{label}", "#CCCCCC")
            title_text = f"{label} (confidence: {score:.2f})"
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; '
            f'margin: 0 2px; display: inline-block;" title="{html.escape(title_text)}">'
            f'{html.escape(text[start:end])} '
            f'<sup style="font-size: 0.7em; font-weight: bold;">[{html.escape(label)}]</sup></span>'
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
def _markdown_cell(value):
    """Make a value safe for one cell of a Markdown table row."""
    return str(value).replace("|", "\\|").replace("\n", " ")


def perform_ner(text):
    """Perform NER on input text.

    Runs the model pipeline, adds stock-symbol and hypernym detections,
    merges everything and renders the result.

    Args:
        text: The text entered by the user; ``None``, empty or
            whitespace-only input yields a prompt instead of a result.

    Returns:
        A ``(html, markdown)`` tuple: the highlighted text as an HTML
        fragment and a Markdown table of the detected entities. On failure
        the first element is an error paragraph and the second is empty.
    """
    if not text or not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
    try:
        # Perform base NER
        entities = ner_pipeline(text)
        # Detect stock symbols
        stock_entities = detect_stock_symbols(text)
        # Detect hypernyms
        hypernym_entities = detect_hypernyms(text, entities)
        # Merge all entities
        all_entities = merge_entities(entities, stock_entities, hypernym_entities)
        # Create highlighted version
        highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, all_entities)}</div>"
        # Create entities table
        if all_entities:
            table_lines = [
                "### موجودیت‌های شناسایی شده (Detected Entities):\n\n",
                "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) | اطلاعات اضافی (Additional Info) |\n",
                "|------------|-----------|---------------------|----------------------------------|\n",
            ]
            for ent in all_entities:
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                extra_info = ""
                if ent['entity_group'] == 'STK' and 'full_name' in ent:
                    extra_info = ent['full_name']
                elif ent['entity_group'] == 'HYP':
                    extra_info = f"Hypernym of {ent.get('base_type', '')}"
                # A literal "|" in a cell would otherwise start a new column.
                table_lines.append(
                    f"| {_markdown_cell(ent['word'])} | {label_fa} | {ent['score']:.2%} | {_markdown_cell(extra_info)} |\n"
                )
            entity_info = "".join(table_lines)
        else:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
        return highlighted_html, entity_info
    except Exception as e:
        # UI boundary: report the failure in the page instead of crashing.
        # The message is escaped because it is inserted into HTML and may
        # echo parts of the input.
        return f"<p style='color: red;'>خطا (Error): {html.escape(str(e))}</p>", ""
# Export the in-memory symbol table back to disk
def save_symbols_to_csv(output_path='symbols.csv'):
    """Write ``STOCK_SYMBOLS`` to a CSV file.

    The file gets a ``symbol,company_name`` header followed by one row per
    entry, and a short summary is printed afterwards.

    Args:
        output_path: Destination file; it is overwritten if it exists.
    """
    with open(output_path, 'w', encoding='utf-8', newline='') as handle:
        sheet = csv.writer(handle)
        sheet.writerow(['symbol', 'company_name'])
        # Each (symbol, name) item becomes one CSV row.
        sheet.writerows(STOCK_SYMBOLS.items())
    print(f"Saved {len(STOCK_SYMBOLS)} symbols to {output_path}")
# Sample inputs offered in the UI; each example is a one-element list
# (one value for the single input component).
examples = [
    [sentence]
    for sentence in (
        "باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی می‌کرد.",
        "شرکت گوگل در کالیفرنیا واقع شده است.",
        "رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد.",
        "دانشگاه تهران یکی از قدیمی‌ترین دانشگاه‌های ایران است.",
        "علی و حسین به همراه مریم به مشهد سفر کردند.",
        "سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود.",
        "بانک ملت و وتجارت در بازار بورس فعال هستند.",
        "آقای احمدی مدیرعامل شرکت پتروشیمی است.",
        "وبملت و فملی امروز در صف خرید قرار گرفتند.",
    )
]
# Create Gradio interface
#
# The Markdown bodies are defined flush-left so that no indentation ends up
# inside the rendered text. Both are f-strings because both interpolate the
# number of loaded symbols. Each "---" is surrounded by blank lines: directly
# under a text line it would be parsed as a heading underline, not a rule.
_INTRO_MARKDOWN = f"""
# 🇮🇷 Persian Named Entity Recognition + Stock Symbols
# شناسایی موجودیت‌های نامدار فارسی + نمادهای بورس
این سیستم موجودیت‌های نامدار مانند اسامی اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها، **نمادهای بورس** و **واژه‌های عمومی (Hypernyms)** را در متن فارسی شناسایی می‌کند.
This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.
**Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection
**Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)
**Running on:** CPU (may be slow for long texts)

---

### 📊 APIs for Updating Stock Symbols:
**Recommended Python Libraries:**
1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
**Example Usage:**
```python
# Using tsetmc-api
from tsetmc_api import market_watch
stocks = market_watch.get_market_watch()
# Using tehran-stocks
from tehran_stocks import Stocks
all_stocks = Stocks.query.all()
```
**Official TSE Website:** https://tse.ir
**TSETMC Data Portal:** http://www.tsetmc.com
"""
_LEGEND_MARKDOWN = f"""
### راهنمای رنگ‌ها (Color Guide):
- 🔴 **PER (شخص)**: اسامی اشخاص / Person names
- 🔵 **ORG (سازمان)**: نام سازمان‌ها / Organizations
- 🟢 **LOC (مکان)**: نام مکان‌ها / Locations
- 🟠 **DAT (تاریخ)**: تاریخ‌ها / Dates
- 🟣 **TIM (زمان)**: زمان‌ها / Times
- 🟡 **MON (پول)**: مقادیر پولی / Money
- 🔷 **PCT (درصد)**: درصدها / Percentages
- 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
- ⚫ **HYP (واژه عمومی)**: واژه‌های عمومی / Hypernyms (general terms)

---

### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد
*برای به‌روزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
"""
with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نامدار فارسی", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        # Left column: input box and trigger button
        with gr.Column():
            input_text = gr.Textbox(
                label="متن فارسی خود را وارد کنید (Enter Persian Text)",
                placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
                lines=5,
                rtl=True
            )
            submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")
        # Right column: highlighted text and entity table
        with gr.Column():
            output_html = gr.HTML(label="متن با موجودیت‌های برجسته (Text with Highlighted Entities)")
            output_entities = gr.Markdown(label="لیست موجودیت‌ها (Entity List)")
    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="مثال‌ها (Examples)"
    )
    # Legend
    gr.Markdown(_LEGEND_MARKDOWN)
    # Event handlers: the button click and submitting the textbox both run
    # the same analysis.
    submit_btn.click(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
    input_text.submit(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()