#!/usr/bin/env python3 """ Production Benchmark Suite for FinEE ===================================== Comprehensive evaluation with: - Precision/Recall/F1 per field - Bank-specific performance - Cross-validation - Failure case analysis - Comparison with baselines Author: Ranjit Behera """ import json import random from pathlib import Path from typing import List, Dict, Tuple, Optional from dataclasses import dataclass, field from collections import defaultdict import time @dataclass class FieldMetrics: """Metrics for a single field.""" tp: int = 0 # True positives fp: int = 0 # False positives fn: int = 0 # False negatives @property def precision(self) -> float: if self.tp + self.fp == 0: return 0.0 return self.tp / (self.tp + self.fp) @property def recall(self) -> float: if self.tp + self.fn == 0: return 0.0 return self.tp / (self.tp + self.fn) @property def f1(self) -> float: if self.precision + self.recall == 0: return 0.0 return 2 * (self.precision * self.recall) / (self.precision + self.recall) @dataclass class BenchmarkResult: """Complete benchmark results.""" field_metrics: Dict[str, FieldMetrics] = field(default_factory=dict) bank_metrics: Dict[str, Dict[str, FieldMetrics]] = field(default_factory=dict) failures: List[Dict] = field(default_factory=list) latency_ms: List[float] = field(default_factory=list) total_samples: int = 0 @property def overall_f1(self) -> float: if not self.field_metrics: return 0.0 return sum(m.f1 for m in self.field_metrics.values()) / len(self.field_metrics) @property def avg_latency_ms(self) -> float: if not self.latency_ms: return 0.0 return sum(self.latency_ms) / len(self.latency_ms) class ProductionBenchmark: """ Production-grade benchmark for financial entity extraction. """ FIELDS = ["amount", "type", "bank", "merchant", "category", "reference", "vpa"] def __init__(self, test_data_path: Optional[Path] = None): self.test_data_path = test_data_path self.extractor = None self.results = BenchmarkResult() def load_extractor(self, use_llm: bool = False): """Load the extractor.""" try: from finee import FinancialExtractor self.extractor = FinancialExtractor(use_llm=use_llm) except ImportError: from finee import extract self.extractor = type('Extractor', (), {'extract': lambda self, x: extract(x)})() def load_test_data(self, path: Optional[Path] = None) -> List[Dict]: """Load test dataset.""" path = path or self.test_data_path if path and path.exists(): records = [] with open(path) as f: for line in f: try: records.append(json.loads(line)) except: continue return records return [] def _normalize_value(self, value, field: str): """Normalize values for comparison.""" if value is None: return None if field == "amount": if isinstance(value, (int, float)): return round(float(value), 2) if isinstance(value, str): try: return round(float(value.replace(',', '')), 2) except: return None if field == "type": v = str(value).lower().strip() if v in ["debit", "dr", "debited"]: return "debit" if v in ["credit", "cr", "credited"]: return "credit" return v if isinstance(value, str): return value.lower().strip() return value def _compare_values(self, predicted, expected, field: str) -> Tuple[bool, str]: """Compare predicted vs expected values.""" pred_norm = self._normalize_value(predicted, field) exp_norm = self._normalize_value(expected, field) if pred_norm is None and exp_norm is None: return True, "both_none" if pred_norm is None and exp_norm is not None: return False, "false_negative" if pred_norm is not None and exp_norm is None: return False, "false_positive" if pred_norm == exp_norm: return True, "true_positive" # Partial match for strings if field in ["merchant", "bank"]: if str(pred_norm) in str(exp_norm) or str(exp_norm) in str(pred_norm): return True, "partial_match" return False, "mismatch" def evaluate_single(self, text: str, expected: Dict) -> Tuple[Dict, Dict, List[str]]: """ Evaluate a single example. Returns: (predicted, expected, error_fields) """ start = time.perf_counter() # Extract if hasattr(self.extractor, 'extract'): predicted = self.extractor.extract(text) else: predicted = self.extractor(text) # Convert to dict if needed if hasattr(predicted, 'to_dict'): predicted = predicted.to_dict() elif hasattr(predicted, '__dict__'): predicted = {k: v for k, v in predicted.__dict__.items() if not k.startswith('_')} latency = (time.perf_counter() - start) * 1000 self.results.latency_ms.append(latency) # Compare each field errors = [] for field in self.FIELDS: pred_val = predicted.get(field) exp_val = expected.get(field) match, reason = self._compare_values(pred_val, exp_val, field) if field not in self.results.field_metrics: self.results.field_metrics[field] = FieldMetrics() metrics = self.results.field_metrics[field] if reason == "true_positive" or reason == "partial_match": metrics.tp += 1 elif reason == "false_negative": metrics.fn += 1 errors.append(f"{field}: expected '{exp_val}', got None") elif reason == "false_positive": metrics.fp += 1 errors.append(f"{field}: expected None, got '{pred_val}'") elif reason == "mismatch": metrics.fn += 1 metrics.fp += 1 errors.append(f"{field}: expected '{exp_val}', got '{pred_val}'") return predicted, expected, errors def run( self, test_data: Optional[List[Dict]] = None, max_samples: int = 1000, include_failures: bool = True ) -> BenchmarkResult: """ Run the full benchmark. Args: test_data: List of test samples max_samples: Maximum samples to evaluate include_failures: Whether to collect failure cases Returns: BenchmarkResult with all metrics """ if self.extractor is None: self.load_extractor() if test_data is None: test_data = self.load_test_data() if not test_data: print("āš ļø No test data provided") return self.results # Sample if too many if len(test_data) > max_samples: test_data = random.sample(test_data, max_samples) self.results = BenchmarkResult() self.results.total_samples = len(test_data) print(f"Running benchmark on {len(test_data)} samples...") for i, record in enumerate(test_data): text = record.get("input", record.get("text", "")) expected = record.get("output", record.get("ground_truth", {})) if isinstance(expected, str): try: expected = json.loads(expected) except: continue predicted, expected, errors = self.evaluate_single(text, expected) # Track failures if include_failures and errors: self.results.failures.append({ "text": text[:100], "expected": expected, "predicted": predicted, "errors": errors, }) # Progress if (i + 1) % 100 == 0: print(f" Processed {i + 1}/{len(test_data)}...") return self.results def print_report(self): """Print a detailed report.""" print("\n" + "=" * 70) print("PRODUCTION BENCHMARK REPORT") print("=" * 70) print(f"\nšŸ“Š Overall Statistics:") print(f" Total Samples: {self.results.total_samples:,}") print(f" Overall F1: {self.results.overall_f1:.1%}") print(f" Avg Latency: {self.results.avg_latency_ms:.2f}ms") print(f"\nšŸ“ˆ Per-Field Metrics:") print(f" {'Field':<12} {'Precision':>10} {'Recall':>10} {'F1':>10}") print(" " + "-" * 42) for field in self.FIELDS: if field in self.results.field_metrics: m = self.results.field_metrics[field] status = "āœ…" if m.f1 >= 0.90 else "āš ļø" if m.f1 >= 0.70 else "āŒ" print(f" {field:<12} {m.precision:>9.1%} {m.recall:>9.1%} {m.f1:>9.1%} {status}") print(f"\nāŒ Failure Cases: {len(self.results.failures)}") if self.results.failures: print("\n Sample Failures:") for failure in self.results.failures[:5]: print(f"\n Text: {failure['text'][:60]}...") for err in failure['errors'][:3]: print(f" • {err}") # Grade f1 = self.results.overall_f1 if f1 >= 0.95: grade = "A+ (Production Ready)" elif f1 >= 0.90: grade = "A (Near Production)" elif f1 >= 0.80: grade = "B (Good)" elif f1 >= 0.70: grade = "C (Needs Improvement)" else: grade = "D (Significant Work Needed)" print(f"\nšŸ† Grade: {grade}") print("=" * 70) def export_results(self, path: Path): """Export results to JSON.""" data = { "overall_f1": self.results.overall_f1, "avg_latency_ms": self.results.avg_latency_ms, "total_samples": self.results.total_samples, "field_metrics": { field: { "precision": m.precision, "recall": m.recall, "f1": m.f1, } for field, m in self.results.field_metrics.items() }, "failure_count": len(self.results.failures), "failures": self.results.failures[:20], } with open(path, 'w') as f: json.dump(data, f, indent=2) print(f"Results exported to {path}") def create_held_out_test_set( data_path: Path, output_path: Path, held_out_banks: List[str] = ["PNB", "BOB", "CANARA"], num_samples: int = 1000 ): """ Create a held-out test set with banks NOT in training. This is critical for proper evaluation. """ print(f"Creating held-out test set with banks: {held_out_banks}") held_out = [] with open(data_path) as f: for line in f: try: record = json.loads(line) text = record.get("input", record.get("text", "")).upper() # Check if contains held-out bank for bank in held_out_banks: if bank in text: held_out.append(record) break if len(held_out) >= num_samples: break except: continue # Save output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, 'w') as f: for record in held_out: f.write(json.dumps(record) + '\n') print(f"Created held-out test set with {len(held_out)} samples at {output_path}") return held_out # ============================================================================ # MAIN # ============================================================================ if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Run production benchmark") parser.add_argument("--test-file", help="Path to test JSONL file") parser.add_argument("--max-samples", type=int, default=1000) parser.add_argument("--export", help="Export results to JSON") parser.add_argument("--create-held-out", action="store_true", help="Create held-out test set") args = parser.parse_args() if args.create_held_out: create_held_out_test_set( Path("data/instruction/test.jsonl"), Path("data/benchmark/held_out_test.jsonl"), ) else: benchmark = ProductionBenchmark() if args.test_file: test_data = benchmark.load_test_data(Path(args.test_file)) else: # Use default test set test_data = benchmark.load_test_data(Path("data/instruction/test.jsonl")) benchmark.run(test_data, max_samples=args.max_samples) benchmark.print_report() if args.export: benchmark.export_results(Path(args.export))