# DocRex

A small on-device document classifier that sorts a single image into one of: `bank_statement`, `invoice`, or `other`.

Designed as a first-stage triage step before any heavyweight OCR or extraction: small enough to ship inside a mobile app and run fully offline.
## Recommended artifact

| File | Format | Size | Top-1 acc | Use |
|---|---|---|---|---|
| `invoice_classifier_fp32.onnx` | ONNX, fp32 | ~5.8 MB | 98.35% | **Ship this.** |
| `invoice_classifier_int8_qdq.onnx` | ONNX, QDQ static int8 | ~1.7 MB | 58.85% | ⚠️ Experimental; see Quantization notes. |
**TL;DR:** use the fp32 model. It's only ~6 MB, runs in well under 100 ms per image on modern phone CPUs, and has no accuracy drop. The int8 build is included for reference but is not recommended for deployment (details below).
## Model details

- **Architecture:** MobileNetV3-Small (torchvision `mobilenet_v3_small`), ImageNet-1k pretrained backbone, final 1000-class head replaced with a 3-class linear layer.
- **Parameters:** ~1.5 M.
- **Input:** 1 × 3 × 224 × 224, float32, NCHW, ImageNet-normalized.
- **Output:** `logits`, a 1 × 3 tensor of unnormalized scores. Apply softmax to get per-class confidence.
- **Class index order** (alphabetical, must match `labels.json`): 0 `bank_statement`, 1 `invoice`, 2 `other`.
- **Opset:** 18.
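If you load labels programmatically, a quick sanity check against the documented order (assuming `labels.json` is a flat JSON array, which the Files section implies):

```python
import json

# Assumed layout: a flat JSON array of class names in model index order.
labels = json.load(open("labels.json"))
assert labels == ["bank_statement", "invoice", "other"], labels
```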
## Intended use
- Triage classifier deciding whether a page is worth running invoice / statement extraction on.
- Lightweight client-side filtering before backend OCR.
## Out of scope
- Not an OCR model: it does not extract text, totals, dates, or account numbers.
- Not a fraud / authenticity detector.
- Not a layout analyzer: it looks at the page as a whole.
- Anything outside {`bank_statement`, `invoice`} collapses into `other`. The model does not distinguish sub-types of `other` (receipts vs IDs vs photos).
## How to use
```python
import json

import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession("invoice_classifier_fp32.onnx")
labels = json.load(open("labels.json"))

# ImageNet normalization stats (also shipped in preprocess.json)
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

# Resize so the shorter edge is 256 (preserving aspect ratio), then center-crop 224 x 224
img = Image.open("page.jpg").convert("RGB")
w, h = img.size
scale = 256 / min(w, h)
img = img.resize((round(w * scale), round(h * scale)))
w, h = img.size
left, top = (w - 224) // 2, (h - 224) // 2
img = img.crop((left, top, left + 224, top + 224))

# HWC uint8 -> CHW float32 in [0, 1], normalize, add batch dimension
x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0
x = ((x - mean) / std)[None].astype(np.float32)

# Forward pass, then softmax over the 3 logits
logits = session.run(["logits"], {"input": x})[0][0]
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(labels[int(probs.argmax())], float(probs.max()))
```
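Since this model is meant as a triage step, downstream extraction is typically gated on both the predicted class and its confidence. A minimal sketch building on the snippet above; the `0.80` threshold is an illustrative assumption, not a tuned value:

```python
def should_extract(probs, labels, threshold=0.80):
    """Decide whether to hand the page to heavyweight OCR / extraction.

    `threshold` is a placeholder; tune it on your own validation data.
    """
    idx = int(probs.argmax())
    label, conf = labels[idx], float(probs[idx])
    return (label in ("bank_statement", "invoice") and conf >= threshold), label, conf

run_extraction, label, conf = should_extract(probs, labels)
print("extract" if run_extraction else "skip", label, f"{conf:.2f}")
```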
## Preprocessing
| Step | Value |
|---|---|
| Resize | shorter edge → 256 |
| Crop | center crop to 224 × 224 |
| Color | RGB |
| Scale | divide by 255 |
| Normalize mean | [0.485, 0.456, 0.406] |
| Normalize std | [0.229, 0.224, 0.225] |
| Layout | NCHW |
| Dtype | float32 |
Standard ImageNet stats; also captured in `preprocess.json` for programmatic loading.
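A sketch of that programmatic loading; the key names below (`mean`, `std`) are assumptions about the file's schema, so adjust to the shipped `preprocess.json`:

```python
import json

import numpy as np

pp = json.load(open("preprocess.json"))
mean = np.asarray(pp["mean"], dtype=np.float32).reshape(3, 1, 1)  # assumed key name
std = np.asarray(pp["std"], dtype=np.float32).reshape(3, 1, 1)    # assumed key name
```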
## Training

- Backbone weights: torchvision `MobileNet_V3_Small_Weights.IMAGENET1K_V1`.
- Head: replaced with `nn.Linear(in_features, 3)`.
- Optimizer: AdamW, weight decay 1 × 10⁻⁴.
- Schedule: cosine annealing across all epochs.
- Stage 1: backbone frozen for 2 epochs; only the new head trains (lr = 3 × 10⁻⁴).
- Stage 2: backbone unfrozen at lr / 10; head stays at the base lr (discriminative learning rates).
- Loss: `CrossEntropyLoss` with inverse-frequency class weights and label smoothing 0.05.
- Augmentation: Resize(256) → RandomResizedCrop(224, scale 0.7–1.0) → ColorJitter (brightness/contrast/saturation/hue) → small RandomRotation → occasional grayscale → ImageNet normalize.
- Best checkpoint: selected by validation accuracy.
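A minimal PyTorch sketch of this recipe. Only the pieces listed above come from the card; the class-weight values, total epoch count, and everything marked as a placeholder are assumptions:

```python
import torch
import torch.nn as nn
from torchvision import models

# Pretrained backbone, 1000-class head swapped for a 3-class linear layer
weights = models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
model = models.mobilenet_v3_small(weights=weights)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 3)

base_lr = 3e-4
class_weights = torch.tensor([1.0, 1.0, 1.0])  # placeholder: real run used inverse-frequency weights
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

# Stage 1 (2 epochs): freeze everything except the new head
for p in model.parameters():
    p.requires_grad = False
for p in model.classifier[3].parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(model.classifier[3].parameters(), lr=base_lr, weight_decay=1e-4)
# ... run 2 epochs of head-only training here ...

# Stage 2: unfreeze the backbone at lr / 10, keep the head at the base lr
for p in model.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(
    [
        {"params": model.features.parameters(), "lr": base_lr / 10},
        {"params": model.classifier.parameters(), "lr": base_lr},
    ],
    weight_decay=1e-4,
)
num_epochs = 20  # placeholder: total epoch count is not stated in the card
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
```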
## Evaluation
Held-out test set: 243 images across the three classes.
### fp32
| Metric | Value |
|---|---|
| Top-1 accuracy | 98.35% |
| Macro F1 | 0.9801 |
| Class | F1 |
|---|---|
| `bank_statement` | 0.9783 |
| `invoice` | 0.9697 |
| `other` | 0.9924 |
Confusion matrix (rows = true, cols = predicted):
|  | bank_statement | invoice | other |
|---|---|---|---|
| **bank_statement** | 45 | 2 | 0 |
| **invoice** | 0 | 64 | 0 |
| **other** | 0 | 2 | 130 |
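The summary metrics above follow directly from this matrix; a quick NumPy re-derivation:

```python
import numpy as np

cm = np.array([[45, 2, 0],
               [0, 64, 0],
               [0, 2, 130]])  # rows = true, cols = predicted

accuracy = np.trace(cm) / cm.sum()                  # 0.9835
precision = np.diag(cm) / cm.sum(axis=0)            # per-class, over predicted counts
recall = np.diag(cm) / cm.sum(axis=1)               # per-class, over true counts
f1 = 2 * precision * recall / (precision + recall)  # [0.9783, 0.9697, 0.9924]
macro_f1 = f1.mean()                                # 0.9801
```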
### int8 (QDQ): not recommended
| Metric | Value |
|---|---|
| Top-1 accuracy | 58.85% |
| Macro F1 | 0.4517 |
| Top-1 disagreement vs fp32 | 40.74% (99/243) |
Best result observed across MinMax / Entropy / Percentile calibration × per-channel / per-tensor weights. All configurations produce a similar collapse (45–58% accuracy).
## Quantization notes

Post-training static quantization of MobileNetV3-Small is a known-difficult problem. The architecture's Hardswish activations and Squeeze-and-Excitation blocks produce activation distributions with extreme outliers that don't fit cleanly into INT8 scales. PTQ, regardless of QDQ vs QOperator format, calibration method, or per-channel vs per-tensor weights, accumulates enough error across ~140 tensors to collapse one or more classes.
If you need a smaller model, in increasing order of effort:
- **FP16**: usually within rounding error of fp32. Simplest path to ~3 MB.
- **Quantization-aware training (QAT)**: torchvision provides `models.quantization.mobilenet_v3_small`. Requires a retraining run but typically lands within 1–2 points of fp32.
- **Switch architectures**: MobileNetV2, EfficientNet-Lite0, or a small ConvNeXt variant all post-train-quantize more reliably than MNV3.
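For the FP16 route, one common path is `onnxconverter-common` (a sketch; whether every op in this graph is fp16-safe on your target runtime is worth checking):

```python
# Sketch: fp32 ONNX -> fp16 ONNX (assumes `pip install onnx onnxconverter-common`)
import onnx
from onnxconverter_common import float16

model = onnx.load("invoice_classifier_fp32.onnx")
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, "invoice_classifier_fp16.onnx")
```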
The shipped int8 file is left in this repo only as evidence of the failure mode, not as a deployable artifact.
**Why QDQ format anyway?** ONNX Runtime Mobile does not include `ConvInteger`/`MatMulInteger` operators. A model quantized with `QuantFormat.QOperator` or `quantize_dynamic` will load on desktop ORT and then fail at runtime on mobile with `code=9 (NOT_IMPLEMENTED)`. QDQ keeps standard `Conv`/`MatMul` nodes surrounded by `QuantizeLinear`/`DequantizeLinear`, which is the path ORT Mobile executes. So if you do produce a working int8 build (e.g. via QAT), export it as QDQ.
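For reference, the QDQ export path through ONNX Runtime's static quantizer looks roughly like this. A sketch only: the calibration wiring is an assumption, and, per the notes above, PTQ collapses on this model, so treat this as the format to target for a QAT-derived build rather than a fix:

```python
from onnxruntime.quantization import (
    CalibrationDataReader,
    QuantFormat,
    QuantType,
    quantize_static,
)

class ImageReader(CalibrationDataReader):
    """Feeds preprocessed 1x3x224x224 float32 batches to the calibrator.

    `batches` is assumed to be an iterable of arrays produced by the
    preprocessing shown in "How to use".
    """
    def __init__(self, batches):
        self._it = iter(batches)

    def get_next(self):
        x = next(self._it, None)
        return None if x is None else {"input": x}

calibration_batches = []  # placeholder: fill with a few hundred preprocessed sample arrays

quantize_static(
    "invoice_classifier_fp32.onnx",
    "invoice_classifier_int8_qdq.onnx",
    calibration_data_reader=ImageReader(calibration_batches),
    quant_format=QuantFormat.QDQ,  # Conv/MatMul stay standard, wrapped in Q/DQ nodes
    per_channel=True,
    weight_type=QuantType.QInt8,
)
```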
## Limitations and bias
- Domain bias toward English-language, Western-format documents. Performance on non-Latin scripts, right-to-left layouts, and regional statement / invoice formats has not been systematically measured.
- Photo conditions matter. Heavy glare, motion blur, extreme skew (> ~15°), or occlusion shifts predictions toward `other`.
- `other` is an open set. Its decision boundary is determined entirely by the contents of the training data's `other/` folder. Receipts, IDs, screenshots, and shipping labels were included; any class not seen in training may be classified inconsistently.
- No PII handling. Documents are processed as opaque pixels; the model does not redact or filter sensitive fields.
## Files

| File | Purpose |
|---|---|
| `invoice_classifier_fp32.onnx` | **Recommended.** fp32 ONNX model. |
| `invoice_classifier_int8_qdq.onnx` | Experimental int8 build (not recommended). |
| `labels.json` | Class names in model index order. |
| `preprocess.json` | Input shape + ImageNet mean/std. |
| `sha256.txt` | SHA-256 hashes + file sizes for pinned downloads. |
## Pinning hashes

```
bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db  invoice_classifier_fp32.onnx  6084524
83d4ac8aafa1c8fa36cfdab50217769f060b6f9ba0a364e8771dbfdda791d7c3  invoice_classifier_int8_qdq.onnx  1664530
```
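To verify a pinned download in Python (the trailing size column likely keeps plain `sha256sum -c sha256.txt` from parsing the file as-is):

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 without loading it all into memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db"
assert sha256_of("invoice_classifier_fp32.onnx") == expected, "hash mismatch; re-download"
```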
## License
Apache-2.0. The pretrained ImageNet backbone is also Apache-2.0 (torchvision MobileNetV3 weights).
## Citation

```bibtex
@software{DocRex,
  title  = {DocRex (MobileNetV3-Small)},
  author = {Vivek Kaushal},
  year   = {2026},
  url    = {https://huggingface.co/vivekkaushal/DocRex}
}
```