{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "o03F7Fu7qZXo" }, "outputs": [], "source": [ "# ==========================================\n", "# 0. SETUP & DEPENDENCIES\n", "# ==========================================\n", "!pip install -q x-transformers flash-attn datasets pandas tabulate huggingface_hub\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import math\n", "import gc\n", "import pandas as pd\n", "from tqdm.auto import tqdm\n", "from torch.utils.data import DataLoader\n", "from datasets import load_dataset\n", "from huggingface_hub import hf_hub_download\n", "from x_transformers import TransformerWrapper, Encoder\n", "\n", "# Global Config\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "DATASET_ID = \"prism-lab/wikitext-103-prism-test-seed42\"\n", "BATCH_SIZE = 8\n", "VOCAB_SIZE = 32768\n", "SEQ_LEN = 4096\n", "\n", "print(f\"๐Ÿ”ฅ Initializing Full Benchmark Suite on {DEVICE}\")\n", "\n", "# ==========================================\n", "# 1. 
# --- UTILS (Shared) ---
class ComplexDropout(nn.Module):
    """Dropout placeholder for complex tensors — currently the identity (p is stored but unused)."""
    def __init__(self, p=0.0): super().__init__(); self.p = p
    def forward(self, z): return z

class RobustPhaseNorm(nn.Module):
    """RMS-style normalization for complex tensors: divide by the root-mean-square
    magnitude over the last dim (plus eps), then apply a learned per-feature scale."""
    def __init__(self, d, eps=1e-5): super().__init__(); self.scale = nn.Parameter(torch.ones(d)); self.eps = eps
    def forward(self, x): return (x / torch.sqrt((x.abs()**2).mean(-1, keepdim=True) + self.eps)) * self.scale

class ModReLU(nn.Module):
    """modReLU: ReLU applied to the magnitude (with a learned bias b) while the phase
    z/|z| is preserved; the 1e-6 term avoids division by zero at the origin."""
    def __init__(self, f): super().__init__(); self.b = nn.Parameter(torch.zeros(f))
    def forward(self, z): return F.relu(z.abs() + self.b) * (z / (z.abs() + 1e-6))

class ComplexToRealBridge(nn.Module):
    """Maps a complex (..., d) tensor to real (..., d): concat [real, imag] -> Linear -> LayerNorm."""
    def __init__(self, d): super().__init__(); self.proj = nn.Linear(d*2, d); self.norm = nn.LayerNorm(d)
    def forward(self, x): return self.norm(self.proj(torch.cat([x.real, x.imag], -1)))

class DynamicRoSE(nn.Module):
    """Complex token embedding with two phase rotations: a fixed sinusoidal positional
    rotation (srot) and a content-predicted unit-magnitude rotation (drot).
    forward() returns (rotated complex embedding, raw real embedding)."""
    def __init__(self, n, d):
        super().__init__(); self.raw_embedding = nn.Embedding(n, d); self.adapter = nn.Linear(d, d*2); self.rotation_predictor = nn.Linear(d, d*2)
        # Log-spaced frequencies, same schedule as classic sinusoidal encodings.
        self.register_buffer('freqs', torch.exp(torch.arange(0, d) * -(math.log(10000.0)/d)))
    def forward(self, x):
        real = self.raw_embedding(x); params = self.adapter(real); D = real.shape[-1]
        # First half of the adapter output becomes the real part, second half the imaginary part.
        z = torch.complex(params[...,:D], params[...,D:]); r = self.rotation_predictor(real); rx, ry = r.chunk(2, -1)
        # Normalize (rx, ry) onto the unit circle -> pure phase rotation, no magnitude change.
        drot = torch.complex(rx/torch.sqrt(rx**2+ry**2+1e-6), ry/torch.sqrt(rx**2+ry**2+1e-6))
        pos = torch.arange(real.shape[1], device=x.device).float()
        # e^{i * pos * freq}: unit-magnitude positional rotations built with torch.polar.
        srot = torch.polar(torch.ones_like(torch.outer(pos, self.freqs)), torch.outer(pos, self.freqs))
        return (z * srot.unsqueeze(0) * drot), real

class HyenaNeuralFilter(nn.Module):
    """Implicit long-convolution filter (Hyena-style): a small MLP over a sinusoidal
    time embedding produces a complex filter of shape (L, d) on demand."""
    def __init__(self, d, max_len=1024, h=64):
        super().__init__(); self.d = d; self.register_buffer("freqs", torch.exp(torch.arange(0, h, 2) * -(math.log(10000.0)/h)))
        self.mlp = nn.Sequential(nn.Linear(h, h), nn.SiLU(), nn.Linear(h, h), nn.SiLU(), nn.Linear(h, d*2))
    def forward(self, L, dev):
        # Normalized time grid in [0, 1], so the filter shape is resolution-independent in L.
        t = torch.linspace(0, 1, steps=L, device=dev).unsqueeze(-1)
        emb = torch.cat([torch.sin(t*self.freqs), torch.cos(t*self.freqs)], -1)
        out = self.mlp(emb).view(L, self.d, 2); return torch.complex(out[...,0], out[...,1])

class GatedHarmonicConvolution(nn.Module):
    """Pre-norm residual block: FFT-domain (circular) convolution with a neural filter,
    sigmoid gating derived from the real/imag parts, complex linear mixing, modReLU,
    then a complex output projection. Dropout is currently a no-op (p=0.0)."""
    def __init__(self, d, max_len):
        super().__init__(); self.d=d; self.filter_len=max_len; self.neural_filter = HyenaNeuralFilter(d, max_len)
        self.gate_proj = nn.Linear(d*2, d*2); self.mix_real = nn.Linear(d,d); self.mix_imag = nn.Linear(d,d)
        self.out_real = nn.Linear(d,d); self.out_imag = nn.Linear(d,d); self.activation = ModReLU(d); self.norm = RobustPhaseNorm(d)
        self.dropout = ComplexDropout(0.0)
    def forward(self, x, mask=None):
        # `mask` is accepted for interface compatibility but unused in this block.
        res = x; x = self.norm(x); B,L,D = x.shape; eff_L = min(L, self.filter_len)
        h = self.neural_filter(eff_L, x.device).unsqueeze(0)
        # Pointwise multiply in the frequency domain == circular convolution over dim 1 (sequence).
        xt = torch.fft.ifft(torch.fft.fft(x, n=eff_L, dim=1, norm='ortho') * h, n=eff_L, dim=1, norm='ortho')
        # Pad back (or trim) to the original sequence length L before the residual add.
        if L > eff_L: xt = F.pad(xt, (0,0,0,L-eff_L));
        else: xt = xt[:, :L, :]
        g = torch.sigmoid(self.gate_proj(torch.cat([x.real, x.imag], -1))); gr, gi = g.chunk(2, -1)
        # Gate real and imaginary parts independently, then complex matrix multiply:
        # (mr + i*mi)(a + i*b) = (mr a - mi b) + i(mr b + mi a).
        xg = torch.complex(xt.real*gr, xt.imag*gi); mr, mi = self.mix_real, self.mix_imag
        xm = torch.complex(mr(xg.real)-mi(xg.imag), mr(xg.imag)+mi(xg.real)); xa = self.activation(xm); or_, oi = self.out_real, self.out_imag
        out = torch.complex(or_(xa.real)-oi(xa.imag), or_(xa.imag)+oi(xa.real))
        return self.dropout(out) + res

class PRISMEncoder(nn.Module):
    """Stack of GatedHarmonicConvolution blocks followed by a final phase norm."""
    def __init__(self, l, d, max_l): super().__init__(); self.layers = nn.ModuleList([GatedHarmonicConvolution(d, max_l) for _ in range(l)]); self.final_norm = RobustPhaseNorm(d)
    def forward(self, x):
        for layer in self.layers: x = layer(x)
        return self.final_norm(x)

# --- A. BASELINE (Transformer) ---
class LocalBaseline(nn.Module):
    """Plain rotary-embedding Transformer encoder baseline (x-transformers), tied embeddings."""
    def __init__(self):
        super().__init__()
        self.model = TransformerWrapper(
            num_tokens=VOCAB_SIZE, max_seq_len=SEQ_LEN, use_abs_pos_emb=False, tie_embedding=True,
            attn_layers=Encoder(dim=512, depth=5, heads=8, rotary_pos_emb=True, attn_flash=True, use_scalenorm=False)
        )
    def forward(self, x): return self.model(x)

# --- B. FNET (Hybrid) ---
class FNetBlock(nn.Module):
    """FNet mixing block: 2-D FFT over (seq, feature), keep the real part, residual add;
    then a standard pre-norm feed-forward sublayer."""
    def __init__(self, d, df):
        super().__init__(); self.norm_mix = nn.LayerNorm(d); self.norm_ff = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, df), nn.GELU(), nn.Dropout(0), nn.Linear(df, d), nn.Dropout(0))
    def forward(self, x):
        # The FFT runs in float32 and is cast back to the input dtype for the residual.
        r = x; x = self.norm_mix(x); x = torch.fft.fftn(x.float(), dim=(-2,-1), norm='ortho').real.to(r.dtype); x = x+r
        r = x; x = self.norm_ff(x); x = self.ff(x); return x+r

