phanerozoic committed on
Commit
d9418d2
·
verified ·
1 Parent(s): 2a2e168

Stage 1: output-channel pruning (Path A, exact parity with Stage 0)

Browse files
stage_1/README.md CHANGED
@@ -1,5 +1,32 @@
1
  # Stage 1: Output-Channel Pruning
2
 
3
- Reserved. See repo root README for plan.
4
 
5
- Scope: keep only the 100 feature dimensions the Stage 0 classifier reads, remove the remaining 668 output channels from EUPE-ViT-B's final projection. No retraining. Expected to preserve F1 exactly.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Stage 1: Output-Channel Pruning
2
 
3
+ Path A implementation: runtime slicing with classifier fusion. Zero behavior drift from Stage 0.
4
 
5
+ ## What changed
6
+
7
+ The backbone emits a 768-D vector per token. Stage 1 wraps that output so downstream code sees only the 100 dimensions the classifier reads. The classifier is fused into a single `Linear(100, 1)` layer with fixed ±1 weights (the non-zero entries of the ternary {+1, 0, -1} classifier) and one free parameter, the decision threshold, which is equivalent to a negated bias.
8
+
9
+ ```python
10
+ from model import Stage1PersonClassifier
11
+ model = Stage1PersonClassifier.from_pretrained_argus('phanerozoic/argus')
12
+ score, pred = model(image_tensor) # (B,) float, (B,) bool
13
+ ```
14
+
15
+ ## What did not change
16
+
17
+ Compute: identical to Stage 0. The backbone still produces all 768 output channels, the final LayerNorm still runs across 768 dimensions, and the 668 unused channels are sliced off at the end. This stage exists to crystallize the interface (single classifier head, single learnable scalar) before later stages actually shrink the backbone.
18
+
19
+ ## Evaluation
20
+
21
+ 5000 COCO val 2017 images, live Argus forward pass at 768 pixel input:
22
+
23
+ ```
24
+ Stage 0 F1 0.8886
25
+ Stage 1 F1 0.8886 (exact parity, match=true)
26
+ ```
27
+
28
+ See `eval.json` for precision and recall.
29
+
30
+ ## What Path B would look like
31
+
32
+ Path B drops the last block's MLP `fc2` output rows for the 668 unused dimensions, calibrates the final LayerNorm using fixed per-channel statistics collected from a corpus, and reduces `fc2` from (3072 → 768) to (3072 → 100). Saves ~2.05M backbone parameters (668 × 3072 weights, about 2.4% of 85M). Expected F1 drift is small (<0.02) but non-zero due to the LayerNorm approximation. Not implemented in this stage; belongs in a later iteration or a Stage 1B branch if pursued.
stage_1/__pycache__/model.cpython-311.pyc ADDED
Binary file (7.26 kB). View file
 
