File size: 14,561 Bytes
a39d4c2
 
 
3c57d2d
 
 
ee8ab7e
a39d4c2
 
ee8ab7e
a39d4c2
 
67ad485
a39d4c2
 
 
 
67ad485
a39d4c2
 
 
 
 
 
 
 
ee8ab7e
3c57d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a39d4c2
 
3c57d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a39d4c2
ee8ab7e
a39d4c2
 
 
 
 
 
 
 
3c57d2d
 
a39d4c2
eb83f82
3c57d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a39d4c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c57d2d
 
 
 
 
 
 
 
 
 
 
a39d4c2
 
3c57d2d
a39d4c2
 
 
 
ee8ab7e
a39d4c2
 
 
 
 
 
3c57d2d
a39d4c2
 
3c57d2d
 
 
 
 
 
 
 
 
a39d4c2
3c57d2d
a39d4c2
 
3c57d2d
a39d4c2
3c57d2d
 
 
a39d4c2
3c57d2d
 
 
 
 
 
 
 
a39d4c2
 
 
 
 
 
 
ee8ab7e
3c57d2d
 
 
 
 
 
 
 
 
 
a39d4c2
 
 
 
 
 
 
3c57d2d
 
 
 
a39d4c2
 
 
 
3c57d2d
 
 
a39d4c2
3c57d2d
a39d4c2
3c57d2d
a39d4c2
3c57d2d
 
a39d4c2
3c57d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a39d4c2
 
eb83f82
a39d4c2
 
 
3c57d2d
a39d4c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c57d2d
a39d4c2
 
 
 
 
3c57d2d
 
 
 
 
 
 
 
a39d4c2
 
 
 
 
 
 
 
 
 
 
 
 
 
ee8ab7e
a39d4c2
ee8ab7e
a39d4c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import gradio as gr
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import torch
import re
import csv
import os

# Set device to CPU explicitly (no GPU assumed in the deployment environment)
device = "cpu"

# Hugging Face model id: ParsBERT fine-tuned for Persian NER.
model_name = "HooshvareLab/bert-base-parsbert-ner-uncased"

# NOTE: the first run downloads the model weights from the Hugging Face hub.
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.to(device)

# Create NER pipeline
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 means CPU
    aggregation_strategy="simple"  # Groups B-/I- word pieces into whole entities
)

# Load stock symbols from CSV file
def load_stock_symbols_from_csv(csv_path='symbols.csv'):
    """Load the stock symbol -> company name table from a CSV file.

    The CSV must have a header row containing ``symbol`` and
    ``company_name`` columns. Rows missing either field (or containing
    only whitespace) are skipped rather than aborting the whole load —
    the original implementation raised ``KeyError`` on the first
    malformed row and lost every symbol.

    Args:
        csv_path: Path to the CSV file (default ``'symbols.csv'``).

    Returns:
        dict mapping stock symbol to company name; empty if the file
        does not exist.
    """
    stock_symbols = {}
    try:
        # newline='' is the documented way to open files for the csv module.
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Tolerate malformed/short rows: .get() instead of [] and
                # skip entries where either field is missing or blank.
                symbol = (row.get('symbol') or '').strip()
                company_name = (row.get('company_name') or '').strip()
                if symbol and company_name:
                    stock_symbols[symbol] = company_name
        print(f"Loaded {len(stock_symbols)} stock symbols from CSV")
    except FileNotFoundError:
        print(f"Warning: {csv_path} not found. Using default symbols.")
    return stock_symbols

# Load stock symbols once at import time; symbol -> company full name.
STOCK_SYMBOLS = load_stock_symbols_from_csv()

# Hypernym patterns (generic terms that can be made more specific).
# Maps a generic Persian word to the entity class it usually precedes:
# e.g. "شرکت" (company) -> ORG, "شهر" (city) -> LOC, "آقای" (Mr.) -> PER.
HYPERNYM_PATTERNS = {
    "شرکت": "ORG",
    "سازمان": "ORG",
    "موسسه": "ORG",
    "بانک": "ORG",
    "دانشگاه": "ORG",
    "شهر": "LOC",
    "کشور": "LOC",
    "استان": "LOC",
    "آقای": "PER",
    "خانم": "PER",
    "دکتر": "PER",
    "مهندس": "PER",
}

