prism-lab committed on
Commit
d23dcd6
·
verified ·
1 Parent(s): c3d749a

Delete WPT_Wikitext_103_Training.ipynb

Browse files
Files changed (1) hide show
  1. WPT_Wikitext_103_Training.ipynb +0 -1061
WPT_Wikitext_103_Training.ipynb DELETED
@@ -1,1061 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "A100"
8
- },
9
- "kernelspec": {
10
- "name": "python3",
11
- "display_name": "Python 3"
12
- },
13
- "language_info": {
14
- "name": "python"
15
- },
16
- "accelerator": "GPU"
17
- },
18
- "cells": [
19
- {
20
- "cell_type": "code",
21
- "source": [
22
- "!pip install -q x-transformers\n",
23
- "!pip install -q flash-attn --no-build-isolation"
24
- ],
25
- "metadata": {
26
- "id": "6q9RTvlf5IiS"
27
- },
28
- "execution_count": null,
29
- "outputs": []
30
- },
31
- {
32
- "cell_type": "code",
33
- "source": [
34
- "import torch\n",
35
- "import torch.nn as nn\n",
36
- "import torch.nn.functional as F\n",
37
- "import torch.optim as optim\n",
38
- "import math\n",
39
- "import os\n",
40
- "import sys\n",
41
- "import subprocess\n",
42
- "import hashlib\n",
43
- "import gc\n",
44
- "from datetime import datetime\n",
45
- "from tqdm.auto import tqdm\n",
46
- "from torch.utils.data import DataLoader\n",
47
- "from torch.utils.tensorboard import SummaryWriter\n",
48
- "from transformers import RobertaTokenizerFast, get_cosine_schedule_with_warmup, DataCollatorForLanguageModeling\n",
49
- "from datasets import load_dataset\n",
50
- "from x_transformers import Encoder\n",
51
- "\n",
52
- "# ==========================================\n",
53
- "# 1. CONFIGURATION\n",
54
- "# ==========================================\n",
55
- "# YOUR REPO ID (Created in previous step)\n",
56
- "HF_ID = \"prism-lab/wikitext-103-prism-32k-seq4k\"\n",
57
- "\n",
58
- "# Hyperparameters\n",
59
- "VOCAB_SIZE = 32768\n",
60
- "SEQ_LEN = 4096\n",
61
- "BATCH_SIZE = 8\n",
62
- "EPOCHS = 40\n",
63
- "LR = 1e-3\n",
64
- "D_MODEL = 512\n",
65
- "D_BRANCH = 256\n",
66
- "DEPTH = 6\n",
67
- "RESUME_PATH = None #\"/content/drive/MyDrive/PRISM_Experiments/PILLARS_SplitStream_8Layer_20260116_025321_8438ce62/last.pt\"\n",
68
- "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
69
- "torch.set_float32_matmul_precision(\"high\")\n",
70
- "\n",
71
- "# ==========================================\n",
72
- "# 2. DATA PIPELINE (The \"Pro\" Way)\n",
73
- "# ==========================================\n",
74
def prepare_data_from_hub():
    """Load the pre-tokenized dataset and build the masked-LM collator.

    Both the tokenizer and the dataset are loaded from the repository id stored
    in the module-level ``HF_ID``.

    Returns:
        tuple: ``(dataset, data_collator)`` where ``data_collator`` is a
        ``DataCollatorForLanguageModeling`` configured for MLM with a masking
        probability of 0.15.
    """
    print(f"⬇️ Pulling Pre-Tokenized Data from {HF_ID}...")

    # Tokenizer and data live under the same hub id.
    hub_tokenizer = RobertaTokenizerFast.from_pretrained(HF_ID)
    dataset = load_dataset(HF_ID)

    print(f"✅ Loaded {len(dataset['train'])} training chunks.")

    mlm_settings = {"tokenizer": hub_tokenizer, "mlm": True, "mlm_probability": 0.15}
    mlm_collator = DataCollatorForLanguageModeling(**mlm_settings)

    return dataset, mlm_collator
95
- "# ==========================================\n",
96
- "# 3. PRISM ARCHITECTURE (Complex-Valued)\n",
97
- "# ==========================================\n",
98
- "\n",
99
class ComplexDropout(nn.Module):
    """Dropout for complex tensors driven by a single real-valued mask.

    A tensor of ones shaped like ``z.real`` is passed through ``F.dropout`` and
    the result is multiplied into ``z``, so the real and imaginary parts of an
    element are always kept (and rescaled) or zeroed together.

    Args:
        p: Drop probability forwarded to ``F.dropout``.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        """Return ``z`` unchanged in eval mode or when ``p == 0``; else mask it."""
        dropout_active = self.training and self.p != 0.0
        if not dropout_active:
            return z
        keep_scale = F.dropout(torch.ones_like(z.real), self.p, self.training, inplace=False)
        return z * keep_scale
108
- "\n",
109
class RobustPhaseNorm(nn.Module):
    """RMS normalization over the last dimension, computed from magnitudes.

    The divisor is the real-valued root-mean-square of ``|x|`` along the last
    axis, so each element keeps its phase; a learned per-feature ``scale``
    (initialized to ones) is applied afterwards.

    Args:
        d_model: Size of the last dimension (length of ``scale``).
        eps: Constant added under the square root to avoid division by zero.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x / rms(|x|) * scale``."""
        power = torch.abs(x) ** 2
        mean_power = torch.mean(power, dim=-1, keepdim=True)
        rms = torch.sqrt(mean_power + self.eps)
        normalized = x / rms
        return normalized * self.scale
118
- "\n",
119
class ModReLU(nn.Module):
    """Magnitude-gated activation for complex tensors.

    Applies ``relu(|z| + b)`` to the magnitude with a learned per-feature bias
    ``b`` (initialized to zero) and re-attaches the original phase. All
    arithmetic runs in complex64 / float32 and the result is cast back to the
    dtype of the input.

    Args:
        features: Length of the bias vector ``b``.
    """

    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        """Return ``relu(|z| + b) * z / (|z| + 1e-6)`` in the dtype of ``z``."""
        in_dtype = z.dtype

        # Work in complex64 so the magnitude is computed in float32.
        z_hp = z.to(torch.complex64)
        radius = torch.abs(z_hp)

        # Unit phasor; the small constant guards the division at |z| == 0.
        unit = z_hp / (radius + 1e-6)

        gated_radius = F.relu(radius + self.b.float())
        return (gated_radius * unit).to(in_dtype)
145
- "\n",
146
class ComplexToRealBridge(nn.Module):
    """Map a complex tensor to a real one of the same feature width.

    The real and imaginary parts are concatenated on the last axis (width
    ``2 * d_model``), projected back to ``d_model`` with a linear layer, and
    layer-normalized.

    Args:
        d_model: Feature width of the complex input and of the real output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x_complex):
        """Return ``norm(proj([real | imag]))``."""
        stacked = torch.cat((x_complex.real, x_complex.imag), dim=-1)
        projected = self.proj(stacked)
        return self.norm(projected)
154
- "\n",
155
- "# ==========================================\n",
156
- "# 4. DYNAMIC RoSE (Mamba-3 Engine)\n",
157
- "# ==========================================\n",
158
class DynamicRoSE(nn.Module):
    """Token embedding that yields a rotated complex signal plus the raw embedding.

    For each token the module:
      1. looks up a real embedding (returned unchanged as the second output);
      2. maps it through ``adapter`` to the real/imaginary halves of a complex
         vector;
      3. multiplies that vector by a position-dependent unit phasor
         ``exp(i * pos * freqs)`` and by a token-dependent unit phasor
         predicted by ``rotation_predictor``.

    Args:
        num_embeddings: Vocabulary size of the embedding table.
        embedding_dim: Width of the real embedding and of the complex output.
        max_period: Controls the frequency range; ``freqs[k]`` equals
            ``max_period ** (-k / embedding_dim)``.
    """

    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim

        # Real-valued token embedding table.
        self.raw_embedding = nn.Embedding(num_embeddings, embedding_dim)

        # Produces [real | imag] halves of the complex content vector.
        self.adapter = nn.Linear(embedding_dim, embedding_dim * 2)

        # Fixed positional frequencies, one per feature.
        freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))
        self.register_buffer('freqs', freqs)

        # Produces [x | y] halves of the (unnormalized) token-dependent rotation.
        self.rotation_predictor = nn.Linear(embedding_dim, embedding_dim * 2)

    def forward(self, input_ids):
        """Embed ``input_ids`` of shape ``(batch, seq)``.

        Returns:
            tuple: ``(z_final, real_base)`` -- the complex64 rotated signal and
            the raw real embedding, both with a trailing ``embedding_dim`` axis.
        """
        real_base = self.raw_embedding(input_ids)
        _, seq_len, dim = real_base.shape

        # torch.complex does not accept bfloat16 inputs, which is what the
        # Linear layers emit under bf16 autocast; upcast to float32 first
        # (a no-op when the module already runs in float32).
        complex_params = self.adapter(real_base).float()
        z_t = torch.complex(complex_params[..., :dim], complex_params[..., dim:])

        # Token-dependent rotation, normalized to (approximately) unit modulus.
        rot_x, rot_y = self.rotation_predictor(real_base).float().chunk(2, dim=-1)
        rot_mag = torch.sqrt(rot_x**2 + rot_y**2 + 1e-6)
        dynamic_rot = torch.complex(rot_x / rot_mag, rot_y / rot_mag)

        # Position-dependent rotation exp(i * pos * freq), shape [seq, dim].
        # Broadcasting multiply is the same product as torch.outer and is never
        # down-cast by autocast, so torch.polar always sees float32 angles.
        pos = torch.arange(seq_len, device=input_ids.device).float()
        static_angles = pos.unsqueeze(-1) * self.freqs.float()
        static_rot = torch.polar(torch.ones_like(static_angles), static_angles)

        z_final = z_t * static_rot.unsqueeze(0) * dynamic_rot

        return z_final, real_base
