|
|
|
|
|
""" |
|
|
Production Benchmark Suite for FinEE
=====================================

Comprehensive evaluation with:
- Precision/Recall/F1 per field
- Bank-specific performance
- Cross-validation
- Failure case analysis
- Comparison with baselines

Author: Ranjit Behera
|
|
""" |
|
|
|
|
|
import json |
|
|
import random |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Tuple, Optional |
|
|
from dataclasses import dataclass, field |
|
|
from collections import defaultdict |
|
|
import time |
|
|
|
|
|
|
|
|
@dataclass
class FieldMetrics:
    """Confusion counts for a single field, plus the scores derived from them.

    The three counters are plain integers that callers increment as
    predictions are compared with ground truth. Precision, recall and F1 are
    computed on demand and fall back to 0.0 whenever their denominator is
    zero, so an untouched instance reports 0.0 for everything.
    """

    tp: int = 0  # true positives
    fp: int = 0  # false positives
    fn: int = 0  # false negatives

    @property
    def precision(self) -> float:
        """Return tp / (tp + fp), or 0.0 when tp + fp is zero."""
        denominator = self.tp + self.fp
        if denominator == 0:
            return 0.0
        return self.tp / denominator

    @property
    def recall(self) -> float:
        """Return tp / (tp + fn), or 0.0 when tp + fn is zero."""
        denominator = self.tp + self.fn
        if denominator == 0:
            return 0.0
        return self.tp / denominator

    @property
    def f1(self) -> float:
        """Return the harmonic mean of precision and recall (0.0 if both are 0)."""
        # Read each property once; both are recomputed on every access.
        precision, recall = self.precision, self.recall
        if precision + recall == 0:
            return 0.0
        return 2 * (precision * recall) / (precision + recall)
|
|
|
|
|
|
|
|
@dataclass
class BenchmarkResult:
    """Container for everything a benchmark run produces.

    Every collection uses ``default_factory`` so that separate results never
    share the same dict or list.
    """

    # Per-field confusion counts, keyed by field name.
    field_metrics: Dict[str, "FieldMetrics"] = field(default_factory=dict)
    # Declared for per-bank breakdowns (bank -> field -> metrics).
    # NOTE(review): nothing visible in this file populates it - confirm intent.
    bank_metrics: Dict[str, Dict[str, "FieldMetrics"]] = field(default_factory=dict)
    # One dict per sample that had at least one field error.
    failures: List[Dict] = field(default_factory=list)
    # Per-sample extraction latency in milliseconds.
    latency_ms: List[float] = field(default_factory=list)
    # Sample count recorded by the benchmark run.
    total_samples: int = 0

    @property
    def overall_f1(self) -> float:
        """Return the unweighted mean of the per-field F1 scores (0.0 if none)."""
        if not self.field_metrics:
            return 0.0
        scores = [metrics.f1 for metrics in self.field_metrics.values()]
        return sum(scores) / len(scores)

    @property
    def avg_latency_ms(self) -> float:
        """Return the arithmetic mean of ``latency_ms`` (0.0 if empty)."""
        if not self.latency_ms:
            return 0.0
        return sum(self.latency_ms) / len(self.latency_ms)
