Ranjit0034's picture
Upload scripts/benchmark.py with huggingface_hub
1cba4da verified
#!/usr/bin/env python3
"""
Production Benchmark Suite for FinEE
=====================================
Comprehensive evaluation with:
- Precision/Recall/F1 per field
- Bank-specific performance
- Cross-validation
- Failure case analysis
- Comparison with baselines
Author: Ranjit Behera
"""
import json
import random
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict
import time
@dataclass
class FieldMetrics:
"""Metrics for a single field."""
tp: int = 0 # True positives
fp: int = 0 # False positives
fn: int = 0 # False negatives
@property
def precision(self) -> float:
if self.tp + self.fp == 0:
return 0.0
return self.tp / (self.tp + self.fp)
@property
def recall(self) -> float:
if self.tp + self.fn == 0:
return 0.0
return self.tp / (self.tp + self.fn)
@property
def f1(self) -> float:
if self.precision + self.recall == 0:
return 0.0
return 2 * (self.precision * self.recall) / (self.precision + self.recall)
@dataclass
class BenchmarkResult:
"""Complete benchmark results."""
field_metrics: Dict[str, FieldMetrics] = field(default_factory=dict)
bank_metrics: Dict[str, Dict[str, FieldMetrics]] = field(default_factory=dict)
failures: List[Dict] = field(default_factory=list)
latency_ms: List[float] = field(default_factory=list)
total_samples: int = 0
@property
def overall_f1(self) -> float:
if not self.field_metrics:
return 0.0
return sum(m.f1 for m in self.field_metrics.values()) / len(self.field_metrics)
@property
def avg_latency_ms(self) -> float:
if not self.latency_ms:
return 0.0
return sum(self.latency_ms) / len(self.latency_ms)
class ProductionBenchmark:
"""
Production-grade benchmark for financial entity extraction.
"""
FIELDS = ["amount", "type", "bank", "merchant", "category", "reference", "vpa"]
def __init__(self, test_data_path: Optional[Path] = None):
self.test_data_path = test_data_path
self.extractor = None
self.results = BenchmarkResult()
def load_extractor(self, use_llm: bool = False):
"""Load the extractor."""
try:
from finee import FinancialExtractor
self.extractor = FinancialExtractor(use_llm=use_llm)
except ImportError:
from finee import extract
self.extractor = type('Extractor', (), {'extract': lambda self, x: extract(x)})()
def load_test_data(self, path: Optional[Path] = None) -> List[Dict]:
"""Load test dataset."""
path = path or self.test_data_path
if path and path.exists():
records = []
with open(path) as f:
for line in f:
try:
records.append(json.loads(line))
except:
continue
return records
return []
def _normalize_value(self, value, field: str):
"""Normalize values for comparison."""
if value is None:
return None
if field == "amount":
if isinstance(value, (int, float)):
return round(float(value), 2)
if isinstance(value, str):
try:
return round(float(value.replace(',', '')), 2)
except:
return None
if field == "type":
v = str(value).lower().strip()
if v in ["debit", "dr", "debited"]:
return "debit"
if v in ["credit", "cr", "credited"]:
return "credit"
return v
if isinstance(value, str):
return value.lower().strip()
return value
def _compare_values(self, predicted, expected, field: str) -> Tuple[bool, str]:
"""Compare predicted vs expected values."""
pred_norm = self._normalize_value(predicted, field)
exp_norm = self._normalize_value(expected, field)
if pred_norm is None and exp_norm is None:
return True, "both_none"
if pred_norm is None and exp_norm is not None:
return False, "false_negative"
if pred_norm is not None and exp_norm is None:
return False, "false_positive"
if pred_norm == exp_norm:
return True, "true_positive"
# Partial match for strings
if field in ["merchant", "bank"]:
if str(pred_norm) in str(exp_norm) or str(exp_norm) in str(pred_norm):
return True, "partial_match"
return False, "mismatch"
def evaluate_single(self, text: str, expected: Dict) -> Tuple[Dict, Dict, List[str]]:
"""
Evaluate a single example.
Returns:
(predicted, expected, error_fields)
"""
start = time.perf_counter()
# Extract
if hasattr(self.extractor, 'extract'):
predicted = self.extractor.extract(text)
else:
predicted = self.extractor(text)
# Convert to dict if needed
if hasattr(predicted, 'to_dict'):
predicted = predicted.to_dict()
elif hasattr(predicted, '__dict__'):
predicted = {k: v for k, v in predicted.__dict__.items() if not k.startswith('_')}
latency = (time.perf_counter() - start) * 1000
self.results.latency_ms.append(latency)
# Compare each field
errors = []
for field in self.FIELDS:
pred_val = predicted.get(field)
exp_val = expected.get(field)
match, reason = self._compare_values(pred_val, exp_val, field)
if field not in self.results.field_metrics:
self.results.field_metrics[field] = FieldMetrics()
metrics = self.results.field_metrics[field]
if reason == "true_positive" or reason == "partial_match":
metrics.tp += 1
elif reason == "false_negative":
metrics.fn += 1
errors.append(f"{field}: expected '{exp_val}', got None")
elif reason == "false_positive":
metrics.fp += 1
errors.append(f"{field}: expected None, got '{pred_val}'")
elif reason == "mismatch":
metrics.fn += 1
metrics.fp += 1
errors.append(f"{field}: expected '{exp_val}', got '{pred_val}'")
return predicted, expected, errors
def run(
self,
test_data: Optional[List[Dict]] = None,
max_samples: int = 1000,
include_failures: bool = True
) -> BenchmarkResult:
"""
Run the full benchmark.
Args:
test_data: List of test samples
max_samples: Maximum samples to evaluate
include_failures: Whether to collect failure cases
Returns:
BenchmarkResult with all metrics
"""
if self.extractor is None:
self.load_extractor()
if test_data is None:
test_data = self.load_test_data()
if not test_data:
print("⚠️ No test data provided")
return self.results
# Sample if too many
if len(test_data) > max_samples:
test_data = random.sample(test_data, max_samples)
self.results = BenchmarkResult()
self.results.total_samples = len(test_data)
print(f"Running benchmark on {len(test_data)} samples...")
for i, record in enumerate(test_data):
text = record.get("input", record.get("text", ""))
expected = record.get("output", record.get("ground_truth", {}))
if isinstance(expected, str):
try:
expected = json.loads(expected)
except:
continue
predicted, expected, errors = self.evaluate_single(text, expected)
# Track failures
if include_failures and errors:
self.results.failures.append({
"text": text[:100],
"expected": expected,
"predicted": predicted,
"errors": errors,
})
# Progress
if (i + 1) % 100 == 0:
print(f" Processed {i + 1}/{len(test_data)}...")
return self.results
def print_report(self):
"""Print a detailed report."""
print("\n" + "=" * 70)
print("PRODUCTION BENCHMARK REPORT")
print("=" * 70)
print(f"\n📊 Overall Statistics:")
print(f" Total Samples: {self.results.total_samples:,}")
print(f" Overall F1: {self.results.overall_f1:.1%}")
print(f" Avg Latency: {self.results.avg_latency_ms:.2f}ms")
print(f"\n📈 Per-Field Metrics:")
print(f" {'Field':<12} {'Precision':>10} {'Recall':>10} {'F1':>10}")
print(" " + "-" * 42)
for field in self.FIELDS:
if field in self.results.field_metrics:
m = self.results.field_metrics[field]
status = "✅" if m.f1 >= 0.90 else "⚠️" if m.f1 >= 0.70 else "❌"
print(f" {field:<12} {m.precision:>9.1%} {m.recall:>9.1%} {m.f1:>9.1%} {status}")
print(f"\n❌ Failure Cases: {len(self.results.failures)}")
if self.results.failures:
print("\n Sample Failures:")
for failure in self.results.failures[:5]:
print(f"\n Text: {failure['text'][:60]}...")
for err in failure['errors'][:3]:
print(f" • {err}")
# Grade
f1 = self.results.overall_f1
if f1 >= 0.95:
grade = "A+ (Production Ready)"
elif f1 >= 0.90:
grade = "A (Near Production)"
elif f1 >= 0.80:
grade = "B (Good)"
elif f1 >= 0.70:
grade = "C (Needs Improvement)"
else:
grade = "D (Significant Work Needed)"
print(f"\n🏆 Grade: {grade}")
print("=" * 70)
def export_results(self, path: Path):
"""Export results to JSON."""
data = {
"overall_f1": self.results.overall_f1,
"avg_latency_ms": self.results.avg_latency_ms,
"total_samples": self.results.total_samples,
"field_metrics": {
field: {
"precision": m.precision,
"recall": m.recall,
"f1": m.f1,
}
for field, m in self.results.field_metrics.items()
},
"failure_count": len(self.results.failures),
"failures": self.results.failures[:20],
}
with open(path, 'w') as f:
json.dump(data, f, indent=2)
print(f"Results exported to {path}")
def create_held_out_test_set(
data_path: Path,
output_path: Path,
held_out_banks: List[str] = ["PNB", "BOB", "CANARA"],
num_samples: int = 1000
):
"""
Create a held-out test set with banks NOT in training.
This is critical for proper evaluation.
"""
print(f"Creating held-out test set with banks: {held_out_banks}")
held_out = []
with open(data_path) as f:
for line in f:
try:
record = json.loads(line)
text = record.get("input", record.get("text", "")).upper()
# Check if contains held-out bank
for bank in held_out_banks:
if bank in text:
held_out.append(record)
break
if len(held_out) >= num_samples:
break
except:
continue
# Save
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for record in held_out:
f.write(json.dumps(record) + '\n')
print(f"Created held-out test set with {len(held_out)} samples at {output_path}")
return held_out
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run production benchmark")
parser.add_argument("--test-file", help="Path to test JSONL file")
parser.add_argument("--max-samples", type=int, default=1000)
parser.add_argument("--export", help="Export results to JSON")
parser.add_argument("--create-held-out", action="store_true",
help="Create held-out test set")
args = parser.parse_args()
if args.create_held_out:
create_held_out_test_set(
Path("data/instruction/test.jsonl"),
Path("data/benchmark/held_out_test.jsonl"),
)
else:
benchmark = ProductionBenchmark()
if args.test_file:
test_data = benchmark.load_test_data(Path(args.test_file))
else:
# Use default test set
test_data = benchmark.load_test_data(Path("data/instruction/test.jsonl"))
benchmark.run(test_data, max_samples=args.max_samples)
benchmark.print_report()
if args.export:
benchmark.export_results(Path(args.export))