198
- "\n",
199
- "# ==========================================\n",
200
- "# 5. HYENA FILTER\n",
201
- "# ==========================================\n",
202
class HyenaNeuralFilter(nn.Module):
    """Small MLP that generates a complex filter of arbitrary length.

    ``forward(L, device)`` evaluates sin/cos features on ``L`` evenly spaced
    points in ``[0, 1]`` and maps them through an MLP to one complex value per
    (position, feature) pair.

    Args:
        d_model: Number of complex filter channels produced per position.
        max_len: Accepted for interface compatibility; not used by this class.
        hidden_dim: Width of the sin/cos feature vector and of the MLP.
    """

    def __init__(self, d_model, max_len=1024, hidden_dim=64):
        super().__init__()
        self.d_model = d_model

        # hidden_dim // 2 frequencies; sin and cos together give hidden_dim features.
        exponent_step = -(math.log(10000.0) / hidden_dim)
        freqs = torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float32) * exponent_step)
        self.register_buffer("freqs", freqs)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, d_model * 2),
        )

    def forward(self, L, device):
        """Return a complex tensor of shape ``(L, d_model)``."""
        t = torch.linspace(0, 1, steps=L, device=device).unsqueeze(-1)
        phases = t * self.freqs
        emb = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)

        # Last axis of the MLP output holds (real, imag) pairs.
        real_part, imag_part = self.mlp(emb).view(L, self.d_model, 2).unbind(dim=-1)
        return torch.complex(real_part, imag_part)