|
|
|
|
|
|
|
|
class ProductionBenchmark:
    """
    Production-grade benchmark for financial entity extraction.

    Typical use: construct, optionally ``load_extractor()``, call ``run()``
    with a list of records (or let it load ``test_data_path``), then
    ``print_report()`` and/or ``export_results()``.

    Each record is a dict carrying the message under ``"input"`` (or
    ``"text"``) and the ground truth under ``"output"`` (or
    ``"ground_truth"``), either as a dict or as a JSON-encoded string.
    """

    # Fields scored for every sample, in report order.
    FIELDS = ["amount", "type", "bank", "merchant", "category", "reference", "vpa"]

    def __init__(self, test_data_path: Optional[Path] = None):
        """Remember the default dataset path and start with empty results."""
        self.test_data_path = test_data_path
        self.extractor = None
        self.results = BenchmarkResult()

    def load_extractor(self, use_llm: bool = False):
        """Load the extractor.

        Prefers ``finee.FinancialExtractor``; if that import fails, wraps the
        module-level ``finee.extract`` function in an object exposing the same
        ``extract(text)`` method.
        """
        try:
            from finee import FinancialExtractor
            self.extractor = FinancialExtractor(use_llm=use_llm)
        except ImportError:
            from finee import extract

            class _FunctionExtractor:
                """Adapter giving the bare ``extract`` function a method API."""

                def extract(self, text):
                    return extract(text)

            self.extractor = _FunctionExtractor()

    def load_test_data(self, path: Optional[Path] = None) -> List[Dict]:
        """Load a JSONL test dataset.

        Args:
            path: File to read; defaults to ``self.test_data_path``.

        Returns:
            The JSON objects found in the file, one per line. Blank lines,
            lines that are not valid JSON and lines that are not JSON objects
            are skipped. Returns ``[]`` when no path is set or the file does
            not exist.
        """
        path = path or self.test_data_path
        if not path:
            return []

        path = Path(path)
        if not path.exists():
            return []

        records = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except ValueError:  # includes json.JSONDecodeError
                    continue
                # Downstream code calls record.get(...); anything that is not
                # a JSON object would crash the run, so drop it here.
                if isinstance(record, dict):
                    records.append(record)
        return records

    def _normalize_value(self, value, field: str):
        """Normalize values for comparison.

        * ``None`` stays ``None``.
        * ``amount``: numbers and numeric strings (thousands separators
          allowed) become a float rounded to 2 places; a string that is not
          numeric becomes ``None``.
        * ``type``: lower-cased/stripped, with dr/debited and cr/credited
          folded into ``"debit"`` / ``"credit"``.
        * any other string is lower-cased and stripped; other values are
          returned unchanged.
        """
        if value is None:
            return None

        if field == "amount":
            if isinstance(value, (int, float)):
                return round(float(value), 2)
            if isinstance(value, str):
                try:
                    return round(float(value.replace(',', '')), 2)
                except ValueError:
                    return None

        if field == "type":
            v = str(value).lower().strip()
            if v in ["debit", "dr", "debited"]:
                return "debit"
            if v in ["credit", "cr", "credited"]:
                return "credit"
            return v

        if isinstance(value, str):
            return value.lower().strip()

        return value

    def _compare_values(self, predicted, expected, field: str) -> Tuple[bool, str]:
        """Compare predicted vs expected values.

        Returns:
            ``(matched, reason)`` where reason is one of ``"both_none"``,
            ``"false_negative"``, ``"false_positive"``, ``"true_positive"``,
            ``"partial_match"`` (merchant/bank only: one normalized value
            contains the other) or ``"mismatch"``.
        """
        pred_norm = self._normalize_value(predicted, field)
        exp_norm = self._normalize_value(expected, field)

        if pred_norm is None and exp_norm is None:
            return True, "both_none"

        if pred_norm is None:
            return False, "false_negative"

        if exp_norm is None:
            return False, "false_positive"

        if pred_norm == exp_norm:
            return True, "true_positive"

        if field in ("merchant", "bank"):
            pred_text, exp_text = str(pred_norm), str(exp_norm)
            # "" is a substring of every string, so an empty value must not
            # be allowed to count as a partial match.
            if pred_text and exp_text and (pred_text in exp_text or exp_text in pred_text):
                return True, "partial_match"

        return False, "mismatch"

    def evaluate_single(self, text: str, expected: Dict) -> Tuple[Dict, Dict, List[str]]:
        """
        Evaluate a single example.

        Runs the extractor on ``text`` (timing the call and the conversion of
        its result to a dict), then scores every entry of ``FIELDS`` against
        ``expected`` and updates ``self.results.field_metrics``. A mismatch
        counts as both a false negative and a false positive; "both none" is
        not counted at all.

        Returns:
            (predicted, expected, error_fields)
        """
        start = time.perf_counter()

        # Support both objects with an .extract() method and bare callables.
        if hasattr(self.extractor, 'extract'):
            predicted = self.extractor.extract(text)
        else:
            predicted = self.extractor(text)

        # Convert result objects into a plain dict of public attributes.
        if hasattr(predicted, 'to_dict'):
            predicted = predicted.to_dict()
        elif hasattr(predicted, '__dict__'):
            predicted = {k: v for k, v in predicted.__dict__.items() if not k.startswith('_')}

        latency = (time.perf_counter() - start) * 1000
        self.results.latency_ms.append(latency)

        # An extractor that returns None (or anything without .get) produced
        # no usable prediction: score it as "nothing extracted" rather than
        # crashing the whole run.
        if not callable(getattr(predicted, 'get', None)):
            predicted = {}

        errors = []
        for name in self.FIELDS:
            pred_val = predicted.get(name)
            exp_val = expected.get(name)

            _, reason = self._compare_values(pred_val, exp_val, name)

            if name not in self.results.field_metrics:
                self.results.field_metrics[name] = FieldMetrics()
            metrics = self.results.field_metrics[name]

            if reason in ("true_positive", "partial_match"):
                metrics.tp += 1
            elif reason == "false_negative":
                metrics.fn += 1
                errors.append(f"{name}: expected '{exp_val}', got None")
            elif reason == "false_positive":
                metrics.fp += 1
                errors.append(f"{name}: expected None, got '{pred_val}'")
            elif reason == "mismatch":
                # Wrong value: the expected one was missed AND a wrong one
                # was emitted.
                metrics.fn += 1
                metrics.fp += 1
                errors.append(f"{name}: expected '{exp_val}', got '{pred_val}'")

        return predicted, expected, errors

    def run(
        self,
        test_data: Optional[List[Dict]] = None,
        max_samples: int = 1000,
        include_failures: bool = True,
        seed: Optional[int] = None
    ) -> "BenchmarkResult":
        """
        Run the full benchmark.

        Args:
            test_data: List of test samples
            max_samples: Maximum samples to evaluate; larger datasets are
                randomly sub-sampled down to this size
            include_failures: Whether to collect failure cases
            seed: Optional seed for the sub-sampling so repeated runs pick
                the same samples; ``None`` uses the global random state

        Returns:
            BenchmarkResult with all metrics. ``total_samples`` counts the
            samples that were actually evaluated; records whose ground truth
            is not a dict (or not valid JSON) are skipped and reported.
            When there is no test data the current ``self.results`` is
            returned untouched.
        """
        if self.extractor is None:
            self.load_extractor()

        if test_data is None:
            test_data = self.load_test_data()

        if not test_data:
            print("⚠️ No test data provided")
            return self.results

        if len(test_data) > max_samples:
            sampler = random if seed is None else random.Random(seed)
            test_data = sampler.sample(test_data, max_samples)

        self.results = BenchmarkResult()

        print(f"Running benchmark on {len(test_data)} samples...")

        skipped = 0
        for i, record in enumerate(test_data):
            text = record.get("input", record.get("text", ""))
            expected = record.get("output", record.get("ground_truth", {}))

            # Ground truth may be stored as a JSON-encoded string.
            if isinstance(expected, str):
                try:
                    expected = json.loads(expected)
                except ValueError:  # includes json.JSONDecodeError
                    expected = None

            if isinstance(expected, dict):
                predicted, expected, errors = self.evaluate_single(text, expected)
                self.results.total_samples += 1

                if include_failures and errors:
                    self.results.failures.append({
                        "text": text[:100],
                        "expected": expected,
                        "predicted": predicted,
                        "errors": errors,
                    })
            else:
                skipped += 1

            if (i + 1) % 100 == 0:
                print(f" Processed {i + 1}/{len(test_data)}...")

        if skipped:
            print(f" Skipped {skipped} samples with unusable ground truth")

        return self.results

    def print_report(self):
        """Print a detailed report of ``self.results`` to stdout."""
        print("\n" + "=" * 70)
        print("PRODUCTION BENCHMARK REPORT")
        print("=" * 70)

        print("\n📊 Overall Statistics:")
        print(f" Total Samples: {self.results.total_samples:,}")
        print(f" Overall F1: {self.results.overall_f1:.1%}")
        print(f" Avg Latency: {self.results.avg_latency_ms:.2f}ms")

        print("\n📈 Per-Field Metrics:")
        print(f" {'Field':<12} {'Precision':>10} {'Recall':>10} {'F1':>10}")
        print(" " + "-" * 42)

        for name in self.FIELDS:
            if name in self.results.field_metrics:
                m = self.results.field_metrics[name]
                status = "✅" if m.f1 >= 0.90 else "⚠️" if m.f1 >= 0.70 else "❌"
                print(f" {name:<12} {m.precision:>9.1%} {m.recall:>9.1%} {m.f1:>9.1%} {status}")

        print(f"\n❌ Failure Cases: {len(self.results.failures)}")

        if self.results.failures:
            print("\n Sample Failures:")
            # Keep the report short: first 5 failures, first 3 errors each.
            for failure in self.results.failures[:5]:
                print(f"\n Text: {failure['text'][:60]}...")
                for err in failure['errors'][:3]:
                    print(f" • {err}")

        f1 = self.results.overall_f1
        if f1 >= 0.95:
            grade = "A+ (Production Ready)"
        elif f1 >= 0.90:
            grade = "A (Near Production)"
        elif f1 >= 0.80:
            grade = "B (Good)"
        elif f1 >= 0.70:
            grade = "C (Needs Improvement)"
        else:
            grade = "D (Significant Work Needed)"

        print(f"\n🏆 Grade: {grade}")
        print("=" * 70)

    def export_results(self, path: Path):
        """Export results to JSON.

        Writes the overall scores, per-field precision/recall/F1, the failure
        count and the first 20 failure records to ``path``, creating the
        parent directory if needed.
        """
        data = {
            "overall_f1": self.results.overall_f1,
            "avg_latency_ms": self.results.avg_latency_ms,
            "total_samples": self.results.total_samples,
            "field_metrics": {
                name: {
                    "precision": m.precision,
                    "recall": m.recall,
                    "f1": m.f1,
                }
                for name, m in self.results.field_metrics.items()
            },
            "failure_count": len(self.results.failures),
            "failures": self.results.failures[:20],
        }

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, 'w', encoding="utf-8") as f:
            # Failure records hold raw extractor output, which may contain
            # values json cannot encode; stringify those deliberately rather
            # than losing the whole export.
            json.dump(data, f, indent=2, default=str)

        print(f"Results exported to {path}")
