{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "!pip install -q x-transformers" ], "metadata": { "id": "6q9RTvlf5IiS" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import math\n", "import os\n", "import sys\n", "import subprocess\n", "import hashlib\n", "import gc\n", "from datetime import datetime\n", "from tqdm.auto import tqdm\n", "from torch.utils.data import DataLoader\n", "from torch.utils.tensorboard import SummaryWriter\n", "from transformers import RobertaTokenizerFast, get_cosine_schedule_with_warmup, DataCollatorForLanguageModeling\n", "from datasets import load_dataset\n", "from x_transformers import Encoder\n", "\n", "# ==========================================\n", "# 1. CONFIGURATION\n", "# ==========================================\n", "# YOUR REPO ID (Created in previous step)\n", "HF_ID = \"prism-lab/wikitext-103-prism-32k-seq4k\"\n", "\n", "# Hyperparameters\n", "VOCAB_SIZE = 32768\n", "SEQ_LEN = 4096\n", "BATCH_SIZE = 8\n", "EPOCHS = 40\n", "LR = 1e-3\n", "D_MODEL = 512\n", "DEPTH = 6\n", "DROPOUT = 0.1\n", "RESUME_PATH = None\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "torch.set_float32_matmul_precision(\"high\")\n", "\n", "# ==========================================\n", "# 2. DATA PIPELINE (The \"Pro\" Way)\n", "# ==========================================\n", "def prepare_data_from_hub():\n", " print(f\"ā¬‡ļø Pulling Pre-Tokenized Data from {HF_ID}...\")\n", "\n", " # 1. Load Tokenizer (Instant)\n", " # This pulls the exact tokenizer you uploaded\n", " tokenizer = RobertaTokenizerFast.from_pretrained(HF_ID)\n", "\n", " # 2. Load Dataset (Instant)\n", " # This pulls the already chunked/tokenized data\n", " dataset = load_dataset(HF_ID)\n", "\n", " print(f\"āœ… Loaded {len(dataset['train'])} training chunks.\")\n", "\n", " # 3. Collator\n", " data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer,\n", " mlm=True,\n", " mlm_probability=0.15\n", " )\n", "\n", " return dataset, data_collator\n", "# ==========================================\n", "# 3. 
"# ==========================================\n", "# 3. PRISM ARCHITECTURE (Complex-Valued)\n", "# ==========================================\n", "\n",
"class ComplexDropout(nn.Module):\n", "    def __init__(self, p=0.5):\n", "        super().__init__()\n", "        self.p = p\n", "    def forward(self, z):\n", "        if not self.training or self.p == 0.0:\n", "            return z\n", "        # One real-valued mask scales real and imaginary parts together, so\n", "        # dropped units vanish entirely and surviving phases are preserved.\n", "        mask = torch.ones_like(z.real)\n", "        mask = F.dropout(mask, self.p, self.training, inplace=False)\n", "        return z * mask\n", "\n",
"class RobustPhaseNorm(nn.Module):\n", "    # RMS-normalizes magnitudes across the feature dim; phase is untouched.\n", "    def __init__(self, d_model, eps=1e-5):\n", "        super().__init__()\n", "        self.scale = nn.Parameter(torch.ones(d_model))\n", "        self.eps = eps\n", "    def forward(self, x):\n", "        mag = torch.abs(x)\n", "        rms = torch.sqrt(torch.mean(mag**2, dim=-1, keepdim=True) + self.eps)\n", "        return (x / rms) * self.scale\n", "\n",
"class ModReLU(nn.Module):\n", "    def __init__(self, features):\n", "        super().__init__()\n", "        self.b = nn.Parameter(torch.zeros(features))\n", "    def forward(self, z):\n", "        mag = torch.abs(z)\n", "        new_mag = F.relu(mag + self.b)\n", "        phase = z / (mag + 1e-6)  # unit-modulus direction of z\n", "        return new_mag * phase\n", "\n",
"class ComplexToRealBridge(nn.Module):\n", "    def __init__(self, d_model):\n", "        super().__init__()\n", "        self.proj = nn.Linear(d_model * 2, d_model)\n", "        self.norm = nn.LayerNorm(d_model)\n", "    def forward(self, x_complex):\n", "        cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)\n", "        return self.norm(self.proj(cat))\n", "\n",
"# ==========================================\n", "# 4. DYNAMIC RoSE (Mamba-3 Engine)\n", "# ==========================================\n", "class DynamicRoSE(nn.Module):\n", "    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):\n", "        super().__init__()\n", "        self.embedding_dim = embedding_dim\n", "\n", "        # 1. Master real embedding (the \"particle\")\n", "        self.raw_embedding = nn.Embedding(num_embeddings, embedding_dim)\n", "\n", "        # 2. Complex adapter (the \"wave\" magnitude / initial phase)\n", "        self.adapter = nn.Linear(embedding_dim, embedding_dim * 2)\n", "\n", "        # 3. Static frequencies (positional)\n", "        freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))\n", "        self.register_buffer('freqs', freqs)\n", "\n", "        # 4. Content-conditioned rotation (predicts a unit complex number per feature)\n", "        self.rotation_predictor = nn.Linear(embedding_dim, embedding_dim * 2)\n", "\n",
"    def forward(self, input_ids):\n", "        # A. Raw particle\n", "        real_base = self.raw_embedding(input_ids)\n", "        B, L, D = real_base.shape\n", "\n", "        # B. Complex wave content\n", "        complex_params = self.adapter(real_base)\n", "        z_t = torch.complex(complex_params[..., :D], complex_params[..., D:])\n", "\n", "        # C. Dynamic semantic rotation, normalized to unit modulus\n", "        rot_raw = self.rotation_predictor(real_base)\n", "        rot_x, rot_y = rot_raw.chunk(2, dim=-1)\n", "        rot_mag = torch.sqrt(rot_x**2 + rot_y**2 + 1e-6)\n", "        dynamic_rot = torch.complex(rot_x / rot_mag, rot_y / rot_mag)\n", "\n", "        # D. Static positional rotation\n", "        pos = torch.arange(L, device=input_ids.device).float()\n", "        static_angles = torch.outer(pos, self.freqs)  # [L, D]\n", "        static_rot = torch.polar(torch.ones_like(static_angles), static_angles)  # [L, D]\n", "\n", "        z_final = z_t * static_rot.unsqueeze(0) * dynamic_rot\n", "\n", "        return z_final, real_base\n", "\n",
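"# Optional sanity check for DynamicRoSE (ours, not part of training; call manually).\n", "# The positional and dynamic rotations are (approximately) unit-modulus, so they\n", "# rotate phases without changing |z_t|.\n", "def _rose_demo():\n", "    rose = DynamicRoSE(num_embeddings=100, embedding_dim=16)\n", "    z, base = rose(torch.randint(0, 100, (2, 8)))\n", "    assert z.shape == (2, 8, 16) and torch.is_complex(z)\n", "    assert base.shape == (2, 8, 16) and not torch.is_complex(base)\n", "\n",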
"# ==========================================\n", "# 5. HYENA FILTER\n", "# ==========================================\n", "class HyenaNeuralFilter(nn.Module):\n", "    # A small MLP that synthesizes a complex frequency-domain filter of any length L\n", "    # from normalized positions. (max_len is kept for API symmetry with the caller;\n", "    # the filter itself is generated on the fly.)\n", "    def __init__(self, d_model, max_len=1024, hidden_dim=64):\n", "        super().__init__()\n", "        self.d_model = d_model\n", "        freqs = torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / hidden_dim))\n", "        self.register_buffer(\"freqs\", freqs)\n", "        self.mlp = nn.Sequential(\n", "            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),\n", "            nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),\n", "            nn.Linear(hidden_dim, d_model * 2)\n", "        )\n", "    def forward(self, L, device):\n", "        t = torch.linspace(0, 1, steps=L, device=device).unsqueeze(-1)\n", "        emb = torch.cat([torch.sin(t * self.freqs), torch.cos(t * self.freqs)], dim=-1)\n", "        out = self.mlp(emb).view(L, self.d_model, 2)\n", "        return torch.complex(out[..., 0], out[..., 1])\n", "\n",
"# ==========================================\n", "# 6. GATED HARMONIC CONVOLUTION (Lean)\n", "# ==========================================\n", "class GatedHarmonicConvolution(nn.Module):\n", "    def __init__(self, d_model, max_len=1024, dropout=0.1):\n", "        super().__init__()\n", "        self.d_model = d_model\n", "        self.filter_len = max_len\n", "        self.neural_filter = HyenaNeuralFilter(d_model, max_len=max_len)\n", "        self.gate_proj = nn.Linear(d_model * 2, d_model * 2)\n", "        self.mix_real = nn.Linear(d_model, d_model)\n", "        self.mix_imag = nn.Linear(d_model, d_model)\n", "        self.out_real = nn.Linear(d_model, d_model)\n", "        self.out_imag = nn.Linear(d_model, d_model)\n", "        self.activation = ModReLU(d_model)\n", "        self.norm = RobustPhaseNorm(d_model)\n", "        self.dropout = ComplexDropout(dropout)\n", "\n",
"    def forward(self, x, src_mask=None):\n", "        residual = x\n", "        x_norm = self.norm(x)\n", "        if src_mask is not None:\n", "            x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)\n", "\n", "        # 1. Global beam (FFT + Hyena filter): pointwise multiplication in the\n", "        #    frequency domain == circular convolution over the sequence dimension\n", "        B, L, D = x_norm.shape\n", "        eff_L = min(L, self.filter_len)\n", "        x_freq = torch.fft.fft(x_norm, n=eff_L, dim=1, norm='ortho')\n", "        h = self.neural_filter(eff_L, x.device).unsqueeze(0)\n", "        x_filtered = x_freq * h\n", "        x_time = torch.fft.ifft(x_filtered, n=eff_L, dim=1, norm='ortho')\n", "        # Positions beyond filter_len get zeros here; the residual carries them through\n", "        if L > eff_L:\n", "            x_time = F.pad(x_time, (0, 0, 0, L - eff_L))\n", "        else:\n", "            x_time = x_time[:, :L, :]\n", "\n", "        # 2. Gating (separate sigmoid gates for real and imaginary parts)\n", "        gates = torch.sigmoid(self.gate_proj(torch.cat([x_norm.real, x_norm.imag], dim=-1)))\n", "        g_r, g_i = gates.chunk(2, dim=-1)\n", "        x_gated = torch.complex(x_time.real * g_r, x_time.imag * g_i)\n", "\n", "        # 3. Mixing & output: (mr + i*mi)(a + i*b) expanded into real arithmetic\n", "        mr, mi = self.mix_real, self.mix_imag\n", "        x_mixed = torch.complex(mr(x_gated.real) - mi(x_gated.imag), mr(x_gated.imag) + mi(x_gated.real))\n", "        x_act = self.activation(x_mixed)\n", "        or_, oi = self.out_real, self.out_imag\n", "        out = torch.complex(or_(x_act.real) - oi(x_act.imag), or_(x_act.imag) + oi(x_act.real))\n", "        return self.dropout(out) + residual\n", "\n",
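"# Optional check of the convolution-theorem identity the block above relies on\n", "# (ours; `_fft_conv_demo` is illustrative and can be deleted):\n", "def _fft_conv_demo(n=8):\n", "    x = torch.randn(n, dtype=torch.complex64)\n", "    h = torch.randn(n, dtype=torch.complex64)\n", "    lhs = torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(h))\n", "    # direct circular convolution: out[i] = sum_j x[j] * h[(i - j) mod n]\n", "    rhs = torch.stack([sum(x[j] * h[(i - j) % n] for j in range(n)) for i in range(n)])\n", "    assert torch.allclose(lhs, rhs, atol=1e-4)\n", "\n",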
"# ==========================================\n", "# 7. MODEL WRAPPERS\n", "# ==========================================\n", "class PRISMEncoder(nn.Module):\n", "    def __init__(self, num_layers, d_model, max_len, dropout=0.1):\n", "        super().__init__()\n", "        self.layers = nn.ModuleList([\n", "            GatedHarmonicConvolution(d_model, max_len, dropout)\n", "            for _ in range(num_layers)\n", "        ])\n", "        self.final_norm = RobustPhaseNorm(d_model)\n", "    def forward(self, x, src_mask=None):\n", "        for layer in self.layers:\n", "            if self.training:\n", "                # Gradient checkpointing trades compute for memory at SEQ_LEN=4096\n", "                x = torch.utils.checkpoint.checkpoint(layer, x, src_mask, use_reentrant=False)\n", "            else:\n", "                x = layer(x, src_mask)\n", "        return self.final_norm(x)\n", "\n",
"class PRISM_WikiText_Model(nn.Module):\n", "    def __init__(self, vocab_size, d_model, max_len, prism_depth=5, trans_depth=1, dropout=0.1):\n", "        super().__init__()\n", "        self.d_model = d_model\n", "\n", "        # 1. PRISM core (the optical/passive part)\n", "        self.rose = DynamicRoSE(vocab_size, d_model)\n", "        self.prism_encoder = PRISMEncoder(prism_depth, d_model, max_len=max_len, dropout=dropout)\n", "        self.bridge = ComplexToRealBridge(d_model)\n", "        self.periscope_proj = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.LayerNorm(d_model), nn.GELU())\n", "\n", "        # 2. Refiner (the digital/active part)\n", "        # šŸ”„ SWAPPED: replaced the standard Transformer with a RoPE-enabled Encoder\n", "        if trans_depth > 0:\n", "            self.refiner = Encoder(\n", "                dim=d_model,\n", "                depth=trans_depth,\n", "                heads=8,\n", "                rotary_pos_emb=True,\n", "                attn_flash=True,\n", "                attn_dropout=dropout,\n", "                ff_dropout=dropout\n", "            )\n", "        else:\n", "            self.refiner = None\n", "\n", "        # 3. Output head, weight-tied to the input embedding\n", "        self.lm_head = nn.Linear(d_model, vocab_size)\n", "        self.lm_head.weight = self.rose.raw_embedding.weight\n", "\n",
"    def forward(self, input_ids):\n", "        # A. Wave physics\n", "        wave_src, particle_src = self.rose(input_ids)\n", "        wave_out = self.prism_encoder(wave_src)\n", "        wave_real = self.bridge(wave_out)\n", "\n", "        # B. Interface: fuse wave features with the raw token embedding\n", "        mixed_memory = self.periscope_proj(torch.cat([wave_real, particle_src], dim=-1))\n", "\n", "        # C. Digital refinement (now with RoPE)\n", "        if self.refiner is not None:\n", "            out = self.refiner(mixed_memory)\n", "        else:\n", "            out = mixed_memory\n", "\n", "        return self.lm_head(out)" ], "metadata": { "id": "V7DOwmmUjyin" }, "execution_count": null, "outputs": [] },
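{ "cell_type": "markdown", "source": [ "**Optional smoke test** (ours, not part of the original training flow): build a deliberately tiny `PRISM_WikiText_Model` and confirm a forward pass returns `[batch, seq, vocab]` logits. All sizes below are illustrative and small enough to run on CPU; the cell is safe to skip." ], "metadata": {} },
{ "cell_type": "code", "source": [ "# Tiny forward-pass smoke test; does not touch training state\n", "_tiny = PRISM_WikiText_Model(vocab_size=128, d_model=32, max_len=64, prism_depth=2, trans_depth=1, dropout=0.0)\n", "_tiny.eval()\n", "with torch.no_grad():\n", "    _logits = _tiny(torch.randint(0, 128, (2, 64)))\n", "print(_logits.shape)  # expected: torch.Size([2, 64, 128])\n", "del _tiny, _logits" ], "metadata": {}, "execution_count": null, "outputs": [] },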
{ "cell_type": "code", "source": [ "# ==========================================\n", "# 8. LOGGING UTILITIES & TRAINING LOOP\n", "# ==========================================\n", "def generate_run_id():\n", "    raw = datetime.now().strftime(\"%Y%m%d%H%M%S%f\")\n", "    return hashlib.md5(raw.encode()).hexdigest()[:8]\n", "\n",
"def log_environment(save_dir, run_id, config):\n", "    log_path = os.path.join(save_dir, f\"env_metadata_{run_id}.txt\")\n", "    with open(log_path, \"w\") as f:\n", "        f.write(f\"PRISM EXPERIMENT METADATA | Run ID: {run_id}\\n{'='*50}\\n\")\n", "        for k, v in config.items():\n", "            f.write(f\"{k}: {v}\\n\")\n", "    print(f\"šŸ“ Environment snapshot saved to: {log_path}\")\n", "\n",
"def log_metrics(save_dir, run_id, epoch, train_loss, val_loss, ppl):\n", "    log_path = os.path.join(save_dir, f\"metrics_log_{run_id}.csv\")\n", "    if not os.path.exists(log_path):\n", "        with open(log_path, \"w\") as f:\n", "            f.write(\"Timestamp,Epoch,Train_Loss,Val_Loss,Perplexity\\n\")\n", "    with open(log_path, \"a\") as f:\n", "        ts = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", "        f.write(f\"{ts},{epoch},{train_loss:.6f},{val_loss:.6f},{ppl:.6f}\\n\")\n", "\n",
"def save_checkpoint(path, model, optimizer, scheduler, epoch, best_loss, config):\n", "    torch.save({\n", "        'epoch': epoch,\n", "        'model_state_dict': model.state_dict(),\n", "        'optimizer_state_dict': optimizer.state_dict(),\n", "        'scheduler_state_dict': scheduler.state_dict(),\n", "        'best_val_loss': best_loss,\n", "        'config': config\n", "    }, path)\n", "\n",
"def run_wikitext_training(experiment_name=\"PRISM2_WT103_40epochs\"):\n", "    from google.colab import drive\n", "    if not os.path.exists('/content/drive'):\n", "        drive.mount('/content/drive')\n", "\n", "    # --- SETUP DIRS ---\n", "    if RESUME_PATH and os.path.exists(RESUME_PATH):\n", "        print(f\"šŸ”„ RESUMING FROM: {RESUME_PATH}\")\n", "        checkpoint = torch.load(RESUME_PATH, map_location=DEVICE)\n", "        SAVE_DIR = os.path.dirname(RESUME_PATH)\n", "        run_id = checkpoint.get('config', {}).get('run_id', 'resumed')\n", "    else:\n", "        run_id = generate_run_id()\n", "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", "        folder_name = f\"{experiment_name}_{timestamp}_{run_id}\"\n", "        SAVE_DIR = os.path.join(\"/content/drive/My Drive/PRISM_Experiments\", folder_name)\n", "        os.makedirs(SAVE_DIR, exist_ok=True)\n", "    print(f\"šŸ’¾ Checkpoints: {SAVE_DIR}\")\n", "\n", "    writer = SummaryWriter(log_dir=SAVE_DIR)\n", "    GRAD_ACCUM = 4\n", "\n", "    # Snapshot the run configuration next to the checkpoints\n", "    log_environment(SAVE_DIR, run_id, {\n", "        'run_id': run_id, 'vocab_size': VOCAB_SIZE, 'seq_len': SEQ_LEN,\n", "        'batch_size': BATCH_SIZE, 'grad_accum': GRAD_ACCUM, 'epochs': EPOCHS,\n", "        'lr': LR, 'd_model': D_MODEL, 'depth': DEPTH, 'dropout': DROPOUT\n", "    })\n", "\n", "    lm_datasets, data_collator = prepare_data_from_hub()\n", "\n", "    # num_workers=2 is safe for Colab\n", "    train_loader = DataLoader(\n", "        lm_datasets[\"train\"], batch_size=BATCH_SIZE, shuffle=True,\n", "        collate_fn=data_collator, num_workers=2, pin_memory=True,\n", "        prefetch_factor=2, persistent_workers=True\n", "    )\n", "    valid_loader = DataLoader(\n", "        lm_datasets[\"validation\"], batch_size=BATCH_SIZE,\n", "        collate_fn=data_collator, num_workers=2, pin_memory=True\n", "    )\n", "    test_loader = DataLoader(\n", "        lm_datasets[\"test\"], batch_size=BATCH_SIZE,\n", "        collate_fn=data_collator, num_workers=2, pin_memory=True\n", "    )\n", "\n", "    print(\"\\n⚔ INITIALIZING MODEL...\")\n", "    model = PRISM_WikiText_Model(\n", "        vocab_size=VOCAB_SIZE, d_model=D_MODEL, max_len=SEQ_LEN,\n", "        prism_depth=DEPTH-1, trans_depth=1, dropout=DROPOUT\n", "    ).to(DEVICE)\n", "\n", "    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.0)\n", "    total_steps = (len(train_loader) // GRAD_ACCUM) * EPOCHS\n", "    scheduler = get_cosine_schedule_with_warmup(\n", "        optimizer, num_warmup_steps=int(0.1 * total_steps), num_training_steps=total_steps\n", "    )\n",
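"    # Step-budget arithmetic (illustrative, with the defaults above): one optimizer\n", "    # step aggregates BATCH_SIZE * GRAD_ACCUM = 8 * 4 = 32 sequences, i.e.\n", "    # 32 * 4096 = 131,072 tokens; the first 10% of total_steps are linear warmup,\n", "    # followed by cosine decay.\n",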
"    criterion = nn.CrossEntropyLoss()  # ignore_index defaults to -100, matching the collator's unmasked labels\n", "\n", "    start_epoch = 0\n", "    best_val_loss = float('inf')\n", "\n",
"    if RESUME_PATH and os.path.exists(RESUME_PATH):\n", "        model.load_state_dict(checkpoint['model_state_dict'])\n", "        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", "        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", "        start_epoch = checkpoint['epoch'] + 1\n", "        best_val_loss = checkpoint['best_val_loss']\n", "        del checkpoint\n", "        torch.cuda.empty_cache()\n", "    else:\n", "        def init_weights_PRISM(m):\n", "            if isinstance(m, nn.Linear):\n", "                nn.init.xavier_uniform_(m.weight)\n", "                if m.bias is not None:\n", "                    nn.init.zeros_(m.bias)\n", "            elif isinstance(m, nn.Embedding):\n", "                nn.init.normal_(m.weight, std=D_MODEL**-0.5)\n", "        model.apply(init_weights_PRISM)\n", "        nn.init.normal_(model.rose.rotation_predictor.weight, std=0.01)\n", "        with torch.no_grad():\n", "            # Bias -> (x=1, y=0): every dynamic rotation starts at the identity\n", "            model.rose.rotation_predictor.bias[:D_MODEL].fill_(1.0)\n", "            model.rose.rotation_predictor.bias[D_MODEL:].fill_(0.0)\n", "\n",
"    print(f\"\\nšŸš€ STARTING (Ep {start_epoch+1} to {EPOCHS})\")\n", "    global_step = (len(train_loader) // GRAD_ACCUM) * start_epoch\n", "    actual_loss = 0.0  # last logged train loss, reused by log_metrics below\n", "\n",
"    for epoch in range(start_epoch, EPOCHS):\n", "        model.train()\n", "        pbar = tqdm(train_loader, desc=f\"Ep {epoch+1}/{EPOCHS}\")\n", "\n", "        for step, batch in enumerate(pbar):\n", "            x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)\n", "\n", "            # Forward: the loss is divided by GRAD_ACCUM so the accumulated\n", "            # gradient matches a single large-batch step\n", "            loss = criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)) / GRAD_ACCUM\n", "            loss.backward()\n", "\n", "            # Step (every GRAD_ACCUM batches)\n", "            if (step + 1) % GRAD_ACCUM == 0:\n", "                # 1. Clip and record the gradient norm\n", "                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", "\n", "                # 2. Step\n", "                optimizer.step()\n", "                scheduler.step()\n", "                optimizer.zero_grad()\n", "                global_step += 1\n", "\n",
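"                # grad_norm is the global L2 norm before clipping; values that sit\n", "                # well above 1.0 mean clipping is active at this learning rate.\n",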
"                # 3. Update the progress bar with loss and gnorm\n", "                actual_loss = loss.item() * GRAD_ACCUM\n", "                writer.add_scalar('Train/Loss', actual_loss, global_step)\n", "                pbar.set_postfix({\n", "                    'loss': f\"{actual_loss:.4f}\",\n", "                    'gnorm': f\"{grad_norm.item():.2f}\"\n", "                })\n", "\n",
"        # Drop any remainder gradients so a partial accumulation window\n", "        # does not leak into the next epoch\n", "        optimizer.zero_grad(set_to_none=True)\n", "\n",
"        # VALIDATION\n", "        model.eval()\n", "        val_loss = 0\n", "        with torch.no_grad():\n", "            for batch in valid_loader:\n", "                x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)\n", "                val_loss += criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)).item()\n", "\n", "        avg_val_loss = val_loss / len(valid_loader)\n", "        ppl = math.exp(avg_val_loss) if avg_val_loss < 100 else float('inf')\n", "\n", "        print(f\"✨ Epoch {epoch+1} | Val Loss: {avg_val_loss:.4f} | PPL: {ppl:.2f}\")\n", "        writer.add_scalar('Val/Loss', avg_val_loss, epoch+1)\n", "        writer.add_scalar('Val/PPL', ppl, epoch+1)\n", "        log_metrics(SAVE_DIR, run_id, epoch+1, actual_loss, avg_val_loss, ppl)\n", "\n", "        config_dump = {\"epoch\": epoch, \"run_id\": run_id}\n", "        save_checkpoint(os.path.join(SAVE_DIR, \"last.pt\"), model, optimizer, scheduler, epoch, best_val_loss, config_dump)\n", "\n", "        if avg_val_loss < best_val_loss:\n", "            best_val_loss = avg_val_loss\n", "            torch.save(model.state_dict(), os.path.join(SAVE_DIR, \"best.pt\"))\n", "            print(\" šŸ† New Best Model Saved!\")\n", "\n",
"    # TEST (on the best validation checkpoint)\n", "    best_path = os.path.join(SAVE_DIR, \"best.pt\")\n", "    if os.path.exists(best_path):\n", "        model.load_state_dict(torch.load(best_path, map_location=DEVICE))\n", "        model.eval()\n", "        test_loss = 0\n", "        with torch.no_grad():\n", "            for batch in tqdm(test_loader, desc=\"Testing\"):\n", "                x, y = batch['input_ids'].to(DEVICE), batch['labels'].to(DEVICE)\n", "                test_loss += criterion(model(x).view(-1, VOCAB_SIZE), y.view(-1)).item()\n", "        print(f\"šŸ† FINAL PPL: {math.exp(test_loss/len(test_loader)):.2f}\")\n", "\n", "    writer.close()\n", "    return model" ], "metadata": { "id": "-TNEv89gkS1k" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [
"def analyze_prism_params(model):\n", "    print(\"=\"*80)\n", "    print(\"šŸ“Š LEAN PRISM-2 PARAMETER ANALYSIS\")\n", "    print(\"=\"*80)\n", "    # parameters() yields each tensor once, so the tied lm_head/embedding\n", "    # weight is not double-counted in the total\n", "    total_params = sum(p.numel() for p in model.parameters())\n", "    # Embeddings (shared)\n", "    vocab_params = model.rose.raw_embedding.weight.numel()\n", "    # Wave engine\n", "    enc_params = sum(p.numel() for p in model.prism_encoder.parameters())\n", "    # Transformer refiner\n", "    ref_params = sum(p.numel() for p in model.refiner.parameters()) if model.refiner else 0\n", "    # Other\n", "    other_params = total_params - vocab_params - enc_params - ref_params\n", "\n", "    print(f\"{'Shared Embeddings (Particle)':<35} | {vocab_params:<15,} | {vocab_params/total_params:.1%} | Tied\")\n", "    print(f\"{'PRISM Optical Engine':<35} | {enc_params:<15,} | {enc_params/total_params:.1%} | {len(model.prism_encoder.layers)} Layers\")\n", "    print(f\"{'Digital Refiner':<35} | {ref_params:<15,} | {ref_params/total_params:.1%} | 1 Layer\")\n", "    print(f\"{'Other (Bridge/Adapters)':<35} | {other_params:<15,} | {other_params/total_params:.1%} |\")\n", "    print(\"=\"*80)\n", "    print(f\"{'TOTAL PARAMETERS':<35} | {total_params:<15,} | 100.0%\")\n", "    print(\"=\"*80)\n", "\n", "\n",
"if __name__ == \"__main__\":\n", "\n", "    print(\"šŸ—ļø INSTANTIATING MODEL FOR INSPECTION...\")\n", "    # 1. Initialize a temporary model just for counting\n", "    dummy_model = PRISM_WikiText_Model(\n", "        vocab_size=VOCAB_SIZE,\n", "        d_model=D_MODEL,\n", "        max_len=SEQ_LEN,\n", "        prism_depth=DEPTH-1,\n", "        trans_depth=1,\n", "        dropout=DROPOUT\n", "    )\n", "\n", "    # 2. Run analysis\n", "    analyze_prism_params(dummy_model)\n", "\n", "    # 3. Clean up to free RAM for actual training\n", "    del dummy_model\n", "    gc.collect()\n", "    torch.cuda.empty_cache()\n", "\n",
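"    # Checkpoint layout on Drive (PRISM_Experiments/<name>_<timestamp>_<run_id>/):\n", "    #   last.pt -- full training state (model/optimizer/scheduler/epoch),\n", "    #              resumable by pointing RESUME_PATH at it\n", "    #   best.pt -- weights only, reloaded for the final test pass\n", "\n",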
"    # 4. Grace period before training (no confirmation prompt; just a short countdown)\n", "    print(\"\\nāœ… Analysis Complete. Starting Training Routine in 5 seconds...\")\n", "    time.sleep(5)\n", "\n", "    # 5. Start training\n", "    trained_prism = run_wikitext_training()\n", "\n", "    # 6. Final analysis (post-training check)\n", "    analyze_prism_params(trained_prism)\n", "\n", "    # 7. Kill the runtime (Colab-specific)\n", "    from google.colab import runtime\n", "    runtime.unassign()" ], "metadata": { "id": "KaiJU0tPkVp-" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Manual kill switch: run this cell to release the GPU runtime immediately\n", "from google.colab import runtime\n", "runtime.unassign()" ], "metadata": { "id": "bxFTYWHVqcSI" }, "execution_count": null, "outputs": [] } ] }