# 🧠 Reproducing “Super Weights” in Large Language Models

**Paper:** *The Super Weight in Large Language Models*

**Authors:** Mengxia Yu, De Wang, Qi Shan, Colorado J. Reed, Alvin Wan

**Affiliation:** Apple & University of Notre Dame

**arXiv:** [2411.07191v2 (July 2025)](https://arxiv.org/abs/2411.07191)

---

## 🧩 1. Background

Large Language Models (LLMs) often exhibit *outlier weights and activations*: values with extremely large magnitudes that strongly influence model quality.
This paper identifies a **single scalar parameter**, termed a **Super Weight (SW)**, whose removal alone can **destroy a model’s ability to generate text**.

### Key findings

- Pruning **one scalar** in Llama-7B drops zero-shot accuracy to random-guessing levels.
- The same weight induces a **Super Activation (SA)**: a huge activation spike that persists across layers.
- Both SW and SA can be found **data-free**, with a single forward pass.
- Preserving them dramatically improves **quantization quality**.

---

## 🧠 2. Conceptual Overview

| Term | Description |
|------|--------------|
| **Super Weight (SW)** | A single extremely important scalar weight in `mlp.down_proj` of an early transformer block. |
| **Super Activation (SA)** | The massive activation value produced by the SW; it propagates through later layers via skip connections. |
| **Effect of Pruning SW** | The model generates gibberish, perplexity rises by about three orders of magnitude, and zero-shot accuracy falls by ≈ 35 points. |
| **Effect of Restoring SA** | Recovers ≈ 40 % of the lost quality, which suggests the SW acts partly, but not only, through the SA. |

---

## ⚙️ 3. How to Find Super Weights (Data-Free Method)

### Step 1: Locate MLP Layers

In each Transformer block, focus on the **MLP down-projection** (`mlp.down_proj`) module.

### Step 2: Forward Pass

Run one forward pass with any prompt (no dataset required):

```python
import torch
prompt = "My favorite food is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():  # inference only; `model` and `tokenizer` are loaded as in Section 7
    outputs = model(**inputs)
```
### Step 3: Record Activations

Hook the input and output of each `down_proj` layer to capture activations:

```python
activations = {}  # layer index -> (down_proj input, down_proj output)
def make_hook(layer_id):
    def hook_fn(module, inp, out):  # `inp` is a tuple of positional inputs
        activations[layer_id] = (inp[0].detach(), out.detach())
    return hook_fn

handles = [layer.mlp.down_proj.register_forward_hook(make_hook(i))
           for i, layer in enumerate(model.model.layers)]
with torch.no_grad():
    model(**inputs)
for h in handles:
    h.remove()  # detach the hooks once the activations are recorded
```
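### Step 4: Find Activation Spikes

For each layer, compute the maximum absolute value per channel, reducing over the batch and token dimensions:

```python
# `inp`, `out`: the tensors captured in Step 3 for one `down_proj` layer,
# each of shape (batch, seq_len, channels).
max_in = inp.abs().amax(dim=(0, 1))    # peak magnitude per input channel
max_out = out.abs().amax(dim=(0, 1))   # peak magnitude per output channel
```

Plot or inspect their peaks across layers (Figure 3 of the paper).
A layer with a **sharp activation spike** in both the input and the output of `down_proj` indicates the presence of a Super Weight.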
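
As a rough way to turn "sharp spike" into a check, the sketch below flags layers whose peak magnitude is far above the median peak across layers. The function name `find_spike_layers`, the `ratio` cutoff, and the sample numbers are illustrative assumptions, not values from the paper; visual inspection of the plot remains the reference.

```python
from statistics import median

def find_spike_layers(layer_maxima, ratio=10.0):
    """Return indices of layers whose peak magnitude exceeds `ratio` times the median peak."""
    baseline = median(layer_maxima)
    return [i for i, peak in enumerate(layer_maxima) if peak > ratio * baseline]

# Made-up per-layer peaks of |down_proj output|, for illustration only.
peaks = [1.2, 0.9, 480.0, 1.5, 1.1, 1.3]
print(find_spike_layers(peaks))
```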
### Step 5: Determine Coordinates

- Row index = the output channel with the largest magnitude (`max_out.argmax()` → row)
- Column index = the input channel with the largest magnitude (`max_in.argmax()` → col)
- The Super Weight is:

```
model.model.layers[layer_id].mlp.down_proj.weight[row, col]
```

Example (Llama-7B): `model.model.layers[2].mlp.down_proj.weight[3968, 7003]`.
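
The index bookkeeping can be illustrated on a toy example. The sketch below uses plain Python lists in place of tensors; `argmax_abs_channel` and the activation values are hypothetical and only show how the output channel maps to the row and the input channel to the column of the weight matrix.

```python
def argmax_abs_channel(activation):
    """Return (channel, value) of the largest-magnitude entry in a tokens x channels matrix."""
    best_channel, best_value = 0, 0.0
    for token_row in activation:
        for channel, value in enumerate(token_row):
            if abs(value) > abs(best_value):
                best_channel, best_value = channel, value
    return best_channel, best_value

# Toy activations for 2 tokens, made up for illustration.
down_proj_input = [[0.1, -0.3, 0.2, 0.0],
                   [0.2, 0.1, -55.0, 0.4]]    # 4 input channels
down_proj_output = [[0.5, -0.2, 0.1],
                    [0.3, -410.0, 0.2]]       # 3 output channels

col, _ = argmax_abs_channel(down_proj_input)    # input channel  -> column
row, _ = argmax_abs_channel(down_proj_output)   # output channel -> row
print((row, col))
```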

---

## 🧮 4. Mathematical Explanation

For the down-projection layer:

\[
Y = X W^T
\]

Here \( X \) is the `down_proj` input (tokens × input channels), \( W \) is the weight matrix (output channels × input channels), and \( Y \) is the output (tokens × output channels).
If a super activation \( Y_{ij} \) is dominant,
then it is mainly produced by one large input–weight pair \((X_{ik}, W_{jk})\).
Detecting the indices of extreme \( X_{ik} \) and \( Y_{ij} \) reveals the Super Weight \( W_{jk} \).
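
Written out per entry, the argument reads as follows. The starred index \( k^* \) is notation introduced here for the input channel of the dominant pair, and the final approximation holds only under the stated premise that this single product outweighs the rest of the sum. The row of the Super Weight is then \( j \) and its column is \( k^* \).

\[
Y_{ij} = \sum_{k} X_{ik} W_{jk}
       = X_{ik^*} W_{jk^*} + \sum_{k \neq k^*} X_{ik} W_{jk}
       \approx X_{ik^*} W_{jk^*}
\]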

---

## 📋 5. Known Super Weight Coordinates (Table 2 of the paper)

| Model | Layer(s) | Type | Coordinates [row, col] |
|:------|:------|:------|:------|
| **Llama-7B** | 2 | mlp.down_proj | [3968, 7003] |
| **Llama-13B** | 2 | mlp.down_proj | [2231, 2278], [2231, 6939] |
| **Llama-30B** | 3 / 10 | mlp.down_proj | [5633, 12817], [5633, 17439], [5633, 14386] |
| **Llama-2 7B** | 1 | mlp.down_proj | [2533, 7890] |
| **Mistral-7B** | 1 | mlp.down_proj | [2070, 7310] |
| **OLMo-7B** | 1 / 2 / 7 / 24 | mlp.down_proj | [269, 7467], [269, 8275], [269, 453], [269, 2300] |
| **Phi-3 mini-4k-instruct** | 2 / 4 | mlp.down_proj | [525, 808], [1113, 2723], … |
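
For scripting, the single-layer rows of the table can be kept in a small lookup structure. This is a minimal sketch: the name `KNOWN_SUPER_WEIGHTS` and the helper `iter_super_weights` are hypothetical, and rows that span several layers or are partly elided in the table are left out rather than guessed.

```python
# Single-layer rows copied from the table above: model -> (layer, [(row, col), ...]).
# Multi-layer rows (Llama-30B, OLMo-7B) and the partly elided Phi-3 row are omitted.
KNOWN_SUPER_WEIGHTS = {
    "Llama-7B": (2, [(3968, 7003)]),
    "Llama-13B": (2, [(2231, 2278), (2231, 6939)]),
    "Llama-2 7B": (1, [(2533, 7890)]),
    "Mistral-7B": (1, [(2070, 7310)]),
}

def iter_super_weights(model_name):
    """Yield (layer, row, col) for every listed super weight of `model_name`."""
    layer, coords = KNOWN_SUPER_WEIGHTS[model_name]
    for row, col in coords:
        yield layer, row, col

for layer, row, col in iter_super_weights("Llama-13B"):
    print(f"layers[{layer}].mlp.down_proj.weight[{row}, {col}]")
```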
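
---

## 🧪 6. Verification Procedure

### ✅ Step A: Pruning Test

```python
row, col = 3968, 7003  # Llama-7B; for Llama-2 7B use layer 1 and [2533, 7890] (Section 5)
with torch.no_grad():  # in-place edits of a parameter must not be tracked by autograd
    model.model.layers[2].mlp.down_proj.weight[row, col] = 0
```

Then generate text:

```python
ids = tokenizer("My favorite condiment is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**ids, max_new_tokens=20)[0], skip_special_tokens=True))
```

If the output turns into gibberish, the Super Weight has been located correctly.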
### ✅ Step B: Super Activation Restoration

Record the Super Activation value before pruning, then write it back manually after pruning (for example with a forward hook on the same `down_proj`) to verify partial recovery.
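
One way to carry out the restoration is a forward hook that overwrites a single output entry. The sketch below runs on a toy `nn.Linear` standing in for `mlp.down_proj`; the planted outlier, the coordinates, and `restore_hook` are all made up for illustration. It only demonstrates the mechanics. On a real model the hook would sit on the actual `down_proj`, and recovery would be judged from perplexity or generated text.

```python
import torch
from torch import nn

torch.manual_seed(0)
down_proj = nn.Linear(8, 4, bias=False)    # toy stand-in for mlp.down_proj
row, col = 1, 5                            # hypothetical super weight coordinates
x = torch.randn(1, 3, 8)                   # (batch, seq_len, in_features)

with torch.no_grad():
    down_proj.weight[row, col] = 50.0      # plant an artificial outlier weight
    x[0, 0, col] = 20.0                    # and a large matching input at token 0
    reference = down_proj(x)               # 1) record activations before pruning
    token = reference[0, :, row].abs().argmax().item()
    super_activation = reference[0, token, row].clone()

    down_proj.weight[row, col] = 0.0       # 2) prune the super weight
    pruned = down_proj(x)

def restore_hook(module, inputs, output):
    # 3) overwrite the one activation with the value recorded before pruning
    output = output.clone()
    output[0, token, row] = super_activation
    return output                          # a returned value replaces the module output

handle = down_proj.register_forward_hook(restore_hook)
with torch.no_grad():
    restored = down_proj(x)
handle.remove()

print(pruned[0, token, row].item(), restored[0, token, row].item())
```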

---

## ⚡ 7. Practical PyTorch Snippet

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hello world"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

def find_super_weight(model, inputs, threshold=3.0):
    """Report down_proj layers whose input and output peaks both exceed `threshold` (a heuristic cutoff)."""
    records, handles = {}, []
    for i, layer in enumerate(model.model.layers):
        def hook(module, inp, out, i=i):
            records[i] = (inp[0].detach().float(), out.detach().float())
        handles.append(layer.mlp.down_proj.register_forward_hook(hook))
    with torch.no_grad():
        model(**inputs)  # one real forward pass; the hooks capture each down_proj
    for h in handles:
        h.remove()
    for i, (x, y) in records.items():
        # Peak magnitude per channel (over batch and tokens), then the top channel.
        max_in, col = x.abs().amax(dim=(0, 1)).max(dim=0)
        max_out, row = y.abs().amax(dim=(0, 1)).max(dim=0)
        if max_in > threshold and max_out > threshold:
            print(f"[Layer {i}] Super weight candidate at ({row.item()}, {col.item()})")

find_super_weight(model, inputs)
```

---

## 📈 8. Interpretation & Use Cases

| Use Case | Effect of Preserving SW/SA |
|:--|:--|
| **Quantization** | Lifts simple round-to-nearest quantization (INT4/INT8) to ≈ 70–80 % of SmoothQuant quality. |
| **Model Compression** | Allows larger quantization block sizes (e.g., 512×512) with less degradation. |
| **Explainability** | Suggests that a few weights shape output token probabilities: removing the SW inflates stopword probabilities. |
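
Why a single outlier matters for round-to-nearest quantization can be seen on a handful of numbers. The sketch below is a simplified illustration, not the paper's procedure: `quantize_rtn`, `quantize_preserving_outlier`, and the weight values are assumptions chosen to show that one large value stretches the quantization scale, and that holding it out in full precision shrinks the error on everything else.

```python
def quantize_rtn(values, bits=4):
    """Symmetric round-to-nearest quantize/dequantize with a single absmax scale."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax
    if scale == 0.0:
        return list(values)
    return [round(v / scale) * scale for v in values]

def quantize_preserving_outlier(values, bits=4):
    """Hold out the largest-magnitude value, quantize the rest, then put it back."""
    idx = max(range(len(values)), key=lambda i: abs(values[i]))
    rest = values[:idx] + values[idx + 1:]
    deq = quantize_rtn(rest, bits)
    return deq[:idx] + [values[idx]] + deq[idx:]

def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

weights = [0.11, -0.27, 0.05, 0.31, -0.18, 2.5]  # made up; 2.5 plays the super weight
naive = quantize_rtn(weights)
kept = quantize_preserving_outlier(weights)
print(mean_abs_error(weights, naive), mean_abs_error(weights, kept))
```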

---

## 🧭 9. Summary

- 🧩 Super Weights exist: a few scalars dominate LLM behavior.
- ⚙️ They can be found with a single forward pass.
- ⚡ Preserving them is vital for model compression and quantization.
- 📊 The authors released an index of SW coordinates for open LLMs.

---

## 📚 10. References

Yu et al., *The Super Weight in Large Language Models*, arXiv:2411.07191v2, 2025.

Sun et al., *Massive Activations in Large Language Models*, ICLR Workshop 2024.

Frantar et al., *GPTQ* (2022); Lin et al., *AWQ* (2023); Xiao et al., *SmoothQuant* (2022).

---