---
license: apache-2.0
tags:
  - image-classification
  - onnx
  - onnxruntime
  - mobilenet
  - mobile
  - on-device
  - document-classification
  - tally
library_name: onnx
pipeline_tag: image-classification
metrics:
  - accuracy
  - f1
base_model: timm/mobilenetv3_small_100.lamb_in1k
datasets: []
---

# DocRex

A small on-device document classifier that sorts a single image into one of:

- `bank_statement`
- `invoice`
- `other`

Designed as a first-stage triage step before any heavyweight OCR or
extraction β€” small enough to ship inside a mobile app and run fully offline.

## Recommended artifact

| File | Format | Size | Top-1 acc | Use |
|------|--------|-----:|----------:|------|
| **`invoice_classifier_fp32.onnx`** | ONNX, fp32 | ~5.8 MB | **98.35%** | **Ship this.** |
| `invoice_classifier_int8_qdq.onnx` | ONNX, QDQ static int8 | ~1.7 MB | 58.85% | ⚠️ Experimental β€” see _Quantization notes_. |

**TL;DR β€” use the fp32 model.** It's only ~6 MB, runs in well under
100 ms per image on modern phone CPUs, and has no accuracy drop. The int8
build is included for reference but is **not recommended for deployment**
(details below).

## Model details

- **Architecture**: MobileNetV3-Small (torchvision `mobilenet_v3_small`),
  ImageNet-1k pretrained backbone, final 1000-class head replaced with a
  3-class linear layer.
- **Parameters**: ~1.5 M.
- **Input**: 1 Γ— 3 Γ— 224 Γ— 224, float32, NCHW, ImageNet-normalized.
- **Output**: `logits` β€” 1 Γ— 3 unnormalized scores. Apply softmax to get
  per-class confidence.
- **Class index order** (alphabetical, must match `labels.json`):

  ```
  0  bank_statement
  1  invoice
  2  other
  ```

- **Opset**: 18.

## Intended use

- Triage classifier deciding whether a page is worth running invoice /
  statement extraction on.
- Lightweight client-side filtering before backend OCR.

### Out of scope

- **Not an OCR model** β€” does not extract text, totals, dates, or account
  numbers.
- **Not a fraud / authenticity detector.**
- **Not a layout analyzer** β€” looks at the page as a whole.
- Anything outside `{bank_statement, invoice}` collapses into `other`. The
  model does not distinguish sub-types of `other` (receipts vs IDs vs
  photos).

## How to use

```python
import json
import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession("invoice_classifier_fp32.onnx")
labels = json.load(open("labels.json"))
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
std  = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

img = Image.open("page.jpg").convert("RGB")
# Resize so the shorter edge is 256 (preserves aspect ratio), then center-crop 224x224
scale = 256 / min(img.size)
img = img.resize((round(img.width * scale), round(img.height * scale)))
left, top = (img.width - 224) // 2, (img.height - 224) // 2
img = img.crop((left, top, left + 224, top + 224))
x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0
x = ((x - mean) / std)[None].astype(np.float32)

logits = session.run(["logits"], {"input": x})[0][0]
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(labels[int(probs.argmax())], float(probs.max()))
```
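Downstream, the triage decision reduces to a confidence gate on the softmax output. A minimal sketch; the `0.80` threshold is an illustrative assumption, not a tuned value — calibrate it on your own validation data:

```python
import numpy as np

def triage(logits, labels=("bank_statement", "invoice", "other"), threshold=0.80):
    """Return (label, confidence); route low-confidence pages to 'other'.

    The 0.80 threshold is an illustrative assumption -- tune it on your
    own validation data before shipping.
    """
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    idx = int(probs.argmax())
    label = labels[idx] if probs[idx] >= threshold else "other"
    return label, float(probs[idx])

print(triage([4.0, 0.5, 0.1]))  # confident bank_statement
print(triage([0.2, 0.3, 0.1]))  # low margin -> routed to "other"
```

Routing below-threshold pages to `other` keeps marginal predictions out of the (expensive) extraction pipeline, at the cost of a few extra false negatives.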

## Preprocessing

| Step | Value |
|------|-------|
| Resize | shorter edge β†’ 256 |
| Crop | center crop to 224 Γ— 224 |
| Color | RGB |
| Scale | divide by 255 |
| Normalize mean | `[0.485, 0.456, 0.406]` |
| Normalize std | `[0.229, 0.224, 0.225]` |
| Layout | NCHW |
| Dtype | float32 |

Standard ImageNet stats β€” also captured in `preprocess.json` for
programmatic loading.

## Training

- **Backbone weights**: torchvision `MobileNet_V3_Small_Weights.IMAGENET1K_V1`.
- **Head**: replaced with `nn.Linear(in, 3)`.
- **Optimizer**: AdamW, weight decay 1 Γ— 10⁻⁴.
- **Schedule**: cosine annealing across all epochs.
- **Stage 1**: backbone frozen for 2 epochs, only the new head trains
  (lr = 3 Γ— 10⁻⁴).
- **Stage 2**: backbone unfrozen at lr / 10, head stays at base lr
  (discriminative learning rates).
- **Loss**: `CrossEntropyLoss` with inverse-frequency class weights and
  label smoothing 0.05.
- **Augmentation**: Resize(256) β†’ RandomResizedCrop(224, scale 0.7–1.0) β†’
  ColorJitter (brightness/contrast/saturation/hue) β†’ small RandomRotation
  β†’ occasional grayscale β†’ ImageNet normalize.
- **Best checkpoint**: selected by validation accuracy.
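The inverse-frequency class weights used in the loss are simple to reproduce. A sketch with *hypothetical* per-class training counts (the real counts are not published in this card):

```python
import numpy as np

# Hypothetical per-class counts (bank_statement, invoice, other) --
# substitute your own dataset's counts.
counts = np.array([470, 640, 1320], dtype=np.float64)

# Inverse-frequency weighting: rarer classes get proportionally larger weight,
# normalized so the frequency-weighted mean of the weights is 1.
weights = counts.sum() / (len(counts) * counts)
print(weights)  # largest weight goes to the rarest class
```

The resulting vector is what gets passed as `weight=` to `torch.nn.CrossEntropyLoss` (which also accepts the `label_smoothing=0.05` used here).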

## Evaluation

Held-out test set: **243 images** across the three classes.

### fp32

| Metric | Value |
|--------|------:|
| Top-1 accuracy | **98.35%** |
| Macro F1 | 0.9801 |

| Class | F1 |
|-------|---:|
| `bank_statement` | 0.9783 |
| `invoice` | 0.9697 |
| `other` | 0.9924 |

Confusion matrix (rows = true, cols = predicted):

|                    | bank_statement | invoice | other |
|--------------------|---------------:|--------:|------:|
| **bank_statement** | 45 | 2 | 0 |
| **invoice**        | 0  | 64 | 0 |
| **other**          | 0  | 2 | 130 |
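The headline fp32 numbers follow directly from this matrix; a quick numpy cross-check:

```python
import numpy as np

# Confusion matrix from the card (rows = true, cols = predicted),
# class order: bank_statement, invoice, other.
cm = np.array([
    [45, 2, 0],
    [0, 64, 0],
    [0, 2, 130],
])

accuracy = np.trace(cm) / cm.sum()          # correct / total

tp = np.diag(cm).astype(float)
precision = tp / cm.sum(axis=0)             # column sums = predicted counts
recall = tp / cm.sum(axis=1)                # row sums = true counts
f1 = 2 * precision * recall / (precision + recall)
macro_f1 = f1.mean()

print(round(accuracy * 100, 2), round(macro_f1, 4))  # 98.35 0.9801
```

The per-class F1 values from this computation match the table above to four decimal places.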

### int8 (QDQ) β€” not recommended

| Metric | Value |
|--------|------:|
| Top-1 accuracy | 58.85% |
| Macro F1 | 0.4517 |
| Top-1 disagreement vs fp32 | **40.74% (99/243)** |

Best result observed across `MinMax` / `Entropy` / `Percentile` calibration
Γ— per-channel / per-tensor weights. All configurations produce a similar
collapse (45–58% accuracy).

## Quantization notes

Post-training static quantization of MobileNetV3-Small is a known-difficult
problem. The architecture's **Hardswish** activations and
**Squeeze-and-Excitation** blocks produce activation distributions with
extreme outliers that don't fit cleanly into INT8 scales. PTQ β€” regardless
of QDQ vs QOperator format, calibration method, or per-channel vs
per-tensor β€” accumulates enough error across ~140 tensors to collapse one
or more classes.

If you need a smaller model, in increasing order of effort:

1. **FP16** β€” usually within rounding error of fp32. Simplest path to ~3 MB.
2. **Quantization-aware training (QAT)** β€” torchvision provides
   `models.quantization.mobilenet_v3_small`. Requires a retraining run but
   typically lands within 1–2 points of fp32.
3. **Switch architectures** β€” MobileNetV2, EfficientNet-Lite0, or a small
   ConvNeXt variant all post-train-quantize more reliably than MNV3.

The shipped int8 file is left in this repo only as evidence of the failure
mode, not as a deployable artifact.

> **Why QDQ format anyway?** ONNX Runtime Mobile does not include
> `ConvInteger` / `MatMulInteger` operators. A model quantized with
> `QuantFormat.QOperator` or `quantize_dynamic` will load on desktop ORT
> and then fail at runtime on mobile with `code=9 (NOT_IMPLEMENTED)`. QDQ
> keeps standard `Conv` / `MatMul` nodes surrounded by
> `QuantizeLinear` / `DequantizeLinear`, which is the path ORT Mobile
> executes. So if you do produce a working int8 build (e.g. via QAT),
> export it as QDQ.

## Limitations and bias

- **Domain bias toward English-language, Western-format documents.**
  Performance on non-Latin scripts, right-to-left layouts, and regional
  statement / invoice formats has not been systematically measured.
- **Photo conditions matter.** Heavy glare, motion blur, extreme skew
  (>~15Β°), or occlusion shifts predictions toward `other`.
- **`other` is an open set.** Its decision boundary is determined entirely
  by the contents of the training data's `other/` folder. Receipts, IDs,
  screenshots, and shipping labels were included; any class not seen in
  training may be classified inconsistently.
- **No PII handling.** Documents are processed as opaque pixels; the model
  does not redact or filter sensitive fields.

## Files

| File | Purpose |
|------|---------|
| `invoice_classifier_fp32.onnx` | **Recommended** β€” fp32 ONNX model. |
| `invoice_classifier_int8_qdq.onnx` | Experimental int8 build (not recommended). |
| `labels.json` | Class names in model index order. |
| `preprocess.json` | Input shape + ImageNet mean/std. |
| `sha256.txt` | SHA-256 hashes + file sizes for pinned downloads. |

### Pinning hashes

```
bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db  invoice_classifier_fp32.onnx     6084524
83d4ac8aafa1c8fa36cfdab50217769f060b6f9ba0a364e8771dbfdda791d7c3  invoice_classifier_int8_qdq.onnx 1664530
```
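To verify a pinned download against `sha256.txt` before loading it, a small stdlib-only helper (the filenames and hashes above are the real artifacts; the helper itself is just a sketch):

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 without loading it all into memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()

# Usage: compare against the corresponding line of sha256.txt and refuse
# to load the model on a mismatch, e.g.
#   sha256_of("invoice_classifier_fp32.onnx") == expected_hex
```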

## License

Apache-2.0. The pretrained ImageNet backbone is also Apache-2.0
(torchvision MobileNetV3 weights).

## Citation

```bibtex
@software{DocRex,
  title  = {DocRex (MobileNetV3-Small)},
  author = {Vivek Kaushal},
  year   = {2026},
  url    = {https://huggingface.co/vivekkaushal/DocRex}
}
```