y3i12 commited on
Commit
56e82ec
·
1 Parent(s): 97022c0

Initial commit

Browse files
Files changed (17) hide show
  1. README.md +307 -296
  2. __init__.py +28 -0
  3. bench.py +176 -0
  4. coherence_eval.py +834 -0
  5. config.py +306 -0
  6. data.py +546 -0
  7. generate.py +195 -0
  8. graft_g2lu.py +300 -0
  9. layers.py +325 -0
  10. lm_eval_wrapper.py +344 -0
  11. mirrored.py +532 -0
  12. model.py +357 -0
  13. scripts/__init__.py +0 -0
  14. scripts/representation_analysis.py +1014 -0
  15. scripts/spectral_analysis.py +969 -0
  16. scripts/spectral_to_csv.py +202 -0
  17. train.py +637 -0
README.md CHANGED
@@ -1,296 +1,307 @@
1
- ---
2
- license: mit
3
- datasets:
4
- - Bingsu/openwebtext_20p
5
- - HuggingFaceFW/fineweb-edu
6
- language:
7
- - en
8
- pipeline_tag: text-generation
9
- ---
10
- # Prisma
11
-
12
- A prototype model that is assembled as a mirrored transformer architecture with nested gating (adds an extra weight to the FFN) and morphological position encoding. It proposes that the model architecture creates different scaffolding, leading to different training regimens and capabilities.
13
-
14
- Prisma is only viable as it piggybacks on pre-trained tokenizers and their weight-tied embeddings, it decomposes the transformer architecture into symmetric **expand** and **compress** phases that share structural weights, connected by a small number of unique **middle** layers. Information expands from tokens to semantics, then compresses back — like light through a prism.
15
-
16
- ```
17
- Token Embeddings
18
- |
19
- [ Expand ] ─── mirror pair 1 (W1, W2 shared) ── G²LU gate (W3·W4)
20
- [ Expand ] ─── mirror pair 2
21
- [ .... ] ─── mirror pair N
22
- |
23
- [ Middle ] ─── unique layers (full capacity, not shared)
24
- |
25
- [Compress ] ─── mirror pair N (same W1, W2 as expand N)
26
- [Compress ] ─── mirror pair 2
27
- [Compress ] ─── mirror pair 1
28
- |
29
- LM Head (weight-tied to embeddings)
30
- ```
31
-
32
-
33
- ## Key Concepts
34
-
35
- **Mirrored layers.** Each expand layer shares W1 (projection) and W2 (output) weights with its corresponding compress layer. The architecture gets 2N virtual layers of processing from N unique parameter sets. At 357M parameters, Prisma runs 41 virtual layers from ~20 unique weight sets + 1 middle layer.
36
-
37
- **G²LU — Gated-Gated Linear Unit.** The gate is itself gated:
38
-
39
- Where typical gated transformers have `y = W2 @ (W1 @ x * silu(W3 @ x))`, Prisma has:
40
- ```python
41
- g4 = silu(W4 @ x) # inner gate
42
- g3 = silu(W3 @ x * g4) # outer gate, modulated by inner
43
- y = W2 @ (W1 @ x * g3) # gated output
44
- ```
45
-
46
- One gate in function of the other. Creates quadratic (saddle-surface) decision boundaries instead of linear hyperplanes — each neuron computes a conjunction ("feature A AND feature B") rather than a single threshold. This produces narrow, separated activation channels that resist memorization and tolerate significantly higher learning rates. Part of the parameters saved with mirroring are re-distributed as W4.
47
-
48
- **WoRPE — Word-position Rotary Position Embedding.** Dedicates a small subspace of each attention head to encode position within a word (0 = prefix, 1 = second subword, ...). The information is already in the BPE tokenizer's word-boundary markers — WoRPE surfaces it geometrically so the model doesn't have to rediscover it. No new tokenizer required.
49
-
50
- **Auxiliary skip prediction.** An optional second head predicts t+K tokens ahead, providing gradient signal that rewards structural representations over local memorization. At K=1, functions as a dual-supervision regularizer through an untied projection.
51
-
52
-
53
- ## Results
54
-
55
- ### ~50M scale prototype (WikiText-103, 4 epochs)
56
-
57
- | Model | Params | LR | WikiText PPL | LAMBADA |
58
- |---|---|---|---|---|
59
- | Standard SwiGLU | 51M | 1e-4 | 4125 | 0.002 |
60
- | Prisma (G²LU) | 47M | 1e-4 | 2914 | 0.001 |
61
- | Prisma (G²LU + WoRPE) | 51M | 1e-2 | 921 | 0.082 |
62
-
63
- *Standard trained 10 epochs; Prisma (G²LU + WoRPE) shown at 1 epoch — the point is LR tolerance, not epoch-matched comparison.*
64
-
65
- The regularization stack (mirroring + G²LU + WoRPE) enables training at **100x the standard learning rate** without instability.
66
-
67
- ### ~350M scale prototype — comparison with published models
68
-
69
- Prisma 357M trained on ~30B tokens (OpenWebText 20% + FineWeb-Edu 10BT continued training), compared against published models at similar scale:
70
-
71
- | Model | Params | Train Data | ARC-C\* | ARC-E\* | BoolQ | HellaSwag\* | LAMBADA | PIQA\* | WikiText\*\* | WinoGrande |
72
- |---|---|---|---|---|---|---|---|---|---|---|
73
- | GPT-2 medium | 355M | 40B | 0.250 | 0.436 | 0.586 | 0.394 | **0.430** | 0.664 | **26.75** | **0.531** |
74
- | Baguettotron | 321M | 200B | 0.302 | 0.506 | 0.589 | 0.354 | 0.294 | 0.618 | 30.93 | 0.530 |
75
- | SmolLM-360M | 360M | 600B | **0.359** | **0.640** | 0.550 | **0.536** | **0.455** | **0.715** | **19.49** | **0.570** |
76
- | SmolLM2-360M | 360M | 4000B | **0.381** | **0.681** | 0.617 | 0.431 | **0.532** | **0.718** | **15.67** | **0.586** |
77
- | LFM2-350M | 350M | 10000B | **0.393** | **0.662** | **0.642** | **0.489** | 0.399 | 0.698 | 25.68 | **0.558** |
78
- | **Prisma** | **357M** | **30B** | 0.290 | **0.548** | **0.620** | **0.427** | 0.362 | **0.670** | 27.40 | 0.506 |
79
-
80
- \* *normalized accuracy* · \*\* *word perplexity*
81
-
82
- **Key findings:**
83
- - **Beats GPT-2 medium on 5/8 benchmarks** (ARC-C, ARC-E, BoolQ, HellaSwag, PIQA) with 25% less training data.
84
- - **Beats Baguettotron (200B) on 6/8 benchmarks** — including PPL — with **7x less data.**
85
- - **BoolQ 0.620** exceeds all models except LFM2 (10000B) and SmolLM2 (4000B). The anti-memorization properties of G²LU force genuine comprehension instead of statistical shortcuts.
86
- - **ARC-Easy 0.548** — the largest absolute gain over GPT-2 medium (+11.2pp). FineWeb-Edu knowledge absorbed efficiently through G²LU's relational features.
87
- - Prisma wins on **reasoning benchmarks** (ARC, HellaSwag, PIQA, BoolQ). Models trained on 20-300x more data win on **content prediction** (LAMBADA, PPL). The architecture trades raw memorization for data-efficient knowledge application.
88
-
89
-
90
- ### Training progression (~350M)
91
-
92
- | Stage | LR | ARC-C | ARC-E | BoolQ | HellaSwag | LAMBADA | PIQA | PPL |
93
- |---|---|---|---|---|---|---|---|---|
94
- | Standard 336M baseline | 1e-4 | 0.228 | 0.341 | 0.618 | 0.280 | 0.226 | 0.574 | 77.2 |
95
- | Prisma 41L (OWT 20%) | 5e-4 | 0.238 | 0.394 | 0.585 | 0.317 | 0.313 | 0.614 | 44.8 |
96
- | + WoRPE (OWT 20%) | 1e-3 | 0.247 | 0.397 | 0.595 | 0.331 | 0.333 | 0.614 | 43.5 |
97
- | + continued (FineWeb c1) | 1e-3 | 0.249 | 0.434 | 0.601 | 0.333 | 0.312 | 0.626 | 34.7 |
98
- | + continued (FineWeb c2) | 1e-3 | 0.290 | 0.548 | 0.620 | 0.427 | 0.362 | 0.670 | 27.4 |
99
-
100
-
101
- ## Quick Start
102
-
103
- ### Install
104
-
105
- ```bash
106
- pip install -r circuits/requirements.txt
107
- ```
108
-
109
-
110
- ### Train
111
-
112
- ```bash
113
- # Small Prisma (~47M) on WikiText-103
114
- python -m circuits.train \
115
- --arch mirrored --dims 384 --heads 6 --kv-heads 2 --layers 57 --n-middle 1 \
116
- --tokenizer facebook/MobileLLM-125M \
117
- --word-rope-dims 8 --word-rope-base 10.0 \
118
- --data hf:wikitext:wikitext-103-raw-v1:train \
119
- --epochs 4 --batch-size 32 --context-length 512 \
120
- --lr 1e-2 --warmup-steps 500 --bf16 --gpu 0
121
-
122
-
123
- # 324M Prisma on OpenWebText
124
- python -m circuits.train \
125
- --gpu 0 --compile --bf16 --arch mirrored \
126
- --dims 1024 --heads 16 --kv-heads 4 --layers 41 --n-middle 1 \
127
- --word-rope-dims 8 --word-rope-base 10.0 \
128
- --tokenizer facebook/MobileLLM-125M \
129
- --data hf:Bingsu/openwebtext_20p --text-column text \
130
- --epochs 4 --batch-size 12 --context-length 1024 --grad-accum 42 \
131
- --lr 1e-3 --warmup 500 \
132
- --log-every 5 --val-every 1000 --save-every 1000 \
133
- --checkpoint-dir path/to/checkpoints/your_model/
134
- ```
135
-
136
-
137
- ### Generate
138
-
139
- ```bash
140
- python -m circuits.generate \
141
- --checkpoint path/to/checkpoints/your_model/best.pt \
142
- --prompt "A thought observing itself discovers that it" \
143
- --max-tokens 256 --temperature 0.8 --top-p 0.95
144
- ```
145
-
146
-
147
- ### Benchmark
148
-
149
- ```bash
150
- # Single model
151
- python -m circuits.bench \
152
- --checkpoint path/to/checkpoints/your_model/best.pt \
153
- --tasks arc_easy,lambada_openai,piqa,hellaswag,winogrande,wikitext \
154
- --gpu 0
155
- ```
156
-
157
-
158
- ## CLI Reference
159
-
160
- ### Architecture
161
-
162
- | Flag | Default | Description |
163
- |---|---|---|
164
- | `--arch` | `mirrored` | Architecture: `standard`, `mirrored`, `graft_g2lu` (experimental) |
165
- | `--dims` | 512 | Hidden dimension |
166
- | `--heads` | 8 | Number of attention heads |
167
- | `--kv-heads` | — | KV heads for GQA (omit = MHA) |
168
- | `--layers` | 12 | Total virtual layers (expand + middle + compress) |
169
- | `--n-middle` | 2 | Unique (non-mirrored) middle layers |
170
-
171
-
172
- ### Prisma-Specific
173
-
174
- | Flag | Default | Description |
175
- |---|---|---|
176
- | `--word-rope-dims` | 0 | Head dims for WoRPE (0 = disabled, try 8) |
177
- | `--word-rope-base` | 10.0 | WoRPE frequency base |
178
- | `--aux-skip` | 0 | Skip-ahead prediction distance (0 = disabled) |
179
- | `--aux-weight` | 0.1 | Weight for auxiliary loss |
180
- | `--no-g2lu` | — | Disable G²LU, use standard SwiGLU in mirrored arch |
181
-
182
-
183
- ### Training
184
-
185
- | Flag | Default | Description |
186
- |---|---|---|
187
- | `--lr` | 3e-4 | Peak learning rate |
188
- | `--min-lr` | 0.0 | LR floor for cosine schedule |
189
- | `--warmup-steps` | 100 | LR warmup steps |
190
- | `--epochs` | 10 | Training epochs |
191
- | `--batch-size` | 32 | Micro-batch size |
192
- | `--grad-accum` | 1 | Gradient accumulation steps |
193
- | `--context-length` | 512 | Sequence length |
194
- | `--bf16` / `--fp16` | — | Mixed precision |
195
- | `--compile` | — | `torch.compile` the model |
196
-
197
-
198
- ### Data
199
-
200
- | Flag | Default | Description |
201
- |---|---|---|
202
- | `--data` | — | Path or `hf:dataset_name` |
203
- | `--text-column` | `text` | Column name for HF datasets |
204
- | `--tokenizer` | `gpt2` | Tokenizer name or path |
205
- | `--num-samples` | — | Limit dataset size |
206
-
207
-
208
- ## Architecture Details
209
-
210
- ### Why Mirroring Works
211
-
212
- Mirroring only works due to the additional gate. W3 and W4 specialize to serve different roles despite sharing weights — spectral analysis confirms the gates swap their stable-rank profiles at the architectural midpoint. The order of mirror layers may be rearrangeable, as the gates adapt to whatever representations flow through them.
213
-
214
-
215
- ### Why G²LU Works
216
-
217
- Standard SwiGLU creates hyperplane decision boundaries — broad, overlapping activation regions. G²LU's nested gate creates **saddle surfaces** — narrow activation bands with isolation gaps (like a spectral comb filter). This has three effects:
218
-
219
- 1. **Anti-memorization.** The gate geometry cannot form sharp, input-specific activations. The model is forced toward broad, relational features.
220
- 2. **Higher LR tolerance.** Narrow activation bands leave headroom between features. Large gradient updates shift features within their bands without colliding.
221
- 3. **Compositional detection.** Each neuron natively computes conjunctions (A AND B), not just thresholds. Might be useful for morphology, syntax, and structural reasoning.
222
-
223
- G²LU can be seen as occupying a point between standard GLU (fixed activation, fixed gate) and KAN (fully learned activations): the activation function is fixed (silu), but its effective shape adapts per-input through the nested gate.
224
-
225
-
226
- ### Why WoRPE Works
227
-
228
- BPE tokenizers already mark word boundaries (`Ġ` for GPT-2, `▁` for SentencePiece). WoRPE surfaces this information geometrically in a dedicated subspace of the rotary embedding, so the model gets word-internal position for free instead of rediscovering it from attention patterns. Requires G²LU to exploit effectively — the saddle surfaces compute morphological conjunctions ("position-0 AND prefix-pattern") that single gates cannot.
229
-
230
-
231
- ### Why Everything Works Together
232
-
233
- The optimization landscape of this architecture is substantially more complex than a standard transformer — shared weights must serve both directions, nested gates must coordinate, and the hourglass bottleneck constrains information flow. This appears to be only tractable when anchored by pre-trained, weight-tied embeddings that provide a stable coordinate system. The frozen embeddings give the model fixed reference geometry, allowing convergence despite the architectural complexity.
234
-
235
-
236
- ## File Map
237
-
238
- ```
239
- circuits/
240
- config.py — CLI arguments, presets, CircuitConfig
241
- layers.py — RMSNorm, RoPE, WoRPE, CausalAttention, SwiGLU
242
- model.py — CircuitTransformer (standard baseline)
243
- mirrored.py — MirroredTransformer, G²LU, MirroredBlock
244
- train.py — Training loop, LR schedule, checkpointing
245
- data.py — MemmapDataset, parallel tokenization, HF/text loading
246
- generate.py — Text generation with KV caching
247
- bench.py — Benchmark runner and comparison tables
248
- lm_eval_wrapper.py — EleutherAI lm-eval harness integration
249
- graft_g2lu.py — Surgical G²LU upgrade for pretrained models (experimental/untested)
250
- scripts/ — Analysis scripts
251
- ```
252
-
253
-
254
- ## Origin
255
-
256
- Prisma grew from interpretability research on _layer grafting_ (writing in progress) in Llama 3.2, which suggests that one of the ways that transformers might self organize to process language can be seen as like a mirrored structure that expands from tokens to semantics, then compressing back — bringing the interpretive analogy of seeing it as a biconvex lens with fractures or polarizing filters within its body. If the two halves are symmetric structurally, they can share weights. The gate (fractures/polarizing filters) becomes the minimum surgical unit for changing behavior. A single weightset becomes insufficient due to shared weights, which brought the question of how to properly make two gates efficiently collaborate.
257
-
258
- G²LU emerged from the observation that for a pair of gates to be expressive and atomic, _one gate needs to be in function of the other_.
259
-
260
- WoRPE emerged from noticing, that tokenizers already carry word structure but positional encodings ignore it — providing hints to the model allows faster convergence during training.
261
-
262
- The architecture is a processing engine that plugs into pretrained tokenizer embeddings. The tokenizer is load-bearing infrastructure — Prisma operates within a pre-existing coordinate system.
263
-
264
-
265
- ## Developer Notes
266
-
267
- This model is the outcome of a POC done by a single individual with limited resources, further investigation, training and tests are being slowly conducted as time and conditions allow.
268
-
269
- The proposed architecture was only fully trained on top of `facebook/MobileLLM-125M` tokenizer and weight-tied embeddings. It might be the case that it doesn't work as expected on untied embeddings and it is highly likely that it is impossible to train a model with this architecture without a pre-trained tokenizer.
270
-
271
- Different arrangements of the architecture (varying middle layer count, mirror depth, width) would likely produce different results. Only this setup — with 1 middle layer — was tested, as a validation of whether the architecture works at all. The extreme case was chosen deliberately: if the bottleneck configuration most prone to failure still produces competitive results, less constrained configurations should too.
272
-
273
- Factorized dimensions for embeddings and an intermediate down proj before the output head were attempted, and nothing useful came out of it.
274
-
275
- It is completely unknown if the architecture is beneficial for larger models (1B+) — observations suggests it might.
276
-
277
-
278
- ## Training
279
-
280
- - **Architecture**:
281
- - 41 layers
282
- - 20 with shared W1 and W2
283
- - 1 unique
284
- - 1024 dimms
285
- - 16 GQA heads, 4 KV heads (4:1)
286
- - vocab size 32k
287
- - RoPE + WoRPE + G²LU
288
- - **Pretraining tokens**: 30B
289
- - **Precision**: bfloat16
290
- - **Tokenizer/Embeddings**: facebook/MobileLLM-125M
291
- - **Hardware**: 1 H100
292
-
293
-
294
- ## Disclaimer
295
-
296
- This model is developed as a research model and it hasn't been tested thoroughly regarding synthesis and coherence quality, as its size is somewhat limiting. Use it at your own risk.
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - Bingsu/openwebtext_20p
5
+ - HuggingFaceFW/fineweb-edu
6
+ language:
7
+ - en
8
+ pipeline_tag: text-generation
9
+ ---
10
+ # Prisma
11
+
12
+ A prototype model that is assembled as a mirrored transformer architecture with nested gating (adds an extra weight to the FFN) and morphological position encoding. It proposes that the model architecture creates different scaffolding, leading to different training regimens and capabilities.
13
+
14
+ Prisma is only viable as it piggybacks on pre-trained tokenizers and their weight-tied embeddings, it decomposes the transformer architecture into symmetric **expand** and **compress** phases that share structural weights, connected by a small number of unique **middle** layers. Information expands from tokens to semantics, then compresses back — like light through a prism.
15
+
16
+ ```
17
+ Token Embeddings
18
+ |
19
+ [ Expand ] ─── mirror pair 1 (W1, W2 shared) ── G²LU gate (W3·W4)
20
+ [ Expand ] ─── mirror pair 2
21
+ [ .... ] ─── mirror pair N
22
+ |
23
+ [ Middle ] ─── unique layers (full capacity, not shared)
24
+ |
25
+ [Compress ] ─── mirror pair N (same W1, W2 as expand N)
26
+ [Compress ] ─── mirror pair 2
27
+ [Compress ] ─── mirror pair 1
28
+ |
29
+ LM Head (weight-tied to embeddings)
30
+ ```
31
+
32
+
33
+ ## Key Concepts
34
+
35
+ **Mirrored layers.** Each expand layer shares W1 (projection) and W2 (output) weights with its corresponding compress layer. The architecture gets 2N virtual layers of processing from N unique parameter sets. At 357M parameters, Prisma runs 41 virtual layers from ~20 unique weight sets + 1 middle layer.
36
+
37
+ **G²LU — Gated-Gated Linear Unit.** The gate is itself gated:
38
+
39
+ Where typical gated transformers have `y = W2 @ (W1 @ x * silu(W3 @ x))`, Prisma has:
40
+ ```python
41
+ g4 = silu(W4 @ x) # inner gate
42
+ g3 = silu(W3 @ x * g4) # outer gate, modulated by inner
43
+ y = W2 @ (W1 @ x * g3) # gated output
44
+ ```
45
+
46
+ One gate in function of the other. Creates quadratic (saddle-surface) decision boundaries instead of linear hyperplanes — each neuron computes a conjunction ("feature A AND feature B") rather than a single threshold. This produces narrow, separated activation channels that resist memorization and tolerate significantly higher learning rates. Part of the parameters saved with mirroring are re-distributed as W4.
47
+
48
+ **WoRPE — Word-position Rotary Position Embedding.** Dedicates a small subspace of each attention head to encode position within a word (0 = prefix, 1 = second subword, ...). The information is already in the BPE tokenizer's word-boundary markers — WoRPE surfaces it geometrically so the model doesn't have to rediscover it. No new tokenizer required.
49
+
50
+ **Auxiliary skip prediction.** An optional second head predicts t+K tokens ahead, providing gradient signal that rewards structural representations over local memorization. At K=1, functions as a dual-supervision regularizer through an untied projection.
51
+
52
+
53
+ ## Results
54
+
55
+ ### ~50M scale prototype (WikiText-103, 4 epochs)
56
+
57
+ | Model | Params | LR | WikiText PPL | LAMBADA |
58
+ |---|---|---|---|---|
59
+ | Standard SwiGLU | 51M | 1e-4 | 4125 | 0.002 |
60
+ | Prisma (G²LU) | 47M | 1e-4 | 2914 | 0.001 |
61
+ | Prisma (G²LU + WoRPE) | 51M | 1e-2 | 921 | 0.082 |
62
+
63
+ *Standard trained 10 epochs; Prisma (G²LU + WoRPE) shown at 1 epoch — the point is LR tolerance, not epoch-matched comparison.*
64
+
65
+ The regularization stack (mirroring + G²LU + WoRPE) enables training at **100x the standard learning rate** without instability.
66
+
67
+ ### ~350M scale prototype — comparison with published models
68
+
69
+ Prisma 357M trained on ~30B tokens (OpenWebText 20% + FineWeb-Edu 10BT continued training), compared against published models at similar scale:
70
+
71
+ | Model | Params | Train Data | ARC-C\* | ARC-E\* | BoolQ | HellaSwag\* | LAMBADA | PIQA\* | WikiText\*\* | WinoGrande |
72
+ |---|---|---|---|---|---|---|---|---|---|---|
73
+ | GPT-2 medium | 355M | 40B | 0.250 | 0.436 | 0.586 | 0.394 | **0.430** | 0.664 | **26.75** | **0.531** |
74
+ | Baguettotron | 321M | 200B | 0.302 | 0.506 | 0.589 | 0.354 | 0.294 | 0.618 | 30.93 | 0.530 |
75
+ | SmolLM-360M | 360M | 600B | **0.359** | **0.640** | 0.550 | **0.536** | **0.455** | **0.715** | **19.49** | **0.570** |
76
+ | SmolLM2-360M | 360M | 4000B | **0.381** | **0.681** | 0.617 | 0.431 | **0.532** | **0.718** | **15.67** | **0.586** |
77
+ | LFM2-350M | 350M | 10000B | **0.393** | **0.662** | **0.642** | **0.489** | 0.399 | 0.698 | 25.68 | **0.558** |
78
+ | **Prisma** | **357M** | **30B** | 0.290 | **0.548** | **0.620** | **0.427** | 0.362 | **0.670** | 27.40 | 0.506 |
79
+
80
+ \* *normalized accuracy* · \*\* *word perplexity*
81
+
82
+ **Key findings:**
83
+ - **Beats GPT-2 medium on 5/8 benchmarks** (ARC-C, ARC-E, BoolQ, HellaSwag, PIQA) with 25% less training data.
84
+ - **Beats Baguettotron (200B) on 6/8 benchmarks** — including PPL — with **7x less data.**
85
+ - **BoolQ 0.620** exceeds all models except LFM2 (10000B) and SmolLM2 (4000B). The anti-memorization properties of G²LU force genuine comprehension instead of statistical shortcuts.
86
+ - **ARC-Easy 0.548** — the largest absolute gain over GPT-2 medium (+11.2pp). FineWeb-Edu knowledge absorbed efficiently through G²LU's relational features.
87
+ - Prisma wins on **reasoning benchmarks** (ARC, HellaSwag, PIQA, BoolQ). Models trained on 20-300x more data win on **content prediction** (LAMBADA, PPL). The architecture trades raw memorization for data-efficient knowledge application.
88
+
89
+
90
+ ### Training progression (~350M)
91
+
92
+ | Stage | LR | ARC-C | ARC-E | BoolQ | HellaSwag | LAMBADA | PIQA | PPL |
93
+ |---|---|---|---|---|---|---|---|---|
94
+ | Standard 336M baseline | 1e-4 | 0.228 | 0.341 | 0.618 | 0.280 | 0.226 | 0.574 | 77.2 |
95
+ | Prisma 41L (OWT 20%) | 5e-4 | 0.238 | 0.394 | 0.585 | 0.317 | 0.313 | 0.614 | 44.8 |
96
+ | + WoRPE (OWT 20%) | 1e-3 | 0.247 | 0.397 | 0.595 | 0.331 | 0.333 | 0.614 | 43.5 |
97
+ | + continued (FineWeb c1) | 1e-3 | 0.249 | 0.434 | 0.601 | 0.333 | 0.312 | 0.626 | 34.7 |
98
+ | + continued (FineWeb c2) | 1e-3 | 0.290 | 0.548 | 0.620 | 0.427 | 0.362 | 0.670 | 27.4 |
99
+
100
+
101
+ ## Quick Start
102
+
103
+ ### Install
104
+
105
+ ```bash
106
+ pip install -r Prisma/requirements.txt
107
+ ```
108
+
109
+
110
+ ### Train
111
+
112
+ ```bash
113
+ # Small Prisma (~47M) on WikiText-103
114
+ python -m Prisma.train \
115
+ --arch mirrored --dims 384 --heads 6 --kv-heads 2 --layers 57 --n-middle 1 \
116
+ --tokenizer facebook/MobileLLM-125M \
117
+ --word-rope-dims 8 --word-rope-base 10.0 \
118
+ --data hf:wikitext:wikitext-103-raw-v1:train \
119
+ --epochs 4 --batch-size 32 --context-length 512 \
120
+ --lr 1e-2 --warmup-steps 500 --bf16 --gpu 0
121
+
122
+
123
+ # 324M Prisma on OpenWebText
124
+ python -m Prisma.train \
125
+ --gpu 0 --compile --bf16 --arch mirrored \
126
+ --dims 1024 --heads 16 --kv-heads 4 --layers 41 --n-middle 1 \
127
+ --word-rope-dims 8 --word-rope-base 10.0 \
128
+ --tokenizer facebook/MobileLLM-125M \
129
+ --data hf:Bingsu/openwebtext_20p --text-column text \
130
+ --epochs 4 --batch-size 12 --context-length 1024 --grad-accum 42 \
131
+ --lr 1e-3 --warmup 500 \
132
+ --log-every 5 --val-every 1000 --save-every 1000 \
133
+ --checkpoint-dir path/to/checkpoints/your_model/
134
+ ```
135
+
136
+
137
+ ### Generate
138
+
139
+ ```bash
140
+ python -m Prisma.generate \
141
+ --checkpoint path/to/checkpoints/your_model/best.pt \
142
+ --prompt "A thought observing itself discovers that it" \
143
+ --max-tokens 256 --temperature 0.8 --top-p 0.95
144
+ ```
145
+
146
+
147
+ ### Benchmark
148
+
149
+ ```bash
150
+ # Single model
151
+ python -m Prisma.bench \
152
+ --checkpoint path/to/checkpoints/your_model/best.pt \
153
+ --tasks arc_easy,lambada_openai,piqa,hellaswag,winogrande,wikitext \
154
+ --gpu 0
155
+ ```
156
+
157
+
158
+ ## CLI Reference
159
+
160
+ ### Architecture
161
+
162
+ | Flag | Default | Description |
163
+ |---|---|---|
164
+ | `--arch` | `mirrored` | Architecture: `standard`, `mirrored`, `graft_g2lu` (experimental) |
165
+ | `--dims` | 512 | Hidden dimension |
166
+ | `--heads` | 8 | Number of attention heads |
167
+ | `--kv-heads` | — | KV heads for GQA (omit = MHA) |
168
+ | `--layers` | 12 | Total virtual layers (expand + middle + compress) |
169
+ | `--n-middle` | 2 | Unique (non-mirrored) middle layers |
170
+
171
+
172
+ ### Prisma-Specific
173
+
174
+ | Flag | Default | Description |
175
+ |---|---|---|
176
+ | `--word-rope-dims` | 0 | Head dims for WoRPE (0 = disabled, try 8) |
177
+ | `--word-rope-base` | 10.0 | WoRPE frequency base |
178
+ | `--aux-skip` | 0 | Skip-ahead prediction distance (0 = disabled) |
179
+ | `--aux-weight` | 0.1 | Weight for auxiliary loss |
180
+ | `--no-g2lu` | — | Disable G²LU, use standard SwiGLU in mirrored arch |
181
+
182
+
183
+ ### Training
184
+
185
+ | Flag | Default | Description |
186
+ |---|---|---|
187
+ | `--lr` | 3e-4 | Peak learning rate |
188
+ | `--min-lr` | 0.0 | LR floor for cosine schedule |
189
+ | `--warmup-steps` | 100 | LR warmup steps |
190
+ | `--epochs` | 10 | Training epochs |
191
+ | `--batch-size` | 32 | Micro-batch size |
192
+ | `--grad-accum` | 1 | Gradient accumulation steps |
193
+ | `--context-length` | 512 | Sequence length |
194
+ | `--bf16` / `--fp16` | — | Mixed precision |
195
+ | `--compile` | — | `torch.compile` the model |
196
+
197
+
198
+ ### Data
199
+
200
+ | Flag | Default | Description |
201
+ |---|---|---|
202
+ | `--data` | — | Path or `hf:dataset_name` |
203
+ | `--text-column` | `text` | Column name for HF datasets |
204
+ | `--tokenizer` | `gpt2` | Tokenizer name or path |
205
+ | `--num-samples` | — | Limit dataset size |
206
+
207
+
208
+ ## Architecture Details
209
+
210
+ ### Why Mirroring Works
211
+
212
+ Mirroring only works due to the additional gate. W3 and W4 specialize to serve different roles despite sharing weights — spectral analysis confirms the gates swap their stable-rank profiles at the architectural midpoint. The order of mirror layers may be rearrangeable, as the gates adapt to whatever representations flow through them.
213
+
214
+
215
+ ### Why G²LU Works
216
+
217
+ Standard SwiGLU creates hyperplane decision boundaries — broad, overlapping activation regions. G²LU's nested gate creates **saddle surfaces** — narrow activation bands with isolation gaps (like a spectral comb filter). This has three effects:
218
+
219
+ 1. **Anti-memorization.** The gate geometry cannot form sharp, input-specific activations. The model is forced toward broad, relational features.
220
+ 2. **Higher LR tolerance.** Narrow activation bands leave headroom between features. Large gradient updates shift features within their bands without colliding.
221
+ 3. **Compositional detection.** Each neuron natively computes conjunctions (A AND B), not just thresholds. Might be useful for morphology, syntax, and structural reasoning.
222
+
223
+ G²LU can be seen as occupying a point between standard GLU (fixed activation, fixed gate) and KAN (fully learned activations): the activation function is fixed (silu), but its effective shape adapts per-input through the nested gate.
224
+
225
+
226
+ ### Why WoRPE Works
227
+
228
+ BPE tokenizers already mark word boundaries (`Ġ` for GPT-2, `▁` for SentencePiece). WoRPE surfaces this information geometrically in a dedicated subspace of the rotary embedding, so the model gets word-internal position for free instead of rediscovering it from attention patterns. Requires G²LU to exploit effectively — the saddle surfaces compute morphological conjunctions ("position-0 AND prefix-pattern") that single gates cannot.
229
+
230
+
231
+ ### Why Everything Works Together
232
+
233
+ The optimization landscape of this architecture is substantially more complex than a standard transformer — shared weights must serve both directions, nested gates must coordinate, and the hourglass bottleneck constrains information flow. This appears to be only tractable when anchored by pre-trained, weight-tied embeddings that provide a stable coordinate system. The frozen embeddings give the model fixed reference geometry, allowing convergence despite the architectural complexity.
234
+
235
+
236
+ ## File Map
237
+
238
+ ```
239
+ Prisma/
240
+ config.py — CLI arguments, presets, CircuitConfig
241
+ layers.py — RMSNorm, RoPE, WoRPE, CausalAttention, SwiGLU
242
+ model.py — CircuitTransformer (standard baseline)
243
+ mirrored.py — MirroredTransformer, G²LU, MirroredBlock
244
+ train.py — Training loop, LR schedule, checkpointing
245
+ data.py — MemmapDataset, parallel tokenization, HF/text loading
246
+ generate.py — Text generation with KV caching
247
+ bench.py — Benchmark runner and comparison tables
248
+ lm_eval_wrapper.py — EleutherAI lm-eval harness integration
249
+ graft_g2lu.py — Surgical G²LU upgrade for pretrained models (experimental/untested)
250
+ scripts/ — Analysis scripts
251
+ ```
252
+
253
+
254
+ ## Origin
255
+
256
+ Prisma grew from interpretability research on _layer grafting_ (writing in progress) in Llama 3.2, which suggests that one of the ways that transformers might self organize to process language can be seen as like a mirrored structure that expands from tokens to semantics, then compressing back — bringing the interpretive analogy of seeing it as a biconvex lens with fractures or polarizing filters within its body. If the two halves are symmetric structurally, they can share weights. The gate (fractures/polarizing filters) becomes the minimum surgical unit for changing behavior. A single weightset becomes insufficient due to shared weights, which brought the question of how to properly make two gates efficiently collaborate.
257
+
258
+ G²LU emerged from the observation that for a pair of gates to be expressive and atomic, _one gate needs to be in function of the other_.
259
+
260
+ WoRPE emerged from noticing, that tokenizers already carry word structure but positional encodings ignore it — providing hints to the model allows faster convergence during training.
261
+
262
+ The architecture is a processing engine that plugs into pretrained tokenizer embeddings. The tokenizer is load-bearing infrastructure — Prisma operates within a pre-existing coordinate system.
263
+
264
+
265
+ ## Developer Notes
266
+
267
+ This model is the outcome of a POC done by a single individual with limited resources, further investigation, training and tests are being slowly conducted as time and conditions allow.
268
+
269
+ The proposed architecture was only fully trained on top of `facebook/MobileLLM-125M` tokenizer and weight-tied embeddings. It might be the case that it doesn't work as expected on untied embeddings and it is highly likely that it is impossible to train a model with this architecture without a pre-trained tokenizer.
270
+
271
+ Different arrangements of the architecture (varying middle layer count, mirror depth, width) would likely produce different results. Only this setup — with 1 middle layer — was tested, as a validation of whether the architecture works at all. The extreme case was chosen deliberately: if the bottleneck configuration most prone to failure still produces competitive results, less constrained configurations should too.
272
+
273
+ Factorized dimensions for embeddings and an intermediate down proj before the output head were attempted, and nothing useful came out of it.
274
+
275
+ It is completely unknown if the architecture is beneficial for larger models (1B+) — observations suggests it might.
276
+
277
+
278
+ ## Training
279
+
280
+ - **Architecture**:
281
+ - 41 layers
282
+ - 20 with shared W1 and W2
283
+ - 1 unique
284
+ - 1024 dimms
285
+ - 16 GQA heads, 4 KV heads (4:1)
286
+ - vocab size 32k
287
+ - RoPE + WoRPE + G²LU
288
+ - **Pretraining tokens**: 30B
289
+ - **Precision**: bfloat16
290
+ - **Tokenizer/Embeddings**: facebook/MobileLLM-125M
291
+ - **Hardware**: 1 H100
292
+
293
+
294
+ ## Disclaimer
295
+
296
+ This model is developed as a research model and it hasn't been tested thoroughly regarding synthesis and coherence quality, as its size is somewhat limiting. Use it at your own risk.
297
+
298
+
299
+ ## Citation
300
+ ```
301
+ @misc{ivatchkovitch2026prisma,
302
+ title={Prisma: Interpretability-Inspired Mirrored Transformer Architecture},
303
+ author={Yuri Ivatchkovitch},
304
+ year={2026},
305
+ howpublished={\url{https://huggingface.co/y3i12/Prisma}},
306
+ }
307
+ ```
__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Circuits: Minimal Transformer for Semantic Circuitry Experiments.
3
+
4
+ A clean, self-contained transformer implementation designed for
5
+ experimenting with neural networks.
6
+ """
7
+
8
+ from .config import CircuitConfig
9
+ from .model import CircuitTransformer, count_parameters
10
+ from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
11
+ from .data import get_tokenizer, load_data, create_dataloader, TextDataset
12
+ from .graft_g2lu import G2LU_GraftedModel, G2LU_MLP, load_g2lu_model
13
+
14
+ __all__ = [
15
+ "CircuitConfig",
16
+ "CircuitTransformer",
17
+ "count_parameters",
18
+ "MirroredConfig",
19
+ "MirroredTransformer",
20
+ "count_mirrored_parameters",
21
+ "get_tokenizer",
22
+ "load_data",
23
+ "create_dataloader",
24
+ "TextDataset",
25
+ "G2LU_GraftedModel",
26
+ "G2LU_MLP",
27
+ "load_g2lu_model",
28
+ ]
bench.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Benchmark Circuit transformer family against standard LM tasks.
4
+
5
+ Usage:
6
+ # Single model
7
+ python -m circuits.bench --checkpoint circuits/checkpoints/slot_local_mirrored/best.pt --gpu 0
8
+
9
+ # Compare all architectures
10
+ python -m circuits.bench --compare --gpu 0
11
+
12
+ # Quick sanity check (100 samples per task)
13
+ python -m circuits.bench --compare --gpu 0 --limit 100
14
+
15
+ # Specific tasks
16
+ python -m circuits.bench --checkpoint path/to/best.pt --tasks hellaswag,lambada_openai
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import time
22
+ import torch
23
+ from pathlib import Path
24
+
25
+ import lm_eval
26
+ from lm_eval.api.registry import register_model
27
+
28
+ from .lm_eval_wrapper import CircuitLM
29
+
30
+ # Register so lm_eval can find it
31
+ register_model("circuit")(CircuitLM)
32
+
33
+ DEFAULT_TASKS = "arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,piqa,wikitext,winogrande"
34
+
35
+ # Known checkpoints for --compare mode
36
+ CHECKPOINTS = {
37
+ "standard_12L": "circuits/checkpoints/flat/best.pt",
38
+ "mirrored_9L_wide": "circuits/checkpoints/hier_wide_2/best.pt",
39
+ "mirrored_15L_deep": "circuits/checkpoints/hier_resized/best.pt",
40
+ "slot_local_mirrored": "circuits/checkpoints/slot_local_mirrored/best.pt",
41
+ }
42
+
43
+
44
+ def run_benchmark(checkpoint: str, tasks: str, device: str, limit: int = None, batch_size: int = 1, compile: bool = False):
45
+ """Run lm-eval on a single checkpoint."""
46
+ model_args = f"checkpoint={checkpoint},device={device},batch_size={batch_size},compile={'true' if compile else 'false'}"
47
+
48
+ task_list = tasks.split(",")
49
+
50
+ results = lm_eval.simple_evaluate(
51
+ model="circuit",
52
+ model_args=model_args,
53
+ tasks=task_list,
54
+ limit=limit,
55
+ )
56
+
57
+ return results
58
+
59
+
60
+ def extract_scores(results: dict) -> dict:
61
+ """Pull headline metrics from lm-eval results."""
62
+ scores = {}
63
+ if "results" not in results:
64
+ return scores
65
+ for task_name, task_results in results["results"].items():
66
+ # Get the primary metric (usually acc or acc_norm)
67
+ if "acc_norm,none" in task_results:
68
+ scores[task_name] = task_results["acc_norm,none"]
69
+ elif "acc,none" in task_results:
70
+ scores[task_name] = task_results["acc,none"]
71
+ elif "perplexity,none" in task_results:
72
+ scores[task_name] = task_results["perplexity,none"]
73
+ elif "word_perplexity,none" in task_results:
74
+ scores[task_name] = task_results["word_perplexity,none"]
75
+ return scores
76
+
77
+
78
+ def print_comparison(all_results: dict, tasks: list):
79
+ """Pretty-print comparison table."""
80
+ # Header
81
+ col_width = max(len(t) for t in tasks) + 2
82
+ name_width = max(len(n) for n in all_results) + 2
83
+
84
+ header = f"{'Model':<{name_width}}"
85
+ for task in tasks:
86
+ header += f"{task:>{col_width}}"
87
+ header += f"{' avg':>8}"
88
+ print("\n" + "=" * len(header))
89
+ print(header)
90
+ print("-" * len(header))
91
+
92
+ for name, scores in all_results.items():
93
+ row = f"{name:<{name_width}}"
94
+ vals = []
95
+ for task in tasks:
96
+ val = scores.get(task, None)
97
+ if val is not None:
98
+ row += f"{val:>{col_width}.4f}"
99
+ vals.append(val)
100
+ else:
101
+ row += f"{'N/A':>{col_width}}"
102
+ avg = sum(vals) / len(vals) if vals else 0
103
+ row += f"{avg:>8.4f}"
104
+ print(row)
105
+
106
+ print("=" * len(header))
107
+
108
+
109
+ def main():
110
+ parser = argparse.ArgumentParser(description="Benchmark Circuit transformers")
111
+ parser.add_argument("--checkpoint", type=str, help="Path to single checkpoint")
112
+ parser.add_argument("--compare", action="store_true", help="Compare all known architectures")
113
+ parser.add_argument("--tasks", type=str, default=DEFAULT_TASKS, help="Comma-separated task list")
114
+ parser.add_argument("--gpu", type=int, default=0, help="GPU index")
115
+ parser.add_argument("--limit", type=int, default=None, help="Limit samples per task (for quick testing)")
116
+ parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
117
+ parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
118
+ parser.add_argument("--compile", action="store_true", help="torch.compile models for faster inference")
119
+ args = parser.parse_args()
120
+
121
+ device = f"cuda:{args.gpu}"
122
+ task_list = args.tasks.split(",")
123
+
124
+ if args.compare:
125
+ all_scores = {}
126
+ all_raw = {}
127
+
128
+ # Filter to existing checkpoints
129
+ available = {k: v for k, v in CHECKPOINTS.items() if Path(v).exists()}
130
+ missing = {k: v for k, v in CHECKPOINTS.items() if not Path(v).exists()}
131
+ if missing:
132
+ print(f"Skipping (not found): {', '.join(missing.keys())}")
133
+
134
+ for name, ckpt_path in available.items():
135
+ print(f"\n{'='*60}")
136
+ print(f"Evaluating: {name}")
137
+ print(f"Checkpoint: {ckpt_path}")
138
+ print(f"{'='*60}")
139
+
140
+ t0 = time.time()
141
+ results = run_benchmark(ckpt_path, args.tasks, device, args.limit, args.batch_size, args.compile)
142
+ elapsed = time.time() - t0
143
+
144
+ scores = extract_scores(results)
145
+ all_scores[name] = scores
146
+ all_raw[name] = results.get("results", {})
147
+ print(f" Completed in {elapsed:.0f}s: {scores}")
148
+
149
+ print_comparison(all_scores, task_list)
150
+
151
+ if args.output:
152
+ with open(args.output, "w") as f:
153
+ json.dump({"scores": all_scores, "raw": all_raw}, f, indent=2, default=str)
154
+ print(f"\nResults saved to {args.output}")
155
+
156
+ elif args.checkpoint:
157
+ print(f"Evaluating: {args.checkpoint}")
158
+ t0 = time.time()
159
+ results = run_benchmark(args.checkpoint, args.tasks, device, args.limit, args.batch_size, args.compile)
160
+ elapsed = time.time() - t0
161
+
162
+ scores = extract_scores(results)
163
+ print(f"\nResults ({elapsed:.0f}s):")
164
+ for task, score in scores.items():
165
+ print(f" {task}: {score:.4f}")
166
+
167
+ if args.output:
168
+ with open(args.output, "w") as f:
169
+ json.dump(results, f, indent=2, default=str)
170
+ print(f"\nResults saved to {args.output}")
171
+ else:
172
+ parser.print_help()
173
+
174
+
175
+ if __name__ == "__main__":
176
+ main()
coherence_eval.py ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Coherence evaluation for language models.
4
+
5
+ Measures what standard benchmarks can't see:
6
+ Tier 1 — Generation diversity (repetition, collapse detection)
7
+ Tier 2 — Multi-distance prediction (context utilization, skip accuracy)
8
+ Tier 3 — Semantic consistency (chunk similarity over long generations)
9
+
10
+ Usage:
11
+ # Custom checkpoint
12
+ python -m circuits.coherence_eval --checkpoint circuits/checkpoints/model/best.pt
13
+
14
+ # HuggingFace model
15
+ python -m circuits.coherence_eval --model gpt2
16
+
17
+ # Compare models
18
+ python -m circuits.coherence_eval --model EleutherAI/pythia-160m --gpu 0
19
+
20
+ # Quick test (fewer prompts, shorter generation)
21
+ python -m circuits.coherence_eval --checkpoint path/to/model.pt --num-prompts 5 --gen-length 256
22
+
23
+ # Run specific tiers
24
+ python -m circuits.coherence_eval --checkpoint path/to/model.pt --tiers 1,3
25
+ """
26
+
27
+ import argparse
28
+ import json
29
+ import math
30
+ import sys
31
+ import time
32
+ from pathlib import Path
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ # ──────────────────────────────────────────────────────────────────────
38
+ # Default prompts — diverse domains, 10-20 tokens each
39
+ # ──────────────────────────────────────────────────────────────────────
40
+
41
+ DEFAULT_PROMPTS = [
42
+ "A thought observing itself discovers that it",
43
+ "The history of science shows that",
44
+ "In the middle of the night, the old house",
45
+ "The relationship between language and thought has been",
46
+ "When the first settlers arrived, they found",
47
+ "The mathematical proof begins by assuming",
48
+ "She opened the door to find",
49
+ "The economic implications of this policy",
50
+ "Deep beneath the ocean surface, researchers discovered",
51
+ "The most important lesson from this experiment is",
52
+ "According to recent studies, the human brain",
53
+ "The old library contained books that",
54
+ "As the temperature continued to rise, the effects on",
55
+ "The development of artificial intelligence has raised questions about",
56
+ "In the small village at the foot of the mountain",
57
+ "The fundamental principles of democracy require",
58
+ "Looking through the telescope, the astronomer noticed",
59
+ "The relationship between music and emotion",
60
+ "During the industrial revolution, working conditions",
61
+ "The ancient manuscript revealed secrets about",
62
+ ]
63
+
64
+
65
+ # ──────────────────────────────────────────────────────────────────────
66
+ # Model wrapper — unified interface for circuit models and HF models
67
+ # ──────────────────────────────────────────────────────────────────────
68
+
69
+ class ModelWrapper:
70
+ """Unified interface for custom circuit models and HuggingFace models."""
71
+
72
+ def __init__(self, model, tokenizer, device, model_type="hf",
73
+ skip_head=None, skip_k=0, max_seq_len=1024, name="unknown"):
74
+ self.model = model
75
+ self.tokenizer = tokenizer
76
+ self.device = device
77
+ self.model_type = model_type # "circuit" or "hf"
78
+ self.skip_head = skip_head
79
+ self.skip_k = skip_k
80
+ self.max_seq_len = max_seq_len
81
+ self.name = name
82
+
83
+ @classmethod
84
+ def from_checkpoint(cls, path, device):
85
+ """Load a custom circuit model from checkpoint."""
86
+ from .config import CircuitConfig
87
+ from .model import CircuitTransformer
88
+ from .mirrored import MirroredConfig, MirroredTransformer
89
+ from .slotted_mirrored import SlotMirroredConfig, SlotMirroredTransformer
90
+ from .data import get_tokenizer
91
+
92
+ checkpoint = torch.load(path, map_location="cpu", weights_only=False)
93
+ model_type = checkpoint.get("model_type", "standard")
94
+
95
+ if model_type == "slot_mirrored":
96
+ config = SlotMirroredConfig.from_dict(checkpoint["config"])
97
+ model = SlotMirroredTransformer(config).to(device)
98
+ arch_desc = f"SlotMirrored ({config.n_slots} slots)"
99
+ elif model_type == "mirrored":
100
+ config = MirroredConfig.from_dict(checkpoint["config"])
101
+ model = MirroredTransformer(config).to(device)
102
+ arch_desc = "Mirrored"
103
+ else:
104
+ config = CircuitConfig.from_dict(checkpoint["config"])
105
+ model = CircuitTransformer(config).to(device)
106
+ arch_desc = "Standard"
107
+
108
+ # Handle torch.compile prefix
109
+ state_dict = checkpoint["model"]
110
+ if any(k.startswith("_orig_mod.") for k in state_dict):
111
+ state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
112
+ model.load_state_dict(state_dict)
113
+ model.eval()
114
+
115
+ tokenizer = get_tokenizer()
116
+ skip_head = model.skip_head if hasattr(model, 'skip_head') else None
117
+ skip_k = getattr(config, 'aux_skip_k', 0)
118
+ max_seq_len = config.max_seq_len
119
+
120
+ params = sum(p.numel() for p in model.parameters()) / 1e6
121
+ name = f"{Path(path).parent.name}/{Path(path).stem} ({arch_desc}, {params:.1f}M)"
122
+
123
+ return cls(model, tokenizer, device, model_type="circuit",
124
+ skip_head=skip_head, skip_k=skip_k,
125
+ max_seq_len=max_seq_len, name=name)
126
+
127
+ @classmethod
128
+ def from_pretrained(cls, model_name, device):
129
+ """Load a HuggingFace model."""
130
+ from transformers import AutoModelForCausalLM, AutoTokenizer
131
+
132
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
133
+ model = AutoModelForCausalLM.from_pretrained(
134
+ model_name, trust_remote_code=True,
135
+ torch_dtype=torch.float32,
136
+ ).to(device)
137
+ model.eval()
138
+
139
+ max_seq_len = getattr(model.config, 'max_position_embeddings', 1024)
140
+ if tokenizer.pad_token is None:
141
+ tokenizer.pad_token = tokenizer.eos_token
142
+
143
+ params = sum(p.numel() for p in model.parameters()) / 1e6
144
+ name = f"{model_name} ({params:.1f}M)"
145
+
146
+ return cls(model, tokenizer, device, model_type="hf",
147
+ max_seq_len=max_seq_len, name=name)
148
+
149
+ @property
150
+ def has_skip_head(self):
151
+ return self.skip_head is not None and self.skip_k > 0
152
+
153
+ def generate(self, prompt_text, max_new_tokens=512):
154
+ """Generate tokens at temperature 0 (greedy). Returns generated token IDs only."""
155
+ prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)
156
+
157
+ with torch.no_grad():
158
+ if self.model_type == "hf":
159
+ output_ids = self.model.generate(
160
+ prompt_ids,
161
+ max_new_tokens=max_new_tokens,
162
+ do_sample=True,
163
+ pad_token_id=self.tokenizer.pad_token_id,
164
+ temperature=0.8,
165
+ top_k=50,
166
+ top_p=0.9,
167
+ repetition_penalty=1.2,
168
+ )
169
+ else:
170
+ output_ids = self.model.generate(
171
+ prompt_ids,
172
+ max_new_tokens=max_new_tokens,
173
+ temperature=0.8,
174
+ top_k=50,
175
+ top_p=0.9,
176
+ repetition_penalty=1.2,
177
+ )
178
+
179
+ # Return only the generated part
180
+ gen_ids = output_ids[0, prompt_ids.shape[1]:]
181
+ return prompt_ids[0], gen_ids
182
+
183
+ def forward_with_hidden(self, input_ids):
184
+ """Forward pass returning (logits, hidden_states, skip_logits_or_None).
185
+ input_ids: [1, L] tensor.
186
+ """
187
+ with torch.no_grad():
188
+ if self.model_type == "hf":
189
+ outputs = self.model(input_ids, output_hidden_states=True)
190
+ logits = outputs.logits
191
+ hidden = outputs.hidden_states[-1]
192
+ return logits, hidden, None
193
+ else:
194
+ # Hook into norm layer to capture pre-lm_head hidden states
195
+ hidden_capture = {}
196
+
197
+ def hook_fn(module, inp, output):
198
+ hidden_capture['h'] = output.detach()
199
+
200
+ handle = self.model.norm.register_forward_hook(hook_fn)
201
+ output = self.model(input_ids)
202
+ handle.remove()
203
+
204
+ logits = output['logits']
205
+ hidden = hidden_capture['h']
206
+
207
+ skip_logits = None
208
+ if self.has_skip_head:
209
+ skip_logits = self.skip_head(hidden)
210
+
211
+ return logits, hidden, skip_logits
212
+
213
+ def forward(self, input_ids):
214
+ """Forward pass returning logits only. input_ids: [1, L] tensor."""
215
+ with torch.no_grad():
216
+ if self.model_type == "hf":
217
+ return self.model(input_ids).logits
218
+ else:
219
+ return self.model(input_ids)['logits']
220
+
221
+
222
+ # ──────────────────────────────────────────────────────────────────────
223
+ # Generation (shared between Tier 1 and Tier 3)
224
+ # ──────────────────────────────────────────────────────────────────────
225
+
226
+ def generate_all(wrapper, prompts, gen_length):
227
+ """Generate from all prompts. Returns list of (prompt_text, prompt_ids, gen_ids)."""
228
+ results = []
229
+ for prompt in prompts:
230
+ prompt_ids, gen_ids = wrapper.generate(prompt, max_new_tokens=gen_length)
231
+ results.append((prompt, prompt_ids, gen_ids))
232
+ print(f" [{len(results)}/{len(prompts)}] {len(gen_ids)} tokens", end="\r")
233
+ print()
234
+ return results
235
+
236
+
237
+ # ──────────────────────────────────────────────────────────────────────
238
+ # Tier 1: Generation Diversity
239
+ # ──────────────────────────────────────────────────────────────────────
240
+
241
+ def ngrams(tokens, n):
242
+ """Extract n-grams from token list."""
243
+ return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
244
+
245
+
246
+ def compute_diversity(gen_ids):
247
+ """Compute diversity metrics for a single generation."""
248
+ tokens = gen_ids.tolist()
249
+ n = len(tokens)
250
+ if n < 4:
251
+ return {"unique_1g": 0, "unique_2g": 0, "unique_3g": 0, "unique_4g": 0,
252
+ "max_repeat": n, "collapsed": True}
253
+
254
+ results = {}
255
+ for k in [1, 2, 3, 4]:
256
+ grams = ngrams(tokens, k)
257
+ results[f"unique_{k}g"] = len(set(grams)) / len(grams) if grams else 0.0
258
+
259
+ # Max consecutive identical token span
260
+ max_repeat = 1
261
+ current = 1
262
+ for i in range(1, n):
263
+ if tokens[i] == tokens[i - 1]:
264
+ current += 1
265
+ max_repeat = max(max_repeat, current)
266
+ else:
267
+ current = 1
268
+ results["max_repeat"] = max_repeat
269
+
270
+ # Longest repeated n-gram span (any n-gram repeated consecutively)
271
+ max_ngram_repeat = 1
272
+ for ng_size in [2, 3, 4, 5, 8]:
273
+ grams = ngrams(tokens, ng_size)
274
+ streak = 1
275
+ for i in range(1, len(grams)):
276
+ if grams[i] == grams[i - 1]:
277
+ streak += 1
278
+ max_ngram_repeat = max(max_ngram_repeat, streak * ng_size)
279
+ else:
280
+ streak = 1
281
+ results["max_ngram_repeat_span"] = max_ngram_repeat
282
+
283
+ # Collapse: unique 4-grams < 50% or max repeat span > 25% of generation
284
+ results["collapsed"] = (results["unique_4g"] < 0.5) or (max_ngram_repeat > n * 0.25)
285
+
286
+ return results
287
+
288
+
289
+ def eval_diversity(generations, tokenizer, show_samples=3):
290
+ """Tier 1: Compute diversity metrics from pre-generated text."""
291
+ print("\n" + "=" * 60)
292
+ print("TIER 1: Generation Diversity")
293
+ print("=" * 60)
294
+
295
+ all_metrics = []
296
+ sample_texts = []
297
+
298
+ for i, (prompt, prompt_ids, gen_ids) in enumerate(generations):
299
+ metrics = compute_diversity(gen_ids)
300
+ metrics["prompt"] = prompt
301
+ metrics["gen_length"] = len(gen_ids)
302
+ all_metrics.append(metrics)
303
+
304
+ if i < show_samples:
305
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True)
306
+ sample_texts.append((prompt, text))
307
+
308
+ n = len(all_metrics)
309
+ if n == 0:
310
+ print(" No generations to evaluate.")
311
+ return {}
312
+
313
+ # Aggregate
314
+ agg = {}
315
+ for key in ["unique_1g", "unique_2g", "unique_3g", "unique_4g",
316
+ "max_repeat", "max_ngram_repeat_span"]:
317
+ values = [m[key] for m in all_metrics]
318
+ agg[key] = {"mean": sum(values) / n, "min": min(values), "max": max(values)}
319
+
320
+ collapse_count = sum(1 for m in all_metrics if m["collapsed"])
321
+ agg["collapse_rate"] = collapse_count / n
322
+ avg_len = sum(m["gen_length"] for m in all_metrics) / n
323
+
324
+ # Print
325
+ print(f"\n Prompts evaluated: {n}")
326
+ print(f" Avg generation length: {avg_len:.0f} tokens")
327
+ print()
328
+ print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
329
+ print(f" {'-' * 50}")
330
+ for key in ["unique_1g", "unique_2g", "unique_3g", "unique_4g"]:
331
+ m = agg[key]
332
+ print(f" {key:<24} {m['mean']:>8.3f} {m['min']:>8.3f} {m['max']:>8.3f}")
333
+ for key in ["max_repeat", "max_ngram_repeat_span"]:
334
+ m = agg[key]
335
+ print(f" {key:<24} {m['mean']:>8.1f} {int(m['min']):>8d} {int(m['max']):>8d}")
336
+ print(f"\n Collapse rate: {collapse_count}/{n} ({agg['collapse_rate']:.1%})")
337
+
338
+ # Show samples
339
+ if sample_texts:
340
+ print(f"\n --- Sample generations (first {len(sample_texts)}) ---")
341
+ for prompt, text in sample_texts:
342
+ print(f"\n Prompt: \"{prompt}\"")
343
+ preview = text[:400].replace("\n", " ")
344
+ if len(text) > 400:
345
+ preview += "..."
346
+ print(f" Output: {preview}")
347
+
348
+ return {"per_prompt": all_metrics, "aggregate": agg}
349
+
350
+
351
+ # ──────────────────────────────────────────────────────────────────────
352
+ # Tier 2: Multi-Distance Prediction
353
+ # ────────────────���─────────────────────────────────────────────────────
354
+
355
+ def prepare_eval_sequences(wrapper, num_sequences=50, data_source=None):
356
+ """Prepare ground truth sequences for Tier 2."""
357
+ max_len = wrapper.max_seq_len
358
+
359
+ if data_source and Path(data_source).exists():
360
+ with open(data_source) as f:
361
+ text = f.read()
362
+ all_ids = wrapper.tokenizer.encode(text)
363
+ else:
364
+ try:
365
+ from datasets import load_dataset
366
+ print(" Loading WikiText-103 validation...")
367
+ ds = load_dataset("wikitext", "wikitext-103-raw-v1",
368
+ split="validation", trust_remote_code=True)
369
+ text = "\n".join(row["text"] for row in ds if row["text"].strip())
370
+ all_ids = wrapper.tokenizer.encode(text)
371
+ except Exception as e:
372
+ print(f" Could not load eval data: {e}")
373
+ print(f" Install 'datasets' or use --eval-data to provide a text file.")
374
+ return None
375
+
376
+ # Chunk into sequences
377
+ sequences = []
378
+ for i in range(0, len(all_ids) - max_len, max_len):
379
+ seq = torch.tensor(all_ids[i:i + max_len], dtype=torch.long)
380
+ sequences.append(seq)
381
+ if len(sequences) >= num_sequences:
382
+ break
383
+
384
+ if len(sequences) < 2:
385
+ print(" Not enough text for evaluation sequences.")
386
+ return None
387
+
388
+ print(f" Prepared {len(sequences)} sequences of {max_len} tokens")
389
+ return sequences
390
+
391
+
392
+ def eval_context_utilization(wrapper, sequences):
393
+ """Tier 2a: Per-position perplexity grouped by depth bucket."""
394
+ max_len = wrapper.max_seq_len
395
+
396
+ # Adaptive buckets based on max_seq_len
397
+ bucket_bounds = [0, 64, 128, 256, 512]
398
+ if max_len > 512:
399
+ bucket_bounds.append(max_len)
400
+ else:
401
+ bucket_bounds.append(max_len)
402
+ # Remove duplicates and sort
403
+ bucket_bounds = sorted(set(b for b in bucket_bounds if b <= max_len))
404
+ if bucket_bounds[-1] < max_len:
405
+ bucket_bounds.append(max_len)
406
+ buckets = [(bucket_bounds[i], bucket_bounds[i + 1])
407
+ for i in range(len(bucket_bounds) - 1)]
408
+
409
+ # Accumulate per-position losses
410
+ all_losses = []
411
+ for seq in sequences:
412
+ input_ids = seq.unsqueeze(0).to(wrapper.device)
413
+ logits = wrapper.forward(input_ids)
414
+
415
+ shift_logits = logits[0, :-1]
416
+ shift_labels = input_ids[0, 1:]
417
+ per_token_loss = F.cross_entropy(shift_logits, shift_labels, reduction='none')
418
+ all_losses.append(per_token_loss.cpu())
419
+ print(f" [{len(all_losses)}/{len(sequences)}]", end="\r")
420
+ print()
421
+
422
+ # Compute per-bucket stats
423
+ stacked = torch.stack(all_losses) # [N, L-1]
424
+ bucket_results = {}
425
+ for start, end in buckets:
426
+ s = min(start, stacked.shape[1])
427
+ e = min(end, stacked.shape[1])
428
+ if s >= e:
429
+ continue
430
+ bucket_losses = stacked[:, s:e]
431
+ avg_loss = bucket_losses.mean().item()
432
+ bucket_results[f"{start}-{end}"] = {
433
+ "loss": avg_loss,
434
+ "ppl": math.exp(min(avg_loss, 20)), # cap to avoid overflow
435
+ "n_tokens": bucket_losses.numel(),
436
+ }
437
+
438
+ return bucket_results
439
+
440
+
441
+ def eval_skip_accuracy(wrapper, sequences, distances):
442
+ """Tier 2b: Skip head prediction accuracy at various distances."""
443
+ if not wrapper.has_skip_head:
444
+ return None
445
+
446
+ results = {f"t+{K}": {"top1": [], "top5": []} for K in distances}
447
+
448
+ for seq in sequences:
449
+ input_ids = seq.unsqueeze(0).to(wrapper.device)
450
+ _, hidden, _ = wrapper.forward_with_hidden(input_ids)
451
+
452
+ for K in distances:
453
+ if K >= input_ids.shape[1]:
454
+ continue
455
+
456
+ skip_logits = wrapper.skip_head(hidden) # [1, L, V]
457
+ targets = input_ids[0, K:] # tokens at t+K
458
+ preds = skip_logits[0, :-K] # predictions from position t
459
+
460
+ top1 = (preds.argmax(-1) == targets).float().mean().item()
461
+ top5_indices = preds.topk(min(5, preds.shape[-1]), dim=-1).indices
462
+ top5 = (top5_indices == targets.unsqueeze(-1)).any(-1).float().mean().item()
463
+
464
+ results[f"t+{K}"]["top1"].append(top1)
465
+ results[f"t+{K}"]["top5"].append(top5)
466
+
467
+ print(f" [{len(results['t+' + str(distances[0])]['top1'])}/{len(sequences)}]", end="\r")
468
+ print()
469
+
470
+ # Average across sequences
471
+ avg_results = {}
472
+ for key in sorted(results.keys(), key=lambda x: int(x.split("+")[1])):
473
+ vals = results[key]
474
+ if vals["top1"]:
475
+ avg_results[key] = {
476
+ "top1": sum(vals["top1"]) / len(vals["top1"]),
477
+ "top5": sum(vals["top5"]) / len(vals["top5"]),
478
+ }
479
+
480
+ return avg_results
481
+
482
+
483
+ def eval_structural(wrapper, eval_data, distances, num_sequences):
484
+ """Run Tier 2 evaluation."""
485
+ print("\n" + "=" * 60)
486
+ print("TIER 2: Structural Prediction")
487
+ print("=" * 60)
488
+
489
+ sequences = prepare_eval_sequences(wrapper, num_sequences, eval_data)
490
+ if sequences is None:
491
+ return {"context_utilization": None, "skip_accuracy": None}
492
+
493
+ # 2a: Context utilization
494
+ print("\n --- 2a: Context Utilization (PPL by position depth) ---")
495
+ ctx_results = eval_context_utilization(wrapper, sequences)
496
+
497
+ if ctx_results:
498
+ print(f"\n {'Depth':<12} {'Loss':>8} {'PPL':>10} {'Tokens':>10}")
499
+ print(f" {'-' * 42}")
500
+ for bucket, vals in ctx_results.items():
501
+ print(f" {bucket:<12} {vals['loss']:>8.3f} {vals['ppl']:>10.2f} {vals['n_tokens']:>10}")
502
+
503
+ buckets_list = list(ctx_results.values())
504
+ if len(buckets_list) >= 2:
505
+ ratio = buckets_list[0]["ppl"] / buckets_list[-1]["ppl"]
506
+ print(f"\n Context utilization ratio (first/last): {ratio:.2f}x")
507
+ print(f" (Higher = model benefits more from additional context)")
508
+
509
+ # 2b: Skip accuracy
510
+ skip_results = None
511
+ if wrapper.has_skip_head:
512
+ print(f"\n --- 2b: Skip Head Accuracy (trained for t+{wrapper.skip_k}) ---")
513
+ skip_results = eval_skip_accuracy(wrapper, sequences, distances)
514
+
515
+ if skip_results:
516
+ print(f"\n {'Distance':<12} {'Top-1':>8} {'Top-5':>8}")
517
+ print(f" {'-' * 30}")
518
+ for key, vals in skip_results.items():
519
+ trained = " *" if int(key.split("+")[1]) == wrapper.skip_k else ""
520
+ print(f" {key:<12} {vals['top1']:>8.4f} {vals['top5']:>8.4f}{trained}")
521
+ print(f"\n * = trained distance")
522
+ else:
523
+ print("\n Skip head: not available")
524
+
525
+ return {"context_utilization": ctx_results, "skip_accuracy": skip_results}
526
+
527
+
528
+ # ──────────────────────────────────────────────────────────────────────
529
+ # Tier 3: Semantic Consistency
530
+ # ──────────────────────────────────────────────────────────────────────
531
+
532
+ def compute_chunk_similarity(hidden_states, chunk_size=128):
533
+ """Compute cosine similarity between chunks of hidden states.
534
+ hidden_states: [L, D] tensor.
535
+ """
536
+ L, D = hidden_states.shape
537
+ n_chunks = L // chunk_size
538
+
539
+ if n_chunks < 2:
540
+ return None
541
+
542
+ # Mean-pool each chunk
543
+ chunks = []
544
+ for i in range(n_chunks):
545
+ chunk = hidden_states[i * chunk_size:(i + 1) * chunk_size]
546
+ chunks.append(chunk.mean(dim=0))
547
+
548
+ chunk_vecs = torch.stack(chunks)
549
+ chunk_vecs = F.normalize(chunk_vecs, dim=-1)
550
+
551
+ # Pairwise cosine similarity
552
+ sim_matrix = chunk_vecs @ chunk_vecs.T
553
+
554
+ # Upper triangle (excluding diagonal)
555
+ mask = torch.triu(torch.ones_like(sim_matrix, dtype=torch.bool), diagonal=1)
556
+ pairwise_sims = sim_matrix[mask]
557
+
558
+ # Adjacent pairs
559
+ adjacent = [sim_matrix[i, i + 1].item() for i in range(n_chunks - 1)]
560
+
561
+ # Distant pairs (first quarter vs last quarter)
562
+ q1 = max(1, n_chunks // 4)
563
+ distant = []
564
+ for i in range(q1):
565
+ for j in range(n_chunks - q1, n_chunks):
566
+ if i < j:
567
+ distant.append(sim_matrix[i, j].item())
568
+
569
+ return {
570
+ "mean_sim": pairwise_sims.mean().item(),
571
+ "min_sim": pairwise_sims.min().item(),
572
+ "adjacent_sim": sum(adjacent) / len(adjacent),
573
+ "distant_sim": sum(distant) / len(distant) if distant else 0.0,
574
+ "n_chunks": n_chunks,
575
+ }
576
+
577
+
578
+ def eval_consistency(wrapper, generations, chunk_size=128):
579
+ """Tier 3: Semantic consistency of generated text via hidden state similarity."""
580
+ print("\n" + "=" * 60)
581
+ print("TIER 3: Semantic Consistency")
582
+ print("=" * 60)
583
+
584
+ all_metrics = []
585
+
586
+ for i, (prompt, prompt_ids, gen_ids) in enumerate(generations):
587
+ if gen_ids.shape[0] < chunk_size * 2:
588
+ continue
589
+
590
+ # Build full sequence: prompt + generated
591
+ full_ids = torch.cat([prompt_ids, gen_ids]).unsqueeze(0).to(wrapper.device)
592
+
593
+ # Trim to max_seq_len
594
+ if full_ids.shape[1] > wrapper.max_seq_len:
595
+ full_ids = full_ids[:, :wrapper.max_seq_len]
596
+
597
+ _, hidden, _ = wrapper.forward_with_hidden(full_ids)
598
+
599
+ # Use only generated part's hidden states
600
+ gen_start = prompt_ids.shape[0]
601
+ gen_hidden = hidden[0, gen_start:] # [gen_len, D]
602
+
603
+ metrics = compute_chunk_similarity(gen_hidden, chunk_size)
604
+ if metrics is not None:
605
+ metrics["prompt"] = prompt
606
+ all_metrics.append(metrics)
607
+
608
+ print(f" [{len(all_metrics)}/{len(generations)}]", end="\r")
609
+ print()
610
+
611
+ if not all_metrics:
612
+ print(" No valid generations for consistency evaluation.")
613
+ return {}
614
+
615
+ n = len(all_metrics)
616
+ agg = {}
617
+ for key in ["mean_sim", "min_sim", "adjacent_sim", "distant_sim"]:
618
+ values = [m[key] for m in all_metrics]
619
+ agg[key] = {"mean": sum(values) / n, "min": min(values), "max": max(values)}
620
+
621
+ # Topic drift: how much similarity drops from adjacent to distant chunks
622
+ drift_vals = [m["adjacent_sim"] - m["distant_sim"] for m in all_metrics]
623
+ agg["topic_drift"] = {"mean": sum(drift_vals) / n,
624
+ "min": min(drift_vals), "max": max(drift_vals)}
625
+
626
+ # Print
627
+ print(f"\n Generations evaluated: {n}")
628
+ print(f" Chunk size: {chunk_size} tokens")
629
+ avg_chunks = sum(m["n_chunks"] for m in all_metrics) / n
630
+ print(f" Avg chunks per generation: {avg_chunks:.1f}")
631
+ print()
632
+ print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
633
+ print(f" {'-' * 50}")
634
+ for key in ["mean_sim", "min_sim", "adjacent_sim", "distant_sim", "topic_drift"]:
635
+ m = agg[key]
636
+ print(f" {key:<24} {m['mean']:>8.3f} {m['min']:>8.3f} {m['max']:>8.3f}")
637
+
638
+ return {"per_prompt": all_metrics, "aggregate": agg}
639
+
640
+
641
+ # ──────────────────────────────────────────────────────────────────────
642
+ # Summary
643
+ # ──────────────────────────────────────────────────────────────────────
644
+
645
+ def print_summary(results):
646
+ """Print composite summary scores."""
647
+ print("\n" + "=" * 60)
648
+ print("SUMMARY")
649
+ print("=" * 60)
650
+
651
+ scores = {}
652
+
653
+ # Diversity score: mean unique-4gram
654
+ t1 = results.get("tier1_diversity", {})
655
+ if t1 and "aggregate" in t1:
656
+ div_score = t1["aggregate"].get("unique_4g", {}).get("mean", None)
657
+ collapse = t1["aggregate"].get("collapse_rate", None)
658
+ if div_score is not None:
659
+ scores["diversity"] = div_score
660
+ print(f" Diversity (unique 4-gram): {div_score:.3f}", end="")
661
+ if collapse is not None:
662
+ print(f" (collapse: {collapse:.0%})", end="")
663
+ print()
664
+
665
+ # Context utilization ratio
666
+ t2 = results.get("tier2_structural", {})
667
+ if t2:
668
+ ctx = t2.get("context_utilization")
669
+ if ctx:
670
+ buckets = list(ctx.values())
671
+ if len(buckets) >= 2:
672
+ ratio = buckets[0]["ppl"] / buckets[-1]["ppl"]
673
+ scores["context_util"] = ratio
674
+ print(f" Context utilization: {ratio:.2f}x")
675
+
676
+ skip = t2.get("skip_accuracy")
677
+ if skip:
678
+ # Report accuracy at trained distance
679
+ trained_key = None
680
+ for key in skip:
681
+ trained_key = key # use first available
682
+ break
683
+ if trained_key:
684
+ top5 = skip[trained_key]["top5"]
685
+ scores["skip_top5"] = top5
686
+ print(f" Skip accuracy ({trained_key} top-5): {top5:.4f}")
687
+
688
+ # Coherence score: mean chunk similarity
689
+ t3 = results.get("tier3_consistency", {})
690
+ if t3 and "aggregate" in t3:
691
+ coh_score = t3["aggregate"].get("mean_sim", {}).get("mean", None)
692
+ drift = t3["aggregate"].get("topic_drift", {}).get("mean", None)
693
+ if coh_score is not None:
694
+ scores["coherence"] = coh_score
695
+ print(f" Coherence (chunk sim): {coh_score:.3f}", end="")
696
+ if drift is not None:
697
+ print(f" (drift: {drift:.3f})", end="")
698
+ print()
699
+
700
+ results["summary"] = scores
701
+ return scores
702
+
703
+
704
+ # ──────────────────────────────────────────────────────────────────────
705
+ # Main
706
+ # ──────────────────────────────────────────────────────────────────────
707
+
708
+ def parse_args():
709
+ parser = argparse.ArgumentParser(
710
+ description="Coherence evaluation for language models",
711
+ formatter_class=argparse.RawDescriptionHelpFormatter,
712
+ )
713
+
714
+ # Model source (mutually exclusive)
715
+ group = parser.add_mutually_exclusive_group(required=True)
716
+ group.add_argument("--checkpoint", type=str, help="Path to circuit model checkpoint")
717
+ group.add_argument("--model", type=str, help="HuggingFace model name or path")
718
+
719
+ # Evaluation config
720
+ parser.add_argument("--prompts", type=str, help="File with prompts (one per line)")
721
+ parser.add_argument("--num-prompts", type=int, default=20,
722
+ help="Number of prompts to use (default: 20)")
723
+ parser.add_argument("--gen-length", type=int, default=512,
724
+ help="Tokens to generate per prompt (default: 512)")
725
+ parser.add_argument("--eval-data", type=str,
726
+ help="Text file for Tier 2 (default: WikiText-103 validation)")
727
+ parser.add_argument("--num-sequences", type=int, default=50,
728
+ help="Number of sequences for Tier 2 (default: 50)")
729
+ parser.add_argument("--chunk-size", type=int, default=128,
730
+ help="Chunk size for Tier 3 similarity (default: 128)")
731
+ parser.add_argument("--distances", type=str, default="2,5,10,25,50,100",
732
+ help="Skip distances for Tier 2b (default: 2,5,10,25,50,100)")
733
+ parser.add_argument("--tiers", type=str, default="1,2,3",
734
+ help="Which tiers to run (default: 1,2,3)")
735
+
736
+ # Hardware
737
+ parser.add_argument("--gpu", type=int, default=0, help="GPU index (default: 0)")
738
+
739
+ # Output
740
+ parser.add_argument("--output", type=str, help="Save results to JSON file")
741
+ parser.add_argument("--samples", type=int, default=3,
742
+ help="Number of sample generations to display (default: 3)")
743
+
744
+ return parser.parse_args()
745
+
746
+
747
+ def main():
748
+ args = parse_args()
749
+
750
+ device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
751
+ tiers = [int(t) for t in args.tiers.split(",")]
752
+ distances = [int(d) for d in args.distances.split(",")]
753
+
754
+ # Load model
755
+ print("=" * 60)
756
+ print("Coherence Evaluation")
757
+ print("=" * 60)
758
+
759
+ if args.checkpoint:
760
+ print(f"Loading: {args.checkpoint}")
761
+ wrapper = ModelWrapper.from_checkpoint(args.checkpoint, device)
762
+ else:
763
+ print(f"Loading: {args.model}")
764
+ wrapper = ModelWrapper.from_pretrained(args.model, device)
765
+
766
+ print(f"Model: {wrapper.name}")
767
+ print(f"Device: {device}")
768
+ print(f"Max seq len: {wrapper.max_seq_len}")
769
+ if wrapper.has_skip_head:
770
+ print(f"Skip head: t+{wrapper.skip_k}")
771
+ print(f"Tiers: {tiers}")
772
+
773
+ # Load prompts
774
+ if args.prompts:
775
+ with open(args.prompts) as f:
776
+ prompts = [line.strip() for line in f if line.strip()]
777
+ else:
778
+ prompts = DEFAULT_PROMPTS
779
+ prompts = prompts[:args.num_prompts]
780
+ print(f"Prompts: {len(prompts)}")
781
+
782
+ results = {"model": wrapper.name}
783
+ t0 = time.time()
784
+
785
+ # Generate once for Tier 1 and Tier 3
786
+ generations = None
787
+ if 1 in tiers or 3 in tiers:
788
+ print(f"\nGenerating {args.gen_length} tokens from {len(prompts)} prompts...")
789
+ generations = generate_all(wrapper, prompts, args.gen_length)
790
+
791
+ # Tier 1
792
+ if 1 in tiers and generations:
793
+ results["tier1_diversity"] = eval_diversity(
794
+ generations, wrapper.tokenizer, show_samples=args.samples)
795
+
796
+ # Tier 2
797
+ if 2 in tiers:
798
+ results["tier2_structural"] = eval_structural(
799
+ wrapper, args.eval_data, distances, args.num_sequences)
800
+
801
+ # Tier 3
802
+ if 3 in tiers and generations:
803
+ results["tier3_consistency"] = eval_consistency(
804
+ wrapper, generations, args.chunk_size)
805
+
806
+ # Summary
807
+ print_summary(results)
808
+
809
+ elapsed = time.time() - t0
810
+ print(f"\nTotal time: {elapsed:.0f}s")
811
+
812
+ # Save
813
+ if args.output:
814
+ def make_serializable(obj):
815
+ if isinstance(obj, dict):
816
+ return {k: make_serializable(v) for k, v in obj.items()}
817
+ elif isinstance(obj, list):
818
+ return [make_serializable(v) for v in obj]
819
+ elif isinstance(obj, torch.Tensor):
820
+ return obj.tolist()
821
+ elif isinstance(obj, float):
822
+ if math.isnan(obj) or math.isinf(obj):
823
+ return str(obj)
824
+ return obj
825
+
826
+ out_path = Path(args.output)
827
+ out_path.parent.mkdir(parents=True, exist_ok=True)
828
+ with open(out_path, "w") as f:
829
+ json.dump(make_serializable(results), f, indent=2)
830
+ print(f"Results saved to {args.output}")
831
+
832
+
833
+ if __name__ == "__main__":
834
+ main()
config.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for Circuit Transformer experiments.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ import argparse
7
+
8
+
9
+ @dataclass
10
+ class CircuitConfig:
11
+ """Configuration for CircuitTransformer model and training."""
12
+
13
+ # Model architecture
14
+ vocab_size: int = 50257 # GPT-2 tokenizer
15
+ hidden_size: int = 256
16
+ num_heads: int = 8
17
+ num_kv_heads: int | None = None # GQA: None = same as num_heads (MHA)
18
+ num_layers: int = 6
19
+ max_seq_len: int = 512
20
+ dropout: float = 0.0
21
+
22
+ # Training
23
+ batch_size: int = 32
24
+ learning_rate: float = 3e-4
25
+ min_lr: float = 0.0
26
+ weight_decay: float = 0.1
27
+ warmup_steps: int = 100
28
+ epochs: int = 10
29
+ grad_clip: float = 1.0
30
+ reset: bool = False
31
+
32
+ # Hardware
33
+ gpu: int = 0
34
+ fp16: bool = True
35
+ bf16: bool = False
36
+ compile: bool = False
37
+
38
+ # Logging
39
+ log_every: int = 50
40
+ save_every: int = 5000
41
+ checkpoint_dir: str = "./circuits/checkpoints"
42
+
43
+ def __post_init__(self):
44
+ assert self.hidden_size % self.num_heads == 0, \
45
+ f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
46
+ if self.num_kv_heads is not None:
47
+ assert self.num_heads % self.num_kv_heads == 0, \
48
+ f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
49
+
50
+ # Presets
51
+ @classmethod
52
+ def tiny(cls) -> "CircuitConfig":
53
+ """~2M params"""
54
+ return cls(hidden_size=128, num_heads=4, num_layers=4)
55
+
56
+ @classmethod
57
+ def small(cls) -> "CircuitConfig":
58
+ """~10M params"""
59
+ return cls(hidden_size=256, num_heads=8, num_layers=6)
60
+
61
+ @classmethod
62
+ def medium(cls) -> "CircuitConfig":
63
+ """~50M params"""
64
+ return cls(hidden_size=512, num_heads=8, num_layers=12)
65
+
66
+ @classmethod
67
+ def medium_plus(cls) -> "CircuitConfig":
68
+ """~50M params"""
69
+ return cls(hidden_size=512, num_heads=8, num_layers=15)
70
+
71
+
72
+ @classmethod
73
+ def medium_wide_9(cls) -> "CircuitConfig":
74
+ """~50M params"""
75
+ return cls(hidden_size=640, num_heads=10, num_layers=9)
76
+
77
+ @classmethod
78
+ def medium_wide_11(cls) -> "CircuitConfig":
79
+ """~50M params"""
80
+ return cls(hidden_size=640, num_heads=10, num_layers=11)
81
+
82
+ @classmethod
83
+ def medium_large(cls) -> "CircuitConfig":
84
+ """~90M params"""
85
+ return cls(hidden_size=768, num_heads=12, num_layers=12)
86
+
87
+ @classmethod
88
+ def large(cls) -> "CircuitConfig":
89
+ return cls(hidden_size=1280, num_heads=20, num_layers=11)
90
+
91
+ # Auxiliary objectives
92
+ aux_skip_k: int = 0 # skip-ahead prediction distance (0 = disabled)
93
+ aux_skip_weight: float = 0.1 # weight for auxiliary skip loss
94
+
95
+ # Word-position RoPE (SemRoPE)
96
+ word_rope_dims: int = 0 # head dims for word-position RoPE (0 = disabled)
97
+ word_rope_base: float = 10.0 # frequency base for word-position RoPE
98
+
99
+ # Factorized embedding / MLP head
100
+ embed_dim: int = 0 # factorized embedding dim (0 = use hidden_size)
101
+ head_dim: int = 0 # MLP head intermediate dim (0 = linear head)
102
+
103
+ def to_dict(self) -> dict:
104
+ """Convert to dictionary for serialization."""
105
+ d = {
106
+ "vocab_size": self.vocab_size,
107
+ "hidden_size": self.hidden_size,
108
+ "num_heads": self.num_heads,
109
+ "num_layers": self.num_layers,
110
+ "max_seq_len": self.max_seq_len,
111
+ "dropout": self.dropout,
112
+ }
113
+ if self.num_kv_heads is not None:
114
+ d["num_kv_heads"] = self.num_kv_heads
115
+ if self.aux_skip_k > 0:
116
+ d["aux_skip_k"] = self.aux_skip_k
117
+ d["aux_skip_weight"] = self.aux_skip_weight
118
+ if self.word_rope_dims > 0:
119
+ d["word_rope_dims"] = self.word_rope_dims
120
+ d["word_rope_base"] = self.word_rope_base
121
+ if self.embed_dim > 0:
122
+ d["embed_dim"] = self.embed_dim
123
+ if self.head_dim > 0:
124
+ d["head_dim"] = self.head_dim
125
+ return d
126
+
127
+ @classmethod
128
+ def from_dict(cls, d: dict) -> "CircuitConfig":
129
+ """Create from dictionary."""
130
+ return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
131
+
132
+
133
+ def parse_args() -> tuple[CircuitConfig, argparse.Namespace]:
134
+ """Parse CLI arguments and return config + extra args."""
135
+ parser = argparse.ArgumentParser(description="Circuit Transformer Training")
136
+
137
+ # Data
138
+ parser.add_argument("--data", type=str, required=True,
139
+ help="Data source: path/to/file.txt, path/to/dir/, or hf:dataset_name")
140
+ parser.add_argument("--text-column", type=str, default="text",
141
+ help="Column name for HF datasets (default: text)")
142
+ parser.add_argument("--data-format", type=str, choices=["text", "chat"], default="text",
143
+ help="Data format: text (single column) or chat (system + conversations)")
144
+ parser.add_argument("--num-samples", type=int, default=None,
145
+ help="Limit samples from HF dataset")
146
+ parser.add_argument("--cache-dir", type=str, default="./circuits/.cache",
147
+ help="Cache directory for tokenized data")
148
+ parser.add_argument("--no-cache", action="store_true",
149
+ help="Disable data caching")
150
+ parser.add_argument("--val-split", type=float, default=0.05,
151
+ help="Fraction of data for validation (default: 0.05, 0 to disable)")
152
+
153
+ # Model architecture
154
+ # TODO: Remove `slot_mirrored`
155
+ parser.add_argument("--arch", type=str, choices=["standard", "mirrored", "graft_g2lu"], default="standard",
156
+ help="Model architecture (default: standard)")
157
+ parser.add_argument("--preset", type=str, choices=["tiny", "small", "medium", "medium_plus", "medium_large", "medium_wide_9", "medium_wide_11", "large"],
158
+ help="Use preset configuration")
159
+ parser.add_argument("--dims", type=int, default=None, help="Hidden size")
160
+ parser.add_argument("--layers", type=int, default=None, help="Number of layers")
161
+ parser.add_argument("--heads", type=int, default=None, help="Number of attention heads")
162
+ parser.add_argument("--kv-heads", type=int, default=None,
163
+ help="Number of KV heads for GQA (default: same as --heads for MHA)")
164
+ parser.add_argument("--context-length", type=int, default=None, help="Max sequence length")
165
+ parser.add_argument("--dropout", type=float, default=None, help="Dropout rate")
166
+ parser.add_argument("--tokenizer", type=str, default="gpt2",
167
+ help="Tokenizer to use (default: gpt2, e.g. facebook/MobileLLM-125M)")
168
+
169
+ # Mirrored architecture specific
170
+ parser.add_argument("--n-middle", type=int, default=2,
171
+ help="Unique middle layers for mirrored arch (default: 2)")
172
+ parser.add_argument("--share-attention", action="store_true", default=True,
173
+ help="Share attention weights between mirror pairs (default)")
174
+ parser.add_argument("--no-share-attention", dest="share_attention", action="store_false",
175
+ help="Separate attention weights per direction")
176
+
177
+ # G²LU gating
178
+ parser.add_argument("--no-g2lu", action="store_true",
179
+ help="Disable G²LU (use vanilla SwiGLU in mirrored arch)")
180
+
181
+ # Auxiliary objectives
182
+ parser.add_argument("--aux-skip", type=int, default=0,
183
+ help="Skip-ahead prediction distance (0 = disabled, e.g. 5 predicts t+5)")
184
+ parser.add_argument("--aux-weight", type=float, default=0.1,
185
+ help="Weight for auxiliary skip loss (default: 0.1)")
186
+
187
+ # Word-position RoPE (SemRoPE)
188
+ parser.add_argument("--word-rope-dims", type=int, default=0,
189
+ help="Head dims dedicated to word-position RoPE (0=disabled, try 8 or 16)")
190
+ parser.add_argument("--word-rope-base", type=float, default=10.0,
191
+ help="Frequency base for word-position RoPE (default: 10.0)")
192
+
193
+ # Factorized embedding / MLP head
194
+ parser.add_argument("--embed-dim", type=int, default=0,
195
+ help="Factorized embedding dim (0=use hidden_size, e.g. 256)")
196
+ parser.add_argument("--head-dim", type=int, default=0,
197
+ help="MLP head intermediate dim (0=linear head, e.g. 512)")
198
+
199
+ # G²LU gate grafting
200
+ parser.add_argument("--pretrained", type=str, default=None,
201
+ help="HuggingFace model for graft_g2lu (e.g. meta-llama/Llama-3.2-1B)")
202
+ parser.add_argument("--align-weight", type=float, default=1.0,
203
+ help="Alignment loss weight for G²LU grafting (default: 1.0)")
204
+ parser.add_argument("--graft-warmup", type=int, default=500,
205
+ help="Blend warmup steps: SwiGLU→G²LU transition (default: 500)")
206
+
207
+ # Training
208
+ parser.add_argument("--epochs", type=int, default=None)
209
+ parser.add_argument("--batch-size", type=int, default=None)
210
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate")
211
+ parser.add_argument("--min-lr", type=float, default=None,
212
+ help="Minimum learning rate for cosine decay (default: 0)")
213
+ parser.add_argument("--weight-decay", type=float, default=None)
214
+ parser.add_argument("--warmup-steps", type=int, default=None)
215
+ parser.add_argument("--grad-clip", type=float, default=None)
216
+ parser.add_argument("--grad-accum", type=int, default=1,
217
+ help="Gradient accumulation steps (effective batch = batch_size * grad_accum)")
218
+
219
+ # Hardware
220
+ parser.add_argument("--gpu", type=int, default=0)
221
+ parser.add_argument("--fp16", action="store_true", help="Use FP16 mixed precision (with GradScaler)")
222
+ parser.add_argument("--bf16", action="store_true", help="Use BF16 mixed precision (no scaler needed)")
223
+ parser.add_argument("--no-fp16", action="store_true", help="Disable mixed precision (FP32)")
224
+ parser.add_argument("--compile", action="store_true", help="Use torch.compile")
225
+
226
+ # Logging/Checkpointing
227
+ parser.add_argument("--log-every", type=int, default=None)
228
+ parser.add_argument("--save-every", type=int, default=None)
229
+ parser.add_argument("--val-every", type=int, default=0,
230
+ help="Run validation every N steps (0 = only at epoch end)")
231
+ parser.add_argument("--checkpoint-dir", type=str, default=None)
232
+ parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
233
+ parser.add_argument("--reset", action="store_true", default=False, help="When resuming the training resets steps and optimizers")
234
+
235
+ args = parser.parse_args()
236
+
237
+ # Build config from preset or defaults
238
+ if args.preset:
239
+ config = getattr(CircuitConfig, args.preset)()
240
+ else:
241
+ config = CircuitConfig()
242
+
243
+ # Override with explicit args
244
+ if args.dims is not None:
245
+ config.hidden_size = args.dims
246
+ if args.layers is not None:
247
+ config.num_layers = args.layers
248
+ if args.heads is not None:
249
+ config.num_heads = args.heads
250
+ if args.kv_heads is not None:
251
+ config.num_kv_heads = args.kv_heads
252
+ if args.context_length is not None:
253
+ config.max_seq_len = args.context_length
254
+ if args.dropout is not None:
255
+ config.dropout = args.dropout
256
+ if args.epochs is not None:
257
+ config.epochs = args.epochs
258
+ if args.batch_size is not None:
259
+ config.batch_size = args.batch_size
260
+ if args.lr is not None:
261
+ config.learning_rate = args.lr
262
+ if args.min_lr is not None:
263
+ config.min_lr = args.min_lr
264
+ if args.weight_decay is not None:
265
+ config.weight_decay = args.weight_decay
266
+ if args.warmup_steps is not None:
267
+ config.warmup_steps = args.warmup_steps
268
+ if args.grad_clip is not None:
269
+ config.grad_clip = args.grad_clip
270
+ if args.log_every is not None:
271
+ config.log_every = args.log_every
272
+ if args.save_every is not None:
273
+ config.save_every = args.save_every
274
+ if args.checkpoint_dir is not None:
275
+ config.checkpoint_dir = args.checkpoint_dir
276
+
277
+ # Auxiliary objectives
278
+ if args.aux_skip > 0:
279
+ config.aux_skip_k = args.aux_skip
280
+ config.aux_skip_weight = args.aux_weight
281
+
282
+ # Word-position RoPE
283
+ if args.word_rope_dims > 0:
284
+ config.word_rope_dims = args.word_rope_dims
285
+ config.word_rope_base = args.word_rope_base
286
+
287
+ # Factorized embedding / MLP head
288
+ if args.embed_dim > 0:
289
+ config.embed_dim = args.embed_dim
290
+ if args.head_dim > 0:
291
+ config.head_dim = args.head_dim
292
+
293
+ config.gpu = args.gpu
294
+ if args.bf16:
295
+ config.bf16 = True
296
+ config.fp16 = False
297
+ elif args.no_fp16:
298
+ config.fp16 = False
299
+ config.bf16 = False
300
+ elif args.fp16:
301
+ config.fp16 = True
302
+ config.bf16 = False
303
+ config.compile = args.compile
304
+ config.reset = args.reset
305
+
306
+ return config, args
data.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data loading utilities for Circuit Transformer.
3
+
4
+ Supports:
5
+ - Single text file: --data path/to/file.txt
6
+ - Directory of text files: --data path/to/dir/
7
+ - HuggingFace dataset: --data hf:dataset_name
8
+
9
+ Caching:
10
+ - HF datasets: memory-mapped binary files (.bin) — O(1) RAM
11
+ - Text files: torch .pt files (legacy, in-memory)
12
+ - Cache location: ./circuits/.cache/ (or custom via cache_dir)
13
+
14
+ Parallelism:
15
+ - HF datasets tokenized via dataset.map(num_proc=N) — multiprocessing, bypasses GIL
16
+ - Fast tokenizer uses Rust internally — additional parallelism within each worker
17
+ """
18
+
19
+ import os
20
+ import struct
21
+ import hashlib
22
+ import multiprocessing
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch.utils.data import Dataset, DataLoader
28
+
29
+ DEFAULT_CACHE_DIR = "./circuits/.cache"
30
+
31
+ # Memmap binary format:
32
+ # Header: 8 bytes = [uint32 n_chunks, uint32 max_seq_len]
33
+ # Data: n_chunks * max_seq_len * 4 bytes (int32, row-major)
34
+ HEADER_SIZE = 8
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Cache utilities
39
+ # ---------------------------------------------------------------------------
40
+
41
+ def _cache_key(data_source: str, max_seq_len: int, num_samples: int | None) -> str:
42
+ """Generate cache filename from parameters."""
43
+ key_str = f"{data_source}|{max_seq_len}|{num_samples}"
44
+ hash_val = hashlib.md5(key_str.encode()).hexdigest()[:12]
45
+ name = data_source.replace("/", "_").replace(":", "_").replace(".", "_")[-30:]
46
+ return f"{name}_{max_seq_len}_{hash_val}.bin"
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Dataset classes
51
+ # ---------------------------------------------------------------------------
52
+
53
+ class MemmapDataset(Dataset):
54
+ """Dataset backed by memory-mapped binary file. O(1) RAM regardless of size."""
55
+
56
+ def __init__(self, path, start=None, end=None):
57
+ self.path = str(path)
58
+ with open(self.path, 'rb') as f:
59
+ total, self.max_seq_len = struct.unpack('II', f.read(HEADER_SIZE))
60
+ self._total = total
61
+ self.data = np.memmap(
62
+ self.path, dtype=np.int32, mode='r',
63
+ offset=HEADER_SIZE, shape=(total, self.max_seq_len),
64
+ )
65
+ self.start = start if start is not None else 0
66
+ self.end = end if end is not None else total
67
+
68
+ def __len__(self):
69
+ return self.end - self.start
70
+
71
+ def __getitem__(self, idx):
72
+ tokens = torch.from_numpy(self.data[self.start + idx].copy()).long()
73
+ return {"input_ids": tokens, "labels": tokens.clone()}
74
+
75
+ def split(self, val_fraction=0.1):
76
+ """Split into (train, val) datasets. Both share the same memmap file."""
77
+ total = self.end - self.start
78
+ n_val = max(1, int(total * val_fraction))
79
+ train = MemmapDataset(self.path, self.start, self.end - n_val)
80
+ val = MemmapDataset(self.path, self.end - n_val, self.end)
81
+ return train, val
82
+
83
+
84
+ class TextDataset(Dataset):
85
+ """Simple in-memory dataset from tokenized chunks. For small datasets."""
86
+
87
+ def __init__(self, token_chunks: list[list[int]], max_seq_len: int):
88
+ self.chunks = token_chunks
89
+ self.max_seq_len = max_seq_len
90
+
91
+ def __len__(self) -> int:
92
+ return len(self.chunks)
93
+
94
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
95
+ tokens = self.chunks[idx]
96
+ if len(tokens) < self.max_seq_len:
97
+ tokens = tokens + [0] * (self.max_seq_len - len(tokens))
98
+ else:
99
+ tokens = tokens[: self.max_seq_len]
100
+ input_ids = torch.tensor(tokens, dtype=torch.long)
101
+ return {"input_ids": input_ids, "labels": input_ids.clone()}
102
+
103
+ def split(self, val_fraction=0.1):
104
+ """Split into (train, val) datasets with shuffle."""
105
+ import random
106
+ random.shuffle(self.chunks)
107
+ n_val = max(1, int(len(self.chunks) * val_fraction))
108
+ val = TextDataset(self.chunks[:n_val], self.max_seq_len)
109
+ train = TextDataset(self.chunks[n_val:], self.max_seq_len)
110
+ return train, val
111
+
112
+
113
+ # ---------------------------------------------------------------------------
114
+ # Tokenizer
115
+ # ---------------------------------------------------------------------------
116
+
117
+ class _SentencePieceTokenizer:
118
+ """Minimal tokenizer wrapper using sentencepiece directly.
119
+ Bypasses transformers tokenizer bugs across versions."""
120
+
121
+ def __init__(self, model_path, name):
122
+ import sentencepiece as spm
123
+ self.sp = spm.SentencePieceProcessor()
124
+ self.sp.Load(model_path)
125
+ self._vocab_size = self.sp.GetPieceSize()
126
+ self.eos_token_id = self.sp.eos_id()
127
+ self.bos_token_id = self.sp.bos_id()
128
+ self.eos_token = self.sp.IdToPiece(self.eos_token_id)
129
+ self.bos_token = self.sp.IdToPiece(self.bos_token_id)
130
+ self.pad_token = None
131
+ self.pad_token_id = None
132
+ self.name_or_path = name
133
+
134
+ def __len__(self):
135
+ return self._vocab_size
136
+
137
+ @property
138
+ def vocab_size(self):
139
+ return self._vocab_size
140
+
141
+ def encode(self, text, add_special_tokens=False, return_tensors=None):
142
+ ids = self.sp.Encode(text)
143
+ if return_tensors == "pt":
144
+ import torch
145
+ return torch.tensor([ids])
146
+ return ids
147
+
148
+ def decode(self, ids, skip_special_tokens=False):
149
+ if hasattr(ids, 'tolist'):
150
+ ids = ids.tolist()
151
+ return self.sp.Decode(list(ids))
152
+
153
+ def __call__(self, texts, add_special_tokens=False, **kwargs):
154
+ if isinstance(texts, str):
155
+ texts = [texts]
156
+ return {"input_ids": [self.sp.Encode(t) for t in texts]}
157
+
158
+
159
+ def get_tokenizer(name: str = "gpt2"):
160
+ """Get tokenizer from HuggingFace, with sentencepiece fallback.
161
+
162
+ Args:
163
+ name: Tokenizer name or path. Default "gpt2" (50257 vocab).
164
+ Use e.g. "facebook/MobileLLM-125M" for 32K vocab.
165
+ """
166
+ from transformers import AutoTokenizer
167
+
168
+ # Try AutoTokenizer (fast then slow)
169
+ for use_fast in (True, False):
170
+ try:
171
+ tokenizer = AutoTokenizer.from_pretrained(name, use_fast=use_fast,
172
+ trust_remote_code=True)
173
+ if isinstance(tokenizer, bool):
174
+ continue
175
+ if tokenizer.pad_token is None:
176
+ tokenizer.pad_token = tokenizer.eos_token
177
+ return tokenizer
178
+ except Exception:
179
+ continue
180
+
181
+ # Fallback: load sentencepiece model directly (bypasses transformers bugs)
182
+ print(f"AutoTokenizer failed for {name}, falling back to sentencepiece")
183
+ from huggingface_hub import hf_hub_download
184
+ model_path = hf_hub_download(name, "tokenizer.model")
185
+ tokenizer = _SentencePieceTokenizer(model_path, name)
186
+ if tokenizer.pad_token is None:
187
+ tokenizer.pad_token = tokenizer.eos_token
188
+ tokenizer.pad_token_id = tokenizer.eos_token_id
189
+ return tokenizer
190
+
191
+
192
+ # ---------------------------------------------------------------------------
193
+ # Streaming memmap writer
194
+ # ---------------------------------------------------------------------------
195
+
196
+ def _stream_chunks_to_memmap(tokenized, total_examples, max_seq_len, output_path,
197
+ num_workers=1, read_batch=10_000):
198
+ """Stream tokenized examples into a memory-mapped binary file.
199
+
200
+ Single-process, numpy-batch approach. Reads batches from Arrow dataset,
201
+ flattens to numpy int32, writes complete chunks to disk.
202
+ Memory: O(read_batch * avg_seq_len * 4 bytes).
203
+ No fork, no multiprocessing, no OOM.
204
+ """
205
+ from itertools import chain
206
+ from tqdm import tqdm
207
+
208
+ temp_path = str(output_path) + ".tmp"
209
+ n_chunks = 0
210
+ total_tokens = 0
211
+ carryover = np.array([], dtype=np.int32)
212
+
213
+ n_batches = (total_examples + read_batch - 1) // read_batch
214
+
215
+ with open(temp_path, 'wb') as f:
216
+ f.write(struct.pack('II', 0, max_seq_len)) # placeholder header
217
+
218
+ for batch_start in tqdm(range(0, total_examples, read_batch),
219
+ total=n_batches, desc="Chunking",
220
+ mininterval=1.0):
221
+ batch_end = min(batch_start + read_batch, total_examples)
222
+ batch_ids = tokenized[batch_start:batch_end]["input_ids"]
223
+
224
+ # Count tokens, flatten Arrow→numpy without intermediate Python list
225
+ n_tok = sum(len(ids) for ids in batch_ids if ids)
226
+ if n_tok == 0:
227
+ del batch_ids
228
+ continue
229
+
230
+ flat = np.fromiter(
231
+ chain.from_iterable(ids for ids in batch_ids if ids),
232
+ dtype=np.int32, count=n_tok,
233
+ )
234
+ del batch_ids
235
+ total_tokens += n_tok
236
+
237
+ # Prepend carryover from previous batch
238
+ if len(carryover) > 0:
239
+ flat = np.concatenate([carryover, flat])
240
+
241
+ # Write complete chunks
242
+ n_complete = len(flat) // max_seq_len
243
+ if n_complete > 0:
244
+ f.write(flat[:n_complete * max_seq_len].tobytes())
245
+ n_chunks += n_complete
246
+
247
+ carryover = flat[n_complete * max_seq_len:].copy()
248
+ del flat
249
+
250
+ # Handle remaining tokens
251
+ if len(carryover) >= 32:
252
+ padded = np.zeros(max_seq_len, dtype=np.int32)
253
+ padded[:len(carryover)] = carryover
254
+ f.write(padded.tobytes())
255
+ n_chunks += 1
256
+
257
+ # Write actual count into header
258
+ f.seek(0)
259
+ f.write(struct.pack('II', n_chunks, max_seq_len))
260
+
261
+ os.rename(temp_path, str(output_path))
262
+ size_gb = os.path.getsize(output_path) / 1e9
263
+ print(f"Total tokens: {total_tokens:,} → {n_chunks:,} chunks ({size_gb:.1f} GB)")
264
+ return n_chunks
265
+
266
+
267
+ # ---------------------------------------------------------------------------
268
+ # HuggingFace dataset loader (parallel + memmap)
269
+ # ---------------------------------------------------------------------------
270
+
271
+ def _flatten_chat(example):
272
+ """Convert chat format (system + conversations list) to plain text.
273
+
274
+ Handles datasets like Bespoke-Stratos-17k and OpenThoughts-114k
275
+ which store data as: system (str) + conversations (list of {from, value}).
276
+ """
277
+ parts = []
278
+ if example.get("system"):
279
+ parts.append(example["system"].strip())
280
+ for msg in example.get("conversations", []):
281
+ value = msg.get("value", "")
282
+ if value:
283
+ parts.append(value.strip())
284
+ return {"text": "\n\n".join(parts)}
285
+
286
+
287
+ def _estimate_avg_chars(dataset, text_column: str, n_sample: int = 200) -> float:
288
+ """Estimate average text length from a sample of the dataset."""
289
+ n = min(n_sample, len(dataset))
290
+ total = sum(len(dataset[i][text_column] or "") for i in range(n))
291
+ return total / max(n, 1)
292
+
293
+
294
+ def _adaptive_params(avg_chars: float, n_examples: int):
295
+ """Scale worker count, batch sizes based on average example length.
296
+
297
+ Long examples (chain-of-thought reasoning) need smaller batches and fewer
298
+ workers to avoid OOM on memory-constrained systems (especially WSL).
299
+ """
300
+ cpu_count = max(1, multiprocessing.cpu_count() - 1)
301
+
302
+ if avg_chars > 20_000: # very long (OpenThoughts-style, ~7K+ tokens)
303
+ num_proc = min(cpu_count, 4)
304
+ tok_batch = 64
305
+ read_batch = 500
306
+ elif avg_chars > 5_000: # long (detailed SFT, ~1.5K+ tokens)
307
+ num_proc = min(cpu_count, 8)
308
+ tok_batch = 256
309
+ read_batch = 2_000
310
+ elif avg_chars > 1_000: # medium (typical SFT)
311
+ num_proc = min(cpu_count, 16)
312
+ tok_batch = 500
313
+ read_batch = 5_000
314
+ else: # short (web text, wiki)
315
+ num_proc = min(cpu_count, 32)
316
+ tok_batch = 1000
317
+ read_batch = 10_000
318
+
319
+ return num_proc, tok_batch, read_batch
320
+
321
+
322
+ def load_hf_dataset(
323
+ name: str,
324
+ split: str,
325
+ text_column: str,
326
+ tokenizer,
327
+ max_seq_len: int,
328
+ num_samples: int | None = None,
329
+ hf_config: str | None = None,
330
+ cache_path: Path | None = None,
331
+ data_format: str = "text",
332
+ ) -> MemmapDataset:
333
+ """Load HF dataset with parallel tokenization and streaming to memmap.
334
+
335
+ Parallelism:
336
+ - dataset.map(num_proc=N) uses multiprocessing — bypasses GIL
337
+ - GPT2TokenizerFast runs Rust tokenization — bypasses GIL
338
+ - batched=True enables efficient batch processing
339
+
340
+ Memory:
341
+ - Adaptive batch sizes based on avg example length — prevents OOM on long sequences
342
+ - Tokenized data in Arrow format (memory-mapped by HuggingFace)
343
+ - Chunks streamed to binary memmap file — never in RAM
344
+ """
345
+ from datasets import load_dataset
346
+
347
+ config_str = f", config={hf_config}" if hf_config else ""
348
+ print(f"Loading HF dataset: {name} (split={split}{config_str})")
349
+ dataset = load_dataset(name, hf_config, split=split)
350
+
351
+ if num_samples is not None:
352
+ dataset = dataset.select(range(min(num_samples, len(dataset))))
353
+
354
+ # Flatten chat format to plain text
355
+ if data_format == "chat":
356
+ # Use conservative parallelism for flattening — light operation
357
+ flat_proc = min(max(1, multiprocessing.cpu_count() - 1), 8)
358
+ print(f"Flattening {len(dataset):,} chat examples to plain text...")
359
+ dataset = dataset.map(
360
+ _flatten_chat,
361
+ num_proc=flat_proc,
362
+ remove_columns=dataset.column_names,
363
+ desc="Flattening chat",
364
+ )
365
+ text_column = "text"
366
+
367
+ # Estimate avg example length and adapt parameters
368
+ avg_chars = _estimate_avg_chars(dataset, text_column)
369
+ num_proc, tok_batch, read_batch = _adaptive_params(avg_chars, len(dataset))
370
+ print(f" Avg example length: ~{avg_chars:,.0f} chars → "
371
+ f"{num_proc} workers, tok_batch={tok_batch}, read_batch={read_batch}")
372
+
373
+ # Filter empty examples
374
+ print(f"Filtering empty examples from {len(dataset):,}...")
375
+ dataset = dataset.filter(
376
+ lambda x: bool(x[text_column] and x[text_column].strip()),
377
+ num_proc=num_proc,
378
+ desc="Filtering",
379
+ )
380
+ print(f" {len(dataset):,} non-empty examples")
381
+
382
+ # Parallel tokenization
383
+ print(f"Tokenizing {len(dataset):,} examples with {num_proc} workers...")
384
+
385
+ def tokenize_batch(examples):
386
+ return tokenizer(examples[text_column], add_special_tokens=False)
387
+
388
+ tokenized = dataset.map(
389
+ tokenize_batch,
390
+ batched=True,
391
+ batch_size=tok_batch,
392
+ num_proc=num_proc,
393
+ remove_columns=dataset.column_names,
394
+ desc="Tokenizing",
395
+ )
396
+
397
+ # Stream to memmap — use temp path if no cache configured
398
+ if cache_path is None:
399
+ import tempfile
400
+ cache_path = Path(tempfile.mktemp(suffix='.bin'))
401
+
402
+ _stream_chunks_to_memmap(tokenized, len(tokenized), max_seq_len, cache_path,
403
+ read_batch=read_batch)
404
+ return MemmapDataset(cache_path)
405
+
406
+
407
+ # ---------------------------------------------------------------------------
408
+ # Text file loaders (unchanged — small datasets, in-memory is fine)
409
+ # ---------------------------------------------------------------------------
410
+
411
+ def tokenize_text(text: str, tokenizer, max_seq_len: int) -> list[list[int]]:
412
+ """Tokenize text into chunks of max_seq_len."""
413
+ tokens = tokenizer.encode(text)
414
+ chunks = []
415
+ for i in range(0, len(tokens), max_seq_len):
416
+ chunk = tokens[i : i + max_seq_len]
417
+ if len(chunk) >= 32:
418
+ chunks.append(chunk)
419
+ return chunks
420
+
421
+
422
+ def load_text_file(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
423
+ """Load and tokenize a single text file."""
424
+ with open(path, "r", encoding="utf-8") as f:
425
+ text = f.read()
426
+ return tokenize_text(text, tokenizer, max_seq_len)
427
+
428
+
429
+ def load_text_directory(path: str, tokenizer, max_seq_len: int) -> list[list[int]]:
430
+ """Load and tokenize all .txt files from a directory."""
431
+ all_chunks = []
432
+ path = Path(path)
433
+ for txt_file in sorted(path.glob("**/*.txt")):
434
+ chunks = load_text_file(str(txt_file), tokenizer, max_seq_len)
435
+ all_chunks.extend(chunks)
436
+ return all_chunks
437
+
438
+
439
+ # ---------------------------------------------------------------------------
440
+ # Main entry point
441
+ # ---------------------------------------------------------------------------
442
+
443
+ def load_data(
444
+ data_source: str,
445
+ tokenizer,
446
+ max_seq_len: int,
447
+ text_column: str = "text",
448
+ num_samples: int | None = None,
449
+ cache_dir: str | None = DEFAULT_CACHE_DIR,
450
+ data_format: str = "text",
451
+ ) -> Dataset:
452
+ """
453
+ Load data from various sources. Returns a Dataset with .split() support.
454
+
455
+ Args:
456
+ data_source: Path or HF dataset identifier
457
+ - "path/to/file.txt" — single file
458
+ - "path/to/dir/" — directory of .txt files
459
+ - "hf:dataset_name" — HuggingFace dataset (train split)
460
+ - "hf:dataset:split" — HuggingFace with specific split
461
+ - "hf:dataset:config:split" — with config and split
462
+ tokenizer: Tokenizer to use
463
+ max_seq_len: Maximum sequence length
464
+ text_column: Column name for HF datasets
465
+ num_samples: Limit samples from HF dataset
466
+ cache_dir: Directory for cache files (None to disable)
467
+
468
+ Returns:
469
+ Dataset object supporting len(), __getitem__(), and split(fraction)
470
+ """
471
+ cache_path = None
472
+ if cache_dir is not None:
473
+ cache_path = Path(cache_dir) / _cache_key(data_source, max_seq_len, num_samples)
474
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
475
+
476
+ # Check for memmap cache (.bin)
477
+ if cache_path.exists():
478
+ print(f"Loading from cache: {cache_path}")
479
+ ds = MemmapDataset(cache_path)
480
+ print(f" Loaded {len(ds):,} chunks")
481
+ return ds
482
+
483
+ # Check for legacy cache (.pt)
484
+ legacy_path = cache_path.with_suffix('.pt')
485
+ if legacy_path.exists():
486
+ print(f"Loading from legacy cache: {legacy_path}")
487
+ data = torch.load(legacy_path, weights_only=False)
488
+ chunks = data["chunks"]
489
+ print(f" Loaded {len(chunks):,} chunks")
490
+ return TextDataset(chunks, max_seq_len)
491
+
492
+ # Load and tokenize
493
+ if data_source.startswith("hf:"):
494
+ parts = data_source[3:].split(":")
495
+ name = parts[0]
496
+ hf_config = None
497
+ split = "train"
498
+ if len(parts) == 2:
499
+ split = parts[1]
500
+ elif len(parts) == 3:
501
+ hf_config = parts[1]
502
+ split = parts[2]
503
+ return load_hf_dataset(
504
+ name, split, text_column, tokenizer, max_seq_len,
505
+ num_samples, hf_config=hf_config, cache_path=cache_path,
506
+ data_format=data_format,
507
+ )
508
+ elif os.path.isfile(data_source):
509
+ chunks = load_text_file(data_source, tokenizer, max_seq_len)
510
+ elif os.path.isdir(data_source):
511
+ chunks = load_text_directory(data_source, tokenizer, max_seq_len)
512
+ else:
513
+ raise ValueError(f"Unknown data source: {data_source}")
514
+
515
+ # For text files: save legacy cache
516
+ if cache_dir is not None:
517
+ legacy_path = cache_path.with_suffix('.pt')
518
+ torch.save({"chunks": chunks, "data_source": data_source,
519
+ "max_seq_len": max_seq_len, "num_samples": num_samples}, legacy_path)
520
+ print(f"Saved to cache: {legacy_path}")
521
+
522
+ return TextDataset(chunks, max_seq_len)
523
+
524
+
525
+ # ---------------------------------------------------------------------------
526
+ # DataLoader factory
527
+ # ---------------------------------------------------------------------------
528
+
529
+ def create_dataloader(
530
+ dataset,
531
+ batch_size: int,
532
+ max_seq_len: int = None,
533
+ shuffle: bool = True,
534
+ num_workers: int = 0,
535
+ ) -> DataLoader:
536
+ """Create a DataLoader from a Dataset or list of chunks."""
537
+ if not isinstance(dataset, Dataset):
538
+ # Legacy compatibility: list of token chunks
539
+ dataset = TextDataset(dataset, max_seq_len)
540
+ return DataLoader(
541
+ dataset,
542
+ batch_size=batch_size,
543
+ shuffle=shuffle,
544
+ num_workers=num_workers,
545
+ pin_memory=True,
546
+ )
generate.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generation script for Circuit Transformer.
4
+
5
+ Usage:
6
+ python circuits/generate.py --checkpoint circuits/checkpoints/latest.pt --prompt "Once upon a time"
7
+ """
8
+
9
+ import argparse
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from transformers import AutoTokenizer
15
+
16
+ from .config import CircuitConfig
17
+ from .model import CircuitTransformer
18
+ from .mirrored import MirroredConfig, MirroredTransformer
19
+ from .graft_g2lu import load_g2lu_model
20
+ from .layers import build_word_start_table
21
+ from .data import get_tokenizer
22
+
23
+
24
+ def parse_args():
25
+ parser = argparse.ArgumentParser(description="Generate text with Circuit Transformer")
26
+
27
+ parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint")
28
+ parser.add_argument("--prompt", type=str, default="", help="Prompt text")
29
+ parser.add_argument("--max-tokens", type=int, default=100, help="Max tokens to generate")
30
+ parser.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
31
+ parser.add_argument("--top-k", type=int, default=50, help="Top-k filtering")
32
+ parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold")
33
+ parser.add_argument("--repetition-penalty", type=float, default=1.0, help="Repetition penalty (1.0=off, 1.3=default for slot models)")
34
+ parser.add_argument("--gpu", type=int, default=0, help="GPU index")
35
+ parser.add_argument("--no-cache", action="store_true", help="Disable KV cache")
36
+
37
+ return parser.parse_args()
38
+
39
+ def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
40
+ """Migrate checkpoint state_dict to match current model architecture.
41
+
42
+ Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
43
+ """
44
+ if any(k.startswith("_orig_mod.") for k in state_dict):
45
+ state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
46
+
47
+ model_keys = set(model.state_dict().keys())
48
+ ckpt_keys = set(state_dict.keys())
49
+
50
+ missing = model_keys - ckpt_keys
51
+ unexpected = ckpt_keys - model_keys
52
+
53
+ print(unexpected)
54
+
55
+ if not missing and not unexpected:
56
+ return state_dict # perfect match, no migration needed
57
+
58
+ migrated = dict(state_dict)
59
+ migrations = []
60
+
61
+ # SwiGLU → MirroredSwiGLU: w3 → gate_expand (dual_gate_middle upgrade)
62
+ for key in list(unexpected):
63
+ if ".ffn.gate_expand.weight" in key:
64
+ new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
65
+ if new_key in missing:
66
+ migrated[new_key] = migrated.pop(key)
67
+ missing.discard(new_key)
68
+ unexpected.discard(key)
69
+ migrations.append(f" {key} → {new_key}")
70
+ if ".ffn.gate_compress.weight" in key:
71
+ new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
72
+ if new_key in missing:
73
+ migrated[new_key] = migrated.pop(key)
74
+ missing.discard(new_key)
75
+ unexpected.discard(key)
76
+ migrations.append(f" {key} → {new_key}")
77
+
78
+ if migrations:
79
+ print(f"State dict migration ({len(migrations)} keys renamed):")
80
+ for m in migrations:
81
+ print(m)
82
+ # Report remaining missing keys (freshly initialized)
83
+ still_missing = model_keys - set(migrated.keys())
84
+ if still_missing:
85
+ print(f" New parameters (freshly initialized): {len(still_missing)}")
86
+ for k in sorted(still_missing):
87
+ print(f" {k}")
88
+
89
+ return migrated
90
+
91
+ def generate():
92
+ args = parse_args()
93
+
94
+ # Setup device
95
+ device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
96
+ print(f"Device: {device}")
97
+
98
+ # Load checkpoint
99
+ print(f"Loading checkpoint: {args.checkpoint}")
100
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
101
+
102
+ # Reconstruct config and model based on architecture type
103
+ model_type = checkpoint.get("model_type", "standard")
104
+ is_folded = model_type == "folded"
105
+
106
+ if model_type == "graft_g2lu":
107
+ model = load_g2lu_model(args.checkpoint, device=device)
108
+ model.eval()
109
+ pretrained_name = checkpoint.get("pretrained_name", "unknown")
110
+ print(f"Architecture: G²LU Graft ({pretrained_name}, {len(model.g2lu_mlps)}L)")
111
+ tokenizer_name = checkpoint.get("tokenizer_name", pretrained_name)
112
+ tokenizer = get_tokenizer(tokenizer_name)
113
+ elif is_folded:
114
+ from grafting.fold_llama import FoldedLlama
115
+ model = FoldedLlama.load_from_checkpoint(args.checkpoint, device=device)
116
+ model.eval()
117
+ fold_cfg = model.config
118
+ print(f"Architecture: FoldedLlama ({fold_cfg.model_name}, "
119
+ f"{fold_cfg.n_expand}E+{fold_cfg.n_middle}M+{fold_cfg.n_compress}C)")
120
+ tokenizer = AutoTokenizer.from_pretrained(fold_cfg.model_name, trust_remote_code=True)
121
+ else:
122
+ if model_type == "mirrored":
123
+ if checkpoint["config"].get("dual_gate_middle"):
124
+ checkpoint["config"].pop("dual_gate_middle")
125
+ config = MirroredConfig.from_dict(checkpoint["config"])
126
+ model = MirroredTransformer(config).to(device)
127
+ print(f"Architecture: MirroredTransformer ({model.total_virtual_layers} virtual layers)")
128
+ else:
129
+ config = CircuitConfig.from_dict(checkpoint["config"])
130
+ model = CircuitTransformer(config).to(device)
131
+ print(f"Architecture: CircuitTransformer ({config.num_layers} layers)")
132
+ # Strip _orig_mod. prefix from torch.compile'd checkpoints
133
+
134
+ state_dict = _migrate_state_dict(checkpoint["model"], model)
135
+
136
+ model.load_state_dict(state_dict)
137
+ model.eval()
138
+ tokenizer_name = checkpoint.get("tokenizer_name", "gpt2")
139
+ tokenizer = get_tokenizer(tokenizer_name)
140
+
141
+ # Build word-position table if model uses SemRoPE
142
+ word_start_table_device = None
143
+ if model_type not in ("graft_g2lu", "folded"):
144
+ ckpt_config = checkpoint.get("config", {})
145
+ word_rope_dims = ckpt_config.get("word_rope_dims", 0)
146
+ if word_rope_dims > 0:
147
+ word_start_table_device = build_word_start_table(tokenizer, len(tokenizer)).to(device)
148
+ print(f"Word-position RoPE: {word_rope_dims} dims")
149
+
150
+ # Tokenize prompt
151
+ if args.prompt:
152
+ prompt_ids = tokenizer.encode(args.prompt, return_tensors="pt").to(device)
153
+ else:
154
+ # Start with BOS/EOS token
155
+ prompt_ids = torch.tensor([[tokenizer.eos_token_id]], device=device)
156
+
157
+ print(f"\nPrompt: {args.prompt or '<empty>'}")
158
+ print(f"Prompt tokens: {prompt_ids.shape[1]}")
159
+ print(f"Generating {args.max_tokens} tokens...")
160
+ print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Top-p: {args.top_p}")
161
+ print("-" * 50)
162
+
163
+ # Generate
164
+ with torch.no_grad():
165
+ gen_kwargs = dict(
166
+ max_new_tokens=args.max_tokens,
167
+ temperature=args.temperature,
168
+ top_k=args.top_k,
169
+ top_p=args.top_p,
170
+ use_cache=not args.no_cache,
171
+ )
172
+ if args.repetition_penalty != 1.0:
173
+ gen_kwargs["repetition_penalty"] = args.repetition_penalty
174
+
175
+ # HF models need do_sample=True for temperature/top_k/top_p
176
+ if model_type == "graft_g2lu":
177
+ if args.temperature > 0 and args.temperature != 1.0:
178
+ gen_kwargs["do_sample"] = True
179
+ elif args.top_p < 1.0 or args.top_k > 0:
180
+ gen_kwargs["do_sample"] = True
181
+
182
+ if word_start_table_device is not None:
183
+ gen_kwargs["word_start_table"] = word_start_table_device
184
+
185
+ output_ids = model.generate(prompt_ids, **gen_kwargs)
186
+
187
+ # Decode and print
188
+ generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
189
+ print(generated_text)
190
+ print("-" * 50)
191
+ print(f"Total tokens: {output_ids.shape[1]}")
192
+
193
+
194
+ if __name__ == "__main__":
195
+ generate()
graft_g2lu.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ G²LU Gate Grafting: Surgically upgrade pretrained SwiGLU models to G²LU.
3
+
4
+ Takes any HuggingFace model with SwiGLU (gate_proj + up_proj), freezes everything
5
+ except gate weights, adds W4 for nested gating, and trains with alignment + LM loss.
6
+
7
+ This is grafting applied to the gate mechanism — the same methodology validated for
8
+ full layer replacement, now targeting the minimum surgical unit.
9
+
10
+ Usage:
11
+ python -m circuits.train --arch graft_g2lu --pretrained meta-llama/Llama-3.2-1B \
12
+ --align-weight 1.0 --graft-warmup 500 --data hf:Bingsu/openwebtext_20p ...
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from pathlib import Path
19
+
20
+
21
+ class G2LU_MLP(nn.Module):
22
+ """Per-layer MLP wrapper that upgrades SwiGLU to G²LU.
23
+
24
+ Holds references to the original gate_proj (W3, frozen), up_proj (W1, frozen),
25
+ down_proj (W2, frozen), plus a new w4 (zero-initialized, trainable).
26
+
27
+ Gate ordering: silu(W4@x * silu(W3@x)) — the pretrained gate (W3) acts as
28
+ structural prior, constraining W4 to operate within the feature subspace the
29
+ pretrained model already deems relevant. W4's gradients are scaled by silu(W3@x),
30
+ inheriting the pretrained model's feature selection hierarchy.
31
+ """
32
+
33
+ def __init__(self, original_mlp: nn.Module):
34
+ super().__init__()
35
+ # References to original weights (all frozen)
36
+ self.gate_proj = original_mlp.gate_proj # W3 — frozen
37
+ self.up_proj = original_mlp.up_proj # W1 — frozen
38
+ self.down_proj = original_mlp.down_proj # W2 — frozen
39
+
40
+ # New W4: same shape as gate_proj, zero-initialized, matched dtype
41
+ self.w4 = nn.Linear(
42
+ self.gate_proj.in_features,
43
+ self.gate_proj.out_features,
44
+ bias=self.gate_proj.bias is not None,
45
+ dtype=self.gate_proj.weight.dtype,
46
+ device=self.gate_proj.weight.device,
47
+ )
48
+ nn.init.zeros_(self.w4.weight)
49
+ if self.w4.bias is not None:
50
+ nn.init.zeros_(self.w4.bias)
51
+
52
+ # Blend alpha: 0 = pure SwiGLU, 1 = full G²LU
53
+ self._alpha = 0.0
54
+ # Per-layer alignment loss (collected by parent)
55
+ self._align_loss = None
56
+
57
+ def forward(self, x):
58
+ # Pretrained gate (frozen W3) — structural prior
59
+ w3_gate = F.silu(self.gate_proj(x))
60
+
61
+ # G²LU gate: silu(W4@x * silu(W3@x))
62
+ # W4 modulated BY pretrained knowledge, not the reverse
63
+ g2lu_gate = F.silu(self.w4(x) * w3_gate)
64
+
65
+ # Blend warmup: smooth transition from SwiGLU → G²LU
66
+ if self._alpha < 1.0:
67
+ gate = (1.0 - self._alpha) * w3_gate + self._alpha * g2lu_gate
68
+ else:
69
+ gate = g2lu_gate
70
+
71
+ # Per-layer alignment loss (compare against original SwiGLU gate)
72
+ self._align_loss = F.mse_loss(gate, w3_gate.detach())
73
+
74
+ return self.down_proj(gate * self.up_proj(x))
75
+
76
+
77
+ class G2LU_GraftedModel(nn.Module):
78
+ """Full model wrapper that upgrades a pretrained HF model's MLPs to G²LU.
79
+
80
+ Interface matches CircuitTransformer: forward(input_ids, labels=labels) returns
81
+ {"loss", "logits", "align_loss"}.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ pretrained_name: str,
87
+ align_weight: float = 1.0,
88
+ warmup_steps: int = 500,
89
+ device: str = "cuda",
90
+ dtype=torch.bfloat16,
91
+ ):
92
+ super().__init__()
93
+ self.pretrained_name = pretrained_name
94
+ self.align_weight = align_weight
95
+ self.warmup_steps = warmup_steps
96
+ self._current_step = 0
97
+
98
+ # Load pretrained HF model
99
+ from transformers import AutoModelForCausalLM
100
+ self.model = AutoModelForCausalLM.from_pretrained(
101
+ pretrained_name,
102
+ dtype=dtype,
103
+ trust_remote_code=True,
104
+ )
105
+
106
+ # Discover and replace MLPs
107
+ self.g2lu_mlps = []
108
+ self._replace_mlps()
109
+
110
+ # Freeze everything, then selectively unfreeze W4 only
111
+ for param in self.model.parameters():
112
+ param.requires_grad = False
113
+
114
+ for g2lu in self.g2lu_mlps:
115
+ for param in g2lu.w4.parameters():
116
+ param.requires_grad = True
117
+
118
+ self.model.to(device)
119
+
120
+ # Print summary
121
+ total_params = sum(p.numel() for p in self.model.parameters())
122
+ trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
123
+ print(f"G²LU Graft: {pretrained_name}")
124
+ print(f" Layers upgraded: {len(self.g2lu_mlps)}")
125
+ print(f" Total params: {total_params:,} ({total_params/1e6:.1f}M)")
126
+ print(f" Trainable params: {trainable:,} ({trainable/1e6:.1f}M, {100*trainable/total_params:.1f}%)")
127
+ print(f" Align weight: {align_weight}, Warmup: {warmup_steps} steps")
128
+
129
+ def _replace_mlps(self):
130
+ """Walk the model tree and replace SwiGLU MLPs with G²LU wrappers."""
131
+ # Try common decoder layer paths
132
+ layers = None
133
+ for attr_path in ["model.layers", "gpt_neox.layers", "transformer.h"]:
134
+ obj = self.model
135
+ try:
136
+ for attr in attr_path.split("."):
137
+ obj = getattr(obj, attr)
138
+ layers = obj
139
+ break
140
+ except AttributeError:
141
+ continue
142
+
143
+ if layers is None:
144
+ raise ValueError(
145
+ f"Could not find decoder layers in {type(self.model).__name__}. "
146
+ f"Tried: model.layers, gpt_neox.layers, transformer.h"
147
+ )
148
+
149
+ for i, layer in enumerate(layers):
150
+ # Try common MLP attribute names
151
+ mlp = None
152
+ mlp_attr = None
153
+ for attr in ["mlp", "feed_forward"]:
154
+ if hasattr(layer, attr):
155
+ mlp = getattr(layer, attr)
156
+ mlp_attr = attr
157
+ break
158
+
159
+ if mlp is None:
160
+ continue
161
+
162
+ # Check for SwiGLU signature (gate_proj + up_proj)
163
+ if hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj"):
164
+ g2lu = G2LU_MLP(mlp)
165
+ setattr(layer, mlp_attr, g2lu)
166
+ self.g2lu_mlps.append(g2lu)
167
+
168
+ if not self.g2lu_mlps:
169
+ raise ValueError(
170
+ "No SwiGLU MLPs found (need gate_proj + up_proj attributes). "
171
+ "This model may not use gated linear units."
172
+ )
173
+
174
+ def set_step(self, step: int):
175
+ """Update blend alpha across all G²LU MLPs."""
176
+ self._current_step = step
177
+ alpha = min(step / max(self.warmup_steps, 1), 1.0)
178
+ for g2lu in self.g2lu_mlps:
179
+ g2lu._alpha = alpha
180
+
181
+ def trainable_parameters(self):
182
+ """Yield only unfrozen parameters (for optimizer and grad clipping)."""
183
+ for param in self.model.parameters():
184
+ if param.requires_grad:
185
+ yield param
186
+
187
+ def collect_align_loss(self):
188
+ """Average per-layer alignment losses."""
189
+ losses = [g2lu._align_loss for g2lu in self.g2lu_mlps if g2lu._align_loss is not None]
190
+ if not losses:
191
+ return torch.tensor(0.0)
192
+ return torch.stack(losses).mean()
193
+
194
+ def forward(self, input_ids, labels=None, **kwargs):
195
+ outputs = self.model(input_ids=input_ids, labels=labels, **kwargs)
196
+
197
+ result = {"logits": outputs.logits}
198
+
199
+ align_loss = self.collect_align_loss()
200
+ result["align_loss"] = align_loss
201
+
202
+ if labels is not None:
203
+ # Combine LM loss + alignment loss
204
+ result["loss"] = outputs.loss + self.align_weight * align_loss
205
+ result["lm_loss"] = outputs.loss
206
+ else:
207
+ result["loss"] = align_loss
208
+
209
+ return result
210
+
211
+ def generate(self, input_ids, **kwargs):
212
+ """Delegate to HF model's .generate()."""
213
+ return self.model.generate(input_ids=input_ids, **kwargs)
214
+
215
+
216
+ def save_g2lu_checkpoint(
217
+ model: G2LU_GraftedModel,
218
+ optimizer: torch.optim.Optimizer,
219
+ step: int,
220
+ epoch: int,
221
+ loss: float,
222
+ path: str,
223
+ epoch_step: int = 0,
224
+ best_val_loss: float | None = None,
225
+ scaler=None,
226
+ tokenizer_name: str = None,
227
+ ):
228
+ """Delta save: only trainable params + metadata."""
229
+ # Extract only requires_grad params
230
+ raw = model.model if not hasattr(model, '_orig_mod') else model._orig_mod.model
231
+ # Handle torch.compile wrapper
232
+ if hasattr(model, '_orig_mod'):
233
+ g2lu_model = model._orig_mod
234
+ else:
235
+ g2lu_model = model
236
+
237
+ delta_sd = {}
238
+ full_sd = g2lu_model.model.state_dict()
239
+ for name, param in g2lu_model.model.named_parameters():
240
+ if param.requires_grad:
241
+ # Strip _orig_mod. prefix if present
242
+ clean_name = name.removeprefix("_orig_mod.")
243
+ delta_sd[clean_name] = full_sd.get(name, param.data).clone()
244
+
245
+ # Also save the w4 weights explicitly (they're part of the replaced modules)
246
+ for name, val in full_sd.items():
247
+ clean_name = name.removeprefix("_orig_mod.")
248
+ if ".w4." in clean_name and clean_name not in delta_sd:
249
+ delta_sd[clean_name] = val.clone()
250
+
251
+ checkpoint = {
252
+ "model": delta_sd,
253
+ "optimizer": optimizer.state_dict(),
254
+ "step": step,
255
+ "epoch": epoch,
256
+ "epoch_step": epoch_step,
257
+ "loss": loss,
258
+ "model_type": "graft_g2lu",
259
+ "pretrained_name": g2lu_model.pretrained_name,
260
+ "align_weight": g2lu_model.align_weight,
261
+ "warmup_steps": g2lu_model.warmup_steps,
262
+ "tokenizer_name": tokenizer_name or g2lu_model.pretrained_name,
263
+ }
264
+ if best_val_loss is not None:
265
+ checkpoint["best_val_loss"] = best_val_loss
266
+ if scaler is not None:
267
+ checkpoint["scaler"] = scaler.state_dict()
268
+ torch.save(checkpoint, path)
269
+
270
+
271
+ def load_g2lu_model(checkpoint_path: str, device: str = "cuda", dtype=torch.bfloat16):
272
+ """Delta load: recreate model from pretrained + apply delta weights."""
273
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
274
+
275
+ pretrained_name = checkpoint["pretrained_name"]
276
+ align_weight = checkpoint.get("align_weight", 1.0)
277
+ warmup_steps = checkpoint.get("warmup_steps", 500)
278
+
279
+ model = G2LU_GraftedModel(
280
+ pretrained_name=pretrained_name,
281
+ align_weight=align_weight,
282
+ warmup_steps=warmup_steps,
283
+ device=device,
284
+ dtype=dtype,
285
+ )
286
+
287
+ # Load delta weights
288
+ delta_sd = checkpoint["model"]
289
+ # Strip _orig_mod. prefix if present
290
+ delta_sd = {k.removeprefix("_orig_mod."): v for k, v in delta_sd.items()}
291
+
292
+ # Apply delta weights to the model
293
+ missing, unexpected = model.model.load_state_dict(delta_sd, strict=False)
294
+ if unexpected:
295
+ print(f" Warning: unexpected keys in delta checkpoint: {unexpected[:5]}...")
296
+
297
+ # Set alpha to 1.0 for inference (full G²LU)
298
+ model.set_step(warmup_steps + 1)
299
+
300
+ return model
layers.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared building blocks for Circuit Transformer architectures.
3
+
4
+ Components:
5
+ - RMSNorm: Root Mean Square Layer Normalization
6
+ - RotaryEmbedding: Rotary Position Embedding (RoPE)
7
+ - CausalAttention: Multi-head causal attention with RoPE + KV cache
8
+ - SwiGLU: Gated feed-forward network
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+ from functools import lru_cache
16
+
17
+
18
+ class RMSNorm(nn.Module):
19
+ """Root Mean Square Layer Normalization."""
20
+
21
+ def __init__(self, dim: int, eps: float = 1e-6):
22
+ super().__init__()
23
+ self.eps = eps
24
+ self.weight = nn.Parameter(torch.ones(dim))
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
28
+ return (x.float() * norm).type_as(x) * self.weight
29
+
30
+
31
+ def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor:
32
+ """Build a boolean table marking which token IDs start a new word.
33
+
34
+ Detects word boundaries from tokenizer's token representations:
35
+ - Ġ prefix (GPT-2/BPE style)
36
+ - ▁ prefix (SentencePiece style)
37
+ - Special tokens (starting with <)
38
+ """
39
+ table = torch.zeros(vocab_size, dtype=torch.bool)
40
+
41
+ # Get all token strings — handle both HF and SentencePiece tokenizers
42
+ if hasattr(tokenizer, 'convert_ids_to_tokens'):
43
+ tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
44
+ elif hasattr(tokenizer, 'sp'):
45
+ tokens = [tokenizer.sp.IdToPiece(i) for i in range(vocab_size)]
46
+ else:
47
+ tokens = [tokenizer.decode([i]) for i in range(vocab_size)]
48
+
49
+ for idx, tok in enumerate(tokens):
50
+ if tok is None:
51
+ continue
52
+ if tok.startswith('Ġ') or tok.startswith('▁') or tok.startswith('<'):
53
+ table[idx] = True
54
+ # Punctuation and newlines that start new "words"
55
+ elif len(tok) > 0 and tok[0] in '\n\r\t':
56
+ table[idx] = True
57
+
58
+ # Token 0 is always a word starter (BOS/padding)
59
+ table[0] = True
60
+
61
+ return table
62
+
63
+
64
+ def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor:
65
+ """Compute position-within-word for each token. Vectorized, no loops.
66
+
67
+ Args:
68
+ input_ids: [B, L] token IDs
69
+ word_start_table: [vocab_size] bool tensor from build_word_start_table
70
+
71
+ Returns:
72
+ [B, L] float tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word boundary)
73
+ """
74
+ is_word_start = word_start_table[input_ids] # [B, L]
75
+ is_word_start[:, 0] = True # First token always starts a word
76
+
77
+ B, L = input_ids.shape
78
+ positions = torch.arange(L, device=input_ids.device, dtype=torch.float32).unsqueeze(0).expand(B, -1)
79
+
80
+ # Fill non-word-start positions with -1, word-start positions with their index
81
+ fill = torch.where(is_word_start, positions, torch.tensor(-1.0, device=input_ids.device))
82
+
83
+ # cummax propagates the most recent word-start position forward
84
+ running_start, _ = fill.cummax(dim=1)
85
+
86
+ # Position within word = distance from the most recent word start
87
+ word_pos = positions - running_start # [B, L] float: 0, 1, 2, 0, 1, 0, ...
88
+
89
+ return word_pos
90
+
91
+
92
+ class WordPositionRoPE(nn.Module):
93
+ """RoPE encoding for position-within-word.
94
+
95
+ Dedicates a small subspace of head dimensions to word-internal position,
96
+ using separate (lower) frequency bases. Overrides the last `word_dims`
97
+ of the standard RoPE cos/sin tensors.
98
+ """
99
+
100
+ def __init__(self, word_dims: int, word_base: float = 10.0):
101
+ super().__init__()
102
+ self.word_dims = word_dims
103
+ word_inv_freq = 1.0 / (word_base ** (torch.arange(0, word_dims, 2).float() / word_dims))
104
+ self.register_buffer("word_inv_freq", word_inv_freq)
105
+
106
+ def forward(
107
+ self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor
108
+ ) -> tuple[torch.Tensor, torch.Tensor]:
109
+ """Override last word_dims of cos/sin with word-position-derived values.
110
+
111
+ Args:
112
+ cos, sin: [L, head_dim] from standard RotaryEmbedding
113
+ word_positions: [B, L] float tensor (position within word)
114
+
115
+ Returns:
116
+ cos, sin: [B, L, head_dim] with word dims overridden
117
+ """
118
+ B, L = word_positions.shape
119
+
120
+ # Compute word angles: [B, L, word_dims/2]
121
+ angles = word_positions.unsqueeze(-1) * self.word_inv_freq
122
+ # Duplicate for rotate_half pattern: [B, L, word_dims]
123
+ word_emb = torch.cat([angles, angles], dim=-1)
124
+ word_cos = word_emb.cos()
125
+ word_sin = word_emb.sin()
126
+
127
+ # Expand standard cos/sin to batch dimension: [L, D] -> [B, L, D]
128
+ cos = cos.unsqueeze(0).expand(B, -1, -1).clone()
129
+ sin = sin.unsqueeze(0).expand(B, -1, -1).clone()
130
+
131
+ # Override last word_dims with word-position values
132
+ cos[:, :, -self.word_dims:] = word_cos
133
+ sin[:, :, -self.word_dims:] = word_sin
134
+
135
+ return cos, sin
136
+
137
+
138
+ class RotaryEmbedding(nn.Module):
139
+ """Rotary Position Embedding (RoPE)."""
140
+
141
+ def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
142
+ super().__init__()
143
+ self.dim = dim
144
+ self.max_seq_len = max_seq_len
145
+
146
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
147
+ self.register_buffer("inv_freq", inv_freq)
148
+ self._build_cache(max_seq_len)
149
+
150
+ def _build_cache(self, seq_len: int):
151
+ t = torch.arange(seq_len, device=self.inv_freq.device)
152
+ freqs = torch.outer(t, self.inv_freq)
153
+ emb = torch.cat((freqs, freqs), dim=-1)
154
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
155
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
156
+
157
+ def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
158
+ if seq_len > self.cos_cached.size(0):
159
+ self._build_cache(seq_len)
160
+ return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
161
+
162
+
163
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
164
+ """Rotate half the hidden dims."""
165
+ x1, x2 = x.chunk(2, dim=-1)
166
+ return torch.cat((-x2, x1), dim=-1)
167
+
168
+
169
+ def apply_rotary_pos_emb(
170
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
171
+ ) -> tuple[torch.Tensor, torch.Tensor]:
172
+ """Apply rotary position embedding to queries and keys.
173
+
174
+ Handles both standard [L, D] and batched [B, L, D] cos/sin.
175
+ Q, K shape: [B, H, L, D]. For batched cos/sin, unsqueeze dim 1 for head broadcast.
176
+ """
177
+ if cos.dim() == 3: # [B, L, D] from WordPositionRoPE
178
+ cos = cos.unsqueeze(1) # [B, 1, L, D] — broadcast over heads
179
+ sin = sin.unsqueeze(1)
180
+ q_embed = (q * cos) + (rotate_half(q) * sin)
181
+ k_embed = (k * cos) + (rotate_half(k) * sin)
182
+ return q_embed, k_embed
183
+
184
+
185
+ class CausalAttention(nn.Module):
186
+ """Multi-head attention with causal mask, RoPE, and optional GQA.
187
+
188
+ Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads.
189
+ Each KV head serves (num_heads // num_kv_heads) query heads.
190
+ KV cache stored at kv_heads granularity for memory efficiency.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ hidden_size: int,
196
+ num_heads: int,
197
+ num_kv_heads: int | None = None,
198
+ max_seq_len: int = 2048,
199
+ dropout: float = 0.0,
200
+ window_size: int | None = None,
201
+ word_rope_dims: int = 0,
202
+ word_rope_base: float = 10.0,
203
+ ):
204
+ super().__init__()
205
+ self.hidden_size = hidden_size
206
+ self.num_heads = num_heads
207
+ self.num_kv_heads = num_kv_heads or num_heads
208
+ self.head_dim = hidden_size // num_heads
209
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
210
+ self.dropout = dropout
211
+ self.window_size = window_size
212
+
213
+ assert self.num_heads % self.num_kv_heads == 0, \
214
+ f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
215
+ if word_rope_dims > 0:
216
+ assert word_rope_dims <= self.head_dim, \
217
+ f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})"
218
+ assert word_rope_dims % 2 == 0, \
219
+ f"word_rope_dims ({word_rope_dims}) must be even"
220
+
221
+ self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
222
+ self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
223
+ self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
224
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
225
+
226
+ self.rotary = RotaryEmbedding(self.head_dim, max_seq_len)
227
+
228
+ # Word-position RoPE (optional)
229
+ self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None
230
+
231
+ # Build causal mask (optionally windowed)
232
+ mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
233
+ if window_size is not None:
234
+ # Band mask: position i attends to [max(0, i-window+1), i]
235
+ band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1))
236
+ mask = mask * band
237
+ self.register_buffer(
238
+ "causal_mask",
239
+ mask.view(1, 1, max_seq_len, max_seq_len),
240
+ persistent=False,
241
+ )
242
+
243
+ def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
244
+ """Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads."""
245
+ if self.num_kv_groups == 1:
246
+ return kv
247
+ B, H_kv, L, D = kv.shape
248
+ return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D)
249
+
250
+ def forward(
251
+ self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
252
+ word_positions: torch.Tensor | None = None,
253
+ ) -> tuple[torch.Tensor, tuple | None]:
254
+ B, L, _ = x.shape
255
+
256
+ q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
257
+ k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
258
+ v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
259
+
260
+ # RoPE: use correct position offset for KV-cached generation
261
+ offset = past_kv[0].size(2) if past_kv is not None else 0
262
+ cos, sin = self.rotary(x, offset + L)
263
+ cos = cos[offset:offset + L]
264
+ sin = sin[offset:offset + L]
265
+
266
+ # Override word-position dims if enabled
267
+ if self.word_rope is not None and word_positions is not None:
268
+ cos, sin = self.word_rope(cos, sin, word_positions)
269
+
270
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
271
+
272
+ # KV cache at kv_heads granularity (memory efficient for GQA)
273
+ if past_kv is not None:
274
+ past_k, past_v = past_kv
275
+ k = torch.cat([past_k, k], dim=2)
276
+ v = torch.cat([past_v, v], dim=2)
277
+
278
+ new_kv = (k, v) if use_cache else None
279
+
280
+ dropout_p = self.dropout if self.training else 0.0
281
+ use_gqa = self.num_kv_groups > 1
282
+
283
+ if self.window_size is not None:
284
+ # Windowed attention: manual path (SDPA FlashAttention doesn't support arbitrary masks)
285
+ k_expanded = self._expand_kv(k)
286
+ v_expanded = self._expand_kv(v)
287
+ seq_len = k.size(2)
288
+ attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim)
289
+ if seq_len <= self.causal_mask.size(-1):
290
+ mask = self.causal_mask[:, :, offset:offset + L, :seq_len]
291
+ attn = attn.masked_fill(mask == 0, float("-inf"))
292
+ attn = F.softmax(attn, dim=-1)
293
+ if dropout_p > 0:
294
+ attn = F.dropout(attn, p=dropout_p)
295
+ out = torch.matmul(attn, v_expanded)
296
+ else:
297
+ # SDPA: auto-dispatches to FlashAttention2 / memory-efficient / math backend
298
+ # Native GQA support avoids expanding KV heads (saves memory + enables FlashAttention GQA kernel)
299
+ is_causal = past_kv is None and L > 1
300
+ out = F.scaled_dot_product_attention(
301
+ q, k, v,
302
+ dropout_p=dropout_p,
303
+ is_causal=is_causal,
304
+ enable_gqa=use_gqa,
305
+ )
306
+
307
+ out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)
308
+
309
+ return self.o_proj(out), new_kv
310
+
311
+
312
+ class SwiGLU(nn.Module):
313
+ """SwiGLU feed-forward network."""
314
+
315
+ def __init__(self, hidden_size: int, intermediate_size: int | None = None):
316
+ super().__init__()
317
+ intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
318
+ intermediate_size = ((intermediate_size + 63) // 64) * 64
319
+
320
+ self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
321
+ self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
322
+ self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
323
+
324
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
325
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
lm_eval_wrapper.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LM-eval harness wrapper for Circuit/Mirrored transformers.
3
+
4
+ Usage:
5
+ # Single model
6
+ python -m circuits.bench --checkpoint circuits/checkpoints/mirrored/best.pt --gpu 0
7
+
8
+ # Compare all architectures
9
+ python -m circuits.bench --compare --gpu 0
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import List
16
+ from tqdm import tqdm
17
+ from lm_eval.api.model import LM
18
+ from lm_eval.api.instance import Instance
19
+
20
+ from .config import CircuitConfig
21
+ from .model import CircuitTransformer
22
+ from .mirrored import MirroredConfig, MirroredTransformer
23
+ from .graft_g2lu import load_g2lu_model
24
+ from .layers import build_word_start_table, compute_word_positions
25
+ from .data import get_tokenizer
26
+
27
+ def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
28
+ """Migrate checkpoint state_dict to match current model architecture.
29
+
30
+ Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
31
+ """
32
+ if any(k.startswith("_orig_mod.") for k in state_dict):
33
+ state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
34
+
35
+ model_keys = set(model.state_dict().keys())
36
+ ckpt_keys = set(state_dict.keys())
37
+
38
+ missing = model_keys - ckpt_keys
39
+ unexpected = ckpt_keys - model_keys
40
+
41
+ print(unexpected)
42
+
43
+ if not missing and not unexpected:
44
+ return state_dict # perfect match, no migration needed
45
+
46
+ migrated = dict(state_dict)
47
+ migrations = []
48
+
49
+ # SwiGLU → MirroredSwiGLU: w3 → gate_expand (dual_gate_middle upgrade)
50
+ for key in list(unexpected):
51
+ if ".ffn.gate_expand.weight" in key:
52
+ new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
53
+ if new_key in missing:
54
+ migrated[new_key] = migrated.pop(key)
55
+ missing.discard(new_key)
56
+ unexpected.discard(key)
57
+ migrations.append(f" {key} → {new_key}")
58
+ if ".ffn.gate_compress.weight" in key:
59
+ new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
60
+ if new_key in missing:
61
+ migrated[new_key] = migrated.pop(key)
62
+ missing.discard(new_key)
63
+ unexpected.discard(key)
64
+ migrations.append(f" {key} → {new_key}")
65
+
66
+ if migrations:
67
+ print(f"State dict migration ({len(migrations)} keys renamed):")
68
+ for m in migrations:
69
+ print(m)
70
+ # Report remaining missing keys (freshly initialized)
71
+ still_missing = model_keys - set(migrated.keys())
72
+ if still_missing:
73
+ print(f" New parameters (freshly initialized): {len(still_missing)}")
74
+ for k in sorted(still_missing):
75
+ print(f" {k}")
76
+
77
+ return migrated
78
+
79
+ def load_model(checkpoint_path: str, device: str = "cuda"):
80
+ """Load any circuit model from checkpoint with auto-detection."""
81
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
82
+
83
+ model_type = checkpoint.get("model_type", "standard")
84
+ if model_type == "graft_g2lu":
85
+ model = load_g2lu_model(checkpoint_path, device=device)
86
+ model.eval()
87
+ n_layers = len(model.g2lu_mlps)
88
+ arch_name = f"G²LU Graft ({checkpoint['pretrained_name']}, {n_layers}L)"
89
+ config = model.model.config # HF config
90
+ return model, config, arch_name, model_type
91
+ elif model_type == "mirrored":
92
+ if checkpoint["config"].get("dual_gate_middle"):
93
+ checkpoint["config"].pop("dual_gate_middle")
94
+ config = MirroredConfig.from_dict(checkpoint["config"])
95
+ model = MirroredTransformer(config)
96
+ arch_name = f"Mirrored ({model.total_virtual_layers}L)"
97
+ else:
98
+ config = CircuitConfig.from_dict(checkpoint["config"])
99
+ model = CircuitTransformer(config)
100
+ arch_name = f"Standard ({config.num_layers}L)"
101
+
102
+ # Strip _orig_mod. prefix from torch.compile'd checkpoints
103
+ state_dict = checkpoint["model"]
104
+ state_dict = _migrate_state_dict(state_dict, model)
105
+ model.load_state_dict(state_dict)
106
+
107
+ model = model.to(device).eval()
108
+ return model, config, arch_name, model_type
109
+
110
+
111
+ class CircuitLM(LM):
112
+ """LM-eval wrapper for Circuit transformer family."""
113
+
114
+ def __init__(
115
+ self,
116
+ checkpoint: str,
117
+ device: str = "cuda",
118
+ batch_size: int = 1,
119
+ compile: bool = False,
120
+ ):
121
+ super().__init__()
122
+
123
+ self.model, self.config, self.arch_name, self.model_type = load_model(
124
+ checkpoint, device
125
+ )
126
+ # Keep raw reference for .generate() — torch.compile only wraps forward()
127
+ self._raw_model = self.model
128
+ if compile == True:
129
+ self.model = torch.compile(self.model)
130
+ print(" torch.compile: enabled")
131
+ _ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
132
+ _tok_name = _ckpt.get("tokenizer_name", "gpt2")
133
+ del _ckpt
134
+ self.tokenizer = get_tokenizer(_tok_name)
135
+ if self.tokenizer.pad_token is None:
136
+ self.tokenizer.pad_token = self.tokenizer.eos_token
137
+
138
+ self._device = device
139
+ self._batch_size = batch_size
140
+
141
+ # Build word-position table if model uses SemRoPE
142
+ self._word_start_table = None
143
+ word_rope_dims = getattr(self.config, 'word_rope_dims', 0)
144
+ if word_rope_dims == 0 and isinstance(self.config, dict):
145
+ word_rope_dims = self.config.get('word_rope_dims', 0)
146
+ if word_rope_dims > 0:
147
+ self._word_start_table = build_word_start_table(
148
+ self.tokenizer, len(self.tokenizer)
149
+ ).to(device)
150
+ print(f" Word-position RoPE: {word_rope_dims} dims")
151
+
152
+ # Count parameters
153
+ n_params = sum(p.numel() for p in self.model.parameters())
154
+ print(f" Architecture: {self.arch_name}")
155
+ print(f" Parameters: {n_params / 1e6:.1f}M")
156
+
157
+ @property
158
+ def eot_token_id(self):
159
+ return self.tokenizer.eos_token_id
160
+
161
+ @property
162
+ def max_length(self):
163
+ return getattr(self.config, "max_seq_len", None) or getattr(self.config, "max_position_embeddings", 512)
164
+
165
+ @property
166
+ def max_gen_toks(self):
167
+ return 256
168
+
169
+ @property
170
+ def batch_size(self):
171
+ return self._batch_size
172
+
173
+ @property
174
+ def device(self):
175
+ return self._device
176
+
177
+ def tok_encode(self, string: str) -> List[int]:
178
+ return self.tokenizer.encode(string, add_special_tokens=False)
179
+
180
+ def tok_decode(self, tokens: List[int]) -> str:
181
+ return self.tokenizer.decode(tokens)
182
+
183
+ def _model_call(self, input_ids: torch.Tensor):
184
+ with torch.inference_mode(), torch.autocast('cuda', dtype=torch.bfloat16, enabled=self._device != "cpu"):
185
+ word_positions = None
186
+ if self._word_start_table is not None:
187
+ word_positions = compute_word_positions(input_ids, self._word_start_table)
188
+ output = self.model(input_ids, use_cache=False, word_positions=word_positions)
189
+ return output["logits"]
190
+
191
+ def _loglikelihood_tokens(self, requests, disable_tqdm=False):
192
+ results = []
193
+ for context_enc, continuation_enc in requests:
194
+ # Truncate from the left if too long
195
+ full_enc = context_enc + continuation_enc
196
+ if len(full_enc) > self.max_length:
197
+ excess = len(full_enc) - self.max_length
198
+ context_enc = context_enc[excess:]
199
+ full_enc = context_enc + continuation_enc
200
+
201
+ input_ids = torch.tensor(
202
+ [full_enc], dtype=torch.long, device=self._device
203
+ )
204
+
205
+ logits = self._model_call(input_ids)
206
+
207
+ ctx_len = len(context_enc)
208
+ cont_logits = logits[:, ctx_len - 1 : -1, :]
209
+ cont_tokens = input_ids[:, ctx_len:]
210
+
211
+ log_probs = F.log_softmax(cont_logits, dim=-1)
212
+ token_log_probs = log_probs.gather(
213
+ 2, cont_tokens.unsqueeze(-1)
214
+ ).squeeze(-1)
215
+
216
+ total_log_prob = token_log_probs.sum().item()
217
+ is_greedy = (cont_logits.argmax(dim=-1) == cont_tokens).all().item()
218
+
219
+ results.append((total_log_prob, is_greedy))
220
+
221
+ return results
222
+
223
+ def loglikelihood(
224
+ self, requests: List[Instance], disable_tqdm: bool = False
225
+ ) -> List[tuple]:
226
+ results = []
227
+ for request in tqdm(
228
+ requests, desc="loglikelihood", disable=disable_tqdm
229
+ ):
230
+ context, continuation = request.args
231
+ # Encode full text together to get correct tokenization,
232
+ # then split — sentencepiece tokenizes differently at string
233
+ # boundaries vs mid-sequence (the leading ▁ problem)
234
+ context_enc = self.tok_encode(context)
235
+ full_enc = self.tok_encode(context + continuation)
236
+ continuation_enc = full_enc[len(context_enc):]
237
+ if not continuation_enc:
238
+ # Edge case: continuation was absorbed into context tokens
239
+ # Fall back to encoding continuation separately
240
+ continuation_enc = self.tok_encode(continuation)
241
+ result = self._loglikelihood_tokens([(context_enc, continuation_enc)])
242
+ results.append(result[0])
243
+ return results
244
+
245
+ def loglikelihood_rolling(
246
+ self, requests: List[Instance], disable_tqdm: bool = False
247
+ ) -> List[float]:
248
+ results = []
249
+ for request in tqdm(
250
+ requests, desc="loglikelihood_rolling", disable=disable_tqdm
251
+ ):
252
+ text = request.args[0]
253
+ encoding = self.tok_encode(text)
254
+
255
+ total_log_prob = 0.0
256
+ max_len = self.max_length
257
+
258
+ for i in range(0, len(encoding), max_len):
259
+ chunk = encoding[i : i + max_len]
260
+ input_ids = torch.tensor(
261
+ [chunk], dtype=torch.long, device=self._device
262
+ )
263
+
264
+ logits = self._model_call(input_ids)
265
+ shift_logits = logits[:, :-1, :]
266
+ shift_labels = input_ids[:, 1:]
267
+
268
+ log_probs = F.log_softmax(shift_logits, dim=-1)
269
+ token_log_probs = log_probs.gather(
270
+ 2, shift_labels.unsqueeze(-1)
271
+ ).squeeze(-1)
272
+
273
+ total_log_prob += token_log_probs.sum().item()
274
+
275
+ results.append(total_log_prob)
276
+ return results
277
+
278
+ def generate_until(
279
+ self, requests: List[Instance], disable_tqdm: bool = False
280
+ ) -> List[str]:
281
+ results = []
282
+ for request in tqdm(
283
+ requests, desc="generate_until", disable=disable_tqdm
284
+ ):
285
+ context = request.args[0]
286
+ gen_kwargs = getattr(request, "kwargs", {}) or {}
287
+
288
+ until = gen_kwargs.get("until", [self.tokenizer.eos_token])
289
+ max_gen = gen_kwargs.get("max_gen_toks", self.max_gen_toks)
290
+
291
+ context_enc = self.tok_encode(context)
292
+ # Truncate context from left if needed
293
+ if len(context_enc) > self.max_length - max_gen:
294
+ context_enc = context_enc[-(self.max_length - max_gen) :]
295
+ input_ids = torch.tensor(
296
+ [context_enc], dtype=torch.long, device=self._device
297
+ )
298
+
299
+ if self.model_type == "graft_g2lu":
300
+ # Use HF's native generate with KV caching — much faster than
301
+ # manual token-by-token without cache (O(n) vs O(n²))
302
+ with torch.no_grad():
303
+ output_ids = self._raw_model.generate(
304
+ input_ids,
305
+ max_new_tokens=max_gen,
306
+ do_sample=False,
307
+ use_cache=True,
308
+ )
309
+ generated_text = self.tok_decode(
310
+ output_ids[0, input_ids.shape[1] :].tolist()
311
+ )
312
+ else:
313
+ generated_ids = input_ids.clone()
314
+ with torch.no_grad():
315
+ for _ in range(max_gen):
316
+ # Truncate if we exceed max_length
317
+ if generated_ids.shape[1] > self.max_length:
318
+ generated_ids = generated_ids[:, -self.max_length :]
319
+
320
+ logits = self._model_call(generated_ids)
321
+ next_logits = logits[:, -1, :]
322
+ next_token = next_logits.argmax(dim=-1, keepdim=True)
323
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
324
+
325
+ if next_token.item() == self.eot_token_id:
326
+ break
327
+
328
+ current_text = self.tok_decode(
329
+ generated_ids[0, len(context_enc) :].tolist()
330
+ )
331
+ if any(s in current_text for s in until):
332
+ break
333
+
334
+ generated_text = self.tok_decode(
335
+ generated_ids[0, len(context_enc) :].tolist()
336
+ )
337
+
338
+ for stop in until:
339
+ if stop in generated_text:
340
+ generated_text = generated_text[: generated_text.index(stop)]
341
+
342
+ results.append(generated_text)
343
+
344
+ return results
mirrored.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mirrored Transformer: Weight-sharing between expand and compress phases.
3
+
4
+ Based on the biconcave lens hypothesis from grafting research:
5
+ - Early layers expand from tokens to semantic space
6
+ - Late layers compress from semantic space back to tokens
7
+ - These phases share structural computation (W₁, W₂)
8
+ - Only the gate (semiotic filter) differs by direction
9
+
10
+ Architecture:
11
+ y = W₂ @ (W₁ @ x ⊙ swish(W₃ @ swish(W₄ @ x)))
12
+
13
+ Both gates fire every pass (additive, OR-logic). W₁ computed once.
14
+ W₁, W₂ shared between mirror pairs. W₃, W₄ are dual gates.
15
+ ~33% FFN parameter savings per mirrored pair vs standard SwiGLU.
16
+ """
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import math
22
+ from dataclasses import dataclass, fields
23
+
24
+ from .layers import RMSNorm, CausalAttention, SwiGLU
25
+
26
+
27
+ @dataclass
28
+ class MirroredConfig:
29
+ """Configuration for Mirrored Transformer."""
30
+ vocab_size: int = 50257
31
+ hidden_size: int = 768
32
+ num_heads: int = 12
33
+ num_kv_heads: int | None = None # GQA: None = same as num_heads (MHA)
34
+ num_layers: int = 12 # effective depth (expand + middle + compress)
35
+ n_middle: int = 2 # unique middle layers (standard SwiGLU)
36
+ max_seq_len: int = 512
37
+ dropout: float = 0.0
38
+ aux_skip_k: int = 0 # skip-ahead prediction distance (0 = disabled)
39
+ aux_skip_weight: float = 0.1 # weight for auxiliary skip loss
40
+ use_g2lu: bool = True # G²LU nested gates (False = vanilla SwiGLU)
41
+ word_rope_dims: int = 0 # head dims for word-position RoPE (0 = disabled)
42
+ word_rope_base: float = 10.0 # frequency base for word-position RoPE
43
+ embed_dim: int = 0 # factorized embedding dim (0 = use hidden_size)
44
+ head_dim: int = 0 # MLP head intermediate dim (0 = linear head)
45
+
46
+ def __post_init__(self):
47
+ assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
48
+ if self.num_kv_heads is not None:
49
+ assert self.num_heads % self.num_kv_heads == 0, \
50
+ f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
51
+ n_mirror_layers = self.num_layers - self.n_middle
52
+ assert n_mirror_layers > 0, "num_layers must be greater than n_middle"
53
+ assert n_mirror_layers % 2 == 0, "num_layers - n_middle must be even"
54
+ self.n_mirror = n_mirror_layers // 2
55
+
56
+ def to_dict(self) -> dict:
57
+ """Convert to dictionary for serialization."""
58
+ return {f.name: getattr(self, f.name) for f in fields(self) if f.name != "n_mirror"}
59
+
60
+ @classmethod
61
+ def from_dict(cls, d: dict) -> "MirroredConfig":
62
+ """Create from dictionary."""
63
+ valid = {f.name for f in fields(cls)}
64
+ filtered = {k: v for k, v in d.items() if k in valid}
65
+ return cls(**filtered)
66
+
67
+ class MLP(nn.Module):
68
+ """Feed-forward network with SiLU activation."""
69
+
70
+ def __init__(self, dim, intermediate_size, dropout):
71
+ super().__init__()
72
+ self.up_proj = nn.Linear(dim, intermediate_size, bias=False)
73
+ self.gate_proj = nn.Linear(dim, intermediate_size, bias=False)
74
+ self.down_proj = nn.Linear(intermediate_size, dim, bias=False)
75
+ self.dropout = nn.Dropout(dropout)
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
79
+
80
+ class MirroredSwiGLU(nn.Module):
81
+ """SwiGLU with shared base weights and dual gates.
82
+
83
+ Standard SwiGLU: y = W₂(silu(W₁x) ⊙ W₃x) — 3 matrices
84
+ Mirrored SwiGLU: y = W₂(W₁x ⊙ (silu(W₃ ⊙ silu(W₄x)))) — 2 shared + 2 gates
85
+
86
+ W₁ computed once, reused for both branches.
87
+ """
88
+
89
+ def __init__(self, hidden_size: int, intermediate_size: int | None = None,
90
+ gate_mode: str = 'additive', use_g2lu: bool = True):
91
+ super().__init__()
92
+ self.gate_mode = gate_mode
93
+ self.use_g2lu = use_g2lu
94
+ self._current_step = 0
95
+ intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
96
+ intermediate_size = ((intermediate_size + 63) // 64) * 64
97
+
98
+ # Shared structural transform
99
+ self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
100
+ self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
101
+
102
+ # Gate(s)
103
+ self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
104
+ if use_g2lu:
105
+ self.w4 = nn.Linear(hidden_size, intermediate_size, bias=False)
106
+
107
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
108
+ hidden = self.w1(x)
109
+ if self.use_g2lu:
110
+ g4 = F.silu(self.w4(x))
111
+ g3 = F.silu(self.w3(x) * g4)
112
+ else:
113
+ g3 = F.silu(self.w3(x))
114
+ return self.w2(hidden * g3)
115
+
116
+
117
+ class MirroredBlock(nn.Module):
118
+ """Transformer block with shared weights for expand/compress phases.
119
+
120
+ Each MirroredBlock is used TWICE in the forward pass:
121
+ once during expand (building semantics) and once during compress (encoding output).
122
+
123
+ Shared: attention weights (optional), FFN W₁/W₂
124
+ Separate: norms (different residual stream statistics), FFN gate
125
+ """
126
+
127
+ def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
128
+ max_seq_len: int = 2048,
129
+ dropout: float = 0.0,
130
+ window_size: int | None = None, gate_mode: str = 'additive',
131
+ word_rope_dims: int = 0, word_rope_base: float = 10.0,
132
+ use_g2lu: bool = True):
133
+ super().__init__()
134
+
135
+ self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout, window_size=window_size,
136
+ word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
137
+
138
+ # FFN with shared base + direction-specific gates
139
+ self.ffn = MirroredSwiGLU(hidden_size, gate_mode=gate_mode, use_g2lu=use_g2lu)
140
+
141
+ # Separate norms per direction (residual stream statistics differ)
142
+ self.expand_attn_norm = RMSNorm(hidden_size)
143
+ self.expand_ffn_norm = RMSNorm(hidden_size)
144
+ self.compress_attn_norm = RMSNorm(hidden_size)
145
+ self.compress_ffn_norm = RMSNorm(hidden_size)
146
+
147
+ def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple = None,
148
+ word_positions: torch.Tensor | None = None) -> tuple:
149
+ attn_norm = self.compress_attn_norm
150
+ ffn_norm = self.compress_ffn_norm
151
+ attn = self.attn
152
+
153
+ attn_out, new_kv = attn(attn_norm(x), use_cache, past_kv, word_positions=word_positions)
154
+ x = x + attn_out
155
+ x = x + self.ffn(ffn_norm(x))
156
+
157
+ return x, new_kv
158
+
159
+
160
+ class MiddleBlock(nn.Module):
161
+ """Standard transformer block for unique middle layers.
162
+
163
+ When gate_mode is provided, uses MirroredSwiGLU (dual-gate) instead of
164
+ single-gate SwiGLU — giving the middle the same rich gating geometry
165
+ as the mirror pairs.
166
+ """
167
+
168
+ def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
169
+ max_seq_len: int = 2048,
170
+ dropout: float = 0.0,
171
+ word_rope_dims: int = 0, word_rope_base: float = 10.0,
172
+ use_g2lu: bool = True):
173
+ super().__init__()
174
+ self.attn_norm = RMSNorm(hidden_size)
175
+ self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout,
176
+ word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
177
+ self.ffn_norm = RMSNorm(hidden_size)
178
+ self.ffn = MirroredSwiGLU(hidden_size, use_g2lu=use_g2lu)
179
+
180
+ def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple = None,
181
+ word_positions: torch.Tensor | None = None) -> tuple:
182
+ attn_out, new_kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions)
183
+ x = x + attn_out
184
+ x = x + self.ffn(self.ffn_norm(x))
185
+ return x, new_kv
186
+
187
+
188
+ class MirroredTransformer(nn.Module):
189
+ """Transformer with mirrored expand/compress architecture.
190
+
191
+ Forward pass:
192
+ 1. Embed tokens
193
+ 2. Expand phase: mirror_blocks[0..N] with w3
194
+ 3. Middle: unique standard blocks
195
+ 4. Compress phase: mirror_blocks[N..0] (reversed) with w4
196
+ 5. Norm + LM head
197
+
198
+ For a 12-layer model with n_middle=2:
199
+ - 5 mirror pairs (10 virtual layers) + 2 middle = 12 effective layers
200
+ - Expand: blocks[0] → blocks[4]
201
+ - Middle: middle[0] → middle[1]
202
+ - Compress: blocks[4] → blocks[0]
203
+ """
204
+
205
+ def __init__(self, config: MirroredConfig):
206
+ super().__init__()
207
+ self.config = config
208
+
209
+ # Token embeddings (optionally factorized)
210
+ embed_dim = getattr(config, 'embed_dim', 0)
211
+ head_dim = getattr(config, 'head_dim', 0)
212
+ # Auto-mirror factorization: head uses embed_dim for weight tying
213
+ if embed_dim > 0 and head_dim == 0:
214
+ head_dim = embed_dim
215
+
216
+ # G²LU config (needed before projection setup)
217
+ use_g2lu = getattr(config, 'use_g2lu', True)
218
+
219
+ if embed_dim > 0:
220
+ self.embed = nn.Embedding(config.vocab_size, embed_dim)
221
+ self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
222
+ # G²LU gates for up-projection (consistent with mirror blocks)
223
+ if use_g2lu:
224
+ self.embed_g3 = nn.Linear(embed_dim, config.hidden_size, bias=False)
225
+ self.embed_g4 = nn.Linear(embed_dim, config.hidden_size, bias=False)
226
+ else:
227
+ self.embed_g3 = None
228
+ self.embed_g4 = None
229
+ else:
230
+ self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
231
+ self.embed_proj = None
232
+ self.embed_g3 = None
233
+ self.embed_g4 = None
234
+ self.embed_scale = math.sqrt(config.hidden_size)
235
+ self.window_sizes = [None] * config.n_mirror
236
+
237
+ # Word-position RoPE config
238
+ word_rope_dims = getattr(config, 'word_rope_dims', 0)
239
+ word_rope_base = getattr(config, 'word_rope_base', 10.0)
240
+
241
+ # Mirrored blocks (used in both expand and compress phases)
242
+ self.mirror_blocks = nn.ModuleList([
243
+ MirroredBlock(
244
+ config.hidden_size, config.num_heads, config.num_kv_heads,
245
+ config.max_seq_len,
246
+ config.dropout,
247
+ window_size=self.window_sizes[i],
248
+ word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
249
+ use_g2lu=use_g2lu,
250
+ )
251
+ for i in range(config.n_mirror)
252
+ ])
253
+
254
+ # Unique middle blocks (standard transformer, optionally dual-gated)
255
+ self.middle_blocks = nn.ModuleList([
256
+ MiddleBlock(config.hidden_size, config.num_heads, config.num_kv_heads,
257
+ config.max_seq_len, config.dropout,
258
+ word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
259
+ use_g2lu=use_g2lu)
260
+ for _ in range(config.n_middle)
261
+ ])
262
+
263
+ # Output (optionally MLP head)
264
+ self.norm = RMSNorm(config.hidden_size)
265
+ if head_dim > 0:
266
+ self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
267
+ self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
268
+ # G²LU gates for down-projection
269
+ if use_g2lu:
270
+ self.head_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
271
+ self.head_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
272
+ else:
273
+ self.head_g3 = None
274
+ self.head_g4 = None
275
+ else:
276
+ self.head_down = None
277
+ self.head_g3 = None
278
+ self.head_g4 = None
279
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
280
+
281
+ # Weight tying (when embed and lm_head dimensions match)
282
+ _e = embed_dim if embed_dim > 0 else config.hidden_size
283
+ _h = head_dim if head_dim > 0 else config.hidden_size
284
+ if _e == _h:
285
+ self.lm_head.weight = self.embed.weight
286
+
287
+ # Auxiliary skip-ahead prediction head
288
+ self.skip_head = None
289
+ self.skip_head_down = None
290
+ self.skip_g3 = None
291
+ self.skip_g4 = None
292
+ if config.aux_skip_k > 0:
293
+ if head_dim > 0:
294
+ self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
295
+ self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
296
+ if use_g2lu:
297
+ self.skip_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
298
+ self.skip_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
299
+ else:
300
+ self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
301
+
302
+ # Initialize weights
303
+ self.apply(self._init_weights)
304
+
305
+ def _init_weights(self, module):
306
+ if isinstance(module, nn.Linear):
307
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
308
+ if module.bias is not None:
309
+ torch.nn.init.zeros_(module.bias)
310
+ elif isinstance(module, nn.Embedding):
311
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
312
+
313
+ @property
314
+ def total_virtual_layers(self) -> int:
315
+ """Total number of virtual layers in the forward pass."""
316
+ return self.config.n_mirror * 2 + self.config.n_middle
317
+
318
+ def forward(
319
+ self,
320
+ input_ids: torch.Tensor,
321
+ labels: torch.Tensor = None,
322
+ use_cache: bool = False,
323
+ past_kv: list = None,
324
+ word_positions: torch.Tensor | None = None,
325
+ ) -> dict:
326
+ B, L = input_ids.shape
327
+
328
+ # Embed tokens (optionally factorized, with G²LU gating)
329
+ x = self.embed(input_ids)
330
+ if self.embed_proj is not None:
331
+ if self.embed_g3 is not None:
332
+ g4 = F.silu(self.embed_g4(x))
333
+ g3 = F.silu(self.embed_g3(x) * g4)
334
+ x = self.embed_proj(x) * g3
335
+ else:
336
+ x = F.silu(self.embed_proj(x))
337
+ x = x * self.embed_scale
338
+
339
+ new_kv = [] if use_cache else None
340
+ kv_idx = 0
341
+
342
+ # === Expand phase ===
343
+ for block in self.mirror_blocks:
344
+ layer_past = past_kv[kv_idx] if past_kv is not None else None
345
+ x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
346
+ if use_cache:
347
+ new_kv.append(kv)
348
+ kv_idx += 1
349
+
350
+ # === Dual-path: save pre-middle state for alignment loss ===
351
+ for block in self.middle_blocks:
352
+ layer_past = past_kv[kv_idx] if past_kv is not None else None
353
+ x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
354
+ if use_cache:
355
+ new_kv.append(kv)
356
+ kv_idx += 1
357
+
358
+ # === Compress phase (reversed order) ===
359
+ for i in reversed(range(len(self.mirror_blocks))):
360
+ layer_past = past_kv[kv_idx] if past_kv is not None else None
361
+ x, kv = self.mirror_blocks[i](x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
362
+ if use_cache:
363
+ new_kv.append(kv)
364
+ kv_idx += 1
365
+
366
+ # === Output (optionally MLP head with G²LU gating) ===
367
+ x = self.norm(x)
368
+ if self.head_down is not None:
369
+ if self.head_g3 is not None:
370
+ g4 = F.silu(self.head_g4(x))
371
+ g3 = F.silu(self.head_g3(x) * g4)
372
+ logits = self.lm_head(self.head_down(x) * g3)
373
+ else:
374
+ logits = self.lm_head(F.silu(self.head_down(x)))
375
+ else:
376
+ logits = self.lm_head(x)
377
+
378
+ result = {"logits": logits}
379
+
380
+ if use_cache:
381
+ result["past_kv"] = new_kv
382
+
383
+ if labels is not None:
384
+ shift_logits = logits[:, :-1, :].contiguous()
385
+ shift_labels = labels[:, 1:].contiguous()
386
+ loss = F.cross_entropy(
387
+ shift_logits.view(-1, self.config.vocab_size),
388
+ shift_labels.view(-1),
389
+ ignore_index=-100
390
+ )
391
+
392
+ if self.skip_head is not None:
393
+ skip_k = self.config.aux_skip_k
394
+ if self.skip_head_down is not None:
395
+ if self.skip_g3 is not None:
396
+ g4 = F.silu(self.skip_g4(x))
397
+ g3 = F.silu(self.skip_g3(x) * g4)
398
+ skip_logits = self.skip_head(self.skip_head_down(x) * g3)[:, :-skip_k, :].contiguous()
399
+ else:
400
+ skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
401
+ else:
402
+ skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
403
+ skip_labels = labels[:, skip_k:].contiguous()
404
+ aux_loss = F.cross_entropy(
405
+ skip_logits.view(-1, self.config.vocab_size),
406
+ skip_labels.view(-1),
407
+ ignore_index=-100
408
+ )
409
+ result["aux_loss"] = aux_loss
410
+ loss = loss + self.config.aux_skip_weight * aux_loss
411
+
412
+ result["loss"] = loss
413
+
414
+ return result
415
+
416
+ @torch.no_grad()
417
+ def generate(
418
+ self,
419
+ prompt_ids: torch.Tensor,
420
+ max_new_tokens: int = 50,
421
+ temperature: float = 0.8,
422
+ top_k: int = 50,
423
+ top_p: float = 0.9,
424
+ use_cache: bool = True,
425
+ word_start_table: torch.Tensor | None = None,
426
+ ) -> torch.Tensor:
427
+ """Autoregressive generation with KV caching."""
428
+ from .layers import compute_word_positions
429
+
430
+ self.eval()
431
+ generated = prompt_ids.clone()
432
+ past_kv = None
433
+ word_pos_counter = 0
434
+
435
+ for _ in range(max_new_tokens):
436
+ if use_cache and past_kv is not None:
437
+ input_ids = generated[:, -1:]
438
+ if word_start_table is not None:
439
+ last_token = generated[0, -1].item()
440
+ if word_start_table[last_token]:
441
+ word_pos_counter = 0
442
+ else:
443
+ word_pos_counter += 1
444
+ word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
445
+ else:
446
+ word_positions = None
447
+ else:
448
+ input_ids = generated
449
+ if word_start_table is not None:
450
+ word_positions = compute_word_positions(input_ids, word_start_table)
451
+ else:
452
+ word_positions = None
453
+
454
+ output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
455
+ logits = output["logits"][:, -1, :]
456
+
457
+ if use_cache:
458
+ past_kv = output["past_kv"]
459
+
460
+ if temperature > 0:
461
+ logits = logits / temperature
462
+
463
+ if top_k > 0:
464
+ top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
465
+ min_top_k = top_k_vals[:, -1].unsqueeze(-1)
466
+ logits = torch.where(logits < min_top_k, float("-inf"), logits)
467
+
468
+ if top_p < 1.0:
469
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
470
+ cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
471
+ sorted_indices_to_remove = cumsum_probs > top_p
472
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
473
+ sorted_indices_to_remove[:, 0] = False
474
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
475
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
476
+
477
+ probs = F.softmax(logits, dim=-1)
478
+ next_token = torch.multinomial(probs, num_samples=1)
479
+ else:
480
+ next_token = logits.argmax(dim=-1, keepdim=True)
481
+
482
+ generated = torch.cat([generated, next_token], dim=1)
483
+
484
+ if generated.size(1) >= self.config.max_seq_len:
485
+ break
486
+
487
+ return generated
488
+
489
+
490
+ def count_mirrored_parameters(model: MirroredTransformer) -> dict:
491
+ """Count parameters with breakdown by component."""
492
+ total = sum(p.numel() for p in model.parameters() if p.requires_grad)
493
+
494
+ # Unique params (not double-counted from weight tying)
495
+ unique = sum(p.numel() for p in set(p for p in model.parameters() if p.requires_grad))
496
+
497
+ mirror_params = sum(p.numel() for p in model.mirror_blocks.parameters())
498
+ middle_params = sum(p.numel() for p in model.middle_blocks.parameters())
499
+ embed_params = model.embed.weight.numel()
500
+ if model.embed_proj is not None:
501
+ embed_params += model.embed_proj.weight.numel()
502
+ head_params = 0
503
+ if model.head_down is not None:
504
+ head_params += model.head_down.weight.numel()
505
+ head_params += model.lm_head.weight.numel()
506
+
507
+ # Break down mirror block into shared vs direction-specific
508
+ shared_attn = 0
509
+ shared_ffn_base = 0
510
+ gate_params = 0
511
+ norm_params = 0
512
+
513
+ for block in model.mirror_blocks:
514
+ shared_attn += sum(p.numel() for p in block.attn.parameters())
515
+ shared_ffn_base += block.ffn.w1.weight.numel() + block.ffn.w2.weight.numel()
516
+ gate_params += block.ffn.w3.weight.numel()
517
+ if hasattr(block.ffn, 'w4'):
518
+ gate_params += block.ffn.w4.weight.numel()
519
+ norm_params += sum(p.numel() for n, p in block.named_parameters() if 'norm' in n)
520
+
521
+ return {
522
+ "total": total,
523
+ "unique": unique,
524
+ "mirror_blocks": mirror_params,
525
+ "middle_blocks": middle_params,
526
+ "embedding": embed_params,
527
+ "head": head_params,
528
+ "shared_attention": shared_attn,
529
+ "shared_ffn_base": shared_ffn_base,
530
+ "direction_gates": gate_params,
531
+ "norms": norm_params,
532
+ }
model.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Circuit Transformer: Minimal transformer for semantic circuitry experiments.
3
+
4
+ Follows patterns from shimmer/lira/gpt.py with extension hooks for future work.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import math
11
+
12
+ from .config import CircuitConfig
13
+ from .layers import RMSNorm, RotaryEmbedding, CausalAttention, SwiGLU, WordPositionRoPE
14
+
15
+
16
+ class TransformerBlock(nn.Module):
17
+ """Pre-norm transformer block with causal attention."""
18
+
19
+ def __init__(
20
+ self,
21
+ hidden_size: int,
22
+ num_heads: int,
23
+ num_kv_heads: int | None = None,
24
+ max_seq_len: int = 2048,
25
+ dropout: float = 0.0,
26
+ window_size: int | None = None,
27
+ word_rope_dims: int = 0,
28
+ word_rope_base: float = 10.0,
29
+ ):
30
+ super().__init__()
31
+ self.attn_norm = RMSNorm(hidden_size)
32
+ self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout, window_size,
33
+ word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
34
+ self.ffn_norm = RMSNorm(hidden_size)
35
+ self.ffn = SwiGLU(hidden_size)
36
+
37
+ def forward(
38
+ self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
39
+ word_positions: torch.Tensor | None = None,
40
+ ) -> tuple[torch.Tensor, tuple | None]:
41
+ # Attention with residual
42
+ attn_out, new_kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions)
43
+ x = x + attn_out
44
+
45
+ # FFN with residual
46
+ x = x + self.ffn(self.ffn_norm(x))
47
+
48
+ return x, new_kv
49
+
50
+
51
+ class CircuitTransformer(nn.Module):
52
+ """
53
+ Minimal transformer for semantic circuitry experiments.
54
+
55
+ Features:
56
+ - Standard GPT-style architecture (RMSNorm, RoPE, SwiGLU, causal attention)
57
+ - Weight tying (embed = lm_head)
58
+ - Extension hooks for future work:
59
+ - freeze_layers() / unfreeze_layers() for progressive training
60
+ - get_layer_outputs() for interpretability
61
+ - window_size param for sliding window attention
62
+ """
63
+
64
+ def __init__(self, config: CircuitConfig):
65
+ super().__init__()
66
+ self.config = config
67
+
68
+ # Token embeddings (optionally factorized)
69
+ embed_dim = getattr(config, 'embed_dim', 0)
70
+ head_dim = getattr(config, 'head_dim', 0)
71
+ # Auto-mirror factorization: head uses embed_dim for weight tying
72
+ if embed_dim > 0 and head_dim == 0:
73
+ head_dim = embed_dim
74
+
75
+ if embed_dim > 0:
76
+ self.embed = nn.Embedding(config.vocab_size, embed_dim)
77
+ self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
78
+ else:
79
+ self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
80
+ self.embed_proj = None
81
+ self.embed_scale = math.sqrt(config.hidden_size)
82
+
83
+ # Transformer blocks
84
+ self.layers = nn.ModuleList([
85
+ TransformerBlock(
86
+ config.hidden_size,
87
+ config.num_heads,
88
+ getattr(config, 'num_kv_heads', None),
89
+ config.max_seq_len,
90
+ config.dropout,
91
+ word_rope_dims=getattr(config, 'word_rope_dims', 0),
92
+ word_rope_base=getattr(config, 'word_rope_base', 10.0),
93
+ )
94
+ for _ in range(config.num_layers)
95
+ ])
96
+
97
+ # Output (optionally MLP head)
98
+ self.norm = RMSNorm(config.hidden_size)
99
+ if head_dim > 0:
100
+ self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
101
+ self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
102
+ else:
103
+ self.head_down = None
104
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
105
+
106
+ # Weight tying (when embed and lm_head dimensions match)
107
+ _e = embed_dim if embed_dim > 0 else config.hidden_size
108
+ _h = head_dim if head_dim > 0 else config.hidden_size
109
+ if _e == _h:
110
+ self.lm_head.weight = self.embed.weight
111
+
112
+ # Auxiliary skip-ahead prediction head
113
+ self.skip_head = None
114
+ self.skip_head_down = None
115
+ aux_skip_k = getattr(config, 'aux_skip_k', 0)
116
+ if aux_skip_k > 0:
117
+ if head_dim > 0:
118
+ self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
119
+ self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
120
+ else:
121
+ self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
122
+
123
+ # Track frozen layers
124
+ self._frozen_layers: set[int] = set()
125
+
126
+ # Initialize weights
127
+ self.apply(self._init_weights)
128
+
129
+ def _init_weights(self, module):
130
+ if isinstance(module, nn.Linear):
131
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
132
+ if module.bias is not None:
133
+ torch.nn.init.zeros_(module.bias)
134
+ elif isinstance(module, nn.Embedding):
135
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
136
+
137
+ def forward(
138
+ self,
139
+ input_ids: torch.Tensor,
140
+ labels: torch.Tensor | None = None,
141
+ use_cache: bool = False,
142
+ past_kv: list | None = None,
143
+ word_positions: torch.Tensor | None = None,
144
+ ) -> dict:
145
+ """
146
+ Forward pass.
147
+
148
+ Args:
149
+ input_ids: [B, L] token IDs
150
+ labels: [B, L] target token IDs (for loss computation)
151
+ use_cache: Whether to return KV cache for generation
152
+ past_kv: Previous KV cache
153
+ word_positions: [B, L] position within word (from compute_word_positions)
154
+
155
+ Returns:
156
+ dict with 'logits', optionally 'loss' and 'past_kv'
157
+ """
158
+ B, L = input_ids.shape
159
+
160
+ # Embed tokens (optionally factorized)
161
+ x = self.embed(input_ids)
162
+ if self.embed_proj is not None:
163
+ x = F.silu(self.embed_proj(x))
164
+ x = x * self.embed_scale
165
+
166
+ # Process through layers
167
+ new_kv = [] if use_cache else None
168
+ for i, layer in enumerate(self.layers):
169
+ layer_past = past_kv[i] if past_kv is not None else None
170
+ x, kv = layer(x, use_cache, layer_past, word_positions=word_positions)
171
+ if use_cache:
172
+ new_kv.append(kv)
173
+
174
+ # Output (optionally MLP head)
175
+ x = self.norm(x)
176
+ if self.head_down is not None:
177
+ logits = self.lm_head(F.silu(self.head_down(x)))
178
+ else:
179
+ logits = self.lm_head(x)
180
+
181
+ result = {"logits": logits}
182
+
183
+ if use_cache:
184
+ result["past_kv"] = new_kv
185
+
186
+ # Compute loss if labels provided
187
+ if labels is not None:
188
+ # Shift for next-token prediction
189
+ shift_logits = logits[:, :-1, :].contiguous()
190
+ shift_labels = labels[:, 1:].contiguous()
191
+ loss = F.cross_entropy(
192
+ shift_logits.view(-1, self.config.vocab_size),
193
+ shift_labels.view(-1),
194
+ ignore_index=-100,
195
+ )
196
+
197
+ # Auxiliary skip-ahead prediction
198
+ if self.skip_head is not None:
199
+ skip_k = getattr(self.config, 'aux_skip_k', 0)
200
+ skip_weight = getattr(self.config, 'aux_skip_weight', 0.1)
201
+ if self.skip_head_down is not None:
202
+ skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
203
+ else:
204
+ skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
205
+ skip_labels = labels[:, skip_k:].contiguous()
206
+ aux_loss = F.cross_entropy(
207
+ skip_logits.view(-1, self.config.vocab_size),
208
+ skip_labels.view(-1),
209
+ ignore_index=-100,
210
+ )
211
+ result["aux_loss"] = aux_loss
212
+ loss = loss + skip_weight * aux_loss
213
+
214
+ result["loss"] = loss
215
+
216
+ return result
217
+
218
+ # === Extension hooks for future experiments ===
219
+
220
+ def freeze_layers(self, indices: list[int]) -> None:
221
+ """Freeze specific layers (stop gradients)."""
222
+ for idx in indices:
223
+ if 0 <= idx < len(self.layers):
224
+ for param in self.layers[idx].parameters():
225
+ param.requires_grad = False
226
+ self._frozen_layers.add(idx)
227
+
228
+ def unfreeze_layers(self, indices: list[int] | None = None) -> None:
229
+ """Unfreeze specific layers (or all if indices=None)."""
230
+ if indices is None:
231
+ indices = list(self._frozen_layers)
232
+ for idx in indices:
233
+ if 0 <= idx < len(self.layers):
234
+ for param in self.layers[idx].parameters():
235
+ param.requires_grad = True
236
+ self._frozen_layers.discard(idx)
237
+
238
+ def get_layer_outputs(self, input_ids: torch.Tensor) -> list[torch.Tensor]:
239
+ """Get intermediate outputs from each layer for interpretability."""
240
+ outputs = []
241
+ x = self.embed(input_ids)
242
+ if self.embed_proj is not None:
243
+ x = F.silu(self.embed_proj(x))
244
+ x = x * self.embed_scale
245
+
246
+ for layer in self.layers:
247
+ x, _ = layer(x, use_cache=False, past_kv=None)
248
+ outputs.append(x.clone())
249
+
250
+ return outputs
251
+
252
+ @torch.no_grad()
253
+ def generate(
254
+ self,
255
+ prompt_ids: torch.Tensor,
256
+ max_new_tokens: int = 50,
257
+ temperature: float = 0.8,
258
+ top_k: int = 50,
259
+ top_p: float = 0.9,
260
+ use_cache: bool = True,
261
+ word_start_table: torch.Tensor | None = None,
262
+ ) -> torch.Tensor:
263
+ """
264
+ Autoregressive generation with KV caching.
265
+
266
+ Args:
267
+ prompt_ids: [B, L] prompt token IDs
268
+ max_new_tokens: Maximum tokens to generate
269
+ temperature: Sampling temperature
270
+ top_k: Top-k filtering
271
+ top_p: Nucleus sampling threshold
272
+ use_cache: Use KV cache for faster generation
273
+ word_start_table: [vocab_size] bool tensor for word-position RoPE
274
+
275
+ Returns:
276
+ [B, L + max_new_tokens] generated token IDs
277
+ """
278
+ from .layers import compute_word_positions
279
+
280
+ self.eval()
281
+ generated = prompt_ids.clone()
282
+ past_kv = None
283
+ word_pos_counter = 0 # Track word position during cached generation
284
+
285
+ for _ in range(max_new_tokens):
286
+ # Get input (full sequence or just last token with cache)
287
+ if use_cache and past_kv is not None:
288
+ input_ids = generated[:, -1:]
289
+ # Compute word position for the single new token
290
+ if word_start_table is not None:
291
+ last_token = generated[0, -1].item()
292
+ if word_start_table[last_token]:
293
+ word_pos_counter = 0
294
+ else:
295
+ word_pos_counter += 1
296
+ word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
297
+ else:
298
+ word_positions = None
299
+ else:
300
+ input_ids = generated
301
+ # Compute word positions for full sequence
302
+ if word_start_table is not None:
303
+ word_positions = compute_word_positions(input_ids, word_start_table)
304
+ else:
305
+ word_positions = None
306
+
307
+ # Forward pass
308
+ output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
309
+ logits = output["logits"][:, -1, :] # Last position
310
+
311
+ if use_cache:
312
+ past_kv = output["past_kv"]
313
+
314
+ # Apply temperature
315
+ if temperature > 0:
316
+ logits = logits / temperature
317
+
318
+ # Top-k filtering
319
+ if top_k > 0:
320
+ top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
321
+ min_top_k = top_k_vals[:, -1].unsqueeze(-1)
322
+ logits = torch.where(logits < min_top_k, float("-inf"), logits)
323
+
324
+ # Top-p (nucleus) filtering
325
+ if top_p < 1.0:
326
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
327
+ cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
328
+
329
+ # Remove tokens with cumulative prob above threshold
330
+ sorted_indices_to_remove = cumsum_probs > top_p
331
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
332
+ sorted_indices_to_remove[:, 0] = False
333
+
334
+ indices_to_remove = sorted_indices_to_remove.scatter(
335
+ 1, sorted_indices, sorted_indices_to_remove
336
+ )
337
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
338
+
339
+ # Sample
340
+ probs = F.softmax(logits, dim=-1)
341
+ next_token = torch.multinomial(probs, num_samples=1)
342
+ else:
343
+ # Greedy
344
+ next_token = logits.argmax(dim=-1, keepdim=True)
345
+
346
+ generated = torch.cat([generated, next_token], dim=1)
347
+
348
+ # Stop if max length reached
349
+ if generated.size(1) >= self.config.max_seq_len:
350
+ break
351
+
352
+ return generated
353
+
354
+
355
+ def count_parameters(model: CircuitTransformer) -> int:
356
+ """Count trainable parameters."""
357
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
scripts/__init__.py ADDED
File without changes
scripts/representation_analysis.py ADDED
@@ -0,0 +1,1014 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Representation analysis: CKA and Logit Lens for Prisma / Circuit Transformer.
4
+
5
+ CKA (Centered Kernel Alignment):
6
+ Measures representational similarity between all layer pairs.
7
+ Produces a heatmap revealing mirror symmetry, phase transitions,
8
+ and cross-model alignment.
9
+
10
+ Logit Lens:
11
+ Projects intermediate representations to vocabulary space at every layer.
12
+ Reveals what the model "thinks" at each processing stage -- from raw
13
+ tokens through the semantic bottleneck back to specific predictions.
14
+
15
+ Also computes representation drift (cosine similarity between consecutive layers).
16
+
17
+ Usage:
18
+ # Full analysis (CKA + logit lens)
19
+ python -m circuits.scripts.representation_analysis \\
20
+ --checkpoint path/to/checkpoint.pt \\
21
+ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train
22
+
23
+ # Cross-model CKA
24
+ python -m circuits.scripts.representation_analysis \\
25
+ --checkpoint path/to/prisma.pt --hf-model gpt2-medium \\
26
+ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train
27
+
28
+ # CKA only (skip logit lens)
29
+ python -m circuits.scripts.representation_analysis \\
30
+ --checkpoint path/to/checkpoint.pt \\
31
+ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train \\
32
+ --no-logit-lens
33
+ """
34
+
35
+ import argparse
36
+ import json
37
+ import sys
38
+ import os
39
+ from pathlib import Path
40
+ from collections import OrderedDict
41
+
42
+ import numpy as np
43
+ import torch
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+
47
+ import matplotlib
48
+ matplotlib.use("Agg")
49
+ import matplotlib.pyplot as plt
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Model loading
54
+ # ---------------------------------------------------------------------------
55
+
56
+ def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
57
+ """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type)."""
58
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
59
+ from circuits.config import CircuitConfig
60
+ from circuits.model import CircuitTransformer
61
+ from circuits.mirrored import MirroredConfig, MirroredTransformer
62
+
63
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
64
+ model_type = ckpt.get("model_type", "standard")
65
+ config_dict = ckpt.get("config", {})
66
+
67
+ if model_type == "mirrored":
68
+ if config_dict.get("dual_gate_middle"):
69
+ config_dict.pop("dual_gate_middle")
70
+ config = MirroredConfig.from_dict(config_dict)
71
+ model = MirroredTransformer(config)
72
+ else:
73
+ config = CircuitConfig.from_dict(config_dict)
74
+ model = CircuitTransformer(config)
75
+
76
+ state_dict = ckpt["model"]
77
+ if any(k.startswith("_orig_mod.") for k in state_dict):
78
+ state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
79
+ model.load_state_dict(state_dict, strict=False)
80
+ model.to(device).eval()
81
+
82
+ return model, config_dict, model_type
83
+
84
+
85
+ def load_hf_model(model_name: str, device: str = "cpu"):
86
+ """Load a HuggingFace causal LM."""
87
+ from transformers import AutoModelForCausalLM
88
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True)
89
+ model.to(device).eval()
90
+ return model
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Data loading
95
+ # ---------------------------------------------------------------------------
96
+
97
+ def load_data(data_source: str, tokenizer_name: str, num_samples: int = 32,
98
+ context_length: int = 512, device: str = "cpu"):
99
+ """Load tokenized data. Returns (input_ids, tokenizer).
100
+
101
+ Supports:
102
+ - Memmap .bin files (from circuits training cache)
103
+ - hf:dataset:config:split (streaming from HuggingFace)
104
+ - Plain text files
105
+ """
106
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
107
+ from circuits.data import get_tokenizer
108
+
109
+ tokenizer = get_tokenizer(tokenizer_name)
110
+
111
+ # Memmap binary file (already tokenized)
112
+ if data_source.endswith(".bin"):
113
+ import struct
114
+ with open(data_source, 'rb') as f:
115
+ n_chunks, seq_len = struct.unpack('II', f.read(8))
116
+ data = np.memmap(data_source, dtype=np.int32, mode='r',
117
+ offset=8, shape=(n_chunks, seq_len))
118
+ n = min(num_samples, n_chunks)
119
+ # Slice to requested context length
120
+ cl = min(context_length, seq_len)
121
+ input_ids = torch.from_numpy(data[:n, :cl].copy()).long().to(device)
122
+ return input_ids, tokenizer
123
+
124
+ # HuggingFace dataset
125
+ if data_source.startswith("hf:"):
126
+ from datasets import load_dataset
127
+ parts = data_source[3:].split(":")
128
+ ds_name = parts[0]
129
+ ds_config = parts[1] if len(parts) > 1 else None
130
+ ds_split = parts[2] if len(parts) > 2 else "train"
131
+ dataset = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
132
+ all_ids = []
133
+ for item in dataset:
134
+ text = item.get("text", "")
135
+ if len(text) < 100:
136
+ continue
137
+ ids = tokenizer.encode(text)
138
+ if len(ids) >= context_length:
139
+ all_ids.append(ids[:context_length])
140
+ if len(all_ids) >= num_samples:
141
+ break
142
+ if not all_ids:
143
+ return None, tokenizer
144
+ return torch.tensor(all_ids, device=device), tokenizer
145
+
146
+ # Plain text file
147
+ with open(data_source) as f:
148
+ texts = [line.strip() for line in f if len(line.strip()) > 100]
149
+ all_ids = []
150
+ for text in texts:
151
+ ids = tokenizer.encode(text)
152
+ if len(ids) >= context_length:
153
+ all_ids.append(ids[:context_length])
154
+ if len(all_ids) >= num_samples:
155
+ break
156
+ if not all_ids:
157
+ return None, tokenizer
158
+ return torch.tensor(all_ids, device=device), tokenizer
159
+
160
+
161
+ def tokenize_for_hf(texts: list, model_name: str, context_length: int = 512,
162
+ device: str = "cpu"):
163
+ """Tokenize texts for an HF model. Returns (input_ids, tokenizer)."""
164
+ from transformers import AutoTokenizer
165
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
166
+ use_fast=False,
167
+ trust_remote_code=True)
168
+ if tokenizer.pad_token is None:
169
+ tokenizer.pad_token = tokenizer.eos_token
170
+
171
+ all_ids = []
172
+ for text in texts:
173
+ ids = tokenizer.encode(text, max_length=context_length, truncation=True)
174
+ if len(ids) >= context_length:
175
+ all_ids.append(ids[:context_length])
176
+ elif len(ids) > 32:
177
+ all_ids.append(ids + [tokenizer.eos_token_id] * (context_length - len(ids)))
178
+
179
+ if not all_ids:
180
+ return None, tokenizer
181
+
182
+ return torch.tensor(all_ids, device=device), tokenizer
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Activation collection
187
+ # ---------------------------------------------------------------------------
188
+
189
+ def collect_mirrored_activations(model, input_ids, word_positions=None):
190
+ """Collect activations from MirroredTransformer at every processing stage."""
191
+ activations = OrderedDict()
192
+
193
+ with torch.no_grad():
194
+ x = model.embed(input_ids)
195
+ if model.embed_proj is not None:
196
+ if model.embed_g3 is not None:
197
+ g4 = F.silu(model.embed_g4(x))
198
+ g3 = F.silu(model.embed_g3(x) * g4)
199
+ x = model.embed_proj(x) * g3
200
+ else:
201
+ x = F.silu(model.embed_proj(x))
202
+ x = x * model.embed_scale
203
+ activations["embedding"] = x.detach().cpu()
204
+
205
+ for i, block in enumerate(model.mirror_blocks):
206
+ x, _ = block(x, word_positions=word_positions)
207
+ activations[f"expand_{i}"] = x.detach().cpu()
208
+
209
+ for i, block in enumerate(model.middle_blocks):
210
+ x, _ = block(x, word_positions=word_positions)
211
+ activations[f"middle_{i}"] = x.detach().cpu()
212
+
213
+ for i in reversed(range(len(model.mirror_blocks))):
214
+ x, _ = model.mirror_blocks[i](x, word_positions=word_positions)
215
+ compress_idx = len(model.mirror_blocks) - 1 - i
216
+ activations[f"compress_{compress_idx}"] = x.detach().cpu()
217
+
218
+ x = model.norm(x)
219
+ activations["final_norm"] = x.detach().cpu()
220
+
221
+ return activations
222
+
223
+
224
+ def collect_standard_activations(model, input_ids, word_positions=None):
225
+ """Collect activations from standard CircuitTransformer."""
226
+ activations = OrderedDict()
227
+
228
+ with torch.no_grad():
229
+ x = model.embed(input_ids)
230
+ if model.embed_proj is not None:
231
+ x = F.silu(model.embed_proj(x))
232
+ x = x * model.embed_scale
233
+ activations["embedding"] = x.detach().cpu()
234
+
235
+ for i, layer in enumerate(model.layers):
236
+ x, _ = layer(x, word_positions=word_positions)
237
+ activations[f"layer_{i}"] = x.detach().cpu()
238
+
239
+ x = model.norm(x)
240
+ activations["final_norm"] = x.detach().cpu()
241
+
242
+ return activations
243
+
244
+
245
+ def collect_hf_activations(model, input_ids):
246
+ """Hook-based activation collection for HuggingFace models."""
247
+ activations = OrderedDict()
248
+ hooks = []
249
+
250
+ if hasattr(model, 'transformer'):
251
+ # GPT-2 style
252
+ blocks = model.transformer.h
253
+ embed = model.transformer.wte
254
+ final_norm = model.transformer.ln_f
255
+ elif hasattr(model, 'model'):
256
+ # Llama / Mistral style
257
+ blocks = model.model.layers
258
+ embed = model.model.embed_tokens
259
+ final_norm = model.model.norm
260
+ else:
261
+ raise ValueError(f"Unsupported HF model: {type(model)}")
262
+
263
+ def make_hook(name):
264
+ def hook_fn(module, input, output):
265
+ out = output[0] if isinstance(output, tuple) else output
266
+ activations[name] = out.detach().cpu()
267
+ return hook_fn
268
+
269
+ hooks.append(embed.register_forward_hook(make_hook("embedding")))
270
+ for i, block in enumerate(blocks):
271
+ hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))
272
+ hooks.append(final_norm.register_forward_hook(make_hook("final_norm")))
273
+
274
+ with torch.no_grad():
275
+ model(input_ids)
276
+
277
+ for h in hooks:
278
+ h.remove()
279
+
280
+ return activations
281
+
282
+
283
+ def collect_activations(model, model_type, config_dict, input_ids, device):
284
+ """Dispatch to the right collector based on model type."""
285
+ word_positions = None
286
+ word_rope_dims = config_dict.get("word_rope_dims", 0) if config_dict else 0
287
+
288
+ if word_rope_dims > 0 and model_type in ("standard", "mirrored"):
289
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
290
+ from circuits.data import get_tokenizer
291
+ from circuits.layers import build_word_start_table, compute_word_positions
292
+ tokenizer_name = config_dict.get("tokenizer_name", "gpt2")
293
+ # Try to get tokenizer from the model's config
294
+ tokenizer = get_tokenizer(tokenizer_name)
295
+ word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
296
+ word_positions = compute_word_positions(input_ids, word_start_table)
297
+
298
+ if model_type == "mirrored":
299
+ return collect_mirrored_activations(model, input_ids, word_positions)
300
+ elif model_type == "standard":
301
+ return collect_standard_activations(model, input_ids, word_positions)
302
+ else:
303
+ return collect_hf_activations(model, input_ids)
304
+
305
+
306
+ # ---------------------------------------------------------------------------
307
+ # Linear CKA
308
+ # ---------------------------------------------------------------------------
309
+
310
+ def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
311
+ """Compute linear CKA between two [N, D] representation matrices.
312
+
313
+ CKA(X, Y) = ||Yc^T Xc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)
314
+ """
315
+ X = X.float()
316
+ Y = Y.float()
317
+
318
+ # Center
319
+ X = X - X.mean(0, keepdim=True)
320
+ Y = Y - Y.mean(0, keepdim=True)
321
+
322
+ N = X.shape[0]
323
+
324
+ if N < min(X.shape[1], Y.shape[1]):
325
+ # Kernel formulation (N < D): K=XX^T, L=YY^T — [N,N] matrices
326
+ K = X @ X.T
327
+ L = Y @ Y.T
328
+ numerator = (K * L).sum()
329
+ denominator = torch.sqrt((K * K).sum() * (L * L).sum())
330
+ else:
331
+ # Feature formulation (D <= N)
332
+ XtY = X.T @ Y
333
+ XtX = X.T @ X
334
+ YtY = Y.T @ Y
335
+ numerator = (XtY * XtY).sum()
336
+ denominator = torch.sqrt((XtX * XtX).sum() * (YtY * YtY).sum())
337
+
338
+ if denominator < 1e-10:
339
+ return 0.0
340
+
341
+ return (numerator / denominator).item()
342
+
343
+
344
+ def compute_cka_matrix(activations: dict, subsample: int = 4) -> tuple:
345
+ """Compute CKA between all layer pairs. Returns (cka_matrix, layer_names)."""
346
+ names = list(activations.keys())
347
+ n_layers = len(names)
348
+
349
+ # Flatten and subsample: [B, L, D] -> [N, D]
350
+ flat_acts = {}
351
+ for name, act in activations.items():
352
+ act_sub = act[:, ::subsample, :]
353
+ flat_acts[name] = act_sub.reshape(-1, act_sub.shape[-1])
354
+
355
+ cka_matrix = np.zeros((n_layers, n_layers))
356
+
357
+ for i in range(n_layers):
358
+ cka_matrix[i, i] = 1.0
359
+ for j in range(i + 1, n_layers):
360
+ cka_val = linear_cka(flat_acts[names[i]], flat_acts[names[j]])
361
+ cka_matrix[i, j] = cka_val
362
+ cka_matrix[j, i] = cka_val
363
+ if (i + 1) % 5 == 0 or i == n_layers - 1:
364
+ print(f" CKA: {i+1}/{n_layers} rows computed")
365
+
366
+ return cka_matrix, names
367
+
368
+
369
+ def compute_cross_model_cka(acts_a: dict, acts_b: dict) -> tuple:
370
+ """Cross-model CKA using sample-level (avg-pooled) representations."""
371
+ names_a = list(acts_a.keys())
372
+ names_b = list(acts_b.keys())
373
+
374
+ def pool(activations):
375
+ return {name: act.mean(dim=1) for name, act in activations.items()}
376
+
377
+ pooled_a = pool(acts_a)
378
+ pooled_b = pool(acts_b)
379
+
380
+ # Ensure same number of samples
381
+ n_samples = min(
382
+ next(iter(pooled_a.values())).shape[0],
383
+ next(iter(pooled_b.values())).shape[0]
384
+ )
385
+
386
+ cka_matrix = np.zeros((len(names_a), len(names_b)))
387
+
388
+ for i, na in enumerate(names_a):
389
+ for j, nb in enumerate(names_b):
390
+ cka_matrix[i, j] = linear_cka(pooled_a[na][:n_samples], pooled_b[nb][:n_samples])
391
+ if (i + 1) % 5 == 0 or i == len(names_a) - 1:
392
+ print(f" Cross-CKA: {i+1}/{len(names_a)} rows computed")
393
+
394
+ return cka_matrix, names_a, names_b
395
+
396
+
397
+ # ---------------------------------------------------------------------------
398
+ # Logit Lens
399
+ # ---------------------------------------------------------------------------
400
+
401
+ def get_unembed_components(model, model_type):
402
+ """Extract (norm_module, unembed_weight) for logit lens projection."""
403
+ if model_type in ("standard", "mirrored"):
404
+ return model.norm, model.embed.weight
405
+ elif hasattr(model, 'transformer'):
406
+ return model.transformer.ln_f, model.transformer.wte.weight
407
+ elif hasattr(model, 'model'):
408
+ return model.model.norm, model.model.embed_tokens.weight
409
+ else:
410
+ raise ValueError(f"Unsupported model: {type(model)}")
411
+
412
+
413
+ def compute_logit_lens(activations: dict, norm: nn.Module, unembed_weight: torch.Tensor,
414
+ labels: torch.Tensor, device: str = "cpu",
415
+ chunk_size: int = 2048) -> OrderedDict:
416
+ """Compute logit lens statistics at every layer.
417
+
418
+ Projects intermediate hidden states through final norm + unembedding.
419
+ Computes entropy, top-1 probability, correct token rank, and
420
+ agreement with the final layer's predictions.
421
+
422
+ Args:
423
+ activations: OrderedDict[name] = [B, L, D]
424
+ norm: final layer norm module
425
+ unembed_weight: [V, D] unembedding matrix
426
+ labels: [B, L-1] next-token labels (input_ids[:, 1:])
427
+ device: computation device
428
+ chunk_size: number of positions per batch for projection
429
+
430
+ Returns:
431
+ OrderedDict[name] = {entropy, top1_prob, correct_rank, ...}
432
+ """
433
+ names = list(activations.keys())
434
+ final_name = names[-1] # "final_norm"
435
+ results = OrderedDict()
436
+
437
+ unembed = unembed_weight.to(device)
438
+ norm_mod = norm.to(device)
439
+ labels_flat = labels.reshape(-1).to(device)
440
+
441
+ def process_layer(name, act, apply_norm=True):
442
+ """Project one layer's activations and compute all metrics."""
443
+ B, L, D = act.shape
444
+ flat = act[:, :-1, :].reshape(-1, D) # [B*(L-1), D]
445
+ N = flat.shape[0]
446
+
447
+ all_entropy = []
448
+ all_top1_prob = []
449
+ all_correct_rank = []
450
+ all_top1_idx = []
451
+
452
+ for start in range(0, N, chunk_size):
453
+ end = min(start + chunk_size, N)
454
+ chunk = flat[start:end].to(device)
455
+ chunk_labels = labels_flat[start:end]
456
+
457
+ if apply_norm:
458
+ chunk = norm_mod(chunk)
459
+
460
+ logits = chunk @ unembed.T # [cs, V]
461
+ log_probs = F.log_softmax(logits, dim=-1)
462
+ probs = log_probs.exp()
463
+
464
+ # Entropy
465
+ entropy = -(probs * log_probs).sum(dim=-1)
466
+ all_entropy.append(entropy.cpu())
467
+
468
+ # Top-1 probability
469
+ top1_prob = probs.max(dim=-1).values
470
+ all_top1_prob.append(top1_prob.cpu())
471
+
472
+ # Correct token rank
473
+ correct_logits = logits.gather(1, chunk_labels.unsqueeze(1))
474
+ rank = (logits > correct_logits).sum(dim=-1) + 1
475
+ all_correct_rank.append(rank.cpu())
476
+
477
+ # Top-1 index
478
+ all_top1_idx.append(logits.argmax(dim=-1).cpu())
479
+
480
+ entropy_t = torch.cat(all_entropy)
481
+ top1_t = torch.cat(all_top1_prob)
482
+ rank_t = torch.cat(all_correct_rank).float()
483
+ top1_idx = torch.cat(all_top1_idx)
484
+
485
+ return {
486
+ "entropy": entropy_t.mean().item(),
487
+ "entropy_std": entropy_t.std().item(),
488
+ "top1_prob": top1_t.mean().item(),
489
+ "correct_rank_mean": rank_t.mean().item(),
490
+ "correct_rank_median": rank_t.median().item(),
491
+ "log_rank_mean": rank_t.log().mean().item(),
492
+ "_top1_idx": top1_idx,
493
+ }
494
+
495
+ # Process all layers
496
+ for name in names:
497
+ is_final = (name == final_name)
498
+ act = activations[name]
499
+ stats = process_layer(name, act, apply_norm=not is_final)
500
+ results[name] = stats
501
+ print(f" Logit lens: {name:20s} entropy={stats['entropy']:.2f} "
502
+ f"top1={stats['top1_prob']:.4f} rank={stats['correct_rank_median']:.0f}")
503
+
504
+ # Compute agreement with final layer
505
+ final_top1 = results[final_name]["_top1_idx"]
506
+ for name in names:
507
+ layer_top1 = results[name]["_top1_idx"]
508
+ agreement = (layer_top1 == final_top1).float().mean().item()
509
+ results[name]["agreement_with_final"] = agreement
510
+
511
+ # Clean up internal tensors
512
+ for name in names:
513
+ del results[name]["_top1_idx"]
514
+
515
+ return results
516
+
517
+
518
+ # ---------------------------------------------------------------------------
519
+ # Representation drift
520
+ # ---------------------------------------------------------------------------
521
+
522
+ def compute_drift(activations: dict) -> OrderedDict:
523
+ """Cosine similarity between consecutive layers' representations."""
524
+ names = list(activations.keys())
525
+ drift = OrderedDict()
526
+
527
+ for i in range(1, len(names)):
528
+ prev = activations[names[i - 1]]
529
+ curr = activations[names[i]]
530
+
531
+ # Flatten to [N, D]
532
+ prev_flat = prev.reshape(-1, prev.shape[-1])
533
+ curr_flat = curr.reshape(-1, curr.shape[-1])
534
+
535
+ # Mean cosine similarity
536
+ cos = F.cosine_similarity(prev_flat, curr_flat, dim=-1)
537
+ drift[names[i]] = {
538
+ "cos_sim_mean": cos.mean().item(),
539
+ "cos_sim_std": cos.std().item(),
540
+ "l2_distance": (curr_flat - prev_flat).norm(dim=-1).mean().item(),
541
+ }
542
+
543
+ return drift
544
+
545
+
546
+ # ---------------------------------------------------------------------------
547
+ # Plotting
548
+ # ---------------------------------------------------------------------------
549
+
550
+ def _phase_color(name):
551
+ """Return color based on layer phase."""
552
+ if "expand" in name:
553
+ return "steelblue"
554
+ elif "middle" in name:
555
+ return "goldenrod"
556
+ elif "compress" in name:
557
+ return "coral"
558
+ elif "embedding" in name:
559
+ return "gray"
560
+ elif "final" in name:
561
+ return "gray"
562
+ else:
563
+ return "mediumpurple"
564
+
565
+
566
+ def _layer_sort_key(name):
567
+ """Sort key for processing order."""
568
+ order = {"embedding": -1, "final_norm": 9999}
569
+ if name in order:
570
+ return order[name]
571
+ parts = name.split("_")
572
+ phase = parts[0]
573
+ idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
574
+ phase_offset = {"expand": 0, "middle": 1000, "compress": 2000, "layer": 0}
575
+ return phase_offset.get(phase, 3000) + idx
576
+
577
+
578
+ def _short_name(name):
579
+ """Shorten layer name for plot labels."""
580
+ if name == "embedding":
581
+ return "emb"
582
+ if name == "final_norm":
583
+ return "out"
584
+ parts = name.split("_")
585
+ if parts[0] == "expand":
586
+ return f"E{parts[1]}"
587
+ elif parts[0] == "middle":
588
+ return f"M{parts[1]}"
589
+ elif parts[0] == "compress":
590
+ return f"C{parts[1]}"
591
+ elif parts[0] == "layer":
592
+ return f"L{parts[1]}"
593
+ return name[:6]
594
+
595
+
596
+ def plot_cka_self(cka_matrix: np.ndarray, names: list, output_dir: Path,
597
+ model_label: str):
598
+ """Plot self-CKA heatmap."""
599
+ n = len(names)
600
+ short = [_short_name(n) for n in names]
601
+
602
+ fig, ax = plt.subplots(figsize=(max(10, n * 0.35), max(8, n * 0.3)))
603
+ fig.suptitle(f"{model_label} -- CKA Self-Similarity", fontsize=14)
604
+
605
+ im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="equal")
606
+
607
+ # Phase separators
608
+ for i, name in enumerate(names):
609
+ if i > 0:
610
+ prev = names[i - 1].split("_")[0]
611
+ curr = name.split("_")[0]
612
+ if prev != curr:
613
+ ax.axhline(i - 0.5, color="white", linewidth=1.5, alpha=0.8)
614
+ ax.axvline(i - 0.5, color="white", linewidth=1.5, alpha=0.8)
615
+
616
+ ax.set_xticks(range(n))
617
+ ax.set_xticklabels(short, rotation=90, fontsize=7)
618
+ ax.set_yticks(range(n))
619
+ ax.set_yticklabels(short, fontsize=7)
620
+
621
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA")
622
+ plt.tight_layout()
623
+ fig.savefig(output_dir / "cka_self.png", dpi=150)
624
+ plt.close(fig)
625
+
626
+
627
+ def plot_cka_cross(cka_matrix: np.ndarray, names_a: list, names_b: list,
628
+ output_dir: Path, label_a: str, label_b: str):
629
+ """Plot cross-model CKA heatmap."""
630
+ short_a = [_short_name(n) for n in names_a]
631
+ short_b = [_short_name(n) for n in names_b]
632
+
633
+ na, nb = len(names_a), len(names_b)
634
+ fig, ax = plt.subplots(figsize=(max(10, nb * 0.35), max(8, na * 0.3)))
635
+ fig.suptitle(f"Cross-CKA: {label_a} vs {label_b}", fontsize=14)
636
+
637
+ im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="auto")
638
+
639
+ ax.set_xticks(range(nb))
640
+ ax.set_xticklabels(short_b, rotation=90, fontsize=7)
641
+ ax.set_xlabel(label_b)
642
+ ax.set_yticks(range(na))
643
+ ax.set_yticklabels(short_a, fontsize=7)
644
+ ax.set_ylabel(label_a)
645
+
646
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA")
647
+ plt.tight_layout()
648
+ fig.savefig(output_dir / "cka_cross.png", dpi=150)
649
+ plt.close(fig)
650
+
651
+
652
+ def plot_logit_lens(lens_results: OrderedDict, output_dir: Path,
653
+ model_label: str):
654
+ """Plot logit lens summary: entropy, confidence, rank, agreement."""
655
+ names = list(lens_results.keys())
656
+ sorted_names = sorted(names, key=_layer_sort_key)
657
+ short = [_short_name(n) for n in sorted_names]
658
+ colors = [_phase_color(n) for n in sorted_names]
659
+ x = range(len(sorted_names))
660
+
661
+ fig, axes = plt.subplots(2, 2, figsize=(16, 10))
662
+ fig.suptitle(f"{model_label} -- Logit Lens", fontsize=14)
663
+
664
+ # Entropy
665
+ vals = [lens_results[n]["entropy"] for n in sorted_names]
666
+ axes[0, 0].bar(x, vals, color=colors, alpha=0.85)
667
+ axes[0, 0].set_ylabel("Entropy (nats)")
668
+ axes[0, 0].set_title("Prediction entropy per layer")
669
+ axes[0, 0].set_xticks(x)
670
+ axes[0, 0].set_xticklabels(short, rotation=90, fontsize=7)
671
+
672
+ # Top-1 probability
673
+ vals = [lens_results[n]["top1_prob"] for n in sorted_names]
674
+ axes[0, 1].bar(x, vals, color=colors, alpha=0.85)
675
+ axes[0, 1].set_ylabel("Top-1 probability")
676
+ axes[0, 1].set_title("Prediction confidence per layer")
677
+ axes[0, 1].set_xticks(x)
678
+ axes[0, 1].set_xticklabels(short, rotation=90, fontsize=7)
679
+
680
+ # Correct rank (log scale)
681
+ vals = [lens_results[n]["correct_rank_median"] for n in sorted_names]
682
+ axes[1, 0].bar(x, vals, color=colors, alpha=0.85)
683
+ axes[1, 0].set_ylabel("Median rank of correct token")
684
+ axes[1, 0].set_yscale("log")
685
+ axes[1, 0].set_title("When does the model find the answer?")
686
+ axes[1, 0].set_xticks(x)
687
+ axes[1, 0].set_xticklabels(short, rotation=90, fontsize=7)
688
+
689
+ # Agreement with final layer
690
+ vals = [lens_results[n]["agreement_with_final"] for n in sorted_names]
691
+ axes[1, 1].bar(x, vals, color=colors, alpha=0.85)
692
+ axes[1, 1].set_ylabel("Agreement with final layer")
693
+ axes[1, 1].set_title("Convergence toward final prediction")
694
+ axes[1, 1].set_ylim(0, 1.05)
695
+ axes[1, 1].set_xticks(x)
696
+ axes[1, 1].set_xticklabels(short, rotation=90, fontsize=7)
697
+
698
+ plt.tight_layout()
699
+ fig.savefig(output_dir / "logit_lens_summary.png", dpi=150)
700
+ plt.close(fig)
701
+
702
+
703
+ def plot_logit_lens_trajectory(activations: dict, norm: nn.Module,
704
+ unembed_weight: torch.Tensor, input_ids: torch.Tensor,
705
+ tokenizer, output_dir: Path, model_label: str,
706
+ device: str = "cpu",
707
+ n_positions: int = 6, n_layers: int = 10):
708
+ """Show top-5 predicted tokens at selected layers for a few positions.
709
+
710
+ Picks positions spread across the first sample and shows how the
711
+ model's prediction evolves through the network.
712
+ """
713
+ names = sorted(activations.keys(), key=_layer_sort_key)
714
+
715
+ # Select layers evenly spread across the network
716
+ if len(names) > n_layers:
717
+ indices = np.linspace(0, len(names) - 1, n_layers, dtype=int)
718
+ selected_layers = [names[i] for i in indices]
719
+ else:
720
+ selected_layers = names
721
+
722
+ # Select positions from the first sample
723
+ seq_len = input_ids.shape[1]
724
+ pos_indices = np.linspace(10, seq_len - 2, n_positions, dtype=int)
725
+
726
+ unembed = unembed_weight.to(device)
727
+ norm_mod = norm.to(device)
728
+ final_name = names[-1]
729
+
730
+ fig, axes = plt.subplots(n_positions, 1, figsize=(14, 3 * n_positions))
731
+ if n_positions == 1:
732
+ axes = [axes]
733
+ fig.suptitle(f"{model_label} -- Token prediction trajectory", fontsize=14, y=1.02)
734
+
735
+ for pos_idx, pos in enumerate(pos_indices):
736
+ ax = axes[pos_idx]
737
+ actual_token = tokenizer.decode([input_ids[0, pos + 1].item()])
738
+ context = tokenizer.decode(input_ids[0, max(0, pos - 5):pos + 1].tolist())
739
+
740
+ layer_labels = []
741
+ top_tokens_per_layer = []
742
+
743
+ for name in selected_layers:
744
+ is_final = (name == final_name)
745
+ hidden = activations[name][0, pos:pos + 1, :].to(device) # [1, D]
746
+ if not is_final:
747
+ hidden = norm_mod(hidden)
748
+ logits = (hidden @ unembed.T).squeeze(0) # [V]
749
+ probs = F.softmax(logits, dim=-1)
750
+ top5_vals, top5_idx = probs.topk(5)
751
+
752
+ tokens_str = []
753
+ for val, idx in zip(top5_vals, top5_idx):
754
+ tok = tokenizer.decode([idx.item()]).replace("\n", "\\n")
755
+ tokens_str.append(f"{tok}({val:.2f})")
756
+
757
+ layer_labels.append(_short_name(name))
758
+ top_tokens_per_layer.append("\n".join(tokens_str))
759
+
760
+ # Create a text table
761
+ ax.set_xlim(-0.5, len(layer_labels) - 0.5)
762
+ ax.set_ylim(-0.5, 5.5)
763
+ ax.set_xticks(range(len(layer_labels)))
764
+ ax.set_xticklabels(layer_labels, fontsize=8)
765
+ ax.set_yticks([])
766
+
767
+ for li, tokens_str in enumerate(top_tokens_per_layer):
768
+ lines = tokens_str.split("\n")
769
+ for rank, line in enumerate(lines):
770
+ color = "darkgreen" if actual_token.strip() in line else "black"
771
+ fontweight = "bold" if actual_token.strip() in line else "normal"
772
+ ax.text(li, rank, line, ha="center", va="center", fontsize=7,
773
+ color=color, fontweight=fontweight)
774
+
775
+ ax.set_title(f'pos {pos}: "...{context}" -> [{actual_token.strip()}]',
776
+ fontsize=9, loc="left")
777
+ ax.invert_yaxis()
778
+ ax.spines["top"].set_visible(False)
779
+ ax.spines["right"].set_visible(False)
780
+ ax.spines["left"].set_visible(False)
781
+
782
+ plt.tight_layout()
783
+ fig.savefig(output_dir / "logit_lens_trajectory.png", dpi=150, bbox_inches="tight")
784
+ plt.close(fig)
785
+
786
+
787
+ def plot_drift(drift: OrderedDict, output_dir: Path, model_label: str):
788
+ """Plot representation drift between consecutive layers."""
789
+ names = list(drift.keys())
790
+ sorted_names = sorted(names, key=_layer_sort_key)
791
+ short = [_short_name(n) for n in sorted_names]
792
+ colors = [_phase_color(n) for n in sorted_names]
793
+ x = range(len(sorted_names))
794
+
795
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
796
+ fig.suptitle(f"{model_label} -- Representation drift", fontsize=14)
797
+
798
+ # Cosine similarity with previous layer
799
+ vals = [drift[n]["cos_sim_mean"] for n in sorted_names]
800
+ axes[0].bar(x, vals, color=colors, alpha=0.85)
801
+ axes[0].set_ylabel("Cosine similarity with previous layer")
802
+ axes[0].set_title("How much each layer preserves direction")
803
+ axes[0].set_xticks(x)
804
+ axes[0].set_xticklabels(short, rotation=90, fontsize=7)
805
+
806
+ # L2 distance
807
+ vals = [drift[n]["l2_distance"] for n in sorted_names]
808
+ axes[1].bar(x, vals, color=colors, alpha=0.85)
809
+ axes[1].set_ylabel("L2 distance from previous layer")
810
+ axes[1].set_title("How much each layer changes magnitude")
811
+ axes[1].set_xticks(x)
812
+ axes[1].set_xticklabels(short, rotation=90, fontsize=7)
813
+
814
+ plt.tight_layout()
815
+ fig.savefig(output_dir / "representation_drift.png", dpi=150)
816
+ plt.close(fig)
817
+
818
+
819
+ # ---------------------------------------------------------------------------
820
+ # Results saving
821
+ # ---------------------------------------------------------------------------
822
+
823
+ def save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir):
824
+ """Save all numerical results to JSON."""
825
+ out = {}
826
+
827
+ if cka_matrix is not None:
828
+ out["cka_self"] = {
829
+ "names": cka_names,
830
+ "matrix": cka_matrix.tolist(),
831
+ }
832
+
833
+ if lens_results:
834
+ out["logit_lens"] = {name: data for name, data in lens_results.items()}
835
+
836
+ if drift:
837
+ out["drift"] = {name: data for name, data in drift.items()}
838
+
839
+ if cross_cka is not None:
840
+ matrix, names_a, names_b = cross_cka
841
+ out["cka_cross"] = {
842
+ "names_a": names_a,
843
+ "names_b": names_b,
844
+ "matrix": matrix.tolist(),
845
+ }
846
+
847
+ with open(output_dir / "results.json", "w") as f:
848
+ json.dump(out, f, indent=2, default=str)
849
+
850
+
851
+ # ---------------------------------------------------------------------------
852
+ # Main
853
+ # ---------------------------------------------------------------------------
854
+
855
+ def main():
856
+ parser = argparse.ArgumentParser(
857
+ description="CKA and Logit Lens analysis for Prisma / Circuit Transformer")
858
+ parser.add_argument("--checkpoint", type=str, required=True,
859
+ help="Path to Prisma/Circuit checkpoint")
860
+ parser.add_argument("--checkpoint-b", type=str, default=None,
861
+ help="Second Prisma checkpoint for cross-model CKA")
862
+ parser.add_argument("--hf-model", type=str, default=None,
863
+ help="HuggingFace model for cross-model CKA (e.g. gpt2-medium)")
864
+ parser.add_argument("--data", type=str, required=True,
865
+ help="Data source (hf:dataset:config:split or file path)")
866
+ parser.add_argument("--num-samples", type=int, default=32,
867
+ help="Number of text samples (default: 32)")
868
+ parser.add_argument("--context-length", type=int, default=512,
869
+ help="Sequence length (default: 512)")
870
+ parser.add_argument("--cka-subsample", type=int, default=4,
871
+ help="Position subsampling for CKA (default: 4)")
872
+ parser.add_argument("--no-logit-lens", action="store_true",
873
+ help="Skip logit lens analysis")
874
+ parser.add_argument("--no-cka", action="store_true",
875
+ help="Skip CKA analysis")
876
+ parser.add_argument("--output-dir", type=str, default=None,
877
+ help="Output directory (default: auto)")
878
+ parser.add_argument("--gpu", type=int, default=0, help="GPU index")
879
+ args = parser.parse_args()
880
+
881
+ device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
882
+ print(f"Device: {device}")
883
+
884
+ # Output directory
885
+ if args.output_dir:
886
+ output_dir = Path(args.output_dir)
887
+ else:
888
+ ckpt_name = Path(args.checkpoint).parent.name
889
+ output_dir = Path("circuits/scripts/representation_output") / ckpt_name
890
+ output_dir.mkdir(parents=True, exist_ok=True)
891
+ print(f"Output: {output_dir}")
892
+
893
+ # === Load model A ===
894
+ print(f"\nLoading: {args.checkpoint}")
895
+ model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
896
+ label_a = Path(args.checkpoint).parent.name
897
+ n_params = sum(p.numel() for p in model_a.parameters())
898
+ print(f" Type: {model_type_a}, params: {n_params:,}")
899
+
900
+ # === Load data ===
901
+ ckpt_data = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
902
+ tokenizer_name = ckpt_data.get("tokenizer_name", config_a.get("tokenizer_name", "gpt2"))
903
+ del ckpt_data
904
+
905
+ print(f"\nLoading data ({args.num_samples} samples, ctx={args.context_length})...")
906
+ result = load_data(
907
+ args.data, tokenizer_name, args.num_samples, args.context_length, device
908
+ )
909
+ if result[0] is None:
910
+ print("ERROR: No valid samples loaded.")
911
+ return
912
+ input_ids, tokenizer = result
913
+ print(f" Data shape: {input_ids.shape}")
914
+
915
+ # === Collect activations (model A) ===
916
+ print(f"\nCollecting activations ({model_type_a})...")
917
+ acts_a = collect_activations(model_a, model_type_a, config_a, input_ids, device)
918
+ print(f" Collected {len(acts_a)} layers")
919
+
920
+ # Free GPU memory
921
+ del model_a
922
+ if device.startswith("cuda"):
923
+ torch.cuda.empty_cache()
924
+
925
+ # === CKA (self) ===
926
+ cka_matrix = None
927
+ cka_names = None
928
+ if not args.no_cka:
929
+ print(f"\nComputing self-CKA (subsample={args.cka_subsample})...")
930
+ cka_matrix, cka_names = compute_cka_matrix(acts_a, subsample=args.cka_subsample)
931
+ plot_cka_self(cka_matrix, cka_names, output_dir, label_a)
932
+ print(f" Saved: cka_self.png")
933
+
934
+ # === Cross-model CKA ===
935
+ cross_cka = None
936
+ if not args.no_cka and (args.checkpoint_b or args.hf_model):
937
+ if args.checkpoint_b:
938
+ print(f"\nLoading comparison: {args.checkpoint_b}")
939
+ model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
940
+ label_b = Path(args.checkpoint_b).parent.name
941
+ acts_b = collect_activations(model_b, model_type_b, config_b, input_ids, device)
942
+ del model_b
943
+ else:
944
+ print(f"\nLoading HF model: {args.hf_model}")
945
+ model_b = load_hf_model(args.hf_model, device)
946
+ label_b = args.hf_model
947
+ # Decode texts from our tokens and re-tokenize for HF model
948
+ print(f" Re-tokenizing for {args.hf_model}...")
949
+ raw_texts = [tokenizer.decode(input_ids[i].tolist()) for i in range(input_ids.shape[0])]
950
+ input_ids_b, _ = tokenize_for_hf(
951
+ raw_texts, args.hf_model, args.context_length, device
952
+ )
953
+ if input_ids_b is not None:
954
+ print(f" HF data shape: {input_ids_b.shape}")
955
+ acts_b = collect_hf_activations(model_b, input_ids_b)
956
+ else:
957
+ acts_b = None
958
+ del model_b
959
+
960
+ if device.startswith("cuda"):
961
+ torch.cuda.empty_cache()
962
+
963
+ if acts_b:
964
+ print(f"\nComputing cross-model CKA...")
965
+ cross_matrix, cross_names_a, cross_names_b = compute_cross_model_cka(acts_a, acts_b)
966
+ cross_cka = (cross_matrix, cross_names_a, cross_names_b)
967
+ plot_cka_cross(cross_matrix, cross_names_a, cross_names_b,
968
+ output_dir, label_a, label_b)
969
+ print(f" Saved: cka_cross.png")
970
+ del acts_b
971
+
972
+ # === Logit lens ===
973
+ lens_results = None
974
+ if not args.no_logit_lens:
975
+ # Reload model for unembedding components (we deleted it for memory)
976
+ print(f"\nReloading model for logit lens...")
977
+ model_a, _, _ = load_prisma_model(args.checkpoint, device)
978
+ norm, unembed_weight = get_unembed_components(model_a, model_type_a)
979
+
980
+ labels = input_ids[:, 1:].cpu() # next-token labels
981
+
982
+ print(f"Computing logit lens...")
983
+ lens_results = compute_logit_lens(acts_a, norm, unembed_weight, labels, device)
984
+ plot_logit_lens(lens_results, output_dir, label_a)
985
+ print(f" Saved: logit_lens_summary.png")
986
+
987
+ # Token trajectory visualization
988
+ print(f" Generating token trajectories...")
989
+ plot_logit_lens_trajectory(
990
+ acts_a, norm, unembed_weight, input_ids.cpu(), tokenizer,
991
+ output_dir, label_a, device
992
+ )
993
+ print(f" Saved: logit_lens_trajectory.png")
994
+
995
+ del model_a
996
+ if device.startswith("cuda"):
997
+ torch.cuda.empty_cache()
998
+
999
+ # === Representation drift ===
1000
+ print(f"\nComputing representation drift...")
1001
+ drift = compute_drift(acts_a)
1002
+ plot_drift(drift, output_dir, label_a)
1003
+ print(f" Saved: representation_drift.png")
1004
+
1005
+ # === Save results ===
1006
+ save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir)
1007
+ print(f"\nAll outputs saved to: {output_dir}")
1008
+ n_plots = len(list(output_dir.glob("*.png")))
1009
+ print(f" Plots: {n_plots} PNG files")
1010
+ print(f" Data: results.json")
1011
+
1012
+
1013
+ if __name__ == "__main__":
1014
+ main()
scripts/spectral_analysis.py ADDED
@@ -0,0 +1,969 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Spectral analysis of Prisma / Circuit Transformer checkpoints.
4
+
5
+ Computes SVD spectra of weight matrices and (optionally) activation covariances,
6
+ revealing how the model organizes information geometrically.
7
+
8
+ Analyses:
9
+ 1. Weight spectra — singular value distributions per matrix
10
+ 2. Effective rank — how many dimensions carry real signal
11
+ 3. Power-law fit — Martin & Mahoney alpha exponent (training quality)
12
+ 4. MP bound — Marchenko-Pastur separation of signal vs noise
13
+ 5. Mirror comparison — expand vs compress activation spectra (Prisma-specific)
14
+ 6. Embedding alignment— spectral similarity between embed and final hidden states
15
+ 7. Layer-wise summary — effective rank progression through the network (the lens)
16
+
17
+ Usage:
18
+ # Weight-only analysis (no data needed)
19
+ python -m circuits.scripts.spectral_analysis --checkpoint path/to/checkpoint.pt
20
+
21
+ # Full analysis with activation spectra (needs data)
22
+ python -m circuits.scripts.spectral_analysis --checkpoint path/to/checkpoint.pt \
23
+ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train --num-samples 512
24
+
25
+ # Compare two checkpoints
26
+ python -m circuits.scripts.spectral_analysis \
27
+ --checkpoint path/to/prisma.pt --checkpoint-b path/to/standard.pt
28
+
29
+ # Compare against HuggingFace model
30
+ python -m circuits.scripts.spectral_analysis \
31
+ --checkpoint path/to/prisma.pt --hf-model gpt2-medium
32
+ """
33
+
34
+ import argparse
35
+ import json
36
+ import sys
37
+ import os
38
+ from pathlib import Path
39
+ from collections import defaultdict
40
+
41
+ import numpy as np
42
+ import torch
43
+ import torch.nn as nn
44
+ import matplotlib
45
+ matplotlib.use("Agg")
46
+ import matplotlib.pyplot as plt
47
+ from matplotlib.gridspec import GridSpec
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Model loading
52
+ # ---------------------------------------------------------------------------
53
+
54
+ def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
55
+ """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type)."""
56
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
57
+ from circuits.config import CircuitConfig
58
+ from circuits.model import CircuitTransformer
59
+ from circuits.mirrored import MirroredConfig, MirroredTransformer
60
+
61
+ ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
62
+ model_type = ckpt.get("model_type", "standard")
63
+ config_dict = ckpt.get("config", {})
64
+
65
+ if model_type == "mirrored":
66
+ if config_dict.get("dual_gate_middle"):
67
+ config_dict.pop("dual_gate_middle")
68
+ config = MirroredConfig.from_dict(config_dict)
69
+ model = MirroredTransformer(config)
70
+ else:
71
+ config = CircuitConfig.from_dict(config_dict)
72
+ model = CircuitTransformer(config)
73
+
74
+ state_dict = ckpt["model"]
75
+ if any(k.startswith("_orig_mod.") for k in state_dict):
76
+ state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
77
+ model.load_state_dict(state_dict, strict=False)
78
+ model.to(device).eval()
79
+
80
+ return model, config_dict, model_type
81
+
82
+
83
+ def load_hf_model(model_name: str, device: str = "cpu"):
84
+ """Load a HuggingFace causal LM."""
85
+ from transformers import AutoModelForCausalLM
86
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
87
+ model.to(device).eval()
88
+ return model
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # SVD utilities
93
+ # ---------------------------------------------------------------------------
94
+
95
+ def compute_singular_values(weight: torch.Tensor) -> np.ndarray:
96
+ """Compute singular values of a 2D weight matrix."""
97
+ w = weight.detach().float().cpu()
98
+ if w.ndim != 2:
99
+ return None
100
+ sv = torch.linalg.svdvals(w).numpy()
101
+ return sv
102
+
103
+
104
+ def effective_rank(sv: np.ndarray) -> float:
105
+ """Entropy-based effective rank (Roy & Vetterli, 2007).
106
+
107
+ erank = exp(H(p)) where p_i = sigma_i / sum(sigma)
108
+ and H is Shannon entropy. Ranges from 1 (rank-1) to min(m,n) (full rank).
109
+ """
110
+ sv = sv[sv > 1e-10]
111
+ if len(sv) == 0:
112
+ return 0.0
113
+ p = sv / sv.sum()
114
+ entropy = -(p * np.log(p)).sum()
115
+ return float(np.exp(entropy))
116
+
117
+
118
+ def stable_rank(sv: np.ndarray) -> float:
119
+ """Stable rank = ||W||_F^2 / ||W||_2^2 = sum(sigma^2) / max(sigma)^2."""
120
+ if len(sv) == 0 or sv[0] < 1e-10:
121
+ return 0.0
122
+ return float((sv ** 2).sum() / (sv[0] ** 2))
123
+
124
+
125
+ def marchenko_pastur_bound(m: int, n: int, sv: np.ndarray) -> float:
126
+ """Estimate Marchenko-Pastur upper edge.
127
+
128
+ For a random matrix with variance sigma^2, the MP upper bound is
129
+ sigma * (1 + sqrt(m/n))^2 (assuming m >= n).
130
+ We estimate sigma from the bulk of singular values.
131
+ """
132
+ gamma = max(m, n) / min(m, n)
133
+ # Estimate noise level from bottom half of spectrum
134
+ bottom_half = sv[len(sv) // 2:]
135
+ if len(bottom_half) == 0:
136
+ return sv[-1] if len(sv) > 0 else 0.0
137
+ sigma_est = float(np.median(bottom_half)) / np.sqrt(max(m, n))
138
+ mp_upper = sigma_est * (1.0 + np.sqrt(gamma)) ** 2 * np.sqrt(min(m, n))
139
+ return mp_upper
140
+
141
+
142
+ def fit_power_law(sv: np.ndarray, fit_fraction: float = 0.8) -> tuple[float, float]:
143
+ """Fit power law to singular value distribution tail.
144
+
145
+ Returns (alpha, r_squared). alpha < 2 = heavy-tailed (well-trained).
146
+ """
147
+ sv = sv[sv > 1e-10]
148
+ if len(sv) < 10:
149
+ return 0.0, 0.0
150
+ # Fit to the top `fit_fraction` of the spectrum (exclude noise floor)
151
+ n_fit = max(10, int(len(sv) * fit_fraction))
152
+ sv_fit = sv[:n_fit]
153
+
154
+ log_rank = np.log(np.arange(1, n_fit + 1))
155
+ log_sv = np.log(sv_fit)
156
+
157
+ # Linear regression in log-log space: log(sv) = -alpha * log(rank) + c
158
+ coeffs = np.polyfit(log_rank, log_sv, 1)
159
+ alpha = -coeffs[0]
160
+
161
+ # R-squared
162
+ predicted = np.polyval(coeffs, log_rank)
163
+ ss_res = ((log_sv - predicted) ** 2).sum()
164
+ ss_tot = ((log_sv - log_sv.mean()) ** 2).sum()
165
+ r_sq = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
166
+
167
+ return float(alpha), float(r_sq)
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Weight spectrum analysis
172
+ # ---------------------------------------------------------------------------
173
+
174
+ def analyze_weight_spectra(model: nn.Module, model_label: str = "model") -> dict:
175
+ """Compute SVD spectra for all 2D weight matrices."""
176
+ results = {}
177
+ for name, param in model.named_parameters():
178
+ if param.ndim != 2:
179
+ continue
180
+ sv = compute_singular_values(param)
181
+ if sv is None:
182
+ continue
183
+ m, n = param.shape
184
+ mp_bound = marchenko_pastur_bound(m, n, sv)
185
+ n_above_mp = int((sv > mp_bound).sum())
186
+ alpha, r_sq = fit_power_law(sv)
187
+
188
+ results[name] = {
189
+ "shape": (m, n),
190
+ "singular_values": sv,
191
+ "effective_rank": effective_rank(sv),
192
+ "stable_rank": stable_rank(sv),
193
+ "spectral_norm": float(sv[0]),
194
+ "frobenius_norm": float(np.sqrt((sv ** 2).sum())),
195
+ "mp_bound": mp_bound,
196
+ "n_above_mp": n_above_mp,
197
+ "n_total": len(sv),
198
+ "signal_ratio": n_above_mp / len(sv) if len(sv) > 0 else 0,
199
+ "alpha": alpha,
200
+ "alpha_r2": r_sq,
201
+ "condition_number": float(sv[0] / sv[-1]) if sv[-1] > 1e-10 else float("inf"),
202
+ }
203
+ return results
204
+
205
+
206
+ # ---------------------------------------------------------------------------
207
+ # Activation spectrum analysis
208
+ # ---------------------------------------------------------------------------
209
+
210
+ def collect_activations(model, input_ids: torch.Tensor,
211
+ word_positions: torch.Tensor = None,
212
+ model_type: str = "standard") -> dict[str, torch.Tensor]:
213
+ """Run a forward pass and collect intermediate activations via hooks."""
214
+ activations = {}
215
+ hooks = []
216
+
217
+ def make_hook(name):
218
+ def hook_fn(module, input, output):
219
+ if isinstance(output, tuple):
220
+ out = output[0]
221
+ else:
222
+ out = output
223
+ # Store mean over batch and sequence for covariance
224
+ activations[name] = out.detach().float().cpu()
225
+ return hook_fn
226
+
227
+ # Register hooks based on model type
228
+ if model_type == "mirrored":
229
+ # Expand phase
230
+ for i, block in enumerate(model.mirror_blocks):
231
+ hooks.append(block.register_forward_hook(make_hook(f"expand_{i}")))
232
+ # Middle
233
+ for i, block in enumerate(model.middle_blocks):
234
+ hooks.append(block.register_forward_hook(make_hook(f"middle_{i}")))
235
+ # Compress — mirror blocks are reused in reverse, so we hook the FFN output
236
+ # We'll collect compress activations differently via a custom forward
237
+ else:
238
+ for i, block in enumerate(model.layers):
239
+ hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))
240
+
241
+ # Also hook the embedding output
242
+ hooks.append(model.embed.register_forward_hook(make_hook("embedding")))
243
+
244
+ with torch.no_grad():
245
+ kwargs = {}
246
+ if word_positions is not None:
247
+ kwargs["word_positions"] = word_positions
248
+ model(input_ids, **kwargs)
249
+
250
+ for h in hooks:
251
+ h.remove()
252
+
253
+ return activations
254
+
255
+
256
+ def collect_mirrored_activations(model, input_ids: torch.Tensor,
257
+ word_positions: torch.Tensor = None) -> dict[str, torch.Tensor]:
258
+ """Collect activations from a MirroredTransformer, separating expand and compress phases.
259
+
260
+ This manually runs the forward pass to capture compress-phase activations
261
+ from the reversed mirror blocks.
262
+ """
263
+ import math
264
+
265
+ activations = {}
266
+
267
+ with torch.no_grad():
268
+ # Embed
269
+ x = model.embed(input_ids)
270
+ if model.embed_proj is not None:
271
+ import torch.nn.functional as F
272
+ if model.embed_g3 is not None:
273
+ g4 = F.silu(model.embed_g4(x))
274
+ g3 = F.silu(model.embed_g3(x) * g4)
275
+ x = model.embed_proj(x) * g3
276
+ else:
277
+ x = F.silu(model.embed_proj(x))
278
+ x = x * model.embed_scale
279
+ activations["embedding"] = x.detach().float().cpu()
280
+
281
+ # Expand phase
282
+ for i, block in enumerate(model.mirror_blocks):
283
+ x, _ = block(x, word_positions=word_positions)
284
+ activations[f"expand_{i}"] = x.detach().float().cpu()
285
+
286
+ # Middle phase
287
+ for i, block in enumerate(model.middle_blocks):
288
+ x, _ = block(x, word_positions=word_positions)
289
+ activations[f"middle_{i}"] = x.detach().float().cpu()
290
+
291
+ # Compress phase (reversed)
292
+ for i in reversed(range(len(model.mirror_blocks))):
293
+ x, _ = model.mirror_blocks[i](x, word_positions=word_positions)
294
+ compress_idx = len(model.mirror_blocks) - 1 - i
295
+ activations[f"compress_{compress_idx}"] = x.detach().float().cpu()
296
+
297
+ # Final norm
298
+ x = model.norm(x)
299
+ activations["final_norm"] = x.detach().float().cpu()
300
+
301
+ return activations
302
+
303
+
304
+ def activation_spectrum(act: torch.Tensor, max_components: int = 256) -> dict:
305
+ """Compute eigenspectrum of activation covariance.
306
+
307
+ act: [B, T, D] — reshape to [B*T, D], compute covariance, eigendecompose.
308
+ """
309
+ # Flatten batch and sequence
310
+ flat = act.reshape(-1, act.shape[-1]) # [N, D]
311
+ N, D = flat.shape
312
+
313
+ if N < 2:
314
+ return None
315
+
316
+ # Center
317
+ flat = flat - flat.mean(dim=0, keepdim=True)
318
+
319
+ # Compute covariance via SVD of the data matrix (more stable than cov matrix)
320
+ n_components = min(max_components, D, N)
321
+ try:
322
+ U, S, Vh = torch.pca_lowrank(flat, q=n_components)
323
+ eigenvalues = (S ** 2 / (N - 1)).numpy()
324
+ except Exception:
325
+ # Fallback: full covariance
326
+ cov = (flat.T @ flat) / (N - 1)
327
+ eigenvalues = torch.linalg.eigvalsh(cov).flip(0).numpy()
328
+ eigenvalues = eigenvalues[:max_components]
329
+
330
+ eigenvalues = eigenvalues[eigenvalues > 1e-10]
331
+
332
+ return {
333
+ "eigenvalues": eigenvalues,
334
+ "effective_rank": effective_rank(np.sqrt(np.maximum(eigenvalues, 0))),
335
+ "total_variance": float(eigenvalues.sum()),
336
+ "top1_variance_ratio": float(eigenvalues[0] / eigenvalues.sum()) if len(eigenvalues) > 0 else 0,
337
+ "top10_variance_ratio": float(eigenvalues[:10].sum() / eigenvalues.sum()) if len(eigenvalues) >= 10 else 0,
338
+ "n_components": len(eigenvalues),
339
+ }
340
+
341
+
342
+ # ---------------------------------------------------------------------------
343
+ # Plotting
344
+ # ---------------------------------------------------------------------------
345
+
346
+ def plot_weight_spectra(results: dict, output_dir: Path, model_label: str = "model",
347
+ results_b: dict = None, model_b_label: str = "model_b"):
348
+ """Plot singular value distributions for all weight matrices."""
349
+ # Group by layer/component type
350
+ groups = defaultdict(list)
351
+ for name, data in results.items():
352
+ # Identify the component type
353
+ if "attn" in name and ("q_proj" in name or "wq" in name):
354
+ groups["attention_Q"].append((name, data))
355
+ elif "attn" in name and ("k_proj" in name or "wk" in name):
356
+ groups["attention_K"].append((name, data))
357
+ elif "attn" in name and ("v_proj" in name or "wv" in name):
358
+ groups["attention_V"].append((name, data))
359
+ elif "attn" in name and ("o_proj" in name or "wo" in name):
360
+ groups["attention_O"].append((name, data))
361
+ elif "w1" in name or "up_proj" in name:
362
+ groups["ffn_W1"].append((name, data))
363
+ elif "w2" in name or "down_proj" in name:
364
+ groups["ffn_W2"].append((name, data))
365
+ elif "w3" in name or "gate_proj" in name:
366
+ groups["ffn_gate_W3"].append((name, data))
367
+ elif "w4" in name:
368
+ groups["ffn_gate_W4"].append((name, data))
369
+ elif "embed" in name or "wte" in name:
370
+ groups["embedding"].append((name, data))
371
+ elif "lm_head" in name:
372
+ groups["lm_head"].append((name, data))
373
+ else:
374
+ groups["other"].append((name, data))
375
+
376
+ # Plot each group
377
+ for group_name, items in groups.items():
378
+ if not items:
379
+ continue
380
+
381
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
382
+ fig.suptitle(f"{model_label} — {group_name} weight spectra", fontsize=13)
383
+
384
+ ax_linear, ax_log = axes
385
+
386
+ cmap = plt.cm.viridis(np.linspace(0.1, 0.9, len(items)))
387
+ for idx, (name, data) in enumerate(items):
388
+ sv = data["singular_values"]
389
+ short_name = name.split(".")[-2] + "." + name.split(".")[-1] if "." in name else name
390
+ ax_linear.plot(sv, color=cmap[idx], alpha=0.7, linewidth=0.8, label=short_name)
391
+ ax_log.loglog(np.arange(1, len(sv) + 1), sv, color=cmap[idx], alpha=0.7,
392
+ linewidth=0.8, label=short_name)
393
+ # MP bound
394
+ ax_linear.axhline(data["mp_bound"], color=cmap[idx], linestyle=":", alpha=0.3)
395
+
396
+ ax_linear.set_xlabel("Rank")
397
+ ax_linear.set_ylabel("Singular value")
398
+ ax_linear.set_title("Linear scale")
399
+ ax_linear.legend(fontsize=6, ncol=2)
400
+
401
+ ax_log.set_xlabel("Rank")
402
+ ax_log.set_ylabel("Singular value")
403
+ ax_log.set_title("Log-log scale (power law)")
404
+ ax_log.legend(fontsize=6, ncol=2)
405
+
406
+ plt.tight_layout()
407
+ fig.savefig(output_dir / f"weight_spectra_{group_name}.png", dpi=150)
408
+ plt.close(fig)
409
+
410
+
411
+ def plot_effective_rank_progression(results: dict, output_dir: Path,
412
+ model_label: str = "model",
413
+ results_b: dict = None,
414
+ model_b_label: str = "model_b"):
415
+ """Plot effective rank per layer — the biconcave lens in eigenvalues."""
416
+ # Extract layer-ordered FFN W1 effective ranks (the main signal path)
417
+ layer_data = []
418
+ for name, data in sorted(results.items()):
419
+ if "w1" in name or "up_proj" in name:
420
+ # Extract layer index
421
+ parts = name.split(".")
422
+ layer_label = name
423
+ for p in parts:
424
+ if p.isdigit():
425
+ layer_label = p
426
+ break
427
+ layer_data.append((name, data["effective_rank"], data["stable_rank"],
428
+ data["alpha"], data["signal_ratio"], layer_label))
429
+
430
+ if not layer_data:
431
+ return
432
+
433
+ fig, axes = plt.subplots(2, 2, figsize=(14, 10))
434
+ fig.suptitle(f"{model_label} — Layer-wise spectral properties (FFN W1)", fontsize=13)
435
+
436
+ names = [d[0] for d in layer_data]
437
+ x = range(len(layer_data))
438
+ short_labels = [d[5] for d in layer_data]
439
+
440
+ # Effective rank
441
+ axes[0, 0].bar(x, [d[1] for d in layer_data], color="steelblue", alpha=0.8)
442
+ axes[0, 0].set_ylabel("Effective rank")
443
+ axes[0, 0].set_title("Effective rank (entropy-based)")
444
+ axes[0, 0].set_xticks(x)
445
+ axes[0, 0].set_xticklabels(short_labels, rotation=45, fontsize=7)
446
+
447
+ # Stable rank
448
+ axes[0, 1].bar(x, [d[2] for d in layer_data], color="coral", alpha=0.8)
449
+ axes[0, 1].set_ylabel("Stable rank")
450
+ axes[0, 1].set_title("Stable rank (Frobenius/spectral)")
451
+ axes[0, 1].set_xticks(x)
452
+ axes[0, 1].set_xticklabels(short_labels, rotation=45, fontsize=7)
453
+
454
+ # Power-law alpha
455
+ axes[1, 0].bar(x, [d[3] for d in layer_data], color="mediumpurple", alpha=0.8)
456
+ axes[1, 0].set_ylabel("Alpha")
457
+ axes[1, 0].set_title("Power-law exponent (lower = heavier tail = more structure)")
458
+ axes[1, 0].axhline(2.0, color="red", linestyle="--", alpha=0.5, label="alpha=2 boundary")
459
+ axes[1, 0].legend(fontsize=8)
460
+ axes[1, 0].set_xticks(x)
461
+ axes[1, 0].set_xticklabels(short_labels, rotation=45, fontsize=7)
462
+
463
+ # Signal ratio (above MP)
464
+ axes[1, 1].bar(x, [d[4] for d in layer_data], color="seagreen", alpha=0.8)
465
+ axes[1, 1].set_ylabel("Signal ratio")
466
+ axes[1, 1].set_title("Fraction of singular values above MP bound")
467
+ axes[1, 1].set_xticks(x)
468
+ axes[1, 1].set_xticklabels(short_labels, rotation=45, fontsize=7)
469
+
470
+ plt.tight_layout()
471
+ fig.savefig(output_dir / "layer_progression.png", dpi=150)
472
+ plt.close(fig)
473
+
474
+
475
+ def plot_activation_spectra(act_spectra: dict, output_dir: Path,
476
+ model_label: str = "model"):
477
+ """Plot activation eigenspectra across layers."""
478
+ if not act_spectra:
479
+ return
480
+
481
+ # Sort layers in processing order
482
+ order_keys = {"embedding": -1, "final_norm": 999}
483
+ def sort_key(name):
484
+ if name in order_keys:
485
+ return order_keys[name]
486
+ parts = name.split("_")
487
+ phase = parts[0]
488
+ idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
489
+ phase_offset = {"expand": 0, "middle": 100, "compress": 200, "layer": 0}
490
+ return phase_offset.get(phase, 300) + idx
491
+
492
+ sorted_names = sorted(act_spectra.keys(), key=sort_key)
493
+
494
+ # -- Eigenvalue distributions --
495
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
496
+ fig.suptitle(f"{model_label} — Activation eigenspectra", fontsize=13)
497
+
498
+ cmap = plt.cm.coolwarm(np.linspace(0, 1, len(sorted_names)))
499
+ for idx, name in enumerate(sorted_names):
500
+ data = act_spectra[name]
501
+ ev = data["eigenvalues"]
502
+ axes[0].semilogy(ev / ev.sum(), color=cmap[idx], alpha=0.7, linewidth=1.0, label=name)
503
+ axes[1].plot(np.cumsum(ev) / ev.sum(), color=cmap[idx], alpha=0.7, linewidth=1.0, label=name)
504
+
505
+ axes[0].set_xlabel("Component")
506
+ axes[0].set_ylabel("Normalized eigenvalue (log)")
507
+ axes[0].set_title("Eigenvalue distribution")
508
+ axes[0].legend(fontsize=6, ncol=2)
509
+
510
+ axes[1].set_xlabel("Component")
511
+ axes[1].set_ylabel("Cumulative variance explained")
512
+ axes[1].set_title("Variance concentration")
513
+ axes[1].axhline(0.9, color="gray", linestyle="--", alpha=0.4, label="90%")
514
+ axes[1].legend(fontsize=6, ncol=2)
515
+
516
+ plt.tight_layout()
517
+ fig.savefig(output_dir / "activation_spectra.png", dpi=150)
518
+ plt.close(fig)
519
+
520
+ # -- Effective rank progression (the lens shape) --
521
+ fig, ax = plt.subplots(figsize=(12, 5))
522
+ fig.suptitle(f"{model_label} — Activation effective rank progression", fontsize=13)
523
+
524
+ eranks = [act_spectra[n]["effective_rank"] for n in sorted_names]
525
+ colors = []
526
+ for name in sorted_names:
527
+ if "expand" in name:
528
+ colors.append("steelblue")
529
+ elif "middle" in name:
530
+ colors.append("goldenrod")
531
+ elif "compress" in name:
532
+ colors.append("coral")
533
+ else:
534
+ colors.append("gray")
535
+
536
+ ax.bar(range(len(sorted_names)), eranks, color=colors, alpha=0.8)
537
+ ax.set_xticks(range(len(sorted_names)))
538
+ ax.set_xticklabels(sorted_names, rotation=45, ha="right", fontsize=8)
539
+ ax.set_ylabel("Effective rank")
540
+ ax.set_title("Expand (blue) → Middle (gold) → Compress (coral)")
541
+
542
+ plt.tight_layout()
543
+ fig.savefig(output_dir / "activation_rank_progression.png", dpi=150)
544
+ plt.close(fig)
545
+
546
+
547
+ def plot_mirror_comparison(act_spectra: dict, output_dir: Path,
548
+ model_label: str = "model"):
549
+ """Compare expand vs compress activation spectra for each mirror pair."""
550
+ expand_layers = sorted([n for n in act_spectra if n.startswith("expand_")])
551
+ compress_layers = sorted([n for n in act_spectra if n.startswith("compress_")])
552
+
553
+ if not expand_layers or not compress_layers:
554
+ return
555
+
556
+ n_pairs = min(len(expand_layers), len(compress_layers))
557
+ fig, axes = plt.subplots(1, n_pairs, figsize=(4 * n_pairs, 4), squeeze=False)
558
+ fig.suptitle(f"{model_label} — Mirror pair activation spectra (expand vs compress)", fontsize=13)
559
+
560
+ for i in range(n_pairs):
561
+ ax = axes[0, i]
562
+ exp_ev = act_spectra[expand_layers[i]]["eigenvalues"]
563
+ comp_ev = act_spectra[compress_layers[i]]["eigenvalues"]
564
+
565
+ n_plot = min(len(exp_ev), len(comp_ev), 100)
566
+ ax.semilogy(exp_ev[:n_plot] / exp_ev.sum(), color="steelblue", alpha=0.8,
567
+ linewidth=1.5, label="expand")
568
+ ax.semilogy(comp_ev[:n_plot] / comp_ev.sum(), color="coral", alpha=0.8,
569
+ linewidth=1.5, label="compress")
570
+
571
+ exp_er = act_spectra[expand_layers[i]]["effective_rank"]
572
+ comp_er = act_spectra[compress_layers[i]]["effective_rank"]
573
+ ax.set_title(f"Pair {i}\nerank: {exp_er:.0f} / {comp_er:.0f}", fontsize=10)
574
+ ax.set_xlabel("Component")
575
+ if i == 0:
576
+ ax.set_ylabel("Normalized eigenvalue")
577
+ ax.legend(fontsize=8)
578
+
579
+ plt.tight_layout()
580
+ fig.savefig(output_dir / "mirror_pair_comparison.png", dpi=150)
581
+ plt.close(fig)
582
+
583
+
584
+ def plot_gate_spectra(results: dict, output_dir: Path, model_label: str = "model"):
585
+ """Compare W3 vs W4 gate weight spectra (G2LU inner vs outer gate)."""
586
+ w3_items = [(n, d) for n, d in sorted(results.items()) if "w3" in n and "ffn" in n]
587
+ w4_items = [(n, d) for n, d in sorted(results.items()) if "w4" in n and "ffn" in n]
588
+
589
+ if not w3_items or not w4_items:
590
+ return
591
+
592
+ n_pairs = min(len(w3_items), len(w4_items))
593
+ fig, axes = plt.subplots(2, 1, figsize=(12, 8))
594
+ fig.suptitle(f"{model_label} — G2LU gate spectra (W3 outer vs W4 inner)", fontsize=13)
595
+
596
+ # Overlay all W3 vs W4
597
+ cmap_w3 = plt.cm.Blues(np.linspace(0.3, 0.9, n_pairs))
598
+ cmap_w4 = plt.cm.Reds(np.linspace(0.3, 0.9, n_pairs))
599
+
600
+ for i in range(n_pairs):
601
+ sv3 = w3_items[i][1]["singular_values"]
602
+ sv4 = w4_items[i][1]["singular_values"]
603
+ axes[0].semilogy(sv3, color=cmap_w3[i], alpha=0.6, linewidth=0.8, label=f"W3 pair {i}")
604
+ axes[0].semilogy(sv4, color=cmap_w4[i], alpha=0.6, linewidth=0.8, label=f"W4 pair {i}")
605
+
606
+ axes[0].set_xlabel("Rank")
607
+ axes[0].set_ylabel("Singular value (log)")
608
+ axes[0].set_title("Gate weight spectra")
609
+ axes[0].legend(fontsize=6, ncol=4)
610
+
611
+ # Effective rank comparison
612
+ er_w3 = [w3_items[i][1]["effective_rank"] for i in range(n_pairs)]
613
+ er_w4 = [w4_items[i][1]["effective_rank"] for i in range(n_pairs)]
614
+ x = np.arange(n_pairs)
615
+ axes[1].bar(x - 0.15, er_w3, 0.3, color="steelblue", alpha=0.8, label="W3 (outer gate)")
616
+ axes[1].bar(x + 0.15, er_w4, 0.3, color="coral", alpha=0.8, label="W4 (inner gate)")
617
+ axes[1].set_xlabel("Mirror pair")
618
+ axes[1].set_ylabel("Effective rank")
619
+ axes[1].set_title("Gate effective rank by pair")
620
+ axes[1].set_xticks(x)
621
+ axes[1].legend()
622
+
623
+ plt.tight_layout()
624
+ fig.savefig(output_dir / "gate_spectra.png", dpi=150)
625
+ plt.close(fig)
626
+
627
+
628
+ def plot_embedding_alignment(results: dict, act_spectra: dict, output_dir: Path,
629
+ model_label: str = "model"):
630
+ """Compare embedding weight spectrum with final layer activation spectrum."""
631
+ embed_data = None
632
+ for name, data in results.items():
633
+ if "embed" in name.lower() and "proj" not in name.lower() and "g3" not in name.lower() and "g4" not in name.lower():
634
+ embed_data = data
635
+ break
636
+
637
+ final_act = act_spectra.get("final_norm") or act_spectra.get("compress_0")
638
+ if embed_data is None or final_act is None:
639
+ return
640
+
641
+ fig, axes = plt.subplots(1, 2, figsize=(12, 5))
642
+ fig.suptitle(f"{model_label} — Embedding vs final activation spectra", fontsize=13)
643
+
644
+ # Normalized comparison
645
+ sv_embed = embed_data["singular_values"]
646
+ ev_final = final_act["eigenvalues"]
647
+ sv_embed_norm = sv_embed / sv_embed.sum()
648
+ ev_final_norm = ev_final / ev_final.sum()
649
+
650
+ n_plot = min(len(sv_embed_norm), len(ev_final_norm), 200)
651
+ axes[0].semilogy(sv_embed_norm[:n_plot], color="steelblue", linewidth=1.5,
652
+ label=f"Embedding (erank={embed_data['effective_rank']:.0f})")
653
+ axes[0].semilogy(ev_final_norm[:n_plot], color="coral", linewidth=1.5,
654
+ label=f"Final act (erank={final_act['effective_rank']:.0f})")
655
+ axes[0].set_xlabel("Component")
656
+ axes[0].set_ylabel("Normalized value (log)")
657
+ axes[0].set_title("Spectral shape comparison")
658
+ axes[0].legend()
659
+
660
+ # Cumulative variance
661
+ axes[1].plot(np.cumsum(sv_embed_norm[:n_plot]), color="steelblue", linewidth=1.5, label="Embedding")
662
+ axes[1].plot(np.cumsum(ev_final_norm[:n_plot]), color="coral", linewidth=1.5, label="Final activation")
663
+ axes[1].set_xlabel("Component")
664
+ axes[1].set_ylabel("Cumulative fraction")
665
+ axes[1].set_title("Variance concentration")
666
+ axes[1].axhline(0.9, color="gray", linestyle="--", alpha=0.4)
667
+ axes[1].legend()
668
+
669
+ plt.tight_layout()
670
+ fig.savefig(output_dir / "embedding_alignment.png", dpi=150)
671
+ plt.close(fig)
672
+
673
+
674
+ def plot_comparison(results_a: dict, results_b: dict,
675
+ label_a: str, label_b: str,
676
+ output_dir: Path):
677
+ """Side-by-side comparison of two models' spectral properties."""
678
+ # Collect effective ranks for FFN W1 / up_proj
679
+ def extract_ffn_ranks(results):
680
+ ranks = []
681
+ for name, data in sorted(results.items()):
682
+ if ("w1" in name or "up_proj" in name or "c_fc" in name
683
+ or "dense_h_to_4h" in name) and "embed" not in name:
684
+ ranks.append((name, data["effective_rank"], data["stable_rank"], data["alpha"]))
685
+ return ranks
686
+
687
+ ranks_a = extract_ffn_ranks(results_a)
688
+ ranks_b = extract_ffn_ranks(results_b)
689
+
690
+ if not ranks_a or not ranks_b:
691
+ return
692
+
693
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
694
+ fig.suptitle(f"Comparison: {label_a} vs {label_b}", fontsize=13)
695
+
696
+ n = min(len(ranks_a), len(ranks_b))
697
+ x = np.arange(n)
698
+
699
+ for ax_idx, (metric_idx, ylabel, title) in enumerate([
700
+ (1, "Effective rank", "Effective rank per layer"),
701
+ (2, "Stable rank", "Stable rank per layer"),
702
+ (3, "Alpha", "Power-law alpha per layer"),
703
+ ]):
704
+ vals_a = [ranks_a[i][metric_idx] for i in range(n)]
705
+ vals_b = [ranks_b[i][metric_idx] for i in range(n)]
706
+ axes[ax_idx].bar(x - 0.15, vals_a, 0.3, color="steelblue", alpha=0.8, label=label_a)
707
+ axes[ax_idx].bar(x + 0.15, vals_b, 0.3, color="coral", alpha=0.8, label=label_b)
708
+ axes[ax_idx].set_xlabel("Layer")
709
+ axes[ax_idx].set_ylabel(ylabel)
710
+ axes[ax_idx].set_title(title)
711
+ axes[ax_idx].legend(fontsize=8)
712
+
713
+ plt.tight_layout()
714
+ fig.savefig(output_dir / "comparison.png", dpi=150)
715
+ plt.close(fig)
716
+
717
+
718
+ # ---------------------------------------------------------------------------
719
+ # Summary report
720
+ # ---------------------------------------------------------------------------
721
+
722
+ def print_summary(results: dict, model_label: str, act_spectra: dict = None):
723
+ """Print a concise text summary of spectral analysis."""
724
+ print(f"\n{'='*70}")
725
+ print(f" Spectral Analysis: {model_label}")
726
+ print(f"{'='*70}")
727
+
728
+ # Group by component type
729
+ components = defaultdict(list)
730
+ for name, data in sorted(results.items()):
731
+ if "w1" in name or "up_proj" in name:
732
+ components["FFN W1 (up)"].append(data)
733
+ elif "w2" in name or "down_proj" in name:
734
+ components["FFN W2 (down)"].append(data)
735
+ elif "w3" in name:
736
+ components["FFN W3 (outer gate)"].append(data)
737
+ elif "w4" in name:
738
+ components["FFN W4 (inner gate)"].append(data)
739
+ elif "embed" in name.lower() and "proj" not in name and "g3" not in name and "g4" not in name:
740
+ components["Embedding"].append(data)
741
+
742
+ print(f"\n{'Component':<25} {'Shape':>12} {'eRank':>8} {'sRank':>8} {'Alpha':>8} {'Sig%':>8} {'Cond#':>10}")
743
+ print("-" * 85)
744
+ for comp_name, items in components.items():
745
+ for i, data in enumerate(items):
746
+ label = f"{comp_name}" if len(items) == 1 else f"{comp_name}[{i}]"
747
+ shape_str = f"{data['shape'][0]}x{data['shape'][1]}"
748
+ cond = f"{data['condition_number']:.0f}" if data['condition_number'] < 1e6 else "inf"
749
+ print(f"{label:<25} {shape_str:>12} {data['effective_rank']:>8.1f} "
750
+ f"{data['stable_rank']:>8.1f} {data['alpha']:>8.3f} "
751
+ f"{data['signal_ratio']*100:>7.1f}% {cond:>10}")
752
+
753
+ # Aggregate stats
754
+ all_alphas = [d["alpha"] for d in results.values() if d["alpha"] > 0]
755
+ all_eranks = [d["effective_rank"] for d in results.values()]
756
+ if all_alphas:
757
+ print(f"\n Mean alpha: {np.mean(all_alphas):.3f} (< 2.0 = heavy-tailed = well-structured)")
758
+ print(f" Mean effective rank: {np.mean(all_eranks):.1f}")
759
+
760
+ # Activation summary
761
+ if act_spectra:
762
+ print(f"\n Activation spectra:")
763
+ print(f" {'Layer':<25} {'eRank':>8} {'Top1%':>8} {'Top10%':>8}")
764
+ print(" " + "-" * 55)
765
+
766
+ order_keys = {"embedding": -1, "final_norm": 999}
767
+ def sort_key(name):
768
+ if name in order_keys:
769
+ return order_keys[name]
770
+ parts = name.split("_")
771
+ phase = parts[0]
772
+ idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
773
+ phase_offset = {"expand": 0, "middle": 100, "compress": 200, "layer": 0}
774
+ return phase_offset.get(phase, 300) + idx
775
+
776
+ for name in sorted(act_spectra.keys(), key=sort_key):
777
+ data = act_spectra[name]
778
+ print(f" {name:<25} {data['effective_rank']:>8.1f} "
779
+ f"{data['top1_variance_ratio']*100:>7.1f}% "
780
+ f"{data['top10_variance_ratio']*100:>7.1f}%")
781
+
782
+
783
+ def save_results_json(results: dict, act_spectra: dict, output_path: Path):
784
+ """Save numerical results (no numpy arrays) to JSON."""
785
+ out = {}
786
+ for name, data in results.items():
787
+ out[name] = {k: v for k, v in data.items() if k != "singular_values"}
788
+ out[name]["top_10_sv"] = data["singular_values"][:10].tolist()
789
+
790
+ if act_spectra:
791
+ out["_activations"] = {}
792
+ for name, data in act_spectra.items():
793
+ out["_activations"][name] = {k: v for k, v in data.items() if k != "eigenvalues"}
794
+ out["_activations"][name]["top_10_ev"] = data["eigenvalues"][:10].tolist()
795
+
796
+ with open(output_path, "w") as f:
797
+ json.dump(out, f, indent=2, default=str)
798
+
799
+
800
+ # ---------------------------------------------------------------------------
801
+ # Data loading (minimal — just enough tokens for activation analysis)
802
+ # ---------------------------------------------------------------------------
803
+
804
+ def load_sample_data(data_source: str, tokenizer_name: str, num_samples: int = 256,
805
+ context_length: int = 512, device: str = "cpu"):
806
+ """Load a small batch of tokenized data for activation analysis."""
807
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
808
+ from circuits.data import get_tokenizer
809
+
810
+ tokenizer = get_tokenizer(tokenizer_name)
811
+
812
+ if data_source.startswith("hf:"):
813
+ from datasets import load_dataset
814
+ parts = data_source[3:].split(":")
815
+ ds_name = parts[0]
816
+ ds_config = parts[1] if len(parts) > 1 else None
817
+ ds_split = parts[2] if len(parts) > 2 else "train"
818
+ dataset = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
819
+ texts = []
820
+ for item in dataset:
821
+ texts.append(item.get("text", ""))
822
+ if len(texts) >= num_samples:
823
+ break
824
+ else:
825
+ with open(data_source) as f:
826
+ texts = [line.strip() for line in f if line.strip()][:num_samples]
827
+
828
+ # Tokenize and create batches
829
+ all_ids = []
830
+ for text in texts:
831
+ ids = tokenizer.encode(text)
832
+ if len(ids) >= context_length:
833
+ all_ids.append(ids[:context_length])
834
+ elif len(ids) > 32:
835
+ all_ids.append(ids + [tokenizer.eos_token_id] * (context_length - len(ids)))
836
+
837
+ if not all_ids:
838
+ return None, tokenizer
839
+
840
+ input_ids = torch.tensor(all_ids[:num_samples], device=device)
841
+ return input_ids, tokenizer
842
+
843
+
844
+ # ---------------------------------------------------------------------------
845
+ # Main
846
+ # ---------------------------------------------------------------------------
847
+
848
+ def main():
849
+ parser = argparse.ArgumentParser(description="Spectral analysis of Prisma checkpoints")
850
+ parser.add_argument("--checkpoint", type=str, required=True, help="Path to Prisma/Circuit checkpoint")
851
+ parser.add_argument("--checkpoint-b", type=str, default=None, help="Second checkpoint for comparison")
852
+ parser.add_argument("--hf-model", type=str, default=None, help="HuggingFace model name for comparison")
853
+ parser.add_argument("--data", type=str, default=None,
854
+ help="Data source for activation analysis (hf:dataset:config:split or path)")
855
+ parser.add_argument("--num-samples", type=int, default=256, help="Number of samples for activation analysis")
856
+ parser.add_argument("--context-length", type=int, default=512, help="Context length for activation analysis")
857
+ parser.add_argument("--output-dir", type=str, default=None, help="Output directory (default: auto)")
858
+ parser.add_argument("--gpu", type=int, default=0, help="GPU index")
859
+ parser.add_argument("--no-activations", action="store_true", help="Skip activation analysis even if data provided")
860
+ args = parser.parse_args()
861
+
862
+ device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
863
+ print(f"Device: {device}")
864
+
865
+ # Output directory
866
+ if args.output_dir:
867
+ output_dir = Path(args.output_dir)
868
+ else:
869
+ ckpt_name = Path(args.checkpoint).parent.name
870
+ output_dir = Path("circuits/scripts/spectral_output") / ckpt_name
871
+ output_dir.mkdir(parents=True, exist_ok=True)
872
+ print(f"Output: {output_dir}")
873
+
874
+ # ── Load model A ──
875
+ print(f"\nLoading: {args.checkpoint}")
876
+ model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
877
+ label_a = Path(args.checkpoint).parent.name
878
+ print(f" Type: {model_type_a}")
879
+ n_params = sum(p.numel() for p in model_a.parameters())
880
+ print(f" Parameters: {n_params:,}")
881
+
882
+ # ── Weight spectra (A) ──
883
+ print("\nAnalyzing weight spectra...")
884
+ weight_results_a = analyze_weight_spectra(model_a, label_a)
885
+ print(f" Analyzed {len(weight_results_a)} weight matrices")
886
+
887
+ # ── Activation spectra (A) ──
888
+ act_spectra_a = None
889
+ if args.data and not args.no_activations:
890
+ tokenizer_name = torch.load(args.checkpoint, map_location="cpu",
891
+ weights_only=False).get("tokenizer_name", "gpt2")
892
+ print(f"\nLoading data for activation analysis ({args.num_samples} samples)...")
893
+ input_ids, tokenizer = load_sample_data(
894
+ args.data, tokenizer_name, args.num_samples, args.context_length, device
895
+ )
896
+ if input_ids is not None:
897
+ print(f" Data shape: {input_ids.shape}")
898
+
899
+ # Compute word positions if needed
900
+ word_positions = None
901
+ word_rope_dims = config_a.get("word_rope_dims", 0)
902
+ if word_rope_dims > 0:
903
+ from circuits.layers import build_word_start_table, compute_word_positions
904
+ word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
905
+ word_positions = compute_word_positions(input_ids, word_start_table)
906
+
907
+ print(" Collecting activations...")
908
+ if model_type_a == "mirrored":
909
+ raw_acts = collect_mirrored_activations(model_a, input_ids, word_positions)
910
+ else:
911
+ raw_acts = collect_activations(model_a, input_ids, word_positions, model_type_a)
912
+
913
+ print(f" Computing activation spectra ({len(raw_acts)} layers)...")
914
+ act_spectra_a = {}
915
+ for name, act in raw_acts.items():
916
+ spec = activation_spectrum(act)
917
+ if spec is not None:
918
+ act_spectra_a[name] = spec
919
+
920
+ # ── Model B (optional comparison) ──
921
+ weight_results_b = None
922
+ label_b = None
923
+ if args.checkpoint_b:
924
+ print(f"\nLoading comparison: {args.checkpoint_b}")
925
+ model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
926
+ label_b = Path(args.checkpoint_b).parent.name
927
+ weight_results_b = analyze_weight_spectra(model_b, label_b)
928
+ del model_b
929
+ elif args.hf_model:
930
+ print(f"\nLoading HF model: {args.hf_model}")
931
+ model_b = load_hf_model(args.hf_model, device)
932
+ label_b = args.hf_model
933
+ weight_results_b = analyze_weight_spectra(model_b, label_b)
934
+ del model_b
935
+
936
+ if device.startswith("cuda"):
937
+ torch.cuda.empty_cache()
938
+
939
+ # ── Plots ──
940
+ print("\nGenerating plots...")
941
+ plot_weight_spectra(weight_results_a, output_dir, label_a)
942
+ plot_effective_rank_progression(weight_results_a, output_dir, label_a)
943
+ plot_gate_spectra(weight_results_a, output_dir, label_a)
944
+
945
+ if act_spectra_a:
946
+ plot_activation_spectra(act_spectra_a, output_dir, label_a)
947
+ plot_mirror_comparison(act_spectra_a, output_dir, label_a)
948
+ plot_embedding_alignment(weight_results_a, act_spectra_a, output_dir, label_a)
949
+
950
+ if weight_results_b and label_b:
951
+ plot_comparison(weight_results_a, weight_results_b, label_a, label_b, output_dir)
952
+ # Also print summary for B
953
+ print_summary(weight_results_b, label_b)
954
+
955
+ # ── Summary ──
956
+ print_summary(weight_results_a, label_a, act_spectra_a)
957
+
958
+ # ── Save ──
959
+ save_results_json(weight_results_a, act_spectra_a, output_dir / "results.json")
960
+ if weight_results_b:
961
+ save_results_json(weight_results_b, None, output_dir / "results_b.json")
962
+
963
+ print(f"\nAll outputs saved to: {output_dir}")
964
+ print(f" Plots: {len(list(output_dir.glob('*.png')))} PNG files")
965
+ print(f" Data: results.json")
966
+
967
+
968
+ if __name__ == "__main__":
969
+ main()
scripts/spectral_to_csv.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert spectral analysis JSON results to CSV tables for analysis."""
2
+ import json
3
+ import csv
4
+ import sys
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+
9
+
10
+ def classify_layer(name, model_type):
11
+ """Classify a weight matrix by layer index, component type, and phase."""
12
+ if model_type == "prisma":
13
+ # mirror_blocks.N.component
14
+ m = re.match(r'mirror_blocks\.(\d+)\.', name)
15
+ if m:
16
+ layer_idx = int(m.group(1))
17
+ phase = "mirror"
18
+ if 'attn' in name:
19
+ comp = 'Q' if 'q_proj' in name else 'K' if 'k_proj' in name else 'V' if 'v_proj' in name else 'O' if 'o_proj' in name else 'attn'
20
+ elif 'ffn.w3' in name or 'gate_expand' in name:
21
+ comp = 'W3'
22
+ elif 'ffn.w4' in name or 'gate_compress' in name:
23
+ comp = 'W4'
24
+ elif 'ffn.w1' in name:
25
+ comp = 'W1'
26
+ elif 'w2' in name:
27
+ comp = 'W2'
28
+ else:
29
+ comp = 'other'
30
+ return layer_idx, comp, phase
31
+
32
+ m = re.match(r'middle_blocks\.(\d+)\.', name)
33
+ if m:
34
+ layer_idx = int(m.group(1))
35
+ phase = "middle"
36
+ if 'attn' in name:
37
+ comp = 'Q' if 'q_proj' in name else 'K' if 'k_proj' in name else 'V' if 'v_proj' in name else 'O' if 'o_proj' in name else 'attn'
38
+ elif 'gate' in name:
39
+ comp = 'W3'
40
+ elif 'ffn.w1' in name:
41
+ comp = 'W1'
42
+ elif 'ffn.w2' in name:
43
+ comp = 'W2'
44
+ else:
45
+ comp = 'other'
46
+ return layer_idx, comp, phase
47
+
48
+ m = re.match(r'(first|last)_block\.', name)
49
+ if m:
50
+ phase = m.group(1)
51
+ if 'attn' in name:
52
+ comp = 'Q' if 'q_proj' in name else 'K' if 'k_proj' in name else 'V' if 'v_proj' in name else 'O' if 'o_proj' in name else 'attn'
53
+ elif 'ffn.w3' in name or 'gate' in name:
54
+ comp = 'W3'
55
+ elif 'ffn.w4' in name:
56
+ comp = 'W4'
57
+ elif 'ffn.w1' in name:
58
+ comp = 'W1'
59
+ elif 'ffn.w2' in name:
60
+ comp = 'W2'
61
+ else:
62
+ comp = 'other'
63
+ return 0, comp, phase
64
+
65
+ if 'embed' in name:
66
+ return -1, 'embed', 'embed'
67
+ if 'head' in name or 'lm_head' in name:
68
+ return 99, 'head', 'head'
69
+ return -1, 'other', 'other'
70
+
71
+ else: # GPT-2 style
72
+ m = re.match(r'transformer\.h\.(\d+)\.', name)
73
+ if m:
74
+ layer_idx = int(m.group(1))
75
+ if 'c_attn' in name:
76
+ comp = 'QKV'
77
+ elif 'c_proj' in name and 'mlp' not in name:
78
+ comp = 'O'
79
+ elif 'c_fc' in name:
80
+ comp = 'W1'
81
+ elif 'mlp.c_proj' in name:
82
+ comp = 'W2'
83
+ else:
84
+ comp = 'other'
85
+ return layer_idx, comp, "layer"
86
+
87
+ if 'wte' in name:
88
+ return -1, 'embed', 'embed'
89
+ if 'wpe' in name:
90
+ return -1, 'pos_embed', 'embed'
91
+ return -1, 'other', 'other'
92
+
93
+
94
+ def json_to_csvs(json_path, output_dir, model_type="prisma"):
95
+ with open(json_path) as f:
96
+ data = json.load(f)
97
+
98
+ os.makedirs(output_dir, exist_ok=True)
99
+
100
+ # 1. Full weight matrix summary
101
+ rows = []
102
+ for name, info in data.items():
103
+ if 'activation' in name or name.startswith('_'):
104
+ continue
105
+ layer_idx, comp, phase = classify_layer(name, model_type)
106
+ rows.append({
107
+ 'name': name,
108
+ 'layer_idx': layer_idx,
109
+ 'component': comp,
110
+ 'phase': phase,
111
+ 'shape': 'x'.join(str(s) for s in info['shape']),
112
+ 'effective_rank': round(info['effective_rank'], 2),
113
+ 'stable_rank': round(info['stable_rank'], 3),
114
+ 'spectral_norm': round(info['spectral_norm'], 4),
115
+ 'frobenius_norm': round(info['frobenius_norm'], 4),
116
+ 'alpha': round(info['alpha'], 4),
117
+ 'alpha_r2': round(info['alpha_r2'], 4),
118
+ 'signal_ratio': round(info['signal_ratio'], 4),
119
+ 'condition_number': round(info['condition_number'], 2),
120
+ 'mp_bound': round(info['mp_bound'], 4),
121
+ 'n_above_mp': info['n_above_mp'],
122
+ 'n_total': info['n_total'],
123
+ 'sv_1': round(info['top_10_sv'][0], 4) if info['top_10_sv'] else 0,
124
+ 'sv_2': round(info['top_10_sv'][1], 4) if len(info['top_10_sv']) > 1 else 0,
125
+ 'sv_10': round(info['top_10_sv'][9], 4) if len(info['top_10_sv']) > 9 else 0,
126
+ 'sv1_sv2_ratio': round(info['top_10_sv'][0] / info['top_10_sv'][1], 4) if len(info['top_10_sv']) > 1 and info['top_10_sv'][1] > 0 else 0,
127
+ })
128
+
129
+ with open(os.path.join(output_dir, 'weights_full.csv'), 'w', newline='') as f:
130
+ w = csv.DictWriter(f, fieldnames=rows[0].keys())
131
+ w.writeheader()
132
+ w.writerows(sorted(rows, key=lambda r: (r['phase'], r['layer_idx'], r['component'])))
133
+
134
+ # 2. Layer-level FFN summary (W1 progression = the lens)
135
+ ffn_rows = [r for r in rows if r['component'] == 'W1']
136
+ with open(os.path.join(output_dir, 'ffn_w1_progression.csv'), 'w', newline='') as f:
137
+ w = csv.DictWriter(f, fieldnames=['layer_idx', 'phase', 'effective_rank', 'stable_rank', 'alpha', 'alpha_r2', 'signal_ratio', 'condition_number', 'sv1_sv2_ratio'])
138
+ w.writeheader()
139
+ for r in sorted(ffn_rows, key=lambda r: (r['phase'], r['layer_idx'])):
140
+ w.writerow({k: r[k] for k in w.fieldnames})
141
+
142
+ # 3. Gate comparison (W3 vs W4)
143
+ gate_rows = [r for r in rows if r['component'] in ('W3', 'W4') and r['phase'] == 'mirror']
144
+ with open(os.path.join(output_dir, 'gate_comparison.csv'), 'w', newline='') as f:
145
+ w = csv.DictWriter(f, fieldnames=['layer_idx', 'component', 'effective_rank', 'stable_rank', 'alpha', 'alpha_r2', 'signal_ratio', 'sv1_sv2_ratio'])
146
+ w.writeheader()
147
+ for r in sorted(gate_rows, key=lambda r: (r['layer_idx'], r['component'])):
148
+ w.writerow({k: r[k] for k in w.fieldnames})
149
+
150
+ # 4. Attention head comparison (Q, K, V, O per layer)
151
+ attn_rows = [r for r in rows if r['component'] in ('Q', 'K', 'V', 'O', 'QKV')]
152
+ with open(os.path.join(output_dir, 'attention_progression.csv'), 'w', newline='') as f:
153
+ w = csv.DictWriter(f, fieldnames=['layer_idx', 'phase', 'component', 'effective_rank', 'stable_rank', 'alpha', 'signal_ratio', 'condition_number'])
154
+ w.writeheader()
155
+ for r in sorted(attn_rows, key=lambda r: (r['phase'], r['layer_idx'], r['component'])):
156
+ w.writerow({k: r[k] for k in w.fieldnames})
157
+
158
+ # 5. Summary statistics
159
+ alphas = [r['alpha'] for r in rows if r['alpha'] > 0]
160
+ eff_ranks = [r['effective_rank'] for r in rows if r['layer_idx'] >= 0]
161
+ signal_ratios = [r['signal_ratio'] for r in rows if r['layer_idx'] >= 0]
162
+
163
+ summary = {
164
+ 'n_matrices': len(rows),
165
+ 'mean_alpha': round(sum(alphas) / len(alphas), 4) if alphas else 0,
166
+ 'min_alpha': round(min(alphas), 4) if alphas else 0,
167
+ 'max_alpha': round(max(alphas), 4) if alphas else 0,
168
+ 'mean_effective_rank': round(sum(eff_ranks) / len(eff_ranks), 2) if eff_ranks else 0,
169
+ 'mean_signal_ratio': round(sum(signal_ratios) / len(signal_ratios), 4) if signal_ratios else 0,
170
+ 'n_well_trained (alpha<2)': sum(1 for a in alphas if a < 2.0),
171
+ 'n_total_alpha': len(alphas),
172
+ }
173
+ with open(os.path.join(output_dir, 'summary.csv'), 'w', newline='') as f:
174
+ w = csv.DictWriter(f, fieldnames=summary.keys())
175
+ w.writeheader()
176
+ w.writerow(summary)
177
+
178
+ print(f"Wrote CSVs to {output_dir}/")
179
+ print(f" weights_full.csv ({len(rows)} matrices)")
180
+ print(f" ffn_w1_progression.csv ({len(ffn_rows)} layers)")
181
+ print(f" gate_comparison.csv ({len(gate_rows)} entries)")
182
+ print(f" attention_progression.csv ({len(attn_rows)} entries)")
183
+ print(f" summary.csv")
184
+
185
+
186
+ if __name__ == '__main__':
187
+ base = "circuits/scripts/spectral_output/mirrored_300M_mk4_cont"
188
+
189
+ # Prisma
190
+ json_to_csvs(
191
+ f"{base}/results.json",
192
+ f"{base}/csv_prisma",
193
+ model_type="prisma"
194
+ )
195
+
196
+ # GPT-2 medium
197
+ if os.path.exists(f"{base}/results_b.json"):
198
+ json_to_csvs(
199
+ f"{base}/results_b.json",
200
+ f"{base}/csv_gpt2",
201
+ model_type="gpt2"
202
+ )
train.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Training script for Circuit Transformer.
4
+
5
+ Usage:
6
+ python circuits/train.py --data hf:roneneldan/TinyStories --preset tiny --epochs 1 --gpu 0
7
+ python circuits/train.py --data path/to/corpus.txt --dims 256 --layers 6 --fp16
8
+ """
9
+
10
+ import gc
11
+ import os
12
+ import time
13
+ import math
14
+ import random
15
+ from pathlib import Path
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.cuda.amp import GradScaler
21
+ from torch.amp import autocast
22
+
23
+ from .config import CircuitConfig, parse_args
24
+ from .model import CircuitTransformer, count_parameters
25
+ from .mirrored import MirroredConfig, MirroredTransformer, count_mirrored_parameters
26
+ from .graft_g2lu import G2LU_GraftedModel, save_g2lu_checkpoint
27
+ from .layers import build_word_start_table, compute_word_positions
28
+ from .data import get_tokenizer, load_data, create_dataloader
29
+
30
+ def corrupt_tokens(input_ids, ratio, vocab_size):
31
+ """Replace random tokens with random vocab tokens for denoising autoencoder.
32
+
33
+ Returns (corrupted_ids, mask) where mask is True at corrupted positions.
34
+ """
35
+ mask = torch.rand(input_ids.shape, device=input_ids.device) < ratio
36
+ mask[:, 0] = False # never corrupt first token (BOS/start)
37
+ random_tokens = torch.randint(0, vocab_size, input_ids.shape, device=input_ids.device)
38
+ corrupted = input_ids.clone()
39
+ corrupted[mask] = random_tokens[mask]
40
+ return corrupted, mask
41
+
42
+
43
+ @torch.no_grad()
44
+ def evaluate(config, model, dataloader, device, use_amp=False, amp_dtype=torch.float16, mid_run_eval=False,
45
+ word_start_table=None):
46
+ """Run validation and return avg loss + perplexity."""
47
+ model.eval()
48
+ total_loss = 0.0
49
+ n_batches = 0
50
+
51
+ for batch in dataloader:
52
+ input_ids = batch["input_ids"].to(device)
53
+ labels = batch["labels"].to(device)
54
+ word_positions = None
55
+ if word_start_table is not None:
56
+ word_positions = compute_word_positions(input_ids, word_start_table)
57
+
58
+ if use_amp:
59
+ with autocast('cuda', dtype=amp_dtype):
60
+ output = model(input_ids, labels=labels, word_positions=word_positions)
61
+ else:
62
+ output = model(input_ids, labels=labels, word_positions=word_positions)
63
+
64
+ total_loss += output["loss"].item()
65
+ n_batches += 1
66
+
67
+ if n_batches % (config.log_every * 10) == 0:
68
+ avg_loss = total_loss / max(n_batches, 1)
69
+ ppl = math.exp(min(avg_loss, 20))
70
+ print(
71
+ f"batch {n_batches:6d}/{len(dataloader):6d} | "
72
+ f"Loss {total_loss / n_batches:.4f} | "
73
+ f"PPL {ppl:8.2f}"
74
+ )
75
+
76
+ if mid_run_eval and n_batches >= 1500 :
77
+ break
78
+
79
+ if not mid_run_eval:
80
+ model.train()
81
+
82
+ avg_loss = total_loss / max(n_batches, 1)
83
+ ppl = math.exp(min(avg_loss, 20)) # cap to avoid overflow
84
+
85
+ return avg_loss, ppl
86
+
87
+
88
+ def get_lr(step: int, warmup_steps: int, max_steps: int, max_lr: float, min_lr: float = 0.0, delay: int = 0) -> float:
89
+ """Cosine learning rate schedule with warmup and optional delay.
90
+
91
+ With delay > 0, the schedule is shifted:
92
+ Steps 0..delay: LR = 0 (frozen)
93
+ Steps delay..delay+warmup: linear ramp 0 → max_lr
94
+ Steps delay+warmup..max_steps: cosine decay max_lr → min_lr
95
+ """
96
+ if step < delay:
97
+ return 0.0
98
+ effective_step = step - delay
99
+ effective_max = max(1, max_steps - delay)
100
+ if effective_step < warmup_steps:
101
+ return max_lr * effective_step / warmup_steps
102
+ if effective_step >= effective_max:
103
+ return min_lr
104
+ progress = (effective_step - warmup_steps) / (effective_max - warmup_steps)
105
+ return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
106
+
107
+
108
+ def save_checkpoint(
109
+ model: nn.Module,
110
+ optimizer: torch.optim.Optimizer,
111
+ step: int,
112
+ epoch: int,
113
+ loss: float,
114
+ config,
115
+ path: str,
116
+ model_type: str = "standard",
117
+ epoch_step: int = 0,
118
+ best_val_loss: float | None = None,
119
+ scaler=None,
120
+ tokenizer_name: str = "gpt2",
121
+ ):
122
+ """Save training checkpoint.
123
+
124
+ Args:
125
+ epoch: Next epoch to start on resume (completed epoch count).
126
+ epoch_step: Batches already processed in `epoch` (0 if epoch is complete).
127
+ optimizer_mid: Middle optimizer for dual-path training (optional).
128
+ """
129
+ checkpoint = {
130
+ "model": model.state_dict(),
131
+ "optimizer": optimizer.state_dict(),
132
+ "step": step,
133
+ "epoch": epoch,
134
+ "epoch_step": epoch_step,
135
+ "loss": loss,
136
+ "config": config.to_dict(),
137
+ "model_type": model_type,
138
+ "tokenizer_name": tokenizer_name,
139
+ }
140
+ if best_val_loss is not None:
141
+ checkpoint["best_val_loss"] = best_val_loss
142
+ if scaler is not None:
143
+ checkpoint["scaler"] = scaler.state_dict()
144
+ torch.save(checkpoint, path)
145
+
146
+
147
+ def _migrate_state_dict(state_dict: dict, model: nn.Module) -> dict:
148
+ """Migrate checkpoint state_dict to match current model architecture.
149
+
150
+ Handles upgrades like SwiGLU → MirroredSwiGLU (dual_gate_middle).
151
+ """
152
+ model_keys = set(model.state_dict().keys())
153
+ ckpt_keys = set(state_dict.keys())
154
+
155
+ missing = model_keys - ckpt_keys
156
+ unexpected = ckpt_keys - model_keys
157
+
158
+ if not missing and not unexpected:
159
+ return state_dict # perfect match, no migration needed
160
+
161
+ migrated = dict(state_dict)
162
+ migrations = []
163
+
164
+ # SwiGLU → MirroredSwiGLU: w3 → gate_expand (dual_gate_middle upgrade)
165
+ for key in list(unexpected):
166
+ if ".ffn.gate_expand.weight" in key:
167
+ new_key = key.replace(".ffn.gate_expand.weight", ".ffn.w3.weight")
168
+ if new_key in missing:
169
+ migrated[new_key] = migrated.pop(key)
170
+ missing.discard(new_key)
171
+ unexpected.discard(key)
172
+ migrations.append(f" {key} → {new_key}")
173
+ if ".ffn.gate_compress.weight" in key:
174
+ new_key = key.replace(".ffn.gate_compress.weight", ".ffn.w4.weight")
175
+ if new_key in missing:
176
+ migrated[new_key] = migrated.pop(key)
177
+ missing.discard(new_key)
178
+ unexpected.discard(key)
179
+ migrations.append(f" {key} → {new_key}")
180
+
181
+ if migrations:
182
+ print(f"State dict migration ({len(migrations)} keys renamed):")
183
+ for m in migrations:
184
+ print(m)
185
+ # Report remaining missing keys (freshly initialized)
186
+ still_missing = model_keys - set(migrated.keys())
187
+ if still_missing:
188
+ print(f" New parameters (freshly initialized): {len(still_missing)}")
189
+ for k in sorted(still_missing):
190
+ print(f" {k}")
191
+
192
+ return migrated
193
+
194
+
195
+ def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer = None,
196
+ scaler=None, reset:bool = False):
197
+ """Load training checkpoint. Returns dict with resume info."""
198
+ checkpoint = torch.load(path, map_location="cpu", weights_only=False)
199
+ state_dict = _migrate_state_dict(checkpoint["model"], model)
200
+ model.load_state_dict(state_dict, strict=False)
201
+ if not reset:
202
+ if optimizer is not None and "optimizer" in checkpoint:
203
+ optimizer.load_state_dict(checkpoint["optimizer"])
204
+ if scaler is not None and "scaler" in checkpoint:
205
+ scaler.load_state_dict(checkpoint["scaler"])
206
+ return {
207
+ "step": checkpoint.get("step", 0),
208
+ "epoch": checkpoint.get("epoch", 0),
209
+ "epoch_step": checkpoint.get("epoch_step", 0),
210
+ "best_val_loss": checkpoint.get("best_val_loss", float("inf")),
211
+ }
212
+
213
+
214
+ def train():
215
+ config, args = parse_args()
216
+
217
+ # Setup device
218
+ device = torch.device(f"cuda:{config.gpu}" if torch.cuda.is_available() else "cpu")
219
+ print(f"Device: {device}")
220
+
221
+ # Load tokenizer and data
222
+ print(f"Loading data from: {args.data}")
223
+ model_type = args.arch
224
+ tokenizer_name = getattr(args, 'tokenizer', 'gpt2')
225
+ if model_type == "graft_g2lu":
226
+ tokenizer_name = args.pretrained
227
+ tokenizer = get_tokenizer(tokenizer_name)
228
+ config.vocab_size = len(tokenizer)
229
+ print(f"Tokenizer: {tokenizer_name} (vocab_size={config.vocab_size})")
230
+ cache_dir = None if args.no_cache else args.cache_dir
231
+ dataset = load_data(
232
+ args.data,
233
+ tokenizer,
234
+ config.max_seq_len,
235
+ text_column=args.text_column,
236
+ num_samples=args.num_samples,
237
+ cache_dir=cache_dir,
238
+ data_format=args.data_format,
239
+ )
240
+ print(f"Loaded {len(dataset):,} chunks")
241
+
242
+ # Train/val split
243
+ val_split = args.val_split
244
+ if val_split > 0 and len(dataset) > 20:
245
+ train_dataset, val_dataset = dataset.split(val_split)
246
+ print(f"Split: {len(train_dataset):,} train / {len(val_dataset):,} val ({val_split:.0%})")
247
+ else:
248
+ train_dataset = dataset
249
+ val_dataset = None
250
+
251
+ # Create dataloaders
252
+ dataloader = create_dataloader(
253
+ train_dataset,
254
+ config.batch_size,
255
+ shuffle=True,
256
+ )
257
+ val_dataloader = None
258
+ if val_dataset is not None:
259
+ val_dataloader = create_dataloader(
260
+ val_dataset,
261
+ config.batch_size,
262
+ shuffle=False,
263
+ )
264
+
265
+ # Create model
266
+ if model_type == "mirrored":
267
+ model_config = MirroredConfig(
268
+ vocab_size=config.vocab_size,
269
+ hidden_size=config.hidden_size,
270
+ num_heads=config.num_heads,
271
+ num_kv_heads=config.num_kv_heads,
272
+ num_layers=config.num_layers,
273
+ n_middle=args.n_middle,
274
+ max_seq_len=config.max_seq_len,
275
+ dropout=config.dropout,
276
+ use_g2lu=not getattr(args, 'no_g2lu', False),
277
+ aux_skip_k=getattr(args, 'aux_skip', 0),
278
+ aux_skip_weight=getattr(args, 'aux_weight', 0.1),
279
+ word_rope_dims=getattr(config, 'word_rope_dims', 0),
280
+ word_rope_base=getattr(config, 'word_rope_base', 10.0),
281
+ embed_dim=getattr(config, 'embed_dim', 0),
282
+ head_dim=getattr(config, 'head_dim', 0),
283
+ )
284
+ model = MirroredTransformer(model_config).to(device)
285
+ param_info = count_mirrored_parameters(model)
286
+ num_params = param_info["unique"]
287
+ print(f"Model: MirroredTransformer")
288
+ print(f" Virtual layers: {model.total_virtual_layers} ({model_config.n_mirror} mirror pairs + {model_config.n_middle} middle)")
289
+ print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M unique)")
290
+ print(f" Shared FFN base: {param_info['shared_ffn_base']:,}")
291
+ print(f" Direction gates: {param_info['direction_gates']:,}")
292
+ print(f" FFN gating: {'G²LU (nested dual gate)' if model_config.use_g2lu else 'SwiGLU (vanilla)'}")
293
+ if model_config.num_kv_heads is not None:
294
+ print(f" GQA: {model_config.num_heads}Q / {model_config.num_kv_heads}KV ({model_config.num_heads // model_config.num_kv_heads}:1 ratio)")
295
+ if model_config.aux_skip_k > 0:
296
+ print(f" Aux skip prediction: t+{model_config.aux_skip_k} (weight={model_config.aux_skip_weight})")
297
+ if getattr(model_config, 'embed_dim', 0) > 0:
298
+ std_embed = config.vocab_size * config.hidden_size
299
+ fact_embed = config.vocab_size * model_config.embed_dim + model_config.embed_dim * config.hidden_size
300
+ print(f" Factorized embedding: {model_config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
301
+ if getattr(model_config, 'head_dim', 0) > 0:
302
+ std_head = config.hidden_size * config.vocab_size
303
+ mlp_head = config.hidden_size * model_config.head_dim + model_config.head_dim * config.vocab_size
304
+ print(f" MLP head: {config.hidden_size} → {model_config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")
305
+ elif model_type == "graft_g2lu":
306
+ assert args.pretrained, "--pretrained is required for graft_g2lu architecture"
307
+ amp_dtype = torch.bfloat16 if config.bf16 else (torch.float16 if config.fp16 else torch.float32)
308
+ model = G2LU_GraftedModel(
309
+ pretrained_name=args.pretrained,
310
+ align_weight=args.align_weight,
311
+ warmup_steps=args.graft_warmup,
312
+ device=device,
313
+ dtype=amp_dtype,
314
+ )
315
+ model_config = None # No CircuitConfig for HF models
316
+ num_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
317
+ else:
318
+ model_config = config
319
+ model = CircuitTransformer(config).to(device)
320
+ num_params = count_parameters(model)
321
+ print(f"Model: CircuitTransformer")
322
+ print(f" Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
323
+ if getattr(config, 'aux_skip_k', 0) > 0:
324
+ print(f" Aux skip prediction: t+{config.aux_skip_k} (weight={config.aux_skip_weight})")
325
+ if getattr(config, 'embed_dim', 0) > 0:
326
+ std_embed = config.vocab_size * config.hidden_size
327
+ fact_embed = config.vocab_size * config.embed_dim + config.embed_dim * config.hidden_size
328
+ print(f" Factorized embedding: {config.embed_dim} → {config.hidden_size} (saves {(std_embed - fact_embed):,} params)")
329
+ if getattr(config, 'head_dim', 0) > 0:
330
+ std_head = config.hidden_size * config.vocab_size
331
+ mlp_head = config.hidden_size * config.head_dim + config.head_dim * config.vocab_size
332
+ print(f" MLP head: {config.hidden_size} → {config.head_dim} → vocab (saves {(std_head - mlp_head):,} params)")
333
+
334
+ # Build word-position table if enabled
335
+ word_rope_dims = getattr(config, 'word_rope_dims', 0)
336
+ if word_rope_dims > 0:
337
+ word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device)
338
+ print(f" Word-position RoPE: {word_rope_dims} dims, base={getattr(config, 'word_rope_base', 10.0)}")
339
+ print(f" Word starters in vocab: {word_start_table.sum().item():,} / {len(tokenizer):,}")
340
+ else:
341
+ word_start_table = None
342
+
343
+ # Keep raw reference for set_gate_step (torch.compile wraps the model)
344
+ raw_model = model
345
+
346
+ # Optionally compile
347
+ if config.compile and hasattr(torch, "compile"):
348
+ print("Compiling model with torch.compile...")
349
+ model = torch.compile(raw_model)
350
+
351
+ # Optimizer — with optional staggered warmup and dual-path training
352
+ grad_accum = getattr(args, 'grad_accum', 1)
353
+
354
+ opt_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
355
+ optimizer = torch.optim.AdamW(
356
+ opt_params,
357
+ lr=config.learning_rate,
358
+ weight_decay=config.weight_decay,
359
+ betas=(0.9, 0.95),
360
+ )
361
+
362
+ # Mixed precision
363
+ use_amp = (config.fp16 or config.bf16) and device.type == "cuda"
364
+ amp_dtype = torch.bfloat16 if config.bf16 else torch.float16
365
+ scaler = GradScaler() if (config.fp16 and use_amp) else None
366
+ if use_amp:
367
+ print(f" Mixed precision: {'BF16' if config.bf16 else 'FP16'}" +
368
+ (" (no scaler)" if scaler is None else " (with GradScaler)"))
369
+
370
+ # Resume from checkpoint
371
+ start_step = 0
372
+ start_epoch = 0
373
+ skip_batches = 0
374
+ best_val_loss = float("inf")
375
+ if args.resume:
376
+ print(f"Resuming from: {args.resume}")
377
+ resume_info = load_checkpoint(args.resume, model, optimizer, scaler, args.reset)
378
+ if not args.reset:
379
+ start_step = resume_info["step"]
380
+ start_epoch = resume_info["epoch"]
381
+ skip_batches = resume_info["epoch_step"]
382
+ best_val_loss = resume_info["best_val_loss"]
383
+ print(f"Resumed at step {start_step}, epoch {start_epoch}" +
384
+ (f", skipping {skip_batches} batches" if skip_batches > 0 else ""))
385
+ if best_val_loss < float("inf"):
386
+ print(f" Best val loss so far: {best_val_loss:.4f} (PPL {math.exp(min(best_val_loss, 20)):.2f})")
387
+
388
+ # Setup checkpoint directory
389
+ checkpoint_dir = Path(config.checkpoint_dir)
390
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
391
+
392
+ # Training loop
393
+ steps_per_epoch = math.ceil(len(dataloader) / grad_accum)
394
+ max_steps = config.epochs * steps_per_epoch
395
+ tokens_per_step = config.batch_size * grad_accum * config.max_seq_len
396
+ total_train_tokens = config.epochs * len(dataloader) * config.batch_size * config.max_seq_len
397
+ step = start_step
398
+ model.train()
399
+
400
+ print(f"\nStarting training:")
401
+ print(f" Epochs: {config.epochs}")
402
+ print(f" Batch size: {config.batch_size}" +
403
+ (f" x {grad_accum} accum = {config.batch_size * grad_accum} effective" if grad_accum > 1 else ""))
404
+ print(f" Steps per epoch: {steps_per_epoch}" +
405
+ (f" ({len(dataloader)} micro-batches)" if grad_accum > 1 else ""))
406
+ print(f" Total steps: {max_steps}")
407
+ print(f" Total tokens: {total_train_tokens:,} ({total_train_tokens/1e6:.1f}M)")
408
+ if num_params > 0:
409
+ print(f" Tokens/param ratio: {total_train_tokens/num_params:.1f}x (Chinchilla=20x)")
410
+ print(f" Learning rate: {config.learning_rate}" +
411
+ (f" → {config.min_lr}" if config.min_lr > 0 else ""))
412
+ print(f" Mixed precision: {use_amp}")
413
+ print(f" Validation: {'enabled' if val_dataloader else 'disabled'}")
414
+ print()
415
+
416
+ total_loss = 0.0
417
+ log_steps = 0
418
+ total_tokens_seen = step * tokens_per_step
419
+ # best_val_loss already set in resume section above
420
+ h_mid_buffer = None
421
+ last_align_val = float("inf")
422
+ start_time = time.time()
423
+
424
+ for epoch in range(start_epoch, config.epochs):
425
+ epoch_start = time.time()
426
+ epoch_loss = 0.0
427
+ epoch_steps = 0
428
+
429
+ micro_batches = []
430
+ epoch_micro_batches = skip_batches if epoch == start_epoch else 0
431
+
432
+ for batch_idx, batch in enumerate(dataloader):
433
+ # Skip already-processed batches on resume
434
+ if epoch == start_epoch and batch_idx < skip_batches:
435
+ continue
436
+
437
+ micro_batches.append(batch)
438
+ epoch_micro_batches += 1
439
+
440
+ # Accumulate micro-batches (flush at accum boundary or epoch end)
441
+ if len(micro_batches) < grad_accum and batch_idx < len(dataloader) - 1:
442
+ continue
443
+
444
+ n_micro = len(micro_batches)
445
+ actual_tokens = n_micro * config.batch_size * config.max_seq_len
446
+
447
+ # Update learning rate (per-group delays for staggered warmup)
448
+ for param_group in optimizer.param_groups:
449
+ delay = param_group.get("delay", 0)
450
+ param_group["lr"] = get_lr(step, config.warmup_steps, max_steps, config.learning_rate, min_lr=config.min_lr, delay=delay)
451
+ lr = optimizer.param_groups[0]["lr"] # for logging
452
+
453
+ loss_ed_val = None
454
+ loss_align_val = None
455
+ grad_norm_mid = None
456
+ absorb_loss_val = None
457
+
458
+ # Update blend alpha for G²LU grafting
459
+ if model_type == "graft_g2lu":
460
+ raw_model.set_step(step)
461
+
462
+ # === Standard single-path training with accumulation ===
463
+ optimizer.zero_grad()
464
+ accum_loss = 0.0
465
+ accum_aux = 0.0
466
+ accum_align = 0.0
467
+
468
+ for mb in micro_batches:
469
+ mb_ids = mb["input_ids"].to(device)
470
+ mb_labels = mb["labels"].to(device)
471
+ word_positions = None
472
+ if word_start_table is not None:
473
+ word_positions = compute_word_positions(mb_ids, word_start_table)
474
+ if use_amp:
475
+ with autocast('cuda', dtype=amp_dtype):
476
+ output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
477
+ else:
478
+ output = model(mb_ids, labels=mb_labels, word_positions=word_positions)
479
+ if scaler:
480
+ scaler.scale(output["loss"] / n_micro).backward()
481
+ else:
482
+ (output["loss"] / n_micro).backward()
483
+ accum_loss += output["loss"].item()
484
+ if "aux_loss" in output:
485
+ accum_aux += output["aux_loss"].item()
486
+ if "align_loss" in output:
487
+ accum_align += output["align_loss"].item()
488
+
489
+ if scaler:
490
+ scaler.unscale_(optimizer)
491
+ clip_params = list(raw_model.trainable_parameters()) if model_type == "graft_g2lu" else model.parameters()
492
+ grad_norm = nn.utils.clip_grad_norm_(clip_params, config.grad_clip).item()
493
+ if scaler:
494
+ scaler.step(optimizer)
495
+ scaler.update()
496
+ else:
497
+ optimizer.step()
498
+ optimizer.zero_grad()
499
+
500
+ loss_val = accum_loss / n_micro
501
+ aux_loss_val = accum_aux / n_micro if accum_aux > 0 else None
502
+ align_loss_val = accum_align / n_micro if accum_align > 0 else None
503
+
504
+ total_loss += loss_val
505
+ epoch_loss += loss_val
506
+ epoch_steps += 1
507
+ log_steps += 1
508
+ total_tokens_seen += actual_tokens
509
+ step += 1
510
+
511
+ # Logging
512
+ if step % config.log_every == 0:
513
+ avg_loss = total_loss / max(log_steps, 1)
514
+ ppl = math.exp(min(avg_loss, 20))
515
+ elapsed = time.time() - start_time
516
+ tok_s = (log_steps * tokens_per_step) / max(elapsed, 1e-6)
517
+
518
+ extra = ""
519
+ if aux_loss_val is not None:
520
+ extra += f" | Aux {aux_loss_val:.3f}"
521
+ if align_loss_val is not None:
522
+ extra += f" | Align {align_loss_val:.4f}"
523
+
524
+ print(
525
+ f"Step {step:6d} | "
526
+ f"Epoch {epoch+1}/{config.epochs} | "
527
+ f"Loss {avg_loss:.4f} | "
528
+ f"PPL {ppl:8.2f} | "
529
+ f"GradN {grad_norm:.3f} | "
530
+ f"LR {lr:.2e} | "
531
+ f"Tok/s {tok_s:.0f}"
532
+ f"{extra}"
533
+ )
534
+
535
+ total_loss = 0.0
536
+ log_steps = 0
537
+ start_time = time.time()
538
+
539
+ # Checkpointing
540
+ if step % config.save_every == 0:
541
+ ckpt_path = checkpoint_dir / f"step_{step:06d}.pt"
542
+ if model_type == "graft_g2lu":
543
+ save_g2lu_checkpoint(raw_model, optimizer, step, epoch, loss_val, str(ckpt_path),
544
+ epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
545
+ else:
546
+ save_checkpoint(model, optimizer, step, epoch, loss_val, model_config, str(ckpt_path), model_type,
547
+ epoch_step=epoch_micro_batches, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
548
+ print(f" Saved checkpoint: {ckpt_path}")
549
+ gc.collect()
550
+ torch.cuda.empty_cache()
551
+
552
+ # Mid-training validation
553
+ val_every = getattr(args, 'val_every', 0)
554
+ if val_every > 0 and step % val_every == 0 and val_dataloader:
555
+ val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, mid_run_eval=True, word_start_table=word_start_table)
556
+ avg_train = epoch_loss / max(epoch_steps, 1)
557
+ gap = val_loss - avg_train
558
+ print(f" [Val @ step {step}] Loss: {val_loss:.4f} | PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
559
+ if val_loss < best_val_loss:
560
+ best_val_loss = val_loss
561
+ best_path = checkpoint_dir / "best.pt"
562
+ if model_type == "graft_g2lu":
563
+ save_g2lu_checkpoint(raw_model, optimizer, step, epoch, val_loss, str(best_path),
564
+ epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
565
+ else:
566
+ save_checkpoint(model, optimizer, step, epoch, val_loss, model_config, str(best_path), model_type,
567
+ epoch_step=epoch_micro_batches, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
568
+ print(f" New best! Saved: {best_path}")
569
+ gc.collect()
570
+ torch.cuda.empty_cache()
571
+
572
+ micro_batches = []
573
+
574
+ # --- Epoch summary ---
575
+ epoch_elapsed = time.time() - epoch_start
576
+ avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
577
+ epoch_ppl = math.exp(min(avg_epoch_loss, 20))
578
+
579
+ print(f"\n{'='*70}")
580
+ print(f"Epoch {epoch+1}/{config.epochs} complete in {epoch_elapsed:.0f}s")
581
+ print(f" Train loss: {avg_epoch_loss:.4f} | Train PPL: {epoch_ppl:.2f}")
582
+ print(f" Tokens seen: {total_tokens_seen:,} ({total_tokens_seen/1e6:.1f}M)")
583
+
584
+ # Validation
585
+ if val_dataloader:
586
+ val_loss, val_ppl = evaluate(config, model, val_dataloader, device, use_amp, amp_dtype, word_start_table=word_start_table)
587
+ gap = val_loss - avg_epoch_loss
588
+ print(f" Val loss: {val_loss:.4f} | Val PPL: {val_ppl:.2f} | Gap: {gap:+.4f}")
589
+
590
+ if val_loss < best_val_loss:
591
+ best_val_loss = val_loss
592
+ best_path = checkpoint_dir / "best.pt"
593
+ if model_type == "graft_g2lu":
594
+ save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, val_loss, str(best_path),
595
+ epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
596
+ else:
597
+ save_checkpoint(model, optimizer, step, epoch + 1, val_loss, model_config, str(best_path), model_type,
598
+ epoch_step=0, best_val_loss=val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
599
+ print(f" New best! Saved: {best_path}")
600
+ # Free validation tensors
601
+ gc.collect()
602
+ torch.cuda.empty_cache()
603
+ print(f"{'='*70}\n")
604
+
605
+ # Save epoch checkpoint
606
+ ckpt_path = checkpoint_dir / f"epoch_{epoch+1:02d}.pt"
607
+ if model_type == "graft_g2lu":
608
+ save_g2lu_checkpoint(raw_model, optimizer, step, epoch + 1, avg_epoch_loss, str(ckpt_path),
609
+ epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
610
+ else:
611
+ save_checkpoint(model, optimizer, step, epoch + 1, avg_epoch_loss, model_config, str(ckpt_path), model_type,
612
+ epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
613
+ gc.collect()
614
+ torch.cuda.empty_cache()
615
+
616
+ # Save final checkpoint
617
+ if step == start_step:
618
+ print(f"\nNo training performed (already at step {step}/{max_steps}).")
619
+ print(f" To train more epochs, increase --epochs beyond {config.epochs}.")
620
+ else:
621
+ final_path = checkpoint_dir / "latest.pt"
622
+ if model_type == "graft_g2lu":
623
+ save_g2lu_checkpoint(raw_model, optimizer, step, config.epochs, avg_epoch_loss, str(final_path),
624
+ epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
625
+ else:
626
+ save_checkpoint(model, optimizer, step, config.epochs, avg_epoch_loss, model_config, str(final_path), model_type,
627
+ epoch_step=0, best_val_loss=best_val_loss, scaler=scaler, tokenizer_name=tokenizer_name)
628
+ print(f"\nTraining complete.")
629
+ print(f" Final train loss: {avg_epoch_loss:.4f} | PPL: {epoch_ppl:.2f}")
630
+ if val_dataloader:
631
+ print(f" Best val loss: {best_val_loss:.4f} | PPL: {math.exp(min(best_val_loss, 20)):.2f}")
632
+ print(f" Total tokens: {total_tokens_seen:,}")
633
+ print(f" Checkpoints: {final_path}")
634
+
635
+
636
+ if __name__ == "__main__":
637
+ train()