stage_1/eval.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "stage": 1,
3
+ "F1": 0.8885542750358582,
4
+ "precision": 0.9011073112487793,
5
+ "recall": 0.8763461112976074,
6
+ "threshold": 25.284494400024414,
7
+ "stage_0_baseline_F1": 0.8886,
8
+ "match": true
9
+ }
stage_1/model.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stage 1: output-channel pruning.
2
+
3
+ The full 768-D EUPE-ViT-B output token is produced as before, the final
4
+ LayerNorm runs as before, and then only the 100 dimensions the classifier
5
+ reads are retained. The classifier is fused into a single Linear(100, 1)
6
+ layer with ternary {+1, 0, -1} fixed weights and one free bias (threshold
7
+ expressed as negative bias). Inference is identical to Stage 0 by
8
+ construction. No compute savings; this stage just cleans the interface
9
+ and sets up the weight shapes that later stages will attack.
10
+
11
+ Usage:
12
+ model = Stage1PersonClassifier.from_pretrained_argus('phanerozoic/argus')
13
+ score, pred = model(image_tensor)
14
+ """
15
+ import json, os, sys
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+
23
+ class Stage1PersonClassifier(nn.Module):
24
+ """EUPE-ViT-B -> 100 dim slice -> ternary linear head -> binary decision.
25
+
26
+ pos_dims and neg_dims are index tensors into the 768-D output. The
27
+ classifier weight matrix stored in `retained_weight` has shape (1, 100)
28
+ where positive-dim positions are +1 and negative-dim positions are -1.
29
+ Bias equals the negated threshold.
30
+ """
31
+
32
+ def __init__(self, argus_model, pos_dims, neg_dims, threshold):
33
+ super().__init__()
34
+ self.backbone = argus_model.backbone
35
+ retained = list(pos_dims) + list(neg_dims)
36
+ self.register_buffer('retained_dims', torch.tensor(retained, dtype=torch.long))
37
+ w = torch.zeros(1, len(retained))
38
+ w[0, : len(pos_dims)] = 1.0
39
+ w[0, len(pos_dims):] = -1.0
40
+ self.register_buffer('retained_weight', w)
41
+ # Stored as a free parameter so gradient descent could retune.
42
+ self.threshold = nn.Parameter(torch.tensor(float(threshold)))
43
+ self.D = 768
44
+
45
+ @torch.inference_mode()
46
+ def forward(self, x):
47
+ """x: (B, 3, 768, 768) normalized (ImageNet stats).
48
+
49
+ Returns (score, pred) where score is (B,) float and pred is (B,) bool.
50
+ """
51
+ with torch.autocast('cuda', dtype=torch.bfloat16):
52
+ out = self.backbone.forward_features(x)
53
+ patches = out['x_norm_patchtokens'].float() # (B, 2304, 768)
54
+ ln = F.layer_norm(patches, [self.D])
55
+ pooled = ln.max(dim=1).values # (B, 768)
56
+ retained = pooled.index_select(-1, self.retained_dims) # (B, 100)
57
+ score = F.linear(retained, self.retained_weight).squeeze(-1) # (B,)
58
+ pred = score > self.threshold
59
+ return score, pred
60
+
61
+ @classmethod
62
+ def from_pretrained_argus(cls, repo_or_path='phanerozoic/argus',
63
+ classifier_json='classifier.json'):
64
+ """Load Argus, read classifier.json, build the wrapper."""
65
+ from transformers import AutoModel
66
+ argus = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
67
+ with open(classifier_json) as f:
68
+ c = json.load(f)
69
+ return cls(argus, c['pos_dims'], c['neg_dims'], c['threshold'])
70
+
71
+
72
if __name__ == '__main__':
    # Smoke test: build the wrapper from the local Argus checkpoint and the
    # Stage 0 classifier, report parameter counts, and run one random batch
    # on CUDA.
    from transformers import AutoModel
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True)
    classifier_path = Path(__file__).parent / '..' / 'stage_0' / 'classifier.json'
    # Context manager so the file handle is closed deterministically.
    with open(classifier_path, encoding='utf-8') as f:
        c = json.load(f)
    m = Stage1PersonClassifier(argus, c['pos_dims'], c['neg_dims'], c['threshold'])
    m = m.cuda().eval()
    n_all = sum(p.numel() for p in m.parameters())
    n_backbone = sum(p.numel() for p in m.backbone.parameters())
    # Whatever is not in the backbone belongs to the head.
    n_head = n_all - n_backbone
    print(f'total params: {n_all:,}')
    print(f'backbone params: {n_backbone:,}')
    print(f'head params: {n_head} (one learnable threshold; weights are fixed buffers)')

    x = torch.randn(2, 3, 768, 768, device='cuda')
    score, pred = m(x)
    print(f'forward OK. score={score.tolist()} pred={pred.tolist()}')
stage_1/verify.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Verify Stage 1 matches Stage 0 F1 exactly on COCO val 2017."""
2
+ import os, sys, json, time
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pycocotools.coco import COCO
7
+
8
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
9
+ from model import Stage1PersonClassifier
10
+ from transformers import AutoModel
11
+
12
# Root of the COCO dataset (expects annotations/ and val2017/ below it).
# Overridable through the COCO_ROOT environment variable; the default is
# the original hard-coded location.
COCO_ROOT = os.environ.get('COCO_ROOT', '/home/zootest/datasets/coco')
# Stage 0 classifier definition, located relative to this script.
CLASSIFIER_JSON = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               '..', 'stage_0', 'classifier.json')
# Square input resolution (pixels) every image is resized to.
RES = 768
16
+
17
+
18
def main():
    """Score every COCO val2017 image with Stage 1 and write ``eval.json``.

    Loads Argus and the Stage 0 classifier definition, labels each image by
    whether it has a non-crowd annotation with ``category_id == 1``, runs the
    Stage 1 model on each image resized to ``RES`` x ``RES``, and computes
    precision, recall and F1 at the model's threshold. The result is compared
    against the recorded Stage 0 baseline and written next to this script.
    Requires CUDA and the local model/dataset paths.
    """
    # Stage 0 F1 as recorded (rounded to four decimals).
    baseline_f1 = 0.8886

    print('[init] loading Argus + classifier', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True)
    # Context manager so the JSON file handle is closed deterministically.
    with open(CLASSIFIER_JSON, encoding='utf-8') as f:
        c = json.load(f)
    m = Stage1PersonClassifier(argus, c['pos_dims'], c['neg_dims'], c['threshold']).cuda().eval()

    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    img_ids = sorted(coco.getImgIds())
    # Positive label: at least one non-crowd annotation of category 1.
    labels = [any(a['category_id'] == 1 for a in coco.loadAnns(coco.getAnnIds(imgIds=i, iscrowd=False)))
              for i in img_ids]

    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

    scores = torch.zeros(len(img_ids), device='cuda')
    t0 = time.time()
    for i, img_id in enumerate(img_ids):
        info = coco.loadImgs(img_id)[0]
        path = f"{COCO_ROOT}/val2017/{info['file_name']}"
        # Close the underlying file as soon as the pixels are decoded instead
        # of leaving thousands of handles to the garbage collector.
        with Image.open(path) as raw:
            img = raw.convert('RGB').resize((RES, RES), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8).copy()
        x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0
        x = (x - MEAN) / STD
        s, _ = m(x)
        scores[i] = s[0]
        if (i + 1) % 500 == 0:
            print(f'  {i+1}/{len(img_ids)}  {(i+1)/(time.time()-t0):.1f} img/s', flush=True)

    y = torch.tensor(labels, dtype=torch.bool, device='cuda')
    pred = scores > m.threshold
    tp = (pred & y).sum().float()
    fp = (pred & ~y).sum().float()
    fn = (~pred & y).sum().float()
    # clamp() guards the denominators when there are no predictions/positives.
    prec = tp / (tp + fp).clamp(min=1)
    rec = tp / (tp + fn).clamp(min=1)
    f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)
    print(f'\n[verify] F1={f1:.4f}  P={prec:.4f}  R={rec:.4f}', flush=True)
    print(f'[parity] Stage 0 baseline = {baseline_f1:.4f}; Stage 1 should be identical', flush=True)

    out = {
        'stage': 1,
        'F1': float(f1),
        'precision': float(prec),
        'recall': float(rec),
        'threshold': float(m.threshold.item()),
        'stage_0_baseline_F1': baseline_f1,
        # The baseline is only known to four decimals, hence the tolerance.
        'match': bool(abs(float(f1) - baseline_f1) < 1e-3),
    }
    out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'eval.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(out, f, indent=2)
    print(f'[done] wrote eval.json  match={out["match"]}', flush=True)


if __name__ == '__main__':
    main()