phanerozoic committed · Commit e477540 · verified · 1 Parent(s): 3729ac4

Stage 3: add Stage2+Stage3 compound grid + README update

stage_3/README.md CHANGED
@@ -48,6 +48,22 @@ The takeaway for backbone compression: **naive block skipping on a frozen pretra
  - `block_importance.json`: per-block F1 + L2 deviation
  - `block_pruning_curve.json`: cumulative F1 at K=1, 2, 3, …, 12

+ ## Compound with Stage 2
+
+ `compound_stage2_stage3.py` sweeps the Stage 2 head-pruning × Stage 3 block-pruning grid (K_heads ∈ {0, 5, 10, 15, 20}, K_blocks ∈ {0, 1, 2}) and writes `compound_results.json`. Selected points:
+
+ ```
+ K_heads  K_blocks  F1     params saved
+ 0        0         0.894  0       (baseline)
+ 10       0         0.916  1.97M   (Stage 2 peak, +0.022 F1)
+ 10       1         0.882  9.05M   (stack block 11, -0.012 F1 from baseline)
+ 5        1         0.880  8.06M   (same tier, fewer heads pruned)
+ 0        1         0.876  7.08M   (Stage 3 alone)
+ 15       2         0.243  17.11M  (collapses; block 6 too important)
+ ```
+
+ Heads and blocks do compose, but with a penalty: pruning 10 heads alone gains 0.022 F1 and dropping block 11 alone costs 0.018, yet doing both lands 0.012 below baseline rather than slightly above it. Removing the 10 most prunable heads while also dropping block 11 gives F1 ≈ 0.88 at about 9M params saved, which is the best combined head + depth trade-off available without training anything new. Compression beyond that needs Stage 4 (specialist backbone).
+
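The size of the composition penalty can be read off the grid directly. The following is a minimal sketch, assuming the four relevant F1 values from `compound_results.json` are hard-coded (rounded to four places); the variable names are illustrative and not part of the repo.

```python
# F1 values keyed by (K_heads, K_blocks), rounded from compound_results.json.
f1 = {(0, 0): 0.8939, (10, 0): 0.9159, (0, 1): 0.8758, (10, 1): 0.8820}

base = f1[(0, 0)]
head_delta = f1[(10, 0)] - base      # pruning 10 heads alone
block_delta = f1[(0, 1)] - base      # dropping 1 block alone
combined_delta = f1[(10, 1)] - base  # doing both at once

# If the two stages were independent, their deltas would simply add.
additive_guess = head_delta + block_delta
penalty = additive_guess - combined_delta

print(f"heads alone:    {head_delta:+.4f}")
print(f"block alone:    {block_delta:+.4f}")
print(f"additive guess: {additive_guess:+.4f}")
print(f"measured:       {combined_delta:+.4f}")
print(f"penalty:        {penalty:.4f}")
```

Under these rounded values the interaction costs roughly 0.016 F1 relative to a purely additive expectation.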
  ## Parameter accounting

  Each block is ~7.08M params (1.77M qkv + 589K proj + 4.72M MLP + LN + LayerScale). At K=1, ~7.1M params are effectively zeroed (8.3% of the 85.6M backbone). At K=2 the figure doubles to ~14.2M (16.6%), but the 0.13 F1 drop makes this generally not worth it for a person detector where 0.87 is the current baseline. Further compression should come from Stages 2 + 4 + 5 combined, not depth alone.
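The per-block figure can be reproduced from the layer shapes. This is a hedged sketch that mirrors the constants used in `compound_stage2_stage3.py` (width 768, 12 heads of 64 dims, MLP hidden size 3072); it counts weight matrices only, and the trailing `1536` is the script's own rough allowance for the small remaining terms, whose exact breakdown is an assumption here.

```python
D = 768            # embedding width
HEADS = 12
HEAD_DIM = 64
MLP_HIDDEN = 3072
BACKBONE = 85.6e6  # backbone size quoted in the README

per_head_qkv = 3 * HEAD_DIM * D   # 147,456: one head's slice of q, k and v
per_head_proj = HEAD_DIM * D      # 49,152: one head's columns of the output proj

qkv = HEADS * per_head_qkv        # 1,769,472
proj = HEADS * per_head_proj      # 589,824
mlp = 2 * D * MLP_HIDDEN          # 4,718,592 (fc1 + fc2 weights)
small_terms = 1536                # script's rough extra term (assumed LN/LayerScale)

per_block = qkv + proj + mlp + small_terms
print(f"per head:  {per_head_qkv + per_head_proj:,}")
print(f"per block: {per_block:,}")
print(f"K=1 share of backbone: {100 * per_block / BACKBONE:.1f}%")
```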
stage_3/compound_results.json ADDED
@@ -0,0 +1,125 @@
+ {
+   "baseline_F1": 0.8939393758773804,
+   "grid": [
+     {
+       "K_heads": 0,
+       "K_blocks": 0,
+       "F1": 0.8939393758773804,
+       "precision": 0.9254902005195618,
+       "recall": 0.8644688725471497,
+       "approx_params_saved": 0
+     },
+     {
+       "K_heads": 0,
+       "K_blocks": 1,
+       "F1": 0.8757709264755249,
+       "precision": 0.8438030481338501,
+       "recall": 0.9102563858032227,
+       "approx_params_saved": 7079424
+     },
+     {
+       "K_heads": 0,
+       "K_blocks": 2,
+       "F1": 0.7702127695083618,
+       "precision": 0.9187816977500916,
+       "recall": 0.66300368309021,
+       "approx_params_saved": 14158848
+     },
+     {
+       "K_heads": 5,
+       "K_blocks": 0,
+       "F1": 0.9085714221000671,
+       "precision": 0.9464285969734192,
+       "recall": 0.8736263513565063,
+       "approx_params_saved": 983040
+     },
+     {
+       "K_heads": 5,
+       "K_blocks": 1,
+       "F1": 0.8795811533927917,
+       "precision": 0.8399999737739563,
+       "recall": 0.9230769276618958,
+       "approx_params_saved": 8062464
+     },
+     {
+       "K_heads": 5,
+       "K_blocks": 2,
+       "F1": 0.8004115223884583,
+       "precision": 0.9131455421447754,
+       "recall": 0.7124541997909546,
+       "approx_params_saved": 15141888
+     },
+     {
+       "K_heads": 10,
+       "K_blocks": 0,
+       "F1": 0.9158878326416016,
+       "precision": 0.9351145029067993,
+       "recall": 0.8974359035491943,
+       "approx_params_saved": 1966080
+     },
+     {
+       "K_heads": 10,
+       "K_blocks": 1,
+       "F1": 0.8819875717163086,
+       "precision": 0.8554216623306274,
+       "recall": 0.9102563858032227,
+       "approx_params_saved": 9045504
+     },
+     {
+       "K_heads": 10,
+       "K_blocks": 2,
+       "F1": 0.7060185074806213,
+       "precision": 0.9591194987297058,
+       "recall": 0.5586080551147461,
+       "approx_params_saved": 16124928
+     },
+     {
+       "K_heads": 15,
+       "K_blocks": 0,
+       "F1": 0.8949342966079712,
+       "precision": 0.9173076748847961,
+       "recall": 0.8736263513565063,
+       "approx_params_saved": 2949120
+     },
+     {
+       "K_heads": 15,
+       "K_blocks": 1,
+       "F1": 0.8675373196601868,
+       "precision": 0.8840304017066956,
+       "recall": 0.8516483306884766,
+       "approx_params_saved": 10028544
+     },
+     {
+       "K_heads": 15,
+       "K_blocks": 2,
+       "F1": 0.24320000410079956,
+       "precision": 0.9620253443717957,
+       "recall": 0.13919414579868317,
+       "approx_params_saved": 17107968
+     },
+     {
+       "K_heads": 20,
+       "K_blocks": 0,
+       "F1": 0.8971269726753235,
+       "precision": 0.908067524433136,
+       "recall": 0.8864468932151794,
+       "approx_params_saved": 3932160
+     },
+     {
+       "K_heads": 20,
+       "K_blocks": 1,
+       "F1": 0.8467432856559753,
+       "precision": 0.8875501751899719,
+       "recall": 0.8095238208770752,
+       "approx_params_saved": 11011584
+     },
+     {
+       "K_heads": 20,
+       "K_blocks": 2,
+       "F1": 0.16415409743785858,
+       "precision": 0.9607843160629272,
+       "recall": 0.08974359184503555,
+       "approx_params_saved": 18091008
+     }
+   ]
+ }
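One way to consume this grid is to pick the configuration that saves the most parameters while staying within an F1 tolerance. The sketch below is illustrative only: it embeds a six-entry excerpt of the grid (F1 rounded to four places) instead of reading the file, and `best_under_budget` and `max_drop` are hypothetical names, not part of the repo.

```python
import json

# Excerpt of compound_results.json with rounded F1 values.
# In practice the full file would be read with json.load instead.
results = json.loads("""
{
  "baseline_F1": 0.8939,
  "grid": [
    {"K_heads": 0,  "K_blocks": 0, "F1": 0.8939, "approx_params_saved": 0},
    {"K_heads": 0,  "K_blocks": 1, "F1": 0.8758, "approx_params_saved": 7079424},
    {"K_heads": 10, "K_blocks": 0, "F1": 0.9159, "approx_params_saved": 1966080},
    {"K_heads": 10, "K_blocks": 1, "F1": 0.8820, "approx_params_saved": 9045504},
    {"K_heads": 15, "K_blocks": 1, "F1": 0.8675, "approx_params_saved": 10028544},
    {"K_heads": 15, "K_blocks": 2, "F1": 0.2432, "approx_params_saved": 17107968}
  ]
}
""")


def best_under_budget(grid, baseline_f1, max_drop):
    """Return the entry with the largest savings whose F1 drop is <= max_drop."""
    allowed = [g for g in grid if baseline_f1 - g["F1"] <= max_drop]
    return max(allowed, key=lambda g: g["approx_params_saved"])


best = best_under_budget(results["grid"], results["baseline_F1"], max_drop=0.02)
print(best["K_heads"], best["K_blocks"], best["approx_params_saved"])
```

With a 0.02 F1 tolerance this excerpt selects K_heads=10, K_blocks=1, which matches the operating point the README recommends.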
stage_3/compound_stage2_stage3.py ADDED
@@ -0,0 +1,129 @@
+ """Compound ablation: Stage 2 head mask + Stage 3 block ablation.
+
+ Verifies that the two stages compose. Measures F1 at combinations of
+ K_heads ∈ {0, 5, 10, 15, 20} and K_blocks ∈ {0, 1, 2} using the
+ already-computed importance rankings.
+
+ Output: compound_results.json
+ """
+ import json
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from pycocotools.coco import COCO
+ from transformers import AutoModel
+
+ COCO_ROOT = '/home/zootest/datasets/coco'
+ CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
+ STAGE2_IMPORT = '/mnt/d/_tmp/1pc_repo/stage_2/head_importance.json'
+ STAGE3_IMPORT = '/mnt/d/_tmp/1pc_repo/stage_3/block_importance.json'
+ OUT = '/mnt/d/_tmp/1pc_repo/stage_3/compound_results.json'
+ DEVICE = 'cuda'
+ HEAD_DIM = 64   # channels per attention head
+ RES = 768       # input resolution (images are resized to RES x RES)
+ D = 768         # backbone embedding width
+ N = 1000        # number of COCO val2017 images to evaluate
+
+
+ def f1_of(scores, labels, thr):
+     """Return (F1, precision, recall) for scores thresholded at thr."""
+     pred = scores > thr
+     tp = (pred & labels).sum().float()
+     fp = (pred & ~labels).sum().float()
+     fn = (~pred & labels).sum().float()
+     # Clamp denominators so an empty prediction set yields 0 instead of NaN.
+     prec = tp / (tp + fp).clamp(min=1)
+     rec = tp / (tp + fn).clamp(min=1)
+     f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)
+     return float(f1), float(prec), float(rec)
+
+
+ @torch.inference_mode()
+ def score_all(model, imgs, pos, neg):
+     """Score each image as sum(pos dims) - sum(neg dims) of max-pooled patch tokens."""
+     scores = []
+     for x in imgs:
+         with torch.autocast(DEVICE, dtype=torch.bfloat16):
+             out = model.backbone.forward_features(x)
+         patches = out['x_norm_patchtokens'].float().squeeze(0)
+         ln = F.layer_norm(patches, [D])
+         pooled = ln.max(dim=0).values
+         scores.append((pooled[pos].sum() - pooled[neg].sum()).item())
+     return torch.tensor(scores, device=DEVICE)
+
+
+ def main():
+     with open(CLASSIFIER) as f:
+         c = json.load(f)
+     pos = torch.tensor(c['pos_dims'], dtype=torch.long, device=DEVICE)
+     neg = torch.tensor(c['neg_dims'], dtype=torch.long, device=DEVICE)
+     thr = float(c['threshold'])
+
+     with open(STAGE2_IMPORT) as f:
+         s2 = json.load(f)
+     head_rank = s2['ranked_most_prunable_first']  # list of (block, head, drop)
+
+     with open(STAGE3_IMPORT) as f:
+         s3 = json.load(f)
+     block_rank = s3['ranked_most_prunable_first']  # list of (block, drop)
+
+     print('[load] Argus + COCO', flush=True)
+     model = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True).to(DEVICE).eval()
+     # ImageNet normalization statistics.
+     MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
+     STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)
+     coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
+     img_ids = sorted(coco.getImgIds())[:N]
+     imgs, labels = [], []
+     for img_id in img_ids:
+         info = coco.loadImgs(img_id)[0]
+         p = f"{COCO_ROOT}/val2017/{info['file_name']}"
+         img = Image.open(p).convert('RGB').resize((RES, RES), Image.BILINEAR)
+         arr = np.asarray(img, dtype=np.uint8).copy()
+         x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
+         imgs.append((x - MEAN) / STD)
+         # Positive label: at least one non-crowd "person" annotation (category_id 1).
+         labels.append(any(a['category_id'] == 1
+                           for a in coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=False))))
+     labels = torch.tensor(labels, dtype=torch.bool, device=DEVICE)
+
+     # Back up the original weights so every grid point starts from a clean model.
+     orig_proj = {b: model.backbone.blocks[b].attn.proj.weight.detach().clone() for b in range(12)}
+     orig_fc2 = {b: model.backbone.blocks[b].mlp.fc2.weight.detach().clone() for b in range(12)}
+
+     def restore():
+         for b in range(12):
+             model.backbone.blocks[b].attn.proj.weight.data.copy_(orig_proj[b])
+             model.backbone.blocks[b].mlp.fc2.weight.data.copy_(orig_fc2[b])
+
+     def apply_head_mask(K_heads):
+         # Zero the output-projection columns that read from each pruned head.
+         for (bl, hd, _) in head_rank[:K_heads]:
+             model.backbone.blocks[bl].attn.proj.weight.data[:, hd*HEAD_DIM:(hd+1)*HEAD_DIM] = 0.0
+
+     def apply_block_drop(K_blocks):
+         # Zero both residual-branch output projections. Any bias terms are left
+         # untouched, so this approximates skipping the block rather than removing it.
+         for (bl, _) in block_rank[:K_blocks]:
+             model.backbone.blocks[bl].attn.proj.weight.data.zero_()
+             model.backbone.blocks[bl].mlp.fc2.weight.data.zero_()
+
+     results = []
+     for kh in [0, 5, 10, 15, 20]:
+         for kb in [0, 1, 2]:
+             restore()
+             apply_head_mask(kh)
+             apply_block_drop(kb)
+             s = score_all(model, imgs, pos, neg)
+             f1, p, r = f1_of(s, labels, thr)
+             # Approximate param savings. A pruned head that sits inside a dropped
+             # block may be counted twice, hence "approx".
+             heads_params = kh * (147456 + 49152)  # per-head qkv+proj cost
+             blocks_params = kb * (147456*12 + 49152*12 + 2*768*3072 + 1536)  # rough per-block
+             saved = heads_params + blocks_params
+             results.append({'K_heads': kh, 'K_blocks': kb,
+                             'F1': f1, 'precision': p, 'recall': r,
+                             'approx_params_saved': saved})
+             print(f'  K_heads={kh:>2} K_blocks={kb} F1={f1:.4f} P={p:.4f} R={r:.4f} '
+                   f'saved={saved/1e6:.2f}M', flush=True)
+
+     restore()
+     with open(OUT, 'w') as f:
+         json.dump({'baseline_F1': s3['baseline_F1'], 'grid': results}, f, indent=2)
+     print(f'[done] -> {OUT}', flush=True)
+
+
+ if __name__ == '__main__':
+     main()
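`apply_head_mask` prunes a head by zeroing a column slice of `attn.proj.weight` instead of removing the head from the attention computation. The toy sketch below, in dependency-free Python with made-up sizes and random values, illustrates why that is equivalent for the weight-driven part of the output: zeroing the columns gives the same result as zeroing that head's output before the projection. `W`, `h`, and the helper names are illustrative stand-ins, not objects from the repo.

```python
import random

HEADS, HEAD_DIM, D_OUT = 3, 2, 4  # toy sizes; the script uses 12 heads of 64 dims
random.seed(0)

# W stands in for attn.proj.weight (D_OUT x HEADS*HEAD_DIM);
# h is the concatenation of per-head attention outputs for one token.
W = [[random.uniform(-1, 1) for _ in range(HEADS * HEAD_DIM)] for _ in range(D_OUT)]
h = [random.uniform(-1, 1) for _ in range(HEADS * HEAD_DIM)]


def matvec(weight, vec):
    """Plain matrix-vector product (bias omitted)."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def mask_head_columns(weight, head):
    """Copy of weight with the column slice belonging to `head` set to zero."""
    lo, hi = head * HEAD_DIM, (head + 1) * HEAD_DIM
    return [[0.0 if lo <= j < hi else w for j, w in enumerate(row)] for row in weight]


def drop_head_output(vec, head):
    """Copy of vec with the output slice of `head` set to zero."""
    lo, hi = head * HEAD_DIM, (head + 1) * HEAD_DIM
    return [0.0 if lo <= j < hi else x for j, x in enumerate(vec)]


pruned = 1
out_masked_weight = matvec(mask_head_columns(W, pruned), h)
out_dropped_head = matvec(W, drop_head_output(h, pruned))
same = all(abs(a - b) < 1e-12 for a, b in zip(out_masked_weight, out_dropped_head))
print(same)
```

The same reasoning does not cover bias terms: a projection bias, if present, still contributes after the weights are zeroed, which is why the block-level ablation is best read as an approximation of skipping a block.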