218
- "\n",
219
- "# ==========================================\n",
220
- "# 6. GATED HARMONIC CONVOLUTION (Lean)\n",
221
- "# ==========================================\n",
222
- "# @title 🛠️ Fixed PRISM Layer (Precision-Gated)\n",
223
- "\n",
224
- "# @title 🛠️ Fixed PRISM Layer (Type-Safe)\n",
225
- "\n",
226
class GatedHarmonicConvolution(nn.Module):
    """Residual complex layer: FFT filtering, sigmoid gating, complex mixing.

    Forward pass, in order:
      1. normalize the input with ``RobustPhaseNorm`` (optionally zeroing
         masked positions);
      2. with CUDA autocast disabled, FFT along the sequence axis, multiply by
         a filter generated by ``HyenaNeuralFilter``, inverse FFT;
      3. gate the real and imaginary parts of the filtered signal with
         sigmoids computed from the normalized input;
      4. apply a complex-style linear mix, ``ModReLU``, and a second
         complex-style linear map;
      5. apply ``ComplexDropout`` and add the residual.

    Args:
        d_model: Feature width of the complex input.
        max_len: Maximum FFT length. For longer inputs only the first
            ``max_len`` positions are transformed and the remainder of the
            filtered signal is zero-padded.
        dropout: Drop probability for the output ``ComplexDropout``.
    """

    def __init__(self, d_model, max_len=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len
        self.neural_filter = HyenaNeuralFilter(d_model, max_len=max_len)
        self.gate_proj = nn.Linear(d_model * 2, d_model * 2)

        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        self.norm = RobustPhaseNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    @staticmethod
    def _complex_linear(real_layer, imag_layer, z):
        """Combine two real Linear layers into one complex-style linear map.

        Computes ``real = R(z.real) - I(z.imag)`` and
        ``imag = R(z.imag) + I(z.real)``. The parts are upcast to float32
        before ``torch.complex`` because that function rejects bfloat16, which
        is what the Linear layers emit under bf16 autocast.
        """
        real_part = (real_layer(z.real) - imag_layer(z.imag)).float()
        imag_part = (real_layer(z.imag) + imag_layer(z.real)).float()
        return torch.complex(real_part, imag_part)

    def forward(self, x, src_mask=None):
        """Process a complex tensor ``x`` of shape ``(batch, seq, d_model)``.

        Args:
            x: Complex input tensor.
            src_mask: Optional boolean tensor of shape ``(batch, seq)``;
                ``True`` positions are zeroed after normalization.

        Returns:
            Complex tensor with the same shape as ``x``.
        """
        residual = x
        x_norm = self.norm(x)
        if src_mask is not None:
            x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # Spectral filtering and gating run in complex64 / float32 with CUDA
        # autocast disabled.
        with torch.amp.autocast('cuda', enabled=False):
            # Explicit complex cast; a plain .float() would drop the imaginary part.
            x_32 = x_norm.to(torch.complex64)

            _, L, _ = x_32.shape
            eff_L = min(L, self.filter_len)

            # 1. FFT along the sequence axis (input is trimmed to eff_L).
            x_freq = torch.fft.fft(x_32, n=eff_L, dim=1, norm='ortho')

            # 2. Multiply by the generated filter, broadcast over the batch.
            h = self.neural_filter(eff_L, x.device).unsqueeze(0).to(torch.complex64)
            x_filtered = x_freq * h

            # 3. Back to the sequence domain; pad with zeros up to L if trimmed.
            x_time = torch.fft.ifft(x_filtered, n=eff_L, dim=1, norm='ortho')
            if L > eff_L:
                x_time = F.pad(x_time, (0, 0, 0, L - eff_L))

            # 4. Sigmoid gates computed from the normalized input, in float32.
            x_cat = torch.cat([x_32.real, x_32.imag], dim=-1)
            gate_w = self.gate_proj.weight.to(torch.float32)
            gate_b = self.gate_proj.bias.to(torch.float32)
            gates = torch.sigmoid(F.linear(x_cat, gate_w, gate_b))

            g_r, g_i = gates.chunk(2, dim=-1)
            x_gated_32 = torch.complex(x_time.real * g_r, x_time.imag * g_i)

        # Return to the dtype of the incoming tensor.
        x_gated = x_gated_32.to(x.dtype)

        # 5. Mix -> activation -> output projection (autocast may apply again).
        x_mixed = self._complex_linear(self.mix_real, self.mix_imag, x_gated)
        x_act = self.activation(x_mixed)
        out = self._complex_linear(self.out_real, self.out_imag, x_act)

        return self.dropout(out) + residual
307
- "# ==========================================\n",
308
- "# 7. MODEL WRAPPERS\n",
309
- "# ==========================================\n",
310
class PRISMEncoder(nn.Module):
    """Stack of ``GatedHarmonicConvolution`` layers followed by a final norm.

    In training mode every layer is run under activation checkpointing
    (non-reentrant variant); in eval mode layers are called directly.

    Args:
        num_layers: Number of ``GatedHarmonicConvolution`` layers.
        d_model: Feature width passed to every layer and to the final norm.
        max_len: ``max_len`` passed to every layer.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            GatedHarmonicConvolution(d_model, max_len, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = RobustPhaseNorm(d_model)

    def forward(self, x, src_mask=None):
        """Run ``x`` through all layers and return the normalized result."""
        # Import the submodule explicitly: a bare `import torch` is not
        # guaranteed to expose `torch.utils.checkpoint` as an attribute.
        from torch.utils.checkpoint import checkpoint

        for layer in self.layers:
            if self.training:
                x = checkpoint(layer, x, src_mask, use_reentrant=False)
            else:
                x = layer(x, src_mask)
        return self.final_norm(x)
323
- "\n",
324
class PRISM_WikiText_Model(nn.Module):
    """PRISM encoder with an optional attention refiner and a tied LM head.

    Pipeline: ``DynamicRoSE`` embedding -> ``PRISMEncoder`` (complex) ->
    ``ComplexToRealBridge`` -> concatenation with the raw embedding and
    projection -> optional ``Encoder`` refiner -> vocabulary logits.

    Args:
        vocab_size: Vocabulary size (embedding rows and logit width).
        d_model: Model width.
        max_len: ``max_len`` passed to the PRISM encoder.
        prism_depth: Number of PRISM layers.
        trans_depth: Depth of the attention refiner; ``0`` disables it.
        dropout: Dropout probability for the PRISM layers and the refiner.
    """

    def __init__(self, vocab_size, d_model, max_len, prism_depth=5, trans_depth=1, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # 1. Complex-valued core.
        self.rose = DynamicRoSE(vocab_size, d_model)
        self.prism_encoder = PRISMEncoder(prism_depth, d_model, max_len=max_len, dropout=dropout)
        self.bridge = ComplexToRealBridge(d_model)
        self.periscope_proj = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.LayerNorm(d_model), nn.GELU())

        # 2. Optional attention refiner with rotary position embeddings.
        if trans_depth > 0:
            self.refiner = Encoder(
                dim=d_model,
                depth=trans_depth,
                heads=8,
                rotary_pos_emb=True,
                attn_flash=True,
                attn_dropout=dropout,
                ff_dropout=dropout,
            )
        else:
            self.refiner = None

        # 3. Output head; its weight is the embedding table itself (weight tying).
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.lm_head.weight = self.rose.raw_embedding.weight

    def forward(self, input_ids):
        """Return vocabulary logits for ``input_ids``."""
        # A. Complex path.
        wave_src, particle_src = self.rose(input_ids)
        wave_out = self.prism_encoder(wave_src)
        wave_real = self.bridge(wave_out)

        # B. Merge the bridged complex path with the raw embedding.
        mixed_memory = self.periscope_proj(torch.cat([wave_real, particle_src], dim=-1))

        # C. Refinement. Compare against None explicitly rather than relying on
        # the truthiness of an nn.Module.
        if self.refiner is not None:
            out = self.refiner(mixed_memory)
        else:
            out = mixed_memory

        return self.lm_head(out)
371
- "\n",
372
- "# ==========================================\n",
373
- "# 1. SENSORY STREAM (Transformer + RoPE)\n",
374
- "# ==========================================\n",
375
class SensoryStream(nn.Module):
    """Thin wrapper around an ``Encoder`` configured for the sensory branch.

    Args:
        depth: Number of encoder layers.
        d_model: Encoder width.
        dropout: Used for both attention and feed-forward dropout.
    """

    def __init__(self, depth, d_model, dropout=0.1):
        super().__init__()
        encoder_options = dict(
            dim=d_model,
            depth=depth,
            heads=4,              # original note: 256 dim / 64 head_dim = 4 heads
            attn_flash=True,      # flash attention
            rotary_pos_emb=True,  # rotary position embeddings
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True,     # RMSNorm
            ff_glu=True,          # gated feed-forward
        )
        self.encoder = Encoder(**encoder_options)

    def forward(self, x):
        """Return the encoder output for ``x``."""
        encoded = self.encoder(x)
        return encoded
392
- "\n",
393
- "# ==========================================\n",
394
- "# 2. PILLARS-DAT (Dual Attention with RoPE)\n",
395
- "# ==========================================\n",
396
class Pillars_DAT(nn.Module):
    """Two-branch model: attention on real embeddings, PRISM on complex ones.

    A shared ``DynamicRoSE`` produces a real embedding and a complex signal.
    Both are projected down to ``d_branch``; the real one goes through
    ``SensoryStream`` and the complex one through ``PRISMEncoder`` plus a
    ``ComplexToRealBridge``. The branch outputs are concatenated, projected
    back to ``d_model``, layer-normalized, refined by a one-layer ``Encoder``,
    and scored against the embedding table (tied weights) plus a learned bias.

    Args:
        vocab_size: Vocabulary size.
        d_model: Width of the shared embedding and of the refiner.
        d_branch: Width of each branch.
        seq_len: Passed as ``max_len`` to the PRISM encoder.
        depth: Number of layers in each branch.
    """

    def __init__(self, vocab_size, d_model=512, d_branch=256, seq_len=4096, depth=4):
        super().__init__()
        self.d_model = d_model
        self.d_branch = d_branch

        # --- A. Shared embedding ---
        self.rose = DynamicRoSE(vocab_size, d_model)

        # --- B. Down-projections (complex one acts on [real | imag]) ---
        self.particle_down = nn.Linear(d_model, d_branch)
        self.wave_down = nn.Linear(d_model * 2, d_branch * 2)

        # --- C. Branch 1: attention encoder (positions via rotary embeddings) ---
        self.stream_sensory = SensoryStream(depth=depth, d_model=d_branch, dropout=0.1)

        # --- D. Branch 2: PRISM encoder on the complex signal ---
        self.stream_relational = PRISMEncoder(num_layers=depth, d_model=d_branch, max_len=seq_len, dropout=0.1)
        self.relational_bridge = ComplexToRealBridge(d_branch)

        # --- E. Fusion ---
        self.fusion_proj = nn.Linear(d_branch * 2, d_model)
        self.fusion_norm = nn.LayerNorm(d_model)

        # --- F. Refiner ---
        self.refiner = Encoder(
            dim=d_model, depth=1, heads=8, attn_flash=True,
            rotary_pos_emb=True, attn_dropout=0.1, ff_dropout=0.1
        )

        # --- G. Output bias (the output weight is the embedding table) ---
        self.head_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, input_ids):
        """Return vocabulary logits for ``input_ids``."""
        # 1. Shared embedding.
        wave_src, particle_src = self.rose(input_ids)

        # 2. Down-projections.
        p_small = self.particle_down(particle_src)

        w_flat = torch.cat([wave_src.real, wave_src.imag], dim=-1)
        # torch.complex rejects bfloat16 (the Linear output dtype under bf16
        # autocast), so upcast to float32 first; a no-op in full precision.
        w_small_flat = self.wave_down(w_flat).float()
        w_small = torch.complex(w_small_flat[..., :self.d_branch], w_small_flat[..., self.d_branch:])

        # 3. Branches.
        sensory_out = self.stream_sensory(p_small)

        relational_out_complex = self.stream_relational(w_small)
        relational_out = self.relational_bridge(relational_out_complex)

        # 4. Fusion.
        stacked = torch.cat([sensory_out, relational_out], dim=-1)
        context = self.fusion_norm(self.fusion_proj(stacked))

        # 5. Refinement.
        refined = self.refiner(context)

        # 6. Logits against the tied embedding matrix.
        logits = F.linear(refined, self.rose.raw_embedding.weight, self.head_bias)

        return logits
465
- "\n",
466
- "import torch\n",
467
- "import torch.nn as nn\n",
468
- "from prettytable import PrettyTable # Optional, but makes tables nice.\n",
469
- "# If you don't have prettytable, the code below uses standard f-strings.\n",
470
- "\n",
471
- "import torch\n",
472
- "import torch.nn as nn\n",
473
- "\n",
474
- "import torch\n",
475
- "import torch.nn as nn\n",
476
- "\n",
477
def deep_analyze_pillars(model):
    """Print a per-module parameter table for a rate/phase two-stream model.

    The model must expose: ``rose.raw_embedding``, ``fnet_pos``,
    ``particle_down``, ``wave_down``, ``stream_rate`` (with ``layers`` and
    ``norm_out``), ``stream_phase`` (with ``layers`` and ``final_norm``),
    ``phase_bridge``, ``fusion_proj``, ``fusion_norm``, ``refiner`` and
    ``head_bias``.

    NOTE(review): the ``Pillars_DAT`` class defined in this notebook does not
    have ``fnet_pos``, ``stream_rate``, ``stream_phase`` or ``phase_bridge``;
    passing it here raises ``AttributeError``. The dimension hints in the
    downsample labels (512->384, 1024->768) are fixed strings, not measured.

    Args:
        model: Module with the attribute layout listed above.

    Returns:
        int: Number of trainable parameters in ``model``.
    """
    def get_p(obj):
        """Parameter count for a module (trainable only) or a raw Parameter."""
        if isinstance(obj, nn.Parameter):
            return obj.numel()
        return sum(p.numel() for p in obj.parameters() if p.requires_grad)

    def format_num(n):
        # >= so that exactly 1e3 / 1e6 are shown as 1.00K / 1.00M.
        if n >= 1e6: return f"{n/1e6:.2f}M"
        if n >= 1e3: return f"{n/1e3:.2f}K"
        return str(n)

    def row(label, count, kind):
        """Print one aligned table row."""
        print(f"{label:<40} | {format_num(count):<15} | {kind}")

    def stream_rows(layers, final_norm, block_label, index_width, kind):
        """Print one row per layer plus the final norm; return the subtotal."""
        subtotal = 0
        for i, layer in enumerate(layers):
            p = get_p(layer)
            subtotal += p
            print(f" ├─ {block_label} {i:<{index_width}} | {format_num(p):<15} | {kind}")
        norm_p = get_p(final_norm)
        # Padding only: the previous version printed the leaked loop index here
        # and raised NameError when the stream had no layers.
        print(f" └─ Final Norm {'':<24} | {format_num(norm_p):<15} | {kind}")
        return subtotal + norm_p

    print("\n" + "="*80)
    print(f"🏗️ PILLARS (COMPACT) - DEEP LAYER ANALYSIS")
    print("="*80)
    print(f"{'MODULE / LAYER':<40} | {'PARAMS':<15} | {'TYPE'}")
    print("-" * 80)

    total_params = get_p(model)

    # 1. Static memory (embeddings).
    vocab_emb = get_p(model.rose.raw_embedding)
    fnet_pos = get_p(model.fnet_pos)
    row('Shared Vocab Embedding', vocab_emb, '💾 STORAGE')
    row('FNet Positional Embedding', fnet_pos, '💾 STORAGE')

    # 2. Input logic: RoSE without the embedding table, and the down-projections.
    rose_logic = get_p(model.rose) - vocab_emb
    print("-" * 80)
    row('Dynamic RoSE (Adapters)', rose_logic, '🌊 PHASE INIT')
    row('Particle Downsample (512->384)', get_p(model.particle_down), '📉 PROJ')
    row('Wave Downsample (1024->768)', get_p(model.wave_down), '📉 PROJ')

    # 3. Stream A: rate.
    print("-" * 80)
    print(f"TRACK A: RATE STREAM (FNet) - Depth {len(model.stream_rate.layers)}")
    fnet_encoder_total = stream_rows(
        model.stream_rate.layers, model.stream_rate.norm_out, 'FNet Block', 24, '⚡ RATE')

    # 4. Stream B: phase.
    print("-" * 80)
    print(f"TRACK B: PHASE STREAM (PRISM) - Depth {len(model.stream_phase.layers)}")
    prism_encoder_total = stream_rows(
        model.stream_phase.layers, model.stream_phase.final_norm, 'PRISM Block', 23, '🌊 PHASE')

    row('Phase Bridge (Complex->Real)', get_p(model.phase_bridge), '🌉 BRIDGE')

    # 5. Fusion, refiner and output bias.
    print("-" * 80)
    fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)
    row('Fusion (Concat -> Proj -> Norm)', fusion_p, '🧠 FUSION')
    row('Transformer Refiner (1 Layer)', get_p(model.refiner), '🧠 ATTENTION')
    head_bias_p = get_p(model.head_bias)
    row('Output Head Bias', head_bias_p, '🎯 OUTPUT')

    # 6. Summary.
    print("="*80)

    storage = vocab_emb + fnet_pos + head_bias_p
    active = total_params - storage

    print(f"TOTAL PARAMETERS: {total_params/1e6:.2f} M")
    print(f" ├─ 💾 Storage: {storage/1e6:.2f} M (Embeddings)")
    print(f" └─ 🧠 Compute: {active/1e6:.2f} M (Logic/Weights)")
    print("-" * 80)
    print(f"STREAM BREAKDOWN:")
    print(f" ├─ ⚡ Rate Stream: {fnet_encoder_total/1e6:.2f} M")
    print(f" └─ 🌊 Phase Stream: {prism_encoder_total/1e6:.2f} M")
    print("="*80 + "\n")

    return total_params
584
- ],
585
- "metadata": {
586
- "id": "V7DOwmmUjyin"
587
- },
588
- "execution_count": null,
589
- "outputs": []
590
- },
591
- {
592
- "cell_type": "code",
593
- "source": [
594
- "\n",
595
- "# Run the parameter analysis to confirm strict adherence to budget\n",
596
def deep_analyze_pillars_dat(model):
    """Print a parameter breakdown for a ``Pillars_DAT``-style model.

    Reads ``rose.raw_embedding``, ``particle_down``, ``wave_down``,
    ``stream_sensory``, ``stream_relational``, ``relational_bridge``,
    ``fusion_proj``, ``fusion_norm``, ``refiner`` and ``head_bias`` from
    ``model``. Depths are read from ``stream_sensory.encoder.layers`` and
    ``stream_relational.layers`` when available; otherwise the row is labeled
    "(Fused)".

    NOTE: this notebook defines a function with this name in two cells; the
    definition executed last is the one in effect.

    Args:
        model: Module exposing the attributes listed above.

    Returns:
        None. The table is written to stdout.
    """
    def get_p(obj):
        """Parameter count for a module (trainable only) or a raw Parameter."""
        if isinstance(obj, nn.Parameter): return obj.numel()
        return sum(p.numel() for p in obj.parameters() if p.requires_grad)

    def format_num(n):
        # >= so that exactly 1e3 / 1e6 are shown as 1.00K / 1.00M.
        if n >= 1e6: return f"{n/1e6:.2f}M"
        if n >= 1e3: return f"{n/1e3:.2f}K"
        return str(n)

    def row(label, count, kind):
        """Print one aligned table row."""
        print(f"{label:<40} | {format_num(count):<12} | {kind}")

    print("\n" + "="*80)
    print(f"🏛️ PILLARS-DAT (Hybrid Transformer-PRISM) - ANALYSIS")
    print("="*80)
    print(f"{'MODULE / LAYER':<40} | {'PARAMS':<12} | {'TYPE'}")
    print("-" * 80)

    total_params = get_p(model)

    # --- 1. MEMORY ---
    vocab_emb = get_p(model.rose.raw_embedding)
    row('Shared Vocab Embedding', vocab_emb, '💾 STORAGE')

    # --- 2. INPUT ---
    rose_logic = get_p(model.rose) - vocab_emb
    row('Dynamic RoSE (Adapters)', rose_logic, '🌊 PHYSICS')

    down_p = get_p(model.particle_down) + get_p(model.wave_down)
    row('Stream Splitters (Downsample)', down_p, '📉 PROJ')

    # --- 3. STREAM A: SENSORY (TRANSFORMER) ---
    print("-" * 80)
    print(f"STREAM A: SENSORY (Identity/Magnitude)")
    sensory_p = get_p(model.stream_sensory)
    # Only a missing attribute or an unsized container should trigger the
    # generic label; anything else is a real error and must propagate.
    try:
        depth_s = len(model.stream_sensory.encoder.layers)
    except (AttributeError, TypeError):
        print(f" ├─ Transformer Encoder (Fused) | {format_num(sensory_p):<12} | ⚡ ATTENTION")
    else:
        print(f" ├─ Transformer Encoder (Depth {depth_s}) | {format_num(sensory_p):<12} | ⚡ ATTENTION")

    # --- 4. STREAM B: RELATIONAL (PRISM) ---
    print("-" * 80)
    print(f"STREAM B: RELATIONAL (Structure/Phase)")
    relational_core = get_p(model.stream_relational)
    relational_bridge = get_p(model.relational_bridge)

    try:
        depth_r = len(model.stream_relational.layers)
    except (AttributeError, TypeError):
        print(f" ├─ PRISM Encoder (Fused) | {format_num(relational_core):<12} | 🌊 SPECTRAL")
    else:
        print(f" ├─ PRISM Encoder (Depth {depth_r}) | {format_num(relational_core):<12} | 🌊 SPECTRAL")

    print(f" └─ Bridge (Complex->Real) | {format_num(relational_bridge):<12} | 🌉 PROJ")

    # --- 5. FUSION & OUTPUT ---
    print("-" * 80)
    fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)
    row('Fusion (Concat -> Proj)', fusion_p, '🧠 MIX')

    refiner_p = get_p(model.refiner)
    row('Refiner (1-Layer Transformer)', refiner_p, '🧠 REASONING')

    bias_p = get_p(model.head_bias)
    row('Output Head Bias', bias_p, '🎯 OUT')

    # --- SUMMARY ---
    print("="*80)
    storage = vocab_emb + bias_p
    active = total_params - storage

    print(f"TOTAL PARAMETERS: {total_params/1e6:.2f} M")
    print(f" ├─ 💾 Storage: {storage/1e6:.2f} M (Embeddings)")
    print(f" └─ 🧠 Compute: {active/1e6:.2f} M (Active Weights)")
    print("-" * 80)
    print(f"RATIO CHECK:")
    print(f" ⚡ Sensory (Transf): {sensory_p/1e6:.2f} M")
    print(f" 🌊 Relation (PRISM): {(relational_core + relational_bridge)/1e6:.2f} M")
    print("="*80 + "\n")
674
- ],
675
- "metadata": {
676
- "id": "ke4fYT8UX5zH"
677
- },
678
- "execution_count": null,
679
- "outputs": []
680
- },
681
- {
682
- "cell_type": "code",
683
- "source": [
684
- "# ==========================================\n",
685
- "# 4. LOGGING & ANALYSIS UTILITIES\n",
686
- "# ==========================================\n",
687
def deep_analyze_pillars_dat(model):
    """Print a parameter breakdown for a ``Pillars_DAT``-style model.

    Attributes read from ``model``: ``rose.raw_embedding``, ``particle_down``,
    ``wave_down``, ``stream_sensory``, ``stream_relational``,
    ``relational_bridge``, ``fusion_proj``, ``fusion_norm``, ``refiner`` and
    ``head_bias``. Branch depths come from ``stream_sensory.encoder.layers``
    and ``stream_relational.layers``; when those cannot be measured the row is
    labeled "(Fused)" instead.

    NOTE: an earlier notebook cell defines a function with the same name; once
    this cell runs, this definition replaces it.

    Args:
        model: Module exposing the attributes listed above.

    Returns:
        None. The table is written to stdout.
    """
    def get_p(obj):
        """Parameter count for a module (trainable only) or a raw Parameter."""
        if isinstance(obj, nn.Parameter): return obj.numel()
        return sum(p.numel() for p in obj.parameters() if p.requires_grad)

    def format_num(n):
        # >= so that exactly 1e3 / 1e6 are shown as 1.00K / 1.00M.
        if n >= 1e6: return f"{n/1e6:.2f}M"
        if n >= 1e3: return f"{n/1e3:.2f}K"
        return str(n)

    def measured_depth(fetch_layers):
        """Return ``len(fetch_layers())`` or None if it cannot be measured."""
        # Narrow handler: a missing attribute or an unsized object means
        # "depth unknown"; any other exception is a real bug and propagates.
        try:
            return len(fetch_layers())
        except (AttributeError, TypeError):
            return None

    print("\n" + "="*80)
    print(f"🏛️ PILLARS-DAT (Hybrid Transformer-PRISM) - ANALYSIS")
    print("="*80)
    print(f"{'MODULE / LAYER':<40} | {'PARAMS':<12} | {'TYPE'}")
    print("-" * 80)

    total_params = get_p(model)

    # --- 1. MEMORY ---
    vocab_emb = get_p(model.rose.raw_embedding)
    print(f"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<12} | 💾 STORAGE")

    # --- 2. INPUT ---
    rose_logic = get_p(model.rose) - vocab_emb
    print(f"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<12} | 🌊 PHYSICS")

    down_p = get_p(model.particle_down) + get_p(model.wave_down)
    print(f"{'Stream Splitters (Downsample)':<40} | {format_num(down_p):<12} | 📉 PROJ")

    # --- 3. STREAM A: SENSORY (TRANSFORMER) ---
    print("-" * 80)
    print(f"STREAM A: SENSORY (Identity/Magnitude)")
    sensory_p = get_p(model.stream_sensory)
    depth_s = measured_depth(lambda: model.stream_sensory.encoder.layers)
    if depth_s is not None:
        print(f" ├─ Transformer Encoder (Depth {depth_s}) | {format_num(sensory_p):<12} | ⚡ ATTENTION")
    else:
        print(f" ├─ Transformer Encoder (Fused) | {format_num(sensory_p):<12} | ⚡ ATTENTION")

    # --- 4. STREAM B: RELATIONAL (PRISM) ---
    print("-" * 80)
    print(f"STREAM B: RELATIONAL (Structure/Phase)")
    relational_core = get_p(model.stream_relational)
    relational_bridge = get_p(model.relational_bridge)

    depth_r = measured_depth(lambda: model.stream_relational.layers)
    if depth_r is not None:
        print(f" ├─ PRISM Encoder (Depth {depth_r}) | {format_num(relational_core):<12} | 🌊 SPECTRAL")
    else:
        print(f" ├─ PRISM Encoder (Fused) | {format_num(relational_core):<12} | 🌊 SPECTRAL")

    print(f" └─ Bridge (Complex->Real) | {format_num(relational_bridge):<12} | 🌉 PROJ")

    # --- 5. FUSION & OUTPUT ---
    print("-" * 80)
    fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)
    print(f"{'Fusion (Concat -> Proj)':<40} | {format_num(fusion_p):<12} | 🧠 MIX")

    refiner_p = get_p(model.refiner)
    print(f"{'Refiner (1-Layer Transformer)':<40} | {format_num(refiner_p):<12} | 🧠 REASONING")

    bias_p = get_p(model.head_bias)
    print(f"{'Output Head Bias':<40} | {format_num(bias_p):<12} | 🎯 OUT")

    # --- SUMMARY ---
    print("="*80)
    storage = vocab_emb + bias_p
    active = total_params - storage

    print(f"TOTAL PARAMETERS: {total_params/1e6:.2f} M")
    print(f" ├─ 💾 Storage: {storage/1e6:.2f} M (Embeddings)")
    print(f" └─ 🧠 Compute: {active/1e6:.2f} M (Active Weights)")
    print("-" * 80)
    print(f"RATIO CHECK:")
    print(f" ⚡ Sensory (Transf): {sensory_p/1e6:.2f} M")
    print(f" 🌊 Relation (PRISM): {(relational_core + relational_bridge)/1e6:.2f} M")
    print("="*80 + "\n")
764
- "\n",
765
def generate_run_id():
    """Return an 8-character hexadecimal identifier for the current run.

    The current local time, rendered down to the microsecond as a digit
    string, is hashed with MD5 and the first eight hex characters of the
    digest are returned. The hash only scrambles the timestamp into a short
    token; it is not used for anything security related.
    """
    stamp = f"{datetime.now():%Y%m%d%H%M%S%f}"
    digest = hashlib.md5(stamp.encode())
    return digest.hexdigest()[:8]
768
- "\n",
769
def log_environment(save_dir, run_id, config):
    """Write a plain-text snapshot of the run configuration.

    Creates (or overwrites) ``env_metadata_<run_id>.txt`` inside ``save_dir``.
    The file starts with a title line and a 50-character ``=`` rule, followed
    by one ``key: value`` line per entry of ``config`` in iteration order.

    Args:
        save_dir: Existing directory that receives the file.
        run_id: Identifier embedded in the file name and the title line.
        config: Mapping of setting names to values; values are written with
            their default string formatting.
    """
    log_path = os.path.join(save_dir, f"env_metadata_{run_id}.txt")
    with open(log_path, "w") as handle:
        title = f"PRISM EXPERIMENT METADATA | Run ID: {run_id}\n{'='*50}\n"
        entries = "".join(f"{key}: {value}\n" for key, value in config.items())
        handle.write(title + entries)
    print(f"📝 Environment Snapshot saved to: {log_path}")
775
- "\n",
776
def log_metrics(save_dir, run_id, epoch, train_loss, val_loss, ppl):
    """Append one epoch's metrics to ``metrics_log_<run_id>.csv``.

    The CSV header is written whenever the file is empty at the moment it is
    opened, which covers both a brand-new file and a pre-existing empty one.
    Each row holds a ``YYYY-MM-DD HH:MM:SS`` local timestamp, the epoch, and
    the three metrics formatted with six decimals.

    Args:
        save_dir: Existing directory that holds the CSV file.
        run_id: Identifier embedded in the file name.
        epoch: Epoch number written verbatim into the row.
        train_loss: Training loss (any value accepting the ``.6f`` format).
        val_loss: Validation loss.
        ppl: Perplexity; ``float('inf')`` is written as ``inf``.
    """
    log_path = os.path.join(save_dir, f"metrics_log_{run_id}.csv")
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Format the row before touching the file so that a bad value cannot
    # leave behind a file that contains only a header.
    row = f"{ts},{epoch},{train_loss:.6f},{val_loss:.6f},{ppl:.6f}\n"
    # A single append-mode open replaces the exists-then-open sequence; in
    # append mode the initial position is the end of the file, so 0 means
    # the file is empty and still needs its header.
    with open(log_path, "a", encoding="utf-8") as f:
        if f.tell() == 0:
            f.write("Timestamp,Epoch,Train_Loss,Val_Loss,Perplexity\n")
        f.write(row)
783
- "\n",
784
def save_checkpoint(path, model, optimizer, scheduler, scaler, epoch, best_loss, config):
    """Serialize the full training state to ``path`` with ``torch.save``.

    The stored dictionary contains the epoch, the ``state_dict()`` of the
    model, optimizer, scheduler and AMP gradient scaler, the best validation
    loss passed in, and the run configuration.

    Args:
        path: Destination file path.
        model: Object exposing ``state_dict()`` (the network).
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``.
        scaler: Object exposing ``state_dict()``; kept so the AMP loss scale
            is available alongside the other states.
        epoch: Epoch value stored under ``'epoch'``.
        best_loss: Value stored under ``'best_val_loss'``.
        config: Configuration object stored under ``'config'``.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        scaler_state_dict=scaler.state_dict(),
        best_val_loss=best_loss,
        config=config,
    )
    torch.save(checkpoint, path)
