# phanerozoic/argus — Stage 1: output-channel pruning (Path A, exact parity with Stage 0)
# Upload d9418d2 (verified).
"""Stage 1: output-channel pruning.
The full 768-D EUPE-ViT-B output token is produced as before, the final
LayerNorm runs as before, and then only the 100 dimensions the classifier
reads are retained. The classifier is fused into a single Linear(100, 1)
layer with ternary {+1, 0, -1} fixed weights and one free bias (threshold
expressed as negative bias). Inference is identical to Stage 0 by
construction. No compute savings; this stage just cleans the interface
and sets up the weight shapes that later stages will attack.
Usage:
model = Stage1PersonClassifier.from_pretrained_argus('phanerozoic/argus')
score, pred = model(image_tensor)
"""
import json, os, sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
class Stage1PersonClassifier(nn.Module):
    """EUPE-ViT-B -> 100-dim slice -> ternary linear head -> binary decision.

    `pos_dims` and `neg_dims` index into the 768-D pooled output. The fused
    head weight (`retained_weight`, shape ``(1, n_retained)``) carries +1 at
    every positive-dim position and -1 at every negative-dim position; the
    decision threshold is the single learnable scalar.
    """

    def __init__(self, argus_model, pos_dims, neg_dims, threshold):
        super().__init__()
        self.backbone = argus_model.backbone
        pos = list(pos_dims)
        neg = list(neg_dims)
        # Retained channel order: positives first, then negatives.
        self.register_buffer('retained_dims',
                             torch.tensor(pos + neg, dtype=torch.long))
        # Ternary weights: +1 over the positive slice, -1 over the negative one.
        signs = torch.cat([torch.ones(1, len(pos)), -torch.ones(1, len(neg))], dim=1)
        self.register_buffer('retained_weight', signs)
        # Kept as a free parameter so gradient descent could retune it later.
        self.threshold = nn.Parameter(torch.tensor(float(threshold)))
        self.D = 768  # backbone embedding width

    @torch.inference_mode()
    def forward(self, x):
        """x: (B, 3, 768, 768), ImageNet-normalized.

        Returns ``(score, pred)``: score is a (B,) float tensor, pred (B,) bool.
        """
        # Backbone runs under bf16 autocast; everything after is plain fp32.
        with torch.autocast('cuda', dtype=torch.bfloat16):
            feats = self.backbone.forward_features(x)
        tokens = feats['x_norm_patchtokens'].float()          # (B, 2304, 768)
        normed = F.layer_norm(tokens, [self.D])
        pooled, _ = normed.max(dim=1)                         # (B, 768)
        kept = pooled.index_select(-1, self.retained_dims)    # (B, 100)
        score = F.linear(kept, self.retained_weight).squeeze(-1)  # (B,)
        return score, score > self.threshold

    @classmethod
    def from_pretrained_argus(cls, repo_or_path='phanerozoic/argus',
                              classifier_json='classifier.json'):
        """Load Argus, read classifier.json, build the wrapper."""
        from transformers import AutoModel
        argus = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
        with open(classifier_json) as fh:
            spec = json.load(fh)
        return cls(argus, spec['pos_dims'], spec['neg_dims'], spec['threshold'])
if __name__ == '__main__':
    # Smoke test: build the wrapper from a local Argus checkpoint, report the
    # parameter split, and run one dummy batch end to end (requires CUDA).
    from transformers import AutoModel
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True)
    # Use a context manager so the config file handle is closed deterministically
    # (the previous json.load(open(...)) leaked it).
    cfg_path = Path(__file__).parent / '..' / 'stage_0' / 'classifier.json'
    with open(cfg_path) as f:
        c = json.load(f)
    m = Stage1PersonClassifier(argus, c['pos_dims'], c['neg_dims'], c['threshold'])
    m = m.cuda().eval()
    n_all = sum(p.numel() for p in m.parameters())
    n_backbone = sum(p.numel() for p in m.backbone.parameters())
    n_head = n_all - n_backbone
    print(f'total params: {n_all:,}')
    print(f'backbone params: {n_backbone:,}')
    # Thousands separator added for consistency with the other counts.
    print(f'head params: {n_head:,} (one learnable threshold; weights are fixed buffers)')
    x = torch.randn(2, 3, 768, 768, device='cuda')
    score, pred = m(x)
    print(f'forward OK. score={score.tolist()} pred={pred.tolist()}')