{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "!pip install -q x-transformers\n", "!pip install -q flash-attn --no-build-isolation" ], "metadata": { "id": "6q9RTvlf5IiS" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import math\n", "import os\n", "import sys\n", "import subprocess\n", "import hashlib\n", "import gc\n", "from datetime import datetime\n", "from tqdm.auto import tqdm\n", "from torch.utils.data import DataLoader\n", "from torch.utils.tensorboard import SummaryWriter\n", "from transformers import RobertaTokenizerFast, get_cosine_schedule_with_warmup, DataCollatorForLanguageModeling\n", "from datasets import load_dataset\n", "from x_transformers import Encoder\n", "\n", "# ==========================================\n", "# 1. CONFIGURATION\n", "# ==========================================\n", "# YOUR REPO ID (Created in previous step)\n", "HF_ID = \"prism-lab/wikitext-103-prism-32k-seq4k\"\n", "\n", "# Hyperparameters\n", "VOCAB_SIZE = 32768\n", "SEQ_LEN = 4096\n", "BATCH_SIZE = 8\n", "EPOCHS = 40\n", "LR = 1e-3\n", "D_MODEL = 512\n", "D_BRANCH = 256\n", "DEPTH = 9\n", "RESUME_PATH = None #\"/content/drive/MyDrive/PRISM_Experiments/PILLARS_SplitStream_8Layer_20260116_025321_8438ce62/last.pt\"\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "torch.set_float32_matmul_precision(\"high\")\n", "\n", "# ==========================================\n", "# 2. DATA PIPELINE (The \"Pro\" Way)\n", "# ==========================================\n", "def prepare_data_from_hub():\n", " print(f\"ā¬‡ļø Pulling Pre-Tokenized Data from {HF_ID}...\")\n", "\n", " # 1. 
Load Tokenizer (Instant)\n", " # This pulls the exact tokenizer you uploaded\n", " tokenizer = RobertaTokenizerFast.from_pretrained(HF_ID)\n", "\n", " # 2. Load Dataset (Instant)\n", " # This pulls the already chunked/tokenized data\n", " dataset = load_dataset(HF_ID)\n", "\n", " print(f\"āœ… Loaded {len(dataset['train'])} training chunks.\")\n", "\n", " # 3. Collator\n", " data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer,\n", " mlm=True,\n", " mlm_probability=0.15\n", " )\n", "\n", " return dataset, data_collator\n", "# ==========================================\n", "# 3. PRISM ARCHITECTURE (Complex-Valued)\n", "# ==========================================\n", "\n", "class ComplexDropout(nn.Module):\n", " def __init__(self, p=0.5):\n", " super().__init__()\n", " self.p = p\n", " def forward(self, z):\n", " if not self.training or self.p == 0.0: return z\n", " mask = torch.ones_like(z.real)\n", " mask = F.dropout(mask, self.p, self.training, inplace=False)\n", " return z * mask\n", "\n", "class RobustPhaseNorm(nn.Module):\n", " def __init__(self, d_model, eps=1e-5):\n", " super().__init__()\n", " self.scale = nn.Parameter(torch.ones(d_model))\n", " self.eps = eps\n", " def forward(self, x):\n", " mag = torch.abs(x)\n", " rms = torch.sqrt(torch.mean(mag**2, dim=-1, keepdim=True) + self.eps)\n", " return (x / rms) * self.scale\n", "\n", "class ModReLU(nn.Module):\n", " def __init__(self, features):\n", " super().__init__()\n", " self.b = nn.Parameter(torch.zeros(features))\n", " def forward(self, z):\n", " mag = torch.abs(z)\n", " new_mag = F.relu(mag + self.b)\n", " phase = z / (mag + 1e-6)\n", " return new_mag * phase\n", "\n", "class ComplexToRealBridge(nn.Module):\n", " def __init__(self, d_model):\n", " super().__init__()\n", " self.proj = nn.Linear(d_model * 2, d_model)\n", " self.norm = nn.LayerNorm(d_model)\n", " def forward(self, x_complex):\n", " cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)\n", " return 
self.norm(self.proj(cat))\n", "\n", "# ==========================================\n", "# 4. DYNAMIC RoSE (Mamba-3 Engine)\n", "# ==========================================\n", "class DynamicRoSE(nn.Module):\n", " def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):\n", " super().__init__()\n", " self.embedding_dim = embedding_dim\n", "\n", " # 1. Master Real Embedding (The \"Particle\")\n", " self.raw_embedding = nn.Embedding(num_embeddings, embedding_dim)\n", "\n", " # 2. Complex Adapter (The \"Wave\" Magnitude/Initial Phase)\n", " self.adapter = nn.Linear(embedding_dim, embedding_dim * 2)\n", "\n", " # 3. Static Frequencies (Positional)\n", " freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))\n", " self.register_buffer('freqs', freqs)\n", "\n", " self.rotation_predictor = nn.Linear(embedding_dim, embedding_dim * 2)\n", "\n", " def forward(self, input_ids):\n", " # A. Raw Particle\n", " real_base = self.raw_embedding(input_ids)\n", " B, L, D = real_base.shape\n", "\n", " # B. Complex Wave Content\n", " complex_params = self.adapter(real_base)\n", " z_t = torch.complex(complex_params[..., :D], complex_params[..., D:])\n", "\n", " rot_raw = self.rotation_predictor(real_base)\n", " rot_x, rot_y = rot_raw.chunk(2, dim=-1)\n", "\n", " rot_mag = torch.sqrt(rot_x**2 + rot_y**2 + 1e-6)\n", " dynamic_rot = torch.complex(rot_x / rot_mag, rot_y / rot_mag)\n", "\n", " # D. Static Positional Rotation\n", " pos = torch.arange(L, device=input_ids.device).float()\n", " static_angles = torch.outer(pos, self.freqs) # [L, D]\n", " static_rot = torch.polar(torch.ones_like(static_angles), static_angles) # [L, D]\n", "\n", " z_final = z_t * static_rot.unsqueeze(0) * dynamic_rot\n", "\n", " return z_final, real_base\n", "\n", "# ==========================================\n", "# 5. 
# ==========================================
# 5. HYENA FILTER
# ==========================================
class HyenaNeuralFilter(nn.Module):
    """Implicitly parameterized frequency-domain filter (Hyena-style).

    Instead of storing a length-L filter, a small MLP maps a sinusoidal
    encoding of normalized position t ∈ [0, 1] to a complex filter value per
    feature; `forward(L, device)` materializes the [L, d_model] filter.
    """
    def __init__(self, d_model, max_len=1024, hidden_dim=64):
        super().__init__()
        self.d_model = d_model
        # Geometric frequency ladder for the positional encoding of t.
        freqs = torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / hidden_dim))
        self.register_buffer("freqs", freqs)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, d_model * 2)  # (real, imag) per feature
        )
    def forward(self, L, device):
        t = torch.linspace(0, 1, steps=L, device=device).unsqueeze(-1)
        emb = torch.cat([torch.sin(t * self.freqs), torch.cos(t * self.freqs)], dim=-1)
        out = self.mlp(emb).view(L, self.d_model, 2)
        return torch.complex(out[..., 0], out[..., 1])

