Initial commit
- README.md +307 -296
- __init__.py +28 -0
- bench.py +176 -0
- coherence_eval.py +834 -0
- config.py +306 -0
- data.py +546 -0
- generate.py +195 -0
- graft_g2lu.py +300 -0
- layers.py +325 -0
- lm_eval_wrapper.py +344 -0
- mirrored.py +532 -0
- model.py +357 -0
- scripts/__init__.py +0 -0
- scripts/representation_analysis.py +1014 -0
- scripts/spectral_analysis.py +969 -0
- scripts/spectral_to_csv.py +202 -0
- train.py +637 -0
---
license: mit
datasets:
- Bingsu/openwebtext_20p
- HuggingFaceFW/fineweb-edu
language:
- en
pipeline_tag: text-generation
---

# Prisma

A prototype model built as a mirrored transformer architecture with nested gating (an extra weight matrix in the FFN) and morphological position encoding. The premise is that the architecture creates different scaffolding, leading to different training regimens and capabilities.

Prisma is only viable because it piggybacks on pre-trained tokenizers and their weight-tied embeddings. It decomposes the transformer into symmetric **expand** and **compress** phases that share structural weights, connected by a small number of unique **middle** layers. Information expands from tokens to semantics, then compresses back — like light through a prism.

```
Token Embeddings
      |
[ Expand ]  ── mirror pair 1 (W1, W2 shared) ── G²LU gate (W3·W4)
[ Expand ]  ── mirror pair 2
[  ....  ]  ── mirror pair N
      |
[ Middle ]  ── unique layers (full capacity, not shared)
      |
[Compress ] ── mirror pair N (same W1, W2 as expand N)
[Compress ] ── mirror pair 2
[Compress ] ── mirror pair 1
      |
LM Head (weight-tied to embeddings)
```

## Key Concepts

**Mirrored layers.** Each expand layer shares its W1 (projection) and W2 (output) weights with the corresponding compress layer, so the architecture gets 2N virtual layers of processing from N unique parameter sets. At 357M parameters, Prisma runs 41 virtual layers from ~20 unique weight sets + 1 middle layer.
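
The virtual-layer arithmetic above can be sketched as follows. The relation `virtual_layers = 2 × pairs + n_middle` is inferred from the 41-layer / 20-pair / 1-middle description; the helper name is hypothetical.

```python
def mirror_pairs(virtual_layers: int, n_middle: int) -> int:
    """Mirrored pairs implied by a virtual layer count (assumed relation:
    virtual_layers = 2 * pairs + n_middle, inferred from the text above)."""
    remainder = virtual_layers - n_middle
    if remainder < 0 or remainder % 2 != 0:
        raise ValueError("virtual_layers - n_middle must be a non-negative even number")
    return remainder // 2

print(mirror_pairs(41, 1))  # → 20 unique mirrored weight sets (the 357M config)
print(mirror_pairs(57, 1))  # → 28 (the ~47M config below)
```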

**G²LU — Gated-Gated Linear Unit.** The gate is itself gated:

Where typical gated transformers have `y = W2 @ (W1 @ x * silu(W3 @ x))`, Prisma has:
```python
g4 = silu(W4 @ x)        # inner gate
g3 = silu(W3 @ x * g4)   # outer gate, modulated by inner
y  = W2 @ (W1 @ x * g3)  # gated output
```

One gate is a function of the other. This creates quadratic (saddle-surface) decision boundaries instead of linear hyperplanes — each neuron computes a conjunction ("feature A AND feature B") rather than a single threshold. The result is narrow, separated activation channels that resist memorization and tolerate significantly higher learning rates. Part of the parameters saved by mirroring is redistributed as W4.
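
As a self-contained sketch of the pseudocode above (NumPy, not Prisma's actual `mirrored.py` module; the weight shapes are assumptions):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))  # SiLU: x * sigmoid(x)

def g2lu(x, W1, W2, W3, W4):
    """G²LU feed-forward following the pseudocode above."""
    g4 = silu(W4 @ x)          # inner gate
    g3 = silu(W3 @ x * g4)     # outer gate, modulated by the inner one
    return W2 @ (W1 @ x * g3)  # gated output

rng = np.random.default_rng(0)
d_model, d_ff = 8, 16
x = rng.standard_normal(d_model)
W1, W3, W4 = (rng.standard_normal((d_ff, d_model)) * 0.1 for _ in range(3))
W2 = rng.standard_normal((d_model, d_ff)) * 0.1
print(g2lu(x, W1, W2, W3, W4).shape)  # (8,) — back to d_model
```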

**WoRPE — Word-position Rotary Position Embedding.** Dedicates a small subspace of each attention head to encoding position within a word (0 = first subword, 1 = second subword, ...). The information is already present in the BPE tokenizer's word-boundary markers — WoRPE surfaces it geometrically so the model doesn't have to rediscover it. No new tokenizer required.
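
Extracting word-internal positions from GPT-2-style BPE tokens can be sketched in a few lines (illustrative only; Prisma's actual WoRPE indexing may differ):

```python
def word_positions(tokens):
    """0 for the first subword of each word, 1 for the second, and so on.
    GPT-2-style tokens mark word starts with a leading 'Ġ'."""
    positions, current = [], 0
    for i, tok in enumerate(tokens):
        current = 0 if (i == 0 or tok.startswith("Ġ")) else current + 1
        positions.append(current)
    return positions

print(word_positions(["The", "Ġun", "believ", "able", "Ġcat"]))
# → [0, 0, 1, 2, 0]
```

These indices would then drive the rotation in the dedicated rotary subspace.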

**Auxiliary skip prediction.** An optional second head predicts the token K steps ahead, providing a gradient signal that rewards structural representations over local memorization. At K=1 it functions as a dual-supervision regularizer through an untied projection.
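
Target construction for such a head can be sketched as (hypothetical helper consistent with the description above, where the head at distance K predicts the token K positions ahead):

```python
def skip_targets(token_ids, K):
    """Inputs and targets for a head that predicts K tokens ahead.
    For K=1 this reproduces the standard next-token targets."""
    return token_ids[:-K], token_ids[K:]

ids = [10, 11, 12, 13, 14]
print(skip_targets(ids, 1))  # → ([10, 11, 12, 13], [11, 12, 13, 14])
print(skip_targets(ids, 2))  # → ([10, 11, 12], [12, 13, 14])
```

The auxiliary loss on these shifted targets would then be scaled by `--aux-weight` and added to the main loss.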


## Results

### ~50M scale prototype (WikiText-103, 4 epochs)

| Model | Params | LR | WikiText PPL | LAMBADA |
|---|---|---|---|---|
| Standard SwiGLU | 51M | 1e-4 | 4125 | 0.002 |
| Prisma (G²LU) | 47M | 1e-4 | 2914 | 0.001 |
| Prisma (G²LU + WoRPE) | 51M | 1e-2 | 921 | 0.082 |

*Standard trained 10 epochs; Prisma (G²LU + WoRPE) shown at 1 epoch — the point is LR tolerance, not an epoch-matched comparison.*

The regularization stack (mirroring + G²LU + WoRPE) enables training at **100x the standard learning rate** without instability.

### ~350M scale prototype — comparison with published models

Prisma 357M trained on ~30B tokens (OpenWebText 20% + FineWeb-Edu 10BT continued training), compared against published models at similar scale:

| Model | Params | Train Data | ARC-C\* | ARC-E\* | BoolQ | HellaSwag\* | LAMBADA | PIQA\* | WikiText\*\* | WinoGrande |
|---|---|---|---|---|---|---|---|---|---|---|
| GPT-2 medium | 355M | 40B | 0.250 | 0.436 | 0.586 | 0.394 | **0.430** | 0.664 | **26.75** | **0.531** |
| Baguettotron | 321M | 200B | 0.302 | 0.506 | 0.589 | 0.354 | 0.294 | 0.618 | 30.93 | 0.530 |
| SmolLM-360M | 360M | 600B | **0.359** | **0.640** | 0.550 | **0.536** | **0.455** | **0.715** | **19.49** | **0.570** |
| SmolLM2-360M | 360M | 4000B | **0.381** | **0.681** | 0.617 | 0.431 | **0.532** | **0.718** | **15.67** | **0.586** |
| LFM2-350M | 350M | 10000B | **0.393** | **0.662** | **0.642** | **0.489** | 0.399 | 0.698 | 25.68 | **0.558** |
| **Prisma** | **357M** | **30B** | 0.290 | **0.548** | **0.620** | **0.427** | 0.362 | **0.670** | 27.40 | 0.506 |

\* *normalized accuracy* · \*\* *word perplexity*

**Key findings:**
- **Beats GPT-2 medium on 5/8 benchmarks** (ARC-C, ARC-E, BoolQ, HellaSwag, PIQA) with 25% less training data.
- **Beats Baguettotron (200B) on 6/8 benchmarks** — including PPL — with **7x less data.**
- **BoolQ 0.620** exceeds all models except LFM2 (10000B) and SmolLM2 (4000B), suggesting the anti-memorization properties of G²LU favor comprehension over statistical shortcuts.
- **ARC-Easy 0.548** — the largest absolute gain over GPT-2 medium (+11.2pp). FineWeb-Edu knowledge appears to be absorbed efficiently through G²LU's relational features.
- Prisma wins on **reasoning benchmarks** (ARC, HellaSwag, PIQA, BoolQ). Models trained on 20-300x more data win on **content prediction** (LAMBADA, PPL). The architecture trades raw memorization for data-efficient knowledge application.


### Training progression (~350M)

| Stage | LR | ARC-C | ARC-E | BoolQ | HellaSwag | LAMBADA | PIQA | PPL |
|---|---|---|---|---|---|---|---|---|
| Standard 336M baseline | 1e-4 | 0.228 | 0.341 | 0.618 | 0.280 | 0.226 | 0.574 | 77.2 |
| Prisma 41L (OWT 20%) | 5e-4 | 0.238 | 0.394 | 0.585 | 0.317 | 0.313 | 0.614 | 44.8 |
| + WoRPE (OWT 20%) | 1e-3 | 0.247 | 0.397 | 0.595 | 0.331 | 0.333 | 0.614 | 43.5 |
| + continued (FineWeb c1) | 1e-3 | 0.249 | 0.434 | 0.601 | 0.333 | 0.312 | 0.626 | 34.7 |
| + continued (FineWeb c2) | 1e-3 | 0.290 | 0.548 | 0.620 | 0.427 | 0.362 | 0.670 | 27.4 |


## Quick Start

### Install

```bash
pip install -r Prisma/requirements.txt
```

### Train

```bash
# Small Prisma (~47M) on WikiText-103
python -m Prisma.train \
  --arch mirrored --dims 384 --heads 6 --kv-heads 2 --layers 57 --n-middle 1 \
  --tokenizer facebook/MobileLLM-125M \
  --word-rope-dims 8 --word-rope-base 10.0 \
  --data hf:wikitext:wikitext-103-raw-v1:train \
  --epochs 4 --batch-size 32 --context-length 512 \
  --lr 1e-2 --warmup-steps 500 --bf16 --gpu 0


# 324M Prisma on OpenWebText
python -m Prisma.train \
  --gpu 0 --compile --bf16 --arch mirrored \
  --dims 1024 --heads 16 --kv-heads 4 --layers 41 --n-middle 1 \
  --word-rope-dims 8 --word-rope-base 10.0 \
  --tokenizer facebook/MobileLLM-125M \
  --data hf:Bingsu/openwebtext_20p --text-column text \
  --epochs 4 --batch-size 12 --context-length 1024 --grad-accum 42 \
  --lr 1e-3 --warmup-steps 500 \
  --log-every 5 --val-every 1000 --save-every 1000 \
  --checkpoint-dir path/to/checkpoints/your_model/
```

### Generate

```bash
python -m Prisma.generate \
  --checkpoint path/to/checkpoints/your_model/best.pt \
  --prompt "A thought observing itself discovers that it" \
  --max-tokens 256 --temperature 0.8 --top-p 0.95
```
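
The `--temperature` and `--top-p` flags correspond to standard temperature scaling plus nucleus sampling. A sketch of that sampling step (an illustrative reimplementation, not Prisma's `generate.py`):

```python
import numpy as np

def sample_top_p(logits, temperature=0.8, top_p=0.95, seed=None):
    """Temperature scaling followed by nucleus (top-p) filtering."""
    rng = np.random.default_rng(seed)
    z = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]                        # most probable first
    cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
    keep = order[:cutoff]                                  # smallest set with mass >= top_p
    p = probs[keep] / probs[keep].sum()                    # renormalize over the nucleus
    return int(rng.choice(keep, p=p))

token = sample_top_p([10.0, 1.0, 0.0, -5.0], top_p=0.5, seed=0)
print(token)  # → 0 (the dominant logit is the whole nucleus here)
```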

### Benchmark

```bash
# Single model
python -m Prisma.bench \
  --checkpoint path/to/checkpoints/your_model/best.pt \
  --tasks arc_easy,lambada_openai,piqa,hellaswag,winogrande,wikitext \
  --gpu 0
```


## CLI Reference

### Architecture

| Flag | Default | Description |
|---|---|---|
| `--arch` | `mirrored` | Architecture: `standard`, `mirrored`, `graft_g2lu` (experimental) |
| `--dims` | 512 | Hidden dimension |
| `--heads` | 8 | Number of attention heads |
| `--kv-heads` | — | KV heads for GQA (omit = MHA) |
| `--layers` | 12 | Total virtual layers (expand + middle + compress) |
| `--n-middle` | 2 | Unique (non-mirrored) middle layers |

### Prisma-Specific

| Flag | Default | Description |
|---|---|---|
| `--word-rope-dims` | 0 | Head dims for WoRPE (0 = disabled, try 8) |
| `--word-rope-base` | 10.0 | WoRPE frequency base |
| `--aux-skip` | 0 | Skip-ahead prediction distance (0 = disabled) |
| `--aux-weight` | 0.1 | Weight for auxiliary loss |
| `--no-g2lu` | — | Disable G²LU, use standard SwiGLU in mirrored arch |

### Training

| Flag | Default | Description |
|---|---|---|
| `--lr` | 3e-4 | Peak learning rate |
| `--min-lr` | 0.0 | LR floor for cosine schedule |
| `--warmup-steps` | 100 | LR warmup steps |
| `--epochs` | 10 | Training epochs |
| `--batch-size` | 32 | Micro-batch size |
| `--grad-accum` | 1 | Gradient accumulation steps |
| `--context-length` | 512 | Sequence length |
| `--bf16` / `--fp16` | — | Mixed precision |
| `--compile` | — | `torch.compile` the model |
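
The `--lr`, `--min-lr`, and `--warmup-steps` flags imply a warmup-then-cosine schedule. One plausible formulation, for illustration (Prisma's `train.py` may differ in details):

```python
import math

def lr_at(step, peak_lr=3e-4, min_lr=0.0, warmup_steps=100, total_steps=10_000):
    """Linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * min(t, 1.0)))

print(lr_at(100))     # peak LR right after warmup
print(lr_at(10_000))  # decayed to the min-lr floor
```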

### Data

| Flag | Default | Description |
|---|---|---|
| `--data` | — | Path or `hf:dataset_name` |
| `--text-column` | `text` | Column name for HF datasets |
| `--tokenizer` | `gpt2` | Tokenizer name or path |
| `--num-samples` | — | Limit dataset size |
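
The `--data` values in the examples follow an `hf:dataset[:config[:split]]` pattern. A hypothetical parser, for illustration of the format only:

```python
def parse_data_spec(spec):
    """Split 'hf:name[:config[:split]]' specs like those passed to --data.
    Plain values (no 'hf:' prefix) are treated as local paths."""
    if not spec.startswith("hf:"):
        return {"path": spec}
    return dict(zip(("dataset", "config", "split"), spec.split(":")[1:]))

print(parse_data_spec("hf:wikitext:wikitext-103-raw-v1:train"))
# → {'dataset': 'wikitext', 'config': 'wikitext-103-raw-v1', 'split': 'train'}
print(parse_data_spec("hf:Bingsu/openwebtext_20p"))
# → {'dataset': 'Bingsu/openwebtext_20p'}
```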


## Architecture Details

### Why Mirroring Works

Mirroring only works because of the additional gate. W3 and W4 specialize to serve different roles despite sharing weights — spectral analysis confirms the gates swap their stable-rank profiles at the architectural midpoint. The order of mirror layers may be rearrangeable, since the gates adapt to whatever representations flow through them.

### Why G²LU Works

Standard SwiGLU creates hyperplane decision boundaries — broad, overlapping activation regions. G²LU's nested gate creates **saddle surfaces** — narrow activation bands with isolation gaps (like a spectral comb filter). This has three effects:

1. **Anti-memorization.** The gate geometry cannot form sharp, input-specific activations. The model is forced toward broad, relational features.
2. **Higher LR tolerance.** Narrow activation bands leave headroom between features. Large gradient updates shift features within their bands without colliding.
3. **Compositional detection.** Each neuron natively computes conjunctions (A AND B), not just thresholds. This may help with morphology, syntax, and structural reasoning.

G²LU can be seen as occupying a point between standard GLU (fixed activation, fixed gate) and KAN (fully learned activations): the activation function is fixed (silu), but its effective shape adapts per input through the nested gate.

### Why WoRPE Works

BPE tokenizers already mark word boundaries (`Ġ` for GPT-2, `▁` for SentencePiece). WoRPE surfaces this information geometrically in a dedicated subspace of the rotary embedding, so the model gets word-internal position for free instead of rediscovering it from attention patterns. It requires G²LU to be exploited effectively — the saddle surfaces compute morphological conjunctions ("position-0 AND prefix-pattern") that single gates cannot.

### Why Everything Works Together

The optimization landscape of this architecture is substantially more complex than a standard transformer's — shared weights must serve both directions, nested gates must coordinate, and the hourglass bottleneck constrains information flow. This appears to be tractable only when anchored by pre-trained, weight-tied embeddings that provide a stable coordinate system. The frozen embeddings give the model a fixed reference geometry, allowing convergence despite the architectural complexity.
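
Weight tying, reusing the embedding table as the output projection, can be illustrated in a few lines (a NumPy sketch, not Prisma's model code):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model = 100, 16
E = rng.standard_normal((vocab, d_model))  # pretrained token embedding table

def embed(ids):
    return E[ids]                # tokens -> hidden vectors

def lm_head(hidden):
    return hidden @ E.T          # tied head: logits reuse E, no separate matrix

logits = lm_head(embed(np.array([3, 7])))
print(logits.shape)  # (2, 100) — one logit per vocabulary entry
```

Because both ends of the network read and write in the coordinate system of `E`, freezing `E` pins that geometry in place for the whole stack.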


## File Map

```
Prisma/
  config.py          — CLI arguments, presets, CircuitConfig
  layers.py          — RMSNorm, RoPE, WoRPE, CausalAttention, SwiGLU
  model.py           — CircuitTransformer (standard baseline)
  mirrored.py        — MirroredTransformer, G²LU, MirroredBlock
  train.py           — Training loop, LR schedule, checkpointing
  data.py            — MemmapDataset, parallel tokenization, HF/text loading
  generate.py        — Text generation with KV caching
  bench.py           — Benchmark runner and comparison tables
  lm_eval_wrapper.py — EleutherAI lm-eval harness integration
  graft_g2lu.py      — Surgical G²LU upgrade for pretrained models (experimental/untested)
  scripts/           — Analysis scripts
```


## Origin

Prisma grew from interpretability research on _layer grafting_ (write-up in progress) in Llama 3.2, which suggests that one way transformers self-organize to process language resembles a mirrored structure that expands from tokens to semantics, then compresses back — the interpretive analogy being a biconvex lens with fractures or polarizing filters within its body. If the two halves are structurally symmetric, they can share weights, and the gate (the fractures/polarizing filters) becomes the minimum surgical unit for changing behavior. Because the weights are shared, a single gate weight set becomes insufficient, which raised the question of how to make two gates collaborate efficiently.

G²LU emerged from the observation that for a pair of gates to be expressive and atomic, _one gate needs to be a function of the other_.

WoRPE emerged from noticing that tokenizers already carry word structure while positional encodings ignore it; surfacing this hint to the model speeds convergence during training.

The architecture is a processing engine that plugs into pretrained tokenizer embeddings. The tokenizer is load-bearing infrastructure — Prisma operates within a pre-existing coordinate system.


## Developer Notes

This model is the outcome of a proof of concept by a single individual with limited resources; further investigation, training, and testing are being conducted slowly as time and conditions allow.

The proposed architecture was only fully trained on top of the `facebook/MobileLLM-125M` tokenizer and its weight-tied embeddings. It may not work as expected with untied embeddings, and it is highly likely that a model with this architecture cannot be trained without a pre-trained tokenizer.

Different arrangements of the architecture (varying middle layer count, mirror depth, width) would likely produce different results. Only this setup — with 1 middle layer — was tested, as a validation of whether the architecture works at all. The extreme case was chosen deliberately: if the bottleneck configuration most prone to failure still produces competitive results, less constrained configurations should too.

Factorized embedding dimensions and an intermediate down-projection before the output head were attempted; nothing useful came of either.

It is unknown whether the architecture is beneficial for larger models (1B+) — observations suggest it might be.


## Training

- **Architecture**:
  - 41 virtual layers: 20 mirrored pairs (shared W1, W2) + 1 unique middle layer
  - 1024 dims
  - 16 GQA heads, 4 KV heads (4:1)
  - vocab size 32k
  - RoPE + WoRPE + G²LU
- **Pretraining tokens**: 30B
- **Precision**: bfloat16
- **Tokenizer/Embeddings**: facebook/MobileLLM-125M
- **Hardware**: 1× H100


## Disclaimer

This is a research model. It has not been tested thoroughly for synthesis and coherence quality, as its size is somewhat limiting. Use it at your own risk.
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
datasets:
|
| 4 |
+
- Bingsu/openwebtext_20p
|
| 5 |
+
- HuggingFaceFW/fineweb-edu
|
| 6 |
+
language:
|
| 7 |
+
- en
|
| 8 |
+
pipeline_tag: text-generation
|
| 9 |
+
---
|
| 10 |
+
# Prisma
|
| 11 |
+
|
| 12 |
+
A prototype model that is assembled as a mirrored transformer architecture with nested gating (adds an extra weight to the FFN) and morphological position encoding. It proposes that the model architecture creates different scaffolding, leading to different training regimens and capabilities.
|
| 13 |
+
|
| 14 |
+
Prisma is only viable as it piggybacks on pre-trained tokenizers and their weight-tied embeddings, it decomposes the transformer architecture into symmetric **expand** and **compress** phases that share structural weights, connected by a small number of unique **middle** layers. Information expands from tokens to semantics, then compresses back — like light through a prism.
|
| 15 |
+
|
| 16 |
+
```
|
| 17 |
+
Token Embeddings
|
| 18 |
+
|
|
| 19 |
+
[ Expand ] ─── mirror pair 1 (W1, W2 shared) ── G²LU gate (W3·W4)
|
| 20 |
+
[ Expand ] ─── mirror pair 2
|
| 21 |
+
[ .... ] ─── mirror pair N
|
| 22 |
+
|
|
| 23 |
+
[ Middle ] ─── unique layers (full capacity, not shared)
|
| 24 |
+
|
|
| 25 |
+
[Compress ] ─── mirror pair N (same W1, W2 as expand N)
|
| 26 |
+
[Compress ] ─── mirror pair 2
|
| 27 |
+
[Compress ] ─── mirror pair 1
|
| 28 |
+
|
|
| 29 |
+
LM Head (weight-tied to embeddings)
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
## Key Concepts
|
| 34 |
+
|
| 35 |
+
**Mirrored layers.** Each expand layer shares W1 (projection) and W2 (output) weights with its corresponding compress layer. The architecture gets 2N virtual layers of processing from N unique parameter sets. At 357M parameters, Prisma runs 41 virtual layers from ~20 unique weight sets + 1 middle layer.
|
| 36 |
+
|
| 37 |
+
**G²LU — Gated-Gated Linear Unit.** The gate is itself gated:
|
| 38 |
+
|
| 39 |
+
Where typical gated transformers have `y = W2 @ (W1 @ x * silu(W3 @ x))`, Prisma has:
|
| 40 |
+
```python
|
| 41 |
+
g4 = silu(W4 @ x) # inner gate
|
| 42 |
+
g3 = silu(W3 @ x * g4) # outer gate, modulated by inner
|
| 43 |
+
y = W2 @ (W1 @ x * g3) # gated output
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
One gate in function of the other. Creates quadratic (saddle-surface) decision boundaries instead of linear hyperplanes — each neuron computes a conjunction ("feature A AND feature B") rather than a single threshold. This produces narrow, separated activation channels that resist memorization and tolerate significantly higher learning rates. Part of the parameters saved with mirroring are re-distributed as W4.
|
| 47 |
+
|
| 48 |
+
**WoRPE — Word-position Rotary Position Embedding.** Dedicates a small subspace of each attention head to encode position within a word (0 = prefix, 1 = second subword, ...). The information is already in the BPE tokenizer's word-boundary markers — WoRPE surfaces it geometrically so the model doesn't have to rediscover it. No new tokenizer required.
|
| 49 |
+
|
| 50 |
+
**Auxiliary skip prediction.** An optional second head predicts t+K tokens ahead, providing gradient signal that rewards structural representations over local memorization. At K=1, functions as a dual-supervision regularizer through an untied projection.
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## Results
|
| 54 |
+
|
| 55 |
+
### ~50M scale prototype (WikiText-103, 4 epochs)
|
| 56 |
+
|
| 57 |
+
| Model | Params | LR | WikiText PPL | LAMBADA |
|
| 58 |
+
|---|---|---|---|---|
|
| 59 |
+
| Standard SwiGLU | 51M | 1e-4 | 4125 | 0.002 |
|
| 60 |
+
| Prisma (G²LU) | 47M | 1e-4 | 2914 | 0.001 |
|
| 61 |
+
| Prisma (G²LU + WoRPE) | 51M | 1e-2 | 921 | 0.082 |
|
| 62 |
+
|
| 63 |
+
*Standard trained 10 epochs; Prisma (G²LU + WoRPE) shown at 1 epoch — the point is LR tolerance, not epoch-matched comparison.*
|
| 64 |
+
|
| 65 |
+
The regularization stack (mirroring + G²LU + WoRPE) enables training at **100x the standard learning rate** without instability.
|
| 66 |
+
|
| 67 |
+
### ~350M scale prototype — comparison with published models
|
| 68 |
+
|
| 69 |
+
Prisma 357M trained on ~30B tokens (OpenWebText 20% + FineWeb-Edu 10BT continued training), compared against published models at similar scale:
|
| 70 |
+
|
| 71 |
+
| Model | Params | Train Data | ARC-C\* | ARC-E\* | BoolQ | HellaSwag\* | LAMBADA | PIQA\* | WikiText\*\* | WinoGrande |
|
| 72 |
+
|---|---|---|---|---|---|---|---|---|---|---|
|
| 73 |
+
| GPT-2 medium | 355M | 40B | 0.250 | 0.436 | 0.586 | 0.394 | **0.430** | 0.664 | **26.75** | **0.531** |
|
| 74 |
+
| Baguettotron | 321M | 200B | 0.302 | 0.506 | 0.589 | 0.354 | 0.294 | 0.618 | 30.93 | 0.530 |
|
| 75 |
+
| SmolLM-360M | 360M | 600B | **0.359** | **0.640** | 0.550 | **0.536** | **0.455** | **0.715** | **19.49** | **0.570** |
|
| 76 |
+
| SmolLM2-360M | 360M | 4000B | **0.381** | **0.681** | 0.617 | 0.431 | **0.532** | **0.718** | **15.67** | **0.586** |
|
| 77 |
+
| LFM2-350M | 350M | 10000B | **0.393** | **0.662** | **0.642** | **0.489** | 0.399 | 0.698 | 25.68 | **0.558** |
|
| 78 |
+
| **Prisma** | **357M** | **30B** | 0.290 | **0.548** | **0.620** | **0.427** | 0.362 | **0.670** | 27.40 | 0.506 |
|
| 79 |
+
|
| 80 |
+
\* *normalized accuracy* · \*\* *word perplexity*
|
| 81 |
+
|
| 82 |
+
**Key findings:**
|
| 83 |
+
- **Beats GPT-2 medium on 5/8 benchmarks** (ARC-C, ARC-E, BoolQ, HellaSwag, PIQA) with 25% less training data.
|
| 84 |
+
- **Beats Baguettotron (200B) on 6/8 benchmarks** — including PPL — with **7x less data.**
|
| 85 |
+
- **BoolQ 0.620** exceeds all models except LFM2 (10000B) and SmolLM2 (4000B). The anti-memorization properties of G²LU force genuine comprehension instead of statistical shortcuts.
|
| 86 |
+
- **ARC-Easy 0.548** — the largest absolute gain over GPT-2 medium (+11.2pp). FineWeb-Edu knowledge absorbed efficiently through G²LU's relational features.
|
| 87 |
+
- Prisma wins on **reasoning benchmarks** (ARC, HellaSwag, PIQA, BoolQ). Models trained on 20-300x more data win on **content prediction** (LAMBADA, PPL). The architecture trades raw memorization for data-efficient knowledge application.
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
### Training progression (~350M)
|
| 91 |
+
|
| 92 |
+
| Stage | LR | ARC-C | ARC-E | BoolQ | HellaSwag | LAMBADA | PIQA | PPL |
|
| 93 |
+
|---|---|---|---|---|---|---|---|---|
|
| 94 |
+
| Standard 336M baseline | 1e-4 | 0.228 | 0.341 | 0.618 | 0.280 | 0.226 | 0.574 | 77.2 |
|
| 95 |
+
| Prisma 41L (OWT 20%) | 5e-4 | 0.238 | 0.394 | 0.585 | 0.317 | 0.313 | 0.614 | 44.8 |
|
| 96 |
+
| + WoRPE (OWT 20%) | 1e-3 | 0.247 | 0.397 | 0.595 | 0.331 | 0.333 | 0.614 | 43.5 |
|
| 97 |
+
| + continued (FineWeb c1) | 1e-3 | 0.249 | 0.434 | 0.601 | 0.333 | 0.312 | 0.626 | 34.7 |
|
| 98 |
+
| + continued (FineWeb c2) | 1e-3 | 0.290 | 0.548 | 0.620 | 0.427 | 0.362 | 0.670 | 27.4 |
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
## Quick Start
|
| 102 |
+
|
| 103 |
+
### Install
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
pip install -r Prisma/requirements.txt
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
### Train
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
# Small Prisma (~47M) on WikiText-103
|
| 114 |
+
python -m Prisma.train \
|
| 115 |
+
--arch mirrored --dims 384 --heads 6 --kv-heads 2 --layers 57 --n-middle 1 \
|
| 116 |
+
--tokenizer facebook/MobileLLM-125M \
|
| 117 |
+
--word-rope-dims 8 --word-rope-base 10.0 \
|
| 118 |
+
--data hf:wikitext:wikitext-103-raw-v1:train \
|
| 119 |
+
--epochs 4 --batch-size 32 --context-length 512 \
|
| 120 |
+
--lr 1e-2 --warmup-steps 500 --bf16 --gpu 0
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# 324M Prisma on OpenWebText
|
| 124 |
+
python -m Prisma.train \
|
| 125 |
+
--gpu 0 --compile --bf16 --arch mirrored \
|
| 126 |
+
--dims 1024 --heads 16 --kv-heads 4 --layers 41 --n-middle 1 \
|
| 127 |
+
--word-rope-dims 8 --word-rope-base 10.0 \
|
| 128 |
+
--tokenizer facebook/MobileLLM-125M \
|
| 129 |
+
--data hf:Bingsu/openwebtext_20p --text-column text \
|
| 130 |
+
--epochs 4 --batch-size 12 --context-length 1024 --grad-accum 42 \
|
| 131 |
+
--lr 1e-3 --warmup 500 \
|
| 132 |
+
--log-every 5 --val-every 1000 --save-every 1000 \
|
| 133 |
+
--checkpoint-dir path/to/checkpoints/your_model/
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
### Generate
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
python -m Prisma.generate \
|
| 141 |
+
--checkpoint path/to/checkpoints/your_model/best.pt \
|
| 142 |
+
--prompt "A thought observing itself discovers that it" \
|
| 143 |
+
--max-tokens 256 --temperature 0.8 --top-p 0.95
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
### Benchmark
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
# Single model
|
| 151 |
+
python -m Prisma.bench \
|
| 152 |
+
--checkpoint path/to/checkpoints/your_model/best.pt \
|
| 153 |
+
--tasks arc_easy,lambada_openai,piqa,hellaswag,winogrande,wikitext \
|
| 154 |
+
--gpu 0
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
## CLI Reference
|
| 159 |
+
|
| 160 |
+
### Architecture
|
| 161 |
+
|
| 162 |
+
| Flag | Default | Description |
|
| 163 |
+
|---|---|---|
|
| 164 |
+
| `--arch` | `mirrored` | Architecture: `standard`, `mirrored`, `graft_g2lu` (experimental) |
|
| 165 |
+
| `--dims` | 512 | Hidden dimension |
|
| 166 |
+
| `--heads` | 8 | Number of attention heads |
|
| 167 |
+
| `--kv-heads` | — | KV heads for GQA (omit = MHA) |
|
| 168 |
+
| `--layers` | 12 | Total virtual layers (expand + middle + compress) |
|
| 169 |
+
| `--n-middle` | 2 | Unique (non-mirrored) middle layers |
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
### Prisma-Specific
|
| 173 |
+
|
| 174 |
+
| Flag | Default | Description |
|
| 175 |
+
|---|---|---|
|
| 176 |
+
| `--word-rope-dims` | 0 | Head dims for WoRPE (0 = disabled, try 8) |
|
| 177 |
+
| `--word-rope-base` | 10.0 | WoRPE frequency base |
|
| 178 |
+
| `--aux-skip` | 0 | Skip-ahead prediction distance (0 = disabled) |
|
| 179 |
+
| `--aux-weight` | 0.1 | Weight for auxiliary loss |
|
| 180 |
+
| `--no-g2lu` | — | Disable G²LU, use standard SwiGLU in mirrored arch |
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
### Training
|
| 184 |
+
|
| 185 |
+
| Flag | Default | Description |
|
| 186 |
+
|---|---|---|
|
| 187 |
+
| `--lr` | 3e-4 | Peak learning rate |
|
| 188 |
+
| `--min-lr` | 0.0 | LR floor for cosine schedule |
|
| 189 |
+
| `--warmup-steps` | 100 | LR warmup steps |
|
| 190 |
+
| `--epochs` | 10 | Training epochs |
|
| 191 |
+
| `--batch-size` | 32 | Micro-batch size |
|
| 192 |
+
| `--grad-accum` | 1 | Gradient accumulation steps |
|
| 193 |
+
| `--context-length` | 512 | Sequence length |
|
| 194 |
+
| `--bf16` / `--fp16` | — | Mixed precision |
|
| 195 |
+
| `--compile` | — | `torch.compile` the model |
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
### Data
|
| 199 |
+
|
| 200 |
+
| Flag | Default | Description |
|
| 201 |
+
|---|---|---|
|
| 202 |
+
| `--data` | — | Path or `hf:dataset_name` |
|
| 203 |
+
| `--text-column` | `text` | Column name for HF datasets |
|
| 204 |
+
| `--tokenizer` | `gpt2` | Tokenizer name or path |
|
| 205 |
+
| `--num-samples` | — | Limit dataset size |
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
## Architecture Details
|
| 209 |
+
|
| 210 |
+
### Why Mirroring Works
|
| 211 |
+
|
| 212 |
+
Mirroring only works due to the additional gate. W3 and W4 specialize to serve different roles despite sharing weights — spectral analysis confirms the gates swap their stable-rank profiles at the architectural midpoint. The order of mirror layers may be rearrangeable, as the gates adapt to whatever representations flow through them.
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
### Why G²LU Works
|
| 216 |
+
|
| 217 |
+
Standard SwiGLU creates hyperplane decision boundaries — broad, overlapping activation regions. G²LU's nested gate creates **saddle surfaces** — narrow activation bands with isolation gaps (like a spectral comb filter). This has three effects:
|
| 218 |
+
|
| 219 |
+
1. **Anti-memorization.** The gate geometry cannot form sharp, input-specific activations. The model is forced toward broad, relational features.
|
| 220 |
+
2. **Higher LR tolerance.** Narrow activation bands leave headroom between features. Large gradient updates shift features within their bands without colliding.
|
| 221 |
+
3. **Compositional detection.** Each neuron natively computes conjunctions (A AND B), not just thresholds. Might be useful for morphology, syntax, and structural reasoning.
|
| 222 |
+
|
| 223 |
+
G²LU can be seen as occupying a point between standard GLU (fixed activation, fixed gate) and KAN (fully learned activations): the activation function is fixed (silu), but its effective shape adapts per-input through the nested gate.
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
### Why WoRPE Works
|
| 227 |
+
|
| 228 |
+
BPE tokenizers already mark word boundaries (`Ġ` for GPT-2, `▁` for SentencePiece). WoRPE surfaces this information geometrically in a dedicated subspace of the rotary embedding, so the model gets word-internal position for free instead of rediscovering it from attention patterns. Requires G²LU to exploit effectively — the saddle surfaces compute morphological conjunctions ("position-0 AND prefix-pattern") that single gates cannot.
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
### Why Everything Works Together
|
| 232 |
+
|
| 233 |
+
The optimization landscape of this architecture is substantially more complex than a standard transformer — shared weights must serve both directions, nested gates must coordinate, and the hourglass bottleneck constrains information flow. This appears to be only tractable when anchored by pre-trained, weight-tied embeddings that provide a stable coordinate system. The frozen embeddings give the model fixed reference geometry, allowing convergence despite the architectural complexity.

## File Map

```
Prisma/
  config.py           — CLI arguments, presets, CircuitConfig
  layers.py           — RMSNorm, RoPE, WoRPE, CausalAttention, SwiGLU
  model.py            — CircuitTransformer (standard baseline)
  mirrored.py         — MirroredTransformer, G²LU, MirroredBlock
  train.py            — Training loop, LR schedule, checkpointing
  data.py             — MemmapDataset, parallel tokenization, HF/text loading
  generate.py         — Text generation with KV caching
  bench.py            — Benchmark runner and comparison tables
  lm_eval_wrapper.py  — EleutherAI lm-eval harness integration
  graft_g2lu.py       — Surgical G²LU upgrade for pretrained models (experimental/untested)
  scripts/            — Analysis scripts
```

## Origin

Prisma grew from interpretability research on _layer grafting_ (write-up in progress) in Llama 3.2, which suggests that one way transformers self-organize to process language resembles a mirrored structure: expanding from tokens to semantics, then compressing back. The interpretive analogy is a biconvex lens with fractures or polarizing filters within its body. If the two halves are structurally symmetric, they can share weights, and the gate (the fractures/polarizing filters) becomes the minimum surgical unit for changing behavior. With shared weights a single weight set becomes insufficient, which raised the question of how to make two gates collaborate efficiently.

G²LU emerged from the observation that for a pair of gates to be expressive and atomic, _one gate needs to be a function of the other_.

WoRPE emerged from noticing that tokenizers already carry word structure while positional encodings ignore it; surfacing these hints to the model allows faster convergence during training.

The architecture is a processing engine that plugs into pretrained tokenizer embeddings. The tokenizer is load-bearing infrastructure: Prisma operates within a pre-existing coordinate system.

## Developer Notes

This model is the outcome of a proof of concept by a single individual with limited resources; further investigation, training, and testing are being conducted slowly as time and conditions allow.

The proposed architecture was only fully trained on top of the `facebook/MobileLLM-125M` tokenizer and its weight-tied embeddings. It may not work as expected on untied embeddings, and it is highly likely that a model with this architecture cannot be trained at all without a pre-trained tokenizer.

Different arrangements of the architecture (varying middle layer count, mirror depth, width) would likely produce different results. Only this setup, with 1 middle layer, was tested, as a validation of whether the architecture works at all. The extreme case was chosen deliberately: if the bottleneck configuration most prone to failure still produces competitive results, less constrained configurations should too.

Factorized dimensions for the embeddings and an intermediate down-projection before the output head were attempted; nothing useful came out of either.

It is completely unknown whether the architecture is beneficial for larger models (1B+), though observations suggest it might be.

## Training

- **Architecture**:
  - 41 layers
    - 20 with shared W1 and W2
    - 1 unique
  - 1024 dims
  - 16 GQA heads, 4 KV heads (4:1)
  - vocab size 32k
  - RoPE + WoRPE + G²LU
- **Pretraining tokens**: 30B
- **Precision**: bfloat16
- **Tokenizer/Embeddings**: facebook/MobileLLM-125M
- **Hardware**: 1 H100

## Disclaimer

This model was developed as a research model and hasn't been tested thoroughly regarding synthesis and coherence quality, as its size is somewhat limiting. Use it at your own risk.

## Citation

```
@misc{ivatchkovitch2026prisma,
  title={Prisma: Interpretability-Inspired Mirrored Transformer Architecture},
  author={Yuri Ivatchkovitch},
  year={2026},
  howpublished={\url{https://huggingface.co/y3i12/Prisma}},
}
```
__init__.py
ADDED

```python
"""
Circuits: Minimal Transformer for Semantic Circuitry Experiments.

A clean, self-contained transformer implementation designed for
experimenting with neural networks.
"""

from .config import CircuitConfig
from .model import CircuitTransformer, count_parameters
from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
from .data import get_tokenizer, load_data, create_dataloader, TextDataset
from .graft_g2lu import G2LU_GraftedModel, G2LU_MLP, load_g2lu_model

__all__ = [
    "CircuitConfig",
    "CircuitTransformer",
    "count_parameters",
    "MirroredConfig",
    "MirroredTransformer",
    "count_mirrored_parameters",
    "get_tokenizer",
    "load_data",
    "create_dataloader",
    "TextDataset",
    "G2LU_GraftedModel",
    "G2LU_MLP",
    "load_g2lu_model",
]
```
bench.py
ADDED

```python
#!/usr/bin/env python3
"""
Benchmark Circuit transformer family against standard LM tasks.

Usage:
    # Single model
    python -m circuits.bench --checkpoint circuits/checkpoints/slot_local_mirrored/best.pt --gpu 0

    # Compare all architectures
    python -m circuits.bench --compare --gpu 0

    # Quick sanity check (100 samples per task)
    python -m circuits.bench --compare --gpu 0 --limit 100

    # Specific tasks
    python -m circuits.bench --checkpoint path/to/best.pt --tasks hellaswag,lambada_openai
"""

import argparse
import json
import time
import torch
from pathlib import Path

import lm_eval
from lm_eval.api.registry import register_model

from .lm_eval_wrapper import CircuitLM

# Register so lm_eval can find it
register_model("circuit")(CircuitLM)

DEFAULT_TASKS = "arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,piqa,wikitext,winogrande"

# Known checkpoints for --compare mode
CHECKPOINTS = {
    "standard_12L": "circuits/checkpoints/flat/best.pt",
    "mirrored_9L_wide": "circuits/checkpoints/hier_wide_2/best.pt",
    "mirrored_15L_deep": "circuits/checkpoints/hier_resized/best.pt",
    "slot_local_mirrored": "circuits/checkpoints/slot_local_mirrored/best.pt",
}


def run_benchmark(checkpoint: str, tasks: str, device: str, limit: int = None, batch_size: int = 1, compile: bool = False):
    """Run lm-eval on a single checkpoint."""
    model_args = f"checkpoint={checkpoint},device={device},batch_size={batch_size},compile={'true' if compile else 'false'}"

    task_list = tasks.split(",")

    results = lm_eval.simple_evaluate(
        model="circuit",
        model_args=model_args,
        tasks=task_list,
        limit=limit,
    )

    return results


def extract_scores(results: dict) -> dict:
    """Pull headline metrics from lm-eval results."""
    scores = {}
    if "results" not in results:
        return scores
    for task_name, task_results in results["results"].items():
        # Get the primary metric (usually acc or acc_norm)
        if "acc_norm,none" in task_results:
            scores[task_name] = task_results["acc_norm,none"]
        elif "acc,none" in task_results:
            scores[task_name] = task_results["acc,none"]
        elif "perplexity,none" in task_results:
            scores[task_name] = task_results["perplexity,none"]
        elif "word_perplexity,none" in task_results:
            scores[task_name] = task_results["word_perplexity,none"]
    return scores


def print_comparison(all_results: dict, tasks: list):
    """Pretty-print comparison table."""
    # Header
    col_width = max(len(t) for t in tasks) + 2
    name_width = max(len(n) for n in all_results) + 2

    header = f"{'Model':<{name_width}}"
    for task in tasks:
        header += f"{task:>{col_width}}"
    header += f"{' avg':>8}"
    print("\n" + "=" * len(header))
    print(header)
    print("-" * len(header))

    for name, scores in all_results.items():
        row = f"{name:<{name_width}}"
        vals = []
        for task in tasks:
            val = scores.get(task, None)
            if val is not None:
                row += f"{val:>{col_width}.4f}"
                vals.append(val)
            else:
                row += f"{'N/A':>{col_width}}"
        avg = sum(vals) / len(vals) if vals else 0
        row += f"{avg:>8.4f}"
        print(row)

    print("=" * len(header))


def main():
    parser = argparse.ArgumentParser(description="Benchmark Circuit transformers")
    parser.add_argument("--checkpoint", type=str, help="Path to single checkpoint")
    parser.add_argument("--compare", action="store_true", help="Compare all known architectures")
    parser.add_argument("--tasks", type=str, default=DEFAULT_TASKS, help="Comma-separated task list")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--limit", type=int, default=None, help="Limit samples per task (for quick testing)")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    parser.add_argument("--compile", action="store_true", help="torch.compile models for faster inference")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}"
    task_list = args.tasks.split(",")

    if args.compare:
        all_scores = {}
        all_raw = {}

        # Filter to existing checkpoints
        available = {k: v for k, v in CHECKPOINTS.items() if Path(v).exists()}
        missing = {k: v for k, v in CHECKPOINTS.items() if not Path(v).exists()}
        if missing:
            print(f"Skipping (not found): {', '.join(missing.keys())}")

        for name, ckpt_path in available.items():
            print(f"\n{'='*60}")
            print(f"Evaluating: {name}")
            print(f"Checkpoint: {ckpt_path}")
            print(f"{'='*60}")

            t0 = time.time()
            results = run_benchmark(ckpt_path, args.tasks, device, args.limit, args.batch_size, args.compile)
            elapsed = time.time() - t0

            scores = extract_scores(results)
            all_scores[name] = scores
            all_raw[name] = results.get("results", {})
            print(f"  Completed in {elapsed:.0f}s: {scores}")

        print_comparison(all_scores, task_list)

        if args.output:
            with open(args.output, "w") as f:
                json.dump({"scores": all_scores, "raw": all_raw}, f, indent=2, default=str)
            print(f"\nResults saved to {args.output}")

    elif args.checkpoint:
        print(f"Evaluating: {args.checkpoint}")
        t0 = time.time()
        results = run_benchmark(args.checkpoint, args.tasks, device, args.limit, args.batch_size, args.compile)
        elapsed = time.time() - t0

        scores = extract_scores(results)
        print(f"\nResults ({elapsed:.0f}s):")
        for task, score in scores.items():
            print(f"  {task}: {score:.4f}")

        if args.output:
            with open(args.output, "w") as f:
                json.dump(results, f, indent=2, default=str)
            print(f"\nResults saved to {args.output}")
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
```
coherence_eval.py
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Coherence evaluation for language models.
|
| 4 |
+
|
| 5 |
+
Measures what standard benchmarks can't see:
|
| 6 |
+
Tier 1 — Generation diversity (repetition, collapse detection)
|
| 7 |
+
Tier 2 — Multi-distance prediction (context utilization, skip accuracy)
|
| 8 |
+
Tier 3 — Semantic consistency (chunk similarity over long generations)
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
# Custom checkpoint
|
| 12 |
+
python -m circuits.coherence_eval --checkpoint circuits/checkpoints/model/best.pt
|
| 13 |
+
|
| 14 |
+
# HuggingFace model
|
| 15 |
+
python -m circuits.coherence_eval --model gpt2
|
| 16 |
+
|
| 17 |
+
# Compare models
|
| 18 |
+
python -m circuits.coherence_eval --model EleutherAI/pythia-160m --gpu 0
|
| 19 |
+
|
| 20 |
+
# Quick test (fewer prompts, shorter generation)
|
| 21 |
+
python -m circuits.coherence_eval --checkpoint path/to/model.pt --num-prompts 5 --gen-length 256
|
| 22 |
+
|
| 23 |
+
# Run specific tiers
|
| 24 |
+
python -m circuits.coherence_eval --checkpoint path/to/model.pt --tiers 1,3
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import json
|
| 29 |
+
import math
|
| 30 |
+
import sys
|
| 31 |
+
import time
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 38 |
+
# Default prompts — diverse domains, 10-20 tokens each
|
| 39 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 40 |
+
|
| 41 |
+
DEFAULT_PROMPTS = [
|
| 42 |
+
"A thought observing itself discovers that it",
|
| 43 |
+
"The history of science shows that",
|
| 44 |
+
"In the middle of the night, the old house",
|
| 45 |
+
"The relationship between language and thought has been",
|
| 46 |
+
"When the first settlers arrived, they found",
|
| 47 |
+
"The mathematical proof begins by assuming",
|
| 48 |
+
"She opened the door to find",
|
| 49 |
+
"The economic implications of this policy",
|
| 50 |
+
"Deep beneath the ocean surface, researchers discovered",
|
| 51 |
+
"The most important lesson from this experiment is",
|
| 52 |
+
"According to recent studies, the human brain",
|
| 53 |
+
"The old library contained books that",
|
| 54 |
+
"As the temperature continued to rise, the effects on",
|
| 55 |
+
"The development of artificial intelligence has raised questions about",
|
| 56 |
+
"In the small village at the foot of the mountain",
|
| 57 |
+
"The fundamental principles of democracy require",
|
| 58 |
+
"Looking through the telescope, the astronomer noticed",
|
| 59 |
+
"The relationship between music and emotion",
|
| 60 |
+
"During the industrial revolution, working conditions",
|
| 61 |
+
"The ancient manuscript revealed secrets about",
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 66 |
+
# Model wrapper — unified interface for circuit models and HF models
|
| 67 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
class ModelWrapper:
|
| 70 |
+
"""Unified interface for custom circuit models and HuggingFace models."""
|
| 71 |
+
|
| 72 |
+
def __init__(self, model, tokenizer, device, model_type="hf",
|
| 73 |
+
skip_head=None, skip_k=0, max_seq_len=1024, name="unknown"):
|
| 74 |
+
self.model = model
|
| 75 |
+
self.tokenizer = tokenizer
|
| 76 |
+
self.device = device
|
| 77 |
+
self.model_type = model_type # "circuit" or "hf"
|
| 78 |
+
self.skip_head = skip_head
|
| 79 |
+
self.skip_k = skip_k
|
| 80 |
+
self.max_seq_len = max_seq_len
|
| 81 |
+
self.name = name
|
| 82 |
+
|
| 83 |
+
@classmethod
|
| 84 |
+
def from_checkpoint(cls, path, device):
|
| 85 |
+
"""Load a custom circuit model from checkpoint."""
|
| 86 |
+
from .config import CircuitConfig
|
| 87 |
+
from .model import CircuitTransformer
|
| 88 |
+
from .mirrored import MirroredConfig, MirroredTransformer
|
| 89 |
+
from .slotted_mirrored import SlotMirroredConfig, SlotMirroredTransformer
|
| 90 |
+
from .data import get_tokenizer
|
| 91 |
+
|
| 92 |
+
checkpoint = torch.load(path, map_location="cpu", weights_only=False)
|
| 93 |
+
model_type = checkpoint.get("model_type", "standard")
|
| 94 |
+
|
| 95 |
+
if model_type == "slot_mirrored":
|
| 96 |
+
config = SlotMirroredConfig.from_dict(checkpoint["config"])
|
| 97 |
+
model = SlotMirroredTransformer(config).to(device)
|
| 98 |
+
arch_desc = f"SlotMirrored ({config.n_slots} slots)"
|
| 99 |
+
elif model_type == "mirrored":
|
| 100 |
+
config = MirroredConfig.from_dict(checkpoint["config"])
|
| 101 |
+
model = MirroredTransformer(config).to(device)
|
| 102 |
+
arch_desc = "Mirrored"
|
| 103 |
+
else:
|
| 104 |
+
config = CircuitConfig.from_dict(checkpoint["config"])
|
| 105 |
+
model = CircuitTransformer(config).to(device)
|
| 106 |
+
arch_desc = "Standard"
|
| 107 |
+
|
| 108 |
+
# Handle torch.compile prefix
|
| 109 |
+
state_dict = checkpoint["model"]
|
| 110 |
+
if any(k.startswith("_orig_mod.") for k in state_dict):
|
| 111 |
+
state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
|
| 112 |
+
model.load_state_dict(state_dict)
|
| 113 |
+
model.eval()
|
| 114 |
+
|
| 115 |
+
tokenizer = get_tokenizer()
|
| 116 |
+
skip_head = model.skip_head if hasattr(model, 'skip_head') else None
|
| 117 |
+
skip_k = getattr(config, 'aux_skip_k', 0)
|
| 118 |
+
max_seq_len = config.max_seq_len
|
| 119 |
+
|
| 120 |
+
params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 121 |
+
name = f"{Path(path).parent.name}/{Path(path).stem} ({arch_desc}, {params:.1f}M)"
|
| 122 |
+
|
| 123 |
+
return cls(model, tokenizer, device, model_type="circuit",
|
| 124 |
+
skip_head=skip_head, skip_k=skip_k,
|
| 125 |
+
max_seq_len=max_seq_len, name=name)
|
| 126 |
+
|
| 127 |
+
@classmethod
|
| 128 |
+
def from_pretrained(cls, model_name, device):
|
| 129 |
+
"""Load a HuggingFace model."""
|
| 130 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 131 |
+
|
| 132 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 133 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 134 |
+
model_name, trust_remote_code=True,
|
| 135 |
+
torch_dtype=torch.float32,
|
| 136 |
+
).to(device)
|
| 137 |
+
model.eval()
|
| 138 |
+
|
| 139 |
+
max_seq_len = getattr(model.config, 'max_position_embeddings', 1024)
|
| 140 |
+
if tokenizer.pad_token is None:
|
| 141 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 142 |
+
|
| 143 |
+
params = sum(p.numel() for p in model.parameters()) / 1e6
|
| 144 |
+
name = f"{model_name} ({params:.1f}M)"
|
| 145 |
+
|
| 146 |
+
return cls(model, tokenizer, device, model_type="hf",
|
| 147 |
+
max_seq_len=max_seq_len, name=name)
|
| 148 |
+
|
| 149 |
+
@property
|
| 150 |
+
def has_skip_head(self):
|
| 151 |
+
return self.skip_head is not None and self.skip_k > 0
|
| 152 |
+
|
| 153 |
+
def generate(self, prompt_text, max_new_tokens=512):
|
| 154 |
+
"""Generate tokens at temperature 0 (greedy). Returns generated token IDs only."""
|
| 155 |
+
prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
|
| 156 |
+
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
if self.model_type == "hf":
|
| 159 |
+
output_ids = self.model.generate(
|
| 160 |
+
prompt_ids,
|
| 161 |
+
max_new_tokens=max_new_tokens,
|
| 162 |
+
do_sample=True,
|
| 163 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 164 |
+
temperature=0.8,
|
| 165 |
+
top_k=50,
|
| 166 |
+
top_p=0.9,
|
| 167 |
+
repetition_penalty=1.2,
|
| 168 |
+
)
|
| 169 |
+
else:
|
| 170 |
+
output_ids = self.model.generate(
|
| 171 |
+
prompt_ids,
|
| 172 |
+
max_new_tokens=max_new_tokens,
|
| 173 |
+
temperature=0.8,
|
| 174 |
+
top_k=50,
|
| 175 |
+
top_p=0.9,
|
| 176 |
+
repetition_penalty=1.2,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
# Return only the generated part
|
| 180 |
+
gen_ids = output_ids[0, prompt_ids.shape[1]:]
|
| 181 |
+
return prompt_ids[0], gen_ids
|
| 182 |
+
|
| 183 |
+
def forward_with_hidden(self, input_ids):
|
| 184 |
+
"""Forward pass returning (logits, hidden_states, skip_logits_or_None).
|
| 185 |
+
input_ids: [1, L] tensor.
|
| 186 |
+
"""
|
| 187 |
+
with torch.no_grad():
|
| 188 |
+
if self.model_type == "hf":
|
| 189 |
+
outputs = self.model(input_ids, output_hidden_states=True)
|
| 190 |
+
logits = outputs.logits
|
| 191 |
+
hidden = outputs.hidden_states[-1]
|
| 192 |
+
return logits, hidden, None
|
| 193 |
+
else:
|
| 194 |
+
# Hook into norm layer to capture pre-lm_head hidden states
|
| 195 |
+
hidden_capture = {}
|
| 196 |
+
|
| 197 |
+
def hook_fn(module, inp, output):
|
| 198 |
+
hidden_capture['h'] = output.detach()
|
| 199 |
+
|
| 200 |
+
handle = self.model.norm.register_forward_hook(hook_fn)
|
| 201 |
+
output = self.model(input_ids)
|
| 202 |
+
handle.remove()
|
| 203 |
+
|
| 204 |
+
logits = output['logits']
|
| 205 |
+
hidden = hidden_capture['h']
|
| 206 |
+
|
| 207 |
+
skip_logits = None
|
| 208 |
+
if self.has_skip_head:
|
| 209 |
+
skip_logits = self.skip_head(hidden)
|
| 210 |
+
|
| 211 |
+
return logits, hidden, skip_logits
|
| 212 |
+
|
| 213 |
+
def forward(self, input_ids):
|
| 214 |
+
"""Forward pass returning logits only. input_ids: [1, L] tensor."""
|
| 215 |
+
with torch.no_grad():
|
| 216 |
+
if self.model_type == "hf":
|
| 217 |
+
return self.model(input_ids).logits
|
| 218 |
+
else:
|
| 219 |
+
return self.model(input_ids)['logits']
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 223 |
+
# Generation (shared between Tier 1 and Tier 3)
|
| 224 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 225 |
+
|
| 226 |
+
def generate_all(wrapper, prompts, gen_length):
|
| 227 |
+
"""Generate from all prompts. Returns list of (prompt_text, prompt_ids, gen_ids)."""
|
| 228 |
+
results = []
|
| 229 |
+
for prompt in prompts:
|
| 230 |
+
prompt_ids, gen_ids = wrapper.generate(prompt, max_new_tokens=gen_length)
|
| 231 |
+
results.append((prompt, prompt_ids, gen_ids))
|
| 232 |
+
print(f" [{len(results)}/{len(prompts)}] {len(gen_ids)} tokens", end="\r")
|
| 233 |
+
print()
|
| 234 |
+
return results
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 238 |
+
# Tier 1: Generation Diversity
|
| 239 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 240 |
+
|
| 241 |
+
def ngrams(tokens, n):
|
| 242 |
+
"""Extract n-grams from token list."""
|
| 243 |
+
return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def compute_diversity(gen_ids):
|
| 247 |
+
"""Compute diversity metrics for a single generation."""
|
| 248 |
+
tokens = gen_ids.tolist()
|
| 249 |
+
n = len(tokens)
|
| 250 |
+
if n < 4:
|
| 251 |
+
return {"unique_1g": 0, "unique_2g": 0, "unique_3g": 0, "unique_4g": 0,
|
| 252 |
+
"max_repeat": n, "collapsed": True}
|
| 253 |
+
|
| 254 |
+
results = {}
|
| 255 |
+
for k in [1, 2, 3, 4]:
|
| 256 |
+
grams = ngrams(tokens, k)
|
| 257 |
+
results[f"unique_{k}g"] = len(set(grams)) / len(grams) if grams else 0.0
|
| 258 |
+
|
| 259 |
+
# Max consecutive identical token span
|
| 260 |
+
max_repeat = 1
|
| 261 |
+
current = 1
|
| 262 |
+
for i in range(1, n):
|
| 263 |
+
if tokens[i] == tokens[i - 1]:
|
| 264 |
+
current += 1
|
| 265 |
+
max_repeat = max(max_repeat, current)
|
| 266 |
+
else:
|
| 267 |
+
current = 1
|
| 268 |
+
results["max_repeat"] = max_repeat
|
| 269 |
+
|
| 270 |
+
# Longest repeated n-gram span (any n-gram repeated consecutively)
|
| 271 |
+
max_ngram_repeat = 1
|
| 272 |
+
for ng_size in [2, 3, 4, 5, 8]:
|
| 273 |
+
grams = ngrams(tokens, ng_size)
|
| 274 |
+
streak = 1
|
| 275 |
+
for i in range(1, len(grams)):
|
| 276 |
+
if grams[i] == grams[i - 1]:
|
| 277 |
+
streak += 1
|
| 278 |
+
max_ngram_repeat = max(max_ngram_repeat, streak * ng_size)
|
| 279 |
+
else:
|
| 280 |
+
streak = 1
|
| 281 |
+
results["max_ngram_repeat_span"] = max_ngram_repeat
|
| 282 |
+
|
| 283 |
+
# Collapse: unique 4-grams < 50% or max repeat span > 25% of generation
|
| 284 |
+
results["collapsed"] = (results["unique_4g"] < 0.5) or (max_ngram_repeat > n * 0.25)
|
| 285 |
+
|
| 286 |
+
return results
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def eval_diversity(generations, tokenizer, show_samples=3):
    """Tier 1: Compute diversity metrics from pre-generated text."""
    print("\n" + "=" * 60)
    print("TIER 1: Generation Diversity")
    print("=" * 60)

    all_metrics = []
    sample_texts = []

    for i, (prompt, prompt_ids, gen_ids) in enumerate(generations):
        metrics = compute_diversity(gen_ids)
        metrics["prompt"] = prompt
        metrics["gen_length"] = len(gen_ids)
        all_metrics.append(metrics)

        if i < show_samples:
            text = tokenizer.decode(gen_ids, skip_special_tokens=True)
            sample_texts.append((prompt, text))

    n = len(all_metrics)
    if n == 0:
        print(" No generations to evaluate.")
        return {}

    # Aggregate
    agg = {}
    for key in ["unique_1g", "unique_2g", "unique_3g", "unique_4g",
                "max_repeat", "max_ngram_repeat_span"]:
        values = [m[key] for m in all_metrics]
        agg[key] = {"mean": sum(values) / n, "min": min(values), "max": max(values)}

    collapse_count = sum(1 for m in all_metrics if m["collapsed"])
    agg["collapse_rate"] = collapse_count / n
    avg_len = sum(m["gen_length"] for m in all_metrics) / n

    # Print
    print(f"\n Prompts evaluated: {n}")
    print(f" Avg generation length: {avg_len:.0f} tokens")
    print()
    print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
    print(f" {'-' * 50}")
    for key in ["unique_1g", "unique_2g", "unique_3g", "unique_4g"]:
        m = agg[key]
        print(f" {key:<24} {m['mean']:>8.3f} {m['min']:>8.3f} {m['max']:>8.3f}")
    for key in ["max_repeat", "max_ngram_repeat_span"]:
        m = agg[key]
        print(f" {key:<24} {m['mean']:>8.1f} {int(m['min']):>8d} {int(m['max']):>8d}")
    print(f"\n Collapse rate: {collapse_count}/{n} ({agg['collapse_rate']:.1%})")

    # Show samples
    if sample_texts:
        print(f"\n --- Sample generations (first {len(sample_texts)}) ---")
        for prompt, text in sample_texts:
            print(f"\n Prompt: \"{prompt}\"")
            preview = text[:400].replace("\n", " ")
            if len(text) > 400:
                preview += "..."
            print(f" Output: {preview}")

    return {"per_prompt": all_metrics, "aggregate": agg}

# ──────────────────────────────────────────────────────────────────────
# Tier 2: Multi-Distance Prediction
# ──────────────────────────────────────────────────────────────────────

def prepare_eval_sequences(wrapper, num_sequences=50, data_source=None):
    """Prepare ground truth sequences for Tier 2."""
    max_len = wrapper.max_seq_len

    if data_source and Path(data_source).exists():
        with open(data_source) as f:
            text = f.read()
        all_ids = wrapper.tokenizer.encode(text)
    else:
        try:
            from datasets import load_dataset
            print(" Loading WikiText-103 validation...")
            ds = load_dataset("wikitext", "wikitext-103-raw-v1",
                              split="validation", trust_remote_code=True)
            text = "\n".join(row["text"] for row in ds if row["text"].strip())
            all_ids = wrapper.tokenizer.encode(text)
        except Exception as e:
            print(f" Could not load eval data: {e}")
            print(" Install 'datasets' or use --eval-data to provide a text file.")
            return None

    # Chunk into non-overlapping sequences of max_len tokens
    sequences = []
    for i in range(0, len(all_ids) - max_len, max_len):
        seq = torch.tensor(all_ids[i:i + max_len], dtype=torch.long)
        sequences.append(seq)
        if len(sequences) >= num_sequences:
            break

    if len(sequences) < 2:
        print(" Not enough text for evaluation sequences.")
        return None

    print(f" Prepared {len(sequences)} sequences of {max_len} tokens")
    return sequences

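The chunking loop above drops any trailing partial window; a minimal stand-in with plain lists shows the same semantics:

```python
# Non-overlapping windows of max_len items, mirroring the loop in
# prepare_eval_sequences (lists stand in for token-id tensors).
all_ids = list(range(10))
max_len = 4
chunks = [all_ids[i:i + max_len]
          for i in range(0, len(all_ids) - max_len, max_len)]
# The incomplete tail [8, 9] is discarded.
print(chunks)  # → [[0, 1, 2, 3], [4, 5, 6, 7]]
```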
def eval_context_utilization(wrapper, sequences):
    """Tier 2a: Per-position perplexity grouped by depth bucket."""
    max_len = wrapper.max_seq_len

    # Adaptive buckets based on max_seq_len: drop bounds beyond max_len,
    # dedupe, and make sure the last bound reaches max_len.
    bucket_bounds = sorted(set(b for b in [0, 64, 128, 256, 512] if b <= max_len))
    if bucket_bounds[-1] < max_len:
        bucket_bounds.append(max_len)
    buckets = [(bucket_bounds[i], bucket_bounds[i + 1])
               for i in range(len(bucket_bounds) - 1)]

    # Accumulate per-position losses
    all_losses = []
    for seq in sequences:
        input_ids = seq.unsqueeze(0).to(wrapper.device)
        logits = wrapper.forward(input_ids)

        shift_logits = logits[0, :-1]
        shift_labels = input_ids[0, 1:]
        per_token_loss = F.cross_entropy(shift_logits, shift_labels, reduction='none')
        all_losses.append(per_token_loss.cpu())
        print(f" [{len(all_losses)}/{len(sequences)}]", end="\r")
    print()

    # Compute per-bucket stats
    stacked = torch.stack(all_losses)  # [N, L-1]
    bucket_results = {}
    for start, end in buckets:
        s = min(start, stacked.shape[1])
        e = min(end, stacked.shape[1])
        if s >= e:
            continue
        bucket_losses = stacked[:, s:e]
        avg_loss = bucket_losses.mean().item()
        bucket_results[f"{start}-{end}"] = {
            "loss": avg_loss,
            "ppl": math.exp(min(avg_loss, 20)),  # cap to avoid overflow
            "n_tokens": bucket_losses.numel(),
        }

    return bucket_results

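The bucket-bounds construction above can be traced in isolation; for a model with a 256-token context it yields three depth buckets:

```python
# Same bucket logic as eval_context_utilization, run standalone.
max_len = 256
bucket_bounds = sorted(set(b for b in [0, 64, 128, 256, 512] if b <= max_len))
if bucket_bounds[-1] < max_len:
    bucket_bounds.append(max_len)
buckets = [(bucket_bounds[i], bucket_bounds[i + 1])
           for i in range(len(bucket_bounds) - 1)]
print(buckets)  # → [(0, 64), (64, 128), (128, 256)]
```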
def eval_skip_accuracy(wrapper, sequences, distances):
    """Tier 2b: Skip head prediction accuracy at various distances."""
    if not wrapper.has_skip_head:
        return None

    results = {f"t+{K}": {"top1": [], "top5": []} for K in distances}

    for seq in sequences:
        input_ids = seq.unsqueeze(0).to(wrapper.device)
        _, hidden, _ = wrapper.forward_with_hidden(input_ids)
        skip_logits = wrapper.skip_head(hidden)  # [1, L, V]

        for K in distances:
            if K >= input_ids.shape[1]:
                continue

            targets = input_ids[0, K:]  # tokens at t+K
            preds = skip_logits[0, :-K]  # predictions from position t

            top1 = (preds.argmax(-1) == targets).float().mean().item()
            top5_indices = preds.topk(min(5, preds.shape[-1]), dim=-1).indices
            top5 = (top5_indices == targets.unsqueeze(-1)).any(-1).float().mean().item()

            results[f"t+{K}"]["top1"].append(top1)
            results[f"t+{K}"]["top5"].append(top5)

        print(f" [{len(results['t+' + str(distances[0])]['top1'])}/{len(sequences)}]", end="\r")
    print()

    # Average across sequences
    avg_results = {}
    for key in sorted(results.keys(), key=lambda x: int(x.split("+")[1])):
        vals = results[key]
        if vals["top1"]:
            avg_results[key] = {
                "top1": sum(vals["top1"]) / len(vals["top1"]),
                "top5": sum(vals["top5"]) / len(vals["top5"]),
            }

    return avg_results

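The top-1 comparison above (`argmax` vs. target, averaged over positions) reduces to the following with plain lists standing in for logit tensors:

```python
# Top-1 accuracy from raw "logit" rows, mirroring the argmax/mean step
# in eval_skip_accuracy (two positions, vocabulary of size 3).
logits = [[0.1, 0.9, 0.0],   # argmax = 1, target = 1  -> hit
          [0.8, 0.1, 0.1]]   # argmax = 0, target = 2  -> miss
targets = [1, 2]
top1 = sum(row.index(max(row)) == t
           for row, t in zip(logits, targets)) / len(targets)
print(top1)  # → 0.5
```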
def eval_structural(wrapper, eval_data, distances, num_sequences):
    """Run Tier 2 evaluation."""
    print("\n" + "=" * 60)
    print("TIER 2: Structural Prediction")
    print("=" * 60)

    sequences = prepare_eval_sequences(wrapper, num_sequences, eval_data)
    if sequences is None:
        return {"context_utilization": None, "skip_accuracy": None}

    # 2a: Context utilization
    print("\n --- 2a: Context Utilization (PPL by position depth) ---")
    ctx_results = eval_context_utilization(wrapper, sequences)

    if ctx_results:
        print(f"\n {'Depth':<12} {'Loss':>8} {'PPL':>10} {'Tokens':>10}")
        print(f" {'-' * 42}")
        for bucket, vals in ctx_results.items():
            print(f" {bucket:<12} {vals['loss']:>8.3f} {vals['ppl']:>10.2f} {vals['n_tokens']:>10}")

        buckets_list = list(ctx_results.values())
        if len(buckets_list) >= 2:
            ratio = buckets_list[0]["ppl"] / buckets_list[-1]["ppl"]
            print(f"\n Context utilization ratio (first/last): {ratio:.2f}x")
            print(" (Higher = model benefits more from additional context)")

    # 2b: Skip accuracy
    skip_results = None
    if wrapper.has_skip_head:
        print(f"\n --- 2b: Skip Head Accuracy (trained for t+{wrapper.skip_k}) ---")
        skip_results = eval_skip_accuracy(wrapper, sequences, distances)

        if skip_results:
            print(f"\n {'Distance':<12} {'Top-1':>8} {'Top-5':>8}")
            print(f" {'-' * 30}")
            for key, vals in skip_results.items():
                trained = " *" if int(key.split("+")[1]) == wrapper.skip_k else ""
                print(f" {key:<12} {vals['top1']:>8.4f} {vals['top5']:>8.4f}{trained}")
            print("\n * = trained distance")
    else:
        print("\n Skip head: not available")

    return {"context_utilization": ctx_results, "skip_accuracy": skip_results}

# ──────────────────────────────────────────────────────────────────────
# Tier 3: Semantic Consistency
# ──────────────────────────────────────────────────────────────────────

def compute_chunk_similarity(hidden_states, chunk_size=128):
    """Compute cosine similarity between chunks of hidden states.

    hidden_states: [L, D] tensor.
    """
    L, D = hidden_states.shape
    n_chunks = L // chunk_size

    if n_chunks < 2:
        return None

    # Mean-pool each chunk
    chunks = []
    for i in range(n_chunks):
        chunk = hidden_states[i * chunk_size:(i + 1) * chunk_size]
        chunks.append(chunk.mean(dim=0))

    chunk_vecs = torch.stack(chunks)
    chunk_vecs = F.normalize(chunk_vecs, dim=-1)

    # Pairwise cosine similarity
    sim_matrix = chunk_vecs @ chunk_vecs.T

    # Upper triangle (excluding diagonal)
    mask = torch.triu(torch.ones_like(sim_matrix, dtype=torch.bool), diagonal=1)
    pairwise_sims = sim_matrix[mask]

    # Adjacent pairs
    adjacent = [sim_matrix[i, i + 1].item() for i in range(n_chunks - 1)]

    # Distant pairs (first quarter vs last quarter)
    q1 = max(1, n_chunks // 4)
    distant = []
    for i in range(q1):
        for j in range(n_chunks - q1, n_chunks):
            if i < j:
                distant.append(sim_matrix[i, j].item())

    return {
        "mean_sim": pairwise_sims.mean().item(),
        "min_sim": pairwise_sims.min().item(),
        "adjacent_sim": sum(adjacent) / len(adjacent),
        "distant_sim": sum(distant) / len(distant) if distant else 0.0,
        "n_chunks": n_chunks,
    }

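The `F.normalize` + matmul step above is just pairwise cosine similarity; a pure-Python stand-in for two mean-pooled chunk vectors makes the metric concrete:

```python
import math

def cosine(u, v):
    # Cosine similarity of two vectors, the per-pair equivalent of the
    # normalize-then-matmul computation in compute_chunk_similarity.
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

# Parallel chunk vectors score 1.0; orthogonal ones score 0.0.
print(round(cosine([1.0, 2.0], [2.0, 4.0]), 3))  # → 1.0
print(round(cosine([1.0, 0.0], [0.0, 1.0]), 3))  # → 0.0
```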
def eval_consistency(wrapper, generations, chunk_size=128):
    """Tier 3: Semantic consistency of generated text via hidden state similarity."""
    print("\n" + "=" * 60)
    print("TIER 3: Semantic Consistency")
    print("=" * 60)

    all_metrics = []

    for i, (prompt, prompt_ids, gen_ids) in enumerate(generations):
        if gen_ids.shape[0] < chunk_size * 2:
            continue

        # Build full sequence: prompt + generated
        full_ids = torch.cat([prompt_ids, gen_ids]).unsqueeze(0).to(wrapper.device)

        # Trim to max_seq_len
        if full_ids.shape[1] > wrapper.max_seq_len:
            full_ids = full_ids[:, :wrapper.max_seq_len]

        _, hidden, _ = wrapper.forward_with_hidden(full_ids)

        # Use only generated part's hidden states
        gen_start = prompt_ids.shape[0]
        gen_hidden = hidden[0, gen_start:]  # [gen_len, D]

        metrics = compute_chunk_similarity(gen_hidden, chunk_size)
        if metrics is not None:
            metrics["prompt"] = prompt
            all_metrics.append(metrics)

        print(f" [{len(all_metrics)}/{len(generations)}]", end="\r")
    print()

    if not all_metrics:
        print(" No valid generations for consistency evaluation.")
        return {}

    n = len(all_metrics)
    agg = {}
    for key in ["mean_sim", "min_sim", "adjacent_sim", "distant_sim"]:
        values = [m[key] for m in all_metrics]
        agg[key] = {"mean": sum(values) / n, "min": min(values), "max": max(values)}

    # Topic drift: how much similarity drops from adjacent to distant chunks
    drift_vals = [m["adjacent_sim"] - m["distant_sim"] for m in all_metrics]
    agg["topic_drift"] = {"mean": sum(drift_vals) / n,
                          "min": min(drift_vals), "max": max(drift_vals)}

    # Print
    print(f"\n Generations evaluated: {n}")
    print(f" Chunk size: {chunk_size} tokens")
    avg_chunks = sum(m["n_chunks"] for m in all_metrics) / n
    print(f" Avg chunks per generation: {avg_chunks:.1f}")
    print()
    print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
    print(f" {'-' * 50}")
    for key in ["mean_sim", "min_sim", "adjacent_sim", "distant_sim", "topic_drift"]:
        m = agg[key]
        print(f" {key:<24} {m['mean']:>8.3f} {m['min']:>8.3f} {m['max']:>8.3f}")

    return {"per_prompt": all_metrics, "aggregate": agg}

# ──────────────────────────────────────────────────────────────────────
# Summary
# ──────────────────────────────────────────────────────────────────────

def print_summary(results):
    """Print composite summary scores."""
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    scores = {}

    # Diversity score: mean unique-4gram
    t1 = results.get("tier1_diversity", {})
    if t1 and "aggregate" in t1:
        div_score = t1["aggregate"].get("unique_4g", {}).get("mean", None)
        collapse = t1["aggregate"].get("collapse_rate", None)
        if div_score is not None:
            scores["diversity"] = div_score
            print(f" Diversity (unique 4-gram): {div_score:.3f}", end="")
            if collapse is not None:
                print(f" (collapse: {collapse:.0%})", end="")
            print()

    # Context utilization ratio
    t2 = results.get("tier2_structural", {})
    if t2:
        ctx = t2.get("context_utilization")
        if ctx:
            buckets = list(ctx.values())
            if len(buckets) >= 2:
                ratio = buckets[0]["ppl"] / buckets[-1]["ppl"]
                scores["context_util"] = ratio
                print(f" Context utilization: {ratio:.2f}x")

        skip = t2.get("skip_accuracy")
        if skip:
            # Report accuracy at the first (shortest) available distance
            trained_key = next(iter(skip))
            top5 = skip[trained_key]["top5"]
            scores["skip_top5"] = top5
            print(f" Skip accuracy ({trained_key} top-5): {top5:.4f}")

    # Coherence score: mean chunk similarity
    t3 = results.get("tier3_consistency", {})
    if t3 and "aggregate" in t3:
        coh_score = t3["aggregate"].get("mean_sim", {}).get("mean", None)
        drift = t3["aggregate"].get("topic_drift", {}).get("mean", None)
        if coh_score is not None:
            scores["coherence"] = coh_score
            print(f" Coherence (chunk sim): {coh_score:.3f}", end="")
            if drift is not None:
                print(f" (drift: {drift:.3f})", end="")
            print()

    results["summary"] = scores
    return scores

# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────

def parse_args():
    parser = argparse.ArgumentParser(
        description="Coherence evaluation for language models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model source (mutually exclusive)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--checkpoint", type=str, help="Path to circuit model checkpoint")
    group.add_argument("--model", type=str, help="HuggingFace model name or path")

    # Evaluation config
    parser.add_argument("--prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--num-prompts", type=int, default=20,
                        help="Number of prompts to use (default: 20)")
    parser.add_argument("--gen-length", type=int, default=512,
                        help="Tokens to generate per prompt (default: 512)")
    parser.add_argument("--eval-data", type=str,
                        help="Text file for Tier 2 (default: WikiText-103 validation)")
    parser.add_argument("--num-sequences", type=int, default=50,
                        help="Number of sequences for Tier 2 (default: 50)")
    parser.add_argument("--chunk-size", type=int, default=128,
                        help="Chunk size for Tier 3 similarity (default: 128)")
    parser.add_argument("--distances", type=str, default="2,5,10,25,50,100",
                        help="Skip distances for Tier 2b (default: 2,5,10,25,50,100)")
    parser.add_argument("--tiers", type=str, default="1,2,3",
                        help="Which tiers to run (default: 1,2,3)")

    # Hardware
    parser.add_argument("--gpu", type=int, default=0, help="GPU index (default: 0)")

    # Output
    parser.add_argument("--output", type=str, help="Save results to JSON file")
    parser.add_argument("--samples", type=int, default=3,
                        help="Number of sample generations to display (default: 3)")

    return parser.parse_args()

def main():
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    tiers = [int(t) for t in args.tiers.split(",")]
    distances = [int(d) for d in args.distances.split(",")]

    # Load model
    print("=" * 60)
    print("Coherence Evaluation")
    print("=" * 60)

    if args.checkpoint:
        print(f"Loading: {args.checkpoint}")
        wrapper = ModelWrapper.from_checkpoint(args.checkpoint, device)
    else:
        print(f"Loading: {args.model}")
        wrapper = ModelWrapper.from_pretrained(args.model, device)

    print(f"Model: {wrapper.name}")
    print(f"Device: {device}")
    print(f"Max seq len: {wrapper.max_seq_len}")
    if wrapper.has_skip_head:
        print(f"Skip head: t+{wrapper.skip_k}")
    print(f"Tiers: {tiers}")

    # Load prompts
    if args.prompts:
        with open(args.prompts) as f:
            prompts = [line.strip() for line in f if line.strip()]
    else:
        prompts = DEFAULT_PROMPTS
    prompts = prompts[:args.num_prompts]
    print(f"Prompts: {len(prompts)}")

    results = {"model": wrapper.name}
    t0 = time.time()

    # Generate once for Tier 1 and Tier 3
    generations = None
    if 1 in tiers or 3 in tiers:
        print(f"\nGenerating {args.gen_length} tokens from {len(prompts)} prompts...")
        generations = generate_all(wrapper, prompts, args.gen_length)

    # Tier 1
    if 1 in tiers and generations:
        results["tier1_diversity"] = eval_diversity(
            generations, wrapper.tokenizer, show_samples=args.samples)

    # Tier 2
    if 2 in tiers:
        results["tier2_structural"] = eval_structural(
            wrapper, args.eval_data, distances, args.num_sequences)

    # Tier 3
    if 3 in tiers and generations:
        results["tier3_consistency"] = eval_consistency(
            wrapper, generations, args.chunk_size)

    # Summary
    print_summary(results)

    elapsed = time.time() - t0
    print(f"\nTotal time: {elapsed:.0f}s")

    # Save
    if args.output:
        def make_serializable(obj):
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [make_serializable(v) for v in obj]
            elif isinstance(obj, torch.Tensor):
                return obj.tolist()
            elif isinstance(obj, float):
                if math.isnan(obj) or math.isinf(obj):
                    return str(obj)
            return obj

        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump(make_serializable(results), f, indent=2)
        print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()
config.py
ADDED
@@ -0,0 +1,306 @@
"""
|
| 2 |
+
Configuration for Circuit Transformer experiments.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class CircuitConfig:
|
| 11 |
+
"""Configuration for CircuitTransformer model and training."""
|
| 12 |
+
|
| 13 |
+
# Model architecture
|
| 14 |
+
vocab_size: int = 50257 # GPT-2 tokenizer
|
| 15 |
+
hidden_size: int = 256
|
| 16 |
+
num_heads: int = 8
|
| 17 |
+
num_kv_heads: int | None = None # GQA: None = same as num_heads (MHA)
|
| 18 |
+
num_layers: int = 6
|
| 19 |
+
max_seq_len: int = 512
|
| 20 |
+
dropout: float = 0.0
|
| 21 |
+
|
| 22 |
+
# Training
|
| 23 |
+
batch_size: int = 32
|
| 24 |
+
learning_rate: float = 3e-4
|
| 25 |
+
min_lr: float = 0.0
|
| 26 |
+
weight_decay: float = 0.1
|
| 27 |
+
warmup_steps: int = 100
|
| 28 |
+
epochs: int = 10
|
| 29 |
+
grad_clip: float = 1.0
|
| 30 |
+
reset: bool = False
|
| 31 |
+
|
| 32 |
+
# Hardware
|
| 33 |
+
gpu: int = 0
|
| 34 |
+
fp16: bool = True
|
| 35 |
+
bf16: bool = False
|
| 36 |
+
compile: bool = False
|
| 37 |
+
|
| 38 |
+
# Logging
|
| 39 |
+
log_every: int = 50
|
| 40 |
+
save_every: int = 5000
|
| 41 |
+
checkpoint_dir: str = "./circuits/checkpoints"
|
| 42 |
+
|
| 43 |
+
def __post_init__(self):
|
| 44 |
+
assert self.hidden_size % self.num_heads == 0, \
|
| 45 |
+
f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
|
| 46 |
+
if self.num_kv_heads is not None:
|
| 47 |
+
assert self.num_heads % self.num_kv_heads == 0, \
|
| 48 |
+
f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
|
| 49 |
+
|
| 50 |
+
# Presets
|
| 51 |
+
@classmethod
|
| 52 |
+
def tiny(cls) -> "CircuitConfig":
|
| 53 |
+
"""~2M params"""
|
| 54 |
+
return cls(hidden_size=128, num_heads=4, num_layers=4)
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def small(cls) -> "CircuitConfig":
|
| 58 |
+
"""~10M params"""
|
| 59 |
+
return cls(hidden_size=256, num_heads=8, num_layers=6)
|
| 60 |
+
|
| 61 |
+
@classmethod
|
| 62 |
+
def medium(cls) -> "CircuitConfig":
|
| 63 |
+
"""~50M params"""
|
| 64 |
+
return cls(hidden_size=512, num_heads=8, num_layers=12)
|
| 65 |
+
|
| 66 |
+
@classmethod
|
| 67 |
+
def medium_plus(cls) -> "CircuitConfig":
|
| 68 |
+
"""~50M params"""
|
| 69 |
+
return cls(hidden_size=512, num_heads=8, num_layers=15)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def medium_wide_9(cls) -> "CircuitConfig":
|
| 74 |
+
"""~50M params"""
|
| 75 |
+
return cls(hidden_size=640, num_heads=10, num_layers=9)
|
| 76 |
+
|
| 77 |
+
@classmethod
|
| 78 |
+
def medium_wide_11(cls) -> "CircuitConfig":
|
| 79 |
+
"""~50M params"""
|
| 80 |
+
return cls(hidden_size=640, num_heads=10, num_layers=11)
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def medium_large(cls) -> "CircuitConfig":
|
| 84 |
+
"""~90M params"""
|
| 85 |
+
return cls(hidden_size=768, num_heads=12, num_layers=12)
|
| 86 |
+
|
| 87 |
+
@classmethod
|
| 88 |
+
def large(cls) -> "CircuitConfig":
|
| 89 |
+
return cls(hidden_size=1280, num_heads=20, num_layers=11)
|
| 90 |
+
|
| 91 |
+
# Auxiliary objectives
|
| 92 |
+
aux_skip_k: int = 0 # skip-ahead prediction distance (0 = disabled)
|
| 93 |
+
aux_skip_weight: float = 0.1 # weight for auxiliary skip loss
|
| 94 |
+
|
| 95 |
+
# Word-position RoPE (SemRoPE)
|
| 96 |
+
word_rope_dims: int = 0 # head dims for word-position RoPE (0 = disabled)
|
| 97 |
+
word_rope_base: float = 10.0 # frequency base for word-position RoPE
|
| 98 |
+
|
| 99 |
+
# Factorized embedding / MLP head
|
| 100 |
+
embed_dim: int = 0 # factorized embedding dim (0 = use hidden_size)
|
| 101 |
+
head_dim: int = 0 # MLP head intermediate dim (0 = linear head)
|
| 102 |
+
|
| 103 |
+
def to_dict(self) -> dict:
|
| 104 |
+
"""Convert to dictionary for serialization."""
|
| 105 |
+
d = {
|
| 106 |
+
"vocab_size": self.vocab_size,
|
| 107 |
+
"hidden_size": self.hidden_size,
|
| 108 |
+
"num_heads": self.num_heads,
|
| 109 |
+
"num_layers": self.num_layers,
|
| 110 |
+
"max_seq_len": self.max_seq_len,
|
| 111 |
+
"dropout": self.dropout,
|
| 112 |
+
}
|
| 113 |
+
if self.num_kv_heads is not None:
|
| 114 |
+
d["num_kv_heads"] = self.num_kv_heads
|
| 115 |
+
if self.aux_skip_k > 0:
|
| 116 |
+
d["aux_skip_k"] = self.aux_skip_k
|
| 117 |
+
d["aux_skip_weight"] = self.aux_skip_weight
|
| 118 |
+
if self.word_rope_dims > 0:
|
| 119 |
+
d["word_rope_dims"] = self.word_rope_dims
|
| 120 |
+
d["word_rope_base"] = self.word_rope_base
|
| 121 |
+
if self.embed_dim > 0:
|
| 122 |
+
d["embed_dim"] = self.embed_dim
|
| 123 |
+
if self.head_dim > 0:
|
| 124 |
+
d["head_dim"] = self.head_dim
|
| 125 |
+
return d
|
| 126 |
+
|
| 127 |
+
@classmethod
|
| 128 |
+
def from_dict(cls, d: dict) -> "CircuitConfig":
|
| 129 |
+
"""Create from dictionary."""
|
| 130 |
+
return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
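The `from_dict` filter (`k in cls.__dataclass_fields__`) lets checkpoints carry extra keys without breaking config loading. A tiny stand-in dataclass (not the real `CircuitConfig`, which needs the rest of this file) demonstrates the round trip:

```python
from dataclasses import dataclass, fields

@dataclass
class MiniConfig:
    # Hypothetical stand-in for CircuitConfig with two of its fields.
    hidden_size: int = 256
    num_heads: int = 8

    def to_dict(self):
        return {f.name: getattr(self, f.name) for f in fields(self)}

    @classmethod
    def from_dict(cls, d):
        # Ignore unknown keys, as CircuitConfig.from_dict does.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in names})

cfg = MiniConfig.from_dict({"hidden_size": 512, "extra_key": 1})
print(cfg.to_dict())  # → {'hidden_size': 512, 'num_heads': 8}
```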
def parse_args() -> tuple[CircuitConfig, argparse.Namespace]:
    """Parse CLI arguments and return config + extra args."""
    parser = argparse.ArgumentParser(description="Circuit Transformer Training")

    # Data
    parser.add_argument("--data", type=str, required=True,
                        help="Data source: path/to/file.txt, path/to/dir/, or hf:dataset_name")
    parser.add_argument("--text-column", type=str, default="text",
                        help="Column name for HF datasets (default: text)")
    parser.add_argument("--data-format", type=str, choices=["text", "chat"], default="text",
                        help="Data format: text (single column) or chat (system + conversations)")
    parser.add_argument("--num-samples", type=int, default=None,
                        help="Limit samples from HF dataset")
    parser.add_argument("--cache-dir", type=str, default="./circuits/.cache",
                        help="Cache directory for tokenized data")
    parser.add_argument("--no-cache", action="store_true",
                        help="Disable data caching")
    parser.add_argument("--val-split", type=float, default=0.05,
                        help="Fraction of data for validation (default: 0.05, 0 to disable)")

    # Model architecture
    # TODO: Remove `slot_mirrored`
    parser.add_argument("--arch", type=str, choices=["standard", "mirrored", "graft_g2lu"], default="standard",
                        help="Model architecture (default: standard)")
    parser.add_argument("--preset", type=str, choices=["tiny", "small", "medium", "medium_plus", "medium_large", "medium_wide_9", "medium_wide_11", "large"],
                        help="Use preset configuration")
    parser.add_argument("--dims", type=int, default=None, help="Hidden size")
    parser.add_argument("--layers", type=int, default=None, help="Number of layers")
    parser.add_argument("--heads", type=int, default=None, help="Number of attention heads")
    parser.add_argument("--kv-heads", type=int, default=None,
                        help="Number of KV heads for GQA (default: same as --heads for MHA)")
    parser.add_argument("--context-length", type=int, default=None, help="Max sequence length")
    parser.add_argument("--dropout", type=float, default=None, help="Dropout rate")
    parser.add_argument("--tokenizer", type=str, default="gpt2",
                        help="Tokenizer to use (default: gpt2, e.g. facebook/MobileLLM-125M)")

    # Mirrored architecture specific
    parser.add_argument("--n-middle", type=int, default=2,
                        help="Unique middle layers for mirrored arch (default: 2)")
    parser.add_argument("--share-attention", action="store_true", default=True,
                        help="Share attention weights between mirror pairs (default)")
    parser.add_argument("--no-share-attention", dest="share_attention", action="store_false",
                        help="Separate attention weights per direction")

    # G²LU gating
    parser.add_argument("--no-g2lu", action="store_true",
                        help="Disable G²LU (use vanilla SwiGLU in mirrored arch)")

    # Auxiliary objectives
    parser.add_argument("--aux-skip", type=int, default=0,
                        help="Skip-ahead prediction distance (0 = disabled, e.g. 5 predicts t+5)")
    parser.add_argument("--aux-weight", type=float, default=0.1,
                        help="Weight for auxiliary skip loss (default: 0.1)")

    # Word-position RoPE (SemRoPE)
    parser.add_argument("--word-rope-dims", type=int, default=0,
                        help="Head dims dedicated to word-position RoPE (0=disabled, try 8 or 16)")
    parser.add_argument("--word-rope-base", type=float, default=10.0,
                        help="Frequency base for word-position RoPE (default: 10.0)")

    # Factorized embedding / MLP head
    parser.add_argument("--embed-dim", type=int, default=0,
                        help="Factorized embedding dim (0=use hidden_size, e.g. 256)")
    parser.add_argument("--head-dim", type=int, default=0,
|
| 197 |
+
help="MLP head intermediate dim (0=linear head, e.g. 512)")
|
| 198 |
+
|
| 199 |
+
# G²LU gate grafting
|
| 200 |
+
parser.add_argument("--pretrained", type=str, default=None,
|
| 201 |
+
help="HuggingFace model for graft_g2lu (e.g. meta-llama/Llama-3.2-1B)")
|
| 202 |
+
parser.add_argument("--align-weight", type=float, default=1.0,
|
| 203 |
+
help="Alignment loss weight for G²LU grafting (default: 1.0)")
|
| 204 |
+
parser.add_argument("--graft-warmup", type=int, default=500,
|
| 205 |
+
help="Blend warmup steps: SwiGLU→G²LU transition (default: 500)")
|
| 206 |
+
|
| 207 |
+
# Training
|
| 208 |
+
parser.add_argument("--epochs", type=int, default=None)
|
| 209 |
+
parser.add_argument("--batch-size", type=int, default=None)
|
| 210 |
+
parser.add_argument("--lr", type=float, default=None, help="Learning rate")
|
| 211 |
+
parser.add_argument("--min-lr", type=float, default=None,
|
| 212 |
+
help="Minimum learning rate for cosine decay (default: 0)")
|
| 213 |
+
parser.add_argument("--weight-decay", type=float, default=None)
|
| 214 |
+
parser.add_argument("--warmup-steps", type=int, default=None)
|
| 215 |
+
parser.add_argument("--grad-clip", type=float, default=None)
|
| 216 |
+
parser.add_argument("--grad-accum", type=int, default=1,
|
| 217 |
+
help="Gradient accumulation steps (effective batch = batch_size * grad_accum)")
|
| 218 |
+
|
| 219 |
+
# Hardware
|
| 220 |
+
parser.add_argument("--gpu", type=int, default=0)
|
| 221 |
+
parser.add_argument("--fp16", action="store_true", help="Use FP16 mixed precision (with GradScaler)")
|
| 222 |
+
parser.add_argument("--bf16", action="store_true", help="Use BF16 mixed precision (no scaler needed)")
|
| 223 |
+
parser.add_argument("--no-fp16", action="store_true", help="Disable mixed precision (FP32)")
|
| 224 |
+
parser.add_argument("--compile", action="store_true", help="Use torch.compile")
|
| 225 |
+
|
| 226 |
+
# Logging/Checkpointing
|
| 227 |
+
parser.add_argument("--log-every", type=int, default=None)
|
| 228 |
+
parser.add_argument("--save-every", type=int, default=None)
|
| 229 |
+
parser.add_argument("--val-every", type=int, default=0,
|
| 230 |
+
help="Run validation every N steps (0 = only at epoch end)")
|
| 231 |
+
parser.add_argument("--checkpoint-dir", type=str, default=None)
|
| 232 |
+
parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
|
| 233 |
+
parser.add_argument("--reset", action="store_true", default=False, help="When resuming the training resets steps and optimizers")
|
| 234 |
+
|
| 235 |
+
args = parser.parse_args()
|
| 236 |
+
|
| 237 |
+
# Build config from preset or defaults
|
| 238 |
+
if args.preset:
|
| 239 |
+
config = getattr(CircuitConfig, args.preset)()
|
| 240 |
+
else:
|
| 241 |
+
config = CircuitConfig()
|
| 242 |
+
|
| 243 |
+
# Override with explicit args
|
| 244 |
+
if args.dims is not None:
|
| 245 |
+
config.hidden_size = args.dims
|
| 246 |
+
if args.layers is not None:
|
| 247 |
+
config.num_layers = args.layers
|
| 248 |
+
if args.heads is not None:
|
| 249 |
+
config.num_heads = args.heads
|
| 250 |
+
if args.kv_heads is not None:
|
| 251 |
+
config.num_kv_heads = args.kv_heads
|
| 252 |
+
if args.context_length is not None:
|
| 253 |
+
config.max_seq_len = args.context_length
|
| 254 |
+
if args.dropout is not None:
|
| 255 |
+
config.dropout = args.dropout
|
| 256 |
+
if args.epochs is not None:
|
| 257 |
+
config.epochs = args.epochs
|
| 258 |
+
if args.batch_size is not None:
|
| 259 |
+
config.batch_size = args.batch_size
|
| 260 |
+
if args.lr is not None:
|
| 261 |
+
config.learning_rate = args.lr
|
| 262 |
+
if args.min_lr is not None:
|
| 263 |
+
config.min_lr = args.min_lr
|
| 264 |
+
if args.weight_decay is not None:
|
| 265 |
+
config.weight_decay = args.weight_decay
|
| 266 |
+
if args.warmup_steps is not None:
|
| 267 |
+
config.warmup_steps = args.warmup_steps
|
| 268 |
+
if args.grad_clip is not None:
|
| 269 |
+
config.grad_clip = args.grad_clip
|
| 270 |
+
if args.log_every is not None:
|
| 271 |
+
config.log_every = args.log_every
|
| 272 |
+
if args.save_every is not None:
|
| 273 |
+
config.save_every = args.save_every
|
| 274 |
+
if args.checkpoint_dir is not None:
|
| 275 |
+
config.checkpoint_dir = args.checkpoint_dir
|
| 276 |
+
|
| 277 |
+
# Auxiliary objectives
|
| 278 |
+
if args.aux_skip > 0:
|
| 279 |
+
config.aux_skip_k = args.aux_skip
|
| 280 |
+
config.aux_skip_weight = args.aux_weight
|
| 281 |
+
|
| 282 |
+
# Word-position RoPE
|
| 283 |
+
if args.word_rope_dims > 0:
|
| 284 |
+
config.word_rope_dims = args.word_rope_dims
|
| 285 |
+
config.word_rope_base = args.word_rope_base
|
| 286 |
+
|
| 287 |
+
# Factorized embedding / MLP head
|
| 288 |
+
if args.embed_dim > 0:
|
| 289 |
+
config.embed_dim = args.embed_dim
|
| 290 |
+
if args.head_dim > 0:
|
| 291 |
+
config.head_dim = args.head_dim
|
| 292 |
+
|
| 293 |
+
config.gpu = args.gpu
|
| 294 |
+
if args.bf16:
|
| 295 |
+
config.bf16 = True
|
| 296 |
+
config.fp16 = False
|
| 297 |
+
elif args.no_fp16:
|
| 298 |
+
config.fp16 = False
|
| 299 |
+
config.bf16 = False
|
| 300 |
+
elif args.fp16:
|
| 301 |
+
config.fp16 = True
|
| 302 |
+
config.bf16 = False
|
| 303 |
+
config.compile = args.compile
|
| 304 |
+
config.reset = args.reset
|
| 305 |
+
|
| 306 |
+
return config, args
|
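The preset dispatch (`getattr(CircuitConfig, args.preset)()`) assumes `CircuitConfig` exposes one classmethod per preset name. A minimal sketch of that pattern — the fields and preset values here are illustrative, not the actual ones from config.py:

```python
from dataclasses import dataclass

@dataclass
class CircuitConfig:
    # Illustrative fields only; the real config.py defines many more.
    hidden_size: int = 256
    num_layers: int = 4

    @classmethod
    def tiny(cls):
        return cls(hidden_size=128, num_layers=2)

    @classmethod
    def large(cls):
        return cls(hidden_size=1024, num_layers=24)

# --preset tiny resolves through getattr, exactly as in the snippet above
config = getattr(CircuitConfig, "tiny")()
print(config.hidden_size)  # 128
```

Any string accepted by `--preset` must therefore match a classmethod name, which is why the argparse `choices` list and the config class have to stay in sync.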
data.py
ADDED
@@ -0,0 +1,546 @@
"""
Data loading utilities for Circuit Transformer.

Supports:
- Single text file: --data path/to/file.txt
- Directory of text files: --data path/to/dir/
- HuggingFace dataset: --data hf:dataset_name

Caching:
- HF datasets: memory-mapped binary files (.bin) — O(1) RAM
- Text files: torch .pt files (legacy, in-memory)
- Cache location: ./circuits/.cache/ (or custom via cache_dir)

Parallelism:
- HF datasets tokenized via dataset.map(num_proc=N) — multiprocessing, bypasses GIL
- Fast tokenizer uses Rust internally — additional parallelism within each worker
"""

import os
import struct
import hashlib
import multiprocessing
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

DEFAULT_CACHE_DIR = "./circuits/.cache"

# Memmap binary format:
# Header: 8 bytes = [uint32 n_chunks, uint32 max_seq_len]
# Data: n_chunks * max_seq_len * 4 bytes (int32, row-major)
HEADER_SIZE = 8


# ---------------------------------------------------------------------------
# Cache utilities
# ---------------------------------------------------------------------------

def _cache_key(data_source: str, max_seq_len: int, num_samples: int | None) -> str:
    """Generate cache filename from parameters."""
    key_str = f"{data_source}|{max_seq_len}|{num_samples}"
    hash_val = hashlib.md5(key_str.encode()).hexdigest()[:12]
    name = data_source.replace("/", "_").replace(":", "_").replace(".", "_")[-30:]
    return f"{name}_{max_seq_len}_{hash_val}.bin"
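To see the naming scheme `_cache_key` produces, here is the same recipe rerun standalone on an example source string (the dataset identifier is just illustrative):

```python
import hashlib

def cache_key(data_source, max_seq_len, num_samples):
    # Same recipe as _cache_key: md5 of the parameter triple for uniqueness,
    # plus a sanitized tail of the source name for human readability.
    key_str = f"{data_source}|{max_seq_len}|{num_samples}"
    hash_val = hashlib.md5(key_str.encode()).hexdigest()[:12]
    name = data_source.replace("/", "_").replace(":", "_").replace(".", "_")[-30:]
    return f"{name}_{max_seq_len}_{hash_val}.bin"

key = cache_key("hf:Bingsu/openwebtext_20p", 1024, None)
# → "hf_Bingsu_openwebtext_20p_1024_<12 hex chars>.bin"
```

Because the hash covers source, sequence length, and sample limit, changing any of the three parameters produces a distinct cache file rather than silently reusing a stale one.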
# ---------------------------------------------------------------------------
# Dataset classes
# ---------------------------------------------------------------------------

class MemmapDataset(Dataset):
    """Dataset backed by memory-mapped binary file. O(1) RAM regardless of size."""

    def __init__(self, path, start=None, end=None):
        self.path = str(path)
        with open(self.path, 'rb') as f:
            total, self.max_seq_len = struct.unpack('II', f.read(HEADER_SIZE))
        self._total = total
        self.data = np.memmap(
            self.path, dtype=np.int32, mode='r',
            offset=HEADER_SIZE, shape=(total, self.max_seq_len),
        )
        self.start = start if start is not None else 0
        self.end = end if end is not None else total

    def __len__(self):
        return self.end - self.start

    def __getitem__(self, idx):
        tokens = torch.from_numpy(self.data[self.start + idx].copy()).long()
        return {"input_ids": tokens, "labels": tokens.clone()}

    def split(self, val_fraction=0.1):
        """Split into (train, val) datasets. Both share the same memmap file."""
        total = self.end - self.start
        n_val = max(1, int(total * val_fraction))
        train = MemmapDataset(self.path, self.start, self.end - n_val)
        val = MemmapDataset(self.path, self.end - n_val, self.end)
        return train, val
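The binary layout that `MemmapDataset` reads (8-byte `[uint32 n_chunks, uint32 max_seq_len]` header, then row-major int32 tokens) can be exercised end to end with a throwaway file; the shapes here are arbitrary:

```python
import os
import struct
import tempfile

import numpy as np

HEADER_SIZE = 8
n_chunks, max_seq_len = 3, 16

path = os.path.join(tempfile.mkdtemp(), "toy.bin")
data = np.arange(n_chunks * max_seq_len, dtype=np.int32).reshape(n_chunks, max_seq_len)

# Write: [uint32 n_chunks, uint32 max_seq_len] header, then row-major int32 tokens
with open(path, "wb") as f:
    f.write(struct.pack("II", n_chunks, max_seq_len))
    f.write(data.tobytes())

# Read back through the same memmap recipe MemmapDataset uses
with open(path, "rb") as f:
    total, seq_len = struct.unpack("II", f.read(HEADER_SIZE))
mm = np.memmap(path, dtype=np.int32, mode="r", offset=HEADER_SIZE,
               shape=(total, seq_len))
print(total, seq_len, mm[1][0])  # 3 16 16
```

Since `np.memmap` only pages in the rows that are touched, RAM usage stays constant no matter how many chunks the file holds, which is the point of this format.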
class TextDataset(Dataset):
    """Simple in-memory dataset from tokenized chunks. For small datasets."""

    def __init__(self, token_chunks: list[list[int]], max_seq_len: int):
        self.chunks = token_chunks
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        tokens = self.chunks[idx]
        if len(tokens) < self.max_seq_len:
            tokens = tokens + [0] * (self.max_seq_len - len(tokens))
        else:
            tokens = tokens[: self.max_seq_len]
        input_ids = torch.tensor(tokens, dtype=torch.long)
        return {"input_ids": input_ids, "labels": input_ids.clone()}

    def split(self, val_fraction=0.1):
        """Split into (train, val) datasets with shuffle."""
        import random
        random.shuffle(self.chunks)
        n_val = max(1, int(len(self.chunks) * val_fraction))
        val = TextDataset(self.chunks[:n_val], self.max_seq_len)
        train = TextDataset(self.chunks[n_val:], self.max_seq_len)
        return train, val


# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------

class _SentencePieceTokenizer:
    """Minimal tokenizer wrapper using sentencepiece directly.
    Bypasses transformers tokenizer bugs across versions."""

    def __init__(self, model_path, name):
        import sentencepiece as spm
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        self._vocab_size = self.sp.GetPieceSize()
        self.eos_token_id = self.sp.eos_id()
        self.bos_token_id = self.sp.bos_id()
        self.eos_token = self.sp.IdToPiece(self.eos_token_id)
        self.bos_token = self.sp.IdToPiece(self.bos_token_id)
        self.pad_token = None
        self.pad_token_id = None
        self.name_or_path = name

    def __len__(self):
        return self._vocab_size

    @property
    def vocab_size(self):
        return self._vocab_size

    def encode(self, text, add_special_tokens=False, return_tensors=None):
        ids = self.sp.Encode(text)
        if return_tensors == "pt":
            import torch
            return torch.tensor([ids])
        return ids

    def decode(self, ids, skip_special_tokens=False):
        if hasattr(ids, 'tolist'):
            ids = ids.tolist()
        return self.sp.Decode(list(ids))

    def __call__(self, texts, add_special_tokens=False, **kwargs):
        if isinstance(texts, str):
            texts = [texts]
        return {"input_ids": [self.sp.Encode(t) for t in texts]}


def get_tokenizer(name: str = "gpt2"):
    """Get tokenizer from HuggingFace, with sentencepiece fallback.

    Args:
        name: Tokenizer name or path. Default "gpt2" (50257 vocab).
            Use e.g. "facebook/MobileLLM-125M" for 32K vocab.
    """
    from transformers import AutoTokenizer

    # Try AutoTokenizer (fast then slow)
    for use_fast in (True, False):
        try:
            tokenizer = AutoTokenizer.from_pretrained(name, use_fast=use_fast,
                                                      trust_remote_code=True)
            if isinstance(tokenizer, bool):
                continue
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        except Exception:
            continue

    # Fallback: load sentencepiece model directly (bypasses transformers bugs)
    print(f"AutoTokenizer failed for {name}, falling back to sentencepiece")
    from huggingface_hub import hf_hub_download
    model_path = hf_hub_download(name, "tokenizer.model")
    tokenizer = _SentencePieceTokenizer(model_path, name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


# ---------------------------------------------------------------------------
# Streaming memmap writer
# ---------------------------------------------------------------------------

def _stream_chunks_to_memmap(tokenized, total_examples, max_seq_len, output_path,
                             num_workers=1, read_batch=10_000):
    """Stream tokenized examples into a memory-mapped binary file.

    Single-process, numpy-batch approach. Reads batches from Arrow dataset,
    flattens to numpy int32, writes complete chunks to disk.
    Memory: O(read_batch * avg_seq_len * 4 bytes).
    No fork, no multiprocessing, no OOM.
    """
    from itertools import chain
    from tqdm import tqdm

    temp_path = str(output_path) + ".tmp"
    n_chunks = 0
    total_tokens = 0
    carryover = np.array([], dtype=np.int32)

    n_batches = (total_examples + read_batch - 1) // read_batch

    with open(temp_path, 'wb') as f:
        f.write(struct.pack('II', 0, max_seq_len))  # placeholder header

        for batch_start in tqdm(range(0, total_examples, read_batch),
                                total=n_batches, desc="Chunking",
                                mininterval=1.0):
            batch_end = min(batch_start + read_batch, total_examples)
            batch_ids = tokenized[batch_start:batch_end]["input_ids"]

            # Count tokens, flatten Arrow→numpy without intermediate Python list
            n_tok = sum(len(ids) for ids in batch_ids if ids)
            if n_tok == 0:
                del batch_ids
                continue

            flat = np.fromiter(
                chain.from_iterable(ids for ids in batch_ids if ids),
                dtype=np.int32, count=n_tok,
            )
            del batch_ids
            total_tokens += n_tok

            # Prepend carryover from previous batch
            if len(carryover) > 0:
                flat = np.concatenate([carryover, flat])

            # Write complete chunks
            n_complete = len(flat) // max_seq_len
            if n_complete > 0:
                f.write(flat[:n_complete * max_seq_len].tobytes())
                n_chunks += n_complete

            carryover = flat[n_complete * max_seq_len:].copy()
            del flat

        # Handle remaining tokens
        if len(carryover) >= 32:
            padded = np.zeros(max_seq_len, dtype=np.int32)
            padded[:len(carryover)] = carryover
            f.write(padded.tobytes())
            n_chunks += 1

        # Write actual count into header
        f.seek(0)
        f.write(struct.pack('II', n_chunks, max_seq_len))

    os.rename(temp_path, str(output_path))
    size_gb = os.path.getsize(output_path) / 1e9
    print(f"Total tokens: {total_tokens:,} → {n_chunks:,} chunks ({size_gb:.1f} GB)")
    return n_chunks
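The core of `_stream_chunks_to_memmap` is the carryover loop: complete `max_seq_len`-sized chunks are emitted per batch, and the remainder is rolled into the next batch (with a final padded chunk if the tail is long enough). That logic can be checked in isolation, without any file I/O — a restatement, not the function itself:

```python
import numpy as np

def chunk_stream(batches, max_seq_len, min_tail=32):
    # Mirrors the inner loop of _stream_chunks_to_memmap, minus the disk writes.
    chunks, carryover = [], np.array([], dtype=np.int32)
    for batch in batches:
        flat = np.concatenate([carryover, np.asarray(batch, dtype=np.int32)])
        n_complete = len(flat) // max_seq_len
        for i in range(n_complete):
            chunks.append(flat[i * max_seq_len:(i + 1) * max_seq_len])
        carryover = flat[n_complete * max_seq_len:]
    if len(carryover) >= min_tail:  # pad the final partial chunk, as above
        padded = np.zeros(max_seq_len, dtype=np.int32)
        padded[:len(carryover)] = carryover
        chunks.append(padded)
    return chunks

chunks = chunk_stream([list(range(10)), list(range(10))], max_seq_len=8, min_tail=2)
print(len(chunks))  # 3
```

Note how the second chunk starts with the carryover tokens `[8, 9]` from the first batch: no tokens are lost at batch boundaries, and only the very last partial chunk is ever padded.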
# ---------------------------------------------------------------------------
# HuggingFace dataset loader (parallel + memmap)
# ---------------------------------------------------------------------------

def _flatten_chat(example):
    """Convert chat format (system + conversations list) to plain text.

    Handles datasets like Bespoke-Stratos-17k and OpenThoughts-114k
    which store data as: system (str) + conversations (list of {from, value}).
    """
    parts = []
    if example.get("system"):
        parts.append(example["system"].strip())
    for msg in example.get("conversations", []):
        value = msg.get("value", "")
        if value:
            parts.append(value.strip())
    return {"text": "\n\n".join(parts)}
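On a toy record (message contents made up), the flattening reduces to stripped turns joined by blank lines — the function body is reproduced standalone here so it can run without the rest of the module:

```python
def flatten_chat(example):
    # Same shape _flatten_chat handles: system (str) + conversations [{from, value}]
    parts = []
    if example.get("system"):
        parts.append(example["system"].strip())
    for msg in example.get("conversations", []):
        value = msg.get("value", "")
        if value:
            parts.append(value.strip())
    return {"text": "\n\n".join(parts)}

ex = {"system": "You are helpful.",
      "conversations": [{"from": "human", "value": "Hi "},
                        {"from": "gpt", "value": "Hello!"}]}
print(flatten_chat(ex)["text"])  # You are helpful.\n\nHi\n\nHello!
```

Role labels (`from`) are dropped entirely; after flattening, chat data is indistinguishable from plain pretraining text downstream.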
def _estimate_avg_chars(dataset, text_column: str, n_sample: int = 200) -> float:
    """Estimate average text length from a sample of the dataset."""
    n = min(n_sample, len(dataset))
    total = sum(len(dataset[i][text_column] or "") for i in range(n))
    return total / max(n, 1)


def _adaptive_params(avg_chars: float, n_examples: int):
    """Scale worker count, batch sizes based on average example length.

    Long examples (chain-of-thought reasoning) need smaller batches and fewer
    workers to avoid OOM on memory-constrained systems (especially WSL).
    """
    cpu_count = max(1, multiprocessing.cpu_count() - 1)

    if avg_chars > 20_000:  # very long (OpenThoughts-style, ~7K+ tokens)
        num_proc = min(cpu_count, 4)
        tok_batch = 64
        read_batch = 500
    elif avg_chars > 5_000:  # long (detailed SFT, ~1.5K+ tokens)
        num_proc = min(cpu_count, 8)
        tok_batch = 256
        read_batch = 2_000
    elif avg_chars > 1_000:  # medium (typical SFT)
        num_proc = min(cpu_count, 16)
        tok_batch = 500
        read_batch = 5_000
    else:  # short (web text, wiki)
        num_proc = min(cpu_count, 32)
        tok_batch = 1000
        read_batch = 10_000

    return num_proc, tok_batch, read_batch
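The tiers are easiest to read by probing each band; a standalone restatement of the same thresholds (the batch sizes are deterministic, while `num_proc` additionally depends on the local CPU count):

```python
import multiprocessing

def adaptive_params(avg_chars):
    # Same tier boundaries as _adaptive_params above.
    cpu_count = max(1, multiprocessing.cpu_count() - 1)
    if avg_chars > 20_000:    # very long chain-of-thought
        return min(cpu_count, 4), 64, 500
    elif avg_chars > 5_000:   # long SFT
        return min(cpu_count, 8), 256, 2_000
    elif avg_chars > 1_000:   # typical SFT
        return min(cpu_count, 16), 500, 5_000
    return min(cpu_count, 32), 1000, 10_000   # short web/wiki text

for avg in (500, 2_000, 8_000, 50_000):
    print(avg, adaptive_params(avg))
```

Longer examples get smaller `tok_batch`/`read_batch` and fewer workers, so peak memory per worker stays roughly constant across dataset styles.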
def load_hf_dataset(
    name: str,
    split: str,
    text_column: str,
    tokenizer,
    max_seq_len: int,
    num_samples: int | None = None,
    hf_config: str | None = None,
    cache_path: Path | None = None,
    data_format: str = "text",
) -> MemmapDataset:
    """Load HF dataset with parallel tokenization and streaming to memmap.

    Parallelism:
    - dataset.map(num_proc=N) uses multiprocessing — bypasses GIL
    - GPT2TokenizerFast runs Rust tokenization — bypasses GIL
    - batched=True enables efficient batch processing

    Memory:
    - Adaptive batch sizes based on avg example length — prevents OOM on long sequences
    - Tokenized data in Arrow format (memory-mapped by HuggingFace)
    - Chunks streamed to binary memmap file — never in RAM
    """
    from datasets import load_dataset

    config_str = f", config={hf_config}" if hf_config else ""
    print(f"Loading HF dataset: {name} (split={split}{config_str})")
    dataset = load_dataset(name, hf_config, split=split)

    if num_samples is not None:
        dataset = dataset.select(range(min(num_samples, len(dataset))))

    # Flatten chat format to plain text
    if data_format == "chat":
        # Use conservative parallelism for flattening — light operation
        flat_proc = min(max(1, multiprocessing.cpu_count() - 1), 8)
        print(f"Flattening {len(dataset):,} chat examples to plain text...")
        dataset = dataset.map(
            _flatten_chat,
            num_proc=flat_proc,
            remove_columns=dataset.column_names,
            desc="Flattening chat",
        )
        text_column = "text"

    # Estimate avg example length and adapt parameters
    avg_chars = _estimate_avg_chars(dataset, text_column)
    num_proc, tok_batch, read_batch = _adaptive_params(avg_chars, len(dataset))
    print(f"  Avg example length: ~{avg_chars:,.0f} chars → "
          f"{num_proc} workers, tok_batch={tok_batch}, read_batch={read_batch}")

    # Filter empty examples
    print(f"Filtering empty examples from {len(dataset):,}...")
    dataset = dataset.filter(
        lambda x: bool(x[text_column] and x[text_column].strip()),
        num_proc=num_proc,
        desc="Filtering",
    )
    print(f"  {len(dataset):,} non-empty examples")

    # Parallel tokenization
    print(f"Tokenizing {len(dataset):,} examples with {num_proc} workers...")

    def tokenize_batch(examples):
        return tokenizer(examples[text_column], add_special_tokens=False)

    tokenized = dataset.map(
        tokenize_batch,
        batched=True,
        batch_size=tok_batch,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
        desc="Tokenizing",
    )

    # Stream to memmap — use temp path if no cache configured
    if cache_path is None:
        import tempfile
        cache_path = Path(tempfile.mktemp(suffix='.bin'))

    _stream_chunks_to_memmap(tokenized, len(tokenized), max_seq_len, cache_path,
                             read_batch=read_batch)
    return MemmapDataset(cache_path)


# ---------------------------------------------------------------------------
# Text file loaders (unchanged — small datasets, in-memory is fine)
# ---------------------------------------------------------------------------

def tokenize_text(text: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Tokenize text into chunks of max_seq_len."""
    tokens = tokenizer.encode(text)
    chunks = []
    for i in range(0, len(tokens), max_seq_len):
        chunk = tokens[i : i + max_seq_len]
        if len(chunk) >= 32:
            chunks.append(chunk)
    return chunks


def load_text_file(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Load and tokenize a single text file."""
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    return tokenize_text(text, tokenizer, max_seq_len)


def load_text_directory(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
    """Load and tokenize all .txt files from a directory."""
    all_chunks = []
    path = Path(path)
    for txt_file in sorted(path.glob("**/*.txt")):
        chunks = load_text_file(str(txt_file), tokenizer, max_seq_len)
        all_chunks.extend(chunks)
    return all_chunks


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

def load_data(
    data_source: str,
    tokenizer,
    max_seq_len: int,
    text_column: str = "text",
    num_samples: int | None = None,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    data_format: str = "text",
) -> Dataset:
    """
    Load data from various sources. Returns a Dataset with .split() support.

    Args:
        data_source: Path or HF dataset identifier
            - "path/to/file.txt" — single file
            - "path/to/dir/" — directory of .txt files
            - "hf:dataset_name" — HuggingFace dataset (train split)
            - "hf:dataset:split" — HuggingFace with specific split
            - "hf:dataset:config:split" — with config and split
        tokenizer: Tokenizer to use
        max_seq_len: Maximum sequence length
        text_column: Column name for HF datasets
        num_samples: Limit samples from HF dataset
        cache_dir: Directory for cache files (None to disable)

    Returns:
        Dataset object supporting len(), __getitem__(), and split(fraction)
    """
    cache_path = None
    if cache_dir is not None:
        cache_path = Path(cache_dir) / _cache_key(data_source, max_seq_len, num_samples)
        cache_path.parent.mkdir(parents=True, exist_ok=True)

        # Check for memmap cache (.bin)
        if cache_path.exists():
            print(f"Loading from cache: {cache_path}")
            ds = MemmapDataset(cache_path)
            print(f"  Loaded {len(ds):,} chunks")
            return ds

        # Check for legacy cache (.pt)
        legacy_path = cache_path.with_suffix('.pt')
        if legacy_path.exists():
            print(f"Loading from legacy cache: {legacy_path}")
            data = torch.load(legacy_path, weights_only=False)
            chunks = data["chunks"]
            print(f"  Loaded {len(chunks):,} chunks")
            return TextDataset(chunks, max_seq_len)

    # Load and tokenize
    if data_source.startswith("hf:"):
        parts = data_source[3:].split(":")
        name = parts[0]
        hf_config = None
        split = "train"
        if len(parts) == 2:
            split = parts[1]
        elif len(parts) == 3:
            hf_config = parts[1]
            split = parts[2]
        return load_hf_dataset(
            name, split, text_column, tokenizer, max_seq_len,
            num_samples, hf_config=hf_config, cache_path=cache_path,
            data_format=data_format,
        )
    elif os.path.isfile(data_source):
        chunks = load_text_file(data_source, tokenizer, max_seq_len)
    elif os.path.isdir(data_source):
        chunks = load_text_directory(data_source, tokenizer, max_seq_len)
    else:
        raise ValueError(f"Unknown data source: {data_source}")

    # For text files: save legacy cache
    if cache_dir is not None:
        legacy_path = cache_path.with_suffix('.pt')
        torch.save({"chunks": chunks, "data_source": data_source,
                    "max_seq_len": max_seq_len, "num_samples": num_samples}, legacy_path)
        print(f"Saved to cache: {legacy_path}")

    return TextDataset(chunks, max_seq_len)
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# ---------------------------------------------------------------------------
|
| 526 |
+
# DataLoader factory
|
| 527 |
+
# ---------------------------------------------------------------------------
|
| 528 |
+
|
| 529 |
+
def create_dataloader(
|
| 530 |
+
dataset,
|
| 531 |
+
batch_size: int,
|
| 532 |
+
max_seq_len: int = None,
|
| 533 |
+
shuffle: bool = True,
|
| 534 |
+
num_workers: int = 0,
|
| 535 |
+
) -> DataLoader:
|
| 536 |
+
"""Create a DataLoader from a Dataset or list of chunks."""
|
| 537 |
+
if not isinstance(dataset, Dataset):
|
| 538 |
+
# Legacy compatibility: list of token chunks
|
| 539 |
+
dataset = TextDataset(dataset, max_seq_len)
|
| 540 |
+
return DataLoader(
|
| 541 |
+
dataset,
|
| 542 |
+
batch_size=batch_size,
|
| 543 |
+
shuffle=shuffle,
|
| 544 |
+
num_workers=num_workers,
|
| 545 |
+
pin_memory=True,
|
| 546 |
+
)
|
generate.py
ADDED
@@ -0,0 +1,195 @@
#!/usr/bin/env python3
"""
Generation script for Circuit Transformer.

Usage:
    python circuits/generate.py --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
"""

import argparse

import torch
import torch.nn as nn

from transformers import AutoTokenizer

from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table
from .data import get_tokenizer


def parse_args():
    parser = argparse.ArgumentParser(description="Generate text with Circuit Transformer")

    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint")
    parser.add_argument("--prompt", type=str, default="", help="Prompt text")
    parser.add_argument("--max-tokens", type=int, default=100, help="Max tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k filtering")
    parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
    parser.add_argument("--repetition-penalty", type=float, default=1.0,
                        help="Repetition penalty (1.0=off, 1.3=default for slot models)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--no-cache", action="store_true", help="Disable KV cache")

    return parser.parse_args()


def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Migrate checkpoint state_dict to match current model architecture.

    Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
    """
    # Strip _orig_mod. prefix from torch.compile'd checkpoints
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys

    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    migrated = dict(state_dict)
    migrations = []

    # dual_gate_middle checkpoints: rename gate_expand/gate_compress → w3/w4
    for key in list(unexpected):
        if ".ffn.gate_expand.weight" in key:
            new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} → {new_key}")
        elif ".ffn.gate_compress.weight" in key:
            new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} → {new_key}")

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for m in migrations:
            print(m)
        # Report remaining missing keys (freshly initialized)
        still_missing = model_keys - set(migrated.keys())
        if still_missing:
            print(f"  New parameters (freshly initialized): {len(still_missing)}")
            for k in sorted(still_missing):
                print(f"    {k}")

    return migrated


def generate():
    args = parse_args()

    # Setup device
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load checkpoint
    print(f"Loading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # Reconstruct config and model based on architecture type
    model_type = checkpoint.get("model_type", "standard")
    is_folded = model_type == "folded"

    if model_type == "graft_g2lu":
        model = load_g2lu_model(args.checkpoint, device=device)
        model.eval()
        pretrained_name = checkpoint.get("pretrained_name", "unknown")
        print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
        tokenizer_name = checkpoint.get("tokenizer_name", pretrained_name)
        tokenizer = get_tokenizer(tokenizer_name)
    elif is_folded:
        from grafting.fold_llama import FoldedLlama
        model = FoldedLlama.load_from_checkpoint(args.checkpoint, device=device)
        model.eval()
        fold_cfg = model.config
        print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
              f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
        tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
    else:
        if model_type == "mirrored":
            if checkpoint["config"].get("dual_gate_middle"):
                checkpoint["config"].pop("dual_gate_middle")
            config = MirroredConfig.from_dict(checkpoint["config"])
            model = MirroredTransformer(config).to(device)
            print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
        else:
            config = CircuitConfig.from_dict(checkpoint["config"])
            model = CircuitTransformer(config).to(device)
            print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")

        state_dict = _migrate_state_dict(checkpoint["model"], model)
        model.load_state_dict(state_dict)
        model.eval()
        tokenizer_name = checkpoint.get("tokenizer_name", "gpt2")
        tokenizer = get_tokenizer(tokenizer_name)

    # Build word-position table if model uses SemRoPE
    word_start_table_device = None
    if model_type not in ("graft_g2lu", "folded"):
        ckpt_config = checkpoint.get("config", {})
        word_rope_dims = ckpt_config.get("word_rope_dims", 0)
        if word_rope_dims > 0:
            word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
            print(f"Word-position RoPE: {word_rope_dims} dims")

    # Tokenize prompt
    if args.prompt:
        prompt_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(device)
    else:
        # Start with BOS/EOS token
        prompt_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)

    print(f"\nPrompt: {args.prompt or '<empty>'}")
    print(f"Prompt tokens: {prompt_ids.shape[1]}")
    print(f"Generating {args.max_tokens} tokens...")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
    print("-" * 50)

    # Generate
    with torch.no_grad():
        gen_kwargs = dict(
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            use_cache=not args.no_cache,
        )
        if args.repetition_penalty != 1.0:
            gen_kwargs["repetition_penalty"] = args.repetition_penalty

        # HF models need do_sample=True for temperature/top_k/top_p
        if model_type == "graft_g2lu":
            if args.temperature > 0 and args.temperature != 1.0:
                gen_kwargs["do_sample"] = True
            elif args.top_p < 1.0 or args.top_k > 0:
                gen_kwargs["do_sample"] = True

        if word_start_table_device is not None:
            gen_kwargs["word_start_table"] = word_start_table_device

        output_ids = model.generate(prompt_ids, **gen_kwargs)

    # Decode and print
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(generated_text)
    print("-" * 50)
    print(f"Total tokens: {output_ids.shape[1]}")


if __name__ == "__main__":
    generate()
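The key migration in `_migrate_state_dict` boils down to a prefix strip plus two renames. A simplified sketch on plain dicts (no torch needed; unlike the real function, this version renames unconditionally rather than checking which keys the target model is actually missing):

```python
def migrate_keys(state_dict: dict) -> dict:
    """Strip torch.compile's _orig_mod. prefix and rename dual-gate
    keys (gate_expand/gate_compress) back to the SwiGLU names (w3/w4)."""
    out = {}
    for key, value in state_dict.items():
        key = key.removeprefix("_orig_mod.")
        key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
        key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
        out[key] = value
    return out
```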
graft_g2lu.py
ADDED
@@ -0,0 +1,300 @@
"""
G²LU Gate Grafting: Surgically upgrade pretrained SwiGLU models to G²LU.

Takes any HuggingFace model with SwiGLU (gate_proj + up_proj), freezes everything
except gate weights, adds W4 for nested gating, and trains with alignment + LM loss.

This is grafting applied to the gate mechanism — the same methodology validated for
full layer replacement, now targeting the minimum surgical unit.

Usage:
    python -m circuits.train --arch graft_g2lu --pretrained meta-llama/Llama-3.2-1B \
        --align-weight 1.0 --graft-warmup 500 --data hf:Bingsu/openwebtext_20p ...
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path


class G2LU_MLP(nn.Module):
    """Per-layer MLP wrapper that upgrades SwiGLU to G²LU.

    Holds references to the original gate_proj (W3, frozen), up_proj (W1, frozen),
    down_proj (W2, frozen), plus a new w4 (zero-initialized, trainable).

    Gate ordering: silu(W4@x * silu(W3@x)) — the pretrained gate (W3) acts as a
    structural prior, constraining W4 to operate within the feature subspace the
    pretrained model already deems relevant. W4's gradients are scaled by silu(W3@x),
    inheriting the pretrained model's feature selection hierarchy.
    """

    def __init__(self, original_mlp: nn.Module):
        super().__init__()
        # References to original weights (all frozen)
        self.gate_proj = original_mlp.gate_proj  # W3 — frozen
        self.up_proj = original_mlp.up_proj      # W1 — frozen
        self.down_proj = original_mlp.down_proj  # W2 — frozen

        # New W4: same shape as gate_proj, zero-initialized, matched dtype
        self.w4 = nn.Linear(
            self.gate_proj.in_features,
            self.gate_proj.out_features,
            bias=self.gate_proj.bias is not None,
            dtype=self.gate_proj.weight.dtype,
            device=self.gate_proj.weight.device,
        )
        nn.init.zeros_(self.w4.weight)
        if self.w4.bias is not None:
            nn.init.zeros_(self.w4.bias)

        # Blend alpha: 0 = pure SwiGLU, 1 = full G²LU
        self._alpha = 0.0
        # Per-layer alignment loss (collected by parent)
        self._align_loss = None

    def forward(self, x):
        # Pretrained gate (frozen W3) — structural prior
        w3_gate = F.silu(self.gate_proj(x))

        # G²LU gate: silu(W4@x * silu(W3@x))
        # W4 is modulated BY pretrained knowledge, not the reverse
        g2lu_gate = F.silu(self.w4(x) * w3_gate)

        # Blend warmup: smooth transition from SwiGLU → G²LU
        if self._alpha < 1.0:
            gate = (1.0 - self._alpha) * w3_gate + self._alpha * g2lu_gate
        else:
            gate = g2lu_gate

        # Per-layer alignment loss (compare against original SwiGLU gate)
        self._align_loss = F.mse_loss(gate, w3_gate.detach())

        return self.down_proj(gate * self.up_proj(x))


class G2LU_GraftedModel(nn.Module):
    """Full model wrapper that upgrades a pretrained HF model's MLPs to G²LU.

    Interface matches CircuitTransformer: forward(input_ids, labels=labels) returns
    {"loss", "logits", "align_loss"}.
    """

    def __init__(
        self,
        pretrained_name: str,
        align_weight: float = 1.0,
        warmup_steps: int = 500,
        device: str = "cuda",
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.pretrained_name = pretrained_name
        self.align_weight = align_weight
        self.warmup_steps = warmup_steps
        self._current_step = 0

        # Load pretrained HF model
        from transformers import AutoModelForCausalLM
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_name,
            dtype=dtype,
            trust_remote_code=True,
        )

        # Discover and replace MLPs
        self.g2lu_mlps = []
        self._replace_mlps()

        # Freeze everything, then selectively unfreeze W4 only
        for param in self.model.parameters():
            param.requires_grad = False

        for g2lu in self.g2lu_mlps:
            for param in g2lu.w4.parameters():
                param.requires_grad = True

        self.model.to(device)

        # Print summary
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"G²LU Graft: {pretrained_name}")
        print(f"  Layers upgraded: {len(self.g2lu_mlps)}")
        print(f"  Total params: {total_params:,} ({total_params/1e6:.1f}M)")
        print(f"  Trainable params: {trainable:,} ({trainable/1e6:.1f}M, {100*trainable/total_params:.1f}%)")
        print(f"  Align weight: {align_weight}, Warmup: {warmup_steps} steps")

    def _replace_mlps(self):
        """Walk the model tree and replace SwiGLU MLPs with G²LU wrappers."""
        # Try common decoder layer paths
        layers = None
        for attr_path in ["model.layers", "gpt_neox.layers", "transformer.h"]:
            obj = self.model
            try:
                for attr in attr_path.split("."):
                    obj = getattr(obj, attr)
                layers = obj
                break
            except AttributeError:
                continue

        if layers is None:
            raise ValueError(
                f"Could not find decoder layers in {type(self.model).__name__}. "
                f"Tried: model.layers, gpt_neox.layers, transformer.h"
            )

        for i, layer in enumerate(layers):
            # Try common MLP attribute names
            mlp = None
            mlp_attr = None
            for attr in ["mlp", "feed_forward"]:
                if hasattr(layer, attr):
                    mlp = getattr(layer, attr)
                    mlp_attr = attr
                    break

            if mlp is None:
                continue

            # Check for SwiGLU signature (gate_proj + up_proj)
            if hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj"):
                g2lu = G2LU_MLP(mlp)
                setattr(layer, mlp_attr, g2lu)
                self.g2lu_mlps.append(g2lu)

        if not self.g2lu_mlps:
            raise ValueError(
                "No SwiGLU MLPs found (need gate_proj + up_proj attributes). "
                "This model may not use gated linear units."
            )

    def set_step(self, step: int):
        """Update blend alpha across all G²LU MLPs."""
        self._current_step = step
        alpha = min(step / max(self.warmup_steps, 1), 1.0)
        for g2lu in self.g2lu_mlps:
            g2lu._alpha = alpha

    def trainable_parameters(self):
        """Yield only unfrozen parameters (for optimizer and grad clipping)."""
        for param in self.model.parameters():
            if param.requires_grad:
                yield param

    def collect_align_loss(self):
        """Average per-layer alignment losses."""
        losses = [g2lu._align_loss for g2lu in self.g2lu_mlps if g2lu._align_loss is not None]
        if not losses:
            return torch.tensor(0.0)
        return torch.stack(losses).mean()

    def forward(self, input_ids, labels=None, **kwargs):
        outputs = self.model(input_ids=input_ids, labels=labels, **kwargs)

        result = {"logits": outputs.logits}

        align_loss = self.collect_align_loss()
        result["align_loss"] = align_loss

        if labels is not None:
            # Combine LM loss + alignment loss
            result["loss"] = outputs.loss + self.align_weight * align_loss
            result["lm_loss"] = outputs.loss
        else:
            result["loss"] = align_loss

        return result

    def generate(self, input_ids, **kwargs):
        """Delegate to HF model's .generate()."""
        return self.model.generate(input_ids=input_ids, **kwargs)


def save_g2lu_checkpoint(
    model: G2LU_GraftedModel,
    optimizer: torch.optim.Optimizer,
    step: int,
    epoch: int,
    loss: float,
    path: str,
    epoch_step: int = 0,
    best_val_loss: float | None = None,
    scaler=None,
    tokenizer_name: str | None = None,
):
    """Delta save: only trainable params + metadata."""
    # Handle torch.compile wrapper
    if hasattr(model, '_orig_mod'):
        g2lu_model = model._orig_mod
    else:
        g2lu_model = model

    # Extract only requires_grad params
    delta_sd = {}
    full_sd = g2lu_model.model.state_dict()
    for name, param in g2lu_model.model.named_parameters():
        if param.requires_grad:
            # Strip _orig_mod. prefix if present
            clean_name = name.removeprefix("_orig_mod.")
            delta_sd[clean_name] = full_sd.get(name, param.data).clone()

    # Also save the w4 weights explicitly (they're part of the replaced modules)
    for name, val in full_sd.items():
        clean_name = name.removeprefix("_orig_mod.")
        if ".w4." in clean_name and clean_name not in delta_sd:
            delta_sd[clean_name] = val.clone()

    checkpoint = {
        "model": delta_sd,
        "optimizer": optimizer.state_dict(),
        "step": step,
        "epoch": epoch,
        "epoch_step": epoch_step,
        "loss": loss,
        "model_type": "graft_g2lu",
        "pretrained_name": g2lu_model.pretrained_name,
        "align_weight": g2lu_model.align_weight,
        "warmup_steps": g2lu_model.warmup_steps,
        "tokenizer_name": tokenizer_name or g2lu_model.pretrained_name,
    }
    if best_val_loss is not None:
        checkpoint["best_val_loss"] = best_val_loss
    if scaler is not None:
        checkpoint["scaler"] = scaler.state_dict()
    torch.save(checkpoint, path)


def load_g2lu_model(checkpoint_path: str, device: str = "cuda", dtype=torch.bfloat16):
    """Delta load: recreate model from pretrained + apply delta weights."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    pretrained_name = checkpoint["pretrained_name"]
    align_weight = checkpoint.get("align_weight", 1.0)
    warmup_steps = checkpoint.get("warmup_steps", 500)

    model = G2LU_GraftedModel(
        pretrained_name=pretrained_name,
        align_weight=align_weight,
        warmup_steps=warmup_steps,
        device=device,
        dtype=dtype,
    )

    # Load delta weights
    delta_sd = checkpoint["model"]
    # Strip _orig_mod. prefix if present
    delta_sd = {k.removeprefix("_orig_mod."): v for k, v in delta_sd.items()}

    # Apply delta weights to the model
    missing, unexpected = model.model.load_state_dict(delta_sd, strict=False)
    if unexpected:
        print(f"  Warning: unexpected keys in delta checkpoint: {unexpected[:5]}...")

    # Set alpha to 1.0 for inference (full G²LU)
    model.set_step(warmup_steps + 1)

    return model
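The SwiGLU → G²LU transition in `set_step` and `G2LU_MLP.forward` is a linear ramp followed by a convex blend of the two gates. Isolated as plain scalar functions (the names `blend_alpha` and `blend` are illustrative):

```python
def blend_alpha(step: int, warmup_steps: int) -> float:
    """Linear warmup schedule used by set_step:
    0.0 (pure SwiGLU) at step 0 → 1.0 (full G²LU) at warmup_steps."""
    return min(step / max(warmup_steps, 1), 1.0)


def blend(w3_gate: float, g2lu_gate: float, alpha: float) -> float:
    """Convex combination from G2LU_MLP.forward:
    gate = (1 - alpha) * SwiGLU gate + alpha * G²LU gate."""
    return (1.0 - alpha) * w3_gate + alpha * g2lu_gate
```

Because W4 is zero-initialized, the G²LU gate starts at `silu(0) = 0`, so the blend (not just alpha) is what keeps early-training outputs close to the pretrained model.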
layers.py
ADDED
@@ -0,0 +1,325 @@
"""
Shared building blocks for Circuit Transformer architectures.

Components:
- RMSNorm: Root Mean Square Layer Normalization
- RotaryEmbedding: Rotary Position Embedding (RoPE)
- CausalAttention: Multi-head causal attention with RoPE + KV cache
- SwiGLU: Gated feed-forward network
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import lru_cache


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight


def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor:
    """Build a boolean table marking which token IDs start a new word.

    Detects word boundaries from the tokenizer's token representations:
    - Ġ prefix (GPT-2/BPE style)
    - ▁ prefix (SentencePiece style)
    - Special tokens (starting with <)
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)

    # Get all token strings — handle both HF and SentencePiece tokenizers
    if hasattr(tokenizer, 'convert_ids_to_tokens'):
        tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
    elif hasattr(tokenizer, 'sp'):
        tokens = [tokenizer.sp.IdToPiece(i) for i in range(vocab_size)]
    else:
        tokens = [tokenizer.decode([i]) for i in range(vocab_size)]

    for idx, tok in enumerate(tokens):
        if tok is None:
            continue
        if tok.startswith('Ġ') or tok.startswith('▁') or tok.startswith('<'):
            table[idx] = True
        # Punctuation and newlines that start new "words"
        elif len(tok) > 0 and tok[0] in '\n\r\t':
            table[idx] = True

    # Token 0 is always a word starter (BOS/padding)
    table[0] = True

    return table


def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor:
    """Compute position-within-word for each token. Vectorized, no loops.

    Args:
        input_ids: [B, L] token IDs
        word_start_table: [vocab_size] bool tensor from build_word_start_table

    Returns:
        [B, L] float tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word boundary)
    """
    is_word_start = word_start_table[input_ids]  # [B, L]
    is_word_start[:, 0] = True  # First token always starts a word

    B, L = input_ids.shape
    positions = torch.arange(L, device=input_ids.device, dtype=torch.float32).unsqueeze(0).expand(B, -1)

    # Fill non-word-start positions with -1, word-start positions with their index
    fill = torch.where(is_word_start, positions, torch.tensor(-1.0, device=input_ids.device))

    # cummax propagates the most recent word-start position forward
    running_start, _ = fill.cummax(dim=1)

    # Position within word = distance from the most recent word start
    word_pos = positions - running_start  # [B, L] float: 0, 1, 2, 0, 1, 0, ...

    return word_pos


class WordPositionRoPE(nn.Module):
    """RoPE encoding for position-within-word.

    Dedicates a small subspace of head dimensions to word-internal position,
    using separate (lower) frequency bases. Overrides the last `word_dims`
    of the standard RoPE cos/sin tensors.
    """

    def __init__(self, word_dims: int, word_base: float = 10.0):
        super().__init__()
        self.word_dims = word_dims
        word_inv_freq = 1.0 / (word_base ** (torch.arange(0, word_dims, 2).float() / word_dims))
        self.register_buffer("word_inv_freq", word_inv_freq)

    def forward(
        self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override last word_dims of cos/sin with word-position-derived values.

        Args:
            cos, sin: [L, head_dim] from standard RotaryEmbedding
            word_positions: [B, L] float tensor (position within word)

        Returns:
            cos, sin: [B, L, head_dim] with word dims overridden
        """
        B, L = word_positions.shape

        # Compute word angles: [B, L, word_dims/2]
        angles = word_positions.unsqueeze(-1) * self.word_inv_freq
        # Duplicate for rotate_half pattern: [B, L, word_dims]
        word_emb = torch.cat([angles, angles], dim=-1)
        word_cos = word_emb.cos()
        word_sin = word_emb.sin()

        # Expand standard cos/sin to batch dimension: [L, D] -> [B, L, D]
        cos = cos.unsqueeze(0).expand(B, -1, -1).clone()
        sin = sin.unsqueeze(0).expand(B, -1, -1).clone()

        # Override last word_dims with word-position values
        cos[:, :, -self.word_dims:] = word_cos
        sin[:, :, -self.word_dims:] = word_sin
sin[:, :, -self.word_dims:] = word_sin
|
| 134 |
+
|
| 135 |
+
return cos, sin
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class RotaryEmbedding(nn.Module):
|
| 139 |
+
"""Rotary Position Embedding (RoPE)."""
|
| 140 |
+
|
| 141 |
+
def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.dim = dim
|
| 144 |
+
self.max_seq_len = max_seq_len
|
| 145 |
+
|
| 146 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 147 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 148 |
+
self._build_cache(max_seq_len)
|
| 149 |
+
|
| 150 |
+
def _build_cache(self, seq_len: int):
|
| 151 |
+
t = torch.arange(seq_len, device=self.inv_freq.device)
|
| 152 |
+
freqs = torch.outer(t, self.inv_freq)
|
| 153 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 154 |
+
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
| 155 |
+
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
| 156 |
+
|
| 157 |
+
def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 158 |
+
if seq_len > self.cos_cached.size(0):
|
| 159 |
+
self._build_cache(seq_len)
|
| 160 |
+
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 164 |
+
"""Rotate half the hidden dims."""
|
| 165 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 166 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def apply_rotary_pos_emb(
|
| 170 |
+
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 171 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 172 |
+
"""Apply rotary position embedding to queries and keys.
|
| 173 |
+
|
| 174 |
+
Handles both standard [L, D] and batched [B, L, D] cos/sin.
|
| 175 |
+
Q, K shape: [B, H, L, D]. For batched cos/sin, unsqueeze dim 1 for head broadcast.
|
| 176 |
+
"""
|
| 177 |
+
if cos.dim() == 3: # [B, L, D] from WordPositionRoPE
|
| 178 |
+
cos = cos.unsqueeze(1) # [B, 1, L, D] — broadcast over heads
|
| 179 |
+
sin = sin.unsqueeze(1)
|
| 180 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 181 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 182 |
+
return q_embed, k_embed
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class CausalAttention(nn.Module):
|
| 186 |
+
"""Multi-head attention with causal mask, RoPE, and optional GQA.
|
| 187 |
+
|
| 188 |
+
Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads.
|
| 189 |
+
Each KV head serves (num_heads // num_kv_heads) query heads.
|
| 190 |
+
KV cache stored at kv_heads granularity for memory efficiency.
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
def __init__(
|
| 194 |
+
self,
|
| 195 |
+
hidden_size: int,
|
| 196 |
+
num_heads: int,
|
| 197 |
+
num_kv_heads: int | None = None,
|
| 198 |
+
max_seq_len: int = 2048,
|
| 199 |
+
dropout: float = 0.0,
|
| 200 |
+
window_size: int | None = None,
|
| 201 |
+
word_rope_dims: int = 0,
|
| 202 |
+
word_rope_base: float = 10.0,
|
| 203 |
+
):
|
| 204 |
+
super().__init__()
|
| 205 |
+
self.hidden_size = hidden_size
|
| 206 |
+
self.num_heads = num_heads
|
| 207 |
+
self.num_kv_heads = num_kv_heads or num_heads
|
| 208 |
+
self.head_dim = hidden_size // num_heads
|
| 209 |
+
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
| 210 |
+
self.dropout = dropout
|
| 211 |
+
self.window_size = window_size
|
| 212 |
+
|
| 213 |
+
assert self.num_heads % self.num_kv_heads == 0, \
|
| 214 |
+
f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
|
| 215 |
+
if word_rope_dims > 0:
|
| 216 |
+
assert word_rope_dims <= self.head_dim, \
|
| 217 |
+
f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})"
|
| 218 |
+
assert word_rope_dims % 2 == 0, \
|
| 219 |
+
f"word_rope_dims ({word_rope_dims}) must be even"
|
| 220 |
+
|
| 221 |
+
self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 222 |
+
self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
| 223 |
+
self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
| 224 |
+
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 225 |
+
|
| 226 |
+
self.rotary = RotaryEmbedding(self.head_dim, max_seq_len)
|
| 227 |
+
|
| 228 |
+
# Word-position RoPE (optional)
|
| 229 |
+
self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None
|
| 230 |
+
|
| 231 |
+
# Build causal mask (optionally windowed)
|
| 232 |
+
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
| 233 |
+
if window_size is not None:
|
| 234 |
+
# Band mask: position i attends to [max(0, i-window+1), i]
|
| 235 |
+
band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1))
|
| 236 |
+
mask = mask * band
|
| 237 |
+
self.register_buffer(
|
| 238 |
+
"causal_mask",
|
| 239 |
+
mask.view(1, 1, max_seq_len, max_seq_len),
|
| 240 |
+
persistent=False,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
|
| 244 |
+
"""Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads."""
|
| 245 |
+
if self.num_kv_groups == 1:
|
| 246 |
+
return kv
|
| 247 |
+
B, H_kv, L, D = kv.shape
|
| 248 |
+
return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D)
|
| 249 |
+
|
| 250 |
+
def forward(
|
| 251 |
+
self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
|
| 252 |
+
word_positions: torch.Tensor | None = None,
|
| 253 |
+
) -> tuple[torch.Tensor, tuple | None]:
|
| 254 |
+
B, L, _ = x.shape
|
| 255 |
+
|
| 256 |
+
q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
|
| 257 |
+
k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 258 |
+
v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 259 |
+
|
| 260 |
+
# RoPE: use correct position offset for KV-cached generation
|
| 261 |
+
offset = past_kv[0].size(2) if past_kv is not None else 0
|
| 262 |
+
cos, sin = self.rotary(x, offset + L)
|
| 263 |
+
cos = cos[offset:offset + L]
|
| 264 |
+
sin = sin[offset:offset + L]
|
| 265 |
+
|
| 266 |
+
# Override word-position dims if enabled
|
| 267 |
+
if self.word_rope is not None and word_positions is not None:
|
| 268 |
+
cos, sin = self.word_rope(cos, sin, word_positions)
|
| 269 |
+
|
| 270 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
| 271 |
+
|
| 272 |
+
# KV cache at kv_heads granularity (memory efficient for GQA)
|
| 273 |
+
if past_kv is not None:
|
| 274 |
+
past_k, past_v = past_kv
|
| 275 |
+
k = torch.cat([past_k, k], dim=2)
|
| 276 |
+
v = torch.cat([past_v, v], dim=2)
|
| 277 |
+
|
| 278 |
+
new_kv = (k, v) if use_cache else None
|
| 279 |
+
|
| 280 |
+
dropout_p = self.dropout if self.training else 0.0
|
| 281 |
+
use_gqa = self.num_kv_groups > 1
|
| 282 |
+
|
| 283 |
+
if self.window_size is not None:
|
| 284 |
+
# Windowed attention: manual path (SDPA FlashAttention doesn't support arbitrary masks)
|
| 285 |
+
k_expanded = self._expand_kv(k)
|
| 286 |
+
v_expanded = self._expand_kv(v)
|
| 287 |
+
seq_len = k.size(2)
|
| 288 |
+
attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 289 |
+
if seq_len <= self.causal_mask.size(-1):
|
| 290 |
+
mask = self.causal_mask[:, :, offset:offset + L, :seq_len]
|
| 291 |
+
attn = attn.masked_fill(mask == 0, float("-inf"))
|
| 292 |
+
attn = F.softmax(attn, dim=-1)
|
| 293 |
+
if dropout_p > 0:
|
| 294 |
+
attn = F.dropout(attn, p=dropout_p)
|
| 295 |
+
out = torch.matmul(attn, v_expanded)
|
| 296 |
+
else:
|
| 297 |
+
# SDPA: auto-dispatches to FlashAttention2 / memory-efficient / math backend
|
| 298 |
+
# Native GQA support avoids expanding KV heads (saves memory + enables FlashAttention GQA kernel)
|
| 299 |
+
is_causal = past_kv is None and L > 1
|
| 300 |
+
out = F.scaled_dot_product_attention(
|
| 301 |
+
q, k, v,
|
| 302 |
+
dropout_p=dropout_p,
|
| 303 |
+
is_causal=is_causal,
|
| 304 |
+
enable_gqa=use_gqa,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)
|
| 308 |
+
|
| 309 |
+
return self.o_proj(out), new_kv
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class SwiGLU(nn.Module):
|
| 313 |
+
"""SwiGLU feed-forward network."""
|
| 314 |
+
|
| 315 |
+
def __init__(self, hidden_size: int, intermediate_size: int | None = None):
|
| 316 |
+
super().__init__()
|
| 317 |
+
intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
|
| 318 |
+
intermediate_size = ((intermediate_size + 63) // 64) * 64
|
| 319 |
+
|
| 320 |
+
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 321 |
+
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 322 |
+
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 323 |
+
|
| 324 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 325 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
lm_eval_wrapper.py ADDED
@@ -0,0 +1,344 @@
```python
"""
LM-eval harness wrapper for Circuit/Mirrored transformers.

Usage:
    # Single model
    python -m circuits.bench --checkpoint circuits/checkpoints/mirrored/best.pt --gpu 0

    # Compare all architectures
    python -m circuits.bench --compare --gpu 0
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from tqdm import tqdm
from lm_eval.api.model import LM
from lm_eval.api.instance import Instance

from .config import CircuitConfig
from .model import CircuitTransformer
from .mirrored import MirroredConfig, MirroredTransformer
from .graft_g2lu import load_g2lu_model
from .layers import build_word_start_table, compute_word_positions
from .data import get_tokenizer


def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Migrate checkpoint state_dict to match current model architecture.

    Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
    """
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys

    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    migrated = dict(state_dict)
    migrations = []

    # SwiGLU → MirroredSwiGLU: w3 → gate_expand (dual_gate_middle upgrade)
    for key in list(unexpected):
        if ".ffn.gate_expand.weight" in key:
            new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} → {new_key}")
        if ".ffn.gate_compress.weight" in key:
            new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f"  {key} → {new_key}")

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for m in migrations:
            print(m)
        # Report remaining missing keys (freshly initialized)
        still_missing = model_keys - set(migrated.keys())
        if still_missing:
            print(f"  New parameters (freshly initialized): {len(still_missing)}")
            for k in sorted(still_missing):
                print(f"    {k}")

    return migrated


def load_model(checkpoint_path: str, device: str = "cuda"):
    """Load any circuit model from checkpoint with auto-detection."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    model_type = checkpoint.get("model_type", "standard")
    if model_type == "graft_g2lu":
        model = load_g2lu_model(checkpoint_path, device=device)
        model.eval()
        n_layers = len(model.g2lu_mlps)
        arch_name = f"G²LU Graft ({checkpoint['pretrained_name']}, {n_layers}L)"
        config = model.model.config  # HF config
        return model, config, arch_name, model_type
    elif model_type == "mirrored":
        if checkpoint["config"].get("dual_gate_middle"):
            checkpoint["config"].pop("dual_gate_middle")
        config = MirroredConfig.from_dict(checkpoint["config"])
        model = MirroredTransformer(config)
        arch_name = f"Mirrored ({model.total_virtual_layers}L)"
    else:
        config = CircuitConfig.from_dict(checkpoint["config"])
        model = CircuitTransformer(config)
        arch_name = f"Standard ({config.num_layers}L)"

    # Strip _orig_mod. prefix from torch.compile'd checkpoints
    state_dict = checkpoint["model"]
    state_dict = _migrate_state_dict(state_dict, model)
    model.load_state_dict(state_dict)

    model = model.to(device).eval()
    return model, config, arch_name, model_type


class CircuitLM(LM):
    """LM-eval wrapper for Circuit transformer family."""

    def __init__(
        self,
        checkpoint: str,
        device: str = "cuda",
        batch_size: int = 1,
        compile: bool = False,
    ):
        super().__init__()

        self.model, self.config, self.arch_name, self.model_type = load_model(
            checkpoint, device
        )
        # Keep raw reference for .generate() — torch.compile only wraps forward()
        self._raw_model = self.model
        if compile:
            self.model = torch.compile(self.model)
            print("  torch.compile: enabled")
        _ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
        _tok_name = _ckpt.get("tokenizer_name", "gpt2")
        del _ckpt
        self.tokenizer = get_tokenizer(_tok_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self._device = device
        self._batch_size = batch_size

        # Build word-position table if model uses SemRoPE
        self._word_start_table = None
        word_rope_dims = getattr(self.config, 'word_rope_dims', 0)
        if word_rope_dims == 0 and isinstance(self.config, dict):
            word_rope_dims = self.config.get('word_rope_dims', 0)
        if word_rope_dims > 0:
            self._word_start_table = build_word_start_table(
                self.tokenizer, len(self.tokenizer)
            ).to(device)
            print(f"  Word-position RoPE: {word_rope_dims} dims")

        # Count parameters
        n_params = sum(p.numel() for p in self.model.parameters())
        print(f"  Architecture: {self.arch_name}")
        print(f"  Parameters: {n_params / 1e6:.1f}M")

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return getattr(self.config, "max_seq_len", None) or getattr(self.config, "max_position_embeddings", 512)

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _model_call(self, input_ids: torch.Tensor):
        with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16, enabled=self._device != "cpu"):
            word_positions = None
            if self._word_start_table is not None:
                word_positions = compute_word_positions(input_ids, self._word_start_table)
            output = self.model(input_ids, use_cache=False, word_positions=word_positions)
            return output["logits"]

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        results = []
        for context_enc, continuation_enc in requests:
            # Truncate from the left if too long
            full_enc = context_enc + continuation_enc
            if len(full_enc) > self.max_length:
                excess = len(full_enc) - self.max_length
                context_enc = context_enc[excess:]
                full_enc = context_enc + continuation_enc

            input_ids = torch.tensor(
                [full_enc], dtype=torch.long, device=self._device
            )

            logits = self._model_call(input_ids)

            ctx_len = len(context_enc)
            cont_logits = logits[:, ctx_len - 1 : -1, :]
            cont_tokens = input_ids[:, ctx_len:]

            log_probs = F.log_softmax(cont_logits, dim=-1)
            token_log_probs = log_probs.gather(
                2, cont_tokens.unsqueeze(-1)
            ).squeeze(-1)

            total_log_prob = token_log_probs.sum().item()
            is_greedy = (cont_logits.argmax(dim=-1) == cont_tokens).all().item()

            results.append((total_log_prob, is_greedy))

        return results

    def loglikelihood(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[tuple]:
        results = []
        for request in tqdm(
            requests, desc="loglikelihood", disable=disable_tqdm
        ):
            context, continuation = request.args
            # Encode full text together to get correct tokenization,
            # then split — sentencepiece tokenizes differently at string
            # boundaries vs mid-sequence (the leading ▁ problem)
            context_enc = self.tok_encode(context)
            full_enc = self.tok_encode(context + continuation)
            continuation_enc = full_enc[len(context_enc):]
            if not continuation_enc:
                # Edge case: continuation was absorbed into context tokens
                # Fall back to encoding continuation separately
                continuation_enc = self.tok_encode(continuation)
            result = self._loglikelihood_tokens([(context_enc, continuation_enc)])
            results.append(result[0])
        return results

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        results = []
        for request in tqdm(
            requests, desc="loglikelihood_rolling", disable=disable_tqdm
        ):
            text = request.args[0]
            encoding = self.tok_encode(text)

            total_log_prob = 0.0
            max_len = self.max_length

            for i in range(0, len(encoding), max_len):
                chunk = encoding[i : i + max_len]
                input_ids = torch.tensor(
                    [chunk], dtype=torch.long, device=self._device
                )

                logits = self._model_call(input_ids)
                shift_logits = logits[:, :-1, :]
                shift_labels = input_ids[:, 1:]

                log_probs = F.log_softmax(shift_logits, dim=-1)
                token_log_probs = log_probs.gather(
                    2, shift_labels.unsqueeze(-1)
                ).squeeze(-1)

                total_log_prob += token_log_probs.sum().item()

            results.append(total_log_prob)
        return results

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        results = []
        for request in tqdm(
            requests, desc="generate_until", disable=disable_tqdm
        ):
            context = request.args[0]
            gen_kwargs = getattr(request, "kwargs", {}) or {}

            until = gen_kwargs.get("until", [self.tokenizer.eos_token])
            max_gen = gen_kwargs.get("max_gen_toks", self.max_gen_toks)

            context_enc = self.tok_encode(context)
            # Truncate context from left if needed
            if len(context_enc) > self.max_length - max_gen:
                context_enc = context_enc[-(self.max_length - max_gen) :]
            input_ids = torch.tensor(
                [context_enc], dtype=torch.long, device=self._device
            )

            if self.model_type == "graft_g2lu":
                # Use HF's native generate with KV caching — much faster than
                # manual token-by-token without cache (O(n) vs O(n²))
                with torch.no_grad():
                    output_ids = self._raw_model.generate(
                        input_ids,
                        max_new_tokens=max_gen,
                        do_sample=False,
                        use_cache=True,
                    )
                generated_text = self.tok_decode(
                    output_ids[0, input_ids.shape[1] :].tolist()
                )
            else:
                generated_ids = input_ids.clone()
                with torch.no_grad():
                    for _ in range(max_gen):
                        # Truncate if we exceed max_length
                        if generated_ids.shape[1] > self.max_length:
                            generated_ids = generated_ids[:, -self.max_length :]

                        logits = self._model_call(generated_ids)
                        next_logits = logits[:, -1, :]
                        next_token = next_logits.argmax(dim=-1, keepdim=True)
                        generated_ids = torch.cat([generated_ids, next_token], dim=1)

                        if next_token.item() == self.eot_token_id:
                            break

                        current_text = self.tok_decode(
                            generated_ids[0, len(context_enc) :].tolist()
                        )
                        if any(s in current_text for s in until):
                            break

                generated_text = self.tok_decode(
                    generated_ids[0, len(context_enc) :].tolist()
                )

            for stop in until:
                if stop in generated_text:
                    generated_text = generated_text[: generated_text.index(stop)]

            results.append(generated_text)

        return results
```
mirrored.py ADDED
@@ -0,0 +1,532 @@
```python
"""
Mirrored Transformer: Weight-sharing between expand and compress phases.

Based on the biconcave lens hypothesis from grafting research:
- Early layers expand from tokens to semantic space
- Late layers compress from semantic space back to tokens
- These phases share structural computation (W₁, W₂)
- Only the gate (semiotic filter) differs by direction

Architecture:
    y = W₂ @ (W₁ @ x ⊙ swish(W₃ @ x ⊙ swish(W₄ @ x)))

Both gates fire every pass (additive, OR-logic). W₁ computed once.
W₁, W₂ shared between mirror pairs. W₃, W₄ are dual gates.
~33% FFN parameter savings per mirrored pair vs standard SwiGLU.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass, fields

from .layers import RMSNorm, CausalAttention, SwiGLU


@dataclass
class MirroredConfig:
    """Configuration for Mirrored Transformer."""
    vocab_size: int = 50257
    hidden_size: int = 768
    num_heads: int = 12
    num_kv_heads: int | None = None   # GQA: None = same as num_heads (MHA)
    num_layers: int = 12              # effective depth (expand + middle + compress)
    n_middle: int = 2                 # unique middle layers
    max_seq_len: int = 512
    dropout: float = 0.0
    aux_skip_k: int = 0               # skip-ahead prediction distance (0 = disabled)
    aux_skip_weight: float = 0.1      # weight for auxiliary skip loss
    use_g2lu: bool = True             # G²LU nested gates (False = vanilla SwiGLU)
    word_rope_dims: int = 0           # head dims for word-position RoPE (0 = disabled)
    word_rope_base: float = 10.0      # frequency base for word-position RoPE
    embed_dim: int = 0                # factorized embedding dim (0 = use hidden_size)
    head_dim: int = 0                 # MLP head intermediate dim (0 = linear head)

    def __post_init__(self):
        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        if self.num_kv_heads is not None:
            assert self.num_heads % self.num_kv_heads == 0, \
                f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
        n_mirror_layers = self.num_layers - self.n_middle
        assert n_mirror_layers > 0, "num_layers must be greater than n_middle"
        assert n_mirror_layers % 2 == 0, "num_layers - n_middle must be even"
        self.n_mirror = n_mirror_layers // 2

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {f.name: getattr(self, f.name) for f in fields(self) if f.name != "n_mirror"}

    @classmethod
    def from_dict(cls, d: dict) -> "MirroredConfig":
        """Create from dictionary."""
        valid = {f.name for f in fields(cls)}
        filtered = {k: v for k, v in d.items() if k in valid}
        return cls(**filtered)


class MLP(nn.Module):
    """Feed-forward network with SiLU activation."""

    def __init__(self, dim, intermediate_size, dropout):
        super().__init__()
        self.up_proj = nn.Linear(dim, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(dim, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))


class MirroredSwiGLU(nn.Module):
    """SwiGLU with shared base weights and dual gates.

    Standard SwiGLU: y = W₂(silu(W₁x) ⊙ W₃x) — 3 matrices
    Mirrored SwiGLU: y = W₂(W₁x ⊙ silu(W₃x ⊙ silu(W₄x))) — 2 shared + 2 gates

    W₁ computed once, reused for both branches.
    """

    def __init__(self, hidden_size: int, intermediate_size: int | None = None,
                 gate_mode: str = 'additive', use_g2lu: bool = True):
        super().__init__()
        self.gate_mode = gate_mode
        self.use_g2lu = use_g2lu
        self._current_step = 0
        intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
        intermediate_size = ((intermediate_size + 63) // 64) * 64

        # Shared structural transform
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)

        # Gate(s)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
        if use_g2lu:
            self.w4 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.w1(x)
        if self.use_g2lu:
            g4 = F.silu(self.w4(x))
            g3 = F.silu(self.w3(x) * g4)
        else:
            g3 = F.silu(self.w3(x))
        return self.w2(hidden * g3)


class MirroredBlock(nn.Module):
    """Transformer block with shared weights for expand/compress phases.

    Each MirroredBlock is used TWICE in the forward pass:
    once during expand (building semantics) and once during compress (encoding output).

    Shared: attention weights, FFN W₁/W₂
    Separate: norms (different residual stream statistics), FFN gate
    """

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
                 max_seq_len: int = 2048,
                 dropout: float = 0.0,
                 window_size: int | None = None, gate_mode: str = 'additive',
                 word_rope_dims: int = 0, word_rope_base: float = 10.0,
                 use_g2lu: bool = True):
        super().__init__()

        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout,
                                    window_size=window_size,
                                    word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)

        # FFN with shared base + direction-specific gates
        self.ffn = MirroredSwiGLU(hidden_size, gate_mode=gate_mode, use_g2lu=use_g2lu)

        # Separate norms per direction (residual stream statistics differ)
        self.expand_attn_norm = RMSNorm(hidden_size)
        self.expand_ffn_norm = RMSNorm(hidden_size)
        self.compress_attn_norm = RMSNorm(hidden_size)
        self.compress_ffn_norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
                word_positions: torch.Tensor | None = None) -> tuple:
        # Note: both passes currently route through the compress-side norms;
        # the expand_* norms are allocated but not selected here.
        attn_norm = self.compress_attn_norm
        ffn_norm = self.compress_ffn_norm

        attn_out, new_kv = self.attn(attn_norm(x), use_cache, past_kv, word_positions=word_positions)
        x = x + attn_out
        x = x + self.ffn(ffn_norm(x))

        return x, new_kv


class MiddleBlock(nn.Module):
    """Standard transformer block for unique middle layers.

    Uses MirroredSwiGLU (dual-gate) rather than single-gate SwiGLU, giving
    the middle the same rich gating geometry as the mirror pairs.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
                 max_seq_len: int = 2048,
                 dropout: float = 0.0,
                 word_rope_dims: int = 0, word_rope_base: float = 10.0,
                 use_g2lu: bool = True):
        super().__init__()
        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout,
                                    word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = MirroredSwiGLU(hidden_size, use_g2lu=use_g2lu)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
                word_positions: torch.Tensor | None = None) -> tuple:
        attn_out, new_kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions)
        x = x + attn_out
        x = x + self.ffn(self.ffn_norm(x))
        return x, new_kv


class MirroredTransformer(nn.Module):
    """Transformer with mirrored expand/compress architecture.

    Forward pass:
    1. Embed tokens
    2. Expand phase: mirror_blocks[0..N-1]
    3. Middle: unique standard blocks
    4. Compress phase: mirror_blocks[N-1..0] (same blocks, reversed order)
    5. Norm + LM head

    For a 12-layer model with n_middle=2:
    - 5 mirror pairs (10 virtual layers) + 2 middle = 12 effective layers
    - Expand: blocks[0] → blocks[4]
    - Middle: middle[0] → middle[1]
    - Compress: blocks[4] → blocks[0]
    """

    def __init__(self, config: MirroredConfig):
        super().__init__()
        self.config = config

        # Token embeddings (optionally factorized)
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)
        # Auto-mirror factorization: head uses embed_dim for weight tying
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim

        # G²LU config (needed before projection setup)
        use_g2lu = getattr(config, 'use_g2lu', True)

        if embed_dim > 0:
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
            # G²LU gates for up-projection (consistent with mirror blocks)
            if use_g2lu:
                self.embed_g3 = nn.Linear(embed_dim, config.hidden_size, bias=False)
                self.embed_g4 = nn.Linear(embed_dim, config.hidden_size, bias=False)
            else:
                self.embed_g3 = None
                self.embed_g4 = None
        else:
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
            self.embed_g3 = None
            self.embed_g4 = None
        self.embed_scale = math.sqrt(config.hidden_size)
        self.window_sizes = [None] * config.n_mirror

        # Word-position RoPE config
        word_rope_dims = getattr(config, 'word_rope_dims', 0)
        word_rope_base = getattr(config, 'word_rope_base', 10.0)

        # Mirrored blocks (used in both expand and compress phases)
        self.mirror_blocks = nn.ModuleList([
            MirroredBlock(
                config.hidden_size, config.num_heads, config.num_kv_heads,
                config.max_seq_len,
                config.dropout,
                window_size=self.window_sizes[i],
                word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
                use_g2lu=use_g2lu,
            )
            for i in range(config.n_mirror)
        ])

        # Unique middle blocks (standard transformer, dual-gated)
        self.middle_blocks = nn.ModuleList([
            MiddleBlock(config.hidden_size, config.num_heads, config.num_kv_heads,
                        config.max_seq_len, config.dropout,
                        word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
                        use_g2lu=use_g2lu)
            for _ in range(config.n_middle)
        ])

        # Output (optionally MLP head)
        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
            # G²LU gates for down-projection
            if use_g2lu:
                self.head_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.head_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
            else:
                self.head_g3 = None
                self.head_g4 = None
        else:
            self.head_down = None
            self.head_g3 = None
            self.head_g4 = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying (when embed and lm_head dimensions match)
        _e = embed_dim if embed_dim > 0 else config.hidden_size
        _h = head_dim if head_dim > 0 else config.hidden_size
        if _e == _h:
            self.lm_head.weight = self.embed.weight

        # Auxiliary skip-ahead prediction head
        self.skip_head = None
        self.skip_head_down = None
        self.skip_g3 = None
        self.skip_g4 = None
        if config.aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
                if use_g2lu:
                    self.skip_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
                    self.skip_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @property
    def total_virtual_layers(self) -> int:
        """Total number of virtual layers in the forward pass."""
        return self.config.n_mirror * 2 + self.config.n_middle

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        B, L = input_ids.shape

        # Embed tokens (optionally factorized, with G²LU gating)
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            if self.embed_g3 is not None:
                g4 = F.silu(self.embed_g4(x))
                g3 = F.silu(self.embed_g3(x) * g4)
                x = self.embed_proj(x) * g3
            else:
                x = F.silu(self.embed_proj(x))
        x = x * self.embed_scale

        new_kv = [] if use_cache else None
        kv_idx = 0

        # === Expand phase ===
        for block in self.mirror_blocks:
            layer_past = past_kv[kv_idx] if past_kv is not None else None
            x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)
            kv_idx += 1

        # === Middle phase (unique blocks) ===
        for block in self.middle_blocks:
            layer_past = past_kv[kv_idx] if past_kv is not None else None
            x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)
            kv_idx += 1

        # === Compress phase (same blocks, reversed order) ===
        for i in reversed(range(len(self.mirror_blocks))):
            layer_past = past_kv[kv_idx] if past_kv is not None else None
            x, kv = self.mirror_blocks[i](x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)
            kv_idx += 1

        # === Output (optionally MLP head with G²LU gating) ===
        x = self.norm(x)
        if self.head_down is not None:
            if self.head_g3 is not None:
                g4 = F.silu(self.head_g4(x))
                g3 = F.silu(self.head_g3(x) * g4)
                logits = self.lm_head(self.head_down(x) * g3)
            else:
                logits = self.lm_head(F.silu(self.head_down(x)))
        else:
            logits = self.lm_head(x)

        result = {"logits": logits}

        if use_cache:
            result["past_kv"] = new_kv

        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            if self.skip_head is not None:
                skip_k = self.config.aux_skip_k
                if self.skip_head_down is not None:
                    if self.skip_g3 is not None:
                        g4 = F.silu(self.skip_g4(x))
                        g3 = F.silu(self.skip_g3(x) * g4)
                        skip_logits = self.skip_head(self.skip_head_down(x) * g3)[:, :-skip_k, :].contiguous()
                    else:
                        skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
                else:
                    skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
                skip_labels = labels[:, skip_k:].contiguous()
                aux_loss = F.cross_entropy(
                    skip_logits.view(-1, self.config.vocab_size),
                    skip_labels.view(-1),
                    ignore_index=-100,
                )
                result["aux_loss"] = aux_loss
                loss = loss + self.config.aux_skip_weight * aux_loss

            result["loss"] = loss

        return result

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        use_cache: bool = True,
        word_start_table: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Autoregressive generation with KV caching."""
        from .layers import compute_word_positions

        self.eval()
        generated = prompt_ids.clone()
        past_kv = None
        word_pos_counter = 0

        for _ in range(max_new_tokens):
            if use_cache and past_kv is not None:
                input_ids = generated[:, -1:]
                if word_start_table is not None:
                    last_token = generated[0, -1].item()
                    if word_start_table[last_token]:
                        word_pos_counter = 0
                    else:
                        word_pos_counter += 1
                    word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
                else:
                    word_positions = None
            else:
                input_ids = generated
                if word_start_table is not None:
                    word_positions = compute_word_positions(input_ids, word_start_table)
                else:
                    word_positions = None

            output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
            logits = output["logits"][:, -1, :]

            if use_cache:
                past_kv = output["past_kv"]

            if temperature > 0:
                logits = logits / temperature

                if top_k > 0:
                    top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    min_top_k = top_k_vals[:, -1].unsqueeze(-1)
                    logits = torch.where(logits < min_top_k, float("-inf"), logits)

                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumsum_probs > top_p
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits = logits.masked_fill(indices_to_remove, float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            if generated.size(1) >= self.config.max_seq_len:
                break

        return generated


def count_mirrored_parameters(model: MirroredTransformer) -> dict:
    """Count parameters with breakdown by component."""
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Unique params (not double-counted from weight tying)
    unique = sum(p.numel() for p in set(p for p in model.parameters() if p.requires_grad))

    mirror_params = sum(p.numel() for p in model.mirror_blocks.parameters())
    middle_params = sum(p.numel() for p in model.middle_blocks.parameters())
    embed_params = model.embed.weight.numel()
    if model.embed_proj is not None:
        embed_params += model.embed_proj.weight.numel()
    head_params = 0
    if model.head_down is not None:
        head_params += model.head_down.weight.numel()
    head_params += model.lm_head.weight.numel()

    # Break down mirror block into shared vs direction-specific
    shared_attn = 0
    shared_ffn_base = 0
    gate_params = 0
    norm_params = 0

    for block in model.mirror_blocks:
        shared_attn += sum(p.numel() for p in block.attn.parameters())
        shared_ffn_base += block.ffn.w1.weight.numel() + block.ffn.w2.weight.numel()
        gate_params += block.ffn.w3.weight.numel()
        if hasattr(block.ffn, 'w4'):
            gate_params += block.ffn.w4.weight.numel()
        norm_params += sum(p.numel() for n, p in block.named_parameters() if 'norm' in n)

    return {
        "total": total,
        "unique": unique,
        "mirror_blocks": mirror_params,
        "middle_blocks": middle_params,
        "embedding": embed_params,
        "head": head_params,
        "shared_attention": shared_attn,
        "shared_ffn_base": shared_ffn_base,
        "direction_gates": gate_params,
        "norms": norm_params,
    }
```
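The "~33% FFN parameter savings per mirrored pair" figure is plain arithmetic once the sharing scheme is clear: one `MirroredSwiGLU` (shared `w1`, `w2` plus gates `w3`, `w4`) serves two virtual layers, where two independent three-matrix SwiGLU FFNs would otherwise be needed. A minimal sketch using the default sizes; the helper name `ffn_intermediate` is mine, not from the repo:

```python
def ffn_intermediate(hidden_size: int) -> int:
    # Same rounding as MirroredSwiGLU: 8/3 expansion, padded to a multiple of 64.
    inter = int(hidden_size * 8 / 3)
    return ((inter + 63) // 64) * 64

h = 768
f = ffn_intermediate(h)        # 2048 for h = 768

# Two standard SwiGLU layers: 3 weight matrices of shape (h, f) each.
standard_pair = 2 * 3 * h * f

# One mirrored pair: w1 + w2 shared across both passes, plus gates w3 + w4.
mirrored_pair = 4 * h * f

savings = 1 - mirrored_pair / standard_pair
print(f"{standard_pair:,} -> {mirrored_pair:,} FFN params per pair ({savings:.0%} saved)")
```

The saving is exactly 1/3 regardless of hidden size, since four matrices of the same shape replace six; attention weights are shared by reuse as well, but are not counted in the docstring's FFN figure.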
model.py
ADDED
@@ -0,0 +1,357 @@
```python
"""
Circuit Transformer: Minimal transformer for semantic circuitry experiments.

Follows patterns from shimmer/lira/gpt.py with extension hooks for future work.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .config import CircuitConfig
from .layers import RMSNorm, RotaryEmbedding, CausalAttention, SwiGLU, WordPositionRoPE


class TransformerBlock(nn.Module):
    """Pre-norm transformer block with causal attention."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout, window_size,
                                    word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size)

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        # Attention with residual
        attn_out, new_kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions)
        x = x + attn_out

        # FFN with residual
        x = x + self.ffn(self.ffn_norm(x))

        return x, new_kv


class CircuitTransformer(nn.Module):
    """
    Minimal transformer for semantic circuitry experiments.

    Features:
    - Standard GPT-style architecture (RMSNorm, RoPE, SwiGLU, causal attention)
    - Weight tying (embed = lm_head)
    - Extension hooks for future work:
      - freeze_layers() / unfreeze_layers() for progressive training
      - get_layer_outputs() for interpretability
      - window_size param for sliding window attention
    """

    def __init__(self, config: CircuitConfig):
        super().__init__()
        self.config = config

        # Token embeddings (optionally factorized)
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)
        # Auto-mirror factorization: head uses embed_dim for weight tying
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim

        if embed_dim > 0:
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
        else:
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
        self.embed_scale = math.sqrt(config.hidden_size)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_heads,
                getattr(config, 'num_kv_heads', None),
                config.max_seq_len,
                config.dropout,
                word_rope_dims=getattr(config, 'word_rope_dims', 0),
                word_rope_base=getattr(config, 'word_rope_base', 10.0),
            )
            for _ in range(config.num_layers)
        ])

        # Output (optionally MLP head)
        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
        else:
            self.head_down = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying (when embed and lm_head dimensions match)
        _e = embed_dim if embed_dim > 0 else config.hidden_size
        _h = head_dim if head_dim > 0 else config.hidden_size
        if _e == _h:
            self.lm_head.weight = self.embed.weight

        # Auxiliary skip-ahead prediction head
        self.skip_head = None
        self.skip_head_down = None
        aux_skip_k = getattr(config, 'aux_skip_k', 0)
        if aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Track frozen layers
        self._frozen_layers: set[int] = set()

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        """
        Forward pass.

        Args:
            input_ids: [B, L] token IDs
            labels: [B, L] target token IDs (for loss computation)
            use_cache: Whether to return KV cache for generation
            past_kv: Previous KV cache
            word_positions: [B, L] position within word (from compute_word_positions)

        Returns:
            dict with 'logits', optionally 'loss' and 'past_kv'
        """
        B, L = input_ids.shape

        # Embed tokens (optionally factorized)
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            x = F.silu(self.embed_proj(x))
        x = x * self.embed_scale

        # Process through layers
        new_kv = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            layer_past = past_kv[i] if past_kv is not None else None
            x, kv = layer(x, use_cache, layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)

        # Output (optionally MLP head)
        x = self.norm(x)
        if self.head_down is not None:
            logits = self.lm_head(F.silu(self.head_down(x)))
        else:
            logits = self.lm_head(x)

        result = {"logits": logits}

        if use_cache:
            result["past_kv"] = new_kv

        # Compute loss if labels provided
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Auxiliary skip-ahead prediction
            if self.skip_head is not None:
                skip_k = getattr(self.config, 'aux_skip_k', 0)
                skip_weight = getattr(self.config, 'aux_skip_weight', 0.1)
                if self.skip_head_down is not None:
                    skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
                else:
                    skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
                skip_labels = labels[:, skip_k:].contiguous()
                aux_loss = F.cross_entropy(
                    skip_logits.view(-1, self.config.vocab_size),
                    skip_labels.view(-1),
                    ignore_index=-100,
                )
                result["aux_loss"] = aux_loss
                loss = loss + skip_weight * aux_loss

            result["loss"] = loss

        return result

    # === Extension hooks for future experiments ===

    def freeze_layers(self, indices: list[int]) -> None:
        """Freeze specific layers (stop gradients)."""
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = False
                self._frozen_layers.add(idx)

    def unfreeze_layers(self, indices: list[int] | None = None) -> None:
        """Unfreeze specific layers (or all if indices=None)."""
        if indices is None:
            indices = list(self._frozen_layers)
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = True
                self._frozen_layers.discard(idx)

    def get_layer_outputs(self, input_ids: torch.Tensor) -> list[torch.Tensor]:
```
|
| 239 |
+
"""Get intermediate outputs from each layer for interpretability."""
|
| 240 |
+
outputs = []
|
| 241 |
+
x = self.embed(input_ids)
|
| 242 |
+
if self.embed_proj is not None:
|
| 243 |
+
x = F.silu(self.embed_proj(x))
|
| 244 |
+
x = x * self.embed_scale
|
| 245 |
+
|
| 246 |
+
for layer in self.layers:
|
| 247 |
+
x, _ = layer(x, use_cache=False, past_kv=None)
|
| 248 |
+
outputs.append(x.clone())
|
| 249 |
+
|
| 250 |
+
return outputs
|
| 251 |
+
|
| 252 |
+
@torch.no_grad()
|
| 253 |
+
def generate(
|
| 254 |
+
self,
|
| 255 |
+
prompt_ids: torch.Tensor,
|
| 256 |
+
max_new_tokens: int = 50,
|
| 257 |
+
temperature: float = 0.8,
|
| 258 |
+
top_k: int = 50,
|
| 259 |
+
top_p: float = 0.9,
|
| 260 |
+
use_cache: bool = True,
|
| 261 |
+
word_start_table: torch.Tensor | None = None,
|
| 262 |
+
) -> torch.Tensor:
|
| 263 |
+
"""
|
| 264 |
+
Autoregressive generation with KV caching.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
prompt_ids: [B, L] prompt token IDs
|
| 268 |
+
max_new_tokens: Maximum tokens to generate
|
| 269 |
+
temperature: Sampling temperature
|
| 270 |
+
top_k: Top-k filtering
|
| 271 |
+
top_p: Nucleus sampling threshold
|
| 272 |
+
use_cache: Use KV cache for faster generation
|
| 273 |
+
word_start_table: [vocab_size] bool tensor for word-position RoPE
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
[B, L + max_new_tokens] generated token IDs
|
| 277 |
+
"""
|
| 278 |
+
from .layers import compute_word_positions
|
| 279 |
+
|
| 280 |
+
self.eval()
|
| 281 |
+
generated = prompt_ids.clone()
|
| 282 |
+
past_kv = None
|
| 283 |
+
word_pos_counter = 0 # Track word position during cached generation
|
| 284 |
+
|
| 285 |
+
for _ in range(max_new_tokens):
|
| 286 |
+
# Get input (full sequence or just last token with cache)
|
| 287 |
+
if use_cache and past_kv is not None:
|
| 288 |
+
input_ids = generated[:, -1:]
|
| 289 |
+
# Compute word position for the single new token
|
| 290 |
+
if word_start_table is not None:
|
| 291 |
+
last_token = generated[0, -1].item()
|
| 292 |
+
if word_start_table[last_token]:
|
| 293 |
+
word_pos_counter = 0
|
| 294 |
+
else:
|
| 295 |
+
word_pos_counter += 1
|
| 296 |
+
word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
|
| 297 |
+
else:
|
| 298 |
+
word_positions = None
|
| 299 |
+
else:
|
| 300 |
+
input_ids = generated
|
| 301 |
+
# Compute word positions for full sequence
|
| 302 |
+
if word_start_table is not None:
|
| 303 |
+
word_positions = compute_word_positions(input_ids, word_start_table)
|
| 304 |
+
else:
|
| 305 |
+
word_positions = None
|
| 306 |
+
|
| 307 |
+
# Forward pass
|
| 308 |
+
output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
|
| 309 |
+
logits = output["logits"][:, -1, :] # Last position
|
| 310 |
+
|
| 311 |
+
if use_cache:
|
| 312 |
+
past_kv = output["past_kv"]
|
| 313 |
+
|
| 314 |
+
# Apply temperature
|
| 315 |
+
if temperature > 0:
|
| 316 |
+
logits = logits / temperature
|
| 317 |
+
|
| 318 |
+
# Top-k filtering
|
| 319 |
+
if top_k > 0:
|
| 320 |
+
top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 321 |
+
min_top_k = top_k_vals[:, -1].unsqueeze(-1)
|
| 322 |
+
logits = torch.where(logits < min_top_k, float("-inf"), logits)
|
| 323 |
+
|
| 324 |
+
# Top-p (nucleus) filtering
|
| 325 |
+
if top_p < 1.0:
|
| 326 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 327 |
+
cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 328 |
+
|
| 329 |
+
# Remove tokens with cumulative prob above threshold
|
| 330 |
+
sorted_indices_to_remove = cumsum_probs > top_p
|
| 331 |
+
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
| 332 |
+
sorted_indices_to_remove[:, 0] = False
|
| 333 |
+
|
| 334 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 335 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 336 |
+
)
|
| 337 |
+
logits = logits.masked_fill(indices_to_remove, float("-inf"))
|
| 338 |
+
|
| 339 |
+
# Sample
|
| 340 |
+
probs = F.softmax(logits, dim=-1)
|
| 341 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 342 |
+
else:
|
| 343 |
+
# Greedy
|
| 344 |
+
next_token = logits.argmax(dim=-1, keepdim=True)
|
| 345 |
+
|
| 346 |
+
generated = torch.cat([generated, next_token], dim=1)
|
| 347 |
+
|
| 348 |
+
# Stop if max length reached
|
| 349 |
+
if generated.size(1) >= self.config.max_seq_len:
|
| 350 |
+
break
|
| 351 |
+
|
| 352 |
+
return generated
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def count_parameters(model: CircuitTransformer) -> int:
|
| 356 |
+
"""Count trainable parameters."""
|
| 357 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
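The nucleus (top-p) step of the sampling pipeline in `generate` can be exercised in isolation. Below is a minimal standalone sketch of the same masking logic on a plain Python list: the token that crosses the cumulative-probability threshold is kept, mirroring the shifted mask (`sorted_indices_to_remove[:, 0] = False`) in the tensor version. `top_p_filter` is an illustrative name, not part of the module.

```python
import math

def top_p_filter(logits: list[float], top_p: float) -> list[float]:
    """Set logits outside the nucleus to -inf; the smallest set of
    highest-probability tokens whose cumulative probability exceeds
    top_p is kept (the crossing token included)."""
    # Softmax over the logits (shifted by the max for stability)
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Visit tokens in order of decreasing probability
    order = sorted(range(len(logits)), key=lambda i: -probs[i])

    # Accumulate probability mass; stop once it exceeds top_p,
    # keeping the token that crossed the threshold
    keep, cum = set(), 0.0
    for idx in order:
        keep.add(idx)
        cum += probs[idx]
        if cum > top_p:
            break

    return [l if i in keep else float("-inf") for i, l in enumerate(logits)]
```

With logits `[5.0, 1.0, 0.0, -1.0]` the first token carries about 97% of the mass, so `top_p=0.9` keeps only that token, while `top_p=0.99` also admits the second.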
scripts/__init__.py
ADDED
File without changes
scripts/representation_analysis.py
ADDED
@@ -0,0 +1,1014 @@
```python
#!/usr/bin/env python3
"""
Representation analysis: CKA and Logit Lens for Prisma / Circuit Transformer.

CKA (Centered Kernel Alignment):
    Measures representational similarity between all layer pairs.
    Produces a heatmap revealing mirror symmetry, phase transitions,
    and cross-model alignment.

Logit Lens:
    Projects intermediate representations to vocabulary space at every layer.
    Reveals what the model "thinks" at each processing stage -- from raw
    tokens through the semantic bottleneck back to specific predictions.

Also computes representation drift (cosine similarity between consecutive layers).

Usage:
    # Full analysis (CKA + logit lens)
    python -m circuits.scripts.representation_analysis \\
        --checkpoint path/to/checkpoint.pt \\
        --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train

    # Cross-model CKA
    python -m circuits.scripts.representation_analysis \\
        --checkpoint path/to/prisma.pt --hf-model gpt2-medium \\
        --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train

    # CKA only (skip logit lens)
    python -m circuits.scripts.representation_analysis \\
        --checkpoint path/to/checkpoint.pt \\
        --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train \\
        --no-logit-lens
"""

import argparse
import json
import sys
import os
from pathlib import Path
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
    """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type)."""
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from circuits.config import CircuitConfig
    from circuits.model import CircuitTransformer
    from circuits.mirrored import MirroredConfig, MirroredTransformer

    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "standard")
    config_dict = ckpt.get("config", {})

    if model_type == "mirrored":
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config)
    else:
        config = CircuitConfig.from_dict(config_dict)
        model = CircuitTransformer(config)

    state_dict = ckpt["model"]
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    model.to(device).eval()

    return model, config_dict, model_type


def load_hf_model(model_name: str, device: str = "cpu"):
    """Load a HuggingFace causal LM."""
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True)
    model.to(device).eval()
    return model


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def load_data(data_source: str, tokenizer_name: str, num_samples: int = 32,
              context_length: int = 512, device: str = "cpu"):
    """Load tokenized data. Returns (input_ids, tokenizer).

    Supports:
    - Memmap .bin files (from circuits training cache)
    - hf:dataset:config:split (streaming from HuggingFace)
    - Plain text files
    """
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from circuits.data import get_tokenizer

    tokenizer = get_tokenizer(tokenizer_name)

    # Memmap binary file (already tokenized)
    if data_source.endswith(".bin"):
        import struct
        with open(data_source, 'rb') as f:
            n_chunks, seq_len = struct.unpack('II', f.read(8))
        data = np.memmap(data_source, dtype=np.int32, mode='r',
                         offset=8, shape=(n_chunks, seq_len))
        n = min(num_samples, n_chunks)
        # Slice to requested context length
        cl = min(context_length, seq_len)
        input_ids = torch.from_numpy(data[:n, :cl].copy()).long().to(device)
        return input_ids, tokenizer

    # HuggingFace dataset
    if data_source.startswith("hf:"):
        from datasets import load_dataset
        parts = data_source[3:].split(":")
        ds_name = parts[0]
        ds_config = parts[1] if len(parts) > 1 else None
        ds_split = parts[2] if len(parts) > 2 else "train"
        dataset = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
        all_ids = []
        for item in dataset:
            text = item.get("text", "")
            if len(text) < 100:
                continue
            ids = tokenizer.encode(text)
            if len(ids) >= context_length:
                all_ids.append(ids[:context_length])
            if len(all_ids) >= num_samples:
                break
        if not all_ids:
            return None, tokenizer
        return torch.tensor(all_ids, device=device), tokenizer

    # Plain text file
    with open(data_source) as f:
        texts = [line.strip() for line in f if len(line.strip()) > 100]
    all_ids = []
    for text in texts:
        ids = tokenizer.encode(text)
        if len(ids) >= context_length:
            all_ids.append(ids[:context_length])
        if len(all_ids) >= num_samples:
            break
    if not all_ids:
        return None, tokenizer
    return torch.tensor(all_ids, device=device), tokenizer


def tokenize_for_hf(texts: list, model_name: str, context_length: int = 512,
                    device: str = "cpu"):
    """Tokenize texts for an HF model. Returns (input_ids, tokenizer)."""
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              use_fast=False,
                                              trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    all_ids = []
    for text in texts:
        ids = tokenizer.encode(text, max_length=context_length, truncation=True)
        if len(ids) >= context_length:
            all_ids.append(ids[:context_length])
        elif len(ids) > 32:
            all_ids.append(ids + [tokenizer.eos_token_id] * (context_length - len(ids)))

    if not all_ids:
        return None, tokenizer

    return torch.tensor(all_ids, device=device), tokenizer


# ---------------------------------------------------------------------------
# Activation collection
# ---------------------------------------------------------------------------

def collect_mirrored_activations(model, input_ids, word_positions=None):
    """Collect activations from MirroredTransformer at every processing stage."""
    activations = OrderedDict()

    with torch.no_grad():
        x = model.embed(input_ids)
        if model.embed_proj is not None:
            if model.embed_g3 is not None:
                g4 = F.silu(model.embed_g4(x))
                g3 = F.silu(model.embed_g3(x) * g4)
                x = model.embed_proj(x) * g3
            else:
                x = F.silu(model.embed_proj(x))
        x = x * model.embed_scale
        activations["embedding"] = x.detach().cpu()

        for i, block in enumerate(model.mirror_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"expand_{i}"] = x.detach().cpu()

        for i, block in enumerate(model.middle_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"middle_{i}"] = x.detach().cpu()

        for i in reversed(range(len(model.mirror_blocks))):
            x, _ = model.mirror_blocks[i](x, word_positions=word_positions)
            compress_idx = len(model.mirror_blocks) - 1 - i
            activations[f"compress_{compress_idx}"] = x.detach().cpu()

        x = model.norm(x)
        activations["final_norm"] = x.detach().cpu()

    return activations


def collect_standard_activations(model, input_ids, word_positions=None):
    """Collect activations from standard CircuitTransformer."""
    activations = OrderedDict()

    with torch.no_grad():
        x = model.embed(input_ids)
        if model.embed_proj is not None:
            x = F.silu(model.embed_proj(x))
        x = x * model.embed_scale
        activations["embedding"] = x.detach().cpu()

        for i, layer in enumerate(model.layers):
            x, _ = layer(x, word_positions=word_positions)
            activations[f"layer_{i}"] = x.detach().cpu()

        x = model.norm(x)
        activations["final_norm"] = x.detach().cpu()

    return activations


def collect_hf_activations(model, input_ids):
    """Hook-based activation collection for HuggingFace models."""
    activations = OrderedDict()
    hooks = []

    if hasattr(model, 'transformer'):
        # GPT-2 style
        blocks = model.transformer.h
        embed = model.transformer.wte
        final_norm = model.transformer.ln_f
    elif hasattr(model, 'model'):
        # Llama / Mistral style
        blocks = model.model.layers
        embed = model.model.embed_tokens
        final_norm = model.model.norm
    else:
        raise ValueError(f"Unsupported HF model: {type(model)}")

    def make_hook(name):
        def hook_fn(module, input, output):
            out = output[0] if isinstance(output, tuple) else output
            activations[name] = out.detach().cpu()
        return hook_fn

    hooks.append(embed.register_forward_hook(make_hook("embedding")))
    for i, block in enumerate(blocks):
        hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))
    hooks.append(final_norm.register_forward_hook(make_hook("final_norm")))

    with torch.no_grad():
        model(input_ids)

    for h in hooks:
        h.remove()

    return activations


def collect_activations(model, model_type, config_dict, input_ids, device):
    """Dispatch to the right collector based on model type."""
    word_positions = None
    word_rope_dims = config_dict.get("word_rope_dims", 0) if config_dict else 0

    if word_rope_dims > 0 and model_type in ("standard", "mirrored"):
        sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
        from circuits.data import get_tokenizer
        from circuits.layers import build_word_start_table, compute_word_positions
        tokenizer_name = config_dict.get("tokenizer_name", "gpt2")
        # Try to get tokenizer from the model's config
        tokenizer = get_tokenizer(tokenizer_name)
        word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
        word_positions = compute_word_positions(input_ids, word_start_table)

    if model_type == "mirrored":
        return collect_mirrored_activations(model, input_ids, word_positions)
    elif model_type == "standard":
        return collect_standard_activations(model, input_ids, word_positions)
    else:
        return collect_hf_activations(model, input_ids)


# ---------------------------------------------------------------------------
# Linear CKA
# ---------------------------------------------------------------------------

def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Compute linear CKA between two [N, D] representation matrices.

    CKA(X, Y) = ||Yc^T Xc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)
    """
    X = X.float()
    Y = Y.float()

    # Center
    X = X - X.mean(0, keepdim=True)
    Y = Y - Y.mean(0, keepdim=True)

    N = X.shape[0]

    if N < min(X.shape[1], Y.shape[1]):
        # Kernel formulation (N < D): K=XX^T, L=YY^T -- [N,N] matrices
        K = X @ X.T
        L = Y @ Y.T
        numerator = (K * L).sum()
        denominator = torch.sqrt((K * K).sum() * (L * L).sum())
    else:
        # Feature formulation (D <= N)
        XtY = X.T @ Y
        XtX = X.T @ X
        YtY = Y.T @ Y
        numerator = (XtY * XtY).sum()
        denominator = torch.sqrt((XtX * XtX).sum() * (YtY * YtY).sum())

    if denominator < 1e-10:
        return 0.0

    return (numerator / denominator).item()
```
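Two properties make linear CKA a better cross-layer comparison than raw cosine similarity: CKA of a representation with itself is exactly 1, and the score is invariant to orthogonal transforms (and rescalings) of the feature axes. A standalone NumPy sketch of the feature-space formulation makes these easy to check; `linear_cka_np` is an illustrative name, not part of this script.

```python
import numpy as np

def linear_cka_np(X: np.ndarray, Y: np.ndarray) -> float:
    """NumPy re-implementation of the feature-space formulation:
    ||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)."""
    X = X - X.mean(0, keepdims=True)  # center columns
    Y = Y - Y.mean(0, keepdims=True)
    num = np.linalg.norm(X.T @ Y, "fro") ** 2
    den = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
    return float(num / den) if den > 1e-10 else 0.0

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 8))

# CKA of a representation with itself is 1...
assert abs(linear_cka_np(X, X) - 1.0) < 1e-6

# ...and invariant to an orthogonal change of feature basis.
Q, _ = np.linalg.qr(rng.standard_normal((8, 8)))
assert abs(linear_cka_np(X, X @ Q) - 1.0) < 1e-6
```

The orthogonal-invariance check is why two layers can score high even when their individual neuron axes are completely rotated relative to each other.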
def compute_cka_matrix(activations: dict, subsample: int = 4) -> tuple:
|
| 345 |
+
"""Compute CKA between all layer pairs. Returns (cka_matrix, layer_names)."""
|
| 346 |
+
names = list(activations.keys())
|
| 347 |
+
n_layers = len(names)
|
| 348 |
+
|
| 349 |
+
# Flatten and subsample: [B, L, D] -> [N, D]
|
| 350 |
+
flat_acts = {}
|
| 351 |
+
for name, act in activations.items():
|
| 352 |
+
act_sub = act[:, ::subsample, :]
|
| 353 |
+
flat_acts[name] = act_sub.reshape(-1, act_sub.shape[-1])
|
| 354 |
+
|
| 355 |
+
cka_matrix = np.zeros((n_layers, n_layers))
|
| 356 |
+
|
| 357 |
+
for i in range(n_layers):
|
| 358 |
+
cka_matrix[i, i] = 1.0
|
| 359 |
+
for j in range(i + 1, n_layers):
|
| 360 |
+
cka_val = linear_cka(flat_acts[names[i]], flat_acts[names[j]])
|
| 361 |
+
cka_matrix[i, j] = cka_val
|
| 362 |
+
cka_matrix[j, i] = cka_val
|
| 363 |
+
if (i + 1) % 5 == 0 or i == n_layers - 1:
|
| 364 |
+
print(f" CKA: {i+1}/{n_layers} rows computed")
|
| 365 |
+
|
| 366 |
+
return cka_matrix, names
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def compute_cross_model_cka(acts_a: dict, acts_b: dict) -> tuple:
|
| 370 |
+
"""Cross-model CKA using sample-level (avg-pooled) representations."""
|
| 371 |
+
names_a = list(acts_a.keys())
|
| 372 |
+
names_b = list(acts_b.keys())
|
| 373 |
+
|
| 374 |
+
def pool(activations):
|
| 375 |
+
return {name: act.mean(dim=1) for name, act in activations.items()}
|
| 376 |
+
|
| 377 |
+
pooled_a = pool(acts_a)
|
| 378 |
+
pooled_b = pool(acts_b)
|
| 379 |
+
|
| 380 |
+
# Ensure same number of samples
|
| 381 |
+
n_samples = min(
|
| 382 |
+
next(iter(pooled_a.values())).shape[0],
|
| 383 |
+
next(iter(pooled_b.values())).shape[0]
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
cka_matrix = np.zeros((len(names_a), len(names_b)))
|
| 387 |
+
|
| 388 |
+
for i, na in enumerate(names_a):
|
| 389 |
+
for j, nb in enumerate(names_b):
|
| 390 |
+
cka_matrix[i, j] = linear_cka(pooled_a[na][:n_samples], pooled_b[nb][:n_samples])
|
| 391 |
+
if (i + 1) % 5 == 0 or i == len(names_a) - 1:
|
| 392 |
+
print(f" Cross-CKA: {i+1}/{len(names_a)} rows computed")
|
| 393 |
+
|
| 394 |
+
return cka_matrix, names_a, names_b
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# ---------------------------------------------------------------------------
|
| 398 |
+
# Logit Lens
|
| 399 |
+
# ---------------------------------------------------------------------------
|
| 400 |
+
|
| 401 |
+
def get_unembed_components(model, model_type):
|
| 402 |
+
"""Extract (norm_module, unembed_weight) for logit lens projection."""
|
| 403 |
+
if model_type in ("standard", "mirrored"):
|
| 404 |
+
return model.norm, model.embed.weight
|
| 405 |
+
elif hasattr(model, 'transformer'):
|
| 406 |
+
return model.transformer.ln_f, model.transformer.wte.weight
|
| 407 |
+
elif hasattr(model, 'model'):
|
| 408 |
+
return model.model.norm, model.model.embed_tokens.weight
|
| 409 |
+
else:
|
| 410 |
+
raise ValueError(f"Unsupported model: {type(model)}")
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def compute_logit_lens(activations: dict, norm: nn.Module, unembed_weight: torch.Tensor,
                       labels: torch.Tensor, device: str = "cpu",
                       chunk_size: int = 2048) -> OrderedDict:
    """Compute logit lens statistics at every layer.

    Projects intermediate hidden states through final norm + unembedding.
    Computes entropy, top-1 probability, correct token rank, and
    agreement with the final layer's predictions.

    Args:
        activations: OrderedDict[name] = [B, L, D]
        norm: final layer norm module
        unembed_weight: [V, D] unembedding matrix
        labels: [B, L-1] next-token labels (input_ids[:, 1:])
        device: computation device
        chunk_size: number of positions per batch for projection

    Returns:
        OrderedDict[name] = {entropy, top1_prob, correct_rank, ...}
    """
    names = list(activations.keys())
    final_name = names[-1]  # "final_norm"
    results = OrderedDict()

    unembed = unembed_weight.to(device)
    norm_mod = norm.to(device)
    labels_flat = labels.reshape(-1).to(device)

    def process_layer(name, act, apply_norm=True):
        """Project one layer's activations and compute all metrics."""
        B, L, D = act.shape
        flat = act[:, :-1, :].reshape(-1, D)  # [B*(L-1), D]
        N = flat.shape[0]

        all_entropy = []
        all_top1_prob = []
        all_correct_rank = []
        all_top1_idx = []

        for start in range(0, N, chunk_size):
            end = min(start + chunk_size, N)
            chunk = flat[start:end].to(device)
            chunk_labels = labels_flat[start:end]

            if apply_norm:
                chunk = norm_mod(chunk)

            logits = chunk @ unembed.T  # [cs, V]
            log_probs = F.log_softmax(logits, dim=-1)
            probs = log_probs.exp()

            # Entropy
            entropy = -(probs * log_probs).sum(dim=-1)
            all_entropy.append(entropy.cpu())

            # Top-1 probability
            top1_prob = probs.max(dim=-1).values
            all_top1_prob.append(top1_prob.cpu())

            # Correct token rank
            correct_logits = logits.gather(1, chunk_labels.unsqueeze(1))
            rank = (logits > correct_logits).sum(dim=-1) + 1
            all_correct_rank.append(rank.cpu())

            # Top-1 index
            all_top1_idx.append(logits.argmax(dim=-1).cpu())

        entropy_t = torch.cat(all_entropy)
        top1_t = torch.cat(all_top1_prob)
        rank_t = torch.cat(all_correct_rank).float()
        top1_idx = torch.cat(all_top1_idx)

        return {
            "entropy": entropy_t.mean().item(),
            "entropy_std": entropy_t.std().item(),
            "top1_prob": top1_t.mean().item(),
            "correct_rank_mean": rank_t.mean().item(),
            "correct_rank_median": rank_t.median().item(),
            "log_rank_mean": rank_t.log().mean().item(),
            "_top1_idx": top1_idx,
        }

    # Process all layers
    for name in names:
        is_final = (name == final_name)
        act = activations[name]
        stats = process_layer(name, act, apply_norm=not is_final)
        results[name] = stats
        print(f"  Logit lens: {name:20s} entropy={stats['entropy']:.2f} "
              f"top1={stats['top1_prob']:.4f} rank={stats['correct_rank_median']:.0f}")

    # Compute agreement with final layer
    final_top1 = results[final_name]["_top1_idx"]
    for name in names:
        layer_top1 = results[name]["_top1_idx"]
        agreement = (layer_top1 == final_top1).float().mean().item()
        results[name]["agreement_with_final"] = agreement

    # Clean up internal tensors
    for name in names:
        del results[name]["_top1_idx"]

    return results


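The rank computation in `process_layer` above avoids a full sort of the vocabulary: the rank of the correct token is one plus the number of logits strictly greater than it. A self-contained toy check of that counting trick (values here are made up for illustration):

```python
import torch

# Toy logits over a vocabulary of 5 tokens at 2 positions.
logits = torch.tensor([[2.0, 5.0, 1.0, 3.0, 0.0],
                       [1.0, 0.0, 4.0, 2.0, 3.0]])
labels = torch.tensor([3, 2])  # correct token ids per position

# Same counting trick as process_layer: rank = 1 + #{logits > correct logit}.
correct = logits.gather(1, labels.unsqueeze(1))
rank = (logits > correct).sum(dim=-1) + 1
print(rank.tolist())  # -> [2, 1]
```

Position 0's correct token (logit 3.0) is beaten only by 5.0, so its rank is 2; position 1's correct token has the largest logit, so its rank is 1.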
# ---------------------------------------------------------------------------
# Representation drift
# ---------------------------------------------------------------------------

def compute_drift(activations: dict) -> OrderedDict:
    """Cosine similarity between consecutive layers' representations."""
    names = list(activations.keys())
    drift = OrderedDict()

    for i in range(1, len(names)):
        prev = activations[names[i - 1]]
        curr = activations[names[i]]

        # Flatten to [N, D]
        prev_flat = prev.reshape(-1, prev.shape[-1])
        curr_flat = curr.reshape(-1, curr.shape[-1])

        # Mean cosine similarity
        cos = F.cosine_similarity(prev_flat, curr_flat, dim=-1)
        drift[names[i]] = {
            "cos_sim_mean": cos.mean().item(),
            "cos_sim_std": cos.std().item(),
            "l2_distance": (curr_flat - prev_flat).norm(dim=-1).mean().item(),
        }

    return drift


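`compute_drift` above flattens each pair of consecutive layers to `[N, D]` and takes a per-position cosine similarity. A tiny worked example of that computation on hand-built activations:

```python
import torch
import torch.nn.functional as F

# Two "layers" of activations with shape [B, L, D] = [1, 2, 3].
prev = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
curr = torch.tensor([[[2.0, 0.0, 0.0], [0.0, 0.0, 1.0]]])

# Flatten to [N, D] exactly as compute_drift does.
prev_flat = prev.reshape(-1, prev.shape[-1])
curr_flat = curr.reshape(-1, curr.shape[-1])
cos = F.cosine_similarity(prev_flat, curr_flat, dim=-1)

# First position only rescales (cos = 1); second rotates 90 degrees (cos = 0).
print(cos.tolist())
```

Scaling a vector leaves its cosine similarity at 1, which is why `l2_distance` is tracked alongside it: the two metrics separate direction changes from magnitude changes.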
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def _phase_color(name):
    """Return color based on layer phase."""
    if "expand" in name:
        return "steelblue"
    elif "middle" in name:
        return "goldenrod"
    elif "compress" in name:
        return "coral"
    elif "embedding" in name:
        return "gray"
    elif "final" in name:
        return "gray"
    else:
        return "mediumpurple"


def _layer_sort_key(name):
    """Sort key for processing order."""
    order = {"embedding": -1, "final_norm": 9999}
    if name in order:
        return order[name]
    parts = name.split("_")
    phase = parts[0]
    idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
    phase_offset = {"expand": 0, "middle": 1000, "compress": 2000, "layer": 0}
    return phase_offset.get(phase, 3000) + idx


def _short_name(name):
    """Shorten layer name for plot labels."""
    if name == "embedding":
        return "emb"
    if name == "final_norm":
        return "out"
    parts = name.split("_")
    if parts[0] == "expand":
        return f"E{parts[1]}"
    elif parts[0] == "middle":
        return f"M{parts[1]}"
    elif parts[0] == "compress":
        return f"C{parts[1]}"
    elif parts[0] == "layer":
        return f"L{parts[1]}"
    return name[:6]


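The sort key above encodes the mirrored architecture's processing order: `embedding` first, then expand, middle, and compress phases (offset by 0/1000/2000), with `final_norm` last. A quick standalone check of the intended ordering, with the helper re-declared so the snippet runs on its own:

```python
def layer_sort_key(name):
    """Same logic as _layer_sort_key above, re-declared for a standalone check."""
    order = {"embedding": -1, "final_norm": 9999}
    if name in order:
        return order[name]
    parts = name.split("_")
    phase = parts[0]
    idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
    phase_offset = {"expand": 0, "middle": 1000, "compress": 2000, "layer": 0}
    return phase_offset.get(phase, 3000) + idx

names = ["compress_1", "embedding", "middle_0", "expand_2", "final_norm"]
print(sorted(names, key=layer_sort_key))
# -> ['embedding', 'expand_2', 'middle_0', 'compress_1', 'final_norm']
```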
def plot_cka_self(cka_matrix: np.ndarray, names: list, output_dir: Path,
                  model_label: str):
    """Plot self-CKA heatmap."""
    n = len(names)
    short = [_short_name(n) for n in names]

    fig, ax = plt.subplots(figsize=(max(10, n * 0.35), max(8, n * 0.3)))
    fig.suptitle(f"{model_label} -- CKA Self-Similarity", fontsize=14)

    im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="equal")

    # Phase separators
    for i, name in enumerate(names):
        if i > 0:
            prev = names[i - 1].split("_")[0]
            curr = name.split("_")[0]
            if prev != curr:
                ax.axhline(i - 0.5, color="white", linewidth=1.5, alpha=0.8)
                ax.axvline(i - 0.5, color="white", linewidth=1.5, alpha=0.8)

    ax.set_xticks(range(n))
    ax.set_xticklabels(short, rotation=90, fontsize=7)
    ax.set_yticks(range(n))
    ax.set_yticklabels(short, fontsize=7)

    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA")
    plt.tight_layout()
    fig.savefig(output_dir / "cka_self.png", dpi=150)
    plt.close(fig)


def plot_cka_cross(cka_matrix: np.ndarray, names_a: list, names_b: list,
                   output_dir: Path, label_a: str, label_b: str):
    """Plot cross-model CKA heatmap."""
    short_a = [_short_name(n) for n in names_a]
    short_b = [_short_name(n) for n in names_b]

    na, nb = len(names_a), len(names_b)
    fig, ax = plt.subplots(figsize=(max(10, nb * 0.35), max(8, na * 0.3)))
    fig.suptitle(f"Cross-CKA: {label_a} vs {label_b}", fontsize=14)

    im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="auto")

    ax.set_xticks(range(nb))
    ax.set_xticklabels(short_b, rotation=90, fontsize=7)
    ax.set_xlabel(label_b)
    ax.set_yticks(range(na))
    ax.set_yticklabels(short_a, fontsize=7)
    ax.set_ylabel(label_a)

    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA")
    plt.tight_layout()
    fig.savefig(output_dir / "cka_cross.png", dpi=150)
    plt.close(fig)


def plot_logit_lens(lens_results: OrderedDict, output_dir: Path,
                    model_label: str):
    """Plot logit lens summary: entropy, confidence, rank, agreement."""
    names = list(lens_results.keys())
    sorted_names = sorted(names, key=_layer_sort_key)
    short = [_short_name(n) for n in sorted_names]
    colors = [_phase_color(n) for n in sorted_names]
    x = range(len(sorted_names))

    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle(f"{model_label} -- Logit Lens", fontsize=14)

    # Entropy
    vals = [lens_results[n]["entropy"] for n in sorted_names]
    axes[0, 0].bar(x, vals, color=colors, alpha=0.85)
    axes[0, 0].set_ylabel("Entropy (nats)")
    axes[0, 0].set_title("Prediction entropy per layer")
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(short, rotation=90, fontsize=7)

    # Top-1 probability
    vals = [lens_results[n]["top1_prob"] for n in sorted_names]
    axes[0, 1].bar(x, vals, color=colors, alpha=0.85)
    axes[0, 1].set_ylabel("Top-1 probability")
    axes[0, 1].set_title("Prediction confidence per layer")
    axes[0, 1].set_xticks(x)
    axes[0, 1].set_xticklabels(short, rotation=90, fontsize=7)

    # Correct rank (log scale)
    vals = [lens_results[n]["correct_rank_median"] for n in sorted_names]
    axes[1, 0].bar(x, vals, color=colors, alpha=0.85)
    axes[1, 0].set_ylabel("Median rank of correct token")
    axes[1, 0].set_yscale("log")
    axes[1, 0].set_title("When does the model find the answer?")
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(short, rotation=90, fontsize=7)

    # Agreement with final layer
    vals = [lens_results[n]["agreement_with_final"] for n in sorted_names]
    axes[1, 1].bar(x, vals, color=colors, alpha=0.85)
    axes[1, 1].set_ylabel("Agreement with final layer")
    axes[1, 1].set_title("Convergence toward final prediction")
    axes[1, 1].set_ylim(0, 1.05)
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(short, rotation=90, fontsize=7)

    plt.tight_layout()
    fig.savefig(output_dir / "logit_lens_summary.png", dpi=150)
    plt.close(fig)


def plot_logit_lens_trajectory(activations: dict, norm: nn.Module,
                               unembed_weight: torch.Tensor, input_ids: torch.Tensor,
                               tokenizer, output_dir: Path, model_label: str,
                               device: str = "cpu",
                               n_positions: int = 6, n_layers: int = 10):
    """Show top-5 predicted tokens at selected layers for a few positions.

    Picks positions spread across the first sample and shows how the
    model's prediction evolves through the network.
    """
    names = sorted(activations.keys(), key=_layer_sort_key)

    # Select layers evenly spread across the network
    if len(names) > n_layers:
        indices = np.linspace(0, len(names) - 1, n_layers, dtype=int)
        selected_layers = [names[i] for i in indices]
    else:
        selected_layers = names

    # Select positions from the first sample
    seq_len = input_ids.shape[1]
    pos_indices = np.linspace(10, seq_len - 2, n_positions, dtype=int)

    unembed = unembed_weight.to(device)
    norm_mod = norm.to(device)
    final_name = names[-1]

    fig, axes = plt.subplots(n_positions, 1, figsize=(14, 3 * n_positions))
    if n_positions == 1:
        axes = [axes]
    fig.suptitle(f"{model_label} -- Token prediction trajectory", fontsize=14, y=1.02)

    for pos_idx, pos in enumerate(pos_indices):
        ax = axes[pos_idx]
        actual_token = tokenizer.decode([input_ids[0, pos + 1].item()])
        context = tokenizer.decode(input_ids[0, max(0, pos - 5):pos + 1].tolist())

        layer_labels = []
        top_tokens_per_layer = []

        for name in selected_layers:
            is_final = (name == final_name)
            hidden = activations[name][0, pos:pos + 1, :].to(device)  # [1, D]
            if not is_final:
                hidden = norm_mod(hidden)
            logits = (hidden @ unembed.T).squeeze(0)  # [V]
            probs = F.softmax(logits, dim=-1)
            top5_vals, top5_idx = probs.topk(5)

            tokens_str = []
            for val, idx in zip(top5_vals, top5_idx):
                tok = tokenizer.decode([idx.item()]).replace("\n", "\\n")
                tokens_str.append(f"{tok}({val:.2f})")

            layer_labels.append(_short_name(name))
            top_tokens_per_layer.append("\n".join(tokens_str))

        # Create a text table
        ax.set_xlim(-0.5, len(layer_labels) - 0.5)
        ax.set_ylim(-0.5, 5.5)
        ax.set_xticks(range(len(layer_labels)))
        ax.set_xticklabels(layer_labels, fontsize=8)
        ax.set_yticks([])

        for li, tokens_str in enumerate(top_tokens_per_layer):
            lines = tokens_str.split("\n")
            for rank, line in enumerate(lines):
                color = "darkgreen" if actual_token.strip() in line else "black"
                fontweight = "bold" if actual_token.strip() in line else "normal"
                ax.text(li, rank, line, ha="center", va="center", fontsize=7,
                        color=color, fontweight=fontweight)

        ax.set_title(f'pos {pos}: "...{context}" -> [{actual_token.strip()}]',
                     fontsize=9, loc="left")
        ax.invert_yaxis()
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)

    plt.tight_layout()
    fig.savefig(output_dir / "logit_lens_trajectory.png", dpi=150, bbox_inches="tight")
    plt.close(fig)


def plot_drift(drift: OrderedDict, output_dir: Path, model_label: str):
    """Plot representation drift between consecutive layers."""
    names = list(drift.keys())
    sorted_names = sorted(names, key=_layer_sort_key)
    short = [_short_name(n) for n in sorted_names]
    colors = [_phase_color(n) for n in sorted_names]
    x = range(len(sorted_names))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f"{model_label} -- Representation drift", fontsize=14)

    # Cosine similarity with previous layer
    vals = [drift[n]["cos_sim_mean"] for n in sorted_names]
    axes[0].bar(x, vals, color=colors, alpha=0.85)
    axes[0].set_ylabel("Cosine similarity with previous layer")
    axes[0].set_title("How much each layer preserves direction")
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(short, rotation=90, fontsize=7)

    # L2 distance
    vals = [drift[n]["l2_distance"] for n in sorted_names]
    axes[1].bar(x, vals, color=colors, alpha=0.85)
    axes[1].set_ylabel("L2 distance from previous layer")
    axes[1].set_title("How much each layer changes magnitude")
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(short, rotation=90, fontsize=7)

    plt.tight_layout()
    fig.savefig(output_dir / "representation_drift.png", dpi=150)
    plt.close(fig)


# ---------------------------------------------------------------------------
# Results saving
# ---------------------------------------------------------------------------

def save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir):
    """Save all numerical results to JSON."""
    out = {}

    if cka_matrix is not None:
        out["cka_self"] = {
            "names": cka_names,
            "matrix": cka_matrix.tolist(),
        }

    if lens_results:
        out["logit_lens"] = {name: data for name, data in lens_results.items()}

    if drift:
        out["drift"] = {name: data for name, data in drift.items()}

    if cross_cka is not None:
        matrix, names_a, names_b = cross_cka
        out["cka_cross"] = {
            "names_a": names_a,
            "names_b": names_b,
            "matrix": matrix.tolist(),
        }

    with open(output_dir / "results.json", "w") as f:
        json.dump(out, f, indent=2, default=str)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(
        description="CKA and Logit Lens analysis for Prisma / Circuit Transformer")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to Prisma/Circuit checkpoint")
    parser.add_argument("--checkpoint-b", type=str, default=None,
                        help="Second Prisma checkpoint for cross-model CKA")
    parser.add_argument("--hf-model", type=str, default=None,
                        help="HuggingFace model for cross-model CKA (e.g. gpt2-medium)")
    parser.add_argument("--data", type=str, required=True,
                        help="Data source (hf:dataset:config:split or file path)")
    parser.add_argument("--num-samples", type=int, default=32,
                        help="Number of text samples (default: 32)")
    parser.add_argument("--context-length", type=int, default=512,
                        help="Sequence length (default: 512)")
    parser.add_argument("--cka-subsample", type=int, default=4,
                        help="Position subsampling for CKA (default: 4)")
    parser.add_argument("--no-logit-lens", action="store_true",
                        help="Skip logit lens analysis")
    parser.add_argument("--no-cka", action="store_true",
                        help="Skip CKA analysis")
    parser.add_argument("--output-dir", type=str, default=None,
                        help="Output directory (default: auto)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Output directory
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        ckpt_name = Path(args.checkpoint).parent.name
        output_dir = Path("circuits/scripts/representation_output") / ckpt_name
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output: {output_dir}")

    # === Load model A ===
    print(f"\nLoading: {args.checkpoint}")
    model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
    label_a = Path(args.checkpoint).parent.name
    n_params = sum(p.numel() for p in model_a.parameters())
    print(f"  Type: {model_type_a}, params: {n_params:,}")

    # === Load data ===
    ckpt_data = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    tokenizer_name = ckpt_data.get("tokenizer_name", config_a.get("tokenizer_name", "gpt2"))
    del ckpt_data

    print(f"\nLoading data ({args.num_samples} samples, ctx={args.context_length})...")
    result = load_data(
        args.data, tokenizer_name, args.num_samples, args.context_length, device
    )
    if result[0] is None:
        print("ERROR: No valid samples loaded.")
        return
    input_ids, tokenizer = result
    print(f"  Data shape: {input_ids.shape}")

    # === Collect activations (model A) ===
    print(f"\nCollecting activations ({model_type_a})...")
    acts_a = collect_activations(model_a, model_type_a, config_a, input_ids, device)
    print(f"  Collected {len(acts_a)} layers")

    # Free GPU memory
    del model_a
    if device.startswith("cuda"):
        torch.cuda.empty_cache()

    # === CKA (self) ===
    cka_matrix = None
    cka_names = None
    if not args.no_cka:
        print(f"\nComputing self-CKA (subsample={args.cka_subsample})...")
        cka_matrix, cka_names = compute_cka_matrix(acts_a, subsample=args.cka_subsample)
        plot_cka_self(cka_matrix, cka_names, output_dir, label_a)
        print(f"  Saved: cka_self.png")

    # === Cross-model CKA ===
    cross_cka = None
    if not args.no_cka and (args.checkpoint_b or args.hf_model):
        if args.checkpoint_b:
            print(f"\nLoading comparison: {args.checkpoint_b}")
            model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
            label_b = Path(args.checkpoint_b).parent.name
            acts_b = collect_activations(model_b, model_type_b, config_b, input_ids, device)
            del model_b
        else:
            print(f"\nLoading HF model: {args.hf_model}")
            model_b = load_hf_model(args.hf_model, device)
            label_b = args.hf_model
            # Decode texts from our tokens and re-tokenize for HF model
            print(f"  Re-tokenizing for {args.hf_model}...")
            raw_texts = [tokenizer.decode(input_ids[i].tolist()) for i in range(input_ids.shape[0])]
            input_ids_b, _ = tokenize_for_hf(
                raw_texts, args.hf_model, args.context_length, device
            )
            if input_ids_b is not None:
                print(f"  HF data shape: {input_ids_b.shape}")
                acts_b = collect_hf_activations(model_b, input_ids_b)
            else:
                acts_b = None
            del model_b

        if device.startswith("cuda"):
            torch.cuda.empty_cache()

        if acts_b:
            print(f"\nComputing cross-model CKA...")
            cross_matrix, cross_names_a, cross_names_b = compute_cross_model_cka(acts_a, acts_b)
            cross_cka = (cross_matrix, cross_names_a, cross_names_b)
            plot_cka_cross(cross_matrix, cross_names_a, cross_names_b,
                           output_dir, label_a, label_b)
            print(f"  Saved: cka_cross.png")
            del acts_b

    # === Logit lens ===
    lens_results = None
    if not args.no_logit_lens:
        # Reload model for unembedding components (we deleted it for memory)
        print(f"\nReloading model for logit lens...")
        model_a, _, _ = load_prisma_model(args.checkpoint, device)
        norm, unembed_weight = get_unembed_components(model_a, model_type_a)

        labels = input_ids[:, 1:].cpu()  # next-token labels

        print(f"Computing logit lens...")
        lens_results = compute_logit_lens(acts_a, norm, unembed_weight, labels, device)
        plot_logit_lens(lens_results, output_dir, label_a)
        print(f"  Saved: logit_lens_summary.png")

        # Token trajectory visualization
        print(f"  Generating token trajectories...")
        plot_logit_lens_trajectory(
            acts_a, norm, unembed_weight, input_ids.cpu(), tokenizer,
            output_dir, label_a, device
        )
        print(f"  Saved: logit_lens_trajectory.png")

        del model_a
        if device.startswith("cuda"):
            torch.cuda.empty_cache()

    # === Representation drift ===
    print(f"\nComputing representation drift...")
    drift = compute_drift(acts_a)
    plot_drift(drift, output_dir, label_a)
    print(f"  Saved: representation_drift.png")

    # === Save results ===
    save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir)
    print(f"\nAll outputs saved to: {output_dir}")
    n_plots = len(list(output_dir.glob("*.png")))
    print(f"  Plots: {n_plots} PNG files")
    print(f"  Data: results.json")


if __name__ == "__main__":
    main()
scripts/spectral_analysis.py
ADDED
@@ -0,0 +1,969 @@
#!/usr/bin/env python3
"""
Spectral analysis of Prisma / Circuit Transformer checkpoints.

Computes SVD spectra of weight matrices and (optionally) activation covariances,
revealing how the model organizes information geometrically.

Analyses:
  1. Weight spectra      — singular value distributions per matrix
  2. Effective rank      — how many dimensions carry real signal
  3. Power-law fit       — Martin & Mahoney alpha exponent (training quality)
  4. MP bound            — Marchenko-Pastur separation of signal vs noise
  5. Mirror comparison   — expand vs compress activation spectra (Prisma-specific)
  6. Embedding alignment — spectral similarity between embed and final hidden states
  7. Layer-wise summary  — effective rank progression through the network (the lens)

Usage:
  # Weight-only analysis (no data needed)
  python -m circuits.scripts.spectral_analysis --checkpoint path/to/checkpoint.pt

  # Full analysis with activation spectra (needs data)
  python -m circuits.scripts.spectral_analysis --checkpoint path/to/checkpoint.pt \
      --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train --num-samples 512

  # Compare two checkpoints
  python -m circuits.scripts.spectral_analysis \
      --checkpoint path/to/prisma.pt --checkpoint-b path/to/standard.pt

  # Compare against HuggingFace model
  python -m circuits.scripts.spectral_analysis \
      --checkpoint path/to/prisma.pt --hf-model gpt2-medium
"""

import argparse
|
| 35 |
+
import json
|
| 36 |
+
import sys
|
| 37 |
+
import os
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
from collections import defaultdict
|
| 40 |
+
|
| 41 |
+
import numpy as np
|
| 42 |
+
import torch
|
| 43 |
+
import torch.nn as nn
|
| 44 |
+
import matplotlib
|
| 45 |
+
matplotlib.use("Agg")
|
| 46 |
+
import matplotlib.pyplot as plt
|
| 47 |
+
from matplotlib.gridspec import GridSpec
|
| 48 |
+
|
| 49 |
+
|
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
    """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type)."""
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from circuits.config import CircuitConfig
    from circuits.model import CircuitTransformer
    from circuits.mirrored import MirroredConfig, MirroredTransformer

    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "standard")
    config_dict = ckpt.get("config", {})

    if model_type == "mirrored":
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config)
    else:
        config = CircuitConfig.from_dict(config_dict)
        model = CircuitTransformer(config)

    state_dict = ckpt["model"]
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    model.to(device).eval()

    return model, config_dict, model_type


def load_hf_model(model_name: str, device: str = "cpu"):
    """Load a HuggingFace causal LM."""
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
    model.to(device).eval()
    return model

# ---------------------------------------------------------------------------
# SVD utilities
# ---------------------------------------------------------------------------

def compute_singular_values(weight: torch.Tensor) -> np.ndarray:
    """Compute singular values of a 2D weight matrix."""
    w = weight.detach().float().cpu()
    if w.ndim != 2:
        return None
    sv = torch.linalg.svdvals(w).numpy()
    return sv


def effective_rank(sv: np.ndarray) -> float:
    """Entropy-based effective rank (Roy & Vetterli, 2007).

    erank = exp(H(p)) where p_i = sigma_i / sum(sigma)
    and H is Shannon entropy. Ranges from 1 (rank-1) to min(m,n) (full rank).
    """
    sv = sv[sv > 1e-10]
    if len(sv) == 0:
        return 0.0
    p = sv / sv.sum()
    entropy = -(p * np.log(p)).sum()
    return float(np.exp(entropy))

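# Illustrative sanity sketch of the erank definition above (added for
# exposition, not called by the analysis pipeline; the `_erank_sanity_demo`
# helper is hypothetical): a flat spectrum has full effective rank, while a
# spectrum dominated by one direction collapses to ~1.
def _erank_sanity_demo() -> tuple[float, float]:
    import numpy as np  # local import so the demo stands alone

    def erank(sv):
        # Same formula as effective_rank: exp of Shannon entropy of the
        # normalized spectrum, after dropping near-zero values.
        sv = sv[sv > 1e-10]
        p = sv / sv.sum()
        return float(np.exp(-(p * np.log(p)).sum()))

    flat = erank(np.ones(16))               # 16 equal values -> erank 16.0
    spiked = erank(np.array([1.0, 1e-12]))  # second value below cutoff -> 1.0
    return flat, spiked
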
def stable_rank(sv: np.ndarray) -> float:
    """Stable rank = ||W||_F^2 / ||W||_2^2 = sum(sigma^2) / max(sigma)^2."""
    if len(sv) == 0 or sv[0] < 1e-10:
        return 0.0
    return float((sv ** 2).sum() / (sv[0] ** 2))


def marchenko_pastur_bound(m: int, n: int, sv: np.ndarray) -> float:
    """Estimate Marchenko-Pastur upper edge.

    For a random matrix with variance sigma^2, the MP upper bound is
    sigma * (1 + sqrt(m/n))^2 (assuming m >= n).
    We estimate sigma from the bulk of singular values.
    """
    gamma = max(m, n) / min(m, n)
    # Estimate noise level from bottom half of spectrum
    bottom_half = sv[len(sv) // 2:]
    if len(bottom_half) == 0:
        return sv[-1] if len(sv) > 0 else 0.0
    sigma_est = float(np.median(bottom_half)) / np.sqrt(max(m, n))
    mp_upper = sigma_est * (1.0 + np.sqrt(gamma)) ** 2 * np.sqrt(min(m, n))
    return mp_upper

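# Illustrative sketch (added for exposition, not used by the pipeline; the
# `_mp_bound_demo` helper is hypothetical): for a pure-noise Gaussian matrix,
# essentially the whole spectrum should sit below the estimated MP edge, so
# singular values above the bound flag real structure rather than noise.
def _mp_bound_demo() -> float:
    import numpy as np

    rng = np.random.default_rng(0)
    m, n = 512, 256
    w = rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, n))  # pure noise, no signal
    sv = np.linalg.svd(w, compute_uv=False)

    # Same estimator as marchenko_pastur_bound above
    gamma = max(m, n) / min(m, n)
    sigma_est = float(np.median(sv[len(sv) // 2:])) / np.sqrt(max(m, n))
    mp_upper = sigma_est * (1.0 + np.sqrt(gamma)) ** 2 * np.sqrt(min(m, n))

    return float((sv > mp_upper).mean())  # fraction flagged as "signal"
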
def fit_power_law(sv: np.ndarray, fit_fraction: float = 0.8) -> tuple[float, float]:
    """Fit power law to singular value distribution tail.

    Returns (alpha, r_squared). alpha < 2 = heavy-tailed (well-trained).
    """
    sv = sv[sv > 1e-10]
    if len(sv) < 10:
        return 0.0, 0.0
    # Fit to the top `fit_fraction` of the spectrum (exclude noise floor)
    n_fit = max(10, int(len(sv) * fit_fraction))
    sv_fit = sv[:n_fit]

    log_rank = np.log(np.arange(1, n_fit + 1))
    log_sv = np.log(sv_fit)

    # Linear regression in log-log space: log(sv) = -alpha * log(rank) + c
    coeffs = np.polyfit(log_rank, log_sv, 1)
    alpha = -coeffs[0]

    # R-squared
    predicted = np.polyval(coeffs, log_rank)
    ss_res = ((log_sv - predicted) ** 2).sum()
    ss_tot = ((log_sv - log_sv.mean()) ** 2).sum()
    r_sq = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0

    return float(alpha), float(r_sq)

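# Illustrative check (added for exposition; the `_power_law_fit_demo` helper is
# hypothetical): on an exact synthetic power-law spectrum sigma_k = k^(-1.5),
# the log-log regression above should recover alpha = 1.5 with an essentially
# perfect fit.
def _power_law_fit_demo() -> tuple[float, float]:
    import numpy as np

    ranks = np.arange(1, 101, dtype=float)
    sv = ranks ** -1.5                   # exact power law, alpha = 1.5

    n_fit = max(10, int(len(sv) * 0.8))  # same tail selection as fit_power_law
    log_rank = np.log(ranks[:n_fit])
    log_sv = np.log(sv[:n_fit])
    coeffs = np.polyfit(log_rank, log_sv, 1)

    alpha = -coeffs[0]
    predicted = np.polyval(coeffs, log_rank)
    ss_res = ((log_sv - predicted) ** 2).sum()
    ss_tot = ((log_sv - log_sv.mean()) ** 2).sum()
    return float(alpha), float(1.0 - ss_res / ss_tot)
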
# ---------------------------------------------------------------------------
# Weight spectrum analysis
# ---------------------------------------------------------------------------

def analyze_weight_spectra(model: nn.Module, model_label: str = "model") -> dict:
    """Compute SVD spectra for all 2D weight matrices."""
    results = {}
    for name, param in model.named_parameters():
        if param.ndim != 2:
            continue
        sv = compute_singular_values(param)
        if sv is None:
            continue
        m, n = param.shape
        mp_bound = marchenko_pastur_bound(m, n, sv)
        n_above_mp = int((sv > mp_bound).sum())
        alpha, r_sq = fit_power_law(sv)

        results[name] = {
            "shape": (m, n),
            "singular_values": sv,
            "effective_rank": effective_rank(sv),
            "stable_rank": stable_rank(sv),
            "spectral_norm": float(sv[0]),
            "frobenius_norm": float(np.sqrt((sv ** 2).sum())),
            "mp_bound": mp_bound,
            "n_above_mp": n_above_mp,
            "n_total": len(sv),
            "signal_ratio": n_above_mp / len(sv) if len(sv) > 0 else 0,
            "alpha": alpha,
            "alpha_r2": r_sq,
            "condition_number": float(sv[0] / sv[-1]) if sv[-1] > 1e-10 else float("inf"),
        }
    return results

# ---------------------------------------------------------------------------
# Activation spectrum analysis
# ---------------------------------------------------------------------------

def collect_activations(model, input_ids: torch.Tensor,
                        word_positions: torch.Tensor = None,
                        model_type: str = "standard") -> dict[str, torch.Tensor]:
    """Run a forward pass and collect intermediate activations via hooks."""
    activations = {}
    hooks = []

    def make_hook(name):
        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                out = output[0]
            else:
                out = output
            # Store the full activation; covariance is computed later
            activations[name] = out.detach().float().cpu()
        return hook_fn

    # Register hooks based on model type
    if model_type == "mirrored":
        # Expand phase
        for i, block in enumerate(model.mirror_blocks):
            hooks.append(block.register_forward_hook(make_hook(f"expand_{i}")))
        # Middle
        for i, block in enumerate(model.middle_blocks):
            hooks.append(block.register_forward_hook(make_hook(f"middle_{i}")))
        # Compress — mirror blocks are reused in reverse, so hooks would fire
        # twice; compress activations are collected separately via a custom
        # forward (see collect_mirrored_activations below).
    else:
        for i, block in enumerate(model.layers):
            hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))

    # Also hook the embedding output
    hooks.append(model.embed.register_forward_hook(make_hook("embedding")))

    with torch.no_grad():
        kwargs = {}
        if word_positions is not None:
            kwargs["word_positions"] = word_positions
        model(input_ids, **kwargs)

    for h in hooks:
        h.remove()

    return activations

def collect_mirrored_activations(model, input_ids: torch.Tensor,
                                 word_positions: torch.Tensor = None) -> dict[str, torch.Tensor]:
    """Collect activations from a MirroredTransformer, separating expand and compress phases.

    This manually runs the forward pass to capture compress-phase activations
    from the reversed mirror blocks.
    """
    import torch.nn.functional as F

    activations = {}

    with torch.no_grad():
        # Embed
        x = model.embed(input_ids)
        if model.embed_proj is not None:
            if model.embed_g3 is not None:
                g4 = F.silu(model.embed_g4(x))
                g3 = F.silu(model.embed_g3(x) * g4)
                x = model.embed_proj(x) * g3
            else:
                x = F.silu(model.embed_proj(x))
        x = x * model.embed_scale
        activations["embedding"] = x.detach().float().cpu()

        # Expand phase
        for i, block in enumerate(model.mirror_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"expand_{i}"] = x.detach().float().cpu()

        # Middle phase
        for i, block in enumerate(model.middle_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"middle_{i}"] = x.detach().float().cpu()

        # Compress phase (reversed)
        for i in reversed(range(len(model.mirror_blocks))):
            x, _ = model.mirror_blocks[i](x, word_positions=word_positions)
            compress_idx = len(model.mirror_blocks) - 1 - i
            activations[f"compress_{compress_idx}"] = x.detach().float().cpu()

        # Final norm
        x = model.norm(x)
        activations["final_norm"] = x.detach().float().cpu()

    return activations

def activation_spectrum(act: torch.Tensor, max_components: int = 256) -> dict:
    """Compute eigenspectrum of activation covariance.

    act: [B, T, D] — reshape to [B*T, D], compute covariance, eigendecompose.
    """
    # Flatten batch and sequence
    flat = act.reshape(-1, act.shape[-1])  # [N, D]
    N, D = flat.shape

    if N < 2:
        return None

    # Center
    flat = flat - flat.mean(dim=0, keepdim=True)

    # Compute covariance via SVD of the data matrix (more stable than cov matrix)
    n_components = min(max_components, D, N)
    try:
        U, S, Vh = torch.pca_lowrank(flat, q=n_components)
        eigenvalues = (S ** 2 / (N - 1)).numpy()
    except Exception:
        # Fallback: full covariance
        cov = (flat.T @ flat) / (N - 1)
        eigenvalues = torch.linalg.eigvalsh(cov).flip(0).numpy()
        eigenvalues = eigenvalues[:max_components]

    eigenvalues = eigenvalues[eigenvalues > 1e-10]

    return {
        "eigenvalues": eigenvalues,
        "effective_rank": effective_rank(np.sqrt(np.maximum(eigenvalues, 0))),
        "total_variance": float(eigenvalues.sum()),
        "top1_variance_ratio": float(eigenvalues[0] / eigenvalues.sum()) if len(eigenvalues) > 0 else 0,
        "top10_variance_ratio": float(eigenvalues[:10].sum() / eigenvalues.sum()) if len(eigenvalues) >= 10 else 0,
        "n_components": len(eigenvalues),
    }

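# Illustrative sketch (added for exposition; the `_activation_rank_demo` helper
# is hypothetical): activations confined to a low-dimensional subspace should
# produce a covariance spectrum of matching rank — here rank-4 data embedded in
# a 32-dim space yields exactly 4 significant eigenvalues.
def _activation_rank_demo() -> int:
    import torch

    torch.manual_seed(0)
    basis = torch.randn(4, 32, dtype=torch.float64)           # 4-dim subspace
    act = torch.randn(2, 64, 4, dtype=torch.float64) @ basis  # [B, T, 32]

    # Same steps as activation_spectrum's covariance fallback path
    flat = act.reshape(-1, 32)
    flat = flat - flat.mean(dim=0, keepdim=True)
    cov = (flat.T @ flat) / (flat.shape[0] - 1)
    ev = torch.linalg.eigvalsh(cov).flip(0)                   # descending
    return int((ev > ev[0] * 1e-8).sum())  # count significant eigenvalues
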
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_weight_spectra(results: dict, output_dir: Path, model_label: str = "model",
                        results_b: dict = None, model_b_label: str = "model_b"):
    """Plot singular value distributions for all weight matrices."""
    # Group by layer/component type
    groups = defaultdict(list)
    for name, data in results.items():
        # Identify the component type
        if "attn" in name and ("q_proj" in name or "wq" in name):
            groups["attention_Q"].append((name, data))
        elif "attn" in name and ("k_proj" in name or "wk" in name):
            groups["attention_K"].append((name, data))
        elif "attn" in name and ("v_proj" in name or "wv" in name):
            groups["attention_V"].append((name, data))
        elif "attn" in name and ("o_proj" in name or "wo" in name):
            groups["attention_O"].append((name, data))
        elif "w1" in name or "up_proj" in name:
            groups["ffn_W1"].append((name, data))
        elif "w2" in name or "down_proj" in name:
            groups["ffn_W2"].append((name, data))
        elif "w3" in name or "gate_proj" in name:
            groups["ffn_gate_W3"].append((name, data))
        elif "w4" in name:
            groups["ffn_gate_W4"].append((name, data))
        elif "embed" in name or "wte" in name:
            groups["embedding"].append((name, data))
        elif "lm_head" in name:
            groups["lm_head"].append((name, data))
        else:
            groups["other"].append((name, data))

    # Plot each group
    for group_name, items in groups.items():
        if not items:
            continue

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        fig.suptitle(f"{model_label} — {group_name} weight spectra", fontsize=13)

        ax_linear, ax_log = axes

        cmap = plt.cm.viridis(np.linspace(0.1, 0.9, len(items)))
        for idx, (name, data) in enumerate(items):
            sv = data["singular_values"]
            short_name = name.split(".")[-2] + "." + name.split(".")[-1] if "." in name else name
            ax_linear.plot(sv, color=cmap[idx], alpha=0.7, linewidth=0.8, label=short_name)
            ax_log.loglog(np.arange(1, len(sv) + 1), sv, color=cmap[idx], alpha=0.7,
                          linewidth=0.8, label=short_name)
            # MP bound
            ax_linear.axhline(data["mp_bound"], color=cmap[idx], linestyle=":", alpha=0.3)

        ax_linear.set_xlabel("Rank")
        ax_linear.set_ylabel("Singular value")
        ax_linear.set_title("Linear scale")
        ax_linear.legend(fontsize=6, ncol=2)

        ax_log.set_xlabel("Rank")
        ax_log.set_ylabel("Singular value")
        ax_log.set_title("Log-log scale (power law)")
        ax_log.legend(fontsize=6, ncol=2)

        plt.tight_layout()
        fig.savefig(output_dir / f"weight_spectra_{group_name}.png", dpi=150)
        plt.close(fig)

def plot_effective_rank_progression(results: dict, output_dir: Path,
                                    model_label: str = "model",
                                    results_b: dict = None,
                                    model_b_label: str = "model_b"):
    """Plot effective rank per layer — the biconcave lens in eigenvalues."""
    # Extract layer-ordered FFN W1 effective ranks (the main signal path)
    layer_data = []
    for name, data in sorted(results.items()):
        if "w1" in name or "up_proj" in name:
            # Extract layer index
            parts = name.split(".")
            layer_label = name
            for p in parts:
                if p.isdigit():
                    layer_label = p
                    break
            layer_data.append((name, data["effective_rank"], data["stable_rank"],
                               data["alpha"], data["signal_ratio"], layer_label))

    if not layer_data:
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f"{model_label} — Layer-wise spectral properties (FFN W1)", fontsize=13)

    x = range(len(layer_data))
    short_labels = [d[5] for d in layer_data]

    # Effective rank
    axes[0, 0].bar(x, [d[1] for d in layer_data], color="steelblue", alpha=0.8)
    axes[0, 0].set_ylabel("Effective rank")
    axes[0, 0].set_title("Effective rank (entropy-based)")
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(short_labels, rotation=45, fontsize=7)

    # Stable rank
    axes[0, 1].bar(x, [d[2] for d in layer_data], color="coral", alpha=0.8)
    axes[0, 1].set_ylabel("Stable rank")
    axes[0, 1].set_title("Stable rank (Frobenius/spectral)")
    axes[0, 1].set_xticks(x)
    axes[0, 1].set_xticklabels(short_labels, rotation=45, fontsize=7)

    # Power-law alpha
    axes[1, 0].bar(x, [d[3] for d in layer_data], color="mediumpurple", alpha=0.8)
    axes[1, 0].set_ylabel("Alpha")
    axes[1, 0].set_title("Power-law exponent (lower = heavier tail = more structure)")
    axes[1, 0].axhline(2.0, color="red", linestyle="--", alpha=0.5, label="alpha=2 boundary")
    axes[1, 0].legend(fontsize=8)
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(short_labels, rotation=45, fontsize=7)

    # Signal ratio (above MP)
    axes[1, 1].bar(x, [d[4] for d in layer_data], color="seagreen", alpha=0.8)
    axes[1, 1].set_ylabel("Signal ratio")
    axes[1, 1].set_title("Fraction of singular values above MP bound")
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(short_labels, rotation=45, fontsize=7)

    plt.tight_layout()
    fig.savefig(output_dir / "layer_progression.png", dpi=150)
    plt.close(fig)

def plot_activation_spectra(act_spectra: dict, output_dir: Path,
                            model_label: str = "model"):
    """Plot activation eigenspectra across layers."""
    if not act_spectra:
        return

    # Sort layers in processing order
    order_keys = {"embedding": -1, "final_norm": 999}

    def sort_key(name):
        if name in order_keys:
            return order_keys[name]
        parts = name.split("_")
        phase = parts[0]
        idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
        phase_offset = {"expand": 0, "middle": 100, "compress": 200, "layer": 0}
        return phase_offset.get(phase, 300) + idx

    sorted_names = sorted(act_spectra.keys(), key=sort_key)

    # -- Eigenvalue distributions --
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"{model_label} — Activation eigenspectra", fontsize=13)

    cmap = plt.cm.coolwarm(np.linspace(0, 1, len(sorted_names)))
    for idx, name in enumerate(sorted_names):
        data = act_spectra[name]
        ev = data["eigenvalues"]
        axes[0].semilogy(ev / ev.sum(), color=cmap[idx], alpha=0.7, linewidth=1.0, label=name)
        axes[1].plot(np.cumsum(ev) / ev.sum(), color=cmap[idx], alpha=0.7, linewidth=1.0, label=name)

    axes[0].set_xlabel("Component")
    axes[0].set_ylabel("Normalized eigenvalue (log)")
    axes[0].set_title("Eigenvalue distribution")
    axes[0].legend(fontsize=6, ncol=2)

    axes[1].set_xlabel("Component")
    axes[1].set_ylabel("Cumulative variance explained")
    axes[1].set_title("Variance concentration")
    axes[1].axhline(0.9, color="gray", linestyle="--", alpha=0.4, label="90%")
    axes[1].legend(fontsize=6, ncol=2)

    plt.tight_layout()
    fig.savefig(output_dir / "activation_spectra.png", dpi=150)
    plt.close(fig)

    # -- Effective rank progression (the lens shape) --
    fig, ax = plt.subplots(figsize=(12, 5))
    fig.suptitle(f"{model_label} — Activation effective rank progression", fontsize=13)

    eranks = [act_spectra[n]["effective_rank"] for n in sorted_names]
    colors = []
    for name in sorted_names:
        if "expand" in name:
            colors.append("steelblue")
        elif "middle" in name:
            colors.append("goldenrod")
        elif "compress" in name:
            colors.append("coral")
        else:
            colors.append("gray")

    ax.bar(range(len(sorted_names)), eranks, color=colors, alpha=0.8)
    ax.set_xticks(range(len(sorted_names)))
    ax.set_xticklabels(sorted_names, rotation=45, ha="right", fontsize=8)
    ax.set_ylabel("Effective rank")
    ax.set_title("Expand (blue) → Middle (gold) → Compress (coral)")

    plt.tight_layout()
    fig.savefig(output_dir / "activation_rank_progression.png", dpi=150)
    plt.close(fig)

def plot_mirror_comparison(act_spectra: dict, output_dir: Path,
                           model_label: str = "model"):
    """Compare expand vs compress activation spectra for each mirror pair."""
    expand_layers = sorted([n for n in act_spectra if n.startswith("expand_")])
    compress_layers = sorted([n for n in act_spectra if n.startswith("compress_")])

    if not expand_layers or not compress_layers:
        return

    n_pairs = min(len(expand_layers), len(compress_layers))
    fig, axes = plt.subplots(1, n_pairs, figsize=(4 * n_pairs, 4), squeeze=False)
    fig.suptitle(f"{model_label} — Mirror pair activation spectra (expand vs compress)", fontsize=13)

    for i in range(n_pairs):
        ax = axes[0, i]
        exp_ev = act_spectra[expand_layers[i]]["eigenvalues"]
        comp_ev = act_spectra[compress_layers[i]]["eigenvalues"]

        n_plot = min(len(exp_ev), len(comp_ev), 100)
        ax.semilogy(exp_ev[:n_plot] / exp_ev.sum(), color="steelblue", alpha=0.8,
                    linewidth=1.5, label="expand")
        ax.semilogy(comp_ev[:n_plot] / comp_ev.sum(), color="coral", alpha=0.8,
                    linewidth=1.5, label="compress")

        exp_er = act_spectra[expand_layers[i]]["effective_rank"]
        comp_er = act_spectra[compress_layers[i]]["effective_rank"]
        ax.set_title(f"Pair {i}\nerank: {exp_er:.0f} / {comp_er:.0f}", fontsize=10)
        ax.set_xlabel("Component")
        if i == 0:
            ax.set_ylabel("Normalized eigenvalue")
        ax.legend(fontsize=8)

    plt.tight_layout()
    fig.savefig(output_dir / "mirror_pair_comparison.png", dpi=150)
    plt.close(fig)

def plot_gate_spectra(results: dict, output_dir: Path, model_label: str = "model"):
    """Compare W3 vs W4 gate weight spectra (G2LU outer vs inner gate)."""
    w3_items = [(n, d) for n, d in sorted(results.items()) if "w3" in n and "ffn" in n]
    w4_items = [(n, d) for n, d in sorted(results.items()) if "w4" in n and "ffn" in n]

    if not w3_items or not w4_items:
        return

    n_pairs = min(len(w3_items), len(w4_items))
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    fig.suptitle(f"{model_label} — G2LU gate spectra (W3 outer vs W4 inner)", fontsize=13)

    # Overlay all W3 vs W4
    cmap_w3 = plt.cm.Blues(np.linspace(0.3, 0.9, n_pairs))
    cmap_w4 = plt.cm.Reds(np.linspace(0.3, 0.9, n_pairs))

    for i in range(n_pairs):
        sv3 = w3_items[i][1]["singular_values"]
        sv4 = w4_items[i][1]["singular_values"]
        axes[0].semilogy(sv3, color=cmap_w3[i], alpha=0.6, linewidth=0.8, label=f"W3 pair {i}")
        axes[0].semilogy(sv4, color=cmap_w4[i], alpha=0.6, linewidth=0.8, label=f"W4 pair {i}")

    axes[0].set_xlabel("Rank")
    axes[0].set_ylabel("Singular value (log)")
    axes[0].set_title("Gate weight spectra")
    axes[0].legend(fontsize=6, ncol=4)

    # Effective rank comparison
    er_w3 = [w3_items[i][1]["effective_rank"] for i in range(n_pairs)]
    er_w4 = [w4_items[i][1]["effective_rank"] for i in range(n_pairs)]
    x = np.arange(n_pairs)
    axes[1].bar(x - 0.15, er_w3, 0.3, color="steelblue", alpha=0.8, label="W3 (outer gate)")
    axes[1].bar(x + 0.15, er_w4, 0.3, color="coral", alpha=0.8, label="W4 (inner gate)")
    axes[1].set_xlabel("Mirror pair")
    axes[1].set_ylabel("Effective rank")
    axes[1].set_title("Gate effective rank by pair")
    axes[1].set_xticks(x)
    axes[1].legend()

    plt.tight_layout()
    fig.savefig(output_dir / "gate_spectra.png", dpi=150)
    plt.close(fig)

def plot_embedding_alignment(results: dict, act_spectra: dict, output_dir: Path,
                             model_label: str = "model"):
    """Compare embedding weight spectrum with final layer activation spectrum."""
    embed_data = None
    for name, data in results.items():
        if "embed" in name.lower() and "proj" not in name.lower() and "g3" not in name.lower() and "g4" not in name.lower():
            embed_data = data
            break

    final_act = act_spectra.get("final_norm") or act_spectra.get("compress_0")
    if embed_data is None or final_act is None:
        return

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle(f"{model_label} — Embedding vs final activation spectra", fontsize=13)

    # Normalized comparison
    sv_embed = embed_data["singular_values"]
    ev_final = final_act["eigenvalues"]
    sv_embed_norm = sv_embed / sv_embed.sum()
    ev_final_norm = ev_final / ev_final.sum()

    n_plot = min(len(sv_embed_norm), len(ev_final_norm), 200)
    axes[0].semilogy(sv_embed_norm[:n_plot], color="steelblue", linewidth=1.5,
                     label=f"Embedding (erank={embed_data['effective_rank']:.0f})")
    axes[0].semilogy(ev_final_norm[:n_plot], color="coral", linewidth=1.5,
                     label=f"Final act (erank={final_act['effective_rank']:.0f})")
    axes[0].set_xlabel("Component")
    axes[0].set_ylabel("Normalized value (log)")
    axes[0].set_title("Spectral shape comparison")
    axes[0].legend()

    # Cumulative variance
    axes[1].plot(np.cumsum(sv_embed_norm[:n_plot]), color="steelblue", linewidth=1.5, label="Embedding")
    axes[1].plot(np.cumsum(ev_final_norm[:n_plot]), color="coral", linewidth=1.5, label="Final activation")
    axes[1].set_xlabel("Component")
    axes[1].set_ylabel("Cumulative fraction")
    axes[1].set_title("Variance concentration")
    axes[1].axhline(0.9, color="gray", linestyle="--", alpha=0.4)
    axes[1].legend()

    plt.tight_layout()
    fig.savefig(output_dir / "embedding_alignment.png", dpi=150)
    plt.close(fig)

def plot_comparison(results_a: dict, results_b: dict,
                    label_a: str, label_b: str,
                    output_dir: Path):
    """Side-by-side comparison of two models' spectral properties."""
    # Collect effective ranks for FFN W1 / up_proj
    def extract_ffn_ranks(results):
        ranks = []
        for name, data in sorted(results.items()):
            if ("w1" in name or "up_proj" in name or "c_fc" in name
                    or "dense_h_to_4h" in name) and "embed" not in name:
                ranks.append((name, data["effective_rank"], data["stable_rank"], data["alpha"]))
        return ranks

    ranks_a = extract_ffn_ranks(results_a)
    ranks_b = extract_ffn_ranks(results_b)

    if not ranks_a or not ranks_b:
        return

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f"Comparison: {label_a} vs {label_b}", fontsize=13)

    n = min(len(ranks_a), len(ranks_b))
    x = np.arange(n)

    for ax_idx, (metric_idx, ylabel, title) in enumerate([
        (1, "Effective rank", "Effective rank per layer"),
        (2, "Stable rank", "Stable rank per layer"),
        (3, "Alpha", "Power-law alpha per layer"),
    ]):
        vals_a = [ranks_a[i][metric_idx] for i in range(n)]
        vals_b = [ranks_b[i][metric_idx] for i in range(n)]
        axes[ax_idx].bar(x - 0.15, vals_a, 0.3, color="steelblue", alpha=0.8, label=label_a)
        axes[ax_idx].bar(x + 0.15, vals_b, 0.3, color="coral", alpha=0.8, label=label_b)
        axes[ax_idx].set_xlabel("Layer")
        axes[ax_idx].set_ylabel(ylabel)
        axes[ax_idx].set_title(title)
        axes[ax_idx].legend(fontsize=8)

    plt.tight_layout()
    fig.savefig(output_dir / "comparison.png", dpi=150)
    plt.close(fig)

# ---------------------------------------------------------------------------
# Summary report
# ---------------------------------------------------------------------------

def print_summary(results: dict, model_label: str, act_spectra: dict = None):
    """Print a concise text summary of spectral analysis."""
    print(f"\n{'='*70}")
    print(f"  Spectral Analysis: {model_label}")
    print(f"{'='*70}")

    # Group by component type
    components = defaultdict(list)
    for name, data in sorted(results.items()):
        if "w1" in name or "up_proj" in name:
            components["FFN W1 (up)"].append(data)
        elif "w2" in name or "down_proj" in name:
            components["FFN W2 (down)"].append(data)
        elif "w3" in name:
            components["FFN W3 (outer gate)"].append(data)
        elif "w4" in name:
            components["FFN W4 (inner gate)"].append(data)
        elif "embed" in name.lower() and "proj" not in name and "g3" not in name and "g4" not in name:
            components["Embedding"].append(data)

    print(f"\n{'Component':<25} {'Shape':>12} {'eRank':>8} {'sRank':>8} {'Alpha':>8} {'Sig%':>8} {'Cond#':>10}")
    print("-" * 85)
    for comp_name, items in components.items():
        for i, data in enumerate(items):
            label = f"{comp_name}" if len(items) == 1 else f"{comp_name}[{i}]"
            shape_str = f"{data['shape'][0]}x{data['shape'][1]}"
            cond = f"{data['condition_number']:.0f}" if data['condition_number'] < 1e6 else "inf"
            print(f"{label:<25} {shape_str:>12} {data['effective_rank']:>8.1f} "
                  f"{data['stable_rank']:>8.1f} {data['alpha']:>8.3f} "
                  f"{data['signal_ratio']*100:>7.1f}% {cond:>10}")

    # Aggregate stats
    all_alphas = [d["alpha"] for d in results.values() if d["alpha"] > 0]
    all_eranks = [d["effective_rank"] for d in results.values()]
    if all_alphas:
        print(f"\n  Mean alpha: {np.mean(all_alphas):.3f} (< 2.0 = heavy-tailed = well-structured)")
        print(f"  Mean effective rank: {np.mean(all_eranks):.1f}")
+
|
| 760 |
+
# Activation summary
|
| 761 |
+
if act_spectra:
|
| 762 |
+
print(f"\n Activation spectra:")
|
| 763 |
+
print(f" {'Layer':<25} {'eRank':>8} {'Top1%':>8} {'Top10%':>8}")
|
| 764 |
+
print(" " + "-" * 55)
|
| 765 |
+
|
| 766 |
+
order_keys = {"embedding": -1, "final_norm": 999}
|
| 767 |
+
def sort_key(name):
|
| 768 |
+
if name in order_keys:
|
| 769 |
+
return order_keys[name]
|
| 770 |
+
parts = name.split("_")
|
| 771 |
+
phase = parts[0]
|
| 772 |
+
idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
|
| 773 |
+
phase_offset = {"expand": 0, "middle": 100, "compress": 200, "layer": 0}
|
| 774 |
+
return phase_offset.get(phase, 300) + idx
|
| 775 |
+
|
| 776 |
+
for name in sorted(act_spectra.keys(), key=sort_key):
|
| 777 |
+
data = act_spectra[name]
|
| 778 |
+
print(f" {name:<25} {data['effective_rank']:>8.1f} "
|
| 779 |
+
f"{data['top1_variance_ratio']*100:>7.1f}% "
|
| 780 |
+
f"{data['top10_variance_ratio']*100:>7.1f}%")
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def save_results_json(results: dict, act_spectra: dict, output_path: Path):
|
| 784 |
+
"""Save numerical results (no numpy arrays) to JSON."""
|
| 785 |
+
out = {}
|
| 786 |
+
for name, data in results.items():
|
| 787 |
+
out[name] = {k: v for k, v in data.items() if k != "singular_values"}
|
| 788 |
+
out[name]["top_10_sv"] = data["singular_values"][:10].tolist()
|
| 789 |
+
|
| 790 |
+
if act_spectra:
|
| 791 |
+
out["_activations"] = {}
|
| 792 |
+
for name, data in act_spectra.items():
|
| 793 |
+
out["_activations"][name] = {k: v for k, v in data.items() if k != "eigenvalues"}
|
| 794 |
+
out["_activations"][name]["top_10_ev"] = data["eigenvalues"][:10].tolist()
|
| 795 |
+
|
| 796 |
+
with open(output_path, "w") as f:
|
| 797 |
+
json.dump(out, f, indent=2, default=str)
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
# ---------------------------------------------------------------------------
|
| 801 |
+
# Data loading (minimal — just enough tokens for activation analysis)
|
| 802 |
+
# ---------------------------------------------------------------------------
|
| 803 |
+
|
| 804 |
+
def load_sample_data(data_source: str, tokenizer_name: str, num_samples: int = 256,
|
| 805 |
+
context_length: int = 512, device: str = "cpu"):
|
| 806 |
+
"""Load a small batch of tokenized data for activation analysis."""
|
| 807 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
| 808 |
+
from circuits.data import get_tokenizer
|
| 809 |
+
|
| 810 |
+
tokenizer = get_tokenizer(tokenizer_name)
|
| 811 |
+
|
| 812 |
+
if data_source.startswith("hf:"):
|
| 813 |
+
from datasets import load_dataset
|
| 814 |
+
parts = data_source[3:].split(":")
|
| 815 |
+
ds_name = parts[0]
|
| 816 |
+
ds_config = parts[1] if len(parts) > 1 else None
|
| 817 |
+
ds_split = parts[2] if len(parts) > 2 else "train"
|
| 818 |
+
dataset = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
|
| 819 |
+
texts = []
|
| 820 |
+
for item in dataset:
|
| 821 |
+
texts.append(item.get("text", ""))
|
| 822 |
+
if len(texts) >= num_samples:
|
| 823 |
+
break
|
| 824 |
+
else:
|
| 825 |
+
with open(data_source) as f:
|
| 826 |
+
texts = [line.strip() for line in f if line.strip()][:num_samples]
|
| 827 |
+
|
| 828 |
+
# Tokenize and create batches
|
| 829 |
+
all_ids = []
|
| 830 |
+
for text in texts:
|
| 831 |
+
ids = tokenizer.encode(text)
|
| 832 |
+
if len(ids) >= context_length:
|
| 833 |
+
all_ids.append(ids[:context_length])
|
| 834 |
+
elif len(ids) > 32:
|
| 835 |
+
all_ids.append(ids + [tokenizer.eos_token_id] * (context_length - len(ids)))
|
| 836 |
+
|
| 837 |
+
if not all_ids:
|
| 838 |
+
return None, tokenizer
|
| 839 |
+
|
| 840 |
+
input_ids = torch.tensor(all_ids[:num_samples], device=device)
|
| 841 |
+
return input_ids, tokenizer
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
# ---------------------------------------------------------------------------
|
| 845 |
+
# Main
|
| 846 |
+
# ---------------------------------------------------------------------------
|
| 847 |
+
|
| 848 |
+
def main():
|
| 849 |
+
parser = argparse.ArgumentParser(description="Spectral analysis of Prisma checkpoints")
|
| 850 |
+
parser.add_argument("--checkpoint", type=str, required=True, help="Path to Prisma/Circuit checkpoint")
|
| 851 |
+
parser.add_argument("--checkpoint-b", type=str, default=None, help="Second checkpoint for comparison")
|
| 852 |
+
parser.add_argument("--hf-model", type=str, default=None, help="HuggingFace model name for comparison")
|
| 853 |
+
parser.add_argument("--data", type=str, default=None,
|
| 854 |
+
help="Data source for activation analysis (hf:dataset:config:split or path)")
|
| 855 |
+
parser.add_argument("--num-samples", type=int, default=256, help="Number of samples for activation analysis")
|
| 856 |
+
parser.add_argument("--context-length", type=int, default=512, help="Context length for activation analysis")
|
| 857 |
+
parser.add_argument("--output-dir", type=str, default=None, help="Output directory (default: auto)")
|
| 858 |
+
parser.add_argument("--gpu", type=int, default=0, help="GPU index")
|
| 859 |
+
parser.add_argument("--no-activations", action="store_true", help="Skip activation analysis even if data provided")
|
| 860 |
+
args = parser.parse_args()
|
| 861 |
+
|
| 862 |
+
device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
|
| 863 |
+
print(f"Device: {device}")
|
| 864 |
+
|
| 865 |
+
# Output directory
|
| 866 |
+
if args.output_dir:
|
| 867 |
+
output_dir = Path(args.output_dir)
|
| 868 |
+
else:
|
| 869 |
+
ckpt_name = Path(args.checkpoint).parent.name
|
| 870 |
+
output_dir = Path("circuits/scripts/spectral_output") / ckpt_name
|
| 871 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 872 |
+
print(f"Output: {output_dir}")
|
| 873 |
+
|
| 874 |
+
# ── Load model A ──
|
| 875 |
+
print(f"\nLoading: {args.checkpoint}")
|
| 876 |
+
model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
|
| 877 |
+
label_a = Path(args.checkpoint).parent.name
|
| 878 |
+
print(f" Type: {model_type_a}")
|
| 879 |
+
n_params = sum(p.numel() for p in model_a.parameters())
|
| 880 |
+
print(f" Parameters: {n_params:,}")
|
| 881 |
+
|
| 882 |
+
# ── Weight spectra (A) ──
|
| 883 |
+
print("\nAnalyzing weight spectra...")
|
| 884 |
+
weight_results_a = analyze_weight_spectra(model_a, label_a)
|
| 885 |
+
print(f" Analyzed {len(weight_results_a)} weight matrices")
|
| 886 |
+
|
| 887 |
+
# ── Activation spectra (A) ──
|
| 888 |
+
act_spectra_a = None
|
| 889 |
+
if args.data and not args.no_activations:
|
| 890 |
+
tokenizer_name = torch.load(args.checkpoint, map_location="cpu",
|
| 891 |
+
weights_only=False).get("tokenizer_name", "gpt2")
|
| 892 |
+
print(f"\nLoading data for activation analysis ({args.num_samples} samples)...")
|
| 893 |
+
input_ids, tokenizer = load_sample_data(
|
| 894 |
+
args.data, tokenizer_name, args.num_samples, args.context_length, device
|
| 895 |
+
)
|
| 896 |
+
if input_ids is not None:
|
| 897 |
+
print(f" Data shape: {input_ids.shape}")
|
| 898 |
+
|
| 899 |
+
# Compute word positions if needed
|
| 900 |
+
word_positions = None
|
| 901 |
+
word_rope_dims = config_a.get("word_rope_dims", 0)
|
| 902 |
+
if word_rope_dims > 0:
|
| 903 |
+
from circuits.layers import build_word_start_table, compute_word_positions
|
| 904 |
+
word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
|
| 905 |
+
word_positions = compute_word_positions(input_ids, word_start_table)
|
| 906 |
+
|
| 907 |
+
print(" Collecting activations...")
|
| 908 |
+
if model_type_a == "mirrored":
|
| 909 |
+
raw_acts = collect_mirrored_activations(model_a, input_ids, word_positions)
|
| 910 |
+
else:
|
| 911 |
+
raw_acts = collect_activations(model_a, input_ids, word_positions, model_type_a)
|
| 912 |
+
|
| 913 |
+
print(f" Computing activation spectra ({len(raw_acts)} layers)...")
|
| 914 |
+
act_spectra_a = {}
|
| 915 |
+
for name, act in raw_acts.items():
|
| 916 |
+
spec = activation_spectrum(act)
|
| 917 |
+
if spec is not None:
|
| 918 |
+
act_spectra_a[name] = spec
|
| 919 |
+
|
| 920 |
+
# ── Model B (optional comparison) ──
|
| 921 |
+
weight_results_b = None
|
| 922 |
+
label_b = None
|
| 923 |
+
if args.checkpoint_b:
|
| 924 |
+
print(f"\nLoading comparison: {args.checkpoint_b}")
|
| 925 |
+
model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
|
| 926 |
+
label_b = Path(args.checkpoint_b).parent.name
|
| 927 |
+
weight_results_b = analyze_weight_spectra(model_b, label_b)
|
| 928 |
+
del model_b
|
| 929 |
+
elif args.hf_model:
|
| 930 |
+
print(f"\nLoading HF model: {args.hf_model}")
|
| 931 |
+
model_b = load_hf_model(args.hf_model, device)
|
| 932 |
+
label_b = args.hf_model
|
| 933 |
+
weight_results_b = analyze_weight_spectra(model_b, label_b)
|
| 934 |
+
del model_b
|
| 935 |
+
|
| 936 |
+
if device.startswith("cuda"):
|
| 937 |
+
torch.cuda.empty_cache()
|
| 938 |
+
|
| 939 |
+
# ── Plots ──
|
| 940 |
+
print("\nGenerating plots...")
|
| 941 |
+
plot_weight_spectra(weight_results_a, output_dir, label_a)
|
| 942 |
+
plot_effective_rank_progression(weight_results_a, output_dir, label_a)
|
| 943 |
+
plot_gate_spectra(weight_results_a, output_dir, label_a)
|
| 944 |
+
|
| 945 |
+
if act_spectra_a:
|
| 946 |
+
plot_activation_spectra(act_spectra_a, output_dir, label_a)
|
| 947 |
+
plot_mirror_comparison(act_spectra_a, output_dir, label_a)
|
| 948 |
+
plot_embedding_alignment(weight_results_a, act_spectra_a, output_dir, label_a)
|
| 949 |
+
|
| 950 |
+
if weight_results_b and label_b:
|
| 951 |
+
plot_comparison(weight_results_a, weight_results_b, label_a, label_b, output_dir)
|
| 952 |
+
# Also print summary for B
|
| 953 |
+
print_summary(weight_results_b, label_b)
|
| 954 |
+
|
| 955 |
+
# ── Summary ──
|
| 956 |
+
print_summary(weight_results_a, label_a, act_spectra_a)
|
| 957 |
+
|
| 958 |
+
# ── Save ──
|
| 959 |
+
save_results_json(weight_results_a, act_spectra_a, output_dir / "results.json")
|
| 960 |
+
if weight_results_b:
|
| 961 |
+
save_results_json(weight_results_b, None, output_dir / "results_b.json")
|
| 962 |
+
|
| 963 |
+
print(f"\nAll outputs saved to: {output_dir}")
|
| 964 |
+
print(f" Plots: {len(list(output_dir.glob('*.png')))} PNG files")
|
| 965 |
+
print(f" Data: results.json")
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
if __name__ == "__main__":
|
| 969 |
+
main()
|
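The eRank and sRank columns printed by `print_summary` come from the weight-spectrum pass earlier in the script. For reference, a minimal standalone sketch of two common definitions — stable rank and the entropy-based effective rank of Roy & Vetterli — computed from a plain list of singular values (illustrative only; the repo's `analyze_weight_spectra` is the authoritative implementation and may differ in detail):

```python
import math

def stable_rank(svs):
    # ||W||_F^2 / ||W||_2^2: how many directions carry comparable energy.
    return sum(s * s for s in svs) / max(svs) ** 2

def effective_rank(svs):
    # exp(entropy) of the L1-normalized singular value distribution.
    total = sum(svs)
    probs = [s / total for s in svs if s > 0]
    return math.exp(-sum(p * math.log(p) for p in probs))

print(stable_rank([2.0, 1.0, 1.0]))                # 1.5
print(round(effective_rank([2.0, 1.0, 1.0]), 4))   # 2.8284
```

A flat spectrum of n equal singular values gives effective rank exactly n; a single dominant direction pulls both metrics toward 1.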
scripts/spectral_to_csv.py (ADDED, 202 lines)
```python
"""Convert spectral analysis JSON results to CSV tables for analysis."""
import json
import csv
import sys
import os
import re
from pathlib import Path


def classify_layer(name, model_type):
    """Classify a weight matrix by layer index, component type, and phase."""
    if model_type == "prisma":
        # mirror_blocks.N.component
        m = re.match(r'mirror_blocks\.(\d+)\.', name)
        if m:
            layer_idx = int(m.group(1))
            phase = "mirror"
            if 'attn' in name:
                comp = 'Q' if 'q_proj' in name else 'K' if 'k_proj' in name else 'V' if 'v_proj' in name else 'O' if 'o_proj' in name else 'attn'
            elif 'ffn.w3' in name or 'gate_expand' in name:
                comp = 'W3'
            elif 'ffn.w4' in name or 'gate_compress' in name:
                comp = 'W4'
            elif 'ffn.w1' in name:
                comp = 'W1'
            elif 'w2' in name:
                comp = 'W2'
            else:
                comp = 'other'
            return layer_idx, comp, phase

        m = re.match(r'middle_blocks\.(\d+)\.', name)
        if m:
            layer_idx = int(m.group(1))
            phase = "middle"
            if 'attn' in name:
                comp = 'Q' if 'q_proj' in name else 'K' if 'k_proj' in name else 'V' if 'v_proj' in name else 'O' if 'o_proj' in name else 'attn'
            elif 'gate' in name:
                comp = 'W3'
            elif 'ffn.w1' in name:
                comp = 'W1'
            elif 'ffn.w2' in name:
                comp = 'W2'
            else:
                comp = 'other'
            return layer_idx, comp, phase

        m = re.match(r'(first|last)_block\.', name)
        if m:
            phase = m.group(1)
            if 'attn' in name:
                comp = 'Q' if 'q_proj' in name else 'K' if 'k_proj' in name else 'V' if 'v_proj' in name else 'O' if 'o_proj' in name else 'attn'
            elif 'ffn.w3' in name or 'gate' in name:
                comp = 'W3'
            elif 'ffn.w4' in name:
                comp = 'W4'
            elif 'ffn.w1' in name:
                comp = 'W1'
            elif 'ffn.w2' in name:
                comp = 'W2'
            else:
                comp = 'other'
            return 0, comp, phase

        if 'embed' in name:
            return -1, 'embed', 'embed'
        if 'head' in name or 'lm_head' in name:
            return 99, 'head', 'head'
        return -1, 'other', 'other'

    else:  # GPT-2 style
        m = re.match(r'transformer\.h\.(\d+)\.', name)
        if m:
            layer_idx = int(m.group(1))
            if 'c_attn' in name:
                comp = 'QKV'
            elif 'c_proj' in name and 'mlp' not in name:
                comp = 'O'
            elif 'c_fc' in name:
                comp = 'W1'
            elif 'mlp.c_proj' in name:
                comp = 'W2'
            else:
                comp = 'other'
            return layer_idx, comp, "layer"

        if 'wte' in name:
            return -1, 'embed', 'embed'
        if 'wpe' in name:
            return -1, 'pos_embed', 'embed'
        return -1, 'other', 'other'


def json_to_csvs(json_path, output_dir, model_type="prisma"):
    with open(json_path) as f:
        data = json.load(f)

    os.makedirs(output_dir, exist_ok=True)

    # 1. Full weight matrix summary
    rows = []
    for name, info in data.items():
        if 'activation' in name or name.startswith('_'):
            continue
        layer_idx, comp, phase = classify_layer(name, model_type)
        rows.append({
            'name': name,
            'layer_idx': layer_idx,
            'component': comp,
            'phase': phase,
            'shape': 'x'.join(str(s) for s in info['shape']),
            'effective_rank': round(info['effective_rank'], 2),
            'stable_rank': round(info['stable_rank'], 3),
            'spectral_norm': round(info['spectral_norm'], 4),
            'frobenius_norm': round(info['frobenius_norm'], 4),
            'alpha': round(info['alpha'], 4),
            'alpha_r2': round(info['alpha_r2'], 4),
            'signal_ratio': round(info['signal_ratio'], 4),
            'condition_number': round(info['condition_number'], 2),
            'mp_bound': round(info['mp_bound'], 4),
            'n_above_mp': info['n_above_mp'],
            'n_total': info['n_total'],
            'sv_1': round(info['top_10_sv'][0], 4) if info['top_10_sv'] else 0,
            'sv_2': round(info['top_10_sv'][1], 4) if len(info['top_10_sv']) > 1 else 0,
            'sv_10': round(info['top_10_sv'][9], 4) if len(info['top_10_sv']) > 9 else 0,
            'sv1_sv2_ratio': round(info['top_10_sv'][0] / info['top_10_sv'][1], 4) if len(info['top_10_sv']) > 1 and info['top_10_sv'][1] > 0 else 0,
        })

    with open(os.path.join(output_dir, 'weights_full.csv'), 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=rows[0].keys())
        w.writeheader()
        w.writerows(sorted(rows, key=lambda r: (r['phase'], r['layer_idx'], r['component'])))

    # 2. Layer-level FFN summary (W1 progression = the lens)
    ffn_rows = [r for r in rows if r['component'] == 'W1']
    with open(os.path.join(output_dir, 'ffn_w1_progression.csv'), 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=['layer_idx', 'phase', 'effective_rank', 'stable_rank', 'alpha', 'alpha_r2', 'signal_ratio', 'condition_number', 'sv1_sv2_ratio'])
        w.writeheader()
        for r in sorted(ffn_rows, key=lambda r: (r['phase'], r['layer_idx'])):
            w.writerow({k: r[k] for k in w.fieldnames})

    # 3. Gate comparison (W3 vs W4)
    gate_rows = [r for r in rows if r['component'] in ('W3', 'W4') and r['phase'] == 'mirror']
    with open(os.path.join(output_dir, 'gate_comparison.csv'), 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=['layer_idx', 'component', 'effective_rank', 'stable_rank', 'alpha', 'alpha_r2', 'signal_ratio', 'sv1_sv2_ratio'])
        w.writeheader()
        for r in sorted(gate_rows, key=lambda r: (r['layer_idx'], r['component'])):
            w.writerow({k: r[k] for k in w.fieldnames})

    # 4. Attention head comparison (Q, K, V, O per layer)
    attn_rows = [r for r in rows if r['component'] in ('Q', 'K', 'V', 'O', 'QKV')]
    with open(os.path.join(output_dir, 'attention_progression.csv'), 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=['layer_idx', 'phase', 'component', 'effective_rank', 'stable_rank', 'alpha', 'signal_ratio', 'condition_number'])
        w.writeheader()
        for r in sorted(attn_rows, key=lambda r: (r['phase'], r['layer_idx'], r['component'])):
            w.writerow({k: r[k] for k in w.fieldnames})

    # 5. Summary statistics
    alphas = [r['alpha'] for r in rows if r['alpha'] > 0]
    eff_ranks = [r['effective_rank'] for r in rows if r['layer_idx'] >= 0]
    signal_ratios = [r['signal_ratio'] for r in rows if r['layer_idx'] >= 0]

    summary = {
        'n_matrices': len(rows),
        'mean_alpha': round(sum(alphas) / len(alphas), 4) if alphas else 0,
        'min_alpha': round(min(alphas), 4) if alphas else 0,
        'max_alpha': round(max(alphas), 4) if alphas else 0,
        'mean_effective_rank': round(sum(eff_ranks) / len(eff_ranks), 2) if eff_ranks else 0,
        'mean_signal_ratio': round(sum(signal_ratios) / len(signal_ratios), 4) if signal_ratios else 0,
        'n_well_trained (alpha<2)': sum(1 for a in alphas if a < 2.0),
        'n_total_alpha': len(alphas),
    }
    with open(os.path.join(output_dir, 'summary.csv'), 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=summary.keys())
        w.writeheader()
        w.writerow(summary)

    print(f"Wrote CSVs to {output_dir}/")
    print(f" weights_full.csv ({len(rows)} matrices)")
    print(f" ffn_w1_progression.csv ({len(ffn_rows)} layers)")
    print(f" gate_comparison.csv ({len(gate_rows)} entries)")
    print(f" attention_progression.csv ({len(attn_rows)} entries)")
    print(f" summary.csv")


if __name__ == '__main__':
    base = "circuits/scripts/spectral_output/mirrored_300M_mk4_cont"

    # Prisma
    json_to_csvs(
        f"{base}/results.json",
        f"{base}/csv_prisma",
        model_type="prisma"
    )

    # GPT-2 medium
    if os.path.exists(f"{base}/results_b.json"):
        json_to_csvs(
            f"{base}/results_b.json",
            f"{base}/csv_gpt2",
            model_type="gpt2"
        )
```
train.py (ADDED, 637 lines)
```python
#!/usr/bin/env python3
"""
Training script for Circuit Transformer.

Usage:
    python circuits/train.py --data hf:roneneldan/TinyStories --preset tiny --epochs 1 --gpu 0
    python circuits/train.py --data path/to/corpus.txt --dims 256 --layers 6 --fp16
"""

import gc
import os
import time
import math
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
from torch.amp import autocast

from .config import CircuitConfig, parse_args
from .model import CircuitTransformer, count_parameters
from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
from .graft_g2lu import G2LU_GraftedModel, save_g2lu_checkpoint
from .layers import build_word_start_table, compute_word_positions
from .data import get_tokenizer, load_data, create_dataloader


def corrupt_tokens(input_ids, ratio, vocab_size):
    """Replace random tokens with random vocab tokens for denoising autoencoder.

    Returns (corrupted_ids, mask) where mask is True at corrupted positions.
    """
    mask = torch.rand(input_ids.shape, device=input_ids.device) < ratio
    mask[:, 0] = False  # never corrupt first token (BOS/start)
    random_tokens = torch.randint(0, vocab_size, input_ids.shape, device=input_ids.device)
    corrupted = input_ids.clone()
    corrupted[mask] = random_tokens[mask]
    return corrupted, mask
```
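The corruption in `corrupt_tokens` is a single vectorized tensor operation; the same behavior on plain Python lists, for intuition (stdlib-only sketch, not part of the training code):

```python
import random

def corrupt_tokens_list(ids, ratio, vocab_size, seed=0):
    # Same idea as corrupt_tokens: flip each position to a random vocab id
    # with probability `ratio`, but never position 0 (BOS/start token).
    rng = random.Random(seed)
    corrupted, mask = list(ids), [False] * len(ids)
    for i in range(1, len(ids)):
        if rng.random() < ratio:
            corrupted[i] = rng.randrange(vocab_size)
            mask[i] = True
    return corrupted, mask

ids = [101, 5, 6, 7]
noisy, mask = corrupt_tokens_list(ids, ratio=1.0, vocab_size=1000)
print(mask)      # [False, True, True, True]
print(noisy[0])  # 101 — first token untouched
```

With `ratio=0.0` the sequence is returned unchanged; the mask lets the loss (or logging) distinguish corrupted from clean positions.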
```python
@torch.no_grad()
def evaluate(config, model, dataloader, device, use_amp=False, amp_dtype=torch.float16, mid_run_eval=False,
             word_start_table=None):
    """Run validation and return avg loss + perplexity."""
    model.eval()
    total_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        word_positions = None
        if word_start_table is not None:
            word_positions = compute_word_positions(input_ids, word_start_table)

        if use_amp:
            with autocast('cuda', dtype=amp_dtype):
                output = model(input_ids, labels=labels, word_positions=word_positions)
        else:
            output = model(input_ids, labels=labels, word_positions=word_positions)

        total_loss += output["loss"].item()
        n_batches += 1

        if n_batches % (config.log_every * 10) == 0:
            avg_loss = total_loss / max(n_batches, 1)
            ppl = math.exp(min(avg_loss, 20))
            print(
                f"batch {n_batches:6d}/{len(dataloader):6d} | "
                f"Loss {total_loss / n_batches:.4f} | "
                f"PPL {ppl:8.2f}"
            )

        if mid_run_eval and n_batches >= 1500:
            break

    if not mid_run_eval:
        model.train()

    avg_loss = total_loss / max(n_batches, 1)
    ppl = math.exp(min(avg_loss, 20))  # cap to avoid overflow

    return avg_loss, ppl


def get_lr(step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float = 0.0, delay: int = 0) -> float:
    """Cosine learning rate schedule with warmup and optional delay.

    With delay > 0, the schedule is shifted:
        Steps 0..delay:                LR = 0 (frozen)
        Steps delay..delay+warmup:     linear ramp 0 → max_lr
        Steps delay+warmup..max_steps: cosine decay max_lr → min_lr
    """
    if step < delay:
        return 0.0
    effective_step = step - delay
    effective_max = max(1, max_steps - delay)
    if effective_step < warmup_steps:
        return max_lr * effective_step / warmup_steps
    if effective_step >= effective_max:
        return min_lr
    progress = (effective_step - warmup_steps) / (effective_max - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```
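The schedule in `get_lr` is pure arithmetic, so it can be sanity-checked at a few probe points. A standalone copy with the expected values (which follow directly from the piecewise definition):

```python
import math

def get_lr(step, warmup_steps, max_steps, max_lr, min_lr=0.0, delay=0):
    # Standalone copy of the schedule above: delay → warmup ramp → cosine decay.
    if step < delay:
        return 0.0
    effective_step = step - delay
    effective_max = max(1, max_steps - delay)
    if effective_step < warmup_steps:
        return max_lr * effective_step / warmup_steps
    if effective_step >= effective_max:
        return min_lr
    progress = (effective_step - warmup_steps) / (effective_max - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(get_lr(0, warmup_steps=10, max_steps=100, max_lr=1.0))             # 0.0 (start of ramp)
print(get_lr(10, warmup_steps=10, max_steps=100, max_lr=1.0))            # 1.0 (peak)
print(round(get_lr(55, warmup_steps=10, max_steps=100, max_lr=1.0), 4))  # 0.5 (cosine midpoint)
print(get_lr(100, warmup_steps=10, max_steps=100, max_lr=1.0))           # 0.0 (floor)
print(get_lr(5, warmup_steps=10, max_steps=100, max_lr=1.0, delay=20))   # 0.0 (frozen)
```

With a nonzero `delay` the whole ramp-plus-decay is shifted right, which is how a parameter group can be kept frozen for the first `delay` steps.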
def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    step: int,
    epoch: int,
    loss: float,
    config,
    path: str,
    model_type: str = "standard",
    epoch_step: int = 0,
    best_val_loss: float | None = None,
    scaler=None,
    tokenizer_name: str = "gpt2",
):
    """Save training checkpoint.

    Args:
        epoch: Next epoch to start on resume (completed epoch count).
        epoch_step: Batches already processed in `epoch` (0 if epoch is complete).
    """
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
        "epoch": epoch,
        "epoch_step": epoch_step,
        "loss": loss,
        "config": config.to_dict(),
        "model_type": model_type,
        "tokenizer_name": tokenizer_name,
    }
    if best_val_loss is not None:
        checkpoint["best_val_loss"] = best_val_loss
    if scaler is not None:
        checkpoint["scaler"] = scaler.state_dict()
    torch.save(checkpoint, path)


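# A saved checkpoint is a plain torch pickle and can be inspected offline,
# e.g. (the path is illustrative):
#   ckpt = torch.load("checkpoints/step_001000.pt", map_location="cpu", weights_only=False)
#   sorted(ckpt.keys())  # config, epoch, epoch_step, loss, model, model_type,
#                        # optimizer, step, tokenizer_name (+ best_val_loss/scaler if present)

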
def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
    """Migrate checkpoint state_dict to match current model architecture.

    Handles FFN gating renames between SwiGLU and MirroredSwiGLU
    (dual_gate_middle): checkpoint keys gate_expand/gate_compress are
    renamed to the current w3/w4 naming.
    """
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys

    if not missing and not unexpected:
        return state_dict  # perfect match, no migration needed

    migrated = dict(state_dict)
    migrations = []

    # Dual-gate FFN rename: checkpoint gate_expand/gate_compress → model w3/w4
    for key in list(unexpected):
        if ".ffn.gate_expand.weight" in key:
            new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f" {key} → {new_key}")
        if ".ffn.gate_compress.weight" in key:
            new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
            if new_key in missing:
                migrated[new_key] = migrated.pop(key)
                missing.discard(new_key)
                unexpected.discard(key)
                migrations.append(f" {key} → {new_key}")

    if migrations:
        print(f"State dict migration ({len(migrations)} keys renamed):")
        for m in migrations:
            print(m)
        # Report remaining missing keys (freshly initialized)
        still_missing = model_keys - set(migrated.keys())
        if still_missing:
            print(f" New parameters (freshly initialized): {len(still_missing)}")
            for k in sorted(still_missing):
                print(f" {k}")

    return migrated


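# Example of the rename performed above, assuming a layer prefix such as
# "layers.3" (illustrative — the real prefix depends on the model class):
#   "layers.3.ffn.gate_expand.weight"   → "layers.3.ffn.w3.weight"
#   "layers.3.ffn.gate_compress.weight" → "layers.3.ffn.w4.weight"
# A key is only moved when the target name is actually missing from the
# current model, so a matching checkpoint passes through untouched.

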
def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer | None = None,
                    scaler=None, reset: bool = False):
    """Load training checkpoint. Returns dict with resume info."""
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    state_dict = _migrate_state_dict(checkpoint["model"], model)
    model.load_state_dict(state_dict, strict=False)
    if not reset:
        if optimizer is not None and "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scaler is not None and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
    return {
        "step": checkpoint.get("step", 0),
        "epoch": checkpoint.get("epoch", 0),
        "epoch_step": checkpoint.get("epoch_step", 0),
        "best_val_loss": checkpoint.get("best_val_loss", float("inf")),
    }


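# Typical round trip (sketch; the path and numbers are illustrative, and
# `cfg` stands in for whatever config object the caller holds):
#   save_checkpoint(model, optimizer, step=1000, epoch=2, loss=2.31,
#                   config=cfg, path="checkpoints/step_001000.pt")
#   info = load_checkpoint("checkpoints/step_001000.pt", model, optimizer)
#   info["step"], info["epoch"]  → 1000, 2
# With reset=True only the model weights are restored; optimizer and scaler
# state are skipped (and train() then ignores the resume counters).

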
def train():
    config, args = parse_args()

    # Setup device
    device = torch.device(f"cuda:{config.gpu}" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load tokenizer and data
    print(f"Loading data from: {args.data}")
    model_type = args.arch
    tokenizer_name = getattr(args, 'tokenizer', 'gpt2')
    if model_type == "graft_g2lu":
        tokenizer_name = args.pretrained
    tokenizer = get_tokenizer(tokenizer_name)
    config.vocab_size = len(tokenizer)
    print(f"Tokenizer: {tokenizer_name} (vocab_size={config.vocab_size})")
    cache_dir = None if args.no_cache else args.cache_dir
    dataset = load_data(
        args.data,
        tokenizer,
        config.max_seq_len,
        text_column=args.text_column,
        num_samples=args.num_samples,
        cache_dir=cache_dir,
        data_format=args.data_format,
    )
    print(f"Loaded {len(dataset):,} chunks")

    # Train/val split
    val_split = args.val_split
    if val_split > 0 and len(dataset) > 20:
        train_dataset, val_dataset = dataset.split(val_split)
        print(f"Split: {len(train_dataset):,} train / {len(val_dataset):,} val ({val_split:.0%})")
    else:
        train_dataset = dataset
        val_dataset = None

    # Create dataloaders
    dataloader = create_dataloader(
        train_dataset,
        config.batch_size,
        shuffle=True,
    )
    val_dataloader = None
    if val_dataset is not None:
        val_dataloader = create_dataloader(
            val_dataset,
            config.batch_size,
            shuffle=False,
        )

    # Create model
    if model_type == "mirrored":
        model_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=args.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            use_g2lu=not getattr(args, 'no_g2lu', False),
            aux_skip_k=getattr(args, 'aux_skip', 0),
            aux_skip_weight=getattr(args, 'aux_weight', 0.1),
            word_rope_dims=getattr(config, 'word_rope_dims', 0),
            word_rope_base=getattr(config, 'word_rope_base', 10.0),
            embed_dim=getattr(config, 'embed_dim', 0),
            head_dim=getattr(config, 'head_dim', 0),
        )
        model = MirroredTransformer(model_config).to(device)
        param_info = count_mirrored_parameters(model)
        num_params = param_info["unique"]
        print(f"Model: MirroredTransformer")
        print(f" Virtual layers: {model.total_virtual_layers} ({model_config.n_mirror} mirror pairs + {model_config.n_middle} middle)")
        print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M unique)")
        print(f" Shared FFN base: {param_info['shared_ffn_base']:,}")
        print(f" Direction gates: {param_info['direction_gates']:,}")
        print(f" FFN gating: {'G²LU (nested dual gate)' if model_config.use_g2lu else 'SwiGLU (vanilla)'}")
        if model_config.num_kv_heads is not None:
            print(f" GQA: {model_config.num_heads}Q / {model_config.num_kv_heads}KV ({model_config.num_heads // model_config.num_kv_heads}:1 ratio)")
        if model_config.aux_skip_k > 0:
            print(f" Aux skip prediction: t+{model_config.aux_skip_k} (weight={model_config.aux_skip_weight})")
        if getattr(model_config, 'embed_dim', 0) > 0:
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * model_config.embed_dim + model_config.embed_dim * config.hidden_size
            print(f" Factorized embedding: {model_config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(model_config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * model_config.head_dim + model_config.head_dim * config.vocab_size
            print(f" MLP head: {config.hidden_size} → {model_config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")
    elif model_type == "graft_g2lu":
        assert args.pretrained, "--pretrained is required for graft_g2lu architecture"
        amp_dtype = torch.bfloat16 if config.bf16 else (torch.float16 if config.fp16 else torch.float32)
        model = G2LU_GraftedModel(
            pretrained_name=args.pretrained,
            align_weight=args.align_weight,
            warmup_steps=args.graft_warmup,
            device=device,
            dtype=amp_dtype,
        )
        model_config = None  # No CircuitConfig for HF models
        num_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
    else:
        model_config = config
        model = CircuitTransformer(config).to(device)
        num_params = count_parameters(model)
        print(f"Model: CircuitTransformer")
        print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
        if getattr(config, 'aux_skip_k', 0) > 0:
            print(f" Aux skip prediction: t+{config.aux_skip_k} (weight={config.aux_skip_weight})")
        if getattr(config, 'embed_dim', 0) > 0:
            std_embed = config.vocab_size * config.hidden_size
            fact_embed = config.vocab_size * config.embed_dim + config.embed_dim * config.hidden_size
            print(f" Factorized embedding: {config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
        if getattr(config, 'head_dim', 0) > 0:
            std_head = config.hidden_size * config.vocab_size
            mlp_head = config.hidden_size * config.head_dim + config.head_dim * config.vocab_size
            print(f" MLP head: {config.hidden_size} → {config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")

    # Build word-position table if enabled
    word_rope_dims = getattr(config, 'word_rope_dims', 0)
    if word_rope_dims > 0:
        word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
        print(f" Word-position RoPE: {word_rope_dims} dims, base={getattr(config, 'word_rope_base', 10.0)}")
        print(f" Word starters in vocab: {word_start_table.sum().item():,} / {len(tokenizer):,}")
    else:
        word_start_table = None

    # Keep raw reference for set_gate_step (torch.compile wraps the model)
    raw_model = model

    # Optionally compile
    if config.compile and hasattr(torch, "compile"):
        print("Compiling model with torch.compile...")
        model = torch.compile(raw_model)

    # Optimizer — with optional staggered warmup and dual-path training
    grad_accum = getattr(args, 'grad_accum', 1)

    opt_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
    optimizer = torch.optim.AdamW(
        opt_params,
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.95),
    )

    # Mixed precision
    use_amp = (config.fp16 or config.bf16) and device.type == "cuda"
    amp_dtype = torch.bfloat16 if config.bf16 else torch.float16
    scaler = GradScaler() if (config.fp16 and use_amp) else None
    if use_amp:
        print(f" Mixed precision: {'BF16' if config.bf16 else 'FP16'}" +
              (" (no scaler)" if scaler is None else " (with GradScaler)"))

    # Resume from checkpoint
    start_step = 0
    start_epoch = 0
    skip_batches = 0
    best_val_loss = float("inf")
    if args.resume:
        print(f"Resuming from: {args.resume}")
        resume_info = load_checkpoint(args.resume, model, optimizer, scaler, args.reset)
        if not args.reset:
            start_step = resume_info["step"]
            start_epoch = resume_info["epoch"]
            skip_batches = resume_info["epoch_step"]
            best_val_loss = resume_info["best_val_loss"]
            print(f"Resumed at step {start_step}, epoch {start_epoch}" +
                  (f", skipping {skip_batches} batches" if skip_batches > 0 else ""))
            if best_val_loss < float("inf"):
                print(f" Best val loss so far: {best_val_loss:.4f} (PPL {math.exp(min(best_val_loss, 20)):.2f})")

    # Setup checkpoint directory
    checkpoint_dir = Path(config.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Training loop
    steps_per_epoch = math.ceil(len(dataloader) / grad_accum)
    max_steps = config.epochs * steps_per_epoch
    tokens_per_step = config.batch_size * grad_accum * config.max_seq_len
    total_train_tokens = config.epochs * len(dataloader) * config.batch_size * config.max_seq_len
    step = start_step
    model.train()

    print(f"\nStarting training:")
    print(f" Epochs: {config.epochs}")
    print(f" Batch size: {config.batch_size}" +
          (f" x {grad_accum} accum = {config.batch_size * grad_accum} effective" if grad_accum > 1 else ""))
    print(f" Steps per epoch: {steps_per_epoch}" +
          (f" ({len(dataloader)} micro-batches)" if grad_accum > 1 else ""))
    print(f" Total steps: {max_steps}")
    print(f" Total tokens: {total_train_tokens:,} ({total_train_tokens/1e6:.1f}M)")
    if num_params > 0:
        print(f" Tokens/param ratio: {total_train_tokens/num_params:.1f}x (Chinchilla=20x)")
    print(f" Learning rate: {config.learning_rate}" +
          (f" → {config.min_lr}" if config.min_lr > 0 else ""))
    print(f" Mixed precision: {use_amp}")
    print(f" Validation: {'enabled' if val_dataloader else 'disabled'}")
    print()

    total_loss = 0.0
    log_steps = 0
    total_tokens_seen = step * tokens_per_step
    # best_val_loss already set in resume section above
    h_mid_buffer = None
    last_align_val = float("inf")
    start_time = time.time()

    for epoch in range(start_epoch, config.epochs):
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_steps = 0

        micro_batches = []
        epoch_micro_batches = skip_batches if epoch == start_epoch else 0

        for batch_idx, batch in enumerate(dataloader):
            # Skip already-processed batches on resume
            if epoch == start_epoch and batch_idx < skip_batches:
                continue

            micro_batches.append(batch)
            epoch_micro_batches += 1

            # Accumulate micro-batches (flush at accum boundary or epoch end)
            if len(micro_batches) < grad_accum and batch_idx < len(dataloader) - 1:
                continue

            n_micro = len(micro_batches)
            actual_tokens = n_micro * config.batch_size * config.max_seq_len

            # Update learning rate (per-group delays for staggered warmup)
            for param_group in optimizer.param_groups:
                delay = param_group.get("delay", 0)
                param_group["lr"] = get_lr(step, config.warmup_steps, max_steps, config.learning_rate, min_lr=config.min_lr, delay=delay)
            lr = optimizer.param_groups[0]["lr"]  # for logging

            loss_ed_val = None
            loss_align_val = None
            grad_norm_mid = None
            absorb_loss_val = None

            # Update blend alpha for G²LU grafting
            if model_type == "graft_g2lu":
                raw_model.set_step(step)

            # === Standard single-path training with accumulation ===
            optimizer.zero_grad()
            accum_loss = 0.0
            accum_aux = 0.0
            accum_align = 0.0

            for mb in micro_batches:
                mb_ids = mb["input_ids"].to(device)
                mb_labels = mb["labels"].to(device)
                word_positions = None
                if word_start_table is not None:
                    word_positions = compute_word_positions(mb_ids, word_start_table)
                if use_amp:
                    with autocast('cuda', dtype=amp_dtype):
                        output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                else:
                    output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
                if scaler:
                    scaler.scale(output["loss"] / n_micro).backward()
                else:
                    (output["loss"] / n_micro).backward()
                accum_loss += output["loss"].item()
                if "aux_loss" in output:
                    accum_aux += output["aux_loss"].item()
                if "align_loss" in output:
                    accum_align += output["align_loss"].item()

            if scaler:
                scaler.unscale_(optimizer)
            clip_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
            grad_norm = nn.utils.clip_grad_norm_(clip_params, config.grad_clip).item()
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

            loss_val = accum_loss / n_micro
            aux_loss_val = accum_aux / n_micro if accum_aux > 0 else None
            align_loss_val = accum_align / n_micro if accum_align > 0 else None

            total_loss += loss_val
            epoch_loss += loss_val
            epoch_steps += 1
            log_steps += 1
            total_tokens_seen += actual_tokens
            step += 1

            # Logging
            if step % config.log_every == 0:
                avg_loss = total_loss / max(log_steps, 1)
                ppl = math.exp(min(avg_loss, 20))
                elapsed = time.time() - start_time
                tok_s = (log_steps * tokens_per_step) / max(elapsed, 1e-6)

                extra = ""
                if aux_loss_val is not None:
                    extra += f" | Aux {aux_loss_val:.3f}"
                if align_loss_val is not None:
                    extra += f" | Align {align_loss_val:.4f}"

                print(
                    f"Step {step:6d} | "
                    f"Epoch {epoch+1}/{config.epochs} | "
                    f"Loss {avg_loss:.4f} | "
                    f"PPL {ppl:8.2f} | "
                    f"GradN {grad_norm:.3f} | "
                    f"LR {lr:.2e} | "
                    f"Tok/s {tok_s:.0f}"
                    f"{extra}"
                )

                total_loss = 0.0
                log_steps = 0
                start_time = time.time()

            # Checkpointing
            if step % config.save_every == 0:
                ckpt_path = checkpoint_dir / f"step_{step:06d}.pt"
                if model_type == "graft_g2lu":
                    save_g2lu_checkpoint(raw_model, optimizer, step, epoch, loss_val, str(ckpt_path),
                                         epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                else:
                    save_checkpoint(model, optimizer, step, epoch, loss_val, model_config, str(ckpt_path), model_type,
                                    epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                print(f" Saved checkpoint: {ckpt_path}")
                gc.collect()
                torch.cuda.empty_cache()

            # Mid-training validation
            val_every = getattr(args, 'val_every', 0)
            if val_every > 0 and step % val_every == 0 and val_dataloader:
                val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, mid_run_eval=True, word_start_table=word_start_table)
                avg_train = epoch_loss / max(epoch_steps, 1)
                gap = val_loss - avg_train
                print(f" [Val @ step {step}] Loss: {val_loss:.4f} | PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_path = checkpoint_dir / "best.pt"
                    if model_type == "graft_g2lu":
                        save_g2lu_checkpoint(raw_model, optimizer, step, epoch, val_loss, str(best_path),
                                             epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                    else:
                        save_checkpoint(model, optimizer, step, epoch, val_loss, model_config, str(best_path), model_type,
                                        epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                    print(f" New best! Saved: {best_path}")
                gc.collect()
                torch.cuda.empty_cache()

            micro_batches = []

        # --- Epoch summary ---
        epoch_elapsed = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
        epoch_ppl = math.exp(min(avg_epoch_loss, 20))

        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{config.epochs} complete in {epoch_elapsed:.0f}s")
        print(f" Train loss: {avg_epoch_loss:.4f} | Train PPL: {epoch_ppl:.2f}")
        print(f" Tokens seen: {total_tokens_seen:,} ({total_tokens_seen/1e6:.1f}M)")

        # Validation
        if val_dataloader:
            val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, word_start_table=word_start_table)
            gap = val_loss - avg_epoch_loss
            print(f" Val loss: {val_loss:.4f} | Val PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_path = checkpoint_dir / "best.pt"
                if model_type == "graft_g2lu":
                    save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, val_loss, str(best_path),
                                         epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                else:
                    save_checkpoint(model, optimizer, step, epoch + 1, val_loss, model_config, str(best_path), model_type,
                                    epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
                print(f" New best! Saved: {best_path}")
            # Free validation tensors
            gc.collect()
            torch.cuda.empty_cache()
        print(f"{'='*70}\n")

        # Save epoch checkpoint
        ckpt_path = checkpoint_dir / f"epoch_{epoch+1:02d}.pt"
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, avg_epoch_loss, str(ckpt_path),
                                 epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, step, epoch + 1, avg_epoch_loss, model_config, str(ckpt_path), model_type,
                            epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        gc.collect()
        torch.cuda.empty_cache()

    # Save final checkpoint
    if step == start_step:
        print(f"\nNo training performed (already at step {step}/{max_steps}).")
        print(f" To train more epochs, increase --epochs beyond {config.epochs}.")
    else:
        final_path = checkpoint_dir / "latest.pt"
        if model_type == "graft_g2lu":
            save_g2lu_checkpoint(raw_model, optimizer, step, config.epochs, avg_epoch_loss, str(final_path),
                                 epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        else:
            save_checkpoint(model, optimizer, step, config.epochs, avg_epoch_loss, model_config, str(final_path), model_type,
                            epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
        print(f"\nTraining complete.")
        print(f" Final train loss: {avg_epoch_loss:.4f} | PPL: {epoch_ppl:.2f}")
        if val_dataloader:
            print(f" Best val loss: {best_val_loss:.4f} | PPL: {math.exp(min(best_val_loss, 20)):.2f}")
        print(f" Total tokens: {total_tokens_seen:,}")
        print(f" Checkpoints: {final_path}")


if __name__ == "__main__":
    train()