Spaces:
Sleeping
Sleeping
| """ | |
| benchmark.py — Accuracy evaluation script for the Bill/Invoice Scanner. | |
| This script processes the 1,000-receipt SROIE dataset and compares | |
| extracted fields (Vendor, Date, Total) against the ground-truth JSON files. | |
| Usage: | |
| conda run -n dl_projects python benchmark.py | |
| Metrics: | |
| - Vendor Accuracy: Case-normalized partial match. | |
| - Date Accuracy: String equality after normalization. | |
| - Total Accuracy: Fuzzy float equality (within 0.01). | |
| """ | |
| import os | |
| import json | |
| import pandas as pd | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import torch | |
| # Project modules | |
| import utils | |
| import ocr | |
| import extractor | |
| # Dataset paths | |
| DATA_DIR = Path("SROIE_Dataset/data") | |
| IMG_DIR = DATA_DIR / "img" | |
| KEY_DIR = DATA_DIR / "key" | |
| def normalize_text(text: str | None) -> str: | |
| """Normalize text for comparison (lower case, stripped, no extra whitespace).""" | |
| if text is None: | |
| return "" | |
| return " ".join(text.lower().strip().split()) | |
| def compare_totals(val1: float | None, val2: str | None) -> bool: | |
| """Compare a float (extracted) with a string (ground truth) fuzzy-style.""" | |
| if val1 is None or val2 is None: | |
| return False | |
| try: | |
| # Convert val2 to float | |
| gt_val = float(val2.replace(",", "")) | |
| return abs(val1 - gt_val) < 0.01 | |
| except ValueError: | |
| return False | |
| def run_benchmark(limit: int = 1000): | |
| """ | |
| Run benchmarking on the SROIE dataset images. | |
| Args: | |
| limit (int): Max number of images to process. | |
| """ | |
| if not IMG_DIR.exists(): | |
| print(f"ERROR: Image directory not found at {IMG_DIR}") | |
| return | |
| # Get list of images | |
| image_files = sorted(list(IMG_DIR.glob("*.jpg")))[:limit] | |
| total_images = len(image_files) | |
| results = [] | |
| print(f"🚀 Starting benchmark on {total_images} images...") | |
| print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}") | |
| for img_path in tqdm(image_files, desc="Benchmarking"): | |
| # 1. Load Ground Truth | |
| base_name = img_path.stem | |
| key_path = KEY_DIR / f"{base_name}.json" | |
| if not key_path.exists(): | |
| continue | |
| with open(key_path, "r") as f: | |
| gt = json.load(f) | |
| # 2. Run Pipeline | |
| try: | |
| # Preprocess | |
| bgr_img = utils.preprocess_image(img_path) | |
| # OCR | |
| full_text = ocr.extract_text(bgr_img) | |
| # Extract fields | |
| extracted = extractor.parse_invoice(full_text) | |
| # 3. Compare Fields | |
| v_match = normalize_text(gt.get("company")) in normalize_text(extracted.get("vendor")) or \ | |
| normalize_text(extracted.get("vendor")) in normalize_text(gt.get("company")) | |
| d_match = normalize_text(gt.get("date")) == normalize_text(extracted.get("date")) | |
| t_match = compare_totals(extracted.get("total"), gt.get("total")) | |
| results.append({ | |
| "file": base_name, | |
| "vendor_ok": v_match, | |
| "date_ok": d_match, | |
| "total_ok": t_match, | |
| "extracted_vendor": extracted.get("vendor"), | |
| "gt_vendor": gt.get("company"), | |
| "extracted_date": extracted.get("date"), | |
| "gt_date": gt.get("date"), | |
| "extracted_total": extracted.get("total"), | |
| "gt_total": gt.get("total"), | |
| }) | |
| except Exception as e: | |
| print(f"ERR processing {base_name}: {e}") | |
| continue | |
| # Generate Report | |
| if not results: | |
| print("No results to report.") | |
| return | |
| df = pd.DataFrame(results) | |
| vendor_acc = df["vendor_ok"].mean() * 100 | |
| date_acc = df["date_ok"].mean() * 100 | |
| total_acc = df["total_ok"].mean() * 100 | |
| print("\n" + "="*40) | |
| print(" SROIE BENCHMARK REPORT ") | |
| print("="*40) | |
| print(f"Total Processed: {len(df)}") | |
| print(f"Vendor Accuracy: {vendor_acc:5.1f}%") | |
| print(f"Date Accuracy: {date_acc:5.1f}%") | |
| print(f"Total Accuracy: {total_acc:5.1f}%") | |
| print("="*40) | |
| # Save mismatches for analysis | |
| mismatches = df[(~df["vendor_ok"]) | (~df["date_ok"]) | (~df["total_ok"])] | |
| mismatches.to_csv("benchmark_mismatches.csv", index=False) | |
| print(f"Mismatches saved to 'benchmark_mismatches.csv' ({len(mismatches)} rows)") | |
| if __name__ == "__main__": | |
| run_benchmark() | |