
Dissecting BERT Layers: FFN Dual Role, Separability-Guided Layer Skip, and Interpretable Classification

Artifacts for the paper by Yeonseong Cynn (River Lab, May 2026).

Summary

Layer-level analysis framework for BERT across five GLUE tasks (SST-2, CoLA, MRPC, QNLI, RTE).

Key findings:

  • Separability-guided layer skip: identifies removable layers via per-layer separability deltas, validated with actual BERT forward-pass skip experiments
  • FFN dual role: roughly 92% of the FFN's effect is structural (hidden-state norm normalization) vs. 8% classification-relevant, explaining why removing the FFN entirely collapses the model even though individual layers appear "harmful"
  • Error analysis: 60-93% of misclassifications are high-confidence errors (margin > 0.3), indicating BERT's [CLS] representation is the bottleneck
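The layer-skip mechanic can be sketched with a numpy toy (an illustrative stand-in, not the paper's code): treat each "layer" as a linear map and drop selected indices from the forward pass. On the real model, the analogous operation would replace `model.bert.encoder.layer[i]` with a pass-through module.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy 4-layer "encoder": each layer is a linear map, standing in for a BERT layer.
layers = [rng.standard_normal((8, 8)) for _ in range(4)]

def encode(x, skip=()):
    """Run x through all layers, passing through (skipping) any index in `skip`."""
    for i, W in enumerate(layers):
        if i in skip:
            continue  # skipped layer: hidden states pass through unchanged
        x = x @ W
    return x

x = rng.standard_normal((2, 8))
full = encode(x)
skipped = encode(x, skip={2})  # remove layer 2 from the forward pass
```

Comparing `full` and `skipped` against the predicted separability delta for layer 2 is the shape of the validation experiment described above.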

Files

Weights

  • bert_sst2_prune_masks.npz β€” Per-layer FFN neuron pruning masks (0/1) for BERT SST-2. Keys: mask_L0 through mask_L11, each shape (3072,).

Results (JSON)

  • results/{task}_layer_analysis.json β€” Layer separability metrics, delta changes, and FFN structural/classification ratio per task
  • results/{task}_skip_results.json β€” Single and multi-layer skip experiment results per task

Tasks: sst2, cola, mrpc, qnli, rte
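A minimal way to inspect these result files with the standard library (the exact keys inside each JSON are task-dependent, so this just prints the top-level structure):

```python
import json
from pathlib import Path

def load_results(path):
    """Load one of the results JSON files and return the parsed object."""
    with open(path) as f:
        return json.load(f)

for task in ("sst2", "cola", "mrpc", "qnli", "rte"):
    path = Path("results") / f"{task}_layer_analysis.json"
    if path.exists():  # files ship with the repo; guard against partial checkouts
        data = load_results(path)
        print(task, sorted(data))  # top-level keys
```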

Figures

  • figures/fig1_separability.png β€” Layer separability curves across 5 GLUE tasks
  • figures/fig2_ffn_ratio.png β€” FFN structural/classification ratio heatmap (log scale)
  • figures/fig3_errors.png β€” Error direction and confidence analysis
  • figures/fig4_skip_prediction.png β€” Separability prediction vs actual skip accuracy

Usage

Loading pruning masks

import numpy as np

masks = np.load("bert_sst2_prune_masks.npz")
for layer in range(12):
    mask = masks[f"mask_L{layer}"]  # (3072,) binary mask
    kept = mask.sum()
    print(f"L{layer}: {int(kept)}/3072 neurons kept ({kept/3072*100:.0f}%)")

Applying masks to BERT

import numpy as np
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
masks = np.load("bert_sst2_prune_masks.npz")

# Zero out pruned neurons in FFN intermediate layer
for layer_idx in range(12):
    mask = torch.tensor(masks[f"mask_L{layer_idx}"], dtype=torch.float32)
    ffn = model.bert.encoder.layer[layer_idx].intermediate.dense
    ffn.weight.data *= mask.unsqueeze(1)
    ffn.bias.data *= mask
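As a sanity check on the masking loop above (a generic torch sketch, not repo code): zeroing a row of a Linear layer's weight together with the matching bias entry forces that output unit to 0 for any input, and GELU(0) = 0, so masked FFN neurons contribute nothing downstream.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dense = nn.Linear(768, 3072)   # same shape as BERT's FFN intermediate projection
mask = torch.ones(3072)
mask[:100] = 0.0               # "prune" the first 100 neurons

with torch.no_grad():
    dense.weight *= mask.unsqueeze(1)  # zero the masked rows
    dense.bias *= mask                 # zero the matching bias entries

x = torch.randn(4, 768)
out = torch.nn.functional.gelu(dense(x))
print(out[:, :100].abs().max().item())  # exactly 0 for masked neurons
```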

Base Model

All experiments use textattack/bert-base-uncased-SST-2 and the corresponding task-specific fine-tuned models from textattack.

License

MIT
