prism-lab commited on
Commit
4774e5b
·
verified ·
1 Parent(s): 63a761a

Upload 2 files

Browse files
Inspect_Resonances_Last.ipynb ADDED
@@ -0,0 +1,991 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "id": "SvLO5U3q_Q3x"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "!pip install -q x-transformers"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "metadata": {
18
+ "id": "xy1HCL1GzAbM"
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "# ==========================================\n",
23
+ "# 1. SETUP & MODEL LOADING (FIXED)\n",
24
+ "# ==========================================\n",
25
+ "import os\n",
26
+ "import sys\n",
27
+ "from huggingface_hub import hf_hub_download\n",
28
+ "\n",
29
+ "# --- CRITICAL FIX: Download the Model Definition FIRST ---\n",
30
+ "REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
31
+ "filename = \"modeling_prism_gated.py\"\n",
32
+ "\n",
33
+ "print(f\"⬇️ Downloading {filename} from Hugging Face...\")\n",
34
+ "if not os.path.exists(filename):\n",
35
+ " hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\".\", force_download=True)\n",
36
+ "\n",
37
+ "# Now that the file exists locally, we can import it\n",
38
+ "sys.path.append(\".\") # Ensure current dir is in path\n",
39
+ "from modeling_prism_gated import PRISMHybrid_RoPE\n",
40
+ "\n",
41
+ "# Continue with standard imports\n",
42
+ "import torch\n",
43
+ "import numpy as np\n",
44
+ "import matplotlib.pyplot as plt\n",
45
+ "import pandas as pd\n",
46
+ "from transformers import AutoTokenizer\n",
47
+ "import json\n",
48
+ "\n",
49
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
50
+ "D_MODEL = 512\n",
51
+ "\n",
52
+ "print(\"⏳ Downloading Weights & Config...\")\n",
53
+ "if not os.path.exists(\"config.json\"):\n",
54
+ " hf_hub_download(repo_id=REPO_ID, filename=\"config.json\", local_dir=\".\")\n",
55
+ "if not os.path.exists(\"pytorch_model.bin\"):\n",
56
+ " hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\", local_dir=\".\")\n",
57
+ "\n",
58
+ "with open(\"config.json\", \"r\") as f: config = json.load(f)\n",
59
+ "tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
60
+ "\n",
61
+ "# Initialize Model\n",
62
+ "model = PRISMHybrid_RoPE(\n",
63
+ " vocab_size=config['vocab_size'], d_model=config['d_model'],\n",
64
+ " num_encoder_layers=config['num_encoder_layers'], num_refining_layers=0,\n",
65
+ " num_decoder_layers=6, num_heads=8, dff=2048, max_length=128, dropout=0.0\n",
66
+ ").to(DEVICE)\n",
67
+ "\n",
68
+ "model.load_state_dict(torch.load(\"pytorch_model.bin\", map_location=DEVICE))\n",
69
+ "model.eval()\n",
70
+ "print(\"✅ Model Loaded Successfully.\")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "metadata": {
77
+ "id": "vmp0y5TYdTd0"
78
+ },
79
+ "outputs": [],
80
+ "source": [
81
+ "# @title 🧭 Extended Phase Compass: Synonyms vs Antonyms vs Randoms\n",
82
+ "import torch\n",
83
+ "import numpy as np\n",
84
+ "import matplotlib.pyplot as plt\n",
85
+ "import pandas as pd\n",
86
+ "from transformers import AutoTokenizer\n",
87
+ "from huggingface_hub import hf_hub_download\n",
88
+ "import os\n",
89
+ "import json\n",
90
+ "\n",
91
+ "\n",
92
+ "# ==========================================\n",
93
+ "# 2. DEFINING THE CANDIDATE PAIRS\n",
94
+ "# ==========================================\n",
95
+ "# Master list of all candidates\n",
96
+ "candidates_raw = [\n",
97
+ " # --- ORIGINAL LIST ---\n",
98
+ " (\"Euro\", \"Geld\", \"Synonym\"),\n",
99
+ " (\"Auto\", \"Wagen\", \"Synonym\"),\n",
100
+ " (\"schnell\", \"rasch\", \"Synonym\"),\n",
101
+ " (\"Stimme\", \"Wahl\", \"Synonym\"),\n",
102
+ " (\"Zeit\", \"Uhr\", \"Synonym\"),\n",
103
+ " (\"Start\", \"Beginn\", \"Synonym\"),\n",
104
+ " (\"Ende\", \"Schluss\", \"Synonym\"),\n",
105
+ " (\"Raum\", \"Platz\", \"Synonym\"),\n",
106
+ "\n",
107
+ " # --- NEW GERMAN SYNONYMS ---\n",
108
+ " (\"Haus\", \"Heim\", \"Synonym\"),\n",
109
+ " (\"Boot\", \"Schiff\", \"Synonym\"),\n",
110
+ " (\"See\", \"Meer\", \"Synonym\"),\n",
111
+ " (\"Wald\", \"Forst\", \"Synonym\"),\n",
112
+ " (\"Weg\", \"Pfad\", \"Synonym\"),\n",
113
+ " (\"Berg\", \"Gipfel\", \"Synonym\"),\n",
114
+ " (\"Mund\", \"Maul\", \"Synonym\"),\n",
115
+ " (\"Pferd\", \"Ross\", \"Synonym\"),\n",
116
+ " (\"Hund\", \"Tier\", \"Synonym\"),\n",
117
+ " (\"Reise\", \"Fahrt\", \"Synonym\"),\n",
118
+ " (\"Angst\", \"Furcht\", \"Synonym\"),\n",
119
+ " (\"Mut\", \"Traute\", \"Synonym\"),\n",
120
+ " (\"Glück\", \"Dusel\", \"Synonym\"),\n",
121
+ " (\"Ding\", \"Sache\", \"Synonym\"),\n",
122
+ " (\"Welt\", \"Erde\", \"Synonym\"),\n",
123
+ " (\"Stadt\", \"Ort\", \"Synonym\"),\n",
124
+ " (\"Vater\", \"Papa\", \"Synonym\"),\n",
125
+ " (\"Mutter\", \"Mama\", \"Synonym\"),\n",
126
+ " (\"klug\", \"weise\", \"Synonym\"),\n",
127
+ " (\"klug\", \"schlau\", \"Synonym\"),\n",
128
+ " (\"schön\", \"hübsch\", \"Synonym\"),\n",
129
+ " (\"klein\", \"winzig\", \"Synonym\"),\n",
130
+ " (\"stark\", \"fest\", \"Synonym\"),\n",
131
+ " (\"neu\", \"frisch\", \"Synonym\"),\n",
132
+ " (\"still\", \"leise\", \"Synonym\"),\n",
133
+ " (\"froh\", \"heiter\", \"Synonym\"),\n",
134
+ " (\"dunkel\", \"finster\", \"Synonym\"),\n",
135
+ " (\"kalt\", \"eisig\", \"Synonym\"),\n",
136
+ " (\"rennen\", \"laufen\", \"Synonym\"),\n",
137
+ " (\"reden\", \"sagen\", \"Synonym\"),\n",
138
+ " (\"sehen\", \"schauen\", \"Synonym\"),\n",
139
+ " (\"gehen\", \"wandern\", \"Synonym\"),\n",
140
+ " (\"essen\", \"speisen\", \"Synonym\"),\n",
141
+ " (\"Wut\", \"Zorn\", \"Synonym\"),\n",
142
+ " (\"Dreck\", \"Schmutz\", \"Synonym\"),\n",
143
+ " (\"Chance\", \"Möglichkeit\", \"Synonym\"), # Möglichkeit might be split, but code will check\n",
144
+ " (\"Lehrer\", \"Pauker\", \"Synonym\"),\n",
145
+ " (\"Gott\", \"Herr\", \"Synonym\"),\n",
146
+ " (\"Chef\", \"Boss\", \"Synonym\"),\n",
147
+ " (\"Haut\", \"Fell\", \"Synonym\"),\n",
148
+ " (\"Tor\", \"Tür\", \"Synonym\"),\n",
149
+ " (\"Zimmer\", \"Raum\", \"Synonym\"),\n",
150
+ " (\"Bahn\", \"Zug\", \"Synonym\"),\n",
151
+ " (\"Boot\", \"Kahn\", \"Synonym\"),\n",
152
+ " (\"Hose\", \"Jeans\", \"Synonym\"),\n",
153
+ " (\"Witz\", \"Scherz\", \"Synonym\"),\n",
154
+ " (\"Hass\", \"Abscheu\", \"Synonym\"),\n",
155
+ " (\"Fett\", \"Dick\", \"Synonym\"),\n",
156
+ " (\"klug\", \"gescheit\", \"Synonym\"),\n",
157
+ " (\"dumm\", \"doof\", \"Synonym\"),\n",
158
+ " (\"rasch\", \"flink\", \"Synonym\"),\n",
159
+ " (\"stumm\", \"still\", \"Synonym\"),\n",
160
+ " (\"echt\", \"wahr\", \"Synonym\"),\n",
161
+ " (\"korrekt\", \"richtig\", \"Synonym\"),\n",
162
+ " # --- NEW ANTONYMS (OPPOSITES) ---\n",
163
+ " (\"gut\", \"böse\", \"Antonym\"),\n",
164
+ " (\"groß\", \"klein\", \"Antonym\"),\n",
165
+ " (\"heiß\", \"kalt\", \"Antonym\"),\n",
166
+ " (\"Tag\", \"Nacht\", \"Antonym\"),\n",
167
+ " (\"hoch\", \"tief\", \"Antonym\"),\n",
168
+ " (\"jung\", \"alt\", \"Antonym\"),\n",
169
+ " (\"voll\", \"leer\", \"Antonym\"),\n",
170
+ " (\"Liebe\", \"Hass\", \"Antonym\"),\n",
171
+ " (\"Licht\", \"Schatten\", \"Antonym\"),\n",
172
+ " (\"Start\", \"Ziel\", \"Antonym\"),\n",
173
+ " (\"Frage\", \"Antwort\", \"Antonym\"),\n",
174
+ "# --- ADDITIONAL ANTONYMS (Fundamental Opposites) ---\n",
175
+ " (\"Leben\", \"Tod\", \"Antonym\"),\n",
176
+ " (\"Freund\", \"Feind\", \"Antonym\"),\n",
177
+ " (\"Krieg\", \"Frieden\", \"Antonym\"),\n",
178
+ " (\"Sieg\", \"Niederlage\", \"Antonym\"),\n",
179
+ " (\"Gewinn\", \"Verlust\", \"Antonym\"),\n",
180
+ " (\"Himmel\", \"Hölle\", \"Antonym\"),\n",
181
+ " (\"Junge\", \"Mädchen\", \"Antonym\"),\n",
182
+ " (\"Vater\", \"Mutter\", \"Antonym\"),\n",
183
+ " (\"Bruder\", \"Schwester\", \"Antonym\"),\n",
184
+ " (\"Sommer\", \"Winter\", \"Antonym\"),\n",
185
+ " (\"Sonne\", \"Mond\", \"Antonym\"),\n",
186
+ " (\"Feuer\", \"Wasser\", \"Antonym\"),\n",
187
+ " (\"schwarz\", \"weiß\", \"Antonym\"),\n",
188
+ " (\"hart\", \"weich\", \"Antonym\"),\n",
189
+ " (\"laut\", \"leise\", \"Antonym\"),\n",
190
+ " (\"schnell\", \"langsam\", \"Antonym\"),\n",
191
+ " (\"teuer\", \"billig\", \"Antonym\"),\n",
192
+ " (\"reich\", \"arm\", \"Antonym\"),\n",
193
+ " (\"schwer\", \"leicht\", \"Antonym\"),\n",
194
+ " (\"nass\", \"trocken\", \"Antonym\"),\n",
195
+ " (\"sauber\", \"schmutzig\", \"Antonym\"),\n",
196
+ " (\"klug\", \"dumm\", \"Antonym\"),\n",
197
+ " (\"stark\", \"schwach\", \"Antonym\"),\n",
198
+ " (\"dick\", \"dünn\", \"Antonym\"),\n",
199
+ " (\"breit\", \"schmal\", \"Antonym\"),\n",
200
+ " # --- NEW RANDOM/UNRELATED ---\n",
201
+ " (\"Mond\", \"Tisch\", \"Random\"),\n",
202
+ " (\"Brot\", \"Wolke\", \"Random\"),\n",
203
+ " (\"Schuh\", \"Idee\", \"Random\"),\n",
204
+ " (\"Baum\", \"Zahn\", \"Random\"),\n",
205
+ " (\"Glas\", \"Löwe\", \"Random\"),\n",
206
+ " (\"Buch\", \"Suppe\", \"Random\"),\n",
207
+ " (\"Wand\", \"Vogel\", \"Random\"),\n",
208
+ " (\"Gras\", \"Auto\", \"Random\"),\n",
209
+ " (\"Salz\", \"Musik\", \"Random\"),\n",
210
+ " (\"Dach\", \"Fisch\", \"Random\"),\n",
211
+ " (\"Stein\", \"Wort\", \"Random\"),\n",
212
+ " (\"Kopf\", \"Preis\", \"Random\"),\n",
213
+ " (\"Hand\", \"Woche\", \"Random\"),\n",
214
+ " (\"Euro\", \"Apfel\", \"Random\"),\n",
215
+ " (\"Auto\", \"Idee\", \"Random\"),\n",
216
+ " (\"schnell\", \"Haus\", \"Random\"),\n",
217
+ " (\"Zeit\", \"Fisch\", \"Random\"),\n",
218
+ " (\"Start\", \"Milch\", \"Random\"),\n",
219
+ " (\"Raum\", \"Laufen\", \"Random\"),\n",
220
+ "# --- ADDITIONAL RANDOM PAIRS (Noise Floor) ---\n",
221
+ " (\"Käse\", \"Mond\", \"Random\"),\n",
222
+ " (\"Bier\", \"Tante\", \"Random\"),\n",
223
+ " (\"Zahn\", \"Autobahn\", \"Random\"),\n",
224
+ " (\"Vogel\", \"Benzin\", \"Random\"),\n",
225
+ " (\"Computer\", \"Blume\", \"Random\"),\n",
226
+ " (\"Glas\", \"Schaf\", \"Random\"),\n",
227
+ " (\"Schuh\", \"Luft\", \"Random\"),\n",
228
+ " (\"Kaffee\", \"Stein\", \"Random\"),\n",
229
+ " (\"Wand\", \"Butter\", \"Random\"),\n",
230
+ " (\"Fenster\", \"Schmerz\", \"Random\"),\n",
231
+ " (\"Nase\", \"Rechnung\", \"Random\"),\n",
232
+ " (\"Hund\", \"Lampe\", \"Random\"),\n",
233
+ " (\"Katze\", \"Strom\", \"Random\"),\n",
234
+ " (\"Apfel\", \"Krieg\", \"Random\"),\n",
235
+ " (\"Löffel\", \"Angst\", \"Random\"),\n",
236
+ " (\"Zucker\", \"Politik\", \"Random\"),\n",
237
+ " (\"Salz\", \"Liebe\", \"Random\"),\n",
238
+ " (\"Pfeffer\", \"Auto\", \"Random\"),\n",
239
+ " (\"Stuhl\", \"Wolke\", \"Random\"),\n",
240
+ "]\n",
241
+ "\n",
242
+ "# ==========================================\n",
243
+ "# 3. HELPER FUNCTIONS\n",
244
+ "# ==========================================\n",
245
+ "def is_single_token(word):\n",
246
+ " \"\"\"Check if word is 1 token in vocabulary.\"\"\"\n",
247
+ " ids = tokenizer.encode(word, add_special_tokens=False)\n",
248
+ " return len(ids) == 1, ids[0] if len(ids) == 1 else None\n",
249
+ "\n",
250
+ "def calculate_coherence(id_a, id_b):\n",
251
+ " \"\"\"Extract phases and calculate Mean Resultant Length (R).\"\"\"\n",
252
+ " # 1. Get Weights (CPU)\n",
253
+ " w = model.harmonic_embedding.complex_embedding.weight.detach().cpu()\n",
254
+ "\n",
255
+ " # 2. Form Complex Numbers (Real + i*Imag)\n",
256
+ " za = torch.complex(w[id_a, :D_MODEL], w[id_a, D_MODEL:])\n",
257
+ " zb = torch.complex(w[id_b, :D_MODEL], w[id_b, D_MODEL:])\n",
258
+ "\n",
259
+ " # 3. Phase Difference (Angle between vectors)\n",
260
+ " diff = torch.angle(za) - torch.angle(zb)\n",
261
+ "\n",
262
+ " # 4. Energy Weighting (Magnitude * Magnitude)\n",
263
+ " # Stronger concepts contribute more to the \"Phase Compass\"\n",
264
+ " weights = torch.abs(za) * torch.abs(zb)\n",
265
+ "\n",
266
+ " # 5. Convert to Numpy\n",
267
+ " diff_np = diff.numpy()\n",
268
+ " weights_np = weights.numpy()\n",
269
+ "\n",
270
+ " # 6. Calculate Circular Mean (R)\n",
271
+ " # R ranges from 0 (Random/Cancel) to 1 (Perfect Alignment)\n",
272
+ " weighted_complex_diffs = weights_np * np.exp(1j * diff_np)\n",
273
+ " mean_vector = np.sum(weighted_complex_diffs) / np.sum(weights_np)\n",
274
+ "\n",
275
+ " return np.abs(mean_vector), np.angle(mean_vector), diff_np, weights_np\n",
276
+ "\n",
277
+ "# ==========================================\n",
278
+ "# 4. EXECUTE ANALYSIS\n",
279
+ "# ==========================================\n",
280
+ "valid_pairs = []\n",
281
+ "results = []\n",
282
+ "\n",
283
+ "print(f\"\\n{'Pair':<25} | {'Type':<10} | {'Status':<15} | {'R (Coherence)'}\")\n",
284
+ "print(\"-\" * 75)\n",
285
+ "\n",
286
+ "for w1, w2, ptype in candidates_raw:\n",
287
+ " s1, id1 = is_single_token(w1)\n",
288
+ " s2, id2 = is_single_token(w2)\n",
289
+ "\n",
290
+ " if s1 and s2:\n",
291
+ " R, angle, diffs, weights = calculate_coherence(id1, id2)\n",
292
+ " valid_pairs.append({\n",
293
+ " \"w1\": w1, \"w2\": w2, \"type\": ptype,\n",
294
+ " \"R\": R, \"angle\": angle, \"diffs\": diffs, \"weights\": weights\n",
295
+ " })\n",
296
+ " results.append({\"Pair\": f\"{w1}-{w2}\", \"Type\": ptype, \"R\": R})\n",
297
+ " print(f\"{w1}-{w2:<20} | {ptype:<10} | ✅ Valid | {R:.4f}\")\n",
298
+ " else:\n",
299
+ " # Just logging for info, skipped in analysis\n",
300
+ " pass\n",
301
+ " # print(f\"{w1}-{w2:<20} | {ptype:<10} | ❌ Multi-token | -\")\n",
302
+ "\n",
303
+ "# ==========================================\n",
304
+ "# 5. STATISTICS\n",
305
+ "# ==========================================\n",
306
+ "df = pd.DataFrame(results)\n",
307
+ "print(\"\\n📊 AGGREGATE STATS (Mean Resultant Length R):\")\n",
308
+ "print(df.groupby(\"Type\")[\"R\"].describe())\n",
309
+ "\n",
310
+ "# ==========================================\n",
311
+ "# 6. VISUALIZATION (GRID 3x3)\n",
312
+ "# ==========================================\n",
313
+ "# Select Top 3 from each category to show clearest examples\n",
314
+ "synonyms = sorted([p for p in valid_pairs if p[\"type\"] == \"Synonym\"], key=lambda x: x[\"R\"], reverse=True)[:3]\n",
315
+ "antonyms = sorted([p for p in valid_pairs if p[\"type\"] == \"Antonym\"], key=lambda x: x[\"R\"], reverse=True)[:3]\n",
316
+ "randoms = sorted([p for p in valid_pairs if p[\"type\"] == \"Random\"], key=lambda x: x[\"R\"], reverse=False)[:3] # Lowest R for randoms\n",
317
+ "\n",
318
+ "plot_list = synonyms + antonyms + randoms\n",
319
+ "\n",
320
+ "if len(plot_list) > 0:\n",
321
+ " fig = plt.figure(figsize=(15, 12))\n",
322
+ " fig.suptitle(\"Semantic Phase Compass: Synonyms vs Antonyms vs Randoms\", fontsize=16, y=0.98)\n",
323
+ "\n",
324
+ " # Colors for categories\n",
325
+ " colors = {\"Synonym\": \"red\", \"Antonym\": \"purple\", \"Random\": \"blue\"}\n",
326
+ "\n",
327
+ " for i, item in enumerate(plot_list):\n",
328
+ " ax = fig.add_subplot(3, 3, i+1, projection='polar')\n",
329
+ "\n",
330
+ " ptype = item[\"type\"]\n",
331
+ " color = colors[ptype]\n",
332
+ "\n",
333
+ " # Weighted Histogram of Phase Differences\n",
334
+ " ax.hist(item[\"diffs\"], bins=30, weights=item[\"weights\"], color=color, alpha=0.7, density=True)\n",
335
+ "\n",
336
+ " # Mean Vector Arrow (The \"Compass Needle\")\n",
337
+ " # Length of arrow = R (Coherence Strength)\n",
338
+ " ax.annotate(\"\", xy=(item[\"angle\"], item[\"R\"]), xytext=(0,0),\n",
339
+ " arrowprops=dict(facecolor='black', width=2, headwidth=10))\n",
340
+ "\n",
341
+ " # Styling\n",
342
+ " ax.set_title(f\"{item['w1']} - {item['w2']}\\n{ptype}\\nR = {item['R']:.3f}\", fontsize=11)\n",
343
+ " ax.set_yticklabels([]) # Hide radial labels\n",
344
+ " ax.set_xticklabels([]) # Hide angular labels\n",
345
+ "\n",
346
+ " plt.tight_layout()\n",
347
+ " plt.savefig(\"phase_compass_extended.png\", dpi=300)\n",
348
+ " plt.show()\n",
349
+ " print(\"\\n📸 Saved plot to 'phase_compass_extended.png'\")\n",
350
+ "else:\n",
351
+ " print(\"⚠️ Not enough valid pairs to generate plot.\")"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "source": [
357
+ "import torch\n",
358
+ "import numpy as np\n",
359
+ "import matplotlib.pyplot as plt\n",
360
+ "import pandas as pd\n",
361
+ "import seaborn as sns\n",
362
+ "\n",
363
+ "# ... [Keep your Candidate Lists & Helper Functions from before] ...\n",
364
+ "\n",
365
+ "# ==========================================\n",
366
+ "# 4. EXECUTE ANALYSIS & SELECT BEST EXAMPLES\n",
367
+ "# ==========================================\n",
368
+ "valid_pairs = []\n",
369
+ "results = []\n",
370
+ "\n",
371
+ "print(f\"🚀 Running Phase Compass on {len(candidates_raw)} pairs...\")\n",
372
+ "\n",
373
+ "for w1, w2, ptype in candidates_raw:\n",
374
+ " s1, id1 = is_single_token(w1)\n",
375
+ " s2, id2 = is_single_token(w2)\n",
376
+ "\n",
377
+ " if s1 and s2:\n",
378
+ " R, angle, diffs, weights = calculate_coherence(id1, id2)\n",
379
+ " valid_pairs.append({\n",
380
+ " \"w1\": w1, \"w2\": w2, \"type\": ptype,\n",
381
+ " \"R\": R, \"angle\": angle, \"diffs\": diffs, \"weights\": weights\n",
382
+ " })\n",
383
+ " results.append({\"Pair\": f\"{w1}-{w2}\", \"Type\": ptype, \"R\": R})\n",
384
+ "\n",
385
+ "# ==========================================\n",
386
+ "# 5. VISUALIZATION (1x3 STRIP)\n",
387
+ "# ==========================================\n",
388
+ "# Select the \"Best\" example for each category (Highest R for Syn/Ant, Lowest for Random)\n",
389
+ "best_syn = max([p for p in valid_pairs if p[\"type\"] == \"Synonym\"], key=lambda x: x[\"R\"])\n",
390
+ "best_ant = max([p for p in valid_pairs if p[\"type\"] == \"Antonym\"], key=lambda x: x[\"R\"])\n",
391
+ "best_rnd = min([p for p in valid_pairs if p[\"type\"] == \"Random\"], key=lambda x: x[\"R\"])\n",
392
+ "\n",
393
+ "plot_list = [best_syn, best_ant, best_rnd]\n",
394
+ "titles = [\"A. Synonyms (High Coherence)\", \"B. Antonyms (High Coherence)\", \"C. Unrelated (Random Phase)\"]\n",
395
+ "colors = [\"#d62728\", \"#9467bd\", \"#7f7f7f\"] # Red, Purple, Gray\n",
396
+ "\n",
397
+ "fig = plt.figure(figsize=(12, 4)) # Wide, Short aspect ratio\n",
398
+ "\n",
399
+ "for i, item in enumerate(plot_list):\n",
400
+ " ax = fig.add_subplot(1, 3, i+1, projection='polar')\n",
401
+ "\n",
402
+ " # 1. Circular Histogram (The \"Cloud\")\n",
403
+ " # We use 'weights' to show that high-energy frequencies matter more\n",
404
+ " ax.hist(item[\"diffs\"], bins=40, weights=item[\"weights\"], color=colors[i], alpha=0.6, density=True)\n",
405
+ "\n",
406
+ " # 2. Mean Resultant Vector (The \"Needle\")\n",
407
+ " # The length of this arrow is the PROOF of phase locking.\n",
408
+ " ax.annotate(\"\", xy=(item[\"angle\"], item[\"R\"]), xytext=(0,0),\n",
409
+ " arrowprops=dict(facecolor='black', width=1.5, headwidth=8, alpha=0.9))\n",
410
+ "\n",
411
+ " # 3. Styling\n",
412
+ " ax.set_title(f\"{titles[i]}\\n'{item['w1']}' - '{item['w2']}'\\n$R = {item['R']:.2f}$\",\n",
413
+ " fontsize=10, fontweight='bold', pad=10)\n",
414
+ " ax.set_yticklabels([]) # Hide radial numbers\n",
415
+ " ax.set_xticklabels([]) # Hide degree numbers\n",
416
+ " ax.grid(True, alpha=0.3)\n",
417
+ " ax.set_ylim(0, 0.6) # Fix scale for fair comparison\n",
418
+ "\n",
419
+ "plt.tight_layout()\n",
420
+ "plt.savefig(\"fig_compass_1x3.png\", dpi=300, bbox_inches='tight')\n",
421
+ "plt.show()\n",
422
+ "\n",
423
+ "# ==========================================\n",
424
+ "# 6. STATISTICAL TABLE OUTPUT\n",
425
+ "# ==========================================\n",
426
+ "df = pd.DataFrame(results)\n",
427
+ "print(\"\\n📊 PHASE LOCKING STATISTICS (Mean Resultant Length R)\")\n",
428
+ "print(\"=\"*60)\n",
429
+ "print(f\"{'Category':<15} | {'Mean R':<10} | {'Std Dev':<10} | {'Count'}\")\n",
430
+ "print(\"-\" * 60)\n",
431
+ "stats = df.groupby(\"Type\")[\"R\"].agg(['mean', 'std', 'count'])\n",
432
+ "for idx, row in stats.iterrows():\n",
433
+ " print(f\"{idx:<15} | {row['mean']:.4f} | {row['std']:.4f} | {int(row['count'])}\")"
434
+ ],
435
+ "metadata": {
436
+ "id": "GzwkznYXTJpL"
437
+ },
438
+ "execution_count": null,
439
+ "outputs": []
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "source": [
444
+ "# ==========================================\n",
445
+ "# 0. SETUP & DEPENDENCIES\n",
446
+ "# ==========================================\n",
447
+ "import torch\n",
448
+ "import torch.nn as nn\n",
449
+ "import torch.nn.functional as F\n",
450
+ "import math\n",
451
+ "import gc\n",
452
+ "import pandas as pd\n",
453
+ "import numpy as np\n",
454
+ "import matplotlib.pyplot as plt\n",
455
+ "import json\n",
456
+ "from transformers import RobertaTokenizerFast\n",
457
+ "from huggingface_hub import hf_hub_download\n",
458
+ "from x_transformers import TransformerWrapper, Encoder\n",
459
+ "\n",
460
+ "# Global Config\n",
461
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
462
+ "SEQ_LEN = 4096\n",
463
+ "MAX_VOCAB_SIZE = 32768\n",
464
+ "TOKENIZER_ID = \"prism-lab/wikitext-103-prism-32k-seq4k\" # <--- YOUR REPO\n",
465
+ "\n",
466
+ "print(f\"🔥 Initializing Phase Compass Analysis on {DEVICE}\")\n",
467
+ "\n",
468
+ "# ==========================================\n",
469
+ "# 1. ARCHITECTURE DEFINITIONS\n",
470
+ "# ==========================================\n",
471
+ "# (Standard Definitions - Collapsed for brevity)\n",
472
+ "class ComplexDropout(nn.Module):\n",
473
+ " def __init__(self, p=0.0): super().__init__(); self.p = p\n",
474
+ " def forward(self, z): return z\n",
475
+ "class RobustPhaseNorm(nn.Module):\n",
476
+ " def __init__(self, d, eps=1e-5): super().__init__(); self.scale = nn.Parameter(torch.ones(d)); self.eps = eps\n",
477
+ " def forward(self, x): return (x / torch.sqrt((x.abs()**2).mean(-1, keepdim=True) + self.eps)) * self.scale\n",
478
+ "class ModReLU(nn.Module):\n",
479
+ " def __init__(self, f): super().__init__(); self.b = nn.Parameter(torch.zeros(f))\n",
480
+ " def forward(self, z): return F.relu(z.abs() + self.b) * (z / (z.abs() + 1e-6))\n",
481
+ "class ComplexToRealBridge(nn.Module):\n",
482
+ " def __init__(self, d): super().__init__(); self.proj = nn.Linear(d*2, d); self.norm = nn.LayerNorm(d)\n",
483
+ " def forward(self, x): return self.norm(self.proj(torch.cat([x.real, x.imag], -1)))\n",
484
+ "class DynamicRoSE(nn.Module):\n",
485
+ " def __init__(self, n, d):\n",
486
+ " super().__init__(); self.raw_embedding = nn.Embedding(n, d); self.adapter = nn.Linear(d, d*2); self.rotation_predictor = nn.Linear(d, d*2)\n",
487
+ " self.register_buffer('freqs', torch.exp(torch.arange(0, d) * -(math.log(10000.0)/d)))\n",
488
+ " def forward(self, x):\n",
489
+ " real = self.raw_embedding(x); params = self.adapter(real); D = real.shape[-1]\n",
490
+ " z = torch.complex(params[...,:D], params[...,D:]); r = self.rotation_predictor(real); rx, ry = r.chunk(2, -1)\n",
491
+ " drot = torch.complex(rx/torch.sqrt(rx**2+ry**2+1e-6), ry/torch.sqrt(rx**2+ry**2+1e-6))\n",
492
+ " pos = torch.arange(real.shape[1], device=x.device).float()\n",
493
+ " srot = torch.polar(torch.ones_like(torch.outer(pos, self.freqs)), torch.outer(pos, self.freqs))\n",
494
+ " return (z * srot.unsqueeze(0) * drot), real\n",
495
+ "class HyenaNeuralFilter(nn.Module):\n",
496
+ " def __init__(self, d, max_len=1024, h=64):\n",
497
+ " super().__init__(); self.d = d; self.register_buffer(\"freqs\", torch.exp(torch.arange(0, h, 2) * -(math.log(10000.0)/h)))\n",
498
+ " self.mlp = nn.Sequential(nn.Linear(h, h), nn.SiLU(), nn.Linear(h, h), nn.SiLU(), nn.Linear(h, d*2))\n",
499
+ " def forward(self, L, dev):\n",
500
+ " t = torch.linspace(0, 1, steps=L, device=dev).unsqueeze(-1)\n",
501
+ " emb = torch.cat([torch.sin(t*self.freqs), torch.cos(t*self.freqs)], -1)\n",
502
+ " out = self.mlp(emb).view(L, self.d, 2); return torch.complex(out[...,0], out[...,1])\n",
503
+ "class GatedHarmonicConvolution(nn.Module):\n",
504
+ " def __init__(self, d, max_len):\n",
505
+ " super().__init__(); self.d=d; self.filter_len=max_len; self.neural_filter = HyenaNeuralFilter(d, max_len)\n",
506
+ " self.gate_proj = nn.Linear(d*2, d*2); self.mix_real = nn.Linear(d,d); self.mix_imag = nn.Linear(d,d)\n",
507
+ " self.out_real = nn.Linear(d,d); self.out_imag = nn.Linear(d,d); self.activation = ModReLU(d); self.norm = RobustPhaseNorm(d)\n",
508
+ " self.dropout = ComplexDropout(0.0)\n",
509
+ " def forward(self, x, mask=None):\n",
510
+ " res = x; x = self.norm(x); B,L,D = x.shape; eff_L = min(L, self.filter_len)\n",
511
+ " h = self.neural_filter(eff_L, x.device).unsqueeze(0)\n",
512
+ " xt = torch.fft.ifft(torch.fft.fft(x, n=eff_L, dim=1, norm='ortho') * h, n=eff_L, dim=1, norm='ortho')\n",
513
+ " if L > eff_L: xt = F.pad(xt, (0,0,0,L-eff_L));\n",
514
+ " else: xt = xt[:, :L, :]\n",
515
+ " g = torch.sigmoid(self.gate_proj(torch.cat([x.real, x.imag], -1))); gr, gi = g.chunk(2, -1)\n",
516
+ " xg = torch.complex(xt.real*gr, xt.imag*gi); mr, mi = self.mix_real, self.mix_imag\n",
517
+ " xm = torch.complex(mr(xg.real)-mi(xg.imag), mr(xg.imag)+mi(xg.real)); xa = self.activation(xm); or_, oi = self.out_real, self.out_imag\n",
518
+ " out = torch.complex(or_(xa.real)-oi(xa.imag), or_(xa.imag)+oi(xa.real))\n",
519
+ " return self.dropout(out) + res\n",
520
+ "class PRISMEncoder(nn.Module):\n",
521
+ " def __init__(self, l, d, max_l): super().__init__(); self.layers = nn.ModuleList([GatedHarmonicConvolution(d, max_l) for _ in range(l)]); self.final_norm = RobustPhaseNorm(d)\n",
522
+ " def forward(self, x):\n",
523
+ " for layer in self.layers: x = layer(x)\n",
524
+ " return self.final_norm(x)\n",
525
+ "\n",
526
+ "# --- A. BASELINE (Transformer) ---\n",
527
+ "class LocalBaseline(nn.Module):\n",
528
+ " def __init__(self, vocab_size):\n",
529
+ " super().__init__()\n",
530
+ " self.model = TransformerWrapper(\n",
531
+ " num_tokens=vocab_size, max_seq_len=SEQ_LEN, use_abs_pos_emb=False, tie_embedding=True,\n",
532
+ " attn_layers=Encoder(dim=512, depth=5, heads=8, rotary_pos_emb=True, attn_flash=True, use_scalenorm=False)\n",
533
+ " )\n",
534
+ " def forward(self, x): return self.model(x)\n",
535
+ "\n",
536
+ "# --- B. FNET (Hybrid) ---\n",
537
+ "class FNetBlock(nn.Module):\n",
538
+ " def __init__(self, d, df):\n",
539
+ " super().__init__(); self.norm_mix = nn.LayerNorm(d); self.norm_ff = nn.LayerNorm(d)\n",
540
+ " self.ff = nn.Sequential(nn.Linear(d, df), nn.GELU(), nn.Dropout(0), nn.Linear(df, d), nn.Dropout(0))\n",
541
+ " def forward(self, x):\n",
542
+ " r = x; x = self.norm_mix(x); x = torch.fft.fftn(x.float(), dim=(-2,-1), norm='ortho').real.to(r.dtype); x = x+r\n",
543
+ " r = x; x = self.norm_ff(x); x = self.ff(x); return x+r\n",
544
+ "class FNetEncoder(nn.Module):\n",
545
+ " def __init__(self, depth, d, df): super().__init__(); self.layers = nn.ModuleList([FNetBlock(d, df) for _ in range(depth)]); self.norm_out = nn.LayerNorm(d)\n",
546
+ " def forward(self, x):\n",
547
+ " for l in self.layers: x = l(x)\n",
548
+ " return self.norm_out(x)\n",
549
+ "class HybridFNetMLM(nn.Module):\n",
550
+ " def __init__(self, vocab_size):\n",
551
+ " super().__init__()\n",
552
+ " self.token_emb = nn.Embedding(vocab_size, 512); self.pos_emb = nn.Parameter(torch.zeros(1, SEQ_LEN, 512))\n",
553
+ " self.fnet_encoder = FNetEncoder(6, 512, 2048)\n",
554
+ " self.transformer_cap = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
555
+ " self.final_norm = nn.LayerNorm(512); self.to_logits = nn.Linear(512, vocab_size)\n",
556
+ " self.to_logits.weight = self.token_emb.weight # Tie\n",
557
+ " def forward(self, x):\n",
558
+ " h = self.token_emb(x) + self.pos_emb[:, :x.shape[1], :]\n",
559
+ " return self.to_logits(self.final_norm(self.transformer_cap(self.fnet_encoder(h))))\n",
560
+ "\n",
561
+ "# --- C. PRISM (Phase Coder) ---\n",
562
+ "class LocalPRISM(nn.Module):\n",
563
+ " def __init__(self, vocab_size):\n",
564
+ " super().__init__()\n",
565
+ " self.rose = DynamicRoSE(vocab_size, 512); self.prism_encoder = PRISMEncoder(5, 512, SEQ_LEN)\n",
566
+ " self.bridge = ComplexToRealBridge(512); self.periscope_proj = nn.Sequential(nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU())\n",
567
+ " self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
568
+ " self.lm_head = nn.Linear(512, vocab_size); self.lm_head.weight = self.rose.raw_embedding.weight # Tie\n",
569
+ " def forward(self, x):\n",
570
+ " w, p = self.rose(x); w = self.bridge(self.prism_encoder(w))\n",
571
+ " return self.lm_head(self.refiner(self.periscope_proj(torch.cat([w, p], -1))))\n",
572
+ "\n",
573
+ "# --- D. PILLARS (Split-Stream) ---\n",
574
+ "class LocalPillars(nn.Module):\n",
575
+ " def __init__(self, vocab_size):\n",
576
+ " super().__init__()\n",
577
+ " self.rose = DynamicRoSE(vocab_size, 512); self.particle_down = nn.Linear(512, 256); self.wave_down = nn.Linear(1024, 512)\n",
578
+ " self.fnet_pos = nn.Embedding(SEQ_LEN, 256); self.stream_rate = FNetEncoder(9, 256, 1024)\n",
579
+ " self.stream_phase = PRISMEncoder(9, 256, SEQ_LEN); self.phase_bridge = ComplexToRealBridge(256)\n",
580
+ " self.fusion_proj = nn.Linear(512, 512); self.fusion_norm = nn.LayerNorm(512)\n",
581
+ " self.refiner = Encoder(dim=512, depth=1, heads=8, rotary_pos_emb=True, attn_flash=True)\n",
582
+ " self.head_bias = nn.Parameter(torch.zeros(vocab_size))\n",
583
+ " def forward(self, x):\n",
584
+ " w, p = self.rose(x); p_sm = self.particle_down(p); w_raw = self.wave_down(torch.cat([w.real, w.imag], -1))\n",
585
+ " w_sm = torch.complex(w_raw[...,:256], w_raw[...,256:])\n",
586
+ " p_path = self.stream_rate(p_sm + self.fnet_pos(torch.arange(x.shape[1], device=x.device)))\n",
587
+ " w_path = self.phase_bridge(self.stream_phase(w_sm))\n",
588
+ " ctx = self.fusion_norm(self.fusion_proj(torch.cat([p_path, w_path], -1)))\n",
589
+ " return F.linear(self.refiner(ctx), self.rose.raw_embedding.weight, self.head_bias)\n",
590
+ "\n",
591
+ "# ==========================================\n",
592
+ "# 2. ANALYSIS LOGIC\n",
593
+ "# ==========================================\n",
594
+ "def smart_load(repo_id, name, cls):\n",
595
+ " # Init Model\n",
596
+ " model = cls(vocab_size=MAX_VOCAB_SIZE).to(DEVICE)\n",
597
+ " print(f\"⬇️ Downloading weights for {name}...\")\n",
598
+ " try: path = hf_hub_download(repo_id, \"best.pt\")\n",
599
+ " except: path = hf_hub_download(repo_id, \"pytorch_model.bin\")\n",
600
+ "\n",
601
+ " state_dict = torch.load(path, map_location=\"cpu\")\n",
602
+ " if 'model' in state_dict: state_dict = state_dict['model']\n",
603
+ " clean = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n",
604
+ "\n",
605
+ " # FIXES for Baseline/FNet\n",
606
+ " if name == \"Baseline\":\n",
607
+ " new_d = {}\n",
608
+ " for k, v in clean.items():\n",
609
+ " nk = k if k.startswith(\"model.\") else \"model.\" + k\n",
610
+ " if \"token_emb.weight\" in nk and \"emb\" not in nk: nk = nk.replace(\"token_emb.weight\", \"token_emb.emb.weight\")\n",
611
+ " new_d[nk] = v\n",
612
+ " clean = new_d\n",
613
+ " elif name == \"FNet\":\n",
614
+ " new_d = {}\n",
615
+ " for k, v in clean.items():\n",
616
+ " nk = k.replace(\"model.\", \"\")\n",
617
+ " new_d[nk] = v\n",
618
+ " clean = new_d\n",
619
+ "\n",
620
+ " model.load_state_dict(clean, strict=False)\n",
621
+ " print(f\"✅ {name} Ready.\")\n",
622
+ " return model\n",
623
+ "\n",
624
+ "def extract_phasor(model, name, token_id):\n",
625
+ " with torch.no_grad():\n",
626
+ " token_tensor = torch.tensor([token_id], device=DEVICE)\n",
627
+ " if name in [\"PRISM\", \"PILLARS\"]:\n",
628
+ " real_emb = model.rose.raw_embedding(token_tensor)\n",
629
+ " params = model.rose.adapter(real_emb)\n",
630
+ " D = real_emb.shape[-1]\n",
631
+ " z = torch.complex(params[...,:D], params[...,D:])\n",
632
+ " return z.squeeze(0).cpu()\n",
633
+ " elif name == \"FNet\":\n",
634
+ " x = model.token_emb(token_tensor)\n",
635
+ " return torch.complex(x, torch.zeros_like(x)).squeeze(0).cpu()\n",
636
+ " return None\n",
637
+ "\n",
638
+ "def calculate_coherence_dynamic(model, name, id_a, id_b):\n",
639
+ " za = extract_phasor(model, name, id_a)\n",
640
+ " zb = extract_phasor(model, name, id_b)\n",
641
+ " diff = torch.angle(za) - torch.angle(zb)\n",
642
+ " weights = torch.abs(za) * torch.abs(zb)\n",
643
+ "\n",
644
+ " diff_np = diff.numpy()\n",
645
+ " weights_np = weights.numpy()\n",
646
+ " weighted_complex_diffs = weights_np * np.exp(1j * diff_np)\n",
647
+ " mean_vector = np.sum(weighted_complex_diffs) / (np.sum(weights_np) + 1e-9)\n",
648
+ " return np.abs(mean_vector), np.angle(mean_vector), diff_np, weights_np\n",
649
+ "\n",
650
+ "# ==========================================\n",
651
+ "# 3. ROBUST CANDIDATE LIST (N = 135)\n",
652
+ "# ==========================================\n",
653
+ "candidates_raw = [\n",
654
+ " # --- SYNONYMS (Positive Correlation) ---\n",
655
+ " (\"fast\", \"quick\", \"Synonym\"), (\"big\", \"large\", \"Synonym\"), (\"small\", \"little\", \"Synonym\"),\n",
656
+ " (\"start\", \"begin\", \"Synonym\"), (\"end\", \"finish\", \"Synonym\"), (\"smart\", \"clever\", \"Synonym\"),\n",
657
+ " (\"hard\", \"tough\", \"Synonym\"), (\"simple\", \"easy\", \"Synonym\"), (\"happy\", \"glad\", \"Synonym\"),\n",
658
+ " (\"sad\", \"unhappy\", \"Synonym\"), (\"angry\", \"mad\", \"Synonym\"), (\"correct\", \"right\", \"Synonym\"),\n",
659
+ " (\"wrong\", \"incorrect\", \"Synonym\"), (\"shut\", \"close\", \"Synonym\"), (\"buy\", \"purchase\", \"Synonym\"),\n",
660
+ " (\"choose\", \"select\", \"Synonym\"), (\"gift\", \"present\", \"Synonym\"), (\"job\", \"work\", \"Synonym\"),\n",
661
+ " (\"trip\", \"journey\", \"Synonym\"), (\"lady\", \"woman\", \"Synonym\"), (\"guy\", \"man\", \"Synonym\"),\n",
662
+ " (\"street\", \"road\", \"Synonym\"), (\"stone\", \"rock\", \"Synonym\"), (\"speak\", \"talk\", \"Synonym\"),\n",
663
+ " (\"listen\", \"hear\", \"Synonym\"), (\"look\", \"see\", \"Synonym\"), (\"run\", \"sprint\", \"Synonym\"),\n",
664
+ " (\"jump\", \"leap\", \"Synonym\"), (\"scary\", \"afraid\", \"Synonym\"), (\"rich\", \"wealthy\", \"Synonym\"),\n",
665
+ " (\"weird\", \"strange\", \"Synonym\"), (\"quiet\", \"silent\", \"Synonym\"), (\"loud\", \"noisy\", \"Synonym\"),\n",
666
+ " (\"trash\", \"garbage\", \"Synonym\"), (\"sick\", \"ill\", \"Synonym\"), (\"thin\", \"slim\", \"Synonym\"),\n",
667
+ " (\"near\", \"close\", \"Synonym\"), (\"far\", \"distant\", \"Synonym\"), (\"safe\", \"secure\", \"Synonym\"),\n",
668
+ " (\"fix\", \"repair\", \"Synonym\"), (\"mix\", \"blend\", \"Synonym\"), (\"keep\", \"hold\", \"Synonym\"),\n",
669
+ " (\"push\", \"shove\", \"Synonym\"), (\"pull\", \"drag\", \"Synonym\"), (\"under\", \"below\", \"Synonym\"),\n",
670
+ " (\"above\", \"over\", \"Synonym\"), (\"center\", \"middle\", \"Synonym\"), (\"area\", \"zone\", \"Synonym\"),\n",
671
+ "\n",
672
+ " # --- ANTONYMS (Negative Correlation / Phase Shift) ---\n",
673
+ " (\"good\", \"bad\", \"Antonym\"), (\"hot\", \"cold\", \"Antonym\"), (\"high\", \"low\", \"Antonym\"),\n",
674
+ " (\"up\", \"down\", \"Antonym\"), (\"left\", \"right\", \"Antonym\"), (\"in\", \"out\", \"Antonym\"),\n",
675
+ " (\"black\", \"white\", \"Antonym\"), (\"day\", \"night\", \"Antonym\"), (\"sun\", \"moon\", \"Antonym\"),\n",
676
+ " (\"boy\", \"girl\", \"Antonym\"), (\"man\", \"woman\", \"Antonym\"), (\"king\", \"queen\", \"Antonym\"),\n",
677
+ " (\"life\", \"death\", \"Antonym\"), (\"war\", \"peace\", \"Antonym\"), (\"win\", \"lose\", \"Antonym\"),\n",
678
+ " (\"rich\", \"poor\", \"Antonym\"), (\"strong\", \"weak\", \"Antonym\"), (\"hard\", \"soft\", \"Antonym\"),\n",
679
+ " (\"loud\", \"quiet\", \"Antonym\"), (\"wet\", \"dry\", \"Antonym\"), (\"clean\", \"dirty\", \"Antonym\"),\n",
680
+ " (\"happy\", \"sad\", \"Antonym\"), (\"full\", \"empty\", \"Antonym\"), (\"open\", \"close\", \"Antonym\"),\n",
681
+ " (\"first\", \"last\", \"Antonym\"), (\"young\", \"old\", \"Antonym\"), (\"new\", \"old\", \"Antonym\"),\n",
682
+ " (\"fast\", \"slow\", \"Antonym\"), (\"tall\", \"short\", \"Antonym\"), (\"heavy\", \"light\", \"Antonym\"),\n",
683
+ " (\"dark\", \"light\", \"Antonym\"), (\"true\", \"false\", \"Antonym\"), (\"yes\", \"no\", \"Antonym\"),\n",
684
+ " (\"on\", \"off\", \"Antonym\"), (\"top\", \"bottom\", \"Antonym\"), (\"friend\", \"enemy\", \"Antonym\"),\n",
685
+ " (\"give\", \"take\", \"Antonym\"), (\"come\", \"go\", \"Antonym\"), (\"rise\", \"fall\", \"Antonym\"),\n",
686
+ " (\"north\", \"south\", \"Antonym\"), (\"east\", \"west\", \"Antonym\"), (\"buy\", \"sell\", \"Antonym\"),\n",
687
+ " (\"love\", \"hate\", \"Antonym\"), (\"win\", \"fail\", \"Antonym\"), (\"start\", \"stop\", \"Antonym\"),\n",
688
+ "\n",
689
+ " # --- RANDOM (Noise Floor) ---\n",
690
+ " (\"apple\", \"car\", \"Random\"), (\"banana\", \"sky\", \"Random\"), (\"bread\", \"cloud\", \"Random\"),\n",
691
+ " (\"cheese\", \"door\", \"Random\"), (\"milk\", \"shoe\", \"Random\"), (\"water\", \"book\", \"Random\"),\n",
692
+ " (\"coffee\", \"tree\", \"Random\"), (\"sugar\", \"phone\", \"Random\"), (\"salt\", \"idea\", \"Random\"),\n",
693
+ " (\"meat\", \"ghost\", \"Random\"), (\"soup\", \"math\", \"Random\"), (\"cake\", \"song\", \"Random\"),\n",
694
+ " (\"pie\", \"fish\", \"Random\"), (\"egg\", \"wall\", \"Random\"), (\"rice\", \"nose\", \"Random\"),\n",
695
+ " (\"tea\", \"frog\", \"Random\"), (\"juice\", \"star\", \"Random\"), (\"fruit\", \"chair\", \"Random\"),\n",
696
+ " (\"lemon\", \"fear\", \"Random\"), (\"melon\", \"bell\", \"Random\"), (\"berry\", \"law\", \"Random\"),\n",
697
+ " (\"grape\", \"dog\", \"Random\"), (\"plum\", \"cat\", \"Random\"), (\"pear\", \"bird\", \"Random\"),\n",
698
+ " (\"lime\", \"rock\", \"Random\"), (\"kiwi\", \"mud\", \"Random\"), (\"bean\", \"joy\", \"Random\"),\n",
699
+ " (\"corn\", \"ice\", \"Random\"), (\"nut\", \"wind\", \"Random\"), (\"fig\", \"pen\", \"Random\"),\n",
700
+ " (\"yam\", \"bus\", \"Random\"), (\"beef\", \"sun\", \"Random\"), (\"pork\", \"hat\", \"Random\"),\n",
701
+ " (\"lamb\", \"ink\", \"Random\"), (\"duck\", \"map\", \"Random\"), (\"goat\", \"art\", \"Random\"),\n",
702
+ " (\"cow\", \"box\", \"Random\"), (\"pig\", \"oil\", \"Random\"), (\"hen\", \"gas\", \"Random\"),\n",
703
+ " (\"fox\", \"cup\", \"Random\"), (\"wolf\", \"key\", \"Random\"), (\"ant\", \"bed\", \"Random\"),\n",
704
+ " (\"bee\", \"rug\", \"Random\"), (\"fly\", \"mud\", \"Random\"), (\"worm\", \"sky\", \"Random\")\n",
705
+ "]\n",
706
+ "\n",
707
+ "# ==========================================\n",
708
+ "# 4. EXECUTION WITH YOUR CUSTOM TOKENIZER\n",
709
+ "# ==========================================\n",
710
+ "print(f\"🔑 Loading Tokenizer from {TOKENIZER_ID}...\")\n",
711
+ "try:\n",
712
+ " tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_ID)\n",
713
+ "except:\n",
714
+ " print(\"⚠️ Fallback to base tokenizer if custom fails (Should not happen)\")\n",
715
+ " tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
716
+ "\n",
717
+ "valid_pairs = []\n",
718
+ "print(f\"🔎 Validating {len(candidates_raw)} candidate pairs...\")\n",
719
+ "\n",
720
+ "# Adding a space prefix \" \" is standard for RoBERTa tokenizers if words are start of sentence\n",
721
+ "# but we check both raw and space-prefixed to be safe.\n",
722
+ "for w1, w2, ptype in candidates_raw:\n",
723
+ " # Try with space prefix which RoBERTa often uses for words\n",
724
+ " ids1 = tokenizer.encode(\" \" + w1, add_special_tokens=False)\n",
725
+ " ids2 = tokenizer.encode(\" \" + w2, add_special_tokens=False)\n",
726
+ "\n",
727
+ " # Fallback to raw if space fails\n",
728
+ " if len(ids1) != 1: ids1 = tokenizer.encode(w1, add_special_tokens=False)\n",
729
+ " if len(ids2) != 1: ids2 = tokenizer.encode(w2, add_special_tokens=False)\n",
730
+ "\n",
731
+ " if len(ids1) == 1 and len(ids2) == 1:\n",
732
+ " id1, id2 = ids1[0], ids2[0]\n",
733
+ " if id1 < MAX_VOCAB_SIZE and id2 < MAX_VOCAB_SIZE:\n",
734
+ " valid_pairs.append((w1, w2, ptype, id1, id2))\n",
735
+ "\n",
736
+ "print(f\"✅ Found {len(valid_pairs)} valid single-token pairs for this tokenizer.\")\n",
737
+ "\n",
738
+ "MODELS_TO_TEST = [\n",
739
+ " (\"PRISM\", \"prism-lab/prism-v2-wikitext\", LocalPRISM),\n",
740
+ " (\"PILLARS\", \"prism-lab/pillars-compact-wikitext\", LocalPillars),\n",
741
+ " (\"FNet\", \"prism-lab/hybrid-fnet-prism-custom\", HybridFNetMLM)\n",
742
+ "]\n",
743
+ "\n",
744
+ "all_results = {}\n",
745
+ "\n",
746
+ "for name, repo, cls in MODELS_TO_TEST:\n",
747
+ " print(f\"\\n🧪 Analyzing {name}...\")\n",
748
+ " try:\n",
749
+ " model = smart_load(repo, name, cls)\n",
750
+ " model.eval()\n",
751
+ "\n",
752
+ " results = []\n",
753
+ " for w1, w2, ptype, id1, id2 in valid_pairs:\n",
754
+ " R, angle, diffs, weights = calculate_coherence_dynamic(model, name, id1, id2)\n",
755
+ " results.append({\"Pair\": f\"{w1}-{w2}\", \"Type\": ptype, \"R\": R, \"Diffs\": diffs, \"Weights\": weights})\n",
756
+ "\n",
757
+ " all_results[name] = results\n",
758
+ " df = pd.DataFrame(results)\n",
759
+ " print(f\"📊 {name} Results (Mean R):\")\n",
760
+ " if not df.empty:\n",
761
+ " print(df.groupby(\"Type\")[\"R\"].mean())\n",
762
+ " del model; torch.cuda.empty_cache(); gc.collect()\n",
763
+ " except Exception as e:\n",
764
+ " print(f\"❌ {name} Failed: {e}\")\n",
765
+ "\n",
766
+ "# Plotting\n",
767
+ "if len(all_results) > 0:\n",
768
+ " fig = plt.figure(figsize=(12, 10))\n",
769
+ " cols = [\"Synonym\", \"Antonym\", \"Random\"]\n",
770
+ " rows = list(all_results.keys())\n",
771
+ " colors = {\"Synonym\": \"red\", \"Antonym\": \"purple\", \"Random\": \"gray\"}\n",
772
+ "\n",
773
+ " idx = 1\n",
774
+ " for model_name in rows:\n",
775
+ " data = all_results[model_name]\n",
776
+ " df = pd.DataFrame(data)\n",
777
+ " if df.empty: continue\n",
778
+ "\n",
779
+ " best_syn = df[df[\"Type\"]==\"Synonym\"].sort_values(\"R\", ascending=False).iloc[0]\n",
780
+ " best_ant = df[df[\"Type\"]==\"Antonym\"].sort_values(\"R\", ascending=False).iloc[0]\n",
781
+ " best_rnd = df[df[\"Type\"]==\"Random\"].sort_values(\"R\", ascending=True).iloc[0]\n",
782
+ "\n",
783
+ " examples = [best_syn, best_ant, best_rnd]\n",
784
+ " for i, ex in enumerate(examples):\n",
785
+ " ax = fig.add_subplot(len(rows), 3, idx, projection='polar')\n",
786
+ " c = colors[ex[\"Type\"]]\n",
787
+ " ax.hist(ex[\"Diffs\"], bins=30, weights=ex[\"Weights\"], color=c, alpha=0.6, density=True)\n",
788
+ " ax.annotate(\"\", xy=(0, ex[\"R\"]), xytext=(0,0), arrowprops=dict(facecolor='black', width=1.5, headwidth=8, alpha=0.9))\n",
789
+ "\n",
790
+ " label = f\"{ex['Pair']}\\nR={ex['R']:.3f}\"\n",
791
+ " if i == 1: ax.set_title(f\"Model: {model_name}\\n{label}\", fontsize=10, weight='bold')\n",
792
+ " else: ax.set_title(label, fontsize=9)\n",
793
+ "\n",
794
+ " ax.set_yticklabels([]); ax.set_xticklabels([])\n",
795
+ " idx += 1\n",
796
+ "\n",
797
+ " plt.tight_layout()\n",
798
+ " plt.savefig(\"multi_model_compass_publication.png\", dpi=300)\n",
799
+ " print(\"\\n📸 Saved plot to 'multi_model_compass_publication.png'\")"
800
+ ],
801
+ "metadata": {
802
+ "id": "2-elQ3KH6aNg"
803
+ },
804
+ "execution_count": null,
805
+ "outputs": []
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "source": [
810
+ "# ========================\n",
811
+ "# 6. ANGULAR TOPOLOGY PLOT\n",
812
+ "# ========================\n",
813
+ "import seaborn as sns\n",
814
+ "\n",
815
+ "def plot_angular_topology(all_results):\n",
816
+ " fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)\n",
817
+ "\n",
818
+ " # We only care about Synonyms to see \"How\" they align\n",
819
+ " models = [\"FNet\", \"PRISM\"]\n",
820
+ " colors = {\"FNet\": \"blue\", \"PRISM\": \"red\"}\n",
821
+ "\n",
822
+ " for i, name in enumerate(models):\n",
823
+ " if name not in all_results: continue\n",
824
+ "\n",
825
+ " # Collect ALL phase differences for Synonyms across all pairs\n",
826
+ " # We flatten the list of angles\n",
827
+ " angles = []\n",
828
+ " data = all_results[name]\n",
829
+ " for item in data:\n",
830
+ " if item[\"Type\"] == \"Synonym\":\n",
831
+ " # Convert radians to degrees for readability\n",
832
+ " deg = np.degrees(item[\"Diffs\"])\n",
833
+ " # Wrap to -180 to 180\n",
834
+ " deg = (deg + 180) % 360 - 180\n",
835
+ " angles.extend(deg)\n",
836
+ "\n",
837
+ " sns.histplot(angles, ax=axes[i], bins=60, color=colors[name], stat=\"density\", kde=True)\n",
838
+ " axes[i].set_title(f\"{name} Phase Topology (Synonyms)\")\n",
839
+ " axes[i].set_xlabel(\"Phase Difference (Degrees)\")\n",
840
+ " axes[i].set_xlim(-180, 180)\n",
841
+ " axes[i].grid(True, alpha=0.3)\n",
842
+ "\n",
843
+ " # Add annotation\n",
844
+ " if name == \"FNet\":\n",
845
+ " axes[i].text(0, 0.01, \"BINARY\\n(Sign Flips)\", ha='center', color='black', fontweight='bold')\n",
846
+ " else:\n",
847
+ " axes[i].text(0, 0.01, \"CONTINUOUS\\n(Rotation)\", ha='center', color='black', fontweight='bold')\n",
848
+ "\n",
849
+ " plt.tight_layout()\n",
850
+ " plt.savefig(\"angular_topology_comparison.png\", dpi=300)\n",
851
+ " print(\"📸 Saved topology proof to 'angular_topology_comparison.png'\")\n",
852
+ "\n",
853
+ "# Run the plot with your existing results\n",
854
+ "plot_angular_topology(all_results)"
855
+ ],
856
+ "metadata": {
857
+ "id": "esnL8jUk89ov"
858
+ },
859
+ "execution_count": null,
860
+ "outputs": []
861
+ },
862
+ {
863
+ "cell_type": "code",
864
+ "source": [
865
+ "# ==========================================\n",
866
+ "# 7. RATE VS PHASE DISSOCIATION PROBE\n",
867
+ "# ==========================================\n",
868
+ "from scipy.stats import pearsonr\n",
869
+ "\n",
870
+ "def check_cosine_and_magnitude(model, name, id_a, id_b):\n",
871
+ " z_a = extract_phasor(model, name, id_a)\n",
872
+ " z_b = extract_phasor(model, name, id_b)\n",
873
+ "\n",
874
+ " # --- 1. Vector Cosine Similarity (The Standard Metric) ---\n",
875
+ " # For Complex (PRISM), we treat Re/Im as two coordinate dimensions\n",
876
+ " if name in [\"PRISM\", \"PILLARS\"]:\n",
877
+ " # Flatten: [Re_1, Im_1, Re_2, Im_2, ...]\n",
878
+ " vec_a = torch.cat([z_a.real, z_a.imag], -1)\n",
879
+ " vec_b = torch.cat([z_b.real, z_b.imag], -1)\n",
880
+ " else:\n",
881
+ " # FNet is already real\n",
882
+ " vec_a = z_a.real\n",
883
+ " vec_b = z_b.real\n",
884
+ "\n",
885
+ " vec_sim = F.cosine_similarity(vec_a.unsqueeze(0), vec_b.unsqueeze(0)).item()\n",
886
+ "\n",
887
+ " # --- 2. Magnitude Correlation (The \"Rate Coding\" Check) ---\n",
888
+ " # Do these words emphasize the same dimensions?\n",
889
+ " mag_a = torch.abs(z_a).numpy()\n",
890
+ " mag_b = torch.abs(z_b).numpy()\n",
891
+ "\n",
892
+ " # Pearson correlation of the magnitude profiles\n",
893
+ " # If the model uses Rate Coding, this should be HIGH.\n",
894
+ " # If the model is Iso-Energetic (PRISM), this should be NOISE.\n",
895
+ " if np.std(mag_a) < 1e-6 or np.std(mag_b) < 1e-6:\n",
896
+ " mag_corr = 0.0 # Handle constant magnitude case\n",
897
+ " else:\n",
898
+ " mag_corr, _ = pearsonr(mag_a, mag_b)\n",
899
+ "\n",
900
+ " return vec_sim, mag_corr\n",
901
+ "\n",
902
+ "print(\"\\n⚖️ Running Rate vs. Phase Dissociation...\")\n",
903
+ "\n",
904
+ "comparison_data = []\n",
905
+ "\n",
906
+ "# We only check Synonyms to see how they agree\n",
907
+ "synonym_pairs = [p for p in valid_pairs if p[2] == \"Synonym\"]\n",
908
+ "\n",
909
+ "for name, repo, cls in MODELS_TO_TEST:\n",
910
+ " try:\n",
911
+ " model = smart_load(repo, name, cls)\n",
912
+ " model.eval()\n",
913
+ "\n",
914
+ " vec_scores = []\n",
915
+ " mag_scores = []\n",
916
+ "\n",
917
+ " for w1, w2, _, id1, id2 in synonym_pairs:\n",
918
+ " v_sim, m_corr = check_cosine_and_magnitude(model, name, id1, id2)\n",
919
+ " vec_scores.append(v_sim)\n",
920
+ " mag_scores.append(m_corr)\n",
921
+ "\n",
922
+ " avg_vec = np.mean(vec_scores)\n",
923
+ " avg_mag = np.mean(mag_scores)\n",
924
+ "\n",
925
+ " comparison_data.append({\n",
926
+ " \"Model\": name,\n",
927
+ " \"Vector Sim (Direction)\": avg_vec,\n",
928
+ " \"Mag Corr (Loudness)\": avg_mag\n",
929
+ " })\n",
930
+ "\n",
931
+ " del model; torch.cuda.empty_cache()\n",
932
+ " except Exception as e:\n",
933
+ " print(f\"Skipping {name}: {e}\")\n",
934
+ "\n",
935
+ "# Display the \"Dissociation\" Table\n",
936
+ "df_comp = pd.DataFrame(comparison_data)\n",
937
+ "print(\"\\n🔥 THE DISSOCIATION TABLE 🔥\")\n",
938
+ "print(df_comp.set_index(\"Model\"))"
939
+ ],
940
+ "metadata": {
941
+ "id": "URcMGvENAE3d"
942
+ },
943
+ "execution_count": null,
944
+ "outputs": []
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "source": [
949
+ "import pandas as pd\n",
950
+ "\n",
951
+ "# 1. Convert the valid_pairs list to a DataFrame\n",
952
+ "df_stats = pd.DataFrame(valid_pairs, columns=[\"Word1\", \"Word2\", \"Category\", \"ID1\", \"ID2\"])\n",
953
+ "\n",
954
+ "# 2. Print the statistics\n",
955
+ "print(\"\\n📊 DATASET STATISTICS (Post-Filtering)\")\n",
956
+ "print(\"========================================\")\n",
957
+ "# Counts per category\n",
958
+ "counts = df_stats[\"Category\"].value_counts()\n",
959
+ "print(counts)\n",
960
+ "\n",
961
+ "print(\"----------------------------------------\")\n",
962
+ "print(f\"✅ Total Valid Pairs: {len(df_stats)}\")\n",
963
+ "print(\"========================================\")\n",
964
+ "\n",
965
+ "# 3. Helper for your Paper's Table\n",
966
+ "print(\"\\n📝 UPDATE FOR TABLE 2 (Count Column):\")\n",
967
+ "for category, count in counts.items():\n",
968
+ " print(f\" > {category}: {count}\")"
969
+ ],
970
+ "metadata": {
971
+ "id": "0kJpuCLhEHAJ"
972
+ },
973
+ "execution_count": null,
974
+ "outputs": []
975
+ }
976
+ ],
977
+ "metadata": {
978
+ "colab": {
979
+ "provenance": []
980
+ },
981
+ "kernelspec": {
982
+ "display_name": "Python 3",
983
+ "name": "python3"
984
+ },
985
+ "language_info": {
986
+ "name": "python"
987
+ }
988
+ },
989
+ "nbformat": 4,
990
+ "nbformat_minor": 0
991
+ }
Skewness_paper_last.ipynb ADDED
@@ -0,0 +1,1342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "source": [
22
+ "!pip install -q x-transformers"
23
+ ],
24
+ "metadata": {
25
+ "id": "tBfQX92lxEfP"
26
+ },
27
+ "execution_count": null,
28
+ "outputs": []
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "source": [
33
+ "# @title 🛠️ Setup & Model Loading\n",
34
+ "# ==============================================================================\n",
35
+ "# 1. INSTALL DEPENDENCIES\n",
36
+ "# ==============================================================================\n",
37
+ "!pip install -q numpy torch pandas scipy transformers huggingface_hub\n",
38
+ "\n",
39
+ "import torch\n",
40
+ "import numpy as np\n",
41
+ "import pandas as pd\n",
42
+ "from scipy.stats import skew\n",
43
+ "import sys\n",
44
+ "import os\n",
45
+ "from huggingface_hub import hf_hub_download\n",
46
+ "from transformers import AutoTokenizer\n",
47
+ "\n",
48
+ "# ==============================================================================\n",
49
+ "# 2. LOAD PRISM ARCHITECTURE\n",
50
+ "# ==============================================================================\n",
51
+ "REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
52
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
53
+ "\n",
54
+ "print(f\"⚙️ Hardware: {DEVICE}\")\n",
55
+ "print(f\"📥 Downloading Architecture from {REPO_ID}...\")\n",
56
+ "\n",
57
+ "# Download the model code to a local folder\n",
58
+ "os.makedirs(\"shimmer_code\", exist_ok=True)\n",
59
+ "hf_hub_download(repo_id=REPO_ID, filename=\"modeling_prism_gated.py\", local_dir=\"shimmer_code\")\n",
60
+ "sys.path.append(\"shimmer_code\")\n",
61
+ "\n",
62
+ "# Now we can import the class\n",
63
+ "from modeling_prism_gated import PRISMHybrid_RoPE\n",
64
+ "\n",
65
+ "# ==============================================================================\n",
66
+ "# 3. LOAD WEIGHTS\n",
67
+ "# ==============================================================================\n",
68
+ "print(\"📚 Loading Tokenizer...\")\n",
69
+ "tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
70
+ "\n",
71
+ "print(\"🏗️ Constructing PRISM Model...\")\n",
72
+ "CONFIG = {\n",
73
+ " \"vocab_size\": 58101,\n",
74
+ " \"d_model\": 512,\n",
75
+ " \"num_heads\": 8,\n",
76
+ " \"dff\": 2048,\n",
77
+ " \"dropout\": 0.1,\n",
78
+ " \"max_length\": 128,\n",
79
+ " \"num_encoder_layers\": 6,\n",
80
+ " \"num_refining_layers\": 0,\n",
81
+ " \"num_decoder_layers\": 6\n",
82
+ "}\n",
83
+ "model = PRISMHybrid_RoPE(**CONFIG)\n",
84
+ "\n",
85
+ "print(\"📥 Loading Checkpoint...\")\n",
86
+ "weights_path = hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\")\n",
87
+ "state_dict = torch.load(weights_path, map_location=DEVICE)\n",
88
+ "model.load_state_dict(state_dict)\n",
89
+ "model.to(DEVICE)\n",
90
+ "model.eval()\n",
91
+ "\n",
92
+ "print(\"✅ Model Ready for Probing.\")"
93
+ ],
94
+ "metadata": {
95
+ "id": "fn7A70MZxV1U"
96
+ },
97
+ "execution_count": null,
98
+ "outputs": []
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "source": [
103
+ "# @title 🧪 The Probe & Datasets\n",
104
+ "# ==============================================================================\n",
105
+ "# 1. DATASETS (The \"76 + 70\" Split)\n",
106
+ "# ==============================================================================\n",
107
+ "\n",
108
+ "# A. HARD MODE (Polysemous / Ambiguous)\n",
109
+ "# Words that require context to resolve (High Entropy)\n",
110
+ "raw_poly_candidates = [\n",
111
+ " # --- ORIGINAL SET ---\n",
112
+ " (\"Ich gehe zur Bank um Geld zu holen\", \"Bank\"), (\"Die Bank hat hohe Zinsen\", \"Bank\"),\n",
113
+ " (\"Wir saßen auf einer Bank im Park\", \"Bank\"), (\"Die Bank aus Holz war bequem\", \"Bank\"),\n",
114
+ " (\"Das Schloss hat viele Türme\", \"Schloss\"), (\"Der König wohnt im Schloss\", \"Schloss\"),\n",
115
+ " (\"Der Schlüssel steckt im Schloss\", \"Schloss\"), (\"Das Schloss an der Tür klemmt\", \"Schloss\"),\n",
116
+ " (\"Der Leiter der Firma ist streng\", \"Leiter\"), (\"Unser Leiter plant das Projekt\", \"Leiter\"),\n",
117
+ " (\"Ich steige auf die Leiter\", \"Leiter\"), (\"Die Leiter ist aus Aluminium\", \"Leiter\"),\n",
118
+ " (\"Die Lampe hängt an der Decke\", \"Decke\"), (\"Die Decke ist weiß gestrichen\", \"Decke\"),\n",
119
+ " (\"Mir ist kalt gib mir eine Decke\", \"Decke\"), (\"Die Decke aus Wolle ist warm\", \"Decke\"),\n",
120
+ " (\"Der Kiefer ist ein Nadelbaum\", \"Kiefer\"), (\"Das Holz der Kiefer ist weich\", \"Kiefer\"),\n",
121
+ " (\"Der Arzt röntgt meinen Kiefer\", \"Kiefer\"), (\"Er hat Schmerzen im Kiefer\", \"Kiefer\"),\n",
122
+ " (\"Der Strauß ist ein schneller Vogel\", \"Strauß\"), (\"Dieser Strauß kann nicht fliegen\", \"Strauß\"),\n",
123
+ " (\"Sie kaufte einen bunten Strauß\", \"Strauß\"), (\"Der Strauß Blumen duftet gut\", \"Strauß\"),\n",
124
+ " (\"Er schoss ein schönes Tor\", \"Tor\"), (\"Der Ball flog ins Tor\", \"Tor\"),\n",
125
+ " (\"Das eiserne Tor war verschlossen\", \"Tor\"), (\"Sie öffneten das große Tor\", \"Tor\"),\n",
126
+ " (\"Wir tanzen auf dem Ball\", \"Ball\"), (\"Der Maskenball war elegant\", \"Ball\"),\n",
127
+ " (\"Er warf den Ball weit weg\", \"Ball\"), (\"Der Ball ist rund und rot\", \"Ball\"),\n",
128
+ " (\"Die Schlange im Zoo ist giftig\", \"Schlange\"), (\"Die Schlange zischte laut\", \"Schlange\"),\n",
129
+ " (\"Wir stehen in einer langen Schlange\", \"Schlange\"), (\"Die Schlange an der Kasse war lang\", \"Schlange\"),\n",
130
+ " (\"Der Strom ist ausgefallen\", \"Strom\"), (\"Strom kostet viel Geld\", \"Strom\"),\n",
131
+ " (\"Der Strom fließt ins Meer\", \"Strom\"), (\"Wir schwammen gegen den Strom\", \"Strom\"),\n",
132
+ " (\"Seine Mutter ist sehr nett\", \"Mutter\"), (\"Die Mutter kocht das Essen\", \"Mutter\"),\n",
133
+ " (\"Die Mutter passt auf die Schraube\", \"Mutter\"), (\"Ich brauche eine neue Mutter\", \"Mutter\"),\n",
134
+ " (\"Die Birne schmeckt süß\", \"Birne\"), (\"Ich esse gerne eine Birne\", \"Birne\"),\n",
135
+ " (\"Die Birne in der Lampe ist kaputt\", \"Birne\"), (\"Wir müssen die Birne wechseln\", \"Birne\"),\n",
136
+ " # --- EXPANSION SET ---\n",
137
+ " (\"Das Gericht hat ihn verurteilt\", \"Gericht\"), (\"Der Anwalt geht zum Gericht\", \"Gericht\"),\n",
138
+ " (\"Mein Lieblingsessen ist ein Gericht aus Reis\", \"Gericht\"), (\"Das Gericht schmeckt sehr salzig\", \"Gericht\"),\n",
139
+ " (\"Der Ton war sehr laut\", \"Ton\"), (\"Ich hörte einen hohen Ton\", \"Ton\"),\n",
140
+ " (\"Die Vase ist aus Ton\", \"Ton\"), (\"Wir formen Figuren aus Ton\", \"Ton\"),\n",
141
+ " (\"Das Blatt fällt vom Baum\", \"Blatt\"), (\"Im Herbst werden die Blätter braun\", \"Blatt\"),\n",
142
+ " (\"Ich schreibe auf ein Blatt Papier\", \"Blatt\"), (\"Gib mir bitte ein leeres Blatt\", \"Blatt\"),\n",
143
+ " (\"Der Nagel steckt in der Wand\", \"Nagel\"), (\"Ich schlage den Nagel mit dem Hammer\", \"Nagel\"),\n",
144
+ " (\"Mein Nagel ist abgebrochen\", \"Nagel\"), (\"Sie lackiert sich den Nagel rot\", \"Nagel\"),\n",
145
+ " (\"Die Maus frisst den Käse\", \"Maus\"), (\"Die Katze jagt die Maus\", \"Maus\"),\n",
146
+ " (\"Ich klicke mit der Maus\", \"Maus\"), (\"Der Computer braucht eine neue Maus\", \"Maus\"),\n",
147
+ " (\"Die Erde dreht sich um die Sonne\", \"Erde\"), (\"Der Astronaut schaut auf die Erde\", \"Erde\"),\n",
148
+ " (\"Die Blume braucht frische Erde\", \"Erde\"), (\"Er gräbt ein Loch in die Erde\", \"Erde\"),\n",
149
+ " (\"Der Hahn kräht am Morgen\", \"Hahn\"), (\"Der Hahn hat bunte Federn\", \"Hahn\"),\n",
150
+ " (\"Der Wasserhahn tropft\", \"Hahn\"), (\"Dreh bitte den Hahn zu\", \"Hahn\"),\n",
151
+ " (\"Die Schale der Orange ist bitter\", \"Schale\"), (\"Er wirft die Schale weg\", \"Schale\"),\n",
152
+ " (\"Die Schale steht auf dem Tisch\", \"Schale\"), (\"Ich esse Müsli aus der Schale\", \"Schale\"),\n",
153
+ " (\"Der Bauer melkt die Kühe\", \"Bauer\"), (\"Der Bauer fährt auf dem Traktor\", \"Bauer\"),\n",
154
+ " (\"Ich ziehe den Bauer auf E4\", \"Bauer\"), (\"Der Bauer schlägt den Turm\", \"Bauer\"),\n",
155
+ "]\n",
156
+ "\n",
157
+ "# B. EASY MODE (Casual)\n",
158
+ "raw_casual_candidates = [\n",
159
+ " (\"Die Katze schläft\", \"Katze\"), (\"Der Hund bellt\", \"Hund\"), (\"Das Auto fährt\", \"Auto\"),\n",
160
+ " (\"Wasser ist nass\", \"Wasser\"), (\"Das Brot schmeckt gut\", \"Brot\"), (\"Die Sonne scheint\", \"Sonne\"),\n",
161
+ " (\"Der Mond leuchtet\", \"Mond\"), (\"Das Buch ist spannend\", \"Buch\"), (\"Der Tisch ist rund\", \"Tisch\"),\n",
162
+ " (\"Der Stuhl ist bequem\", \"Stuhl\"), (\"Der Apfel ist rot\", \"Apfel\"), (\"Meine Hand ist kalt\", \"Hand\"),\n",
163
+ " (\"Das Herz klopft\", \"Herz\"), (\"Wir haben Zeit\", \"Zeit\"), (\"Geld ist wichtig\", \"Geld\"),\n",
164
+ " (\"Musik ist schön\", \"Musik\"), (\"Der Film ist zu Ende\", \"Film\"), (\"Das Spiel beginnt\", \"Spiel\"),\n",
165
+ " (\"Die Schule ist aus\", \"Schule\"), (\"Die Stadt ist laut\", \"Stadt\"), (\"Der Fluss fließt\", \"Fluss\"),\n",
166
+ " (\"Das Meer ist tief\", \"Meer\"), (\"Kaffee ist schwarz\", \"Kaffee\"), (\"Milch ist weiß\", \"Milch\"),\n",
167
+ " (\"Der Bruder lacht\", \"Bruder\"), (\"Die Schwester weint\", \"Schwester\"), (\"Das Haus ist groß\", \"Haus\"),\n",
168
+ " (\"Der Garten ist grün\", \"Garten\"), (\"Der Sommer ist heiß\", \"Sommer\"), (\"Der Winter ist kalt\", \"Winter\"),\n",
169
+ " (\"Das Fenster ist offen\", \"Fenster\"), (\"Die Tür ist zu\", \"Tür\"), (\"Der Boden ist sauber\", \"Boden\"),\n",
170
+ " (\"Die Wand ist weiß\", \"Wand\"), (\"Das Dach ist rot\", \"Dach\"), (\"Der Wald ist dunkel\", \"Wald\"),\n",
171
+ " (\"Der Berg ist hoch\", \"Berg\"), (\"Der See ist ruhig\", \"See\"), (\"Das Tier ist wild\", \"Tier\"),\n",
172
+ " (\"Der Mensch denkt\", \"Mensch\"), (\"Das Kind spielt\", \"Kind\"), (\"Die Frau arbeitet\", \"Frau\"),\n",
173
+ " (\"Der Mann schläft\", \"Mann\"), (\"Das Auge sieht\", \"Auge\"), (\"Das Ohr hört\", \"Ohr\"),\n",
174
+ " (\"Die Nase riecht\", \"Nase\"), (\"Der Mund spricht\", \"Mund\"), (\"Der Arm ist stark\", \"Arm\"),\n",
175
+ " (\"Das Bein tut weh\", \"Bein\"), (\"Der Fuß ist groß\", \"Fuß\"), (\"Der Tee ist heiß\", \"Tee\"),\n",
176
+ " (\"Das Bier ist kalt\", \"Bier\"), (\"Der Wein ist rot\", \"Wein\"), (\"Das Glas ist voll\", \"Glas\"),\n",
177
+ " (\"Die Tasse ist leer\", \"Tasse\"), (\"Der Teller ist blau\", \"Teller\"), (\"Die Gabel ist spitz\", \"Gabel\"),\n",
178
+ " (\"Der Löffel ist rund\", \"Löffel\"), (\"Das Messer ist scharf\", \"Messer\"), (\"Der Stift schreibt\", \"Stift\"),\n",
179
+ " (\"Der Brief ist lang\", \"Brief\"), (\"Das Bild ist schön\", \"Bild\"), (\"Die Uhr tickt\", \"Uhr\"),\n",
180
+ " (\"Das Bett ist weich\", \"Bett\"), (\"Der Schrank ist voll\", \"Schrank\"), (\"Das Sofa ist neu\", \"Sofa\"),\n",
181
+ " (\"Das Radio spielt\", \"Radio\"), (\"Das Jahr ist um\", \"Jahr\"), (\"Der Tag war lang\", \"Tag\"),\n",
182
+ " (\"Die Nacht ist kurz\", \"Nacht\")\n",
183
+ "]\n",
184
+ "# ==============================================================================\n",
185
+ "# 2. HELPER FUNCTIONS\n",
186
+ "# ==============================================================================\n",
187
+ "def find_token_index(input_ids, target_word, tokenizer):\n",
188
+ " tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
189
+ " for i, t in enumerate(tokens):\n",
190
+ " # Clean BPE artifacts\n",
191
+ " clean = t.replace('Ġ', '').replace('▁', '').replace(' ', '')\n",
192
+ " if target_word.lower() == clean.lower():\n",
193
+ " return i\n",
194
+ " # Fallback\n",
195
+ " for i, t in enumerate(tokens):\n",
196
+ " clean = t.replace('Ġ', '').replace('▁', '').replace(' ', '')\n",
197
+ " if target_word.lower() in clean.lower():\n",
198
+ " return i\n",
199
+ " return 1\n",
200
+ "\n",
201
+ "def filter_dataset(candidates, tokenizer, label):\n",
202
+ " \"\"\"Ensures we only test single-token words to keep phase metrics valid.\"\"\"\n",
203
+ " valid_data = []\n",
204
+ " rejected_count = 0\n",
205
+ " print(f\"\\n🔍 Validating {label} Candidates...\")\n",
206
+ "\n",
207
+ " for context, target in candidates:\n",
208
+ " # Check standard and space-prefixed tokenization\n",
209
+ " t1 = tokenizer.encode(target, add_special_tokens=False)\n",
210
+ " t2 = tokenizer.encode(\" \" + target, add_special_tokens=False)\n",
211
+ "\n",
212
+ " if len(t1) == 1 or len(t2) == 1:\n",
213
+ " valid_data.append((context, target))\n",
214
+ " else:\n",
215
+ " rejected_count += 1\n",
216
+ "\n",
217
+ " print(f\" ✅ Accepted: {len(valid_data)} examples.\")\n",
218
+ " print(f\" ❌ Rejected: {rejected_count} multi-token words.\")\n",
219
+ " return valid_data\n",
220
+ "\n",
221
+ "# ==============================================================================\n",
222
+ "# 3. UNIFIED PROBE CLASS\n",
223
+ "# ==============================================================================\n",
224
+ "def run_unified_probe(model, tokenizer, dataset, label, device):\n",
225
+ " num_layers = len(model.prism_encoder.layers)\n",
226
+ " rotation_stats = {i: [] for i in range(num_layers)}\n",
227
+ "\n",
228
+ " # Hooks\n",
229
+ " hook_data = {}\n",
230
+ "\n",
231
+ " def physics_hook(layer_idx):\n",
232
+ " def hook(module, input, output):\n",
233
+ " x, y = input[0].detach(), output.detach()\n",
234
+ "\n",
235
+ " # --- PHASE ROTATION CALCULATION ---\n",
236
+ " # 1. Norms\n",
237
+ " norm_x = torch.norm(x, p=2, dim=-1)\n",
238
+ " norm_y = torch.norm(y, p=2, dim=-1)\n",
239
+ "\n",
240
+ " # 2. Flatten Complex to 2D Real [Batch, Seq, Dim*2]\n",
241
+ " # This allows us to calculate geometric angle\n",
242
+ " x_f = x.view(x.shape[0], x.shape[1], -1)\n",
243
+ " y_f = y.view(y.shape[0], y.shape[1], -1)\n",
244
+ "\n",
245
+ " # 3. Dot Product\n",
246
+ " dot = (x_f.real * y_f.real + x_f.imag * y_f.imag).sum(dim=-1)\n",
247
+ "\n",
248
+ " # 4. Angle (Arccos)\n",
249
+ " cosine = torch.clamp(dot / (norm_x * norm_y + 1e-9), -1.0, 1.0)\n",
250
+ " angle = torch.rad2deg(torch.acos(cosine)).cpu()\n",
251
+ "\n",
252
+ " hook_data[f'rot_{layer_idx}'] = angle\n",
253
+ " return hook\n",
254
+ "\n",
255
+ " # Register\n",
256
+ " model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
257
+ " for i, layer in enumerate(model.prism_encoder.layers):\n",
258
+ " layer.register_forward_hook(physics_hook(i))\n",
259
+ "\n",
260
+ " # Execute\n",
261
+ " model.eval()\n",
262
+ " print(f\"🔬 Running Probe on {len(dataset)} {label} examples...\")\n",
263
+ "\n",
264
+ " for context, target in dataset:\n",
265
+ " hook_data = {}\n",
266
+ " inputs = tokenizer(context, return_tensors=\"pt\").to(device)\n",
267
+ "\n",
268
+ " with torch.no_grad():\n",
269
+ " x = model.harmonic_embedding(inputs.input_ids)\n",
270
+ " src_mask = (inputs.input_ids == tokenizer.pad_token_id)\n",
271
+ " model.prism_encoder(x, src_mask)\n",
272
+ "\n",
273
+ " idx = find_token_index(inputs.input_ids[0], target, tokenizer)\n",
274
+ "\n",
275
+ " for i in range(num_layers):\n",
276
+ " if f'rot_{i}' in hook_data:\n",
277
+ " batch = hook_data[f'rot_{i}']\n",
278
+ " # Handle batch dimension if present\n",
279
+ " val = batch[0, idx].item() if batch.dim() > 1 else batch[idx].item()\n",
280
+ " rotation_stats[i].append(val)\n",
281
+ "\n",
282
+ " model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
283
+ " return pd.DataFrame(rotation_stats)\n",
284
+ "\n",
285
+ "# ==============================================================================\n",
286
+ "# 4. EXECUTION & REPORTING\n",
287
+ "# ==============================================================================\n",
288
+ "# A. Filter Data\n",
289
+ "ds_hard = filter_dataset(raw_poly_candidates, tokenizer, \"HARD (Polysemous)\")\n",
290
+ "ds_easy = filter_dataset(raw_casual_candidates, tokenizer, \"EASY (Casual)\")\n",
291
+ "\n",
292
+ "# B. Run Analysis\n",
293
+ "DEVICE = next(model.parameters()).device # Robust device check\n",
294
+ "df_hard = run_unified_probe(model, tokenizer, ds_hard, \"HARD\", DEVICE)\n",
295
+ "df_easy = run_unified_probe(model, tokenizer, ds_easy, \"EASY\", DEVICE)\n",
296
+ "\n",
297
+ "# C. Generate ASCII Tables\n",
298
+ "def print_stats(df, title):\n",
299
+ " print(f\"\\n📊 {title} (N={len(df)})\")\n",
300
+ " print(\"=\"*90)\n",
301
+ " print(f\"{'Lyr':<3} | {'Mean (°)':<10} | {'Median (°)':<10} | {'Max (°)':<10} | {'Skewness':<10} | {'Regime'}\")\n",
302
+ " print(\"-\" * 90)\n",
303
+ "\n",
304
+ " total_skew = 0\n",
305
+ " for col in df.columns:\n",
306
+ " d = df[col]\n",
307
+ " skew_val = d.skew()\n",
308
+ " total_skew += skew_val\n",
309
+ "\n",
310
+ " # Interpret Regime\n",
311
+ " if skew_val > 1.5: regime = \"⚡ STEERING (Heavy Tail)\"\n",
312
+ " elif skew_val > 0.5: regime = \"⚖️ HYBRID\"\n",
313
+ " else: regime = \"💤 INERTIAL\"\n",
314
+ "\n",
315
+ " print(f\"{col:<3} | {d.mean():6.2f} | {d.median():6.2f} | {d.max():6.2f} | {skew_val:6.2f} | {regime}\")\n",
316
+ "\n",
317
+ " print(\"-\" * 90)\n",
318
+ " print(f\"∑ INTEGRATED SKEWNESS (Metabolic Load): {total_skew:.2f}\")\n",
319
+ "\n",
320
+ "print_stats(df_hard, \"TABLE 3A: POLYSEMOUS TOKENS (Ambiguous)\")\n",
321
+ "print_stats(df_easy, \"TABLE 3B: CASUAL TOKENS (Unambiguous)\")"
322
+ ],
323
+ "metadata": {
324
+ "id": "hvHJcuqmxDkt"
325
+ },
326
+ "execution_count": null,
327
+ "outputs": []
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "source": [
332
+ "import matplotlib.pyplot as plt\n",
333
+ "import seaborn as sns\n",
334
+ "\n",
335
+ "# ==============================================================================\n",
336
+ "# FIGURE 3 GENERATOR (Robust N=76 Version)\n",
337
+ "# ==============================================================================\n",
338
+ "def plot_figure_3_robust(df_hard, df_easy, save_path=\"fig3_phase_steering_robust.png\"):\n",
339
+ " \"\"\"\n",
340
+ " Generates the \"Dual Regime\" Violin Plot matching the paper style.\n",
341
+ " Visualizes the 'Mid-Network Resolution' (Skew spike at Layer 3).\n",
342
+ " \"\"\"\n",
343
+ " # Setup the canvas (Two panels, shared Y-axis)\n",
344
+ " fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True, dpi=300)\n",
345
+ "\n",
346
+ " # 1. AMBIGUOUS PANEL (The Steering Regime)\n",
347
+ " # We use a Red palette to signify \"Metabolic Work\"\n",
348
+ " sns.violinplot(\n",
349
+ " data=df_hard,\n",
350
+ " palette=\"Reds\",\n",
351
+ " ax=axes[0],\n",
352
+ " inner=\"quartile\",\n",
353
+ " linewidth=1.2,\n",
354
+ " cut=0 # Don't extend past data range\n",
355
+ " )\n",
356
+ " axes[0].set_title(\"(A) Ambiguous (Steering)\", fontweight='bold', fontsize=12, color='darkred')\n",
357
+ " axes[0].set_ylabel(\"Phase Rotation (°)\", fontsize=11, fontweight='bold')\n",
358
+ " axes[0].set_xlabel(\"Layer Depth\", fontsize=10)\n",
359
+ " axes[0].grid(axis='y', linestyle='--', alpha=0.3)\n",
360
+ "\n",
361
+ " # 2. UNAMBIGUOUS PANEL (The Inertial Regime)\n",
362
+ " # We use a Blue/Green palette to signify \"Coasting\"\n",
363
+ " sns.violinplot(\n",
364
+ " data=df_easy,\n",
365
+ " palette=\"mako\",\n",
366
+ " ax=axes[1],\n",
367
+ " inner=\"quartile\",\n",
368
+ " linewidth=1.2,\n",
369
+ " cut=0\n",
370
+ " )\n",
371
+ " axes[1].set_title(\"(B) Unambiguous (Inertial)\", fontweight='bold', fontsize=12, color='darkgreen')\n",
372
+ " axes[1].set_xlabel(\"Layer Depth\", fontsize=10)\n",
373
+ " axes[1].grid(axis='y', linestyle='--', alpha=0.3)\n",
374
+ " axes[1].set_ylabel(\"\") # Remove redundant y-label\n",
375
+ "\n",
376
+ " # Formatting\n",
377
+ " plt.ylim(0, 30) # Focus on the active range (Max is ~22 deg)\n",
378
+ " sns.despine(trim=True, offset=5)\n",
379
+ " plt.tight_layout()\n",
380
+ "\n",
381
+ " # Save & Show\n",
382
+ " plt.savefig(save_path, bbox_inches='tight')\n",
383
+ " plt.show()\n",
384
+ " print(f\"✅ Updated Figure 3 saved to: {save_path}\")\n",
385
+ "\n",
386
+ "# Run with your existing dataframes\n",
387
+ "plot_figure_3_robust(df_hard, df_easy)"
388
+ ],
389
+ "metadata": {
390
+ "id": "bG4t_QIyycpD"
391
+ },
392
+ "execution_count": null,
393
+ "outputs": []
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "source": [
398
+ "# @title 🛠️ Fixed Stress Test (Using Custom Generation API)\n",
399
+ "import torch\n",
400
+ "import pandas as pd\n",
401
+ "from tqdm import tqdm\n",
402
+ "\n",
403
+ "# ==============================================================================\n",
404
+ "# 1. SETUP & DATA\n",
405
+ "# ==============================================================================\n",
406
+ "test_cases = [\n",
407
+ " # (German Word, Context Helper, Expected English)\n",
408
+ " (\"Bank\", \"Die Bank.\", \"bench\"), # Ambiguous: Bench vs Bank\n",
409
+ " (\"Schloss\", \"Das Schloss.\", \"castle\"), # Ambiguous: Castle vs Lock\n",
410
+ " (\"Leiter\", \"Der Leiter.\", \"leader\"), # Ambiguous: Ladder vs Leader\n",
411
+ " (\"Decke\", \"Die Decke.\", \"ceiling\"), # Ambiguous: Blanket vs Ceiling\n",
412
+ " (\"Kiefer\", \"Der Kiefer.\", \"jaw\"), # Ambiguous: Pine vs Jaw\n",
413
+ " (\"Gericht\", \"Das Gericht.\", \"court\"), # Ambiguous: Dish vs Court\n",
414
+ " (\"Steuer\", \"Das Steuer.\", \"helm\"), # Ambiguous: Tax vs Helm\n",
415
+ " (\"Hahn\", \"Der Hahn.\", \"rooster\"), # Ambiguous: Tap vs Rooster\n",
416
+ " (\"Tau\", \"Das Tau.\", \"rope\"), # Ambiguous: Dew vs Rope\n",
417
+ " (\"Strauß\", \"Der Strauß.\", \"bouquet\"), # Ambiguous: Ostrich vs Bouquet\n",
418
+ "]\n",
419
+ "\n",
420
+ "# ==============================================================================\n",
421
+ "# 2. THE EXPERIMENT LOOP\n",
422
+ "# ==============================================================================\n",
423
+ "results = []\n",
424
+ "print(f\"📉 Running Phase Interference Test on {len(test_cases)} ambiguous terms...\\n\")\n",
425
+ "\n",
426
+ "model.eval()\n",
427
+ "\n",
428
+ "# Helper to run your custom generate function\n",
429
+ "def run_prism_gen(text_input):\n",
430
+ " # 1. Tokenize (Get IDs only)\n",
431
+ " input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
432
+ "\n",
433
+ " # 2. Call YOUR custom generate method\n",
434
+ " # Signature: generate(self, src, max_length, num_beams=5)\n",
435
+ " with torch.no_grad():\n",
436
+ " out_ids = model.generate(src=input_tensor, max_length=10, num_beams=1)\n",
437
+ "\n",
438
+ " # 3. Decode\n",
439
+ " return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
440
+ "\n",
441
+ "for word, context_phrase, target in tqdm(test_cases):\n",
442
+ " try:\n",
443
+ " # --- PASS 1: ISOLATION (Single Token) ---\n",
444
+ " trans_iso = run_prism_gen(word)\n",
445
+ "\n",
446
+ " # --- PASS 2: INTERFERENCE (Context) ---\n",
447
+ " trans_int = run_prism_gen(context_phrase)\n",
448
+ "\n",
449
+ " # Log Result\n",
450
+ " # We flag it as \"FAIL\" (bug) if the isolation translation is WRONG.\n",
451
+ " # BUT for your paper, an \"Isolation Fail\" + \"Context Pass\" is actually a SUCCESSFUL scientific result.\n",
452
+ "\n",
453
+ " status = \"✅ Context Fix\" if (target.lower() not in trans_iso.lower() and target.lower() in trans_int.lower()) else \"Neutral\"\n",
454
+ "\n",
455
+ " results.append({\n",
456
+ " \"Source\": word,\n",
457
+ " \"Target\": target,\n",
458
+ " \"⛔ Isolation\": trans_iso,\n",
459
+ " \"✅ Context\": trans_int,\n",
460
+ " \"Outcome\": status\n",
461
+ " })\n",
462
+ " except Exception as e:\n",
463
+ " print(f\"Error on {word}: {e}\")\n",
464
+ "\n",
465
+ "# ==============================================================================\n",
466
+ "# 3. ANALYSIS & REPORT\n",
467
+ "# ==============================================================================\n",
468
+ "df_res = pd.DataFrame(results)\n",
469
+ "\n",
470
+ "print(\"\\n\\n🌊 EXPERIMENTAL RESULTS: The Single-Token Paradox\")\n",
471
+ "print(\"=\"*110)\n",
472
+ "print(f\"{'Input':<10} | {'Target':<10} | {'⛔ Isolation (No Wave)':<30} | {'✅ Context (Interference)':<30}\")\n",
473
+ "print(\"-\" * 110)\n",
474
+ "\n",
475
+ "success_scientific_count = 0\n",
476
+ "\n",
477
+ "for _, row in df_res.iterrows():\n",
478
+ " iso_text = row['⛔ Isolation']\n",
479
+ " ctx_text = row['✅ Context']\n",
480
+ "\n",
481
+ " # Visual Logic: If Isolation failed to find target, but Context found it -> Highlight\n",
482
+ " if row['Outcome'] == \"✅ Context Fix\":\n",
483
+ " iso_text = f\"--> {iso_text} <--\" # Shows the \"Inertial\" failure\n",
484
+ " success_scientific_count += 1\n",
485
+ "\n",
486
+ " print(f\"{row['Source']:<10} | {row['Target']:<10} | {iso_text:<30} | {ctx_text:<30}\")\n",
487
+ "\n",
488
+ "print(\"=\"*110)\n",
489
+ "print(f\"\\n🧪 SCIENTIFIC CONCLUSION:\")\n",
490
+ "if success_scientific_count > 0:\n",
491
+ " print(f\"🎉 OBSERVED: {success_scientific_count}/{len(test_cases)} cases showed the 'Physical Abstractor' effect!\")\n",
492
+ " print(\" The model failed to resolve ambiguity in isolation (as predicted) but resolved it with context.\")\n",
493
+ "else:\n",
494
+ " print(\"🤔 OBSERVED: The model resolved isolation perfectly. Rate-Coding (Magnitude) might be leaking.\")"
495
+ ],
496
+ "metadata": {
497
+ "id": "llU4wmT67ZRm"
498
+ },
499
+ "execution_count": null,
500
+ "outputs": []
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "source": [
505
+ "# @title 🧪 Sanity Check: Full Wave Packets (Sentences)\n",
506
+ "# ==============================================================================\n",
507
+ "# HYPOTHESIS:\n",
508
+ "# If the \"Single-Token Paradox\" is real, then full sentences (rich interference)\n",
509
+ "# should translate fluently, unlike the \"broken\" isolated tokens.\n",
510
+ "# ==============================================================================\n",
511
+ "\n",
512
+ "sanity_sentences = [\n",
513
+ " # 1. Standard Fluency (Control)\n",
514
+ " \"Das Haus ist groß und schön.\",\n",
515
+ " \"Die Katze schläft auf dem Sofa.\",\n",
516
+ " \"Wir gehen heute in den Park.\",\n",
517
+ " \"Das Wetter ist sehr gut.\",\n",
518
+ "\n",
519
+ " # 2. Contextual Resolution (The ambiguous words from before)\n",
520
+ " \"Der Lehrer schreibt an die Tafel.\", # \"Leiter\" implies leader/head, but checking context\n",
521
+ " \"Ich sitze auf einer Bank im Garten.\", # Should lock to \"Bench\"\n",
522
+ " \"Ich bringe mein Geld zur Bank.\", # Should lock to \"Bank\" (Financial)\n",
523
+ " \"Das Schloss ist alt und aus Stein.\", # Should lock to \"Castle\"\n",
524
+ " \"Der Schlüssel steckt im Schloss.\", # Should lock to \"Lock\"\n",
525
+ "\n",
526
+ " # 3. Complex Grammar (Phase coherence test)\n",
527
+ " \"Obwohl es regnet, gehe ich spazieren.\",\n",
528
+ " \"Wenn du Zeit hast, komm bitte vorbei.\"\n",
529
+ "]\n",
530
+ "\n",
531
+ "print(f\"🌊 Running Sanity Check on {len(sanity_sentences)} sentences...\\n\")\n",
532
+ "print(\"=\"*100)\n",
533
+ "print(f\"{'German Source':<40} | {'🇬🇧 PRISM Translation'}\")\n",
534
+ "print(\"-\" * 100)\n",
535
+ "\n",
536
+ "model.eval()\n",
537
+ "\n",
538
+ "# Re-using the helper from before\n",
539
+ "def run_prism_gen(text_input):\n",
540
+ " input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
541
+ " with torch.no_grad():\n",
542
+ " # Increased max_length for full sentences\n",
543
+ " out_ids = model.generate(src=input_tensor, max_length=40, num_beams=1)\n",
544
+ " return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
545
+ "\n",
546
+ "for sent in sanity_sentences:\n",
547
+ " try:\n",
548
+ " translation = run_prism_gen(sent)\n",
549
+ " print(f\"{sent:<40} | {translation}\")\n",
550
+ " except Exception as e:\n",
551
+ " print(f\"{sent:<40} | ❌ Error: {e}\")\n",
552
+ "\n",
553
+ "print(\"=\"*100)"
554
+ ],
555
+ "metadata": {
556
+ "id": "71sdvAAn8O63"
557
+ },
558
+ "execution_count": null,
559
+ "outputs": []
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "source": [
564
+ "# @title 🧪 Control Experiment: Unambiguous Single Tokens\n",
565
+ "# ==============================================================================\n",
566
+ "# HYPOTHESIS:\n",
567
+ "# If the single-token collapse is due to AMBIGUITY (phase can't resolve meaning),\n",
568
+ "# then UNAMBIGUOUS tokens should translate correctly even in isolation.\n",
569
+ "# If they ALSO collapse, the issue is purely L=1 (no interference at all).\n",
570
+ "# ==============================================================================\n",
571
+ "\n",
572
+ "import torch\n",
573
+ "import pandas as pd\n",
574
+ "from tqdm import tqdm\n",
575
+ "\n",
576
+ "# ==============================================================================\n",
577
+ "# 1. UNAMBIGUOUS TEST SET\n",
578
+ "# ==============================================================================\n",
579
+ "# These words have ONE clear meaning - no context needed for humans\n",
580
+ "unambiguous_cases = [\n",
581
+ " # (German Word, Expected English)\n",
582
+ " # --- Animals (No ambiguity) ---\n",
583
+ " (\"Katze\", \"cat\"),\n",
584
+ " (\"Hund\", \"dog\"),\n",
585
+ " (\"Pferd\", \"horse\"),\n",
586
+ " (\"Vogel\", \"bird\"),\n",
587
+ " (\"Fisch\", \"fish\"),\n",
588
+ " (\"Elefant\", \"elephant\"),\n",
589
+ " (\"Löwe\", \"lion\"),\n",
590
+ " (\"Bär\", \"bear\"),\n",
591
+ "\n",
592
+ " # --- Objects (No ambiguity) ---\n",
593
+ " (\"Tisch\", \"table\"),\n",
594
+ " (\"Stuhl\", \"chair\"),\n",
595
+ " (\"Buch\", \"book\"),\n",
596
+ " (\"Auto\", \"car\"),\n",
597
+ " (\"Haus\", \"house\"),\n",
598
+ " (\"Fenster\", \"window\"),\n",
599
+ " (\"Lampe\", \"lamp\"),\n",
600
+ " (\"Telefon\", \"phone\"),\n",
601
+ "\n",
602
+ " # --- Nature (No ambiguity) ---\n",
603
+ " (\"Baum\", \"tree\"),\n",
604
+ " (\"Blume\", \"flower\"),\n",
605
+ " (\"Wolke\", \"cloud\"),\n",
606
+ " (\"Regen\", \"rain\"),\n",
607
+ " (\"Schnee\", \"snow\"),\n",
608
+ " (\"Feuer\", \"fire\"),\n",
609
+ "\n",
610
+ " # --- Body Parts (No ambiguity) ---\n",
611
+ " (\"Kopf\", \"head\"),\n",
612
+ " (\"Auge\", \"eye\"),\n",
613
+ " (\"Ohr\", \"ear\"),\n",
614
+ " (\"Nase\", \"nose\"),\n",
615
+ " (\"Finger\", \"finger\"),\n",
616
+ "\n",
617
+ " # --- Food (No ambiguity) ---\n",
618
+ " (\"Brot\", \"bread\"),\n",
619
+ " (\"Käse\", \"cheese\"),\n",
620
+ " (\"Apfel\", \"apple\"),\n",
621
+ " (\"Wasser\", \"water\"),\n",
622
+ " (\"Milch\", \"milk\"),\n",
623
+ "]\n",
624
+ "\n",
625
+ "# ==============================================================================\n",
626
+ "# 2. THE EXPERIMENT\n",
627
+ "# ==============================================================================\n",
628
+ "results_unambig = []\n",
629
+ "print(f\"🔬 Running UNAMBIGUOUS Single-Token Test on {len(unambiguous_cases)} words...\\n\")\n",
630
+ "\n",
631
+ "model.eval()\n",
632
+ "\n",
633
+ "def run_prism_gen(text_input, max_len=10):\n",
634
+ " \"\"\"Helper to run generation\"\"\"\n",
635
+ " input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
636
+ " with torch.no_grad():\n",
637
+ " out_ids = model.generate(src=input_tensor, max_length=max_len, num_beams=1)\n",
638
+ " return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
639
+ "\n",
640
+ "def is_repetition_collapse(text):\n",
641
+ " \"\"\"Detect if output is repetitive garbage\"\"\"\n",
642
+ " words = text.split()\n",
643
+ " if len(words) < 2:\n",
644
+ " return False\n",
645
+ " # Check if first word repeats\n",
646
+ " first_word = words[0].lower().strip('.,!?')\n",
647
+ " repeat_count = sum(1 for w in words if w.lower().strip('.,!?') == first_word)\n",
648
+ " return repeat_count >= len(words) * 0.5 # 50%+ repetition = collapse\n",
649
+ "\n",
650
+ "def check_correct(output, target):\n",
651
+ " \"\"\"Check if target word appears in output\"\"\"\n",
652
+ " return target.lower() in output.lower()\n",
653
+ "\n",
654
+ "# Run the test\n",
655
+ "for word, target in tqdm(unambiguous_cases):\n",
656
+ " try:\n",
657
+ " # Single token translation\n",
658
+ " translation = run_prism_gen(word)\n",
659
+ "\n",
660
+ " # Analyze\n",
661
+ " collapsed = is_repetition_collapse(translation)\n",
662
+ " correct = check_correct(translation, target)\n",
663
+ "\n",
664
+ " # Also test with minimal context (article)\n",
665
+ " # German articles: der/die/das\n",
666
+ " context_translation = run_prism_gen(f\"Das {word}.\")\n",
667
+ " context_correct = check_correct(context_translation, target)\n",
668
+ "\n",
669
+ " results_unambig.append({\n",
670
+ " \"German\": word,\n",
671
+ " \"Target\": target,\n",
672
+ " \"Isolation\": translation,\n",
673
+ " \"Collapsed\": collapsed,\n",
674
+ " \"Correct\": correct,\n",
675
+ " \"With Article\": context_translation,\n",
676
+ " \"Context Correct\": context_correct\n",
677
+ " })\n",
678
+ " except Exception as e:\n",
679
+ " print(f\"Error on {word}: {e}\")\n",
680
+ "\n",
681
+ "# ==============================================================================\n",
682
+ "# 3. ANALYSIS & REPORT\n",
683
+ "# ==============================================================================\n",
684
+ "df_unambig = pd.DataFrame(results_unambig)\n",
685
+ "\n",
686
+ "# Stats\n",
687
+ "total = len(df_unambig)\n",
688
+ "collapsed_count = df_unambig['Collapsed'].sum()\n",
689
+ "correct_iso = df_unambig['Correct'].sum()\n",
690
+ "correct_ctx = df_unambig['Context Correct'].sum()\n",
691
+ "\n",
692
+ "print(\"\\n\" + \"=\"*120)\n",
693
+ "print(\"🧪 CONTROL EXPERIMENT: UNAMBIGUOUS SINGLE TOKENS\")\n",
694
+ "print(\"=\"*120)\n",
695
+ "print(f\"{'German':<12} | {'Target':<10} | {'⛔ Isolation (L=1)':<35} | {'Collapse?':<10} | {'✅ With Article':<30}\")\n",
696
+ "print(\"-\" * 120)\n",
697
+ "\n",
698
+ "for _, row in df_unambig.iterrows():\n",
699
+ " collapse_marker = \"💥 YES\" if row['Collapsed'] else \"No\"\n",
700
+ " iso_display = row['Isolation'][:33] + \"..\" if len(row['Isolation']) > 35 else row['Isolation']\n",
701
+ " ctx_display = row['With Article'][:28] + \"..\" if len(row['With Article']) > 30 else row['With Article']\n",
702
+ "\n",
703
+ " print(f\"{row['German']:<12} | {row['Target']:<10} | {iso_display:<35} | {collapse_marker:<10} | {ctx_display:<30}\")\n",
704
+ "\n",
705
+ "print(\"=\"*120)\n",
706
+ "\n",
707
+ "# ==============================================================================\n",
708
+ "# 4. SCIENTIFIC SUMMARY\n",
709
+ "# ==============================================================================\n",
710
+ "print(\"\\n📊 STATISTICAL SUMMARY\")\n",
711
+ "print(\"-\"*60)\n",
712
+ "print(f\"Total test cases: {total}\")\n",
713
+ "print(f\"Repetition collapses (L=1): {collapsed_count} ({100*collapsed_count/total:.1f}%)\")\n",
714
+ "print(f\"Correct in isolation: {correct_iso} ({100*correct_iso/total:.1f}%)\")\n",
715
+ "print(f\"Correct with article context: {correct_ctx} ({100*correct_ctx/total:.1f}%)\")\n",
716
+ "print(\"-\"*60)\n",
717
+ "\n",
718
+ "print(\"\\n🔬 INTERPRETATION:\")\n",
719
+ "if collapsed_count > total * 0.5:\n",
720
+ " print(\" 💥 FINDING: Unambiguous tokens ALSO collapse!\")\n",
721
+ " print(\" → This suggests L=1 provides insufficient interference for ANY semantic encoding.\")\n",
722
+ " print(\" → The decoder needs wave structure, not just meaning clarity.\")\n",
723
+ " print(\"\\n 📝 PAPER IMPLICATION: Phase interference is necessary for GENERATION,\")\n",
724
+ " print(\" not just disambiguation. Single tokens lack the spectral 'carrier wave'\")\n",
725
+ " print(\" needed to bootstrap autoregressive decoding.\")\n",
726
+ "elif collapsed_count > 0:\n",
727
+ " print(\" ⚖️ FINDING: Mixed results - some collapse, some don't.\")\n",
728
+ " print(\" → Suggests a threshold effect based on embedding quality or token frequency.\")\n",
729
+ "else:\n",
730
+ " print(\" ✅ FINDING: Unambiguous tokens translate correctly!\")\n",
731
+ " print(\" → This confirms the hypothesis: ambiguity requires interference to resolve,\")\n",
732
+ " print(\" but clear semantics can propagate through the encoder even at L=1.\")\n",
733
+ " print(\"\\n 📝 PAPER IMPLICATION: The single-token collapse is SPECIFIC to polysemy.\")\n",
734
+ " print(\" PRISM's phase encoding captures unambiguous semantics in static embeddings,\")\n",
735
+ " print(\" but requires interference patterns to perform 'semantic selection'.\")\n",
736
+ "\n",
737
+ "# ==============================================================================\n",
738
+ "# 5. COMPARISON TABLE (Side by Side)\n",
739
+ "# ==============================================================================\n",
740
+ "print(\"\\n\\n\" + \"=\"*80)\n",
741
+ "print(\"📈 AMBIGUOUS vs UNAMBIGUOUS COMPARISON\")\n",
742
+ "print(\"=\"*80)\n",
743
+ "print(f\"{'Metric':<40} | {'Ambiguous':<15} | {'Unambiguous':<15}\")\n",
744
+ "print(\"-\"*80)\n",
745
+ "print(f\"{'Repetition Collapse Rate':<40} | {'~100%':<15} | {f'{100*collapsed_count/total:.0f}%':<15}\")\n",
746
+ "print(f\"{'Correct in Isolation':<40} | {'~0%':<15} | {f'{100*correct_iso/total:.0f}%':<15}\")\n",
747
+ "print(f\"{'Correct with Minimal Context':<40} | {'Partial':<15} | {f'{100*correct_ctx/total:.0f}%':<15}\")\n",
748
+ "print(\"=\"*80)"
749
+ ],
750
+ "metadata": {
751
+ "id": "k6rH5YJWauTc"
752
+ },
753
+ "execution_count": null,
754
+ "outputs": []
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "source": [
759
+ "# Test with num_beams=1 (greedy) vs num_beams=5 (beam search)\n",
760
+ "word = \"Hund\"\n",
761
+ "input_tensor = tokenizer(word, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
762
+ "\n",
763
+ "with torch.no_grad():\n",
764
+ " out_greedy = model.generate(src=input_tensor, max_length=10, num_beams=1)\n",
765
+ " out_beam = model.generate(src=input_tensor, max_length=10, num_beams=5)\n",
766
+ "\n",
767
+ "print(f\"Greedy (beams=1): {tokenizer.decode(out_greedy[0], skip_special_tokens=True)}\")\n",
768
+ "print(f\"Beam (beams=5): {tokenizer.decode(out_beam[0], skip_special_tokens=True)}\")"
769
+ ],
770
+ "metadata": {
771
+ "id": "E0WRta7Qdjcg"
772
+ },
773
+ "execution_count": null,
774
+ "outputs": []
775
+ },
776
+ {
777
+ "cell_type": "code",
778
+ "source": [
779
+ "# @title 🧪 Large-Scale Single-Token Analysis (N >> 32)\n",
780
+ "# ==============================================================================\n",
781
+ "# GOAL: Increase sample size with automatic single-token validation\n",
782
+ "# ==============================================================================\n",
783
+ "# This is Claude's overkill code. It is a savant.\n",
784
+ "\n",
785
+ "\n",
786
+ "import torch\n",
787
+ "import pandas as pd\n",
788
+ "from tqdm import tqdm\n",
789
+ "\n",
790
+ "# ==============================================================================\n",
791
+ "# 1. LARGE CANDIDATE POOLS\n",
792
+ "# ==============================================================================\n",
793
+ "\n",
794
+ "# A. AMBIGUOUS CANDIDATES (German words with multiple meanings)\n",
795
+ "ambiguous_candidates = [\n",
796
+ " # Word, Meaning1, Meaning2\n",
797
+ " (\"Bank\", \"bench\", \"bank\"),\n",
798
+ " (\"Schloss\", \"castle\", \"lock\"),\n",
799
+ " (\"Leiter\", \"ladder\", \"leader\"),\n",
800
+ " (\"Decke\", \"ceiling\", \"blanket\"),\n",
801
+ " (\"Kiefer\", \"pine\", \"jaw\"),\n",
802
+ " (\"Strauß\", \"ostrich\", \"bouquet\"),\n",
803
+ " (\"Tor\", \"gate\", \"goal\"),\n",
804
+ " (\"Ball\", \"ball\", \"dance\"),\n",
805
+ " (\"Schlange\", \"snake\", \"queue\"),\n",
806
+ " (\"Strom\", \"electricity\", \"river\"),\n",
807
+ " (\"Mutter\", \"mother\", \"nut\"),\n",
808
+ " (\"Birne\", \"pear\", \"lightbulb\"),\n",
809
+ " (\"Gericht\", \"court\", \"dish\"),\n",
810
+ " (\"Ton\", \"sound\", \"clay\"),\n",
811
+ " (\"Blatt\", \"leaf\", \"sheet\"),\n",
812
+ " (\"Nagel\", \"nail\", \"fingernail\"),\n",
813
+ " (\"Maus\", \"mouse\", \"computer mouse\"),\n",
814
+ " (\"Erde\", \"earth\", \"soil\"),\n",
815
+ " (\"Hahn\", \"rooster\", \"tap\"),\n",
816
+ " (\"Schale\", \"shell\", \"bowl\"),\n",
817
+ " (\"Bauer\", \"farmer\", \"pawn\"),\n",
818
+ " (\"Steuer\", \"tax\", \"steering wheel\"),\n",
819
+ " (\"Tau\", \"dew\", \"rope\"),\n",
820
+ " (\"Feder\", \"feather\", \"spring\"),\n",
821
+ " (\"Absatz\", \"heel\", \"paragraph\"),\n",
822
+ " (\"Band\", \"ribbon\", \"volume\"),\n",
823
+ " (\"Brücke\", \"bridge\", \"dental bridge\"),\n",
824
+ " (\"Flügel\", \"wing\", \"grand piano\"),\n",
825
+ " (\"Golf\", \"golf\", \"gulf\"),\n",
826
+ " (\"Grund\", \"reason\", \"ground\"),\n",
827
+ " (\"Hut\", \"hat\", \"guard\"),\n",
828
+ " (\"Kette\", \"chain\", \"necklace\"),\n",
829
+ " (\"Kran\", \"crane bird\", \"crane machine\"),\n",
830
+ " (\"Lauf\", \"run\", \"barrel\"),\n",
831
+ " (\"Linse\", \"lens\", \"lentil\"),\n",
832
+ " (\"Mark\", \"marrow\", \"mark currency\"),\n",
833
+ " (\"Masse\", \"mass\", \"crowd\"),\n",
834
+ " (\"Netz\", \"net\", \"network\"),\n",
835
+ " (\"Pony\", \"pony\", \"bangs\"),\n",
836
+ " (\"Raum\", \"room\", \"space\"),\n",
837
+ " (\"Reif\", \"hoop\", \"frost\"),\n",
838
+ " (\"Rock\", \"skirt\", \"rock music\"),\n",
839
+ " (\"Schalter\", \"switch\", \"counter\"),\n",
840
+ " (\"Schild\", \"sign\", \"shield\"),\n",
841
+ " (\"See\", \"lake\", \"sea\"),\n",
842
+ " (\"Seite\", \"side\", \"page\"),\n",
843
+ " (\"Star\", \"starling\", \"celebrity\"),\n",
844
+ " (\"Stock\", \"stick\", \"floor\"),\n",
845
+ " (\"Wahl\", \"choice\", \"election\"),\n",
846
+ " (\"Welle\", \"wave\", \"shaft\"),\n",
847
+ " (\"Zug\", \"train\", \"pull\"),\n",
848
+ "]\n",
849
+ "\n",
850
+ "# B. UNAMBIGUOUS CANDIDATES (German words with single clear meaning)\n",
851
+ "unambiguous_candidates = [\n",
852
+ " # Animals\n",
853
+ " (\"Katze\", \"cat\"), (\"Hund\", \"dog\"), (\"Pferd\", \"horse\"), (\"Vogel\", \"bird\"),\n",
854
+ " (\"Fisch\", \"fish\"), (\"Elefant\", \"elephant\"), (\"Löwe\", \"lion\"), (\"Bär\", \"bear\"),\n",
855
+ " (\"Tiger\", \"tiger\"), (\"Affe\", \"monkey\"), (\"Schaf\", \"sheep\"), (\"Kuh\", \"cow\"),\n",
856
+ " (\"Schwein\", \"pig\"), (\"Huhn\", \"chicken\"), (\"Ente\", \"duck\"), (\"Gans\", \"goose\"),\n",
857
+ " (\"Wolf\", \"wolf\"), (\"Fuchs\", \"fox\"), (\"Hase\", \"rabbit\"), (\"Hirsch\", \"deer\"),\n",
858
+ " (\"Frosch\", \"frog\"), (\"Spinne\", \"spider\"), (\"Biene\", \"bee\"), (\"Käfer\", \"beetle\"),\n",
859
+ "\n",
860
+ " # Objects\n",
861
+ " (\"Tisch\", \"table\"), (\"Stuhl\", \"chair\"), (\"Buch\", \"book\"), (\"Auto\", \"car\"),\n",
862
+ " (\"Haus\", \"house\"), (\"Fenster\", \"window\"), (\"Lampe\", \"lamp\"), (\"Telefon\", \"phone\"),\n",
863
+ " (\"Computer\", \"computer\"), (\"Uhr\", \"clock\"), (\"Brille\", \"glasses\"), (\"Schlüssel\", \"key\"),\n",
864
+ " (\"Tasche\", \"bag\"), (\"Schuh\", \"shoe\"), (\"Hemd\", \"shirt\"), (\"Hose\", \"pants\"),\n",
865
+ " (\"Kleid\", \"dress\"), (\"Jacke\", \"jacket\"), (\"Tür\", \"door\"), (\"Bett\", \"bed\"),\n",
866
+ " (\"Schrank\", \"closet\"), (\"Sofa\", \"sofa\"), (\"Spiegel\", \"mirror\"), (\"Teppich\", \"carpet\"),\n",
867
+ "\n",
868
+ " # Nature\n",
869
+ " (\"Baum\", \"tree\"), (\"Blume\", \"flower\"), (\"Wolke\", \"cloud\"), (\"Regen\", \"rain\"),\n",
870
+ " (\"Schnee\", \"snow\"), (\"Feuer\", \"fire\"), (\"Berg\", \"mountain\"), (\"Wald\", \"forest\"),\n",
871
+ " (\"Fluss\", \"river\"), (\"Himmel\", \"sky\"), (\"Stern\", \"star\"), (\"Mond\", \"moon\"),\n",
872
+ " (\"Sonne\", \"sun\"), (\"Gras\", \"grass\"), (\"Stein\", \"stone\"), (\"Sand\", \"sand\"),\n",
873
+ "\n",
874
+ " # Body parts\n",
875
+ " (\"Kopf\", \"head\"), (\"Auge\", \"eye\"), (\"Ohr\", \"ear\"), (\"Nase\", \"nose\"),\n",
876
+ " (\"Mund\", \"mouth\"), (\"Zahn\", \"tooth\"), (\"Zunge\", \"tongue\"), (\"Hals\", \"neck\"),\n",
877
+ " (\"Arm\", \"arm\"), (\"Bein\", \"leg\"), (\"Fuß\", \"foot\"), (\"Knie\", \"knee\"),\n",
878
+ " (\"Finger\", \"finger\"), (\"Herz\", \"heart\"), (\"Lunge\", \"lung\"), (\"Magen\", \"stomach\"),\n",
879
+ "\n",
880
+ " # Food & Drink\n",
881
+ " (\"Brot\", \"bread\"), (\"Käse\", \"cheese\"), (\"Apfel\", \"apple\"), (\"Wasser\", \"water\"),\n",
882
+ " (\"Milch\", \"milk\"), (\"Ei\", \"egg\"), (\"Fleisch\", \"meat\"), (\"Reis\", \"rice\"),\n",
883
+ " (\"Nudel\", \"noodle\"), (\"Suppe\", \"soup\"), (\"Salat\", \"salad\"), (\"Kuchen\", \"cake\"),\n",
884
+ " (\"Kaffee\", \"coffee\"), (\"Tee\", \"tea\"), (\"Bier\", \"beer\"), (\"Wein\", \"wine\"),\n",
885
+ " (\"Saft\", \"juice\"), (\"Zucker\", \"sugar\"), (\"Salz\", \"salt\"), (\"Butter\", \"butter\"),\n",
886
+ "\n",
887
+ " # Colors (as nouns)\n",
888
+ " (\"Rot\", \"red\"), (\"Blau\", \"blue\"), (\"Grün\", \"green\"), (\"Gelb\", \"yellow\"),\n",
889
+ " (\"Schwarz\", \"black\"), (\"Weiß\", \"white\"), (\"Braun\", \"brown\"), (\"Grau\", \"gray\"),\n",
890
+ "\n",
891
+ " # Numbers (as nouns)\n",
892
+ " (\"Eins\", \"one\"), (\"Zwei\", \"two\"), (\"Drei\", \"three\"), (\"Vier\", \"four\"),\n",
893
+ " (\"Fünf\", \"five\"), (\"Sechs\", \"six\"), (\"Sieben\", \"seven\"), (\"Acht\", \"eight\"),\n",
894
+ "\n",
895
+ " # Family\n",
896
+ " (\"Vater\", \"father\"), (\"Bruder\", \"brother\"), (\"Schwester\", \"sister\"),\n",
897
+ " (\"Onkel\", \"uncle\"), (\"Tante\", \"aunt\"), (\"Oma\", \"grandma\"), (\"Opa\", \"grandpa\"),\n",
898
+ "\n",
899
+ " # Professions\n",
900
+ " (\"Arzt\", \"doctor\"), (\"Lehrer\", \"teacher\"), (\"Koch\", \"cook\"), (\"Pilot\", \"pilot\"),\n",
901
+ " (\"Polizist\", \"policeman\"), (\"Bäcker\", \"baker\"), (\"Maler\", \"painter\"),\n",
902
+ "]\n",
903
+ "\n",
904
+ "# ==============================================================================\n",
905
+ "# 2. SINGLE-TOKEN VALIDATION FUNCTION\n",
906
+ "# ==============================================================================\n",
907
+ "\n",
908
+ "def is_single_token(word, tokenizer):\n",
909
+ " \"\"\"\n",
910
+ " Check if a word is represented as a single token.\n",
911
+ " Tests both with and without space prefix (BPE behavior varies).\n",
912
+ " \"\"\"\n",
913
+ " # Test 1: Raw word\n",
914
+ " tokens_raw = tokenizer.encode(word, add_special_tokens=False)\n",
915
+ "\n",
916
+ " # Test 2: With space prefix (common in BPE)\n",
917
+ " tokens_space = tokenizer.encode(\" \" + word, add_special_tokens=False)\n",
918
+ "\n",
919
+ " # Test 3: With article (might help with German nouns)\n",
920
+ " tokens_article = tokenizer.encode(\"Das \" + word, add_special_tokens=False)\n",
921
+ "\n",
922
+ " # Accept if ANY encoding is single token (or 2 tokens for article version)\n",
923
+ " is_single = (len(tokens_raw) == 1) or (len(tokens_space) == 1)\n",
924
+ "\n",
925
+ " return is_single, len(tokens_raw), len(tokens_space)\n",
926
+ "\n",
927
+ "def filter_single_tokens(candidates, tokenizer, is_ambiguous=True):\n",
928
+ " \"\"\"\n",
929
+ " Filter candidates to only include single-token words.\n",
930
+ " Returns validated list with token info.\n",
931
+ " \"\"\"\n",
932
+ " valid = []\n",
933
+ " rejected = []\n",
934
+ "\n",
935
+ " for item in candidates:\n",
936
+ " if is_ambiguous:\n",
937
+ " word = item[0]\n",
938
+ " else:\n",
939
+ " word = item[0]\n",
940
+ "\n",
941
+ " is_single, n_raw, n_space = is_single_token(word, tokenizer)\n",
942
+ "\n",
943
+ " if is_single:\n",
944
+ " valid.append(item)\n",
945
+ " else:\n",
946
+ " rejected.append((word, n_raw, n_space))\n",
947
+ "\n",
948
+ " return valid, rejected\n",
949
+ "\n",
950
+ "# ==============================================================================\n",
951
+ "# 3. GENERATION & ANALYSIS FUNCTIONS\n",
952
+ "# ==============================================================================\n",
953
+ "\n",
954
+ "def run_prism_gen(text_input, model, tokenizer, device, max_len=10):\n",
955
+ " \"\"\"Run PRISM generation\"\"\"\n",
956
+ " input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(device)\n",
957
+ " with torch.no_grad():\n",
958
+ " out_ids = model.generate(src=input_tensor, max_length=max_len, num_beams=1)\n",
959
+ " return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
960
+ "\n",
961
+ "def is_repetition_collapse(text):\n",
962
+ " \"\"\"Detect if output shows repetition collapse\"\"\"\n",
963
+ " if not text or len(text.split()) < 2:\n",
964
+ " return False\n",
965
+ "\n",
966
+ " words = text.lower().split()\n",
967
+ " # Clean punctuation\n",
968
+ " words = [w.strip('.,!?;:') for w in words]\n",
969
+ "\n",
970
+ " if len(words) < 2:\n",
971
+ " return False\n",
972
+ "\n",
973
+ " # Check various collapse patterns\n",
974
+ " # Pattern 1: Same word repeats\n",
975
+ " first_word = words[0]\n",
976
+ " same_count = sum(1 for w in words if w == first_word or first_word.startswith(w) or w.startswith(first_word))\n",
977
+ "\n",
978
+ " # Pattern 2: Substring repetition (e.g., \"dogdogdog\")\n",
979
+ " joined = ''.join(words)\n",
980
+ " if len(joined) > 3:\n",
981
+ " chunk = joined[:3]\n",
982
+ " if joined.count(chunk) >= 3:\n",
983
+ " return True\n",
984
+ "\n",
985
+ " return same_count >= len(words) * 0.5\n",
986
+ "\n",
987
+ "def check_correct(output, targets):\n",
988
+ " \"\"\"Check if any target word appears in output\"\"\"\n",
989
+ " output_lower = output.lower()\n",
990
+ " if isinstance(targets, str):\n",
991
+ " targets = [targets]\n",
992
+ " return any(t.lower() in output_lower for t in targets)\n",
993
+ "\n",
994
+ "# ==============================================================================\n",
995
+ "# 4. MAIN EXPERIMENT\n",
996
+ "# ==============================================================================\n",
997
+ "\n",
998
+ "print(\"=\" * 100)\n",
999
+ "print(\"🔬 LARGE-SCALE SINGLE-TOKEN CARRIER WAVE ANALYSIS\")\n",
1000
+ "print(\"=\" * 100)\n",
1001
+ "\n",
1002
+ "# A. Filter to single tokens only\n",
1003
+ "print(\"\\n📋 STEP 1: Validating Single-Token Candidates...\")\n",
1004
+ "print(\"-\" * 60)\n",
1005
+ "\n",
1006
+ "ambig_valid, ambig_rejected = filter_single_tokens(ambiguous_candidates, tokenizer, is_ambiguous=True)\n",
1007
+ "unambig_valid, unambig_rejected = filter_single_tokens(unambiguous_candidates, tokenizer, is_ambiguous=False)\n",
1008
+ "\n",
1009
+ "print(f\"AMBIGUOUS: {len(ambig_valid)} valid / {len(ambig_rejected)} rejected (multi-token)\")\n",
1010
+ "print(f\"UNAMBIGUOUS: {len(unambig_valid)} valid / {len(unambig_rejected)} rejected (multi-token)\")\n",
1011
+ "\n",
1012
+ "# Show some rejected examples\n",
1013
+ "if ambig_rejected:\n",
1014
+ " print(f\"\\n Rejected ambiguous (examples): {ambig_rejected[:5]}\")\n",
1015
+ "if unambig_rejected:\n",
1016
+ " print(f\" Rejected unambiguous (examples): {unambig_rejected[:5]}\")\n",
1017
+ "\n",
1018
+ "# B. Run experiments\n",
1019
+ "print(f\"\\n📋 STEP 2: Running Generation Tests...\")\n",
1020
+ "print(\"-\" * 60)\n",
1021
+ "\n",
1022
+ "# Storage\n",
1023
+ "results_ambig = []\n",
1024
+ "results_unambig = []\n",
1025
+ "\n",
1026
+ "# Test AMBIGUOUS tokens\n",
1027
+ "print(f\"\\n🔴 Testing {len(ambig_valid)} AMBIGUOUS single tokens...\")\n",
1028
+ "for word, meaning1, meaning2 in tqdm(ambig_valid):\n",
1029
+ " try:\n",
1030
+ " output = run_prism_gen(word, model, tokenizer, DEVICE)\n",
1031
+ " collapsed = is_repetition_collapse(output)\n",
1032
+ " # For ambiguous, \"correct\" is undefined - check if either meaning appears\n",
1033
+ " has_meaning = check_correct(output, [meaning1, meaning2])\n",
1034
+ "\n",
1035
+ " results_ambig.append({\n",
1036
+ " \"word\": word,\n",
1037
+ " \"meanings\": f\"{meaning1}/{meaning2}\",\n",
1038
+ " \"output\": output,\n",
1039
+ " \"collapsed\": collapsed,\n",
1040
+ " \"has_any_meaning\": has_meaning\n",
1041
+ " })\n",
1042
+ " except Exception as e:\n",
1043
+ " print(f\"Error on {word}: {e}\")\n",
1044
+ "\n",
1045
+ "# Test UNAMBIGUOUS tokens\n",
1046
+ "print(f\"\\n🟢 Testing {len(unambig_valid)} UNAMBIGUOUS single tokens...\")\n",
1047
+ "for word, meaning in tqdm(unambig_valid):\n",
1048
+ " try:\n",
1049
+ " output = run_prism_gen(word, model, tokenizer, DEVICE)\n",
1050
+ " collapsed = is_repetition_collapse(output)\n",
1051
+ " correct = check_correct(output, meaning)\n",
1052
+ "\n",
1053
+ " results_unambig.append({\n",
1054
+ " \"word\": word,\n",
1055
+ " \"meaning\": meaning,\n",
1056
+ " \"output\": output,\n",
1057
+ " \"collapsed\": collapsed,\n",
1058
+ " \"correct\": correct\n",
1059
+ " })\n",
1060
+ " except Exception as e:\n",
1061
+ " print(f\"Error on {word}: {e}\")\n",
1062
+ "\n",
1063
+ "# ==============================================================================\n",
1064
+ "# 5. STATISTICAL ANALYSIS\n",
1065
+ "# ==============================================================================\n",
1066
+ "\n",
1067
+ "df_ambig = pd.DataFrame(results_ambig)\n",
1068
+ "df_unambig = pd.DataFrame(results_unambig)\n",
1069
+ "\n",
1070
+ "print(\"\\n\" + \"=\" * 100)\n",
1071
+ "print(\"📊 RESULTS: AMBIGUOUS TOKENS (L=1)\")\n",
1072
+ "print(\"=\" * 100)\n",
1073
+ "print(f\"{'Word':<15} | {'Meanings':<25} | {'Output':<40} | {'Collapse?'}\")\n",
1074
+ "print(\"-\" * 100)\n",
1075
+ "\n",
1076
+ "for _, row in df_ambig.iterrows():\n",
1077
+ " collapse_mark = \"💥 YES\" if row['collapsed'] else \"No\"\n",
1078
+ " output_display = row['output'][:38] + \"..\" if len(row['output']) > 40 else row['output']\n",
1079
+ " print(f\"{row['word']:<15} | {row['meanings']:<25} | {output_display:<40} | {collapse_mark}\")\n",
1080
+ "\n",
1081
+ "print(\"\\n\" + \"=\" * 100)\n",
1082
+ "print(\"📊 RESULTS: UNAMBIGUOUS TOKENS (L=1)\")\n",
1083
+ "print(\"=\" * 100)\n",
1084
+ "print(f\"{'Word':<15} | {'Target':<15} | {'Output':<40} | {'Collapse?':<10} | {'Correct?'}\")\n",
1085
+ "print(\"-\" * 100)\n",
1086
+ "\n",
1087
+ "for _, row in df_unambig.iterrows():\n",
1088
+ " collapse_mark = \"💥 YES\" if row['collapsed'] else \"No\"\n",
1089
+ " correct_mark = \"✅\" if row['correct'] else \"❌\"\n",
1090
+ " output_display = row['output'][:38] + \"..\" if len(row['output']) > 40 else row['output']\n",
1091
+ " print(f\"{row['word']:<15} | {row['meaning']:<15} | {output_display:<40} | {collapse_mark:<10} | {correct_mark}\")\n",
1092
+ "\n",
1093
+ "# ==============================================================================\n",
1094
+ "# 6. SUMMARY STATISTICS\n",
1095
+ "# ==============================================================================\n",
1096
+ "\n",
1097
+ "print(\"\\n\" + \"=\" * 100)\n",
1098
+ "print(\"📈 SUMMARY STATISTICS\")\n",
1099
+ "print(\"=\" * 100)\n",
1100
+ "\n",
1101
+ "n_ambig = len(df_ambig)\n",
1102
+ "n_unambig = len(df_unambig)\n",
1103
+ "\n",
1104
+ "ambig_collapse_rate = df_ambig['collapsed'].sum() / n_ambig * 100 if n_ambig > 0 else 0\n",
1105
+ "unambig_collapse_rate = df_unambig['collapsed'].sum() / n_unambig * 100 if n_unambig > 0 else 0\n",
1106
+ "unambig_correct_rate = df_unambig['correct'].sum() / n_unambig * 100 if n_unambig > 0 else 0\n",
1107
+ "\n",
1108
+ "print(f\"\"\"\n",
1109
+ "┌─────────────────────────────────────────────────────────────┐\n",
1110
+ "│ CARRIER WAVE THRESHOLD │\n",
1111
+ "├─────────────────────────────────────────────────────────────┤\n",
1112
+ "│ Condition │ N │ Collapse Rate │ Correct │\n",
1113
+ "├─────────────────────────────────────────────────────────────┤\n",
1114
+ "│ AMBIGUOUS (L=1) │ {n_ambig:<5} │ {ambig_collapse_rate:>6.1f}% │ N/A │\n",
1115
+ "│ UNAMBIGUOUS (L=1) │ {n_unambig:<5} │ {unambig_collapse_rate:>6.1f}% │ {unambig_correct_rate:>5.1f}% │\n",
1116
+ "└─────────────────────────────────────────────────────────────┘\n",
1117
+ "\"\"\")\n",
1118
+ "\n",
1119
+ "# Statistical comparison\n",
1120
+ "print(\"🔬 STATISTICAL INTERPRETATION:\")\n",
1121
+ "print(\"-\" * 60)\n",
1122
+ "\n",
1123
+ "if ambig_collapse_rate > unambig_collapse_rate + 10:\n",
1124
+ " print(f\" → Ambiguous tokens collapse MORE ({ambig_collapse_rate:.1f}% vs {unambig_collapse_rate:.1f}%)\")\n",
1125
+ " print(f\" → Difference: {ambig_collapse_rate - unambig_collapse_rate:.1f} percentage points\")\n",
1126
+ " print(f\" → SUPPORTS: Ambiguity exacerbates L=1 failure\")\n",
1127
+ "elif abs(ambig_collapse_rate - unambig_collapse_rate) <= 10:\n",
1128
+ " print(f\" → Both collapse at similar rates ({ambig_collapse_rate:.1f}% vs {unambig_collapse_rate:.1f}%)\")\n",
1129
+ " print(f\" → SUPPORTS: L=1 failure is about sequence length, not ambiguity\")\n",
1130
+ "else:\n",
1131
+ " print(f\" → Unexpected pattern - investigate further\")\n",
1132
+ "\n",
1133
+ "print(f\"\\n → Unambiguous tokens that ARE correct despite collapse: {unambig_correct_rate:.1f}%\")\n",
1134
+ "print(f\" → This suggests embeddings DO encode meaning, but decoder loops anyway\")\n",
1135
+ "\n",
1136
+ "# ==============================================================================\n",
1137
+ "# 7. EXPORT FOR PAPER\n",
1138
+ "# ==============================================================================\n",
1139
+ "\n",
1140
+ "print(\"\\n\" + \"=\" * 100)\n",
1141
+ "print(\"📝 LATEX-READY TABLE\")\n",
1142
+ "print(\"=\" * 100)\n",
1143
+ "\n",
1144
+ "print(f\"\"\"\n",
1145
+ "\\\\begin{{table}}[h]\n",
1146
+ "\\\\centering\n",
1147
+ "\\\\caption{{Carrier Wave Threshold Analysis (Extended). Large-scale validation confirms\n",
1148
+ "that repetition collapse at $L=1$ affects both ambiguous and unambiguous tokens.}}\n",
1149
+ "\\\\label{{tab:carrier_wave_extended}}\n",
1150
+ "\\\\begin{{tabular}}{{lccc}}\n",
1151
+ "\\\\toprule\n",
1152
+ "\\\\textbf{{Condition}} & \\\\textbf{{N}} & \\\\textbf{{Collapse Rate}} & \\\\textbf{{Correct (if applicable)}} \\\\\\\\\n",
1153
+ "\\\\midrule\n",
1154
+ "Ambiguous ($L=1$) & {n_ambig} & {ambig_collapse_rate:.1f}\\\\% & N/A \\\\\\\\\n",
1155
+ "Unambiguous ($L=1$) & {n_unambig} & {unambig_collapse_rate:.1f}\\\\% & {unambig_correct_rate:.1f}\\\\% \\\\\\\\\n",
1156
+ "\\\\bottomrule\n",
1157
+ "\\\\end{{tabular}}\n",
1158
+ "\\\\end{{table}}\n",
1159
+ "\"\"\")"
1160
+ ],
1161
+ "metadata": {
1162
+ "id": "A1Ta_6di5s7E"
1163
+ },
1164
+ "execution_count": null,
1165
+ "outputs": []
1166
+ },
1167
+ {
1168
+ "cell_type": "code",
1169
+ "source": [
1170
+ "import torch\n",
1171
+ "import pandas as pd\n",
1172
+ "from tqdm import tqdm\n",
1173
+ "\n",
1174
+ "# ==============================================================================\n",
1175
+ "# 1. SETUP & HELPERS\n",
1176
+ "# ==============================================================================\n",
1177
+ "print(\"=\" * 100)\n",
1178
+ "print(\"🌊 PHASE 2: THE CARRIER WAVE RESURRECTION (L=2)\")\n",
1179
+ "print(\"=\" * 100)\n",
1180
+ "\n",
1181
+ "def add_minimal_context(word):\n",
1182
+ " \"\"\"Prepends a neutral article to force L >= 2\"\"\"\n",
1183
+ " return f\"Das {word}\"\n",
1184
+ "\n",
1185
+ "def run_prism_gen_exact(text_input, max_len=10):\n",
1186
+ " \"\"\"\n",
1187
+ " Exact generation wrapper matching previous usage.\n",
1188
+ " \"\"\"\n",
1189
+ " # 1. Tokenize (Get IDs only, no special tokens)\n",
1190
+ " input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
1191
+ "\n",
1192
+ " # 2. Call custom generate method exactly as before\n",
1193
+ " with torch.no_grad():\n",
1194
+ " out_ids = model.generate(src=input_tensor, max_length=max_len, num_beams=1)\n",
1195
+ "\n",
1196
+ " # 3. Decode\n",
1197
+ " return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
1198
+ "\n",
1199
+ "results_recovery = []\n",
1200
+ "\n",
1201
+ "# ==============================================================================\n",
1202
+ "# 2. RUN THE COMPARISON LOOP\n",
1203
+ "# ==============================================================================\n",
1204
+ "print(f\"🔄 Retesting {len(ambig_valid)} Ambiguous & {len(unambig_valid)} Unambiguous tokens with 'Das [X]'...\")\n",
1205
+ "\n",
1206
+ "# --- A. AMBIGUOUS RECOVERY ---\n",
1207
+ "# Expected format: (word, meaning1, meaning2)\n",
1208
+ "for item in tqdm(ambig_valid, desc=\"Ambiguous L=2\"):\n",
1209
+ " word = item[0]\n",
1210
+ " meanings = item[1:] # Capture all meanings provided in tuple\n",
1211
+ "\n",
1212
+ " try:\n",
1213
+ " # Run L=1 (Single Token)\n",
1214
+ " out_L1 = run_prism_gen_exact(word)\n",
1215
+ " is_collapsed_L1 = is_repetition_collapse(out_L1)\n",
1216
+ "\n",
1217
+ " # Run L=2 (Minimal Context)\n",
1218
+ " input_L2 = add_minimal_context(word)\n",
1219
+ " out_L2 = run_prism_gen_exact(input_L2)\n",
1220
+ " is_collapsed_L2 = is_repetition_collapse(out_L2)\n",
1221
+ "\n",
1222
+ " # Check Meaning Recovery: Does L=2 output contain any valid meaning?\n",
1223
+ " # We assume check_correct handles a list of valid targets\n",
1224
+ " has_meaning_L2 = check_correct(out_L2, meanings)\n",
1225
+ "\n",
1226
+ " results_recovery.append({\n",
1227
+ " \"Type\": \"Ambiguous\",\n",
1228
+ " \"Word\": word,\n",
1229
+ " \"L1_Output\": out_L1,\n",
1230
+ " \"L1_Collapse\": is_collapsed_L1,\n",
1231
+ " \"L2_Input\": input_L2,\n",
1232
+ " \"L2_Output\": out_L2,\n",
1233
+ " \"L2_Collapse\": is_collapsed_L2,\n",
1234
+ " \"L2_Success\": has_meaning_L2\n",
1235
+ " })\n",
1236
+ " except Exception as e:\n",
1237
+ " print(f\"Error on {word}: {e}\")\n",
1238
+ "\n",
1239
+ "# --- B. UNAMBIGUOUS RECOVERY ---\n",
1240
+ "# Expected format: (word, target)\n",
1241
+ "for item in tqdm(unambig_valid, desc=\"Unambiguous L=2\"):\n",
1242
+ " word = item[0]\n",
1243
+ " target = item[1]\n",
1244
+ "\n",
1245
+ " try:\n",
1246
+ " # Run L=1\n",
1247
+ " out_L1 = run_prism_gen_exact(word)\n",
1248
+ " is_collapsed_L1 = is_repetition_collapse(out_L1)\n",
1249
+ "\n",
1250
+ " # Run L=2\n",
1251
+ " input_L2 = add_minimal_context(word)\n",
1252
+ " out_L2 = run_prism_gen_exact(input_L2)\n",
1253
+ " is_collapsed_L2 = is_repetition_collapse(out_L2)\n",
1254
+ "\n",
1255
+ " # Check Accuracy\n",
1256
+ " is_correct_L2 = check_correct(out_L2, [target])\n",
1257
+ "\n",
1258
+ " results_recovery.append({\n",
1259
+ " \"Type\": \"Unambiguous\",\n",
1260
+ " \"Word\": word,\n",
1261
+ " \"L1_Output\": out_L1,\n",
1262
+ " \"L1_Collapse\": is_collapsed_L1,\n",
1263
+ " \"L2_Input\": input_L2,\n",
1264
+ " \"L2_Output\": out_L2,\n",
1265
+ " \"L2_Collapse\": is_collapsed_L2,\n",
1266
+ " \"L2_Success\": is_correct_L2\n",
1267
+ " })\n",
1268
+ " except Exception as e:\n",
1269
+ " print(f\"Error on {word}: {e}\")\n",
1270
+ "\n",
1271
+ "# ==============================================================================\n",
1272
+ "# 3. ANALYSIS & VISUALIZATION\n",
1273
+ "# ==============================================================================\n",
1274
+ "df_rec = pd.DataFrame(results_recovery)\n",
1275
+ "\n",
1276
+ "# Metrics Calculation\n",
1277
+ "total_ambig = len(df_rec[df_rec[\"Type\"] == \"Ambiguous\"])\n",
1278
+ "total_unambig = len(df_rec[df_rec[\"Type\"] == \"Unambiguous\"])\n",
1279
+ "\n",
1280
+ "# L1 Collapse counts\n",
1281
+ "collapse_L1_ambig = df_rec[(df_rec[\"Type\"] == \"Ambiguous\") & (df_rec[\"L1_Collapse\"])].shape[0]\n",
1282
+ "collapse_L1_unambig = df_rec[(df_rec[\"Type\"] == \"Unambiguous\") & (df_rec[\"L1_Collapse\"])].shape[0]\n",
1283
+ "\n",
1284
+ "# L2 Collapse counts\n",
1285
+ "collapse_L2_ambig = df_rec[(df_rec[\"Type\"] == \"Ambiguous\") & (df_rec[\"L2_Collapse\"])].shape[0]\n",
1286
+ "collapse_L2_unambig = df_rec[(df_rec[\"Type\"] == \"Unambiguous\") & (df_rec[\"L2_Collapse\"])].shape[0]\n",
1287
+ "\n",
1288
+ "# Resurrection counts (Collapsed at L1 AND Succeeded at L2)\n",
1289
+ "resurrected_ambig = df_rec[\n",
1290
+ " (df_rec[\"Type\"] == \"Ambiguous\") &\n",
1291
+ " (df_rec[\"L1_Collapse\"] == True) &\n",
1292
+ " (df_rec[\"L2_Success\"] == True)\n",
1293
+ "].shape[0]\n",
1294
+ "\n",
1295
+ "resurrected_unambig = df_rec[\n",
1296
+ " (df_rec[\"Type\"] == \"Unambiguous\") &\n",
1297
+ " (df_rec[\"L1_Collapse\"] == True) &\n",
1298
+ " (df_rec[\"L2_Success\"] == True)\n",
1299
+ "].shape[0]\n",
1300
+ "\n",
1301
+ "# Recovery Rates (Percentage of collapsed L1 tokens that were fixed)\n",
1302
+ "recov_rate_ambig = (resurrected_ambig / collapse_L1_ambig * 100) if collapse_L1_ambig > 0 else 0\n",
1303
+ "recov_rate_unambig = (resurrected_unambig / collapse_L1_unambig * 100) if collapse_L1_unambig > 0 else 0\n",
1304
+ "\n",
1305
+ "# --- ASCII TABLE OUTPUT ---\n",
1306
+ "print(\"\\n\" + \"=\"*110)\n",
1307
+ "print(\"🏥 THE RECOVERY WARD: L=1 vs L=2 COMPARISON\")\n",
1308
+ "print(\"=\"*110)\n",
1309
+ "print(f\"{'Word':<12} | {'L=1 Output (Collapsed)':<30} | {'➡️'} | {'L=2 Output (Recovered)':<30} | {'Status'}\")\n",
1310
+ "print(\"-\" * 110)\n",
1311
+ "\n",
1312
+ "# Show examples of successful recoveries\n",
1313
+ "recoveries = df_rec[(df_rec[\"L1_Collapse\"] == True) & (df_rec[\"L2_Success\"] == True)]\n",
1314
+ "for _, row in recoveries.head(15).iterrows():\n",
1315
+ " l1_short = row['L1_Output'][:28] + \"..\" if len(row['L1_Output']) > 30 else row['L1_Output']\n",
1316
+ " l2_short = row['L2_Output'][:28] + \"..\" if len(row['L2_Output']) > 30 else row['L2_Output']\n",
1317
+ " print(f\"{row['Word']:<12} | {l1_short:<30} | {'➡️'} | {l2_short:<30} | {'✅ FIXED'}\")\n",
1318
+ "\n",
1319
+ "print(\"=\"*110)\n",
1320
+ "\n",
1321
+ "# --- SUMMARY STATISTICS ---\n",
1322
+ "print(\"\\n📊 CARRIER WAVE RECOVERY STATISTICS\")\n",
1323
+ "print(\"-\" * 80)\n",
1324
+ "print(f\"{'Metric':<30} | {'Ambiguous':<15} | {'Unambiguous':<15}\")\n",
1325
+ "print(\"-\" * 80)\n",
1326
+ "print(f\"{'Total Samples':<30} | {total_ambig:<15} | {total_unambig:<15}\")\n",
1327
+ "print(f\"{'L=1 Collapse Rate':<30} | {collapse_L1_ambig/total_ambig*100:5.1f}% | {collapse_L1_unambig/total_unambig*100:5.1f}%\")\n",
1328
+ "print(f\"{'L=2 Collapse Rate':<30} | {collapse_L2_ambig/total_ambig*100:5.1f}% | {collapse_L2_unambig/total_unambig*100:5.1f}%\")\n",
1329
+ "print(f\"{'Resurrection Rate':<30} | {recov_rate_ambig:5.1f}% | {recov_rate_unambig:5.1f}%\")\n",
1330
+ "print(\"-\" * 80)\n",
1331
+ "\n",
1332
+ "\n",
1333
+ "print(\"DO NOT MENTION 'FIXED' TAGS. RESULTS ARE CORRECTLY INTERPRETED ON PAPER, NOT HERE.\" )"
1334
+ ],
1335
+ "metadata": {
1336
+ "id": "9EemBN675ZLh"
1337
+ },
1338
+ "execution_count": null,
1339
+ "outputs": []
1340
+ }
1341
+ ]
1342
+ }