Commit 3729ac4 (verified) · phanerozoic · 1 parent: a7e09b2

Stage 3: depth reduction results (only 1 block cleanly prunable)

stage_3/README.md CHANGED
@@ -1,5 +1,53 @@
  # Stage 3: Depth Reduction

- Reserved. See repo root README for plan.

- Scope: drop transformer blocks that do not route signal to the 100 Stage 0 dims.
+ Attempted block-level pruning, analogous to Stage 2 but at block granularity. For each of EUPE-ViT-B's 12 transformer blocks, zeroed both `block.attn.proj.weight` and `block.mlp.fc2.weight`; because of the residual structure, this reduces the block to a pass-through (an exact identity only if the `proj` and `fc2` biases are zero or absent, since the script leaves biases in place). Measured F1 on 1000 COCO val images with the Stage 0 classifier and its fixed threshold.
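
The pass-through claim can be checked on a toy residual block. The sketch below is illustrative only: `residual_block` and its stand-in activations are hypothetical and not part of the EUPE-ViT-B code. It shows that zeroing the two output weight matrices gives an exact identity when the output biases are zero, and a constant offset otherwise.

```python
def matvec(w, v):
    """Plain-Python matrix-vector product."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def residual_block(x, w_attn_out, b_attn_out, w_mlp_out, b_mlp_out):
    """Toy block: two residual branches applied in sequence.

    Each branch adds (W_out @ h + b_out) to the stream, where h is a
    stand-in for that branch's hidden activation (hypothetical, for
    illustration only).
    """
    h_attn = [abs(v) for v in x]        # stand-in for attention features
    attn_out = matvec(w_attn_out, h_attn)
    x = [xi + oi + bi for xi, oi, bi in zip(x, attn_out, b_attn_out)]

    h_mlp = [max(v, 0.0) for v in x]    # stand-in for the MLP hidden layer
    mlp_out = matvec(w_mlp_out, h_mlp)
    x = [xi + oi + bi for xi, oi, bi in zip(x, mlp_out, b_mlp_out)]
    return x


x = [1.0, -2.0, 0.5]
zero_w = [[0.0] * 3 for _ in range(3)]  # output weights zeroed, as in the sweep
no_bias = [0.0] * 3
some_bias = [0.1, 0.0, -0.2]            # a bias the sweep would leave in place

identity_out = residual_block(x, zero_w, no_bias, zero_w, no_bias)
offset_out = residual_block(x, zero_w, some_bias, zero_w, no_bias)

print(identity_out == x)  # zero weights and zero biases: exact pass-through
print(offset_out)         # zero weights, non-zero bias: input plus a constant
```

Whether the real `proj` and `fc2` layers carry non-zero biases is not stated in this commit; if they do, zeroing the biases as well would be needed for an exact identity.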

+ ## Headline result
+
+ Only one block is cleanly prunable. Block 11 (the final block) can be removed with F1 dropping from 0.894 to 0.876. Block 6 is borderline (drop 0.030). No other block is safe to ablate: blocks 0, 1, and 2 collapse the classifier to near-zero F1, block 7 leaves 0.152, and the rest cost between 0.11 and 0.46 F1. Cumulative pruning degrades quickly past K=1: K=2 loses 12 F1 points, and K=3 destroys the classifier.
8
+
9
+ ## Per-block importance
10
+
11
+ ```
12
+ Block F1 ΔF1 vs baseline
13
+ 0 0.000 +0.89 (critical)
14
+ 1 0.011 +0.88 (critical)
15
+ 2 0.000 +0.89 (critical)
16
+ 3 0.783 +0.11
17
+ 4 0.765 +0.13
18
+ 5 0.599 +0.29 (important)
19
+ 6 0.864 +0.03 (borderline)
20
+ 7 0.152 +0.74 (critical)
21
+ 8 0.430 +0.46
22
+ 9 0.674 +0.22
23
+ 10 0.743 +0.15
24
+ 11 0.876 +0.02 (most prunable)
25
+ ```
26
+
27
+ Baseline F1 = 0.8939 (1000-image calibration pool).
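
The ranking used for the cumulative sweep follows directly from these numbers. Here is a minimal sketch that sorts blocks by F1 drop, using the rounded table values; `TOLERANCE` is a hypothetical cutoff chosen for illustration, not a value from the script.

```python
BASELINE_F1 = 0.8939

# F1 after ablating each block alone (rounded values from the table above).
ablated_f1 = {0: 0.000, 1: 0.011, 2: 0.000, 3: 0.783, 4: 0.765, 5: 0.599,
              6: 0.864, 7: 0.152, 8: 0.430, 9: 0.674, 10: 0.743, 11: 0.876}


def rank_by_drop(f1_by_block, baseline):
    """Return (block, drop) pairs sorted with the smallest F1 drop first."""
    drops = {b: baseline - f1 for b, f1 in f1_by_block.items()}
    return sorted(drops.items(), key=lambda kv: kv[1])


ranked = rank_by_drop(ablated_f1, BASELINE_F1)

TOLERANCE = 0.02  # hypothetical "cleanly prunable" cutoff, for illustration
candidates = [b for b, drop in ranked if drop <= TOLERANCE]

for block, drop in ranked:
    print(f"block {block:>2}  drop={drop:.3f}")
print("cleanly prunable under the assumed tolerance:", candidates)
```

The resulting order matches `ranked_most_prunable_first` in `block_importance.json`; the sort is stable, so the tie between blocks 0 and 2 resolves to block 0 first.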
+
+ ## Cumulative pruning
+
+ ```
+ K pruned   F1      Blocks ablated
+ 1          0.876   [11]
+ 2          0.770   [11, 6]
+ 3          0.000   [11, 6, 3]
+ 4+         0.000
+ ```
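
The K=2 result is worse than the single-block numbers would suggest if the drops simply added up. A small arithmetic sketch, using values rounded from `block_importance.json` and `block_pruning_curve.json`, makes the gap explicit; the variable names are illustrative.

```python
BASELINE_F1 = 0.8939

single_f1 = {11: 0.8758, 6: 0.8645}  # each block ablated on its own
joint_f1 = 0.7702                    # blocks 11 and 6 ablated together (K=2)

# If the two ablations were independent, the drops would roughly add.
single_drops = {b: BASELINE_F1 - f1 for b, f1 in single_f1.items()}
additive_estimate = sum(single_drops.values())

joint_drop = BASELINE_F1 - joint_f1
interaction = joint_drop - additive_estimate  # extra loss beyond additivity

print(f"additive estimate: {additive_estimate:.4f}")
print(f"measured K=2 drop: {joint_drop:.4f}")
print(f"interaction term:  {interaction:.4f}")
```

The measured drop is more than twice the additive estimate, which is why single-block rankings are a poor predictor of how far cumulative pruning can go.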
+
+ ## Interpretation
+
+ Transformer blocks pass information forward through residual updates. Unlike individual attention heads (which can be redundant within a single block), blocks build the representation incrementally; removing most early or middle blocks breaks the chain that produces the person-discriminative dims by the final layer. Block 11 is late refinement that the classifier can survive without. Block 6 is tolerable on its own but not in combination with block 11. Everything else is load-bearing.
+
+ The takeaway for backbone compression: **naive block skipping on a frozen pretrained ViT-B reaches a hard ceiling at one block**. To get a shallower model, we need Stage 4: train a new, shallower student that learns a compact representation directly, rather than trying to strip layers from the existing one.
+
+ ## What this stage ships
+
+ - `block_ablation.py`: the sweep script
+ - `block_importance.json`: per-block F1, precision, recall, and F1 drop, plus the most-prunable-first ranking
+ - `block_pruning_curve.json`: cumulative F1 at K = 1, 2, 3, 4, 5, 6, 8, 10, 12
+
+ ## Parameter accounting
+
+ Each block is ~7.08M params (1.77M qkv + 589K proj + 4.72M MLP + LN + LayerScale). At K=1, ~7.1M params are effectively zeroed (8.3% of the 85.6M backbone). At K=2, ~14.2M (16.6%) are zeroed, but the 0.12 F1 drop makes this generally not worth it for a person detector whose baseline F1 is 0.894. Further compression should come from Stages 2 + 4 + 5 combined, not depth alone.
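
The per-block figure can be reproduced from the layer shapes. The sketch below assumes a standard ViT-B layout: width 768 (the `D` constant in `block_ablation.py`), an MLP hidden width of 4×768, biases on every linear layer, two LayerNorms, and two LayerScale vectors per block. The README does not spell these details out, so treat the bias, LayerNorm, and LayerScale terms as assumptions.

```python
D = 768            # embedding width (D in block_ablation.py)
HIDDEN = 4 * D     # assumed MLP hidden width (standard ViT-B ratio)
BACKBONE = 85.6e6  # backbone parameter count quoted in the README

# Weight matrices.
qkv = D * 3 * D         # fused query/key/value projection
proj = D * D            # attention output projection
mlp = 2 * D * HIDDEN    # fc1 and fc2
weights = qkv + proj + mlp

# Small terms, present only under the assumptions stated above.
biases = 3 * D + D + HIDDEN + D  # qkv, proj, fc1, fc2
layer_norms = 2 * 2 * D          # two LayerNorms, scale and shift each
layer_scale = 2 * D              # two LayerScale vectors
per_block = weights + biases + layer_norms + layer_scale


def fraction_removed(k, params_per_block=weights, total=BACKBONE):
    """Share of backbone parameters in k ablated blocks."""
    return k * params_per_block / total


print(f"weights per block:       {weights:,}")
print(f"with assumed small terms: {per_block:,}")
print(f"K=1 share of backbone:   {100 * fraction_removed(1):.1f}%")
```

The weight matrices alone account for about 7.08M parameters per block; the small terms add roughly 0.01M if all of them are present.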
stage_3/block_ablation.py ADDED
@@ -0,0 +1,171 @@
+ """Stage 3: depth reduction via block ablation.
+
+ For each of the 12 transformer blocks, zero the attention proj and MLP fc2
+ output weights. Each block only adds attn and mlp outputs to the residual
+ stream, so it reduces to a pass-through (up to any proj/fc2 bias terms, which
+ are left in place). Measure F1 with the Stage 0 classifier, rank blocks by
+ smallest F1 drop, then sweep cumulative skipping to see how many can go.
+
+ Output:
+     block_importance.json
+     block_pruning_curve.json
+ """
+ import os, sys, json, time
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from PIL import Image
+ from pycocotools.coco import COCO
+ from transformers import AutoModel
+
+ COCO_ROOT = '/home/zootest/datasets/coco'
+ STAGE0_CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
+ N_CALIBRATION = 1000
+ N_BLOCKS = 12
+ RES = 768
+ D = 768
+ OUT_DIR = '/mnt/d/_tmp/1pc_repo/stage_3'
+ DEVICE = 'cuda'
+
+
+ def load_classifier():
+     with open(STAGE0_CLASSIFIER) as f:
+         c = json.load(f)
+     pos = torch.tensor(c['pos_dims'], dtype=torch.long, device=DEVICE)
+     neg = torch.tensor(c['neg_dims'], dtype=torch.long, device=DEVICE)
+     return pos, neg, float(c['threshold'])
+
+
+ @torch.inference_mode()
+ def score_images(argus, imgs, pos, neg):
+     scores = []
+     for x in imgs:
+         with torch.autocast('cuda', dtype=torch.bfloat16):
+             out = argus.backbone.forward_features(x)
+         patches = out['x_norm_patchtokens'].float().squeeze(0)
+         ln = F.layer_norm(patches, [D])
+         pooled = ln.max(dim=0).values
+         scores.append((pooled[pos].sum() - pooled[neg].sum()).item())
+     return torch.tensor(scores, device=DEVICE)
+
+
+ def f1_of(scores, labels, thr):
+     pred = scores > thr
+     tp = (pred & labels).sum().float()
+     fp = (pred & ~labels).sum().float()
+     fn = (~pred & labels).sum().float()
+     prec = tp / (tp + fp).clamp(min=1)
+     rec = tp / (tp + fn).clamp(min=1)
+     f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)
+     return float(f1), float(prec), float(rec)
+
+
+ def ablate_block(model, block_idx, zero=True):
+     """Zero attn.proj and mlp.fc2 of the given block so the block degenerates
+     to an identity via residual. Returns (orig_proj, orig_fc2) for restoring."""
+     block = model.backbone.blocks[block_idx]
+     orig_proj = block.attn.proj.weight.detach().clone()
+     orig_fc2 = block.mlp.fc2.weight.detach().clone()
+     if zero:
+         with torch.no_grad():
+             block.attn.proj.weight.data.zero_()
+             block.mlp.fc2.weight.data.zero_()
+     return orig_proj, orig_fc2
+
+
+ def restore_block(model, block_idx, orig_proj, orig_fc2):
+     block = model.backbone.blocks[block_idx]
+     block.attn.proj.weight.data.copy_(orig_proj)
+     block.mlp.fc2.weight.data.copy_(orig_fc2)
+
+
+ def load_calibration(coco, n, MEAN, STD):
+     img_ids = sorted(coco.getImgIds())[:n]
+     tensors, labels = [], []
+     for img_id in img_ids:
+         info = coco.loadImgs(img_id)[0]
+         path = f"{COCO_ROOT}/val2017/{info['file_name']}"
+         img = Image.open(path).convert('RGB').resize((RES, RES), Image.BILINEAR)
+         arr = np.asarray(img, dtype=np.uint8).copy()
+         x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0
+         tensors.append((x - MEAN) / STD)
+         labels.append(any(a['category_id'] == 1
+                           for a in coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=False))))
+     return tensors, torch.tensor(labels, dtype=torch.bool, device=DEVICE)
+
+
+ def main():
+     os.makedirs(OUT_DIR, exist_ok=True)
+     print('[init] loading Argus', flush=True)
+     model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
+     pos, neg, thr = load_classifier()
+
+     MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
+     STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
+
+     print(f'[calib] loading {N_CALIBRATION} COCO val images', flush=True)
+     coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
+     imgs, labels = load_calibration(coco, N_CALIBRATION, MEAN, STD)
+
+     print('[baseline]', flush=True)
+     base_scores = score_images(model, imgs, pos, neg)
+     base_f1, base_p, base_r = f1_of(base_scores, labels, thr)
+     print(f' baseline F1={base_f1:.4f} P={base_p:.4f} R={base_r:.4f}', flush=True)
+
+     # Per-block individual ablation
+     per_block = []
+     t0 = time.time()
+     for b in range(N_BLOCKS):
+         op, of = ablate_block(model, b)
+         scores = score_images(model, imgs, pos, neg)
+         restore_block(model, b, op, of)
+         f1, p, r = f1_of(scores, labels, thr)
+         drop = base_f1 - f1
+         per_block.append({'block': b, 'F1': f1, 'precision': p, 'recall': r,
+                           'F1_drop': drop})
+         print(f' block {b:>2} F1={f1:.4f} drop={drop:+.4f} '
+               f'{(time.time()-t0):.1f}s', flush=True)
+
+     ranked = sorted(per_block, key=lambda x: x['F1_drop'])
+
+     # Cumulative ablation curve
+     print('[curve] cumulative block ablation', flush=True)
+     curve = []
+     backups = {b: ablate_block(model, b) for b in range(N_BLOCKS)}
+     for b, (op, of) in backups.items():
+         restore_block(model, b, op, of)  # ensure clean start
+     for K in [1, 2, 3, 4, 5, 6, 8, 10, 12]:
+         # Restore all
+         for b in range(N_BLOCKS):
+             op, of = backups[b]
+             restore_block(model, b, op, of)
+         # Ablate top-K most-prunable
+         for r in ranked[:K]:
+             b = r['block']
+             with torch.no_grad():
+                 model.backbone.blocks[b].attn.proj.weight.data.zero_()
+                 model.backbone.blocks[b].mlp.fc2.weight.data.zero_()
+         scores = score_images(model, imgs, pos, neg)
+         f1, p, rr = f1_of(scores, labels, thr)
+         curve.append({'blocks_pruned': K, 'F1': f1, 'F1_drop': base_f1 - f1,
+                       'precision': p, 'recall': rr,
+                       'pruned_list': [r['block'] for r in ranked[:K]]})
+         print(f' K={K:>2} F1={f1:.4f} drop={base_f1-f1:+.4f} '
+               f'blocks pruned={[r["block"] for r in ranked[:K]]}', flush=True)
+     # Restore
+     for b in range(N_BLOCKS):
+         op, of = backups[b]
+         restore_block(model, b, op, of)
+
+     with open(f'{OUT_DIR}/block_importance.json', 'w') as f:
+         json.dump({'baseline_F1': base_f1, 'per_block': per_block,
+                    'ranked_most_prunable_first': [(r['block'], r['F1_drop'])
+                                                   for r in ranked]},
+                   f, indent=2)
+     with open(f'{OUT_DIR}/block_pruning_curve.json', 'w') as f:
+         json.dump({'baseline_F1': base_f1, 'curve': curve}, f, indent=2)
+     print(f'[done] -> {OUT_DIR}', flush=True)
+
+
+ if __name__ == '__main__':
+     main()
stage_3/block_importance.json ADDED
@@ -0,0 +1,139 @@
+ {
+   "baseline_F1": 0.8939393758773804,
+   "per_block": [
+     {
+       "block": 0,
+       "F1": 0.0,
+       "precision": 0.0,
+       "recall": 0.0,
+       "F1_drop": 0.8939393758773804
+     },
+     {
+       "block": 1,
+       "F1": 0.010928962379693985,
+       "precision": 1.0,
+       "recall": 0.005494505632668734,
+       "F1_drop": 0.8830104134976864
+     },
+     {
+       "block": 2,
+       "F1": 0.0,
+       "precision": 0.0,
+       "recall": 0.0,
+       "F1_drop": 0.8939393758773804
+     },
+     {
+       "block": 3,
+       "F1": 0.7833163142204285,
+       "precision": 0.8810068368911743,
+       "recall": 0.7051281929016113,
+       "F1_drop": 0.1106230616569519
+     },
+     {
+       "block": 4,
+       "F1": 0.76458340883255,
+       "precision": 0.8864734172821045,
+       "recall": 0.6721611618995667,
+       "F1_drop": 0.12935596704483032
+     },
+     {
+       "block": 5,
+       "F1": 0.5989974737167358,
+       "precision": 0.9484127163887024,
+       "recall": 0.4377289414405823,
+       "F1_drop": 0.29494190216064453
+     },
+     {
+       "block": 6,
+       "F1": 0.864454984664917,
+       "precision": 0.8958742618560791,
+       "recall": 0.8351648449897766,
+       "F1_drop": 0.02948439121246338
+     },
+     {
+       "block": 7,
+       "F1": 0.15228426456451416,
+       "precision": 1.0,
+       "recall": 0.08241758495569229,
+       "F1_drop": 0.7416551113128662
+     },
+     {
+       "block": 8,
+       "F1": 0.430379718542099,
+       "precision": 0.9272727370262146,
+       "recall": 0.28021979331970215,
+       "F1_drop": 0.46355965733528137
+     },
+     {
+       "block": 9,
+       "F1": 0.674500584602356,
+       "precision": 0.9409835934638977,
+       "recall": 0.5256410241127014,
+       "F1_drop": 0.21943879127502441
+     },
+     {
+       "block": 10,
+       "F1": 0.7431092262268066,
+       "precision": 0.9335179924964905,
+       "recall": 0.6172161102294922,
+       "F1_drop": 0.15083014965057373
+     },
+     {
+       "block": 11,
+       "F1": 0.8757709264755249,
+       "precision": 0.8438030481338501,
+       "recall": 0.9102563858032227,
+       "F1_drop": 0.01816844940185547
+     }
+   ],
+   "ranked_most_prunable_first": [
+     [
+       11,
+       0.01816844940185547
+     ],
+     [
+       6,
+       0.02948439121246338
+     ],
+     [
+       3,
+       0.1106230616569519
+     ],
+     [
+       4,
+       0.12935596704483032
+     ],
+     [
+       10,
+       0.15083014965057373
+     ],
+     [
+       9,
+       0.21943879127502441
+     ],
+     [
+       5,
+       0.29494190216064453
+     ],
+     [
+       8,
+       0.46355965733528137
+     ],
+     [
+       7,
+       0.7416551113128662
+     ],
+     [
+       1,
+       0.8830104134976864
+     ],
+     [
+       0,
+       0.8939393758773804
+     ],
+     [
+       2,
+       0.8939393758773804
+     ]
+   ]
+ }
stage_3/block_pruning_curve.json ADDED
@@ -0,0 +1,137 @@
+ {
+   "baseline_F1": 0.8939393758773804,
+   "curve": [
+     {
+       "blocks_pruned": 1,
+       "F1": 0.8757709264755249,
+       "F1_drop": 0.01816844940185547,
+       "precision": 0.8438030481338501,
+       "recall": 0.9102563858032227,
+       "pruned_list": [
+         11
+       ]
+     },
+     {
+       "blocks_pruned": 2,
+       "F1": 0.7702127695083618,
+       "F1_drop": 0.12372660636901855,
+       "precision": 0.9187816977500916,
+       "recall": 0.66300368309021,
+       "pruned_list": [
+         11,
+         6
+       ]
+     },
+     {
+       "blocks_pruned": 3,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3
+       ]
+     },
+     {
+       "blocks_pruned": 4,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3,
+         4
+       ]
+     },
+     {
+       "blocks_pruned": 5,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3,
+         4,
+         10
+       ]
+     },
+     {
+       "blocks_pruned": 6,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3,
+         4,
+         10,
+         9
+       ]
+     },
+     {
+       "blocks_pruned": 8,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3,
+         4,
+         10,
+         9,
+         5,
+         8
+       ]
+     },
+     {
+       "blocks_pruned": 10,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3,
+         4,
+         10,
+         9,
+         5,
+         8,
+         7,
+         1
+       ]
+     },
+     {
+       "blocks_pruned": 12,
+       "F1": 0.0,
+       "F1_drop": 0.8939393758773804,
+       "precision": 0.0,
+       "recall": 0.0,
+       "pruned_list": [
+         11,
+         6,
+         3,
+         4,
+         10,
+         9,
+         5,
+         8,
+         7,
+         1,
+         0,
+         2
+       ]
+     }
+   ]
+ }