794
- "\n",
795
# ==========================================
# 4. LOGGING & ANALYSIS UTILITIES
# ==========================================
801
def deep_analyze_pillars_dat(model):
    """Print a per-component parameter breakdown of a Pillars-DAT model.

    The report lists the shared vocabulary embedding, the remaining RoSE
    parameters, the two down-projections, both streams, the fusion layers,
    the refiner and the output head bias, followed by a summary that splits
    the total into "storage" (embedding + head bias) and everything else.

    Args:
        model: Module exposing the attributes read below: ``rose`` (with
            ``raw_embedding``), ``particle_down``, ``wave_down``,
            ``stream_sensory``, ``stream_relational``, ``relational_bridge``,
            ``fusion_proj``, ``fusion_norm``, ``refiner`` and ``head_bias``.

    Returns:
        None. The report is written to standard output.
    """
    def get_p(obj):
        # A bare parameter is counted in full; for a module only the
        # trainable (requires_grad) parameters are counted.
        if isinstance(obj, nn.Parameter):
            return obj.numel()
        return sum(p.numel() for p in obj.parameters() if p.requires_grad)

    def format_num(n):
        # Inclusive thresholds so that exactly 1000 reads "1.00K" and
        # exactly 1,000,000 reads "1.00M".
        if n >= 1e6:
            return f"{n/1e6:.2f}M"
        if n >= 1e3:
            return f"{n/1e3:.2f}K"
        return str(n)

    def depth_label(read_layers):
        # Report the number of layers when the stream exposes a sized layer
        # container; otherwise fall back to the generic "Fused" label. Only
        # the two failures that mean "no such container" are caught, so that
        # unrelated errors and Ctrl-C are not silently swallowed.
        try:
            return f"Depth {len(read_layers())}"
        except (AttributeError, TypeError):
            return "Fused"

    print("\n" + "="*80)
    print("🏛️ PILLARS-DAT (Hybrid Transformer-PRISM) - ANALYSIS")
    print("="*80)
    print(f"{'MODULE / LAYER':<40} | {'PARAMS':<12} | {'TYPE'}")
    print("-" * 80)

    total_params = get_p(model)

    # --- 1. MEMORY ---
    vocab_emb = get_p(model.rose.raw_embedding)
    print(f"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<12} | 💾 STORAGE")

    # --- 2. INPUT PHYSICS ---
    # Everything inside RoSE except the embedding table itself.
    rose_logic = get_p(model.rose) - vocab_emb
    print(f"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<12} | 🌊 PHYSICS")

    down_p = get_p(model.particle_down) + get_p(model.wave_down)
    print(f"{'Stream Splitters (Downsample)':<40} | {format_num(down_p):<12} | 📉 PROJ")

    # --- 3. STREAM A: SENSORY (TRANSFORMER) ---
    print("-" * 80)
    print("STREAM A: SENSORY (Identity/Magnitude)")
    sensory_p = get_p(model.stream_sensory)
    sensory_label = depth_label(lambda: model.stream_sensory.encoder.layers)
    print(f" ├─ Transformer Encoder ({sensory_label}) | {format_num(sensory_p):<12} | ⚡ ATTENTION")

    # --- 4. STREAM B: RELATIONAL (PRISM) ---
    print("-" * 80)
    print("STREAM B: RELATIONAL (Structure/Phase)")
    relational_core = get_p(model.stream_relational)
    relational_bridge = get_p(model.relational_bridge)

    relational_label = depth_label(lambda: model.stream_relational.layers)
    print(f" ├─ PRISM Encoder ({relational_label}) | {format_num(relational_core):<12} | 🌊 SPECTRAL")

    print(f" └─ Bridge (Complex->Real) | {format_num(relational_bridge):<12} | 🌉 PROJ")

    # --- 5. FUSION & OUTPUT ---
    print("-" * 80)
    fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)
    print(f"{'Fusion (Concat -> Proj)':<40} | {format_num(fusion_p):<12} | 🧠 MIX")

    refiner_p = get_p(model.refiner)
    print(f"{'Refiner (1-Layer Transformer)':<40} | {format_num(refiner_p):<12} | 🧠 REASONING")

    bias_p = get_p(model.head_bias)
    print(f"{'Output Head Bias':<40} | {format_num(bias_p):<12} | 🎯 OUT")

    # --- SUMMARY ---
    print("="*80)
    storage = vocab_emb + bias_p
    active = total_params - storage

    print(f"TOTAL PARAMETERS: {total_params/1e6:.2f} M")
    print(f" ├─ 💾 Storage: {storage/1e6:.2f} M (Embeddings)")
    print(f" └─ 🧠 Compute: {active/1e6:.2f} M (Active Weights)")
    print("-" * 80)
    print("RATIO CHECK:")
    print(f" ⚡ Sensory (Transf): {sensory_p/1e6:.2f} M")
    print(f" 🌊 Relation (PRISM): {(relational_core + relational_bridge)/1e6:.2f} M")
    print("="*80 + "\n")
