Ranjit0034 commited on
Commit
1cba4da
·
verified ·
1 Parent(s): f28b6fd

Upload scripts/benchmark.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/benchmark.py +421 -0
scripts/benchmark.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Production Benchmark Suite for FinEE
4
+ =====================================
5
+
6
+ Comprehensive evaluation with:
7
+ - Precision/Recall/F1 per field
8
+ - Bank-specific performance
9
+ - Cross-validation
10
+ - Failure case analysis
11
+ - Comparison with baselines
12
+
13
+ Author: Ranjit Behera
14
+ """
15
+
16
+ import json
17
+ import random
18
+ from pathlib import Path
19
+ from typing import List, Dict, Tuple, Optional
20
+ from dataclasses import dataclass, field
21
+ from collections import defaultdict
22
+ import time
23
+
24
+
25
+ @dataclass
26
+ class FieldMetrics:
27
+ """Metrics for a single field."""
28
+ tp: int = 0 # True positives
29
+ fp: int = 0 # False positives
30
+ fn: int = 0 # False negatives
31
+
32
+ @property
33
+ def precision(self) -> float:
34
+ if self.tp + self.fp == 0:
35
+ return 0.0
36
+ return self.tp / (self.tp + self.fp)
37
+
38
+ @property
39
+ def recall(self) -> float:
40
+ if self.tp + self.fn == 0:
41
+ return 0.0
42
+ return self.tp / (self.tp + self.fn)
43
+
44
+ @property
45
+ def f1(self) -> float:
46
+ if self.precision + self.recall == 0:
47
+ return 0.0
48
+ return 2 * (self.precision * self.recall) / (self.precision + self.recall)
49
+
50
+
51
+ @dataclass
52
+ class BenchmarkResult:
53
+ """Complete benchmark results."""
54
+ field_metrics: Dict[str, FieldMetrics] = field(default_factory=dict)
55
+ bank_metrics: Dict[str, Dict[str, FieldMetrics]] = field(default_factory=dict)
56
+ failures: List[Dict] = field(default_factory=list)
57
+ latency_ms: List[float] = field(default_factory=list)
58
+ total_samples: int = 0
59
+
60
+ @property
61
+ def overall_f1(self) -> float:
62
+ if not self.field_metrics:
63
+ return 0.0
64
+ return sum(m.f1 for m in self.field_metrics.values()) / len(self.field_metrics)
65
+
66
+ @property
67
+ def avg_latency_ms(self) -> float:
68
+ if not self.latency_ms:
69
+ return 0.0
70
+ return sum(self.latency_ms) / len(self.latency_ms)
71
+
72
+
73
+ class ProductionBenchmark:
74
+ """
75
+ Production-grade benchmark for financial entity extraction.
76
+ """
77
+
78
+ FIELDS = ["amount", "type", "bank", "merchant", "category", "reference", "vpa"]
79
+
80
+ def __init__(self, test_data_path: Optional[Path] = None):
81
+ self.test_data_path = test_data_path
82
+ self.extractor = None
83
+ self.results = BenchmarkResult()
84
+
85
+ def load_extractor(self, use_llm: bool = False):
86
+ """Load the extractor."""
87
+ try:
88
+ from finee import FinancialExtractor
89
+ self.extractor = FinancialExtractor(use_llm=use_llm)
90
+ except ImportError:
91
+ from finee import extract
92
+ self.extractor = type('Extractor', (), {'extract': lambda self, x: extract(x)})()
93
+
94
+ def load_test_data(self, path: Optional[Path] = None) -> List[Dict]:
95
+ """Load test dataset."""
96
+ path = path or self.test_data_path
97
+
98
+ if path and path.exists():
99
+ records = []
100
+ with open(path) as f:
101
+ for line in f:
102
+ try:
103
+ records.append(json.loads(line))
104
+ except:
105
+ continue
106
+ return records
107
+
108
+ return []
109
+
110
+ def _normalize_value(self, value, field: str):
111
+ """Normalize values for comparison."""
112
+ if value is None:
113
+ return None
114
+
115
+ if field == "amount":
116
+ if isinstance(value, (int, float)):
117
+ return round(float(value), 2)
118
+ if isinstance(value, str):
119
+ try:
120
+ return round(float(value.replace(',', '')), 2)
121
+ except:
122
+ return None
123
+
124
+ if field == "type":
125
+ v = str(value).lower().strip()
126
+ if v in ["debit", "dr", "debited"]:
127
+ return "debit"
128
+ if v in ["credit", "cr", "credited"]:
129
+ return "credit"
130
+ return v
131
+
132
+ if isinstance(value, str):
133
+ return value.lower().strip()
134
+
135
+ return value
136
+
137
+ def _compare_values(self, predicted, expected, field: str) -> Tuple[bool, str]:
138
+ """Compare predicted vs expected values."""
139
+ pred_norm = self._normalize_value(predicted, field)
140
+ exp_norm = self._normalize_value(expected, field)
141
+
142
+ if pred_norm is None and exp_norm is None:
143
+ return True, "both_none"
144
+
145
+ if pred_norm is None and exp_norm is not None:
146
+ return False, "false_negative"
147
+
148
+ if pred_norm is not None and exp_norm is None:
149
+ return False, "false_positive"
150
+
151
+ if pred_norm == exp_norm:
152
+ return True, "true_positive"
153
+
154
+ # Partial match for strings
155
+ if field in ["merchant", "bank"]:
156
+ if str(pred_norm) in str(exp_norm) or str(exp_norm) in str(pred_norm):
157
+ return True, "partial_match"
158
+
159
+ return False, "mismatch"
160
+
161
+ def evaluate_single(self, text: str, expected: Dict) -> Tuple[Dict, Dict, List[str]]:
162
+ """
163
+ Evaluate a single example.
164
+
165
+ Returns:
166
+ (predicted, expected, error_fields)
167
+ """
168
+ start = time.perf_counter()
169
+
170
+ # Extract
171
+ if hasattr(self.extractor, 'extract'):
172
+ predicted = self.extractor.extract(text)
173
+ else:
174
+ predicted = self.extractor(text)
175
+
176
+ # Convert to dict if needed
177
+ if hasattr(predicted, 'to_dict'):
178
+ predicted = predicted.to_dict()
179
+ elif hasattr(predicted, '__dict__'):
180
+ predicted = {k: v for k, v in predicted.__dict__.items() if not k.startswith('_')}
181
+
182
+ latency = (time.perf_counter() - start) * 1000
183
+ self.results.latency_ms.append(latency)
184
+
185
+ # Compare each field
186
+ errors = []
187
+ for field in self.FIELDS:
188
+ pred_val = predicted.get(field)
189
+ exp_val = expected.get(field)
190
+
191
+ match, reason = self._compare_values(pred_val, exp_val, field)
192
+
193
+ if field not in self.results.field_metrics:
194
+ self.results.field_metrics[field] = FieldMetrics()
195
+
196
+ metrics = self.results.field_metrics[field]
197
+
198
+ if reason == "true_positive" or reason == "partial_match":
199
+ metrics.tp += 1
200
+ elif reason == "false_negative":
201
+ metrics.fn += 1
202
+ errors.append(f"{field}: expected '{exp_val}', got None")
203
+ elif reason == "false_positive":
204
+ metrics.fp += 1
205
+ errors.append(f"{field}: expected None, got '{pred_val}'")
206
+ elif reason == "mismatch":
207
+ metrics.fn += 1
208
+ metrics.fp += 1
209
+ errors.append(f"{field}: expected '{exp_val}', got '{pred_val}'")
210
+
211
+ return predicted, expected, errors
212
+
213
+ def run(
214
+ self,
215
+ test_data: Optional[List[Dict]] = None,
216
+ max_samples: int = 1000,
217
+ include_failures: bool = True
218
+ ) -> BenchmarkResult:
219
+ """
220
+ Run the full benchmark.
221
+
222
+ Args:
223
+ test_data: List of test samples
224
+ max_samples: Maximum samples to evaluate
225
+ include_failures: Whether to collect failure cases
226
+
227
+ Returns:
228
+ BenchmarkResult with all metrics
229
+ """
230
+ if self.extractor is None:
231
+ self.load_extractor()
232
+
233
+ if test_data is None:
234
+ test_data = self.load_test_data()
235
+
236
+ if not test_data:
237
+ print("⚠️ No test data provided")
238
+ return self.results
239
+
240
+ # Sample if too many
241
+ if len(test_data) > max_samples:
242
+ test_data = random.sample(test_data, max_samples)
243
+
244
+ self.results = BenchmarkResult()
245
+ self.results.total_samples = len(test_data)
246
+
247
+ print(f"Running benchmark on {len(test_data)} samples...")
248
+
249
+ for i, record in enumerate(test_data):
250
+ text = record.get("input", record.get("text", ""))
251
+ expected = record.get("output", record.get("ground_truth", {}))
252
+
253
+ if isinstance(expected, str):
254
+ try:
255
+ expected = json.loads(expected)
256
+ except:
257
+ continue
258
+
259
+ predicted, expected, errors = self.evaluate_single(text, expected)
260
+
261
+ # Track failures
262
+ if include_failures and errors:
263
+ self.results.failures.append({
264
+ "text": text[:100],
265
+ "expected": expected,
266
+ "predicted": predicted,
267
+ "errors": errors,
268
+ })
269
+
270
+ # Progress
271
+ if (i + 1) % 100 == 0:
272
+ print(f" Processed {i + 1}/{len(test_data)}...")
273
+
274
+ return self.results
275
+
276
+ def print_report(self):
277
+ """Print a detailed report."""
278
+ print("\n" + "=" * 70)
279
+ print("PRODUCTION BENCHMARK REPORT")
280
+ print("=" * 70)
281
+
282
+ print(f"\n📊 Overall Statistics:")
283
+ print(f" Total Samples: {self.results.total_samples:,}")
284
+ print(f" Overall F1: {self.results.overall_f1:.1%}")
285
+ print(f" Avg Latency: {self.results.avg_latency_ms:.2f}ms")
286
+
287
+ print(f"\n📈 Per-Field Metrics:")
288
+ print(f" {'Field':<12} {'Precision':>10} {'Recall':>10} {'F1':>10}")
289
+ print(" " + "-" * 42)
290
+
291
+ for field in self.FIELDS:
292
+ if field in self.results.field_metrics:
293
+ m = self.results.field_metrics[field]
294
+ status = "✅" if m.f1 >= 0.90 else "⚠️" if m.f1 >= 0.70 else "❌"
295
+ print(f" {field:<12} {m.precision:>9.1%} {m.recall:>9.1%} {m.f1:>9.1%} {status}")
296
+
297
+ print(f"\n❌ Failure Cases: {len(self.results.failures)}")
298
+
299
+ if self.results.failures:
300
+ print("\n Sample Failures:")
301
+ for failure in self.results.failures[:5]:
302
+ print(f"\n Text: {failure['text'][:60]}...")
303
+ for err in failure['errors'][:3]:
304
+ print(f" • {err}")
305
+
306
+ # Grade
307
+ f1 = self.results.overall_f1
308
+ if f1 >= 0.95:
309
+ grade = "A+ (Production Ready)"
310
+ elif f1 >= 0.90:
311
+ grade = "A (Near Production)"
312
+ elif f1 >= 0.80:
313
+ grade = "B (Good)"
314
+ elif f1 >= 0.70:
315
+ grade = "C (Needs Improvement)"
316
+ else:
317
+ grade = "D (Significant Work Needed)"
318
+
319
+ print(f"\n🏆 Grade: {grade}")
320
+ print("=" * 70)
321
+
322
+ def export_results(self, path: Path):
323
+ """Export results to JSON."""
324
+ data = {
325
+ "overall_f1": self.results.overall_f1,
326
+ "avg_latency_ms": self.results.avg_latency_ms,
327
+ "total_samples": self.results.total_samples,
328
+ "field_metrics": {
329
+ field: {
330
+ "precision": m.precision,
331
+ "recall": m.recall,
332
+ "f1": m.f1,
333
+ }
334
+ for field, m in self.results.field_metrics.items()
335
+ },
336
+ "failure_count": len(self.results.failures),
337
+ "failures": self.results.failures[:20],
338
+ }
339
+
340
+ with open(path, 'w') as f:
341
+ json.dump(data, f, indent=2)
342
+
343
+ print(f"Results exported to {path}")
344
+
345
+
346
+ def create_held_out_test_set(
347
+ data_path: Path,
348
+ output_path: Path,
349
+ held_out_banks: List[str] = ["PNB", "BOB", "CANARA"],
350
+ num_samples: int = 1000
351
+ ):
352
+ """
353
+ Create a held-out test set with banks NOT in training.
354
+
355
+ This is critical for proper evaluation.
356
+ """
357
+ print(f"Creating held-out test set with banks: {held_out_banks}")
358
+
359
+ held_out = []
360
+ with open(data_path) as f:
361
+ for line in f:
362
+ try:
363
+ record = json.loads(line)
364
+ text = record.get("input", record.get("text", "")).upper()
365
+
366
+ # Check if contains held-out bank
367
+ for bank in held_out_banks:
368
+ if bank in text:
369
+ held_out.append(record)
370
+ break
371
+
372
+ if len(held_out) >= num_samples:
373
+ break
374
+ except:
375
+ continue
376
+
377
+ # Save
378
+ output_path.parent.mkdir(parents=True, exist_ok=True)
379
+ with open(output_path, 'w') as f:
380
+ for record in held_out:
381
+ f.write(json.dumps(record) + '\n')
382
+
383
+ print(f"Created held-out test set with {len(held_out)} samples at {output_path}")
384
+ return held_out
385
+
386
+
387
+ # ============================================================================
388
+ # MAIN
389
+ # ============================================================================
390
+
391
+ if __name__ == "__main__":
392
+ import argparse
393
+
394
+ parser = argparse.ArgumentParser(description="Run production benchmark")
395
+ parser.add_argument("--test-file", help="Path to test JSONL file")
396
+ parser.add_argument("--max-samples", type=int, default=1000)
397
+ parser.add_argument("--export", help="Export results to JSON")
398
+ parser.add_argument("--create-held-out", action="store_true",
399
+ help="Create held-out test set")
400
+
401
+ args = parser.parse_args()
402
+
403
+ if args.create_held_out:
404
+ create_held_out_test_set(
405
+ Path("data/instruction/test.jsonl"),
406
+ Path("data/benchmark/held_out_test.jsonl"),
407
+ )
408
+ else:
409
+ benchmark = ProductionBenchmark()
410
+
411
+ if args.test_file:
412
+ test_data = benchmark.load_test_data(Path(args.test_file))
413
+ else:
414
+ # Use default test set
415
+ test_data = benchmark.load_test_data(Path("data/instruction/test.jsonl"))
416
+
417
+ benchmark.run(test_data, max_samples=args.max_samples)
418
+ benchmark.print_report()
419
+
420
+ if args.export:
421
+ benchmark.export_results(Path(args.export))