# ==========================================
# 6. GATED HARMONIC CONVOLUTION (Lean)
# ==========================================
class GatedHarmonicConvolution(nn.Module):
    """Pre-norm residual block: FFT -> learned spectral filter -> iFFT,
    followed by content gating and two complex linear mixes."""
    def __init__(self, d_model, max_len=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len  # cap on FFT length
        self.neural_filter = HyenaNeuralFilter(d_model, max_len=max_len)
        # Produces 2*d_model sigmoid gates (real-gate, imag-gate halves).
        self.gate_proj = nn.Linear(d_model * 2, d_model * 2)
        # mix_* and out_* implement complex matmuls via 4 real Linears each.
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)
        self.activation = ModReLU(d_model)
        self.norm = RobustPhaseNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    def forward(self, x, src_mask=None):
        residual = x
        x_norm = self.norm(x)
        # Zero out masked (e.g. padding) positions before the FFT so they do
        # not leak into the spectrum. Mask semantics: True = masked — TODO
        # confirm against callers (no caller passes src_mask in this file).
        if src_mask is not None:
            x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # 1. Global Beam (FFT + Hyena)
        B, L, D = x_norm.shape
        eff_L = min(L, self.filter_len)
        # NOTE(review): when L > filter_len, fft(n=eff_L) truncates the
        # sequence to its first eff_L positions and the tail is re-added as
        # zeros below — positions beyond filter_len get no global mixing.
        x_freq = torch.fft.fft(x_norm, n=eff_L, dim=1, norm='ortho')
        h = self.neural_filter(eff_L, x.device).unsqueeze(0)
        x_filtered = x_freq * h  # pointwise spectral filtering = circular conv
        x_time = torch.fft.ifft(x_filtered, n=eff_L, dim=1, norm='ortho')
        if L > eff_L: x_time = F.pad(x_time, (0,0,0,L-eff_L))
        else: x_time = x_time[:, :L, :]

        # 2. Gating: sigmoid gates computed from the (pre-filter) normed
        # input, applied separately to real and imaginary parts.
        gates = torch.sigmoid(self.gate_proj(torch.cat([x_norm.real, x_norm.imag], dim=-1)))
        g_r, g_i = gates.chunk(2, dim=-1)
        x_gated = torch.complex(x_time.real * g_r, x_time.imag * g_i)

        # 3. Mixing & Out — (a+bi)(W_r + W_i i) expanded into real Linears.
        mr, mi = self.mix_real, self.mix_imag
        x_mixed = torch.complex(mr(x_gated.real) - mi(x_gated.imag), mr(x_gated.imag) + mi(x_gated.real))
        x_act = self.activation(x_mixed)
        or_, oi = self.out_real, self.out_imag
        out = torch.complex(or_(x_act.real) - oi(x_act.imag), or_(x_act.imag) + oi(x_act.real))
        return self.dropout(out) + residual

