DocRex

A small on-device document classifier that sorts a single image into one of:

  • bank_statement
  • invoice
  • other

Designed as a first-stage triage step before any heavyweight OCR or extraction: small enough to ship inside a mobile app and run fully offline.

Recommended artifact

File                              Format                 Size     Top-1 acc  Use
invoice_classifier_fp32.onnx      ONNX, fp32             ~5.8 MB  98.35%     Ship this.
invoice_classifier_int8_qdq.onnx  ONNX, QDQ static int8  ~1.7 MB  58.85%     ⚠️ Experimental; see Quantization notes.

TL;DR: use the fp32 model. It's only ~6 MB, runs in well under 100 ms per image on modern phone CPUs, and gives up no accuracy. The int8 build is included for reference but is not recommended for deployment (details below).

Model details

  • Architecture: MobileNetV3-Small (torchvision mobilenet_v3_small), ImageNet-1k pretrained backbone, final 1000-class head replaced with a 3-class linear layer.

  • Parameters: ~1.5 M.

  • Input: 1 × 3 × 224 × 224, float32, NCHW, ImageNet-normalized.

  • Output: logits, a 1 × 3 vector of unnormalized scores. Apply softmax to get per-class confidences.

  • Class index order (alphabetical, must match labels.json):

    0  bank_statement
    1  invoice
    2  other
    
  • Opset: 18.
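
A quick way to confirm the I/O signature before wiring up preprocessing; the expected names and shapes in the comments come from the details above:

import onnxruntime as ort

sess = ort.InferenceSession("invoice_classifier_fp32.onnx")
inp, out = sess.get_inputs()[0], sess.get_outputs()[0]
print(inp.name, inp.shape, inp.type)   # expect: input [1, 3, 224, 224] tensor(float)
print(out.name, out.shape, out.type)   # expect: logits [1, 3] tensor(float)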

Intended use

  • Triage classifier deciding whether a page is worth running invoice / statement extraction on.
  • Lightweight client-side filtering before backend OCR.
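
For the triage use case, a minimal confidence gate might look like the sketch below; should_extract and the 0.80 threshold are illustrative, not shipped with the model:

THRESHOLD = 0.80  # illustrative operating point; tune on your own validation data

def should_extract(probs, labels):
    # probs: softmax output from the model; labels: contents of labels.json.
    top = int(probs.argmax())
    return labels[top] in {"bank_statement", "invoice"} and float(probs[top]) >= THRESHOLD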

Out of scope

  • Not an OCR model: it does not extract text, totals, dates, or account numbers.
  • Not a fraud / authenticity detector.
  • Not a layout analyzer: it looks at the page as a whole.
  • Anything outside {bank_statement, invoice} collapses into other. The model does not distinguish sub-types of other (receipts vs IDs vs photos).

How to use

import json
import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession("invoice_classifier_fp32.onnx")
labels = json.load(open("labels.json"))  # ["bank_statement", "invoice", "other"]

# ImageNet normalization constants, shaped to broadcast over a CHW tensor.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
std  = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

# Resize the shorter edge to 256 (preserving aspect ratio), then center-crop to 224.
img = Image.open("page.jpg").convert("RGB")
scale = 256 / min(img.size)
img = img.resize((round(img.width * scale), round(img.height * scale)))
left, top = (img.width - 224) // 2, (img.height - 224) // 2
img = img.crop((left, top, left + 224, top + 224))

# HWC uint8 -> CHW float32 in [0, 1], normalize, and add the batch dimension.
x = np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0
x = ((x - mean) / std)[None].astype(np.float32)

# Run inference, then apply a numerically stable softmax to the three logits.
logits = session.run(["logits"], {"input": x})[0][0]
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(labels[int(probs.argmax())], float(probs.max()))

Preprocessing

Step            Value
Resize          shorter edge → 256
Crop            center crop to 224 × 224
Color           RGB
Scale           divide by 255
Normalize mean  [0.485, 0.456, 0.406]
Normalize std   [0.229, 0.224, 0.225]
Layout          NCHW
Dtype           float32

These are the standard ImageNet stats; they are also captured in preprocess.json for programmatic loading.
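
A sketch of loading it programmatically; the "mean" / "std" key names here are assumptions, so check preprocess.json for the actual schema:

import json
import numpy as np

cfg = json.load(open("preprocess.json"))
# "mean" / "std" are assumed key names; adjust to the file's real schema.
mean = np.asarray(cfg["mean"], dtype=np.float32).reshape(3, 1, 1)
std = np.asarray(cfg["std"], dtype=np.float32).reshape(3, 1, 1)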

Training

  • Backbone weights: torchvision MobileNet_V3_Small_Weights.IMAGENET1K_V1.
  • Head: replaced with nn.Linear(in, 3).
  • Optimizer: AdamW, weight decay 1 × 10⁻⁴.
  • Schedule: cosine annealing across all epochs.
  • Stage 1: backbone frozen for 2 epochs, only the new head trains (lr = 3 × 10⁻⁴).
  • Stage 2: backbone unfrozen at lr / 10, head stays at base lr (discriminative learning rates).
  • Loss: CrossEntropyLoss with inverse-frequency class weights and label smoothing 0.05.
  • Augmentation: Resize(256) → RandomResizedCrop(224, scale 0.7–1.0) → ColorJitter (brightness/contrast/saturation/hue) → small RandomRotation → occasional grayscale → ImageNet normalize.
  • Best checkpoint: selected by validation accuracy.
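
A condensed PyTorch sketch of this recipe; the epoch count and class counts below are placeholders, not values from the actual training run:

import torch
from torch import nn, optim
from torchvision import models

# Load the ImageNet backbone and swap the 1000-class head for a 3-class one.
model = models.mobilenet_v3_small(
    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 3)

# Stage 1: freeze the backbone so only the new head trains.
for p in model.features.parameters():
    p.requires_grad = False
opt = optim.AdamW((p for p in model.parameters() if p.requires_grad),
                  lr=3e-4, weight_decay=1e-4)

# Stage 2: unfreeze the backbone at lr / 10; the head keeps the base lr.
for p in model.features.parameters():
    p.requires_grad = True
opt = optim.AdamW(
    [{"params": model.features.parameters(), "lr": 3e-5},
     {"params": model.classifier.parameters(), "lr": 3e-4}],
    weight_decay=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20)  # epoch count assumed

# Inverse-frequency class weights with label smoothing 0.05.
counts = torch.tensor([500.0, 700.0, 1300.0])  # hypothetical per-class counts
weights = counts.sum() / (len(counts) * counts)
criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.05)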

Evaluation

Held-out test set: 243 images across the three classes.

fp32

Metric          Value
Top-1 accuracy  98.35%
Macro F1        0.9801

Per-class F1:

  bank_statement  0.9783
  invoice         0.9697
  other           0.9924

Confusion matrix (rows = true, cols = predicted):

                bank_statement  invoice  other
bank_statement              45        2      0
invoice                      0       64      0
other                        0        2    130
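
The per-class and macro F1 figures above can be reproduced straight from this matrix:

import numpy as np

cm = np.array([[45, 2, 0],   # rows = true class
               [0, 64, 0],   # cols = predicted class
               [0, 2, 130]])
tp = np.diag(cm).astype(float)
precision, recall = tp / cm.sum(axis=0), tp / cm.sum(axis=1)
f1 = 2 * precision * recall / (precision + recall)
print(f1.round(4))               # [0.9783 0.9697 0.9924]
print(f1.mean().round(4))        # 0.9801
print((tp.sum() / cm.sum()).round(4))  # 0.9835 top-1 accuracy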

int8 (QDQ): not recommended

Metric                      Value
Top-1 accuracy              58.85%
Macro F1                    0.4517
Top-1 disagreement vs fp32  40.74% (99/243)

Best result observed across MinMax / Entropy / Percentile calibration × per-channel / per-tensor weights. All configurations produce a similar collapse (45–58% accuracy).

Quantization notes

Post-training static quantization of MobileNetV3-Small is a known-difficult problem. The architecture's Hardswish activations and Squeeze-and-Excitation blocks produce activation distributions with extreme outliers that don't fit cleanly into INT8 scales. PTQ, regardless of QDQ vs QOperator format, calibration method, or per-channel vs per-tensor weights, accumulates enough error across ~140 tensors to collapse one or more classes.

If you need a smaller model, in increasing order of effort:

  1. FP16: usually within rounding error of fp32 and the simplest path to ~3 MB (a conversion sketch follows this list).
  2. Quantization-aware training (QAT): torchvision provides models.quantization.mobilenet_v3_small. Requires a retraining run but typically lands within 1–2 points of fp32.
  3. Switch architectures: MobileNetV2, EfficientNet-Lite0, or a small ConvNeXt variant all post-training-quantize more reliably than MobileNetV3.
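
For option 1, a minimal fp16 conversion sketch; this assumes the onnxconverter-common package, which is not part of this repo's tooling:

import onnx
from onnxconverter_common import float16  # assumed dependency

model = onnx.load("invoice_classifier_fp32.onnx")
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, "invoice_classifier_fp16.onnx")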

The shipped int8 file is left in this repo only as evidence of the failure mode, not as a deployable artifact.

Why QDQ format anyway? ONNX Runtime Mobile does not include ConvInteger / MatMulInteger operators. A model quantized with QuantFormat.QOperator or quantize_dynamic will load on desktop ORT and then fail at runtime on mobile with code=9 (NOT_IMPLEMENTED). QDQ keeps standard Conv / MatMul nodes surrounded by QuantizeLinear / DequantizeLinear, which is the path ORT Mobile executes. So if you do produce a working int8 build (e.g. via QAT), export it as QDQ.
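
For reference, a static QDQ quantization call looks like the sketch below; Reader and calibration_batches are placeholders you would supply with preprocessed data:

from onnxruntime.quantization import (
    CalibrationDataReader, QuantFormat, QuantType, quantize_static)

class Reader(CalibrationDataReader):
    # Feeds preprocessed 1 x 3 x 224 x 224 float32 batches for calibration.
    def __init__(self, batches):
        self.it = iter(batches)
    def get_next(self):
        x = next(self.it, None)
        return None if x is None else {"input": x}

quantize_static(
    "invoice_classifier_fp32.onnx",
    "invoice_classifier_int8_qdq.onnx",
    Reader(calibration_batches),      # placeholder calibration set
    quant_format=QuantFormat.QDQ,     # QDQ: Conv/MatMul stay standard ops
    per_channel=True,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QInt8,
)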

Limitations and bias

  • Domain bias toward English-language, Western-format documents. Performance on non-Latin scripts, right-to-left layouts, and regional statement / invoice formats has not been systematically measured.
  • Photo conditions matter. Heavy glare, motion blur, extreme skew (>~15°), or occlusion shifts predictions toward other.
  • other is an open set. Its decision boundary is determined entirely by the contents of the training data's other/ folder. Receipts, IDs, screenshots, and shipping labels were included; any class not seen in training may be classified inconsistently.
  • No PII handling. Documents are processed as opaque pixels; the model does not redact or filter sensitive fields.

Files

File                              Purpose
invoice_classifier_fp32.onnx      Recommended: fp32 ONNX model.
invoice_classifier_int8_qdq.onnx  Experimental int8 build (not recommended).
labels.json                       Class names in model index order.
preprocess.json                   Input shape + ImageNet mean/std.
sha256.txt                        SHA-256 hashes + file sizes for pinned downloads.

Pinning hashes

bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db  invoice_classifier_fp32.onnx     6084524
83d4ac8aafa1c8fa36cfdab50217769f060b6f9ba0a364e8771dbfdda791d7c3  invoice_classifier_int8_qdq.onnx 1664530
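
A small verification sketch; the three columns above are SHA-256 hash, filename, and size in bytes, per sha256.txt:

import hashlib

def verify(path, expected_sha256, expected_size):
    data = open(path, "rb").read()
    assert len(data) == expected_size, f"size mismatch: {path}"
    assert hashlib.sha256(data).hexdigest() == expected_sha256, f"hash mismatch: {path}"

verify("invoice_classifier_fp32.onnx",
       "bbe9997671953145206939291c592dc937c6c523202234bba66e8a589cc643db",
       6084524)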

License

Apache-2.0. The pretrained ImageNet backbone is also Apache-2.0 (torchvision MobileNetV3 weights).

Citation

@software{DocRex,
  title  = {DocRex (MobileNetV3-Small)},
  author = {Vivek Kaushal},
  year   = {2026},
  url    = {https://huggingface.co/vivekkaushal/DocRex}
}