class FNetEncoder(nn.Module):
    """Stack of FNetBlocks with a final LayerNorm."""
    def __init__(self, depth, d, df): super().__init__(); self.layers = nn.ModuleList([FNetBlock(d, df) for _ in range(depth)]); self.norm_out = nn.LayerNorm(d)
    def forward(self, x):
        for l in self.layers: x = l(x)
        return self.norm_out(x)

class HybridFNetMLM(nn.Module):
    """FNet encoder capped by one attention layer; learned absolute positions; output
    head weight tied to the token embedding."""
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, 512); self.pos_emb = nn.Parameter(torch.zeros(1, SEQ_LEN, 512))
        self.fnet_encoder = FNetEncoder(6, 512, 2048)
        self.transformer_cap = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)
        self.final_norm = nn.LayerNorm(512); self.to_logits = nn.Linear(512, VOCAB_SIZE)
        self.to_logits.weight = self.token_emb.weight # Tie
    def forward(self, x):
        h = self.token_emb(x) + self.pos_emb[:, :x.shape[1], :]
        return self.to_logits(self.final_norm(self.transformer_cap(self.fnet_encoder(h))))

# --- C. PRISM (Phase Coder) ---
class LocalPRISM(nn.Module):
    """PRISM encoder over complex embeddings, bridged to real, concatenated with the raw
    real embedding ('periscope'), refined by one attention layer; tied output head."""
    def __init__(self):
        super().__init__()
        self.rose = DynamicRoSE(VOCAB_SIZE, 512); self.prism_encoder = PRISMEncoder(5, 512, SEQ_LEN)
        self.bridge = ComplexToRealBridge(512); self.periscope_proj = nn.Sequential(nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU())
        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)
        self.lm_head = nn.Linear(512, VOCAB_SIZE); self.lm_head.weight = self.rose.raw_embedding.weight # Tie
    def forward(self, x):
        w, p = self.rose(x); w = self.bridge(self.prism_encoder(w))
        return self.lm_head(self.refiner(self.periscope_proj(torch.cat([w, p], -1))))