878
- "\n",
879
def init_pillars_dat_weights(model):
    """Re-initialize the parameters of a Pillars-DAT model in place.

    Groups are initialized in a fixed order (RoSE, down-projections, sensory
    stream, relational stream, fusion/refiner, head bias), so the sequence of
    random draws is the same on every call for a given RNG state.

    Args:
        model: Module exposing ``d_model``, ``rose`` (with ``raw_embedding``,
            ``adapter`` and ``rotation_predictor``), ``particle_down``,
            ``wave_down``, ``stream_sensory``, ``stream_relational``,
            ``fusion_proj``, ``refiner`` and ``head_bias``.

    Returns:
        None. Parameters are modified in place and progress is printed.
    """
    print("✨ APPLYING PILLARS-DAT INITIALIZATION PROTOCOL...")
    rose = model.rose
    width = model.d_model

    # 1. Shared root (RoSE): embedding scaled by d_model^-0.5, orthogonal adapter.
    nn.init.normal_(rose.raw_embedding.weight, std=width ** -0.5)
    nn.init.orthogonal_(rose.adapter.weight)

    # RoSE identity trick: near-zero weights plus a bias whose first d_model
    # entries ("real") are 1 and whose remaining entries ("imag") are 0.
    # NOTE(review): presumably this makes the predicted rotation start close
    # to the identity - confirm against the RoSE forward pass.
    predictor = rose.rotation_predictor
    nn.init.normal_(predictor.weight, std=0.01)
    with torch.no_grad():
        predictor.bias[:width].fill_(1.0)
        predictor.bias[width:].fill_(0.0)

    # 2. Down-projections: orthogonal with gain 1.414.
    for projection in (model.particle_down, model.wave_down):
        nn.init.orthogonal_(projection.weight, gain=1.414)

    # 3. Sensory stream: Xavier for every tensor with more than one
    # dimension; 1-D tensors whose name mentions "norm" are reset to
    # ones (weight) / zeros (bias). Other 1-D tensors are left untouched.
    print(" ├─ Initializing Sensory Stream (Transformer)...")
    for pname, param in model.stream_sensory.named_parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
            continue
        if param.dim() != 1 or "norm" not in pname.lower():
            continue
        if "weight" in pname:
            nn.init.ones_(param)
        if "bias" in pname:
            nn.init.zeros_(param)

    # 4. Relational stream: Xavier weights and zero biases for linear
    # layers, and a small constant for each ModReLU's ``b`` parameter.
    print(" ├─ Initializing Relational Stream (PRISM)...")
    for _, module in model.stream_relational.named_modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=1.0)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, ModReLU):
            nn.init.constant_(module.b, 0.01)

    # 5. Fusion projection and refiner matrices.
    nn.init.xavier_uniform_(model.fusion_proj.weight, gain=1.0)
    refiner_matrices = (p for p in model.refiner.parameters() if p.dim() > 1)
    for matrix in refiner_matrices:
        nn.init.xavier_uniform_(matrix)

    # 6. Output head bias.
    nn.init.zeros_(model.head_bias)
    print("✅ DAT INITIALIZATION COMPLETE.")
