leilaghomashchi commited on
Commit
646303e
·
verified ·
1 Parent(s): de2a64d

Upload fixed_ner_evaluator.py

Browse files
Files changed (1) hide show
  1. fixed_ner_evaluator.py +353 -0
fixed_ner_evaluator.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fixed NER Anonymization Evaluator
3
+ ارزیاب درست و دقیق - بدون مشکلات tokenization
4
+
5
+ این نسخه مستقیماً entities را مقایسه می‌کند بدون IOB2
6
+ """
7
+
8
+ import pandas as pd
9
+ import re
10
+ from typing import Dict, List, Set, Tuple
11
+ import gradio as gr
12
+ from datetime import datetime
13
+ import tempfile
14
+ import os
15
+
16
+
17
+ class FixedNEREvaluator:
18
+ """ارزیاب درست - مقایسه مستقیم entities"""
19
+
20
+ def __init__(self):
21
+ self.results_df = None
22
+
23
+ # الگوهای regex برای تشخیص entities
24
+ # توجه: این الگوها باید با فرمت واقعی شما match کنند
25
+ self.entity_patterns = [
26
+ # فرمت استاندارد: type-number
27
+ r'\b(COMPANY|company|PERSON|person|AMOUNT|amount|PERCENT|percent|GROUP|group|STOCK|stock)-(\d+)\b',
28
+ # فرمت با underscore: TYPE_NUMBER
29
+ r'\b(COMPANY|PERSON|AMOUNT|PERCENT|GROUP|STOCK)_(\d+)\b',
30
+ # فرمت کامل: TYPE_NUMBER_SUFFIX
31
+ r'\b(COMPANY|PERSON|AMOUNT|PERCENT|GROUP|STOCK)_(\d+)_[A-Z]+\b',
32
+ # فرمت STOCK خاص
33
+ r'\bSTOCK_SYMBOL_(\d+)(?:_[A-Z]+)?\b',
34
+ ]
35
+
36
+ def extract_entities(self, text: str) -> Set[Tuple[str, str]]:
37
+ """
38
+ استخراج entities از متن
39
+
40
+ Returns:
41
+ Set of (entity_type, entity_id) tuples
42
+ مثال: {('COMPANY', '01'), ('PERSON', '02')}
43
+ """
44
+ if pd.isna(text) or not isinstance(text, str):
45
+ return set()
46
+
47
+ entities = set()
48
+
49
+ for pattern in self.entity_patterns:
50
+ matches = re.finditer(pattern, text, re.IGNORECASE)
51
+ for match in matches:
52
+ groups = match.groups()
53
+ if len(groups) >= 2:
54
+ entity_type = groups[0].upper()
55
+ entity_id = groups[1]
56
+ # نرمال‌سازی: همه به فرمت TYPE-ID
57
+ entities.add((entity_type, entity_id))
58
+
59
+ return entities
60
+
61
+ def calculate_metrics(self, reference_entities: Set, predicted_entities: Set) -> Dict:
62
+ """
63
+ محاسبه metrics بر اساس مجموعه entities
64
+
65
+ Args:
66
+ reference_entities: مجموعه entities مرجع
67
+ predicted_entities: مجموعه entities پیش‌بینی شده
68
+
69
+ Returns:
70
+ دیکشنری شامل TP, FP, FN, Precision, Recall, F1
71
+ """
72
+ # محاسبه TP, FP, FN
73
+ tp = len(reference_entities & predicted_entities) # اشتراک
74
+ fp = len(predicted_entities - reference_entities) # پیش‌بینی اضافی
75
+ fn = len(reference_entities - predicted_entities) # فراموش شده
76
+
77
+ # محاسبه Precision, Recall, F1
78
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
79
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
80
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
81
+
82
+ # اگر هر دو خالی باشند = تطابق کامل
83
+ if len(reference_entities) == 0 and len(predicted_entities) == 0:
84
+ precision = recall = f1 = 1.0
85
+
86
+ return {
87
+ 'tp': tp,
88
+ 'fp': fp,
89
+ 'fn': fn,
90
+ 'precision': round(precision, 4),
91
+ 'recall': round(recall, 4),
92
+ 'f1': round(f1, 4)
93
+ }
94
+
95
+ def evaluate_single_row(self, reference_text: str, predicted_text: str) -> Dict:
96
+ """
97
+ ارزیابی یک سطر
98
+
99
+ Returns:
100
+ دیکشنری شامل metrics + entities برای debug
101
+ """
102
+ ref_entities = self.extract_entities(reference_text)
103
+ pred_entities = self.extract_entities(predicted_text)
104
+
105
+ metrics = self.calculate_metrics(ref_entities, pred_entities)
106
+
107
+ # اضافه کردن entities برای debug
108
+ metrics['ref_entities'] = sorted(list(ref_entities))
109
+ metrics['pred_entities'] = sorted(list(pred_entities))
110
+ metrics['matched'] = sorted(list(ref_entities & pred_entities))
111
+ metrics['missed'] = sorted(list(ref_entities - pred_entities))
112
+ metrics['extra'] = sorted(list(pred_entities - ref_entities))
113
+
114
+ return metrics
115
+
116
+ def evaluate_dataset(self, file_path: str) -> Tuple[bool, str, pd.DataFrame]:
117
+ """ارزیابی کل دیتاست"""
118
+ try:
119
+ print(f"📂 در حال خواندن فایل: {file_path}")
120
+ df = pd.read_csv(file_path, encoding='utf-8-sig')
121
+ print(f"✅ فایل خوانده شد: {len(df)} سطر")
122
+ print(f"📋 ستون‌ها: {list(df.columns)}")
123
+
124
+ # تشخیص ستون‌ها
125
+ if 'Reference_text' in df.columns and 'anonymized_text' in df.columns:
126
+ reference_col = 'Reference_text'
127
+ predicted_col = 'anonymized_text'
128
+ elif 'original_text' in df.columns and 'anonymized_text' in df.columns:
129
+ reference_col = 'original_text'
130
+ predicted_col = 'anonymized_text'
131
+ else:
132
+ return (
133
+ False,
134
+ f"❌ ستون‌های مورد نیاز یافت نشد!\n\nستون‌های موجود: {list(df.columns)}",
135
+ pd.DataFrame()
136
+ )
137
+
138
+ print(f"🔍 شروع ارزیابی...")
139
+
140
+ # ارزیابی هر سطر
141
+ results = []
142
+ for index, row in df.iterrows():
143
+ if (index + 1) % 10 == 0:
144
+ print(f" پردازش سطر {index + 1}/{len(df)}...")
145
+
146
+ metrics = self.evaluate_single_row(
147
+ str(row[reference_col]),
148
+ str(row[predicted_col])
149
+ )
150
+ results.append(metrics)
151
+
152
+ print(f"✅ ارزیابی کامل شد!")
153
+
154
+ # ایجاد DataFrame
155
+ results_df = pd.DataFrame(results)
156
+
157
+ # اضافه کردن ستون‌های اصلی
158
+ for col in df.columns:
159
+ results_df[col] = df[col].values
160
+
161
+ # ترتیب ستون‌ها
162
+ metric_cols = ['precision', 'recall', 'f1', 'tp', 'fp', 'fn']
163
+ debug_cols = ['ref_entities', 'pred_entities', 'matched', 'missed', 'extra']
164
+ main_cols = [col for col in df.columns if col in results_df.columns]
165
+
166
+ results_df = results_df[metric_cols + debug_cols + main_cols]
167
+
168
+ self.results_df = results_df
169
+
170
+ # محاسبه آمار کلی
171
+ avg_precision = results_df['precision'].mean()
172
+ avg_recall = results_df['recall'].mean()
173
+ avg_f1 = results_df['f1'].mean()
174
+
175
+ total_tp = results_df['tp'].sum()
176
+ total_fp = results_df['fp'].sum()
177
+ total_fn = results_df['fn'].sum()
178
+
179
+ # F1 کلی (macro-average)
180
+ macro_f1 = avg_f1
181
+
182
+ # F1 کلی (micro-average) - بر اساس مجموع TP/FP/FN
183
+ micro_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
184
+ micro_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
185
+ micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall) if (micro_precision + micro_recall) > 0 else 0
186
+
187
+ high_f1 = len(results_df[results_df['f1'] >= 0.9])
188
+ mid_f1 = len(results_df[results_df['f1'] >= 0.7])
189
+ low_f1 = len(results_df[results_df['f1'] < 0.5])
190
+
191
+ status = f"""✅ ارزیابی با موفقیت انجام شد!
192
+
193
+ 📊 **نتایج کلی (Direct Entity Matching):**
194
+ • Macro-Average F1: {macro_f1:.4f}
195
+ • Micro-Average F1: {micro_f1:.4f}
196
+ • میانگین Precision: {avg_precision:.4f}
197
+ • میانگین Recall: {avg_recall:.4f}
198
+
199
+ 📈 **آمار کلی:**
200
+ • کل True Positives: {total_tp}
201
+ • کل False Positives: {total_fp}
202
+ • کل False Negatives: {total_fn}
203
+ • تعداد سطرها: {len(df)}
204
+
205
+ 📊 **توزیع عملکرد:**
206
+ • F1 ≥ 0.9 (عالی): {high_f1} سطر ({high_f1/len(df)*100:.1f}%)
207
+ • F1 ≥ 0.7 (خوب): {mid_f1} سطر ({mid_f1/len(df)*100:.1f}%)
208
+ • F1 < 0.5 (ضعیف): {low_f1} سطر ({low_f1/len(df)*100:.1f}%)
209
+
210
+ 🔬 **مقایسه:**
211
+ • مرجع (انسانی): {reference_col}
212
+ • پیش‌بینی (LLM): {predicted_col}
213
+
214
+ 💡 **تفاوت با seqeval:**
215
+ این نسخه مستقیماً entities را مقایسه می‌کند بدون مشکلات tokenization
216
+ """
217
+
218
+ return True, status, results_df
219
+
220
+ except Exception as e:
221
+ import traceback
222
+ error_details = traceback.format_exc()
223
+ return False, f"❌ خطا در پردازش:\n\n{str(e)}\n\n{error_details[:500]}", pd.DataFrame()
224
+
225
+ def create_downloadable_csv(self) -> str:
226
+ """ایجاد فایل CSV برای دانلود"""
227
+ if self.results_df is None or self.results_df.empty:
228
+ return None
229
+
230
+ try:
231
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
232
+ temp_filename = f"fixed_evaluation_results_{timestamp}.csv"
233
+ temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
234
+
235
+ # تبدیل لیست‌ها به string برای CSV
236
+ df_to_save = self.results_df.copy()
237
+ for col in ['ref_entities', 'pred_entities', 'matched', 'missed', 'extra']:
238
+ if col in df_to_save.columns:
239
+ df_to_save[col] = df_to_save[col].apply(str)
240
+
241
+ df_to_save.to_csv(temp_path, index=False, encoding='utf-8-sig')
242
+
243
+ return temp_path
244
+ except Exception as e:
245
+ print(f"❌ خطا در ایجاد CSV: {str(e)}")
246
+ return None
247
+
248
+
249
+ def create_interface():
250
+ """ایجاد رابط کاربری Gradio"""
251
+
252
+ evaluator = FixedNEREvaluator()
253
+
254
+ with gr.Blocks(title="Fixed NER Evaluator", theme=gr.themes.Soft()) as demo:
255
+
256
+ gr.Markdown("""
257
+ # 🎯 ارزیاب درست و دقیق NER
258
+ ## Fixed NER Anonymization Evaluator
259
+
260
+ ### ✅ این نسخه بدون مشکلات tokenization کار می‌کند
261
+ """)
262
+
263
+ with gr.Row():
264
+ with gr.Column(scale=1):
265
+ gr.Markdown("### 📂 بارگذاری فایل")
266
+
267
+ file_input = gr.File(
268
+ label="فایل CSV (با ستون‌های Reference_text و anonymized_text)",
269
+ file_types=[".csv"]
270
+ )
271
+
272
+ evaluate_btn = gr.Button("🚀 شروع ارزیابی", variant="primary", size="lg")
273
+ download_btn = gr.Button("💾 دانلود نتایج CSV", visible=False, variant="secondary")
274
+
275
+ with gr.Column(scale=2):
276
+ status_output = gr.Markdown("آماده دریافت فایل...")
277
+
278
+ results_table = gr.Dataframe(
279
+ label="نتایج تفصیلی (10 سطر اول)",
280
+ visible=False,
281
+ wrap=True
282
+ )
283
+
284
+ download_file = gr.File(visible=False)
285
+
286
+ with gr.Accordion("📖 راهنمای استفاده", open=False):
287
+ gr.Markdown("""
288
+ ## نحوه استفاده:
289
+
290
+ 1. فایل CSV خود را آپلود کنید
291
+ 2. فایل باید شامل این ستون‌ها باشد:
292
+ - `Reference_text` (مرجع انسانی)
293
+ - `anonymized_text` (پیش‌بینی LLM)
294
+ 3. روی دکمه "شروع ارزیابی" کلیک کنید
295
+ 4. نتایج را مشاهده و دانلود کنید
296
+
297
+ ## تفاوت با نسخه قبلی:
298
+
299
+ - ✅ مستقیماً entities را مقایسه می‌کند
300
+ - ✅ بدون مشکلات tokenization
301
+ - ✅ برای فارسی کاملاً دقیق
302
+ - ✅ شامل اطلاعات debug (matched, missed, extra entities)
303
+ """)
304
+
305
+ def evaluate_file(file):
306
+ if file is None:
307
+ return (
308
+ "❌ لطفاً فایل CSV را بارگذاری کنید",
309
+ gr.Dataframe(visible=False),
310
+ gr.Button(visible=False),
311
+ gr.File(visible=False)
312
+ )
313
+
314
+ success, message, df = evaluator.evaluate_dataset(file)
315
+
316
+ if not success:
317
+ return (
318
+ f"❌ {message}",
319
+ gr.Dataframe(visible=False),
320
+ gr.Button(visible=False),
321
+ gr.File(visible=False)
322
+ )
323
+
324
+ return (
325
+ message,
326
+ gr.Dataframe(value=df.head(10), visible=True),
327
+ gr.Button(visible=True),
328
+ gr.File(visible=False)
329
+ )
330
+
331
+ def download_results():
332
+ csv_path = evaluator.create_downloadable_csv()
333
+ if csv_path and os.path.exists(csv_path):
334
+ return "✅ فایل نتایج آماده دانلود است", gr.File(value=csv_path, visible=True)
335
+ return "❌ خطا در ایجاد فایل", gr.File(visible=False)
336
+
337
+ evaluate_btn.click(
338
+ fn=evaluate_file,
339
+ inputs=[file_input],
340
+ outputs=[status_output, results_table, download_btn, download_file]
341
+ )
342
+
343
+ download_btn.click(
344
+ fn=download_results,
345
+ outputs=[status_output, download_file]
346
+ )
347
+
348
+ return demo
349
+
350
+
351
+ if __name__ == "__main__":
352
+ demo = create_interface()
353
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)