# --- D. PILLARS (Split-Stream) ---
class LocalPillars(nn.Module):
    """Two parallel 256-dim streams — FNet over the real ('rate') path and PRISM over
    the complex ('phase') path — fused back to 512 and refined by one attention layer.
    Output head is tied to the embedding via F.linear plus a separate bias parameter."""
    def __init__(self):
        super().__init__()
        self.rose = DynamicRoSE(VOCAB_SIZE, 512); self.particle_down = nn.Linear(512, 256); self.wave_down = nn.Linear(1024, 512)
        self.fnet_pos = nn.Embedding(SEQ_LEN, 256); self.stream_rate = FNetEncoder(9, 256, 1024)
        self.stream_phase = PRISMEncoder(9, 256, SEQ_LEN); self.phase_bridge = ComplexToRealBridge(256)
        self.fusion_proj = nn.Linear(512, 512); self.fusion_norm = nn.LayerNorm(512)
        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)
        self.head_bias = nn.Parameter(torch.zeros(VOCAB_SIZE))
    def forward(self, x):
        w, p = self.rose(x); p_sm = self.particle_down(p); w_raw = self.wave_down(torch.cat([w.real, w.imag], -1))
        # Reassemble the downsampled wave path as a 256-dim complex tensor.
        w_sm = torch.complex(w_raw[...,:256], w_raw[...,256:])
        p_path = self.stream_rate(p_sm + self.fnet_pos(torch.arange(x.shape[1], device=x.device)))
        w_path = self.phase_bridge(self.stream_phase(w_sm))
        ctx = self.fusion_norm(self.fusion_proj(torch.cat([p_path, w_path], -1)))
        return F.linear(self.refiner(ctx), self.rose.raw_embedding.weight, self.head_bias)