# ==========================================
# 7. MODEL WRAPPERS
# ==========================================
class PRISMEncoder(nn.Module):
    """Stack of GatedHarmonicConvolution blocks with a final phase norm.

    Uses gradient checkpointing during training to trade compute for memory.
    """
    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            GatedHarmonicConvolution(d_model, max_len, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = RobustPhaseNorm(d_model)
    def forward(self, x, src_mask=None):
        for layer in self.layers:
            # Checkpoint only in training; use_reentrant=False is the
            # recommended non-reentrant implementation.
            if self.training: x = torch.utils.checkpoint.checkpoint(layer, x, src_mask, use_reentrant=False)
            else: x = layer(x, src_mask)
        return self.final_norm(x)
PRISM Core (The Optical/Passive Part)\n", " self.rose = DynamicRoSE(vocab_size, d_model)\n", " self.prism_encoder = PRISMEncoder(prism_depth, d_model, max_len=max_len, dropout=dropout)\n", " self.bridge = ComplexToRealBridge(d_model)\n", " self.periscope_proj = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.LayerNorm(d_model), nn.GELU())\n", "\n", " # 2. Refiner (The Digital/Active Part)\n", " # šŸ”„ SWAPPED: Replaced Standard Transformer with RoPE-Enabled Encoder\n", " if trans_depth > 0:\n", " self.refiner = Encoder(\n", " dim=d_model,\n", " depth=trans_depth,\n", " heads=8,\n", " rotary_pos_emb=True,\n", " attn_flash=True,\n", " attn_dropout=dropout,\n", " ff_dropout=dropout,\n", "\n", " )\n", " else:\n", " self.refiner = None\n", "\n", " # 3. Output\n", " self.lm_head = nn.Linear(d_model, vocab_size)\n", " self.lm_head.weight = self.rose.raw_embedding.weight\n", "\n", " def forward(self, input_ids):\n", " # A. Wave Physics\n", " wave_src, particle_src = self.rose(input_ids)\n", " wave_out = self.prism_encoder(wave_src)\n", " wave_real = self.bridge(wave_out)\n", "\n", " # B. Interface\n", " mixed_memory = self.periscope_proj(torch.cat([wave_real, particle_src], dim=-1))\n", "\n", " # C. Digital Refinement (Now with RoPE)\n", " if self.refiner:\n", " out = self.refiner(mixed_memory)\n", " else:\n", " out = mixed_memory\n", "\n", " return self.lm_head(out)\n", "\n", "class FNetBlock(nn.Module):\n", " def __init__(self, d_model, d_ff, dropout):\n", " super().__init__()\n", " self.norm_mix = nn.LayerNorm(d_model) # LayerNorm is safer for FNet than RMSNorm\n", " self.norm_ff = nn.LayerNorm(d_model)\n", "\n", " self.mix_dropout = nn.Dropout(dropout)\n", "\n", " self.ff = nn.Sequential(\n", " nn.Linear(d_model, d_ff),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(d_ff, d_model),\n", " nn.Dropout(dropout)\n", " )\n", "\n", " def forward(self, x):\n", " # 1. 
Fourier Mixing Branch\n", " residual = x\n", " x = self.norm_mix(x)\n", "\n", " # --- THE FIX ---\n", " with torch.cuda.amp.autocast(enabled=False):\n", " x = x.float()\n", " # norm='ortho' makes the FFT energy-preserving.\n", " # Output magnitude will match input magnitude (~1).\n", " x = torch.fft.fftn(x, dim=(-2, -1), norm='ortho').real\n", " x = x.to(dtype=residual.dtype)\n", " # ---------------\n", "\n", " # Now 'x' and 'residual' have roughly same magnitude.\n", " # The skip connection works again.\n", " x = self.mix_dropout(x)\n", " x = x + residual\n", "\n", " # 2. Feed Forward Branch\n", " residual = x\n", " x = self.norm_ff(x)\n", " x = self.ff(x)\n", " return x + residual\n", "\n", "\n", "class FNetEncoder(nn.Module):\n", " def __init__(self, depth, d_model, d_ff, dropout):\n", " super().__init__()\n", " self.layers = nn.ModuleList([\n", " FNetBlock(d_model, d_ff, dropout) for _ in range(depth)\n", " ])\n", " # [FIX] Use LayerNorm here to match the blocks\n", " self.norm_out = nn.LayerNorm(d_model)\n", "\n", " def forward(self, x):\n", " for layer in self.layers:\n", " x = layer(x)\n", " return self.norm_out(x)\n", "\n", "class Pillars_DualStream(nn.Module):\n", " def __init__(self, vocab_size, d_model=512, d_branch=384, seq_len=4096, depth=4):\n", " super().__init__()\n", " self.d_branch = d_branch\n", " self.d_refiner = d_model\n", "\n", " # --- A. Rate Stream (FNet) ---\n", " self.fnet_emb = nn.Embedding(vocab_size, d_branch)\n", " self.fnet_pos = nn.Embedding(seq_len, d_branch)\n", " self.stream_rate = FNetEncoder(depth=depth, d_model=d_branch, d_ff=d_branch*4, dropout=0.1)\n", "\n", " # --- B. Phase Stream (PRISM) ---\n", " self.stream_phase_emb = DynamicRoSE(vocab_size, d_branch)\n", " self.stream_phase = PRISMEncoder(num_layers=depth, d_model=d_branch, max_len=seq_len, dropout=0.1)\n", " self.phase_bridge = ComplexToRealBridge(d_branch)\n", "\n", " # --- C. 
Fusion (The Funnel) ---\n", " self.fusion_proj = nn.Linear(d_branch * 2, d_model)\n", " self.fusion_norm = nn.LayerNorm(d_model)\n", "\n", " # --- D. Refiner ---\n", " self.refiner = Encoder(\n", " dim=d_model, depth=1, heads=8, attn_flash=True,\n", " rotary_pos_emb=True, attn_dropout=0.1, ff_dropout=0.1\n", " )\n", " self.lm_head = nn.Linear(d_model, vocab_size)\n", "\n", " def forward(self, x):\n", " # 1. Rate Path\n", " f_emb = self.fnet_emb(x) + self.fnet_pos(torch.arange(x.shape[1], device=x.device))\n", " rate_out = self.stream_rate(f_emb)\n", "\n", " # 2. Phase Path\n", " p_src, _ = self.stream_phase_emb(x)\n", " phase_out = self.phase_bridge(self.stream_phase(p_src))\n", "\n", " # 3. Fusion\n", " fused = self.fusion_norm(self.fusion_proj(torch.cat([rate_out, phase_out], dim=-1)))\n", "\n", " # 4. Refine & Output\n", " return self.lm_head(self.refiner(fused))\n", "\n", "\n", "class Pillars_Compact(nn.Module):\n", " def __init__(self, vocab_size, d_model=512, d_branch=384, seq_len=4096, depth=4):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.d_branch = d_branch\n", "\n", " # 1. SHARED ROOT\n", " self.rose = DynamicRoSE(vocab_size, d_model)\n", "\n", " # 2. DOWNSAMPLE (512 -> 384)\n", " self.particle_down = nn.Linear(d_model, d_branch)\n", " self.wave_down = nn.Linear(d_model * 2, d_branch * 2)\n", "\n", " # 3. RATE STREAM (FNet, Depth 4)\n", " self.fnet_pos = nn.Embedding(seq_len, d_branch)\n", " self.stream_rate = FNetEncoder(depth=depth, d_model=d_branch, d_ff=d_branch*4, dropout=0.1)\n", "\n", " # 4. PHASE STREAM (PRISM, Depth 4)\n", " self.stream_phase = PRISMEncoder(num_layers=depth, d_model=d_branch, max_len=seq_len, dropout=0.1)\n", " self.phase_bridge = ComplexToRealBridge(d_branch)\n", "\n", " # 5. FUSION (Clean Projection)\n", " # Input: 384 (Rate) + 384 (Phase) = 768\n", " # Output: 512 (Refiner Dim)\n", " self.fusion_proj = nn.Linear(d_branch * 2, d_model)\n", " self.fusion_norm = nn.LayerNorm(d_model)\n", "\n", " # 6. 
class Pillars_Compact(nn.Module):
    """Compact dual-stream ("Pillars") MLM encoder.

    A single DynamicRoSE embedding feeds both streams: its real "particle"
    view is downsampled into an FNet rate stream, its complex "wave" view
    into a PRISM phase stream. The two d_branch outputs are concatenated,
    projected to d_model, refined by one RoPE flash-attention layer, and
    decoded through a head tied to the shared embedding matrix (only a bias
    vector is head-specific).
    """
    def __init__(self, vocab_size, d_model=512, d_branch=384, seq_len=4096, depth=4):
        super().__init__()
        self.d_model = d_model
        self.d_branch = d_branch

        # 1. SHARED ROOT: one embedding serves both streams and the head.
        self.rose = DynamicRoSE(vocab_size, d_model)

        # 2. DOWNSAMPLE (d_model -> d_branch). The wave path carries real and
        #    imaginary parts stacked, hence the doubled widths.
        self.particle_down = nn.Linear(d_model, d_branch)
        self.wave_down = nn.Linear(d_model * 2, d_branch * 2)

        # 3. RATE STREAM (FNet) with learned absolute positions.
        self.fnet_pos = nn.Embedding(seq_len, d_branch)
        self.stream_rate = FNetEncoder(depth=depth, d_model=d_branch, d_ff=d_branch*4, dropout=0.1)

        # 4. PHASE STREAM (PRISM) + complex->real bridge.
        self.stream_phase = PRISMEncoder(num_layers=depth, d_model=d_branch, max_len=seq_len, dropout=0.1)
        self.phase_bridge = ComplexToRealBridge(d_branch)

        # 5. FUSION: concat(d_branch rate, d_branch phase) -> d_model.
        self.fusion_proj = nn.Linear(d_branch * 2, d_model)
        self.fusion_norm = nn.LayerNorm(d_model)

        # 6. REFINER (the brain): single attention layer, RoPE + flash.
        self.refiner = Encoder(
            dim=d_model, depth=1, heads=8, attn_flash=True,
            rotary_pos_emb=True, attn_dropout=0.1, ff_dropout=0.1
        )

        # 7. TIED HEAD: logits reuse rose.raw_embedding.weight (see forward);
        #    only this bias is head-specific.
        self.head_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, input_ids):
        # A. Shared root: complex wave + real particle views.
        wave_src, particle_src = self.rose(input_ids)

        # B. Downsample both views to d_branch; the wave is flattened to
        #    (real | imag), projected, then re-assembled as a complex tensor.
        p_small = self.particle_down(particle_src)
        w_flat = torch.cat([wave_src.real, wave_src.imag], dim=-1)
        w_small_flat = self.wave_down(w_flat)
        w_small = torch.complex(w_small_flat[..., :self.d_branch], w_small_flat[..., self.d_branch:])

        # C. Branches: rate stream gets absolute positions; phase stream is
        #    bridged back to the real domain.
        pos_emb = self.fnet_pos(torch.arange(input_ids.shape[1], device=input_ids.device))
        rate_out = self.stream_rate(p_small + pos_emb)
        phase_out = self.phase_bridge(self.stream_phase(w_small))

        # D. Fusion (concat -> project -> norm); the refiner is left to
        #    attend to whichever half is useful.
        stacked = torch.cat([rate_out, phase_out], dim=-1)
        context = self.fusion_norm(self.fusion_proj(stacked))

        # E. Refiner.
        refined = self.refiner(context)

        # F. Output: tied head — embedding matrix as weights + learned bias.
        logits = F.linear(refined, self.rose.raw_embedding.weight, self.head_bias)

        return logits

