# Dissecting BERT Layers: FFN Dual Role, Separability-Guided Layer Skip, and Interpretable Classification

Artifacts for the paper by Yeonseong Cynn (River Lab, May 2026).

## Summary

Layer-level analysis framework for BERT across five GLUE tasks (SST-2, CoLA, MRPC, QNLI, RTE).

Key findings:
- **Separability-guided layer skip**: identifies removable layers via separability delta analysis, validated by actual BERT forward-pass experiments
- **FFN dual role**: the FFN is 92% structural (norm normalization) and only 8% classification-relevant, which explains why removing the FFN collapses the model even though individual layers appear "harmful"
- **Error analysis**: 60-93% of misclassifications are high-confidence errors (margin > 0.3), indicating BERT's CLS representation is the bottleneck
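
To make the idea of a separability delta concrete, here is a minimal sketch on synthetic data. It assumes a Fisher-style ratio (between-class distance over within-class spread) as the separability score and treats a non-positive delta as a skip signal; both choices are illustrative assumptions, not necessarily the metric or rule used in the paper, and the function names are hypothetical.

```python
import numpy as np

def fisher_separability(features, labels):
    """Two-class score: squared distance between class means over summed variances.

    Assumption: a stand-in metric, chosen only for illustration.
    """
    a = features[labels == 0]
    b = features[labels == 1]
    between = np.sum((a.mean(axis=0) - b.mean(axis=0)) ** 2)
    within = a.var(axis=0).sum() + b.var(axis=0).sum()
    return between / (within + 1e-12)

def separability_deltas(per_layer_features, labels):
    """Return per-layer scores and the change from each layer to the next."""
    scores = np.array([fisher_separability(f, labels) for f in per_layer_features])
    return scores, np.diff(scores)

# Synthetic stand-in for per-layer CLS features: same noise, varying class gap.
rng = np.random.default_rng(0)
labels = np.array([0, 1] * 50)
noise = rng.normal(size=(100, 8))
gaps = [0.5, 2.0, 2.0, 1.0]  # hypothetical class-mean gap after each layer
layers = [noise + gap * labels[:, None] for gap in gaps]

scores, deltas = separability_deltas(layers, labels)

# Assumed rule: a layer that does not increase separability is a skip candidate.
skip_candidates = [i + 1 for i, d in enumerate(deltas) if d <= 0]
print("scores:", np.round(scores, 3))
print("skip candidates (layer index):", skip_candidates)
```

In practice the per-layer features would come from the model's hidden states rather than synthetic noise, and any candidate still has to be confirmed by an actual forward pass with the layer removed.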

## Files

### Weights
- `bert_sst2_prune_masks.npz` — Per-layer FFN neuron pruning masks (0/1) for BERT SST-2. Keys: `mask_L0` through `mask_L11`, each shape `(3072,)`.

### Results (JSON)
- `results/{task}_layer_analysis.json` — Layer separability metrics, delta changes, and FFN structural/classification ratio per task
- `results/{task}_skip_results.json` — Single and multi-layer skip experiment results per task

Tasks: `sst2`, `cola`, `mrpc`, `qnli`, `rte`
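
The result files can be read with the standard library. The sketch below only builds the documented paths and prints the top-level keys of whatever it finds; the JSON schema is not documented here, so no field names are assumed, and `result_paths` and `load_results` are hypothetical helper names.

```python
import json
from pathlib import Path

TASKS = ["sst2", "cola", "mrpc", "qnli", "rte"]

def result_paths(task, root="results"):
    """Build the two documented result paths for one task."""
    root = Path(root)
    return {
        "layer_analysis": root / f"{task}_layer_analysis.json",
        "skip_results": root / f"{task}_skip_results.json",
    }

def load_results(task, root="results"):
    """Load whichever result files exist for a task; missing files are skipped."""
    loaded = {}
    for name, path in result_paths(task, root).items():
        if path.exists():
            with path.open() as fh:
                loaded[name] = json.load(fh)
    return loaded

for task in TASKS:
    for name, payload in load_results(task).items():
        # Only inspect the top level, since the schema is not assumed.
        summary = sorted(payload) if isinstance(payload, dict) else type(payload).__name__
        print(task, name, summary)
```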

### Figures
- `figures/fig1_separability.png` — Layer separability curves across 5 GLUE tasks
- `figures/fig2_ffn_ratio.png` — FFN structural/classification ratio heatmap (log scale)
- `figures/fig3_errors.png` — Error direction and confidence analysis
- `figures/fig4_skip_prediction.png` — Separability prediction vs actual skip accuracy

## Usage

### Loading pruning masks
```python
import numpy as np

masks = np.load("bert_sst2_prune_masks.npz")
for layer in range(12):
    mask = masks[f"mask_L{layer}"]  # (3072,) binary mask; keys are 0-indexed
    kept = int(mask.sum())
    # Printed labels are 1-indexed (L1..L12); the keys run mask_L0..mask_L11.
    print(f"L{layer+1}: {kept}/3072 neurons kept ({kept/3072*100:.0f}%)")
```

### Applying masks to BERT
```python
import numpy as np
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
masks = np.load("bert_sst2_prune_masks.npz")

# Zero the weight rows and biases of pruned neurons in each FFN intermediate layer.
# A pruned neuron's pre-activation becomes 0, and GELU(0) = 0, so it stops contributing.
with torch.no_grad():
    for layer_idx in range(12):
        mask = torch.tensor(masks[f"mask_L{layer_idx}"], dtype=torch.float32)
        ffn = model.bert.encoder.layer[layer_idx].intermediate.dense
        ffn.weight.mul_(mask.unsqueeze(1))  # weight is (3072, 768): one row per neuron
        ffn.bias.mul_(mask)
```
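
Masking the intermediate weights and biases is enough to silence a neuron because BERT's intermediate activation is GELU and GELU(0) = 0. The following toy check uses NumPy only, with small made-up dimensions and a made-up mask (BERT-base uses 768 and 3072), to show that masked neurons output exactly zero while kept neurons are unchanged.

```python
import math
import numpy as np

def gelu(x):
    """Exact (erf-based) GELU, applied elementwise."""
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

rng = np.random.default_rng(0)
hidden, inter = 6, 8  # toy sizes, for illustration only
weight = rng.normal(size=(inter, hidden))  # (out, in) layout, as in torch.nn.Linear.weight
bias = rng.normal(size=inter)
mask = np.array([1, 0, 1, 1, 0, 1, 0, 1], dtype=np.float64)  # hypothetical mask

# Same operation as in the BERT snippet: zero the rows and biases of pruned neurons.
pruned_weight = weight * mask[:, None]
pruned_bias = bias * mask

x = rng.normal(size=(4, hidden))
dense = gelu(x @ weight.T + bias)
activations = gelu(x @ pruned_weight.T + pruned_bias)

print("pruned columns all zero:", bool(np.all(activations[:, mask == 0] == 0)))
print("kept columns unchanged:", bool(np.allclose(activations[:, mask == 1], dense[:, mask == 1])))
```

Because the pruned activations are exactly zero, the matching columns of the FFN output projection never receive input, so they can be left in place or physically removed to shrink the layer.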

## Base Model

SST-2 experiments use [textattack/bert-base-uncased-SST-2](https://huggingface.co/textattack/bert-base-uncased-SST-2); the other tasks use the corresponding task-specific fine-tuned models from [textattack](https://huggingface.co/textattack).

## License

MIT