# Background colors for highlighted spans, keyed by BIO tag (B-/I- pairs
# share a hue) plus the two custom labels STK and HYP.
label_colors = {
    "B-PER": "#FF6B6B",
    "I-PER": "#FFB3B3",
    "B-ORG": "#4ECDC4",
    "I-ORG": "#A7E9E4",
    "B-LOC": "#95E1D3",
    "I-LOC": "#C7F0E8",
    "B-DAT": "#FFA07A",
    "I-DAT": "#FFDAB9",
    "B-TIM": "#DDA0DD",
    "I-TIM": "#E6D0E6",
    "B-MON": "#FFD700",
    "I-MON": "#FFEB99",
    "B-PCT": "#87CEEB",
    "I-PCT": "#B3DFEF",
    "STK": "#FF1493",  # Stock symbol - Deep Pink
    "HYP": "#A9A9A9",  # Hypernym - Dark Gray
}

# Human-readable (Persian + English) display names for the entity table.
label_names = {
    "PER": "شخص (Person)",
    "ORG": "سازمان (Organization)",
    "LOC": "مکان (Location)",
    "DAT": "تاریخ (Date)",
    "TIM": "زمان (Time)",
    "MON": "پول (Money)",
    "PCT": "درصد (Percent)",
    "STK": "نماد بورس (Stock Symbol)",
    "HYP": "واژه عمومی (Hypernym)",
}