# torch/nn are already imported at the top of the notebook; the original cell
# re-imported them three times — deduplicated to a single (harmless) pair.
import torch
import torch.nn as nn

# prettytable is optional (and unused below); guard the import so the cell
# does not crash on environments where it is not installed.
try:
    from prettytable import PrettyTable
except ImportError:
    PrettyTable = None
def deep_analyze_pillars(model):
    """Print a layer-by-layer parameter breakdown of a Pillars_Compact model.

    Walks the model's named components (embeddings, RoSE adapters,
    downsamplers, both streams, fusion, refiner, head bias), prints a table
    with per-component counts, then a storage-vs-compute summary.

    Args:
        model: a Pillars_Compact-shaped module (needs .rose.raw_embedding,
               .fnet_pos, .particle_down, .wave_down, .stream_rate,
               .stream_phase, .phase_bridge, .fusion_proj, .fusion_norm,
               .refiner, .head_bias).
    Returns:
        Total trainable parameter count (int).
    """
    def get_p(obj):
        """Safely returns parameter count for Modules OR raw Parameters."""
        if isinstance(obj, nn.Parameter):
            return obj.numel()
        return sum(p.numel() for p in obj.parameters() if p.requires_grad)

    def format_num(n):
        # Human-readable count: millions / thousands / raw.
        if n > 1e6: return f"{n/1e6:.2f}M"
        if n > 1e3: return f"{n/1e3:.2f}K"
        return str(n)

    print("\n" + "="*80)
    print(f"šŸ—ļø PILLARS (COMPACT) - DEEP LAYER ANALYSIS")
    print("="*80)
    print(f"{'MODULE / LAYER':<40} | {'PARAMS':<15} | {'TYPE'}")
    print("-" * 80)

    total_params = get_p(model)

    # -----------------------------------------------
    # 1. STATIC MEMORY (Embeddings)
    # -----------------------------------------------
    vocab_emb = get_p(model.rose.raw_embedding)
    fnet_pos = get_p(model.fnet_pos)

    print(f"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<15} | šŸ’¾ STORAGE")
    print(f"{'FNet Positional Embedding':<40} | {format_num(fnet_pos):<15} | šŸ’¾ STORAGE")

    # -----------------------------------------------
    # 2. INPUT LOGIC (RoSE & Downsampling)
    # -----------------------------------------------
    rose_total = get_p(model.rose)
    rose_logic = rose_total - vocab_emb # Subtract the embedding matrix we already counted

    print("-" * 80)
    print(f"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<15} | 🌊 PHASE INIT")
    print(f"{'Particle Downsample (512->384)':<40} | {format_num(get_p(model.particle_down)):<15} | šŸ“‰ PROJ")
    print(f"{'Wave Downsample (1024->768)':<40} | {format_num(get_p(model.wave_down)):<15} | šŸ“‰ PROJ")

    # -----------------------------------------------
    # 3. STREAM A: RATE (FNet)
    # -----------------------------------------------
    print("-" * 80)
    print(f"TRACK A: RATE STREAM (FNet) - Depth {len(model.stream_rate.layers)}")

    fnet_encoder_total = 0
    for i, layer in enumerate(model.stream_rate.layers):
        p = get_p(layer)
        fnet_encoder_total += p
        print(f" ā”œā”€ FNet Block {i:<24} | {format_num(p):<15} | ⚔ RATE")

    fnet_norm = get_p(model.stream_rate.norm_out)
    fnet_encoder_total += fnet_norm
    # [FIX] The original interpolated the leftover loop index `i` after
    # "Final Norm"; pad with blanks instead to keep columns aligned.
    print(f" └─ Final Norm {'':<24} | {format_num(fnet_norm):<15} | ⚔ RATE")

    # -----------------------------------------------
    # 4. STREAM B: PHASE (PRISM)
    # -----------------------------------------------
    print("-" * 80)
    print(f"TRACK B: PHASE STREAM (PRISM) - Depth {len(model.stream_phase.layers)}")

    prism_encoder_total = 0
    for i, layer in enumerate(model.stream_phase.layers):
        p = get_p(layer)
        prism_encoder_total += p
        print(f" ā”œā”€ PRISM Block {i:<23} | {format_num(p):<15} | 🌊 PHASE")

    prism_norm = get_p(model.stream_phase.final_norm)
    prism_encoder_total += prism_norm
    # [FIX] Same stale-index fix as the FNet final norm above.
    print(f" └─ Final Norm {'':<24} | {format_num(prism_norm):<15} | 🌊 PHASE")

    bridge_p = get_p(model.phase_bridge)
    print(f"{'Phase Bridge (Complex->Real)':<40} | {format_num(bridge_p):<15} | šŸŒ‰ BRIDGE")

    # -----------------------------------------------
    # 5. THE BRAIN (Fusion & Refiner)
    # -----------------------------------------------
    print("-" * 80)
    fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)
    print(f"{'Fusion (Concat -> Proj -> Norm)':<40} | {format_num(fusion_p):<15} | 🧠 FUSION")

    refiner_p = get_p(model.refiner)
    print(f"{'Transformer Refiner (1 Layer)':<40} | {format_num(refiner_p):<15} | 🧠 ATTENTION")

    # head_bias is a bare nn.Parameter; get_p handles it directly.
    head_bias_p = get_p(model.head_bias)
    print(f"{'Output Head Bias':<40} | {format_num(head_bias_p):<15} | šŸŽÆ OUTPUT")

    # -----------------------------------------------
    # 6. SUMMARY
    # -----------------------------------------------
    print("="*80)

    # "Storage" = lookup tables; "Compute" = everything that transforms.
    storage = vocab_emb + fnet_pos + head_bias_p
    active = total_params - storage

    print(f"TOTAL PARAMETERS: {total_params/1e6:.2f} M")
    print(f" ā”œā”€ šŸ’¾ Storage: {storage/1e6:.2f} M (Embeddings)")
    print(f" └─ 🧠 Compute: {active/1e6:.2f} M (Logic/Weights)")
    print("-" * 80)
    print(f"STREAM BREAKDOWN:")
    print(f" ā”œā”€ ⚔ Rate Stream: {fnet_encoder_total/1e6:.2f} M")
    print(f" └─ 🌊 Phase Stream: {prism_encoder_total/1e6:.2f} M")
    print("="*80 + "\n")

    return total_params

