AbstractPhil commited on
Commit
c1c10d7
·
verified ·
1 Parent(s): a225307

Upload 6 files

Browse files
Geolip_Procrustes_Bert_Model_Step_Model_Scaling.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiment_bulk (1).ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiment_bulk.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
experiment_bulk_claude_generated.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
qwen35_embedding_explorer.ipynb ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "accelerator": "GPU"
14
+ },
15
+ "cells": [
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "# Qwen3.5-0.8B Embedding Explorer\n",
21
+ "Extract all-layer embeddings, compare prompt similarity, and evaluate potential for diffusion conditioning.\n",
22
+ "\n",
23
+ "**Runtime: GPU (T4 is fine for 0.8B)**"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "metadata": {},
29
+ "source": [
30
+ "# Qwen3.5 requires transformers from git main (not yet in PyPI release)\n",
31
+ "!pip install -q \"transformers @ git+https://github.com/huggingface/transformers.git@main\"\n",
32
+ "!pip install -q accelerate torch matplotlib seaborn numpy scipy"
33
+ ],
34
+ "execution_count": null,
35
+ "outputs": []
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "metadata": {},
40
+ "source": [
41
+ "import torch\n",
42
+ "import numpy as np\n",
43
+ "import matplotlib.pyplot as plt\n",
44
+ "import seaborn as sns\n",
45
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
46
+ "from scipy.spatial.distance import cosine\n",
47
+ "from typing import Optional\n",
48
+ "import gc\n",
49
+ "\n",
50
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
51
+ "print(f'Device: {device}')\n",
52
+ "if device.type == 'cuda':\n",
53
+ " print(f'GPU: {torch.cuda.get_device_name()}')\n",
54
+ " print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
55
+ ],
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {},
62
+ "source": [
63
+ "## Load Model"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "metadata": {},
69
+ "source": [
70
+ "MODEL_ID = 'Qwen/Qwen3.5-0.8B'\n",
71
+ "\n",
72
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
73
+ "model = AutoModelForCausalLM.from_pretrained(\n",
74
+ " MODEL_ID,\n",
75
+ " torch_dtype=torch.bfloat16,\n",
76
+ " device_map='auto',\n",
77
+ " trust_remote_code=True,\n",
78
+ " output_hidden_states=True, # Critical: get all layer outputs\n",
79
+ ")\n",
80
+ "model.eval()\n",
81
+ "\n",
82
+ "num_layers = model.config.num_hidden_layers\n",
83
+ "hidden_dim = model.config.hidden_size\n",
84
+ "print(f'Layers: {num_layers}, Hidden dim: {hidden_dim}')\n",
85
+ "print(f'Total hidden states returned: {num_layers + 1} (embedding layer + {num_layers} transformer layers)')"
86
+ ],
87
+ "execution_count": null,
88
+ "outputs": []
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {},
93
+ "source": [
94
+ "## Embedding Extraction Engine\n",
95
+ "Extracts hidden states from all layers with multiple pooling strategies."
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "metadata": {},
101
+ "source": [
102
+ "class QwenEmbeddingExtractor:\n",
103
+ " \"\"\"Extract and pool hidden states from all layers of Qwen3.5-0.8B.\"\"\"\n",
104
+ "\n",
105
+ " def __init__(self, model, tokenizer, device):\n",
106
+ " self.model = model\n",
107
+ " self.tokenizer = tokenizer\n",
108
+ " self.device = device\n",
109
+ " self.num_layers = model.config.num_hidden_layers + 1 # +1 for embedding layer\n",
110
+ " self.hidden_dim = model.config.hidden_size\n",
111
+ "\n",
112
+ " @torch.no_grad()\n",
113
+ " def extract_hidden_states(self, text: str) -> dict:\n",
114
+ " \"\"\"\n",
115
+ " Run forward pass and return all hidden states + metadata.\n",
116
+ "\n",
117
+ " Returns dict with:\n",
118
+ " - hidden_states: tuple of (num_layers+1) tensors, each [1, seq_len, hidden_dim]\n",
119
+ " - input_ids: token IDs\n",
120
+ " - tokens: decoded token strings\n",
121
+ " - seq_len: number of tokens\n",
122
+ " \"\"\"\n",
123
+ " inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n",
124
+ " outputs = self.model(**inputs)\n",
125
+ "\n",
126
+ " hidden_states = outputs.hidden_states # tuple of (num_layers+1) tensors\n",
127
+ " input_ids = inputs['input_ids'][0]\n",
128
+ " tokens = [self.tokenizer.decode(tid) for tid in input_ids]\n",
129
+ "\n",
130
+ " return {\n",
131
+ " 'hidden_states': hidden_states,\n",
132
+ " 'input_ids': input_ids,\n",
133
+ " 'tokens': tokens,\n",
134
+ " 'seq_len': len(tokens),\n",
135
+ " }\n",
136
+ "\n",
137
+ " def pool_hidden_states(\n",
138
+ " self,\n",
139
+ " hidden_states: tuple,\n",
140
+ " method: str = 'mean',\n",
141
+ " layer_indices: Optional[list] = None,\n",
142
+ " ) -> torch.Tensor:\n",
143
+ " \"\"\"\n",
144
+ " Pool hidden states across tokens for specified layers.\n",
145
+ "\n",
146
+ " Args:\n",
147
+ " hidden_states: tuple from extract_hidden_states\n",
148
+ " method: 'mean', 'last_token', 'max', or 'all_tokens'\n",
149
+ " layer_indices: which layers to return (None = all)\n",
150
+ "\n",
151
+ " Returns:\n",
152
+ " For 'all_tokens': [num_layers, seq_len, hidden_dim]\n",
153
+ " Otherwise: [num_layers, hidden_dim]\n",
154
+ " \"\"\"\n",
155
+ " if layer_indices is None:\n",
156
+ " layer_indices = list(range(len(hidden_states)))\n",
157
+ "\n",
158
+ " pooled = []\n",
159
+ " for idx in layer_indices:\n",
160
+ " hs = hidden_states[idx].squeeze(0) # [seq_len, hidden_dim]\n",
161
+ "\n",
162
+ " if method == 'mean':\n",
163
+ " pooled.append(hs.mean(dim=0)) # [hidden_dim]\n",
164
+ " elif method == 'last_token':\n",
165
+ " pooled.append(hs[-1]) # [hidden_dim]\n",
166
+ " elif method == 'max':\n",
167
+ " pooled.append(hs.max(dim=0).values) # [hidden_dim]\n",
168
+ " elif method == 'all_tokens':\n",
169
+ " pooled.append(hs) # [seq_len, hidden_dim]\n",
170
+ " else:\n",
171
+ " raise ValueError(f'Unknown pooling method: {method}')\n",
172
+ "\n",
173
+ " return torch.stack(pooled)\n",
174
+ "\n",
175
+ " def extract_and_pool(self, text: str, method: str = 'mean') -> dict:\n",
176
+ " \"\"\"\n",
177
+ " Convenience: extract + pool in one call.\n",
178
+ "\n",
179
+ " Returns dict with:\n",
180
+ " - embeddings: [num_layers, hidden_dim] (or [num_layers, seq_len, hidden_dim] for all_tokens)\n",
181
+ " - tokens: list of token strings\n",
182
+ " - seq_len: int\n",
183
+ " \"\"\"\n",
184
+ " data = self.extract_hidden_states(text)\n",
185
+ " embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n",
186
+ " return {\n",
187
+ " 'embeddings': embeddings,\n",
188
+ " 'tokens': data['tokens'],\n",
189
+ " 'seq_len': data['seq_len'],\n",
190
+ " }\n",
191
+ "\n",
192
+ "extractor = QwenEmbeddingExtractor(model, tokenizer, device)\n",
193
+ "print(f'Extractor ready. Will return {extractor.num_layers} layer embeddings per prompt.')"
194
+ ],
195
+ "execution_count": null,
196
+ "outputs": []
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "metadata": {},
201
+ "source": [
202
+ "## Define Test Prompts\n",
203
+ "Edit these to whatever you want to compare. Grouped by semantic category to see clustering behavior."
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "metadata": {},
209
+ "source": [
210
+ "# ---- EDIT THESE ----\n",
211
+ "# Groups help visualize clustering. Flat list is also fine.\n",
212
+ "PROMPT_GROUPS = {\n",
213
+ " 'photorealistic': [\n",
214
+ " 'a photograph of a cat sitting on a windowsill in golden hour light',\n",
215
+ " 'professional photo of a mountain landscape at sunset with dramatic clouds',\n",
216
+ " 'close-up portrait of an elderly man with weathered skin and blue eyes',\n",
217
+ " ],\n",
218
+ " 'artistic': [\n",
219
+ " 'an oil painting of a stormy sea in the style of Turner',\n",
220
+ " 'watercolor illustration of a quiet Japanese garden with cherry blossoms',\n",
221
+ " 'abstract geometric composition with overlapping translucent shapes',\n",
222
+ " ],\n",
223
+ " 'semantic_shift': [\n",
224
+ " 'a red cube on a blue floor',\n",
225
+ " 'a blue cube on a red floor',\n",
226
+ " 'a green sphere floating above a white plane',\n",
227
+ " ],\n",
228
+ " 'edge_cases': [\n",
229
+ " 'darkness',\n",
230
+ " '', # empty string baseline\n",
231
+ " 'asdfghjkl random noise tokens xyzzy',\n",
232
+ " ],\n",
233
+ "}\n",
234
+ "\n",
235
+ "# Flatten for processing\n",
236
+ "prompts = []\n",
237
+ "prompt_labels = []\n",
238
+ "prompt_groups = []\n",
239
+ "for group_name, group_prompts in PROMPT_GROUPS.items():\n",
240
+ " for p in group_prompts:\n",
241
+ " prompts.append(p)\n",
242
+ " label = p[:50] + '...' if len(p) > 50 else p\n",
243
+ " label = label if label else '<empty>'\n",
244
+ " prompt_labels.append(label)\n",
245
+ " prompt_groups.append(group_name)\n",
246
+ "\n",
247
+ "print(f'{len(prompts)} prompts across {len(PROMPT_GROUPS)} groups')"
248
+ ],
249
+ "execution_count": null,
250
+ "outputs": []
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "metadata": {},
255
+ "source": [
256
+ "## Extract All Embeddings"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "metadata": {},
262
+ "source": [
263
+ "POOL_METHODS = ['mean', 'last_token']\n",
264
+ "\n",
265
+ "# Store results: {method: {prompt_idx: [num_layers, hidden_dim]}}\n",
266
+ "all_embeddings = {method: {} for method in POOL_METHODS}\n",
267
+ "token_counts = {}\n",
268
+ "\n",
269
+ "for i, prompt in enumerate(prompts):\n",
270
+ " print(f'[{i+1}/{len(prompts)}] ({len(prompt)} chars) \"{prompt_labels[i]}\"')\n",
271
+ " for method in POOL_METHODS:\n",
272
+ " result = extractor.extract_and_pool(prompt, method=method)\n",
273
+ " all_embeddings[method][i] = result['embeddings'].float().cpu() # [num_layers, hidden_dim]\n",
274
+ " if method == POOL_METHODS[0]:\n",
275
+ " token_counts[i] = result['seq_len']\n",
276
+ "\n",
277
+ "print(f'\\nDone. Shape per prompt per method: {all_embeddings[\"mean\"][0].shape}')\n",
278
+ "print(f'Token counts: {list(token_counts.values())}')"
279
+ ],
280
+ "execution_count": null,
281
+ "outputs": []
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "metadata": {},
286
+ "source": [
287
+ "## Cosine Similarity Analysis\n",
288
+ "Compute pairwise similarity at every layer, for each pooling method."
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "metadata": {},
294
+ "source": [
295
+ "def compute_pairwise_cosine(embeddings_dict, num_prompts, num_layers):\n",
296
+ " \"\"\"\n",
297
+ " Compute cosine similarity between all prompt pairs at each layer.\n",
298
+ "\n",
299
+ " Returns: [num_layers, num_prompts, num_prompts] numpy array\n",
300
+ " \"\"\"\n",
301
+ " sim_matrix = np.zeros((num_layers, num_prompts, num_prompts))\n",
302
+ "\n",
303
+ " for layer_idx in range(num_layers):\n",
304
+ " for i in range(num_prompts):\n",
305
+ " for j in range(num_prompts):\n",
306
+ " if i == j:\n",
307
+ " sim_matrix[layer_idx, i, j] = 1.0\n",
308
+ " elif j > i:\n",
309
+ " vec_i = embeddings_dict[i][layer_idx].numpy()\n",
310
+ " vec_j = embeddings_dict[j][layer_idx].numpy()\n",
311
+ " sim = 1.0 - cosine(vec_i, vec_j)\n",
312
+ " sim_matrix[layer_idx, i, j] = sim\n",
313
+ " sim_matrix[layer_idx, j, i] = sim\n",
314
+ "\n",
315
+ " return sim_matrix\n",
316
+ "\n",
317
+ "n_prompts = len(prompts)\n",
318
+ "n_layers = extractor.num_layers\n",
319
+ "\n",
320
+ "sim_matrices = {}\n",
321
+ "for method in POOL_METHODS:\n",
322
+ " sim_matrices[method] = compute_pairwise_cosine(\n",
323
+ " all_embeddings[method], n_prompts, n_layers\n",
324
+ " )\n",
325
+ " print(f'{method}: similarity matrix shape = {sim_matrices[method].shape}')"
326
+ ],
327
+ "execution_count": null,
328
+ "outputs": []
329
+ },
330
+ {
331
+ "cell_type": "markdown",
332
+ "metadata": {},
333
+ "source": [
334
+ "## Heatmaps: Per-Layer Similarity\n",
335
+ "Shows how prompt-pair similarity evolves across layers."
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "metadata": {},
341
+ "source": [
342
+ "def plot_similarity_heatmaps(sim_matrix, labels, method_name, layers_to_show=None):\n",
343
+ " \"\"\"\n",
344
+ " Plot similarity heatmaps for selected layers.\n",
345
+ " If layers_to_show is None, picks: first, 25%, 50%, 75%, last.\n",
346
+ " \"\"\"\n",
347
+ " n_layers = sim_matrix.shape[0]\n",
348
+ "\n",
349
+ " if layers_to_show is None:\n",
350
+ " layers_to_show = sorted(set([\n",
351
+ " 0,\n",
352
+ " n_layers // 4,\n",
353
+ " n_layers // 2,\n",
354
+ " 3 * n_layers // 4,\n",
355
+ " n_layers - 2, # penultimate\n",
356
+ " n_layers - 1, # final\n",
357
+ " ]))\n",
358
+ "\n",
359
+ " n_plots = len(layers_to_show)\n",
360
+ " fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5))\n",
361
+ " if n_plots == 1:\n",
362
+ " axes = [axes]\n",
363
+ "\n",
364
+ " for ax, layer_idx in zip(axes, layers_to_show):\n",
365
+ " layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n",
366
+ " sns.heatmap(\n",
367
+ " sim_matrix[layer_idx],\n",
368
+ " xticklabels=labels, yticklabels=labels,\n",
369
+ " vmin=-0.2, vmax=1.0,\n",
370
+ " cmap='RdYlBu_r', annot=True, fmt='.2f',\n",
371
+ " ax=ax, square=True,\n",
372
+ " cbar_kws={'shrink': 0.6},\n",
373
+ " )\n",
374
+ " ax.set_title(f'{method_name} | {layer_name}', fontsize=11)\n",
375
+ " ax.tick_params(axis='x', rotation=45)\n",
376
+ " ax.tick_params(axis='y', rotation=0)\n",
377
+ "\n",
378
+ " plt.tight_layout()\n",
379
+ " plt.show()\n",
380
+ "\n",
381
+ "# Short labels for readability\n",
382
+ "short_labels = [l[:30] for l in prompt_labels]\n",
383
+ "\n",
384
+ "for method in POOL_METHODS:\n",
385
+ " print(f'\\n=== {method.upper()} POOLING ===')\n",
386
+ " plot_similarity_heatmaps(sim_matrices[method], short_labels, method)"
387
+ ],
388
+ "execution_count": null,
389
+ "outputs": []
390
+ },
391
+ {
392
+ "cell_type": "markdown",
393
+ "metadata": {},
394
+ "source": [
395
+ "## Layer-wise Discriminability\n",
396
+ "For each layer, compute average within-group similarity vs. between-group similarity.\n",
397
+ "Higher gap = better semantic clustering = more useful for conditioning."
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "metadata": {},
403
+ "source": [
404
+ "def compute_discriminability(sim_matrix, group_labels):\n",
405
+ " \"\"\"\n",
406
+ " Per-layer: avg within-group sim, avg between-group sim, and gap.\n",
407
+ " Returns arrays of shape [num_layers].\n",
408
+ " \"\"\"\n",
409
+ " n_layers = sim_matrix.shape[0]\n",
410
+ " n = sim_matrix.shape[1]\n",
411
+ " unique_groups = list(set(group_labels))\n",
412
+ "\n",
413
+ " within_sim = np.zeros(n_layers)\n",
414
+ " between_sim = np.zeros(n_layers)\n",
415
+ "\n",
416
+ " for layer in range(n_layers):\n",
417
+ " w_vals, b_vals = [], []\n",
418
+ " for i in range(n):\n",
419
+ " for j in range(i + 1, n):\n",
420
+ " val = sim_matrix[layer, i, j]\n",
421
+ " if group_labels[i] == group_labels[j]:\n",
422
+ " w_vals.append(val)\n",
423
+ " else:\n",
424
+ " b_vals.append(val)\n",
425
+ " within_sim[layer] = np.mean(w_vals) if w_vals else 0\n",
426
+ " between_sim[layer] = np.mean(b_vals) if b_vals else 0\n",
427
+ "\n",
428
+ " return within_sim, between_sim, within_sim - between_sim\n",
429
+ "\n",
430
+ "\n",
431
+ "fig, axes = plt.subplots(1, len(POOL_METHODS), figsize=(8 * len(POOL_METHODS), 5))\n",
432
+ "if len(POOL_METHODS) == 1:\n",
433
+ " axes = [axes]\n",
434
+ "\n",
435
+ "best_layers = {}\n",
436
+ "for ax, method in zip(axes, POOL_METHODS):\n",
437
+ " within, between, gap = compute_discriminability(sim_matrices[method], prompt_groups)\n",
438
+ "\n",
439
+ " layer_x = np.arange(n_layers)\n",
440
+ " ax.plot(layer_x, within, label='Within-group sim', color='#2196F3', linewidth=2)\n",
441
+ " ax.plot(layer_x, between, label='Between-group sim', color='#FF5722', linewidth=2)\n",
442
+ " ax.fill_between(layer_x, between, within, alpha=0.15, color='green')\n",
443
+ " ax.plot(layer_x, gap, label='Gap (discriminability)', color='green', linewidth=2, linestyle='--')\n",
444
+ "\n",
445
+ " best_layer = np.argmax(gap)\n",
446
+ " best_layers[method] = best_layer\n",
447
+ " ax.axvline(best_layer, color='green', linestyle=':', alpha=0.5)\n",
448
+ " ax.annotate(f'Best: L{best_layer}', xy=(best_layer, gap[best_layer]),\n",
449
+ " xytext=(best_layer + 1, gap[best_layer] + 0.02),\n",
450
+ " arrowprops=dict(arrowstyle='->', color='green'),\n",
451
+ " fontsize=10, color='green')\n",
452
+ "\n",
453
+ " ax.set_xlabel('Layer Index')\n",
454
+ " ax.set_ylabel('Cosine Similarity')\n",
455
+ " ax.set_title(f'{method} pooling — Semantic Discriminability')\n",
456
+ " ax.legend()\n",
457
+ " ax.grid(True, alpha=0.3)\n",
458
+ "\n",
459
+ "plt.tight_layout()\n",
460
+ "plt.show()\n",
461
+ "\n",
462
+ "print('\\nBest discriminability layers:')\n",
463
+ "for method, layer in best_layers.items():\n",
464
+ " print(f' {method}: layer {layer}')"
465
+ ],
466
+ "execution_count": null,
467
+ "outputs": []
468
+ },
469
+ {
470
+ "cell_type": "markdown",
471
+ "metadata": {},
472
+ "source": [
473
+ "## Embedding Norm & Variance Across Layers\n",
474
+ "Checks for collapse (all norms converging) or explosion — both bad for conditioning."
475
+ ]
476
+ },
477
+ {
478
+ "cell_type": "code",
479
+ "metadata": {},
480
+ "source": [
481
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
482
+ "\n",
483
+ "# Norms per layer per prompt\n",
484
+ "for method, ax in zip(POOL_METHODS, axes):\n",
485
+ " for i in range(n_prompts):\n",
486
+ " norms = all_embeddings[method][i].norm(dim=-1).numpy() # [num_layers]\n",
487
+ " ax.plot(range(n_layers), norms, alpha=0.6, label=short_labels[i][:20])\n",
488
+ "\n",
489
+ " ax.set_xlabel('Layer')\n",
490
+ " ax.set_ylabel('L2 Norm')\n",
491
+ " ax.set_title(f'{method} pooling — Embedding Norms')\n",
492
+ " ax.grid(True, alpha=0.3)\n",
493
+ " ax.legend(fontsize=7, loc='upper left')\n",
494
+ "\n",
495
+ "plt.tight_layout()\n",
496
+ "plt.show()"
497
+ ],
498
+ "execution_count": null,
499
+ "outputs": []
500
+ },
501
+ {
502
+ "cell_type": "markdown",
503
+ "metadata": {},
504
+ "source": [
505
+ "## Effective Dimensionality per Layer\n",
506
+ "How many dimensions are actually being used? Low rank = bad for diffusion conditioning diversity."
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "code",
511
+ "metadata": {},
512
+ "source": [
513
+ "def effective_dimensionality(embeddings_list):\n",
514
+ " \"\"\"\n",
515
+ " Compute effective dimensionality via participation ratio of singular values.\n",
516
+ " embeddings_list: list of [hidden_dim] vectors\n",
517
+ " Returns: float (effective rank)\n",
518
+ " \"\"\"\n",
519
+ " mat = torch.stack(embeddings_list) # [n_prompts, hidden_dim]\n",
520
+ " mat = mat - mat.mean(dim=0) # center\n",
521
+ " _, S, _ = torch.svd(mat)\n",
522
+ " S = S / S.sum()\n",
523
+ " participation_ratio = 1.0 / (S ** 2).sum().item()\n",
524
+ " return participation_ratio\n",
525
+ "\n",
526
+ "\n",
527
+ "for method in POOL_METHODS:\n",
528
+ " eff_dims = []\n",
529
+ " for layer_idx in range(n_layers):\n",
530
+ " layer_vecs = [all_embeddings[method][i][layer_idx] for i in range(n_prompts)]\n",
531
+ " ed = effective_dimensionality(layer_vecs)\n",
532
+ " eff_dims.append(ed)\n",
533
+ "\n",
534
+ " plt.plot(range(n_layers), eff_dims, label=method, linewidth=2)\n",
535
+ "\n",
536
+ "plt.xlabel('Layer')\n",
537
+ "plt.ylabel('Effective Dimensionality (participation ratio)')\n",
538
+ "plt.title('Effective Rank of Embedding Space per Layer')\n",
539
+ "plt.legend()\n",
540
+ "plt.grid(True, alpha=0.3)\n",
541
+ "plt.tight_layout()\n",
542
+ "plt.show()"
543
+ ],
544
+ "execution_count": null,
545
+ "outputs": []
546
+ },
547
+ {
548
+ "cell_type": "markdown",
549
+ "metadata": {},
550
+ "source": [
551
+ "## Quick Diffusion Conditioning Assessment\n",
552
+ "Summary: which layers look most promising as conditioning vectors?"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "metadata": {},
558
+ "source": [
559
+ "print('=' * 70)\n",
560
+ "print('DIFFUSION CONDITIONING VIABILITY SUMMARY')\n",
561
+ "print('=' * 70)\n",
562
+ "print(f'\\nModel: {MODEL_ID}')\n",
563
+ "print(f'Layers: {n_layers} (0=input embeddings, rest=transformer layers)')\n",
564
+ "print(f'Hidden dim: {hidden_dim}')\n",
565
+ "print(f'Prompts tested: {n_prompts}')\n",
566
+ "print()\n",
567
+ "\n",
568
+ "for method in POOL_METHODS:\n",
569
+ " within, between, gap = compute_discriminability(sim_matrices[method], prompt_groups)\n",
570
+ " best_l = np.argmax(gap)\n",
571
+ " worst_l = np.argmin(gap)\n",
572
+ "\n",
573
+ " # Check for near-collapse: if all pairwise sims > 0.95 at any layer\n",
574
+ " collapse_layers = []\n",
575
+ " for l in range(n_layers):\n",
576
+ " off_diag = sim_matrices[method][l][np.triu_indices(n_prompts, k=1)]\n",
577
+ " if off_diag.min() > 0.95:\n",
578
+ " collapse_layers.append(l)\n",
579
+ "\n",
580
+ " print(f'--- {method.upper()} POOLING ---')\n",
581
+ " print(f' Best discriminability: Layer {best_l} (gap = {gap[best_l]:.4f})')\n",
582
+ " print(f' Worst discriminability: Layer {worst_l} (gap = {gap[worst_l]:.4f})')\n",
583
+ " print(f' Penultimate layer gap: {gap[-2]:.4f}')\n",
584
+ " print(f' Final layer gap: {gap[-1]:.4f}')\n",
585
+ " if collapse_layers:\n",
586
+ " print(f' WARNING: Near-collapse at layers: {collapse_layers}')\n",
587
+ " else:\n",
588
+ " print(f' No collapse detected (all layers have some discrimination)')\n",
589
+ " print()\n",
590
+ "\n",
591
+ "print('RECOMMENDATIONS:')\n",
592
+ "print(' For POOLED conditioning (global vector): Use the best discriminability layer.')\n",
593
+ "print(' For TOKEN-LEVEL conditioning (cross-attention): Re-run with method=\"all_tokens\"')\n",
594
+ "print(' and compare token-level structure against T5/CLIP token outputs.')\n",
595
+ "print(' Watch for: norm explosion in later layers (may need LayerNorm before conditioning).')\n",
596
+ "print(' The penultimate layer often outperforms the final layer (CLIP effect).')"
597
+ ],
598
+ "execution_count": null,
599
+ "outputs": []
600
+ },
601
+ {
602
+ "cell_type": "markdown",
603
+ "metadata": {},
604
+ "source": [
605
+ "## (Optional) Export Embeddings for Further Analysis\n",
606
+ "Save to disk for loading into your geometric pipeline."
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "metadata": {},
612
+ "source": [
613
+ "# Uncomment to save\n",
614
+ "# export = {\n",
615
+ "# 'model_id': MODEL_ID,\n",
616
+ "# 'prompts': prompts,\n",
617
+ "# 'prompt_groups': prompt_groups,\n",
618
+ "# 'pool_methods': POOL_METHODS,\n",
619
+ "# 'embeddings': {m: {i: all_embeddings[m][i] for i in range(n_prompts)} for m in POOL_METHODS},\n",
620
+ "# 'sim_matrices': sim_matrices,\n",
621
+ "# 'num_layers': n_layers,\n",
622
+ "# 'hidden_dim': hidden_dim,\n",
623
+ "# }\n",
624
+ "# torch.save(export, 'qwen35_0.8b_embeddings.pt')\n",
625
+ "# print('Saved to qwen35_0.8b_embeddings.pt')"
626
+ ],
627
+ "execution_count": null,
628
+ "outputs": []
629
+ }
630
+ ]
631
+ }
qwen35_twoshot_embedding_explorer.ipynb ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "accelerator": "GPU"
14
+ },
15
+ "cells": [
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "# Qwen3.5-0.8B Two-Shot Embedding Explorer\n",
21
+ "Generate descriptions via two-shot prompting, then re-encode the output to extract embeddings with actual semantic diversity.\n",
22
+ "\n",
23
+ "**Runtime: GPU (T4)**"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "metadata": {},
29
+ "source": [
30
+ "# Qwen3.5 requires transformers from git main\n",
31
+ "!pip install -q \"transformers @ git+https://github.com/huggingface/transformers.git@main\"\n",
32
+ "!pip install -q accelerate torch matplotlib seaborn numpy scipy"
33
+ ],
34
+ "execution_count": null,
35
+ "outputs": []
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "metadata": {},
40
+ "source": [
41
+ "import torch\n",
42
+ "import torch.nn.functional as F\n",
43
+ "import numpy as np\n",
44
+ "import matplotlib.pyplot as plt\n",
45
+ "import seaborn as sns\n",
46
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
47
+ "from scipy.spatial.distance import cosine\n",
48
+ "from typing import Optional\n",
49
+ "import gc\n",
50
+ "\n",
51
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
52
+ "print(f'Device: {device}')\n",
53
+ "if device.type == 'cuda':\n",
54
+ " print(f'GPU: {torch.cuda.get_device_name()}')\n",
55
+ " print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
56
+ ],
57
+ "execution_count": null,
58
+ "outputs": []
59
+ },
60
+ {
61
+ "cell_type": "markdown",
62
+ "metadata": {},
63
+ "source": [
64
+ "## Load Model"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "metadata": {},
70
+ "source": [
71
+ "MODEL_ID = 'Qwen/Qwen3.5-0.8B'\n",
72
+ "\n",
73
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n",
74
+ "model = AutoModelForCausalLM.from_pretrained(\n",
75
+ " MODEL_ID,\n",
76
+ " torch_dtype=torch.bfloat16,\n",
77
+ " device_map='auto',\n",
78
+ " trust_remote_code=True,\n",
79
+ ")\n",
80
+ "model.eval()\n",
81
+ "\n",
82
+ "num_layers = model.config.num_hidden_layers\n",
83
+ "hidden_dim = model.config.hidden_size\n",
84
+ "print(f'Layers: {num_layers}, Hidden dim: {hidden_dim}')\n",
85
+ "print(f'Total hidden states: {num_layers + 1}')"
86
+ ],
87
+ "execution_count": null,
88
+ "outputs": []
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {},
93
+ "source": [
94
+ "## Two-Shot Generation + Re-Encode Pipeline\n",
95
+ "1. Build a two-shot chat prompt with examples\n",
96
+ "2. Generate a description\n",
97
+ "3. Re-encode the generated text (not the prompt) and extract all hidden states"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "metadata": {},
103
+ "source": [
104
+ "class TwoShotEmbeddingExtractor:\n",
105
+ " \"\"\"\n",
106
+ " Two-shot generation -> re-encode pipeline.\n",
107
+ " \n",
108
+ " Step 1: Chat-template two-shot prompt -> generate description\n",
109
+ " Step 2: Encode the GENERATED text alone -> extract hidden states\n",
110
+ " \n",
111
+ " This produces embeddings of the model's own description,\n",
112
+ " which has far more semantic diversity than raw prompt encoding.\n",
113
+ " \"\"\"\n",
114
+ "\n",
115
+ " def __init__(self, model, tokenizer, device, min_tokens=2):\n",
116
+ " self.model = model\n",
117
+ " self.tokenizer = tokenizer\n",
118
+ " self.device = device\n",
119
+ " self.min_tokens = min_tokens\n",
120
+ " self.num_layers = model.config.num_hidden_layers + 1\n",
121
+ " self.hidden_dim = model.config.hidden_size\n",
122
+ "\n",
123
+ " def build_twoshot_prompt(self, subject: str) -> str:\n",
124
+ " \"\"\"Build two-shot chat prompt with visual description examples.\"\"\"\n",
125
+ " messages = [\n",
126
+ " {\n",
127
+ " 'role': 'system',\n",
128
+ " 'content': 'You describe scenes and subjects in exactly one sentence. '\n",
129
+ " 'Be specific about visual features, lighting, colors, and composition.'\n",
130
+ " },\n",
131
+ " {\n",
132
+ " 'role': 'user',\n",
133
+ " 'content': 'Describe: a car on a highway'\n",
134
+ " },\n",
135
+ " {\n",
136
+ " 'role': 'assistant',\n",
137
+ " 'content': 'A silver sedan cruises along a sunlit four-lane highway '\n",
138
+ " 'cutting through rolling green hills under a pale blue sky with wispy cirrus clouds.'\n",
139
+ " },\n",
140
+ " {\n",
141
+ " 'role': 'user',\n",
142
+ " 'content': 'Describe: a sunflower field'\n",
143
+ " },\n",
144
+ " {\n",
145
+ " 'role': 'assistant',\n",
146
+ " 'content': 'Thousands of tall sunflowers with bright yellow petals and dark brown centers '\n",
147
+ " 'stand in dense rows across a flat field stretching to the horizon at golden hour.'\n",
148
+ " },\n",
149
+ " {\n",
150
+ " 'role': 'user',\n",
151
+ " 'content': f'Describe: {subject}'\n",
152
+ " },\n",
153
+ " ]\n",
154
+ " return self.tokenizer.apply_chat_template(\n",
155
+ " messages, tokenize=False, add_generation_prompt=True\n",
156
+ " )\n",
157
+ "\n",
158
+ " @torch.no_grad()\n",
159
+ " def generate_description(self, subject: str, max_new_tokens=80) -> str:\n",
160
+ " \"\"\"Generate a one-sentence visual description via two-shot.\"\"\"\n",
161
+ " prompt = self.build_twoshot_prompt(subject)\n",
162
+ " inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)\n",
163
+ "\n",
164
+ " output_ids = self.model.generate(\n",
165
+ " **inputs,\n",
166
+ " max_new_tokens=max_new_tokens,\n",
167
+ " do_sample=True,\n",
168
+ " temperature=0.7,\n",
169
+ " top_p=0.9,\n",
170
+ " pad_token_id=self.tokenizer.eos_token_id,\n",
171
+ " )\n",
172
+ "\n",
173
+ " # Decode only the new tokens\n",
174
+ " new_tokens = output_ids[0][inputs['input_ids'].shape[1]:]\n",
175
+ " description = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n",
176
+ " return description\n",
177
+ "\n",
178
+ " @torch.no_grad()\n",
179
+ " def encode_text(self, text: str) -> dict:\n",
180
+ " \"\"\"\n",
181
+ " Encode text and return all hidden states.\n",
182
+ " Pads ultra-short inputs to avoid conv1d crash in DeltaNet layers.\n",
183
+ " \"\"\"\n",
184
+ " inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n",
185
+ " seq_len = inputs['input_ids'].shape[1]\n",
186
+ "\n",
187
+ " if seq_len < self.min_tokens:\n",
188
+ " text = text + ' . .'\n",
189
+ " inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n",
190
+ " seq_len = inputs['input_ids'].shape[1]\n",
191
+ "\n",
192
+ " outputs = self.model(**inputs, output_hidden_states=True)\n",
193
+ "\n",
194
+ " hidden_states = outputs.hidden_states\n",
195
+ " if hidden_states is None:\n",
196
+ " raise RuntimeError('Model returned None for hidden_states.')\n",
197
+ "\n",
198
+ " input_ids = inputs['input_ids'][0]\n",
199
+ " tokens = [self.tokenizer.decode(tid) for tid in input_ids]\n",
200
+ "\n",
201
+ " return {\n",
202
+ " 'hidden_states': hidden_states,\n",
203
+ " 'input_ids': input_ids,\n",
204
+ " 'tokens': tokens,\n",
205
+ " 'seq_len': len(tokens),\n",
206
+ " }\n",
207
+ "\n",
208
+ " def pool_hidden_states(self, hidden_states, method='mean'):\n",
209
+ " \"\"\"Pool across tokens for all layers. Returns [num_layers, hidden_dim].\"\"\"\n",
210
+ " pooled = []\n",
211
+ " for hs in hidden_states:\n",
212
+ " hs = hs.squeeze(0) # [seq_len, hidden_dim]\n",
213
+ " if method == 'mean':\n",
214
+ " pooled.append(hs.mean(dim=0))\n",
215
+ " elif method == 'last_token':\n",
216
+ " pooled.append(hs[-1])\n",
217
+ " elif method == 'max':\n",
218
+ " pooled.append(hs.max(dim=0).values)\n",
219
+ " else:\n",
220
+ " raise ValueError(f'Unknown method: {method}')\n",
221
+ " return torch.stack(pooled)\n",
222
+ "\n",
223
+ " def generate_and_encode(self, subject: str, method='mean') -> dict:\n",
224
+ " \"\"\"\n",
225
+ " Full pipeline: generate description, then re-encode it.\n",
226
+ " Returns embeddings of the GENERATED text, not the prompt.\n",
227
+ " \"\"\"\n",
228
+ " description = self.generate_description(subject)\n",
229
+ " data = self.encode_text(description)\n",
230
+ " embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n",
231
+ " return {\n",
232
+ " 'embeddings': embeddings,\n",
233
+ " 'description': description,\n",
234
+ " 'tokens': data['tokens'],\n",
235
+ " 'seq_len': data['seq_len'],\n",
236
+ " }\n",
237
+ "\n",
238
+ " def encode_raw(self, text: str, method='mean') -> dict:\n",
239
+ " \"\"\"\n",
240
+ " Direct encode (no generation). For comparison baseline.\n",
241
+ " \"\"\"\n",
242
+ " data = self.encode_text(text)\n",
243
+ " embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n",
244
+ " return {\n",
245
+ " 'embeddings': embeddings,\n",
246
+ " 'description': text,\n",
247
+ " 'tokens': data['tokens'],\n",
248
+ " 'seq_len': data['seq_len'],\n",
249
+ " }\n",
250
+ "\n",
251
+ "\n",
252
+ "extractor = TwoShotEmbeddingExtractor(model, tokenizer, device)\n",
253
+ "print(f'Extractor ready. {extractor.num_layers} layers, {extractor.hidden_dim}d')"
254
+ ],
255
+ "execution_count": null,
256
+ "outputs": []
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "metadata": {},
261
+ "source": [
262
+ "## Test Generation\n",
263
+ "Quick sanity check that the two-shot pipeline produces good descriptions."
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "metadata": {},
269
+ "source": [
270
+ "test_subjects = [\n",
271
+ " 'a cat on a windowsill',\n",
272
+ " 'a red cube on a blue floor',\n",
273
+ " 'an oil painting of a stormy sea',\n",
274
+ " 'darkness',\n",
275
+ " 'cheese',\n",
276
+ "]\n",
277
+ "\n",
278
+ "print('Two-shot generation test:')\n",
279
+ "print('=' * 70)\n",
280
+ "for subject in test_subjects:\n",
281
+ " desc = extractor.generate_description(subject)\n",
282
+ " print(f'\\nSubject: {subject}')\n",
283
+ " print(f'Generated: {desc}')"
284
+ ],
285
+ "execution_count": null,
286
+ "outputs": []
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "metadata": {},
291
+ "source": [
292
+ "## Define Test Subjects\n",
293
+ "Same groups as before. Each gets a two-shot generated description + raw encode for comparison."
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "metadata": {},
299
+ "source": [
300
+ "SUBJECT_GROUPS = {\n",
301
+ " 'photorealistic': [\n",
302
+ " 'a cat sitting on a windowsill in golden hour light',\n",
303
+ " 'a mountain landscape at sunset with dramatic clouds',\n",
304
+ " 'an elderly man with weathered skin and blue eyes',\n",
305
+ " ],\n",
306
+ " 'artistic': [\n",
307
+ " 'an oil painting of a stormy sea',\n",
308
+ " 'a quiet Japanese garden with cherry blossoms',\n",
309
+ " 'abstract geometric shapes overlapping',\n",
310
+ " ],\n",
311
+ " 'semantic_shift': [\n",
312
+ " 'a red cube on a blue floor',\n",
313
+ " 'a blue cube on a red floor',\n",
314
+ " 'a green sphere floating above a white plane',\n",
315
+ " ],\n",
316
+ " 'gibberish': [\n",
317
+ " 'mxkrl vvtonp qazhif bwsdee lpoqnr yttmz',\n",
318
+ " 'florpnax grindleby shovantic wumblecrax tazzifer',\n",
319
+ " 'aaaa bbbb cccc dddd eeee ffff gggg hhhh',\n",
320
+ " ],\n",
321
+ " 'short': [\n",
322
+ " 'taco',\n",
323
+ " '1girl',\n",
324
+ " 'cheese',\n",
325
+ " 'cheddar bacon sub',\n",
326
+ " ],\n",
327
+ "}\n",
328
+ "\n",
329
+ "subjects = []\n",
330
+ "subject_labels = []\n",
331
+ "subject_groups = []\n",
332
+ "for group_name, group_items in SUBJECT_GROUPS.items():\n",
333
+ " for s in group_items:\n",
334
+ " subjects.append(s)\n",
335
+ " label = s[:40] + '...' if len(s) > 40 else s\n",
336
+ " subject_labels.append(label)\n",
337
+ " subject_groups.append(group_name)\n",
338
+ "\n",
339
+ "print(f'{len(subjects)} subjects across {len(SUBJECT_GROUPS)} groups')"
340
+ ],
341
+ "execution_count": null,
342
+ "outputs": []
343
+ },
344
+ {
345
+ "cell_type": "markdown",
346
+ "metadata": {},
347
+ "source": [
348
+ "## Generate Descriptions + Extract Embeddings\n",
349
+ "Both two-shot (generated) and raw (direct encode) for comparison."
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "metadata": {},
355
+ "source": [
356
+ "POOL_METHODS = ['mean', 'last_token']\n",
357
+ "\n",
358
+ "# Two-shot generated embeddings\n",
359
+ "twoshot_embeddings = {method: {} for method in POOL_METHODS}\n",
360
+ "twoshot_descriptions = {}\n",
361
+ "twoshot_token_counts = {}\n",
362
+ "\n",
363
+ "# Raw direct-encode embeddings (baseline)\n",
364
+ "raw_embeddings = {method: {} for method in POOL_METHODS}\n",
365
+ "raw_token_counts = {}\n",
366
+ "\n",
367
+ "print('=== TWO-SHOT GENERATION + ENCODE ===')\n",
368
+ "print('=' * 70)\n",
369
+ "for i, subject in enumerate(subjects):\n",
370
+ " print(f'\\n[{i+1}/{len(subjects)}] \"{subject_labels[i]}\"')\n",
371
+ " for method in POOL_METHODS:\n",
372
+ " result = extractor.generate_and_encode(subject, method=method)\n",
373
+ " twoshot_embeddings[method][i] = result['embeddings'].float().cpu()\n",
374
+ " if method == POOL_METHODS[0]:\n",
375
+ " twoshot_descriptions[i] = result['description']\n",
376
+ " twoshot_token_counts[i] = result['seq_len']\n",
377
+ " print(f' -> \"{result[\"description\"][:80]}...\"' if len(result['description']) > 80 else f' -> \"{result[\"description\"]}\"')\n",
378
+ "\n",
379
+ "print('\\n\\n=== RAW DIRECT ENCODE (BASELINE) ===')\n",
380
+ "print('=' * 70)\n",
381
+ "for i, subject in enumerate(subjects):\n",
382
+ " print(f'[{i+1}/{len(subjects)}] \"{subject_labels[i]}\"')\n",
383
+ " for method in POOL_METHODS:\n",
384
+ " result = extractor.encode_raw(subject, method=method)\n",
385
+ " raw_embeddings[method][i] = result['embeddings'].float().cpu()\n",
386
+ " if method == POOL_METHODS[0]:\n",
387
+ " raw_token_counts[i] = result['seq_len']\n",
388
+ "\n",
389
+ "n_subjects = len(subjects)\n",
390
+ "n_layers = extractor.num_layers\n",
391
+ "\n",
392
+ "print(f'\\nDone. {n_subjects} subjects, {n_layers} layers, {extractor.hidden_dim}d')\n",
393
+ "print(f'Two-shot token counts: {list(twoshot_token_counts.values())}')\n",
394
+ "print(f'Raw token counts: {list(raw_token_counts.values())}')"
395
+ ],
396
+ "execution_count": null,
397
+ "outputs": []
398
+ },
399
+ {
400
+ "cell_type": "markdown",
401
+ "metadata": {},
402
+ "source": [
403
+ "## Cosine Similarity Matrices"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "metadata": {},
409
+ "source": [
410
+ "def compute_pairwise_cosine(embeddings_dict, num_prompts, num_layers):\n",
411
+ " sim_matrix = np.zeros((num_layers, num_prompts, num_prompts))\n",
412
+ " for layer_idx in range(num_layers):\n",
413
+ " for i in range(num_prompts):\n",
414
+ " for j in range(num_prompts):\n",
415
+ " if i == j:\n",
416
+ " sim_matrix[layer_idx, i, j] = 1.0\n",
417
+ " elif j > i:\n",
418
+ " vec_i = embeddings_dict[i][layer_idx].numpy()\n",
419
+ " vec_j = embeddings_dict[j][layer_idx].numpy()\n",
420
+ " sim = 1.0 - cosine(vec_i, vec_j)\n",
421
+ " sim_matrix[layer_idx, i, j] = sim\n",
422
+ " sim_matrix[layer_idx, j, i] = sim\n",
423
+ " return sim_matrix\n",
424
+ "\n",
425
+ "# Compute for both pipelines, both pooling methods\n",
426
+ "twoshot_sim = {}\n",
427
+ "raw_sim = {}\n",
428
+ "for method in POOL_METHODS:\n",
429
+ " twoshot_sim[method] = compute_pairwise_cosine(twoshot_embeddings[method], n_subjects, n_layers)\n",
430
+ " raw_sim[method] = compute_pairwise_cosine(raw_embeddings[method], n_subjects, n_layers)\n",
431
+ " print(f'{method}: twoshot {twoshot_sim[method].shape}, raw {raw_sim[method].shape}')"
432
+ ],
433
+ "execution_count": null,
434
+ "outputs": []
435
+ },
436
+ {
437
+ "cell_type": "markdown",
438
+ "metadata": {},
439
+ "source": [
440
+ "## Head-to-Head: Two-Shot vs Raw at Best Layer\n",
441
+ "Side-by-side heatmaps showing how two-shot generation changes the similarity landscape."
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "metadata": {},
447
+ "source": [
448
+ "def plot_comparison_heatmaps(twoshot_sim, raw_sim, labels, method, layer_idx):\n",
449
+ " \"\"\"Side-by-side: raw vs two-shot at a specific layer.\"\"\"\n",
450
+ " layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n",
451
+ " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))\n",
452
+ "\n",
453
+ " sns.heatmap(\n",
454
+ " raw_sim[layer_idx], xticklabels=labels, yticklabels=labels,\n",
455
+ " vmin=-0.2, vmax=1.0, cmap='RdYlBu_r', annot=True, fmt='.2f',\n",
456
+ " ax=ax1, square=True, annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n",
457
+ " )\n",
458
+ " ax1.set_title(f'RAW encode | {method} | {layer_name}', fontsize=14)\n",
459
+ " ax1.tick_params(axis='x', rotation=90, labelsize=7)\n",
460
+ " ax1.tick_params(axis='y', rotation=0, labelsize=7)\n",
461
+ "\n",
462
+ " sns.heatmap(\n",
463
+ " twoshot_sim[layer_idx], xticklabels=labels, yticklabels=labels,\n",
464
+ " vmin=-0.2, vmax=1.0, cmap='RdYlBu_r', annot=True, fmt='.2f',\n",
465
+ " ax=ax2, square=True, annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n",
466
+ " )\n",
467
+ " ax2.set_title(f'TWO-SHOT encode | {method} | {layer_name}', fontsize=14)\n",
468
+ " ax2.tick_params(axis='x', rotation=90, labelsize=7)\n",
469
+ " ax2.tick_params(axis='y', rotation=0, labelsize=7)\n",
470
+ "\n",
471
+ " plt.tight_layout()\n",
472
+ " plt.show()\n",
473
+ "\n",
474
+ "short_labels = [l[:30] for l in subject_labels]\n",
475
+ "\n",
476
+ "# Show at penultimate and final layer for last_token\n",
477
+ "for layer_idx in [n_layers - 2, n_layers - 1]:\n",
478
+ " plot_comparison_heatmaps(\n",
479
+ " twoshot_sim['last_token'], raw_sim['last_token'],\n",
480
+ " short_labels, 'last_token', layer_idx\n",
481
+ " )"
482
+ ],
483
+ "execution_count": null,
484
+ "outputs": []
485
+ },
486
+ {
487
+ "cell_type": "markdown",
488
+ "metadata": {},
489
+ "source": [
490
+ "## Two-Shot Heatmap Grid (All Sampled Layers)"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "metadata": {},
496
+ "source": [
497
+ "def plot_heatmap_grid(sim_matrix, labels, method_name, title_prefix=''):\n",
498
+ " n_layers = sim_matrix.shape[0]\n",
499
+ " layers_to_show = sorted(set([\n",
500
+ " 0, n_layers // 4, n_layers // 2,\n",
501
+ " 3 * n_layers // 4, n_layers - 2, n_layers - 1,\n",
502
+ " ]))\n",
503
+ "\n",
504
+ " fig, axes = plt.subplots(2, 3, figsize=(24, 20))\n",
505
+ " axes = axes.flatten()\n",
506
+ "\n",
507
+ " for idx, (ax, layer_idx) in enumerate(zip(axes, layers_to_show)):\n",
508
+ " layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n",
509
+ " sns.heatmap(\n",
510
+ " sim_matrix[layer_idx],\n",
511
+ " xticklabels=labels, yticklabels=labels,\n",
512
+ " vmin=-0.2, vmax=1.0, cmap='RdYlBu_r',\n",
513
+ " annot=True, fmt='.2f', ax=ax, square=True,\n",
514
+ " annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n",
515
+ " )\n",
516
+ " ax.set_title(f'{title_prefix}{method_name} | {layer_name}', fontsize=13)\n",
517
+ " ax.tick_params(axis='x', rotation=90, labelsize=7)\n",
518
+ " ax.tick_params(axis='y', rotation=0, labelsize=7)\n",
519
+ "\n",
520
+ " for idx in range(len(layers_to_show), len(axes)):\n",
521
+ " axes[idx].set_visible(False)\n",
522
+ "\n",
523
+ " plt.tight_layout()\n",
524
+ " plt.show()\n",
525
+ "\n",
526
+ "for method in POOL_METHODS:\n",
527
+ " print(f'\\n=== TWO-SHOT | {method.upper()} ===')\n",
528
+ " plot_heatmap_grid(twoshot_sim[method], short_labels, method, 'twoshot | ')"
529
+ ],
530
+ "execution_count": null,
531
+ "outputs": []
532
+ },
533
+ {
534
+ "cell_type": "markdown",
535
+ "metadata": {},
536
+ "source": [
537
+ "## Discriminability: Two-Shot vs Raw"
538
+ ]
539
+ },
540
+ {
541
+ "cell_type": "code",
542
+ "metadata": {},
543
+ "source": [
544
+ "def compute_discriminability(sim_matrix, group_labels):\n",
545
+ " n_layers = sim_matrix.shape[0]\n",
546
+ " n = sim_matrix.shape[1]\n",
547
+ " within_sim = np.zeros(n_layers)\n",
548
+ " between_sim = np.zeros(n_layers)\n",
549
+ "\n",
550
+ " for layer in range(n_layers):\n",
551
+ " w_vals, b_vals = [], []\n",
552
+ " for i in range(n):\n",
553
+ " for j in range(i + 1, n):\n",
554
+ " val = sim_matrix[layer, i, j]\n",
555
+ " if group_labels[i] == group_labels[j]:\n",
556
+ " w_vals.append(val)\n",
557
+ " else:\n",
558
+ " b_vals.append(val)\n",
559
+ " within_sim[layer] = np.mean(w_vals) if w_vals else 0\n",
560
+ " between_sim[layer] = np.mean(b_vals) if b_vals else 0\n",
561
+ "\n",
562
+ " return within_sim, between_sim, within_sim - between_sim\n",
563
+ "\n",
564
+ "\n",
565
+ "fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
566
+ "\n",
567
+ "configs = [\n",
568
+ " ('mean', raw_sim, 'RAW | mean'),\n",
569
+ " ('mean', twoshot_sim, 'TWO-SHOT | mean'),\n",
570
+ " ('last_token', raw_sim, 'RAW | last_token'),\n",
571
+ " ('last_token', twoshot_sim, 'TWO-SHOT | last_token'),\n",
572
+ "]\n",
573
+ "\n",
574
+ "for ax, (method, sim_dict, title) in zip(axes.flatten(), configs):\n",
575
+ " within, between, gap = compute_discriminability(sim_dict[method], subject_groups)\n",
576
+ " layer_x = np.arange(n_layers)\n",
577
+ "\n",
578
+ " ax.plot(layer_x, within, label='Within-group', color='#2196F3', linewidth=2)\n",
579
+ " ax.plot(layer_x, between, label='Between-group', color='#FF5722', linewidth=2)\n",
580
+ " ax.fill_between(layer_x, between, within, alpha=0.15, color='green')\n",
581
+ " ax.plot(layer_x, gap, label='Gap', color='green', linewidth=2, linestyle='--')\n",
582
+ "\n",
583
+ " best = np.argmax(gap)\n",
584
+ " ax.axvline(best, color='green', linestyle=':', alpha=0.5)\n",
585
+ " ax.annotate(f'Best: L{best} ({gap[best]:.3f})', xy=(best, gap[best]),\n",
586
+ " xytext=(best + 1, gap[best] + 0.02),\n",
587
+ " arrowprops=dict(arrowstyle='->', color='green'),\n",
588
+ " fontsize=9, color='green')\n",
589
+ "\n",
590
+ " ax.set_xlabel('Layer')\n",
591
+ " ax.set_ylabel('Cosine Similarity')\n",
592
+ " ax.set_title(title, fontsize=13)\n",
593
+ " ax.legend(fontsize=8)\n",
594
+ " ax.grid(True, alpha=0.3)\n",
595
+ "\n",
596
+ "plt.tight_layout()\n",
597
+ "plt.show()"
598
+ ],
599
+ "execution_count": null,
600
+ "outputs": []
601
+ },
602
+ {
603
+ "cell_type": "markdown",
604
+ "metadata": {},
605
+ "source": [
606
+ "## Similarity Statistics: Two-Shot vs Raw"
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "metadata": {},
612
+ "source": [
613
+ "print('=' * 70)\n",
614
+ "print('SIMILARITY STATISTICS COMPARISON')\n",
615
+ "print('=' * 70)\n",
616
+ "\n",
617
+ "for method in POOL_METHODS:\n",
618
+ " print(f'\\n--- {method.upper()} ---')\n",
619
+ " for label, sim_dict in [('RAW', raw_sim), ('TWO-SHOT', twoshot_sim)]:\n",
620
+ " # Use penultimate layer\n",
621
+ " layer = n_layers - 2\n",
622
+ " mat = sim_dict[method][layer]\n",
623
+ " off_diag = mat[np.triu_indices(n_subjects, k=1)]\n",
624
+ "\n",
625
+ " print(f' {label} (L{layer}):')\n",
626
+ " print(f' Mean sim: {off_diag.mean():.4f}')\n",
627
+ " print(f' Std sim: {off_diag.std():.4f}')\n",
628
+ " print(f' Min sim: {off_diag.min():.4f}')\n",
629
+ " print(f' Max sim: {off_diag.max():.4f}')\n",
630
+ " print(f' Near-zero (<0.05): {(off_diag < 0.05).sum()}')\n",
631
+ " print(f' High (>0.9): {(off_diag > 0.9).sum()}')"
632
+ ],
633
+ "execution_count": null,
634
+ "outputs": []
635
+ },
636
+ {
637
+ "cell_type": "markdown",
638
+ "metadata": {},
639
+ "source": [
640
+ "## Norms & Effective Dimensionality (Two-Shot)"
641
+ ]
642
+ },
643
+ {
644
+ "cell_type": "code",
645
+ "metadata": {},
646
+ "source": [
647
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
648
+ "\n",
649
+ "for method, ax in zip(POOL_METHODS, axes):\n",
650
+ " for i in range(n_subjects):\n",
651
+ " norms = twoshot_embeddings[method][i].norm(dim=-1).numpy()\n",
652
+ " ax.plot(range(n_layers), norms, alpha=0.6, label=short_labels[i][:20])\n",
653
+ " ax.set_xlabel('Layer')\n",
654
+ " ax.set_ylabel('L2 Norm')\n",
655
+ " ax.set_title(f'TWO-SHOT | {method} | Embedding Norms')\n",
656
+ " ax.grid(True, alpha=0.3)\n",
657
+ " ax.legend(fontsize=6, loc='upper left')\n",
658
+ "\n",
659
+ "plt.tight_layout()\n",
660
+ "plt.show()"
661
+ ],
662
+ "execution_count": null,
663
+ "outputs": []
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "metadata": {},
668
+ "source": [
669
+ "def effective_dimensionality(embeddings_list):\n",
670
+ " mat = torch.stack(embeddings_list)\n",
671
+ " mat = mat - mat.mean(dim=0)\n",
672
+ " _, S, _ = torch.svd(mat)\n",
673
+ " S = S / S.sum()\n",
674
+ " return 1.0 / (S ** 2).sum().item()\n",
675
+ "\n",
676
+ "fig, ax = plt.subplots(figsize=(10, 5))\n",
677
+ "\n",
678
+ "for label, emb_dict, ls in [('raw', raw_embeddings, '--'), ('twoshot', twoshot_embeddings, '-')]:\n",
679
+ " for method, color in zip(POOL_METHODS, ['#2196F3', '#FF5722']):\n",
680
+ " eff_dims = []\n",
681
+ " for layer_idx in range(n_layers):\n",
682
+ " vecs = [emb_dict[method][i][layer_idx] for i in range(n_subjects)]\n",
683
+ " eff_dims.append(effective_dimensionality(vecs))\n",
684
+ " ax.plot(range(n_layers), eff_dims, label=f'{label} | {method}',\n",
685
+ " linewidth=2, linestyle=ls, color=color)\n",
686
+ "\n",
687
+ "ax.set_xlabel('Layer')\n",
688
+ "ax.set_ylabel('Effective Dimensionality')\n",
689
+ "ax.set_title('Effective Rank: Raw (dashed) vs Two-Shot (solid)')\n",
690
+ "ax.legend()\n",
691
+ "ax.grid(True, alpha=0.3)\n",
692
+ "plt.tight_layout()\n",
693
+ "plt.show()"
694
+ ],
695
+ "execution_count": null,
696
+ "outputs": []
697
+ },
698
+ {
699
+ "cell_type": "markdown",
700
+ "metadata": {},
701
+ "source": [
702
+ "## Summary"
703
+ ]
704
+ },
705
+ {
706
+ "cell_type": "code",
707
+ "metadata": {},
708
+ "source": [
709
+ "print('=' * 70)\n",
710
+ "print('TWO-SHOT vs RAW EMBEDDING SUMMARY')\n",
711
+ "print('=' * 70)\n",
712
+ "print(f'Model: {MODEL_ID}')\n",
713
+ "print(f'Layers: {n_layers}, Hidden dim: {extractor.hidden_dim}')\n",
714
+ "print(f'Subjects: {n_subjects}')\n",
715
+ "print()\n",
716
+ "\n",
717
+ "for method in POOL_METHODS:\n",
718
+ " print(f'--- {method.upper()} ---')\n",
719
+ " for label, sim_dict in [('RAW', raw_sim), ('TWO-SHOT', twoshot_sim)]:\n",
720
+ " within, between, gap = compute_discriminability(sim_dict[method], subject_groups)\n",
721
+ " best_l = np.argmax(gap)\n",
722
+ " print(f' {label}:')\n",
723
+ " print(f' Best layer: L{best_l} (gap = {gap[best_l]:.4f})')\n",
724
+ " print(f' Final layer gap: {gap[-1]:.4f}')\n",
725
+ " print()\n",
726
+ "\n",
727
+ "print('\\nGENERATED DESCRIPTIONS:')\n",
728
+ "for i, subject in enumerate(subjects):\n",
729
+ " print(f' [{subject_labels[i]}]')\n",
730
+ " print(f' -> {twoshot_descriptions[i]}')"
731
+ ],
732
+ "execution_count": null,
733
+ "outputs": []
734
+ },
735
+ {
736
+ "cell_type": "markdown",
737
+ "metadata": {},
738
+ "source": [
739
+ "## (Optional) Export"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "code",
744
+ "metadata": {},
745
+ "source": [
746
+ "# Uncomment to save\n",
747
+ "# export = {\n",
748
+ "# 'model_id': MODEL_ID,\n",
749
+ "# 'subjects': subjects,\n",
750
+ "# 'subject_groups': subject_groups,\n",
751
+ "# 'twoshot_descriptions': twoshot_descriptions,\n",
752
+ "# 'twoshot_embeddings': twoshot_embeddings,\n",
753
+ "# 'raw_embeddings': raw_embeddings,\n",
754
+ "# 'twoshot_sim': twoshot_sim,\n",
755
+ "# 'raw_sim': raw_sim,\n",
756
+ "# 'num_layers': n_layers,\n",
757
+ "# 'hidden_dim': extractor.hidden_dim,\n",
758
+ "# }\n",
759
+ "# torch.save(export, 'qwen35_twoshot_embeddings.pt')\n",
760
+ "# print('Saved.')"
761
+ ],
762
+ "execution_count": null,
763
+ "outputs": []
764
+ }
765
+ ]
766
+ }