Spaces:
Sleeping
Sleeping
File size: 14,561 Bytes
a39d4c2 3c57d2d ee8ab7e a39d4c2 ee8ab7e a39d4c2 67ad485 a39d4c2 67ad485 a39d4c2 ee8ab7e 3c57d2d a39d4c2 3c57d2d a39d4c2 ee8ab7e a39d4c2 3c57d2d a39d4c2 eb83f82 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 ee8ab7e a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 ee8ab7e 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 eb83f82 a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 3c57d2d a39d4c2 ee8ab7e a39d4c2 ee8ab7e a39d4c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import torch
import re
import csv
import os
# Set device to CPU explicitly (this Space has no GPU)
device = "cpu"
# Load the model and tokenizer (downloads from the HF hub on first run)
model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)
# Create NER pipeline
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 means CPU
    aggregation_strategy="simple"  # Groups entities together
)
# Load stock symbols from CSV file
def load_stock_symbols_from_csv(csv_path='symbols.csv'):
    """Load Tehran Stock Exchange symbols from a CSV file.

    The CSV must have a header row with 'symbol' and 'company_name'
    columns; rows missing either value are skipped.

    Args:
        csv_path: Path to the CSV file (default: 'symbols.csv').

    Returns:
        dict: Mapping of symbol -> company name. Empty if the file is
        missing or contains no valid rows.
    """
    stock_symbols = {}
    try:
        with open(csv_path, 'r', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                symbol = row.get('symbol')
                company_name = row.get('company_name')
                # Skip malformed rows instead of crashing with KeyError.
                if symbol and company_name:
                    stock_symbols[symbol] = company_name
        print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
    except FileNotFoundError:
        # No built-in defaults exist; detection will simply find nothing.
        print(f"Warning: {csv_path} not found. No stock symbols loaded.")
    return stock_symbols
# Load stock symbols once at import time; empty dict if the CSV is missing
STOCK_SYMBOLS = load_stock_symbols_from_csv()
# Hypernym patterns (generic terms that can be made more specific).
# Maps a generic Persian term to the base entity type it is associated with.
HYPERNYM_PATTERNS = {
    "شرکت": "ORG",     # company
    "سازمان": "ORG",   # organization
    "موسسه": "ORG",    # institute
    "بانک": "ORG",     # bank
    "دانشگاه": "ORG",  # university
    "شهر": "LOC",      # city
    "کشور": "LOC",     # country
    "استان": "LOC",    # province
    "آقای": "PER",     # Mr.
    "خانم": "PER",     # Mrs./Ms.
    "دکتر": "PER",     # doctor (title)
    "مهندس": "PER",    # engineer (title)
}
# Label mapping for better readability.
# Hex highlight colors per label; model labels use BIO tags
# (B- = begin, I- = inside), while STK and HYP are custom labels
# produced by this app's dictionary/pattern detectors.
label_colors = {
    "B-PER": "#FF6B6B",
    "I-PER": "#FFB3B3",
    "B-ORG": "#4ECDC4",
    "I-ORG": "#A7E9E4",
    "B-LOC": "#95E1D3",
    "I-LOC": "#C7F0E8",
    "B-DAT": "#FFA07A",
    "I-DAT": "#FFDAB9",
    "B-TIM": "#DDA0DD",
    "I-TIM": "#E6D0E6",
    "B-MON": "#FFD700",
    "I-MON": "#FFEB99",
    "B-PCT": "#87CEEB",
    "I-PCT": "#B3DFEF",
    "STK": "#FF1493",  # Stock symbol - Deep Pink
    "HYP": "#A9A9A9",  # Hypernym - Dark Gray
}
# Persian + English display names for the aggregated entity groups
label_names = {
    "PER": "شخص (Person)",
    "ORG": "سازمان (Organization)",
    "LOC": "مکان (Location)",
    "DAT": "تاریخ (Date)",
    "TIM": "زمان (Time)",
    "MON": "پول (Money)",
    "PCT": "درصد (Percent)",
    "STK": "نماد بورس (Stock Symbol)",
    "HYP": "واژه عمومی (Hypernym)",
}
def detect_stock_symbols(text, symbols=None):
    """Detect Persian stock market symbols in text via dictionary lookup.

    Args:
        text: Input Persian text.
        symbols: Optional symbol -> company-name mapping; defaults to the
            module-level STOCK_SYMBOLS loaded from CSV.

    Returns:
        list[dict]: One pipeline-compatible entity dict per occurrence
        ('entity_group', 'word', 'start', 'end', 'score') plus 'full_name'.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS
    stock_entities = []
    # Unique Persian words only — a word repeated in the text would
    # otherwise emit every span multiple times, since finditer already
    # finds all occurrences on the first pass.
    words = set(re.findall(r'[\u0600-\u06FF]+', text))
    for word in words:
        if word in symbols:
            # Record every occurrence of this symbol in the text
            for match in re.finditer(re.escape(word), text):
                stock_entities.append({
                    'entity_group': 'STK',
                    'word': word,
                    'start': match.start(),
                    'end': match.end(),
                    'score': 0.99,  # High confidence for dictionary match
                    'full_name': symbols[word]
                })
    return stock_entities
def detect_hypernyms(text, entities, patterns=None):
    """Detect hypernyms (generic terms like 'شرکت'/'شهر') in the text.

    A hypernym whose span falls entirely inside an already-detected
    (more specific) entity is skipped.

    Args:
        text: Input Persian text.
        entities: Entities already found (dicts with 'start'/'end').
        patterns: Optional term -> base entity type mapping; defaults to
            the module-level HYPERNYM_PATTERNS.

    Returns:
        list[dict]: HYP entity dicts carrying 'base_type' and 'is_hypernym'.
    """
    if patterns is None:
        patterns = HYPERNYM_PATTERNS
    hypernym_entities = []
    for hypernym, entity_type in patterns.items():
        # \b keeps the term from matching inside a longer word run
        # (str patterns are Unicode-aware, so \b works for Persian too).
        pattern = r'\b' + re.escape(hypernym) + r'\b'
        for match in re.finditer(pattern, text):
            start, end = match.start(), match.end()
            # Skip spans already covered by a specific entity.
            is_covered = any(start >= ent['start'] and end <= ent['end']
                             for ent in entities)
            if not is_covered:
                hypernym_entities.append({
                    'entity_group': 'HYP',
                    'word': hypernym,
                    'start': start,
                    'end': end,
                    'score': 0.95,  # fixed heuristic confidence
                    'base_type': entity_type,
                    'is_hypernym': True
                })
    return hypernym_entities
def merge_entities(entities, stock_entities, hypernym_entities):
    """Merge model, stock-symbol and hypernym entities, resolving overlaps.

    Priority on overlap: STK > model entities (PER/ORG/...) > HYP.
    Ties in priority go to the entity that starts first.

    The original implementation removed items from the result list while
    iterating over it, which can skip elements and mis-resolve chains of
    overlaps; this version sorts by priority and keeps greedily.

    Args:
        entities: Entities from the NER pipeline.
        stock_entities: Entities from detect_stock_symbols().
        hypernym_entities: Entities from detect_hypernyms().

    Returns:
        list[dict]: Non-overlapping entities sorted by start position.
    """
    def priority(entity):
        # Dictionary stock matches outrank the model; hypernyms rank last.
        return {'STK': 2, 'HYP': 0}.get(entity['entity_group'], 1)

    candidates = sorted(entities + stock_entities + hypernym_entities,
                        key=lambda e: (-priority(e), e['start']))
    kept = []
    for entity in candidates:
        # Everything kept so far has equal or higher priority, so a new
        # entity survives only if it overlaps none of them.
        if all(entity['end'] <= other['start'] or entity['start'] >= other['end']
               for other in kept):
            kept.append(entity)
    return sorted(kept, key=lambda x: x['start'])
def highlight_entities(text, entities):
    """Create HTML with color-highlighted entity spans.

    Args:
        text: The original plain input text.
        entities: Non-overlapping entity dicts with 'start'/'end'/
            'entity_group'/'score' (plus 'full_name' / 'base_type').

    Returns:
        str: HTML string with <span> highlights. Entity text and tooltip
        attributes are escaped so user input cannot inject markup.
        NOTE(review): text *between* entity spans is still emitted as-is —
        consider escaping the whole string if untrusted input is a concern.
    """
    import html as _html  # stdlib; local import keeps the fix self-contained

    if not entities:
        return text
    # Replace from the end of the string so earlier offsets stay valid.
    entities_sorted = sorted(entities, key=lambda x: x['start'], reverse=True)
    result = text
    for entity in entities_sorted:
        start, end = entity['start'], entity['end']
        label = entity['entity_group']
        # Escape user-controlled text before inlining it into HTML.
        word = _html.escape(text[start:end])
        score = entity['score']
        # Pick the highlight color and tooltip per label family.
        if label == 'STK':
            color = label_colors.get('STK')
            extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
            title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
        elif label == 'HYP':
            color = label_colors.get('HYP')
            base_type = entity.get('base_type', '')
            title_text = f"Hypernym (general term for {base_type})"
        else:
            # Aggregated model labels (e.g. 'PER') map to the B- color key.
            color = label_colors.get(f"B-{label}", "#CCCCCC")
            title_text = f"{label} (confidence: {score:.2f})"
        # quote=True escapes '"' so the title attribute cannot be broken out of.
        title_text = _html.escape(title_text, quote=True)
        highlighted = (
            f'<span style="background-color: {color}; padding: 2px 6px; '
            f'border-radius: 3px; margin: 0 2px; display: inline-block;" '
            f'title="{title_text}">{word} '
            f'<sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'
        )
        result = result[:start] + highlighted + result[end:]
    return result
def perform_ner(text):
    """Run the full pipeline on input text and render the results.

    Combines transformer NER, dictionary stock-symbol detection, and
    hypernym detection, then merges and renders the entities.

    Args:
        text: Persian input text from the UI.

    Returns:
        tuple[str, str]: (HTML with highlighted entities, markdown table
        of detected entities). On empty input or any error, the first
        element is an error message and the second is "".
    """
    if not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
    try:
        # Perform base NER with the transformer pipeline
        entities = ner_pipeline(text)
        # Dictionary-based Tehran Stock Exchange symbol detection
        stock_entities = detect_stock_symbols(text)
        # Generic-term detection (skips spans the model already covered)
        hypernym_entities = detect_hypernyms(text, entities)
        # Merge all entities, resolving overlaps by priority
        all_entities = merge_entities(entities, stock_entities, hypernym_entities)
        # RTL container so Persian text renders right-to-left
        highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, all_entities)}</div>"
        # Build a markdown table listing the detected entities
        if all_entities:
            entity_info = "### موجودیتهای شناسایی شده (Detected Entities):\n\n"
            entity_info += "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) | اطلاعات اضافی (Additional Info) |\n"
            entity_info += "|------------|-----------|---------------------|----------------------------------|\n"
            for ent in all_entities:
                # Fall back to the raw tag if the label has no display name
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                extra_info = ""
                if ent['entity_group'] == 'STK' and 'full_name' in ent:
                    extra_info = ent['full_name']
                elif ent['entity_group'] == 'HYP':
                    extra_info = f"Hypernym of {ent.get('base_type', '')}"
                entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} | {extra_info} |\n"
        else:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
        return highlighted_html, entity_info
    except Exception as e:
        # Broad catch: surface any model/tokenizer failure in the UI
        # instead of crashing the Gradio handler.
        return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""
# Save stock symbols to CSV function
def save_symbols_to_csv(output_path='symbols.csv', symbols=None):
    """Write stock symbols to a CSV file with a 'symbol,company_name' header.

    Args:
        output_path: Destination CSV path (default: 'symbols.csv').
        symbols: Optional symbol -> company-name mapping to write;
            defaults to the module-level STOCK_SYMBOLS.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS
    # newline='' prevents blank lines on Windows (csv module requirement).
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['symbol', 'company_name'])
        for symbol, name in symbols.items():
            writer.writerow([symbol, name])
    print(f"Saved {len(symbols)} symbols to {output_path}")
