| """ |
| Evaluate the fine-tuned Donut model and generate a Field-Level Confusion Matrix. |
| Run this on the Workbench where the model and datasets are located. |
| |
| Usage: |
| python scripts/evaluate_model.py \ |
| --model_path outputs/receipt_donut_gcp_enterprise/best_model \ |
| --config configs/gcp_l4_enterprise.yaml \ |
| --output_dir evaluation_results |
| |
| Outputs: |
| - evaluation_results/field_confusion_matrix.png |
| - evaluation_results/field_accuracy.json |
| - evaluation_results/error_analysis.html |
| """ |
|
|
import os
import sys
import json
import argparse
from pathlib import Path

import Levenshtein
import numpy as np
import torch
import yaml
from PIL import Image
import matplotlib.pyplot as plt
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| from core.unified_dataset import UnifiedReceiptDataset |
|
|
|
|
| FIELDS = ["merchant", "date", "subtotal", "tax", "total", "address"] |
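# NOTE: these keys are assumed to match both the ground-truth dicts produced by
# UnifiedReceiptDataset and the JSON keys emitted by the fine-tuned model; any
# field missing on either side is scored as an empty string.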
|
|
|
|
| def normalize_text(text): |
| """Lowercase and strip whitespace for fair comparison.""" |
| if text is None: |
| return "" |
| return str(text).lower().strip().replace("$", "").replace(",", "") |
|
|
|
|
| def categorize_match(gt, pred): |
| """ |
| Categorize a single field prediction into: |
| - correct: exact match after normalization |
| - minor_typo: < 20% Levenshtein distance |
| - incorrect: everything else |
| """ |
| gt_norm = normalize_text(gt) |
| pred_norm = normalize_text(pred) |
|
|
| if not gt_norm and not pred_norm: |
| return "correct" |
| if not gt_norm or not pred_norm: |
| return "incorrect" |
|
|
| if gt_norm == pred_norm: |
| return "correct" |
|
|
| dist = Levenshtein.distance(gt_norm, pred_norm) |
| max_len = max(len(gt_norm), len(pred_norm)) |
| ratio = dist / max_len if max_len > 0 else 0 |
|
|
| if ratio < 0.20: |
| return "minor_typo" |
| return "incorrect" |
|
|
|
|
| def run_inference(model, processor, image_path, device): |
| """Run model inference on a single image and return parsed JSON dict.""" |
| img = Image.open(image_path).convert("RGB") |
| pixel_values = processor(img, return_tensors="pt").pixel_values.to(device) |
| decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device) |
|
|
| with torch.no_grad(): |
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,  # required so outputs.sequences exists below
        )
|
|
| seq = processor.tokenizer.batch_decode(outputs.sequences)[0] |
| seq = seq.replace(processor.tokenizer.eos_token, "").replace( |
| processor.tokenizer.pad_token, "" |
| ) |
| seq = seq.replace( |
| processor.tokenizer.decode([model.config.decoder_start_token_id]), "" |
| ).strip() |
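    # Assumes the fine-tuned model was trained to emit a raw JSON string.
    # If it instead emits Donut-style field tokens (e.g. <s_total>...</s_total>),
    # processor.token2json(seq) would be the appropriate parser here.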
|
|
| try: |
| return json.loads(seq) |
| except json.JSONDecodeError: |
| return {} |
|
|
|
|
| def evaluate(model, processor, dataset, device, max_samples=None): |
| """ |
| Evaluate the model on a dataset and return per-field statistics. |
| """ |
| counts = {field: {"correct": 0, "minor_typo": 0, "incorrect": 0} for field in FIELDS} |
| errors = [] |
|
|
| n = min(len(dataset), max_samples) if max_samples else len(dataset) |
| print(f"Evaluating on {n} samples...") |
|
|
| for i in range(n): |
| sample = dataset[i] |
| image_path = sample["image_path"] |
| gt = sample["ground_truth"] |
|
|
| pred = run_inference(model, processor, image_path, device) |
|
|
| sample_error = {"image": image_path, "gt": gt, "pred": pred, "fields": {}} |
| all_correct = True |
|
|
| for field in FIELDS: |
| gt_val = gt.get(field, "") |
| pred_val = pred.get(field, "") |
| cat = categorize_match(gt_val, pred_val) |
| counts[field][cat] += 1 |
| sample_error["fields"][field] = cat |
| if cat != "correct": |
| all_correct = False |
|
|
| if not all_correct: |
| errors.append(sample_error) |
|
|
| if (i + 1) % 50 == 0: |
| print(f" Processed {i + 1}/{n}") |
|
|
| return counts, errors |
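# The returned `counts` dict has the shape (numbers purely illustrative):
#   {"merchant": {"correct": 412, "minor_typo": 18, "incorrect": 12}, ...}
# while `errors` holds one entry per sample with at least one non-correct field.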
|
|
|
|
| def plot_confusion_matrix(counts, output_dir): |
| """Generate a stacked bar chart confusion matrix per field.""" |
| categories = ["correct", "minor_typo", "incorrect"] |
| colors = ["#4CAF50", "#FFC107", "#F44336"] |
|
|
| fig, ax = plt.subplots(figsize=(10, 6)) |
| x = np.arange(len(FIELDS)) |
| width = 0.25 |
|
|
| for i, cat in enumerate(categories): |
| values = [counts[f][cat] for f in FIELDS] |
| ax.bar(x + i * width, values, width, label=cat.replace("_", " ").title(), color=colors[i]) |
|
|
| ax.set_xlabel("Field") |
| ax.set_ylabel("Count") |
| ax.set_title("Field-Level Confusion Matrix (Validation/Test Set)") |
| ax.set_xticks(x + width) |
| ax.set_xticklabels(FIELDS, rotation=15, ha="right") |
| ax.legend() |
| ax.grid(axis="y", linestyle="--", alpha=0.5) |
| plt.tight_layout() |
|
|
| save_path = os.path.join(output_dir, "field_confusion_matrix.png") |
| plt.savefig(save_path, dpi=150) |
| print(f"Saved confusion matrix to {save_path}") |
| plt.close() |
|
|
|
|
| def save_accuracy_json(counts, output_dir): |
| """Save numerical accuracy breakdown per field.""" |
| results = {} |
    for field in FIELDS:
        total = sum(counts[field].values())
        if total == 0:
            continue
        results[field] = {
            "correct_pct": round(counts[field]["correct"] / total * 100, 1),
            "minor_typo_pct": round(counts[field]["minor_typo"] / total * 100, 1),
            "incorrect_pct": round(counts[field]["incorrect"] / total * 100, 1),
            "counts": counts[field],
        }
|
|
| save_path = os.path.join(output_dir, "field_accuracy.json") |
| with open(save_path, "w") as f: |
| json.dump(results, f, indent=2) |
| print(f"Saved accuracy JSON to {save_path}") |
|
|
|
|
| def save_error_html(errors, output_dir, max_display=50): |
| """Generate an HTML file showing side-by-side GT vs Pred errors.""" |
| html = ["<html><head><style>", |
| "body{font-family:sans-serif;margin:20px}", |
| "table{border-collapse:collapse;width:100%}", |
| "th,td{border:1px solid #ccc;padding:8px;text-align:left}", |
| "th{background:#f0f0f0}", |
| ".correct{color:green}.minor{color:orange}.incorrect{color:red}", |
| "</style></head><body>", |
| f"<h1>Error Analysis ({min(len(errors), max_display)} of {len(errors)} failures)</h1>", |
| "<table><tr><th>Image</th><th>Field</th><th>Ground Truth</th><th>Predicted</th><th>Status</th></tr>"] |
|
|
| for err in errors[:max_display]: |
| img_name = os.path.basename(err["image"]) |
| for field in FIELDS: |
| status = err["fields"][field] |
| if status == "correct": |
| continue |
| css_class = "correct" if status == "correct" else ("minor" if status == "minor_typo" else "incorrect") |
| html.append(f"<tr><td>{img_name}</td><td>{field}</td>" |
| f"<td>{err['gt'].get(field, 'N/A')}</td>" |
| f"<td>{err['pred'].get(field, 'N/A')}</td>" |
| f"<td class='{css_class}'>{status}</td></tr>") |
|
|
| html.append("</table></body></html>") |
|
|
| save_path = os.path.join(output_dir, "error_analysis.html") |
| with open(save_path, "w") as f: |
| f.write("\n".join(html)) |
| print(f"Saved error analysis HTML to {save_path}") |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Evaluate Donut receipt model") |
| parser.add_argument("--model_path", required=True, help="Path to fine-tuned model") |
| parser.add_argument("--config", default="configs/gcp_l4_enterprise.yaml", help="Training config YAML") |
| parser.add_argument("--output_dir", default="evaluation_results", help="Where to save results") |
| parser.add_argument("--max_samples", type=int, default=None, help="Limit evaluation samples") |
| parser.add_argument("--split", default="test", choices=["train", "val", "test"], help="Which split to evaluate") |
| args = parser.parse_args() |
|
|
| with open(args.config, "r") as f: |
| config = yaml.safe_load(f) |
|
|
| os.makedirs(args.output_dir, exist_ok=True) |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| print(f"Loading model from {args.model_path}...") |
| processor = DonutProcessor.from_pretrained(args.model_path) |
| model = VisionEncoderDecoderModel.from_pretrained(args.model_path) |
| model.to(device).eval() |
|
|
| print(f"Loading dataset split: {args.split}") |
| dataset = UnifiedReceiptDataset( |
| root=config["data"]["dataset_root"], |
| split=args.split, |
| processor=None, |
| include_datasets=config["data"].get("include_datasets"), |
| ) |
|
|
| counts, errors = evaluate(model, processor, dataset, device, args.max_samples) |
| plot_confusion_matrix(counts, args.output_dir) |
| save_accuracy_json(counts, args.output_dir) |
| save_error_html(errors, args.output_dir) |
|
|
| print("\n=== Evaluation Complete ===") |
    for field in FIELDS:
        total = sum(counts[field].values())
        c = counts[field]["correct"]
        m = counts[field]["minor_typo"]
        i = counts[field]["incorrect"]
        pct = c / total * 100 if total else 0.0
        print(f" {field:12s}: Correct={c}/{total} ({pct:.1f}%) | "
              f"Minor={m} | Incorrect={i}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|