Add trained linear softmax classifier alongside the kNN head
Browse filesAdds a second classification method to Argus: a learned linear softmax
classifier that lives next to the existing kNN-over-prototypes head.
Both share the frozen EUPE-ViT-B backbone, both run from the same CLS
token in a single forward pass, and the caller picks per-invocation
via a new `method` keyword argument on `classify()`.
Why
---
The kNN protocol is decisive on top-1 but produces flatter top-k
distributions because nearby class-mean prototypes look similar in
feature space. A learned linear layer can sharpen distinctions between
visually confusable classes and produces calibrated softmax
probabilities that downstream code can actually threshold against.
Both heads are useful; this commit gives callers the choice without
forcing one over the other.
Architecture
------------
- Single Linear(768 -> 1000) layer with bias, applied to the
L2-normalized EUPE-ViT-B CLS token.
- 769K trainable parameters added to the previously frozen pipeline
(86M backbone, 117K seg head, 200K depth head).
- Two new persistent buffers on the Argus class: class_logit_weight
[1000, 768] and class_logit_bias [1000].
- Checkpoint grows from 332 MB to 334 MB.
Training recipe
---------------
- ImageNet-1k train (1,281,167 images) at 224 px center crop with
standard ImageNet normalization.
- Two-pass: extract per-image CLS features once via the frozen
backbone (cached to disk as a [N, 768] tensor + labels), then train
the linear layer on the cached features. Re-runs of the training
step skip extraction entirely.
- SGD with momentum=0.9, weight_decay=0, batch=4096, 100 epochs with
cosine schedule, no augmentation (cached features are fixed).
- LR swept across {0.5, 1.0, 3.0, 10.0, 30.0} x weight_decay {0, 1e-6}.
Best: lr=30.0, wd=0.0. The large LR reflects the L2-normalized
features; with \|\|x\|\|=1 the logits are bounded by \|\|W_i\|\|, so the
weights need aggressive growth from zero init to produce sharp
softmax distributions.
- Best checkpoint by val top-1, restored at the end.
Results on ImageNet-1k val (50,000 images)
-------------------------------------------
| method | top-1 | top-5 |
|-----------------|----------|----------|
| kNN (k=10) | 84.07 % | 93.99 % |
| linear softmax | 85.53 % | 97.69 % |
| delta | +1.46 pt | +3.70 pt |
The linear probe beats kNN on both metrics. The top-5 improvement is
particularly meaningful: softmax is decisive where the nearest-mean
kNN is flat on visually similar classes.
EUPE paper references for context:
- IN1k-ZS (PE text-tower zero-shot): 79.7
- IN1k-KNN: 84.1 (reproduced at 84.07)
API
---
`model.classify(image, top_k=5, method="knn")` is the new signature.
The `method` argument accepts "knn" (default, unchanged behavior)
or "softmax" (new path). The return shape is identical in both modes:
list of {class_id, class_name, score}, with `margin` (top-1 minus
top-2) attached to the first entry. The only behavioral difference
is that `score` is in [0, 1] for softmax (probability) and in
[-1, 1] for kNN (cosine similarity). Batch inputs work for both
methods. Calling with an unknown method raises ValueError.
Inference cost
--------------
- Adds one matmul on the CLS token: under 1 ms latency, negligible
memory.
- The backbone forward pass is shared with all other tasks, so
neither perceive() nor any other method gains any new compute
unless the caller explicitly passes method="softmax".
Backward compatibility
----------------------
- Default method="knn" keeps every existing call site identical to
the previous release.
- The new buffers default to zero when absent from a loaded
checkpoint, and the Argus._init_weights override now zeros any
buffer that came back NaN from HF's torch.empty() reallocation of
missing keys. Before the fix, a checkpoint without trained linear
weights loaded with NaN in class_logit_bias and the softmax path
returned all-NaN outputs.
Files changed
-------------
- argus.py: two new buffers in Argus.__init__, method parameter on
classify(), NaN-safe _init_weights override.
- model.safetensors: gains class_logit_weight and class_logit_bias
flat keys; class_prototypes and all other weights unchanged.
205 tensors total (was 203).
- config.json: unchanged.
|
@@ -929,6 +929,16 @@ class Argus(PreTrainedModel):
|
|
| 929 |
torch.zeros(config.num_imagenet_classes, config.embed_dim),
|
| 930 |
persistent=True,
|
| 931 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 932 |
|
| 933 |
for p in self.backbone.parameters():
|
| 934 |
p.requires_grad = False
|
|
@@ -937,7 +947,14 @@ class Argus(PreTrainedModel):
|
|
| 937 |
self.depth_head.eval()
|
| 938 |
|
| 939 |
def _init_weights(self, module):
|
| 940 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
|
| 942 |
@property
|
| 943 |
def class_ids(self):
|
|
@@ -959,15 +976,28 @@ class Argus(PreTrainedModel):
|
|
| 959 |
return cls, spatial
|
| 960 |
|
| 961 |
@torch.inference_mode()
|
| 962 |
-
def classify(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 963 |
single, images = _normalize_image_input(image_or_images)
|
| 964 |
transform = make_eupe_transform(224)
|
| 965 |
batch = torch.stack([transform(img) for img in images]).to(self.device)
|
| 966 |
cls, _ = self._extract(batch)
|
| 967 |
cls = F.normalize(cls, dim=-1)
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
margins = (top2.values[:, 0] - top2.values[:, 1]).tolist()
|
| 972 |
|
| 973 |
results = []
|
|
|
|
| 929 |
torch.zeros(config.num_imagenet_classes, config.embed_dim),
|
| 930 |
persistent=True,
|
| 931 |
)
|
| 932 |
+
self.register_buffer(
|
| 933 |
+
"class_logit_weight",
|
| 934 |
+
torch.zeros(config.num_imagenet_classes, config.embed_dim),
|
| 935 |
+
persistent=True,
|
| 936 |
+
)
|
| 937 |
+
self.register_buffer(
|
| 938 |
+
"class_logit_bias",
|
| 939 |
+
torch.zeros(config.num_imagenet_classes),
|
| 940 |
+
persistent=True,
|
| 941 |
+
)
|
| 942 |
|
| 943 |
for p in self.backbone.parameters():
|
| 944 |
p.requires_grad = False
|
|
|
|
| 947 |
self.depth_head.eval()
|
| 948 |
|
| 949 |
def _init_weights(self, module):
|
| 950 |
+
# HF reallocates missing buffers with torch.empty() (uninitialized memory).
|
| 951 |
+
# Zero any buffer that came back NaN; leave loaded buffers untouched.
|
| 952 |
+
if module is self:
|
| 953 |
+
for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"):
|
| 954 |
+
if hasattr(self, name):
|
| 955 |
+
buf = getattr(self, name)
|
| 956 |
+
if torch.isnan(buf).any() or torch.isinf(buf).any():
|
| 957 |
+
buf.data.zero_()
|
| 958 |
|
| 959 |
@property
|
| 960 |
def class_ids(self):
|
|
|
|
| 976 |
return cls, spatial
|
| 977 |
|
| 978 |
@torch.inference_mode()
|
| 979 |
+
def classify(
|
| 980 |
+
self,
|
| 981 |
+
image_or_images,
|
| 982 |
+
top_k: int = 5,
|
| 983 |
+
method: Literal["knn", "softmax"] = "knn",
|
| 984 |
+
):
|
| 985 |
single, images = _normalize_image_input(image_or_images)
|
| 986 |
transform = make_eupe_transform(224)
|
| 987 |
batch = torch.stack([transform(img) for img in images]).to(self.device)
|
| 988 |
cls, _ = self._extract(batch)
|
| 989 |
cls = F.normalize(cls, dim=-1)
|
| 990 |
+
|
| 991 |
+
if method == "knn":
|
| 992 |
+
scores_full = cls @ self.class_prototypes.T # cosine similarity in [-1, 1]
|
| 993 |
+
elif method == "softmax":
|
| 994 |
+
logits = F.linear(cls, self.class_logit_weight, self.class_logit_bias)
|
| 995 |
+
scores_full = F.softmax(logits, dim=-1) # in [0, 1]
|
| 996 |
+
else:
|
| 997 |
+
raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')")
|
| 998 |
+
|
| 999 |
+
topk = scores_full.topk(top_k, dim=-1)
|
| 1000 |
+
top2 = scores_full.topk(2, dim=-1)
|
| 1001 |
margins = (top2.values[:, 0] - top2.values[:, 1]).tolist()
|
| 1002 |
|
| 1003 |
results = []
|