# Example texts (Persian) demonstrating persons, locations, dates,
# organizations, hypernym titles, and Tehran Stock Exchange symbols
examples = [
    ["باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی میکرد."],
    ["شرکت گوگل در کالیفرنیا واقع شده است."],
    ["رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد."],
    ["دانشگاه تهران یکی از قدیمیترین دانشگاههای ایران است."],
    ["علی و حسین به همراه مریم به مشهد سفر کردند."],
    ["سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود."],
    ["بانک ملت و وتجارت در بازار بورس فعال هستند."],
    ["آقای احمدی مدیرعامل شرکت پتروشیمی است."],
    ["وبملت و فملی امروز در صف خرید قرار گرفتند."],
]
# Create Gradio interface
with gr.Blocks(title="Persian NER - شناسایی موجودیتهای نامدار فارسی", theme=gr.themes.Soft()) as demo:
    # Header / intro (f-string: interpolates the live symbol count)
    gr.Markdown(f"""
# 🇮🇷 Persian Named Entity Recognition + Stock Symbols
# شناسایی موجودیتهای نامدار فارسی + نمادهای بورس
این سیستم موجودیتهای نامدار مانند اسامی اشخاص، سازمانها، مکانها، تاریخها، **نمادهای بورس** و **واژههای عمومی (Hypernyms)** را در متن فارسی شناسایی میکند.
This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.
**Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection
**Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)
**Running on:** CPU (may be slow for long texts)
---
### 📊 APIs for Updating Stock Symbols:
**Recommended Python Libraries:**
1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
**Example Usage:**
```python
# Using tsetmc-api
from tsetmc_api import market_watch
stocks = market_watch.get_market_watch()
# Using tehran-stocks
from tehran_stocks import Stocks
all_stocks = Stocks.query.all()
```
**Official TSE Website:** https://tse.ir
**TSETMC Data Portal:** http://www.tsetmc.com
""")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="متن فارسی خود را وارد کنید (Enter Persian Text)",
                placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
                lines=5,
                rtl=True
            )
            submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")
        with gr.Column():
            output_html = gr.HTML(label="متن با موجودیتهای برجسته (Text with Highlighted Entities)")
            output_entities = gr.Markdown(label="لیست موجودیتها (Entity List)")

    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="مثالها (Examples)"
    )

    # Legend — BUGFIX: this was a plain string, so "{len(STOCK_SYMBOLS)}"
    # rendered literally in the UI; it must be an f-string.
    gr.Markdown(f"""
### راهنمای رنگها (Color Guide):
- 🔴 **PER (شخص)**: اسامی اشخاص / Person names
- 🔵 **ORG (سازمان)**: نام سازمانها / Organizations
- 🟢 **LOC (مکان)**: نام مکانها / Locations
- 🟠 **DAT (تاریخ)**: تاریخها / Dates
- 🟣 **TIM (زمان)**: زمانها / Times
- 🟡 **MON (پول)**: مقادیر پولی / Money
- 🔷 **PCT (درصد)**: درصدها / Percentages
- 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
- ⚫ **HYP (واژه عمومی)**: واژههای عمومی / Hypernyms (general terms)
---
### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد
*برای بهروزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
""")

    # Event handlers: button click and Enter key both run the pipeline
    submit_btn.click(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
    input_text.submit(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
# Launch the app only when run as a script (not when imported)
if __name__ == "__main__":
    demo.launch()