{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "!pip install -q x-transformers\n", "!pip install -q flash-attn --no-build-isolation" ], "metadata": { "id": "6q9RTvlf5IiS" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import math\n", "import os\n", "import sys\n", "import subprocess\n", "import hashlib\n", "import gc\n", "from datetime import datetime\n", "from tqdm.auto import tqdm\n", "from torch.utils.data import DataLoader\n", "from torch.utils.tensorboard import SummaryWriter\n", "from transformers import RobertaTokenizerFast, get_cosine_schedule_with_warmup, DataCollatorForLanguageModeling\n", "from datasets import load_dataset\n", "from x_transformers import Encoder\n", "\n", "# ==========================================\n", "# 1. CONFIGURATION\n", "# ==========================================\n", "# YOUR REPO ID (Created in previous step)\n", "HF_ID = \"prism-lab/wikitext-103-prism-32k-seq4k\"\n", "\n", "# Hyperparameters\n", "VOCAB_SIZE = 32768\n", "SEQ_LEN = 4096\n", "BATCH_SIZE = 8\n", "EPOCHS = 40\n", "LR = 1e-3\n", "D_MODEL = 512\n", "D_BRANCH = 256\n", "DEPTH = 6\n", "RESUME_PATH = None #\"/content/drive/MyDrive/PRISM_Experiments/PILLARS_SplitStream_8Layer_20260116_025321_8438ce62/last.pt\"\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "torch.set_float32_matmul_precision(\"high\")\n", "\n", "# ==========================================\n", "# 2. DATA PIPELINE (The \"Pro\" Way)\n", "# ==========================================\n", "def prepare_data_from_hub():\n", " print(f\"ā¬‡ļø Pulling Pre-Tokenized Data from {HF_ID}...\")\n", "\n", " # 1. Load Tokenizer (Instant)\n", " # This pulls the exact tokenizer you uploaded\n", " tokenizer = RobertaTokenizerFast.from_pretrained(HF_ID)\n", "\n", " # 2. Load Dataset (Instant)\n", " # This pulls the already chunked/tokenized data\n", " dataset = load_dataset(HF_ID)\n", "\n", " print(f\"āœ… Loaded {len(dataset['train'])} training chunks.\")\n", "\n", " # 3. Collator\n", " data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer,\n", " mlm=True,\n", " mlm_probability=0.15\n", " )\n", "\n", " return dataset, data_collator\n", "# ==========================================\n", "# 3. PRISM ARCHITECTURE (Complex-Valued)\n", "# ==========================================\n", "\n", "class ComplexDropout(nn.Module):\n", " def __init__(self, p=0.5):\n", " super().__init__()\n", " self.p = p\n", " def forward(self, z):\n", " if not self.training or self.p == 0.0: return z\n", " mask = torch.ones_like(z.real)\n", " mask = F.dropout(mask, self.p, self.training, inplace=False)\n", " return z * mask\n", "\n", "class RobustPhaseNorm(nn.Module):\n", " def __init__(self, d_model, eps=1e-5):\n", " super().__init__()\n", " self.scale = nn.Parameter(torch.ones(d_model))\n", " self.eps = eps\n", " def forward(self, x):\n", " mag = torch.abs(x)\n", " rms = torch.sqrt(torch.mean(mag**2, dim=-1, keepdim=True) + self.eps)\n", " return (x / rms) * self.scale\n", "\n", "class ModReLU(nn.Module):\n", " def __init__(self, features):\n", " super().__init__()\n", " self.b = nn.Parameter(torch.zeros(features))\n", "\n", " def forward(self, z):\n", " # 1. FORCE FLOAT32 FOR GEOMETRY\n", " # We must calculate magnitude in high precision to prevent\n", " # square-law overflow (Re^2 + Im^2) from killing the gradients.\n", " z_32 = z.to(torch.complex64)\n", "\n", " # 2. Calculate Magnitude (Safe)\n", " mag = torch.abs(z_32)\n", "\n", " # 3. Activation Logic (Still FP32)\n", " new_mag = F.relu(mag + self.b.float())\n", "\n", " # 4. Reconstruct Phase (Safe Division)\n", " # (z / mag) is the unit vector (phase)\n", " phase = z_32 / (mag + 1e-6)\n", "\n", " # 5. Result\n", " out = new_mag * phase\n", "\n", " # 6. Cast back to network dtype (BF16/FP16)\n", " return out.to(z.dtype)\n", "\n", "class ComplexToRealBridge(nn.Module):\n", " def __init__(self, d_model):\n", " super().__init__()\n", " self.proj = nn.Linear(d_model * 2, d_model)\n", " self.norm = nn.LayerNorm(d_model)\n", " def forward(self, x_complex):\n", " cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)\n", " return self.norm(self.proj(cat))\n", "\n", "# ==========================================\n", "# 4. DYNAMIC RoSE (Mamba-3 Engine)\n", "# ==========================================\n", "class DynamicRoSE(nn.Module):\n", " def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):\n", " super().__init__()\n", " self.embedding_dim = embedding_dim\n", "\n", " # 1. Master Real Embedding (The \"Particle\")\n", " self.raw_embedding = nn.Embedding(num_embeddings, embedding_dim)\n", "\n", " # 2. Complex Adapter (The \"Wave\" Magnitude/Initial Phase)\n", " self.adapter = nn.Linear(embedding_dim, embedding_dim * 2)\n", "\n", " # 3. Static Frequencies (Positional)\n", " freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))\n", " self.register_buffer('freqs', freqs)\n", "\n", " self.rotation_predictor = nn.Linear(embedding_dim, embedding_dim * 2)\n", "\n", " def forward(self, input_ids):\n", " # A. Raw Particle\n", " real_base = self.raw_embedding(input_ids)\n", " B, L, D = real_base.shape\n", "\n", " # B. Complex Wave Content\n", " complex_params = self.adapter(real_base)\n", " z_t = torch.complex(complex_params[..., :D], complex_params[..., D:])\n", "\n", " rot_raw = self.rotation_predictor(real_base)\n", " rot_x, rot_y = rot_raw.chunk(2, dim=-1)\n", "\n", " rot_mag = torch.sqrt(rot_x**2 + rot_y**2 + 1e-6)\n", " dynamic_rot = torch.complex(rot_x / rot_mag, rot_y / rot_mag)\n", "\n", " # D. Static Positional Rotation\n", " pos = torch.arange(L, device=input_ids.device).float()\n", " static_angles = torch.outer(pos, self.freqs) # [L, D]\n", " static_rot = torch.polar(torch.ones_like(static_angles), static_angles) # [L, D]\n", "\n", " z_final = z_t * static_rot.unsqueeze(0) * dynamic_rot\n", "\n", " return z_final, real_base\n", "\n", "# ==========================================\n", "# 5. HYENA FILTER\n", "# ==========================================\n", "class HyenaNeuralFilter(nn.Module):\n", " def __init__(self, d_model, max_len=1024, hidden_dim=64):\n", " super().__init__()\n", " self.d_model = d_model\n", " freqs = torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / hidden_dim))\n", " self.register_buffer(\"freqs\", freqs)\n", " self.mlp = nn.Sequential(\n", " nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),\n", " nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),\n", " nn.Linear(hidden_dim, d_model * 2)\n", " )\n", " def forward(self, L, device):\n", " t = torch.linspace(0, 1, steps=L, device=device).unsqueeze(-1)\n", " emb = torch.cat([torch.sin(t * self.freqs), torch.cos(t * self.freqs)], dim=-1)\n", " out = self.mlp(emb).view(L, self.d_model, 2)\n", " return torch.complex(out[..., 0], out[..., 1])\n", "\n", "# ==========================================\n", "# 6. GATED HARMONIC CONVOLUTION (Lean)\n", "# ==========================================\n", "# @title šŸ› ļø Fixed PRISM Layer (Precision-Gated)\n", "\n", "# @title šŸ› ļø Fixed PRISM Layer (Type-Safe)\n", "\n", "class GatedHarmonicConvolution(nn.Module):\n", " def __init__(self, d_model, max_len=1024, dropout=0.1):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.filter_len = max_len\n", " self.neural_filter = HyenaNeuralFilter(d_model, max_len=max_len)\n", " self.gate_proj = nn.Linear(d_model * 2, d_model * 2)\n", "\n", " self.mix_real = nn.Linear(d_model, d_model)\n", " self.mix_imag = nn.Linear(d_model, d_model)\n", " self.out_real = nn.Linear(d_model, d_model)\n", " self.out_imag = nn.Linear(d_model, d_model)\n", "\n", " self.activation = ModReLU(d_model)\n", " self.norm = RobustPhaseNorm(d_model)\n", " self.dropout = ComplexDropout(dropout)\n", "\n", " def forward(self, x, src_mask=None):\n", " residual = x\n", " x_norm = self.norm(x)\n", " if src_mask is not None:\n", " x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)\n", "\n", " # šŸ›‘ PRECISION GATE šŸ›‘\n", " # Force operations to Float32 Complex to preserve Phase Physics\n", " with torch.amp.autocast('cuda', enabled=False):\n", "\n", " # --- THE FIX IS HERE ---\n", " # Old: x_32 = x_norm.float() <-- This stripped the imaginary part\n", " # New: Explicit cast to Complex64\n", " x_32 = x_norm.to(torch.complex64)\n", " # -----------------------\n", "\n", " B, L, D = x_32.shape\n", " eff_L = min(L, self.filter_len)\n", "\n", " # 1. FFT (Now safe because x_32 is definitely complex)\n", " x_freq = torch.fft.fft(x_32, n=eff_L, dim=1, norm='ortho')\n", "\n", " # 2. Filter (Ensure filter is also complex64)\n", " h = self.neural_filter(eff_L, x.device).unsqueeze(0).to(torch.complex64)\n", " x_filtered = x_freq * h\n", "\n", " # 3. IFFT\n", " x_time = torch.fft.ifft(x_filtered, n=eff_L, dim=1, norm='ortho')\n", "\n", " if L > eff_L: x_time = F.pad(x_time, (0,0,0,L-eff_L))\n", " else: x_time = x_time[:, :L, :]\n", "\n", " # 4. Gating (Sigmoid logic)\n", " # Safe concatenation because x_32 is complex64\n", " x_cat = torch.cat([x_32.real, x_32.imag], dim=-1)\n", "\n", " # Cast weights to Float32 for the calculation\n", " gate_w = self.gate_proj.weight.to(torch.float32)\n", " gate_b = self.gate_proj.bias.to(torch.float32)\n", "\n", " gate_out = F.linear(x_cat, gate_w, gate_b)\n", " gates = torch.sigmoid(gate_out)\n", "\n", " g_r, g_i = gates.chunk(2, dim=-1)\n", " x_gated_32 = torch.complex(x_time.real * g_r, x_time.imag * g_i)\n", "\n", " # šŸ EXIT GATE: Cast back to original dtype (likely BFloat16 from autocast)\n", " # We cast real/imag separately to be safe\n", " target_dtype = x.dtype\n", " # If x was complex, target is complex. If x was real, we have an issue.\n", " # Assuming x comes from autocast, it might be complex16.\n", "\n", " x_gated = x_gated_32.to(target_dtype)\n", "\n", " # 5. Mixing (Back in mixed precision)\n", " mr, mi = self.mix_real, self.mix_imag\n", " x_mixed = torch.complex(mr(x_gated.real) - mi(x_gated.imag), mr(x_gated.imag) + mi(x_gated.real))\n", "\n", " x_act = self.activation(x_mixed)\n", "\n", " or_, oi = self.out_real, self.out_imag\n", " out = torch.complex(or_(x_act.real) - oi(x_act.imag), or_(x_act.imag) + oi(x_act.real))\n", "\n", " return self.dropout(out) + residual\n", "# ==========================================\n", "# 7. MODEL WRAPPERS\n", "# ==========================================\n", "class PRISMEncoder(nn.Module):\n", " def __init__(self, num_layers, d_model, max_len, dropout=0.1):\n", " super().__init__()\n", " self.layers = nn.ModuleList([\n", " GatedHarmonicConvolution(d_model, max_len, dropout)\n", " for _ in range(num_layers)\n", " ])\n", " self.final_norm = RobustPhaseNorm(d_model)\n", " def forward(self, x, src_mask=None):\n", " for layer in self.layers:\n", " if self.training: x = torch.utils.checkpoint.checkpoint(layer, x, src_mask, use_reentrant=False)\n", " else: x = layer(x, src_mask)\n", " return self.final_norm(x)\n", "\n", "class PRISM_WikiText_Model(nn.Module):\n", " def __init__(self, vocab_size, d_model, max_len, prism_depth=5, trans_depth=1, dropout=0.1):\n", " super().__init__()\n", " self.d_model = d_model\n", "\n", " # 1. PRISM Core (The Optical/Passive Part)\n", " self.rose = DynamicRoSE(vocab_size, d_model)\n", " self.prism_encoder = PRISMEncoder(prism_depth, d_model, max_len=max_len, dropout=dropout)\n", " self.bridge = ComplexToRealBridge(d_model)\n", " self.periscope_proj = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.LayerNorm(d_model), nn.GELU())\n", "\n", " # 2. Refiner (The Digital/Active Part)\n", " # šŸ”„ SWAPPED: Replaced Standard Transformer with RoPE-Enabled Encoder\n", " if trans_depth > 0:\n", " self.refiner = Encoder(\n", " dim=d_model,\n", " depth=trans_depth,\n", " heads=8,\n", " rotary_pos_emb=True,\n", " attn_flash=True,\n", " attn_dropout=dropout,\n", " ff_dropout=dropout,\n", "\n", " )\n", " else:\n", " self.refiner = None\n", "\n", " # 3. Output\n", " self.lm_head = nn.Linear(d_model, vocab_size)\n", " self.lm_head.weight = self.rose.raw_embedding.weight\n", "\n", " def forward(self, input_ids):\n", " # A. Wave Physics\n", " wave_src, particle_src = self.rose(input_ids)\n", " wave_out = self.prism_encoder(wave_src)\n", " wave_real = self.bridge(wave_out)\n", "\n", " # B. Interface\n", " mixed_memory = self.periscope_proj(torch.cat([wave_real, particle_src], dim=-1))\n", "\n", " # C. Digital Refinement (Now with RoPE)\n", " if self.refiner:\n", " out = self.refiner(mixed_memory)\n", " else:\n", " out = mixed_memory\n", "\n", " return self.lm_head(out)\n", "\n", "# ==========================================\n", "# 1. SENSORY STREAM (Transformer + RoPE)\n", "# ==========================================\n", "class SensoryStream(nn.Module):\n", " def __init__(self, depth, d_model, dropout=0.1):\n", " super().__init__()\n", " self.encoder = Encoder(\n", " dim=d_model,\n", " depth=depth,\n", " heads=4, # 256 dim / 64 head_dim = 4 heads\n", " attn_flash=True, # Flash Attention\n", " rotary_pos_emb=True, # <--- CRITICAL: RoPE Enabled\n", " attn_dropout=dropout,\n", " ff_dropout=dropout,\n", " use_rmsnorm=True, # RMSNorm (Llama style)\n", " ff_glu=True # SwiGLU (Llama style)\n", " )\n", "\n", " def forward(self, x):\n", " return self.encoder(x)\n", "\n", "# ==========================================\n", "# 2. PILLARS-DAT (Dual Attention with RoPE)\n", "# ==========================================\n", "class Pillars_DAT(nn.Module):\n", " def __init__(self, vocab_size, d_model=512, d_branch=256, seq_len=4096, depth=4):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.d_branch = d_branch\n", "\n", " # --- A. SHARED ROOT ---\n", " self.rose = DynamicRoSE(vocab_size, d_model)\n", "\n", " # --- B. DOWNSAMPLE ---\n", " self.particle_down = nn.Linear(d_model, d_branch)\n", " self.wave_down = nn.Linear(d_model * 2, d_branch * 2)\n", "\n", " # --- C. STREAM 1: SENSORY (Object Attributes) ---\n", " # REPLACED: FNet -> Transformer with RoPE\n", " # NOTE: No self.sensory_pos anymore! RoPE handles it.\n", " self.stream_sensory = SensoryStream(depth=depth, d_model=d_branch, dropout=0.1)\n", "\n", " # --- D. STREAM 2: RELATIONAL (Structure / Phase) ---\n", " # PRISM handles positions internally via RoSE frequencies\n", " self.stream_relational = PRISMEncoder(num_layers=depth, d_model=d_branch, max_len=seq_len, dropout=0.1)\n", " self.relational_bridge = ComplexToRealBridge(d_branch)\n", "\n", " # --- E. FUSION ---\n", " self.fusion_proj = nn.Linear(d_branch * 2, d_model)\n", " self.fusion_norm = nn.LayerNorm(d_model)\n", "\n", " # --- F. REFINER ---\n", " self.refiner = Encoder(\n", " dim=d_model, depth=1, heads=8, attn_flash=True,\n", " rotary_pos_emb=True, attn_dropout=0.1, ff_dropout=0.1\n", " )\n", "\n", " # --- G. OUTPUT ---\n", " self.head_bias = nn.Parameter(torch.zeros(vocab_size))\n", "\n", " def forward(self, input_ids):\n", " # 1. Root Physics\n", " wave_src, particle_src = self.rose(input_ids)\n", "\n", " # 2. Downsample\n", " p_small = self.particle_down(particle_src)\n", "\n", " # Prepare complex wave input\n", " w_flat = torch.cat([wave_src.real, wave_src.imag], dim=-1)\n", " w_small_flat = self.wave_down(w_flat)\n", " w_small = torch.complex(w_small_flat[..., :self.d_branch], w_small_flat[..., self.d_branch:])\n", "\n", " # 3. Parallel Processing\n", "\n", " # --- Stream A: Sensory (Transformer + RoPE) ---\n", " # Pass pure features. RoPE adds position info inside the attention layer.\n", " sensory_out = self.stream_sensory(p_small)\n", "\n", " # --- Stream B: Relational (PRISM) ---\n", " relational_out_complex = self.stream_relational(w_small)\n", " relational_out = self.relational_bridge(relational_out_complex)\n", "\n", " # 4. Integration\n", " stacked = torch.cat([sensory_out, relational_out], dim=-1)\n", " context = self.fusion_norm(self.fusion_proj(stacked))\n", "\n", " # 5. Refinement\n", " refined = self.refiner(context)\n", "\n", " # 6. Output\n", " logits = F.linear(refined, self.rose.raw_embedding.weight, self.head_bias)\n", "\n", " return logits\n", "\n", "import torch\n", "import torch.nn as nn\n", "from prettytable import PrettyTable # Optional, but makes tables nice.\n", "# If you don't have prettytable, the code below uses standard f-strings.\n", "\n", "import torch\n", "import torch.nn as nn\n", "\n", "import torch\n", "import torch.nn as nn\n", "\n", "def deep_analyze_pillars(model):\n", " def get_p(obj):\n", " \"\"\"Safely returns parameter count for Modules OR raw Parameters.\"\"\"\n", " if isinstance(obj, nn.Parameter):\n", " return obj.numel()\n", " return sum(p.numel() for p in obj.parameters() if p.requires_grad)\n", "\n", " def format_num(n):\n", " if n > 1e6: return f\"{n/1e6:.2f}M\"\n", " if n > 1e3: return f\"{n/1e3:.2f}K\"\n", " return str(n)\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(f\"šŸ—ļø PILLARS (COMPACT) - DEEP LAYER ANALYSIS\")\n", " print(\"=\"*80)\n", " print(f\"{'MODULE / LAYER':<40} | {'PARAMS':<15} | {'TYPE'}\")\n", " print(\"-\" * 80)\n", "\n", " total_params = get_p(model)\n", "\n", " # -----------------------------------------------\n", " # 1. STATIC MEMORY (Embeddings)\n", " # -----------------------------------------------\n", " vocab_emb = get_p(model.rose.raw_embedding)\n", " fnet_pos = get_p(model.fnet_pos)\n", "\n", " print(f\"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<15} | šŸ’¾ STORAGE\")\n", " print(f\"{'FNet Positional Embedding':<40} | {format_num(fnet_pos):<15} | šŸ’¾ STORAGE\")\n", "\n", " # -----------------------------------------------\n", " # 2. INPUT LOGIC (RoSE & Downsampling)\n", " # -----------------------------------------------\n", " rose_total = get_p(model.rose)\n", " rose_logic = rose_total - vocab_emb # Subtract the embedding matrix we already counted\n", "\n", " print(\"-\" * 80)\n", " print(f\"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<15} | 🌊 PHASE INIT\")\n", " print(f\"{'Particle Downsample (512->384)':<40} | {format_num(get_p(model.particle_down)):<15} | šŸ“‰ PROJ\")\n", " print(f\"{'Wave Downsample (1024->768)':<40} | {format_num(get_p(model.wave_down)):<15} | šŸ“‰ PROJ\")\n", "\n", " # -----------------------------------------------\n", " # 3. STREAM A: RATE (FNet)\n", " # -----------------------------------------------\n", " print(\"-\" * 80)\n", " print(f\"TRACK A: RATE STREAM (FNet) - Depth {len(model.stream_rate.layers)}\")\n", "\n", " fnet_encoder_total = 0\n", " for i, layer in enumerate(model.stream_rate.layers):\n", " p = get_p(layer)\n", " fnet_encoder_total += p\n", " print(f\" ā”œā”€ FNet Block {i:<24} | {format_num(p):<15} | ⚔ RATE\")\n", "\n", " fnet_norm = get_p(model.stream_rate.norm_out)\n", " fnet_encoder_total += fnet_norm\n", " print(f\" └─ Final Norm {i:<24} | {format_num(fnet_norm):<15} | ⚔ RATE\")\n", "\n", " # -----------------------------------------------\n", " # 4. STREAM B: PHASE (PRISM)\n", " # -----------------------------------------------\n", " print(\"-\" * 80)\n", " print(f\"TRACK B: PHASE STREAM (PRISM) - Depth {len(model.stream_phase.layers)}\")\n", "\n", " prism_encoder_total = 0\n", " for i, layer in enumerate(model.stream_phase.layers):\n", " p = get_p(layer)\n", " prism_encoder_total += p\n", " print(f\" ā”œā”€ PRISM Block {i:<23} | {format_num(p):<15} | 🌊 PHASE\")\n", "\n", " prism_norm = get_p(model.stream_phase.final_norm)\n", " prism_encoder_total += prism_norm\n", " print(f\" └─ Final Norm {i:<24} | {format_num(prism_norm):<15} | 🌊 PHASE\")\n", "\n", " bridge_p = get_p(model.phase_bridge)\n", " print(f\"{'Phase Bridge (Complex->Real)':<40} | {format_num(bridge_p):<15} | šŸŒ‰ BRIDGE\")\n", "\n", " # -----------------------------------------------\n", " # 5. THE BRAIN (Fusion & Refiner)\n", " # -----------------------------------------------\n", " print(\"-\" * 80)\n", " fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)\n", " print(f\"{'Fusion (Concat -> Proj -> Norm)':<40} | {format_num(fusion_p):<15} | 🧠 FUSION\")\n", "\n", " refiner_p = get_p(model.refiner)\n", " print(f\"{'Transformer Refiner (1 Layer)':<40} | {format_num(refiner_p):<15} | 🧠 ATTENTION\")\n", "\n", " # [FIX] Handle nn.Parameter directly\n", " head_bias_p = get_p(model.head_bias)\n", " print(f\"{'Output Head Bias':<40} | {format_num(head_bias_p):<15} | šŸŽÆ OUTPUT\")\n", "\n", " # -----------------------------------------------\n", " # 6. SUMMARY\n", " # -----------------------------------------------\n", " print(\"=\"*80)\n", "\n", " storage = vocab_emb + fnet_pos + head_bias_p\n", " active = total_params - storage\n", "\n", " print(f\"TOTAL PARAMETERS: {total_params/1e6:.2f} M\")\n", " print(f\" ā”œā”€ šŸ’¾ Storage: {storage/1e6:.2f} M (Embeddings)\")\n", " print(f\" └─ 🧠 Compute: {active/1e6:.2f} M (Logic/Weights)\")\n", " print(\"-\" * 80)\n", " print(f\"STREAM BREAKDOWN:\")\n", " print(f\" ā”œā”€ ⚔ Rate Stream: {fnet_encoder_total/1e6:.2f} M\")\n", " print(f\" └─ 🌊 Phase Stream: {prism_encoder_total/1e6:.2f} M\")\n", " print(\"=\"*80 + \"\\n\")\n", "\n", " return total_params\n" ], "metadata": { "id": "V7DOwmmUjyin" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "# Run the parameter analysis to confirm strict adherence to budget\n", "def deep_analyze_pillars_dat(model):\n", " def get_p(obj):\n", " if isinstance(obj, nn.Parameter): return obj.numel()\n", " return sum(p.numel() for p in obj.parameters() if p.requires_grad)\n", "\n", " def format_num(n):\n", " if n > 1e6: return f\"{n/1e6:.2f}M\"\n", " if n > 1e3: return f\"{n/1e3:.2f}K\"\n", " return str(n)\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(f\"šŸ›ļø PILLARS-DAT (Hybrid Transformer-PRISM) - ANALYSIS\")\n", " print(\"=\"*80)\n", " print(f\"{'MODULE / LAYER':<40} | {'PARAMS':<12} | {'TYPE'}\")\n", " print(\"-\" * 80)\n", "\n", " total_params = get_p(model)\n", "\n", " # --- 1. MEMORY ---\n", " vocab_emb = get_p(model.rose.raw_embedding)\n", " print(f\"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<12} | šŸ’¾ STORAGE\")\n", "\n", " # --- 2. INPUT PHYSICS ---\n", " rose_logic = get_p(model.rose) - vocab_emb\n", " print(f\"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<12} | 🌊 PHYSICS\")\n", "\n", " down_p = get_p(model.particle_down) + get_p(model.wave_down)\n", " print(f\"{'Stream Splitters (Downsample)':<40} | {format_num(down_p):<12} | šŸ“‰ PROJ\")\n", "\n", " # --- 3. STREAM A: SENSORY (TRANSFORMER) ---\n", " print(\"-\" * 80)\n", " print(f\"STREAM A: SENSORY (Identity/Magnitude)\")\n", " sensory_p = get_p(model.stream_sensory)\n", " # Attempt to count depth if accessible, else generic\n", " try:\n", " depth_s = len(model.stream_sensory.encoder.layers)\n", " print(f\" ā”œā”€ Transformer Encoder (Depth {depth_s}) | {format_num(sensory_p):<12} | ⚔ ATTENTION\")\n", " except:\n", " print(f\" ā”œā”€ Transformer Encoder (Fused) | {format_num(sensory_p):<12} | ⚔ ATTENTION\")\n", "\n", " # --- 4. STREAM B: RELATIONAL (PRISM) ---\n", " print(\"-\" * 80)\n", " print(f\"STREAM B: RELATIONAL (Structure/Phase)\")\n", " relational_core = get_p(model.stream_relational)\n", " relational_bridge = get_p(model.relational_bridge)\n", "\n", " try:\n", " depth_r = len(model.stream_relational.layers)\n", " print(f\" ā”œā”€ PRISM Encoder (Depth {depth_r}) | {format_num(relational_core):<12} | 🌊 SPECTRAL\")\n", " except:\n", " print(f\" ā”œā”€ PRISM Encoder (Fused) | {format_num(relational_core):<12} | 🌊 SPECTRAL\")\n", "\n", " print(f\" └─ Bridge (Complex->Real) | {format_num(relational_bridge):<12} | šŸŒ‰ PROJ\")\n", "\n", " # --- 5. FUSION & OUTPUT ---\n", " print(\"-\" * 80)\n", " fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)\n", " print(f\"{'Fusion (Concat -> Proj)':<40} | {format_num(fusion_p):<12} | 🧠 MIX\")\n", "\n", " refiner_p = get_p(model.refiner)\n", " print(f\"{'Refiner (1-Layer Transformer)':<40} | {format_num(refiner_p):<12} | 🧠 REASONING\")\n", "\n", " bias_p = get_p(model.head_bias)\n", " print(f\"{'Output Head Bias':<40} | {format_num(bias_p):<12} | šŸŽÆ OUT\")\n", "\n", " # --- SUMMARY ---\n", " print(\"=\"*80)\n", " storage = vocab_emb + bias_p\n", " active = total_params - storage\n", "\n", " print(f\"TOTAL PARAMETERS: {total_params/1e6:.2f} M\")\n", " print(f\" ā”œā”€ šŸ’¾ Storage: {storage/1e6:.2f} M (Embeddings)\")\n", " print(f\" └─ 🧠 Compute: {active/1e6:.2f} M (Active Weights)\")\n", " print(\"-\" * 80)\n", " print(f\"RATIO CHECK:\")\n", " print(f\" ⚔ Sensory (Transf): {sensory_p/1e6:.2f} M\")\n", " print(f\" 🌊 Relation (PRISM): {(relational_core + relational_bridge)/1e6:.2f} M\")\n", " print(\"=\"*80 + \"\\n\")\n" ], "metadata": { "id": "ke4fYT8UX5zH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ==========================================\n", "# 4. LOGGING & ANALYSIS UTILITIES\n", "# ==========================================\n", "def deep_analyze_pillars_dat(model):\n", " def get_p(obj):\n", " if isinstance(obj, nn.Parameter): return obj.numel()\n", " return sum(p.numel() for p in obj.parameters() if p.requires_grad)\n", "\n", " def format_num(n):\n", " if n > 1e6: return f\"{n/1e6:.2f}M\"\n", " if n > 1e3: return f\"{n/1e3:.2f}K\"\n", " return str(n)\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(f\"šŸ›ļø PILLARS-DAT (Hybrid Transformer-PRISM) - ANALYSIS\")\n", " print(\"=\"*80)\n", " print(f\"{'MODULE / LAYER':<40} | {'PARAMS':<12} | {'TYPE'}\")\n", " print(\"-\" * 80)\n", "\n", " total_params = get_p(model)\n", "\n", " # --- 1. MEMORY ---\n", " vocab_emb = get_p(model.rose.raw_embedding)\n", " print(f\"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<12} | šŸ’¾ STORAGE\")\n", "\n", " # --- 2. INPUT PHYSICS ---\n", " rose_logic = get_p(model.rose) - vocab_emb\n", " print(f\"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<12} | 🌊 PHYSICS\")\n", "\n", " down_p = get_p(model.particle_down) + get_p(model.wave_down)\n", " print(f\"{'Stream Splitters (Downsample)':<40} | {format_num(down_p):<12} | šŸ“‰ PROJ\")\n", "\n", " # --- 3. STREAM A: SENSORY (TRANSFORMER) ---\n", " print(\"-\" * 80)\n", " print(f\"STREAM A: SENSORY (Identity/Magnitude)\")\n", " sensory_p = get_p(model.stream_sensory)\n", " try:\n", " depth_s = len(model.stream_sensory.encoder.layers)\n", " print(f\" ā”œā”€ Transformer Encoder (Depth {depth_s}) | {format_num(sensory_p):<12} | ⚔ ATTENTION\")\n", " except:\n", " print(f\" ā”œā”€ Transformer Encoder (Fused) | {format_num(sensory_p):<12} | ⚔ ATTENTION\")\n", "\n", " # --- 4. STREAM B: RELATIONAL (PRISM) ---\n", " print(\"-\" * 80)\n", " print(f\"STREAM B: RELATIONAL (Structure/Phase)\")\n", " relational_core = get_p(model.stream_relational)\n", " relational_bridge = get_p(model.relational_bridge)\n", "\n", " try:\n", " depth_r = len(model.stream_relational.layers)\n", " print(f\" ā”œā”€ PRISM Encoder (Depth {depth_r}) | {format_num(relational_core):<12} | 🌊 SPECTRAL\")\n", " except:\n", " print(f\" ā”œā”€ PRISM Encoder (Fused) | {format_num(relational_core):<12} | 🌊 SPECTRAL\")\n", "\n", " print(f\" └─ Bridge (Complex->Real) | {format_num(relational_bridge):<12} | šŸŒ‰ PROJ\")\n", "\n", " # --- 5. FUSION & OUTPUT ---\n", " print(\"-\" * 80)\n", " fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)\n", " print(f\"{'Fusion (Concat -> Proj)':<40} | {format_num(fusion_p):<12} | 🧠 MIX\")\n", "\n", " refiner_p = get_p(model.refiner)\n", " print(f\"{'Refiner (1-Layer Transformer)':<40} | {format_num(refiner_p):<12} | 🧠 REASONING\")\n", "\n", " bias_p = get_p(model.head_bias)\n", " print(f\"{'Output Head Bias':<40} | {format_num(bias_p):<12} | šŸŽÆ OUT\")\n", "\n", " # --- SUMMARY ---\n", " print(\"=\"*80)\n", " storage = vocab_emb + bias_p\n", " active = total_params - storage\n", "\n", " print(f\"TOTAL PARAMETERS: {total_params/1e6:.2f} M\")\n", " print(f\" ā”œā”€ šŸ’¾ Storage: {storage/1e6:.2f} M (Embeddings)\")\n", " print(f\" └─ 🧠 Compute: {active/1e6:.2f} M (Active Weights)\")\n", " print(\"-\" * 80)\n", " print(f\"RATIO CHECK:\")\n", " print(f\" ⚔ Sensory (Transf): {sensory_p/1e6:.2f} M\")\n", " print(f\" 🌊 Relation (PRISM): {(relational_core + relational_bridge)/1e6:.2f} M\")\n", " print(\"=\"*80 + \"\\n\")\n", "\n", "def generate_run_id():\n", " raw = datetime.now().strftime(\"%Y%m%d%H%M%S%f\")\n", " return hashlib.md5(raw.encode()).hexdigest()[:8]\n", "\n", "def log_environment(save_dir, run_id, config):\n", " log_path = os.path.join(save_dir, f\"env_metadata_{run_id}.txt\")\n", " with open(log_path, \"w\") as f:\n", " f.write(f\"PRISM EXPERIMENT METADATA | Run ID: {run_id}\\n{'='*50}\\n\")\n", " for k, v in config.items(): f.write(f\"{k}: {v}\\n\")\n", " print(f\"šŸ“ Environment Snapshot saved to: {log_path}\")\n", "\n", "def log_metrics(save_dir, run_id, epoch, train_loss, val_loss, ppl):\n", " log_path = os.path.join(save_dir, f\"metrics_log_{run_id}.csv\")\n", " if not os.path.exists(log_path):\n", " with open(log_path, \"w\") as f: f.write(\"Timestamp,Epoch,Train_Loss,Val_Loss,Perplexity\\n\")\n", " with open(log_path, \"a\") as f:\n", " ts = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " f.write(f\"{ts},{epoch},{train_loss:.6f},{val_loss:.6f},{ppl:.6f}\\n\")\n", "\n", "def save_checkpoint(path, model, optimizer, scheduler, scaler, epoch, best_loss, config):\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': scheduler.state_dict(),\n", " 'scaler_state_dict': scaler.state_dict(), # <--- IMPORTANT FOR AMP\n", " 'best_val_loss': best_loss,\n", " 'config': config\n", " }, path)\n", "\n", "# ==========================================\n", "# 5. A100 TRAINING LOOP (WITH LOGGING)\n", "# ==========================================\n", "# ==========================================\n", "# 4. LOGGING & ANALYSIS UTILITIES\n", "# ==========================================\n", "def deep_analyze_pillars_dat(model):\n", " def get_p(obj):\n", " if isinstance(obj, nn.Parameter): return obj.numel()\n", " return sum(p.numel() for p in obj.parameters() if p.requires_grad)\n", "\n", " def format_num(n):\n", " if n > 1e6: return f\"{n/1e6:.2f}M\"\n", " if n > 1e3: return f\"{n/1e3:.2f}K\"\n", " return str(n)\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(f\"šŸ›ļø PILLARS-DAT (Hybrid Transformer-PRISM) - ANALYSIS\")\n", " print(\"=\"*80)\n", " print(f\"{'MODULE / LAYER':<40} | {'PARAMS':<12} | {'TYPE'}\")\n", " print(\"-\" * 80)\n", "\n", " total_params = get_p(model)\n", "\n", " # --- 1. MEMORY ---\n", " vocab_emb = get_p(model.rose.raw_embedding)\n", " print(f\"{'Shared Vocab Embedding':<40} | {format_num(vocab_emb):<12} | šŸ’¾ STORAGE\")\n", "\n", " # --- 2. INPUT PHYSICS ---\n", " rose_logic = get_p(model.rose) - vocab_emb\n", " print(f\"{'Dynamic RoSE (Adapters)':<40} | {format_num(rose_logic):<12} | 🌊 PHYSICS\")\n", "\n", " down_p = get_p(model.particle_down) + get_p(model.wave_down)\n", " print(f\"{'Stream Splitters (Downsample)':<40} | {format_num(down_p):<12} | šŸ“‰ PROJ\")\n", "\n", " # --- 3. STREAM A: SENSORY (TRANSFORMER) ---\n", " print(\"-\" * 80)\n", " print(f\"STREAM A: SENSORY (Identity/Magnitude)\")\n", " sensory_p = get_p(model.stream_sensory)\n", " try:\n", " depth_s = len(model.stream_sensory.encoder.layers)\n", " print(f\" ā”œā”€ Transformer Encoder (Depth {depth_s}) | {format_num(sensory_p):<12} | ⚔ ATTENTION\")\n", " except:\n", " print(f\" ā”œā”€ Transformer Encoder (Fused) | {format_num(sensory_p):<12} | ⚔ ATTENTION\")\n", "\n", " # --- 4. STREAM B: RELATIONAL (PRISM) ---\n", " print(\"-\" * 80)\n", " print(f\"STREAM B: RELATIONAL (Structure/Phase)\")\n", " relational_core = get_p(model.stream_relational)\n", " relational_bridge = get_p(model.relational_bridge)\n", "\n", " try:\n", " depth_r = len(model.stream_relational.layers)\n", " print(f\" ā”œā”€ PRISM Encoder (Depth {depth_r}) | {format_num(relational_core):<12} | 🌊 SPECTRAL\")\n", " except:\n", " print(f\" ā”œā”€ PRISM Encoder (Fused) | {format_num(relational_core):<12} | 🌊 SPECTRAL\")\n", "\n", " print(f\" └─ Bridge (Complex->Real) | {format_num(relational_bridge):<12} | šŸŒ‰ PROJ\")\n", "\n", " # --- 5. FUSION & OUTPUT ---\n", " print(\"-\" * 80)\n", " fusion_p = get_p(model.fusion_proj) + get_p(model.fusion_norm)\n", " print(f\"{'Fusion (Concat -> Proj)':<40} | {format_num(fusion_p):<12} | 🧠 MIX\")\n", "\n", " refiner_p = get_p(model.refiner)\n", " print(f\"{'Refiner (1-Layer Transformer)':<40} | {format_num(refiner_p):<12} | 🧠 REASONING\")\n", "\n", " bias_p = get_p(model.head_bias)\n", " print(f\"{'Output Head Bias':<40} | {format_num(bias_p):<12} | šŸŽÆ OUT\")\n", "\n", " # --- SUMMARY ---\n", " print(\"=\"*80)\n", " storage = vocab_emb + bias_p\n", " active = total_params - storage\n", "\n", " print(f\"TOTAL PARAMETERS: {total_params/1e6:.2f} M\")\n", " print(f\" ā”œā”€ šŸ’¾ Storage: {storage/1e6:.2f} M (Embeddings)\")\n", " print(f\" └─ 🧠 Compute: {active/1e6:.2f} M (Active Weights)\")\n", " print(\"-\" * 80)\n", " print(f\"RATIO CHECK:\")\n", " print(f\" ⚔ Sensory (Transf): {sensory_p/1e6:.2f} M\")\n", " print(f\" 🌊 Relation (PRISM): {(relational_core + relational_bridge)/1e6:.2f} M\")\n", " print(\"=\"*80 + \"\\n\")\n", "\n", "def init_pillars_dat_weights(model):\n", " print(\"✨ APPLYING PILLARS-DAT INITIALIZATION PROTOCOL...\")\n", " # 1. SHARED ROOT (RoSE)\n", " nn.init.normal_(model.rose.raw_embedding.weight, std=model.d_model ** -0.5)\n", " nn.init.orthogonal_(model.rose.adapter.weight)\n", "\n", " # --- ROSE IDENTITY TRICK ---\n", " nn.init.normal_(model.rose.rotation_predictor.weight, std=0.01)\n", " with torch.no_grad():\n", " model.rose.rotation_predictor.bias[:model.d_model].fill_(1.0) # Real=1\n", " model.rose.rotation_predictor.bias[model.d_model:].fill_(0.0) # Imag=0\n", "\n", " # 2. DOWNSAMPLERS\n", " nn.init.orthogonal_(model.particle_down.weight, gain=1.414)\n", " nn.init.orthogonal_(model.wave_down.weight, gain=1.414)\n", "\n", " # 3. SENSORY STREAM (Transformer + RoPE)\n", " print(\" ā”œā”€ Initializing Sensory Stream (Transformer)...\")\n", " for name, p in model.stream_sensory.named_parameters():\n", " if p.dim() > 1:\n", " nn.init.xavier_uniform_(p)\n", " elif \"norm\" in name.lower() and p.dim() == 1:\n", " if \"weight\" in name: nn.init.ones_(p)\n", " if \"bias\" in name: nn.init.zeros_(p)\n", "\n", " # 4. RELATIONAL STREAM (PRISM)\n", " print(\" ā”œā”€ Initializing Relational Stream (PRISM)...\")\n", " for name, m in model.stream_relational.named_modules():\n", " if isinstance(m, nn.Linear):\n", " nn.init.xavier_uniform_(m.weight, gain=1.0)\n", " if m.bias is not None: nn.init.zeros_(m.bias)\n", " if isinstance(m, ModReLU):\n", " nn.init.constant_(m.b, 0.01)\n", "\n", " # 5. FUSION & REFINER\n", " nn.init.xavier_uniform_(model.fusion_proj.weight, gain=1.0)\n", " for p in model.refiner.parameters():\n", " if p.dim() > 1: nn.init.xavier_uniform_(p)\n", "\n", " # 6. TIED HEAD BIAS\n", " nn.init.zeros_(model.head_bias)\n", " print(\"āœ… DAT INITIALIZATION COMPLETE.\")\n", "\n", "# ==========================================\n", "# 5. A100 TRAINING LOOP (WITH LOGGING)\n", "# ==========================================\n", "def run_a100_training(experiment_name=\"PILLARS_DAT_A100_Final\"):\n", " from torch.cuda.amp import autocast, GradScaler\n", " from torch.utils.tensorboard import SummaryWriter\n", "\n", " # --- 1. SETUP DRIVE & LOGGING ---\n", " from google.colab import drive\n", " if not os.path.exists('/content/drive'): drive.mount('/content/drive')\n", "\n", " run_id = generate_run_id()\n", " timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", " SAVE_DIR = os.path.join(\"/content/drive/My Drive/PRISM_Experiments\", f\"{experiment_name}_{timestamp}_{run_id}\")\n", " os.makedirs(SAVE_DIR, exist_ok=True)\n", "\n", " writer = SummaryWriter(log_dir=SAVE_DIR)\n", "\n", " # Config for Logs\n", " config_dump = {\n", " \"run_id\": run_id, \"batch_size\": 6, \"accum\": 8, \"d_model\": D_MODEL, \"depth\": DEPTH, \"seq_len\": SEQ_LEN\n", " }\n", " log_environment(SAVE_DIR, run_id, config_dump)\n", "\n", " # --- 2. MODEL & DATA ---\n", " SAFE_BATCH_SIZE = BATCH_SIZE\n", " GRAD_ACCUM = 4\n", " print(f\"\\n⚔ A100 DETECTED. CONFIGURING FLASH ATTENTION PIPELINE...\")\n", "\n", " lm_datasets, data_collator = prepare_data_from_hub()\n", " train_loader = DataLoader(lm_datasets[\"train\"], batch_size=SAFE_BATCH_SIZE, shuffle=True, collate_fn=data_collator, num_workers=4, pin_memory=True)\n", " valid_loader = DataLoader(lm_datasets[\"validation\"], batch_size=SAFE_BATCH_SIZE, collate_fn=data_collator, num_workers=2)\n", "\n", " model = Pillars_DAT(vocab_size=VOCAB_SIZE, d_model=D_MODEL, d_branch=D_BRANCH, seq_len=SEQ_LEN, depth=DEPTH).to(DEVICE)\n", " init_pillars_dat_weights(model)\n", " print(model)\n", " deep_analyze_pillars_dat(model) # <--- Parameter Analysis\n", "\n", " optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)\n", " total_steps = (len(train_loader) // GRAD_ACCUM) * EPOCHS\n", " warmup_steps = int(total_steps * 0.1)\n", " scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)\n", " criterion = nn.CrossEntropyLoss()\n", " scaler = GradScaler() # For AMP\n", "\n", " print(f\"\\nšŸš€ IGNITING FUSION DRIVE... Saving to: {SAVE_DIR}\")\n", "\n", " global_step = 0\n", " best_val_loss = float('inf')\n", "\n", " for epoch in range(EPOCHS):\n", " model.train()\n", " pbar = tqdm(train_loader, desc=f\"Ep {epoch+1}\")\n", "\n", " for step, batch in enumerate(pbar):\n", " x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)\n", "\n", " # ⚔ AMP CONTEXT\n", " with autocast(dtype=torch.float16):\n", " logits = model(x).view(-1, VOCAB_SIZE)\n", " loss = criterion(logits, y.view(-1)) / GRAD_ACCUM\n", "\n", " scaler.scale(loss).backward()\n", "\n", " if (step + 1) % GRAD_ACCUM == 0:\n", " scaler.unscale_(optimizer)\n", " # šŸ›‘ CALC GRAD NORM HERE FOR PBAR šŸ›‘\n", " grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", "\n", " scaler.step(optimizer)\n", " scaler.update()\n", " scheduler.step()\n", " optimizer.zero_grad()\n", " global_step += 1\n", "\n", " # šŸ“ STEP LOGGING\n", " actual_loss = loss.item() * GRAD_ACCUM\n", " writer.add_scalar('Train/Loss', actual_loss, global_step)\n", " writer.add_scalar('Train/GradNorm', grad_norm.item(), global_step)\n", " writer.add_scalar('Train/LR', scheduler.get_last_lr()[0], global_step)\n", "\n", " # ✨ UPDATE PBAR WITH GNORM ✨\n", " pbar.set_postfix({\n", " 'loss': f\"{actual_loss:.4f}\",\n", " 'gnorm': f\"{grad_norm.item():.2f}\"\n", " })\n", "\n", " # --- VALIDATION ---\n", " model.eval()\n", " val_loss = 0\n", " with torch.no_grad(), autocast():\n", " for batch in valid_loader:\n", " x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)\n", " val_loss += criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)).item()\n", "\n", " avg_val_loss = val_loss / len(valid_loader)\n", " # Prevent overflow if loss is exploding\n", " ppl = math.exp(avg_val_loss) if avg_val_loss < 20 else float('inf')\n", "\n", " print(f\"✨ Ep {epoch+1} | Val Loss: {avg_val_loss:.4f} | PPL: {ppl:.2f}\")\n", "\n", " # šŸ“ EPOCH LOGGING\n", " writer.add_scalar('Val/Loss', avg_val_loss, epoch+1)\n", " writer.add_scalar('Val/PPL', ppl, epoch+1)\n", " log_metrics(SAVE_DIR, run_id, epoch+1, 0.0, avg_val_loss, ppl)\n", "\n", " # šŸ’¾ SAVE CHECKPOINTS (Includes Scaler/Optim/Sched)\n", " save_checkpoint(os.path.join(SAVE_DIR, \"last.pt\"), model, optimizer, scheduler, scaler, epoch, best_val_loss, config_dump)\n", "\n", " if avg_val_loss < best_val_loss:\n", " best_val_loss = avg_val_loss\n", " print(f\" šŸ† New Best Model! Saving best.pt...\")\n", " save_checkpoint(os.path.join(SAVE_DIR, \"best.pt\"), model, optimizer, scheduler, scaler, epoch, best_val_loss, config_dump)\n", "\n", " writer.close()\n", " return model\n", "\n", "if __name__ == \"__main__\":\n", " run_a100_training()" ], "metadata": { "id": "-TNEv89gkS1k" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from google.colab import runtime\n", "runtime.unassign()" ], "metadata": { "id": "bxFTYWHVqcSI" }, "execution_count": null, "outputs": [] } ] }