Add trained linear softmax classifier alongside the kNN head
Adds a second classification method to Argus: a learned linear softmax
classifier that lives next to the existing kNN-over-prototypes head.
Both share the frozen EUPE-ViT-B backbone, both run from the same CLS
token in a single forward pass, and the caller picks per-invocation
via a new `method` keyword argument on `classify()`.
Why
---
The kNN protocol is strong on top-1 but produces flatter top-k
distributions because nearby class-mean prototypes look similar in
feature space. A learned linear layer can sharpen distinctions between
visually confusable classes and produces calibrated softmax
probabilities that downstream code can actually threshold against.
Both heads are useful; this commit gives callers the choice without
forcing one over the other.
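The flat-vs-sharp contrast can be illustrated numerically. The values below are made-up toy scores, not Argus outputs: cosine similarities to class-mean prototypes cluster tightly for confusable classes, while a softmax over learned logits separates them and yields probabilities a caller can threshold.

```python
import torch

# Toy numbers (not Argus outputs): three confusable classes plus one
# distractor. The kNN cosine scores are nearly tied; the softmax over
# learned logits concentrates mass on the top class.
cos_sims = torch.tensor([0.62, 0.60, 0.59, 0.41])  # kNN scores in [-1, 1]
logits = torch.tensor([9.1, 6.4, 5.8, 1.2])        # linear-head logits
probs = torch.softmax(logits, dim=0)               # probabilities in [0, 1]

print(cos_sims.tolist())                  # near-flat ranking
print([round(p, 3) for p in probs.tolist()])  # sharp, thresholdable
```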
Architecture
------------
- Single Linear(768 -> 1000) layer with bias, applied to the
L2-normalized EUPE-ViT-B CLS token.
- 769K trainable parameters added to the previously frozen pipeline
(86M backbone, 117K seg head, 200K depth head).
- Two new persistent buffers on the Argus class: class_logit_weight
[1000, 768] and class_logit_bias [1000].
- Checkpoint grows from 332 MB to 334 MB.
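A minimal sketch of the head described above, using the buffer names from this commit; the real Argus class registers these in its own `__init__` rather than in a standalone module, so `LinearSoftmaxHead` here is purely illustrative.

```python
import torch
import torch.nn.functional as F

class LinearSoftmaxHead(torch.nn.Module):
    """Illustrative stand-in for the head Argus registers inline."""

    def __init__(self, dim=768, num_classes=1000):
        super().__init__()
        # Persistent buffers, not Parameters: the head is trained
        # offline and shipped inside the checkpoint.
        self.register_buffer("class_logit_weight",
                             torch.zeros(num_classes, dim))   # [1000, 768]
        self.register_buffer("class_logit_bias",
                             torch.zeros(num_classes))        # [1000]

    def forward(self, cls_token):                             # [B, 768]
        x = F.normalize(cls_token, dim=-1)  # L2-normalize, as for kNN
        return x @ self.class_logit_weight.T + self.class_logit_bias

head = LinearSoftmaxHead()
logits = head(torch.randn(2, 768))
print(logits.shape)  # [2, 1000]
```

The buffer sizes account for the stated parameter count: 1000 x 768 weights plus 1000 biases = 769,000.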
Training recipe
---------------
- ImageNet-1k train (1,281,167 images) at 224 px center crop with
standard ImageNet normalization.
- Two-pass: extract per-image CLS features once via the frozen
backbone (cached to disk as a [N, 768] tensor + labels), then train
the linear layer on the cached features. Re-runs of the training
step skip extraction entirely.
- SGD with momentum=0.9, weight_decay=0, batch=4096, 100 epochs with
cosine schedule, no augmentation (cached features are fixed).
- LR swept across {0.5, 1.0, 3.0, 10.0, 30.0} x weight_decay {0, 1e-6}.
  Best: lr=30.0, wd=0.0. The large LR reflects the L2-normalized
  features: with ||x|| = 1 the logits are bounded by ||W_i||, so the
  weights need aggressive growth from zero init to produce sharp
  softmax distributions.
- Best checkpoint by val top-1, restored at the end.
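The second pass of the recipe reduces to ordinary linear-probe training on cached features. The sketch below uses toy sizes so it runs in seconds (random features, a handful of epochs); the real run uses the on-disk [N, 768] cache with batch=4096 for 100 epochs, and the validation-selection step is omitted here.

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for the cached extraction pass: in the real recipe,
# feats is the [N, 768] tensor of frozen-backbone CLS features loaded
# from disk, and labels the matching ImageNet-1k targets.
N, D, C, BATCH, EPOCHS = 1024, 768, 1000, 256, 5
feats = F.normalize(torch.randn(N, D), dim=-1)
labels = torch.randint(0, C, (N,))

linear = torch.nn.Linear(D, C)
torch.nn.init.zeros_(linear.weight)   # zero init, as described above
torch.nn.init.zeros_(linear.bias)
opt = torch.optim.SGD(linear.parameters(), lr=30.0,
                      momentum=0.9, weight_decay=0.0)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

for epoch in range(EPOCHS):
    for i in range(0, N, BATCH):
        loss = F.cross_entropy(linear(feats[i:i+BATCH]),
                               labels[i:i+BATCH])
        opt.zero_grad()
        loss.backward()
        opt.step()
    sched.step()
```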
Results on ImageNet-1k val (50,000 images)
-------------------------------------------
| method | top-1 | top-5 |
|-----------------|----------|----------|
| kNN (k=10) | 84.07 % | 93.99 % |
| linear softmax | 85.53 % | 97.69 % |
| delta | +1.46 pt | +3.70 pt |
The linear probe beats kNN on both metrics. The top-5 improvement is
particularly meaningful: softmax is decisive where the nearest-mean
kNN is flat on visually similar classes.
EUPE paper references for context:
- IN1k-ZS (PE text-tower zero-shot): 79.7
- IN1k-KNN: 84.1 (reproduced at 84.07)
API
---
`model.classify(image, top_k=5, method="knn")` is the new signature.
The `method` argument accepts "knn" (default, unchanged behavior)
or "softmax" (new path). The return shape is identical in both modes:
list of {class_id, class_name, score}, with `margin` (top-1 minus
top-2) attached to the first entry. The only behavioral difference
is that `score` is in [0, 1] for softmax (probability) and in
[-1, 1] for kNN (cosine similarity). Batch inputs work for both
methods. Calling with an unknown method raises ValueError.
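The dispatch the `method` argument implies can be sketched as below. The scoring functions are stand-ins, not the real Argus internals, and the example handles a single image for brevity (the real `classify()` also takes batches and returns class names):

```python
import torch
import torch.nn.functional as F

def classify(cls_token, prototypes, weight, bias, top_k=5, method="knn"):
    """Illustrative dispatch; Argus's real classify() takes an image."""
    x = F.normalize(cls_token, dim=-1)
    if method == "knn":
        # Cosine similarity to class-mean prototypes, scores in [-1, 1].
        scores = x @ F.normalize(prototypes, dim=-1).T
    elif method == "softmax":
        # Calibrated probabilities from the linear head, scores in [0, 1].
        scores = torch.softmax(x @ weight.T + bias, dim=-1)
    else:
        raise ValueError(f"unknown method: {method!r}")
    vals, ids = scores.topk(top_k, dim=-1)
    out = [{"class_id": int(i), "score": float(v)}
           for i, v in zip(ids[0], vals[0])]
    out[0]["margin"] = float(vals[0][0] - vals[0][1])  # top-1 minus top-2
    return out

proto = torch.randn(1000, 768)
W, b = torch.randn(1000, 768), torch.zeros(1000)
x = torch.randn(1, 768)
print(classify(x, proto, W, b, method="softmax")[0])
```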
Inference cost
--------------
- Adds one matmul on the CLS token: under 1 ms latency, negligible
memory.
- The backbone forward pass is shared with all other tasks, so
neither perceive() nor any other method gains any new compute
unless the caller explicitly passes method="softmax".
Backward compatibility
----------------------
- Default method="knn" keeps every existing call site identical to
the previous release.
- The new buffers default to zero when absent from a loaded
checkpoint, and the Argus._init_weights override now zeros any
buffer that came back NaN from HF's torch.empty() reallocation of
missing keys. Before the fix, a checkpoint without trained linear
weights loaded with NaN in class_logit_bias and the softmax path
returned all-NaN outputs.
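The fix amounts to zeroing any buffer that contains NaN after checkpoint load. A hedged sketch, with `zero_nan_buffers` as a stand-in for the actual `_init_weights` override:

```python
import torch

def zero_nan_buffers(module):
    """Stand-in for the Argus._init_weights override: zero any buffer
    that came back NaN from uninitialized torch.empty() memory."""
    for name, buf in module.named_buffers():
        if torch.isnan(buf).any():
            buf.zero_()

# Simulate a load where class_logit_bias was missing from the
# checkpoint and got reallocated as NaN-filled memory.
m = torch.nn.Module()
m.register_buffer("class_logit_bias", torch.full((1000,), float("nan")))
zero_nan_buffers(m)
print(torch.isnan(m.class_logit_bias).any().item())  # False
```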
Files changed
-------------
- argus.py: two new buffers in Argus.__init__, method parameter on
classify(), NaN-safe _init_weights override.
- model.safetensors: gains class_logit_weight and class_logit_bias
flat keys; class_prototypes and all other weights unchanged.
205 tensors total (was 203).
- config.json: unchanged.
model.safetensors (+2 -2), LFS pointer diff:

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:8107701b7b36e436eda4488d0eaa1297769b374025d67177cf4eeef0240ae053
+size 350231376