
Stage 2b: Structural Head Removal

Unlike Stage 2a, which masks the 10 most prunable attention heads by zeroing their output-projection columns, Stage 2b physically shrinks the attention tensors: the qkv.weight rows and the proj.weight columns belonging to pruned heads are deleted, and each block's num_heads is reduced to match. MLPs, LayerNorms, and LayerScales are unchanged.
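As a rough illustration of which rows and columns get deleted, the sketch below computes the surviving indices for one block. It is not taken from stage_2b_structural.py. It assumes a fused qkv weight laid out as [Q; K; V] with heads contiguous inside each third, 12 heads of dimension 64, and a proj weight whose columns follow the same head order. The function name kept_indices is hypothetical.

```python
def kept_indices(num_heads, head_dim, removed):
    """Return (kept_heads, qkv_rows, proj_cols) for one attention block.

    Assumed layout (not confirmed by the source):
      qkv.weight  has shape (3 * num_heads * head_dim, embed_dim), stacked [Q; K; V]
      proj.weight has shape (embed_dim, num_heads * head_dim)
    """
    removed = set(removed)
    kept_heads = [h for h in range(num_heads) if h not in removed]
    # Columns of proj.weight that belong to kept heads.
    proj_cols = [h * head_dim + i for h in kept_heads for i in range(head_dim)]
    # The same per-head offsets, repeated inside the Q, K, and V thirds of qkv.weight.
    full = num_heads * head_dim
    qkv_rows = [part * full + c for part in range(3) for c in proj_cols]
    return kept_heads, qkv_rows, proj_cols


# Block 9 from the pruning plan: heads 11, 10, and 9 are removed.
kept_heads, qkv_rows, proj_cols = kept_indices(12, 64, [11, 10, 9])
print(len(kept_heads), len(qkv_rows), len(proj_cols))  # -> 9 1728 576
```

With real tensors, these index lists would be used to select rows of qkv.weight and columns of proj.weight. If the actual model orders the fused projection differently, the row computation has to change accordingly.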

Per-block pruning plan

Block    Heads removed      Heads kept
 3       [5]                 11
 4       [8]                 11
 6       [9]                 11
 7       [11]                11
 9       [11, 10, 9]         9
 10      [4]                 11
 11      [1, 9]              10

Other blocks (0, 1, 2, 5, 8) retain all 12 heads.
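The plan can be checked mechanically. The short sketch below re-enters the table as a Python dict (the variable names are illustrative, not from the repository) and confirms that the kept-head counts follow from the removed lists and that the total matches the 10 heads masked in Stage 2a.

```python
NUM_BLOCKS = 12
NUM_HEADS = 12

# Block index -> removed head indices, copied from the table above.
plan = {
    3: [5],
    4: [8],
    6: [9],
    7: [11],
    9: [11, 10, 9],
    10: [4],
    11: [1, 9],
}

# Heads kept per block; blocks missing from the plan keep all 12.
kept = {b: NUM_HEADS - len(plan.get(b, [])) for b in range(NUM_BLOCKS)}
untouched = [b for b in range(NUM_BLOCKS) if b not in plan]
total_removed = sum(len(heads) for heads in plan.values())

print(total_removed, kept[9], kept[11])  # -> 10 9 10
print(untouched)                         # -> [0, 1, 2, 5, 8]
```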

Result

backbone params before:  85,641,984  = 85.64 M
backbone params after:   83,675,904  = 83.68 M
saved:                    1,966,080  =  1.97 M (2.30 %)
F1 at K=10 structural:    0.9159
F1 at K=10 Stage 2a mask: 0.9159    (byte-identical forward)
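The saved-parameter figure can be reproduced by hand. The sketch below assumes the usual ViT-B dimensions (embedding width 768, head dimension 64), which the text above does not state, and counts only the qkv and proj weight entries. Under those assumptions the weight-only count matches the reported saving exactly, which suggests bias terms are either absent or not counted.

```python
EMBED_DIM = 768     # assumption: standard ViT-B width
HEAD_DIM = 64       # assumption: 768 / 12 heads
HEADS_REMOVED = 10  # total across all blocks in the pruning plan

# Each removed head deletes 3 * HEAD_DIM rows of qkv.weight (one slice each for Q, K, V)
# and HEAD_DIM columns of proj.weight.
qkv_per_head = 3 * HEAD_DIM * EMBED_DIM
proj_per_head = EMBED_DIM * HEAD_DIM
saved = HEADS_REMOVED * (qkv_per_head + proj_per_head)

before = 85_641_984
after = before - saved
pct = 100 * saved / before

print(saved)          # -> 1966080
print(after)          # -> 83675904
print(f"{pct:.2f}")   # -> 2.30
```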

Loading

The pruned backbone is not a drop-in replacement for the stock Argus backbone, because the attention module shapes now differ from block to block. Load it with load_pruned_backbone.py:

from load_pruned_backbone import load_stage2b_backbone
backbone = load_stage2b_backbone('pruned_state_dict.safetensors', 'head_config.json')

The loader constructs an Argus ViT-B, walks head_config.json, and replaces each block's attention with a PrunedSelfAttention sized for the kept heads before copying weights.
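To make "sized for the kept heads" concrete, the sketch below walks a config and derives the attention weight shapes each pruned block would need. It does not reproduce load_pruned_backbone.py. The JSON schema shown (keys blocks, block, num_heads, kept_heads, removed_heads) is a guess at what head_config.json might look like, and the dimensions 768 and 64 are assumed ViT-B values.

```python
import json

# Hypothetical stand-in for head_config.json; the real file's schema may differ.
config_text = """
{"blocks": [
  {"block": 9,  "num_heads": 9,
   "kept_heads": [0, 1, 2, 3, 4, 5, 6, 7, 8], "removed_heads": [9, 10, 11]},
  {"block": 11, "num_heads": 10,
   "kept_heads": [0, 2, 3, 4, 5, 6, 7, 8, 10, 11], "removed_heads": [1, 9]}
]}
"""


def attention_shapes(config, embed_dim=768, head_dim=64):
    """Map block index -> weight shapes for an attention module with fewer heads."""
    shapes = {}
    for entry in config["blocks"]:
        if entry["num_heads"] != len(entry["kept_heads"]):
            raise ValueError(f"block {entry['block']}: num_heads does not match kept_heads")
        inner = entry["num_heads"] * head_dim  # width of the concatenated kept heads
        shapes[entry["block"]] = {
            "qkv.weight": (3 * inner, embed_dim),
            "proj.weight": (embed_dim, inner),
        }
    return shapes


shapes = attention_shapes(json.loads(config_text))
print(shapes[9])   # -> {'qkv.weight': (1728, 768), 'proj.weight': (768, 576)}
print(shapes[11])  # -> {'qkv.weight': (1920, 768), 'proj.weight': (768, 640)}
```

A loader along these lines would build the replacement attention module with these shapes first, so that the shrunk tensors in pruned_state_dict.safetensors copy in without a size mismatch.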

Files

  • stage_2b_structural.py — the conversion script
  • pruned_state_dict.safetensors — shrunk backbone weights
  • head_config.json — per-block num_heads, kept-head indices, removed-head indices
  • load_pruned_backbone.py — loader
  • eval.json — F1 parity + param delta

What this buys

  • 2.3 % backbone param reduction for free (no F1 cost; +0.022 F1 gain over Stage 0 baseline).
  • Smaller forward pass: in a pruned block, the qkv projection, the attention maps, and the output projection all shrink in proportion to the kept-head count.
  • Sets up Stage 3 (depth reduction) and Stage 4 (specialist backbone) on a smaller starting model.