# Build the training model and report its parameter budget.
model = Pillars_Compact(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    d_branch=D_BRANCH,
    seq_len=SEQ_LEN,
    depth=DEPTH
).to(DEVICE)
deep_analyze_pillars(model)
\"refiner\" in name or \"down\" in name or \"proj\" in name or \"norm\" in name:\n", " stats[\"Fusion & Refiner\"] += n\n", " elif \"head_bias\" in name:\n", " stats[\"Tied Head Bias\"] += n\n", " else:\n", " print(f\"āš ļø Uncategorized: {name} ({n})\")\n", "\n", " print(f\"{'COMPONENT':<30} | {'PARAMS':<12} | {'% TOTAL':<8}\")\n", " print(\"-\" * 60)\n", "\n", " for category, count in stats.items():\n", " if count > 0:\n", " pct = (count / total_params) * 100\n", " print(f\"{category:<30} | {count:12,} | {pct:6.1f}%\")\n", "\n", " print(\"-\" * 60)\n", " print(f\"{'TOTAL PARAMETERS':<30} | {total_params:12,} | 100.0%\")\n", " print(\"=\" * 70)\n", "\n", "\n", " active_params = total_params - stats[\"Shared Memory (Storage)\"] - stats[\"Tied Head Bias\"]\n", " print(f\" 1. Total Model Size: {total_params/1e6:.1f}M\")\n", " print(f\" 2. Baseline Target: ~32.5M\")\n", " print(f\" 3. Active Reasoning Params: {active_params/1e6:.1f}M (The actual brain)\")\n", " print(\"=\"*70 + \"\\n\")\n", "\n", "\n" ], "metadata": { "id": "ke4fYT8UX5zH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ==========================================\n", "# 4. 
LOGGING UTILITIES\n", "# ==========================================\n", "def generate_run_id():\n", " raw = datetime.now().strftime(\"%Y%m%d%H%M%S%f\")\n", " return hashlib.md5(raw.encode()).hexdigest()[:8]\n", "\n", "def log_environment(save_dir, run_id, config):\n", " log_path = os.path.join(save_dir, f\"env_metadata_{run_id}.txt\")\n", " with open(log_path, \"w\") as f:\n", " f.write(f\"PRISM EXPERIMENT METADATA | Run ID: {run_id}\\n{'='*50}\\n\")\n", " for k, v in config.items(): f.write(f\"{k}: {v}\\n\")\n", " print(f\"šŸ“ Environment Snapshot saved to: {log_path}\")\n", "\n", "def log_metrics(save_dir, run_id, epoch, train_loss, val_loss, ppl):\n", " log_path = os.path.join(save_dir, f\"metrics_log_{run_id}.csv\")\n", " if not os.path.exists(log_path):\n", " with open(log_path, \"w\") as f: f.write(\"Timestamp,Epoch,Train_Loss,Val_Loss,Perplexity\\n\")\n", " with open(log_path, \"a\") as f:\n", " ts = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " f.write(f\"{ts},{epoch},{train_loss:.6f},{val_loss:.6f},{ppl:.6f}\\n\")\n", "\n", "\n", "def save_checkpoint(path, model, optimizer, scheduler, epoch, best_loss, config):\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': scheduler.state_dict(),\n", " 'best_val_loss': best_loss,\n", " 'config': config\n", " }, path)\n", "\n", "def init_pillars_weights(model):\n", " print(\"✨ APPLYING PILLARS INITIALIZATION PROTOCOL...\")\n", "\n", " # 1. 
SHARED ROOT (RoSE) - MATCHING YOUR ORIGINAL LOGIC\n", " # Standard embedding init\n", " nn.init.normal_(model.rose.raw_embedding.weight, std=model.d_model ** -0.5)\n", "\n", " # Adapter: Orthogonal ensures clean entry to complex plane\n", " nn.init.orthogonal_(model.rose.adapter.weight)\n", "\n", " # --- THE ROSE IDENTITY TRICK (From your original code) ---\n", " # Start with almost zero rotation influence from content\n", " nn.init.normal_(model.rose.rotation_predictor.weight, std=0.01)\n", " with torch.no_grad():\n", " # Force initial vector to (1, 0) -> Angle 0, Mag 1\n", " # This allows the model to start with \"Safe\" static physics\n", " model.rose.rotation_predictor.bias[:model.d_model].fill_(1.0)\n", " model.rose.rotation_predictor.bias[model.d_model:].fill_(0.0)\n", " # -------------------------------------------------------\n", "\n", " # 2. DOWNSAMPLERS (The Split)\n", " # Scale gain by 1.414 (sqrt 2) to preserve energy when halving dimensions\n", " nn.init.orthogonal_(model.particle_down.weight, gain=1.414)\n", " nn.init.orthogonal_(model.wave_down.weight, gain=1.414)\n", "\n", " # 3. FNET BRANCH (Rate Stream)\n", " # Kaiming Normal (Good for GELU)\n", " for name, m in model.stream_rate.named_modules():\n", " if isinstance(m, nn.Linear):\n", " nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", " if m.bias is not None: nn.init.zeros_(m.bias)\n", "\n", " # 4. PRISM BRANCH (Phase Stream)\n", " # Xavier Uniform (Good for Complex/Linear)\n", " for name, m in model.stream_phase.named_modules():\n", " if isinstance(m, nn.Linear):\n", " nn.init.xavier_uniform_(m.weight, gain=1.0)\n", " if m.bias is not None: nn.init.zeros_(m.bias)\n", " # Initialize ModReLU bias slightly positive to avoid dead neurons\n", " if isinstance(m, ModReLU):\n", " nn.init.constant_(m.b, 0.01)\n", "\n", " # 5. 
def run_wikitext_training(experiment_name="PILLARS_SplitStream_9Layer"):
    """End-to-end Colab training routine for Pillars_Compact on WikiText-103.

    Mounts Google Drive, builds dataloaders from the Hub dataset, creates
    (or resumes) the model/optimizer/scheduler, runs the MLM training loop
    with gradient accumulation, validates each epoch, checkpoints, and
    finally evaluates the best checkpoint on the test split.

    Returns the trained model.
    """
    from google.colab import drive  # Colab-only dependency, imported lazily
    if not os.path.exists('/content/drive'): drive.mount('/content/drive')

    # --- SETUP DIRS: resume in place, or create a fresh timestamped folder ---
    if RESUME_PATH and os.path.exists(RESUME_PATH):
        print(f"šŸ”„ RESUMING FROM: {RESUME_PATH}")
        checkpoint = torch.load(RESUME_PATH, map_location=DEVICE)
        SAVE_DIR = os.path.dirname(RESUME_PATH)
        run_id = checkpoint.get('config', {}).get('run_id', 'resumed')
    else:
        run_id = hashlib.md5(datetime.now().strftime("%Y%m%d%H%M%S%f").encode()).hexdigest()[:8]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        folder_name = f"{experiment_name}_{timestamp}_{run_id}"
        SAVE_DIR = os.path.join("/content/drive/My Drive/PRISM_Experiments", folder_name)
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f"šŸ’¾ Checkpoints: {SAVE_DIR}")

    writer = SummaryWriter(log_dir=SAVE_DIR)
    GRAD_ACCUM = 4  # effective batch = BATCH_SIZE * GRAD_ACCUM

    lm_datasets, data_collator = prepare_data_from_hub()

    train_loader = DataLoader(
        lm_datasets["train"], batch_size=BATCH_SIZE, shuffle=True,
        collate_fn=data_collator, num_workers=2, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )
    valid_loader = DataLoader(
        lm_datasets["validation"], batch_size=BATCH_SIZE,
        collate_fn=data_collator, num_workers=2, pin_memory=True
    )
    test_loader = DataLoader(
        lm_datasets["test"], batch_size=BATCH_SIZE,
        collate_fn=data_collator, num_workers=2, pin_memory=True
    )

    print("\n⚔ INITIALIZING PILLARS MODEL...")

    # Instantiate the dual-stream compact model from the global config.
    model = Pillars_Compact(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        d_branch=D_BRANCH,
        seq_len=SEQ_LEN,
        depth=DEPTH
    ).to(DEVICE)

    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01) # Added decay for stabilization
    # Scheduler horizon counts optimizer steps, not micro-batches.
    total_steps = (len(train_loader) // GRAD_ACCUM) * EPOCHS
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.05 * total_steps), num_training_steps=total_steps
    )
    # NOTE(review): labels from DataCollatorForLanguageModeling use -100 for
    # unmasked tokens, which CrossEntropyLoss ignores by default.
    criterion = nn.CrossEntropyLoss()

    start_epoch = 0
    best_val_loss = float('inf')

    if RESUME_PATH and os.path.exists(RESUME_PATH):
        # Restore full training state and free the checkpoint dict.
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        del checkpoint
        torch.cuda.empty_cache()
    else:
        # Fresh run: apply the custom initialization protocol.
        init_pillars_weights(model)
    print(model)
    analyze_pillars_compact(model)
    print(f"\nšŸš€ STARTING (Ep {start_epoch+1} to {EPOCHS})")
    global_step = (len(train_loader) // GRAD_ACCUM) * start_epoch

    for epoch in range(start_epoch, EPOCHS):
        model.train()
        pbar = tqdm(train_loader, desc=f"Ep {epoch+1}/{EPOCHS}")

        for step, batch in enumerate(pbar):
            x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)

            # Scale loss so accumulated gradients average over GRAD_ACCUM
            # micro-batches.
            loss = criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)) / GRAD_ACCUM
            loss.backward()

            # NOTE(review): if len(train_loader) % GRAD_ACCUM != 0, the
            # trailing micro-batches of an epoch are never stepped/zeroed and
            # their gradients leak into the next epoch's first step — TODO
            # confirm whether that is acceptable here.
            if (step + 1) % GRAD_ACCUM == 0:
                # 1. Clip and record the global gradient norm.
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # 2. Optimizer/scheduler step for the accumulated window.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # 3. Logging (un-scaled loss of the last micro-batch).
                actual_loss = loss.item() * GRAD_ACCUM

                writer.add_scalar('Train/Loss', actual_loss, global_step)
                writer.add_scalar('Train/GradNorm', grad_norm.item(), global_step)

                # 4. Progress bar.
                pbar.set_postfix({
                    'loss': f"{actual_loss:.4f}",
                    'gnorm': f"{grad_norm.item():.2f}"
                })

        # VALIDATION — mean loss over validation batches, PPL = exp(loss)
        # (guarded against overflow for very large losses).
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in valid_loader:
                x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)
                val_loss += criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)).item()

        avg_val_loss = val_loss / len(valid_loader)
        ppl = math.exp(avg_val_loss) if avg_val_loss < 100 else float('inf')

        print(f"✨ Epoch {epoch+1} | Val Loss: {avg_val_loss:.4f} | PPL: {ppl:.2f}")
        writer.add_scalar('Val/PPL', ppl, epoch+1)

        # Always save a resumable "last" checkpoint; save weights-only
        # "best" when validation improves.
        config_dump = {"epoch": epoch, "run_id": run_id}
        save_checkpoint(os.path.join(SAVE_DIR, "last.pt"), model, optimizer, scheduler, epoch, best_val_loss, config_dump)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best.pt"))
            print(" šŸ† New Best Model Saved!")

    # FINAL TEST — reload the best weights (if any epoch improved) and
    # report test perplexity.
    best_path = os.path.join(SAVE_DIR, "best.pt")
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path))
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)
            test_loss += criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)).item()
    print(f"šŸ† FINAL PPL: {math.exp(test_loss/len(test_loader)):.2f}")

    writer.close()
    return model
\"__main__\":\n", "\n", " print(\"šŸ”„ IGNITING PILLARS TRAINING PIPELINE...\")\n", "\n", " # 1. Run the Training Routine\n", " # This handles Model Creation -> Analysis -> Training -> Saving\n", " trained_prism = run_wikitext_training()\n", "\n", " # 2. Cleanup & Shutdown\n", " print(\"āœ… Experiment Complete. Shutting down runtime...\")\n", " from google.colab import runtime\n", " runtime.unassign()" ], "metadata": { "id": "KaiJU0tPkVp-" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from google.colab import runtime\n", "runtime.unassign()" ], "metadata": { "id": "bxFTYWHVqcSI" }, "execution_count": null, "outputs": [] } ] }