DIVYANSHI SINGH
Root project layout configured for deployment
b0bec61
"""
benchmark.py — Accuracy evaluation script for the Bill/Invoice Scanner.
This script processes the 1,000-receipt SROIE dataset and compares
extracted fields (Vendor, Date, Total) against the ground-truth JSON files.
Usage:
conda run -n dl_projects python benchmark.py
Metrics:
- Vendor Accuracy: Case-normalized partial match.
- Date Accuracy: String equality after normalization.
- Total Accuracy: Fuzzy float equality (within 0.01).
"""
import os
import json
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import torch
# Project modules
import utils
import ocr
import extractor
# Dataset paths: receipt images (*.jpg) live under img/, and the matching
# ground-truth JSON files (<image stem>.json) live under key/.
DATA_DIR = Path("SROIE_Dataset", "data")
IMG_DIR = DATA_DIR.joinpath("img")
KEY_DIR = DATA_DIR.joinpath("key")
def normalize_text(text: str | None) -> str:
"""Normalize text for comparison (lower case, stripped, no extra whitespace)."""
if text is None:
return ""
return " ".join(text.lower().strip().split())
def compare_totals(val1: float | None, val2: str | None) -> bool:
"""Compare a float (extracted) with a string (ground truth) fuzzy-style."""
if val1 is None or val2 is None:
return False
try:
# Convert val2 to float
gt_val = float(val2.replace(",", ""))
return abs(val1 - gt_val) < 0.01
except ValueError:
return False
def _compare_fields(gt: dict, extracted: dict) -> dict:
    """Score one extraction against its ground truth and return a result row (without 'file').

    Args:
        gt: Parsed ground-truth JSON; the ``company``, ``date`` and ``total``
            keys are read.
        extracted: Output of ``extractor.parse_invoice``; the ``vendor``,
            ``date`` and ``total`` keys are read.
    """
    gt_vendor = normalize_text(gt.get("company"))
    pred_vendor = normalize_text(extracted.get("vendor"))
    # `"" in s` is always True, so an empty prediction (or an empty ground
    # truth) must not be counted as a vendor match.
    vendor_ok = bool(gt_vendor and pred_vendor) and (
        gt_vendor in pred_vendor or pred_vendor in gt_vendor
    )

    gt_date = normalize_text(gt.get("date"))
    # Likewise, a date missing on both sides is not a correct extraction.
    date_ok = bool(gt_date) and gt_date == normalize_text(extracted.get("date"))

    total_ok = compare_totals(extracted.get("total"), gt.get("total"))

    return {
        "vendor_ok": vendor_ok,
        "date_ok": date_ok,
        "total_ok": total_ok,
        "extracted_vendor": extracted.get("vendor"),
        "gt_vendor": gt.get("company"),
        "extracted_date": extracted.get("date"),
        "gt_date": gt.get("date"),
        "extracted_total": extracted.get("total"),
        "gt_total": gt.get("total"),
        "error": None,
    }


def _error_row(gt, error: Exception) -> dict:
    """Build a result row for an image whose pipeline raised; every field counts as a miss."""
    truth = gt if isinstance(gt, dict) else {}
    return {
        "vendor_ok": False,
        "date_ok": False,
        "total_ok": False,
        "extracted_vendor": None,
        "gt_vendor": truth.get("company"),
        "extracted_date": None,
        "gt_date": truth.get("date"),
        "extracted_total": None,
        "gt_total": truth.get("total"),
        "error": str(error),
    }


def _print_report(df: pd.DataFrame, skipped: int, failed: int) -> None:
    """Print per-field accuracy percentages plus skip/error counts for *df*."""
    vendor_acc = df["vendor_ok"].mean() * 100
    date_acc = df["date_ok"].mean() * 100
    total_acc = df["total_ok"].mean() * 100

    print("\n" + "=" * 40)
    print(" SROIE BENCHMARK REPORT ")
    print("=" * 40)
    print(f"Total Processed: {len(df)}")
    print(f"Pipeline Errors: {failed}")
    print(f"Skipped (no usable ground truth): {skipped}")
    print(f"Vendor Accuracy: {vendor_acc:5.1f}%")
    print(f"Date Accuracy: {date_acc:5.1f}%")
    print(f"Total Accuracy: {total_acc:5.1f}%")
    print("=" * 40)


def run_benchmark(
    limit: int = 1000,
    mismatches_csv: str | Path = "benchmark_mismatches.csv",
) -> pd.DataFrame | None:
    """
    Run the scanner pipeline over SROIE images and report per-field accuracy.

    The first ``limit`` ``*.jpg`` files in IMG_DIR (sorted by path) are each
    paired with ``KEY_DIR/<stem>.json``. Images without a readable key file
    are skipped and excluded from the accuracy denominator. Images whose
    pipeline raises are kept and scored as misses on every field, so that
    crashes do not inflate accuracy.

    Args:
        limit (int): Max number of images to process.
        mismatches_csv: Path of the CSV that receives every row with at
            least one incorrect field.

    Returns:
        The per-image results DataFrame, or None when the image directory is
        missing or nothing was scored.
    """
    if not IMG_DIR.exists():
        print(f"ERROR: Image directory not found at {IMG_DIR}")
        return None

    # Get list of images
    image_files = sorted(IMG_DIR.glob("*.jpg"))[:limit]
    total_images = len(image_files)
    results = []
    skipped = 0
    failed = 0

    print(f"🚀 Starting benchmark on {total_images} images...")
    print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

    for img_path in tqdm(image_files, desc="Benchmarking"):
        base_name = img_path.stem

        # 1. Load Ground Truth (images without a usable key are not scored)
        key_path = KEY_DIR / f"{base_name}.json"
        if not key_path.exists():
            skipped += 1
            continue
        try:
            with open(key_path, "r", encoding="utf-8") as f:
                gt = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            # tqdm.write prints above the progress bar instead of corrupting it.
            tqdm.write(f"ERR reading ground truth {key_path}: {e}")
            skipped += 1
            continue

        # 2. Run Pipeline, then 3. Compare Fields
        try:
            bgr_img = utils.preprocess_image(img_path)
            full_text = ocr.extract_text(bgr_img)
            extracted = extractor.parse_invoice(full_text)
            row = _compare_fields(gt, extracted)
        except Exception as e:  # per-image boundary: one bad receipt must not abort the run
            tqdm.write(f"ERR processing {base_name}: {e}")
            failed += 1
            row = _error_row(gt, e)

        results.append({"file": base_name, **row})

    # Generate Report
    if not results:
        print("No results to report.")
        return None

    df = pd.DataFrame(results)
    _print_report(df, skipped, failed)

    # Save mismatches for analysis
    mismatches = df[(~df["vendor_ok"]) | (~df["date_ok"]) | (~df["total_ok"])]
    mismatches.to_csv(mismatches_csv, index=False)
    print(f"Mismatches saved to '{mismatches_csv}' ({len(mismatches)} rows)")
    return df
if __name__ == "__main__":
    # CLI-only dependency, imported locally so importing this module stays cheap.
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate the Bill/Invoice Scanner on the SROIE dataset."
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=1000,
        help="Maximum number of images to process (default: 1000).",
    )
    args = parser.parse_args()
    run_benchmark(limit=args.limit)