prism-lab committed on
Commit e4f4eca · verified · 1 Parent(s): 4378fd8

Upload Skewness_paper_last.ipynb

Files changed (1):
  1. Skewness_paper_last.ipynb +1340 -0
Skewness_paper_last.ipynb ADDED
@@ -0,0 +1,1340 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -q x-transformers"
+ ],
+ "metadata": {
+ "id": "tBfQX92lxEfP"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 🛠️ Setup & Model Loading\n",
+ "# ==============================================================================\n",
+ "# 1. INSTALL DEPENDENCIES\n",
+ "# ==============================================================================\n",
+ "!pip install -q numpy torch pandas scipy transformers huggingface_hub\n",
+ "\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from scipy.stats import skew\n",
+ "import sys\n",
+ "import os\n",
+ "from huggingface_hub import hf_hub_download\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 2. LOAD PRISM ARCHITECTURE\n",
+ "# ==============================================================================\n",
+ "REPO_ID = \"prism-lab/prism-shimmer-100k\"\n",
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "print(f\"⚙️ Hardware: {DEVICE}\")\n",
+ "print(f\"📥 Downloading Architecture from {REPO_ID}...\")\n",
+ "\n",
+ "# Download the model code to a local folder\n",
+ "os.makedirs(\"shimmer_code\", exist_ok=True)\n",
+ "hf_hub_download(repo_id=REPO_ID, filename=\"modeling_prism_gated.py\", local_dir=\"shimmer_code\")\n",
+ "sys.path.append(\"shimmer_code\")\n",
+ "\n",
+ "# Now we can import the class\n",
+ "from modeling_prism_gated import PRISMHybrid_RoPE\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 3. LOAD WEIGHTS\n",
+ "# ==============================================================================\n",
+ "print(\"📚 Loading Tokenizer...\")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(REPO_ID)\n",
+ "\n",
+ "print(\"🏗️ Constructing PRISM Model...\")\n",
+ "CONFIG = {\n",
+ "    \"vocab_size\": 58101,\n",
+ "    \"d_model\": 512,\n",
+ "    \"num_heads\": 8,\n",
+ "    \"dff\": 2048,\n",
+ "    \"dropout\": 0.1,\n",
+ "    \"max_length\": 128,\n",
+ "    \"num_encoder_layers\": 6,\n",
+ "    \"num_refining_layers\": 0,\n",
+ "    \"num_decoder_layers\": 6\n",
+ "}\n",
+ "model = PRISMHybrid_RoPE(**CONFIG)\n",
+ "\n",
+ "print(\"📥 Loading Checkpoint...\")\n",
+ "weights_path = hf_hub_download(repo_id=REPO_ID, filename=\"pytorch_model.bin\")\n",
+ "state_dict = torch.load(weights_path, map_location=DEVICE)\n",
+ "model.load_state_dict(state_dict)\n",
+ "model.to(DEVICE)\n",
+ "model.eval()\n",
+ "\n",
+ "print(\"✅ Model Ready for Probing.\")"
+ ],
+ "metadata": {
+ "id": "fn7A70MZxV1U"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 🧪 The Probe & Datasets\n",
+ "# ==============================================================================\n",
+ "# 1. DATASETS (The \"76 + 70\" Split)\n",
+ "# ==============================================================================\n",
+ "\n",
+ "# A. HARD MODE (Polysemous / Ambiguous)\n",
+ "# Words that require context to resolve (High Entropy)\n",
+ "raw_poly_candidates = [\n",
+ "    # --- ORIGINAL SET ---\n",
+ "    (\"Ich gehe zur Bank um Geld zu holen\", \"Bank\"), (\"Die Bank hat hohe Zinsen\", \"Bank\"),\n",
+ "    (\"Wir saßen auf einer Bank im Park\", \"Bank\"), (\"Die Bank aus Holz war bequem\", \"Bank\"),\n",
+ "    (\"Das Schloss hat viele Türme\", \"Schloss\"), (\"Der König wohnt im Schloss\", \"Schloss\"),\n",
+ "    (\"Der Schlüssel steckt im Schloss\", \"Schloss\"), (\"Das Schloss an der Tür klemmt\", \"Schloss\"),\n",
+ "    (\"Der Leiter der Firma ist streng\", \"Leiter\"), (\"Unser Leiter plant das Projekt\", \"Leiter\"),\n",
+ "    (\"Ich steige auf die Leiter\", \"Leiter\"), (\"Die Leiter ist aus Aluminium\", \"Leiter\"),\n",
+ "    (\"Die Lampe hängt an der Decke\", \"Decke\"), (\"Die Decke ist weiß gestrichen\", \"Decke\"),\n",
+ "    (\"Mir ist kalt gib mir eine Decke\", \"Decke\"), (\"Die Decke aus Wolle ist warm\", \"Decke\"),\n",
+ "    (\"Die Kiefer ist ein Nadelbaum\", \"Kiefer\"), (\"Das Holz der Kiefer ist weich\", \"Kiefer\"),\n",
+ "    (\"Der Arzt röntgt meinen Kiefer\", \"Kiefer\"), (\"Er hat Schmerzen im Kiefer\", \"Kiefer\"),\n",
+ "    (\"Der Strauß ist ein schneller Vogel\", \"Strauß\"), (\"Dieser Strauß kann nicht fliegen\", \"Strauß\"),\n",
+ "    (\"Sie kaufte einen bunten Strauß\", \"Strauß\"), (\"Der Strauß Blumen duftet gut\", \"Strauß\"),\n",
+ "    (\"Er schoss ein schönes Tor\", \"Tor\"), (\"Der Ball flog ins Tor\", \"Tor\"),\n",
+ "    (\"Das eiserne Tor war verschlossen\", \"Tor\"), (\"Sie öffneten das große Tor\", \"Tor\"),\n",
+ "    (\"Wir tanzen auf dem Ball\", \"Ball\"), (\"Der Maskenball war elegant\", \"Ball\"),\n",
+ "    (\"Er warf den Ball weit weg\", \"Ball\"), (\"Der Ball ist rund und rot\", \"Ball\"),\n",
+ "    (\"Die Schlange im Zoo ist giftig\", \"Schlange\"), (\"Die Schlange zischte laut\", \"Schlange\"),\n",
+ "    (\"Wir stehen in einer langen Schlange\", \"Schlange\"), (\"Die Schlange an der Kasse war lang\", \"Schlange\"),\n",
+ "    (\"Der Strom ist ausgefallen\", \"Strom\"), (\"Strom kostet viel Geld\", \"Strom\"),\n",
+ "    (\"Der Strom fließt ins Meer\", \"Strom\"), (\"Wir schwammen gegen den Strom\", \"Strom\"),\n",
+ "    (\"Seine Mutter ist sehr nett\", \"Mutter\"), (\"Die Mutter kocht das Essen\", \"Mutter\"),\n",
+ "    (\"Die Mutter passt auf die Schraube\", \"Mutter\"), (\"Ich brauche eine neue Mutter\", \"Mutter\"),\n",
+ "    (\"Die Birne schmeckt süß\", \"Birne\"), (\"Ich esse gerne eine Birne\", \"Birne\"),\n",
+ "    (\"Die Birne in der Lampe ist kaputt\", \"Birne\"), (\"Wir müssen die Birne wechseln\", \"Birne\"),\n",
+ "    # --- EXPANSION SET ---\n",
+ "    (\"Das Gericht hat ihn verurteilt\", \"Gericht\"), (\"Der Anwalt geht zum Gericht\", \"Gericht\"),\n",
+ "    (\"Mein Lieblingsessen ist ein Gericht aus Reis\", \"Gericht\"), (\"Das Gericht schmeckt sehr salzig\", \"Gericht\"),\n",
+ "    (\"Der Ton war sehr laut\", \"Ton\"), (\"Ich hörte einen hohen Ton\", \"Ton\"),\n",
+ "    (\"Die Vase ist aus Ton\", \"Ton\"), (\"Wir formen Figuren aus Ton\", \"Ton\"),\n",
+ "    (\"Das Blatt fällt vom Baum\", \"Blatt\"), (\"Im Herbst werden die Blätter braun\", \"Blatt\"),\n",
+ "    (\"Ich schreibe auf ein Blatt Papier\", \"Blatt\"), (\"Gib mir bitte ein leeres Blatt\", \"Blatt\"),\n",
+ "    (\"Der Nagel steckt in der Wand\", \"Nagel\"), (\"Ich schlage den Nagel mit dem Hammer\", \"Nagel\"),\n",
+ "    (\"Mein Nagel ist abgebrochen\", \"Nagel\"), (\"Sie lackiert sich den Nagel rot\", \"Nagel\"),\n",
+ "    (\"Die Maus frisst den Käse\", \"Maus\"), (\"Die Katze jagt die Maus\", \"Maus\"),\n",
+ "    (\"Ich klicke mit der Maus\", \"Maus\"), (\"Der Computer braucht eine neue Maus\", \"Maus\"),\n",
+ "    (\"Die Erde dreht sich um die Sonne\", \"Erde\"), (\"Der Astronaut schaut auf die Erde\", \"Erde\"),\n",
+ "    (\"Die Blume braucht frische Erde\", \"Erde\"), (\"Er gräbt ein Loch in die Erde\", \"Erde\"),\n",
+ "    (\"Der Hahn kräht am Morgen\", \"Hahn\"), (\"Der Hahn hat bunte Federn\", \"Hahn\"),\n",
+ "    (\"Der Wasserhahn tropft\", \"Hahn\"), (\"Dreh bitte den Hahn zu\", \"Hahn\"),\n",
+ "    (\"Die Schale der Orange ist bitter\", \"Schale\"), (\"Er wirft die Schale weg\", \"Schale\"),\n",
+ "    (\"Die Schale steht auf dem Tisch\", \"Schale\"), (\"Ich esse Müsli aus der Schale\", \"Schale\"),\n",
+ "    (\"Der Bauer melkt die Kühe\", \"Bauer\"), (\"Der Bauer fährt auf dem Traktor\", \"Bauer\"),\n",
+ "    (\"Ich ziehe den Bauer auf E4\", \"Bauer\"), (\"Der Bauer schlägt den Turm\", \"Bauer\"),\n",
+ "]\n",
+ "\n",
+ "# B. EASY MODE (Casual)\n",
+ "raw_casual_candidates = [\n",
+ "    (\"Die Katze schläft\", \"Katze\"), (\"Der Hund bellt\", \"Hund\"), (\"Das Auto fährt\", \"Auto\"),\n",
+ "    (\"Wasser ist nass\", \"Wasser\"), (\"Das Brot schmeckt gut\", \"Brot\"), (\"Die Sonne scheint\", \"Sonne\"),\n",
+ "    (\"Der Mond leuchtet\", \"Mond\"), (\"Das Buch ist spannend\", \"Buch\"), (\"Der Tisch ist rund\", \"Tisch\"),\n",
+ "    (\"Der Stuhl ist bequem\", \"Stuhl\"), (\"Der Apfel ist rot\", \"Apfel\"), (\"Meine Hand ist kalt\", \"Hand\"),\n",
+ "    (\"Das Herz klopft\", \"Herz\"), (\"Wir haben Zeit\", \"Zeit\"), (\"Geld ist wichtig\", \"Geld\"),\n",
+ "    (\"Musik ist schön\", \"Musik\"), (\"Der Film ist zu Ende\", \"Film\"), (\"Das Spiel beginnt\", \"Spiel\"),\n",
+ "    (\"Die Schule ist aus\", \"Schule\"), (\"Die Stadt ist laut\", \"Stadt\"), (\"Der Fluss fließt\", \"Fluss\"),\n",
+ "    (\"Das Meer ist tief\", \"Meer\"), (\"Kaffee ist schwarz\", \"Kaffee\"), (\"Milch ist weiß\", \"Milch\"),\n",
+ "    (\"Der Bruder lacht\", \"Bruder\"), (\"Die Schwester weint\", \"Schwester\"), (\"Das Haus ist groß\", \"Haus\"),\n",
+ "    (\"Der Garten ist grün\", \"Garten\"), (\"Der Sommer ist heiß\", \"Sommer\"), (\"Der Winter ist kalt\", \"Winter\"),\n",
+ "    (\"Das Fenster ist offen\", \"Fenster\"), (\"Die Tür ist zu\", \"Tür\"), (\"Der Boden ist sauber\", \"Boden\"),\n",
+ "    (\"Die Wand ist weiß\", \"Wand\"), (\"Das Dach ist rot\", \"Dach\"), (\"Der Wald ist dunkel\", \"Wald\"),\n",
+ "    (\"Der Berg ist hoch\", \"Berg\"), (\"Der See ist ruhig\", \"See\"), (\"Das Tier ist wild\", \"Tier\"),\n",
+ "    (\"Der Mensch denkt\", \"Mensch\"), (\"Das Kind spielt\", \"Kind\"), (\"Die Frau arbeitet\", \"Frau\"),\n",
+ "    (\"Der Mann schläft\", \"Mann\"), (\"Das Auge sieht\", \"Auge\"), (\"Das Ohr hört\", \"Ohr\"),\n",
+ "    (\"Die Nase riecht\", \"Nase\"), (\"Der Mund spricht\", \"Mund\"), (\"Der Arm ist stark\", \"Arm\"),\n",
+ "    (\"Das Bein tut weh\", \"Bein\"), (\"Der Fuß ist groß\", \"Fuß\"), (\"Der Tee ist heiß\", \"Tee\"),\n",
+ "    (\"Das Bier ist kalt\", \"Bier\"), (\"Der Wein ist rot\", \"Wein\"), (\"Das Glas ist voll\", \"Glas\"),\n",
+ "    (\"Die Tasse ist leer\", \"Tasse\"), (\"Der Teller ist blau\", \"Teller\"), (\"Die Gabel ist spitz\", \"Gabel\"),\n",
+ "    (\"Der Löffel ist rund\", \"Löffel\"), (\"Das Messer ist scharf\", \"Messer\"), (\"Der Stift schreibt\", \"Stift\"),\n",
+ "    (\"Der Brief ist lang\", \"Brief\"), (\"Das Bild ist schön\", \"Bild\"), (\"Die Uhr tickt\", \"Uhr\"),\n",
+ "    (\"Das Bett ist weich\", \"Bett\"), (\"Der Schrank ist voll\", \"Schrank\"), (\"Das Sofa ist neu\", \"Sofa\"),\n",
+ "    (\"Das Radio spielt\", \"Radio\"), (\"Das Jahr ist um\", \"Jahr\"), (\"Der Tag war lang\", \"Tag\"),\n",
+ "    (\"Die Nacht ist kurz\", \"Nacht\")\n",
+ "]\n",
+ "# ==============================================================================\n",
+ "# 2. HELPER FUNCTIONS\n",
+ "# ==============================================================================\n",
+ "def find_token_index(input_ids, target_word, tokenizer):\n",
+ "    tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
+ "    for i, t in enumerate(tokens):\n",
+ "        # Clean BPE artifacts\n",
+ "        clean = t.replace('Ġ', '').replace('▁', '').replace(' ', '')\n",
+ "        if target_word.lower() == clean.lower():\n",
+ "            return i\n",
+ "    # Fallback: substring match\n",
+ "    for i, t in enumerate(tokens):\n",
+ "        clean = t.replace('Ġ', '').replace('▁', '').replace(' ', '')\n",
+ "        if target_word.lower() in clean.lower():\n",
+ "            return i\n",
+ "    return 1  # default to the first content position if nothing matches\n",
+ "\n",
+ "def filter_dataset(candidates, tokenizer, label):\n",
+ "    \"\"\"Ensures we only test single-token words to keep phase metrics valid.\"\"\"\n",
+ "    valid_data = []\n",
+ "    rejected_count = 0\n",
+ "    print(f\"\\n🔍 Validating {label} Candidates...\")\n",
+ "\n",
+ "    for context, target in candidates:\n",
+ "        # Check standard and space-prefixed tokenization\n",
+ "        t1 = tokenizer.encode(target, add_special_tokens=False)\n",
+ "        t2 = tokenizer.encode(\" \" + target, add_special_tokens=False)\n",
+ "\n",
+ "        if len(t1) == 1 or len(t2) == 1:\n",
+ "            valid_data.append((context, target))\n",
+ "        else:\n",
+ "            rejected_count += 1\n",
+ "\n",
+ "    print(f\" ✅ Accepted: {len(valid_data)} examples.\")\n",
+ "    print(f\" ❌ Rejected: {rejected_count} multi-token words.\")\n",
+ "    return valid_data\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 3. UNIFIED PROBE\n",
+ "# ==============================================================================\n",
+ "def run_unified_probe(model, tokenizer, dataset, label, device):\n",
+ "    num_layers = len(model.prism_encoder.layers)\n",
+ "    rotation_stats = {i: [] for i in range(num_layers)}\n",
+ "\n",
+ "    # Hooks\n",
+ "    hook_data = {}\n",
+ "\n",
+ "    def physics_hook(layer_idx):\n",
+ "        def hook(module, input, output):\n",
+ "            x, y = input[0].detach(), output.detach()\n",
+ "\n",
+ "            # --- PHASE ROTATION CALCULATION ---\n",
+ "            # 1. Norms\n",
+ "            norm_x = torch.norm(x, p=2, dim=-1)\n",
+ "            norm_y = torch.norm(y, p=2, dim=-1)\n",
+ "\n",
+ "            # 2. Flatten trailing dims so each token has one state vector;\n",
+ "            #    Re⟨x, y⟩ below equals the dot product of the [Batch, Seq, Dim*2] real views\n",
+ "            x_f = x.view(x.shape[0], x.shape[1], -1)\n",
+ "            y_f = y.view(y.shape[0], y.shape[1], -1)\n",
+ "\n",
+ "            # 3. Dot Product\n",
+ "            dot = (x_f.real * y_f.real + x_f.imag * y_f.imag).sum(dim=-1)\n",
+ "\n",
+ "            # 4. Angle (Arccos)\n",
+ "            cosine = torch.clamp(dot / (norm_x * norm_y + 1e-9), -1.0, 1.0)\n",
+ "            angle = torch.rad2deg(torch.acos(cosine)).cpu()\n",
+ "\n",
+ "            hook_data[f'rot_{layer_idx}'] = angle\n",
+ "        return hook\n",
+ "\n",
+ "    # Register\n",
+ "    model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
+ "    for i, layer in enumerate(model.prism_encoder.layers):\n",
+ "        layer.register_forward_hook(physics_hook(i))\n",
+ "\n",
+ "    # Execute\n",
+ "    model.eval()\n",
+ "    print(f\"🔬 Running Probe on {len(dataset)} {label} examples...\")\n",
+ "\n",
+ "    for context, target in dataset:\n",
+ "        hook_data = {}  # reset per example; the hooks close over this name\n",
+ "        inputs = tokenizer(context, return_tensors=\"pt\").to(device)\n",
+ "\n",
+ "        with torch.no_grad():\n",
+ "            x = model.harmonic_embedding(inputs.input_ids)\n",
+ "            src_mask = (inputs.input_ids == tokenizer.pad_token_id)\n",
+ "            model.prism_encoder(x, src_mask)\n",
+ "\n",
+ "        idx = find_token_index(inputs.input_ids[0], target, tokenizer)\n",
+ "\n",
+ "        for i in range(num_layers):\n",
+ "            if f'rot_{i}' in hook_data:\n",
+ "                batch = hook_data[f'rot_{i}']\n",
+ "                # Handle batch dimension if present\n",
+ "                val = batch[0, idx].item() if batch.dim() > 1 else batch[idx].item()\n",
+ "                rotation_stats[i].append(val)\n",
+ "\n",
+ "    model.prism_encoder.apply(lambda m: m._forward_hooks.clear())\n",
+ "    return pd.DataFrame(rotation_stats)\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 4. EXECUTION & REPORTING\n",
+ "# ==============================================================================\n",
+ "# A. Filter Data\n",
+ "ds_hard = filter_dataset(raw_poly_candidates, tokenizer, \"HARD (Polysemous)\")\n",
+ "ds_easy = filter_dataset(raw_casual_candidates, tokenizer, \"EASY (Casual)\")\n",
+ "\n",
+ "# B. Run Analysis\n",
+ "DEVICE = next(model.parameters()).device  # Robust device check\n",
+ "df_hard = run_unified_probe(model, tokenizer, ds_hard, \"HARD\", DEVICE)\n",
+ "df_easy = run_unified_probe(model, tokenizer, ds_easy, \"EASY\", DEVICE)\n",
+ "\n",
+ "# C. Generate ASCII Tables\n",
+ "def print_stats(df, title):\n",
+ "    print(f\"\\n📊 {title} (N={len(df)})\")\n",
+ "    print(\"=\"*90)\n",
+ "    print(f\"{'Lyr':<3} | {'Mean (°)':<10} | {'Median (°)':<10} | {'Max (°)':<10} | {'Skewness':<10} | {'Regime'}\")\n",
+ "    print(\"-\" * 90)\n",
+ "\n",
+ "    total_skew = 0\n",
+ "    for col in df.columns:\n",
+ "        d = df[col]\n",
+ "        skew_val = d.skew()\n",
+ "        total_skew += skew_val\n",
+ "\n",
+ "        # Interpret Regime\n",
+ "        if skew_val > 1.5: regime = \"⚡ STEERING (Heavy Tail)\"\n",
+ "        elif skew_val > 0.5: regime = \"⚖️ HYBRID\"\n",
+ "        else: regime = \"💤 INERTIAL\"\n",
+ "\n",
+ "        print(f\"{col:<3} | {d.mean():10.2f} | {d.median():10.2f} | {d.max():10.2f} | {skew_val:10.2f} | {regime}\")\n",
+ "\n",
+ "    print(\"-\" * 90)\n",
+ "    print(f\"∑ INTEGRATED SKEWNESS (Metabolic Load): {total_skew:.2f}\")\n",
+ "\n",
+ "print_stats(df_hard, \"TABLE 3A: POLYSEMOUS TOKENS (Ambiguous)\")\n",
+ "print_stats(df_easy, \"TABLE 3B: CASUAL TOKENS (Unambiguous)\")"
+ ],
+ "metadata": {
+ "id": "hvHJcuqmxDkt"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
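+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative sanity check for the regime thresholds above (not a paper result):\n",
+ "# a single heavy right-tail outlier pushes skewness past the 1.5 \"STEERING\" cut-off.\n",
+ "demo = pd.Series([1, 1, 1, 1, 10])\n",
+ "# pandas' bias-corrected estimator gives ≈ 2.24; scipy.stats.skew (default bias=True) gives 1.5\n",
+ "print(demo.skew())"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },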
+ {
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# FIGURE 3 GENERATOR (Robust N=76 Version)\n",
+ "# ==============================================================================\n",
+ "def plot_figure_3_robust(df_hard, df_easy, save_path=\"fig3_phase_steering_robust.png\"):\n",
+ "    \"\"\"\n",
+ "    Generates the \"Dual Regime\" Violin Plot matching the paper style.\n",
+ "    Visualizes the 'Mid-Network Resolution' (Skew spike at Layer 3).\n",
+ "    \"\"\"\n",
+ "    # Setup the canvas (Two panels, shared Y-axis)\n",
+ "    fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True, dpi=300)\n",
+ "\n",
+ "    # 1. AMBIGUOUS PANEL (The Steering Regime)\n",
+ "    # We use a Red palette to signify \"Metabolic Work\"\n",
+ "    sns.violinplot(\n",
+ "        data=df_hard,\n",
+ "        palette=\"Reds\",\n",
+ "        ax=axes[0],\n",
+ "        inner=\"quartile\",\n",
+ "        linewidth=1.2,\n",
+ "        cut=0  # Don't extend past data range\n",
+ "    )\n",
+ "    axes[0].set_title(\"(A) Ambiguous (Steering)\", fontweight='bold', fontsize=12, color='darkred')\n",
+ "    axes[0].set_ylabel(\"Phase Rotation (°)\", fontsize=11, fontweight='bold')\n",
+ "    axes[0].set_xlabel(\"Layer Depth\", fontsize=10)\n",
+ "    axes[0].grid(axis='y', linestyle='--', alpha=0.3)\n",
+ "\n",
+ "    # 2. UNAMBIGUOUS PANEL (The Inertial Regime)\n",
+ "    # We use a Blue/Green palette to signify \"Coasting\"\n",
+ "    sns.violinplot(\n",
+ "        data=df_easy,\n",
+ "        palette=\"mako\",\n",
+ "        ax=axes[1],\n",
+ "        inner=\"quartile\",\n",
+ "        linewidth=1.2,\n",
+ "        cut=0\n",
+ "    )\n",
+ "    axes[1].set_title(\"(B) Unambiguous (Inertial)\", fontweight='bold', fontsize=12, color='darkgreen')\n",
+ "    axes[1].set_xlabel(\"Layer Depth\", fontsize=10)\n",
+ "    axes[1].grid(axis='y', linestyle='--', alpha=0.3)\n",
+ "    axes[1].set_ylabel(\"\")  # Remove redundant y-label\n",
+ "\n",
+ "    # Formatting\n",
+ "    plt.ylim(0, 30)  # Focus on the active range (Max is ~22 deg)\n",
+ "    sns.despine(trim=True, offset=5)\n",
+ "    plt.tight_layout()\n",
+ "\n",
+ "    # Save & Show\n",
+ "    plt.savefig(save_path, bbox_inches='tight')\n",
+ "    plt.show()\n",
+ "    print(f\"✅ Updated Figure 3 saved to: {save_path}\")\n",
+ "\n",
+ "# Run with your existing dataframes\n",
+ "plot_figure_3_robust(df_hard, df_easy)"
+ ],
+ "metadata": {
+ "id": "bG4t_QIyycpD"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 🛠️ Fixed Stress Test (Using Custom Generation API)\n",
+ "import torch\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 1. SETUP & DATA\n",
+ "# ==============================================================================\n",
+ "test_cases = [\n",
+ "    # (German Word, Context Helper, Expected English)\n",
+ "    (\"Bank\", \"Die Bank.\", \"bench\"),        # Ambiguous: Bench vs Bank\n",
+ "    (\"Schloss\", \"Das Schloss.\", \"castle\"), # Ambiguous: Castle vs Lock\n",
+ "    (\"Leiter\", \"Der Leiter.\", \"leader\"),   # Ambiguous: Ladder vs Leader\n",
+ "    (\"Decke\", \"Die Decke.\", \"ceiling\"),    # Ambiguous: Blanket vs Ceiling\n",
+ "    (\"Kiefer\", \"Der Kiefer.\", \"jaw\"),      # Ambiguous: Pine vs Jaw\n",
+ "    (\"Gericht\", \"Das Gericht.\", \"court\"),  # Ambiguous: Dish vs Court\n",
+ "    (\"Steuer\", \"Das Steuer.\", \"helm\"),     # Ambiguous: Tax vs Helm\n",
+ "    (\"Hahn\", \"Der Hahn.\", \"rooster\"),      # Ambiguous: Tap vs Rooster\n",
+ "    (\"Tau\", \"Das Tau.\", \"rope\"),           # Ambiguous: Dew vs Rope\n",
+ "    (\"Strauß\", \"Der Strauß.\", \"bouquet\"),  # Ambiguous: Ostrich vs Bouquet\n",
+ "]\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 2. THE EXPERIMENT LOOP\n",
+ "# ==============================================================================\n",
+ "results = []\n",
+ "print(f\"📉 Running Phase Interference Test on {len(test_cases)} ambiguous terms...\\n\")\n",
+ "\n",
+ "model.eval()\n",
+ "\n",
+ "# Helper to run your custom generate function\n",
+ "def run_prism_gen(text_input):\n",
+ "    # 1. Tokenize (Get IDs only)\n",
+ "    input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
+ "\n",
+ "    # 2. Call YOUR custom generate method\n",
+ "    # Signature: generate(self, src, max_length, num_beams=5)\n",
+ "    with torch.no_grad():\n",
+ "        out_ids = model.generate(src=input_tensor, max_length=10, num_beams=1)\n",
+ "\n",
+ "    # 3. Decode\n",
+ "    return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
+ "\n",
+ "for word, context_phrase, target in tqdm(test_cases):\n",
+ "    try:\n",
+ "        # --- PASS 1: ISOLATION (Single Token) ---\n",
+ "        trans_iso = run_prism_gen(word)\n",
+ "\n",
+ "        # --- PASS 2: INTERFERENCE (Context) ---\n",
+ "        trans_int = run_prism_gen(context_phrase)\n",
+ "\n",
+ "        # Log the result. A wrong isolation translation is not treated as a bug:\n",
+ "        # an isolation miss plus a context hit (\"Context Fix\") is precisely the\n",
+ "        # effect this experiment is designed to surface.\n",
+ "\n",
+ "        status = \"✅ Context Fix\" if (target.lower() not in trans_iso.lower() and target.lower() in trans_int.lower()) else \"Neutral\"\n",
+ "\n",
+ "        results.append({\n",
+ "            \"Source\": word,\n",
+ "            \"Target\": target,\n",
+ "            \"⛔ Isolation\": trans_iso,\n",
+ "            \"✅ Context\": trans_int,\n",
+ "            \"Outcome\": status\n",
+ "        })\n",
+ "    except Exception as e:\n",
+ "        print(f\"Error on {word}: {e}\")\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 3. ANALYSIS & REPORT\n",
+ "# ==============================================================================\n",
+ "df_res = pd.DataFrame(results)\n",
+ "\n",
+ "print(\"\\n\\n🌊 EXPERIMENTAL RESULTS: The Single-Token Paradox\")\n",
+ "print(\"=\"*110)\n",
+ "print(f\"{'Input':<10} | {'Target':<10} | {'⛔ Isolation (No Wave)':<30} | {'✅ Context (Interference)':<30}\")\n",
+ "print(\"-\" * 110)\n",
+ "\n",
+ "success_scientific_count = 0\n",
+ "\n",
+ "for _, row in df_res.iterrows():\n",
+ "    iso_text = row['⛔ Isolation']\n",
+ "    ctx_text = row['✅ Context']\n",
+ "\n",
+ "    # Visual Logic: If Isolation failed to find target, but Context found it -> Highlight\n",
+ "    if row['Outcome'] == \"✅ Context Fix\":\n",
+ "        iso_text = f\"--> {iso_text} <--\"  # Shows the \"Inertial\" failure\n",
+ "        success_scientific_count += 1\n",
+ "\n",
+ "    print(f\"{row['Source']:<10} | {row['Target']:<10} | {iso_text:<30} | {ctx_text:<30}\")\n",
+ "\n",
+ "print(\"=\"*110)\n",
+ "print(f\"\\n🧪 SCIENTIFIC CONCLUSION:\")\n",
+ "if success_scientific_count > 0:\n",
+ "    print(f\"🎉 OBSERVED: {success_scientific_count}/{len(test_cases)} cases showed the 'Physical Abstractor' effect!\")\n",
+ "    print(\" The model failed to resolve ambiguity in isolation (as predicted) but resolved it with context.\")\n",
+ "else:\n",
+ "    print(\"🤔 OBSERVED: The model resolved isolation perfectly. Rate-Coding (Magnitude) might be leaking.\")"
+ ],
+ "metadata": {
+ "id": "llU4wmT67ZRm"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 🧪 Sanity Check: Full Wave Packets (Sentences)\n",
+ "# ==============================================================================\n",
+ "# HYPOTHESIS:\n",
+ "# If the \"Single-Token Paradox\" is real, then full sentences (rich interference)\n",
+ "# should translate fluently, unlike the \"broken\" isolated tokens.\n",
+ "# ==============================================================================\n",
+ "\n",
+ "sanity_sentences = [\n",
+ "    # 1. Standard Fluency (Control)\n",
+ "    \"Das Haus ist groß und schön.\",\n",
+ "    \"Die Katze schläft auf dem Sofa.\",\n",
+ "    \"Wir gehen heute in den Park.\",\n",
+ "    \"Das Wetter ist sehr gut.\",\n",
+ "\n",
+ "    # 2. Contextual Resolution (The ambiguous words from before)\n",
+ "    \"Der Lehrer schreibt an die Tafel.\",   # surface-close to \"Leiter\" (leader/head); context should keep \"teacher\"\n",
+ "    \"Ich sitze auf einer Bank im Garten.\", # Should lock to \"Bench\"\n",
+ "    \"Ich bringe mein Geld zur Bank.\",      # Should lock to \"Bank\" (Financial)\n",
+ "    \"Das Schloss ist alt und aus Stein.\",  # Should lock to \"Castle\"\n",
+ "    \"Der Schlüssel steckt im Schloss.\",    # Should lock to \"Lock\"\n",
+ "\n",
+ "    # 3. Complex Grammar (Phase coherence test)\n",
+ "    \"Obwohl es regnet, gehe ich spazieren.\",\n",
+ "    \"Wenn du Zeit hast, komm bitte vorbei.\"\n",
+ "]\n",
+ "\n",
+ "print(f\"🌊 Running Sanity Check on {len(sanity_sentences)} sentences...\\n\")\n",
+ "print(\"=\"*100)\n",
+ "print(f\"{'German Source':<40} | {'🇬🇧 PRISM Translation'}\")\n",
+ "print(\"-\" * 100)\n",
+ "\n",
+ "model.eval()\n",
+ "\n",
+ "# Re-defining the helper from before with a larger generation budget\n",
+ "def run_prism_gen(text_input):\n",
+ "    input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
+ "    with torch.no_grad():\n",
+ "        # Increased max_length for full sentences\n",
+ "        out_ids = model.generate(src=input_tensor, max_length=40, num_beams=1)\n",
+ "    return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
+ "\n",
+ "for sent in sanity_sentences:\n",
+ "    try:\n",
+ "        translation = run_prism_gen(sent)\n",
+ "        print(f\"{sent:<40} | {translation}\")\n",
+ "    except Exception as e:\n",
+ "        print(f\"{sent:<40} | ❌ Error: {e}\")\n",
+ "\n",
+ "print(\"=\"*100)"
+ ],
+ "metadata": {
+ "id": "71sdvAAn8O63"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 🧪 Control Experiment: Unambiguous Single Tokens\n",
+ "# ==============================================================================\n",
+ "# HYPOTHESIS:\n",
+ "# If the single-token collapse is due to AMBIGUITY (phase can't resolve meaning),\n",
+ "# then UNAMBIGUOUS tokens should translate correctly even in isolation.\n",
+ "# If they ALSO collapse, the issue is purely L=1 (no interference at all).\n",
+ "# ==============================================================================\n",
+ "\n",
+ "import torch\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 1. UNAMBIGUOUS TEST SET\n",
+ "# ==============================================================================\n",
+ "# These words have ONE clear meaning - no context needed for humans\n",
+ "unambiguous_cases = [\n",
+ "    # (German Word, Expected English)\n",
+ "    # --- Animals (No ambiguity) ---\n",
+ "    (\"Katze\", \"cat\"),\n",
+ "    (\"Hund\", \"dog\"),\n",
+ "    (\"Pferd\", \"horse\"),\n",
+ "    (\"Vogel\", \"bird\"),\n",
+ "    (\"Fisch\", \"fish\"),\n",
+ "    (\"Elefant\", \"elephant\"),\n",
+ "    (\"Löwe\", \"lion\"),\n",
+ "    (\"Bär\", \"bear\"),\n",
+ "\n",
+ "    # --- Objects (No ambiguity) ---\n",
+ "    (\"Tisch\", \"table\"),\n",
+ "    (\"Stuhl\", \"chair\"),\n",
+ "    (\"Buch\", \"book\"),\n",
+ "    (\"Auto\", \"car\"),\n",
+ "    (\"Haus\", \"house\"),\n",
+ "    (\"Fenster\", \"window\"),\n",
+ "    (\"Lampe\", \"lamp\"),\n",
+ "    (\"Telefon\", \"phone\"),\n",
+ "\n",
+ "    # --- Nature (No ambiguity) ---\n",
+ "    (\"Baum\", \"tree\"),\n",
+ "    (\"Blume\", \"flower\"),\n",
+ "    (\"Wolke\", \"cloud\"),\n",
+ "    (\"Regen\", \"rain\"),\n",
+ "    (\"Schnee\", \"snow\"),\n",
+ "    (\"Feuer\", \"fire\"),\n",
+ "\n",
+ "    # --- Body Parts (No ambiguity) ---\n",
+ "    (\"Kopf\", \"head\"),\n",
+ "    (\"Auge\", \"eye\"),\n",
+ "    (\"Ohr\", \"ear\"),\n",
+ "    (\"Nase\", \"nose\"),\n",
+ "    (\"Finger\", \"finger\"),\n",
+ "\n",
+ "    # --- Food (No ambiguity) ---\n",
+ "    (\"Brot\", \"bread\"),\n",
+ "    (\"Käse\", \"cheese\"),\n",
+ "    (\"Apfel\", \"apple\"),\n",
+ "    (\"Wasser\", \"water\"),\n",
+ "    (\"Milch\", \"milk\"),\n",
+ "]\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 2. THE EXPERIMENT\n",
+ "# ==============================================================================\n",
+ "results_unambig = []\n",
+ "print(f\"🔬 Running UNAMBIGUOUS Single-Token Test on {len(unambiguous_cases)} words...\\n\")\n",
+ "\n",
+ "model.eval()\n",
+ "\n",
+ "def run_prism_gen(text_input, max_len=10):\n",
+ "    \"\"\"Helper to run generation\"\"\"\n",
+ "    input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
+ "    with torch.no_grad():\n",
+ "        out_ids = model.generate(src=input_tensor, max_length=max_len, num_beams=1)\n",
+ "    return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
+ "\n",
+ "def is_repetition_collapse(text):\n",
+ "    \"\"\"Detect if output is repetitive garbage\"\"\"\n",
+ "    words = text.split()\n",
+ "    if len(words) < 2:\n",
+ "        return False\n",
+ "    # Check if first word repeats\n",
+ "    first_word = words[0].lower().strip('.,!?')\n",
+ "    repeat_count = sum(1 for w in words if w.lower().strip('.,!?') == first_word)\n",
+ "    return repeat_count >= len(words) * 0.5  # 50%+ repetition = collapse\n",
+ "\n",
+ "def check_correct(output, target):\n",
+ "    \"\"\"Check if target word appears in output\"\"\"\n",
+ "    return target.lower() in output.lower()\n",
+ "\n",
+ "# Run the test\n",
+ "for word, target in tqdm(unambiguous_cases):\n",
+ "    try:\n",
+ "        # Single token translation\n",
+ "        translation = run_prism_gen(word)\n",
+ "\n",
+ "        # Analyze\n",
+ "        collapsed = is_repetition_collapse(translation)\n",
+ "        correct = check_correct(translation, target)\n",
+ "\n",
+ "        # Also test with minimal context (article)\n",
+ "        # German articles: der/die/das\n",
+ "        context_translation = run_prism_gen(f\"Das {word}.\")\n",
+ "        context_correct = check_correct(context_translation, target)\n",
+ "\n",
+ "        results_unambig.append({\n",
+ "            \"German\": word,\n",
+ "            \"Target\": target,\n",
+ "            \"Isolation\": translation,\n",
+ "            \"Collapsed\": collapsed,\n",
+ "            \"Correct\": correct,\n",
+ "            \"With Article\": context_translation,\n",
+ "            \"Context Correct\": context_correct\n",
+ "        })\n",
+ "    except Exception as e:\n",
+ "        print(f\"Error on {word}: {e}\")\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 3. ANALYSIS & REPORT\n",
+ "# ==============================================================================\n",
+ "df_unambig = pd.DataFrame(results_unambig)\n",
+ "\n",
+ "# Stats\n",
+ "total = len(df_unambig)\n",
+ "collapsed_count = df_unambig['Collapsed'].sum()\n",
+ "correct_iso = df_unambig['Correct'].sum()\n",
+ "correct_ctx = df_unambig['Context Correct'].sum()\n",
+ "\n",
+ "print(\"\\n\" + \"=\"*120)\n",
+ "print(\"🧪 CONTROL EXPERIMENT: UNAMBIGUOUS SINGLE TOKENS\")\n",
+ "print(\"=\"*120)\n",
+ "print(f\"{'German':<12} | {'Target':<10} | {'⛔ Isolation (L=1)':<35} | {'Collapse?':<10} | {'✅ With Article':<30}\")\n",
+ "print(\"-\" * 120)\n",
+ "\n",
+ "for _, row in df_unambig.iterrows():\n",
+ "    collapse_marker = \"💥 YES\" if row['Collapsed'] else \"No\"\n",
+ "    iso_display = row['Isolation'][:33] + \"..\" if len(row['Isolation']) > 35 else row['Isolation']\n",
+ "    ctx_display = row['With Article'][:28] + \"..\" if len(row['With Article']) > 30 else row['With Article']\n",
+ "\n",
+ "    print(f\"{row['German']:<12} | {row['Target']:<10} | {iso_display:<35} | {collapse_marker:<10} | {ctx_display:<30}\")\n",
+ "\n",
+ "print(\"=\"*120)\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 4. SCIENTIFIC SUMMARY\n",
+ "# ==============================================================================\n",
+ "print(\"\\n📊 STATISTICAL SUMMARY\")\n",
+ "print(\"-\"*60)\n",
+ "print(f\"Total test cases: {total}\")\n",
+ "print(f\"Repetition collapses (L=1): {collapsed_count} ({100*collapsed_count/total:.1f}%)\")\n",
+ "print(f\"Correct in isolation: {correct_iso} ({100*correct_iso/total:.1f}%)\")\n",
+ "print(f\"Correct with article context: {correct_ctx} ({100*correct_ctx/total:.1f}%)\")\n",
+ "print(\"-\"*60)\n",
+ "\n",
+ "print(\"\\n🔬 INTERPRETATION:\")\n",
+ "if collapsed_count > total * 0.5:\n",
+ "    print(\" 💥 FINDING: Unambiguous tokens ALSO collapse!\")\n",
+ "    print(\" → This suggests L=1 provides insufficient interference for ANY semantic encoding.\")\n",
+ "    print(\" → The decoder needs wave structure, not just meaning clarity.\")\n",
+ "    print(\"\\n 📝 PAPER IMPLICATION: Phase interference is necessary for GENERATION,\")\n",
+ "    print(\" not just disambiguation. Single tokens lack the spectral 'carrier wave'\")\n",
+ "    print(\" needed to bootstrap autoregressive decoding.\")\n",
+ "elif collapsed_count > 0:\n",
+ "    print(\" ⚖️ FINDING: Mixed results - some collapse, some don't.\")\n",
+ "    print(\" → Suggests a threshold effect based on embedding quality or token frequency.\")\n",
+ "else:\n",
+ "    print(\" ✅ FINDING: Unambiguous tokens translate correctly!\")\n",
+ "    print(\" → This confirms the hypothesis: ambiguity requires interference to resolve,\")\n",
+ "    print(\" but clear semantics can propagate through the encoder even at L=1.\")\n",
+ "    print(\"\\n 📝 PAPER IMPLICATION: The single-token collapse is SPECIFIC to polysemy.\")\n",
+ "    print(\" PRISM's phase encoding captures unambiguous semantics in static embeddings,\")\n",
+ "    print(\" but requires interference patterns to perform 'semantic selection'.\")\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 5. COMPARISON TABLE (Side by Side)\n",
+ "# ==============================================================================\n",
+ "print(\"\\n\\n\" + \"=\"*80)\n",
+ "print(\"📈 AMBIGUOUS vs UNAMBIGUOUS COMPARISON\")\n",
+ "print(\"=\"*80)\n",
+ "print(f\"{'Metric':<40} | {'Ambiguous':<15} | {'Unambiguous':<15}\")\n",
+ "print(\"-\"*80)\n",
+ "# Ambiguous-column values are hardcoded summaries from the earlier stress-test run\n",
+ "print(f\"{'Repetition Collapse Rate':<40} | {'~100%':<15} | {f'{100*collapsed_count/total:.0f}%':<15}\")\n",
+ "print(f\"{'Correct in Isolation':<40} | {'~0%':<15} | {f'{100*correct_iso/total:.0f}%':<15}\")\n",
+ "print(f\"{'Correct with Minimal Context':<40} | {'Partial':<15} | {f'{100*correct_ctx/total:.0f}%':<15}\")\n",
+ "print(\"=\"*80)"
+ ],
+ "metadata": {
+ "id": "k6rH5YJWauTc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Test with num_beams=1 (greedy) vs num_beams=5 (beam search)\n",
+ "word = \"Hund\"\n",
+ "input_tensor = tokenizer(word, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    out_greedy = model.generate(src=input_tensor, max_length=10, num_beams=1)\n",
+ "    out_beam = model.generate(src=input_tensor, max_length=10, num_beams=5)\n",
+ "\n",
+ "print(f\"Greedy (beams=1): {tokenizer.decode(out_greedy[0], skip_special_tokens=True)}\")\n",
+ "print(f\"Beam (beams=5): {tokenizer.decode(out_beam[0], skip_special_tokens=True)}\")"
+ ],
+ "metadata": {
+ "id": "E0WRta7Qdjcg"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 🧪 Large-Scale Single-Token Analysis (N >> 32)\n",
+ "# ==============================================================================\n",
+ "# GOAL: Increase sample size with automatic single-token validation\n",
+ "# ==============================================================================\n",
+ "\n",
+ "import torch\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 1. LARGE CANDIDATE POOLS\n",
+ "# ==============================================================================\n",
+ "\n",
+ "# A. AMBIGUOUS CANDIDATES (German words with multiple meanings)\n",
+ "ambiguous_candidates = [\n",
+ "    # Word, Meaning1, Meaning2\n",
+ "    (\"Bank\", \"bench\", \"bank\"),\n",
+ "    (\"Schloss\", \"castle\", \"lock\"),\n",
+ "    (\"Leiter\", \"ladder\", \"leader\"),\n",
+ "    (\"Decke\", \"ceiling\", \"blanket\"),\n",
+ "    (\"Kiefer\", \"pine\", \"jaw\"),\n",
+ "    (\"Strauß\", \"ostrich\", \"bouquet\"),\n",
+ "    (\"Tor\", \"gate\", \"goal\"),\n",
+ "    (\"Ball\", \"ball\", \"dance\"),\n",
+ "    (\"Schlange\", \"snake\", \"queue\"),\n",
+ "    (\"Strom\", \"electricity\", \"river\"),\n",
+ "    (\"Mutter\", \"mother\", \"nut\"),\n",
+ "    (\"Birne\", \"pear\", \"lightbulb\"),\n",
+ "    (\"Gericht\", \"court\", \"dish\"),\n",
+ "    (\"Ton\", \"sound\", \"clay\"),\n",
+ "    (\"Blatt\", \"leaf\", \"sheet\"),\n",
+ "    (\"Nagel\", \"nail\", \"fingernail\"),\n",
+ "    (\"Maus\", \"mouse\", \"computer mouse\"),\n",
+ "    (\"Erde\", \"earth\", \"soil\"),\n",
+ "    (\"Hahn\", \"rooster\", \"tap\"),\n",
+ "    (\"Schale\", \"shell\", \"bowl\"),\n",
+ "    (\"Bauer\", \"farmer\", \"pawn\"),\n",
+ "    (\"Steuer\", \"tax\", \"steering wheel\"),\n",
+ "    (\"Tau\", \"dew\", \"rope\"),\n",
+ "    (\"Feder\", \"feather\", \"spring\"),\n",
+ "    (\"Absatz\", \"heel\", \"paragraph\"),\n",
+ "    (\"Band\", \"ribbon\", \"volume\"),\n",
+ "    (\"Brücke\", \"bridge\", \"dental bridge\"),\n",
+ "    (\"Flügel\", \"wing\", \"grand piano\"),\n",
+ "    (\"Golf\", \"golf\", \"gulf\"),\n",
+ "    (\"Grund\", \"reason\", \"ground\"),\n",
+ "    (\"Hut\", \"hat\", \"guard\"),\n",
+ "    (\"Kette\", \"chain\", \"necklace\"),\n",
+ "    (\"Kran\", \"crane bird\", \"crane machine\"),\n",
+ "    (\"Lauf\", \"run\", \"barrel\"),\n",
+ "    (\"Linse\", \"lens\", \"lentil\"),\n",
+ "    (\"Mark\", \"marrow\", \"mark currency\"),\n",
+ "    (\"Masse\", \"mass\", \"crowd\"),\n",
+ "    (\"Netz\", \"net\", \"network\"),\n",
+ "    (\"Pony\", \"pony\", \"bangs\"),\n",
+ "    (\"Raum\", \"room\", \"space\"),\n",
+ "    (\"Reif\", \"hoop\", \"frost\"),\n",
+ "    (\"Rock\", \"skirt\", \"rock music\"),\n",
+ "    (\"Schalter\", \"switch\", \"counter\"),\n",
+ "    (\"Schild\", \"sign\", \"shield\"),\n",
+ "    (\"See\", \"lake\", \"sea\"),\n",
+ "    (\"Seite\", \"side\", \"page\"),\n",
+ "    (\"Star\", \"starling\", \"celebrity\"),\n",
+ "    (\"Stock\", \"stick\", \"floor\"),\n",
+ "    (\"Wahl\", \"choice\", \"election\"),\n",
+ "    (\"Welle\", \"wave\", \"shaft\"),\n",
+ "    (\"Zug\", \"train\", \"pull\"),\n",
+ "]\n",
+ "\n",
+ "# B. UNAMBIGUOUS CANDIDATES (German words with single clear meaning)\n",
+ "unambiguous_candidates = [\n",
+ "    # Animals\n",
+ "    (\"Katze\", \"cat\"), (\"Hund\", \"dog\"), (\"Pferd\", \"horse\"), (\"Vogel\", \"bird\"),\n",
+ "    (\"Fisch\", \"fish\"), (\"Elefant\", \"elephant\"), (\"Löwe\", \"lion\"), (\"Bär\", \"bear\"),\n",
+ "    (\"Tiger\", \"tiger\"), (\"Affe\", \"monkey\"), (\"Schaf\", \"sheep\"), (\"Kuh\", \"cow\"),\n",
+ "    (\"Schwein\", \"pig\"), (\"Huhn\", \"chicken\"), (\"Ente\", \"duck\"), (\"Gans\", \"goose\"),\n",
+ "    (\"Wolf\", \"wolf\"), (\"Fuchs\", \"fox\"), (\"Hase\", \"rabbit\"), (\"Hirsch\", \"deer\"),\n",
+ "    (\"Frosch\", \"frog\"), (\"Spinne\", \"spider\"), (\"Biene\", \"bee\"), (\"Käfer\", \"beetle\"),\n",
+ "\n",
+ "    # Objects\n",
+ "    (\"Tisch\", \"table\"), (\"Stuhl\", \"chair\"), (\"Buch\", \"book\"), (\"Auto\", \"car\"),\n",
+ "    (\"Haus\", \"house\"), (\"Fenster\", \"window\"), (\"Lampe\", \"lamp\"), (\"Telefon\", \"phone\"),\n",
+ "    (\"Computer\", \"computer\"), (\"Uhr\", \"clock\"), (\"Brille\", \"glasses\"), (\"Schlüssel\", \"key\"),\n",
+ "    (\"Tasche\", \"bag\"), (\"Schuh\", \"shoe\"), (\"Hemd\", \"shirt\"), (\"Hose\", \"pants\"),\n",
+ "    (\"Kleid\", \"dress\"), (\"Jacke\", \"jacket\"), (\"Tür\", \"door\"), (\"Bett\", \"bed\"),\n",
+ "    (\"Schrank\", \"closet\"), (\"Sofa\", \"sofa\"), (\"Spiegel\", \"mirror\"), (\"Teppich\", \"carpet\"),\n",
+ "\n",
+ "    # Nature\n",
+ "    (\"Baum\", \"tree\"), (\"Blume\", \"flower\"), (\"Wolke\", \"cloud\"), (\"Regen\", \"rain\"),\n",
+ "    (\"Schnee\", \"snow\"), (\"Feuer\", \"fire\"), (\"Berg\", \"mountain\"), (\"Wald\", \"forest\"),\n",
+ "    (\"Fluss\", \"river\"), (\"Himmel\", \"sky\"), (\"Stern\", \"star\"), (\"Mond\", \"moon\"),\n",
+ "    (\"Sonne\", \"sun\"), (\"Gras\", \"grass\"), (\"Stein\", \"stone\"), (\"Sand\", \"sand\"),\n",
+ "\n",
+ "    # Body parts\n",
+ "    (\"Kopf\", \"head\"), (\"Auge\", \"eye\"), (\"Ohr\", \"ear\"), (\"Nase\", \"nose\"),\n",
+ "    (\"Mund\", \"mouth\"), (\"Zahn\", \"tooth\"), (\"Zunge\", \"tongue\"), (\"Hals\", \"neck\"),\n",
+ "    (\"Arm\", \"arm\"), (\"Bein\", \"leg\"), (\"Fuß\", \"foot\"), (\"Knie\", \"knee\"),\n",
+ "    (\"Finger\", \"finger\"), (\"Herz\", \"heart\"), (\"Lunge\", \"lung\"), (\"Magen\", \"stomach\"),\n",
+ "\n",
+ "    # Food & Drink\n",
+ "    (\"Brot\", \"bread\"), (\"Käse\", \"cheese\"), (\"Apfel\", \"apple\"), (\"Wasser\", \"water\"),\n",
+ "    (\"Milch\", \"milk\"), (\"Ei\", \"egg\"), (\"Fleisch\", \"meat\"), (\"Reis\", \"rice\"),\n",
+ "    (\"Nudel\", \"noodle\"), (\"Suppe\", \"soup\"), (\"Salat\", \"salad\"), (\"Kuchen\", \"cake\"),\n",
+ "    (\"Kaffee\", \"coffee\"), (\"Tee\", \"tea\"), (\"Bier\", \"beer\"), (\"Wein\", \"wine\"),\n",
+ "    (\"Saft\", \"juice\"), (\"Zucker\", \"sugar\"), (\"Salz\", \"salt\"), (\"Butter\", \"butter\"),\n",
+ "\n",
+ "    # Colors (as nouns)\n",
+ "    (\"Rot\", \"red\"), (\"Blau\", \"blue\"), (\"Grün\", \"green\"), (\"Gelb\", \"yellow\"),\n",
+ "    (\"Schwarz\", \"black\"), (\"Weiß\", \"white\"), (\"Braun\", \"brown\"), (\"Grau\", \"gray\"),\n",
+ "\n",
+ "    # Numbers (as nouns)\n",
+ "    (\"Eins\", \"one\"), (\"Zwei\", \"two\"), (\"Drei\", \"three\"), (\"Vier\", \"four\"),\n",
+ "    (\"Fünf\", \"five\"), (\"Sechs\", \"six\"), (\"Sieben\", \"seven\"), (\"Acht\", \"eight\"),\n",
+ "\n",
+ "    # Family\n",
+ "    (\"Vater\", \"father\"), (\"Bruder\", \"brother\"), (\"Schwester\", \"sister\"),\n",
+ "    (\"Onkel\", \"uncle\"), (\"Tante\", \"aunt\"), (\"Oma\", \"grandma\"), (\"Opa\", \"grandpa\"),\n",
+ "\n",
+ "    # Professions\n",
+ "    (\"Arzt\", \"doctor\"), (\"Lehrer\", \"teacher\"), (\"Koch\", \"cook\"), (\"Pilot\", \"pilot\"),\n",
+ "    (\"Polizist\", \"policeman\"), (\"Bäcker\", \"baker\"), (\"Maler\", \"painter\"),\n",
+ "]\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 2. SINGLE-TOKEN VALIDATION FUNCTION\n",
+ "# ==============================================================================\n",
+ "\n",
+ "def is_single_token(word, tokenizer):\n",
+ "    \"\"\"\n",
+ "    Check if a word is represented as a single token.\n",
+ "    Tests both with and without space prefix (BPE behavior varies).\n",
+ "    \"\"\"\n",
+ "    # Test 1: Raw word\n",
+ "    tokens_raw = tokenizer.encode(word, add_special_tokens=False)\n",
+ "\n",
+ "    # Test 2: With space prefix (common in BPE)\n",
+ "    tokens_space = tokenizer.encode(\" \" + word, add_special_tokens=False)\n",
+ "\n",
+ "    # Accept if either encoding is a single token\n",
+ "    is_single = (len(tokens_raw) == 1) or (len(tokens_space) == 1)\n",
+ "\n",
+ "    return is_single, len(tokens_raw), len(tokens_space)\n",
+ "\n",
+ "def filter_single_tokens(candidates, tokenizer):\n",
+ "    \"\"\"\n",
+ "    Filter candidates to only include single-token words.\n",
+ "    Returns validated list with token info.\n",
+ "    \"\"\"\n",
+ "    valid = []\n",
+ "    rejected = []\n",
+ "\n",
+ "    for item in candidates:\n",
+ "        word = item[0]  # the German word is the first element in both tuple schemas\n",
+ "\n",
+ "        is_single, n_raw, n_space = is_single_token(word, tokenizer)\n",
+ "\n",
+ "        if is_single:\n",
+ "            valid.append(item)\n",
+ "        else:\n",
+ "            rejected.append((word, n_raw, n_space))\n",
+ "\n",
+ "    return valid, rejected\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 3. GENERATION & ANALYSIS FUNCTIONS\n",
+ "# ==============================================================================\n",
+ "\n",
+ "def run_prism_gen(text_input, model, tokenizer, device, max_len=10):\n",
+ "    \"\"\"Run PRISM generation\"\"\"\n",
+ "    input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(device)\n",
+ "    with torch.no_grad():\n",
+ "        out_ids = model.generate(src=input_tensor, max_length=max_len, num_beams=1)\n",
+ "    return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
+ "\n",
+ "def is_repetition_collapse(text):\n",
+ "    \"\"\"Detect if output shows repetition collapse\"\"\"\n",
+ "    if not text or len(text.split()) < 2:\n",
+ "        return False\n",
+ "\n",
+ "    words = text.lower().split()\n",
+ "    # Clean punctuation\n",
+ "    words = [w.strip('.,!?;:') for w in words]\n",
+ "\n",
+ "    if len(words) < 2:\n",
+ "        return False\n",
+ "\n",
+ "    # Check various collapse patterns\n",
+ "    # Pattern 1: Same word repeats\n",
+ "    first_word = words[0]\n",
+ "    same_count = sum(1 for w in words if w == first_word or first_word.startswith(w) or w.startswith(first_word))\n",
+ "\n",
+ "    # Pattern 2: Substring repetition (e.g., \"dogdogdog\")\n",
+ "    joined = ''.join(words)\n",
+ "    if len(joined) > 3:\n",
+ "        chunk = joined[:3]\n",
+ "        if joined.count(chunk) >= 3:\n",
+ "            return True\n",
+ "\n",
+ "    return same_count >= len(words) * 0.5\n",
+ "\n",
+ "def check_correct(output, targets):\n",
+ "    \"\"\"Check if any target word appears in output\"\"\"\n",
+ "    output_lower = output.lower()\n",
+ "    if isinstance(targets, str):\n",
+ "        targets = [targets]\n",
+ "    return any(t.lower() in output_lower for t in targets)\n",
991
+ "\n",
992
+ "# ==============================================================================\n",
993
+ "# 4. MAIN EXPERIMENT\n",
994
+ "# ==============================================================================\n",
995
+ "\n",
996
+ "print(\"=\" * 100)\n",
997
+ "print(\"🔬 LARGE-SCALE SINGLE-TOKEN CARRIER WAVE ANALYSIS\")\n",
998
+ "print(\"=\" * 100)\n",
999
+ "\n",
1000
+ "# A. Filter to single tokens only\n",
1001
+ "print(\"\\n📋 STEP 1: Validating Single-Token Candidates...\")\n",
1002
+ "print(\"-\" * 60)\n",
1003
+ "\n",
1004
+ "ambig_valid, ambig_rejected = filter_single_tokens(ambiguous_candidates, tokenizer, is_ambiguous=True)\n",
1005
+ "unambig_valid, unambig_rejected = filter_single_tokens(unambiguous_candidates, tokenizer, is_ambiguous=False)\n",
1006
+ "\n",
1007
+ "print(f\"AMBIGUOUS: {len(ambig_valid)} valid / {len(ambig_rejected)} rejected (multi-token)\")\n",
1008
+ "print(f\"UNAMBIGUOUS: {len(unambig_valid)} valid / {len(unambig_rejected)} rejected (multi-token)\")\n",
1009
+ "\n",
1010
+ "# Show some rejected examples\n",
1011
+ "if ambig_rejected:\n",
1012
+ " print(f\"\\n Rejected ambiguous (examples): {ambig_rejected[:5]}\")\n",
1013
+ "if unambig_rejected:\n",
1014
+ " print(f\" Rejected unambiguous (examples): {unambig_rejected[:5]}\")\n",
1015
+ "\n",
1016
+ "# B. Run experiments\n",
1017
+ "print(f\"\\n📋 STEP 2: Running Generation Tests...\")\n",
1018
+ "print(\"-\" * 60)\n",
1019
+ "\n",
1020
+ "# Storage\n",
1021
+ "results_ambig = []\n",
1022
+ "results_unambig = []\n",
1023
+ "\n",
1024
+ "# Test AMBIGUOUS tokens\n",
1025
+ "print(f\"\\n🔴 Testing {len(ambig_valid)} AMBIGUOUS single tokens...\")\n",
1026
+ "for word, meaning1, meaning2 in tqdm(ambig_valid):\n",
1027
+ " try:\n",
1028
+ " output = run_prism_gen(word, model, tokenizer, DEVICE)\n",
1029
+ " collapsed = is_repetition_collapse(output)\n",
1030
+ " # For ambiguous, \"correct\" is undefined - check if either meaning appears\n",
1031
+ " has_meaning = check_correct(output, [meaning1, meaning2])\n",
1032
+ "\n",
1033
+ " results_ambig.append({\n",
1034
+ " \"word\": word,\n",
1035
+ " \"meanings\": f\"{meaning1}/{meaning2}\",\n",
1036
+ " \"output\": output,\n",
1037
+ " \"collapsed\": collapsed,\n",
1038
+ " \"has_any_meaning\": has_meaning\n",
1039
+ " })\n",
1040
+ " except Exception as e:\n",
1041
+ " print(f\"Error on {word}: {e}\")\n",
1042
+ "\n",
1043
+ "# Test UNAMBIGUOUS tokens\n",
1044
+ "print(f\"\\n🟢 Testing {len(unambig_valid)} UNAMBIGUOUS single tokens...\")\n",
1045
+ "for word, meaning in tqdm(unambig_valid):\n",
1046
+ " try:\n",
1047
+ " output = run_prism_gen(word, model, tokenizer, DEVICE)\n",
1048
+ " collapsed = is_repetition_collapse(output)\n",
1049
+ " correct = check_correct(output, meaning)\n",
1050
+ "\n",
1051
+ " results_unambig.append({\n",
1052
+ " \"word\": word,\n",
1053
+ " \"meaning\": meaning,\n",
1054
+ " \"output\": output,\n",
1055
+ " \"collapsed\": collapsed,\n",
1056
+ " \"correct\": correct\n",
1057
+ " })\n",
1058
+ " except Exception as e:\n",
1059
+ " print(f\"Error on {word}: {e}\")\n",
1060
+ "\n",
+ "# ==============================================================================\n",
+ "# 5. STATISTICAL ANALYSIS\n",
+ "# ==============================================================================\n",
+ "\n",
+ "df_ambig = pd.DataFrame(results_ambig)\n",
+ "df_unambig = pd.DataFrame(results_unambig)\n",
+ "\n",
+ "print(\"\\n\" + \"=\" * 100)\n",
+ "print(\"📊 RESULTS: AMBIGUOUS TOKENS (L=1)\")\n",
+ "print(\"=\" * 100)\n",
+ "print(f\"{'Word':<15} | {'Meanings':<25} | {'Output':<40} | {'Collapse?'}\")\n",
+ "print(\"-\" * 100)\n",
+ "\n",
+ "for _, row in df_ambig.iterrows():\n",
+ "    collapse_mark = \"💥 YES\" if row['collapsed'] else \"No\"\n",
+ "    output_display = row['output'][:38] + \"..\" if len(row['output']) > 40 else row['output']\n",
+ "    print(f\"{row['word']:<15} | {row['meanings']:<25} | {output_display:<40} | {collapse_mark}\")\n",
+ "\n",
+ "print(\"\\n\" + \"=\" * 100)\n",
+ "print(\"📊 RESULTS: UNAMBIGUOUS TOKENS (L=1)\")\n",
+ "print(\"=\" * 100)\n",
+ "print(f\"{'Word':<15} | {'Target':<15} | {'Output':<40} | {'Collapse?':<10} | {'Correct?'}\")\n",
+ "print(\"-\" * 100)\n",
+ "\n",
+ "for _, row in df_unambig.iterrows():\n",
+ "    collapse_mark = \"💥 YES\" if row['collapsed'] else \"No\"\n",
+ "    correct_mark = \"✅\" if row['correct'] else \"❌\"\n",
+ "    output_display = row['output'][:38] + \"..\" if len(row['output']) > 40 else row['output']\n",
+ "    print(f\"{row['word']:<15} | {row['meaning']:<15} | {output_display:<40} | {collapse_mark:<10} | {correct_mark}\")\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 6. SUMMARY STATISTICS\n",
+ "# ==============================================================================\n",
+ "\n",
+ "print(\"\\n\" + \"=\" * 100)\n",
+ "print(\"📈 SUMMARY STATISTICS\")\n",
+ "print(\"=\" * 100)\n",
+ "\n",
+ "n_ambig = len(df_ambig)\n",
+ "n_unambig = len(df_unambig)\n",
+ "\n",
+ "ambig_collapse_rate = df_ambig['collapsed'].sum() / n_ambig * 100 if n_ambig > 0 else 0\n",
+ "unambig_collapse_rate = df_unambig['collapsed'].sum() / n_unambig * 100 if n_unambig > 0 else 0\n",
+ "unambig_correct_rate = df_unambig['correct'].sum() / n_unambig * 100 if n_unambig > 0 else 0\n",
+ "\n",
+ "print(f\"\"\"\n",
+ "┌──────────────────────────────────────────────────────┐\n",
+ "│                CARRIER WAVE THRESHOLD                │\n",
+ "├──────────────────────────────────────────────────────┤\n",
+ "│ Condition          │   N   │ Collapse Rate │ Correct │\n",
+ "├──────────────────────────────────────────────────────┤\n",
+ "│ AMBIGUOUS (L=1)    │ {n_ambig:<5} │ {ambig_collapse_rate:>6.1f}%       │   N/A   │\n",
+ "│ UNAMBIGUOUS (L=1)  │ {n_unambig:<5} │ {unambig_collapse_rate:>6.1f}%       │ {unambig_correct_rate:>5.1f}%  │\n",
+ "└──────────────────────────────────────────────────────┘\n",
+ "\"\"\")\n",
+ "\n",
+ "# Statistical comparison\n",
+ "print(\"🔬 STATISTICAL INTERPRETATION:\")\n",
+ "print(\"-\" * 60)\n",
+ "\n",
+ "if ambig_collapse_rate > unambig_collapse_rate + 10:\n",
+ "    print(f\"   → Ambiguous tokens collapse MORE ({ambig_collapse_rate:.1f}% vs {unambig_collapse_rate:.1f}%)\")\n",
+ "    print(f\"   → Difference: {ambig_collapse_rate - unambig_collapse_rate:.1f} percentage points\")\n",
+ "    print(\"   → SUPPORTS: Ambiguity exacerbates L=1 failure\")\n",
+ "elif abs(ambig_collapse_rate - unambig_collapse_rate) <= 10:\n",
+ "    print(f\"   → Both collapse at similar rates ({ambig_collapse_rate:.1f}% vs {unambig_collapse_rate:.1f}%)\")\n",
+ "    print(\"   → SUPPORTS: L=1 failure is about sequence length, not ambiguity\")\n",
+ "else:\n",
+ "    print(\"   → Unexpected: unambiguous tokens collapse markedly MORE - investigate further\")\n",
+ "\n",
+ "print(f\"\\n   → Unambiguous tokens whose output contains the target meaning: {unambig_correct_rate:.1f}%\")\n",
+ "print(\"   → This suggests the embeddings DO encode meaning, but the decoder loops anyway\")\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 7. EXPORT FOR PAPER\n",
+ "# ==============================================================================\n",
+ "\n",
+ "print(\"\\n\" + \"=\" * 100)\n",
+ "print(\"📝 LATEX-READY TABLE\")\n",
+ "print(\"=\" * 100)\n",
+ "\n",
+ "print(f\"\"\"\n",
+ "\\\\begin{{table}}[h]\n",
+ "\\\\centering\n",
+ "\\\\caption{{Carrier Wave Threshold Analysis (Extended). Large-scale validation confirms\n",
+ "that repetition collapse at $L=1$ affects both ambiguous and unambiguous tokens.}}\n",
+ "\\\\label{{tab:carrier_wave_extended}}\n",
+ "\\\\begin{{tabular}}{{lccc}}\n",
+ "\\\\toprule\n",
+ "\\\\textbf{{Condition}} & \\\\textbf{{N}} & \\\\textbf{{Collapse Rate}} & \\\\textbf{{Correct (if applicable)}} \\\\\\\\\n",
+ "\\\\midrule\n",
+ "Ambiguous ($L=1$) & {n_ambig} & {ambig_collapse_rate:.1f}\\\\% & N/A \\\\\\\\\n",
+ "Unambiguous ($L=1$) & {n_unambig} & {unambig_collapse_rate:.1f}\\\\% & {unambig_correct_rate:.1f}\\\\% \\\\\\\\\n",
+ "\\\\bottomrule\n",
+ "\\\\end{{tabular}}\n",
+ "\\\\end{{table}}\n",
+ "\"\"\")"
+ ],
+ "metadata": {
+ "id": "A1Ta_6di5s7E"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 1. SETUP & HELPERS\n",
+ "# ==============================================================================\n",
+ "print(\"=\" * 100)\n",
+ "print(\"🌊 PHASE 2: THE CARRIER WAVE RESURRECTION (L=2)\")\n",
+ "print(\"=\" * 100)\n",
+ "\n",
+ "def add_minimal_context(word):\n",
+ " \"\"\"Prepends a neutral article to force L >= 2\"\"\"\n",
1181
+ " return f\"Das {word}\"\n",
1182
+ "\n",
1183
+ "def run_prism_gen_exact(text_input, max_len=10):\n",
1184
+ " \"\"\"\n",
1185
+ " Exact generation wrapper matching previous usage.\n",
1186
+ " \"\"\"\n",
1187
+ " # 1. Tokenize (Get IDs only, no special tokens)\n",
1188
+ " input_tensor = tokenizer(text_input, return_tensors=\"pt\", add_special_tokens=False).input_ids.to(DEVICE)\n",
1189
+ "\n",
1190
+ " # 2. Call custom generate method exactly as before\n",
1191
+ " with torch.no_grad():\n",
1192
+ " out_ids = model.generate(src=input_tensor, max_length=max_len, num_beams=1)\n",
1193
+ "\n",
1194
+ " # 3. Decode\n",
1195
+ " return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()\n",
1196
+ "\n",
+ "results_recovery = []\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 2. RUN THE COMPARISON LOOP\n",
+ "# ==============================================================================\n",
+ "print(f\"🔄 Retesting {len(ambig_valid)} Ambiguous & {len(unambig_valid)} Unambiguous tokens with 'Das [X]'...\")\n",
+ "\n",
+ "# --- A. AMBIGUOUS RECOVERY ---\n",
+ "# Expected format: (word, meaning1, meaning2)\n",
+ "for item in tqdm(ambig_valid, desc=\"Ambiguous L=2\"):\n",
+ "    word = item[0]\n",
+ "    meanings = item[1:]  # Capture all meanings provided in the tuple\n",
+ "\n",
+ "    try:\n",
+ "        # Run L=1 (Single Token)\n",
+ "        out_L1 = run_prism_gen_exact(word)\n",
+ "        is_collapsed_L1 = is_repetition_collapse(out_L1)\n",
+ "\n",
+ "        # Run L=2 (Minimal Context)\n",
+ "        input_L2 = add_minimal_context(word)\n",
+ "        out_L2 = run_prism_gen_exact(input_L2)\n",
+ "        is_collapsed_L2 = is_repetition_collapse(out_L2)\n",
+ "\n",
+ "        # Check Meaning Recovery: does the L=2 output contain any valid meaning?\n",
+ "        # We assume check_correct handles a list of valid targets\n",
+ "        has_meaning_L2 = check_correct(out_L2, meanings)\n",
+ "\n",
+ "        results_recovery.append({\n",
+ "            \"Type\": \"Ambiguous\",\n",
+ "            \"Word\": word,\n",
+ "            \"L1_Output\": out_L1,\n",
+ "            \"L1_Collapse\": is_collapsed_L1,\n",
+ "            \"L2_Input\": input_L2,\n",
+ "            \"L2_Output\": out_L2,\n",
+ "            \"L2_Collapse\": is_collapsed_L2,\n",
+ "            \"L2_Success\": has_meaning_L2\n",
+ "        })\n",
+ "    except Exception as e:\n",
+ "        print(f\"Error on {word}: {e}\")\n",
+ "\n",
+ "# --- B. UNAMBIGUOUS RECOVERY ---\n",
+ "# Expected format: (word, target)\n",
+ "for item in tqdm(unambig_valid, desc=\"Unambiguous L=2\"):\n",
+ "    word = item[0]\n",
+ "    target = item[1]\n",
+ "\n",
+ "    try:\n",
+ "        # Run L=1\n",
+ "        out_L1 = run_prism_gen_exact(word)\n",
+ "        is_collapsed_L1 = is_repetition_collapse(out_L1)\n",
+ "\n",
+ "        # Run L=2\n",
+ "        input_L2 = add_minimal_context(word)\n",
+ "        out_L2 = run_prism_gen_exact(input_L2)\n",
+ "        is_collapsed_L2 = is_repetition_collapse(out_L2)\n",
+ "\n",
+ "        # Check Accuracy\n",
+ "        is_correct_L2 = check_correct(out_L2, [target])\n",
+ "\n",
+ "        results_recovery.append({\n",
+ "            \"Type\": \"Unambiguous\",\n",
+ "            \"Word\": word,\n",
+ "            \"L1_Output\": out_L1,\n",
+ "            \"L1_Collapse\": is_collapsed_L1,\n",
+ "            \"L2_Input\": input_L2,\n",
+ "            \"L2_Output\": out_L2,\n",
+ "            \"L2_Collapse\": is_collapsed_L2,\n",
+ "            \"L2_Success\": is_correct_L2\n",
+ "        })\n",
+ "    except Exception as e:\n",
+ "        print(f\"Error on {word}: {e}\")\n",
+ "\n",
+ "# ==============================================================================\n",
+ "# 3. ANALYSIS & VISUALIZATION\n",
+ "# ==============================================================================\n",
+ "df_rec = pd.DataFrame(results_recovery)\n",
+ "\n",
+ "# Metrics Calculation\n",
+ "total_ambig = len(df_rec[df_rec[\"Type\"] == \"Ambiguous\"])\n",
+ "total_unambig = len(df_rec[df_rec[\"Type\"] == \"Unambiguous\"])\n",
+ "\n",
+ "# L1 Collapse counts\n",
+ "collapse_L1_ambig = df_rec[(df_rec[\"Type\"] == \"Ambiguous\") & (df_rec[\"L1_Collapse\"])].shape[0]\n",
+ "collapse_L1_unambig = df_rec[(df_rec[\"Type\"] == \"Unambiguous\") & (df_rec[\"L1_Collapse\"])].shape[0]\n",
+ "\n",
+ "# L2 Collapse counts\n",
+ "collapse_L2_ambig = df_rec[(df_rec[\"Type\"] == \"Ambiguous\") & (df_rec[\"L2_Collapse\"])].shape[0]\n",
+ "collapse_L2_unambig = df_rec[(df_rec[\"Type\"] == \"Unambiguous\") & (df_rec[\"L2_Collapse\"])].shape[0]\n",
+ "\n",
+ "# Resurrection counts (Collapsed at L1 AND Succeeded at L2)\n",
+ "resurrected_ambig = df_rec[\n",
1288
+ " (df_rec[\"Type\"] == \"Ambiguous\") &\n",
1289
+ " (df_rec[\"L1_Collapse\"] == True) &\n",
1290
+ " (df_rec[\"L2_Success\"] == True)\n",
1291
+ "].shape[0]\n",
1292
+ "\n",
1293
+ "resurrected_unambig = df_rec[\n",
1294
+ " (df_rec[\"Type\"] == \"Unambiguous\") &\n",
1295
+ " (df_rec[\"L1_Collapse\"] == True) &\n",
1296
+ " (df_rec[\"L2_Success\"] == True)\n",
1297
+ "].shape[0]\n",
1298
+ "\n",
+ "# Recovery Rates (percentage of collapsed L=1 tokens that recovered at L=2)\n",
+ "recov_rate_ambig = (resurrected_ambig / collapse_L1_ambig * 100) if collapse_L1_ambig > 0 else 0\n",
+ "recov_rate_unambig = (resurrected_unambig / collapse_L1_unambig * 100) if collapse_L1_unambig > 0 else 0\n",
+ "\n",
+ "# --- ASCII TABLE OUTPUT ---\n",
+ "print(\"\\n\" + \"=\"*110)\n",
+ "print(\"🏥 THE RECOVERY WARD: L=1 vs L=2 COMPARISON\")\n",
+ "print(\"=\"*110)\n",
+ "print(f\"{'Word':<12} | {'L=1 Output (Collapsed)':<30} | {'➡️'} | {'L=2 Output (Recovered)':<30} | {'Status'}\")\n",
+ "print(\"-\" * 110)\n",
+ "\n",
+ "# Show examples of successful recoveries\n",
+ "recoveries = df_rec[(df_rec[\"L1_Collapse\"] == True) & (df_rec[\"L2_Success\"] == True)]\n",
1312
+ "for _, row in recoveries.head(15).iterrows():\n",
+ "    l1_short = row['L1_Output'][:28] + \"..\" if len(row['L1_Output']) > 30 else row['L1_Output']\n",
+ "    l2_short = row['L2_Output'][:28] + \"..\" if len(row['L2_Output']) > 30 else row['L2_Output']\n",
+ "    print(f\"{row['Word']:<12} | {l1_short:<30} | {'➡️'} | {l2_short:<30} | {'✅ FIXED'}\")\n",
+ "\n",
+ "print(\"=\"*110)\n",
+ "\n",
+ "# --- SUMMARY STATISTICS ---\n",
+ "print(\"\\n📊 CARRIER WAVE RECOVERY STATISTICS\")\n",
+ "print(\"-\" * 80)\n",
+ "print(f\"{'Metric':<30} | {'Ambiguous':<15} | {'Unambiguous':<15}\")\n",
+ "print(\"-\" * 80)\n",
+ "print(f\"{'Total Samples':<30} | {total_ambig:<15} | {total_unambig:<15}\")\n",
+ "print(f\"{'L=1 Collapse Rate':<30} | {collapse_L1_ambig/total_ambig*100:5.1f}%          | {collapse_L1_unambig/total_unambig*100:5.1f}%\")\n",
+ "print(f\"{'L=2 Collapse Rate':<30} | {collapse_L2_ambig/total_ambig*100:5.1f}%          | {collapse_L2_unambig/total_unambig*100:5.1f}%\")\n",
+ "print(f\"{'Resurrection Rate':<30} | {recov_rate_ambig:5.1f}%          | {recov_rate_unambig:5.1f}%\")\n",
+ "print(\"-\" * 80)\n",
+ "\n",
+ "\n",
+ "print(\"Note: the '✅ FIXED' tags above are diagnostic console output only; \"\n",
+ "      \"the formal interpretation of these results is given in the paper.\")"
+ ],
+ "metadata": {
+ "id": "9EemBN675ZLh"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }