# Stage 2b: Structural Head Removal

Unlike Stage 2a, which masks the 10 most prunable attention heads by zeroing their output-projection columns, Stage 2b physically shrinks the attention tensors: the `qkv.weight` rows belonging to pruned heads are deleted, the matching `proj.weight` columns are deleted, and each block's `num_heads` is reduced. MLPs, LayerNorms, and LayerScales are unchanged.
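A minimal pure-Python sketch of the index bookkeeping follows. It assumes a timm-style fused attention, where `qkv.weight` has shape `(3 * dim, dim)` with rows stacked as `[q heads; k heads; v heads]` and `proj.weight` has its input columns grouped by head. The helper names are hypothetical and not taken from `stage_2b_structural.py`. The toy parity check at the end shows why deletion can reproduce the Stage 2a mask: a zeroed projection column contributes only ±0.0 to each output, and adding ±0.0 leaves a float accumulator unchanged.

```python
import random

# Hypothetical helpers, not taken from stage_2b_structural.py.
# Assumes a timm-style fused qkv: qkv.weight has shape (3 * dim, dim) with rows
# stacked as [q heads; k heads; v heads], and proj.weight has shape
# (dim, num_heads * head_dim) with input columns grouped by head.

def kept_heads(num_heads, removed):
    removed_set = set(removed)
    return [h for h in range(num_heads) if h not in removed_set]

def qkv_rows_to_keep(num_heads, head_dim, removed):
    """Row indices of qkv.weight that survive pruning, in q, k, v order."""
    inner = num_heads * head_dim
    rows = []
    for part in range(3):  # 0 = q, 1 = k, 2 = v
        for h in kept_heads(num_heads, removed):
            start = part * inner + h * head_dim
            rows.extend(range(start, start + head_dim))
    return rows

def proj_cols_to_keep(num_heads, head_dim, removed):
    """Column indices of proj.weight that survive pruning."""
    cols = []
    for h in kept_heads(num_heads, removed):
        cols.extend(range(h * head_dim, (h + 1) * head_dim))
    return cols

def matvec(weight, x):
    """Plain left-to-right dot products, one per output row."""
    out = []
    for row in weight:
        acc = 0.0
        for w, xi in zip(row, x):
            acc += w * xi
        out.append(acc)
    return out

# Toy parity check: zeroing a head's proj columns (Stage 2a) gives the same
# output as deleting those columns and that head's activations (Stage 2b).
rng = random.Random(0)
NUM_HEADS, HEAD_DIM, OUT_DIM, REMOVED = 4, 2, 3, [1]
inner = NUM_HEADS * HEAD_DIM
W = [[rng.uniform(-1, 1) for _ in range(inner)] for _ in range(OUT_DIM)]
x = [rng.uniform(-1, 1) for _ in range(inner)]  # concatenated per-head outputs

keep = proj_cols_to_keep(NUM_HEADS, HEAD_DIM, REMOVED)
keep_set = set(keep)
W_masked = [[w if j in keep_set else 0.0 for j, w in enumerate(row)] for row in W]
masked_out = matvec(W_masked, x)
pruned_out = matvec([[row[j] for j in keep] for row in W], [x[j] for j in keep])
assert masked_out == pruned_out
```

In a PyTorch conversion, the same index lists could slice the real tensors as `qkv.weight[rows]` and `proj.weight[:, cols]`.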

## Per-block pruning plan

```
Block    Heads removed      Heads kept
 3       [5]                 11
 4       [8]                 11
 6       [9]                 11
 7       [11]                11
 9       [11, 10, 9]         9
 10      [4]                 11
 11      [1, 9]              10
```

Other blocks (0, 1, 2, 5, 8) retain all 12 heads.
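A quick consistency check of the plan, transcribing the table into a dict. It confirms the "Heads kept" column, that exactly 10 heads are removed in total (the same count Stage 2a masks), and that the untouched blocks are 0, 1, 2, 5, and 8.

```python
# Removal plan transcribed from the table above: block index -> removed heads.
BASE_HEADS = 12   # ViT-B heads per block
NUM_BLOCKS = 12   # ViT-B depth
REMOVED = {3: [5], 4: [8], 6: [9], 7: [11], 9: [11, 10, 9], 10: [4], 11: [1, 9]}

kept_per_block = [BASE_HEADS - len(REMOVED.get(b, [])) for b in range(NUM_BLOCKS)]
total_removed = sum(len(heads) for heads in REMOVED.values())
untouched = [b for b in range(NUM_BLOCKS) if b not in REMOVED]

print(kept_per_block)  # [12, 12, 12, 11, 11, 12, 11, 11, 12, 9, 11, 10]
print(total_removed)   # 10
print(untouched)       # [0, 1, 2, 5, 8]
```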

## Result

```
backbone params before:  85,641,984  = 85.64 M
backbone params after:   83,675,904  = 83.68 M
saved:                    1,966,080  =  1.97 M (2.30 %)
F1 at K=10 structural:    0.9159
F1 at K=10 Stage 2a mask: 0.9159    (byte-identical forward)
```
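The parameter delta can be reproduced by hand. Each removed ViT-B head (width 768, `head_dim` 64) takes away 3 × 64 rows of `qkv.weight` and 64 columns of `proj.weight`, and the `proj` bias is unaffected because the output width stays 768. The reported 1,966,080 appears consistent with a bias-free `qkv` projection; that is an inference from the numbers, not something the README states. A per-row `qkv` bias would add 192 parameters per head (1,968,000 in total).

```python
# Sanity check of the Result numbers under stated assumptions:
# ViT-B width 768, head_dim 64, and no qkv bias (inferred, not stated).
DIM = 768
HEAD_DIM = 64
HEADS_REMOVED = 10
PARAMS_BEFORE = 85_641_984

qkv_rows_per_head = 3 * HEAD_DIM                      # q, k and v rows for one head
per_head = qkv_rows_per_head * DIM + DIM * HEAD_DIM   # qkv rows + proj columns
saved = per_head * HEADS_REMOVED
params_after = PARAMS_BEFORE - saved
pct_saved = 100 * saved / PARAMS_BEFORE

print(f"{per_head=:,} {saved=:,} {params_after=:,} {pct_saved=:.2f}%")
# per_head=196,608 saved=1,966,080 params_after=83,675,904 pct_saved=2.30%
```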

## Loading

The pruned backbone is *not* a drop-in replacement for the stock Argus backbone, because its attention module shapes differ from block to block. Use `load_pruned_backbone.py`:

```python
from load_pruned_backbone import load_stage2b_backbone
backbone = load_stage2b_backbone('pruned_state_dict.safetensors', 'head_config.json')
```

The loader constructs an Argus ViT-B, walks `head_config.json`, and replaces each block's attention with a `PrunedSelfAttention` sized for the kept heads before copying weights.
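One detail a `PrunedSelfAttention`-style module has to get right, assuming it follows the usual ViT attention design, is that `head_dim` must stay 64 rather than being recomputed as `dim // num_heads` (768 // 11 would give 69). The sketch below is hypothetical shape bookkeeping, not the actual class in `load_pruned_backbone.py`. It shows how the `qkv` and `proj` weight shapes shrink with the kept-head count while the residual width stays 768.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class PrunedAttentionShapes:
    """Hypothetical shape bookkeeping for a PrunedSelfAttention-style module.

    head_dim is kept at the original 64 instead of being recomputed as
    dim // num_heads, so the inner (concatenated heads) width shrinks
    while the residual width `dim` stays 768.
    """
    dim: int = 768
    num_heads: int = 12
    head_dim: int = 64

    @property
    def inner_dim(self) -> int:
        return self.num_heads * self.head_dim

    @property
    def qkv_weight_shape(self) -> tuple:
        # nn.Linear stores weights as (out_features, in_features)
        return (3 * self.inner_dim, self.dim)

    @property
    def proj_weight_shape(self) -> tuple:
        return (self.dim, self.inner_dim)

    @property
    def scale(self) -> float:
        # attention logits are scaled by 1 / sqrt(head_dim), unchanged by pruning
        return 1 / math.sqrt(self.head_dim)

for n in (12, 11, 10, 9):
    s = PrunedAttentionShapes(num_heads=n)
    print(n, s.qkv_weight_shape, s.proj_weight_shape)
# 12 (2304, 768) (768, 768)
# 11 (2112, 768) (768, 704)
# 10 (1920, 768) (768, 640)
# 9 (1728, 768) (768, 576)
```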

## Files

- `stage_2b_structural.py` — the conversion script
- `pruned_state_dict.safetensors` — shrunk backbone weights
- `head_config.json` — per-block `num_heads`, kept-head indices, removed-head indices
- `load_pruned_backbone.py` — loader
- `eval.json` — F1 parity + param delta
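The README says only that `head_config.json` holds per-block `num_heads`, kept-head indices, and removed-head indices. The sketch below shows one possible layout built from the pruning plan. The key names (`num_heads`, `kept_heads`, `removed_heads`) and the string block keys are hypothetical, so check the real file before relying on them.

```python
import json

# Hypothetical head_config.json layout: key names and string block keys
# are assumptions; only the kinds of fields come from the README.
BASE_HEADS, NUM_BLOCKS = 12, 12
REMOVED = {3: [5], 4: [8], 6: [9], 7: [11], 9: [11, 10, 9], 10: [4], 11: [1, 9]}

def build_head_config(removed_by_block, num_blocks=NUM_BLOCKS, base_heads=BASE_HEADS):
    config = {}
    for block in range(num_blocks):
        removed = sorted(removed_by_block.get(block, []))
        kept = [h for h in range(base_heads) if h not in removed]
        config[str(block)] = {
            "num_heads": len(kept),
            "kept_heads": kept,
            "removed_heads": removed,
        }
    return config

text = json.dumps(build_head_config(REMOVED), indent=2)
config = json.loads(text)  # JSON object keys come back as strings
print(config["9"])
# {'num_heads': 9, 'kept_heads': [0, 1, 2, 3, 4, 5, 6, 7, 8], 'removed_heads': [9, 10, 11]}
```

With this layout, a loader would iterate `config.items()` and size each block's attention from `num_heads`, using `kept_heads` to pick which original head slices to copy.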

## What this buys

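- 2.3 % fewer backbone parameters at no F1 cost: F1 matches the Stage 2a mask (0.9159) and remains +0.022 above the Stage 0 baseline.
- A cheaper forward pass: pruned blocks compute Q/K/V projections, attention maps, and the output projection for fewer heads.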
- Sets up Stage 3 (depth reduction) and Stage 4 (specialist backbone) on a smaller starting model.
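A rough estimate of the attention-compute saving, counting only the four attention matmuls and ignoring softmax, MLP, and LayerNorm. The token count of 197 assumes a 224×224 input with 16×16 patches plus a CLS token. Every term is linear in the number of kept heads, so the saving is 10/144 ≈ 6.9 % of attention-matmul MACs regardless of token count. Because the MLPs are untouched, the whole-backbone saving is smaller than this.

```python
from fractions import Fraction

DIM = 768
HEAD_DIM = 64
NUM_TOKENS = 197  # assumption: 224x224 input, 16x16 patches, plus one CLS token
HEADS_BEFORE = [12] * 12
HEADS_AFTER = [12, 12, 12, 11, 11, 12, 11, 11, 12, 9, 11, 10]  # from the plan table

def attention_macs(num_heads, dim=DIM, head_dim=HEAD_DIM, n=NUM_TOKENS):
    """Multiply-accumulates of one attention block's matmuls (softmax ignored)."""
    inner = num_heads * head_dim
    qkv = n * dim * 3 * inner    # input -> q, k, v
    scores = n * n * inner       # q @ k^T, summed over all heads
    mix = n * n * inner          # softmax(scores) @ v
    proj = n * inner * dim       # concatenated heads -> output projection
    return qkv + scores + mix + proj

before = sum(attention_macs(h) for h in HEADS_BEFORE)
after = sum(attention_macs(h) for h in HEADS_AFTER)
saved_fraction = Fraction(before - after, before)
print(f"attention MACs saved: {float(saved_fraction):.2%}")  # attention MACs saved: 6.94%
```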