{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": { "provenance": [], "gpuType": "T4" },
  "kernelspec": { "name": "python3", "display_name": "Python 3" },
  "language_info": { "name": "python" },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "# x-transformers is not imported directly in this notebook; presumably the downloaded\n",
    "# modeling_prism_gated.py needs it (TODO confirm).\n",
    "# TODO: pin an exact version (x-transformers==<version>) so Restart & Run All is reproducible.\n",
    "%pip install -q x-transformers"
   ],
   "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TWiErEkm1YNU" },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "XfhKiI_Z1Q6F" },
   "outputs": [],
   "source": [
    "# @title 🛠️ Appendix Physical Validation (Gain & Stability)\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from huggingface_hub import hf_hub_download\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# ==============================================================================\n",
    "# 1. 
SETUP & MODEL LOADING\n",
    "# ==============================================================================\n",
    "REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(f\"⚙️ Hardware: {DEVICE}\")\n",
    "print(f\"📥 Loading PRISM from {REPO_ID}...\")\n",
    "\n",
    "# Download architecture. NOTE: importing this file executes code fetched from the Hub;\n",
    "# pass revision=\"<commit sha>\" to hf_hub_download to pin a reviewed version.\n",
    "CODE_DIR = \"shimmer_code\"\n",
    "os.makedirs(CODE_DIR, exist_ok=True)\n",
    "hf_hub_download(repo_id=REPO_ID, filename=\"modeling_prism_gated.py\", local_dir=CODE_DIR)\n",
    "if CODE_DIR not in sys.path:  # keep re-runs from appending duplicate entries\n",
    "    sys.path.append(CODE_DIR)\n",
    "\n",
    "from modeling_prism_gated import PRISMHybrid_RoPE\n",
    "\n",
    "# Load Model\n",
    "tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
    "CONFIG = {\n",
    "    \"vocab_size\": 58101, \"d_model\": 512, \"num_heads\": 8, \"dff\": 2048,\n",
    "    \"dropout\": 0.1, \"max_length\": 128, \"num_encoder_layers\": 6,\n",
    "    \"num_refining_layers\": 0, \"num_decoder_layers\": 6\n",
    "}\n",
    "model = PRISMHybrid_RoPE(**CONFIG)\n",
    "# The checkpoint is consumed as a plain state_dict, so refuse to unpickle arbitrary\n",
    "# objects from the downloaded file (torch.load is pickle-based).\n",
    "weights_path = hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\")\n",
    "state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model.to(DEVICE)\n",
    "model.eval()  # inference mode: disables dropout (CONFIG sets 0.1) for the measurements below\n",
    "\n",
    "print(\"✅ Model Ready.\")\n",
    "\n",
    "# ==============================================================================\n",
    "# 2. 
DATASETS\n",
    "# ==============================================================================\n",
    "# Raw candidate lists; filter_dataset() below keeps only pairs whose target is a single token.\n",
    "# A. HARD MODE (Polysemous): 21 ambiguous nouns x 4 contexts (two per sense) = 84 raw pairs.\n",
    "# B. EASY MODE (Casual): 70 unambiguous nouns, one context each.\n",
    "# Each context must contain its target word so find_token_index() can locate the token.\n",
    "\n",
    "# A. HARD MODE (Polysemous)\n",
    "raw_poly_candidates = [\n",
    "    # --- ORIGINAL SET ---\n",
    "    (\"Ich gehe zur Bank um Geld zu holen\", \"Bank\"), (\"Die Bank hat hohe Zinsen\", \"Bank\"),\n",
    "    (\"Wir saßen auf einer Bank im Park\", \"Bank\"), (\"Die Bank aus Holz war bequem\", \"Bank\"),\n",
    "    (\"Das Schloss hat viele Türme\", \"Schloss\"), (\"Der König wohnt im Schloss\", \"Schloss\"),\n",
    "    (\"Der Schlüssel steckt im Schloss\", \"Schloss\"), (\"Das Schloss an der Tür klemmt\", \"Schloss\"),\n",
    "    (\"Der Leiter der Firma ist streng\", \"Leiter\"), (\"Unser Leiter plant das Projekt\", \"Leiter\"),\n",
    "    (\"Ich steige auf die Leiter\", \"Leiter\"), (\"Die Leiter ist aus Aluminium\", \"Leiter\"),\n",
    "    (\"Die Lampe hängt an der Decke\", \"Decke\"), (\"Die Decke ist weiß gestrichen\", \"Decke\"),\n",
    "    (\"Mir ist kalt gib mir eine Decke\", \"Decke\"), (\"Die Decke aus Wolle ist warm\", \"Decke\"),\n",
    "    # \"die Kiefer\" (fem.) = pine tree, \"der Kiefer\" (masc.) = jaw: the article must match the sense.\n",
    "    (\"Die Kiefer ist ein Nadelbaum\", \"Kiefer\"), (\"Das Holz der Kiefer ist weich\", \"Kiefer\"),\n",
    "    (\"Der Arzt röntgt meinen Kiefer\", \"Kiefer\"), (\"Er hat Schmerzen im Kiefer\", \"Kiefer\"),\n",
    "    (\"Der Strauß ist ein schneller Vogel\", \"Strauß\"), (\"Dieser Strauß kann nicht fliegen\", \"Strauß\"),\n",
    "    (\"Sie kaufte einen bunten Strauß\", \"Strauß\"), (\"Der Strauß Blumen duftet gut\", \"Strauß\"),\n",
    "    (\"Er schoss ein schönes Tor\", \"Tor\"), (\"Der Ball flog ins Tor\", \"Tor\"),\n",
    "    (\"Das eiserne Tor war verschlossen\", \"Tor\"), (\"Sie öffneten das große Tor\", \"Tor\"),\n",
    "    (\"Wir tanzen auf dem Ball\", \"Ball\"), (\"Der Maskenball war elegant\", \"Ball\"),\n",
    "    (\"Er warf den Ball weit weg\", \"Ball\"), (\"Der Ball ist rund und rot\", \"Ball\"),\n",
    "    (\"Die Schlange im Zoo ist giftig\", \"Schlange\"), (\"Die Schlange zischte laut\", \"Schlange\"),\n",
    "    (\"Wir stehen in einer langen Schlange\", \"Schlange\"), (\"Die Schlange an der Kasse war lang\", \"Schlange\"),\n",
    "    (\"Der Strom ist ausgefallen\", \"Strom\"), (\"Strom kostet viel Geld\", \"Strom\"),\n",
    "    (\"Der Strom fließt ins Meer\", \"Strom\"), (\"Wir schwammen gegen den Strom\", \"Strom\"),\n",
    "    (\"Seine Mutter ist sehr nett\", \"Mutter\"), (\"Die Mutter kocht das Essen\", \"Mutter\"),\n",
    "    (\"Die Mutter passt auf die Schraube\", \"Mutter\"), (\"Ich brauche eine neue Mutter\", \"Mutter\"),\n",
    "    (\"Die Birne schmeckt süß\", \"Birne\"), (\"Ich esse gerne eine Birne\", \"Birne\"),\n",
    "    (\"Die Birne in der Lampe ist kaputt\", \"Birne\"), (\"Wir müssen die Birne wechseln\", \"Birne\"),\n",
    "    # --- EXPANSION SET ---\n",
    "    (\"Das Gericht hat ihn verurteilt\", \"Gericht\"), (\"Der Anwalt geht zum Gericht\", \"Gericht\"),\n",
    "    (\"Mein Lieblingsessen ist ein Gericht aus Reis\", \"Gericht\"), (\"Das Gericht schmeckt sehr salzig\", \"Gericht\"),\n",
    "    (\"Der Ton war sehr laut\", \"Ton\"), (\"Ich hörte einen hohen Ton\", \"Ton\"),\n",
    "    (\"Die Vase ist aus Ton\", \"Ton\"), (\"Wir formen Figuren aus Ton\", \"Ton\"),\n",
    "    # Singular \"Blatt\" on purpose: the plural \"Blätter\" never contains the target string.\n",
    "    (\"Das Blatt fällt vom Baum\", \"Blatt\"), (\"Im Herbst wird das Blatt braun\", \"Blatt\"),\n",
    "    (\"Ich schreibe auf ein Blatt Papier\", \"Blatt\"), (\"Gib mir bitte ein leeres Blatt\", \"Blatt\"),\n",
    "    (\"Der Nagel steckt in der Wand\", \"Nagel\"), (\"Ich schlage den Nagel mit dem Hammer\", \"Nagel\"),\n",
    "    (\"Mein Nagel ist abgebrochen\", \"Nagel\"), (\"Sie lackiert sich den Nagel rot\", \"Nagel\"),\n",
    "    (\"Die Maus frisst den Käse\", \"Maus\"), (\"Die Katze jagt die Maus\", \"Maus\"),\n",
    "    (\"Ich klicke mit der Maus\", \"Maus\"), (\"Der Computer braucht eine neue Maus\", \"Maus\"),\n",
    "    (\"Die Erde dreht sich um die Sonne\", \"Erde\"), (\"Der Astronaut schaut auf die Erde\", \"Erde\"),\n",
    "    (\"Die Blume braucht frische Erde\", \"Erde\"), (\"Er gräbt ein Loch in die Erde\", \"Erde\"),\n",
    "    (\"Der Hahn kräht am Morgen\", \"Hahn\"), (\"Der Hahn hat bunte Federn\", \"Hahn\"),\n",
    "    (\"Der Wasserhahn tropft\", \"Hahn\"), (\"Dreh bitte den Hahn zu\", \"Hahn\"),\n",
    "    (\"Die Schale der Orange ist bitter\", \"Schale\"), (\"Er wirft die Schale weg\", \"Schale\"),\n",
    "    (\"Die Schale steht auf dem Tisch\", \"Schale\"), (\"Ich esse Müsli aus der Schale\", \"Schale\"),\n",
    "    (\"Der Bauer melkt die Kühe\", \"Bauer\"), (\"Der Bauer fährt auf dem Traktor\", \"Bauer\"),\n",
    "    (\"Ich ziehe den Bauer auf E4\", \"Bauer\"), (\"Der Bauer schlägt den Turm\", \"Bauer\"),\n",
    "]\n",
    "\n",
    "# B. EASY MODE (Casual)\n",
    "raw_casual_candidates = [\n",
    "    (\"Die Katze schläft\", \"Katze\"), (\"Der Hund bellt\", \"Hund\"), (\"Das Auto fährt\", \"Auto\"),\n",
    "    (\"Wasser ist nass\", \"Wasser\"), (\"Das Brot schmeckt gut\", \"Brot\"), (\"Die Sonne scheint\", \"Sonne\"),\n",
    "    (\"Der Mond leuchtet\", \"Mond\"), (\"Das Buch ist spannend\", \"Buch\"), (\"Der Tisch ist rund\", \"Tisch\"),\n",
    "    (\"Der Stuhl ist bequem\", \"Stuhl\"), (\"Der Apfel ist rot\", \"Apfel\"), (\"Meine Hand ist kalt\", \"Hand\"),\n",
    "    (\"Das Herz klopft\", \"Herz\"), (\"Wir haben Zeit\", \"Zeit\"), (\"Geld ist wichtig\", \"Geld\"),\n",
    "    (\"Musik ist schön\", \"Musik\"), (\"Der Film ist zu Ende\", \"Film\"), (\"Das Spiel beginnt\", \"Spiel\"),\n",
    "    (\"Die Schule ist aus\", \"Schule\"), (\"Die Stadt ist laut\", \"Stadt\"), (\"Der Fluss fließt\", \"Fluss\"),\n",
    "    (\"Das Meer ist tief\", \"Meer\"), (\"Kaffee ist schwarz\", \"Kaffee\"), (\"Milch ist weiß\", \"Milch\"),\n",
    "    (\"Der Bruder lacht\", \"Bruder\"), (\"Die Schwester weint\", \"Schwester\"), (\"Das Haus ist groß\", \"Haus\"),\n",
    "    (\"Der Garten ist grün\", \"Garten\"), (\"Der Sommer ist heiß\", \"Sommer\"), (\"Der Winter ist kalt\", \"Winter\"),\n",
    "    (\"Das Fenster ist offen\", \"Fenster\"), (\"Die Tür ist zu\", \"Tür\"), (\"Der Boden ist sauber\", \"Boden\"),\n",
    "    (\"Die Wand ist weiß\", \"Wand\"), (\"Das Dach ist rot\", \"Dach\"), (\"Der Wald ist dunkel\", \"Wald\"),\n",
    "    (\"Der Berg ist hoch\", \"Berg\"), (\"Der See ist ruhig\", \"See\"), (\"Das Tier ist wild\", \"Tier\"),\n",
    "    (\"Der Mensch denkt\", \"Mensch\"), (\"Das Kind spielt\", \"Kind\"), (\"Die Frau arbeitet\", \"Frau\"),\n",
    "    (\"Der Mann schläft\", \"Mann\"), (\"Das Auge sieht\", \"Auge\"), (\"Das Ohr hört\", \"Ohr\"),\n",
    "    (\"Die Nase riecht\", \"Nase\"), (\"Der Mund spricht\", \"Mund\"), (\"Der Arm ist stark\", \"Arm\"),\n",
    "    (\"Das Bein tut weh\", \"Bein\"), (\"Der Fuß ist groß\", \"Fuß\"), (\"Der Tee ist heiß\", \"Tee\"),\n",
    "    (\"Das Bier ist kalt\", \"Bier\"), (\"Der Wein ist rot\", \"Wein\"), (\"Das Glas ist voll\", \"Glas\"),\n",
    "    (\"Die Tasse ist leer\", \"Tasse\"), (\"Der Teller ist blau\", \"Teller\"), (\"Die Gabel ist spitz\", \"Gabel\"),\n",
    "    (\"Der Löffel ist rund\", \"Löffel\"), (\"Das Messer ist scharf\", \"Messer\"), (\"Der Stift schreibt\", \"Stift\"),\n",
    "    (\"Der Brief ist lang\", \"Brief\"), (\"Das Bild ist schön\", \"Bild\"), (\"Die Uhr tickt\", \"Uhr\"),\n",
    "    (\"Das Bett ist weich\", \"Bett\"), (\"Der Schrank ist voll\", \"Schrank\"), (\"Das Sofa ist neu\", \"Sofa\"),\n",
    "    (\"Das Radio spielt\", \"Radio\"), (\"Das Jahr ist um\", \"Jahr\"), (\"Der Tag war lang\", \"Tag\"),\n",
    "    (\"Die Nacht ist kurz\", \"Nacht\")\n",
    "]\n",
    "\n",
    "# ==============================================================================\n",
    "# 3. 
HELPER: Single-Token Validator\n",
    "# ==============================================================================\n",
    "def filter_dataset(candidates, tokenizer, label):\n",
    "    \"\"\"Keep only (context, target) pairs whose target word is a single token.\n",
    "\n",
    "    A target counts as atomic if it encodes (without special tokens) to exactly\n",
    "    one id, either bare or with a leading space.\n",
    "    \"\"\"\n",
    "    valid = []\n",
    "    for ctx, tgt in candidates:\n",
    "        t1 = tokenizer.encode(tgt, add_special_tokens=False)\n",
    "        t2 = tokenizer.encode(\" \" + tgt, add_special_tokens=False)\n",
    "        if len(t1) == 1 or len(t2) == 1:\n",
    "            valid.append((ctx, tgt))\n",
    "    print(f\"✅ {label}: {len(valid)} atomic examples validated.\")\n",
    "    return valid\n",
    "\n",
    "def find_token_index(input_ids, target_word, tokenizer):\n",
    "    \"\"\"Locate the position of `target_word` in a sequence of token ids (e.g. input_ids[0]).\n",
    "\n",
    "    Token strings are compared case-insensitively after stripping the 'Ġ' / '▁'\n",
    "    word-boundary markers and spaces. An exact match wins; otherwise the first\n",
    "    token that contains the target (e.g. a piece of a compound) is used.\n",
    "\n",
    "    Returns:\n",
    "        The token index, or None when the target cannot be found. Returning None\n",
    "        (instead of a fixed fallback index) lets callers skip the example rather\n",
    "        than silently measuring an unrelated token.\n",
    "    \"\"\"\n",
    "    target = target_word.lower()\n",
    "    cleaned = [\n",
    "        t.replace('Ġ', '').replace('▁', '').replace(' ', '').lower()\n",
    "        for t in tokenizer.convert_ids_to_tokens(input_ids)\n",
    "    ]\n",
    "    for i, tok in enumerate(cleaned):\n",
    "        if tok == target:\n",
    "            return i\n",
    "    for i, tok in enumerate(cleaned):  # Fallback: target inside a larger piece\n",
    "        if target in tok:\n",
    "            return i\n",
    "    return None\n",
    "\n",
    "# ==============================================================================\n",
    "# 4. PHYSICAL PROBE (Gain & Magnitude)\n",
    "# ==============================================================================\n",
    "def run_physical_probe(model, tokenizer, dataset, label, device):\n",
    "    \"\"\"\n",
    "    Extracts Gain (Ratio) and Raw Magnitude (Norm) for CV analysis.\n",
    "\n",
    "    A forward hook on each layer of `model.prism_encoder.layers` records, per\n",
    "    token position, the output norm ||y|| and the gain ||y|| / (||x|| + 1e-9),\n",
    "    where x is the layer's first positional input. For every example only the\n",
    "    values at the target word's position are kept; examples whose target token\n",
    "    cannot be located are skipped and reported.\n",
    "\n",
    "    Returns:\n",
    "        dict with\n",
    "          'gain':      DataFrame, one column per layer index, one row per kept example.\n",
    "          'magnitude': dict layer index -> list of output norms (same rows).\n",
    "          'embedding': np.ndarray of embedding norms at the same target positions,\n",
    "                       so it is directly comparable with the per-layer magnitudes.\n",
    "        A layer whose hook did not fire for an example is recorded as NaN, so all\n",
    "        columns stay aligned.\n",
    "    \"\"\"\n",
    "    num_layers = len(model.prism_encoder.layers)\n",
    "\n",
    "    # Store Gain (for Fig B3) and Magnitude (for Fig B1)\n",
    "    gain_stats = {i: [] for i in range(num_layers)}\n",
    "    magnitude_stats = {i: [] for i in range(num_layers)}\n",
    "    embedding_mags = []\n",
    "    skipped = []\n",
    "\n",
    "    hook_data = {}\n",
    "\n",
    "    def physics_hook(layer_idx):\n",
    "        def hook(module, layer_inputs, output):\n",
    "            x, y = layer_inputs[0].detach(), output.detach()\n",
    "\n",
    "            # 1. Norms (Energy)\n",
    "            norm_x = torch.norm(x, p=2, dim=-1)\n",
    "            norm_y = torch.norm(y, p=2, dim=-1)\n",
    "\n",
    "            # 2. Gain Calculation (epsilon avoids division by zero)\n",
    "            gain = norm_y / (norm_x + 1e-9)\n",
    "\n",
    "            hook_data[f'layer_{layer_idx}'] = {\n",
    "                'gain': gain.cpu(),\n",
    "                'mag': norm_y.cpu()  # Output magnitude\n",
    "            }\n",
    "        return hook\n",
    "\n",
    "    def at_target(values, idx):\n",
    "        # Per-token tensors may or may not keep a leading batch dimension.\n",
    "        return values[0, idx].item() if values.dim() > 1 else values[idx].item()\n",
    "\n",
    "    # Register hooks and keep their handles, so exactly these hooks are removed\n",
    "    # afterwards (even if a forward pass raises) and hooks owned by anyone else\n",
    "    # on the encoder are left untouched.\n",
    "    handles = [\n",
    "        layer.register_forward_hook(physics_hook(i))\n",
    "        for i, layer in enumerate(model.prism_encoder.layers)\n",
    "    ]\n",
    "\n",
    "    # Run Probe\n",
    "    print(f\"🔬 Measuring Physics on {len(dataset)} {label} examples...\")\n",
    "    try:\n",
    "        for context, target in dataset:\n",
    "            hook_data.clear()\n",
    "            inputs = tokenizer(context, return_tensors=\"pt\").to(device)\n",
    "\n",
    "            idx = find_token_index(inputs.input_ids[0], target, tokenizer)\n",
    "            if idx is None:\n",
    "                skipped.append(target)\n",
    "                continue\n",
    "\n",
    "            with torch.no_grad():\n",
    "                # Embedding magnitude before the encoder, at the same target token\n",
    "                emb = model.harmonic_embedding(inputs.input_ids)\n",
    "                embedding_mags.append(at_target(torch.norm(emb, p=2, dim=-1), idx))\n",
    "\n",
    "                # Forward pass (hooks fill hook_data)\n",
    "                src_mask = (inputs.input_ids == tokenizer.pad_token_id)\n",
    "                model.prism_encoder(emb, src_mask)\n",
    "\n",
    "            for i in range(num_layers):\n",
    "                data = hook_data.get(f'layer_{i}')\n",
    "                if data is None:\n",
    "                    gain_stats[i].append(float('nan'))\n",
    "                    magnitude_stats[i].append(float('nan'))\n",
    "                    continue\n",
    "                # Extract atomic token metrics\n",
    "                gain_stats[i].append(at_target(data['gain'], idx))\n",
    "                magnitude_stats[i].append(at_target(data['mag'], idx))\n",
    "    finally:\n",
    "        for handle in handles:\n",
    "            handle.remove()\n",
    "\n",
    "    if skipped:\n",
    "        print(f\"⚠️ {label}: skipped {len(skipped)} examples (target token not found): {skipped}\")\n",
    "\n",
    "    return {\n",
    "        'gain': pd.DataFrame(gain_stats),\n",
    "        'magnitude': magnitude_stats,  # Dict of lists\n",
    "        'embedding': np.asarray(embedding_mags, dtype=float)\n",
    "    }\n",
    "\n",
    "# ==============================================================================\n",
    "# 5. 
EXECUTION\n",
    "# ==============================================================================\n",
    "# Filter\n",
    "ds_hard = filter_dataset(raw_poly_candidates, tokenizer, \"HARD\")\n",
    "ds_easy = filter_dataset(raw_casual_candidates, tokenizer, \"EASY\")\n",
    "\n",
    "# Run\n",
    "res_hard = run_physical_probe(model, tokenizer, ds_hard, \"HARD\", DEVICE)\n",
    "res_easy = run_physical_probe(model, tokenizer, ds_easy, \"EASY\", DEVICE)\n",
    "\n",
    "# ==============================================================================\n",
    "# 6. PLOT FIGURE B3: ISO-ENERGETIC GAIN\n",
    "# ==============================================================================\n",
    "def plot_gain_chart(res_hard, res_easy, ylim=(0.85, 1.15)):\n",
    "    \"\"\"Plot per-layer mean ± std signal gain ||y||/||x|| for both conditions (Fig B3).\n",
    "\n",
    "    Args:\n",
    "        res_hard, res_easy: results of run_physical_probe; only their 'gain'\n",
    "            DataFrames (one column per layer) are used.\n",
    "        ylim: preferred y-range, zoomed around unity gain. It is widened whenever\n",
    "            a mean or error bar falls outside it, so deviations from g = 1 can\n",
    "            never be clipped out of the figure.\n",
    "\n",
    "    Saves the figure to fig_B3_gain.png.\n",
    "    \"\"\"\n",
    "    df_h = res_hard['gain']\n",
    "    df_e = res_easy['gain']\n",
    "\n",
    "    layers = list(df_h.columns)\n",
    "    means_h = np.array([df_h[i].mean() for i in layers], dtype=float)\n",
    "    stds_h = np.array([df_h[i].std() for i in layers], dtype=float)\n",
    "    means_e = np.array([df_e[i].mean() for i in layers], dtype=float)\n",
    "    stds_e = np.array([df_e[i].std() for i in layers], dtype=float)\n",
    "\n",
    "    x = np.arange(len(layers))\n",
    "    width = 0.35\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 4), dpi=300)\n",
    "    ax.bar(x - width/2, means_h, width, yerr=stds_h, label='Ambiguous',\n",
    "           color='indianred', alpha=0.8, capsize=3)\n",
    "    ax.bar(x + width/2, means_e, width, yerr=stds_e, label='Unambiguous',\n",
    "           color='steelblue', alpha=0.8, capsize=3)\n",
    "\n",
    "    ax.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Unity Gain (g=1.0)')\n",
    "    ax.set_ylabel('Signal Gain (||y|| / ||x||)', fontweight='bold')\n",
    "    ax.set_xlabel('Layer Depth')\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(layers)\n",
    "\n",
    "    # Zoom in to show the gain is flat, but widen the zoom if any bar or error bar\n",
    "    # would otherwise be cut off.\n",
    "    extent = np.concatenate([means_h, means_e,\n",
    "                             means_h - stds_h, means_h + stds_h,\n",
    "                             means_e - stds_e, means_e + stds_e])\n",
    "    extent = extent[np.isfinite(extent)]\n",
    "    lo, hi = ylim\n",
    "    if extent.size:\n",
    "        pad = 0.02 * (hi - lo)\n",
    "        lo = min(lo, extent.min() - pad)\n",
    "        hi = max(hi, extent.max() + pad)\n",
    "    ax.set_ylim(lo, hi)\n",
    "\n",
    "    ax.legend(loc='upper right')\n",
    "    ax.set_title('Iso-Energetic Constraint: Gain ≈ 1.0 Across All Conditions', fontweight='bold')\n",
    "    ax.grid(axis='y', linestyle='--', alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"fig_B3_gain.png\")\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "    print(\"✅ Figure B3 Saved.\")\n",
    "\n",
    "# 
==============================================================================\n",
    "# 7. PLOT FIGURE B1: MAGNITUDE STABILITY (CV)\n",
    "# ==============================================================================\n",
    "def plot_cv_chart(res_hard, res_easy):\n",
    "    \"\"\"Plot the coefficient of variation CV = σ/μ of token magnitudes per stage (Fig B1).\n",
    "\n",
    "    The Ambiguous and Unambiguous runs are pooled to check global network\n",
    "    stability: one bar for the embedding norms, then one per encoder layer\n",
    "    present in the probe results (no hard-coded layer count). NaN entries are\n",
    "    ignored. The y-axis spans [0, 1] and grows if a CV (plus room for its label)\n",
    "    exceeds that, so no bar is clipped. Saves the figure to fig_B1_cv.png.\n",
    "    \"\"\"\n",
    "    def coeff_of_variation(values):\n",
    "        values = np.asarray(values, dtype=float)\n",
    "        return np.nanstd(values) / (np.nanmean(values) + 1e-9)\n",
    "\n",
    "    # 1. Embedding Stage\n",
    "    stages = [\"Embedding\"]\n",
    "    cvs = [coeff_of_variation(np.concatenate([res_hard['embedding'], res_easy['embedding']]))]\n",
    "\n",
    "    # 2. Encoder layers, in index order\n",
    "    for i in sorted(res_hard['magnitude']):\n",
    "        all_mags = np.concatenate([\n",
    "            np.asarray(res_hard['magnitude'][i], dtype=float),\n",
    "            np.asarray(res_easy['magnitude'].get(i, []), dtype=float),\n",
    "        ])\n",
    "        cvs.append(coeff_of_variation(all_mags))\n",
    "        stages.append(f\"Layer {i}\")\n",
    "\n",
    "    mean_cv = np.nanmean(cvs)\n",
    "    finite_cvs = [v for v in cvs if np.isfinite(v)]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 4), dpi=300)\n",
    "    bars = ax.bar(stages, cvs, color='steelblue', alpha=0.8, edgecolor='grey')\n",
    "\n",
    "    ax.axhline(y=mean_cv, color='red', linestyle='--', label=f'Mean CV = {mean_cv:.3f}')\n",
    "    ax.set_ylabel('Coefficient of Variation (σ/μ)', fontweight='bold')\n",
    "    ax.set_xlabel('Network Stage')\n",
    "    ax.set_title('Magnitude Stability Across Layers (Iso-Energetic Check)', fontweight='bold')\n",
    "    ax.set_ylim(0, max([1.0] + [v * 1.15 for v in finite_cvs]))\n",
    "    ax.legend()\n",
    "\n",
    "    # Label bars\n",
    "    for bar, v in zip(bars, cvs):\n",
    "        if np.isfinite(v):\n",
    "            ax.text(bar.get_x() + bar.get_width()/2, v, f\"{v:.3f}\",\n",
    "                    ha='center', va='bottom', fontsize=9)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"fig_B1_cv.png\")\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "    print(\"✅ Figure B1 Saved.\")\n",
    "\n",
    "# ==============================================================================\n",
    "# RUN PLOTS\n",
    "# ==============================================================================\n",
    "plot_gain_chart(res_hard, res_easy)\n",
    "plot_cv_chart(res_hard, res_easy)"
   ]
  }
 ]
}