prism-lab committed on
Commit
48f6f9f
·
verified ·
1 Parent(s): 5fe9601

Upload Eval_T4_Last.ipynb

Browse files

Reproduces Table 4. Just open it in Colab, connect to a T4 GPU, and click Run to replicate our results on our pretrained models.

Files changed (1) hide show
  1. Eval_T4_Last.ipynb +407 -0
Eval_T4_Last.ipynb ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "id": "o03F7Fu7qZXo"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "# ==========================================\n",
28
+ "# 0. SETUP & DEPENDENCIES\n",
29
+ "# ==========================================\n",
30
+ "!pip install -q x-transformers flash-attn datasets pandas tabulate huggingface_hub\n",
31
+ "\n",
32
+ "import torch\n",
33
+ "import torch.nn as nn\n",
34
+ "import torch.nn.functional as F\n",
35
+ "import math\n",
36
+ "import gc\n",
37
+ "import pandas as pd\n",
38
+ "from tqdm.auto import tqdm\n",
39
+ "from torch.utils.data import DataLoader\n",
40
+ "from datasets import load_dataset\n",
41
+ "from huggingface_hub import hf_hub_download\n",
42
+ "from x_transformers import TransformerWrapper, Encoder\n",
43
+ "\n",
44
+ "# Global Config\n",
45
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
46
+ "DATASET_ID = \"prism-lab/wikitext-103-prism-test-seed42\"\n",
47
+ "BATCH_SIZE = 8\n",
48
+ "VOCAB_SIZE = 32768\n",
49
+ "SEQ_LEN = 4096\n",
50
+ "\n",
51
+ "print(f\"🔥 Initializing Full Benchmark Suite on {DEVICE}\")\n",
52
+ "\n",
# ==========================================
# 1. ARCHITECTURE DEFINITIONS (Exact Matches)
# ==========================================

# --- UTILS (Shared) ---
class ComplexDropout(nn.Module):
    """Placeholder dropout for complex tensors; acts as identity (p is stored but unused)."""

    def __init__(self, p=0.0):
        super().__init__()
        self.p = p

    def forward(self, z):
        return z


class RobustPhaseNorm(nn.Module):
    """RMS-style normalization over the last dim using |x|^2, with a learned scale."""

    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt((x.abs() ** 2).mean(-1, keepdim=True) + self.eps)
        return (x / rms) * self.scale


class ModReLU(nn.Module):
    """ReLU applied to the magnitude of z with a learned bias; the phase is preserved."""

    def __init__(self, f):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(f))

    def forward(self, z):
        mag = z.abs()
        return F.relu(mag + self.b) * (z / (mag + 1e-6))


