---
license: apache-2.0
tags:
- image-classification
- onnx
- onnxruntime
- mobilenet
- mobile
- on-device
- document-classification
- tally
library_name: onnx
pipeline_tag: image-classification
metrics:
- accuracy
- f1
base_model: timm/mobilenetv3_small_100.lamb_in1k
datasets: []
---

# DocRex

A small on-device document classifier that sorts a single image into one of:

- `bank_statement`
- `invoice`
- `other`

Designed as a first-stage triage step before any heavyweight OCR or extraction — small enough to ship inside a mobile app and run fully offline.

## Recommended artifact

| File | Format | Size | Top-1 acc | Use |
|------|--------|-----:|----------:|-----|
| **`invoice_classifier_fp32.onnx`** | ONNX, fp32 | ~5.8 MB | **98.35%** | **Ship this.** |
| `invoice_classifier_int8_qdq.onnx` | ONNX, QDQ static int8 | ~1.7 MB | 58.85% | ⚠️ Experimental — see _Quantization notes_. |

**TL;DR — use the fp32 model.** It's only ~6 MB, runs in well under 100 ms per image on modern phone CPUs, and has no accuracy drop. The int8 build is included for reference but is **not recommended for deployment** (details below).

## Model details

- **Architecture**: MobileNetV3-Small (torchvision `mobilenet_v3_small`), ImageNet-1k pretrained backbone, final 1000-class head replaced with a 3-class linear layer.
- **Parameters**: ~1.5 M.
- **Input**: 1 × 3 × 224 × 224, float32, NCHW, ImageNet-normalized.
- **Output**: `logits` — 1 × 3 unnormalized scores. Apply softmax to get per-class confidence.
- **Class index order** (alphabetical, must match `labels.json`):

  ```
  0 bank_statement
  1 invoice
  2 other
  ```

- **Opset**: 18.

## Intended use

- Triage classifier deciding whether a page is worth running invoice / statement extraction on.
- Lightweight client-side filtering before backend OCR.

### Out of scope

- **Not an OCR model** — does not extract text, totals, dates, or account numbers.
- **Not a fraud / authenticity detector.**
- **Not a layout analyzer** — looks at the page as a whole.
- Anything outside `{bank_statement, invoice}` collapses into `other`. The model does not distinguish sub-types of `other` (receipts vs IDs vs photos).

## How to use

```python
import json

import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession("invoice_classifier_fp32.onnx")
labels = json.load(open("labels.json"))

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

# Resize shorter edge to 256 (preserving aspect ratio), then center-crop
# to 224 x 224 — matching the preprocessing table below.
img = Image.open("page.jpg").convert("RGB")
scale = 256 / min(img.size)
img = img.resize((round(img.width * scale), round(img.height * scale)))
left = (img.width - 224) // 2
top = (img.height - 224) // 2
img = img.crop((left, top, left + 224, top + 224))

x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0
x = ((x - mean) / std)[None].astype(np.float32)

logits = session.run(["logits"], {"input": x})[0][0]
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(labels[int(probs.argmax())], float(probs.max()))
```

## Preprocessing

| Step | Value |
|------|-------|
| Resize | shorter edge → 256 |
| Crop | center crop to 224 × 224 |
| Color | RGB |
| Scale | divide by 255 |
| Normalize mean | `[0.485, 0.456, 0.406]` |
| Normalize std | `[0.229, 0.224, 0.225]` |
| Layout | NCHW |
| Dtype | float32 |

Standard ImageNet stats — also captured in `preprocess.json` for programmatic loading.

## Training

- **Backbone weights**: torchvision `MobileNet_V3_Small_Weights.IMAGENET1K_V1`.
- **Head**: replaced with `nn.Linear(in, 3)`.
- **Optimizer**: AdamW, weight decay 1 × 10⁻⁴.
- **Schedule**: cosine annealing across all epochs.
- **Stage 1**: backbone frozen for 2 epochs, only the new head trains (lr = 3 × 10⁻⁴).
- **Stage 2**: backbone unfrozen at lr / 10, head stays at base lr (discriminative learning rates).
- **Loss**: `CrossEntropyLoss` with inverse-frequency class weights and label smoothing 0.05.
- **Augmentation**: Resize(256) → RandomResizedCrop(224, scale 0.7–1.0) → ColorJitter (brightness/contrast/saturation/hue) → small RandomRotation → occasional grayscale → ImageNet normalize.
- **Best checkpoint**: selected by validation accuracy.

## Evaluation

Held-out test set: **243 images** across the three classes.

### fp32

| Metric | Value |
|--------|------:|
| Top-1 accuracy | **98.35%** |
| Macro F1 | 0.9801 |

| Class | F1 |
|-------|---:|
| `bank_statement` | 0.9783 |
| `invoice` | 0.9697 |
| `other` | 0.9924 |

Confusion matrix (rows = true, cols = predicted):

|                    | bank_statement | invoice | other |
|--------------------|---------------:|--------:|------:|
| **bank_statement** | 45 | 2 | 0 |
| **invoice**        | 0 | 64 | 0 |
| **other**          | 0 | 2 | 130 |

### int8 (QDQ) — not recommended

| Metric | Value |
|--------|------:|
| Top-1 accuracy | 58.85% |
| Macro F1 | 0.4517 |
| Top-1 disagreement vs fp32 | **40.74% (99/243)** |

Best result observed across `MinMax` / `Entropy` / `Percentile` calibration × per-channel / per-tensor weights. All configurations produce a similar collapse (45–58% accuracy).

## Quantization notes

Post-training static quantization of MobileNetV3-Small is a known-difficult problem. The architecture's **Hardswish** activations and **Squeeze-and-Excitation** blocks produce activation distributions with extreme outliers that don't fit cleanly into INT8 scales. PTQ — regardless of QDQ vs QOperator format, calibration method, or per-channel vs per-tensor — accumulates enough error across ~140 tensors to collapse one or more classes.

If you need a smaller model, in increasing order of effort:

1. **FP16** — usually within rounding error of fp32. Simplest path to ~3 MB.
2. **Quantization-aware training (QAT)** — torchvision provides `models.quantization.mobilenet_v3_small`. Requires a retraining run but typically lands within 1–2 points of fp32.
3. **Switch architectures** — MobileNetV2, EfficientNet-Lite0, or a small ConvNeXt variant all post-train-quantize more reliably than MNV3.

The shipped int8 file is left in this repo only as evidence of the failure mode, not as a deployable artifact.

> **Why QDQ format anyway?** ONNX Runtime Mobile does not include
> `ConvInteger` / `MatMulInteger` operators. A model quantized with
> `QuantFormat.QOperator` or `quantize_dynamic` will load on desktop ORT
> and then fail at runtime on mobile with `code=9 (NOT_IMPLEMENTED)`. QDQ
> keeps standard `Conv` / `MatMul` nodes surrounded by
> `QuantizeLinear` / `DequantizeLinear`, which is the path ORT Mobile
> executes. So if you do produce a working int8 build (e.g. via QAT),
> export it as QDQ.

## Limitations and bias

- **Domain bias toward English-language, Western-format documents.** Performance on non-Latin scripts, right-to-left layouts, and regional statement / invoice formats has not been systematically measured.
- **Photo conditions matter.** Heavy glare, motion blur, extreme skew (>~15°), or occlusion shifts predictions toward `other`.
- **`other` is an open set.** Its decision boundary is determined entirely by the contents of the training data's `other/` folder. Receipts, IDs, screenshots, and shipping labels were included; any class not seen in training may be classified inconsistently.
- **No PII handling.** Documents are processed as opaque pixels; the model does not redact or filter sensitive fields.

## Files

| File | Purpose |
|------|---------|
| `invoice_classifier_fp32.onnx` | **Recommended** — fp32 ONNX model. |
| `invoice_classifier_int8_qdq.onnx` | Experimental int8 build (not recommended). |
| `labels.json` | Class names in model index order. |
| `preprocess.json` | Input shape + ImageNet mean/std. |
| `sha256.txt` | SHA-256 hashes + file sizes for pinned downloads. |
### Pinning hashes

```
bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db  invoice_classifier_fp32.onnx  6084524
83d4ac8aafa1c8fa36cfdab50217769f060b6f9ba0a364e8771dbfdda791d7c3  invoice_classifier_int8_qdq.onnx  1664530
```

## License

Apache-2.0. The pretrained ImageNet backbone is also Apache-2.0 (torchvision MobileNetV3 weights).

## Citation

```bibtex
@software{DocRex,
  title  = {DocRex (MobileNetV3-Small)},
  author = {Vivek Kaushal},
  year   = {2026},
  url    = {https://huggingface.co/vivekkaushal/DocRex}
}
```
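
The pinned hashes in the _Pinning hashes_ section can be checked at download time with Python's standard `hashlib`; a minimal sketch (the digests are copied from `sha256.txt`, and the files are assumed to sit in the current working directory):

```python
import hashlib
import os

# Digests from sha256.txt above.
EXPECTED = {
    "invoice_classifier_fp32.onnx": "bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db",
    "invoice_classifier_int8_qdq.onnx": "83d4ac8aafa1c8fa36cfdab50217769f060b6f9ba0a364e8771dbfdda791d7c3",
}

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file through SHA-256 in 1 MB chunks so the
    whole model never has to fit in memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

for name, expected in EXPECTED.items():
    if not os.path.exists(name):
        print(f"{name}: not present, skipping")
        continue
    status = "OK" if sha256_of(name) == expected else "HASH MISMATCH"
    print(f"{name}: {status}")
```

Treat a mismatch as a failed download and re-fetch rather than loading the model.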