phanerozoic committed on
Commit
6027a13
·
verified ·
1 Parent(s): caedda8

Stage 2b: structural head removal (83.68M backbone, F1 0.9159 preserved)

Browse files
stage_2b/README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stage 2b: Structural Head Removal
2
+
3
+ Unlike Stage 2a, which masks the 10 most prunable attention heads by zeroing their output-projection columns, Stage 2b physically shrinks the attention tensors. The `qkv.weight` rows corresponding to pruned heads are deleted, the matching `proj.weight` columns are deleted, and each block's `num_heads` is reduced. MLPs, LayerNorms, and LayerScales are unchanged.
4
+
5
+ ## Per-block pruning plan
6
+
7
+ ```
8
+ Block Heads removed Heads kept
9
+ 3 [5] 11
10
+ 4 [8] 11
11
+ 6 [9] 11
12
+ 7 [11] 11
13
+ 9 [11, 10, 9] 9
14
+ 10 [4] 11
15
+ 11 [1, 9] 10
16
+ ```
17
+
18
+ Other blocks (0, 1, 2, 5, 8) retain all 12 heads.
19
+
20
+ ## Result
21
+
22
+ ```
23
+ backbone params before: 85,641,984 = 85.64 M
24
+ backbone params after: 83,675,904 = 83.68 M
25
+ saved: 1,966,080 = 1.97 M (2.30 %)
26
+ F1 at K=10 structural: 0.9159
27
+ F1 at K=10 Stage 2a mask: 0.9159 (byte-identical forward)
28
+ ```
29
+
30
+ ## Loading
31
+
32
+ The pruned backbone is *not* a drop-in replacement for the stock Argus backbone because the attention module shapes differ from block to block. Use `load_pruned_backbone.py`:
33
+
34
+ ```python
35
+ from load_pruned_backbone import load_stage2b_backbone
36
+ backbone = load_stage2b_backbone('pruned_state_dict.safetensors', 'head_config.json')
37
+ ```
38
+
39
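+ The loader constructs an Argus ViT-B, walks `head_config.json`, and, for every block that lost heads, replaces the block's attention with a `PrunedSelfAttention` sized for the kept heads before copying weights. Blocks that keep all 12 heads retain their stock attention module.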
40
+
41
+ ## Files
42
+
43
+ - `stage_2b_structural.py` — the conversion script
44
+ - `pruned_state_dict.safetensors` — shrunk backbone weights
45
+ - `head_config.json` — per-block `num_heads`, kept-head indices, removed-head indices
46
+ - `load_pruned_backbone.py` — loader
47
+ - `eval.json` — F1 parity + param delta
48
+
49
+ ## What this buys
50
+
51
+ - 2.3 % backbone parameter reduction at no F1 cost relative to the Stage 2a mask (F1 0.9159 in both cases, which is +0.022 F1 over the Stage 0 baseline).
52
+ - Smaller forward pass: pruned blocks do less attention compute.
53
+ - Sets up Stage 3 (depth reduction) and Stage 4 (specialist backbone) on a smaller starting model.
stage_2b/eval.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "baseline_F1_stage2a_mask_K10": 0.9159,
3
+ "stage2b_structural_F1": 0.9158878326416016,
4
+ "precision": 0.9351145029067993,
5
+ "recall": 0.8974359035491943,
6
+ "backbone_params_before": 85641984,
7
+ "backbone_params_after": 83675904,
8
+ "backbone_params_saved": 1966080,
9
+ "n_calibration_images": 1000
10
+ }
stage_2b/head_config.json ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "per_block_num_heads": [
3
+ 12,
4
+ 12,
5
+ 12,
6
+ 11,
7
+ 11,
8
+ 12,
9
+ 11,
10
+ 11,
11
+ 12,
12
+ 9,
13
+ 11,
14
+ 10
15
+ ],
16
+ "per_block_kept_heads": {
17
+ "0": [
18
+ 0,
19
+ 1,
20
+ 2,
21
+ 3,
22
+ 4,
23
+ 5,
24
+ 6,
25
+ 7,
26
+ 8,
27
+ 9,
28
+ 10,
29
+ 11
30
+ ],
31
+ "1": [
32
+ 0,
33
+ 1,
34
+ 2,
35
+ 3,
36
+ 4,
37
+ 5,
38
+ 6,
39
+ 7,
40
+ 8,
41
+ 9,
42
+ 10,
43
+ 11
44
+ ],
45
+ "2": [
46
+ 0,
47
+ 1,
48
+ 2,
49
+ 3,
50
+ 4,
51
+ 5,
52
+ 6,
53
+ 7,
54
+ 8,
55
+ 9,
56
+ 10,
57
+ 11
58
+ ],
59
+ "3": [
60
+ 0,
61
+ 1,
62
+ 2,
63
+ 3,
64
+ 4,
65
+ 6,
66
+ 7,
67
+ 8,
68
+ 9,
69
+ 10,
70
+ 11
71
+ ],
72
+ "4": [
73
+ 0,
74
+ 1,
75
+ 2,
76
+ 3,
77
+ 4,
78
+ 5,
79
+ 6,
80
+ 7,
81
+ 9,
82
+ 10,
83
+ 11
84
+ ],
85
+ "5": [
86
+ 0,
87
+ 1,
88
+ 2,
89
+ 3,
90
+ 4,
91
+ 5,
92
+ 6,
93
+ 7,
94
+ 8,
95
+ 9,
96
+ 10,
97
+ 11
98
+ ],
99
+ "6": [
100
+ 0,
101
+ 1,
102
+ 2,
103
+ 3,
104
+ 4,
105
+ 5,
106
+ 6,
107
+ 7,
108
+ 8,
109
+ 10,
110
+ 11
111
+ ],
112
+ "7": [
113
+ 0,
114
+ 1,
115
+ 2,
116
+ 3,
117
+ 4,
118
+ 5,
119
+ 6,
120
+ 7,
121
+ 8,
122
+ 9,
123
+ 10
124
+ ],
125
+ "8": [
126
+ 0,
127
+ 1,
128
+ 2,
129
+ 3,
130
+ 4,
131
+ 5,
132
+ 6,
133
+ 7,
134
+ 8,
135
+ 9,
136
+ 10,
137
+ 11
138
+ ],
139
+ "9": [
140
+ 0,
141
+ 1,
142
+ 2,
143
+ 3,
144
+ 4,
145
+ 5,
146
+ 6,
147
+ 7,
148
+ 8
149
+ ],
150
+ "10": [
151
+ 0,
152
+ 1,
153
+ 2,
154
+ 3,
155
+ 5,
156
+ 6,
157
+ 7,
158
+ 8,
159
+ 9,
160
+ 10,
161
+ 11
162
+ ],
163
+ "11": [
164
+ 0,
165
+ 2,
166
+ 3,
167
+ 4,
168
+ 5,
169
+ 6,
170
+ 7,
171
+ 8,
172
+ 10,
173
+ 11
174
+ ]
175
+ },
176
+ "per_block_removed_heads": {
177
+ "0": [],
178
+ "1": [],
179
+ "2": [],
180
+ "3": [
181
+ 5
182
+ ],
183
+ "4": [
184
+ 8
185
+ ],
186
+ "5": [],
187
+ "6": [
188
+ 9
189
+ ],
190
+ "7": [
191
+ 11
192
+ ],
193
+ "8": [],
194
+ "9": [
195
+ 11,
196
+ 10,
197
+ 9
198
+ ],
199
+ "10": [
200
+ 4
201
+ ],
202
+ "11": [
203
+ 1,
204
+ 9
205
+ ]
206
+ },
207
+ "head_dim": 64,
208
+ "dim": 768
209
+ }
stage_2b/load_pruned_backbone.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load the Stage 2b pruned backbone.
2
+
3
+ Reconstructs an argus.DinoVisionTransformer, replaces each block's attention
4
+ with a PrunedSelfAttention sized per head_config.json, and copies weights
5
+ from pruned_state_dict.safetensors.
6
+ """
7
+ import json, sys, os
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ sys.path.insert(0, '/mnt/d/Argus')
13
+ import argus
14
+
15
+
16
class PrunedSelfAttention(nn.Module):
    """Multi-head self-attention whose inner width is ``num_heads * head_dim``.

    Unlike a stock attention layer, the inner width is decoupled from the
    model width ``dim``: ``qkv`` maps ``dim -> 3 * num_heads * head_dim`` and
    ``proj`` maps ``num_heads * head_dim -> dim``, so a block can carry fewer
    heads than ``dim // head_dim``.

    Args:
        dim: width of the input and output token embeddings.
        num_heads: number of attention heads kept in this block.
        head_dim: width of each head.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        proj_bias: whether the output projection has a bias term.
        mask_k_bias: if true, build ``qkv`` with ``argus.LinearKMaskedBias``
            instead of ``nn.Linear``.
    """

    def __init__(self, dim=768, num_heads=12, head_dim=64,
                 qkv_bias=False, proj_bias=True, mask_k_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        # Kept as an attribute for inspection; forward() relies on SDPA's
        # default scaling, which is the same 1/sqrt(head_dim).
        self.scale = head_dim ** -0.5
        linear_class = argus.LinearKMaskedBias if mask_k_bias else nn.Linear
        self.qkv = linear_class(dim, 3 * self.inner_dim, bias=qkv_bias)
        self.proj = nn.Linear(self.inner_dim, dim, bias=proj_bias)

    def forward(self, x, attn_bias=None, rope=None):
        """Apply self-attention to ``x`` of shape ``(B, N, dim)``.

        Args:
            x: input tokens; unpacked as ``(B, N, _)``.
            attn_bias: optional attention mask/bias. It is handed to
                ``F.scaled_dot_product_attention`` as ``attn_mask``, so it
                must be a boolean or floating-point tensor broadcastable to
                ``(B, num_heads, N, N)``. ``None`` means unmasked attention.
            rope: optional ``(sin, cos)`` pair. Rotary embedding is applied
                to the last ``sin.shape[-2]`` tokens only; the leading
                tokens are left untouched.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = torch.unbind(qkv, 2)
        # (B, N, H, D) -> (B, H, N, D), the layout SDPA expects.
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
        if rope is not None:
            sin, cos = rope
            # Tokens before `prefix` get no rotary embedding.
            prefix = N - sin.shape[-2]
            q_pre, q_suf = q[:, :, :prefix, :], q[:, :, prefix:, :]
            k_pre, k_suf = k[:, :, :prefix, :], k[:, :, prefix:, :]
            q = torch.cat([q_pre, argus.rope_apply(q_suf, sin, cos)], dim=-2)
            k = torch.cat([k_pre, argus.rope_apply(k_suf, sin, cos)], dim=-2)
        # Forward the bias instead of silently dropping it; with
        # attn_bias=None this is the same call as before.
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
        attn = attn.transpose(1, 2).reshape(B, N, self.inner_dim)
        return self.proj(attn)
43
+
44
+
45
def load_stage2b_backbone(state_dict_path, head_config_path):
    """Build an Argus ViT-B backbone and load the Stage 2b pruned weights.

    The stock backbone is created with ``argus.build_eupe_vitb16()``. Every
    block whose entry in ``per_block_num_heads`` is smaller than the full
    head count (``dim // head_dim``) gets its ``attn`` module replaced by a
    ``PrunedSelfAttention`` of matching size, then the state dict is loaded.

    Args:
        state_dict_path: path to the safetensors file with pruned weights.
        head_config_path: path to the JSON config providing
            ``per_block_num_heads``, ``dim`` and ``head_dim``.

    Returns:
        The backbone with resized attention modules and loaded weights.

    Raises:
        ValueError: if the config lists a different number of blocks than
            the backbone has.
    """
    import warnings

    from safetensors.torch import load_file

    with open(head_config_path, encoding='utf-8') as f:
        cfg = json.load(f)

    dim = cfg['dim']
    head_dim = cfg['head_dim']
    # Head count of an unpruned block, derived from the config instead of
    # being hard-coded.
    full_heads = dim // head_dim
    per_block_heads = cfg['per_block_num_heads']

    backbone = argus.build_eupe_vitb16()
    if len(per_block_heads) != len(backbone.blocks):
        raise ValueError(
            f'head config lists {len(per_block_heads)} blocks but the '
            f'backbone has {len(backbone.blocks)}'
        )

    # Resize the attention module of every block that lost heads.
    for b, new_heads in enumerate(per_block_heads):
        if new_heads != full_heads:
            backbone.blocks[b].attn = PrunedSelfAttention(
                dim=dim, num_heads=new_heads, head_dim=head_dim,
                qkv_bias=False, proj_bias=True, mask_k_bias=False,
            )

    state = load_file(state_dict_path)
    # strict=False tolerates key-set differences, but a silently missing
    # weight would stay randomly initialised, so report what did not match.
    result = backbone.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            'Stage 2b state dict does not match the backbone exactly: '
            f'missing keys={list(result.missing_keys)}, '
            f'unexpected keys={list(result.unexpected_keys)}'
        )
    return backbone
61
+
62
+
63
if __name__ == '__main__':
    # Smoke test: load the pruned backbone stored next to this script,
    # report its parameter count, and run one forward pass on random input.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(script_dir, 'pruned_state_dict.safetensors')
    config_path = os.path.join(script_dir, 'head_config.json')
    model = load_stage2b_backbone(weights_path, config_path)

    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    print(f'Stage 2b backbone loaded: {n_params:,} params = {n_params/1e6:.2f}M')

    model.eval()
    dummy = torch.randn(1, 3, 768, 768)
    with torch.inference_mode():
        features = model.forward_features(dummy)
    patch_shape = tuple(features["x_norm_patchtokens"].shape)
    print(f'forward OK patch tokens: {patch_shape}')
stage_2b/pruned_state_dict.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:133aae4f1e7b7e232b517c71aec50628d6d4475e41d19c2023a04e5b260962d6
3
+ size 334718768