Stage 2b: Structural Head Removal
Unlike Stage 2a, which masks the 10 most prunable attention heads by zeroing their output-projection columns, Stage 2b physically shrinks the attention tensors. For each pruned head, the corresponding qkv.weight rows and proj.weight columns are deleted, and the block's num_heads is reduced to match. MLPs, LayerNorms, and LayerScales are unchanged.
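The row and column selection described above can be sketched as plain index arithmetic. This is a minimal illustration, not the code in stage_2b_structural.py: it assumes the fused qkv.weight is laid out as [all Q heads | all K heads | all V heads] with head_dim consecutive rows per head, and that head_dim is 64 (768 / 12 for a ViT-B). The function names are hypothetical.

```python
from typing import List


def qkv_rows_to_keep(kept_heads: List[int], num_heads: int, head_dim: int) -> List[int]:
    """Row indices of a fused qkv.weight that survive pruning.

    Assumed layout: rows ordered [Q heads | K heads | V heads], each head
    occupying head_dim consecutive rows within its section.
    """
    embed_dim = num_heads * head_dim
    rows: List[int] = []
    for section in range(3):  # 0 = Q, 1 = K, 2 = V
        for h in sorted(kept_heads):
            start = section * embed_dim + h * head_dim
            rows.extend(range(start, start + head_dim))
    return rows


def proj_cols_to_keep(kept_heads: List[int], head_dim: int) -> List[int]:
    """Column indices of proj.weight that read from the kept heads."""
    cols: List[int] = []
    for h in sorted(kept_heads):
        cols.extend(range(h * head_dim, (h + 1) * head_dim))
    return cols


# A 12-head block with heads 9, 10, 11 removed (assumed head_dim = 64).
kept = [h for h in range(12) if h not in {9, 10, 11}]
rows = qkv_rows_to_keep(kept, num_heads=12, head_dim=64)
cols = proj_cols_to_keep(kept, head_dim=64)
print(len(rows), len(cols))  # 1728 576
```

The resulting index lists would then be used to slice the weight tensors. If the backbone stores Q, K, and V in a different order or interleaves heads, the row arithmetic has to change accordingly.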
Per-block pruning plan
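| Block | Heads removed | Heads kept |
|-------|---------------|------------|
| 3     | [5]           | 11         |
| 4     | [8]           | 11         |
| 6     | [9]           | 11         |
| 7     | [11]          | 11         |
| 9     | [11, 10, 9]   | 9          |
| 10    | [4]           | 11         |
| 11    | [1, 9]        | 10         |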
Other blocks (0, 1, 2, 5, 8) retain all 12 heads.
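As a quick consistency check, the plan can be written as a dictionary and the kept-head counts, the total number of removed heads, and the untouched blocks recomputed from it. The sketch assumes 0-indexed heads and 12 blocks (0 through 11), which is what the table and the sentence above imply.

```python
NUM_HEADS = 12   # heads per block before pruning
NUM_BLOCKS = 12  # blocks 0..11, as implied by the plan

# Removed-head indices per block, copied from the pruning plan.
removed = {
    3: [5], 4: [8], 6: [9], 7: [11],
    9: [11, 10, 9], 10: [4], 11: [1, 9],
}

# Kept heads are whatever remains after dropping the removed indices.
kept = {
    block: [h for h in range(NUM_HEADS) if h not in heads]
    for block, heads in removed.items()
}
kept_counts = {block: len(heads) for block, heads in kept.items()}
total_removed = sum(len(heads) for heads in removed.values())
untouched = [b for b in range(NUM_BLOCKS) if b not in removed]

print(kept_counts)    # {3: 11, 4: 11, 6: 11, 7: 11, 9: 9, 10: 11, 11: 10}
print(total_removed)  # 10
print(untouched)      # [0, 1, 2, 5, 8]
```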
Result
backbone params before: 85,641,984 = 85.64 M
backbone params after: 83,675,904 = 83.68 M
saved: 1,966,080 = 1.97 M (2.30 %)
F1 at K=10 structural: 0.9159
F1 at K=10 Stage 2a mask: 0.9159 (byte-identical forward)
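The parameter delta can be reproduced with a short calculation. The sketch below assumes the standard ViT-B width of 768 and a head dimension of 64, and counts only the qkv.weight rows and proj.weight columns of each removed head. Under those assumptions the result matches the reported figure exactly; the source does not say whether any qkv bias entries are also removed.

```python
EMBED_DIM = 768     # assumed ViT-B embedding width
HEAD_DIM = 64       # assumed: 768 / 12 heads
HEADS_REMOVED = 10  # K = 10

# Each removed head drops 3 * HEAD_DIM rows of qkv.weight (Q, K, V)
# and HEAD_DIM columns of proj.weight, each of length EMBED_DIM.
qkv_per_head = 3 * HEAD_DIM * EMBED_DIM
proj_per_head = EMBED_DIM * HEAD_DIM
saved = HEADS_REMOVED * (qkv_per_head + proj_per_head)

before = 85_641_984
after = before - saved
print(saved)                             # 1966080
print(after)                             # 83675904
print(f"{100 * saved / before:.2f} %")   # 2.30 %
```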
Loading
The pruned backbone is not a drop-in replacement for the stock Argus backbone because the attention module shapes differ per-block. Use load_pruned_backbone.py:
    from load_pruned_backbone import load_stage2b_backbone
    backbone = load_stage2b_backbone('pruned_state_dict.safetensors', 'head_config.json')
The loader constructs an Argus ViT-B, walks head_config.json, and replaces each block's attention with a PrunedSelfAttention sized for the kept heads before copying weights.
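To see why a stock backbone cannot load these weights, it can help to compute the tensor shapes each pruned block expects. The sketch below is illustrative only: the JSON key names (num_heads, kept, removed) are a hypothetical schema, the real head_config.json may be organized differently, and the 768 / 64 dimensions are assumed ViT-B values. Shapes follow the (out_features, in_features) convention of a linear layer's weight.

```python
import json

EMBED_DIM, HEAD_DIM = 768, 64  # assumed ViT-B dimensions

# Hypothetical schema for two of the pruned blocks; the actual
# head_config.json may use different keys or nesting.
config_text = """
{
  "9":  {"num_heads": 9,  "kept": [0, 1, 2, 3, 4, 5, 6, 7, 8], "removed": [11, 10, 9]},
  "11": {"num_heads": 10, "kept": [0, 2, 3, 4, 5, 6, 7, 8, 10, 11], "removed": [1, 9]}
}
"""


def expected_shapes(config):
    """Map block index -> expected attention weight shapes after pruning."""
    shapes = {}
    for block, entry in config.items():
        # The head count must agree with the list of kept heads.
        assert entry["num_heads"] == len(entry["kept"])
        inner = entry["num_heads"] * HEAD_DIM
        shapes[int(block)] = {
            "qkv.weight": (3 * inner, EMBED_DIM),
            "proj.weight": (EMBED_DIM, inner),
        }
    return shapes


shapes = expected_shapes(json.loads(config_text))
print(shapes[9])   # {'qkv.weight': (1728, 768), 'proj.weight': (768, 576)}
print(shapes[11])  # {'qkv.weight': (1920, 768), 'proj.weight': (768, 640)}
```

An unpruned block would expect (2304, 768) and (768, 768) for the same two tensors, so copying the pruned state dict into it fails on shape mismatch.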
Files
- stage_2b_structural.py: the conversion script
- pruned_state_dict.safetensors: shrunk backbone weights
- head_config.json: per-block num_heads, kept-head indices, removed-head indices
- load_pruned_backbone.py: loader
- eval.json: F1 parity + param delta
What this buys
- 2.3 % backbone parameter reduction at no F1 cost relative to the Stage 2a mask (0.9159 in both cases); F1 is +0.022 above the Stage 0 baseline.
- Smaller forward pass: pruned blocks do less attention compute.
- Sets up Stage 3 (depth reduction) and Stage 4 (specialist backbone) on a smaller starting model.