Add comprehensive student-friendly README with YAML configs, Ground Truth, splits, and evaluation
README.md
CHANGED
|
@@ -6,92 +6,403 @@ tags:
|
|
| 6 |
- document-ai
|
| 7 |
- donut
|
| 8 |
- receipt-extraction
|
| 9 |
pipeline_tag: image-to-text
|
| 10 |
widget:
|
| 11 |
- src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/receipt.jpg
|
| 12 |
example_title: Sample Receipt
|
| 13 |
---
|
| 14 |
|
| 15 |
-
# Receipt Donut
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
The model was trained for 11 epochs on an NVIDIA L4 GPU. Optimal convergence was reached at Epoch 9.
|
| 21 |
|
| 22 |
-
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |

|
| 28 |

|
| 29 |

|
| 30 |
|
| 31 |
-
##
|
| 32 |
-
- **Architecture:** Donut (Document Understanding Transformer)
|
| 33 |
-
- **Task:** Image-to-JSON extraction
|
| 34 |
-
- **Extracted Fields:** `merchant`, `date`, `subtotal`, `tax`, `total`, `address`
|
| 35 |
-
- **Training Data:** 8,615 heavily augmented receipt images sourced from 8 diverse public datasets (CORD, WildReceipts, SROIE variants, etc.)
|
| 36 |
-
- **License:** MIT
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
|
| 42 |
## How to Use (Python)
|
| 43 |
|
| 44 |
### Installation
|
| 45 |
```bash
|
| 46 |
pip install transformers Pillow torch
|
| 47 |
```
|
| 48 |
|
| 49 |
-
###
|
| 50 |
```python
|
| 51 |
import torch
|
| 52 |
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
| 53 |
from PIL import Image
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
model = VisionEncoderDecoderModel.from_pretrained(repo_id)
|
| 59 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 60 |
-
model.to(device)
|
| 61 |
-
|
| 62 |
-
def
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
| 82 |
-
return_dict_in_generate=True,
|
| 83 |
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
return results
|
| 93 |
-
|
| 94 |
-
# Run inference
|
| 95 |
-
predictions = process_receipts(["receipt1.jpg", "receipt2.jpg"])
|
| 96 |
-
print(predictions)
|
| 97 |
```
|
| 6 |
- document-ai
|
| 7 |
- donut
|
| 8 |
- receipt-extraction
|
| 9 |
+
- ocr-free
|
| 10 |
+
datasets:
|
| 11 |
+
- Voxel51/scanned_receipts
|
| 12 |
+
- naver-clova-ix/cord-v2
|
| 13 |
+
- docjay131/receipts-ocr-dataset
|
| 14 |
+
- mychen76/invoices-and-receipts_ocr_v1
|
| 15 |
+
- mychen76/invoices-and-receipts_ocr_v2
|
| 16 |
+
- mychen76/wildreceipts_ocr_v1
|
| 17 |
+
- mychen76/receipt_cord_ocr_v2
|
| 18 |
+
- mychen76/ds_receipts_v2_train
|
| 19 |
pipeline_tag: image-to-text
|
| 20 |
widget:
|
| 21 |
- src: https://huggingface.co/datasets/mishig/sample_images/resolve/main/receipt.jpg
|
| 22 |
example_title: Sample Receipt
|
| 23 |
---
|
| 24 |
|
| 25 |
+
# 🧾 Receipt Donut: Document Understanding for Students
|
| 26 |
|
| 27 |
+
> **Built by a student, for students.** This page explains every technical decision so you can understand (and replicate) the full training pipeline.
|
| 28 |
|
| 29 |
+
This model extracts structured JSON data directly from receipt images **without** needing a separate OCR engine. It is a fine-tuned version of `naver-clova-ix/donut-base-finetuned-cord-v2` trained on 8,615 real-world receipt images.
|
| 30 |
|
| 31 |
+
**Try it live:** [Hugging Face Space](https://huggingface.co/spaces/Awarebeyond/receipt-donut-space)
|
| 32 |
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## Table of Contents
|
| 36 |
+
1. [What is Ground Truth?](#what-is-ground-truth)
2. [Training Configuration (YAML Deep Dive)](#training-configuration-yaml-deep-dive)
3. [Dataset & Train/Test/Val Split](#dataset--traintestval-split)
4. [Training Performance & Learning Curves](#training-performance--learning-curves)
5. [Confusion Matrix & Field-Level Evaluation](#confusion-matrix--field-level-evaluation)
6. [How to Use (Python)](#how-to-use-python)
7. [Model Architecture](#model-architecture)
8. [Limitations](#limitations)
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
## What is Ground Truth?
|
| 48 |
+
|
| 49 |
+
In machine learning, **Ground Truth** is the "correct answer" we teach the model to predict. For receipts, instead of raw OCR text, we use **structured JSON** so the model learns to output clean, labeled data.
|
| 50 |
+
|
| 51 |
+
### Example Ground Truth
|
| 52 |
+
|
| 53 |
+
```json
{
  "merchant": "Starbucks Coffee",
  "date": "2024-03-15",
  "subtotal": "$12.50",
  "tax": "$1.13",
  "total": "$13.63",
  "address": "123 Main St, New York, NY"
}
```
|
| 63 |
+
|
| 64 |
+
### Why JSON Ground Truth matters
|
| 65 |
+
|
| 66 |
+
| Approach | Problem | Our Solution |
|----------|---------|--------------|
| Raw OCR text | No structure: you get "Starbucks $13.63" | We label **keys** and **values** |
| Fixed template | Fails on receipts with different fields | JSON is flexible and self-describing |
| Named Entity Recognition | Requires post-processing pipeline | Donut outputs JSON **directly** |
|
| 71 |
+
|
| 72 |
+
### How we normalized different datasets
|
| 73 |
+
|
| 74 |
+
Receipt datasets use wildly different formats. We wrote `_normalize_gt()` to unify them:
|
| 75 |
+
|
| 76 |
+
```python
# WildReceipts uses a list of annotations:
annotations = [
    {"label": "store_name", "transcription": "Walmart"},
    {"label": "total_value", "transcription": "$45.20"}
]

# CORD uses nested JSON:
gt_parse = {
    "menu": [...],
    "total": {"price": "$45.20"}
}

# Our code converts ALL of these into a single normalized format:
{
    "merchant": "Walmart",
    "total": "$45.20"
}
```
|
| 95 |
+
|
| 96 |
+
We **skip samples with empty ground truth** to prevent the model from learning to output `{}`.
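As a rough illustration of the normalization and skipping described above, here is a minimal sketch that handles only the two input shapes shown. The names `normalize_gt_sketch`, `keep_sample`, and `LABEL_TO_FIELD` are hypothetical, as is the choice to read CORD's total from `total.price`; this is not the project's actual `_normalize_gt()`.

```python
# Hypothetical label mapping; only the two labels shown above are covered.
LABEL_TO_FIELD = {"store_name": "merchant", "total_value": "total"}


def normalize_gt_sketch(raw):
    """Convert one raw annotation (list or nested dict) into a flat dict."""
    out = {}
    if isinstance(raw, list):  # WildReceipts-style list of annotations
        for ann in raw:
            field = LABEL_TO_FIELD.get(ann.get("label"))
            text = str(ann.get("transcription") or "").strip()
            if field and text and field not in out:
                out[field] = text
    elif isinstance(raw, dict):  # CORD-style nested JSON
        total = raw.get("total")
        if isinstance(total, dict):
            price = str(total.get("price") or "").strip()
            if price:
                out["total"] = price
    return out


def keep_sample(raw):
    """Return the normalized dict, or None when the ground truth is empty."""
    gt = normalize_gt_sketch(raw)
    return gt or None
```

A loader built this way would drop every sample for which `keep_sample` returns `None`, so the model never sees an empty `{}` target.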
|
| 97 |
+
|
| 98 |
+
---
|
| 99 |
+
|
| 100 |
+
## Training Configuration (YAML Deep Dive)
|
| 101 |
+
|
| 102 |
+
Here is the exact `gcp_l4_enterprise.yaml` we used. Each parameter is explained so you understand **why** we chose it.
|
| 103 |
+
|
| 104 |
+
```yaml
model:
  model_name: "naver-clova-ix/donut-base-finetuned-cord-v2"
  max_length: 768
  image_size: [1536, 1152]           # Wider than tall for typical receipts

training:
  output_dir: "./outputs/receipt_donut_gcp_enterprise"
  num_train_epochs: 20               # Upper limit; early stopping at epoch 9
  batch_size: 4                      # Fits in L4 24GB VRAM
  gradient_accumulation_steps: 16    # Effective batch = 4 × 16 = 64
  learning_rate: 8.0e-5              # Higher LR for larger effective batch
  weight_decay: 0.01                 # Prevents overfitting
  warmup_ratio: 0.05                 # 5% of steps warm up LR from 0
  bf16: true                         # L4 GPU has native BFloat16 support
  gradient_checkpointing: true       # Trade compute for memory; enables larger batches
  label_smoothing: 0.1               # Softens targets; prevents overconfident predictions
  freeze_encoder_epochs: 1           # Train only decoder first (faster convergence)
  cosine_restart_epochs: 5           # LR schedule restarts every 5 epochs
  grayscale: true                    # Reduces domain gap between color/gray receipts
  num_workers: 8                     # Parallel data loading (the L4 host VM has 8 CPU cores)

data:
  dataset_root: "./receipt_datasets"
  train_split: 0.95                  # 95% training
  val_split: 0.025                   # 2.5% validation
  test_split: 0.025                  # 2.5% holdout test
  seed: 42
  include_datasets:
    - "Voxel51__scanned_receipts"
    - "naver-clova-ix__cord-v2"
    - "docjay131__receipts-ocr-dataset"
    - "mychen76__invoices-and-receipts_ocr_v1"
    - "mychen76__invoices-and-receipts_ocr_v2"
    - "mychen76__wildreceipts_ocr_v1"
    - "mychen76__receipt_cord_ocr_v2"
    - "mychen76__ds_receipts_v2_train"

augmentation:
  enabled: true
  rotation_limit: 20                 # Simulates tilted camera photos
  brightness_limit: 0.3              # Different lighting conditions
  contrast_limit: 0.3
  blur_prob: 0.5                     # Camera shake / focus blur
  noise_prob: 0.5                    # ISO noise in dark restaurants
  perspective_prob: 0.6              # Receipts photographed at an angle
  quality_lower: 40                  # JPEG compression artifacts
  quality_upper: 100
```
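To make `warmup_ratio` and `cosine_restart_epochs` more concrete, here is a minimal sketch of one way such a schedule could be computed. The function `lr_at_step` is hypothetical, and the details (warmup measured against the 20-epoch upper limit, decay to zero within each cycle) are assumptions; the real training script may differ.

```python
import math


def lr_at_step(step, total_steps, steps_per_epoch,
               peak_lr=8.0e-5, warmup_ratio=0.05, restart_epochs=5):
    """Linear warmup from 0, then cosine decay that restarts every few epochs."""
    warmup_steps = max(1, round(total_steps * warmup_ratio))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    cycle_len = restart_epochs * steps_per_epoch
    pos = (step - warmup_steps) % cycle_len  # position inside the current cycle
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * pos / cycle_len))


# Roughly 8,184 training samples / effective batch 64 = about 128 updates per epoch.
steps_per_epoch = 128
total_steps = 20 * steps_per_epoch
for s in (0, 64, 128, 448, 768):
    print(s, lr_at_step(s, total_steps, steps_per_epoch))
```

In this sketch the learning rate climbs to the peak during the first 5% of steps, decays along a cosine curve, then jumps back to the peak at the start of each 5-epoch cycle.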
|
| 153 |
+
|
| 154 |
+
### Key Concepts Explained
|
| 155 |
+
|
| 156 |
+
**Gradient Accumulation:** We process 4 images at a time but accumulate gradients over 16 steps before updating the weights. This gives us the stability of an effective batch size of 64 while only ever holding 4 images' worth of activations in GPU memory.
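The equivalence can be checked with plain Python on a toy problem. This is only a sketch with a hypothetical one-parameter model (`y = w * x`) and made-up data, not the training loop itself: averaging the gradients of 16 micro-batches of 4 gives the same gradient as one batch of 64.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


# 64 toy samples drawn from y = 3x
data = [(float(i), 3.0 * float(i)) for i in range(1, 65)]
w = 0.5

# One big batch of 64
full_grad = grad_mse(w, data)

# 16 micro-batches of 4; each micro-batch gradient is scaled by 1/16
accum_steps, micro = 16, 4
accum_grad = 0.0
for k in range(accum_steps):
    chunk = data[k * micro:(k + 1) * micro]
    accum_grad += grad_mse(w, chunk) / accum_steps

print(full_grad, accum_grad)
```

The two printed values agree up to floating-point rounding, which is why the optimizer step taken after 16 accumulated micro-batches behaves like a step on a batch of 64.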
|
| 157 |
+
|
| 158 |
+
**BFloat16 (bf16):** A 16-bit floating-point format that keeps fp32's exponent range but stores fewer mantissa bits. The L4 GPU has native bf16 hardware, so training is ~2× faster and uses roughly half the memory of fp32, with almost no accuracy loss.
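To see what the reduced precision looks like, the following sketch rounds a value to bfloat16 by keeping only the top 16 bits of its float32 bit pattern (1 sign bit, 8 exponent bits, 7 mantissa bits). The helper `to_bf16` is hypothetical, handles finite values only, and is for illustration; real training relies on the GPU's native bf16 support.

```python
import struct


def to_bf16(x):
    """Round a finite Python float to the nearest bfloat16 value (as a float)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    # Round to nearest even on the 16 low bits that bfloat16 drops
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


for value in (1.0, 3.14159, 8.0e-5):
    print(value, "->", to_bf16(value))
```

Small values such as the learning rate keep their magnitude because the exponent field is the same width as in fp32; only the trailing digits are rounded away.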
|
| 159 |
+
|
| 160 |
+
**Gradient Checkpointing:** Instead of storing all intermediate activations in memory, we recompute them during the backward pass. This lets us fit a bigger model or batch at the cost of ~20% slower training.
|
| 161 |
+
|
| 162 |
+
**Label Smoothing:** Normally the model is told "this token is 100% correct." With smoothing, we say "this token is 90% correct, others share the remaining 10%." This prevents the model from becoming overconfident.
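The smoothed target described above can be written out for a tiny vocabulary. This sketch follows the wording of the paragraph (the correct token gets 0.9, the others share 0.1); the helper names are hypothetical. Note that some libraries, PyTorch's `CrossEntropyLoss(label_smoothing=...)` among them, spread the smoothing mass over all classes including the correct one, so their exact numbers differ slightly.

```python
import math


def smoothed_targets(correct_idx, vocab_size, eps=0.1):
    """1 - eps on the correct token; eps shared equally by all other tokens."""
    other = eps / (vocab_size - 1)
    return [1.0 - eps if i == correct_idx else other for i in range(vocab_size)]


def cross_entropy(target_probs, predicted_probs):
    """Cross-entropy between a target distribution and predicted probabilities."""
    return -sum(t * math.log(p) for t, p in zip(target_probs, predicted_probs) if t > 0)


targets = smoothed_targets(correct_idx=2, vocab_size=5)
confident = [0.001, 0.001, 0.996, 0.001, 0.001]
moderate = [0.025, 0.025, 0.9, 0.025, 0.025]
print(targets)
print(cross_entropy(targets, confident), cross_entropy(targets, moderate))
```

With smoothed targets, the extremely confident prediction receives a higher loss than the moderate one, which is the pressure that discourages overconfidence.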
|
| 163 |
+
|
| 164 |
+
---
|
| 165 |
+
|
| 166 |
+
## Dataset & Train/Test/Val Split
|
| 167 |
+
|
| 168 |
+
### Data Sources (8 Datasets, ~8,615 labeled samples)
|
| 169 |
+
|
| 170 |
+
| Dataset | Type | Approx. Samples | Notes |
|---------|------|-----------------|-------|
| CORD-v2 | Structured | ~800 | Clean, high-quality receipts |
| WildReceipts | List annotations | ~2,000 | Noisy real-world scans |
| Scanned Receipts | Image + OCR | ~1,000 | Voxel51 collection |
| Invoices & Receipts v1/v2 | Mixed | ~2,500 | mychen76 datasets |
| Receipt CORD OCR v2 | OCR pairs | ~1,000 | Double-escaped JSON (we fixed parsing) |
| DS Receipts v2 Train | Synthetic | ~1,000 | Also had double-escaped strings |
|
| 178 |
+
|
| 179 |
+
### Split Ratios
|
| 180 |
+
|
| 181 |
+
```
Total: 8,615 samples
├── Train: 8,184 (95%)
├── Val:     215 (2.5%)  ← Used to pick the best checkpoint
└── Test:    215 (2.5%)  ← Holdout set, never seen during training
```
|
| 187 |
+
|
| 188 |
+
We used a **single unified dataset loader** (`UnifiedReceiptDataset`) so all 8 datasets are mixed and shuffled together. This prevents the model from overfitting to any one receipt style.
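A seeded shuffle-and-slice split along these lines could look like the sketch below. The function `split_indices` is hypothetical and not taken from `UnifiedReceiptDataset`; it puts any rounding remainder into the test split, so its counts can differ by one sample from the figures listed above.

```python
import random


def split_indices(n, train=0.95, val=0.025, seed=42):
    """Shuffle indices 0..n-1 with a fixed seed and slice them into three splits."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # same seed, same split on every run
    n_train = int(n * train)
    n_val = int(n * val)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train_idx, val_idx, test_idx = split_indices(8615)
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the shuffle happens before slicing, samples from all source datasets end up mixed across the three splits, and the fixed seed keeps the holdout set stable between runs.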
|
| 189 |
+
|
| 190 |
+
### Why these splits?
|
| 191 |
+
|
| 192 |
+
- **95% train:** With <10k samples, we need as much training data as possible.
|
| 193 |
+
- **2.5% val:** Just enough to detect overfitting without wasting data.
|
| 194 |
+
- **2.5% test:** Final unbiased evaluation. In practice, we also evaluated visually on unseen real receipts.
|
| 195 |
+
|
| 196 |
+
---
|
| 197 |
+
|
| 198 |
+
## Training Performance & Learning Curves
|
| 199 |
+
|
| 200 |
+
### Loss Curve
|
| 201 |
+
|
| 202 |
+

|
| 203 |
+
|
| 204 |
+
The model converged around **Epoch 9**. Training was stopped early because:
|
| 205 |
+
- Validation loss plateaued
|
| 206 |
+
- No improvement for 3 consecutive epochs
|
| 207 |
+
- Further training risked overfitting
|
| 208 |
+
|
| 209 |
+
### Key Metrics
|
| 210 |
+
|
| 211 |
+
| Metric | Value |
|--------|-------|
| Total labeled samples (train + val + test) | 8,615 |
| Effective batch size | 64 |
| Peak learning rate | 8.0e-5 |
| Training precision | bf16 |
| GPU | NVIDIA L4 (24 GB VRAM) |
| Training duration | ~4 hours |
| Early stopping epoch | 9 / 20 |
|
| 220 |
+
|
| 221 |
+
### Sample Visual Results
|
| 222 |
+
|
| 223 |
+
Below are real model outputs on the validation set (Original Image vs. Predicted JSON).
|
| 224 |
|
| 225 |

|
| 226 |
+
*Example 1: Correctly extracted merchant, date, and total.*
|
| 227 |
+
|
| 228 |

|
| 229 |
+
*Example 2: Handled a partially blurred receipt with minor date typo.*
|
| 230 |
+
|
| 231 |

|
| 232 |
+
*Example 3: Multi-line address and tax amount correctly parsed.*
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## Confusion Matrix & Field-Level Evaluation
|
| 237 |
+
|
| 238 |
+
Since this is a **generative text model** (not a classifier), a traditional confusion matrix doesn't apply. Instead, we evaluate each extracted field with a **Field-Level Confusion Matrix** based on string similarity.
|
| 239 |
+
|
| 240 |
+
### Evaluation Categories
|
| 241 |
+
|
| 242 |
+
| Category | Criteria | Example |
|----------|----------|---------|
| ✅ **Correct** | 100% character match | `$13.63` == `$13.63` |
| ⚠️ **Minor Typo** | Levenshtein distance < 20% of the string length | `Starbuks` vs `Starbucks` |
| ❌ **Incorrect** | Distance ≥ 20%, or field missing | `null` vs `Walmart` |
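The three categories can be sketched in a few lines of plain Python. The function names are hypothetical, and dividing the edit distance by the longer string's length is an assumption about how the percentage is normalized; `scripts/evaluate_model.py` may do this differently.

```python
def levenshtein(a, b):
    """Classic dynamic-programming edit distance between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]


def categorize(pred, truth, threshold=0.20):
    """Return 'correct', 'minor_typo', or 'incorrect' for one field."""
    if pred == truth:
        return "correct"
    if not pred or not truth:  # one side is missing or empty
        return "incorrect"
    ratio = levenshtein(pred, truth) / max(len(pred), len(truth))
    return "minor_typo" if ratio < threshold else "incorrect"


print(categorize("$13.63", "$13.63"))
print(categorize("Starbuks", "Starbucks"))
print(categorize(None, "Walmart"))
```

For `Starbuks` vs `Starbucks` the edit distance is 1 over 9 characters (about 11%), which falls under the 20% threshold.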
|
| 247 |
+
|
| 248 |
+
### Field-Level Confusion Matrix (Validation Set)
|
| 249 |
+
|
| 250 |
+
| Field | Correct | Minor Typo | Incorrect | Notes |
|-------|---------|------------|-----------|-------|
| `merchant` | ~82% | ~10% | ~8% | Handwritten signs are hardest |
| `date` | ~89% | ~5% | ~6% | Very consistent format |
| `subtotal` | ~85% | ~8% | ~7% | Currency symbols sometimes dropped |
| `tax` | ~78% | ~12% | ~10% | Often missing on simple receipts |
| `total` | ~91% | ~5% | ~4% | Usually the largest, most visible number |
| `address` | ~65% | ~15% | ~20% | Multi-line text is hardest |
|
| 258 |
+
|
| 259 |
+
### Overall Performance
|
| 260 |
+
|
| 261 |
+
```
Exact Match (all fields correct): ~55%
Usable Match (≤1 minor typo):     ~78%
Any Incorrect Field:              ~22%
```
|
| 266 |
+
|
| 267 |
+
> **Why is Exact Match only 55%?** Receipt extraction is genuinely hard. Even human transcribers disagree on exact formatting (e.g., `$13.63` vs `13.63` vs `13.63 USD`). The model is still highly useful: 78% of receipts are "usable", with at most one small typo.
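The receipt-level numbers can be derived from per-field categories roughly as follows. This is a sketch under stated assumptions: `summarize` is a hypothetical helper, the category strings are illustrative, and "usable" is taken to mean no incorrect field and at most one minor typo (so exact matches also count as usable).

```python
def summarize(per_receipt):
    """per_receipt: list of dicts mapping field name -> category string."""
    n = len(per_receipt)
    exact = usable = any_incorrect = 0
    for fields in per_receipt:
        cats = list(fields.values())
        if "incorrect" in cats:
            any_incorrect += 1
            continue
        typos = cats.count("minor_typo")
        if typos == 0:
            exact += 1
        if typos <= 1:
            usable += 1  # exact receipts also count as usable
    return {"exact": exact / n, "usable": usable / n, "any_incorrect": any_incorrect / n}


sample = [
    {"merchant": "correct", "total": "correct"},
    {"merchant": "minor_typo", "total": "correct"},
    {"merchant": "correct", "total": "incorrect"},
]
print(summarize(sample))
```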
|
| 268 |
|
| 269 |
+
### Generating the Confusion Matrix Yourself
|
| 270 |
|
| 271 |
+
Run this on your Workbench to reproduce the evaluation:
|
| 272 |
+
|
| 273 |
+
```bash
python scripts/evaluate_model.py \
  --model_path outputs/receipt_donut_gcp_enterprise/best_model \
  --dataset_root receipt_datasets \
  --output_dir evaluation_results
```
|
| 279 |
+
|
| 280 |
+
This outputs:
|
| 281 |
+
- `confusion_matrix.png`: Visual matrix per field
- `field_accuracy.json`: Numerical breakdown
- `error_analysis.html`: Side-by-side failures
|
| 284 |
+
|
| 285 |
+
---
|
| 286 |
|
| 287 |
## How to Use (Python)
|
| 288 |
|
| 289 |
### Installation
|
| 290 |
+
|
| 291 |
```bash
|
| 292 |
pip install transformers Pillow torch
|
| 293 |
```
|
| 294 |
|
| 295 |
+
### Single Image Inference
|
| 296 |
+
|
| 297 |
```python
import json

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image

MODEL = "Awarebeyond/receipt-donut"
processor = DonutProcessor.from_pretrained(MODEL)
model = VisionEncoderDecoderModel.from_pretrained(MODEL)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()

def extract(image_path):
    """Run the model on one receipt image and return the parsed JSON fields."""
    img = Image.open(image_path).convert("RGB")
    pixel_values = processor(img, return_tensors="pt").pixel_values.to(device)
    # Generation starts from the decoder start token
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)

    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,  # required for outputs.sequences below
        )

    # Decode the tokens, then strip the EOS, padding, and start tokens
    seq = processor.tokenizer.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    seq = seq.replace(
        processor.tokenizer.decode([model.config.decoder_start_token_id]), ""
    ).strip()

    # Parse the remaining text as JSON
    return json.loads(seq)

result = extract("my_receipt.jpg")
print(json.dumps(result, indent=2))
```
|
| 337 |
+
|
| 338 |
+
### Batch Inference
|
| 339 |
+
|
| 340 |
+
```python
import json
from glob import glob

# Reuses extract() from the single-image example
receipts = glob("receipts/*.jpg")
results = [extract(r) for r in receipts]

# Save to JSON
with open("batch_results.json", "w") as f:
    json.dump(results, f, indent=2)
```
|
| 350 |
+
|
| 351 |
+
---
|
| 352 |
+
|
| 353 |
+
## Model Architecture
|
| 354 |
+
|
| 355 |
+
```
Input Image (1536×1152)
        ↓
Swin Transformer Encoder
        ↓
Encoder Hidden States
        ↓
BART Decoder (cross-attention)
        ↓
JSON Text Tokens
```
|
| 366 |
+
|
| 367 |
+
- **Encoder:** Swin Transformer (hierarchical vision backbone)
|
| 368 |
+
- **Decoder:** BART (text generation with cross-attention)
|
| 369 |
+
- **Vocabulary:** ~5,000 tokens (includes special receipt tokens)
|
| 370 |
+
- **Parameters:** ~300M total
|
| 371 |
+
|
| 372 |
+
### Why Donut?
|
| 373 |
+
|
| 374 |
+
| Feature | OCR + NER Pipeline | Donut (End-to-End) |
|---------|--------------------|--------------------|
| Errors compound | OCR error → NER fails | Single model, single optimization |
| Layout handling | Requires separate layout model | Built into vision encoder |
| Speed | Multi-stage, slower | One forward pass |
| Maintenance | 3+ models to update | One model, one checkpoint |
|
| 380 |
+
|
| 381 |
+
---
|
| 382 |
+
|
| 383 |
+
## Limitations
|
| 384 |
+
|
| 385 |
+
1. **Resolution:** Works best on receipts with text height ≥ 10 pixels. Very low-res images may fail.
2. **Languages:** Primarily trained on English receipts. Other languages may produce lower accuracy.
3. **Handwriting:** Printed text works best. Cursive handwriting is not well supported.
4. **Field coverage:** Only extracts `merchant`, `date`, `subtotal`, `tax`, `total`, `address`. Line items are not extracted.
5. **Currency normalization:** Outputs raw strings (`$13.63`); post-processing may be needed to convert to floats.
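For the currency post-processing mentioned in the last item, a minimal sketch might look like this. The helper `parse_amount` is hypothetical, assumes US-style separators (comma for thousands, dot for decimals), and returns a `Decimal` rather than a float to avoid rounding surprises with money; call `float()` on the result if a float is really needed.

```python
import re
from decimal import Decimal, InvalidOperation


def parse_amount(text):
    """Return a Decimal for strings such as "$13.63", or None if no number is found."""
    if not text:
        return None
    match = re.search(r"-?\d[\d,]*(?:\.\d+)?", text)
    if not match:
        return None
    try:
        return Decimal(match.group(0).replace(",", ""))
    except InvalidOperation:
        return None


for raw in ("$13.63", "13.63 USD", "$1,234.50", "n/a"):
    print(repr(raw), "->", parse_amount(raw))
```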
|
| 390 |
+
|
| 391 |
+
---
|
| 392 |
+
|
| 393 |
+
## Citation
|
| 394 |
+
|
| 395 |
+
If you use this model in research, please cite:
|
| 396 |
+
|
| 397 |
+
```bibtex
@misc{receipt_donut_2024,
  title={Receipt Donut: Fine-tuned Document Understanding for Receipt Extraction},
  author={Awarebeyond},
  year={2024},
  howpublished={\url{https://huggingface.co/Awarebeyond/receipt-donut}}
}
```
|
| 405 |
+
|
| 406 |
+
---
|
| 407 |
+
|
| 408 |
+
*Built with ❤️ by a NAVTTC 🇵🇰 student using Google Cloud Workbench (L4 GPU) and the Hugging Face ecosystem.*
|