File size: 22,438 Bytes
5fe9601 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip install -q x-transformers"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TWiErEkm1YNU",
"outputId": "1dd7de09-712e-4f5a-f74d-9c48f7702dd9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m97.8/97.8 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m101.6/101.6 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m103.0/103.0 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m61.6/61.6 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XfhKiI_Z1Q6F"
},
"outputs": [],
"source": [
"# @title π οΈ Appendix Physical Validation (Gain & Stability)\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from huggingface_hub import hf_hub_download\n",
"from transformers import AutoTokenizer\n",
"import sys\n",
"import os\n",
"\n",
"# ==============================================================================\n",
"# 1. SETUP & MODEL LOADING\n",
"# ==============================================================================\n",
"REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(f\"βοΈ Hardware: {DEVICE}\")\n",
"print(f\"π₯ Loading PRISM from {REPO_ID}...\")\n",
"\n",
"# Download architecture\n",
"os.makedirs(\"shimmer_code\", exist_ok=True)\n",
"hf_hub_download(repo_id=REPO_ID, filename=\"modeling_prism_gated.py\", local_dir=\"shimmer_code\")\n",
"sys.path.append(\"shimmer_code\")\n",
"\n",
"from modeling_prism_gated import PRISMHybrid_RoPE\n",
"\n",
"# Load Model\n",
"tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
"CONFIG = {\n",
" \"vocab_size\": 58101, \"d_model\": 512, \"num_heads\": 8, \"dff\": 2048,\n",
" \"dropout\": 0.1, \"max_length\": 128, \"num_encoder_layers\": 6,\n",
" \"num_refining_layers\": 0, \"num_decoder_layers\": 6\n",
"}\n",
"model = PRISMHybrid_RoPE(**CONFIG)\n",
"state_dict = torch.load(hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\"), map_location=DEVICE)\n",
"model.load_state_dict(state_dict)\n",
"model.to(DEVICE)\n",
"model.eval()\n",
"\n",
"print(\"β
Model Ready.\")\n",
"\n",
"# ==============================================================================\n",
"# 2. DATASETS (Placeholders)\n",
"# ==============================================================================\n",
"# β οΈ PASTE YOUR FULL LISTS HERE FROM THE PREVIOUS STEP\n",
"# N=76 Hard, N=70 Easy\n",
"\n",
"raw_poly_candidates = [\n",
" # --- ORIGINAL SET ---\n",
" (\"Ich gehe zur Bank um Geld zu holen\", \"Bank\"), (\"Die Bank hat hohe Zinsen\", \"Bank\"),\n",
" (\"Wir saΓen auf einer Bank im Park\", \"Bank\"), (\"Die Bank aus Holz war bequem\", \"Bank\"),\n",
" (\"Das Schloss hat viele TΓΌrme\", \"Schloss\"), (\"Der KΓΆnig wohnt im Schloss\", \"Schloss\"),\n",
" (\"Der SchlΓΌssel steckt im Schloss\", \"Schloss\"), (\"Das Schloss an der TΓΌr klemmt\", \"Schloss\"),\n",
" (\"Der Leiter der Firma ist streng\", \"Leiter\"), (\"Unser Leiter plant das Projekt\", \"Leiter\"),\n",
" (\"Ich steige auf die Leiter\", \"Leiter\"), (\"Die Leiter ist aus Aluminium\", \"Leiter\"),\n",
" (\"Die Lampe hΓ€ngt an der Decke\", \"Decke\"), (\"Die Decke ist weiΓ gestrichen\", \"Decke\"),\n",
" (\"Mir ist kalt gib mir eine Decke\", \"Decke\"), (\"Die Decke aus Wolle ist warm\", \"Decke\"),\n",
" (\"Der Kiefer ist ein Nadelbaum\", \"Kiefer\"), (\"Das Holz der Kiefer ist weich\", \"Kiefer\"),\n",
" (\"Der Arzt rΓΆntgt meinen Kiefer\", \"Kiefer\"), (\"Er hat Schmerzen im Kiefer\", \"Kiefer\"),\n",
" (\"Der StrauΓ ist ein schneller Vogel\", \"StrauΓ\"), (\"Dieser StrauΓ kann nicht fliegen\", \"StrauΓ\"),\n",
" (\"Sie kaufte einen bunten StrauΓ\", \"StrauΓ\"), (\"Der StrauΓ Blumen duftet gut\", \"StrauΓ\"),\n",
" (\"Er schoss ein schΓΆnes Tor\", \"Tor\"), (\"Der Ball flog ins Tor\", \"Tor\"),\n",
" (\"Das eiserne Tor war verschlossen\", \"Tor\"), (\"Sie ΓΆffneten das groΓe Tor\", \"Tor\"),\n",
" (\"Wir tanzen auf dem Ball\", \"Ball\"), (\"Der Maskenball war elegant\", \"Ball\"),\n",
" (\"Er warf den Ball weit weg\", \"Ball\"), (\"Der Ball ist rund und rot\", \"Ball\"),\n",
" (\"Die Schlange im Zoo ist giftig\", \"Schlange\"), (\"Die Schlange zischte laut\", \"Schlange\"),\n",
" (\"Wir stehen in einer langen Schlange\", \"Schlange\"), (\"Die Schlange an der Kasse war lang\", \"Schlange\"),\n",
" (\"Der Strom ist ausgefallen\", \"Strom\"), (\"Strom kostet viel Geld\", \"Strom\"),\n",
" (\"Der Strom flieΓt ins Meer\", \"Strom\"), (\"Wir schwammen gegen den Strom\", \"Strom\"),\n",
" (\"Seine Mutter ist sehr nett\", \"Mutter\"), (\"Die Mutter kocht das Essen\", \"Mutter\"),\n",
" (\"Die Mutter passt auf die Schraube\", \"Mutter\"), (\"Ich brauche eine neue Mutter\", \"Mutter\"),\n",
" (\"Die Birne schmeckt sΓΌΓ\", \"Birne\"), (\"Ich esse gerne eine Birne\", \"Birne\"),\n",
" (\"Die Birne in der Lampe ist kaputt\", \"Birne\"), (\"Wir mΓΌssen die Birne wechseln\", \"Birne\"),\n",
" # --- EXPANSION SET ---\n",
" (\"Das Gericht hat ihn verurteilt\", \"Gericht\"), (\"Der Anwalt geht zum Gericht\", \"Gericht\"),\n",
" (\"Mein Lieblingsessen ist ein Gericht aus Reis\", \"Gericht\"), (\"Das Gericht schmeckt sehr salzig\", \"Gericht\"),\n",
" (\"Der Ton war sehr laut\", \"Ton\"), (\"Ich hΓΆrte einen hohen Ton\", \"Ton\"),\n",
" (\"Die Vase ist aus Ton\", \"Ton\"), (\"Wir formen Figuren aus Ton\", \"Ton\"),\n",
" (\"Das Blatt fΓ€llt vom Baum\", \"Blatt\"), (\"Im Herbst werden die BlΓ€tter braun\", \"Blatt\"),\n",
" (\"Ich schreibe auf ein Blatt Papier\", \"Blatt\"), (\"Gib mir bitte ein leeres Blatt\", \"Blatt\"),\n",
" (\"Der Nagel steckt in der Wand\", \"Nagel\"), (\"Ich schlage den Nagel mit dem Hammer\", \"Nagel\"),\n",
" (\"Mein Nagel ist abgebrochen\", \"Nagel\"), (\"Sie lackiert sich den Nagel rot\", \"Nagel\"),\n",
" (\"Die Maus frisst den KΓ€se\", \"Maus\"), (\"Die Katze jagt die Maus\", \"Maus\"),\n",
" (\"Ich klicke mit der Maus\", \"Maus\"), (\"Der Computer braucht eine neue Maus\", \"Maus\"),\n",
" (\"Die Erde dreht sich um die Sonne\", \"Erde\"), (\"Der Astronaut schaut auf die Erde\", \"Erde\"),\n",
" (\"Die Blume braucht frische Erde\", \"Erde\"), (\"Er grΓ€bt ein Loch in die Erde\", \"Erde\"),\n",
" (\"Der Hahn krΓ€ht am Morgen\", \"Hahn\"), (\"Der Hahn hat bunte Federn\", \"Hahn\"),\n",
" (\"Der Wasserhahn tropft\", \"Hahn\"), (\"Dreh bitte den Hahn zu\", \"Hahn\"),\n",
" (\"Die Schale der Orange ist bitter\", \"Schale\"), (\"Er wirft die Schale weg\", \"Schale\"),\n",
" (\"Die Schale steht auf dem Tisch\", \"Schale\"), (\"Ich esse MΓΌsli aus der Schale\", \"Schale\"),\n",
" (\"Der Bauer melkt die KΓΌhe\", \"Bauer\"), (\"Der Bauer fΓ€hrt auf dem Traktor\", \"Bauer\"),\n",
" (\"Ich ziehe den Bauer auf E4\", \"Bauer\"), (\"Der Bauer schlΓ€gt den Turm\", \"Bauer\"),\n",
"]\n",
"\n",
"# B. EASY MODE (Casual)\n",
"raw_casual_candidates = [\n",
" (\"Die Katze schlΓ€ft\", \"Katze\"), (\"Der Hund bellt\", \"Hund\"), (\"Das Auto fΓ€hrt\", \"Auto\"),\n",
" (\"Wasser ist nass\", \"Wasser\"), (\"Das Brot schmeckt gut\", \"Brot\"), (\"Die Sonne scheint\", \"Sonne\"),\n",
" (\"Der Mond leuchtet\", \"Mond\"), (\"Das Buch ist spannend\", \"Buch\"), (\"Der Tisch ist rund\", \"Tisch\"),\n",
" (\"Der Stuhl ist bequem\", \"Stuhl\"), (\"Der Apfel ist rot\", \"Apfel\"), (\"Meine Hand ist kalt\", \"Hand\"),\n",
" (\"Das Herz klopft\", \"Herz\"), (\"Wir haben Zeit\", \"Zeit\"), (\"Geld ist wichtig\", \"Geld\"),\n",
" (\"Musik ist schΓΆn\", \"Musik\"), (\"Der Film ist zu Ende\", \"Film\"), (\"Das Spiel beginnt\", \"Spiel\"),\n",
" (\"Die Schule ist aus\", \"Schule\"), (\"Die Stadt ist laut\", \"Stadt\"), (\"Der Fluss flieΓt\", \"Fluss\"),\n",
" (\"Das Meer ist tief\", \"Meer\"), (\"Kaffee ist schwarz\", \"Kaffee\"), (\"Milch ist weiΓ\", \"Milch\"),\n",
" (\"Der Bruder lacht\", \"Bruder\"), (\"Die Schwester weint\", \"Schwester\"), (\"Das Haus ist groΓ\", \"Haus\"),\n",
" (\"Der Garten ist grΓΌn\", \"Garten\"), (\"Der Sommer ist heiΓ\", \"Sommer\"), (\"Der Winter ist kalt\", \"Winter\"),\n",
" (\"Das Fenster ist offen\", \"Fenster\"), (\"Die TΓΌr ist zu\", \"TΓΌr\"), (\"Der Boden ist sauber\", \"Boden\"),\n",
" (\"Die Wand ist weiΓ\", \"Wand\"), (\"Das Dach ist rot\", \"Dach\"), (\"Der Wald ist dunkel\", \"Wald\"),\n",
" (\"Der Berg ist hoch\", \"Berg\"), (\"Der See ist ruhig\", \"See\"), (\"Das Tier ist wild\", \"Tier\"),\n",
" (\"Der Mensch denkt\", \"Mensch\"), (\"Das Kind spielt\", \"Kind\"), (\"Die Frau arbeitet\", \"Frau\"),\n",
" (\"Der Mann schlΓ€ft\", \"Mann\"), (\"Das Auge sieht\", \"Auge\"), (\"Das Ohr hΓΆrt\", \"Ohr\"),\n",
" (\"Die Nase riecht\", \"Nase\"), (\"Der Mund spricht\", \"Mund\"), (\"Der Arm ist stark\", \"Arm\"),\n",
" (\"Das Bein tut weh\", \"Bein\"), (\"Der FuΓ ist groΓ\", \"FuΓ\"), (\"Der Tee ist heiΓ\", \"Tee\"),\n",
" (\"Das Bier ist kalt\", \"Bier\"), (\"Der Wein ist rot\", \"Wein\"), (\"Das Glas ist voll\", \"Glas\"),\n",
" (\"Die Tasse ist leer\", \"Tasse\"), (\"Der Teller ist blau\", \"Teller\"), (\"Die Gabel ist spitz\", \"Gabel\"),\n",
" (\"Der LΓΆffel ist rund\", \"LΓΆffel\"), (\"Das Messer ist scharf\", \"Messer\"), (\"Der Stift schreibt\", \"Stift\"),\n",
" (\"Der Brief ist lang\", \"Brief\"), (\"Das Bild ist schΓΆn\", \"Bild\"), (\"Die Uhr tickt\", \"Uhr\"),\n",
" (\"Das Bett ist weich\", \"Bett\"), (\"Der Schrank ist voll\", \"Schrank\"), (\"Das Sofa ist neu\", \"Sofa\"),\n",
" (\"Das Radio spielt\", \"Radio\"), (\"Das Jahr ist um\", \"Jahr\"), (\"Der Tag war lang\", \"Tag\"),\n",
" (\"Die Nacht ist kurz\", \"Nacht\")\n",
"]\n",
"\n",
"# ==============================================================================\n",
"# 3. HELPER: Single-Token Validator\n",
"# ==============================================================================\n",
"def filter_dataset(candidates, tokenizer, label):\n",
" valid = []\n",
" for ctx, tgt in candidates:\n",
" t1 = tokenizer.encode(tgt, add_special_tokens=False)\n",
" t2 = tokenizer.encode(\" \" + tgt, add_special_tokens=False)\n",
" if len(t1) == 1 or len(t2) == 1: valid.append((ctx, tgt))\n",
" print(f\"β
{label}: {len(valid)} atomic examples validated.\")\n",
" return valid\n",
"\n",
"def find_token_index(input_ids, target_word, tokenizer):\n",
" tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
" for i, t in enumerate(tokens):\n",
" clean = t.replace('Δ ', '').replace('β', '').replace(' ', '')\n",
" if target_word.lower() == clean.lower(): return i\n",
" for i, t in enumerate(tokens): # Fallback\n",
" clean = t.replace('Δ ', '').replace('β', '').replace(' ', '')\n",
" if target_word.lower() in clean.lower(): return i\n",
" return 1\n",
"\n",
"# ==============================================================================\n",
"# 4. PHYSICAL PROBE (Gain & Magnitude)\n",
"# ==============================================================================\n",
"def run_physical_probe(model, tokenizer, dataset, label, device):\n",
" \"\"\"\n",
" Extracts Gain (Ratio) and Raw Magnitude (Norm) for CV analysis.\n",
" \"\"\"\n",
" num_layers = len(model.prism_encoder.layers)\n",
"\n",
" # Store Gain (for Fig B3) and Magnitude (for Fig B1)\n",
" gain_stats = {i: [] for i in range(num_layers)}\n",
" magnitude_stats = {i: [] for i in range(num_layers)}\n",
" embedding_mags = []\n",
"\n",
" hook_data = {}\n",
"\n",
" def physics_hook(layer_idx):\n",
" def hook(module, input, output):\n",
" x, y = input[0].detach(), output.detach()\n",
"\n",
" # 1. Norms (Energy)\n",
" norm_x = torch.norm(x, p=2, dim=-1)\n",
" norm_y = torch.norm(y, p=2, dim=-1)\n",
"\n",
" # 2. Gain Calculation\n",
" gain = norm_y / (norm_x + 1e-9)\n",
"\n",
" hook_data[f'layer_{layer_idx}'] = {\n",
" 'gain': gain.cpu(),\n",
" 'mag': norm_y.cpu() # Output magnitude\n",
" }\n",
" return hook\n",
"\n",
" # Register Hooks\n",
" model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
" for i, layer in enumerate(model.prism_encoder.layers):\n",
" layer.register_forward_hook(physics_hook(i))\n",
"\n",
" # Run Probe\n",
" print(f\"π¬ Measuring Physics on {len(dataset)} {label} examples...\")\n",
" for context, target in dataset:\n",
" hook_data = {}\n",
" inputs = tokenizer(context, return_tensors=\"pt\").to(device)\n",
"\n",
" with torch.no_grad():\n",
" # Capture embedding magnitude before encoder\n",
" emb = model.harmonic_embedding(inputs.input_ids)\n",
" embedding_mags.append(torch.norm(emb, p=2, dim=-1).flatten().cpu())\n",
"\n",
" # Forward pass\n",
" src_mask = (inputs.input_ids == tokenizer.pad_token_id)\n",
" model.prism_encoder(emb, src_mask)\n",
"\n",
" idx = find_token_index(inputs.input_ids[0], target, tokenizer)\n",
"\n",
" for i in range(num_layers):\n",
" if f'layer_{i}' in hook_data:\n",
" data = hook_data[f'layer_{i}']\n",
"\n",
" # Extract atomic token metrics\n",
" g = data['gain']\n",
" m = data['mag']\n",
"\n",
" val_g = g[0, idx].item() if g.dim() > 1 else g[idx].item()\n",
" val_m = m[0, idx].item() if m.dim() > 1 else m[idx].item()\n",
"\n",
" gain_stats[i].append(val_g)\n",
" magnitude_stats[i].append(val_m)\n",
"\n",
" model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
"\n",
" return {\n",
" 'gain': pd.DataFrame(gain_stats),\n",
" 'magnitude': magnitude_stats, # Dict of lists\n",
" 'embedding': torch.cat(embedding_mags).numpy()\n",
" }\n",
"\n",
"# ==============================================================================\n",
"# 5. EXECUTION\n",
"# ==============================================================================\n",
"# Filter\n",
"ds_hard = filter_dataset(raw_poly_candidates, tokenizer, \"HARD\")\n",
"ds_easy = filter_dataset(raw_casual_candidates, tokenizer, \"EASY\")\n",
"\n",
"# Run\n",
"res_hard = run_physical_probe(model, tokenizer, ds_hard, \"HARD\", DEVICE)\n",
"res_easy = run_physical_probe(model, tokenizer, ds_easy, \"EASY\", DEVICE)\n",
"\n",
"# ==============================================================================\n",
"# 6. PLOT FIGURE B3: ISO-ENERGETIC GAIN\n",
"# ==============================================================================\n",
"def plot_gain_chart(res_hard, res_easy):\n",
" df_h = res_hard['gain']\n",
" df_e = res_easy['gain']\n",
"\n",
" layers = list(df_h.columns)\n",
" means_h = [df_h[i].mean() for i in layers]\n",
" stds_h = [df_h[i].std() for i in layers]\n",
" means_e = [df_e[i].mean() for i in layers]\n",
" stds_e = [df_e[i].std() for i in layers]\n",
"\n",
" x = np.arange(len(layers))\n",
" width = 0.35\n",
"\n",
" fig, ax = plt.subplots(figsize=(8, 4), dpi=300)\n",
" ax.bar(x - width/2, means_h, width, yerr=stds_h, label='Ambiguous',\n",
" color='indianred', alpha=0.8, capsize=3)\n",
" ax.bar(x + width/2, means_e, width, yerr=stds_e, label='Unambiguous',\n",
" color='steelblue', alpha=0.8, capsize=3)\n",
"\n",
" ax.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Unity Gain (g=1.0)')\n",
" ax.set_ylabel('Signal Gain (||y|| / ||x||)', fontweight='bold')\n",
" ax.set_xlabel('Layer Depth')\n",
" ax.set_xticks(x)\n",
" ax.set_xticklabels(layers)\n",
" ax.set_ylim(0.85, 1.15) # Zoom in to show it's flat\n",
" ax.legend(loc='upper right')\n",
" ax.set_title('Iso-Energetic Constraint: Gain β 1.0 Across All Conditions', fontweight='bold')\n",
" ax.grid(axis='y', linestyle='--', alpha=0.3)\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(\"fig_B3_gain.png\")\n",
" plt.show()\n",
" print(\"β
Figure B3 Saved.\")\n",
"\n",
"# ==============================================================================\n",
"# 7. PLOT FIGURE B1: MAGNITUDE STABILITY (CV)\n",
"# ==============================================================================\n",
"def plot_cv_chart(res_hard, res_easy):\n",
" # Combine data to check global network stability\n",
" # CV = sigma / mu\n",
"\n",
" stages = [\"Embedding\"]\n",
" cvs = []\n",
"\n",
" # 1. Embedding Stage\n",
" all_emb = np.concatenate([res_hard['embedding'], res_easy['embedding']])\n",
" cvs.append(all_emb.std() / all_emb.mean())\n",
"\n",
" # 2. Layers 0-5\n",
" for i in range(6):\n",
" # Flatten lists\n",
" mags_h = np.array(res_hard['magnitude'][i])\n",
" mags_e = np.array(res_easy['magnitude'][i])\n",
" all_mags = np.concatenate([mags_h, mags_e])\n",
"\n",
" cv = all_mags.std() / (all_mags.mean() + 1e-9)\n",
" cvs.append(cv)\n",
" stages.append(f\"Layer {i}\")\n",
"\n",
" mean_cv = np.mean(cvs)\n",
"\n",
" fig, ax = plt.subplots(figsize=(8, 4), dpi=300)\n",
" bars = ax.bar(stages, cvs, color='steelblue', alpha=0.8, edgecolor='grey')\n",
"\n",
" ax.axhline(y=mean_cv, color='red', linestyle='--', label=f'Mean CV = {mean_cv:.3f}')\n",
" ax.set_ylabel('Coefficient of Variation (Ο/ΞΌ)', fontweight='bold')\n",
" ax.set_xlabel('Network Stage')\n",
" ax.set_title('Magnitude Stability Across Layers (Iso-Energetic Check)', fontweight='bold')\n",
" ax.set_ylim(0, 1.0)\n",
" ax.legend()\n",
"\n",
" # Label bars\n",
" for bar, v in zip(bars, cvs):\n",
" ax.text(bar.get_x() + bar.get_width()/2, v, f\"{v:.3f}\",\n",
" ha='center', va='bottom', fontsize=9)\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(\"fig_B1_cv.png\")\n",
" plt.show()\n",
" print(\"β
Figure B1 Saved.\")\n",
"\n",
"# ==============================================================================\n",
"# RUN PLOTS\n",
"# ==============================================================================\n",
"plot_gain_chart(res_hard, res_easy)\n",
"plot_cv_chart(res_hard, res_easy)"
]
}
]
} |