# 🧠 Reproducing “Super Weights” in Large Language Models
**Paper:** *The Super Weight in Large Language Models*
**Authors:** Mengxia Yu, De Wang, Qi Shan, Colorado J. Reed, Alvin Wan
**Affiliation:** Apple & University of Notre Dame
**arXiv:** [2411.07191v2 (July 2025)](https://arxiv.org/abs/2411.07191)
---
## 🧩 1. Background
Large Language Models (LLMs) often exhibit *outlier weights and activations* — values with extremely large magnitudes that strongly influence model quality.
This paper identifies a **single scalar parameter**, termed a **Super Weight (SW)**, whose removal alone can **destroy a model’s ability to generate text**.
### Key findings
- Pruning **one scalar** in Llama-7B drops zero-shot accuracy to random-guessing levels.
- The same weight induces a **Super Activation (SA)** — a huge activation spike that persists across layers.
- Both SW and SA can be found **data-free**, with a single forward pass.
- Preserving them dramatically improves **quantization quality**.
---
## 🧠 2. Conceptual Overview
| Term | Description |
|------|--------------|
| **Super Weight (SW)** | A single extremely important weight (scalar) in `mlp.down_proj` of an early transformer block. |
| **Super Activation (SA)** | The corresponding massive activation value generated by SW; propagates via skip connections. |
| **Effect of Pruning SW** | Model generates gibberish output, perplexity ↑ ×1000, zero-shot accuracy ↓ ≈ 35 points. |
| **Effect of Restoring SA** | Recovers ≈ 40 % of the lost performance, showing the SW acts partly through the SA. |
---
## ⚙️ 3. How to Find Super Weights (Data-Free Method)
### Step 1 — Locate MLP Layers
In each Transformer block, focus on the **MLP down-projection** (`mlp.down_proj`) module.
### Step 2 — Forward Pass
Run one forward pass with any prompt (no dataset required):
```python
prompt = "My favorite food is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model(**inputs)
```
### Step 3 — Record Activations
Hook the input and output of each `down_proj` layer to capture activations:
```python
activations = {}
def make_hook(idx):
    def hook_fn(module, inp, out):
        activations[idx] = (inp[0].detach(), out.detach())  # inp is a tuple
    return hook_fn
handles = [layer.mlp.down_proj.register_forward_hook(make_hook(i))
           for i, layer in enumerate(model.model.layers)]
with torch.no_grad():
    model(**inputs)
for h in handles:
    h.remove()  # detach hooks once activations are captured
```
### Step 4 — Find Activation Spikes
For each layer, compute maximum absolute values per channel:
```python
# inp: (batch, seq, d_ff); out: (batch, seq, d_model) — reduce over all tokens
max_in = inp.abs().flatten(0, -2).max(dim=0).values
max_out = out.abs().flatten(0, -2).max(dim=0).values
```
Plot or inspect their peaks across layers (Figure 3 of the paper).
A layer with a **sharp activation spike** indicates the presence of a Super Weight.
### Step 5 — Determine Coordinates
- Row index = output channel with the largest magnitude (`max_out.argmax()` → row)
- Column index = input channel with the largest magnitude (`max_in.argmax()` → col)
- The Super Weight is:
```
model.model.layers[layer_id].mlp.down_proj.weight[row, col]
```
Example (Llama-7B): `layer[2].mlp.down_proj.weight[3968, 7003]`.
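The index logic can be sanity-checked on a toy matrix (a stand-in, not the real model): plant one oversized weight, feed an input that spikes on the matching channel, and recover the coordinates from the activation maxima alone.

```python
import torch

# Toy stand-in for down_proj: plant a "Super Weight" at [5, 11] and a large
# input on channel 11, then recover [row, col] from activation spikes alone.
W = torch.randn(8, 16) * 0.01
W[5, 11] = 50.0                 # planted Super Weight
X = torch.randn(4, 16)          # 4 "tokens", 16 input channels
X[:, 11] += 10.0                # spiking input channel
Y = X @ W.T                     # down-projection output: (4, 8)

col = X.abs().max(dim=0).values.argmax().item()  # spiking input channel
row = Y.abs().max(dim=0).values.argmax().item()  # spiking output channel
print(row, col)  # → 5 11
```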
---
## 🧮 4. Mathematical Explanation
For down-projection layer:
\[
Y = X W^T
\]
If a super activation \( Y_{ij} \) is dominant,
then it is mainly produced by one large input–weight pair \((X_{ik}, W_{jk})\).
Detecting the indices of extreme \( X_{ik} \) and \( Y_{ij} \) reveals the Super Weight \( W_{jk} \).
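Writing out a single entry of the product makes the detection rule concrete:
\[
Y_{ij} = \sum_{k} X_{ik} W_{jk} \approx X_{ik^{*}} W_{jk^{*}},
\]
where \( k^{*} \) is the channel at which \( |X_{ik}| \) spikes. Both \( X_{ik^{*}} \) and \( Y_{ij} \) are observable in a single forward pass, which is why the Super Weight \( W_{jk^{*}} \) can be located without any training data.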
---
## 📋 5. Known Super Weight Coordinates (Table 2)
| Model | Layer | Type | Coordinates [row, col] |
|:------|:------|:------|:------|
| **Llama-7B** | 2 | mlp.down_proj | [3968, 7003] |
| **Llama-13B** | 2 | mlp.down_proj | [2231, 2278], [2231, 6939] |
| **Llama-30B** | 3 / 10 | mlp.down_proj | [5633, 12817], [5633, 17439], [5633, 14386] |
| **Llama-2 7B** | 1 | mlp.down_proj | [2533, 7890] |
| **Mistral-7B** | 1 | mlp.down_proj | [2070, 7310] |
| **OLMo-7B** | 1 / 2 / 7 / 24 | mlp.down_proj | [269, 7467], [269, 8275], [269, 453], [269, 2300] |
| **Phi-3 mini-4k-instruct** | 2 / 4 | mlp.down_proj | [525, 808], [1113, 2723], … |
---
## 🧪 6. Verification Procedure
### ✅ Step A — Pruning Test
```python
row, col = 3968, 7003  # Llama-7B coordinates from Table 2
with torch.no_grad():  # avoid autograd errors on in-place Parameter edits
    model.model.layers[2].mlp.down_proj.weight[row, col] = 0.0
```
Then generate text:
```python
inputs = tokenizer("My favorite condiment is", return_tensors="pt").to(model.device)
out_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out_ids[0], skip_special_tokens=True))
```
If the output degenerates into gibberish, the Super Weight has been located.
### ✅ Step B — Super Activation Restoration
Record the Super Activation's value before pruning; after pruning, re-inject that value at the same position and check that part of the lost quality returns.
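A minimal sketch of the re-injection mechanics, using a plain `nn.Linear` as a stand-in for `down_proj` (the channel index and `sa_value` below are illustrative, not from the paper): a forward hook that returns a value replaces the layer's output, so the recorded SA can be written back in place.

```python
import torch

proj = torch.nn.Linear(16, 8, bias=False)  # stand-in for mlp.down_proj
sa_value = 120.0   # Super Activation recorded before pruning (illustrative)
channel = 5        # output channel carrying the SA (illustrative)

def restore_sa(module, inp, out):
    out = out.clone()
    out[..., channel] = sa_value  # re-inject the recorded Super Activation
    return out                    # returning a value replaces the output

handle = proj.register_forward_hook(restore_sa)
y = proj(torch.randn(2, 16))
print(y[:, channel])  # both rows forced to 120.0
handle.remove()
```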
---
## ⚡ 7. Practical PyTorch Snippet
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "Hello world"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
def find_super_weight(model, inputs, threshold=100.0):
    """Report per-layer [row, col] candidates from down_proj activation spikes."""
    records = {}
    def make_hook(i):
        return lambda m, inp, out: records.update({i: (inp[0].detach(), out.detach())})
    handles = [layer.mlp.down_proj.register_forward_hook(make_hook(i))
               for i, layer in enumerate(model.model.layers)]
    with torch.no_grad():
        model(**inputs)
    for h in handles:
        h.remove()
    for i, (inp, out) in records.items():
        max_in = inp.abs().flatten(0, -2).max(dim=0).values
        max_out = out.abs().flatten(0, -2).max(dim=0).values
        if max_in.max() > threshold and max_out.max() > threshold:
            print(f"[Layer {i}] Super weight candidate at "
                  f"({max_out.argmax().item()}, {max_in.argmax().item()})")
find_super_weight(model, inputs)
```
---
## 📈 8. Interpretation & Use Cases
| Use Case | Effect of Preserving SW/SA |
|:--|:--|
| **Quantization** | Enhances simple round-to-nearest (INT4/INT8) to ≈ 70–80 % of SmoothQuant quality. |
| **Model Compression** | Allows larger block sizes (e.g., 512×512) with less degradation. |
| **Explainability** | Reveals that a few weights govern semantic token probabilities (stopword suppression). |
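As a simplified illustration of the quantization row (my own sketch, not the paper's exact recipe): hold the Super Weight out at full precision, apply round-to-nearest INT8 to the rest, and restore the SW after dequantization so the outlier never inflates the quantization scale.

```python
import torch

def rtn_int8_preserve_sw(W, row, col):
    """Round-to-nearest INT8 quantization that holds out one Super Weight."""
    sw = W[row, col].clone()
    W_rest = W.clone()
    W_rest[row, col] = 0.0                      # exclude outlier from scaling
    scale = W_rest.abs().max() / 127.0
    W_hat = torch.clamp(torch.round(W_rest / scale), -127, 127) * scale
    W_hat[row, col] = sw                        # restore SW in full precision
    return W_hat

W = torch.randn(8, 16) * 0.02
W[5, 11] = 3.0                                  # planted outlier
W_hat = rtn_int8_preserve_sw(W, 5, 11)
print(W_hat[5, 11].item())  # → 3.0 (preserved exactly)
```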
---
## 🧭 9. Summary
- 🧩 Super Weights exist — a few scalars dominate LLM behavior.
- ⚙️ They can be found with a single forward pass.
- ⚡ Preserving them is vital for model compression and quantization.
- 📊 The authors released a directory of SW coordinates for open LLMs.
---
## 📚 10. References
Yu et al., **“The Super Weight in Large Language Models,”** arXiv:2411.07191v2, 2025.
Sun et al., *Massive Activations in Large Language Models*, ICLR Workshop 2024.
Frantar et al., *GPTQ*; Lin et al., *AWQ*; Xiao et al., *SmoothQuant* (2022–2024).
---