|
|
|
|
|
|
|
|
def create_held_out_test_set(
    data_path: Path,
    output_path: Path,
    held_out_banks: Optional[List[str]] = None,
    num_samples: int = 1000
):
    """
    Create a held-out test set with banks NOT in training.

    This is critical for proper evaluation.

    Scans the JSONL file at ``data_path`` in order and keeps each record
    whose message (``"input"``, else ``"text"``) contains one of the bank
    names, compared case-insensitively as a plain substring. Stops once
    ``num_samples`` records have been collected, then writes them as JSONL
    to ``output_path`` (parent directories are created).

    Args:
        data_path: Source JSONL file.
        output_path: Destination JSONL file.
        held_out_banks: Bank names to select; defaults to
            ``["PNB", "BOB", "CANARA"]``.
        num_samples: Maximum number of records to keep; zero or negative
            yields an empty set.

    Returns:
        The list of selected records.
    """
    if held_out_banks is None:
        held_out_banks = ["PNB", "BOB", "CANARA"]

    print(f"Creating held-out test set with banks: {held_out_banks}")

    # The message text is upper-cased below, so the bank names must be too.
    # NOTE(review): substring matching means a short name such as "BOB" also
    # matches longer words containing it - confirm this is acceptable.
    bank_tokens = [str(bank).upper() for bank in held_out_banks]

    held_out = []
    with open(data_path, encoding="utf-8") as f:
        for line in f:
            # Check the limit before collecting so num_samples <= 0 keeps
            # nothing instead of one record.
            if len(held_out) >= num_samples:
                break

            try:
                record = json.loads(line)
            except ValueError:  # includes json.JSONDecodeError
                continue
            if not isinstance(record, dict):
                continue

            text = record.get("input", record.get("text", ""))
            if not isinstance(text, str):
                continue
            text = text.upper()

            if any(bank in text for bank in bank_tokens):
                held_out.append(record)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding="utf-8") as f:
        f.writelines(json.dumps(record) + '\n' for record in held_out)

    print(f"Created held-out test set with {len(held_out)} samples at {output_path}")
    return held_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: either build the held-out set or run the
    # benchmark and print (optionally export) its report.
    cli = argparse.ArgumentParser(description="Run production benchmark")
    cli.add_argument("--test-file", help="Path to test JSONL file")
    cli.add_argument("--max-samples", type=int, default=1000)
    cli.add_argument("--export", help="Export results to JSON")
    cli.add_argument(
        "--create-held-out",
        action="store_true",
        help="Create held-out test set",
    )
    options = cli.parse_args()

    if not options.create_held_out:
        runner = ProductionBenchmark()

        # Fall back to the default dataset when no --test-file is given.
        if options.test_file:
            dataset_path = Path(options.test_file)
        else:
            dataset_path = Path("data/instruction/test.jsonl")
        samples = runner.load_test_data(dataset_path)

        runner.run(samples, max_samples=options.max_samples)
        runner.print_report()

        if options.export:
            runner.export_results(Path(options.export))
    else:
        create_held_out_test_set(
            Path("data/instruction/test.jsonl"),
            Path("data/benchmark/held_out_test.jsonl"),
        )
|
|
|