# --- NEW: Sensory Stream for DAT ---
class SensoryStream(nn.Module):
    """Small rotary Transformer encoder used as the 'sensory' (attention) branch."""
    def __init__(self, depth, d, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(
            dim=d, depth=depth, heads=4, attn_flash=True,
            rotary_pos_emb=True, attn_dropout=dropout, ff_dropout=dropout,
            use_rmsnorm=True, ff_glu=True
        )
    def forward(self, x): return self.encoder(x)

# --- E. PILLARS-DAT (New Hybrid) ---
class LocalPillarsDAT(nn.Module):
    """Dual-stream hybrid: attention (SensoryStream) over the real path plus PRISM over
    the complex path, fused to d_model and refined by one attention layer.
    Output head weight is tied to the token embedding."""
    def __init__(self):
        super().__init__()
        # Config matching your training
        d_model, d_branch, depth = 512, 256, 6
        self.d_branch = d_branch  # stored so forward() doesn't hard-code the 256 split point

        self.rose = DynamicRoSE(VOCAB_SIZE, d_model)
        self.particle_down = nn.Linear(d_model, d_branch)
        self.wave_down = nn.Linear(d_model * 2, d_branch * 2)

        # Stream A: Sensory (Transformer)
        self.stream_sensory = SensoryStream(depth, d_branch)

        # Stream B: Relational (PRISM)
        # Re-using existing PRISMEncoder(layers, dim, seq_len)
        self.stream_relational = PRISMEncoder(depth, d_branch, SEQ_LEN)
        self.relational_bridge = ComplexToRealBridge(d_branch)

        # Fusion
        self.fusion_proj = nn.Linear(d_branch * 2, d_model)
        self.fusion_norm = nn.LayerNorm(d_model)

        # Refiner
        self.refiner = Encoder(dim=d_model, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)

        # Output (Standard Linear to match HF Checkpoint)
        self.lm_head = nn.Linear(d_model, VOCAB_SIZE)
        self.lm_head.weight = self.rose.raw_embedding.weight # Tie

    def forward(self, x):
        w, p = self.rose(x)
        p_sm = self.particle_down(p)

        # Complex downsample: first half of the projection -> real part, second half -> imag.
        w_raw = self.wave_down(torch.cat([w.real, w.imag], -1))
        w_sm = torch.complex(w_raw[..., :self.d_branch], w_raw[..., self.d_branch:])

        # Parallel Streams
        sensory_out = self.stream_sensory(p_sm)
        rel_out = self.relational_bridge(self.stream_relational(w_sm))

        # Fusion
        ctx = self.fusion_norm(self.fusion_proj(torch.cat([sensory_out, rel_out], -1)))
        return self.lm_head(self.refiner(ctx))

# ==========================================
# 2. INTELLIGENT LOADING & EVALUATION
# ==========================================
def smart_load(model, repo_id, name):
    """Download a checkpoint from the HF Hub and load it with per-model key remapping.

    Tries 'best.pt' first, falling back to 'pytorch_model.bin'. Loads with
    strict=False so architectural extras don't abort the benchmark.
    """
    print(f"โฌ‡๏ธ Downloading {name} from HF: {repo_id}...")
    try:
        path = hf_hub_download(repo_id, "best.pt")
    # FIX: was a bare `except:` — only fall back on real errors, not KeyboardInterrupt/SystemExit.
    except Exception:
        path = hf_hub_download(repo_id, "pytorch_model.bin")

    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load trusted repos.
    state_dict = torch.load(path, map_location="cpu")
    if 'model' in state_dict: state_dict = state_dict['model']

    # Strip DataParallel/DDP "module." prefixes.
    clean = {k.replace("module.", ""): v for k, v in state_dict.items()}

    # ๐Ÿ”ง SPECIFIC FIXES (Programmatic remapping)
    if name == "Baseline":
        new_d = {}
        for k, v in clean.items():
            nk = k if k.startswith("model.") else "model." + k
            # FIX: the original guard was `"emb" not in nk`, which is always False
            # ("emb" is a substring of "token_emb.weight"), so the remap never fired.
            # Only skip keys that are already in the remapped form.
            if "token_emb.weight" in nk and "token_emb.emb.weight" not in nk:
                nk = nk.replace("token_emb.weight", "token_emb.emb.weight")
            new_d[nk] = v
        clean = new_d
    elif name == "FNet":
        new_d = {}
        for k, v in clean.items():
            nk = k.replace("model.", "")
            new_d[nk] = v
        clean = new_d

    # LOAD (non-strict; report mismatches in both directions instead of discarding them)
    missing, unexpected = model.load_state_dict(clean, strict=False)
    print(f"โœ… {name} Loaded. Missing Keys: {len(missing)} | Unexpected Keys: {len(unexpected)}")
    return model

def evaluate_full(model, loader):
    """Masked-LM evaluation over `loader`.

    Returns (perplexity, top-1 accuracy %, top-5 accuracy %) computed over exactly
    the positions whose label is not -100; returns (0, 0, 0) if nothing is scored.
    """
    model.eval()
    total_nll = 0
    total_mask_count = 0 # Track exact number of predictions
    correct_1 = 0
    correct_5 = 0

    # reduction='sum' + the default ignore_index=-100 means the loss is summed over
    # exactly the scored (label != -100) positions, so batches with varying mask
    # counts are weighted correctly when we divide once at the end.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for b in tqdm(loader, leave=False):
            ids = b['input_ids'].to(DEVICE)
            lbl = b['labels'].to(DEVICE)

            logits = model(ids)
            if isinstance(logits, dict): logits = logits['logits']

            # Score only positions whose label is NOT the ignore value (-100).
            mask = (lbl != -100)
            n_scored = mask.sum().item()  # hoisted: was computed twice per batch
            if n_scored > 0:
                # 1. PPL: accumulate the summed NLL; average only at the end.
                loss = criterion(logits.view(-1, VOCAB_SIZE), lbl.view(-1))
                total_nll += loss.item()
                total_mask_count += n_scored

                # 2. Accuracy
                m_log = logits[mask]
                m_lbl = lbl[mask]
                correct_1 += (m_log.argmax(-1) == m_lbl).sum().item()
                _, top5 = m_log.topk(5, -1)
                correct_5 += (top5 == m_lbl.unsqueeze(1)).any(1).sum().item()

    if total_mask_count == 0: return 0, 0, 0

    # Final Calculation: PPL = exp(mean NLL per scored token).
    ppl = math.exp(total_nll / total_mask_count)
    acc1 = (correct_1 / total_mask_count) * 100
    acc5 = (correct_5 / total_mask_count) * 100
    return ppl, acc1, acc5


def audit_efficiency(model, name):
    """Summarize parameter counts and how many parameters perform token mixing.

    BUG FIX: the benchmark registry uses the display names "HSSM" and "WPT" for the
    PILLARS / PILLARS-DAT architectures, so the original exact-match branches never
    fired for them and silently reported 0 active mixing params. Both aliases are
    now accepted.
    """
    total = sum(p.numel() for p in model.parameters())
    # Heuristic: tied models land in the ~30-40M range; an untied head would exceed 45M.
    is_tied = total < 45000000 # Hard check for 30-40M range

    active_mix = 0
    if name == "Baseline":
        for n, p in model.named_parameters():
            if "attn" in n and "norm" not in n: active_mix += p.numel()
    elif name == "FNet":
        for n, p in model.named_parameters():
            if "transformer_cap" in n and "attn" in n and "norm" not in n: active_mix += p.numel()
    elif name == "PRISM":
        for n, p in model.named_parameters():
            if ("neural_filter" in n or "gate_proj" in n) or ("refiner" in n and "attn" in n and "norm" not in n):
                active_mix += p.numel()
    elif name in ("PILLARS", "HSSM"): # Old HSSM (registry name: "HSSM")
        for n, p in model.named_parameters():
            if (("neural_filter" in n or "gate_proj" in n) and "stream_phase" in n) or ("fusion_proj" in n) or ("refiner" in n and "attn" in n and "norm" not in n):
                active_mix += p.numel()
    # --- NEW BLOCK ---
    elif name in ("PILLARS-DAT", "WPT"): # registry name: "WPT"
        # The three conditions below are mutually exclusive (disjoint name prefixes),
        # so no parameter is double-counted.
        for n, p in model.named_parameters():
            # 1. Sensory Stream Attention (Rate)
            if "stream_sensory" in n and "attn" in n and "norm" not in n:
                active_mix += p.numel()
            # 2. Relational Stream Filters (Phase)
            if "stream_relational" in n and ("neural_filter" in n or "gate_proj" in n):
                active_mix += p.numel()
            # 3. Refiner Attention (Readout)
            if "refiner" in n and "attn" in n and "norm" not in n:
                active_mix += p.numel()

    return {
        "Model": name,
        "Weights Tied?": "โœ… YES" if is_tied else "โŒ NO",
        "Total Params (M)": total / 1e6,
        "Active Mixing (M)": active_mix / 1e6,
        "Mixing Ratio": f"{(active_mix/total)*100:.1f}%"
    }
# ==========================================
# 4. EXECUTION
# ==========================================
# Registry of (display name, HF repo id, local architecture class).
# NOTE(review): the display names must stay in sync with the branch labels accepted
# by audit_efficiency — a mismatch silently reports 0 "Active Mixing" params.
MODELS = [
    ("Baseline", "prism-lab/baseline-wikitext-prism", LocalBaseline),
    ("FNet", "prism-lab/hybrid-fnet-prism-custom", HybridFNetMLM),
    ("PRISM", "prism-lab/prism-v2-wikitext", LocalPRISM),
    ("HSSM", "prism-lab/pillars-compact-wikitext", LocalPillars),
    ("WPT", "prism-lab/pillars-dat-wikitext-32k", LocalPillarsDAT)
]
print("๐Ÿ“ฆ Loading Data...")
d = load_dataset(DATASET_ID)

# Priority: Test > Validation > Train (fallback)
if 'test' in d:
    ds = d['test']
elif 'validation' in d:
    ds = d['validation']
else:
    ds = d['train']

print(f"๐Ÿ“Š Evaluating on split: {ds.split if hasattr(ds, 'split') else 'Unknown'}")
ds.set_format("torch", columns=["input_ids", "labels"])
# shuffle=False keeps the evaluation order deterministic across all models.
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

perf_results = []
param_results = []

print("\n๐Ÿ”ฅ STARTING FULL BENCHMARK")

# Load -> evaluate -> audit each model in turn, freeing GPU memory between runs.
for name, repo, cls in MODELS:
    print(f"\n๐Ÿงช Processing {name}...")
    try:
        model = cls().to(DEVICE)
        model = smart_load(model, repo, name)

        # 1. Performance
        ppl, top1, top5 = evaluate_full(model, loader)
        perf_results.append({"Model": name, "PPL": ppl, "Top-1": top1, "Top-5": top5})

        # 2. Efficiency
        param_results.append(audit_efficiency(model, name))

        # Release the current model before the next architecture is instantiated.
        del model; torch.cuda.empty_cache(); gc.collect()
    except Exception as e:
        # Best-effort: a single failing model must not abort the whole benchmark.
        print(f"โŒ {name} Failed: {e}")

print("\n\n" + "="*80)
print("๐Ÿ† TABLE 1: PERFORMANCE BENCHMARK")
print("="*80)
print(pd.DataFrame(perf_results).sort_values("PPL").to_markdown(index=False, floatfmt=".4f"))

print("\n\n" + "="*80)
print("๐Ÿ† TABLE 2: COMPUTATIONAL EFFICIENCY")
print("="*80)
print(pd.DataFrame(param_results).to_markdown(index=False, floatfmt=".2f"))