921
- "\n",
922
- "# ==========================================\n",
923
- "# 5. A100 TRAINING LOOP (WITH LOGGING)\n",
924
- "# ==========================================\n",
925
def run_a100_training(experiment_name="PILLARS_DAT_A100_Final"):
    """Train a Pillars-DAT model on Colab with mixed precision and return it.

    Steps performed:
      1. Mount Google Drive if ``/content/drive`` does not exist and create
         a run directory named from ``experiment_name``, a timestamp and a
         generated run id.
      2. Build the data loaders from ``prepare_data_from_hub()`` and
         construct and initialize the model.
      3. Train for ``EPOCHS`` epochs with AdamW, a cosine schedule with 10%
         warmup, fp16 autocast, a ``GradScaler``, gradient accumulation over
         4 micro-batches and gradient-norm clipping at 1.0.
      4. After each epoch run validation, log metrics to TensorBoard and the
         CSV log, and write ``last.pt`` (always) and ``best.pt`` (when the
         validation loss improves).

    Args:
        experiment_name: Prefix of the run directory created under
            ``/content/drive/My Drive/PRISM_Experiments``.

    Returns:
        The trained model.
    """
    from torch.cuda.amp import autocast, GradScaler
    from torch.utils.tensorboard import SummaryWriter

    # --- 1. SETUP DRIVE & LOGGING ---
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')

    run_id = generate_run_id()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    SAVE_DIR = os.path.join("/content/drive/My Drive/PRISM_Experiments", f"{experiment_name}_{timestamp}_{run_id}")
    os.makedirs(SAVE_DIR, exist_ok=True)

    SAFE_BATCH_SIZE = BATCH_SIZE
    GRAD_ACCUM = 4

    writer = SummaryWriter(log_dir=SAVE_DIR)
    # try/finally guarantees the TensorBoard writer is flushed and closed
    # even when training is interrupted or raises.
    try:
        # Config for logs: record the values this run actually uses, so the
        # metadata file and the checkpoints describe the real batch size and
        # accumulation factor.
        config_dump = {
            "run_id": run_id, "batch_size": SAFE_BATCH_SIZE, "accum": GRAD_ACCUM,
            "d_model": D_MODEL, "depth": DEPTH, "seq_len": SEQ_LEN
        }
        log_environment(SAVE_DIR, run_id, config_dump)

        # --- 2. MODEL & DATA ---
        print("\n⚡ A100 DETECTED. CONFIGURING FLASH ATTENTION PIPELINE...")

        lm_datasets, data_collator = prepare_data_from_hub()
        train_loader = DataLoader(lm_datasets["train"], batch_size=SAFE_BATCH_SIZE, shuffle=True, collate_fn=data_collator, num_workers=4, pin_memory=True)
        valid_loader = DataLoader(lm_datasets["validation"], batch_size=SAFE_BATCH_SIZE, collate_fn=data_collator, num_workers=2)

        model = Pillars_DAT(vocab_size=VOCAB_SIZE, d_model=D_MODEL, d_branch=D_BRANCH, seq_len=SEQ_LEN, depth=DEPTH).to(DEVICE)
        init_pillars_dat_weights(model)
        print(model)
        deep_analyze_pillars_dat(model)  # Parameter analysis

        optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
        # One scheduler step per optimizer step, i.e. per full accumulation window.
        total_steps = (len(train_loader) // GRAD_ACCUM) * EPOCHS
        warmup_steps = int(total_steps * 0.1)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
        criterion = nn.CrossEntropyLoss()
        scaler = GradScaler()  # For AMP

        print(f"\n🚀 IGNITING FUSION DRIVE... Saving to: {SAVE_DIR}")

        global_step = 0
        best_val_loss = float('inf')

        for epoch in range(EPOCHS):
            model.train()
            # Discard gradients left over from an incomplete accumulation
            # window at the end of the previous epoch; otherwise they would
            # be mixed into this epoch's first optimizer step.
            optimizer.zero_grad()
            # Running sum kept as a detached tensor so no host sync is
            # needed on every micro-batch.
            train_loss_sum = 0.0
            micro_batches = 0
            pbar = tqdm(train_loader, desc=f"Ep {epoch+1}")

            for step, batch in enumerate(pbar):
                x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)

                # ⚡ AMP CONTEXT
                with autocast(dtype=torch.float16):
                    logits = model(x).view(-1, VOCAB_SIZE)
                    # Divide so the accumulated gradient is the window mean.
                    loss = criterion(logits, y.view(-1)) / GRAD_ACCUM

                scaler.scale(loss).backward()

                train_loss_sum = train_loss_sum + loss.detach().float() * GRAD_ACCUM
                micro_batches += 1

                if (step + 1) % GRAD_ACCUM == 0:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    # 📝 STEP LOGGING
                    actual_loss = loss.item() * GRAD_ACCUM
                    gnorm = grad_norm.item()
                    writer.add_scalar('Train/Loss', actual_loss, global_step)
                    writer.add_scalar('Train/GradNorm', gnorm, global_step)
                    writer.add_scalar('Train/LR', scheduler.get_last_lr()[0], global_step)

                    pbar.set_postfix({
                        'loss': f"{actual_loss:.4f}",
                        'gnorm': f"{gnorm:.2f}"
                    })

            # Mean training loss over every micro-batch of this epoch.
            avg_train_loss = float(train_loss_sum) / max(micro_batches, 1)

            # --- VALIDATION ---
            model.eval()
            val_loss = 0
            with torch.no_grad(), autocast():
                for batch in valid_loader:
                    x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)
                    val_loss += criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)).item()

            avg_val_loss = val_loss / len(valid_loader)
            # Prevent overflow if loss is exploding
            ppl = math.exp(avg_val_loss) if avg_val_loss < 20 else float('inf')

            print(f"✨ Ep {epoch+1} | Val Loss: {avg_val_loss:.4f} | PPL: {ppl:.2f}")

            # 📝 EPOCH LOGGING
            writer.add_scalar('Val/Loss', avg_val_loss, epoch+1)
            writer.add_scalar('Val/PPL', ppl, epoch+1)
            log_metrics(SAVE_DIR, run_id, epoch+1, avg_train_loss, avg_val_loss, ppl)

            # Update the best score before saving, so last.pt also records
            # the current best validation loss.
            improved = avg_val_loss < best_val_loss
            if improved:
                best_val_loss = avg_val_loss

            # 💾 SAVE CHECKPOINTS (includes scaler/optimizer/scheduler state)
            save_checkpoint(os.path.join(SAVE_DIR, "last.pt"), model, optimizer, scheduler, scaler, epoch, best_val_loss, config_dump)

            if improved:
                print(" 🏆 New Best Model! Saving best.pt...")
                save_checkpoint(os.path.join(SAVE_DIR, "best.pt"), model, optimizer, scheduler, scaler, epoch, best_val_loss, config_dump)
    finally:
        writer.close()

    return model
1038
- "\n",
1039
- "if __name__ == \"__main__\":\n",
1040
- " run_a100_training()"
1041
- ],
1042
- "metadata": {
1043
- "id": "-TNEv89gkS1k"
1044
- },
1045
- "execution_count": null,
1046
- "outputs": []
1047
- },
1048
- {
1049
- "cell_type": "code",
1050
- "source": [
1051
- "from google.colab import runtime\n",
1052
- "runtime.unassign()"
1053
- ],
1054
- "metadata": {
1055
- "id": "bxFTYWHVqcSI"
1056
- },
1057
- "execution_count": null,
1058
- "outputs": []
1059
- }
1060
- ]
1061
- }