def detect_stock_symbols(text, symbols=None):
    """Detect Tehran Stock Exchange symbols in *text* by dictionary lookup.

    Fixes two defects of the original implementation:
    1. A symbol appearing N times in the text was scanned N times with
       ``finditer`` over the whole string, producing duplicate entities.
    2. ``finditer(re.escape(word))`` also matched a symbol embedded inside
       a longer Persian word (e.g. "ملت" inside "وبملت") with wrong spans.

    Args:
        text: Input Persian text.
        symbols: Optional mapping of symbol -> company name; defaults to
            the module-level STOCK_SYMBOLS table.

    Returns:
        List of entity dicts with entity_group 'STK', character offsets,
        a fixed high score, and the company's full name.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS
    stock_entities = []

    # Deduplicate candidate words so each symbol is scanned exactly once.
    words = set(re.findall(r'[\u0600-\u06FF]+', text))

    for word in words:
        if word in symbols:
            # Require a non-Persian character (or string boundary) on each
            # side so the symbol is matched only as a standalone word.
            pattern = r'(?<![\u0600-\u06FF])' + re.escape(word) + r'(?![\u0600-\u06FF])'
            for match in re.finditer(pattern, text):
                stock_entities.append({
                    'entity_group': 'STK',
                    'word': word,
                    'start': match.start(),
                    'end': match.end(),
                    'score': 0.99,  # High confidence for dictionary match
                    'full_name': symbols[word]
                })

    return stock_entities

def detect_hypernyms(text, entities, patterns=None):
    """Detect hypernyms (generic terms like "شرکت"/"بانک") in *text*.

    A hypernym occurrence is suppressed when it falls entirely inside an
    already-detected specific entity. Unlike the original, the pattern is
    anchored on non-Persian boundaries so a hypernym is no longer matched
    inside a longer word.

    Args:
        text: Input Persian text.
        entities: Entities already found (dicts with 'start'/'end').
        patterns: Optional mapping hypernym -> base entity type; defaults
            to the module-level HYPERNYM_PATTERNS.

    Returns:
        List of entity dicts with entity_group 'HYP' and the base type
        the generic term refers to.
    """
    if patterns is None:
        patterns = HYPERNYM_PATTERNS
    hypernym_entities = []

    for hypernym, entity_type in patterns.items():
        # Match only standalone occurrences of the generic term.
        pattern = r'(?<![\u0600-\u06FF])' + re.escape(hypernym) + r'(?![\u0600-\u06FF])'
        for match in re.finditer(pattern, text):
            start, end = match.start(), match.end()

            # Skip if this span is fully covered by a specific entity.
            is_covered = any(
                start >= ent['start'] and end <= ent['end']
                for ent in entities
            )

            if not is_covered:
                hypernym_entities.append({
                    'entity_group': 'HYP',
                    'word': hypernym,
                    'start': start,
                    'end': end,
                    'score': 0.95,
                    'base_type': entity_type,
                    'is_hypernym': True
                })

    return hypernym_entities

def merge_entities(entities, stock_entities, hypernym_entities):
    """Merge all entity lists, resolving overlaps by priority.

    Priority: STK (dictionary hit) > model entities > HYP (generic term).

    The original implementation removed items from the kept list while
    scanning it and then broke out of the loop, so an entity overlapping
    several already-kept entities could leave overlapping pairs behind.
    This version considers candidates highest-priority first and keeps a
    candidate only if it overlaps nothing already kept, which guarantees
    a non-overlapping result.

    Args:
        entities: Entities from the NER model.
        stock_entities: Entities from the stock-symbol dictionary.
        hypernym_entities: Entities from the hypernym patterns.

    Returns:
        Non-overlapping entities sorted by start position.
    """
    def _priority(ent):
        # Dictionary stock hits are most trusted; hypernyms least.
        if ent['entity_group'] == 'STK':
            return 2
        if ent['entity_group'] == 'HYP':
            return 0
        return 1

    # Highest priority first; ties broken by position (sort is stable).
    candidates = sorted(
        entities + stock_entities + hypernym_entities,
        key=lambda e: (-_priority(e), e['start']),
    )

    kept = []
    for ent in candidates:
        no_overlap = all(
            ent['end'] <= k['start'] or ent['start'] >= k['end']
            for k in kept
        )
        if no_overlap:
            kept.append(ent)

    return sorted(kept, key=lambda x: x['start'])

def highlight_entities(text, entities):
    """Return *text* as HTML with each entity wrapped in a colored <span>.

    Replacements are applied from the end of the string toward the start
    so earlier character offsets stay valid. Entity text and tooltip
    content are HTML-escaped before being interpolated into the markup —
    the original injected them raw, so ``<``, ``&`` or ``"`` in the input
    could break (or inject into) the generated HTML.

    Args:
        text: The original input text.
        entities: Entity dicts with 'start', 'end', 'entity_group', 'score'.

    Returns:
        HTML string (the unmodified text if there are no entities).
    """
    import html as _html  # local import: stdlib-only, used for escaping

    if not entities:
        return text

    # Sort by start position, reversed, so replacements do not shift the
    # offsets of entities that come earlier in the text.
    entities_sorted = sorted(entities, key=lambda x: x['start'], reverse=True)

    result = text
    for entity in entities_sorted:
        start = entity['start']
        end = entity['end']
        label = entity['entity_group']
        word = _html.escape(text[start:end])  # safe to embed in markup
        score = entity['score']

        # Pick a color and tooltip per label family.
        if label == 'STK':
            color = label_colors.get('STK')
            extra_info = f" - {entity.get('full_name', '')}" if 'full_name' in entity else ""
            title_text = f"Stock Symbol{extra_info} (confidence: {score:.2f})"
        elif label == 'HYP':
            color = label_colors.get('HYP')
            base_type = entity.get('base_type', '')
            title_text = f"Hypernym (general term for {base_type})"
        else:
            color = label_colors.get(f"B-{label}", "#CCCCCC")
            title_text = f"{label} (confidence: {score:.2f})"

        # quote=True because title_text lands inside a double-quoted attribute.
        title_text = _html.escape(title_text, quote=True)

        # Create highlighted span
        highlighted = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; margin: 0 2px; display: inline-block;" title="{title_text}">{word} <sup style="font-size: 0.7em; font-weight: bold;">[{label}]</sup></span>'

        result = result[:start] + highlighted + result[end:]

    return result

def perform_ner(text):
    """Run the full NER stack on *text* and render the results.

    Combines three detectors — the ParsBERT pipeline, the stock-symbol
    dictionary, and the hypernym patterns — merges their spans, and
    returns a pair ``(highlighted_html, entity_table_markdown)``.
    Any exception is caught and rendered as an error message instead of
    propagating to the Gradio UI.
    """
    # Guard: empty/whitespace-only input gets a user-facing prompt, not a crash.
    if not text.strip():
        return "<p style='color: red;'>لطفا متن فارسی وارد کنید (Please enter Persian text)</p>", ""
    
    try:
        # Perform base NER
        entities = ner_pipeline(text)
        
        # Detect stock symbols
        stock_entities = detect_stock_symbols(text)
        
        # Detect hypernyms
        hypernym_entities = detect_hypernyms(text, entities)
        
        # Merge all entities
        all_entities = merge_entities(entities, stock_entities, hypernym_entities)
        
        # Create highlighted version (RTL container for Persian text)
        highlighted_html = f"<div style='direction: rtl; text-align: right; font-size: 18px; line-height: 2; padding: 20px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;'>{highlight_entities(text, all_entities)}</div>"
        
        # Create entities table (markdown, one row per detected entity)
        if all_entities:
            entity_info = "### موجودیت‌های شناسایی شده (Detected Entities):\n\n"
            entity_info += "| کلمه (Word) | نوع (Type) | اطمینان (Confidence) | اطلاعات اضافی (Additional Info) |\n"
            entity_info += "|------------|-----------|---------------------|----------------------------------|\n"
            for ent in all_entities:
                label_fa = label_names.get(ent['entity_group'], ent['entity_group'])
                extra_info = ""
                
                # Extra column: company name for stocks, base type for hypernyms.
                if ent['entity_group'] == 'STK' and 'full_name' in ent:
                    extra_info = ent['full_name']
                elif ent['entity_group'] == 'HYP':
                    extra_info = f"Hypernym of {ent.get('base_type', '')}"
                
                entity_info += f"| {ent['word']} | {label_fa} | {ent['score']:.2%} | {extra_info} |\n"
        else:
            entity_info = "هیچ موجودیتی شناسایی نشد (No entities detected)"
        
        return highlighted_html, entity_info
        
    except Exception as e:
        # Surface the error to the UI rather than letting Gradio show a stack trace.
        return f"<p style='color: red;'>خطا (Error): {str(e)}</p>", ""

# Save stock symbols to CSV function
def save_symbols_to_csv(output_path='symbols.csv', symbols=None):
    """Write a symbol table to a CSV file with a header row.

    Backward-compatible generalization: an explicit mapping can now be
    passed; calling with no ``symbols`` argument behaves exactly as
    before (writes the module-level STOCK_SYMBOLS table).

    Args:
        output_path: Destination CSV path (default ``'symbols.csv'``).
        symbols: Optional mapping symbol -> company name; defaults to
            the module-level STOCK_SYMBOLS.
    """
    if symbols is None:
        symbols = STOCK_SYMBOLS
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['symbol', 'company_name'])
        for symbol, name in symbols.items():
            writer.writerow([symbol, name])
    print(f"Saved {len(symbols)} symbols to {output_path}")

# Example texts shown in the Gradio UI: a mix of person/location/date
# sentences plus stock-ticker sentences (e.g. فولاد, شپنا, وبملت) to
# exercise the custom STK and HYP detectors.
examples = [
    ["باراک اوباما در هاوایی متولد شد و در شیکاگو زندگی می‌کرد."],
    ["شرکت گوگل در کالیفرنیا واقع شده است."],
    ["رضا در تهران در تاریخ ۱۵ خرداد ۱۳۸۰ متولد شد."],
    ["دانشگاه تهران یکی از قدیمی‌ترین دانشگاه‌های ایران است."],
    ["علی و حسین به همراه مریم به مشهد سفر کردند."],
    ["سهام فولاد و خودرو امروز رشد خوبی داشتند و شپنا هم صعودی بود."],
    ["بانک ملت و وتجارت در بازار بورس فعال هستند."],
    ["آقای احمدی مدیرعامل شرکت پتروشیمی است."],
    ["وبملت و فملی امروز در صف خرید قرار گرفتند."],
]

# Create Gradio interface
with gr.Blocks(title="Persian NER - شناسایی موجودیت‌های نامدار فارسی", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # 🇮🇷 Persian Named Entity Recognition + Stock Symbols
    # شناسایی موجودیت‌های نامدار فارسی + نمادهای بورس
    
    این سیستم موجودیت‌های نامدار مانند اسامی اشخاص، سازمان‌ها، مکان‌ها، تاریخ‌ها، **نمادهای بورس** و **واژه‌های عمومی (Hypernyms)** را در متن فارسی شناسایی می‌کند.
    
    This system identifies named entities including person names, organizations, locations, dates, **stock symbols**, and **hypernyms** in Persian text.
    
    **Model:** ParsBERT-NER (HooshvareLab) + Custom Stock Symbol Detection  
    **Stock Symbols Loaded:** {len(STOCK_SYMBOLS)} symbols from Tehran Stock Exchange (TSE)  
    **Running on:** CPU (may be slow for long texts)
    
    ---
    
    ### 📊 APIs for Updating Stock Symbols:
    
    **Recommended Python Libraries:**
    1. **tsetmc-api** - `pip install tsetmc-api` - Direct access to TSETMC data
    2. **tehran-stocks** - `pip install tehran-stocks` - Full stock price history with ORM
    3. **tse-dataloader** - Data extraction from Tehran Stock Exchange
    
    **Example Usage:**
```python
    # Using tsetmc-api
    from tsetmc_api import market_watch
    stocks = market_watch.get_market_watch()
    
    # Using tehran-stocks
    from tehran_stocks import Stocks
    all_stocks = Stocks.query.all()
```
    
    **Official TSE Website:** https://tse.ir  
    **TSETMC Data Portal:** http://www.tsetmc.com
    """)
    
    # Left column: input + button; right column: highlighted HTML + entity table.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="متن فارسی خود را وارد کنید (Enter Persian Text)",
                placeholder="مثال: سهام فولاد و خودرو امروز رشد کردند...",
                lines=5,
                rtl=True
            )
            submit_btn = gr.Button("🔍 تحلیل متن (Analyze Text)", variant="primary")
            
        with gr.Column():
            output_html = gr.HTML(label="متن با موجودیت‌های برجسته (Text with Highlighted Entities)")
            output_entities = gr.Markdown(label="لیست موجودیت‌ها (Entity List)")
    
    gr.Examples(
        examples=examples,
        inputs=input_text,
        label="مثال‌ها (Examples)"
    )
    
    # Legend. BUGFIX: this must be an f-string — the original used a plain
    # string, so "{len(STOCK_SYMBOLS)}" rendered literally in the UI.
    gr.Markdown(f"""
    ### راهنمای رنگ‌ها (Color Guide):
    - 🔴 **PER (شخص)**: اسامی اشخاص / Person names
    - 🔵 **ORG (سازمان)**: نام سازمان‌ها / Organizations  
    - 🟢 **LOC (مکان)**: نام مکان‌ها / Locations
    - 🟠 **DAT (تاریخ)**: تاریخ‌ها / Dates
    - 🟣 **TIM (زمان)**: زمان‌ها / Times
    - 🟡 **MON (پول)**: مقادیر پولی / Money
    - 🔷 **PCT (درصد)**: درصدها / Percentages
    - 💗 **STK (نماد بورس)**: نمادهای بورس تهران / Tehran Stock Exchange symbols
    - ⚫ **HYP (واژه عمومی)**: واژه‌های عمومی / Hypernyms (general terms)
    
    ---
    
    ### 📝 تعداد نمادهای بورس: {len(STOCK_SYMBOLS)} نماد
    
    *برای به‌روزرسانی نمادها، فایل CSV را جایگزین کنید یا از API استفاده کنید.*
    """)
    
    # Event handlers: button click and Enter key both run the same analysis.
    submit_btn.click(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )
    
    input_text.submit(
        fn=perform_ner,
        inputs=input_text,
        outputs=[output_html, output_entities]
    )

# Launch the app
if __name__ == "__main__":
    # Entry point: start the Gradio web server (blocking call).
    demo.launch()