{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SvLO5U3q_Q3x"
   },
   "outputs": [],
   "source": [
    "%pip install -q x-transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xy1HCL1GzAbM"
   },
   "outputs": [],
   "source": [
    "# ==========================================\n",
    "# 1. SETUP & MODEL LOADING (FIXED)\n",
    "# ==========================================\n",
    "import os\n",
    "import sys\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "# --- CRITICAL FIX: Download the Model Definition FIRST ---\n",
    "# The model class ships as a .py file inside the Hub repo, so it must be on disk\n",
    "# (and importable) before the model can be constructed.\n",
    "REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
    "filename = \"modeling_prism_gated.py\"\n",
    "\n",
    "if not os.path.exists(filename):\n",
    "    # Only announce a download when one actually happens.\n",
    "    print(f\"⬇️ Downloading {filename} from Hugging Face...\")\n",
    "    hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\".\", force_download=True)\n",
    "\n",
    "# Now that the file exists locally, we can import it\n",
    "sys.path.append(\".\")  # Ensure current dir is in path\n",
    "from modeling_prism_gated import PRISMHybrid_RoPE\n",
    "\n",
    "# Continue with standard imports\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer\n",
    "import json\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(\"⏳ Downloading Weights & Config...\")\n",
    "if not os.path.exists(\"config.json\"):\n",
    "    hf_hub_download(repo_id=REPO_ID, filename=\"config.json\", local_dir=\".\")\n",
    "if not os.path.exists(\"pytorch_model.bin\"):\n",
    "    hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\", local_dir=\".\")\n",
    "\n",
    "with open(\"config.json\", \"r\") as f:\n",
    "    config = json.load(f)\n",
    "tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
    "\n",
    "# Model width comes from the checkpoint config (not a hard-coded 512), so it always\n",
    "# agrees with the d_model the model is built with below.\n",
    "D_MODEL = config['d_model']\n",
    "\n",
    "# Initialize Model\n",
    "# NOTE(review): decoder depth, heads, dff and max_length are hard-coded rather than\n",
    "# read from config.json -- confirm they match the checkpoint.\n",
    "model = PRISMHybrid_RoPE(\n",
    "    vocab_size=config['vocab_size'], d_model=config['d_model'],\n",
    "    num_encoder_layers=config['num_encoder_layers'], num_refining_layers=0,\n",
    "    num_decoder_layers=6, num_heads=8, dff=2048, max_length=128, dropout=0.0\n",
    ").to(DEVICE)\n",
    "\n",
    "# weights_only=True: deserialize tensors only, never arbitrary pickled objects.\n",
    "model.load_state_dict(torch.load(\"pytorch_model.bin\", map_location=DEVICE, weights_only=True))\n",
    "model.eval()\n",
    "print(\"✅ Model Loaded Successfully.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vmp0y5TYdTd0"
   },
   "outputs": [],
   "source": [
    "# @title 🧭 Extended Phase Compass: Synonyms vs Antonyms vs Randoms\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer\n",
    "from huggingface_hub import hf_hub_download\n",
    "import os\n",
    "import json\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 2. DEFINING THE CANDIDATE PAIRS\n",
    "# ==========================================\n",
    "# Master list of all candidates: (word_a, word_b, relation).\n",
    "# Pairs containing a word that is not a single token are skipped by the analysis.\n",
    "candidates_raw = [\n",
    "    # --- ORIGINAL LIST ---\n",
    "    (\"Euro\", \"Geld\", \"Synonym\"),  # NOTE(review): hyponym (euro -> money) rather than a synonym\n",
    "    (\"Auto\", \"Wagen\", \"Synonym\"),\n",
    "    (\"schnell\", \"rasch\", \"Synonym\"),\n",
    "    (\"Stimme\", \"Wahl\", \"Synonym\"),  # NOTE(review): related (vote / election), not synonyms\n",
    "    (\"Zeit\", \"Uhr\", \"Synonym\"),  # NOTE(review): association (time / clock), not synonyms\n",
    "    (\"Start\", \"Beginn\", \"Synonym\"),\n",
    "    (\"Ende\", \"Schluss\", \"Synonym\"),\n",
    "    (\"Raum\", \"Platz\", \"Synonym\"),\n",
    "\n",
    "    # --- NEW GERMAN SYNONYMS ---\n",
    "    (\"Haus\", \"Heim\", \"Synonym\"),\n",
    "    (\"Boot\", \"Schiff\", \"Synonym\"),\n",
    "    (\"See\", \"Meer\", \"Synonym\"),\n",
    "    (\"Wald\", \"Forst\", \"Synonym\"),\n",
    "    (\"Weg\", \"Pfad\", \"Synonym\"),\n",
    "    (\"Berg\", \"Gipfel\", \"Synonym\"),\n",
    "    (\"Mund\", \"Maul\", \"Synonym\"),\n",
    "    (\"Pferd\", \"Ross\", \"Synonym\"),\n",
    "    (\"Hund\", \"Tier\", \"Synonym\"),  # NOTE(review): hypernym (dog -> animal), not a synonym\n",
    "    (\"Reise\", \"Fahrt\", \"Synonym\"),\n",
    "    (\"Angst\", \"Furcht\", \"Synonym\"),\n",
    "    (\"Mut\", \"Traute\", \"Synonym\"),\n",
    "    (\"Glück\", \"Dusel\", \"Synonym\"),\n",
    "    (\"Ding\", \"Sache\", \"Synonym\"),\n",
    "    (\"Welt\", \"Erde\", \"Synonym\"),\n",
    "    (\"Stadt\", \"Ort\", \"Synonym\"),\n",
    "    (\"Vater\", \"Papa\", \"Synonym\"),\n",
    "    (\"Mutter\", \"Mama\", \"Synonym\"),\n",
    "    (\"klug\", \"weise\", \"Synonym\"),\n",
    "    (\"klug\", \"schlau\", \"Synonym\"),\n",
    "    (\"schön\", \"hübsch\", \"Synonym\"),\n",
    "    (\"klein\", \"winzig\", \"Synonym\"),\n",
    "    (\"stark\", \"fest\", \"Synonym\"),\n",
    "    (\"neu\", \"frisch\", \"Synonym\"),\n",
    "    (\"still\", \"leise\", \"Synonym\"),\n",
    "    (\"froh\", \"heiter\", \"Synonym\"),\n",
    "    (\"dunkel\", \"finster\", \"Synonym\"),\n",
    "    (\"kalt\", \"eisig\", \"Synonym\"),\n",
    "    (\"rennen\", \"laufen\", \"Synonym\"),\n",
    "    (\"reden\", \"sagen\", \"Synonym\"),\n",
    "    (\"sehen\", \"schauen\", \"Synonym\"),\n",
    "    (\"gehen\", \"wandern\", \"Synonym\"),\n",
    "    (\"essen\", \"speisen\", \"Synonym\"),\n",
    "    (\"Wut\", \"Zorn\", \"Synonym\"),\n",
    "    (\"Dreck\", \"Schmutz\", \"Synonym\"),\n",
    "    (\"Chance\", \"Möglichkeit\", \"Synonym\"), # Möglichkeit might be split, but code will check\n",
    "    (\"Lehrer\", \"Pauker\", \"Synonym\"),\n",
    "    (\"Gott\", \"Herr\", \"Synonym\"),\n",
    "    (\"Chef\", \"Boss\", \"Synonym\"),\n",
    "    (\"Haut\", \"Fell\", \"Synonym\"),  # NOTE(review): skin vs. fur -- related, not synonyms\n",
    "    (\"Tor\", \"Tür\", \"Synonym\"),\n",
    "    (\"Zimmer\", \"Raum\", \"Synonym\"),\n",
    "    (\"Bahn\", \"Zug\", \"Synonym\"),\n",
    "    (\"Boot\", \"Kahn\", \"Synonym\"),\n",
    "    (\"Hose\", \"Jeans\", \"Synonym\"),  # NOTE(review): hyponym (trousers -> jeans)\n",
    "    (\"Witz\", \"Scherz\", \"Synonym\"),\n",
    "    (\"Hass\", \"Abscheu\", \"Synonym\"),\n",
    "    (\"Fett\", \"Dick\", \"Synonym\"),  # NOTE(review): noun \"Fett\" vs. capitalised adjective \"Dick\" -- confirm intent\n",
    "    (\"klug\", \"gescheit\", \"Synonym\"),\n",
    "    (\"dumm\", \"doof\", \"Synonym\"),\n",
    "    (\"rasch\", \"flink\", \"Synonym\"),\n",
    "    (\"stumm\", \"still\", \"Synonym\"),\n",
    "    (\"echt\", \"wahr\", \"Synonym\"),\n",
    "    (\"korrekt\", \"richtig\", \"Synonym\"),\n",
    "    # --- NEW ANTONYMS (OPPOSITES) ---\n",
    "    (\"gut\", \"böse\", \"Antonym\"),\n",
    "    (\"groß\", \"klein\", \"Antonym\"),\n",
    "    (\"heiß\", \"kalt\", \"Antonym\"),\n",
    "    (\"Tag\", \"Nacht\", \"Antonym\"),\n",
    "    (\"hoch\", \"tief\", \"Antonym\"),\n",
    "    (\"jung\", \"alt\", \"Antonym\"),\n",
    "    (\"voll\", \"leer\", \"Antonym\"),\n",
    "    
(\"Liebe\", \"Hass\", \"Antonym\"),\n",
    "    (\"Licht\", \"Schatten\", \"Antonym\"),\n",
    "    (\"Start\", \"Ziel\", \"Antonym\"),\n",
    "    (\"Frage\", \"Antwort\", \"Antonym\"),\n",
    "# --- ADDITIONAL ANTONYMS (Fundamental Opposites) ---\n",
    "    (\"Leben\", \"Tod\", \"Antonym\"),\n",
    "    (\"Freund\", \"Feind\", \"Antonym\"),\n",
    "    (\"Krieg\", \"Frieden\", \"Antonym\"),\n",
    "    (\"Sieg\", \"Niederlage\", \"Antonym\"),\n",
    "    (\"Gewinn\", \"Verlust\", \"Antonym\"),\n",
    "    (\"Himmel\", \"Hölle\", \"Antonym\"),\n",
    "    # NOTE(review): the next six pairs (gender pairs, seasons, sun/moon, fire/water) are\n",
    "    # contrasting co-hyponyms rather than strict antonyms.\n",
    "    (\"Junge\", \"Mädchen\", \"Antonym\"),\n",
    "    (\"Vater\", \"Mutter\", \"Antonym\"),\n",
    "    (\"Bruder\", \"Schwester\", \"Antonym\"),\n",
    "    (\"Sommer\", \"Winter\", \"Antonym\"),\n",
    "    (\"Sonne\", \"Mond\", \"Antonym\"),\n",
    "    (\"Feuer\", \"Wasser\", \"Antonym\"),\n",
    "    (\"schwarz\", \"weiß\", \"Antonym\"),\n",
    "    (\"hart\", \"weich\", \"Antonym\"),\n",
    "    (\"laut\", \"leise\", \"Antonym\"),\n",
    "    (\"schnell\", \"langsam\", \"Antonym\"),\n",
    "    (\"teuer\", \"billig\", \"Antonym\"),\n",
    "    (\"reich\", \"arm\", \"Antonym\"),\n",
    "    (\"schwer\", \"leicht\", \"Antonym\"),\n",
    "    (\"nass\", \"trocken\", \"Antonym\"),\n",
    "    (\"sauber\", \"schmutzig\", \"Antonym\"),\n",
    "    (\"klug\", \"dumm\", \"Antonym\"),\n",
    "    (\"stark\", \"schwach\", \"Antonym\"),\n",
    "    (\"dick\", \"dünn\", \"Antonym\"),\n",
    "    (\"breit\", \"schmal\", \"Antonym\"),\n",
    "    # --- NEW RANDOM/UNRELATED ---\n",
    "    (\"Mond\", \"Tisch\", \"Random\"),\n",
    "    (\"Brot\", \"Wolke\", \"Random\"),\n",
    "    (\"Schuh\", \"Idee\", \"Random\"),\n",
    "    (\"Baum\", \"Zahn\", \"Random\"),\n",
    "    (\"Glas\", \"Löwe\", \"Random\"),\n",
    "    (\"Buch\", \"Suppe\", \"Random\"),\n",
    "    (\"Wand\", \"Vogel\", \"Random\"),\n",
    "    (\"Gras\", \"Auto\", \"Random\"),\n",
    "    (\"Salz\", \"Musik\", \"Random\"),\n",
    "    (\"Dach\", \"Fisch\", \"Random\"),\n",
    "    (\"Stein\", \"Wort\", \"Random\"),\n",
    "    (\"Kopf\", \"Preis\", \"Random\"),\n",
    "    (\"Hand\", \"Woche\", \"Random\"),\n",
    "    (\"Euro\", \"Apfel\", \"Random\"),\n",
    "    (\"Auto\", \"Idee\", \"Random\"),\n",
    "    (\"schnell\", \"Haus\", \"Random\"),\n",
    "    (\"Zeit\", \"Fisch\", \"Random\"),\n",
    "    (\"Start\", \"Milch\", \"Random\"),\n",
    "    (\"Raum\", \"Laufen\", \"Random\"),\n",
    "# --- ADDITIONAL RANDOM PAIRS (Noise Floor) ---\n",
    "    (\"Käse\", \"Mond\", \"Random\"),\n",
    "    (\"Bier\", \"Tante\", \"Random\"),\n",
    "    (\"Zahn\", \"Autobahn\", \"Random\"),\n",
    "    (\"Vogel\", \"Benzin\", \"Random\"),\n",
    "    (\"Computer\", \"Blume\", \"Random\"),\n",
    "    (\"Glas\", \"Schaf\", \"Random\"),\n",
    "    (\"Schuh\", \"Luft\", \"Random\"),\n",
    "    (\"Kaffee\", \"Stein\", \"Random\"),\n",
    "    (\"Wand\", \"Butter\", \"Random\"),\n",
    "    (\"Fenster\", \"Schmerz\", \"Random\"),\n",
    "    (\"Nase\", \"Rechnung\", \"Random\"),\n",
    "    (\"Hund\", \"Lampe\", \"Random\"),\n",
    "    (\"Katze\", \"Strom\", \"Random\"),\n",
    "    (\"Apfel\", \"Krieg\", \"Random\"),\n",
    "    (\"Löffel\", \"Angst\", \"Random\"),\n",
    "    (\"Zucker\", \"Politik\", \"Random\"),\n",
    "    (\"Salz\", \"Liebe\", \"Random\"),\n",
    "    (\"Pfeffer\", \"Auto\", \"Random\"),\n",
    "    (\"Stuhl\", \"Wolke\", \"Random\"),\n",
    "]\n",
    "\n",
    "# ==========================================\n",
    "# 3. HELPER FUNCTIONS\n",
    "# ==========================================\n",
    "def is_single_token(word):\n",
    "    \"\"\"Check if word is 1 token in vocabulary.\n",
    "\n",
    "    Encodes without special tokens and returns (True, token_id) when exactly one id\n",
    "    comes back, otherwise (False, None).\n",
    "    \"\"\"\n",
    "    ids = tokenizer.encode(word, add_special_tokens=False)\n",
    "    return len(ids) == 1, ids[0] if len(ids) == 1 else None\n",
    "\n",
    "def calculate_coherence(id_a, id_b):\n",
    "    \"\"\"Extract phases and calculate the energy-weighted Mean Resultant Length (R).\n",
    "\n",
    "    Each row of the complex embedding table is read as [real half | imaginary half].\n",
    "    For tokens id_a and id_b the per-dimension phase differences are turned into unit\n",
    "    phasors, weighted by |za| * |zb|, and averaged.\n",
    "\n",
    "    Returns:\n",
    "        R       -- length of the mean vector, 0 (phases cancel) .. 1 (perfect alignment)\n",
    "        angle   -- direction of the mean vector, in radians\n",
    "        diffs   -- per-dimension phase differences wrapped to (-pi, pi]\n",
    "        weights -- per-dimension weights |za| * |zb|\n",
    "    \"\"\"\n",
    "    # 1. Get Weights: slice the two rows first, then copy only those to CPU\n",
    "    #    (avoids copying the whole embedding table on every call).\n",
    "    w = model.harmonic_embedding.complex_embedding.weight.detach()\n",
    "    half = w.shape[-1] // 2  # assumes rows are [real | imag] halves of equal width\n",
    "    rows = w[[id_a, id_b]].float().cpu()\n",
    "\n",
    "    # 2. Form Complex Numbers (Real + i*Imag)\n",
    "    za = torch.complex(rows[0, :half], rows[0, half:])\n",
    "    zb = torch.complex(rows[1, :half], rows[1, half:])\n",
    "\n",
    "    # 3. Phase Difference (Angle between vectors)\n",
    "    # angle(za * conj(zb)) is angle(za) - angle(zb) wrapped to (-pi, pi], so each\n",
    "    # direction is counted exactly once on the polar histogram.\n",
    "    diff = torch.angle(za * zb.conj())\n",
    "\n",
    "    # 4. Energy Weighting (Magnitude * Magnitude)\n",
    "    # Stronger concepts contribute more to the \"Phase Compass\"\n",
    "    weights = torch.abs(za) * torch.abs(zb)\n",
    "\n",
    "    # 5. Convert to Numpy\n",
    "    diff_np = diff.numpy()\n",
    "    weights_np = weights.numpy()\n",
    "\n",
    "    # 6. Calculate Circular Mean (R)\n",
    "    # R ranges from 0 (Random/Cancel) to 1 (Perfect Alignment)\n",
    "    weighted_complex_diffs = weights_np * np.exp(1j * diff_np)\n",
    "    # Epsilon avoids 0/0 -> nan when both embeddings are all-zero.\n",
    "    mean_vector = np.sum(weighted_complex_diffs) / (np.sum(weights_np) + 1e-9)\n",
    "\n",
    "    return np.abs(mean_vector), np.angle(mean_vector), diff_np, weights_np\n",
    "\n",
    "# ==========================================\n",
    "# 4. EXECUTE ANALYSIS\n",
    "# ==========================================\n",
    "valid_pairs = []\n",
    "results = []\n",
    "skipped = []  # pairs where at least one word splits into several tokens\n",
    "\n",
    "print(f\"\\n{'Pair':<25} | {'Type':<10} | {'Status':<15} | {'R (Coherence)'}\")\n",
    "print(\"-\" * 75)\n",
    "\n",
    "for w1, w2, ptype in candidates_raw:\n",
    "    s1, id1 = is_single_token(w1)\n",
    "    s2, id2 = is_single_token(w2)\n",
    "    pair = f\"{w1}-{w2}\"\n",
    "\n",
    "    if not (s1 and s2):\n",
    "        # A multi-token word has no single embedding row to compare.\n",
    "        skipped.append(pair)\n",
    "        continue\n",
    "\n",
    "    R, angle, diffs, weights = calculate_coherence(id1, id2)\n",
    "    valid_pairs.append({\n",
    "        \"w1\": w1, \"w2\": w2, \"type\": ptype,\n",
    "        \"R\": R, \"angle\": angle, \"diffs\": diffs, \"weights\": weights\n",
    "    })\n",
    "    results.append({\"Pair\": pair, \"Type\": ptype, \"R\": R})\n",
    "    # Pad the whole pair label (not just w2) so rows line up with the header.\n",
    "    print(f\"{pair:<25} | {ptype:<10} | {'✅ Valid':<15} | {R:.4f}\")\n",
    "\n",
    "print(f\"\\nSkipped {len(skipped)} of {len(candidates_raw)} pairs (multi-token).\")\n",
    "\n",
    "# ==========================================\n",
    "# 5. STATISTICS\n",
    "# ==========================================\n",
    "# Explicit columns keep groupby working even if no pair was valid.\n",
    "df = pd.DataFrame(results, columns=[\"Pair\", \"Type\", \"R\"])\n",
    "print(\"\\n📊 AGGREGATE STATS (Mean Resultant Length R):\")\n",
    "print(df.groupby(\"Type\")[\"R\"].describe())\n",
    "\n",
    "# ==========================================\n",
    "# 6. 
VISUALIZATION (GRID 3x3)\n",
    "# ==========================================\n",
    "# Select Top 3 from each category to show clearest examples\n",
    "synonyms = sorted([p for p in valid_pairs if p[\"type\"] == \"Synonym\"], key=lambda x: x[\"R\"], reverse=True)[:3]\n",
    "antonyms = sorted([p for p in valid_pairs if p[\"type\"] == \"Antonym\"], key=lambda x: x[\"R\"], reverse=True)[:3]\n",
    "randoms = sorted([p for p in valid_pairs if p[\"type\"] == \"Random\"], key=lambda x: x[\"R\"], reverse=False)[:3]  # Lowest R for randoms\n",
    "\n",
    "plot_list = synonyms + antonyms + randoms\n",
    "\n",
    "if len(plot_list) > 0:\n",
    "    fig = plt.figure(figsize=(15, 12))\n",
    "    fig.suptitle(\"Semantic Phase Compass: Synonyms vs Antonyms vs Randoms\", fontsize=16, y=0.98)\n",
    "\n",
    "    # Colors for categories\n",
    "    colors = {\"Synonym\": \"red\", \"Antonym\": \"purple\", \"Random\": \"blue\"}\n",
    "\n",
    "    # One grid row per category with fixed slots, so a category with fewer than\n",
    "    # 3 valid pairs does not push the next category into its row.\n",
    "    for row, group in enumerate([synonyms, antonyms, randoms]):\n",
    "        for col, item in enumerate(group):\n",
    "            ax = fig.add_subplot(3, 3, row * 3 + col + 1, projection='polar')\n",
    "\n",
    "            ptype = item[\"type\"]\n",
    "            color = colors[ptype]\n",
    "\n",
    "            # Weighted Histogram of Phase Differences, wrapped to one turn (-pi, pi]\n",
    "            wrapped = np.angle(np.exp(1j * np.asarray(item[\"diffs\"])))\n",
    "            counts, _, _ = ax.hist(wrapped, bins=30, range=(-np.pi, np.pi), weights=item[\"weights\"],\n",
    "                                   color=color, alpha=0.7, density=True)\n",
    "\n",
    "            # Mean Vector Arrow (The \"Compass Needle\")\n",
    "            # Length of arrow = R (Coherence Strength)\n",
    "            ax.annotate(\"\", xy=(item[\"angle\"], item[\"R\"]), xytext=(0,0),\n",
    "                        arrowprops=dict(facecolor='black', width=2, headwidth=10))\n",
    "            # Keep the needle tip inside the radial range: a data-coordinate\n",
    "            # annotation whose tip falls outside the axes is not drawn.\n",
    "            ax.set_ylim(0, max(float(np.nan_to_num(counts).max()), float(item[\"R\"]), 1e-3) * 1.05)\n",
    "\n",
    "            # Styling\n",
    "            ax.set_title(f\"{item['w1']} - {item['w2']}\\n{ptype}\\nR = {item['R']:.3f}\", fontsize=11)\n",
    "            ax.set_yticklabels([])  # Hide radial labels\n",
    "            ax.set_xticklabels([])  # Hide angular labels\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"phase_compass_extended.png\", dpi=300)\n",
    "    plt.show()\n",
    "    print(\"\\n📸 Saved plot to 'phase_compass_extended.png'\")\n",
    "else:\n",
    "    print(\"⚠️ Not enough valid pairs to generate plot.\")"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# Reuses candidates_raw, is_single_token and calculate_coherence from the previous\n",
    "# cell -- run that cell first.\n",
    "\n",
    "# ==========================================\n",
    "# 4. EXECUTE ANALYSIS & SELECT BEST EXAMPLES\n",
    "# ==========================================\n",
    "valid_pairs = []\n",
    "results = []\n",
    "\n",
    "print(f\"🚀 Running Phase Compass on {len(candidates_raw)} pairs...\")\n",
    "\n",
    "for w1, w2, ptype in candidates_raw:\n",
    "    s1, id1 = is_single_token(w1)\n",
    "    s2, id2 = is_single_token(w2)\n",
    "\n",
    "    if s1 and s2:\n",
    "        R, angle, diffs, weights = calculate_coherence(id1, id2)\n",
    "        valid_pairs.append({\n",
    "            \"w1\": w1, \"w2\": w2, \"type\": ptype,\n",
    "            \"R\": R, \"angle\": angle, \"diffs\": diffs, \"weights\": weights\n",
    "        })\n",
    "        results.append({\"Pair\": f\"{w1}-{w2}\", \"Type\": ptype, \"R\": R})\n",
    "\n",
    "# ==========================================\n",
    "# 5. VISUALIZATION (1x3 STRIP)\n",
    "# ==========================================\n",
    "def _extreme_pair(ptype, pick):\n",
    "    \"\"\"Return the `ptype` pair that `pick` (max or min) selects by R, or None if the category is empty.\"\"\"\n",
    "    pool = [p for p in valid_pairs if p[\"type\"] == ptype]\n",
    "    return pick(pool, key=lambda x: x[\"R\"]) if pool else None\n",
    "\n",
    "# Select the \"Best\" example for each category (Highest R for Syn/Ant, Lowest for Random)\n",
    "best_syn = _extreme_pair(\"Synonym\", max)\n",
    "best_ant = _extreme_pair(\"Antonym\", max)\n",
    "best_rnd = _extreme_pair(\"Random\", min)\n",
    "\n",
    "panels = [\n",
    "    (best_syn, \"A. Synonyms (High Coherence)\", \"#d62728\"),  # Red\n",
    "    (best_ant, \"B. Antonyms (High Coherence)\", \"#9467bd\"),  # Purple\n",
    "    (best_rnd, \"C. Unrelated (Random Phase)\", \"#7f7f7f\"),   # Gray\n",
    "]\n",
    "missing_panels = [title for item, title, _ in panels if item is None]\n",
    "if missing_panels:\n",
    "    print(f\"⚠️ No valid single-token pair for: {missing_panels}\")\n",
    "panels = [panel for panel in panels if panel[0] is not None]\n",
    "\n",
    "plot_list = [item for item, _, _ in panels]\n",
    "titles = [title for _, title, _ in panels]\n",
    "colors = [color for _, _, color in panels]\n",
    "\n",
    "if plot_list:\n",
    "    fig = plt.figure(figsize=(12, 4))  # Wide, Short aspect ratio\n",
    "    axes = []\n",
    "    r_top = 1e-3  # shared radial limit, grown to fit every histogram and needle\n",
    "\n",
    "    for i, item in enumerate(plot_list):\n",
    "        ax = fig.add_subplot(1, len(plot_list), i+1, projection='polar')\n",
    "        axes.append(ax)\n",
    "\n",
    "        # 1. Circular Histogram (The \"Cloud\")\n",
    "        # We use 'weights' to show that high-energy frequencies matter more.\n",
    "        # Phases are wrapped to (-pi, pi] so each direction is binned once.\n",
    "        wrapped = np.angle(np.exp(1j * np.asarray(item[\"diffs\"])))\n",
    "        counts, _, _ = ax.hist(wrapped, bins=40, range=(-np.pi, np.pi), weights=item[\"weights\"],\n",
    "                               color=colors[i], alpha=0.6, density=True)\n",
    "\n",
    "        # 2. Mean Resultant Vector (The \"Needle\")\n",
    "        # The length of this arrow is R, the phase-locking strength.\n",
    "        ax.annotate(\"\", xy=(item[\"angle\"], item[\"R\"]), xytext=(0,0),\n",
    "                    arrowprops=dict(facecolor='black', width=1.5, headwidth=8, alpha=0.9))\n",
    "\n",
    "        # 3. Styling\n",
    "        ax.set_title(f\"{titles[i]}\\n'{item['w1']}' - '{item['w2']}'\\n$R = {item['R']:.2f}$\",\n",
    "                     fontsize=10, fontweight='bold', pad=10)\n",
    "        ax.set_yticklabels([])  # Hide radial numbers\n",
    "        ax.set_xticklabels([])  # Hide degree numbers\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        r_top = max(r_top, float(np.nan_to_num(counts).max()), float(item[\"R\"]))\n",
    "\n",
    "    # Same radial scale on every panel for a fair comparison, but large enough that\n",
    "    # no histogram peak or needle is cut off (R can reach 1.0; a fixed 0.6 limit\n",
    "    # would hide any needle with R > 0.6).\n",
    "    for ax in axes:\n",
    "        ax.set_ylim(0, r_top * 1.05)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"fig_compass_1x3.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"⚠️ No valid pairs to plot.\")\n",
    "\n",
    "# ==========================================\n",
    "# 6. STATISTICAL TABLE OUTPUT\n",
    "# ==========================================\n",
    "# Explicit columns keep groupby working even if no pair was valid.\n",
    "df = pd.DataFrame(results, columns=[\"Pair\", \"Type\", \"R\"])\n",
    "print(\"\\n📊 PHASE LOCKING STATISTICS (Mean Resultant Length R)\")\n",
    "print(\"=\"*60)\n",
    "print(f\"{'Category':<15} | {'Mean R':<10} | {'Std Dev':<10} | {'Count'}\")\n",
    "print(\"-\" * 60)\n",
    "stats = df.groupby(\"Type\")[\"R\"].agg(['mean', 'std', 'count'])\n",
    "for idx, row in stats.iterrows():\n",
    "    # Width-10 numeric fields line up with the header; std is NaN for single-pair groups.\n",
    "    print(f\"{idx:<15} | {row['mean']:<10.4f} | {row['std']:<10.4f} | {int(row['count'])}\")"
   ],
   "metadata": {
    "id": "GzwkznYXTJpL"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# ==========================================\n",
    "# 0. 
SETUP & DEPENDENCIES\n",
    "# ==========================================\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "import gc\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "from transformers import RobertaTokenizerFast\n",
    "from huggingface_hub import hf_hub_download\n",
    "from x_transformers import TransformerWrapper, Encoder\n",
    "\n",
    "# Global Config\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "SEQ_LEN = 4096  # sizes the positional tables and filter lengths of the models below\n",
    "MAX_VOCAB_SIZE = 32768  # vocab size every model is built with; token ids must be below it\n",
    "TOKENIZER_ID = \"prism-lab/wikitext-103-prism-32k-seq4k\" # <--- YOUR REPO\n",
    "\n",
    "print(f\"🔥 Initializing Phase Compass Analysis on {DEVICE}\")\n",
    "\n",
    "# ==========================================\n",
    "# 1. ARCHITECTURE DEFINITIONS\n",
    "# ==========================================\n",
    "# (Standard Definitions - Collapsed for brevity)\n",
    "# Local copies of the checkpointed architectures. Attribute names must match the\n",
    "# checkpoint's state_dict keys, because weights are loaded with strict=False.\n",
    "class ComplexDropout(nn.Module):\n",
    "    # Identity placeholder: `p` is stored but forward() returns z unchanged.\n",
    "    def __init__(self, p=0.0): super().__init__(); self.p = p\n",
    "    def forward(self, z): return z\n",
    "class RobustPhaseNorm(nn.Module):\n",
    "    # Divides x by the RMS of |x| over the last dim (eps for stability), then applies\n",
    "    # a learnable real per-feature scale.\n",
    "    def __init__(self, d, eps=1e-5): super().__init__(); self.scale = nn.Parameter(torch.ones(d)); self.eps = eps\n",
    "    def forward(self, x): return (x / torch.sqrt((x.abs()**2).mean(-1, keepdim=True) + self.eps)) * self.scale\n",
    "class ModReLU(nn.Module):\n",
    "    # Complex activation: magnitude becomes relu(|z| + b); the unit phase z/|z| is kept.\n",
    "    def __init__(self, f): super().__init__(); self.b = nn.Parameter(torch.zeros(f))\n",
    "    def forward(self, z): return F.relu(z.abs() + self.b) * (z / (z.abs() + 1e-6))\n",
    "class ComplexToRealBridge(nn.Module):\n",
    "    # Concatenates real and imaginary parts (2*d features), projects to d, then LayerNorm.\n",
    "    def __init__(self, d): super().__init__(); self.proj = nn.Linear(d*2, d); self.norm = nn.LayerNorm(d)\n",
    "    def forward(self, x): return self.norm(self.proj(torch.cat([x.real, x.imag], -1)))\n",
    "class DynamicRoSE(nn.Module):\n",
    "    # Embedding that returns (complex_embedding, raw_real_embedding). The complex part is\n",
    "    # adapter(raw) split into real/imag halves, multiplied by a static positional phasor\n",
    "    # exp(i * pos * freqs) and by an (approximately) unit phasor from rotation_predictor.\n",
    "    def __init__(self, n, d):\n",
    "        super().__init__(); self.raw_embedding = nn.Embedding(n, d); self.adapter = nn.Linear(d, d*2); self.rotation_predictor = nn.Linear(d, d*2)\n",
    "        self.register_buffer('freqs', torch.exp(torch.arange(0, d) * -(math.log(10000.0)/d)))\n",
    "    def forward(self, x):\n",
    "        # assumes x is (batch, seq) token ids: positions are read from dim 1 and the\n",
    "        # (seq, d) positional phasor is broadcast over the batch -- TODO confirm.\n",
    "        real = self.raw_embedding(x); params = self.adapter(real); D = real.shape[-1]\n",
    "        z = torch.complex(params[...,:D], params[...,D:]); r = self.rotation_predictor(real); rx, ry = r.chunk(2, -1)\n",
    "        drot = torch.complex(rx/torch.sqrt(rx**2+ry**2+1e-6), ry/torch.sqrt(rx**2+ry**2+1e-6))\n",
    "        pos = torch.arange(real.shape[1], device=x.device).float()\n",
    "        srot = torch.polar(torch.ones_like(torch.outer(pos, self.freqs)), torch.outer(pos, self.freqs))\n",
    "        return (z * srot.unsqueeze(0) * drot), real\n",
    "class HyenaNeuralFilter(nn.Module):\n",
    "    # Implicit filter: an MLP maps sin/cos features of L evenly spaced points in [0, 1]\n",
    "    # to a complex (L, d) filter. `max_len` is accepted but not used.\n",
    "    def __init__(self, d, max_len=1024, h=64):\n",
    "        super().__init__(); self.d = d; self.register_buffer(\"freqs\", torch.exp(torch.arange(0, h, 2) * -(math.log(10000.0)/h)))\n",
    "        self.mlp = nn.Sequential(nn.Linear(h, h), nn.SiLU(), nn.Linear(h, h), nn.SiLU(), nn.Linear(h, d*2))\n",
    "    def forward(self, L, dev):\n",
    "        t = torch.linspace(0, 1, steps=L, device=dev).unsqueeze(-1)\n",
    "        emb = torch.cat([torch.sin(t*self.freqs), torch.cos(t*self.freqs)], -1)\n",
    "        out = self.mlp(emb).view(L, self.d, 2); return torch.complex(out[...,0], out[...,1])\n",
    "class GatedHarmonicConvolution(nn.Module):\n",
    "    # Pre-norm residual block: sequence-dim FFT, multiply by the learned filter, inverse\n",
    "    # FFT; then a sigmoid gate (computed from the normalised input) on the real/imag\n",
    "    # parts, a complex linear mix, ModReLU and a complex output projection.\n",
    "    def __init__(self, d, max_len):\n",
    "        super().__init__(); self.d=d; self.filter_len=max_len; self.neural_filter = HyenaNeuralFilter(d, max_len)\n",
    "        self.gate_proj = nn.Linear(d*2, d*2); self.mix_real = nn.Linear(d,d); self.mix_imag = nn.Linear(d,d)\n",
    "        self.out_real = nn.Linear(d,d); self.out_imag = nn.Linear(d,d); self.activation = ModReLU(d); self.norm = RobustPhaseNorm(d)\n",
    "        self.dropout = ComplexDropout(0.0)\n",
    "    def forward(self, x, mask=None):\n",
    "        # `mask` is accepted but not used.\n",
    "        res = x; x = self.norm(x); B,L,D = x.shape; eff_L = min(L, self.filter_len)\n",
    "        h = self.neural_filter(eff_L, x.device).unsqueeze(0)\n",
    "        # Multiplying length-eff_L FFTs is a circular convolution over eff_L positions.\n",
    "        # NOTE(review): when L > filter_len, fft(n=eff_L) keeps only the first eff_L\n",
    "        # positions and the remainder of xt is zero-padded below.\n",
    "        xt = torch.fft.ifft(torch.fft.fft(x, n=eff_L, dim=1, norm='ortho') * h, n=eff_L, dim=1, norm='ortho')\n",
    "        if L > eff_L: xt = F.pad(xt, (0,0,0,L-eff_L));\n",
    "        else: xt = xt[:, :L, :]\n",
    "        g = torch.sigmoid(self.gate_proj(torch.cat([x.real, x.imag], -1))); gr, gi = g.chunk(2, -1)\n",
    "        xg = torch.complex(xt.real*gr, xt.imag*gi); mr, mi = self.mix_real, self.mix_imag\n",
    "        # Complex-valued layer built from two real Linears: re = mr(re) - mi(im), im = mr(im) + mi(re).\n",
    "        xm = torch.complex(mr(xg.real)-mi(xg.imag), mr(xg.imag)+mi(xg.real)); xa = self.activation(xm); or_, oi = self.out_real, self.out_imag\n",
    "        out = torch.complex(or_(xa.real)-oi(xa.imag), or_(xa.imag)+oi(xa.real))\n",
    "        return self.dropout(out) + res\n",
    "class PRISMEncoder(nn.Module):\n",
    "    # Stack of `l` GatedHarmonicConvolution blocks followed by a final RobustPhaseNorm.\n",
    "    def __init__(self, l, d, max_l): super().__init__(); self.layers = nn.ModuleList([GatedHarmonicConvolution(d, max_l) for _ in range(l)]); self.final_norm = RobustPhaseNorm(d)\n",
    "    def forward(self, x):\n",
    "        for layer in self.layers: x = layer(x)\n",
    "        return self.final_norm(x)\n",
    "\n",
    "# --- A. BASELINE (Transformer) ---\n",
    "# NOTE(review): LocalBaseline is not listed in MODELS_TO_TEST below.\n",
    "class LocalBaseline(nn.Module):\n",
    "    def __init__(self, vocab_size):\n",
    "        super().__init__()\n",
    "        self.model = TransformerWrapper(\n",
    "            num_tokens=vocab_size, max_seq_len=SEQ_LEN, use_abs_pos_emb=False, tie_embedding=True,\n",
    "            attn_layers=Encoder(dim=512, depth=5, heads=8, rotary_pos_emb=True, attn_flash=True, use_scalenorm=False)\n",
    "        )\n",
    "    def forward(self, x): return self.model(x)\n",
    "\n",
    "# --- B. FNET (Hybrid) ---\n",
    "class FNetBlock(nn.Module):\n",
    "    # FNet block: token mixing via the real part of a 2-D FFT over the last two dims\n",
    "    # (computed in float32, cast back), then a feed-forward sublayer; both pre-norm + residual.\n",
    "    def __init__(self, d, df):\n",
    "        super().__init__(); self.norm_mix = nn.LayerNorm(d); self.norm_ff = nn.LayerNorm(d)\n",
    "        self.ff = nn.Sequential(nn.Linear(d, df), nn.GELU(), nn.Dropout(0), nn.Linear(df, d), nn.Dropout(0))\n",
    "    def forward(self, x):\n",
    "        r = x; x = self.norm_mix(x); x = torch.fft.fftn(x.float(), dim=(-2,-1), norm='ortho').real.to(r.dtype); x = x+r\n",
    "        r = x; x = self.norm_ff(x); x = self.ff(x); return x+r\n",
    "class FNetEncoder(nn.Module):\n",
    "    # `depth` FNetBlocks followed by a LayerNorm.\n",
    "    def __init__(self, depth, d, df): super().__init__(); self.layers = nn.ModuleList([FNetBlock(d, df) for _ in range(depth)]); self.norm_out = nn.LayerNorm(d)\n",
    "    def forward(self, x):\n",
    "        for l in self.layers: x = l(x)\n",
    "        return self.norm_out(x)\n",
    "class HybridFNetMLM(nn.Module):\n",
    "    # Token + learned absolute position embeddings -> 6 FNet blocks -> a depth-1\n",
    "    # x-transformers Encoder -> LayerNorm -> logits; to_logits shares token_emb's weight.\n",
    "    def __init__(self, vocab_size):\n",
    "        super().__init__()\n",
    "        self.token_emb = nn.Embedding(vocab_size, 512); self.pos_emb = nn.Parameter(torch.zeros(1, SEQ_LEN, 512))\n",
    "        self.fnet_encoder = FNetEncoder(6, 512, 2048)\n",
    "        self.transformer_cap = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
    "        self.final_norm = nn.LayerNorm(512); self.to_logits = nn.Linear(512, vocab_size)\n",
    "        self.to_logits.weight = self.token_emb.weight # Tie\n",
    "    def forward(self, x):\n",
    "        h = self.token_emb(x) + self.pos_emb[:, :x.shape[1], :]\n",
    "        return self.to_logits(self.final_norm(self.transformer_cap(self.fnet_encoder(h))))\n",
    "\n",
    "# --- C. PRISM (Phase Coder) ---\n",
    "class LocalPRISM(nn.Module):\n",
    "    # DynamicRoSE -> 5-block PRISM encoder on the complex stream -> bridge to real; the\n",
    "    # result is concatenated with the raw embedding, projected to 512, refined by a\n",
    "    # depth-1 Encoder, and decoded by an LM head tied to rose.raw_embedding.\n",
    "    def __init__(self, vocab_size):\n",
    "        super().__init__()\n",
    "        self.rose = DynamicRoSE(vocab_size, 512); self.prism_encoder = PRISMEncoder(5, 512, SEQ_LEN)\n",
    "        self.bridge = ComplexToRealBridge(512); self.periscope_proj = nn.Sequential(nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU())\n",
    "        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
    "        self.lm_head = nn.Linear(512, vocab_size); self.lm_head.weight = self.rose.raw_embedding.weight # Tie\n",
    "    def forward(self, x):\n",
    "        w, p = self.rose(x); w = self.bridge(self.prism_encoder(w))\n",
    "        return self.lm_head(self.refiner(self.periscope_proj(torch.cat([w, p], -1))))\n",
    "\n",
    "# --- D. PILLARS (Split-Stream) ---\n",
    "class LocalPillars(nn.Module):\n",
    "    # Two 256-wide streams: the raw embedding (+ learned positions) through 9 FNet blocks,\n",
    "    # and the complex embedding (projected to 256 complex features) through 9 PRISM blocks.\n",
    "    # They are concatenated, fused to 512, refined by a depth-1 Encoder, and decoded with\n",
    "    # the raw embedding matrix (weight tying) plus a learned bias.\n",
    "    def __init__(self, vocab_size):\n",
    "        super().__init__()\n",
    "        self.rose = DynamicRoSE(vocab_size, 512); self.particle_down = nn.Linear(512, 256); self.wave_down = nn.Linear(1024, 512)\n",
    "        self.fnet_pos = nn.Embedding(SEQ_LEN, 256); self.stream_rate = FNetEncoder(9, 256, 1024)\n",
    "        self.stream_phase = PRISMEncoder(9, 256, SEQ_LEN); self.phase_bridge = ComplexToRealBridge(256)\n",
    "        self.fusion_proj = nn.Linear(512, 512); self.fusion_norm = nn.LayerNorm(512)\n",
    "        self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
    "        self.head_bias = nn.Parameter(torch.zeros(vocab_size))\n",
    "    def forward(self, x):\n",
    "        w, p = self.rose(x); p_sm = self.particle_down(p); w_raw = self.wave_down(torch.cat([w.real, w.imag], -1))\n",
    "        w_sm = torch.complex(w_raw[...,:256], w_raw[...,256:])\n",
    "        p_path = self.stream_rate(p_sm + self.fnet_pos(torch.arange(x.shape[1], device=x.device)))\n",
    "        w_path = self.phase_bridge(self.stream_phase(w_sm))\n",
    "        ctx = self.fusion_norm(self.fusion_proj(torch.cat([p_path, w_path], -1)))\n",
    "        return F.linear(self.refiner(ctx), self.rose.raw_embedding.weight, self.head_bias)\n",
    "\n",
    "# 
==========================================\n",
    "# 2. ANALYSIS LOGIC\n",
    "# ==========================================\n",
    "def smart_load(repo_id, name, cls):\n",
    "    \"\"\"Build `cls`, fetch its checkpoint from the Hub repo `repo_id`, and load the weights.\n",
    "\n",
    "    Tries \"best.pt\" first and falls back to \"pytorch_model.bin\". A checkpoint that\n",
    "    wraps its weights under a 'model' key is unwrapped and \"module.\" prefixes are\n",
    "    removed; `name` (\"Baseline\" / \"FNet\") selects extra key remapping.\n",
    "    Returns the model on DEVICE. Missing/unexpected keys are reported, not raised.\n",
    "    \"\"\"\n",
    "    # Init Model\n",
    "    model = cls(vocab_size=MAX_VOCAB_SIZE).to(DEVICE)\n",
    "    print(f\"⬇️ Downloading weights for {name}...\")\n",
    "    try:\n",
    "        path = hf_hub_download(repo_id, \"best.pt\")\n",
    "    except Exception:  # not bare: Ctrl-C / SystemExit must still propagate\n",
    "        path = hf_hub_download(repo_id, \"pytorch_model.bin\")\n",
    "\n",
    "    # NOTE(review): torch.load unpickles the file -- only point this at trusted repos.\n",
    "    state_dict = torch.load(path, map_location=\"cpu\")\n",
    "    if 'model' in state_dict: state_dict = state_dict['model']\n",
    "    clean = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n",
    "\n",
    "    # FIXES for Baseline/FNet\n",
    "    if name == \"Baseline\":\n",
    "        new_d = {}\n",
    "        for k, v in clean.items():\n",
    "            nk = k if k.startswith(\"model.\") else \"model.\" + k\n",
    "            # x-transformers keeps the token table at \"token_emb.emb.weight\"; match the\n",
    "            # flat name by suffix (an `\"emb\" not in nk` test can never be true here).\n",
    "            if nk.endswith(\"token_emb.weight\"):\n",
    "                nk = nk[:-len(\"token_emb.weight\")] + \"token_emb.emb.weight\"\n",
    "            new_d[nk] = v\n",
    "        clean = new_d\n",
    "    elif name == \"FNet\":\n",
    "        # Drop a leading \"model.\" wrapper prefix only; inner key segments are left alone.\n",
    "        clean = {(k[len(\"model.\"):] if k.startswith(\"model.\") else k): v for k, v in clean.items()}\n",
    "\n",
    "    # strict=False tolerates naming mismatches, so surface them: results from a\n",
    "    # partially loaded (largely random) model are not meaningful.\n",
    "    load_report = model.load_state_dict(clean, strict=False)\n",
    "    if load_report.missing_keys or load_report.unexpected_keys:\n",
    "        print(f\"⚠️ {name}: {len(load_report.missing_keys)} missing / \"\n",
    "              f\"{len(load_report.unexpected_keys)} unexpected state_dict keys\")\n",
    "    print(f\"✅ {name} Ready.\")\n",
    "    return model\n",
    "\n",
    "def extract_phasor(model, name, token_id):\n",
    "    \"\"\"Return the context-free complex embedding (\"phasor\") of one token as a 1-D CPU tensor.\n",
    "\n",
    "    PRISM / PILLARS: adapter(raw_embedding) split into real/imag halves, i.e. the\n",
    "    DynamicRoSE embedding before its positional and dynamic rotations.\n",
    "    FNet: the real token embedding with a zero imaginary part, so every phase is 0 or pi\n",
    "    (a sign-agreement control rather than a learned phase).\n",
    "    Returns None for any other model name.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        token_tensor = torch.tensor([token_id], device=DEVICE)\n",
    "        if name in [\"PRISM\", \"PILLARS\"]:\n",
    "            real_emb = model.rose.raw_embedding(token_tensor)\n",
    "            params = model.rose.adapter(real_emb)\n",
    "            D = real_emb.shape[-1]\n",
    "            z = torch.complex(params[...,:D], params[...,D:])\n",
    "            return z.squeeze(0).cpu()\n",
    "        elif name == \"FNet\":\n",
    "            x = model.token_emb(token_tensor)\n",
    "            return torch.complex(x, torch.zeros_like(x)).squeeze(0).cpu()\n",
    "    return None\n",
    "\n",
    "def calculate_coherence_dynamic(model, name, id_a, id_b):\n",
    "    \"\"\"Energy-weighted mean resultant length R of the phase differences between two tokens.\n",
    "\n",
    "    Returns (R in [0, 1], mean angle in radians, per-dimension phase differences\n",
    "    wrapped to (-pi, pi], per-dimension weights |za| * |zb|).\n",
    "    Raises ValueError if `name` has no phasor extractor.\n",
    "    \"\"\"\n",
    "    za = extract_phasor(model, name, id_a)\n",
    "    zb = extract_phasor(model, name, id_b)\n",
    "    if za is None or zb is None:\n",
    "        raise ValueError(f\"No phasor extractor for model '{name}'\")\n",
    "    # angle(za * conj(zb)) == angle(za) - angle(zb) wrapped into (-pi, pi].\n",
    "    diff = torch.angle(za * zb.conj())\n",
    "    weights = torch.abs(za) * torch.abs(zb)\n",
    "\n",
    "    diff_np = diff.numpy()\n",
    "    weights_np = weights.numpy()\n",
    "    weighted_complex_diffs = weights_np * np.exp(1j * diff_np)\n",
    "    # Epsilon avoids 0/0 -> nan when both phasors are all-zero.\n",
    "    mean_vector = np.sum(weighted_complex_diffs) / (np.sum(weights_np) + 1e-9)\n",
    "    return np.abs(mean_vector), np.angle(mean_vector), diff_np, weights_np\n",
    "\n",
    "# ==========================================\n",
    "# 3. ROBUST CANDIDATE LIST (N = 138: 48 synonyms, 45 antonyms, 45 random)\n",
    "# ==========================================\n",
    "candidates_raw = [\n",
    "    # --- SYNONYMS (Positive Correlation) ---\n",
    "    (\"fast\", \"quick\", \"Synonym\"), (\"big\", \"large\", \"Synonym\"), (\"small\", \"little\", \"Synonym\"),\n",
    "    (\"start\", \"begin\", \"Synonym\"), (\"end\", \"finish\", \"Synonym\"), (\"smart\", \"clever\", \"Synonym\"),\n",
    "    (\"hard\", \"tough\", \"Synonym\"), (\"simple\", \"easy\", \"Synonym\"), (\"happy\", \"glad\", \"Synonym\"),\n",
    "    (\"sad\", \"unhappy\", \"Synonym\"), (\"angry\", \"mad\", \"Synonym\"), (\"correct\", \"right\", \"Synonym\"),\n",
    "    (\"wrong\", \"incorrect\", \"Synonym\"), (\"shut\", \"close\", \"Synonym\"), (\"buy\", \"purchase\", \"Synonym\"),\n",
    "    (\"choose\", \"select\", \"Synonym\"), (\"gift\", \"present\", \"Synonym\"), (\"job\", \"work\", \"Synonym\"),\n",
    "    (\"trip\", \"journey\", \"Synonym\"), (\"lady\", \"woman\", \"Synonym\"), (\"guy\", \"man\", \"Synonym\"),\n",
    "    (\"street\", \"road\", \"Synonym\"), (\"stone\", \"rock\", \"Synonym\"), (\"speak\", \"talk\", \"Synonym\"),\n",
    "    (\"listen\", \"hear\", \"Synonym\"), (\"look\", \"see\", \"Synonym\"), (\"run\", \"sprint\", \"Synonym\"),\n",
    "    (\"jump\", \"leap\", \"Synonym\"), (\"scary\", \"afraid\", \"Synonym\"), (\"rich\", \"wealthy\", \"Synonym\"),\n",
    "    (\"weird\", \"strange\", \"Synonym\"), (\"quiet\", \"silent\", \"Synonym\"), (\"loud\", \"noisy\", \"Synonym\"),\n",
    "    (\"trash\", \"garbage\", \"Synonym\"), (\"sick\", \"ill\", \"Synonym\"), (\"thin\", \"slim\", \"Synonym\"),\n",
    "    (\"near\", \"close\", \"Synonym\"), (\"far\", \"distant\", \"Synonym\"), (\"safe\", \"secure\", \"Synonym\"),\n",
    "    (\"fix\", \"repair\", \"Synonym\"), (\"mix\", \"blend\", \"Synonym\"), (\"keep\", \"hold\", \"Synonym\"),\n",
    "    (\"push\", \"shove\", \"Synonym\"), (\"pull\", \"drag\", \"Synonym\"), (\"under\", \"below\", \"Synonym\"),\n",
    "    (\"above\", \"over\", \"Synonym\"), (\"center\", \"middle\", \"Synonym\"), (\"area\", \"zone\", \"Synonym\"),\n",
    "\n",
    "    # --- ANTONYMS (Negative Correlation / Phase Shift) ---\n",
    "    # NOTE(review): pairs such as boy/girl, man/woman, king/queen and sun/moon are\n",
    "    # contrasting co-hyponyms rather than strict antonyms.\n",
    "    (\"good\", \"bad\", \"Antonym\"), (\"hot\", \"cold\", \"Antonym\"), (\"high\", \"low\", \"Antonym\"),\n",
    "    (\"up\", \"down\", \"Antonym\"), (\"left\", \"right\", \"Antonym\"), (\"in\", \"out\", \"Antonym\"),\n",
    "    (\"black\", \"white\", \"Antonym\"), (\"day\", \"night\", \"Antonym\"), (\"sun\", \"moon\", \"Antonym\"),\n",
    "    (\"boy\", \"girl\", \"Antonym\"), (\"man\", \"woman\", \"Antonym\"), (\"king\", \"queen\", \"Antonym\"),\n",
    "    (\"life\", \"death\", \"Antonym\"), (\"war\", \"peace\", \"Antonym\"), (\"win\", \"lose\", \"Antonym\"),\n",
    "    (\"rich\", \"poor\", \"Antonym\"), (\"strong\", \"weak\", \"Antonym\"), (\"hard\", \"soft\", \"Antonym\"),\n",
    "    (\"loud\", \"quiet\", \"Antonym\"), (\"wet\", \"dry\", \"Antonym\"), (\"clean\", \"dirty\", \"Antonym\"),\n",
    "    (\"happy\", \"sad\", \"Antonym\"), (\"full\", \"empty\", \"Antonym\"), (\"open\", \"close\", \"Antonym\"),\n",
    "    (\"first\", \"last\", \"Antonym\"), (\"young\", \"old\", \"Antonym\"), (\"new\", \"old\", \"Antonym\"),\n",
    "    (\"fast\", \"slow\", \"Antonym\"), (\"tall\", \"short\", \"Antonym\"), (\"heavy\", \"light\", \"Antonym\"),\n",
    "    (\"dark\", \"light\", \"Antonym\"), (\"true\", \"false\", \"Antonym\"), (\"yes\", \"no\", \"Antonym\"),\n",
    "    (\"on\", \"off\", \"Antonym\"), (\"top\", \"bottom\", \"Antonym\"), (\"friend\", \"enemy\", \"Antonym\"),\n",
    "    (\"give\", \"take\", \"Antonym\"), (\"come\", \"go\", \"Antonym\"), (\"rise\", \"fall\", \"Antonym\"),\n",
    "    (\"north\", \"south\", \"Antonym\"), (\"east\", \"west\", \"Antonym\"), (\"buy\", \"sell\", \"Antonym\"),\n",
    "    (\"love\", \"hate\", \"Antonym\"), (\"win\", \"fail\", \"Antonym\"), (\"start\", \"stop\", \"Antonym\"),\n",
    "\n",
    "    # --- RANDOM (Noise Floor) ---\n",
    "    (\"apple\", \"car\", \"Random\"), (\"banana\", \"sky\", \"Random\"), (\"bread\", \"cloud\", \"Random\"),\n",
    "    (\"cheese\", \"door\", \"Random\"), (\"milk\", \"shoe\", \"Random\"), (\"water\", \"book\", \"Random\"),\n",
    "    (\"coffee\", \"tree\", \"Random\"), (\"sugar\", \"phone\", \"Random\"), (\"salt\", \"idea\", \"Random\"),\n",
    "    (\"meat\", \"ghost\", \"Random\"), (\"soup\", \"math\", \"Random\"), (\"cake\", \"song\", \"Random\"),\n",
    "    (\"pie\", \"fish\", \"Random\"), (\"egg\", \"wall\", \"Random\"), (\"rice\", \"nose\", \"Random\"),\n",
    "    (\"tea\", \"frog\", \"Random\"), (\"juice\", \"star\", \"Random\"), (\"fruit\", \"chair\", \"Random\"),\n",
    "    (\"lemon\", \"fear\", \"Random\"), (\"melon\", \"bell\", \"Random\"), (\"berry\", \"law\", \"Random\"),\n",
    "    (\"grape\", \"dog\", \"Random\"), (\"plum\", \"cat\", \"Random\"), (\"pear\", \"bird\", \"Random\"),\n",
    "    (\"lime\", \"rock\", \"Random\"), (\"kiwi\", \"mud\", \"Random\"), (\"bean\", \"joy\", \"Random\"),\n",
    "    (\"corn\", \"ice\", \"Random\"), (\"nut\", \"wind\", \"Random\"), (\"fig\", \"pen\", \"Random\"),\n",
    "    (\"yam\", \"bus\", \"Random\"), (\"beef\", \"sun\", \"Random\"), (\"pork\", \"hat\", \"Random\"),\n",
    "    (\"lamb\", \"ink\", \"Random\"), (\"duck\", \"map\", \"Random\"), (\"goat\", \"art\", \"Random\"),\n",
    "    (\"cow\", \"box\", \"Random\"), (\"pig\", \"oil\", \"Random\"), (\"hen\", \"gas\", \"Random\"),\n",
    "    (\"fox\", \"cup\", \"Random\"), (\"wolf\", \"key\", \"Random\"), (\"ant\", \"bed\", \"Random\"),\n",
    "    (\"bee\", \"rug\", \"Random\"), (\"fly\", \"mud\", \"Random\"), (\"worm\", \"sky\", \"Random\")\n",
    "]\n",
    "\n",
    "# ==========================================\n",
    "# 4. 
EXECUTION WITH YOUR CUSTOM TOKENIZER\n",
"# ==========================================\n",
"import gc  # explicit cleanup between model loads below\n",
"\n",
"print(f\"🔑 Loading Tokenizer from {TOKENIZER_ID}...\")\n",
"try:\n",
"    tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_ID)\n",
"except Exception as e:\n",
"    # Best-effort fallback, but report WHY it happened: roberta-base token IDs\n",
"    # are not guaranteed to match the vocabulary the models were trained with.\n",
"    print(\"⚠️ Fallback to base tokenizer if custom fails (Should not happen)\")\n",
"    print(f\"   Reason: {e!r}\")\n",
"    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
"\n",
"valid_pairs = []\n",
"print(f\"🔎 Validating {len(candidates_raw)} candidate pairs...\")\n",
"\n",
"# RoBERTa-style BPE tokenizes a word differently with and without a leading\n",
"# space, so try the space-prefixed form first and fall back to the raw form.\n",
"# A pair is kept only if BOTH words map to exactly one token below MAX_VOCAB_SIZE.\n",
"for w1, w2, ptype in candidates_raw:\n",
"    ids1 = tokenizer.encode(\" \" + w1, add_special_tokens=False)\n",
"    ids2 = tokenizer.encode(\" \" + w2, add_special_tokens=False)\n",
"\n",
"    # Fallback to raw if space fails\n",
"    if len(ids1) != 1: ids1 = tokenizer.encode(w1, add_special_tokens=False)\n",
"    if len(ids2) != 1: ids2 = tokenizer.encode(w2, add_special_tokens=False)\n",
"\n",
"    if len(ids1) == 1 and len(ids2) == 1:\n",
"        id1, id2 = ids1[0], ids2[0]\n",
"        if id1 < MAX_VOCAB_SIZE and id2 < MAX_VOCAB_SIZE:\n",
"            valid_pairs.append((w1, w2, ptype, id1, id2))\n",
"\n",
"print(f\"✅ Found {len(valid_pairs)} valid single-token pairs for this tokenizer.\")\n",
"\n",
"MODELS_TO_TEST = [\n",
"    (\"PRISM\", \"prism-lab/prism-v2-wikitext\", LocalPRISM),\n",
"    (\"PILLARS\", \"prism-lab/pillars-compact-wikitext\", LocalPillars),\n",
"    (\"FNet\", \"prism-lab/hybrid-fnet-prism-custom\", HybridFNetMLM)\n",
"]\n",
"\n",
"all_results = {}\n",
"\n",
"for name, repo, cls in MODELS_TO_TEST:\n",
"    print(f\"\\n🧪 Analyzing {name}...\")\n",
"    model = None\n",
"    try:\n",
"        model = smart_load(repo, name, cls)\n",
"        model.eval()\n",
"\n",
"        results = []\n",
"        for w1, w2, ptype, id1, id2 in valid_pairs:\n",
"            R, angle, diffs, weights = calculate_coherence_dynamic(model, name, id1, id2)\n",
"            # Keep the mean phase angle: the compass arrow below points along it.\n",
"            results.append({\"Pair\": f\"{w1}-{w2}\", \"Type\": ptype, \"R\": R, \"Angle\": angle,\n",
"                            \"Diffs\": diffs, \"Weights\": weights})\n",
"\n",
"        all_results[name] = results\n",
"        df = pd.DataFrame(results)\n",
"        print(f\"📊 {name} Results (Mean R):\")\n",
"        if not df.empty:\n",
"            print(df.groupby(\"Type\")[\"R\"].mean())\n",
"    except Exception as e:\n",
"        print(f\"❌ {name} Failed: {e}\")\n",
"    finally:\n",
"        # Free the model even if analysis failed part-way, so the next load\n",
"        # does not compete with it for (GPU) memory.\n",
"        del model\n",
"        torch.cuda.empty_cache()\n",
"        gc.collect()\n",
"\n",
"# Plotting: one row per model; columns = highest-R Synonym, highest-R Antonym,\n",
"# lowest-R Random pair (the noise floor).\n",
"if len(all_results) > 0:\n",
"    fig = plt.figure(figsize=(12, 10))\n",
"    rows = list(all_results.keys())\n",
"    colors = {\"Synonym\": \"red\", \"Antonym\": \"purple\", \"Random\": \"gray\"}\n",
"    # (category, sort ascending?) for each of the three panel columns\n",
"    panels = [(\"Synonym\", False), (\"Antonym\", False), (\"Random\", True)]\n",
"\n",
"    idx = 1\n",
"    for model_name in rows:\n",
"        data = all_results[model_name]\n",
"        df = pd.DataFrame(data)\n",
"        if df.empty: continue\n",
"\n",
"        for i, (ptype, ascending) in enumerate(panels):\n",
"            subset = df[df[\"Type\"] == ptype]\n",
"            if subset.empty:\n",
"                # No pair of this category survived filtering: leave the slot\n",
"                # blank instead of raising IndexError on .iloc[0].\n",
"                idx += 1\n",
"                continue\n",
"            ex = subset.sort_values(\"R\", ascending=ascending).iloc[0]\n",
"\n",
"            ax = fig.add_subplot(len(rows), 3, idx, projection='polar')\n",
"            c = colors[ex[\"Type\"]]\n",
"            ax.hist(ex[\"Diffs\"], bins=30, weights=ex[\"Weights\"], color=c, alpha=0.6, density=True)\n",
"            # Mean resultant vector: length R, direction = weighted mean phase.\n",
"            ax.annotate(\"\", xy=(ex[\"Angle\"], ex[\"R\"]), xytext=(0,0), arrowprops=dict(facecolor='black', width=1.5, headwidth=8, alpha=0.9))\n",
"\n",
"            label = f\"{ex['Pair']}\\nR={ex['R']:.3f}\"\n",
"            if i == 1: ax.set_title(f\"Model: {model_name}\\n{label}\", fontsize=10, weight='bold')\n",
"            else: ax.set_title(label, fontsize=9)\n",
"\n",
"            ax.set_yticklabels([]); ax.set_xticklabels([])\n",
"            idx += 1\n",
"\n",
"    plt.tight_layout()\n",
"    plt.savefig(\"multi_model_compass_publication.png\", dpi=300)\n",
"    print(\"\\n📸 Saved plot to 'multi_model_compass_publication.png'\")"
], "metadata": { "id": 
"2-elQ3KH6aNg" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [
"# ========================\n",
"# 6. ANGULAR TOPOLOGY PLOT\n",
"# ========================\n",
"import seaborn as sns\n",
"\n",
"def plot_angular_topology(all_results):\n",
"    \"\"\"Plot the distribution of per-dimension phase differences for Synonym pairs.\n",
"\n",
"    One panel per model in (\"FNet\", \"PRISM\"); models absent from `all_results`\n",
"    leave their panel empty. `all_results` maps model name -> list of dicts with\n",
"    \"Type\" and \"Diffs\" keys, as built by the compass analysis. Diffs are converted\n",
"    with np.degrees, i.e. they are assumed to be radians. Saves the figure to\n",
"    'angular_topology_comparison.png'.\n",
"    \"\"\"\n",
"    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)\n",
"\n",
"    # We only care about Synonyms to see \"How\" they align\n",
"    models = [\"FNet\", \"PRISM\"]\n",
"    colors = {\"FNet\": \"blue\", \"PRISM\": \"red\"}\n",
"    # NOTE(review): these captions are fixed text, not computed from the data -\n",
"    # check them against the rendered histograms before publishing.\n",
"    captions = {\"FNet\": \"BINARY\\n(Sign Flips)\", \"PRISM\": \"CONTINUOUS\\n(Rotation)\"}\n",
"\n",
"    for i, name in enumerate(models):\n",
"        if name not in all_results: continue\n",
"\n",
"        # Pool the phase differences of ALL Synonym pairs into one flat array.\n",
"        synonym_diffs = [np.ravel(item[\"Diffs\"]) for item in all_results[name]\n",
"                         if item[\"Type\"] == \"Synonym\"]\n",
"        if not synonym_diffs:\n",
"            # Nothing to histogram (e.g. no Synonym pair survived filtering).\n",
"            axes[i].set_title(f\"{name} Phase Topology (Synonyms): no data\")\n",
"            continue\n",
"        # Radians -> degrees, then wrap to [-180, 180)\n",
"        angles = (np.degrees(np.concatenate(synonym_diffs)) + 180) % 360 - 180\n",
"\n",
"        sns.histplot(angles, ax=axes[i], bins=60, color=colors[name], stat=\"density\", kde=True)\n",
"        axes[i].set_title(f\"{name} Phase Topology (Synonyms)\")\n",
"        axes[i].set_xlabel(\"Phase Difference (Degrees)\")\n",
"        axes[i].set_xlim(-180, 180)\n",
"        axes[i].grid(True, alpha=0.3)\n",
"\n",
"        # Axes-fraction coordinates keep the caption inside the panel whatever the\n",
"        # density scale (a fixed data y=0.01 can fall above the plotted range).\n",
"        axes[i].text(0.5, 0.05, captions[name], transform=axes[i].transAxes,\n",
"                     ha='center', color='black', fontweight='bold')\n",
"\n",
"    plt.tight_layout()\n",
"    plt.savefig(\"angular_topology_comparison.png\", dpi=300)\n",
"    print(\"📸 Saved topology proof to 'angular_topology_comparison.png'\")\n",
"\n",
"# Run the plot with your existing results\n",
"plot_angular_topology(all_results)"
], "metadata": { "id": "esnL8jUk89ov" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# ==========================================\n", "# 7. 
RATE VS PHASE DISSOCIATION PROBE\n",
"# ==========================================\n",
"import gc\n",
"import torch.nn.functional as F  # used by check_cosine_and_magnitude\n",
"from scipy.stats import pearsonr\n",
"\n",
"def check_cosine_and_magnitude(model, name, id_a, id_b):\n",
"    \"\"\"Compare two token phasors by direction and by magnitude profile.\n",
"\n",
"    Returns (vec_sim, mag_corr):\n",
"      vec_sim  - cosine similarity of the two vectors; for the complex models\n",
"                 (\"PRISM\", \"PILLARS\") real and imaginary parts are concatenated\n",
"                 into one real vector first.\n",
"      mag_corr - Pearson correlation of the magnitudes |z|, or 0.0 when either\n",
"                 magnitude profile is (near-)constant (correlation undefined).\n",
"    \"\"\"\n",
"    z_a = extract_phasor(model, name, id_a)\n",
"    z_b = extract_phasor(model, name, id_b)\n",
"\n",
"    # --- 1. Vector Cosine Similarity (The Standard Metric) ---\n",
"    if name in [\"PRISM\", \"PILLARS\"]:\n",
"        # Concatenating along the last dim gives [Re_1..Re_d, Im_1..Im_d] (all real\n",
"        # parts, then all imaginary parts). Both vectors share that layout, so the\n",
"        # cosine equals the one computed on interleaved [Re_1, Im_1, ...] pairs.\n",
"        vec_a = torch.cat([z_a.real, z_a.imag], -1)\n",
"        vec_b = torch.cat([z_b.real, z_b.imag], -1)\n",
"    else:\n",
"        # FNet is already real\n",
"        vec_a = z_a.real\n",
"        vec_b = z_b.real\n",
"\n",
"    vec_sim = F.cosine_similarity(vec_a.unsqueeze(0), vec_b.unsqueeze(0)).item()\n",
"\n",
"    # --- 2. Magnitude Correlation (The \"Rate Coding\" Check) ---\n",
"    # Do these words emphasize the same dimensions?\n",
"    mag_a = torch.abs(z_a).numpy()\n",
"    mag_b = torch.abs(z_b).numpy()\n",
"\n",
"    # Hypothesis under test: rate coding -> HIGH correlation;\n",
"    # an iso-energetic model (PRISM) -> correlation near noise.\n",
"    if np.std(mag_a) < 1e-6 or np.std(mag_b) < 1e-6:\n",
"        mag_corr = 0.0  # Pearson r is undefined for constant input\n",
"    else:\n",
"        mag_corr, _ = pearsonr(mag_a, mag_b)\n",
"\n",
"    return vec_sim, mag_corr\n",
"\n",
"print(\"\\n⚖️ Running Rate vs. Phase Dissociation...\")\n",
"\n",
"comparison_data = []\n",
"\n",
"# We only check Synonyms to see how they agree\n",
"synonym_pairs = [p for p in valid_pairs if p[2] == \"Synonym\"]\n",
"if not synonym_pairs:\n",
"    print(\"⚠️ No valid Synonym pairs; nothing to compare.\")\n",
"\n",
"# Skip model loading entirely when there is nothing to score.\n",
"for name, repo, cls in (MODELS_TO_TEST if synonym_pairs else []):\n",
"    model = None\n",
"    try:\n",
"        model = smart_load(repo, name, cls)\n",
"        model.eval()\n",
"\n",
"        vec_scores = []\n",
"        mag_scores = []\n",
"\n",
"        for w1, w2, _, id1, id2 in synonym_pairs:\n",
"            v_sim, m_corr = check_cosine_and_magnitude(model, name, id1, id2)\n",
"            vec_scores.append(v_sim)\n",
"            mag_scores.append(m_corr)\n",
"\n",
"        comparison_data.append({\n",
"            \"Model\": name,\n",
"            \"Vector Sim (Direction)\": np.mean(vec_scores),\n",
"            \"Mag Corr (Loudness)\": np.mean(mag_scores)\n",
"        })\n",
"    except Exception as e:\n",
"        print(f\"Skipping {name}: {e}\")\n",
"    finally:\n",
"        # Release the model even when scoring fails part-way.\n",
"        del model\n",
"        torch.cuda.empty_cache()\n",
"        gc.collect()\n",
"\n",
"# Display the \"Dissociation\" Table\n",
"df_comp = pd.DataFrame(comparison_data)\n",
"print(\"\\n🔥 THE DISSOCIATION TABLE 🔥\")\n",
"if df_comp.empty:\n",
"    # An empty frame has no \"Model\" column, so set_index would raise KeyError.\n",
"    print(\"(no model completed the probe)\")\n",
"else:\n",
"    print(df_comp.set_index(\"Model\"))"
], "metadata": { "id": "URcMGvENAE3d" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [
"import pandas as pd\n",
"\n",
"# 1. Convert the valid_pairs list to a DataFrame\n",
"df_stats = pd.DataFrame(valid_pairs, columns=[\"Word1\", \"Word2\", \"Category\", \"ID1\", \"ID2\"])\n",
"\n",
"# 2. Print the statistics\n",
"print(\"\\n📊 DATASET STATISTICS (Post-Filtering)\")\n",
"print(\"========================================\")\n",
"# Counts per category\n",
"counts = df_stats[\"Category\"].value_counts()\n",
"print(counts)\n",
"\n",
"print(\"----------------------------------------\")\n",
"print(f\"✅ Total Valid Pairs: {len(df_stats)}\")\n",
"print(\"========================================\")\n",
"\n",
"# 3. Helper for your Paper's Table\n",
"print(\"\\n📝 UPDATE FOR TABLE 2 (Count Column):\")\n",
"for category, count in counts.items():\n",
"    print(f\" > {category}: {count}\")"
], "metadata": { "id": "0kJpuCLhEHAJ" }, "execution_count": null, "outputs": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }