Upload Eval_T4_Last.ipynb
Reproduces Table 4. Open the notebook in Colab, connect to a T4 GPU runtime, and run all cells to replicate our results with our pretrained models.
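To reproduce the same run outside Colab, a minimal sketch (assumes a local CUDA GPU and the nbformat/nbclient packages, which the notebook itself does not install; the output filename is arbitrary):

import nbformat
from nbclient import NotebookClient

# Execute every cell of the uploaded notebook headlessly, the equivalent of "Run all" in Colab.
nb = nbformat.read("Eval_T4_Last.ipynb", as_version=4)
NotebookClient(nb, kernel_name="python3").execute()

# Save the executed copy, including its outputs, for inspection.
nbformat.write(nb, "Eval_T4_Last_executed.ipynb")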
- Eval_T4_Last.ipynb +407 -0
Eval_T4_Last.ipynb
ADDED
@@ -0,0 +1,407 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o03F7Fu7qZXo"
      },
      "outputs": [],
      "source": [
        "# ==========================================\n",
        "# 0. SETUP & DEPENDENCIES\n",
        "# ==========================================\n",
        "!pip install -q x-transformers flash-attn datasets pandas tabulate huggingface_hub\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import math\n",
        "import gc\n",
        "import pandas as pd\n",
        "from tqdm.auto import tqdm\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from huggingface_hub import hf_hub_download\n",
        "from x_transformers import TransformerWrapper, Encoder\n",
        "\n",
        "# Global Config\n",
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "DATASET_ID = \"prism-lab/wikitext-103-prism-test-seed42\"\n",
        "BATCH_SIZE = 8\n",
        "VOCAB_SIZE = 32768\n",
        "SEQ_LEN = 4096\n",
        "\n",
        "print(f\"🔥 Initializing Full Benchmark Suite on {DEVICE}\")\n",
        "\n",
        "# ==========================================\n",
        "# 1. ARCHITECTURE DEFINITIONS (Exact Matches)\n",
        "# ==========================================\n",
        "\n",
        "# --- UTILS (Shared) ---\n",
        "class ComplexDropout(nn.Module):\n",
        "    def __init__(self, p=0.0): super().__init__(); self.p = p\n",
        "    def forward(self, z): return z\n",
        "class RobustPhaseNorm(nn.Module):\n",
        "    def __init__(self, d, eps=1e-5): super().__init__(); self.scale = nn.Parameter(torch.ones(d)); self.eps = eps\n",
        "    def forward(self, x): return (x / torch.sqrt((x.abs()**2).mean(-1, keepdim=True) + self.eps)) * self.scale\n",
        "class ModReLU(nn.Module):\n",
        "    def __init__(self, f): super().__init__(); self.b = nn.Parameter(torch.zeros(f))\n",
        "    def forward(self, z): return F.relu(z.abs() + self.b) * (z / (z.abs() + 1e-6))\n",
        "class ComplexToRealBridge(nn.Module):\n",
        "    def __init__(self, d): super().__init__(); self.proj = nn.Linear(d*2, d); self.norm = nn.LayerNorm(d)\n",
        "    def forward(self, x): return self.norm(self.proj(torch.cat([x.real, x.imag], -1)))\n",
        "class DynamicRoSE(nn.Module):\n",
        "    def __init__(self, n, d):\n",
        "        super().__init__(); self.raw_embedding = nn.Embedding(n, d); self.adapter = nn.Linear(d, d*2); self.rotation_predictor = nn.Linear(d, d*2)\n",
        "        self.register_buffer('freqs', torch.exp(torch.arange(0, d) * -(math.log(10000.0)/d)))\n",
        "    def forward(self, x):\n",
        "        real = self.raw_embedding(x); params = self.adapter(real); D = real.shape[-1]\n",
        "        z = torch.complex(params[...,:D], params[...,D:]); r = self.rotation_predictor(real); rx, ry = r.chunk(2, -1)\n",
        "        drot = torch.complex(rx/torch.sqrt(rx**2+ry**2+1e-6), ry/torch.sqrt(rx**2+ry**2+1e-6))\n",
        "        pos = torch.arange(real.shape[1], device=x.device).float()\n",
        "        srot = torch.polar(torch.ones_like(torch.outer(pos, self.freqs)), torch.outer(pos, self.freqs))\n",
        "        return (z * srot.unsqueeze(0) * drot), real\n",
        "class HyenaNeuralFilter(nn.Module):\n",
        "    def __init__(self, d, max_len=1024, h=64):\n",
        "        super().__init__(); self.d = d; self.register_buffer(\"freqs\", torch.exp(torch.arange(0, h, 2) * -(math.log(10000.0)/h)))\n",
        "        self.mlp = nn.Sequential(nn.Linear(h, h), nn.SiLU(), nn.Linear(h, h), nn.SiLU(), nn.Linear(h, d*2))\n",
        "    def forward(self, L, dev):\n",
        "        t = torch.linspace(0, 1, steps=L, device=dev).unsqueeze(-1)\n",
        "        emb = torch.cat([torch.sin(t*self.freqs), torch.cos(t*self.freqs)], -1)\n",
        "        out = self.mlp(emb).view(L, self.d, 2); return torch.complex(out[...,0], out[...,1])\n",
        "class GatedHarmonicConvolution(nn.Module):\n",
        "    def __init__(self, d, max_len):\n",
        "        super().__init__(); self.d=d; self.filter_len=max_len; self.neural_filter = HyenaNeuralFilter(d, max_len)\n",
        "        self.gate_proj = nn.Linear(d*2, d*2); self.mix_real = nn.Linear(d,d); self.mix_imag = nn.Linear(d,d)\n",
        "        self.out_real = nn.Linear(d,d); self.out_imag = nn.Linear(d,d); self.activation = ModReLU(d); self.norm = RobustPhaseNorm(d)\n",
        "        self.dropout = ComplexDropout(0.0)\n",
        "    def forward(self, x, mask=None):\n",
        "        res = x; x = self.norm(x); B,L,D = x.shape; eff_L = min(L, self.filter_len)\n",
        "        h = self.neural_filter(eff_L, x.device).unsqueeze(0)\n",
        "        xt = torch.fft.ifft(torch.fft.fft(x, n=eff_L, dim=1, norm='ortho') * h, n=eff_L, dim=1, norm='ortho')\n",
        "        if L > eff_L: xt = F.pad(xt, (0,0,0,L-eff_L))\n",
        "        else: xt = xt[:, :L, :]\n",
        "        g = torch.sigmoid(self.gate_proj(torch.cat([x.real, x.imag], -1))); gr, gi = g.chunk(2, -1)\n",
        "        xg = torch.complex(xt.real*gr, xt.imag*gi); mr, mi = self.mix_real, self.mix_imag\n",
        "        xm = torch.complex(mr(xg.real)-mi(xg.imag), mr(xg.imag)+mi(xg.real)); xa = self.activation(xm); or_, oi = self.out_real, self.out_imag\n",
        "        out = torch.complex(or_(xa.real)-oi(xa.imag), or_(xa.imag)+oi(xa.real))\n",
        "        return self.dropout(out) + res\n",
        "class PRISMEncoder(nn.Module):\n",
        "    def __init__(self, l, d, max_l): super().__init__(); self.layers = nn.ModuleList([GatedHarmonicConvolution(d, max_l) for _ in range(l)]); self.final_norm = RobustPhaseNorm(d)\n",
        "    def forward(self, x):\n",
        "        for layer in self.layers: x = layer(x)\n",
        "        return self.final_norm(x)\n",
        "\n",
        "# --- A. BASELINE (Transformer) ---\n",
        "class LocalBaseline(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.model = TransformerWrapper(\n",
        "            num_tokens=VOCAB_SIZE, max_seq_len=SEQ_LEN, use_abs_pos_emb=False, tie_embedding=True,\n",
        "            attn_layers=Encoder(dim=512, depth=5, heads=8, rotary_pos_emb=True, attn_flash=True, use_scalenorm=False)\n",
        "        )\n",
        "    def forward(self, x): return self.model(x)\n",
        "\n",
        "# --- B. FNET (Hybrid) ---\n",
        "class FNetBlock(nn.Module):\n",
        "    def __init__(self, d, df):\n",
        "        super().__init__(); self.norm_mix = nn.LayerNorm(d); self.norm_ff = nn.LayerNorm(d)\n",
        "        self.ff = nn.Sequential(nn.Linear(d, df), nn.GELU(), nn.Dropout(0), nn.Linear(df, d), nn.Dropout(0))\n",
        "    def forward(self, x):\n",
        "        r = x; x = self.norm_mix(x); x = torch.fft.fftn(x.float(), dim=(-2,-1), norm='ortho').real.to(r.dtype); x = x+r\n",
        "        r = x; x = self.norm_ff(x); x = self.ff(x); return x+r\n",
        "class FNetEncoder(nn.Module):\n",
        "    def __init__(self, depth, d, df): super().__init__(); self.layers = nn.ModuleList([FNetBlock(d, df) for _ in range(depth)]); self.norm_out = nn.LayerNorm(d)\n",
        "    def forward(self, x):\n",
        "        for l in self.layers: x = l(x)\n",
        "        return self.norm_out(x)\n",
        "class HybridFNetMLM(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.token_emb = nn.Embedding(VOCAB_SIZE, 512); self.pos_emb = nn.Parameter(torch.zeros(1, SEQ_LEN, 512))\n",
        "        self.fnet_encoder = FNetEncoder(6, 512, 2048)\n",
        "        self.transformer_cap = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
        "        self.final_norm = nn.LayerNorm(512); self.to_logits = nn.Linear(512, VOCAB_SIZE)\n",
        "        self.to_logits.weight = self.token_emb.weight # Tie\n",
        "    def forward(self, x):\n",
        "        h = self.token_emb(x) + self.pos_emb[:, :x.shape[1], :]\n",
        "        return self.to_logits(self.final_norm(self.transformer_cap(self.fnet_encoder(h))))\n",
        "\n",
        "# --- C. PRISM (Phase Coder) ---\n",
        "class LocalPRISM(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.rose = DynamicRoSE(VOCAB_SIZE, 512); self.prism_encoder = PRISMEncoder(5, 512, SEQ_LEN)\n",
        "        self.bridge = ComplexToRealBridge(512); self.periscope_proj = nn.Sequential(nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU())\n",
        "        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
        "        self.lm_head = nn.Linear(512, VOCAB_SIZE); self.lm_head.weight = self.rose.raw_embedding.weight # Tie\n",
        "    def forward(self, x):\n",
        "        w, p = self.rose(x); w = self.bridge(self.prism_encoder(w))\n",
        "        return self.lm_head(self.refiner(self.periscope_proj(torch.cat([w, p], -1))))\n",
        "\n",
        "# --- D. PILLARS (Split-Stream) ---\n",
        "class LocalPillars(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.rose = DynamicRoSE(VOCAB_SIZE, 512); self.particle_down = nn.Linear(512, 256); self.wave_down = nn.Linear(1024, 512)\n",
        "        self.fnet_pos = nn.Embedding(SEQ_LEN, 256); self.stream_rate = FNetEncoder(9, 256, 1024)\n",
        "        self.stream_phase = PRISMEncoder(9, 256, SEQ_LEN); self.phase_bridge = ComplexToRealBridge(256)\n",
        "        self.fusion_proj = nn.Linear(512, 512); self.fusion_norm = nn.LayerNorm(512)\n",
        "        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
        "        self.head_bias = nn.Parameter(torch.zeros(VOCAB_SIZE))\n",
        "    def forward(self, x):\n",
        "        w, p = self.rose(x); p_sm = self.particle_down(p); w_raw = self.wave_down(torch.cat([w.real, w.imag], -1))\n",
        "        w_sm = torch.complex(w_raw[...,:256], w_raw[...,256:])\n",
        "        p_path = self.stream_rate(p_sm + self.fnet_pos(torch.arange(x.shape[1], device=x.device)))\n",
        "        w_path = self.phase_bridge(self.stream_phase(w_sm))\n",
        "        ctx = self.fusion_norm(self.fusion_proj(torch.cat([p_path, w_path], -1)))\n",
        "        return F.linear(self.refiner(ctx), self.rose.raw_embedding.weight, self.head_bias)\n",
        "\n",
        "# --- NEW: Sensory Stream for DAT ---\n",
        "class SensoryStream(nn.Module):\n",
        "    def __init__(self, depth, d, dropout=0.1):\n",
        "        super().__init__()\n",
        "        self.encoder = Encoder(\n",
        "            dim=d, depth=depth, heads=4, attn_flash=True,\n",
        "            rotary_pos_emb=True, attn_dropout=dropout, ff_dropout=dropout,\n",
        "            use_rmsnorm=True, ff_glu=True\n",
        "        )\n",
        "    def forward(self, x): return self.encoder(x)\n",
        "\n",
        "# --- E. PILLARS-DAT (New Hybrid) ---\n",
        "class LocalPillarsDAT(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # Config matching the training run\n",
        "        d_model, d_branch, depth = 512, 256, 6\n",
        "\n",
        "        self.rose = DynamicRoSE(VOCAB_SIZE, d_model)\n",
        "        self.particle_down = nn.Linear(d_model, d_branch)\n",
        "        self.wave_down = nn.Linear(d_model * 2, d_branch * 2)\n",
        "\n",
        "        # Stream A: Sensory (Transformer)\n",
        "        self.stream_sensory = SensoryStream(depth, d_branch)\n",
        "\n",
        "        # Stream B: Relational (PRISM)\n",
        "        # Re-using existing PRISMEncoder(layers, dim, seq_len)\n",
        "        self.stream_relational = PRISMEncoder(depth, d_branch, SEQ_LEN)\n",
        "        self.relational_bridge = ComplexToRealBridge(d_branch)\n",
        "\n",
        "        # Fusion\n",
        "        self.fusion_proj = nn.Linear(d_branch * 2, d_model)\n",
        "        self.fusion_norm = nn.LayerNorm(d_model)\n",
        "\n",
        "        # Refiner\n",
        "        self.refiner = Encoder(dim=d_model, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
        "\n",
        "        # Output (Standard Linear to match HF Checkpoint)\n",
        "        self.lm_head = nn.Linear(d_model, VOCAB_SIZE)\n",
        "        self.lm_head.weight = self.rose.raw_embedding.weight # Tie\n",
        "\n",
        "    def forward(self, x):\n",
        "        w, p = self.rose(x)\n",
        "        p_sm = self.particle_down(p)\n",
        "\n",
        "        # Complex Downsample\n",
        "        w_raw = self.wave_down(torch.cat([w.real, w.imag], -1))\n",
        "        w_sm = torch.complex(w_raw[...,:256], w_raw[...,256:])\n",
        "\n",
        "        # Parallel Streams\n",
        "        sensory_out = self.stream_sensory(p_sm)\n",
        "        rel_out = self.relational_bridge(self.stream_relational(w_sm))\n",
        "\n",
        "        # Fusion\n",
        "        ctx = self.fusion_norm(self.fusion_proj(torch.cat([sensory_out, rel_out], -1)))\n",
        "        return self.lm_head(self.refiner(ctx))\n",
        "\n",
        "# ==========================================\n",
        "# 2. INTELLIGENT LOADING & EVALUATION\n",
        "# ==========================================\n",
        "def smart_load(model, repo_id, name):\n",
        "    print(f\"⬇️ Downloading {name} from HF: {repo_id}...\")\n",
        "    try: path = hf_hub_download(repo_id, \"best.pt\")\n",
        "    except: path = hf_hub_download(repo_id, \"pytorch_model.bin\")\n",
        "\n",
        "    state_dict = torch.load(path, map_location=\"cpu\")\n",
        "    if 'model' in state_dict: state_dict = state_dict['model']\n",
        "\n",
        "    clean = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n",
        "\n",
        "    # 🔧 SPECIFIC FIXES (Programmatic remapping)\n",
        "    if name == \"Baseline\":\n",
        "        new_d = {}\n",
        "        for k, v in clean.items():\n",
        "            nk = k if k.startswith(\"model.\") else \"model.\" + k\n",
        "            if \"token_emb.weight\" in nk and \"emb\" not in nk: nk = nk.replace(\"token_emb.weight\", \"token_emb.emb.weight\")\n",
        "            new_d[nk] = v\n",
        "        clean = new_d\n",
        "    elif name == \"FNet\":\n",
        "        new_d = {}\n",
        "        for k, v in clean.items():\n",
        "            nk = k.replace(\"model.\", \"\")\n",
        "            new_d[nk] = v\n",
        "        clean = new_d\n",
        "\n",
        "    # LOAD\n",
        "    missing, _ = model.load_state_dict(clean, strict=False)\n",
        "    print(f\"✅ {name} Loaded. Missing Keys: {len(missing)}\")\n",
        "    return model\n",
        "\n",
        "def evaluate_full(model, loader):\n",
        "    model.eval()\n",
        "    total_nll = 0\n",
        "    total_mask_count = 0  # Track exact number of predictions\n",
        "    correct_1 = 0\n",
        "    correct_5 = 0\n",
        "\n",
        "    # Switch to SUM reduction to handle varying mask counts correctly\n",
        "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for b in tqdm(loader, leave=False):\n",
        "            ids = b['input_ids'].to(DEVICE)\n",
        "            lbl = b['labels'].to(DEVICE)\n",
        "\n",
        "            logits = model(ids)\n",
        "            if isinstance(logits, dict): logits = logits['logits']\n",
        "\n",
        "            # Mask logic: only score positions whose label is not -100 (i.e. the masked tokens)\n",
        "            mask = (lbl != -100)\n",
        "            if mask.sum() > 0:\n",
        "                # 1. PPL: accumulate the summed loss, average only at the end\n",
        "                loss = criterion(logits.view(-1, VOCAB_SIZE), lbl.view(-1))\n",
        "                total_nll += loss.item()\n",
        "                total_mask_count += mask.sum().item()\n",
        "\n",
        "                # 2. Accuracy\n",
        "                m_log = logits[mask]\n",
        "                m_lbl = lbl[mask]\n",
        "                correct_1 += (m_log.argmax(-1) == m_lbl).sum().item()\n",
        "                _, top5 = m_log.topk(5, -1)\n",
        "                correct_5 += (top5 == m_lbl.unsqueeze(1)).any(1).sum().item()\n",
        "\n",
        "    if total_mask_count == 0: return 0, 0, 0\n",
        "\n",
        "    # Final Calculation\n",
        "    ppl = math.exp(total_nll / total_mask_count)\n",
        "    acc1 = (correct_1 / total_mask_count) * 100\n",
        "    acc5 = (correct_5 / total_mask_count) * 100\n",
        "    return ppl, acc1, acc5\n",
        "\n",
        "\n",
        "def audit_efficiency(model, name):\n",
        "    total = sum(p.numel() for p in model.parameters())\n",
        "    is_tied = total < 45000000  # Hard check for 30-40M range\n",
        "\n",
        "    active_mix = 0\n",
        "    if name == \"Baseline\":\n",
        "        for n, p in model.named_parameters():\n",
        "            if \"attn\" in n and \"norm\" not in n: active_mix += p.numel()\n",
        "    elif name == \"FNet\":\n",
        "        for n, p in model.named_parameters():\n",
        "            if \"transformer_cap\" in n and \"attn\" in n and \"norm\" not in n: active_mix += p.numel()\n",
        "    elif name == \"PRISM\":\n",
        "        for n, p in model.named_parameters():\n",
        "            if (\"neural_filter\" in n or \"gate_proj\" in n) or (\"refiner\" in n and \"attn\" in n and \"norm\" not in n):\n",
        "                active_mix += p.numel()\n",
        "    elif name in (\"PILLARS\", \"HSSM\"):  # Old HSSM (registered as \"HSSM\" in MODELS below)\n",
        "        for n, p in model.named_parameters():\n",
        "            if ((\"neural_filter\" in n or \"gate_proj\" in n) and \"stream_phase\" in n) or (\"fusion_proj\" in n) or (\"refiner\" in n and \"attn\" in n and \"norm\" not in n):\n",
        "                active_mix += p.numel()\n",
        "    # --- NEW BLOCK ---\n",
        "    elif name in (\"PILLARS-DAT\", \"WPT\"):  # registered as \"WPT\" in MODELS below\n",
        "        for n, p in model.named_parameters():\n",
        "            # 1. Sensory Stream Attention (Rate)\n",
        "            if \"stream_sensory\" in n and \"attn\" in n and \"norm\" not in n:\n",
        "                active_mix += p.numel()\n",
        "            # 2. Relational Stream Filters (Phase)\n",
        "            if \"stream_relational\" in n and (\"neural_filter\" in n or \"gate_proj\" in n):\n",
        "                active_mix += p.numel()\n",
        "            # 3. Refiner Attention (Readout)\n",
        "            if \"refiner\" in n and \"attn\" in n and \"norm\" not in n:\n",
        "                active_mix += p.numel()\n",
        "\n",
        "    return {\n",
        "        \"Model\": name,\n",
        "        \"Weights Tied?\": \"✅ YES\" if is_tied else \"❌ NO\",\n",
        "        \"Total Params (M)\": total / 1e6,\n",
        "        \"Active Mixing (M)\": active_mix / 1e6,\n",
        "        \"Mixing Ratio\": f\"{(active_mix/total)*100:.1f}%\"\n",
        "    }\n",
        "# ==========================================\n",
        "# 4. EXECUTION\n",
        "# ==========================================\n",
        "MODELS = [\n",
        "    (\"Baseline\", \"prism-lab/baseline-wikitext-prism\", LocalBaseline),\n",
        "    (\"FNet\", \"prism-lab/hybrid-fnet-prism-custom\", HybridFNetMLM),\n",
        "    (\"PRISM\", \"prism-lab/prism-v2-wikitext\", LocalPRISM),\n",
        "    (\"HSSM\", \"prism-lab/pillars-compact-wikitext\", LocalPillars),\n",
        "    (\"WPT\", \"prism-lab/pillars-dat-wikitext-32k\", LocalPillarsDAT)\n",
        "]\n",
        "print(\"📦 Loading Data...\")\n",
        "d = load_dataset(DATASET_ID)\n",
        "\n",
        "# Priority: Test > Validation > Train (fallback)\n",
        "if 'test' in d:\n",
        "    ds = d['test']\n",
        "elif 'validation' in d:\n",
        "    ds = d['validation']\n",
        "else:\n",
        "    ds = d['train']\n",
        "\n",
        "print(f\"📊 Evaluating on split: {ds.split if hasattr(ds, 'split') else 'Unknown'}\")\n",
        "ds.set_format(\"torch\", columns=[\"input_ids\", \"labels\"])\n",
        "loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)\n",
        "\n",
        "perf_results = []\n",
        "param_results = []\n",
        "\n",
        "print(\"\\n🔥 STARTING FULL BENCHMARK\")\n",
        "\n",
        "for name, repo, cls in MODELS:\n",
        "    print(f\"\\n🧪 Processing {name}...\")\n",
        "    try:\n",
        "        model = cls().to(DEVICE)\n",
        "        model = smart_load(model, repo, name)\n",
        "\n",
        "        # 1. Performance\n",
        "        ppl, top1, top5 = evaluate_full(model, loader)\n",
        "        perf_results.append({\"Model\": name, \"PPL\": ppl, \"Top-1\": top1, \"Top-5\": top5})\n",
        "\n",
        "        # 2. Efficiency\n",
        "        param_results.append(audit_efficiency(model, name))\n",
        "\n",
        "        del model; torch.cuda.empty_cache(); gc.collect()\n",
        "    except Exception as e:\n",
        "        print(f\"❌ {name} Failed: {e}\")\n",
        "\n",
        "print(\"\\n\\n\" + \"=\"*80)\n",
        "print(\"🏆 TABLE 1: PERFORMANCE BENCHMARK\")\n",
        "print(\"=\"*80)\n",
        "print(pd.DataFrame(perf_results).sort_values(\"PPL\").to_markdown(index=False, floatfmt=\".4f\"))\n",
        "\n",
        "print(\"\\n\\n\" + \"=\"*80)\n",
        "print(\"🏆 TABLE 2: COMPUTATIONAL EFFICIENCY\")\n",
        "print(\"=\"*80)\n",
        "print(pd.DataFrame(param_results).to_markdown(index=False, floatfmt=\".2f\"))"
      ]
    }
  ]
}