class ComplexToRealBridge(nn.Module):
    """Map a complex (..., d) tensor to a real (..., d) tensor via [Re ; Im] -> Linear -> LayerNorm."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d * 2, d)
        self.norm = nn.LayerNorm(d)

    def forward(self, x):
        stacked = torch.cat([x.real, x.imag], -1)
        return self.norm(self.proj(stacked))


class DynamicRoSE(nn.Module):
    """Token embedding that emits a complex 'wave' representation (static rotary
    phase times a learned per-token rotation) alongside the plain real embedding."""

    def __init__(self, n, d):
        super().__init__()
        self.raw_embedding = nn.Embedding(n, d)
        self.adapter = nn.Linear(d, d * 2)
        self.rotation_predictor = nn.Linear(d, d * 2)
        self.register_buffer('freqs', torch.exp(torch.arange(0, d) * -(math.log(10000.0) / d)))

    def forward(self, x):
        real = self.raw_embedding(x)
        params = self.adapter(real)
        D = real.shape[-1]
        z = torch.complex(params[..., :D], params[..., D:])
        r = self.rotation_predictor(real)
        rx, ry = r.chunk(2, -1)
        # Normalize to unit modulus so the dynamic term is a pure rotation.
        denom = torch.sqrt(rx ** 2 + ry ** 2 + 1e-6)
        drot = torch.complex(rx / denom, ry / denom)
        pos = torch.arange(real.shape[1], device=x.device).float()
        angles = torch.outer(pos, self.freqs)
        srot = torch.polar(torch.ones_like(angles), angles)  # static rotary phase
        return (z * srot.unsqueeze(0) * drot), real


class HyenaNeuralFilter(nn.Module):
    """Implicit convolution filter: an MLP over sinusoidal time features emits a
    complex filter of any requested length on demand."""

    def __init__(self, d, max_len=1024, h=64):
        super().__init__()
        self.d = d
        self.register_buffer("freqs", torch.exp(torch.arange(0, h, 2) * -(math.log(10000.0) / h)))
        self.mlp = nn.Sequential(
            nn.Linear(h, h), nn.SiLU(),
            nn.Linear(h, h), nn.SiLU(),
            nn.Linear(h, d * 2),
        )

    def forward(self, L, dev):
        t = torch.linspace(0, 1, steps=L, device=dev).unsqueeze(-1)
        feats = torch.cat([torch.sin(t * self.freqs), torch.cos(t * self.freqs)], -1)
        out = self.mlp(feats).view(L, self.d, 2)
        return torch.complex(out[..., 0], out[..., 1])


class GatedHarmonicConvolution(nn.Module):
    """Pre-norm residual block: FFT convolution with a neural filter, per-channel
    sigmoid gating, two complex linear mixes, and a ModReLU activation."""

    def __init__(self, d, max_len):
        super().__init__()
        self.d = d
        self.filter_len = max_len
        self.neural_filter = HyenaNeuralFilter(d, max_len)
        self.gate_proj = nn.Linear(d * 2, d * 2)
        self.mix_real = nn.Linear(d, d)
        self.mix_imag = nn.Linear(d, d)
        self.out_real = nn.Linear(d, d)
        self.out_imag = nn.Linear(d, d)
        self.activation = ModReLU(d)
        self.norm = RobustPhaseNorm(d)
        self.dropout = ComplexDropout(0.0)

    def forward(self, x, mask=None):
        res = x
        x = self.norm(x)
        B, L, D = x.shape
        eff_L = min(L, self.filter_len)
        # Circular convolution in the frequency domain over the first eff_L steps.
        h = self.neural_filter(eff_L, x.device).unsqueeze(0)
        spec = torch.fft.fft(x, n=eff_L, dim=1, norm='ortho')
        xt = torch.fft.ifft(spec * h, n=eff_L, dim=1, norm='ortho')
        xt = F.pad(xt, (0, 0, 0, L - eff_L)) if L > eff_L else xt[:, :L, :]
        # Independent sigmoid gates for the real and imaginary channels.
        g = torch.sigmoid(self.gate_proj(torch.cat([x.real, x.imag], -1)))
        gr, gi = g.chunk(2, -1)
        xg = torch.complex(xt.real * gr, xt.imag * gi)
        # Complex linear layers via the (a+bi)(Wr+Wi*i) expansion.
        mr, mi = self.mix_real, self.mix_imag
        xm = torch.complex(mr(xg.real) - mi(xg.imag), mr(xg.imag) + mi(xg.real))
        xa = self.activation(xm)
        or_, oi = self.out_real, self.out_imag
        out = torch.complex(or_(xa.real) - oi(xa.imag), or_(xa.imag) + oi(xa.real))
        return self.dropout(out) + res


class PRISMEncoder(nn.Module):
    """Stack of GatedHarmonicConvolution blocks followed by a final phase norm."""

    def __init__(self, l, d, max_l):
        super().__init__()
        self.layers = nn.ModuleList([GatedHarmonicConvolution(d, max_l) for _ in range(l)])
        self.final_norm = RobustPhaseNorm(d)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.final_norm(x)
# --- A. BASELINE (Transformer) ---
class LocalBaseline(nn.Module):
    """Vanilla rotary-attention Transformer baseline (x-transformers) with tied embeddings."""

    def __init__(self):
        super().__init__()
        core = Encoder(
            dim=512, depth=5, heads=8,
            rotary_pos_emb=True, attn_flash=True, use_scalenorm=False,
        )
        self.model = TransformerWrapper(
            num_tokens=VOCAB_SIZE,
            max_seq_len=SEQ_LEN,
            use_abs_pos_emb=False,
            tie_embedding=True,
            attn_layers=core,
        )

    def forward(self, x):
        return self.model(x)
# --- B. FNET (Hybrid) ---
class FNetBlock(nn.Module):
    """FNet mixing block: 2-D FFT over (seq, channel) keeping the real part, plus
    a feed-forward sublayer; both are pre-norm residual branches."""

    def __init__(self, d, df):
        super().__init__()
        self.norm_mix = nn.LayerNorm(d)
        self.norm_ff = nn.LayerNorm(d)
        self.ff = nn.Sequential(
            nn.Linear(d, df), nn.GELU(), nn.Dropout(0),
            nn.Linear(df, d), nn.Dropout(0),
        )

    def forward(self, x):
        residual = x
        mixed = torch.fft.fftn(self.norm_mix(x).float(), dim=(-2, -1), norm='ortho').real
        x = mixed.to(residual.dtype) + residual
        residual = x
        return self.ff(self.norm_ff(x)) + residual


class FNetEncoder(nn.Module):
    """Stack of FNetBlocks followed by a final LayerNorm."""

    def __init__(self, depth, d, df):
        super().__init__()
        self.layers = nn.ModuleList([FNetBlock(d, df) for _ in range(depth)])
        self.norm_out = nn.LayerNorm(d)

    def forward(self, x):
        for blk in self.layers:
            x = blk(x)
        return self.norm_out(x)


class HybridFNetMLM(nn.Module):
    """FNet encoder capped by a single attention layer; output head tied to the embedding."""

    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(VOCAB_SIZE, 512)
        self.pos_emb = nn.Parameter(torch.zeros(1, SEQ_LEN, 512))
        self.fnet_encoder = FNetEncoder(6, 512, 2048)
        self.transformer_cap = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)
        self.final_norm = nn.LayerNorm(512)
        self.to_logits = nn.Linear(512, VOCAB_SIZE)
        self.to_logits.weight = self.token_emb.weight  # weight tying

    def forward(self, x):
        h = self.token_emb(x) + self.pos_emb[:, :x.shape[1], :]
        h = self.fnet_encoder(h)
        h = self.transformer_cap(h)
        return self.to_logits(self.final_norm(h))
# --- C. PRISM (Phase Coder) ---
class LocalPRISM(nn.Module):
    """Complex PRISM encoder bridged back to real, concatenated with the raw
    embedding ('periscope'), refined by one attention layer; tied output head."""

    def __init__(self):
        super().__init__()
        self.rose = DynamicRoSE(VOCAB_SIZE, 512)
        self.prism_encoder = PRISMEncoder(5, 512, SEQ_LEN)
        self.bridge = ComplexToRealBridge(512)
        self.periscope_proj = nn.Sequential(nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU())
        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)
        self.lm_head = nn.Linear(512, VOCAB_SIZE)
        self.lm_head.weight = self.rose.raw_embedding.weight  # weight tying

    def forward(self, x):
        wave, particle = self.rose(x)
        wave = self.bridge(self.prism_encoder(wave))
        fused = self.periscope_proj(torch.cat([wave, particle], -1))
        return self.lm_head(self.refiner(fused))
# --- D. PILLARS (Split-Stream) ---
class LocalPillars(nn.Module):
    """Two-stream model: an FNet 'rate' stream over the real embedding and a
    complex PRISM 'phase' stream, fused and refined by one attention layer."""

    def __init__(self):
        super().__init__()
        self.rose = DynamicRoSE(VOCAB_SIZE, 512)
        self.particle_down = nn.Linear(512, 256)
        self.wave_down = nn.Linear(1024, 512)
        self.fnet_pos = nn.Embedding(SEQ_LEN, 256)
        self.stream_rate = FNetEncoder(9, 256, 1024)
        self.stream_phase = PRISMEncoder(9, 256, SEQ_LEN)
        self.phase_bridge = ComplexToRealBridge(256)
        self.fusion_proj = nn.Linear(512, 512)
        self.fusion_norm = nn.LayerNorm(512)
        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)
        self.head_bias = nn.Parameter(torch.zeros(VOCAB_SIZE))

    def forward(self, x):
        wave, particle = self.rose(x)
        p_sm = self.particle_down(particle)
        # Complex downsample: project [Re ; Im] and re-split into a complex tensor.
        w_raw = self.wave_down(torch.cat([wave.real, wave.imag], -1))
        w_sm = torch.complex(w_raw[..., :256], w_raw[..., 256:])
        positions = torch.arange(x.shape[1], device=x.device)
        p_path = self.stream_rate(p_sm + self.fnet_pos(positions))
        w_path = self.phase_bridge(self.stream_phase(w_sm))
        ctx = self.fusion_norm(self.fusion_proj(torch.cat([p_path, w_path], -1)))
        # Tied output head expressed as a functional linear with its own bias.
        return F.linear(self.refiner(ctx), self.rose.raw_embedding.weight, self.head_bias)
# --- NEW: Sensory Stream for DAT ---
class SensoryStream(nn.Module):
    """Thin wrapper around an x-transformers Encoder, used as the DAT 'sensory' branch."""

    def __init__(self, depth, d, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(
            dim=d, depth=depth, heads=4, attn_flash=True,
            rotary_pos_emb=True, attn_dropout=dropout, ff_dropout=dropout,
            use_rmsnorm=True, ff_glu=True,
        )

    def forward(self, x):
        return self.encoder(x)
# --- E. PILLARS-DAT (New Hybrid) ---
class LocalPillarsDAT(nn.Module):
    """Dual-stream hybrid: a small Transformer 'sensory' branch on the real
    embedding plus a complex PRISM 'relational' branch; fused, refined by one
    attention layer, with the output head tied to the embedding."""

    def __init__(self):
        super().__init__()
        # Config matching training: model width, per-branch width, branch depth.
        d_model, d_branch, depth = 512, 256, 6

        self.rose = DynamicRoSE(VOCAB_SIZE, d_model)
        self.particle_down = nn.Linear(d_model, d_branch)
        self.wave_down = nn.Linear(d_model * 2, d_branch * 2)

        # Stream A: sensory (Transformer).
        self.stream_sensory = SensoryStream(depth, d_branch)

        # Stream B: relational (PRISM), mapped back to real via the bridge.
        self.stream_relational = PRISMEncoder(depth, d_branch, SEQ_LEN)
        self.relational_bridge = ComplexToRealBridge(d_branch)

        # Fusion of the two branches back to model width.
        self.fusion_proj = nn.Linear(d_branch * 2, d_model)
        self.fusion_norm = nn.LayerNorm(d_model)

        # Single attention layer as readout refiner.
        self.refiner = Encoder(dim=d_model, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)

        # Standard linear head (matches the HF checkpoint layout), weight-tied.
        self.lm_head = nn.Linear(d_model, VOCAB_SIZE)
        self.lm_head.weight = self.rose.raw_embedding.weight

    def forward(self, x):
        wave, particle = self.rose(x)
        p_sm = self.particle_down(particle)

        # Complex downsample: project [Re ; Im] and re-split into a complex tensor.
        w_raw = self.wave_down(torch.cat([wave.real, wave.imag], -1))
        w_sm = torch.complex(w_raw[..., :256], w_raw[..., 256:])

        # Parallel branches.
        sensory_out = self.stream_sensory(p_sm)
        rel_out = self.relational_bridge(self.stream_relational(w_sm))

        ctx = self.fusion_norm(self.fusion_proj(torch.cat([sensory_out, rel_out], -1)))
        return self.lm_head(self.refiner(ctx))
# ==========================================
# 2. INTELLIGENT LOADING & EVALUATION
# ==========================================
def smart_load(model, repo_id, name):
    """Download a checkpoint from the Hugging Face Hub and load it into `model`.

    Tries `best.pt` first, falling back to `pytorch_model.bin`. Strips DDP
    `module.` prefixes and applies per-model key remappings so older checkpoint
    layouts line up with the local module names. Loads non-strictly and reports
    the number of missing keys.
    """
    print(f"⬇️ Downloading {name} from HF: {repo_id}...")
    try:
        path = hf_hub_download(repo_id, "best.pt")
    except Exception:  # fixed: was a bare `except:` that also swallowed KeyboardInterrupt
        path = hf_hub_download(repo_id, "pytorch_model.bin")

    # NOTE(review): torch.load on a downloaded pickle can execute arbitrary code;
    # consider weights_only=True if these checkpoints are plain state dicts.
    state_dict = torch.load(path, map_location="cpu")
    if 'model' in state_dict:
        state_dict = state_dict['model']

    # Strip DistributedDataParallel prefixes.
    clean = {k.replace("module.", ""): v for k, v in state_dict.items()}

    # 🔧 SPECIFIC FIXES (Programmatic remapping)
    if name == "Baseline":
        new_d = {}
        for k, v in clean.items():
            nk = k if k.startswith("model.") else "model." + k
            # x-transformers wraps the embedding as `token_emb.emb`; remap flat
            # `token_emb.weight` keys. (Bugfix: the old guard `"emb" not in nk`
            # was always False — "token_emb" contains "emb" — so the remap
            # could never fire.)
            if "token_emb.weight" in nk and "token_emb.emb.weight" not in nk:
                nk = nk.replace("token_emb.weight", "token_emb.emb.weight")
            new_d[nk] = v
        clean = new_d
    elif name == "FNet":
        # FNet checkpoints carry a `model.` prefix the local class does not use.
        clean = {k.replace("model.", ""): v for k, v in clean.items()}

    # LOAD (non-strict so extra/renamed keys do not abort the benchmark).
    missing, _ = model.load_state_dict(clean, strict=False)
    print(f"✅ {name} Loaded. Missing Keys: {len(missing)}")
    return model
def evaluate_full(model, loader):
    """Masked-LM evaluation: perplexity, top-1 and top-5 accuracy.

    Only positions whose label is NOT -100 are scored (the old comment had this
    inverted). Returns (ppl, top1%, top5%), or (0, 0, 0) when the loader yields
    no masked positions.
    """
    model.eval()
    total_nll = 0.0
    total_mask_count = 0  # exact number of scored predictions
    correct_1 = 0
    correct_5 = 0

    # SUM reduction so batches with different mask counts average correctly;
    # CrossEntropyLoss already skips ignore_index=-100 targets.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for b in tqdm(loader, leave=False):
            ids = b['input_ids'].to(DEVICE)
            lbl = b['labels'].to(DEVICE)

            logits = model(ids)
            if isinstance(logits, dict):
                logits = logits['logits']
            # Bugfix: derive the vocab size from the model output instead of the
            # global VOCAB_SIZE, so a mismatched checkpoint fails loudly elsewhere
            # rather than silently reshaping here.
            vocab = logits.size(-1)

            # Score only the masked positions (labels != -100).
            mask = (lbl != -100)
            n_masked = mask.sum().item()
            if n_masked == 0:
                continue

            # 1. PPL: accumulate the summed NLL; divide once at the end.
            total_nll += criterion(logits.view(-1, vocab), lbl.view(-1)).item()
            total_mask_count += n_masked

            # 2. Accuracy over masked positions.
            m_log = logits[mask]
            m_lbl = lbl[mask]
            correct_1 += (m_log.argmax(-1) == m_lbl).sum().item()
            k = min(5, vocab)  # guard against tiny vocabularies
            _, topk = m_log.topk(k, -1)
            correct_5 += (topk == m_lbl.unsqueeze(1)).any(1).sum().item()

    if total_mask_count == 0:
        return 0, 0, 0

    # Final aggregation.
    ppl = math.exp(total_nll / total_mask_count)
    acc1 = (correct_1 / total_mask_count) * 100
    acc5 = (correct_5 / total_mask_count) * 100
    return ppl, acc1, acc5
def audit_efficiency(model, name):
    """Audit parameter efficiency: total size, a weight-tying heuristic, and the
    share of parameters implementing token mixing (attention / spectral filters)."""
    total = sum(p.numel() for p in model.parameters())
    # Tied-embedding variants of these models land in the 30-40M range.
    is_tied = total < 45000000

    def attn_param(n):
        # Attention projection weights, excluding their norm layers.
        return "attn" in n and "norm" not in n

    # Per-model predicate over parameter names: is this a mixing parameter?
    mixing_rules = {
        "Baseline": attn_param,
        "FNet": lambda n: "transformer_cap" in n and attn_param(n),
        "PRISM": lambda n: ("neural_filter" in n or "gate_proj" in n)
                           or ("refiner" in n and attn_param(n)),
        "PILLARS": lambda n: ((("neural_filter" in n or "gate_proj" in n) and "stream_phase" in n)
                              or "fusion_proj" in n
                              or ("refiner" in n and attn_param(n))),
        "PILLARS-DAT": lambda n: (("stream_sensory" in n and attn_param(n))
                                  or ("stream_relational" in n and ("neural_filter" in n or "gate_proj" in n))
                                  or ("refiner" in n and attn_param(n))),
    }
    rule = mixing_rules.get(name)
    active_mix = sum(p.numel() for n, p in model.named_parameters() if rule and rule(n))

    return {
        "Model": name,
        "Weights Tied?": "✅ YES" if is_tied else "❌ NO",
        "Total Params (M)": total / 1e6,
        "Active Mixing (M)": active_mix / 1e6,
        "Mixing Ratio": f"{(active_mix/total)*100:.1f}%"
    }
# ==========================================
# 4. EXECUTION
# ==========================================
# (display name, HF repo, local architecture class)
MODELS = [
    ("Baseline", "prism-lab/baseline-wikitext-prism", LocalBaseline),
    ("FNet", "prism-lab/hybrid-fnet-prism-custom", HybridFNetMLM),
    ("PRISM", "prism-lab/prism-v2-wikitext", LocalPRISM),
    ("HSSM", "prism-lab/pillars-compact-wikitext", LocalPillars),
    ("WPT", "prism-lab/pillars-dat-wikitext-32k", LocalPillarsDAT)
]

print("📦 Loading Data...")
d = load_dataset(DATASET_ID)

# Split priority: test > validation > train (fallback).
split_name = next((s for s in ("test", "validation") if s in d), "train")
ds = d[split_name]

print(f"📊 Evaluating on split: {getattr(ds, 'split', 'Unknown')}")
ds.set_format("torch", columns=["input_ids", "labels"])
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

perf_results = []
param_results = []

print("\n🔥 STARTING FULL BENCHMARK")

for name, repo, cls in MODELS:
    print(f"\n🧪 Processing {name}...")
    try:
        model = smart_load(cls().to(DEVICE), repo, name)

        # 1. Performance metrics on the held-out split.
        ppl, top1, top5 = evaluate_full(model, loader)
        perf_results.append({"Model": name, "PPL": ppl, "Top-1": top1, "Top-5": top5})

        # 2. Parameter/efficiency audit.
        param_results.append(audit_efficiency(model, name))

        # Free GPU memory before the next model.
        del model
        torch.cuda.empty_cache()
        gc.collect()
    except Exception as e:
        # Keep benchmarking the remaining models even if one fails.
        print(f"❌ {name} Failed: {e}")

print("\n\n" + "=" * 80)
print("🏆 TABLE 1: PERFORMANCE BENCHMARK")
print("=" * 80)
print(pd.DataFrame(perf_results).sort_values("PPL").to_markdown(index=False, floatfmt=".4f"))

print("\n\n" + "=" * 80)
print("🏆 TABLE 2: COMPUTATIONAL EFFICIENCY")
print("=" * 80)
print(pd.DataFrame(param_results).to_markdown(index=False, floatfmt=".2f"))
404
+ ]
405
+ }
406
+ ]
407
+ }