File size: 22,438 Bytes
5fe9601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "!pip install -q x-transformers"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TWiErEkm1YNU",
        "outputId": "1dd7de09-712e-4f5a-f74d-9c48f7702dd9"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m97.8/97.8 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.6/101.6 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.0/103.0 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.6/61.6 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XfhKiI_Z1Q6F"
      },
      "outputs": [],
      "source": [
        "# @title πŸ› οΈ Appendix Physical Validation (Gain & Stability)\n",
        "import torch\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from huggingface_hub import hf_hub_download\n",
        "from transformers import AutoTokenizer\n",
        "import sys\n",
        "import os\n",
        "\n",
        "# ==============================================================================\n",
        "# 1. SETUP & MODEL LOADING\n",
        "# ==============================================================================\n",
        "REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "print(f\"βš™οΈ Hardware: {DEVICE}\")\n",
        "print(f\"πŸ“₯ Loading PRISM from {REPO_ID}...\")\n",
        "\n",
        "# Download architecture\n",
        "os.makedirs(\"shimmer_code\", exist_ok=True)\n",
        "hf_hub_download(repo_id=REPO_ID, filename=\"modeling_prism_gated.py\", local_dir=\"shimmer_code\")\n",
        "sys.path.append(\"shimmer_code\")\n",
        "\n",
        "from modeling_prism_gated import PRISMHybrid_RoPE\n",
        "\n",
        "# Load Model\n",
        "tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
        "CONFIG = {\n",
        "    \"vocab_size\": 58101, \"d_model\": 512, \"num_heads\": 8, \"dff\": 2048,\n",
        "    \"dropout\": 0.1, \"max_length\": 128, \"num_encoder_layers\": 6,\n",
        "    \"num_refining_layers\": 0, \"num_decoder_layers\": 6\n",
        "}\n",
        "model = PRISMHybrid_RoPE(**CONFIG)\n",
        "state_dict = torch.load(hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\"), map_location=DEVICE)\n",
        "model.load_state_dict(state_dict)\n",
        "model.to(DEVICE)\n",
        "model.eval()\n",
        "\n",
        "print(\"βœ… Model Ready.\")\n",
        "\n",
        "# ==============================================================================\n",
        "# 2. DATASETS (Placeholders)\n",
        "# ==============================================================================\n",
        "# ⚠️ PASTE YOUR FULL LISTS HERE FROM THE PREVIOUS STEP\n",
        "# N=76 Hard, N=70 Easy\n",
        "\n",
        "raw_poly_candidates = [\n",
        "    # --- ORIGINAL SET ---\n",
        "    (\"Ich gehe zur Bank um Geld zu holen\", \"Bank\"), (\"Die Bank hat hohe Zinsen\", \"Bank\"),\n",
        "    (\"Wir saßen auf einer Bank im Park\", \"Bank\"), (\"Die Bank aus Holz war bequem\", \"Bank\"),\n",
        "    (\"Das Schloss hat viele TΓΌrme\", \"Schloss\"), (\"Der KΓΆnig wohnt im Schloss\", \"Schloss\"),\n",
        "    (\"Der SchlΓΌssel steckt im Schloss\", \"Schloss\"), (\"Das Schloss an der TΓΌr klemmt\", \"Schloss\"),\n",
        "    (\"Der Leiter der Firma ist streng\", \"Leiter\"), (\"Unser Leiter plant das Projekt\", \"Leiter\"),\n",
        "    (\"Ich steige auf die Leiter\", \"Leiter\"), (\"Die Leiter ist aus Aluminium\", \"Leiter\"),\n",
        "    (\"Die Lampe hÀngt an der Decke\", \"Decke\"), (\"Die Decke ist weiß gestrichen\", \"Decke\"),\n",
        "    (\"Mir ist kalt gib mir eine Decke\", \"Decke\"), (\"Die Decke aus Wolle ist warm\", \"Decke\"),\n",
        "    (\"Der Kiefer ist ein Nadelbaum\", \"Kiefer\"), (\"Das Holz der Kiefer ist weich\", \"Kiefer\"),\n",
        "    (\"Der Arzt rΓΆntgt meinen Kiefer\", \"Kiefer\"), (\"Er hat Schmerzen im Kiefer\", \"Kiefer\"),\n",
        "    (\"Der Strauß ist ein schneller Vogel\", \"Strauß\"), (\"Dieser Strauß kann nicht fliegen\", \"Strauß\"),\n",
        "    (\"Sie kaufte einen bunten Strauß\", \"Strauß\"), (\"Der Strauß Blumen duftet gut\", \"Strauß\"),\n",
        "    (\"Er schoss ein schΓΆnes Tor\", \"Tor\"), (\"Der Ball flog ins Tor\", \"Tor\"),\n",
        "    (\"Das eiserne Tor war verschlossen\", \"Tor\"), (\"Sie âffneten das große Tor\", \"Tor\"),\n",
        "    (\"Wir tanzen auf dem Ball\", \"Ball\"), (\"Der Maskenball war elegant\", \"Ball\"),\n",
        "    (\"Er warf den Ball weit weg\", \"Ball\"), (\"Der Ball ist rund und rot\", \"Ball\"),\n",
        "    (\"Die Schlange im Zoo ist giftig\", \"Schlange\"), (\"Die Schlange zischte laut\", \"Schlange\"),\n",
        "    (\"Wir stehen in einer langen Schlange\", \"Schlange\"), (\"Die Schlange an der Kasse war lang\", \"Schlange\"),\n",
        "    (\"Der Strom ist ausgefallen\", \"Strom\"), (\"Strom kostet viel Geld\", \"Strom\"),\n",
        "    (\"Der Strom fließt ins Meer\", \"Strom\"), (\"Wir schwammen gegen den Strom\", \"Strom\"),\n",
        "    (\"Seine Mutter ist sehr nett\", \"Mutter\"), (\"Die Mutter kocht das Essen\", \"Mutter\"),\n",
        "    (\"Die Mutter passt auf die Schraube\", \"Mutter\"), (\"Ich brauche eine neue Mutter\", \"Mutter\"),\n",
        "    (\"Die Birne schmeckt süß\", \"Birne\"), (\"Ich esse gerne eine Birne\", \"Birne\"),\n",
        "    (\"Die Birne in der Lampe ist kaputt\", \"Birne\"), (\"Wir mΓΌssen die Birne wechseln\", \"Birne\"),\n",
        "    # --- EXPANSION SET ---\n",
        "    (\"Das Gericht hat ihn verurteilt\", \"Gericht\"), (\"Der Anwalt geht zum Gericht\", \"Gericht\"),\n",
        "    (\"Mein Lieblingsessen ist ein Gericht aus Reis\", \"Gericht\"), (\"Das Gericht schmeckt sehr salzig\", \"Gericht\"),\n",
        "    (\"Der Ton war sehr laut\", \"Ton\"), (\"Ich hΓΆrte einen hohen Ton\", \"Ton\"),\n",
        "    (\"Die Vase ist aus Ton\", \"Ton\"), (\"Wir formen Figuren aus Ton\", \"Ton\"),\n",
        "    (\"Das Blatt fΓ€llt vom Baum\", \"Blatt\"), (\"Im Herbst werden die BlΓ€tter braun\", \"Blatt\"),\n",
        "    (\"Ich schreibe auf ein Blatt Papier\", \"Blatt\"), (\"Gib mir bitte ein leeres Blatt\", \"Blatt\"),\n",
        "    (\"Der Nagel steckt in der Wand\", \"Nagel\"), (\"Ich schlage den Nagel mit dem Hammer\", \"Nagel\"),\n",
        "    (\"Mein Nagel ist abgebrochen\", \"Nagel\"), (\"Sie lackiert sich den Nagel rot\", \"Nagel\"),\n",
        "    (\"Die Maus frisst den KΓ€se\", \"Maus\"), (\"Die Katze jagt die Maus\", \"Maus\"),\n",
        "    (\"Ich klicke mit der Maus\", \"Maus\"), (\"Der Computer braucht eine neue Maus\", \"Maus\"),\n",
        "    (\"Die Erde dreht sich um die Sonne\", \"Erde\"), (\"Der Astronaut schaut auf die Erde\", \"Erde\"),\n",
        "    (\"Die Blume braucht frische Erde\", \"Erde\"), (\"Er grΓ€bt ein Loch in die Erde\", \"Erde\"),\n",
        "    (\"Der Hahn krΓ€ht am Morgen\", \"Hahn\"), (\"Der Hahn hat bunte Federn\", \"Hahn\"),\n",
        "    (\"Der Wasserhahn tropft\", \"Hahn\"), (\"Dreh bitte den Hahn zu\", \"Hahn\"),\n",
        "    (\"Die Schale der Orange ist bitter\", \"Schale\"), (\"Er wirft die Schale weg\", \"Schale\"),\n",
        "    (\"Die Schale steht auf dem Tisch\", \"Schale\"), (\"Ich esse MΓΌsli aus der Schale\", \"Schale\"),\n",
        "    (\"Der Bauer melkt die KΓΌhe\", \"Bauer\"), (\"Der Bauer fΓ€hrt auf dem Traktor\", \"Bauer\"),\n",
        "    (\"Ich ziehe den Bauer auf E4\", \"Bauer\"), (\"Der Bauer schlΓ€gt den Turm\", \"Bauer\"),\n",
        "]\n",
        "\n",
        "# B. EASY MODE (Casual)\n",
        "raw_casual_candidates = [\n",
        "    (\"Die Katze schlΓ€ft\", \"Katze\"), (\"Der Hund bellt\", \"Hund\"), (\"Das Auto fΓ€hrt\", \"Auto\"),\n",
        "    (\"Wasser ist nass\", \"Wasser\"), (\"Das Brot schmeckt gut\", \"Brot\"), (\"Die Sonne scheint\", \"Sonne\"),\n",
        "    (\"Der Mond leuchtet\", \"Mond\"), (\"Das Buch ist spannend\", \"Buch\"), (\"Der Tisch ist rund\", \"Tisch\"),\n",
        "    (\"Der Stuhl ist bequem\", \"Stuhl\"), (\"Der Apfel ist rot\", \"Apfel\"), (\"Meine Hand ist kalt\", \"Hand\"),\n",
        "    (\"Das Herz klopft\", \"Herz\"), (\"Wir haben Zeit\", \"Zeit\"), (\"Geld ist wichtig\", \"Geld\"),\n",
        "    (\"Musik ist schΓΆn\", \"Musik\"), (\"Der Film ist zu Ende\", \"Film\"), (\"Das Spiel beginnt\", \"Spiel\"),\n",
        "    (\"Die Schule ist aus\", \"Schule\"), (\"Die Stadt ist laut\", \"Stadt\"), (\"Der Fluss fließt\", \"Fluss\"),\n",
        "    (\"Das Meer ist tief\", \"Meer\"), (\"Kaffee ist schwarz\", \"Kaffee\"), (\"Milch ist weiß\", \"Milch\"),\n",
        "    (\"Der Bruder lacht\", \"Bruder\"), (\"Die Schwester weint\", \"Schwester\"), (\"Das Haus ist groß\", \"Haus\"),\n",
        "    (\"Der Garten ist grün\", \"Garten\"), (\"Der Sommer ist heiß\", \"Sommer\"), (\"Der Winter ist kalt\", \"Winter\"),\n",
        "    (\"Das Fenster ist offen\", \"Fenster\"), (\"Die TΓΌr ist zu\", \"TΓΌr\"), (\"Der Boden ist sauber\", \"Boden\"),\n",
        "    (\"Die Wand ist weiß\", \"Wand\"), (\"Das Dach ist rot\", \"Dach\"), (\"Der Wald ist dunkel\", \"Wald\"),\n",
        "    (\"Der Berg ist hoch\", \"Berg\"), (\"Der See ist ruhig\", \"See\"), (\"Das Tier ist wild\", \"Tier\"),\n",
        "    (\"Der Mensch denkt\", \"Mensch\"), (\"Das Kind spielt\", \"Kind\"), (\"Die Frau arbeitet\", \"Frau\"),\n",
        "    (\"Der Mann schlΓ€ft\", \"Mann\"), (\"Das Auge sieht\", \"Auge\"), (\"Das Ohr hΓΆrt\", \"Ohr\"),\n",
        "    (\"Die Nase riecht\", \"Nase\"), (\"Der Mund spricht\", \"Mund\"), (\"Der Arm ist stark\", \"Arm\"),\n",
        "    (\"Das Bein tut weh\", \"Bein\"), (\"Der Fuß ist groß\", \"Fuß\"), (\"Der Tee ist heiß\", \"Tee\"),\n",
        "    (\"Das Bier ist kalt\", \"Bier\"), (\"Der Wein ist rot\", \"Wein\"), (\"Das Glas ist voll\", \"Glas\"),\n",
        "    (\"Die Tasse ist leer\", \"Tasse\"), (\"Der Teller ist blau\", \"Teller\"), (\"Die Gabel ist spitz\", \"Gabel\"),\n",
        "    (\"Der LΓΆffel ist rund\", \"LΓΆffel\"), (\"Das Messer ist scharf\", \"Messer\"), (\"Der Stift schreibt\", \"Stift\"),\n",
        "    (\"Der Brief ist lang\", \"Brief\"), (\"Das Bild ist schΓΆn\", \"Bild\"), (\"Die Uhr tickt\", \"Uhr\"),\n",
        "    (\"Das Bett ist weich\", \"Bett\"), (\"Der Schrank ist voll\", \"Schrank\"), (\"Das Sofa ist neu\", \"Sofa\"),\n",
        "    (\"Das Radio spielt\", \"Radio\"), (\"Das Jahr ist um\", \"Jahr\"), (\"Der Tag war lang\", \"Tag\"),\n",
        "    (\"Die Nacht ist kurz\", \"Nacht\")\n",
        "]\n",
        "\n",
        "# ==============================================================================\n",
        "# 3. HELPER: Single-Token Validator\n",
        "# ==============================================================================\n",
        "def filter_dataset(candidates, tokenizer, label):\n",
        "    valid = []\n",
        "    for ctx, tgt in candidates:\n",
        "        t1 = tokenizer.encode(tgt, add_special_tokens=False)\n",
        "        t2 = tokenizer.encode(\" \" + tgt, add_special_tokens=False)\n",
        "        if len(t1) == 1 or len(t2) == 1: valid.append((ctx, tgt))\n",
        "    print(f\"βœ… {label}: {len(valid)} atomic examples validated.\")\n",
        "    return valid\n",
        "\n",
        "def find_token_index(input_ids, target_word, tokenizer):\n",
        "    tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
        "    for i, t in enumerate(tokens):\n",
        "        clean = t.replace('Δ ', '').replace('▁', '').replace(' ', '')\n",
        "        if target_word.lower() == clean.lower(): return i\n",
        "    for i, t in enumerate(tokens): # Fallback\n",
        "        clean = t.replace('Δ ', '').replace('▁', '').replace(' ', '')\n",
        "        if target_word.lower() in clean.lower(): return i\n",
        "    return 1\n",
        "\n",
        "# ==============================================================================\n",
        "# 4. PHYSICAL PROBE (Gain & Magnitude)\n",
        "# ==============================================================================\n",
        "def run_physical_probe(model, tokenizer, dataset, label, device):\n",
        "    \"\"\"\n",
        "    Extracts Gain (Ratio) and Raw Magnitude (Norm) for CV analysis.\n",
        "    \"\"\"\n",
        "    num_layers = len(model.prism_encoder.layers)\n",
        "\n",
        "    # Store Gain (for Fig B3) and Magnitude (for Fig B1)\n",
        "    gain_stats = {i: [] for i in range(num_layers)}\n",
        "    magnitude_stats = {i: [] for i in range(num_layers)}\n",
        "    embedding_mags = []\n",
        "\n",
        "    hook_data = {}\n",
        "\n",
        "    def physics_hook(layer_idx):\n",
        "        def hook(module, input, output):\n",
        "            x, y = input[0].detach(), output.detach()\n",
        "\n",
        "            # 1. Norms (Energy)\n",
        "            norm_x = torch.norm(x, p=2, dim=-1)\n",
        "            norm_y = torch.norm(y, p=2, dim=-1)\n",
        "\n",
        "            # 2. Gain Calculation\n",
        "            gain = norm_y / (norm_x + 1e-9)\n",
        "\n",
        "            hook_data[f'layer_{layer_idx}'] = {\n",
        "                'gain': gain.cpu(),\n",
        "                'mag': norm_y.cpu() # Output magnitude\n",
        "            }\n",
        "        return hook\n",
        "\n",
        "    # Register Hooks\n",
        "    model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
        "    for i, layer in enumerate(model.prism_encoder.layers):\n",
        "        layer.register_forward_hook(physics_hook(i))\n",
        "\n",
        "    # Run Probe\n",
        "    print(f\"πŸ”¬ Measuring Physics on {len(dataset)} {label} examples...\")\n",
        "    for context, target in dataset:\n",
        "        hook_data = {}\n",
        "        inputs = tokenizer(context, return_tensors=\"pt\").to(device)\n",
        "\n",
        "        with torch.no_grad():\n",
        "            # Capture embedding magnitude before encoder\n",
        "            emb = model.harmonic_embedding(inputs.input_ids)\n",
        "            embedding_mags.append(torch.norm(emb, p=2, dim=-1).flatten().cpu())\n",
        "\n",
        "            # Forward pass\n",
        "            src_mask = (inputs.input_ids == tokenizer.pad_token_id)\n",
        "            model.prism_encoder(emb, src_mask)\n",
        "\n",
        "        idx = find_token_index(inputs.input_ids[0], target, tokenizer)\n",
        "\n",
        "        for i in range(num_layers):\n",
        "            if f'layer_{i}' in hook_data:\n",
        "                data = hook_data[f'layer_{i}']\n",
        "\n",
        "                # Extract atomic token metrics\n",
        "                g = data['gain']\n",
        "                m = data['mag']\n",
        "\n",
        "                val_g = g[0, idx].item() if g.dim() > 1 else g[idx].item()\n",
        "                val_m = m[0, idx].item() if m.dim() > 1 else m[idx].item()\n",
        "\n",
        "                gain_stats[i].append(val_g)\n",
        "                magnitude_stats[i].append(val_m)\n",
        "\n",
        "    model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
        "\n",
        "    return {\n",
        "        'gain': pd.DataFrame(gain_stats),\n",
        "        'magnitude': magnitude_stats, # Dict of lists\n",
        "        'embedding': torch.cat(embedding_mags).numpy()\n",
        "    }\n",
        "\n",
        "# ==============================================================================\n",
        "# 5. EXECUTION\n",
        "# ==============================================================================\n",
        "# Filter\n",
        "ds_hard = filter_dataset(raw_poly_candidates, tokenizer, \"HARD\")\n",
        "ds_easy = filter_dataset(raw_casual_candidates, tokenizer, \"EASY\")\n",
        "\n",
        "# Run\n",
        "res_hard = run_physical_probe(model, tokenizer, ds_hard, \"HARD\", DEVICE)\n",
        "res_easy = run_physical_probe(model, tokenizer, ds_easy, \"EASY\", DEVICE)\n",
        "\n",
        "# ==============================================================================\n",
        "# 6. PLOT FIGURE B3: ISO-ENERGETIC GAIN\n",
        "# ==============================================================================\n",
        "def plot_gain_chart(res_hard, res_easy):\n",
        "    df_h = res_hard['gain']\n",
        "    df_e = res_easy['gain']\n",
        "\n",
        "    layers = list(df_h.columns)\n",
        "    means_h = [df_h[i].mean() for i in layers]\n",
        "    stds_h = [df_h[i].std() for i in layers]\n",
        "    means_e = [df_e[i].mean() for i in layers]\n",
        "    stds_e = [df_e[i].std() for i in layers]\n",
        "\n",
        "    x = np.arange(len(layers))\n",
        "    width = 0.35\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(8, 4), dpi=300)\n",
        "    ax.bar(x - width/2, means_h, width, yerr=stds_h, label='Ambiguous',\n",
        "           color='indianred', alpha=0.8, capsize=3)\n",
        "    ax.bar(x + width/2, means_e, width, yerr=stds_e, label='Unambiguous',\n",
        "           color='steelblue', alpha=0.8, capsize=3)\n",
        "\n",
        "    ax.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Unity Gain (g=1.0)')\n",
        "    ax.set_ylabel('Signal Gain (||y|| / ||x||)', fontweight='bold')\n",
        "    ax.set_xlabel('Layer Depth')\n",
        "    ax.set_xticks(x)\n",
        "    ax.set_xticklabels(layers)\n",
        "    ax.set_ylim(0.85, 1.15) # Zoom in to show it's flat\n",
        "    ax.legend(loc='upper right')\n",
        "    ax.set_title('Iso-Energetic Constraint: Gain β‰ˆ 1.0 Across All Conditions', fontweight='bold')\n",
        "    ax.grid(axis='y', linestyle='--', alpha=0.3)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(\"fig_B3_gain.png\")\n",
        "    plt.show()\n",
        "    print(\"βœ… Figure B3 Saved.\")\n",
        "\n",
        "# ==============================================================================\n",
        "# 7. PLOT FIGURE B1: MAGNITUDE STABILITY (CV)\n",
        "# ==============================================================================\n",
        "def plot_cv_chart(res_hard, res_easy):\n",
        "    # Combine data to check global network stability\n",
        "    # CV = sigma / mu\n",
        "\n",
        "    stages = [\"Embedding\"]\n",
        "    cvs = []\n",
        "\n",
        "    # 1. Embedding Stage\n",
        "    all_emb = np.concatenate([res_hard['embedding'], res_easy['embedding']])\n",
        "    cvs.append(all_emb.std() / all_emb.mean())\n",
        "\n",
        "    # 2. Layers 0-5\n",
        "    for i in range(6):\n",
        "        # Flatten lists\n",
        "        mags_h = np.array(res_hard['magnitude'][i])\n",
        "        mags_e = np.array(res_easy['magnitude'][i])\n",
        "        all_mags = np.concatenate([mags_h, mags_e])\n",
        "\n",
        "        cv = all_mags.std() / (all_mags.mean() + 1e-9)\n",
        "        cvs.append(cv)\n",
        "        stages.append(f\"Layer {i}\")\n",
        "\n",
        "    mean_cv = np.mean(cvs)\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(8, 4), dpi=300)\n",
        "    bars = ax.bar(stages, cvs, color='steelblue', alpha=0.8, edgecolor='grey')\n",
        "\n",
        "    ax.axhline(y=mean_cv, color='red', linestyle='--', label=f'Mean CV = {mean_cv:.3f}')\n",
        "    ax.set_ylabel('Coefficient of Variation (Οƒ/ΞΌ)', fontweight='bold')\n",
        "    ax.set_xlabel('Network Stage')\n",
        "    ax.set_title('Magnitude Stability Across Layers (Iso-Energetic Check)', fontweight='bold')\n",
        "    ax.set_ylim(0, 1.0)\n",
        "    ax.legend()\n",
        "\n",
        "    # Label bars\n",
        "    for bar, v in zip(bars, cvs):\n",
        "        ax.text(bar.get_x() + bar.get_width()/2, v, f\"{v:.3f}\",\n",
        "                ha='center', va='bottom', fontsize=9)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.savefig(\"fig_B1_cv.png\")\n",
        "    plt.show()\n",
        "    print(\"βœ… Figure B1 Saved.\")\n",
        "\n",
        "# ==============================================================================\n",
        "# RUN PLOTS\n",
        "# ==============================================================================\n",
        "plot_gain_chart(res_hard, res_easy)\n",
        "plot_cv_chart(res_hard, res_easy)"
      ]
    }
  ]
}