"""Stage 1: output-channel pruning. The full 768-D EUPE-ViT-B output token is produced as before, the final LayerNorm runs as before, and then only the 100 dimensions the classifier reads are retained. The classifier is fused into a single Linear(100, 1) layer with ternary {+1, 0, -1} fixed weights and one free bias (threshold expressed as negative bias). Inference is identical to Stage 0 by construction. No compute savings; this stage just cleans the interface and sets up the weight shapes that later stages will attack. Usage: model = Stage1PersonClassifier.from_pretrained_argus('phanerozoic/argus') score, pred = model(image_tensor) """ import json, os, sys from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F class Stage1PersonClassifier(nn.Module): """EUPE-ViT-B -> 100 dim slice -> ternary linear head -> binary decision. pos_dims and neg_dims are index tensors into the 768-D output. The classifier weight matrix stored in `retained_weight` has shape (1, 100) where positive-dim positions are +1 and negative-dim positions are -1. Bias equals the negated threshold. """ def __init__(self, argus_model, pos_dims, neg_dims, threshold): super().__init__() self.backbone = argus_model.backbone retained = list(pos_dims) + list(neg_dims) self.register_buffer('retained_dims', torch.tensor(retained, dtype=torch.long)) w = torch.zeros(1, len(retained)) w[0, : len(pos_dims)] = 1.0 w[0, len(pos_dims):] = -1.0 self.register_buffer('retained_weight', w) # Stored as a free parameter so gradient descent could retune. self.threshold = nn.Parameter(torch.tensor(float(threshold))) self.D = 768 @torch.inference_mode() def forward(self, x): """x: (B, 3, 768, 768) normalized (ImageNet stats). Returns (score, pred) where score is (B,) float and pred is (B,) bool. """ with torch.autocast('cuda', dtype=torch.bfloat16): out = self.backbone.forward_features(x) patches = out['x_norm_patchtokens'].float() # (B, 2304, 768) ln = F.layer_norm(patches, [self.D]) pooled = ln.max(dim=1).values # (B, 768) retained = pooled.index_select(-1, self.retained_dims) # (B, 100) score = F.linear(retained, self.retained_weight).squeeze(-1) # (B,) pred = score > self.threshold return score, pred @classmethod def from_pretrained_argus(cls, repo_or_path='phanerozoic/argus', classifier_json='classifier.json'): """Load Argus, read classifier.json, build the wrapper.""" from transformers import AutoModel argus = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True) with open(classifier_json) as f: c = json.load(f) return cls(argus, c['pos_dims'], c['neg_dims'], c['threshold']) if __name__ == '__main__': # Smoke test from transformers import AutoModel argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True) c = json.load(open(Path(__file__).parent / '..' / 'stage_0' / 'classifier.json')) m = Stage1PersonClassifier(argus, c['pos_dims'], c['neg_dims'], c['threshold']) m = m.cuda().eval() n_all = sum(p.numel() for p in m.parameters()) n_backbone = sum(p.numel() for p in m.backbone.parameters()) n_head = n_all - n_backbone print(f'total params: {n_all:,}') print(f'backbone params: {n_backbone:,}') print(f'head params: {n_head} (one learnable threshold; weights are fixed buffers)') x = torch.randn(2, 3, 768, 768, device='cuda') score, pred = m(x) print(f'forward OK. score={score.tolist()} pred={pred.tolist()}')