# Dissecting BERT Layers: FFN Dual Role, Separability-Guided Layer Skip, and Interpretable Classification

Artifacts for the paper by Yeonseong Cynn (River Lab, May 2026).

## Summary

A layer-level analysis framework for BERT, applied to five GLUE tasks (SST-2, CoLA, MRPC, QNLI, RTE).

Key findings:
- **Separability-guided layer skip**: identifies removable layers from per-layer separability deltas, validated with actual BERT forward-pass experiments
- **FFN dual role**: 92% structural (norm normalization) vs. 8% classification-relevant, which explains why removing FFNs collapses models while individual layers appear "harmful"
- **Error analysis**: 60-93% of misclassifications are high-confidence errors (margin > 0.3), indicating that BERT's CLS representation is the bottleneck
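
The error-analysis figure depends on how "margin" is defined, which this README does not spell out. As a minimal sketch, assuming margin means the gap between the two largest softmax probabilities, a misclassified example could be flagged as high-confidence as follows. The function names and the example logits are hypothetical and are not taken from the paper's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def confidence_margin(logits):
    """Gap between the top two class probabilities (assumed definition of margin)."""
    probs = sorted(softmax(logits), reverse=True)
    return probs[0] - probs[1]


def is_high_confidence_error(logits, gold_label, threshold=0.3):
    """True if the prediction is wrong AND the margin exceeds the threshold."""
    predicted = max(range(len(logits)), key=lambda i: logits[i])
    return predicted != gold_label and confidence_margin(logits) > threshold


if __name__ == "__main__":
    # Hypothetical binary-classification logits where the gold label is class 1.
    print(is_high_confidence_error([2.0, 0.0], gold_label=1))  # confident and wrong
    print(is_high_confidence_error([0.1, 0.0], gold_label=1))  # wrong but uncertain
```

Under this assumed definition, a high share of such errors would mean the classifier is confidently wrong rather than hesitating near the decision boundary, which is the reading the bullet above gives.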

## Files

### Weights
- `bert_sst2_prune_masks.npz`: Per-layer FFN neuron pruning masks (0/1) for BERT SST-2. Keys: `mask_L0` through `mask_L11`, each of shape `(3072,)`.

### Results (JSON)
- `results/{task}_layer_analysis.json`: Layer separability metrics, delta changes, and FFN structural/classification ratio per task
- `results/{task}_skip_results.json`: Single- and multi-layer skip experiment results per task

Tasks: `sst2`, `cola`, `mrpc`, `qnli`, `rte`
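
The file-name pattern above can be expanded programmatically. The following is a minimal sketch that loads whichever result files are present and prints only their top-level keys, since the JSON schema is not documented here. The helper name `load_results` is hypothetical.

```python
import json
from pathlib import Path

TASKS = ["sst2", "cola", "mrpc", "qnli", "rte"]
KINDS = ("layer_analysis", "skip_results")


def load_results(results_dir="results", tasks=TASKS):
    """Return {task: {kind: parsed_json}} for every result file that exists."""
    root = Path(results_dir)
    loaded = {}
    for task in tasks:
        entry = {}
        for kind in KINDS:
            path = root / f"{task}_{kind}.json"
            if path.is_file():
                with path.open(encoding="utf-8") as fh:
                    entry[kind] = json.load(fh)
        if entry:
            loaded[task] = entry
    return loaded


if __name__ == "__main__":
    for task, entry in load_results().items():
        for kind, data in entry.items():
            # Schema is not assumed: only the top level is inspected.
            summary = sorted(data) if isinstance(data, dict) else type(data).__name__
            print(f"{task} / {kind}: {summary}")
```

Missing files are skipped silently, so the same call works on a partial download.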

### Figures
- `figures/fig1_separability.png`: Layer separability curves across 5 GLUE tasks
- `figures/fig2_ffn_ratio.png`: FFN structural/classification ratio heatmap (log scale)
- `figures/fig3_errors.png`: Error direction and confidence analysis
- `figures/fig4_skip_prediction.png`: Separability prediction vs. actual skip accuracy

## Usage

### Loading pruning masks
```python
import numpy as np

masks = np.load("bert_sst2_prune_masks.npz")
for layer in range(12):
    mask = masks[f"mask_L{layer}"]  # (3072,) binary mask
    kept = int(mask.sum())
    # Keys are 0-indexed (mask_L0..mask_L11); layers are printed 1-indexed.
    print(f"L{layer + 1}: {kept}/3072 neurons kept ({kept / 3072 * 100:.0f}%)")
```

### Applying masks to BERT
```python
import numpy as np
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
masks = np.load("bert_sst2_prune_masks.npz")

# Zero out pruned neurons in the FFN intermediate layer
with torch.no_grad():
    for layer_idx in range(12):
        mask = torch.tensor(masks[f"mask_L{layer_idx}"], dtype=torch.float32)
        ffn = model.bert.encoder.layer[layer_idx].intermediate.dense
        ffn.weight.mul_(mask.unsqueeze(1))  # weight is (3072, 768): zeroes whole rows
        ffn.bias.mul_(mask)
```
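
Zeroing a row of the intermediate weight together with its bias entry silences that neuron because its pre-activation becomes exactly 0 and GELU(0) = 0, so nothing reaches the output projection. The toy NumPy check below illustrates this under stated assumptions: small hypothetical dimensions (8 hidden, 16 intermediate) instead of BERT's 768 and 3072, a random mask, and the tanh approximation of GELU rather than the exact form.

```python
import numpy as np


def gelu(x):
    """tanh approximation of GELU; gelu(0) is exactly 0."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def intermediate_activations(x, weight, bias, mask=None):
    """FFN intermediate output; weight uses the (out, in) layout of torch.nn.Linear."""
    if mask is not None:
        weight = weight * mask[:, None]  # zero whole rows, as in the snippet above
        bias = bias * mask
    return gelu(x @ weight.T + bias)


rng = np.random.default_rng(0)
hidden, inter = 8, 16                    # toy sizes, not BERT's
weight = rng.normal(size=(inter, hidden))
bias = rng.normal(size=inter)
mask = (rng.random(inter) > 0.5).astype(np.float64)
mask[0], mask[1] = 0.0, 1.0              # guarantee one pruned and one kept neuron

x = rng.normal(size=(4, hidden))
full_acts = intermediate_activations(x, weight, bias)
masked_acts = intermediate_activations(x, weight, bias, mask)

print("pruned neurons all zero:", bool(np.all(masked_acts[:, mask == 0] == 0.0)))
print("kept neurons unchanged:", bool(np.allclose(masked_acts[:, mask == 1], full_acts[:, mask == 1])))
```

The same argument is why the bias must be masked along with the weight: a nonzero bias alone would leave the neuron emitting a constant activation.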

## Base Model

The SST-2 experiments use [textattack/bert-base-uncased-SST-2](https://huggingface.co/textattack/bert-base-uncased-SST-2); the other tasks use the corresponding task-specific fine-tuned models from [textattack](https://huggingface.co/textattack).

## License

MIT