{ "cells": [ { "cell_type": "markdown", "id": "221e53a9", "metadata": {}, "source": [ "# Geometric Terrain Analysis — Experiment Bulk\n", "\n", "**Repository:** AbstractPhil/procrustes-analysis \n", "**Date:** 2026-03-05/06 \n", "**Models Profiled:** 17 \n", "**Architecture Families:** 5 (Transformer enc-dec, encoder-only/vision, adapted enc-dec, conv UNet, conv autoencoder) \n", "**Training Objectives:** 6 (span corruption, MLM, contrastive, self-supervised, diffusion, reconstruction)\n", "\n", "This notebook contains all experiments from the geometric terrain analysis sessions. Each section corresponds to sections in the README statistics composite.\n", "\n", "---\n" ] }, { "cell_type": "markdown", "id": "a265f262", "metadata": {}, "source": [ "## 1. T5-Small Embedding + Layer Geometry\n", "*Sections II–V.1: PR, CV, digit manifold, categories, layer evolution*" ] }, { "cell_type": "code", "execution_count": null, "id": "6a5e7396", "metadata": {}, "outputs": [], "source": [ "# T5-Small terrain\n", "\n", "# ============================================================================\n", "# T5-SMALL: FULL GEOMETRIC TERRAIN MAP\n", "# Everything. 
Kitchen sink and more.\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import math\n", "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", "import matplotlib.pyplot as plt\n", "from scipy.stats import spearmanr\n", "from scipy.spatial.distance import pdist\n", "\n", "model_id = \"google-t5/t5-small\"\n", "print(f\"Loading {model_id}...\")\n", "tokenizer = T5Tokenizer.from_pretrained(model_id, legacy=True)\n", "model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)\n", "model.eval()\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ARCHITECTURE INTROSPECTION\")\n", "print(f\"{'='*70}\")\n", "\n", "# Parameter census\n", "components = {}\n", "for name, param in model.named_parameters():\n", " parts = name.split('.')\n", " key = f\"{parts[0]}.{parts[1]}\" if len(parts) > 1 else parts[0]\n", " components[key] = components.get(key, 0) + param.numel()\n", "\n", "print(\"Parameter distribution:\")\n", "total = sum(components.values())\n", "for key, count in sorted(components.items(), key=lambda x: -x[1]):\n", " print(f\" {key:40s} {count:>12,} ({count/total*100:.1f}%)\")\n", "print(f\" {'TOTAL':40s} {total:>12,}\")\n", "\n", "# Embedding details\n", "embed = model.shared # T5 uses shared embedding\n", "E = embed.weight.detach().float().clone()\n", "vocab_size, hidden_dim = E.shape\n", "print(f\"\\nEmbedding: vocab={vocab_size}, dim={hidden_dim}\")\n", "print(f\"Embed params: {vocab_size * hidden_dim:,}\")\n", "\n", "# Check weight tying\n", "enc_embed = model.encoder.embed_tokens.weight\n", "dec_embed = model.decoder.embed_tokens.weight\n", "lm_head = model.lm_head.weight\n", "\n", "enc_tied = torch.allclose(E, enc_embed.detach().float())\n", "dec_tied = torch.allclose(E, dec_embed.detach().float())\n", "lm_tied = torch.allclose(E, lm_head.detach().float())\n", "print(f\"Encoder embed tied to shared: {enc_tied}\")\n", "print(f\"Decoder 
embed tied to shared: {dec_tied}\")\n", "print(f\"LM head tied to shared: {lm_tied}\")\n", "\n", "# Architecture shape\n", "n_enc_layers = len(model.encoder.block)\n", "n_dec_layers = len(model.decoder.block)\n", "print(f\"\\nEncoder layers: {n_enc_layers}\")\n", "print(f\"Decoder layers: {n_dec_layers}\")\n", "print(f\"Hidden dim: {hidden_dim}\")\n", "\n", "# Attention heads\n", "enc_attn = model.encoder.block[0].layer[0].SelfAttention\n", "print(f\"Attention heads: {enc_attn.n_heads}\")\n", "print(f\"d_kv: {enc_attn.key_value_proj_dim}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"GLOBAL EMBEDDING STATISTICS\")\n", "print(f\"{'='*70}\")\n", "\n", "E_np = E.numpy()\n", "norms = np.linalg.norm(E_np, axis=1)\n", "print(f\"Norm mean={norms.mean():.6f} std={norms.std():.6f} min={norms.min():.6f} max={norms.max():.6f}\")\n", "\n", "# Per-dim stats\n", "per_dim_mean = E_np.mean(axis=0)\n", "per_dim_std = E_np.std(axis=0)\n", "print(f\"Per-dim mean of means: {per_dim_mean.mean():.8f}\")\n", "print(f\"Per-dim mean of stds: {per_dim_std.mean():.8f}\")\n", "\n", "# Zero / near-zero embeddings\n", "zero_count = (norms < 1e-6).sum()\n", "print(f\"Zero embeddings: {zero_count} / {vocab_size}\")\n", "\n", "# Min/max norm tokens\n", "min_idx = norms.argmin()\n", "max_idx = norms.argmax()\n", "print(f\"Min norm token {min_idx}: '{tokenizer.decode([min_idx])}' (norm={norms[min_idx]:.6f})\")\n", "print(f\"Max norm token {max_idx}: '{tokenizer.decode([max_idx])}' (norm={norms[max_idx]:.6f})\")\n", "\n", "# Norm histogram percentiles\n", "for p in [1, 5, 25, 50, 75, 95, 99]:\n", " print(f\" {p:>3}% norm: {np.percentile(norms, p):.6f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"COSINE SIMILARITY DISTRIBUTION\")\n", "print(f\"{'='*70}\")\n", "\n", "rng = np.random.default_rng(42)\n", "N_SAMPLE = 5000\n", "sample_idx = rng.choice(vocab_size, size=N_SAMPLE, replace=False)\n", "E_sample = E_np[sample_idx]\n", "E_sample_n = E_sample / (np.linalg.norm(E_sample, axis=1, 
keepdims=True) + 1e-8)\n", "cos_mat = E_sample_n @ E_sample_n.T\n", "\n", "tri = np.triu_indices(N_SAMPLE, k=1)\n", "flat_cos = cos_mat[tri[0], tri[1]]\n", "\n", "print(f\"Pairs: {len(flat_cos):,}\")\n", "print(f\"Mean: {flat_cos.mean():.6f}\")\n", "print(f\"Std: {flat_cos.std():.6f}\")\n", "print(f\"Median: {np.median(flat_cos):.6f}\")\n", "for p in [1, 5, 25, 50, 75, 95, 99]:\n", " print(f\" {p:>3}%: {np.percentile(flat_cos, p):.6f}\")\n", "\n", "# Check for key constants\n", "for val, name in [(0.0, \"zero\"), (0.19471, \"qwen_mean\"), (0.29514, \"phil_constant\"), (0.5, \"half\")]:\n", " within = (np.abs(flat_cos - val) < 0.01).mean()\n", " nearest = np.abs(flat_cos - val).min()\n", " print(f\" Pairs within ±0.01 of {val:.5f} ({name}): {within*100:.3f}% nearest={nearest:.8f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"EIGENSPECTRUM & INTRINSIC DIMENSIONALITY\")\n", "print(f\"{'='*70}\")\n", "\n", "# Use all embeddings for covariance (T5 vocab is small enough)\n", "E_centered = E_np - E_np.mean(axis=0)\n", "cov = (E_centered.T @ E_centered) / vocab_size\n", "eigvals = np.linalg.eigvalsh(cov)[::-1]\n", "\n", "total_var = eigvals.sum()\n", "cumvar = np.cumsum(eigvals) / total_var\n", "\n", "print(f\"Total variance: {total_var:.4f}\")\n", "print(f\"Top 5 eigenvalues: {eigvals[:5]}\")\n", "print(f\"Top eigenvalue %: {eigvals[0]/total_var*100:.2f}%\")\n", "\n", "# Participation ratio\n", "pr = (eigvals.sum()) ** 2 / (eigvals ** 2).sum()\n", "print(f\"Participation ratio: {pr:.1f}\")\n", "print(f\"Participation / dim: {pr/hidden_dim:.3f}\")\n", "\n", "for frac in [0.80, 0.90, 0.95, 0.99]:\n", " n_dims = np.searchsorted(cumvar, frac) + 1\n", " print(f\"Dims for {frac*100:.0f}% var: {n_dims} ({n_dims/hidden_dim*100:.1f}% of {hidden_dim})\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"PENTACHORON GEOMETRY (Cayley-Menger)\")\n", "print(f\"{'='*70}\")\n", "\n", "def cayley_menger_volume_sq(points):\n", " n = len(points)\n", " D = np.zeros((n + 1, n + 1))\n", " 
D[0, 1:] = 1\n", " D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq\n", " D[j + 1, i + 1] = d_sq\n", " k = n - 1\n", " sign = (-1) ** (k + 1)\n", " factorial_sq = math.factorial(k) ** 2\n", " denom = (2 ** k) * factorial_sq\n", " det = np.linalg.det(D)\n", " vol_sq = sign * det / denom\n", " return vol_sq\n", "\n", "N_SIMP = 1000\n", "vols_embed = []\n", "vols_random = []\n", "\n", "for _ in range(N_SIMP):\n", " idx = rng.choice(vocab_size, size=5, replace=False)\n", " pts = E_np[idx]\n", " vol_sq = cayley_menger_volume_sq(pts)\n", " if vol_sq > 0:\n", " vols_embed.append(np.sqrt(vol_sq))\n", "\n", " pts_r = rng.normal(0, E_np.std(), size=(5, hidden_dim)).astype(np.float32)\n", " vol_sq_r = cayley_menger_volume_sq(pts_r)\n", " if vol_sq_r > 0:\n", " vols_random.append(np.sqrt(vol_sq_r))\n", "\n", "vols_embed = np.array(vols_embed)\n", "vols_random = np.array(vols_random)\n", "\n", "print(f\"Valid: {len(vols_embed)} / {N_SIMP}\")\n", "print(f\"Mean vol: {vols_embed.mean():.6f}\")\n", "print(f\"Std vol: {vols_embed.std():.6f}\")\n", "print(f\"CV: {vols_embed.std()/vols_embed.mean():.6f}\")\n", "print(f\"Random mean: {vols_random.mean():.6f}\")\n", "print(f\"Ratio: {vols_embed.mean()/vols_random.mean():.6f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"DIGIT EMBEDDING GEOMETRY\")\n", "print(f\"{'='*70}\")\n", "\n", "# T5 tokenizer encodes digits differently — find them\n", "digit_tokens = []\n", "for d in range(10):\n", " ids = tokenizer.encode(str(d), add_special_tokens=False)\n", " digit_tokens.append(ids[0] if len(ids) == 1 else ids[0])\n", " tok_str = tokenizer.decode([digit_tokens[-1]])\n", " print(f\" '{d}' -> token {digit_tokens[-1]} '{tok_str}' (encode len={len(ids)})\")\n", "\n", "digit_embeds = E_np[digit_tokens]\n", "digit_n = digit_embeds / (np.linalg.norm(digit_embeds, axis=1, keepdims=True) + 1e-8)\n", "cos_digits = digit_n @ digit_n.T\n", "\n", 
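"# (Added sketch) Rank-based monotonicity check on the digit manifold: Spearman\n",
"# correlation between numeric distance |i-j| and pairwise cosine. spearmanr is\n",
"# already imported at the top of this cell; this complements the Pearson\n",
"# correlation computed further down.\n",
"_digit_pairs = [(abs(i - j), cos_digits[i, j]) for i in range(10) for j in range(i + 1, 10)]\n",
"_rho, _ = spearmanr([d for d, _ in _digit_pairs], [c for _, c in _digit_pairs])\n",
"print(f\"Spearman(|i-j|, cosine): {_rho:.4f}\")\n",
"\n",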
"print(f\"\\n \", end=\"\")\n", "for d in range(10):\n", " print(f\" '{d}' \", end=\"\")\n", "print()\n", "for i in range(10):\n", " print(f\" '{i}' \", end=\"\")\n", " for j in range(10):\n", " if j <= i:\n", " print(\" \", end=\"\")\n", " else:\n", " print(f\"{cos_digits[i,j]:.4f} \", end=\"\")\n", " print()\n", "\n", "# Distance-cosine correlation\n", "pairs = []\n", "for i in range(10):\n", " for j in range(i+1, 10):\n", " pairs.append((abs(i - j), cos_digits[i, j]))\n", "dists, cosines = zip(*pairs)\n", "corr = np.corrcoef(dists, cosines)[0, 1]\n", "adj_mean = np.mean([c for d, c in pairs if d == 1])\n", "nonadj_mean = np.mean([c for d, c in pairs if d > 1])\n", "print(f\"\\nCorrelation(|i-j|, cosine): {corr:.4f}\")\n", "print(f\"Adjacent mean: {adj_mean:.4f}\")\n", "print(f\"Non-adjacent mean: {nonadj_mean:.4f}\")\n", "print(f\"Gap: {adj_mean - nonadj_mean:.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"SEMANTIC CATEGORY CLUSTERING\")\n", "print(f\"{'='*70}\")\n", "\n", "categories = {\n", " \"animals\": [\"cat\", \"dog\", \"bird\", \"fish\", \"horse\", \"mouse\", \"bear\", \"wolf\", \"deer\", \"fox\",\n", " \"lion\", \"tiger\", \"snake\", \"whale\", \"frog\", \"rabbit\", \"monkey\", \"elephant\"],\n", " \"colors\": [\"red\", \"blue\", \"green\", \"yellow\", \"black\", \"white\", \"orange\", \"purple\", \"brown\", \"pink\"],\n", " \"numbers\": [\"one\", \"two\", \"three\", \"four\", \"five\", \"six\", \"seven\", \"eight\", \"nine\", \"ten\"],\n", " \"body\": [\"head\", \"hand\", \"eye\", \"foot\", \"arm\", \"leg\", \"face\", \"mouth\", \"heart\", \"brain\"],\n", " \"food\": [\"bread\", \"meat\", \"rice\", \"milk\", \"fish\", \"salt\", \"sugar\", \"cheese\", \"fruit\", \"water\"],\n", " \"emotions\": [\"happy\", \"sad\", \"angry\", \"fear\", \"love\", \"hate\", \"joy\", \"hope\", \"pain\", \"calm\"],\n", " \"actions\": [\"run\", \"walk\", \"jump\", \"fly\", \"swim\", \"eat\", \"sleep\", \"talk\", \"think\", \"write\"],\n", " \"time\": [\"day\", 
\"night\", \"year\", \"month\", \"week\", \"hour\", \"morning\", \"evening\", \"today\", \"tomorrow\"],\n", "}\n", "\n", "# Global mean cosine for reference\n", "global_mean_cos = flat_cos.mean()\n", "print(f\"Global mean pairwise cosine: {global_mean_cos:.4f}\\n\")\n", "\n", "for cat_name, words in categories.items():\n", " token_ids = []\n", " valid_words = []\n", " for w in words:\n", " ids = tokenizer.encode(w, add_special_tokens=False)\n", " if len(ids) == 1:\n", " token_ids.append(ids[0])\n", " valid_words.append(w)\n", "\n", " if len(token_ids) < 3:\n", " print(f\" {cat_name:12s}: only {len(token_ids)} single-token words, skipping\")\n", " continue\n", "\n", " cat_embeds = E_np[token_ids]\n", " cat_n = cat_embeds / (np.linalg.norm(cat_embeds, axis=1, keepdims=True) + 1e-8)\n", " n_cat = len(token_ids)\n", " tri_c = np.triu_indices(n_cat, k=1)\n", " intra = (cat_n @ cat_n.T)[tri_c].mean()\n", " print(f\" {cat_name:12s}: n={n_cat:2d} intra_cos={intra:.4f} lift={intra - global_mean_cos:+.4f} words={valid_words[:5]}...\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"CROSS-CATEGORY RELATIONAL STRUCTURE\")\n", "print(f\"{'='*70}\")\n", "\n", "# Build a single matrix of all category centroids\n", "centroids = {}\n", "for cat_name, words in categories.items():\n", " token_ids = []\n", " for w in words:\n", " ids = tokenizer.encode(w, add_special_tokens=False)\n", " if len(ids) == 1:\n", " token_ids.append(ids[0])\n", " if len(token_ids) >= 3:\n", " cat_embeds = E_np[token_ids]\n", " centroids[cat_name] = cat_embeds.mean(axis=0)\n", "\n", "cat_names = list(centroids.keys())\n", "centroid_mat = np.stack([centroids[c] for c in cat_names])\n", "centroid_n = centroid_mat / (np.linalg.norm(centroid_mat, axis=1, keepdims=True) + 1e-8)\n", "cross_cos = centroid_n @ centroid_n.T\n", "\n", "print(\"Category centroid cosine similarity:\")\n", "print(f\"{'':12s}\", end=\"\")\n", "for c in cat_names:\n", " print(f\" {c[:7]:>7s}\", end=\"\")\n", "print()\n", "for i, ci in 
enumerate(cat_names):\n", " print(f\"{ci:12s}\", end=\"\")\n", " for j, cj in enumerate(cat_names):\n", " if j < i:\n", " print(f\" \", end=\"\")\n", " elif j == i:\n", " print(f\" --- \", end=\"\")\n", " else:\n", " print(f\" {cross_cos[i,j]:.4f} \", end=\"\")\n", " print()\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ENCODER LAYER-BY-LAYER GEOMETRY\")\n", "print(f\"{'='*70}\")\n", "\n", "# Feed a diverse set of sentences through encoder, capture hidden states at each layer\n", "test_sentences = [\n", " \"The cat sat on the mat.\",\n", " \"Quantum mechanics describes the behavior of particles at the atomic scale.\",\n", " \"She quickly ran to the store before it closed.\",\n", " \"The derivative of x squared is two x.\",\n", " \"Red and blue make purple when mixed together.\",\n", " \"The president signed the new trade agreement yesterday.\",\n", " \"Three plus four equals seven.\",\n", " \"Love is patient, love is kind.\",\n", " \"The function returns a sorted list of integers.\",\n", " \"Mount Everest is the tallest mountain in the world.\",\n", "]\n", "\n", "layer_stats = []\n", "\n", "for sent in test_sentences:\n", " inputs = tokenizer(sent, return_tensors=\"pt\", padding=False)\n", " with torch.no_grad():\n", " outputs = model.encoder(\n", " input_ids=inputs.input_ids,\n", " output_hidden_states=True,\n", " )\n", "\n", " # outputs.hidden_states: tuple of (n_layers+1) tensors, each [1, seq_len, dim]\n", " for layer_idx, hs in enumerate(outputs.hidden_states):\n", " h = hs[0].float().numpy() # [seq_len, dim]\n", " h_norms = np.linalg.norm(h, axis=1)\n", " # Pairwise cosine between token positions\n", " h_n = h / (np.linalg.norm(h, axis=1, keepdims=True) + 1e-8)\n", " if h.shape[0] > 1:\n", " tri_h = np.triu_indices(h.shape[0], k=1)\n", " pairwise_cos = (h_n @ h_n.T)[tri_h]\n", " else:\n", " pairwise_cos = np.array([0.0])\n", "\n", " layer_stats.append({\n", " 'layer': layer_idx,\n", " 'mean_norm': h_norms.mean(),\n", " 'std_norm': h_norms.std(),\n", " 
'mean_cos': pairwise_cos.mean(),\n", " 'std_cos': pairwise_cos.std(),\n", " 'seq_len': h.shape[0],\n", " })\n", "\n", "# Aggregate by layer\n", "import pandas as pd\n", "df = pd.DataFrame(layer_stats)\n", "layer_agg = df.groupby('layer').agg({\n", " 'mean_norm': 'mean',\n", " 'std_norm': 'mean',\n", " 'mean_cos': 'mean',\n", " 'std_cos': 'mean',\n", "}).reset_index()\n", "\n", "print(f\"\\nEncoder hidden state geometry across {len(test_sentences)} sentences:\")\n", "print(f\"{'Layer':>5s} {'Norm':>8s} {'NormStd':>8s} {'Cos':>8s} {'CosStd':>8s}\")\n", "for _, row in layer_agg.iterrows():\n", " print(f\"{int(row['layer']):5d} {row['mean_norm']:8.4f} {row['std_norm']:8.4f} {row['mean_cos']:8.4f} {row['std_cos']:8.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ENCODER vs DECODER HIDDEN STATE COMPARISON\")\n", "print(f\"{'='*70}\")\n", "\n", "# Run a translation-style task to get both encoder and decoder states\n", "test_input = \"translate English to German: The house is big.\"\n", "inputs = tokenizer(test_input, return_tensors=\"pt\")\n", "decoder_input = tokenizer(\"Das Haus ist groß.\", return_tensors=\"pt\")\n", "\n", "with torch.no_grad():\n", " enc_out = model.encoder(input_ids=inputs.input_ids, output_hidden_states=True)\n", " dec_out = model.decoder(\n", " input_ids=decoder_input.input_ids,\n", " encoder_hidden_states=enc_out.last_hidden_state,\n", " output_hidden_states=True,\n", " )\n", "\n", "print(\"Encoder final hidden state:\")\n", "enc_final = enc_out.last_hidden_state[0].float().numpy()\n", "enc_norms = np.linalg.norm(enc_final, axis=1)\n", "print(f\" Shape: {enc_final.shape}, Norm mean={enc_norms.mean():.4f}, std={enc_norms.std():.4f}\")\n", "\n", "print(\"Decoder final hidden state:\")\n", "dec_final = dec_out.last_hidden_state[0].float().numpy()\n", "dec_norms = np.linalg.norm(dec_final, axis=1)\n", "print(f\" Shape: {dec_final.shape}, Norm mean={dec_norms.mean():.4f}, std={dec_norms.std():.4f}\")\n", "\n", "# Cosine between encoder and 
decoder token representations\n", "# Compare embedding space: encode same tokens, see if they diverge\n", "common_sent = \"the cat\"\n", "common_ids = tokenizer.encode(common_sent, add_special_tokens=False)\n", "print(f\"\\nCommon tokens '{common_sent}': {common_ids}\")\n", "\n", "for layer_idx in [0, n_enc_layers // 2, n_enc_layers]:\n", " enc_h = enc_out.hidden_states[layer_idx][0].float().numpy()\n", " if layer_idx < len(dec_out.hidden_states):\n", " dec_h = dec_out.hidden_states[layer_idx][0].float().numpy()\n", " # Both start from same embeddings — measure divergence\n", " # Use first few tokens only (they share the embedding)\n", " n_compare = min(enc_h.shape[0], dec_h.shape[0], 5)\n", " cos_vals = []\n", " for t in range(n_compare):\n", " cos = np.dot(enc_h[t], dec_h[t]) / (np.linalg.norm(enc_h[t]) * np.linalg.norm(dec_h[t]) + 1e-8)\n", " cos_vals.append(cos)\n", " print(f\" Layer {layer_idx}: enc-dec cosine per position = {[f'{c:.4f}' for c in cos_vals]}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"SPECIAL TOKEN STRUCTURE\")\n", "print(f\"{'='*70}\")\n", "\n", "# T5 special tokens\n", "special_tokens = {\n", " 'pad': tokenizer.pad_token_id,\n", " 'eos': tokenizer.eos_token_id,\n", " 'unk': tokenizer.unk_token_id,\n", "}\n", "# Sentinel tokens (T5 uses <extra_id_0> through <extra_id_99>)\n", "for i in range(5):\n", " tok = f\"<extra_id_{i}>\"\n", " ids = tokenizer.encode(tok, add_special_tokens=False)\n", " if len(ids) == 1:\n", " special_tokens[f'sentinel_{i}'] = ids[0]\n", "\n", "print(\"Special token norms and pairwise cosine:\")\n", "sp_ids = list(special_tokens.values())\n", "sp_names = list(special_tokens.keys())\n", "sp_embeds = E_np[sp_ids]\n", "sp_norms = np.linalg.norm(sp_embeds, axis=1)\n", "\n", "for name, sid, norm in zip(sp_names, sp_ids, sp_norms):\n", " print(f\" {name:15s} id={sid:6d} norm={norm:.6f}\")\n", "\n", "# Sentinel pairwise cosine\n", "if len(sp_ids) > 1:\n", " sp_n = sp_embeds / (np.linalg.norm(sp_embeds, axis=1, keepdims=True) + 1e-8)\n", " sp_cos = sp_n @ 
sp_n.T\n", " print(f\"\\nSentinel/special pairwise cosine:\")\n", " for i, ni in enumerate(sp_names):\n", " for j, nj in enumerate(sp_names):\n", " if j > i:\n", " print(f\" {ni:15s} ↔ {nj:15s}: {sp_cos[i,j]:.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"RELATIVE POSITION BIAS STRUCTURE\")\n", "print(f\"{'='*70}\")\n", "\n", "# T5 uses relative position biases instead of absolute PE\n", "rpb = model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias\n", "rpb_weight = rpb.weight.detach().float().numpy() # [num_buckets, n_heads]\n", "print(f\"Relative position bias shape: {rpb_weight.shape}\")\n", "print(f\" Num buckets: {rpb_weight.shape[0]}\")\n", "print(f\" Num heads: {rpb_weight.shape[1]}\")\n", "print(f\" Mean: {rpb_weight.mean():.6f}\")\n", "print(f\" Std: {rpb_weight.std():.6f}\")\n", "print(f\" Min: {rpb_weight.min():.6f}\")\n", "print(f\" Max: {rpb_weight.max():.6f}\")\n", "\n", "# Per-head bias profile\n", "print(f\"\\n Per-head statistics:\")\n", "for h in range(rpb_weight.shape[1]):\n", " col = rpb_weight[:, h]\n", " print(f\" Head {h:2d}: mean={col.mean():.4f} std={col.std():.4f} range=[{col.min():.4f}, {col.max():.4f}]\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"EUCLIDEAN DISTANCE STRUCTURE\")\n", "print(f\"{'='*70}\")\n", "\n", "N_DIST = 2000\n", "dist_idx = rng.choice(vocab_size, size=N_DIST, replace=False)\n", "dists = pdist(E_np[dist_idx], metric='euclidean')\n", "dists_normed = dists / dists.mean()\n", "\n", "print(f\"Pairwise Euclidean distances ({N_DIST} tokens):\")\n", "print(f\" Mean: {dists.mean():.6f}\")\n", "print(f\" Std: {dists.std():.6f}\")\n", "print(f\" CV: {dists.std()/dists.mean():.6f}\")\n", "for p in [1, 5, 25, 50, 75, 95, 99]:\n", " print(f\" {p:>3}%: {np.percentile(dists, p):.6f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"GENERATING VISUALIZATIONS\")\n", "print(f\"{'='*70}\")\n", "\n", "fig, axes = plt.subplots(3, 3, figsize=(18, 15))\n", "fig.suptitle(f\"T5-Small Complete Geometric Terrain Map 
(vocab={vocab_size}, dim={hidden_dim})\", fontsize=14)\n", "\n", "# 1. Norm distribution\n", "axes[0, 0].hist(norms, bins=200, color='steelblue', alpha=0.8)\n", "axes[0, 0].axvline(norms.mean(), color='red', ls='--', label=f'mean={norms.mean():.3f}')\n", "axes[0, 0].set_title(\"Embedding norm distribution\")\n", "axes[0, 0].legend()\n", "\n", "# 2. Cosine similarity distribution\n", "axes[0, 1].hist(flat_cos, bins=200, color='darkorange', alpha=0.8)\n", "axes[0, 1].axvline(flat_cos.mean(), color='red', ls='--', label=f'mean={flat_cos.mean():.3f}')\n", "axes[0, 1].set_title(\"Pairwise cosine distribution\")\n", "axes[0, 1].legend()\n", "\n", "# 3. Eigenspectrum\n", "axes[0, 2].semilogy(range(min(hidden_dim, 200)), eigvals[:200], color='darkgreen')\n", "axes[0, 2].axhline(eigvals[int(pr)], color='red', ls='--', alpha=0.5, label=f'PR={pr:.0f}')\n", "axes[0, 2].set_title(\"Eigenspectrum (top 200)\")\n", "axes[0, 2].set_xlabel(\"Component\")\n", "axes[0, 2].legend()\n", "\n", "# 4. Pentachoron volume distribution\n", "axes[1, 0].hist(vols_embed, bins=100, alpha=0.6, color='purple', label='Embeddings')\n", "axes[1, 0].hist(vols_random, bins=100, alpha=0.6, color='gray', label='Random')\n", "axes[1, 0].set_title(f\"Pentachoron volumes (ratio={vols_embed.mean()/vols_random.mean():.3f})\")\n", "axes[1, 0].legend()\n", "\n", "# 5. Digit cosine heatmap\n", "im = axes[1, 1].imshow(cos_digits, cmap='YlOrRd', vmin=0, vmax=1)\n", "axes[1, 1].set_xticks(range(10))\n", "axes[1, 1].set_yticks(range(10))\n", "axes[1, 1].set_xticklabels([str(d) for d in range(10)])\n", "axes[1, 1].set_yticklabels([str(d) for d in range(10)])\n", "axes[1, 1].set_title(f\"Digit cosine (|i-j| corr={corr:.3f})\")\n", "plt.colorbar(im, ax=axes[1, 1])\n", "\n", "# 6. 
Category intra-cosine bar chart\n", "cat_intras = {}\n", "for cat_name, words in categories.items():\n", " token_ids = [tokenizer.encode(w, add_special_tokens=False)[0]\n", " for w in words if len(tokenizer.encode(w, add_special_tokens=False)) == 1]\n", " if len(token_ids) >= 3:\n", " cat_e = E_np[token_ids]\n", " cat_n = cat_e / (np.linalg.norm(cat_e, axis=1, keepdims=True) + 1e-8)\n", " tri_c = np.triu_indices(len(token_ids), k=1)\n", " cat_intras[cat_name] = (cat_n @ cat_n.T)[tri_c].mean()\n", "\n", "cats = list(cat_intras.keys())\n", "vals = [cat_intras[c] for c in cats]\n", "axes[1, 2].barh(cats, vals, color='teal')\n", "axes[1, 2].axvline(global_mean_cos, color='red', ls='--', label=f'global={global_mean_cos:.3f}')\n", "axes[1, 2].set_title(\"Intra-category cosine\")\n", "axes[1, 2].legend()\n", "\n", "# 7. Layer-by-layer norm evolution\n", "axes[2, 0].plot(layer_agg['layer'], layer_agg['mean_norm'], 'o-', color='navy')\n", "axes[2, 0].fill_between(layer_agg['layer'],\n", " layer_agg['mean_norm'] - layer_agg['std_norm'],\n", " layer_agg['mean_norm'] + layer_agg['std_norm'], alpha=0.2)\n", "axes[2, 0].set_title(\"Encoder layer norm evolution\")\n", "axes[2, 0].set_xlabel(\"Layer\")\n", "axes[2, 0].set_ylabel(\"Mean norm\")\n", "\n", "# 8. Layer-by-layer cosine evolution\n", "axes[2, 1].plot(layer_agg['layer'], layer_agg['mean_cos'], 's-', color='crimson')\n", "axes[2, 1].fill_between(layer_agg['layer'],\n", " layer_agg['mean_cos'] - layer_agg['std_cos'],\n", " layer_agg['mean_cos'] + layer_agg['std_cos'], alpha=0.2)\n", "axes[2, 1].set_title(\"Encoder layer pairwise cosine\")\n", "axes[2, 1].set_xlabel(\"Layer\")\n", "axes[2, 1].set_ylabel(\"Mean pairwise cosine\")\n", "\n", "# 9. 
Relative position bias heatmap (first 8 heads)\n", "rpb_show = rpb_weight[:, :min(8, rpb_weight.shape[1])].T\n", "im2 = axes[2, 2].imshow(rpb_show, aspect='auto', cmap='RdBu_r')\n", "axes[2, 2].set_title(\"Relative position bias (heads × buckets)\")\n", "axes[2, 2].set_xlabel(\"Bucket\")\n", "axes[2, 2].set_ylabel(\"Head\")\n", "plt.colorbar(im2, ax=axes[2, 2])\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"/content/t5_small_terrain_map.png\", dpi=150, bbox_inches='tight')\n", "plt.show()\n", "print(\"Saved: /content/t5_small_terrain_map.png\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"T5-SMALL COMPLETE TERRAIN MAP — SUMMARY\")\n", "print(f\"{'='*70}\")\n", "print(f\"Model: {model_id}\")\n", "print(f\"Total params: {total:,}\")\n", "print(f\"Vocab: {vocab_size}\")\n", "print(f\"Hidden dim: {hidden_dim}\")\n", "print(f\"Encoder layers: {n_enc_layers}\")\n", "print(f\"Decoder layers: {n_dec_layers}\")\n", "print(f\"Weight tying: shared→enc={enc_tied}, shared→dec={dec_tied}, shared→lm_head={lm_tied}\")\n", "print(f\"\")\n", "print(f\"--- EMBEDDING GEOMETRY ---\")\n", "print(f\"Mean norm: {norms.mean():.4f}\")\n", "print(f\"Norm std: {norms.std():.4f}\")\n", "print(f\"Mean pairwise cosine: {flat_cos.mean():.4f}\")\n", "print(f\"Cosine std: {flat_cos.std():.4f}\")\n", "print(f\"\")\n", "print(f\"--- INTRINSIC DIMENSIONALITY ---\")\n", "print(f\"Participation ratio: {pr:.1f}\")\n", "print(f\"Participation / dim: {pr/hidden_dim:.3f}\")\n", "print(f\"Dims for 95% variance: {np.searchsorted(cumvar, 0.95)+1} ({(np.searchsorted(cumvar, 0.95)+1)/hidden_dim*100:.1f}%)\")\n", "print(f\"\")\n", "print(f\"--- PENTACHORON GEOMETRY ---\")\n", "print(f\"Valid simplices: {len(vols_embed)}/{N_SIMP}\")\n", "print(f\"Volume CV: {vols_embed.std()/vols_embed.mean():.4f}\")\n", "print(f\"Embed/random ratio: {vols_embed.mean()/vols_random.mean():.4f}\")\n", "print(f\"\")\n", "print(f\"--- DIGIT MANIFOLD ---\")\n", "print(f\"|i-j| correlation: {corr:.4f}\")\n", "print(f\"Adjacent 
mean: {adj_mean:.4f}\")\n", "print(f\"Non-adjacent mean: {nonadj_mean:.4f}\")\n", "print(f\"Gap: {adj_mean - nonadj_mean:.4f}\")\n", "print(f\"\")\n", "print(f\"--- REFERENCE (Qwen3.5-0.8B) ---\")\n", "print(f\"Participation / dim: 0.535\")\n", "print(f\"Volume CV: 0.208\")\n", "print(f\"Embed/random ratio: 0.984\")\n", "print(f\"Digit |i-j| corr: -0.862\")\n", "print(f\"Mean pairwise cosine: 0.195\")\n" ] }, { "cell_type": "markdown", "id": "791d5d1f", "metadata": {}, "source": [ "## 2. T5-Small × WordNet Relational Alignment\n", "*Section V.2–V.4: relational correlation, distance bands, hypernym decay*" ] }, { "cell_type": "code", "execution_count": null, "id": "fa2515ae", "metadata": {}, "outputs": [], "source": [ "# T5 × WordNet\n", "\n", "# ============================================================================\n", "# T5-SMALL × WORDNET: Relational Geometry via Summarization\n", "# Feed \"summarize: {definition}\" through encoder, compare hidden state\n", "# geometry against WordNet's relational graph structure.\n", "# ============================================================================\n", "\n", "# !pip install nltk -q\n", "import torch\n", "import numpy as np\n", "import math\n", "import time\n", "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", "import matplotlib.pyplot as plt\n", "from scipy.stats import spearmanr\n", "\n", "import nltk\n", "nltk.download('wordnet', quiet=True)\n", "nltk.download('omw-1.4', quiet=True)\n", "from nltk.corpus import wordnet as wn\n", "\n", "model_id = \"google-t5/t5-small\"\n", "print(f\"Loading {model_id}...\")\n", "tokenizer = T5Tokenizer.from_pretrained(model_id, legacy=True)\n", "model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)\n", "model.eval()\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model = model.to(device)\n", "print(f\"Device: {device}\")\n", "\n", "# Build WordNet → T5 token mapping\n", 
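"# (Added sketch, not part of the original run) The matching loop below keeps a\n",
"# lemma only when it round-trips to exactly one sentencepiece id; this helper\n",
"# mirrors that filter and is handy for spot-checking individual words.\n",
"def single_token_id(word):\n",
"    ids = tokenizer.encode(word, add_special_tokens=False)\n",
"    return ids[0] if len(ids) == 1 else None\n",
"\n",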
"print(\"\\nMatching WordNet lemmas to T5 single-tokens...\")\n", "t5_vocab = {tokenizer.decode([i]).strip(): i for i in range(tokenizer.vocab_size)}\n", "\n", "matched = [] # (lemma_name, synset, token_id, definition)\n", "seen_tokens = set()\n", "\n", "for synset in wn.all_synsets():\n", " for lemma in synset.lemmas():\n", " name = lemma.name().replace('_', ' ')\n", " # Try as single T5 token\n", " ids = tokenizer.encode(name, add_special_tokens=False)\n", " if len(ids) == 1 and ids[0] not in seen_tokens:\n", " defn = synset.definition()\n", " if len(defn) > 10: # skip trivially short definitions\n", " matched.append((name, synset, ids[0], defn))\n", " seen_tokens.add(ids[0])\n", "\n", "print(f\"Matched: {len(matched)} unique single-token WordNet entries\")\n", "print(f\"Sample: {[(m[0], m[3][:60]) for m in matched[:5]]}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ENCODING WORDNET DEFINITIONS THROUGH T5 ENCODER\")\n", "print(f\"{'='*70}\")\n", "\n", "BATCH_SIZE = 64\n", "MAX_LEN = 128 # truncate long definitions\n", "\n", "# Prepare \"summarize: {definition}\" inputs\n", "texts = [f\"summarize: {m[3]}\" for m in matched]\n", "token_ids_list = [m[2] for m in matched]\n", "lemma_names = [m[0] for m in matched]\n", "synsets = [m[1] for m in matched]\n", "\n", "# Storage for encoder final hidden states (mean-pooled per definition)\n", "encoder_reps = np.zeros((len(matched), 512), dtype=np.float32)\n", "\n", "t0 = time.time()\n", "n_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE\n", "\n", "for batch_idx in range(n_batches):\n", " start = batch_idx * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(texts))\n", " batch_texts = texts[start:end]\n", "\n", " inputs = tokenizer(\n", " batch_texts,\n", " return_tensors=\"pt\",\n", " padding=True,\n", " truncation=True,\n", " max_length=MAX_LEN,\n", " ).to(device)\n", "\n", " with torch.no_grad():\n", " enc_out = model.encoder(\n", " input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask,\n", " 
)\n", " # Mean pool over non-padding positions\n", " hidden = enc_out.last_hidden_state.float() # [B, seq, 512]\n", " mask = inputs.attention_mask.unsqueeze(-1).float() # [B, seq, 1]\n", " pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1) # [B, 512]\n", " encoder_reps[start:end] = pooled.cpu().numpy()\n", "\n", " if (batch_idx + 1) % 50 == 0 or batch_idx == n_batches - 1:\n", " elapsed = time.time() - t0\n", " print(f\" Batch {batch_idx+1}/{n_batches} ({end}/{len(texts)}) - {elapsed:.1f}s\")\n", "\n", "total_time = time.time() - t0\n", "print(f\"\\nEncoded {len(texts)} definitions in {total_time:.1f}s ({len(texts)/total_time:.0f} defs/s)\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"STATIC EMBEDDINGS FOR MATCHED TOKENS\")\n", "print(f\"{'='*70}\")\n", "\n", "E = model.shared.weight.detach().float().cpu().numpy()\n", "static_reps = E[token_ids_list]\n", "\n", "static_norms = np.linalg.norm(static_reps, axis=1)\n", "encoder_norms = np.linalg.norm(encoder_reps, axis=1)\n", "print(f\"Static embed norms: mean={static_norms.mean():.2f} std={static_norms.std():.2f}\")\n", "print(f\"Encoder rep norms: mean={encoder_norms.mean():.2f} std={encoder_norms.std():.2f}\")\n", "\n", "# Per-token cosine between static and encoder representation\n", "dot = (static_reps * encoder_reps).sum(axis=1)\n", "cos_static_enc = dot / (static_norms * encoder_norms + 1e-8)\n", "print(f\"\\nCosine(static, encoder) per token:\")\n", "print(f\" Mean: {cos_static_enc.mean():.4f} Std: {cos_static_enc.std():.4f}\")\n", "print(f\" This tells us how much the encoder transforms the representation\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"WORDNET GRAPH STRUCTURE\")\n", "print(f\"{'='*70}\")\n", "\n", "# Compute WordNet path similarity for a sample of pairs\n", "N_REL = min(3000, len(matched))\n", "rng = np.random.default_rng(42)\n", "rel_idx = rng.choice(len(matched), size=N_REL, replace=False)\n", "\n", "# Pairwise WordNet path similarity (expensive, so subsample pairs)\n", "N_PAIRS = 
500000\n", "pair_i = rng.choice(N_REL, size=N_PAIRS)\n", "pair_j = rng.choice(N_REL, size=N_PAIRS)\n", "# Remove self-pairs\n", "valid = pair_i != pair_j\n", "pair_i = pair_i[valid]\n", "pair_j = pair_j[valid]\n", "\n", "print(f\"Computing WordNet path similarity for {len(pair_i)} pairs...\")\n", "t0 = time.time()\n", "\n", "wn_sims = []\n", "enc_cosines = []\n", "static_cosines = []\n", "\n", "# Normalize for cosine\n", "enc_sub = encoder_reps[rel_idx]\n", "static_sub = static_reps[rel_idx]\n", "enc_n = enc_sub / (np.linalg.norm(enc_sub, axis=1, keepdims=True) + 1e-8)\n", "static_n = static_sub / (np.linalg.norm(static_sub, axis=1, keepdims=True) + 1e-8)\n", "\n", "batch_count = 0\n", "for pi, pj in zip(pair_i, pair_j):\n", " s1 = synsets[rel_idx[pi]]\n", " s2 = synsets[rel_idx[pj]]\n", " sim = s1.path_similarity(s2)\n", " if sim is not None and sim > 0:\n", " wn_sims.append(sim)\n", " enc_cosines.append(np.dot(enc_n[pi], enc_n[pj]))\n", " static_cosines.append(np.dot(static_n[pi], static_n[pj]))\n", " batch_count += 1\n", " if batch_count % 100000 == 0:\n", " print(f\" {batch_count}/{len(pair_i)} pairs processed...\")\n", "\n", "wn_sims = np.array(wn_sims)\n", "enc_cosines = np.array(enc_cosines)\n", "static_cosines = np.array(static_cosines)\n", "\n", "elapsed = time.time() - t0\n", "print(f\"Valid pairs with path similarity: {len(wn_sims)} ({elapsed:.1f}s)\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"RELATIONAL CORRELATION: WordNet vs T5 Geometry\")\n", "print(f\"{'='*70}\")\n", "\n", "# Pearson\n", "static_pearson = np.corrcoef(wn_sims, static_cosines)[0, 1]\n", "enc_pearson = np.corrcoef(wn_sims, enc_cosines)[0, 1]\n", "\n", "# Spearman\n", "static_spearman, _ = spearmanr(wn_sims, static_cosines)\n", "enc_spearman, _ = spearmanr(wn_sims, enc_cosines)\n", "\n", "print(f\"Static embeddings vs WordNet path similarity:\")\n", "print(f\" Pearson: {static_pearson:.6f}\")\n", "print(f\" Spearman: {static_spearman:.6f}\")\n", "\n", "print(f\"\\nEncoder 
representations vs WordNet path similarity:\")\n", "print(f\" Pearson: {enc_pearson:.6f}\")\n", "print(f\" Spearman: {enc_spearman:.6f}\")\n", "\n", "print(f\"\\nLift from encoder processing:\")\n", "print(f\" Pearson: {enc_pearson - static_pearson:+.6f}\")\n", "print(f\" Spearman: {enc_spearman - static_spearman:+.6f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"COSINE BY WORDNET DISTANCE BAND\")\n", "print(f\"{'='*70}\")\n", "\n", "bands = [(0.9, 1.0), (0.5, 0.9), (0.25, 0.5), (0.1, 0.25), (0.05, 0.1), (0.0, 0.05)]\n", "print(f\"{'WN Band':>12s} {'N':>7s} {'Static Cos':>10s} {'Enc Cos':>10s} {'Gap':>8s}\")\n", "for lo, hi in bands:\n", " mask = (wn_sims >= lo) & (wn_sims < hi) if lo > 0 else (wn_sims > lo) & (wn_sims <= hi)\n", " if mask.sum() < 10:\n", " continue\n", " sc = static_cosines[mask].mean()\n", " ec = enc_cosines[mask].mean()\n", " print(f\" [{lo:.2f},{hi:.2f}) {mask.sum():7d} {sc:10.4f} {ec:10.4f} {ec-sc:+8.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"PENTACHORON GEOMETRY: Static vs Encoder Space\")\n", "print(f\"{'='*70}\")\n", "\n", "def cayley_menger_volume_sq(points):\n", " n = len(points)\n", " D = np.zeros((n + 1, n + 1))\n", " D[0, 1:] = 1\n", " D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq\n", " D[j + 1, i + 1] = d_sq\n", " k = n - 1\n", " sign = (-1) ** (k + 1)\n", " factorial_sq = math.factorial(k) ** 2\n", " denom = (2 ** k) * factorial_sq\n", " det = np.linalg.det(D)\n", " vol_sq = sign * det / denom\n", " return vol_sq\n", "\n", "N_SIMP = 1000\n", "vols_static = []\n", "vols_encoder = []\n", "\n", "for _ in range(N_SIMP):\n", " idx = rng.choice(len(matched), size=5, replace=False)\n", "\n", " pts_s = static_reps[idx]\n", " vs = cayley_menger_volume_sq(pts_s)\n", " if vs > 0:\n", " vols_static.append(np.sqrt(vs))\n", "\n", " pts_e = encoder_reps[idx]\n", " ve = cayley_menger_volume_sq(pts_e)\n", " if ve > 0:\n", " 
vols_encoder.append(np.sqrt(ve))\n", "\n", "vols_static = np.array(vols_static)\n", "vols_encoder = np.array(vols_encoder)\n", "\n", "print(f\"Static: valid={len(vols_static)} mean={vols_static.mean():.4e} CV={vols_static.std()/vols_static.mean():.4f}\")\n", "print(f\"Encoder: valid={len(vols_encoder)} mean={vols_encoder.mean():.4e} CV={vols_encoder.std()/vols_encoder.mean():.4f}\")\n", "\n", "min_len = min(len(vols_static), len(vols_encoder))\n", "vol_corr = np.corrcoef(vols_static[:min_len], vols_encoder[:min_len])[0, 1]\n", "print(f\"Per-simplex volume correlation (static vs encoder): {vol_corr:.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"HYPERNYM CHAIN GEOMETRY\")\n", "print(f\"{'='*70}\")\n", "\n", "# Find tokens that form hypernym chains in WordNet\n", "# e.g., dog → canine → carnivore → mammal → animal → organism\n", "# Measure whether cosine decays with hypernym distance\n", "\n", "# Build lookup: synset -> index in matched\n", "synset_to_idx = {}\n", "for i, (name, syn, tid, defn) in enumerate(matched):\n", " if syn not in synset_to_idx:\n", " synset_to_idx[syn] = i\n", "\n", "# Find hypernym chains of length >= 3\n", "chains = []\n", "for syn, idx in synset_to_idx.items():\n", " chain = [(syn, idx)]\n", " current = syn\n", " for _ in range(10): # max depth\n", " hypernyms = current.hypernyms()\n", " if not hypernyms:\n", " break\n", " parent = hypernyms[0]\n", " if parent in synset_to_idx:\n", " chain.append((parent, synset_to_idx[parent]))\n", " current = parent\n", " if len(chain) >= 3:\n", " chains.append(chain)\n", "\n", "print(f\"Found {len(chains)} chains of length >= 3\")\n", "\n", "if len(chains) > 0:\n", " # Measure cosine decay along chains\n", " max_depth = min(8, max(len(c) for c in chains))\n", " depth_cosines_static = {d: [] for d in range(1, max_depth)}\n", " depth_cosines_enc = {d: [] for d in range(1, max_depth)}\n", "\n", " for chain in chains:\n", " root_idx = chain[0][1]\n", " # Normalize the chain root from the FULL representation matrices\n", " # (static_n/enc_n only cover the rel_idx subsample, so they cannot be\n", " # indexed by root_idx directly)\n", " rs_n = static_reps[root_idx] / (np.linalg.norm(static_reps[root_idx]) + 1e-8)\n", " re_n = encoder_reps[root_idx] / (np.linalg.norm(encoder_reps[root_idx]) + 1e-8)\n", "\n", " for depth in range(1, min(len(chain), max_depth)):\n", " anc_idx = chain[depth][1]\n", " as_n = static_reps[anc_idx] / (np.linalg.norm(static_reps[anc_idx]) + 1e-8)\n", " ae_n = encoder_reps[anc_idx] / (np.linalg.norm(encoder_reps[anc_idx]) + 1e-8)\n", " depth_cosines_static[depth].append(np.dot(rs_n, as_n))\n", " depth_cosines_enc[depth].append(np.dot(re_n, ae_n))\n", "\n", " print(f\"\\n{'Depth':>5s} {'N':>5s} {'Static Cos':>10s} {'Enc Cos':>10s}\")\n", " for d in range(1, max_depth):\n", " if len(depth_cosines_static[d]) > 0:\n", " sc = np.mean(depth_cosines_static[d])\n", " ec = np.mean(depth_cosines_enc[d])\n", " print(f\" {d:3d} {len(depth_cosines_static[d]):5d} {sc:10.4f} {ec:10.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"WORDNET CATEGORY CLUSTERING: Static vs Encoder\")\n", "print(f\"{'='*70}\")\n", "\n", "# Use WordNet's top-level categories (lexicographer names)\n", "from collections import defaultdict\n", "lexname_groups = defaultdict(list)\n", "\n", "for i, (name, syn, tid, defn) in enumerate(matched):\n", " lexname = syn.lexname()\n", " lexname_groups[lexname].append(i)\n", "\n", "print(f\"{'Category':>25s} {'N':>5s} {'Static intra':>12s} {'Enc intra':>12s} {'Lift':>8s}\")\n", "for lexname, indices in sorted(lexname_groups.items(), key=lambda x: -len(x[1])):\n", " if len(indices) < 10:\n", " continue\n", " idx_arr = np.array(indices[:200]) # cap at 200\n", "\n", " s_cat = static_reps[idx_arr]\n", " e_cat = encoder_reps[idx_arr]\n", " s_n = s_cat / (np.linalg.norm(s_cat, axis=1, keepdims=True) + 1e-8)\n", " e_n = e_cat / (np.linalg.norm(e_cat, axis=1, 
keepdims=True) + 1e-8)\n", "\n", " n_c = len(idx_arr)\n", " tri_c = np.triu_indices(n_c, k=1)\n", " intra_s = (s_n @ s_n.T)[tri_c].mean()\n", " intra_e = (e_n @ e_n.T)[tri_c].mean()\n", " print(f\" {lexname:23s} {len(indices):5d} {intra_s:12.4f} {intra_e:12.4f} {intra_e-intra_s:+8.4f}\")\n", "\n", "fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", "fig.suptitle(\"T5-Small × WordNet: Static vs Encoder Geometry\", fontsize=14)\n", "\n", "# 1. WN sim vs encoder cosine scatter\n", "sub = rng.choice(len(wn_sims), size=min(50000, len(wn_sims)), replace=False)\n", "axes[0, 0].scatter(wn_sims[sub], enc_cosines[sub], alpha=0.02, s=1, color='darkgreen')\n", "axes[0, 0].set_xlabel(\"WordNet path similarity\")\n", "axes[0, 0].set_ylabel(\"Encoder cosine\")\n", "axes[0, 0].set_title(f\"Encoder vs WordNet (r={enc_pearson:.4f})\")\n", "\n", "# 2. WN sim vs static cosine scatter\n", "axes[0, 1].scatter(wn_sims[sub], static_cosines[sub], alpha=0.02, s=1, color='steelblue')\n", "axes[0, 1].set_xlabel(\"WordNet path similarity\")\n", "axes[0, 1].set_ylabel(\"Static cosine\")\n", "axes[0, 1].set_title(f\"Static vs WordNet (r={static_pearson:.4f})\")\n", "\n", "# 3. Static vs encoder cosine per token\n", "axes[0, 2].hist(cos_static_enc, bins=200, color='darkorange', alpha=0.8)\n", "axes[0, 2].set_title(f\"Cosine(static, encoder) mean={cos_static_enc.mean():.3f}\")\n", "\n", "# 4. Pentachoron volumes\n", "if len(vols_static) > 0 and len(vols_encoder) > 0:\n", " axes[1, 0].scatter(vols_static[:min_len], vols_encoder[:min_len], alpha=0.3, s=5, color='purple')\n", " axes[1, 0].set_xlabel(\"Static volume\")\n", " axes[1, 0].set_ylabel(\"Encoder volume\")\n", " axes[1, 0].set_title(f\"Pentachoron vols (r={vol_corr:.4f})\")\n", "\n", "# 5. 
Hypernym depth decay\n", "if len(chains) > 0:\n", " depths = []\n", " cos_s_means = []\n", " cos_e_means = []\n", " for d in range(1, max_depth):\n", " if len(depth_cosines_static[d]) > 0:\n", " depths.append(d)\n", " cos_s_means.append(np.mean(depth_cosines_static[d]))\n", " cos_e_means.append(np.mean(depth_cosines_enc[d]))\n", " axes[1, 1].plot(depths, cos_s_means, 'o-', label='Static', color='steelblue')\n", " axes[1, 1].plot(depths, cos_e_means, 's-', label='Encoder', color='crimson')\n", " axes[1, 1].set_xlabel(\"Hypernym depth\")\n", " axes[1, 1].set_ylabel(\"Mean cosine to root\")\n", " axes[1, 1].set_title(\"Hypernym chain cosine decay\")\n", " axes[1, 1].legend()\n", "\n", "# 6. Category intra-cosine comparison\n", "cat_names_plot = []\n", "static_intras = []\n", "enc_intras = []\n", "for lexname, indices in sorted(lexname_groups.items(), key=lambda x: -len(x[1])):\n", " if len(indices) < 10:\n", " continue\n", " idx_arr = np.array(indices[:200])\n", " s_cat = static_reps[idx_arr]\n", " e_cat = encoder_reps[idx_arr]\n", " s_n = s_cat / (np.linalg.norm(s_cat, axis=1, keepdims=True) + 1e-8)\n", " e_n = e_cat / (np.linalg.norm(e_cat, axis=1, keepdims=True) + 1e-8)\n", " n_c = len(idx_arr)\n", " tri_c = np.triu_indices(n_c, k=1)\n", " static_intras.append((s_n @ s_n.T)[tri_c].mean())\n", " enc_intras.append((e_n @ e_n.T)[tri_c].mean())\n", " cat_names_plot.append(lexname.split('.')[-1][:12] if '.' 
in lexname else lexname[:12])\n", "\n", " if len(cat_names_plot) >= 15:\n", " break\n", "\n", "y_pos = np.arange(len(cat_names_plot))\n", "axes[1, 2].barh(y_pos - 0.15, static_intras, 0.3, label='Static', color='steelblue', alpha=0.7)\n", "axes[1, 2].barh(y_pos + 0.15, enc_intras, 0.3, label='Encoder', color='crimson', alpha=0.7)\n", "axes[1, 2].set_yticks(y_pos)\n", "axes[1, 2].set_yticklabels(cat_names_plot)\n", "axes[1, 2].set_title(\"Intra-category cosine by WordNet lexname\")\n", "axes[1, 2].legend()\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"/content/t5_wordnet_geometry.png\", dpi=150, bbox_inches='tight')\n", "plt.show()\n", "print(\"\\nSaved: /content/t5_wordnet_geometry.png\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"T5-SMALL × WORDNET — SUMMARY\")\n", "print(f\"{'='*70}\")\n", "print(f\"WordNet entries matched: {len(matched)}\")\n", "print(f\"Encoding time: {total_time:.1f}s\")\n", "print(f\"Throughput: {len(texts)/total_time:.0f} defs/s\")\n", "print(f\"\")\n", "print(f\"--- STATIC vs ENCODER TRANSFORMATION ---\")\n", "print(f\"Cosine(static, encoder): {cos_static_enc.mean():.4f} ± {cos_static_enc.std():.4f}\")\n", "print(f\"\")\n", "print(f\"--- WORDNET RELATIONAL ALIGNMENT ---\")\n", "print(f\"Static Pearson: {static_pearson:.4f}\")\n", "print(f\"Static Spearman: {static_spearman:.4f}\")\n", "print(f\"Encoder Pearson: {enc_pearson:.4f}\")\n", "print(f\"Encoder Spearman: {enc_spearman:.4f}\")\n", "print(f\"Lift (Pearson): {enc_pearson - static_pearson:+.4f}\")\n", "print(f\"Lift (Spearman): {enc_spearman - static_spearman:+.4f}\")\n", "print(f\"\")\n", "print(f\"--- PENTACHORON GEOMETRY ---\")\n", "print(f\"Static CV: {vols_static.std()/vols_static.mean():.4f}\")\n", "print(f\"Encoder CV: {vols_encoder.std()/vols_encoder.mean():.4f}\")\n", "print(f\"Per-simplex correlation: {vol_corr:.4f}\")\n" ] }, { "cell_type": "markdown", "id": "9735f304", "metadata": {}, "source": [ "## 3. 
50-Seed Stability Test (GPU)\n", "*Confidence intervals for Section V findings*" ] }, { "cell_type": "code", "execution_count": null, "id": "d5570695", "metadata": {}, "outputs": [], "source": [ "# 50-seed stability\n", "\n", "# ============================================================================\n", "# T5-SMALL × WORDNET: 50-Seed Stability Test\n", "# Run AFTER the main WordNet probe (reuses encoder_reps, static_reps, etc.)\n", "# ============================================================================\n", "\n", "import numpy as np\n", "import time\n", "from scipy.stats import spearmanr\n", "from tqdm import tqdm\n", "\n", "import torch\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Compute device: {device}\")\n", "\n", "N_SEEDS = 50\n", "N_REL_SAMPLE = 2000\n", "N_PAIRS_PER = 200000\n", "N_SIMP_PER = 500\n", "\n", "# Precompute GPU tensors — all reps normalized on GPU once\n", "enc_reps_t = torch.tensor(encoder_reps, device=device, dtype=torch.float32)\n", "static_reps_t = torch.tensor(static_reps, device=device, dtype=torch.float32)\n", "enc_normed_t = enc_reps_t / (enc_reps_t.norm(dim=1, keepdim=True) + 1e-8)\n", "static_normed_t = static_reps_t / (static_reps_t.norm(dim=1, keepdim=True) + 1e-8)\n", "\n", "# Precompute hypernym chain index tensors for fast GPU gather\n", "chain_roots = []\n", "chain_depths = {1: [], 3: [], 5: []}\n", "for chain in chains:\n", " root_idx = chain[0][1]\n", " chain_roots.append(root_idx)\n", " for d in [1, 3, 5]:\n", " if d < len(chain):\n", " chain_depths[d].append((len(chain_roots) - 1, chain[d][1]))\n", "\n", "# Batched Cayley-Menger on GPU\n", "def cayley_menger_batch_gpu(points_batch):\n", " \"\"\"points_batch: [B, 5, D] -> volumes [B]\"\"\"\n", " B = points_batch.shape[0]\n", " # Pairwise squared distances [B, 5, 5]\n", " diff = points_batch.unsqueeze(2) - points_batch.unsqueeze(1) # [B,5,5,D]\n", " D_sq = (diff ** 2).sum(dim=-1) # [B, 5, 5]\n", "\n", " # Build 
Cayley-Menger matrix [B, 6, 6]\n", " CM = torch.zeros(B, 6, 6, device=points_batch.device, dtype=points_batch.dtype)\n", " CM[:, 0, 1:] = 1.0\n", " CM[:, 1:, 0] = 1.0\n", " CM[:, 1:, 1:] = D_sq\n", "\n", " det = torch.linalg.det(CM) # [B]\n", " # k=4 (4-simplex): sign = (-1)^5 = -1, denom = 2^4 * (4!)^2 = 16 * 576 = 9216\n", " vol_sq = -det / 9216.0\n", " return vol_sq\n", "\n", "# Storage\n", "results = {\n", " 'enc_pearson': [], 'enc_spearman': [],\n", " 'static_pearson': [], 'static_spearman': [],\n", " 'enc_cv': [], 'static_cv': [],\n", " 'enc_vol_mean': [], 'static_vol_mean': [],\n", " 'hyp_depth1_enc': [], 'hyp_depth3_enc': [], 'hyp_depth5_enc': [],\n", " 'hyp_depth1_static': [], 'hyp_depth3_static': [], 'hyp_depth5_static': [],\n", " 'band_high_enc': [], 'band_low_enc': [],\n", " 'band_high_static': [], 'band_low_static': [],\n", "}\n", "\n", "N_TOTAL = len(matched)\n", "print(f\"Running {N_SEEDS} seeds across {N_TOTAL} WordNet entries...\")\n", "print(f\" Per seed: {N_REL_SAMPLE} rel tokens, {N_PAIRS_PER} WN pairs, {N_SIMP_PER} simplices\")\n", "\n", "t0 = time.time()\n", "\n", "for seed in range(N_SEEDS):\n", " rng = np.random.default_rng(seed)\n", "\n", " # Sample tokens\n", " rel_idx = rng.choice(N_TOTAL, size=min(N_REL_SAMPLE, N_TOTAL), replace=False)\n", " rel_idx_t = torch.tensor(rel_idx, device=device, dtype=torch.long)\n", "\n", " # GPU normalized subsets\n", " enc_n = enc_normed_t[rel_idx_t] # [N_REL, 512]\n", " static_n = static_normed_t[rel_idx_t] # [N_REL, 512]\n", "\n", " # Random pairs — WN similarity is CPU-bound, but cosines are precomputed on GPU\n", " pi = rng.choice(len(rel_idx), size=N_PAIRS_PER)\n", " pj = rng.choice(len(rel_idx), size=N_PAIRS_PER)\n", " valid = pi != pj\n", " pi, pj = pi[valid], pj[valid]\n", "\n", " # Batch cosines on GPU for ALL pairs at once\n", " pi_t = torch.tensor(pi, device=device, dtype=torch.long)\n", " pj_t = torch.tensor(pj, device=device, dtype=torch.long)\n", " enc_cos_all = (enc_n[pi_t] * 
enc_n[pj_t]).sum(dim=1).cpu().numpy()\n", " static_cos_all = (static_n[pi_t] * static_n[pj_t]).sum(dim=1).cpu().numpy()\n", "\n", " # WN similarity — CPU bound, unavoidable\n", " wn_s = np.empty(len(pi), dtype=np.float32)\n", " wn_valid = np.zeros(len(pi), dtype=bool)\n", " for k in tqdm(range(len(pi)), desc=f\"Seed {seed+1} WN pairs\", leave=False, miniters=10000):\n", " sim = synsets[rel_idx[pi[k]]].path_similarity(synsets[rel_idx[pj[k]]])\n", " if sim is not None and sim > 0:\n", " wn_s[k] = sim\n", " wn_valid[k] = True\n", "\n", " wn_s = wn_s[wn_valid]\n", " enc_c = enc_cos_all[wn_valid]\n", " static_c = static_cos_all[wn_valid]\n", "\n", " # Relational correlations\n", " results['enc_pearson'].append(np.corrcoef(wn_s, enc_c)[0, 1])\n", " results['static_pearson'].append(np.corrcoef(wn_s, static_c)[0, 1])\n", " sp_enc, _ = spearmanr(wn_s, enc_c)\n", " sp_static, _ = spearmanr(wn_s, static_c)\n", " results['enc_spearman'].append(sp_enc)\n", " results['static_spearman'].append(sp_static)\n", "\n", " # Distance bands\n", " high_mask = wn_s >= 0.25\n", " low_mask = wn_s < 0.1\n", " if high_mask.sum() > 0:\n", " results['band_high_enc'].append(enc_c[high_mask].mean())\n", " results['band_high_static'].append(static_c[high_mask].mean())\n", " if low_mask.sum() > 0:\n", " results['band_low_enc'].append(enc_c[low_mask].mean())\n", " results['band_low_static'].append(static_c[low_mask].mean())\n", "\n", " # Pentachoron geometry — batched on GPU\n", " simp_idx = np.stack([rng.choice(N_TOTAL, size=5, replace=False) for _ in range(N_SIMP_PER)])\n", " simp_idx_t = torch.tensor(simp_idx, device=device, dtype=torch.long)\n", "\n", " static_pts = static_reps_t[simp_idx_t] # [N_SIMP, 5, 512]\n", " enc_pts = enc_reps_t[simp_idx_t] # [N_SIMP, 5, 512]\n", "\n", " vol_sq_s = cayley_menger_batch_gpu(static_pts)\n", " vol_sq_e = cayley_menger_batch_gpu(enc_pts)\n", "\n", " valid_s = vol_sq_s > 0\n", " valid_e = vol_sq_e > 0\n", "\n", " if valid_s.sum() > 0:\n", " vs = 
torch.sqrt(vol_sq_s[valid_s]).cpu().numpy()\n", " results['static_cv'].append(vs.std() / vs.mean())\n", " results['static_vol_mean'].append(vs.mean())\n", " if valid_e.sum() > 0:\n", " ve = torch.sqrt(vol_sq_e[valid_e]).cpu().numpy()\n", " results['enc_cv'].append(ve.std() / ve.mean())\n", " results['enc_vol_mean'].append(ve.mean())\n", "\n", " # Hypernym chains — GPU cosine, CPU chain lookup\n", " chain_sub = rng.choice(len(chains), size=min(1000, len(chains)), replace=False)\n", "\n", " for d in [1, 3, 5]:\n", " root_indices = []\n", " anc_indices = []\n", " for ci in chain_sub:\n", " chain = chains[ci]\n", " if d < len(chain):\n", " root_indices.append(chain[0][1])\n", " anc_indices.append(chain[d][1])\n", "\n", " if len(root_indices) > 0:\n", " root_t = torch.tensor(root_indices, device=device, dtype=torch.long)\n", " anc_t = torch.tensor(anc_indices, device=device, dtype=torch.long)\n", "\n", " enc_cos_hyp = (enc_normed_t[root_t] * enc_normed_t[anc_t]).sum(dim=1).mean().item()\n", " static_cos_hyp = (static_normed_t[root_t] * static_normed_t[anc_t]).sum(dim=1).mean().item()\n", "\n", " results[f'hyp_depth{d}_enc'].append(enc_cos_hyp)\n", " results[f'hyp_depth{d}_static'].append(static_cos_hyp)\n", "\n", " if (seed + 1) % 10 == 0:\n", " elapsed = time.time() - t0\n", " print(f\" Seed {seed+1}/{N_SEEDS} - {elapsed:.1f}s\")\n", "\n", "total_time = time.time() - t0\n", "print(f\"\\nCompleted {N_SEEDS} seeds in {total_time:.1f}s ({total_time/N_SEEDS:.1f}s/seed)\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(f\"50-SEED STABILITY REPORT\")\n", "print(f\"{'='*70}\")\n", "\n", "def report(name, vals):\n", " v = np.array(vals)\n", " print(f\" {name:35s} {v.mean():.6f} ± {v.std():.6f} [{v.min():.6f}, {v.max():.6f}]\")\n", "\n", "print(f\"\\n--- RELATIONAL CORRELATION (WordNet vs T5) ---\")\n", "report(\"Encoder Pearson\", results['enc_pearson'])\n", "report(\"Encoder Spearman\", results['enc_spearman'])\n", "report(\"Static Pearson\", results['static_pearson'])\n", 
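"\n", "# (Illustrative addition, not part of the original run: ci95 is an assumed\n", "# helper.) The section header promises confidence intervals, but report() only\n", "# prints mean ± std; a normal-approximation 95% CI over seeds looks like this:\n", "def ci95(vals):\n", " v = np.asarray(vals, dtype=float)\n", " half = 1.96 * v.std(ddof=1) / np.sqrt(len(v))\n", " return v.mean() - half, v.mean() + half\n", "lo_p, hi_p = ci95(results['enc_pearson'])\n", "print(f\" Encoder Pearson 95% CI: [{lo_p:.6f}, {hi_p:.6f}]\")\n", 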
"report(\"Static Spearman\", results['static_spearman'])\n", "\n", "print(f\"\\n--- DISTANCE BANDS ---\")\n", "report(\"High WN sim (≥0.25) Enc cos\", results['band_high_enc'])\n", "report(\"High WN sim (≥0.25) Static cos\", results['band_high_static'])\n", "report(\"Low WN sim (<0.10) Enc cos\", results['band_low_enc'])\n", "report(\"Low WN sim (<0.10) Static cos\", results['band_low_static'])\n", "if len(results['band_high_enc']) > 0 and len(results['band_low_enc']) > 0:\n", " gradient = np.array(results['band_high_enc']) - np.array(results['band_low_enc'])\n", " report(\"Enc gradient (high - low)\", gradient.tolist())\n", " gradient_s = np.array(results['band_high_static']) - np.array(results['band_low_static'])\n", " report(\"Static gradient (high - low)\", gradient_s.tolist())\n", "\n", "print(f\"\\n--- PENTACHORON GEOMETRY ---\")\n", "report(\"Encoder CV\", results['enc_cv'])\n", "report(\"Static CV\", results['static_cv'])\n", "\n", "print(f\"\\n--- HYPERNYM CHAIN DECAY ---\")\n", "report(\"Depth 1 Encoder cos\", results['hyp_depth1_enc'])\n", "report(\"Depth 3 Encoder cos\", results['hyp_depth3_enc'])\n", "report(\"Depth 5 Encoder cos\", results['hyp_depth5_enc'])\n", "report(\"Depth 1 Static cos\", results['hyp_depth1_static'])\n", "report(\"Depth 3 Static cos\", results['hyp_depth3_static'])\n", "report(\"Depth 5 Static cos\", results['hyp_depth5_static'])\n", "\n", "if len(results['hyp_depth1_enc']) > 0 and len(results['hyp_depth5_enc']) > 0:\n", " # Depth-5 values exist only for seeds whose chain sample reached depth 5,\n", " # so these lists can be SHORTER than the depth-1 lists; truncate both to\n", " # the common length before subtracting\n", " m = min(len(results['hyp_depth1_enc']), len(results['hyp_depth5_enc']))\n", " decay_enc = np.array(results['hyp_depth1_enc'][:m]) - np.array(results['hyp_depth5_enc'][:m])\n", " m_s = min(len(results['hyp_depth1_static']), len(results['hyp_depth5_static']))\n", " decay_static = np.array(results['hyp_depth1_static'][:m_s]) - np.array(results['hyp_depth5_static'][:m_s])\n", " report(\"Hypernym decay 1→5 Encoder\", decay_enc.tolist())\n", " report(\"Hypernym decay 1→5 Static\", decay_static.tolist())\n", "\n", "print(f\"\\n--- INVARIANT CHECK ---\")\n", "enc_cv_arr = np.array(results['enc_cv'])\n", "static_cv_arr = 
np.array(results['static_cv'])\n", "print(f\" Encoder CV coefficient of variation: {enc_cv_arr.std()/enc_cv_arr.mean()*100:.2f}%\")\n", "print(f\" Static CV coefficient of variation: {static_cv_arr.std()/static_cv_arr.mean()*100:.2f}%\")\n", "enc_p = np.array(results['enc_pearson'])\n", "print(f\" Enc Pearson CV: {enc_p.std()/enc_p.mean()*100:.2f}%\")\n" ] }, { "cell_type": "markdown", "id": "f6100d9e", "metadata": {}, "source": [ "## 4. T5-Small/Base Inactive Weight Topology\n", "*Section VI (T5 entries): SVD, sparsity, QK manifold, dead neurons*" ] }, { "cell_type": "code", "execution_count": null, "id": "520e0b9e", "metadata": {}, "outputs": [], "source": [ "# T5 inactive weights\n", "\n", "# ============================================================================\n", "# T5-SMALL: INACTIVE WEIGHT GEOMETRY\n", "# Analyze the raw weight matrices as geometric objects.\n", "# No inference. No data. Just the learned topology.\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from collections import defaultdict\n", "\n", "model_id = \"google-t5/t5-small\"\n", "from transformers import T5ForConditionalGeneration\n", "model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)\n", "model.eval()\n", "\n", "# Catalog all weight matrices by type\n", "weight_catalog = defaultdict(list)\n", "for name, param in model.named_parameters():\n", " p = param.detach().float()\n", " parts = name.split('.')\n", "\n", " # Classify\n", " if 'embed' in name or 'shared' in name:\n", " wtype = 'embedding'\n", " elif 'relative_attention_bias' in name:\n", " wtype = 'position_bias'\n", " elif 'layer_norm' in name or 'final_layer_norm' in name:\n", " wtype = 'layernorm'\n", " elif 'SelfAttention' in name:\n", " # e.g. 
encoder.block.0.layer.0.SelfAttention.q.weight\n", " for subpart in parts:\n", " if subpart in ('q', 'k', 'v', 'o'):\n", " wtype = f'self_attn_{subpart}'\n", " break\n", " else:\n", " wtype = 'self_attn_other'\n", " elif 'EncDecAttention' in name:\n", " for subpart in parts:\n", " if subpart in ('q', 'k', 'v', 'o'):\n", " wtype = f'cross_attn_{subpart}'\n", " break\n", " else:\n", " wtype = 'cross_attn_other'\n", " elif 'DenseReluDense' in name:\n", " for subpart in parts:\n", " if subpart in ('wi', 'wo', 'wi_0', 'wi_1'):\n", " wtype = f'mlp_{subpart}'\n", " break\n", " else:\n", " wtype = 'mlp_other'\n", " else:\n", " wtype = 'other'\n", "\n", " # Determine if encoder or decoder\n", " if 'encoder' in name:\n", " location = 'encoder'\n", " elif 'decoder' in name:\n", " location = 'decoder'\n", " else:\n", " location = 'shared'\n", "\n", " # Layer number\n", " layer_num = -1\n", " for i, part in enumerate(parts):\n", " if part == 'block' and i + 1 < len(parts):\n", " try:\n", " layer_num = int(parts[i + 1])\n", " except:\n", " pass\n", "\n", " weight_catalog[wtype].append({\n", " 'name': name,\n", " 'shape': tuple(p.shape),\n", " 'param': p,\n", " 'location': location,\n", " 'layer': layer_num,\n", " 'numel': p.numel(),\n", " })\n", "\n", "print(f\"{'='*70}\")\n", "print(\"WEIGHT CATALOG\")\n", "print(f\"{'='*70}\")\n", "for wtype, entries in sorted(weight_catalog.items()):\n", " total = sum(e['numel'] for e in entries)\n", " shapes = set(str(e['shape']) for e in entries)\n", " print(f\" {wtype:25s}: {len(entries):3d} matrices, {total:>12,} params, shapes={shapes}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"SINGULAR VALUE DECOMPOSITION — EFFECTIVE RANK\")\n", "print(f\"{'='*70}\")\n", "\n", "svd_results = []\n", "\n", "for wtype, entries in weight_catalog.items():\n", " if wtype in ['layernorm', 'position_bias', 'embedding']:\n", " continue # Skip 1D params and embeddings (already analyzed)\n", "\n", " for entry in entries:\n", " W = entry['param']\n", " if 
W.dim() != 2:\n", " continue\n", "\n", " U, S, Vh = torch.linalg.svd(W, full_matrices=False)\n", " S_np = S.numpy()\n", "\n", " # Effective rank metrics\n", " total_s = S_np.sum()\n", " cumsum = np.cumsum(S_np) / total_s\n", "\n", " # Stable rank: ||W||_F^2 / ||W||_2^2\n", " stable_rank = (S_np ** 2).sum() / (S_np[0] ** 2) if S_np[0] > 0 else 0\n", "\n", " # Participation ratio of singular values\n", " pr = (S_np.sum()) ** 2 / ((S_np ** 2).sum()) if (S_np ** 2).sum() > 0 else 0\n", "\n", " # Fraction of singular values > 1% of max\n", " active_frac = (S_np > 0.01 * S_np[0]).sum() / len(S_np)\n", "\n", " # 90% energy rank\n", " rank_90 = np.searchsorted(cumsum, 0.90) + 1\n", "\n", " svd_results.append({\n", " 'name': entry['name'],\n", " 'wtype': wtype,\n", " 'location': entry['location'],\n", " 'layer': entry['layer'],\n", " 'shape': entry['shape'],\n", " 'stable_rank': stable_rank,\n", " 'pr': pr,\n", " 'active_frac': active_frac,\n", " 'rank_90': rank_90,\n", " 'max_sv': S_np[0],\n", " 'min_sv': S_np[-1],\n", " 'condition': S_np[0] / (S_np[-1] + 1e-10),\n", " 'singular_values': S_np,\n", " })\n", "\n", "# Print summary by type\n", "print(f\"\\n{'Type':25s} {'StableRank':>10s} {'PR':>8s} {'Active%':>8s} {'Rank90':>7s} {'Condition':>10s}\")\n", "for wtype in sorted(set(r['wtype'] for r in svd_results)):\n", " subset = [r for r in svd_results if r['wtype'] == wtype]\n", " sr = np.mean([r['stable_rank'] for r in subset])\n", " pr = np.mean([r['pr'] for r in subset])\n", " af = np.mean([r['active_frac'] for r in subset])\n", " r90 = np.mean([r['rank_90'] for r in subset])\n", " cond = np.mean([r['condition'] for r in subset])\n", " print(f\" {wtype:23s} {sr:10.2f} {pr:8.2f} {af:8.3f} {r90:7.1f} {cond:10.1f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"SPARSITY TOPOLOGY — THE NEGATIVE SPACE\")\n", "print(f\"{'='*70}\")\n", "\n", "thresholds = [1e-6, 1e-4, 1e-3, 1e-2, 1e-1]\n", "\n", "print(f\"\\n{'Type':25s}\", end=\"\")\n", "for t in thresholds:\n", " print(f\" 
{'<'+str(t):>8s}\", end=\"\")\n", "print()\n", "\n", "for wtype in sorted(set(r['wtype'] for r in svd_results)):\n", " entries = [e for e in weight_catalog.get(wtype, []) if e['param'].dim() == 2]\n", " if not entries:\n", " continue\n", "\n", " all_vals = torch.cat([e['param'].abs().flatten() for e in entries])\n", " total = len(all_vals)\n", "\n", " print(f\" {wtype:23s}\", end=\"\")\n", " for t in thresholds:\n", " frac = (all_vals < t).sum().item() / total\n", " print(f\" {frac:8.4f}\", end=\"\")\n", " print()\n", "\n", "# Overall sparsity\n", "all_params = torch.cat([p.flatten() for p in model.parameters()])\n", "print(f\"\\n {'FULL MODEL':23s}\", end=\"\")\n", "for t in thresholds:\n", " frac = (all_params.abs() < t).sum().item() / len(all_params)\n", " print(f\" {frac:8.4f}\", end=\"\")\n", "print()\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"QK SIMILARITY MANIFOLD — THE LEARNED SIMILARITY FUNCTION\")\n", "print(f\"{'='*70}\")\n", "\n", "# For each self-attention layer, compute the bilinear form W_q^T W_k\n", "# This matrix defines the model's notion of similarity in d_model space\n", "for location in ['encoder', 'decoder']:\n", " print(f\"\\n--- {location.upper()} ---\")\n", "\n", " q_weights = sorted([r for r in svd_results if r['wtype'] == 'self_attn_q' and r['location'] == location],\n", " key=lambda x: x['layer'])\n", " k_weights = sorted([r for r in svd_results if r['wtype'] == 'self_attn_k' and r['location'] == location],\n", " key=lambda x: x['layer'])\n", "\n", " for q_entry, k_entry in zip(q_weights, k_weights):\n", " layer = q_entry['layer']\n", " W_q = [e for e in weight_catalog['self_attn_q'] if e['name'] == q_entry['name']][0]['param']\n", " W_k = [e for e in weight_catalog['self_attn_k'] if e['name'] == k_entry['name']][0]['param']\n", "\n", " # HF T5 Linear weights are [inner_dim, d_model], and attention logits are\n", " # x @ W_q.T @ W_k @ y.T, so the similarity form is W_q.T @ W_k\n", " # (not W_q @ W_k.T; they only coincide in shape because inner_dim = d_model here)\n", " QK = W_q.T @ W_k # [d_model, inner] @ [inner, d_model] -> [d_model, d_model]\n", "\n", " # SVD of the similarity matrix\n", " S_qk = 
torch.linalg.svdvals(QK).numpy()\n", " stable_rank_qk = (S_qk ** 2).sum() / (S_qk[0] ** 2) if S_qk[0] > 0 else 0\n", " pr_qk = (S_qk.sum()) ** 2 / ((S_qk ** 2).sum()) if (S_qk ** 2).sum() > 0 else 0\n", "\n", " # Symmetry: is QK^T symmetric? (would mean Q≈K)\n", " sym_diff = torch.norm(QK - QK.T).item() / torch.norm(QK).item()\n", "\n", " # Eigendecomposition of symmetric part\n", " QK_sym = (QK + QK.T) / 2\n", " eigvals = torch.linalg.eigvalsh(QK_sym).numpy()[::-1]\n", " n_positive = (eigvals > 0).sum()\n", " n_negative = (eigvals < 0).sum()\n", "\n", " print(f\" Layer {layer}: QK shape={tuple(QK.shape)}, stable_rank={stable_rank_qk:.2f}, PR={pr_qk:.2f}\")\n", " print(f\" symmetry_deviation={sym_diff:.4f}, positive_eig={n_positive}, negative_eig={n_negative}\")\n", " print(f\" top5_eig: {eigvals[:5]}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"MLP DEAD NEURON ANALYSIS\")\n", "print(f\"{'='*70}\")\n", "\n", "# T5 MLP: DenseReluDense with wi (up-project) and wo (down-project)\n", "# Dead neurons: columns of wi that are near-zero, or rows of wo that are near-zero\n", "# These are directions in intermediate space that the model learned to suppress\n", "\n", "for location in ['encoder', 'decoder']:\n", " print(f\"\\n--- {location.upper()} ---\")\n", "\n", " wi_entries = sorted([e for e in weight_catalog.get('mlp_wi', []) if e['location'] == location],\n", " key=lambda x: x['layer'])\n", " wo_entries = sorted([e for e in weight_catalog.get('mlp_wo', []) if e['location'] == location],\n", " key=lambda x: x['layer'])\n", "\n", " for wi_entry, wo_entry in zip(wi_entries, wo_entries):\n", " layer = wi_entry['layer']\n", " W_up = wi_entry['param'] # [d_ff, d_model] — columns are input features\n", " W_down = wo_entry['param'] # [d_model, d_ff] — rows are output features\n", "\n", " # Per-neuron norms in intermediate space\n", " up_norms = torch.norm(W_up, dim=1) # [d_ff] — norm of each neuron's input weights\n", " down_norms = torch.norm(W_down, dim=0) # [d_ff] — norm 
of each neuron's output weights\n", "\n", " # Combined importance: a neuron is dead if EITHER its input or output is near-zero\n", " combined = up_norms * down_norms\n", "\n", " # Thresholds\n", " d_ff = W_up.shape[0]\n", " dead_01 = (combined < 0.01 * combined.mean()).sum().item()\n", " dead_10 = (combined < 0.10 * combined.mean()).sum().item()\n", "\n", " # Distribution\n", " print(f\" Layer {layer}: d_ff={d_ff}\")\n", " print(f\" Up norms: mean={up_norms.mean():.4f} std={up_norms.std():.4f} min={up_norms.min():.4f} max={up_norms.max():.4f}\")\n", " print(f\" Down norms: mean={down_norms.mean():.4f} std={down_norms.std():.4f} min={down_norms.min():.4f} max={down_norms.max():.4f}\")\n", " print(f\" Combined: dead(<1% mean)={dead_01}/{d_ff} ({dead_01/d_ff*100:.1f}%), weak(<10% mean)={dead_10}/{d_ff} ({dead_10/d_ff*100:.1f}%)\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"CROSS-LAYER WEIGHT CORRELATION\")\n", "print(f\"{'='*70}\")\n", "\n", "# For each weight type, compute cosine similarity between layers\n", "for wtype in ['self_attn_q', 'self_attn_k', 'self_attn_v', 'mlp_wi', 'mlp_wo']:\n", " for location in ['encoder', 'decoder']:\n", " entries = sorted([e for e in weight_catalog.get(wtype, []) if e['location'] == location and e['param'].dim() == 2],\n", " key=lambda x: x['layer'])\n", "\n", " if len(entries) < 2:\n", " continue\n", "\n", " n_layers = len(entries)\n", " cross_cos = np.zeros((n_layers, n_layers))\n", "\n", " for i in range(n_layers):\n", " for j in range(n_layers):\n", " Wi = entries[i]['param'].flatten()\n", " Wj = entries[j]['param'].flatten()\n", " cos = torch.dot(Wi, Wj) / (torch.norm(Wi) * torch.norm(Wj) + 1e-8)\n", " cross_cos[i, j] = cos.item()\n", "\n", " # Print compact\n", " print(f\"\\n {location}/{wtype}:\")\n", " print(f\" \", end=\"\")\n", " for j in range(n_layers):\n", " print(f\" L{entries[j]['layer']}\", end=\"\")\n", " print()\n", " for i in range(n_layers):\n", " print(f\" L{entries[i]['layer']}\", end=\"\")\n", " for j in 
range(n_layers):\n", " if j < i:\n", " print(f\" \", end=\"\")\n", " elif j == i:\n", " print(f\" -- \", end=\"\")\n", " else:\n", " print(f\" {cross_cos[i,j]:.3f}\", end=\"\")\n", " print()\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"LAYER NORM LEARNED SCALES — DIMENSION IMPORTANCE\")\n", "print(f\"{'='*70}\")\n", "\n", "ln_entries = sorted(weight_catalog.get('layernorm', []), key=lambda x: (x['location'], x['layer'], x['name']))\n", "\n", "for entry in ln_entries:\n", " if 'weight' not in entry['name']:\n", " continue\n", " w = entry['param'].numpy()\n", " high = (w > 1.5).sum()\n", " low = (w < 0.5).sum()\n", " near_one = ((w > 0.8) & (w < 1.2)).sum()\n", " print(f\" {entry['name']:60s} mean={w.mean():.4f} std={w.std():.4f} high={high} low={low} near_one={near_one}/{len(w)}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"RELATIVE POSITION BIAS — PER-HEAD TOPOLOGY\")\n", "print(f\"{'='*70}\")\n", "\n", "for entry in weight_catalog.get('position_bias', []):\n", " rpb = entry['param'].numpy() # [n_buckets, n_heads]\n", " n_buckets, n_heads = rpb.shape\n", "\n", " print(f\"\\n {entry['name']}: [{n_buckets} buckets, {n_heads} heads]\")\n", "\n", " # Per-head: is the bias monotonic? 
(nearby tokens more attended)\n", " for h in range(n_heads):\n", " bias = rpb[:, h]\n", " # Check monotonicity of first half (nearby distances)\n", " half = n_buckets // 2\n", " nearby = bias[:half]\n", " diffs = np.diff(nearby)\n", " monotonic_frac = (diffs <= 0).sum() / len(diffs) # fraction monotonically decreasing\n", "\n", " # Correlation with bucket index (linear distance relationship)\n", " dist_corr = np.corrcoef(np.arange(n_buckets), bias)[0, 1]\n", "\n", " print(f\" Head {h}: range=[{bias.min():.3f}, {bias.max():.3f}], \"\n", " f\"nearby_monotonic={monotonic_frac:.2f}, dist_corr={dist_corr:.3f}, \"\n", " f\"peak_bucket={np.argmax(bias)}\")\n", "\n", "fig, axes = plt.subplots(3, 3, figsize=(18, 15))\n", "fig.suptitle(\"T5-Small Inactive Weight Geometry\", fontsize=14)\n", "\n", "# 1. Singular value spectra by weight type\n", "for wtype in ['self_attn_q', 'mlp_wi', 'mlp_wo']:\n", " subset = [r for r in svd_results if r['wtype'] == wtype and r['location'] == 'encoder']\n", " for i, r in enumerate(subset):\n", " s = r['singular_values']\n", " axes[0, 0].semilogy(s / s[0], alpha=0.3, label=f\"{wtype}\" if i == 0 else None)\n", "axes[0, 0].set_title(\"Normalized singular value spectra (encoder)\")\n", "axes[0, 0].set_xlabel(\"Index\")\n", "axes[0, 0].legend(fontsize=8)\n", "\n", "# 2. Stable rank by layer\n", "for wtype in ['self_attn_q', 'self_attn_k', 'mlp_wi']:\n", " for loc in ['encoder']:\n", " subset = sorted([r for r in svd_results if r['wtype'] == wtype and r['location'] == loc],\n", " key=lambda x: x['layer'])\n", " if subset:\n", " layers = [r['layer'] for r in subset]\n", " ranks = [r['stable_rank'] for r in subset]\n", " axes[0, 1].plot(layers, ranks, 'o-', label=f\"{wtype}\")\n", "axes[0, 1].set_title(\"Stable rank by layer\")\n", "axes[0, 1].set_xlabel(\"Layer\")\n", "if any(r['wtype'] in ['self_attn_q', 'self_attn_k', 'mlp_wi'] for r in svd_results):\n", " axes[0, 1].legend(fontsize=8)\n", "\n", "# 3. 
Sparsity histogram (all weights)\n", "all_abs = torch.cat([p.detach().abs().flatten() for p in model.parameters()]).numpy()\n", "axes[0, 2].hist(np.log10(all_abs + 1e-10), bins=500, color='steelblue', alpha=0.8)\n", "axes[0, 2].set_title(\"log10(|weight|) distribution (all params)\")\n", "axes[0, 2].axvline(np.log10(0.01), color='red', ls='--', alpha=0.5, label='0.01')\n", "axes[0, 2].legend()\n", "\n", "# 4. QK eigenvalue spectra\n", "for location in ['encoder']:\n", " q_ws = sorted([e for e in weight_catalog['self_attn_q'] if e['location'] == location], key=lambda x: x['layer'])\n", " k_ws = sorted([e for e in weight_catalog['self_attn_k'] if e['location'] == location], key=lambda x: x['layer'])\n", " for q_e, k_e in zip(q_ws, k_ws):\n", " QK = q_e['param'] @ k_e['param'].T\n", " QK_sym = (QK + QK.T) / 2\n", " eigvals = torch.linalg.eigvalsh(QK_sym).numpy()[::-1]\n", " axes[1, 0].plot(eigvals, alpha=0.5, label=f\"L{q_e['layer']}\")\n", "axes[1, 0].set_title(\"QK^T eigenvalues (encoder)\")\n", "axes[1, 0].axhline(0, color='black', ls='-', alpha=0.3)\n", "if q_ws:\n", " axes[1, 0].legend(fontsize=8)\n", "\n", "# 5. MLP neuron importance distribution\n", "all_combined = []\n", "for wi_e in weight_catalog.get('mlp_wi', []):\n", " if wi_e['location'] != 'encoder':\n", " continue\n", " wo_e = [e for e in weight_catalog.get('mlp_wo', [])\n", " if e['location'] == 'encoder' and e['layer'] == wi_e['layer']][0]\n", " up_norms = torch.norm(wi_e['param'], dim=1)\n", " down_norms = torch.norm(wo_e['param'], dim=0)\n", " combined = (up_norms * down_norms).numpy()\n", " all_combined.extend(combined)\n", "axes[1, 1].hist(all_combined, bins=200, color='darkorange', alpha=0.8)\n", "axes[1, 1].axvline(np.mean(all_combined) * 0.1, color='red', ls='--', label='10% of mean')\n", "axes[1, 1].set_title(\"MLP neuron importance (encoder)\")\n", "axes[1, 1].legend()\n", "\n", "# 6. 
Cross-layer Q weight correlation heatmap (encoder)\n", "q_entries = sorted([e for e in weight_catalog.get('self_attn_q', []) if e['location'] == 'encoder'],\n", " key=lambda x: x['layer'])\n", "n_q = len(q_entries)\n", "q_cross = np.zeros((n_q, n_q))\n", "for i in range(n_q):\n", " for j in range(n_q):\n", " Wi = q_entries[i]['param'].flatten()\n", " Wj = q_entries[j]['param'].flatten()\n", " q_cross[i, j] = (torch.dot(Wi, Wj) / (torch.norm(Wi) * torch.norm(Wj) + 1e-8)).item()\n", "im = axes[1, 2].imshow(q_cross, cmap='RdBu_r', vmin=-0.2, vmax=0.2)\n", "axes[1, 2].set_title(\"Cross-layer Q weight cosine (encoder)\")\n", "axes[1, 2].set_xticks(range(n_q))\n", "axes[1, 2].set_yticks(range(n_q))\n", "axes[1, 2].set_xticklabels([f\"L{e['layer']}\" for e in q_entries])\n", "axes[1, 2].set_yticklabels([f\"L{e['layer']}\" for e in q_entries])\n", "plt.colorbar(im, ax=axes[1, 2])\n", "\n", "# 7. LayerNorm scale distributions\n", "ln_weights = []\n", "ln_labels = []\n", "for entry in weight_catalog.get('layernorm', []):\n", " if 'weight' in entry['name'] and 'encoder' in entry['name']:\n", " ln_weights.append(entry['param'].numpy())\n", " ln_labels.append(f\"L{entry['layer']}\")\n", "if ln_weights:\n", " axes[2, 0].boxplot(ln_weights, tick_labels=ln_labels[:len(ln_weights)])\n", " axes[2, 0].axhline(1.0, color='red', ls='--', alpha=0.5)\n", " axes[2, 0].set_title(\"LayerNorm scales (encoder)\")\n", "\n", "# 8. Position bias heatmap (encoder)\n", "for entry in weight_catalog.get('position_bias', []):\n", " if 'encoder' in entry['name']:\n", " rpb = entry['param'].numpy()\n", " im2 = axes[2, 1].imshow(rpb.T, aspect='auto', cmap='RdBu_r')\n", " axes[2, 1].set_title(\"Position bias (encoder, heads × buckets)\")\n", " axes[2, 1].set_xlabel(\"Bucket\")\n", " axes[2, 1].set_ylabel(\"Head\")\n", " plt.colorbar(im2, ax=axes[2, 1])\n", " break\n", "\n", "# 9. 
Condition number by layer\n", "for wtype in ['self_attn_q', 'self_attn_k', 'mlp_wi', 'mlp_wo']:\n", " subset = sorted([r for r in svd_results if r['wtype'] == wtype and r['location'] == 'encoder'],\n", " key=lambda x: x['layer'])\n", " if subset:\n", " layers = [r['layer'] for r in subset]\n", " conds = [np.log10(r['condition']) for r in subset]\n", " axes[2, 2].plot(layers, conds, 'o-', label=wtype)\n", "axes[2, 2].set_title(\"log10(condition number) by layer\")\n", "axes[2, 2].set_xlabel(\"Layer\")\n", "if any(r['wtype'] in ['self_attn_q', 'self_attn_k', 'mlp_wi', 'mlp_wo'] for r in svd_results):\n", " axes[2, 2].legend(fontsize=8)\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"/content/t5_inactive_geometry.png\", dpi=150, bbox_inches='tight')\n", "plt.show()\n", "print(\"\\nSaved: /content/t5_inactive_geometry.png\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"T5-SMALL INACTIVE WEIGHT GEOMETRY — SUMMARY\")\n", "print(f\"{'='*70}\")\n", "\n", "# Key findings\n", "q_results = [r for r in svd_results if r['wtype'] == 'self_attn_q']\n", "k_results = [r for r in svd_results if r['wtype'] == 'self_attn_k']\n", "mlp_wi_results = [r for r in svd_results if r['wtype'] == 'mlp_wi']\n", "\n", "print(f\"\\n--- EFFECTIVE RANK ---\")\n", "print(f\"Q matrices stable rank: {np.mean([r['stable_rank'] for r in q_results]):.2f} ± {np.std([r['stable_rank'] for r in q_results]):.2f}\")\n", "print(f\"K matrices stable rank: {np.mean([r['stable_rank'] for r in k_results]):.2f} ± {np.std([r['stable_rank'] for r in k_results]):.2f}\")\n", "print(f\"MLP wi stable rank: {np.mean([r['stable_rank'] for r in mlp_wi_results]):.2f} ± {np.std([r['stable_rank'] for r in mlp_wi_results]):.2f}\")\n", "\n", "print(f\"\\n--- SPARSITY ---\")\n", "total_params = sum(p.numel() for p in model.parameters())\n", "near_zero = sum((p.abs() < 1e-3).sum().item() for p in model.parameters())\n", "print(f\"Params < 1e-3: {near_zero:,} / {total_params:,} ({near_zero/total_params*100:.2f}%)\")\n", "\n", 
"print(f\"\\n--- QK MANIFOLD ---\")\n", "print(f\"(See per-layer QK eigenvalue analysis above)\")\n", "print(f\"Negative eigenvalues present = model learned anti-similarity directions\")\n", "print(f\"These define the BOUNDARIES of semantic categories\")\n" ] }, { "cell_type": "markdown", "id": "5dd48e3f", "metadata": {}, "source": [ "## 5. Cross-Architecture Weight Battery\n", "*Section VI (BERT, CLIP, DINOv2): GPU-accelerated*\n", "\n", "**Requires:** `pip install open_clip_torch`" ] }, { "cell_type": "code", "execution_count": null, "id": "f3d941f0", "metadata": {}, "outputs": [], "source": [ "# Cross-architecture battery\n", "\n", "# ============================================================================\n", "# CROSS-ARCHITECTURE INACTIVE WEIGHT GEOMETRY BATTERY\n", "# BERT-large | CLIP-ViT-B/16 | DINOv2-large | CLIP-ViT-bigG\n", "# Question: Is Q sparsity universal? Is the pentachoron CV constant?\n", "# No inference. Just the frozen weights.\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import math\n", "import time\n", "from collections import defaultdict\n", "\n", "device = torch.device(\"cpu\") # Weight analysis only — no GPU needed\n", "\n", "def classify_weights_bert(model):\n", " \"\"\"Classify BERT-large weight matrices.\"\"\"\n", " catalog = defaultdict(list)\n", " for name, param in model.named_parameters():\n", " p = param.detach().float().cpu()\n", " parts = name.split('.')\n", "\n", " layer_num = -1\n", " for i, part in enumerate(parts):\n", " if part == 'layer' and i + 1 < len(parts):\n", " try: layer_num = int(parts[i + 1])\n", " except: pass\n", "\n", " if 'embeddings' in name:\n", " wtype = 'embedding'\n", " elif 'LayerNorm' in name or 'layernorm' in name:\n", " wtype = 'layernorm'\n", " elif 'attention' in name and 'self' in name:\n", " for sub in ['query', 'key', 'value']:\n", " if sub in name:\n", " wtype = f'self_attn_{sub[0]}'\n", " break\n", " 
else:\n", " wtype = 'self_attn_other'\n", " elif 'attention' in name and 'output' in name and 'dense' in name:\n", " wtype = 'self_attn_o'\n", " elif 'intermediate' in name:\n", " wtype = 'mlp_up'\n", " elif 'output' in name and 'dense' in name and 'attention' not in name:\n", " wtype = 'mlp_down'\n", " elif 'pooler' in name:\n", " wtype = 'pooler'\n", " else:\n", " wtype = 'other'\n", "\n", " if p.dim() == 2:\n", " catalog[wtype].append({\n", " 'name': name, 'shape': tuple(p.shape), 'param': p,\n", " 'layer': layer_num, 'numel': p.numel(),\n", " })\n", " return catalog\n", "\n", "\n", "def classify_weights_vit_transformers(model):\n", " \"\"\"Classify ViT weights from transformers library (DINOv2).\"\"\"\n", " catalog = defaultdict(list)\n", " for name, param in model.named_parameters():\n", " p = param.detach().float().cpu()\n", " parts = name.split('.')\n", "\n", " layer_num = -1\n", " for i, part in enumerate(parts):\n", " if part == 'layer' and i + 1 < len(parts):\n", " try: layer_num = int(parts[i + 1])\n", " except: pass\n", "\n", " if 'embeddings' in name:\n", " wtype = 'embedding'\n", " elif 'layernorm' in name.lower() or 'layer_norm' in name.lower() or 'norm' in name:\n", " wtype = 'layernorm'\n", " elif 'attention' in name:\n", " if 'query' in name or '.q.' in name or name.endswith('.q.weight'):\n", " wtype = 'self_attn_q'\n", " elif 'key' in name or '.k.' in name or name.endswith('.k.weight'):\n", " wtype = 'self_attn_k'\n", " elif 'value' in name or '.v.' 
in name or name.endswith('.v.weight'):\n", " wtype = 'self_attn_v'\n", " elif 'output' in name or 'proj' in name:\n", " wtype = 'self_attn_o'\n", " else:\n", " # DINOv2 uses qkv fused\n", " if 'qkv' in name:\n", " wtype = 'self_attn_qkv'\n", " else:\n", " wtype = 'self_attn_other'\n", " elif 'mlp' in name or 'intermediate' in name:\n", " if 'fc1' in name or 'dense' in name.split('.')[-2:-1]:\n", " wtype = 'mlp_up'\n", " elif 'fc2' in name:\n", " wtype = 'mlp_down'\n", " else:\n", " wtype = 'mlp_other'\n", " else:\n", " wtype = 'other'\n", "\n", " if p.dim() == 2:\n", " catalog[wtype].append({\n", " 'name': name, 'shape': tuple(p.shape), 'param': p,\n", " 'layer': layer_num, 'numel': p.numel(),\n", " })\n", " return catalog\n", "\n", "\n", "def classify_weights_open_clip(model_visual):\n", " \"\"\"Classify open_clip ViT visual encoder weights.\"\"\"\n", " catalog = defaultdict(list)\n", " for name, param in model_visual.named_parameters():\n", " p = param.detach().float().cpu()\n", " parts = name.split('.')\n", "\n", " layer_num = -1\n", " for i, part in enumerate(parts):\n", " if part == 'resblocks' and i + 1 < len(parts):\n", " try: layer_num = int(parts[i + 1])\n", " except: pass\n", "\n", " if 'token_embedding' in name or 'class_embedding' in name or 'positional_embedding' in name:\n", " wtype = 'embedding'\n", " elif 'ln_' in name or 'norm' in name:\n", " wtype = 'layernorm'\n", " elif 'attn' in name:\n", " if 'in_proj_weight' in name:\n", " wtype = 'self_attn_qkv' # fused QKV\n", " elif 'out_proj' in name:\n", " wtype = 'self_attn_o'\n", " elif 'q_proj' in name:\n", " wtype = 'self_attn_q'\n", " elif 'k_proj' in name:\n", " wtype = 'self_attn_k'\n", " elif 'v_proj' in name:\n", " wtype = 'self_attn_v'\n", " else:\n", " wtype = 'self_attn_other'\n", " elif 'mlp' in name or 'c_fc' in name or 'c_proj' in name:\n", " if 'c_fc' in name or ('mlp' in name and ('0' in parts[-2] or 'fc1' in name)):\n", " wtype = 'mlp_up'\n", " elif 'c_proj' in name or ('mlp' in name 
and ('2' in parts[-2] or 'fc2' in name)):\n", " wtype = 'mlp_down'\n", " else:\n", " wtype = 'mlp_other'\n", " elif 'proj' in name and 'attn' not in name:\n", " wtype = 'projection'\n", " else:\n", " wtype = 'other'\n", "\n", " if p.dim() == 2:\n", " catalog[wtype].append({\n", " 'name': name, 'shape': tuple(p.shape), 'param': p,\n", " 'layer': layer_num, 'numel': p.numel(),\n", " })\n", " return catalog\n", "\n", "\n", "\n", "def analyze_sparsity(catalog, thresholds=[1e-4, 1e-3, 1e-2, 1e-1]):\n", " \"\"\"Compute sparsity at multiple thresholds per weight type. GPU-accelerated.\"\"\"\n", " gpu = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " results = {}\n", " for wtype, entries in catalog.items():\n", " if not entries:\n", " continue\n", " all_vals = torch.cat([e['param'].abs().flatten() for e in entries]).to(gpu)\n", " total = len(all_vals)\n", " results[wtype] = {\n", " 'n_matrices': len(entries),\n", " 'total_params': total,\n", " }\n", " for t in thresholds:\n", " results[wtype][f'<{t}'] = (all_vals < t).sum().item() / total\n", " del all_vals\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " return results\n", "\n", "\n", "def analyze_svd(catalog, types_to_analyze=None):\n", " \"\"\"SVD analysis on 2D weight matrices. 
GPU-accelerated.\"\"\"\n", " if types_to_analyze is None:\n", " types_to_analyze = [k for k in catalog.keys()\n", " if k not in ['embedding', 'layernorm', 'other', 'pooler', 'projection']]\n", "\n", " gpu = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " results = []\n", " for wtype in types_to_analyze:\n", " if wtype not in catalog:\n", " continue\n", " for entry in catalog[wtype]:\n", " W = entry['param']\n", " if W.dim() != 2:\n", " continue\n", " W_gpu = W.to(gpu)\n", " S = torch.linalg.svdvals(W_gpu).cpu().numpy()\n", " del W_gpu\n", "\n", " stable_rank = (S ** 2).sum() / (S[0] ** 2) if S[0] > 0 else 0\n", " pr = (S.sum()) ** 2 / ((S ** 2).sum()) if (S ** 2).sum() > 0 else 0\n", " active_frac = (S > 0.01 * S[0]).sum() / len(S)\n", " rank_90 = np.searchsorted(np.cumsum(S) / S.sum(), 0.90) + 1\n", " condition = S[0] / (S[-1] + 1e-10)\n", "\n", " results.append({\n", " 'name': entry['name'], 'wtype': wtype, 'layer': entry['layer'],\n", " 'shape': entry['shape'], 'stable_rank': stable_rank,\n", " 'pr': pr, 'active_frac': active_frac, 'rank_90': rank_90,\n", " 'condition': condition, 'max_sv': S[0], 'min_sv': S[-1],\n", " })\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " return results\n", "\n", "\n", "def analyze_qk_manifold(catalog, d_model):\n", " \"\"\"\n", " QK similarity manifold analysis.\n", " Handles both separate Q/K matrices and fused QKV.\n", " GPU-accelerated.\n", " \"\"\"\n", " gpu = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " results = []\n", "\n", " # Try separate Q, K first\n", " q_entries = sorted([e for e in catalog.get('self_attn_q', [])], key=lambda x: x['layer'])\n", " k_entries = sorted([e for e in catalog.get('self_attn_k', [])], key=lambda x: x['layer'])\n", "\n", " if q_entries and k_entries and len(q_entries) == len(k_entries):\n", " for q_e, k_e in zip(q_entries, k_entries):\n", " W_q = q_e['param'].to(gpu)\n", " W_k = k_e['param'].to(gpu)\n", " QK = W_q @ 
W_k.T\n", " _analyze_qk_matrix(QK, q_e['layer'], results, gpu)\n", " del W_q, W_k, QK\n", " elif 'self_attn_qkv' in catalog:\n", " for entry in sorted(catalog['self_attn_qkv'], key=lambda x: x['layer']):\n", " W = entry['param'].to(gpu)\n", " total_out = W.shape[0]\n", " third = total_out // 3\n", " W_q = W[:third]\n", " W_k = W[third:2*third]\n", " QK = W_q @ W_k.T\n", " _analyze_qk_matrix(QK, entry['layer'], results, gpu)\n", " del W, W_q, W_k, QK\n", "\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " return results\n", "\n", "\n", "def _analyze_qk_matrix(QK, layer, results, device):\n", " \"\"\"Analyze a single QK^T matrix on GPU.\"\"\"\n", " S_qk = torch.linalg.svdvals(QK).cpu().numpy()\n", " stable_rank = (S_qk ** 2).sum() / (S_qk[0] ** 2) if S_qk[0] > 0 else 0\n", " pr = (S_qk.sum()) ** 2 / ((S_qk ** 2).sum()) if (S_qk ** 2).sum() > 0 else 0\n", "\n", " sym_diff = torch.norm(QK - QK.T).item() / (torch.norm(QK).item() + 1e-10)\n", "\n", " QK_sym = (QK + QK.T) / 2\n", " eigvals = torch.linalg.eigvalsh(QK_sym).cpu().numpy()[::-1]\n", " n_pos = (eigvals > 0).sum()\n", " n_neg = (eigvals < 0).sum()\n", "\n", " results.append({\n", " 'layer': layer, 'stable_rank': stable_rank, 'pr': pr,\n", " 'sym_dev': sym_diff, 'n_positive': n_pos, 'n_negative': n_neg,\n", " 'top3_eig': eigvals[:3].tolist(), 'dim': QK.shape[0],\n", " })\n", " del QK_sym, eigvals\n", "\n", "\n", "def analyze_dead_neurons(catalog):\n", " \"\"\"MLP dead neuron analysis.\"\"\"\n", " results = []\n", " up_entries = sorted(catalog.get('mlp_up', []), key=lambda x: x['layer'])\n", " down_entries = sorted(catalog.get('mlp_down', []), key=lambda x: x['layer'])\n", "\n", " if not up_entries or not down_entries:\n", " return results\n", "\n", " for up_e, down_e in zip(up_entries, down_entries):\n", " W_up = up_e['param']\n", " W_down = down_e['param']\n", "\n", " # up: [d_ff, d_model], down: [d_model, d_ff]\n", " up_norms = torch.norm(W_up, dim=1) # [d_ff]\n", " down_norms = 
torch.norm(W_down, dim=0) # [d_ff]\n", " combined = up_norms * down_norms\n", "\n", " d_ff = W_up.shape[0]\n", " mean_c = combined.mean().item()\n", " dead_01 = (combined < 0.01 * mean_c).sum().item()\n", " dead_10 = (combined < 0.10 * mean_c).sum().item()\n", "\n", " results.append({\n", " 'layer': up_e['layer'], 'd_ff': d_ff,\n", " 'dead_1pct': dead_01, 'dead_10pct': dead_10,\n", " 'dead_1pct_frac': dead_01 / d_ff, 'dead_10pct_frac': dead_10 / d_ff,\n", " })\n", " return results\n", "\n", "\n", "def cross_layer_correlation(catalog, wtype):\n", " \"\"\"Compute cross-layer weight cosine for a given weight type. GPU-accelerated.\"\"\"\n", " entries = sorted([e for e in catalog.get(wtype, []) if e['param'].dim() == 2],\n", " key=lambda x: x['layer'])\n", " if len(entries) < 2:\n", " return None\n", "\n", " gpu = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " n = len(entries)\n", " corr = np.zeros((n, n))\n", " flat_gpu = [e['param'].flatten().to(gpu) for e in entries]\n", "\n", " for i in range(n):\n", " for j in range(n):\n", " if len(flat_gpu[i]) == len(flat_gpu[j]):\n", " cos = torch.dot(flat_gpu[i], flat_gpu[j]) / (\n", " torch.norm(flat_gpu[i]) * torch.norm(flat_gpu[j]) + 1e-8)\n", " corr[i, j] = cos.item()\n", "\n", " del flat_gpu\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " return corr, [e['layer'] for e in entries]\n", "\n", "\n", "\n", "def run_full_battery(model_name, catalog, d_model):\n", " \"\"\"Run complete inactive weight geometry battery.\"\"\"\n", " print(f\"\\n{'='*70}\")\n", " print(f\"MODEL: {model_name}\")\n", " print(f\"{'='*70}\")\n", "\n", " # Catalog summary\n", " print(f\"\\n--- WEIGHT CATALOG ---\")\n", " total_params = 0\n", " for wtype, entries in sorted(catalog.items()):\n", " n_params = sum(e['numel'] for e in entries)\n", " total_params += n_params\n", " shapes = set(str(e['shape']) for e in entries)\n", " print(f\" {wtype:25s}: {len(entries):3d} matrices, {n_params:>12,} params, 
shapes={shapes}\")\n", " print(f\" {'TOTAL':25s}: {total_params:>12,} params (2D only)\")\n", "\n", " # SVD\n", " print(f\"\\n--- SVD EFFECTIVE RANK ---\")\n", " svd_results = analyze_svd(catalog)\n", " svd_by_type = defaultdict(list)\n", " for r in svd_results:\n", " svd_by_type[r['wtype']].append(r)\n", "\n", " print(f\"{'Type':25s} {'StableRank':>10s} {'PR':>8s} {'Active%':>8s} {'Rank90':>7s} {'Condition':>10s}\")\n", " for wtype in sorted(svd_by_type.keys()):\n", " subset = svd_by_type[wtype]\n", " sr = np.mean([r['stable_rank'] for r in subset])\n", " pr = np.mean([r['pr'] for r in subset])\n", " af = np.mean([r['active_frac'] for r in subset])\n", " r90 = np.mean([r['rank_90'] for r in subset])\n", " cond = np.mean([r['condition'] for r in subset])\n", " print(f\" {wtype:23s} {sr:10.2f} {pr:8.2f} {af:8.3f} {r90:7.1f} {cond:10.1f}\")\n", "\n", " # Sparsity\n", " print(f\"\\n--- SPARSITY TOPOLOGY ---\")\n", " sparsity = analyze_sparsity(catalog)\n", " thresholds = [1e-4, 1e-3, 1e-2, 1e-1]\n", " print(f\"{'Type':25s}\", end=\"\")\n", " for t in thresholds:\n", " print(f\" {'<'+str(t):>8s}\", end=\"\")\n", " print()\n", " for wtype in sorted(sparsity.keys()):\n", " print(f\" {wtype:23s}\", end=\"\")\n", " for t in thresholds:\n", " print(f\" {sparsity[wtype].get(f'<{t}', 0):8.4f}\", end=\"\")\n", " print()\n", "\n", " # Full model sparsity\n", " gpu = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " all_params = torch.cat([e['param'].abs().flatten()\n", " for entries in catalog.values()\n", " for e in entries]).to(gpu)\n", " print(f\" {'FULL MODEL':23s}\", end=\"\")\n", " for t in thresholds:\n", " frac = (all_params < t).sum().item() / len(all_params)\n", " print(f\" {frac:8.4f}\", end=\"\")\n", " print()\n", " del all_params\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", "\n", " # Q vs K vs V sparsity highlight\n", " print(f\"\\n--- Q/K/V SPARSITY COMPARISON (<0.1 threshold) ---\")\n", " for wtype in ['self_attn_q', 
'self_attn_k', 'self_attn_v', 'self_attn_qkv']:\n", " if wtype in sparsity:\n", " print(f\" {wtype:25s}: {sparsity[wtype].get('<0.1', 0)*100:.1f}%\")\n", "\n", " # QK manifold\n", " print(f\"\\n--- QK SIMILARITY MANIFOLD ---\")\n", " qk_results = analyze_qk_manifold(catalog, d_model)\n", " if qk_results:\n", " print(f\"{'Layer':>6s} {'StableRk':>8s} {'PR':>8s} {'Pos':>5s} {'Neg':>5s} {'SymDev':>8s} {'TopEig':>10s}\")\n", " for r in qk_results[:6]: # first 6 layers\n", " print(f\" {r['layer']:4d} {r['stable_rank']:8.2f} {r['pr']:8.2f} \"\n", " f\"{r['n_positive']:5d} {r['n_negative']:5d} {r['sym_dev']:8.4f} \"\n", " f\"{r['top3_eig'][0]:10.2f}\")\n", " if len(qk_results) > 6:\n", " print(f\" ... ({len(qk_results)} layers total)\")\n", " # Last layer\n", " r = qk_results[-1]\n", " print(f\" {r['layer']:4d} {r['stable_rank']:8.2f} {r['pr']:8.2f} \"\n", " f\"{r['n_positive']:5d} {r['n_negative']:5d} {r['sym_dev']:8.4f} \"\n", " f\"{r['top3_eig'][0]:10.2f}\")\n", "\n", " # Trend: positive eigenvalue fraction by depth\n", " first = qk_results[0]\n", " last = qk_results[-1]\n", " first_pos_frac = first['n_positive'] / first['dim']\n", " last_pos_frac = last['n_positive'] / last['dim']\n", " print(f\"\\n Positive eig fraction: layer 0 = {first_pos_frac:.3f}, last = {last_pos_frac:.3f}\")\n", " else:\n", " print(\" (Could not extract QK manifold)\")\n", "\n", " # Dead neurons\n", " print(f\"\\n--- MLP DEAD NEURONS ---\")\n", " dead = analyze_dead_neurons(catalog)\n", " if dead:\n", " total_dead_1 = sum(d['dead_1pct'] for d in dead)\n", " total_neurons = sum(d['d_ff'] for d in dead)\n", " total_dead_10 = sum(d['dead_10pct'] for d in dead)\n", " print(f\" Dead (<1% mean): {total_dead_1}/{total_neurons} ({total_dead_1/total_neurons*100:.2f}%)\")\n", " print(f\" Weak (<10% mean): {total_dead_10}/{total_neurons} ({total_dead_10/total_neurons*100:.2f}%)\")\n", " else:\n", " print(\" (Could not analyze — check weight naming)\")\n", "\n", " # Cross-layer Q correlation\n", " 
print(f\"\\n--- CROSS-LAYER CORRELATION (adjacent pairs) ---\")\n", " for wtype in ['self_attn_q', 'self_attn_k', 'self_attn_qkv', 'mlp_up']:\n", " result = cross_layer_correlation(catalog, wtype)\n", " if result is not None:\n", " corr, layers = result\n", " adj_corrs = [corr[i, i+1] for i in range(len(layers)-1)]\n", " print(f\" {wtype:25s}: adj_mean={np.mean(adj_corrs):.4f}, \"\n", " f\"adj_range=[{min(adj_corrs):.4f}, {max(adj_corrs):.4f}]\")\n", "\n", " return {\n", " 'svd': svd_results,\n", " 'sparsity': sparsity,\n", " 'qk': qk_results,\n", " 'dead': dead,\n", " 'model_name': model_name,\n", " }\n", "\n", "\n", "print(\"Loading BERT-large...\")\n", "from transformers import BertModel\n", "bert = BertModel.from_pretrained(\"google-bert/bert-large-uncased\", torch_dtype=torch.float32)\n", "bert.eval()\n", "bert_catalog = classify_weights_bert(bert)\n", "bert_results = run_full_battery(\"BERT-large (1024d, 24L, 16H)\", bert_catalog, d_model=1024)\n", "del bert\n", "torch.cuda.empty_cache()\n", "\n", "print(\"\\n\\nLoading CLIP-ViT-B/16 (LAION)...\")\n", "import open_clip\n", "clip_b_model, _, _ = open_clip.create_model_and_transforms(\n", " 'ViT-B-16', pretrained='laion2b_s34b_b88k'\n", ")\n", "clip_b_model.eval()\n", "clip_b_catalog = classify_weights_open_clip(clip_b_model.visual)\n", "clip_b_results = run_full_battery(\"CLIP-ViT-B/16 LAION (768d, 12L, 12H)\", clip_b_catalog, d_model=768)\n", "del clip_b_model\n", "torch.cuda.empty_cache()\n", "\n", "print(\"\\n\\nLoading DINOv2-large...\")\n", "from transformers import Dinov2Model\n", "dino = Dinov2Model.from_pretrained(\"facebook/dinov2-large\", torch_dtype=torch.float32)\n", "dino.eval()\n", "dino_catalog = classify_weights_vit_transformers(dino)\n", "dino_results = run_full_battery(\"DINOv2-large (1024d, 24L, 16H)\", dino_catalog, d_model=1024)\n", "del dino\n", "torch.cuda.empty_cache()\n", "\n", "print(\"\\n\\nLoading CLIP-ViT-bigG/14 (LAION)...\")\n", "clip_g_model, _, _ = 
open_clip.create_model_and_transforms(\n", " 'ViT-bigG-14', pretrained='laion2b_s39b_b160k'\n", ")\n", "clip_g_model.eval()\n", "clip_g_catalog = classify_weights_open_clip(clip_g_model.visual)\n", "clip_g_results = run_full_battery(\"CLIP-ViT-bigG/14 LAION (1664d, 48L, 16H)\", clip_g_catalog, d_model=1664)\n", "del clip_g_model\n", "torch.cuda.empty_cache()\n", "\n", "print(f\"\\n\\n{'='*70}\")\n", "print(\"CROSS-MODEL COMPARISON\")\n", "print(f\"{'='*70}\")\n", "\n", "all_results = [bert_results, clip_b_results, dino_results, clip_g_results]\n", "\n", "# Q sparsity comparison\n", "print(f\"\\n--- Q SPARSITY (<0.1 threshold) ---\")\n", "print(f\"{'Model':45s} {'Q':>8s} {'K':>8s} {'V':>8s} {'QKV':>8s}\")\n", "for res in all_results:\n", " sp = res['sparsity']\n", " q = sp.get('self_attn_q', {}).get('<0.1', None)\n", " k = sp.get('self_attn_k', {}).get('<0.1', None)\n", " v = sp.get('self_attn_v', {}).get('<0.1', None)\n", " qkv = sp.get('self_attn_qkv', {}).get('<0.1', None)\n", " print(f\" {res['model_name']:43s}\", end=\"\")\n", " print(f\" {q*100:7.1f}%\" if q else \" -\", end=\"\")\n", " print(f\" {k*100:7.1f}%\" if k else \" -\", end=\"\")\n", " print(f\" {v*100:7.1f}%\" if v else \" -\", end=\"\")\n", " print(f\" {qkv*100:7.1f}%\" if qkv else \" -\")\n", "\n", "# Reference: T5 numbers\n", "print(f\" {'T5-Small (512d, 6L, 8H) [reference]':43s} 93.7% 19.2% 12.1% -\")\n", "print(f\" {'T5-Base (768d, 12L, 12H) [reference]':43s} 99.4% 30.0% 16.2% -\")\n", "\n", "# SVD stable rank comparison\n", "print(f\"\\n--- SVD STABLE RANK (mean across layers) ---\")\n", "print(f\"{'Model':45s} {'Q':>8s} {'K':>8s} {'V':>8s} {'MLP_up':>8s}\")\n", "for res in all_results:\n", " svd_by_type = defaultdict(list)\n", " for r in res['svd']:\n", " svd_by_type[r['wtype']].append(r['stable_rank'])\n", "\n", " print(f\" {res['model_name']:43s}\", end=\"\")\n", " for wtype in ['self_attn_q', 'self_attn_k', 'self_attn_v', 'mlp_up']:\n", " vals = svd_by_type.get(wtype, [])\n", " if 
vals:\n", " print(f\" {np.mean(vals):8.1f}\", end=\"\")\n", " else:\n", " print(f\" -\", end=\"\")\n", " print()\n", "\n", "# QK manifold comparison\n", "print(f\"\\n--- QK MANIFOLD: POSITIVE EIGENVALUE FRACTION ---\")\n", "print(f\"{'Model':45s} {'First':>8s} {'Last':>8s} {'Trend':>8s}\")\n", "for res in all_results:\n", " qk = res['qk']\n", " if qk:\n", " first = qk[0]\n", " last = qk[-1]\n", " f_frac = first['n_positive'] / first['dim']\n", " l_frac = last['n_positive'] / last['dim']\n", " trend = l_frac - f_frac\n", " print(f\" {res['model_name']:43s} {f_frac:8.3f} {l_frac:8.3f} {trend:+8.3f}\")\n", " else:\n", " print(f\" {res['model_name']:43s} - - -\")\n", "\n", "# Dead neurons comparison\n", "print(f\"\\n--- MLP DEAD NEURONS (<1% of mean) ---\")\n", "for res in all_results:\n", " dead = res['dead']\n", " if dead:\n", " total_dead = sum(d['dead_1pct'] for d in dead)\n", " total_neurons = sum(d['d_ff'] for d in dead)\n", " print(f\" {res['model_name']:43s}: {total_dead}/{total_neurons} ({total_dead/total_neurons*100:.2f}%)\")\n", " else:\n", " print(f\" {res['model_name']:43s}: N/A\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"BATTERY COMPLETE\")\n", "print(f\"{'='*70}\")\n" ] }, { "cell_type": "markdown", "id": "8560e52b", "metadata": {}, "source": [ "## 6. T5-v1.1-XXL (Flux Text Encoder) — Full Battery\n", "*Section VI (XXL): ALL 48 layers, encoder + decoder + cross-attn*\n", "\n", "**Hardware:** ~30GB VRAM\n", "\n", "**Key:** Q=100.0%, cross-attn QK locked 0.500, decoder 0 mixed position heads" ] }, { "cell_type": "code", "execution_count": null, "id": "710e5974", "metadata": {}, "outputs": [], "source": [ "# T5-v1.1-XXL full battery\n", "\n", "# ============================================================================\n", "# T5-v1.1-XXL (FLUX TEXT ENCODER) — G4 94GB\n", "# Load full model on GPU. SVD on GPU. 
Eigendecompositions on CPU.\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import time\n", "import gc\n", "from collections import defaultdict\n", "\n", "device = torch.device(\"cuda\")\n", "print(f\"GPU: {torch.cuda.get_device_name()}\")\n", "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", "\n", "# Load\n", "print(\"\\nLoading google/t5-v1_1-xxl (fp16 → GPU)...\")\n", "t0 = time.time()\n", "from transformers import T5ForConditionalGeneration\n", "model = T5ForConditionalGeneration.from_pretrained(\n", " \"google/t5-v1_1-xxl\", torch_dtype=torch.float16, device_map=\"auto\",\n", ")\n", "model.eval()\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Loaded in {time.time()-t0:.0f}s, {total_params:,} params\")\n", "print(f\"VRAM used: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "\n", "config = model.config\n", "print(f\"d_model={config.d_model}, d_kv={config.d_kv}, d_ff={config.d_ff}, \"\n", " f\"heads={config.num_heads}, layers={config.num_layers}+{config.num_decoder_layers}, \"\n", " f\"ff={config.feed_forward_proj}\")\n", "\n", "# Classify\n", "def classify(name):\n", " parts = name.split('.')\n", " loc = 'encoder' if 'encoder' in name else ('decoder' if 'decoder' in name else 'shared')\n", " layer = -1\n", " for i, p in enumerate(parts):\n", " if p == 'block' and i+1 < len(parts):\n", " try: layer = int(parts[i+1])\n", " except: pass\n", " if 'embed' in name or 'shared' in name: wt = 'embedding'\n", " elif 'relative_attention_bias' in name: wt = 'position_bias'\n", " elif 'layer_norm' in name: wt = 'layernorm'\n", " elif 'SelfAttention' in name:\n", " wt = 'self_attn_other'\n", " for s in parts:\n", " if s in ('q','k','v','o'): wt = f'self_attn_{s}'; break\n", " elif 'EncDecAttention' in name:\n", " wt = 'cross_attn_other'\n", " for s in parts:\n", " if s in ('q','k','v','o'): wt = f'cross_attn_{s}'; 
break\n", " elif 'DenseReluDense' in name:\n", " wt = 'mlp_other'\n", " for s in parts:\n", " if s == 'wi_0': wt = 'mlp_gate'; break\n", " elif s == 'wi_1': wt = 'mlp_up'; break\n", " elif s == 'wo': wt = 'mlp_down'; break\n", " else: wt = 'other'\n", " return wt, loc, layer\n", "\n", "# Catalog\n", "print(f\"\\n{'='*70}\\nCATALOG\\n{'='*70}\")\n", "catalog = defaultdict(list)\n", "for name, param in model.named_parameters():\n", " if param.dim() != 2: continue\n", " wt, loc, layer = classify(name)\n", " catalog[wt].append({'name': name, 'shape': tuple(param.shape), 'loc': loc,\n", " 'layer': layer, 'numel': param.numel()})\n", "\n", "for wt, entries in sorted(catalog.items()):\n", " t = sum(e['numel'] for e in entries)\n", " enc = sum(1 for e in entries if e['loc']=='encoder')\n", " dec = sum(1 for e in entries if e['loc']=='decoder')\n", " shapes = set(str(e['shape']) for e in entries)\n", " print(f\" {wt:25s}: {len(entries):4d} (E:{enc} D:{dec}) {t:>15,} {shapes}\")\n", "\n", "# Helper — get param as fp32 on GPU\n", "def get_w(name):\n", " parts = name.split('.')\n", " obj = model\n", " for p in parts:\n", " obj = obj[int(p)] if p.isdigit() else getattr(obj, p)\n", " return obj.detach().float() # stays on whatever device it's on\n", "\n", "# ALL layers — no sampling\n", "enc_sample = list(range(config.num_layers)) # 0..23\n", "dec_sample = list(range(config.num_decoder_layers)) # 0..23\n", "print(f\"\\nEncoder layers: ALL {len(enc_sample)}\")\n", "print(f\"Decoder layers: ALL {len(dec_sample)}\")\n", "\n", "# ── SVD ──\n", "print(f\"\\n{'='*70}\\nSVD EFFECTIVE RANK\\n{'='*70}\")\n", "svd_results = []\n", "skip = {'embedding','layernorm','position_bias','other','self_attn_other',\n", " 'cross_attn_other','mlp_other'}\n", "all_entries = [(wt,e) for wt,entries in catalog.items() if wt not in skip for e in entries]\n", "total = len(all_entries)\n", "done = 0\n", "t0 = time.time()\n", "\n", "for wt, entry in all_entries:\n", " layer, loc = entry['layer'], 
entry['loc']\n", " done += 1\n", "\n", " if done % 10 == 0:\n", " print(f\" [{done}/{total}] {wt} {loc} L{layer} \", end=\"\\r\")\n", "\n", " try:\n", " W = get_w(entry['name'])\n", " S = torch.linalg.svdvals(W).cpu().numpy()\n", " svd_results.append({\n", " 'wt': wt, 'loc': loc, 'layer': layer, 'shape': entry['shape'],\n", " 'sr': (S**2).sum()/(S[0]**2) if S[0]>0 else 0,\n", " 'pr': (S.sum())**2/((S**2).sum()) if (S**2).sum()>0 else 0,\n", " 'af': (S>0.01*S[0]).sum()/len(S),\n", " 'r90': np.searchsorted(np.cumsum(S)/S.sum(),0.90)+1,\n", " 'cond': S[0]/(S[-1]+1e-10),\n", " })\n", " except Exception as e:\n", " print(f\"\\n SVD fail {entry['name']}: {e}\")\n", "\n", "print(f\" SVD done: {len(svd_results)} matrices in {time.time()-t0:.0f}s \")\n", "\n", "by_type = defaultdict(list)\n", "for r in svd_results: by_type[r['wt']].append(r)\n", "print(f\"\\n{'Type':25s} {'SR':>8s} {'PR':>8s} {'Act%':>6s} {'R90':>5s} {'Cond':>12s}\")\n", "for w in sorted(by_type):\n", " s = by_type[w]\n", " print(f\" {w:23s} {np.mean([r['sr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['pr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['af'] for r in s]):6.3f} \"\n", " f\"{np.mean([r['r90'] for r in s]):5.0f} \"\n", " f\"{np.mean([r['cond'] for r in s]):12.1f}\")\n", "\n", "# ── SPARSITY ──\n", "print(f\"\\n{'='*70}\\nSPARSITY\\n{'='*70}\")\n", "thresholds = [1e-4,1e-3,1e-2,1e-1]\n", "sparsity = defaultdict(lambda: {'total':0, **{t:0 for t in thresholds}})\n", "done = 0\n", "t0 = time.time()\n", "\n", "for wt, entries in catalog.items():\n", " if wt in skip: continue\n", " for entry in entries:\n", " done += 1\n", " if done % 20 == 0:\n", " print(f\" [{done}] {wt} {entry['loc']} L{entry['layer']} \", end=\"\\r\")\n", " W = get_w(entry['name'])\n", " a = W.abs(); n = a.numel()\n", " sparsity[wt]['total'] += n\n", " for t in thresholds:\n", " sparsity[wt][t] += (a<t).sum().item()\n", "\n", "print(f\"\\n{'Type':25s} {'<1e-4':>8s} {'<1e-3':>8s} {'<0.01':>8s} {'<0.1':>8s}\")\n", "for wt in sorted(sparsity):\n", " sc = sparsity[wt]\n", " if 
sc['total']==0: continue\n", " print(f\" {wt:23s} {sc[1e-4]/sc['total']:8.4f} {sc[1e-3]/sc['total']:8.4f} \"\n", " f\"{sc[1e-2]/sc['total']:8.4f} {sc[1e-1]/sc['total']:8.4f}\")\n", "\n", "# Encoder vs decoder Q/K/V\n", "print(f\"\\n--- ENCODER vs DECODER SPARSITY (<0.1) ---\")\n", "for wt in ['self_attn_q','self_attn_k','self_attn_v','cross_attn_q','cross_attn_k','cross_attn_v']:\n", " for loc in ['encoder','decoder']:\n", " entries = [e for e in catalog.get(wt,[]) if e['loc']==loc]\n", " if not entries: continue\n", " total_n = 0; below = 0\n", " for e in entries:\n", " W = get_w(e['name']); a = W.abs()\n", " total_n += a.numel(); below += (a<0.1).sum().item()\n", " if total_n > 0:\n", " print(f\" {loc:8s} {wt:20s}: {below/total_n*100:.1f}%\")\n", "\n", "# ── QK MANIFOLD — eigendecomposition on CPU ──\n", "print(f\"\\n{'='*70}\\nQK MANIFOLD (eigvalsh on CPU)\\n{'='*70}\")\n", "\n", "for loc, samples in [('encoder', enc_sample), ('decoder', dec_sample)]:\n", " print(f\"\\n--- {loc.upper()} self-attention ---\")\n", " q_map = {e['layer']:e['name'] for e in catalog.get('self_attn_q',[]) if e['loc']==loc}\n", " k_map = {e['layer']:e['name'] for e in catalog.get('self_attn_k',[]) if e['loc']==loc}\n", "\n", " qk_results = []\n", " for layer in samples:\n", " if layer not in q_map or layer not in k_map: continue\n", " try:\n", " t0 = time.time()\n", " # Get Q,K on GPU, compute QK on GPU, move result to CPU for eigvalsh\n", " Wq = get_w(q_map[layer])\n", " Wk = get_w(k_map[layer])\n", " QK = Wq @ Wk.T\n", " del Wq, Wk\n", "\n", " # SVD on GPU (fast, small result)\n", " S = torch.linalg.svdvals(QK).cpu().numpy()\n", " sr = (S**2).sum()/(S[0]**2) if S[0]>0 else 0\n", "\n", " sym = torch.norm(QK-QK.T).item()/(torch.norm(QK).item()+1e-10)\n", "\n", " # Move to CPU for eigvalsh (avoids GPU memory spike)\n", " QK_cpu = ((QK+QK.T)/2).cpu()\n", " del QK\n", " torch.cuda.empty_cache()\n", "\n", " eig = torch.linalg.eigvalsh(QK_cpu).numpy()[::-1]\n", " del QK_cpu\n", "\n", " 
n_pos=(eig>0).sum(); n_neg=(eig<0).sum(); dim=len(eig)\n", " print(f\" L{layer:2d}: SR={sr:.2f}, pos={n_pos}({n_pos/dim:.3f}), \"\n", " f\"neg={n_neg}({n_neg/dim:.3f}), sym={sym:.4f}, \"\n", " f\"top={eig[0]:.2f} ({time.time()-t0:.1f}s)\")\n", " qk_results.append({'layer':layer,'n_pos':n_pos,'n_neg':n_neg,'dim':dim,'sr':sr})\n", " del eig; gc.collect()\n", " except Exception as e:\n", " print(f\" L{layer}: FAIL — {e}\")\n", "\n", " if len(qk_results)>=2:\n", " f,l=qk_results[0],qk_results[-1]\n", " print(f\" Trend: L{f['layer']}={f['n_pos']/f['dim']:.3f} → L{l['layer']}={l['n_pos']/l['dim']:.3f}\")\n", "\n", "# Cross-attention QK\n", "print(f\"\\n--- DECODER cross-attention ---\")\n", "xq_map = {e['layer']:e['name'] for e in catalog.get('cross_attn_q',[]) if e['loc']=='decoder'}\n", "xk_map = {e['layer']:e['name'] for e in catalog.get('cross_attn_k',[]) if e['loc']=='decoder'}\n", "for layer in dec_sample:\n", " if layer not in xq_map or layer not in xk_map: continue\n", " try:\n", " t0 = time.time()\n", " Wq = get_w(xq_map[layer]); Wk = get_w(xk_map[layer])\n", " QK = Wq @ Wk.T; del Wq, Wk\n", " sym = torch.norm(QK-QK.T).item()/(torch.norm(QK).item()+1e-10)\n", " QK_cpu = ((QK+QK.T)/2).cpu(); del QK; torch.cuda.empty_cache()\n", " eig = torch.linalg.eigvalsh(QK_cpu).numpy()[::-1]; del QK_cpu\n", " n_pos=(eig>0).sum(); n_neg=(eig<0).sum(); dim=len(eig)\n", " print(f\" L{layer:2d}: pos={n_pos}({n_pos/dim:.3f}), neg={n_neg}({n_neg/dim:.3f}), \"\n", " f\"sym={sym:.4f}, top={eig[0]:.2f} ({time.time()-t0:.1f}s)\")\n", " del eig; gc.collect()\n", " except Exception as e:\n", " print(f\" L{layer}: FAIL — {e}\")\n", "\n", "# ── DEAD NEURONS ──\n", "print(f\"\\n{'='*70}\\nMLP DEAD NEURONS (GeGLU)\\n{'='*70}\")\n", "for loc, samples in [('encoder', enc_sample), ('decoder', dec_sample)]:\n", " print(f\"\\n--- {loc.upper()} ---\")\n", " g_map = {e['layer']:e['name'] for e in catalog.get('mlp_gate',[]) if e['loc']==loc}\n", " u_map = {e['layer']:e['name'] for e in 
catalog.get('mlp_up',[]) if e['loc']==loc}\n", " d_map = {e['layer']:e['name'] for e in catalog.get('mlp_down',[]) if e['loc']==loc}\n", " td=tn=0\n", " for layer in samples:\n", " if layer not in g_map or layer not in u_map or layer not in d_map: continue\n", " Wg = get_w(g_map[layer]); gn = torch.norm(Wg,dim=1).cpu().numpy(); del Wg\n", " Wu = get_w(u_map[layer]); un = torch.norm(Wu,dim=1).cpu().numpy(); del Wu\n", " Wd = get_w(d_map[layer]); dn = torch.norm(Wd,dim=0).cpu().numpy(); del Wd\n", " c = gn*un*dn; d_ff=len(c); mc=c.mean()\n", " dead=(c<0.01*mc).sum(); weak=(c<0.10*mc).sum()\n", " td+=dead; tn+=d_ff\n", " print(f\" L{layer:2d}: d_ff={d_ff}, dead={dead}({dead/d_ff*100:.1f}%), weak={weak}({weak/d_ff*100:.1f}%)\")\n", " if tn: print(f\" Total: {td}/{tn} ({td/tn*100:.2f}%)\")\n", "\n", "# ── CROSS-LAYER Q ──\n", "print(f\"\\n{'='*70}\\nCROSS-LAYER Q CORRELATION\\n{'='*70}\")\n", "for loc, samples in [('encoder', enc_sample), ('decoder', dec_sample)]:\n", " q_map = {e['layer']:e['name'] for e in catalog.get('self_attn_q',[]) if e['loc']==loc}\n", " qf = {}\n", " for l in samples:\n", " if l in q_map:\n", " try: qf[l] = get_w(q_map[l]).cpu().flatten()\n", " except: pass\n", " if len(qf)>=2:\n", " ls = sorted(qf)\n", " adj = []\n", " for i in range(len(ls)-1):\n", " a,b = qf[ls[i]],qf[ls[i+1]]\n", " if a.shape==b.shape:\n", " adj.append((torch.dot(a,b)/(torch.norm(a)*torch.norm(b)+1e-8)).item())\n", " if adj:\n", " print(f\" {loc} adj Q cos: mean={np.mean(adj):.4f}, range=[{min(adj):.4f},{max(adj):.4f}]\")\n", " del qf; gc.collect()\n", "\n", "# ── POSITION BIAS ──\n", "print(f\"\\n{'='*70}\\nPOSITION BIAS\\n{'='*70}\")\n", "for entry in catalog.get('position_bias',[]):\n", " rpb = get_w(entry['name']).cpu().numpy()\n", " nb,nh = rpb.shape\n", " loc_h = sum(1 for h in range(nh) if np.argmax(rpb[:,h])<=2)\n", " glb_h = sum(1 for h in range(nh) if np.argmax(rpb[:,h])>=nb-3)\n", " print(f\" {entry['loc']:8s}: [{nb}×{nh}] Local:{loc_h} Global:{glb_h} \"\n", " 
f\"Mixed:{nh-loc_h-glb_h} Range:[{rpb.min():.1f},{rpb.max():.1f}]\")\n", "\n", "# ── SUMMARY ──\n", "print(f\"\\n{'='*70}\\nSUMMARY — T5-v1.1-XXL (FLUX)\\n{'='*70}\")\n", "print(f\"Params: {total_params:,}\")\n", "print(f\"d_model={config.d_model}, d_ff={config.d_ff}, heads={config.num_heads}\")\n", "print(f\"Layers: {config.num_layers} enc + {config.num_decoder_layers} dec\")\n", "print(f\"MLP: {config.feed_forward_proj} (GeGLU)\")\n", "for w in ['self_attn_q','self_attn_k','self_attn_v','cross_attn_q']:\n", " sc = sparsity.get(w,{})\n", " if sc.get('total',0) > 0:\n", " print(f\" {w} (<0.1): {sc[1e-1]/sc['total']*100:.1f}%\")\n", "print(f\"\\nRef: T5-Small Q=93.7% | T5-Base Q=99.4% | BERT=99.1% | DINOv2=100%\")\n", "print(f\"VRAM at end: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "print(\"Done.\")\n" ] }, { "cell_type": "markdown", "id": "27266aa7", "metadata": {}, "source": [ "## 7. Geometric Residual Modulator (LERP)\n", "*Section VII: per-token geometric embedding + learned projection + per-layer alpha*" ] }, { "cell_type": "code", "execution_count": null, "id": "28454051", "metadata": {}, "outputs": [], "source": [ "# LERP modulator\n", "\n", "# ============================================================================\n", "# GEOMETRIC RESIDUAL MODULATOR\n", "# Injects geometric structure into the residual stream via LERP.\n", "# Q reads a better map. No loss changes. 
No architecture changes.\n", "# Just a better starting terrain for the existing compass.\n", "# ============================================================================\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import math\n", "from typing import Optional, Dict\n", "\n", "\n", "class GeometricResidualModulator(nn.Module):\n", " \"\"\"\n", " Injects geometric structure into a transformer's residual stream.\n", "\n", " Mechanism:\n", " residual_out = (1 - alpha) * residual_in + alpha * geometric_target\n", "\n", " Where geometric_target is derived from the token's geometric embedding\n", " projected into the residual stream's coordinate system.\n", "\n", " The alpha is small (0.01-0.1) so the model's existing computation\n", " dominates. The geometric delta is a nudge, not a replacement.\n", "\n", " Intervention point: between residual accumulation and pre-attention LayerNorm.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " d_model: int = 512,\n", " vocab_size: int = 32128,\n", " n_geometric_dims: int = 64,\n", " initial_alpha: float = 0.01,\n", " learnable_alpha: bool = True,\n", " per_layer_alpha: bool = False,\n", " n_layers: int = 6,\n", " ):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.vocab_size = vocab_size\n", " self.n_geometric_dims = n_geometric_dims\n", " self.n_layers = n_layers\n", "\n", " # Geometric embedding: each token gets a geometric fingerprint\n", " # This is the \"terrain\" — initialized with structure, then frozen or lightly tuned\n", " self.geometric_embed = nn.Embedding(vocab_size, n_geometric_dims)\n", "\n", " # Projection from geometric space to residual stream space\n", " # This learns HOW to inject the geometry, not WHAT the geometry is\n", " self.proj = nn.Linear(n_geometric_dims, d_model, bias=False)\n", "\n", " # LERP coefficient — store in logit space so sigmoid gives desired alpha\n", " # sigmoid(logit(x)) = x\n", " def 
_logit(x):\n", " return math.log(x / (1 - x))\n", "\n", " if per_layer_alpha:\n", " if learnable_alpha:\n", " self.alpha = nn.Parameter(torch.full((n_layers,), _logit(initial_alpha)))\n", " else:\n", " self.register_buffer('alpha', torch.full((n_layers,), _logit(initial_alpha)))\n", " else:\n", " if learnable_alpha:\n", " self.alpha = nn.Parameter(torch.tensor(_logit(initial_alpha)))\n", " else:\n", " self.register_buffer('alpha', torch.tensor(_logit(initial_alpha)))\n", "\n", " # Initialize projection to be small — don't disrupt existing model\n", " nn.init.normal_(self.proj.weight, std=0.01)\n", "\n", " def init_from_cayley_menger(self, simplices: torch.Tensor):\n", " \"\"\"\n", " Initialize geometric embeddings from precomputed simplex vertices.\n", "\n", " Args:\n", " simplices: [vocab_size, n_vertices, n_geometric_dims] or\n", " [vocab_size, n_geometric_dims] (already pooled)\n", " \"\"\"\n", " with torch.no_grad():\n", " if simplices.dim() == 3:\n", " # Sum-pool vertices\n", " pooled = simplices.sum(dim=1)\n", " else:\n", " pooled = simplices\n", "\n", " # Normalize to unit sphere — geometry is in the directions, not magnitudes\n", " norms = pooled.norm(dim=1, keepdim=True).clamp(min=1e-8)\n", " normalized = pooled / norms\n", "\n", " self.geometric_embed.weight.copy_(normalized[:self.vocab_size])\n", "\n", " def init_from_relational_target(self, pairwise_cosine_target: torch.Tensor, n_dims: int = None):\n", " \"\"\"\n", " Initialize geometric embeddings from a target pairwise similarity matrix.\n", " Uses eigendecomposition to find embeddings that reproduce the target.\n", "\n", " Args:\n", " pairwise_cosine_target: [vocab_size, vocab_size] target similarity\n", " n_dims: number of dimensions to keep (default: self.n_geometric_dims)\n", " \"\"\"\n", " if n_dims is None:\n", " n_dims = self.n_geometric_dims\n", "\n", " with torch.no_grad():\n", " # Eigendecomposition of target similarity\n", " eigvals, eigvecs = torch.linalg.eigh(pairwise_cosine_target)\n", " 
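`init_from_relational_target` recovers embeddings from a target similarity matrix via `eigh`: keep the top eigenvectors, scale each by the square root of its eigenvalue. A toy check of that recovery step, before the class's final row-normalization (matrix and names invented here):

```python
import torch

torch.manual_seed(0)

# Build a target Gram matrix that is exactly rank-4, then recover 4-dim
# embeddings from its eigendecomposition.
X = torch.randn(50, 4)
target = X @ X.T                                   # symmetric PSD, rank 4

eigvals, eigvecs = torch.linalg.eigh(target)       # eigenvalues in ascending order
top_idx = torch.argsort(eigvals, descending=True)[:4]
emb = eigvecs[:, top_idx] * eigvals[top_idx].clamp(min=0).sqrt()

# The recovered embeddings reproduce the target Gram matrix up to float error.
err = (emb @ emb.T - target).abs().max()
print(f"max reconstruction error: {err:.2e}")
```

The `clamp(min=0)` matters when the target is only approximately PSD: tiny negative eigenvalues would otherwise produce NaNs under the square root.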
# Take top n_dims eigenvectors (largest eigenvalues)\n", " top_idx = torch.argsort(eigvals, descending=True)[:n_dims]\n", " top_vecs = eigvecs[:, top_idx] # [vocab_size, n_dims]\n", " top_vals = eigvals[top_idx].clamp(min=0).sqrt() # scale by sqrt(eigenvalue)\n", "\n", " embeddings = top_vecs * top_vals.unsqueeze(0)\n", "\n", " # Normalize\n", " norms = embeddings.norm(dim=1, keepdim=True).clamp(min=1e-8)\n", " embeddings = embeddings / norms\n", "\n", " self.geometric_embed.weight.copy_(embeddings[:self.vocab_size])\n", "\n", " def forward(\n", " self,\n", " residual: torch.Tensor,\n", " token_ids: torch.Tensor,\n", " layer_idx: int = 0,\n", " ) -> torch.Tensor:\n", " \"\"\"\n", " Apply geometric modulation to residual stream.\n", "\n", " Args:\n", " residual: [batch, seq_len, d_model] current residual state\n", " token_ids: [batch, seq_len] token indices\n", " layer_idx: which layer (for per-layer alpha)\n", "\n", " Returns:\n", " modulated: [batch, seq_len, d_model] residual with geometric delta\n", " \"\"\"\n", " # Get geometric fingerprints for these tokens\n", " geo = self.geometric_embed(token_ids) # [B, S, n_geo]\n", "\n", " # Project into residual stream coordinates\n", " geo_projected = self.proj(geo) # [B, S, d_model]\n", "\n", " # Get alpha for this layer\n", " if self.alpha.dim() > 0:\n", " a = torch.sigmoid(self.alpha[layer_idx]) # sigmoid to keep in [0, 1]\n", " else:\n", " a = torch.sigmoid(self.alpha)\n", "\n", " # LERP: nudge residual toward geometric target\n", " modulated = (1 - a) * residual + a * geo_projected\n", "\n", " return modulated\n", "\n", " def get_geometric_similarity(self, token_ids_a: torch.Tensor, token_ids_b: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Compute cosine similarity in geometric space between token sets.\"\"\"\n", " geo_a = self.geometric_embed(token_ids_a) # [..., n_geo]\n", " geo_b = self.geometric_embed(token_ids_b)\n", " geo_a = F.normalize(geo_a, dim=-1)\n", " geo_b = F.normalize(geo_b, dim=-1)\n", " return 
(geo_a * geo_b).sum(dim=-1)\n", "\n", " def geometric_residuals(self) -> Dict[str, torch.Tensor]:\n", " \"\"\"\n", " Compute geometric health metrics — the PDE residuals.\n", " These measure how well the geometric embeddings satisfy structural constraints.\n", " \"\"\"\n", " W = self.geometric_embed.weight # [vocab_size, n_geo]\n", " W_n = F.normalize(W, dim=1)\n", "\n", " # 1. Pairwise cosine distribution (random sample of up to 5000 tokens for speed;\n", " # permute the full vocab so the sample isn't biased toward early token ids)\n", " idx = torch.randperm(self.vocab_size)[:5000]\n", " sample = W_n[idx]\n", " cos_mat = sample @ sample.T\n", " tri = torch.triu_indices(len(idx), len(idx), offset=1)\n", " flat_cos = cos_mat[tri[0], tri[1]]\n", "\n", " # 2. Norm distribution\n", " norms = W.norm(dim=1)\n", "\n", " # 3. Effective dimensionality (participation ratio)\n", " # Covariance eigenspectrum\n", " centered = W - W.mean(dim=0)\n", " cov = (centered.T @ centered) / W.shape[0]\n", " eigvals = torch.linalg.eigvalsh(cov)\n", " pr = (eigvals.sum() ** 2) / (eigvals ** 2).sum()\n", "\n", " # 4. 
Alpha state\n", " a = torch.sigmoid(self.alpha)\n", "\n", " return {\n", " 'cos_mean': flat_cos.mean(),\n", " 'cos_std': flat_cos.std(),\n", " 'norm_mean': norms.mean(),\n", " 'norm_std': norms.std(),\n", " 'participation_ratio': pr,\n", " 'pr_over_dim': pr / self.n_geometric_dims,\n", " 'alpha': a,\n", " }\n", "\n", "\n", "class ModulatedT5Encoder(nn.Module):\n", " \"\"\"\n", " Wraps a T5 encoder with geometric residual modulation.\n", " Hooks into the residual stream at configurable layers.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " t5_encoder: nn.Module,\n", " modulator: GeometricResidualModulator,\n", " modulate_layers: Optional[list] = None,\n", " ):\n", " super().__init__()\n", " self.encoder = t5_encoder\n", " self.modulator = modulator\n", " # Default: modulate at every layer\n", " if modulate_layers is None:\n", " n_layers = len(t5_encoder.block)\n", " modulate_layers = list(range(n_layers))\n", " self.modulate_layers = set(modulate_layers)\n", "\n", " # Store token_ids for the hooks\n", " self._current_token_ids = None\n", "\n", " def forward(self, input_ids, attention_mask=None, output_hidden_states=False, **kwargs):\n", " \"\"\"\n", " Forward pass with geometric modulation injected into residual stream.\n", "\n", " We manually step through encoder blocks instead of calling encoder.forward()\n", " so we can intervene between blocks.\n", " \"\"\"\n", " self._current_token_ids = input_ids\n", "\n", " # Get embeddings\n", " hidden_states = self.encoder.embed_tokens(input_ids)\n", " hidden_states = self.encoder.dropout(hidden_states)\n", "\n", " # Prepare extended attention mask (same as T5 encoder does internally)\n", " if attention_mask is not None:\n", " # [B, S] -> [B, 1, 1, S] with large negative values for padding\n", " extended_attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)\n", " extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(hidden_states.dtype).min\n", " else:\n", " 
extended_attention_mask = None\n", "\n", " all_hidden_states = [hidden_states] if output_hidden_states else None\n", "\n", " position_bias = None\n", "\n", " # Build cache_position — simple sequential indices\n", " seq_length = input_ids.shape[1]\n", " cache_position = torch.arange(seq_length, device=input_ids.device)\n", "\n", " # Step through blocks with modulation\n", " for i, block in enumerate(self.encoder.block):\n", " # MODULATE: inject geometric structure before the block processes it\n", " if i in self.modulate_layers:\n", " hidden_states = self.modulator(\n", " residual=hidden_states,\n", " token_ids=input_ids,\n", " layer_idx=i,\n", " )\n", "\n", " # Run the actual transformer block\n", " block_output = block(\n", " hidden_states,\n", " attention_mask=extended_attention_mask,\n", " position_bias=position_bias,\n", " cache_position=cache_position,\n", " )\n", " hidden_states = block_output[0]\n", "\n", " # Extract position bias from first block's self-attention output\n", " # T5Block returns: (hidden, ) + attention_outputs\n", " # where attention_outputs includes position_bias\n", " if position_bias is None:\n", " # position_bias is typically the second element after hidden_states\n", " # but the exact index depends on block config; try to find it\n", " for out in block_output[1:]:\n", " if isinstance(out, torch.Tensor) and out.dim() == 4:\n", " position_bias = out\n", " break\n", "\n", " if output_hidden_states:\n", " all_hidden_states.append(hidden_states)\n", "\n", " # Final layer norm\n", " hidden_states = self.encoder.final_layer_norm(hidden_states)\n", " hidden_states = self.encoder.dropout(hidden_states)\n", "\n", " if output_hidden_states:\n", " all_hidden_states.append(hidden_states)\n", " return type('Output', (), {\n", " 'last_hidden_state': hidden_states,\n", " 'hidden_states': tuple(all_hidden_states),\n", " })()\n", " else:\n", " return type('Output', (), {\n", " 'last_hidden_state': hidden_states,\n", " })()\n", "\n", "\n", "# 
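The extended attention mask built in `ModulatedT5Encoder.forward` mirrors what the T5 encoder does internally: real tokens get an additive bias of 0, padding gets a huge negative number so softmax sends its attention weight to zero. A toy check of that construction (float32 assumed, tensors invented here):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])            # [B, S]; 0 marks padding
ext = attention_mask[:, None, None, :].to(torch.float32)    # [B, 1, 1, S]
ext = (1.0 - ext) * torch.finfo(torch.float32).min          # 0 where real, ~-inf where pad

scores = torch.zeros(1, 1, 5, 5)                            # toy pre-softmax logits
probs = torch.softmax(scores + ext, dim=-1)
print(probs[0, 0, 0])                                       # mass only on the 3 real tokens
```

The `[B, 1, 1, S]` shape broadcasts over heads and query positions, so one mask row silences a padded key for every query.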
============================================================================\n", "# MEASUREMENT UTILITIES\n", "# ============================================================================\n", "\n", "def measure_modulator_impact(\n", " original_encoder,\n", " modulated_encoder,\n", " tokenizer,\n", " test_sentences: list,\n", ") -> Dict[str, float]:\n", " \"\"\"\n", " Compare encoder outputs with and without geometric modulation.\n", " Measures how much the modulator changes the representation.\n", " \"\"\"\n", " device = next(modulated_encoder.parameters()).device\n", "\n", " results = {\n", " 'cos_per_token': [],\n", " 'norm_ratio': [],\n", " 'pairwise_cos_shift': [],\n", " }\n", "\n", " for sent in test_sentences:\n", " inputs = tokenizer(sent, return_tensors=\"pt\", padding=False).to(device)\n", "\n", " with torch.no_grad():\n", " orig_out = original_encoder(\n", " input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask,\n", " )\n", " mod_out = modulated_encoder(\n", " input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask,\n", " )\n", "\n", " orig_h = orig_out.last_hidden_state[0] # [seq, d_model]\n", " mod_h = mod_out.last_hidden_state[0]\n", "\n", " # Per-token cosine between original and modulated\n", " cos = F.cosine_similarity(orig_h, mod_h, dim=-1)\n", " results['cos_per_token'].extend(cos.cpu().tolist())\n", "\n", " # Norm ratio\n", " orig_norms = orig_h.norm(dim=-1)\n", " mod_norms = mod_h.norm(dim=-1)\n", " ratio = (mod_norms / (orig_norms + 1e-8))\n", " results['norm_ratio'].extend(ratio.cpu().tolist())\n", "\n", " # Pairwise cosine shift\n", " if orig_h.shape[0] > 1:\n", " orig_n = F.normalize(orig_h, dim=-1)\n", " mod_n = F.normalize(mod_h, dim=-1)\n", " orig_pw = (orig_n @ orig_n.T)\n", " mod_pw = (mod_n @ mod_n.T)\n", " tri = torch.triu_indices(orig_h.shape[0], orig_h.shape[0], offset=1)\n", " shift = (mod_pw[tri[0], tri[1]] - orig_pw[tri[0], tri[1]])\n", " 
results['pairwise_cos_shift'].extend(shift.cpu().tolist())\n", "\n", " return {\n", " 'cos_mean': np.mean(results['cos_per_token']),\n", " 'cos_std': np.std(results['cos_per_token']),\n", " 'norm_ratio_mean': np.mean(results['norm_ratio']),\n", " 'pairwise_shift_mean': np.mean(results['pairwise_cos_shift']),\n", " 'pairwise_shift_std': np.std(results['pairwise_cos_shift']),\n", " }\n", "\n", "\n", "# ============================================================================\n", "# QUICK TEST\n", "# ============================================================================\n", "\n", "if __name__ == \"__main__\":\n", " from transformers import T5ForConditionalGeneration, T5Tokenizer\n", "\n", " model_id = \"google-t5/t5-small\"\n", " print(f\"Loading {model_id}...\")\n", " tokenizer = T5Tokenizer.from_pretrained(model_id, legacy=True)\n", " model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)\n", " model.eval()\n", "\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model = model.to(device)\n", "\n", " # Create modulator\n", " modulator = GeometricResidualModulator(\n", " d_model=512,\n", " vocab_size=32128,\n", " n_geometric_dims=64,\n", " initial_alpha=0.01, # Start very small — barely perturb\n", " learnable_alpha=True,\n", " per_layer_alpha=True,\n", " n_layers=6,\n", " ).to(device)\n", "\n", " # Wrap encoder\n", " mod_encoder = ModulatedT5Encoder(\n", " t5_encoder=model.encoder,\n", " modulator=modulator,\n", " modulate_layers=[0, 1, 2, 3, 4, 5], # all layers\n", " )\n", "\n", " print(f\"\\nModulator params: {sum(p.numel() for p in modulator.parameters()):,}\")\n", " print(f\" Geometric embed: {modulator.geometric_embed.weight.shape}\")\n", " print(f\" Projection: {modulator.proj.weight.shape}\")\n", " print(f\" Alpha: {torch.sigmoid(modulator.alpha).detach().cpu().numpy()}\")\n", "\n", " # Check geometric health\n", " health = modulator.geometric_residuals()\n", " print(f\"\\nGeometric 
health (random init):\")\n", " for k, v in health.items():\n", " if isinstance(v, torch.Tensor):\n", " print(f\" {k}: {v.item():.6f}\" if v.dim() == 0 else f\" {k}: {v.detach().cpu().numpy()}\")\n", " else:\n", " print(f\" {k}: {v}\")\n", "\n", " # Measure impact on encoder output\n", " test_sents = [\n", " \"summarize: The cat sat on the mat.\",\n", " \"summarize: Quantum mechanics describes particles at atomic scale.\",\n", " \"summarize: The derivative of x squared is two x.\",\n", " \"summarize: Love is patient, love is kind.\",\n", " \"summarize: Mount Everest is the tallest mountain.\",\n", " ]\n", "\n", " impact = measure_modulator_impact(\n", " original_encoder=model.encoder,\n", " modulated_encoder=mod_encoder,\n", " tokenizer=tokenizer,\n", " test_sentences=test_sents,\n", " )\n", "\n", " print(f\"\\nModulator impact (alpha={torch.sigmoid(modulator.alpha).mean().item():.4f}):\")\n", " for k, v in impact.items():\n", " print(f\" {k}: {v:.6f}\")\n", "\n", " print(f\"\\nAt alpha=0.01, the modulator should barely change the output.\")\n", " print(f\"cos_mean near 1.0 = minimal disruption. Good.\")\n", " print(f\"The geometric structure is there but whispering, not shouting.\")\n" ] }, { "cell_type": "markdown", "id": "20b431c9", "metadata": {}, "source": [ "## 8. Self-Contained Modulator Pipeline\n", "*Section VII end-to-end: T5 → WordNet → Encode → Modulator → Procrustes → Measure*" ] }, { "cell_type": "code", "execution_count": null, "id": "9a1ccb61", "metadata": {}, "outputs": [], "source": [ "# Full pipeline\n", "\n", "# ============================================================================\n", "# GEOMETRIC RESIDUAL MODULATOR — FULL SELF-CONTAINED PIPELINE\n", "# Load T5 → Match WordNet → Encode → Build Modulator → Procrustes → Measure\n", "# One file. No dependencies on prior cells. 
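One step the pipeline in this cell leans on is masked mean-pooling of encoder hidden states: average only over real tokens, ignoring padding. In isolation, with hand-checkable numbers (toy tensors, not from the pipeline):

```python
import torch

hidden = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)  # [B=2, S=4, D=3]
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])                    # second sequence: 2 real tokens

mask = attention_mask.unsqueeze(-1).float()                      # [B, S, 1]
pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)            # masked mean over S
print(pooled)
```

For the second sequence only the first two rows of `hidden` contribute, so the pooled vector is their plain average; a naive `hidden.mean(dim=1)` would be dragged toward the padding rows.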
Just run it.\n", "# ============================================================================\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import math\n", "import time\n", "from scipy.linalg import orthogonal_procrustes\n", "from scipy.stats import spearmanr\n", "from tqdm import tqdm\n", "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", "from collections import defaultdict\n", "\n", "import nltk\n", "nltk.download('wordnet', quiet=True)\n", "nltk.download('omw-1.4', quiet=True)\n", "from nltk.corpus import wordnet as wn\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Device: {device}\")\n", "\n", "model_id = \"google-t5/t5-small\"\n", "print(f\"Loading {model_id}...\")\n", "tokenizer = T5Tokenizer.from_pretrained(model_id, legacy=True)\n", "model = T5ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)\n", "model.eval()\n", "model = model.to(device)\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"MATCHING WORDNET → T5 TOKENIZER\")\n", "print(f\"{'='*70}\")\n", "\n", "matched = []\n", "seen_tokens = set()\n", "for synset in wn.all_synsets():\n", " for lemma in synset.lemmas():\n", " name = lemma.name().replace('_', ' ')\n", " ids = tokenizer.encode(name, add_special_tokens=False)\n", " if len(ids) == 1 and ids[0] not in seen_tokens:\n", " defn = synset.definition()\n", " if len(defn) > 10:\n", " matched.append((name, synset, ids[0], defn))\n", " seen_tokens.add(ids[0])\n", "\n", "synsets = [m[1] for m in matched]\n", "token_ids_list = [m[2] for m in matched]\n", "texts = [f\"summarize: {m[3]}\" for m in matched]\n", "print(f\"Matched: {len(matched)} tokens\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ENCODING THROUGH ORIGINAL ENCODER\")\n", "print(f\"{'='*70}\")\n", "\n", "BATCH_SIZE = 64\n", "MAX_LEN = 128\n", "\n", "encoder_reps = np.zeros((len(matched), 512), dtype=np.float32)\n", "t0 = 
time.time()\n", "n_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE\n", "\n", "for batch_idx in range(n_batches):\n", " start = batch_idx * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(texts))\n", " inputs = tokenizer(texts[start:end], return_tensors=\"pt\", padding=True,\n", " truncation=True, max_length=MAX_LEN).to(device)\n", " with torch.no_grad():\n", " enc_out = model.encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " hidden = enc_out.last_hidden_state.float()\n", " mask = inputs.attention_mask.unsqueeze(-1).float()\n", " pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)\n", " encoder_reps[start:end] = pooled.cpu().numpy()\n", "\n", "print(f\"Encoded {len(texts)} definitions in {time.time()-t0:.1f}s\")\n", "\n", "# Static embeddings\n", "E = model.shared.weight.detach().float().cpu().numpy()\n", "static_reps = E[token_ids_list]\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"BUILDING GEOMETRIC RESIDUAL MODULATOR\")\n", "print(f\"{'='*70}\")\n", "\n", "\n", "class GeometricResidualModulator(nn.Module):\n", " def __init__(self, d_model=512, vocab_size=32128, n_geometric_dims=64,\n", " initial_alpha=0.01, n_layers=6):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.n_geometric_dims = n_geometric_dims\n", " self.geometric_embed = nn.Embedding(vocab_size, n_geometric_dims)\n", " self.proj = nn.Linear(n_geometric_dims, d_model, bias=False)\n", " logit = math.log(initial_alpha / (1 - initial_alpha))\n", " self.alpha = nn.Parameter(torch.full((n_layers,), logit))\n", " nn.init.normal_(self.proj.weight, std=0.01)\n", "\n", " def forward(self, residual, token_ids, layer_idx=0):\n", " geo = self.geometric_embed(token_ids)\n", " geo_projected = self.proj(geo)\n", " a = torch.sigmoid(self.alpha[layer_idx])\n", " return (1 - a) * residual + a * geo_projected\n", "\n", " def geometric_residuals(self):\n", " W = self.geometric_embed.weight\n", " W_n = F.normalize(W, dim=1)\n", " idx = 
torch.randperm(W.shape[0])[:5000] # random sample of up to 5000 rows from the full table\n", " sample = W_n[idx]\n", " cos_mat = sample @ sample.T\n", " tri = torch.triu_indices(len(idx), len(idx), offset=1)\n", " flat_cos = cos_mat[tri[0], tri[1]]\n", " norms = W.norm(dim=1)\n", " centered = W - W.mean(dim=0)\n", " cov = (centered.T @ centered) / W.shape[0]\n", " eigvals = torch.linalg.eigvalsh(cov)\n", " pr = (eigvals.sum() ** 2) / (eigvals ** 2).sum()\n", " return {\n", " 'cos_mean': flat_cos.mean().item(),\n", " 'cos_std': flat_cos.std().item(),\n", " 'norm_mean': norms.mean().item(),\n", " 'pr_over_dim': (pr / self.n_geometric_dims).item(),\n", " 'alpha': torch.sigmoid(self.alpha).detach().cpu().numpy(),\n", " }\n", "\n", "\n", "class ModulatedT5Encoder(nn.Module):\n", " def __init__(self, t5_encoder, modulator, modulate_layers=None):\n", " super().__init__()\n", " self.encoder = t5_encoder\n", " self.modulator = modulator\n", " if modulate_layers is None:\n", " modulate_layers = list(range(len(t5_encoder.block)))\n", " self.modulate_layers = set(modulate_layers)\n", "\n", " def forward(self, input_ids, attention_mask=None, output_hidden_states=False, **kwargs):\n", " hidden_states = self.encoder.embed_tokens(input_ids)\n", " hidden_states = self.encoder.dropout(hidden_states)\n", "\n", " if attention_mask is not None:\n", " extended_attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)\n", " extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(hidden_states.dtype).min\n", " else:\n", " extended_attention_mask = None\n", "\n", " all_hidden_states = [hidden_states] if output_hidden_states else None\n", " position_bias = None\n", " seq_length = input_ids.shape[1]\n", " cache_position = torch.arange(seq_length, device=input_ids.device)\n", "\n", " for i, block in enumerate(self.encoder.block):\n", " if i in self.modulate_layers:\n", " hidden_states = self.modulator(hidden_states, input_ids, layer_idx=i)\n", "\n", " block_output = block(hidden_states, 
attention_mask=extended_attention_mask,\n", " position_bias=position_bias, cache_position=cache_position)\n", " hidden_states = block_output[0]\n", "\n", " if position_bias is None:\n", " for out in block_output[1:]:\n", " if isinstance(out, torch.Tensor) and out.dim() == 4:\n", " position_bias = out\n", " break\n", "\n", " if output_hidden_states:\n", " all_hidden_states.append(hidden_states)\n", "\n", " hidden_states = self.encoder.final_layer_norm(hidden_states)\n", " hidden_states = self.encoder.dropout(hidden_states)\n", "\n", " if output_hidden_states:\n", " all_hidden_states.append(hidden_states)\n", "\n", " return type('Output', (), {\n", " 'last_hidden_state': hidden_states,\n", " 'hidden_states': tuple(all_hidden_states) if all_hidden_states else None,\n", " })()\n", "\n", "\n", "N_GEO = 64\n", "modulator = GeometricResidualModulator(\n", " d_model=512, vocab_size=32128, n_geometric_dims=N_GEO,\n", " initial_alpha=0.01, n_layers=6,\n", ").to(device)\n", "\n", "mod_encoder = ModulatedT5Encoder(\n", " t5_encoder=model.encoder, modulator=modulator,\n", " modulate_layers=[0, 1, 2, 3, 4, 5],\n", ")\n", "\n", "print(f\"Modulator params: {sum(p.numel() for p in modulator.parameters()):,}\")\n", "print(f\"Alpha: {torch.sigmoid(modulator.alpha).detach().cpu().numpy()}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"BUILDING WORDNET SIMILARITY MATRIX (3000 anchors)\")\n", "print(f\"{'='*70}\")\n", "\n", "N_ANCHOR = 3000\n", "rng = np.random.default_rng(42)\n", "anchor_idx = rng.choice(len(matched), size=N_ANCHOR, replace=False)\n", "anchor_synsets = [synsets[i] for i in anchor_idx]\n", "\n", "sim_matrix = np.eye(N_ANCHOR, dtype=np.float32)\n", "t0 = time.time()\n", "\n", "for i in tqdm(range(N_ANCHOR), desc=\"WN sim\", miniters=100):\n", " syn_i = anchor_synsets[i]\n", " for j in range(i + 1, N_ANCHOR):\n", " sim = syn_i.wup_similarity(anchor_synsets[j])\n", " if sim is not None:\n", " sim_matrix[i, j] = sim\n", " sim_matrix[j, i] = sim\n", "\n", "print(f\"Built 
in {time.time()-t0:.1f}s\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"EIGENDECOMPOSITION → 64-d GEOMETRIC EMBEDDINGS\")\n", "print(f\"{'='*70}\")\n", "\n", "eigvals, eigvecs = np.linalg.eigh(sim_matrix)\n", "idx_sort = np.argsort(eigvals)[::-1]\n", "eigvals = eigvals[idx_sort]\n", "eigvecs = eigvecs[:, idx_sort]\n", "\n", "top_vals = eigvals[:N_GEO]\n", "top_vecs = eigvecs[:, :N_GEO]\n", "scales = np.sqrt(np.maximum(top_vals, 0))\n", "anchor_geo = top_vecs * scales[None, :]\n", "norms = np.linalg.norm(anchor_geo, axis=1, keepdims=True)\n", "anchor_geo = anchor_geo / np.maximum(norms, 1e-8)\n", "\n", "recon_cos = anchor_geo @ anchor_geo.T\n", "tri = np.triu_indices(N_ANCHOR, k=1)\n", "recon_corr = np.corrcoef(sim_matrix[tri], recon_cos[tri])[0, 1]\n", "print(f\"Reconstruction correlation: {recon_corr:.4f}\")\n", "print(f\"Top 5 eigenvalues: {eigvals[:5]}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"FAST PROJECTION VIA EMBEDDING COSINE PROXY\")\n", "print(f\"{'='*70}\")\n", "\n", "t0 = time.time()\n", "static_t = torch.tensor(static_reps, device=device, dtype=torch.float32)\n", "static_n = static_t / (static_t.norm(dim=1, keepdim=True) + 1e-8)\n", "anchor_static_n = static_n[anchor_idx]\n", "anchor_geo_t = torch.tensor(anchor_geo, device=device, dtype=torch.float32)\n", "\n", "all_geo = torch.zeros(len(matched), N_GEO, device=device, dtype=torch.float32)\n", "K_NEIGHBORS = 10\n", "BATCH = 1000\n", "\n", "# Place anchors\n", "for local_i, global_i in enumerate(anchor_idx):\n", " all_geo[global_i] = anchor_geo_t[local_i]\n", "\n", "# Project non-anchors\n", "non_anchor_mask = torch.ones(len(matched), dtype=torch.bool)\n", "non_anchor_mask[anchor_idx] = False\n", "non_anchor_indices = torch.where(non_anchor_mask)[0]\n", "\n", "for batch_start in range(0, len(non_anchor_indices), BATCH):\n", " batch_end = min(batch_start + BATCH, len(non_anchor_indices))\n", " batch_idx = non_anchor_indices[batch_start:batch_end]\n", " batch_static = static_n[batch_idx]\n", " 
cos_to_anchors = batch_static @ anchor_static_n.T\n", " topk_vals, topk_idx = cos_to_anchors.topk(K_NEIGHBORS, dim=1)\n", " weights = torch.softmax(topk_vals * 10.0, dim=1)\n", " neighbor_geo = anchor_geo_t[topk_idx]\n", " interpolated = (neighbor_geo * weights.unsqueeze(-1)).sum(dim=1)\n", " interpolated = interpolated / (interpolated.norm(dim=1, keepdim=True) + 1e-8)\n", " all_geo[batch_idx] = interpolated\n", "\n", "print(f\"Projected {len(non_anchor_indices)} tokens in {time.time()-t0:.1f}s\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"PROCRUSTES ALIGNMENT: Geometric → Residual Stream\")\n", "print(f\"{'='*70}\")\n", "\n", "anchor_enc = encoder_reps[anchor_idx]\n", "enc_mean = anchor_enc.mean(axis=0)\n", "enc_centered = anchor_enc - enc_mean\n", "U, S, Vt = np.linalg.svd(enc_centered, full_matrices=False)\n", "pca_components = Vt[:N_GEO]\n", "\n", "enc_pca = enc_centered @ pca_components.T\n", "enc_pca_n = enc_pca / (np.linalg.norm(enc_pca, axis=1, keepdims=True) + 1e-8)\n", "\n", "anchor_geo_np = all_geo[anchor_idx].cpu().numpy()\n", "geo_n = anchor_geo_np / (np.linalg.norm(anchor_geo_np, axis=1, keepdims=True) + 1e-8)\n", "\n", "geo_mean = geo_n.mean(axis=0)\n", "enc_mean_pca = enc_pca_n.mean(axis=0)\n", "geo_c = geo_n - geo_mean\n", "enc_c = enc_pca_n - enc_mean_pca\n", "\n", "R, procrustes_scale = orthogonal_procrustes(geo_c, enc_c)\n", "print(f\"Procrustes scale: {procrustes_scale:.4f}\")\n", "\n", "aligned_geo = geo_c @ R + enc_mean_pca\n", "alignment_cos = np.sum(aligned_geo * enc_pca_n, axis=1) / (\n", " np.linalg.norm(aligned_geo, axis=1) * np.linalg.norm(enc_pca_n, axis=1) + 1e-8\n", ")\n", "print(f\"Alignment cosine: mean={alignment_cos.mean():.4f} std={alignment_cos.std():.4f}\")\n", "\n", "proj_weight = (R @ pca_components).T\n", "residual_scale = np.linalg.norm(anchor_enc, axis=1).mean()\n", "proj_weight = proj_weight * (residual_scale * 0.1)\n", "print(f\"Projection norm: {np.linalg.norm(proj_weight):.4f}\")\n", "\n", "with 
torch.no_grad():\n", " modulator.proj.weight.copy_(torch.tensor(proj_weight, dtype=torch.float32, device=device))\n", " for i, (name, syn, tid, defn) in enumerate(matched):\n", " modulator.geometric_embed.weight[tid] = all_geo[i]\n", "\n", "health = modulator.geometric_residuals()\n", "print(f\"Geo health: cos_mean={health['cos_mean']:.4f}, pr/dim={health['pr_over_dim']:.4f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ENCODING THROUGH PROCRUSTES-ALIGNED MODULATOR\")\n", "print(f\"{'='*70}\")\n", "\n", "proc_reps = np.zeros((len(matched), 512), dtype=np.float32)\n", "t0 = time.time()\n", "\n", "for batch_idx in range(n_batches):\n", " start = batch_idx * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(texts))\n", " inputs = tokenizer(texts[start:end], return_tensors=\"pt\", padding=True,\n", " truncation=True, max_length=MAX_LEN).to(device)\n", " with torch.no_grad():\n", " enc_out = mod_encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " hidden = enc_out.last_hidden_state.float()\n", " mask = inputs.attention_mask.unsqueeze(-1).float()\n", " pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)\n", " proc_reps[start:end] = pooled.cpu().numpy()\n", "\n", "print(f\"Done in {time.time()-t0:.1f}s\")\n", "\n", "# Per-token divergence check\n", "per_tok = (encoder_reps * proc_reps).sum(axis=1) / (\n", " np.linalg.norm(encoder_reps, axis=1) * np.linalg.norm(proc_reps, axis=1) + 1e-8\n", ")\n", "print(f\"Per-token cos(orig, procrustes): mean={per_tok.mean():.6f} std={per_tok.std():.6f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"FULL COMPARISON: Static | Original | Procrustes-Aligned\")\n", "print(f\"{'='*70}\")\n", "\n", "rng2 = np.random.default_rng(42)\n", "N_REL = min(3000, len(matched))\n", "rel_idx2 = rng2.choice(len(matched), size=N_REL, replace=False)\n", "\n", "sets = {\n", " 'Static': static_reps[rel_idx2],\n", " 'Original': encoder_reps[rel_idx2],\n", " 'Procrustes': proc_reps[rel_idx2],\n", "}\n", "sets_n = 
{name: reps / (np.linalg.norm(reps, axis=1, keepdims=True) + 1e-8)\n", " for name, reps in sets.items()}\n", "\n", "pi = rng2.choice(N_REL, size=50000)\n", "pj = rng2.choice(N_REL, size=50000)\n", "valid = pi != pj\n", "pi, pj = pi[valid], pj[valid]\n", "\n", "wn_s = []\n", "cos_arrays = {name: [] for name in sets_n}\n", "\n", "print(\"Computing 50K WordNet pairs...\")\n", "for k in tqdm(range(min(50000, len(pi))), desc=\"WN pairs\", miniters=5000):\n", " a, b = pi[k], pj[k]\n", " sim = synsets[rel_idx2[a]].path_similarity(synsets[rel_idx2[b]])\n", " if sim is not None and sim > 0:\n", " wn_s.append(sim)\n", " for name, normed in sets_n.items():\n", " cos_arrays[name].append(np.dot(normed[a], normed[b]))\n", "\n", "wn_s = np.array(wn_s)\n", "for name in cos_arrays:\n", " cos_arrays[name] = np.array(cos_arrays[name])\n", "\n", "orig_p = np.corrcoef(wn_s, cos_arrays['Original'])[0, 1]\n", "orig_sp, _ = spearmanr(wn_s, cos_arrays['Original'])\n", "\n", "print(f\"\\n{'Method':20s} {'Pearson':>10s} {'Spearman':>10s} {'ΔP':>10s} {'ΔS':>10s}\")\n", "for name, cos_arr in cos_arrays.items():\n", " p = np.corrcoef(wn_s, cos_arr)[0, 1]\n", " sp, _ = spearmanr(wn_s, cos_arr)\n", " dp = p - orig_p if name != 'Original' else 0.0\n", " ds = sp - orig_sp if name != 'Original' else 0.0\n", " print(f\" {name:18s} {p:10.6f} {sp:10.6f} {dp:+10.6f} {ds:+10.6f}\")\n", "\n", "# Distance bands\n", "print(f\"\\n--- DISTANCE BANDS ---\")\n", "bands = [(0.5, 1.0), (0.25, 0.5), (0.10, 0.25), (0.05, 0.10), (0.0, 0.05)]\n", "print(f\"{'Band':>12s} {'Orig':>8s} {'Procrust':>8s} {'Δ':>8s}\")\n", "for lo, hi in bands:\n", " mask = (wn_s >= lo) & (wn_s < hi) if hi < 1.0 else (wn_s >= lo) & (wn_s <= hi)\n", " if mask.sum() < 5:\n", " continue\n", " oc = cos_arrays['Original'][mask].mean()\n", " pc = cos_arrays['Procrustes'][mask].mean()\n", " print(f\" [{lo:.2f},{hi:.2f}) {oc:8.4f} {pc:8.4f} {pc-oc:+8.4f}\")\n", "\n", "# Gradients\n", "high_mask = wn_s >= 0.25\n", "low_mask = wn_s < 0.10\n", "if 
high_mask.sum() > 0 and low_mask.sum() > 0:\n", " print(f\"\\n Gradients (high - low):\")\n", " for name, cos_arr in cos_arrays.items():\n", " grad = cos_arr[high_mask].mean() - cos_arr[low_mask].mean()\n", " print(f\" {name:18s}: {grad:.6f}\")\n", "\n", "# Pentachoron\n", "print(f\"\\n--- PENTACHORON GEOMETRY ---\")\n", "def cayley_menger_volume_sq(points):\n", " n = len(points)\n", " D = np.zeros((n + 1, n + 1))\n", " D[0, 1:] = 1; D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq; D[j + 1, i + 1] = d_sq\n", " k = n - 1\n", " det = np.linalg.det(D)\n", " return ((-1) ** (k + 1)) * det / ((2 ** k) * (math.factorial(k) ** 2))\n", "\n", "rng3 = np.random.default_rng(42)\n", "for name, reps in [('Original', encoder_reps), ('Procrustes', proc_reps)]:\n", " vols = []\n", " for _ in range(500):\n", " idx = rng3.choice(len(matched), size=5, replace=False)\n", " v = cayley_menger_volume_sq(reps[idx])\n", " if v > 0:\n", " vols.append(np.sqrt(v))\n", " vols = np.array(vols)\n", " print(f\" {name:18s}: CV={vols.std()/vols.mean():.4f} mean={vols.mean():.4e}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"DONE\")\n", "print(f\"{'='*70}\")\n", "print(f\"Alpha: {torch.sigmoid(modulator.alpha).detach().cpu().numpy()}\")\n", "print(f\"Procrustes alignment cosine: {alignment_cos.mean():.4f}\")\n", "print(f\"Per-token preservation: {per_tok.mean():.6f}\")\n" ] }, { "cell_type": "markdown", "id": "7efa6ba5", "metadata": {}, "source": [ "## 9. Modulator Training — Alpha Convergence\n", "*Section VII.3–VII.4: freeze T5, train modulator, track alpha per layer*" ] }, { "cell_type": "code", "execution_count": null, "id": "471acb2d", "metadata": {}, "outputs": [], "source": [ "# Alpha convergence training\n", "\n", "# ============================================================================\n", "# TRAIN THE GEOMETRIC MODULATOR — LET ALPHA FIND ITSELF\n", "# Freeze T5. 
Train only: geometric_embed, proj, alpha.\n", "# Task: summarize definitions → lemma words\n", "# Watch where alpha settles.\n", "# Run AFTER the full pipeline (model, modulator, mod_encoder, matched, etc.)\n", "# ============================================================================\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import math\n", "import time\n", "from scipy.stats import spearmanr\n", "from tqdm import tqdm\n", "\n", "# Freeze T5 entirely\n", "for param in model.parameters():\n", " param.requires_grad = False\n", "\n", "# Unfreeze only the modulator\n", "for param in modulator.parameters():\n", " param.requires_grad = True\n", "\n", "# Reset alpha to neutral starting point — let it find its own equilibrium\n", "INITIAL_ALPHA = 0.01\n", "with torch.no_grad():\n", " modulator.alpha.fill_(math.log(INITIAL_ALPHA / (1 - INITIAL_ALPHA)))\n", "\n", "trainable = sum(p.numel() for p in modulator.parameters() if p.requires_grad)\n", "frozen = sum(p.numel() for p in model.parameters())\n", "print(f\"Frozen T5 params: {frozen:,}\")\n", "print(f\"Trainable mod params: {trainable:,}\")\n", "print(f\"Ratio: {trainable/frozen*100:.3f}%\")\n", "print(f\"Starting alpha: {torch.sigmoid(modulator.alpha).detach().cpu().numpy()}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"BUILDING TRAINING DATA\")\n", "print(f\"{'='*70}\")\n", "\n", "# Pair each definition with its lemma as the target\n", "train_inputs = []\n", "train_targets = []\n", "\n", "for name, syn, tid, defn in matched:\n", " train_inputs.append(f\"summarize: {defn}\")\n", " train_targets.append(name)\n", "\n", "# Shuffle and split\n", "rng = np.random.default_rng(42)\n", "perm = rng.permutation(len(train_inputs))\n", "n_train = int(len(perm) * 0.9)\n", "train_idx = perm[:n_train]\n", "val_idx = perm[n_train:]\n", "\n", "print(f\"Train: {len(train_idx)}, Val: {len(val_idx)}\")\n", "print(f\"Example: '{train_inputs[train_idx[0]][:80]}...' 
→ '{train_targets[train_idx[0]]}'\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"TRAINING — ALPHA FINDS ITSELF\")\n", "print(f\"{'='*70}\")\n", "\n", "BATCH_SIZE = 32\n", "N_EPOCHS = 10\n", "LR = 1e-3\n", "MAX_INPUT_LEN = 128\n", "MAX_TARGET_LEN = 16\n", "\n", "optimizer = torch.optim.AdamW(modulator.parameters(), lr=LR, weight_decay=0.01)\n", "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)\n", "\n", "# Tracking\n", "alpha_history = []\n", "loss_history = []\n", "val_loss_history = []\n", "\n", "def compute_batch_loss(batch_indices, inputs_list, targets_list):\n", " \"\"\"Forward pass through modulated encoder + T5 decoder, return CE loss.\"\"\"\n", " batch_in = [inputs_list[i] for i in batch_indices]\n", " batch_tgt = [targets_list[i] for i in batch_indices]\n", "\n", " # Tokenize\n", " enc_inputs = tokenizer(\n", " batch_in, return_tensors=\"pt\", padding=True,\n", " truncation=True, max_length=MAX_INPUT_LEN,\n", " ).to(device)\n", "\n", " dec_inputs = tokenizer(\n", " batch_tgt, return_tensors=\"pt\", padding=True,\n", " truncation=True, max_length=MAX_TARGET_LEN,\n", " ).to(device)\n", "\n", " labels = dec_inputs.input_ids.clone()\n", " labels[labels == tokenizer.pad_token_id] = -100 # ignore padding in loss\n", "\n", " # T5 expects decoder inputs shifted right (start token prepended);\n", " # feeding the raw targets would let each position see the token it must predict\n", " decoder_input_ids = model._shift_right(dec_inputs.input_ids)\n", "\n", " # Forward through modulated encoder\n", " enc_out = mod_encoder(\n", " input_ids=enc_inputs.input_ids,\n", " attention_mask=enc_inputs.attention_mask,\n", " )\n", "\n", " # Forward through T5 decoder with encoder output\n", " dec_out = model.decoder(\n", " input_ids=decoder_input_ids,\n", " encoder_hidden_states=enc_out.last_hidden_state,\n", " encoder_attention_mask=enc_inputs.attention_mask,\n", " )\n", "\n", " # LM head + loss\n", " logits = model.lm_head(dec_out.last_hidden_state)\n", " loss = F.cross_entropy(\n", " logits.view(-1, logits.size(-1)),\n", " labels.view(-1),\n", " ignore_index=-100,\n", " )\n", " return loss\n", "\n", "\n", "t0 = time.time()\n", "\n", "for epoch in range(N_EPOCHS):\n", 
" # Shuffle training data\n", " epoch_perm = rng.permutation(len(train_idx))\n", " epoch_indices = train_idx[epoch_perm]\n", "\n", " n_batches = (len(epoch_indices) + BATCH_SIZE - 1) // BATCH_SIZE\n", " epoch_losses = []\n", "\n", " modulator.train()\n", "\n", " for batch_i in range(n_batches):\n", " start = batch_i * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(epoch_indices))\n", " batch_idx = epoch_indices[start:end]\n", "\n", " optimizer.zero_grad()\n", " loss = compute_batch_loss(batch_idx, train_inputs, train_targets)\n", " loss.backward()\n", "\n", " # Gradient clip\n", " torch.nn.utils.clip_grad_norm_(modulator.parameters(), 1.0)\n", " optimizer.step()\n", "\n", " epoch_losses.append(loss.item())\n", "\n", " # Track alpha every 10 batches\n", " if batch_i % 10 == 0:\n", " current_alpha = torch.sigmoid(modulator.alpha).detach().cpu().numpy()\n", " alpha_history.append(current_alpha.copy())\n", "\n", " scheduler.step()\n", "\n", " # Validation\n", " modulator.eval()\n", " val_losses = []\n", " val_batches = (len(val_idx) + BATCH_SIZE - 1) // BATCH_SIZE\n", "\n", " with torch.no_grad():\n", " for batch_i in range(val_batches):\n", " start = batch_i * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(val_idx))\n", " batch = val_idx[start:end]\n", " vloss = compute_batch_loss(batch, train_inputs, train_targets)\n", " val_losses.append(vloss.item())\n", "\n", " train_loss = np.mean(epoch_losses)\n", " val_loss = np.mean(val_losses)\n", " loss_history.append(train_loss)\n", " val_loss_history.append(val_loss)\n", "\n", " current_alpha = torch.sigmoid(modulator.alpha).detach().cpu().numpy()\n", " elapsed = time.time() - t0\n", "\n", " print(f\" Epoch {epoch+1:2d}/{N_EPOCHS} \"\n", " f\"train_loss={train_loss:.4f} val_loss={val_loss:.4f} \"\n", " f\"alpha=[{', '.join(f'{a:.4f}' for a in current_alpha)}] \"\n", " f\"({elapsed:.0f}s)\")\n", "\n", "total_time = time.time() - t0\n", "print(f\"\\nTraining complete in {total_time:.1f}s\")\n", "\n", 
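"\n", "# Sanity check (plain arithmetic, no new assumptions): alpha is stored in logit\n", "# space, and sigmoid(log(a / (1 - a))) = a exactly, so re-deriving INITIAL_ALPHA\n", "# from its own logit should recover it to float precision.\n", "assert abs(torch.sigmoid(torch.tensor(math.log(INITIAL_ALPHA / (1 - INITIAL_ALPHA)))).item() - INITIAL_ALPHA) < 1e-6\n",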
"print(f\"\\n{'='*70}\")\n", "print(\"ALPHA CONVERGENCE ANALYSIS\")\n", "print(f\"{'='*70}\")\n", "\n", "alpha_arr = np.array(alpha_history) # [n_checkpoints, n_layers]\n", "print(f\"Alpha checkpoints: {alpha_arr.shape[0]}\")\n", "print(f\"\\nFinal alpha per layer:\")\n", "final_alpha = torch.sigmoid(modulator.alpha).detach().cpu().numpy()\n", "for i, a in enumerate(final_alpha):\n", " print(f\" Layer {i}: {a:.6f}\")\n", "\n", "print(f\"\\nMean final alpha: {final_alpha.mean():.6f}\")\n", "print(f\"Std final alpha: {final_alpha.std():.6f}\")\n", "\n", "# Distance from known constants\n", "print(f\"\\nDistance from known constants:\")\n", "print(f\" |α - 0.29154| = {abs(final_alpha.mean() - 0.29154):.6f}\")\n", "print(f\" |α - 0.50000| = {abs(final_alpha.mean() - 0.50000):.6f}\")\n", "print(f\" |α - 0.70846| = {abs(final_alpha.mean() - 0.70846):.6f}\") # 1 - 0.29154\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"POST-TRAINING RELATIONAL MEASUREMENT\")\n", "print(f\"{'='*70}\")\n", "\n", "modulator.eval()\n", "\n", "# Re-encode all definitions with trained modulator\n", "trained_reps = np.zeros((len(matched), 512), dtype=np.float32)\n", "n_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE\n", "\n", "for batch_idx in range(n_batches):\n", " start = batch_idx * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(texts))\n", " inputs = tokenizer(texts[start:end], return_tensors=\"pt\", padding=True,\n", " truncation=True, max_length=MAX_INPUT_LEN).to(device)\n", " with torch.no_grad():\n", " enc_out = mod_encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " hidden = enc_out.last_hidden_state.float()\n", " mask = inputs.attention_mask.unsqueeze(-1).float()\n", " pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)\n", " trained_reps[start:end] = pooled.cpu().numpy()\n", "\n", "# Per-token preservation\n", "per_tok = (encoder_reps * trained_reps).sum(axis=1) / (\n", " np.linalg.norm(encoder_reps, axis=1) * np.linalg.norm(trained_reps, 
axis=1) + 1e-8\n", ")\n", "print(f\"Per-token cos(orig, trained): mean={per_tok.mean():.6f}\")\n", "\n", "# Relational correlation\n", "rng2 = np.random.default_rng(42)\n", "N_REL = min(3000, len(matched))\n", "rel_idx2 = rng2.choice(len(matched), size=N_REL, replace=False)\n", "\n", "sets = {\n", " 'Original': encoder_reps[rel_idx2],\n", " 'Trained': trained_reps[rel_idx2],\n", "}\n", "sets_n = {name: reps / (np.linalg.norm(reps, axis=1, keepdims=True) + 1e-8)\n", " for name, reps in sets.items()}\n", "\n", "pi = rng2.choice(N_REL, size=50000)\n", "pj = rng2.choice(N_REL, size=50000)\n", "valid = pi != pj\n", "pi, pj = pi[valid], pj[valid]\n", "\n", "wn_s = []\n", "cos_arrays = {name: [] for name in sets_n}\n", "\n", "for k in tqdm(range(min(50000, len(pi))), desc=\"WN pairs\", miniters=5000):\n", " a, b = pi[k], pj[k]\n", " sim = synsets[rel_idx2[a]].path_similarity(synsets[rel_idx2[b]])\n", " if sim is not None and sim > 0:\n", " wn_s.append(sim)\n", " for name, normed in sets_n.items():\n", " cos_arrays[name].append(np.dot(normed[a], normed[b]))\n", "\n", "wn_s = np.array(wn_s)\n", "for name in cos_arrays:\n", " cos_arrays[name] = np.array(cos_arrays[name])\n", "\n", "orig_p = np.corrcoef(wn_s, cos_arrays['Original'])[0, 1]\n", "trained_p = np.corrcoef(wn_s, cos_arrays['Trained'])[0, 1]\n", "orig_sp, _ = spearmanr(wn_s, cos_arrays['Original'])\n", "trained_sp, _ = spearmanr(wn_s, cos_arrays['Trained'])\n", "\n", "print(f\"\\n{'Method':20s} {'Pearson':>10s} {'Spearman':>10s}\")\n", "print(f\" {'Original':18s} {orig_p:10.6f} {orig_sp:10.6f}\")\n", "print(f\" {'Trained':18s} {trained_p:10.6f} {trained_sp:10.6f}\")\n", "print(f\" {'Δ':18s} {trained_p-orig_p:+10.6f} {trained_sp-orig_sp:+10.6f}\")\n", "\n", "# Gradient\n", "high_mask = wn_s >= 0.25\n", "low_mask = wn_s < 0.10\n", "if high_mask.sum() > 0 and low_mask.sum() > 0:\n", " orig_grad = cos_arrays['Original'][high_mask].mean() - cos_arrays['Original'][low_mask].mean()\n", " trained_grad = 
cos_arrays['Trained'][high_mask].mean() - cos_arrays['Trained'][low_mask].mean()\n", " print(f\"\\n Gradient: orig={orig_grad:.6f} trained={trained_grad:.6f} Δ={trained_grad-orig_grad:+.6f}\")\n", "\n", "# Pentachoron\n", "def cayley_menger_volume_sq(points):\n", " n = len(points)\n", " D = np.zeros((n + 1, n + 1))\n", " D[0, 1:] = 1; D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq; D[j + 1, i + 1] = d_sq\n", " k = n - 1\n", " det = np.linalg.det(D)\n", " return ((-1) ** (k + 1)) * det / ((2 ** k) * (math.factorial(k) ** 2))\n", "\n", "rng3 = np.random.default_rng(42)\n", "for name, reps in [('Original', encoder_reps), ('Trained', trained_reps)]:\n", " vols = []\n", " for _ in range(500):\n", " idx = rng3.choice(len(matched), size=5, replace=False)\n", " v = cayley_menger_volume_sq(reps[idx])\n", " if v > 0:\n", " vols.append(np.sqrt(v))\n", " vols = np.array(vols)\n", " print(f\" {name:18s}: CV={vols.std()/vols.mean():.4f} mean={vols.mean():.4e}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"POST-TRAINING COHERENCE\")\n", "print(f\"{'='*70}\")\n", "\n", "from transformers.modeling_outputs import BaseModelOutput\n", "\n", "test_prompts = [\n", " \"summarize: The cat is a small domesticated carnivorous mammal with soft fur.\",\n", " \"translate English to German: The geometric structure of language is a universal attractor.\",\n", " \"summarize: A triangle is a polygon with three edges and three vertices.\",\n", " \"summarize: Seven is a prime number that comes after six and before eight.\",\n", "]\n", "\n", "for prompt in test_prompts:\n", " inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n", " with torch.no_grad():\n", " orig_enc = model.encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " orig_gen = model.generate(encoder_outputs=orig_enc,\n", " attention_mask=inputs.attention_mask,\n", " max_new_tokens=64)\n", " 
orig_text = tokenizer.decode(orig_gen[0], skip_special_tokens=True)\n", "\n", " mod_enc = mod_encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " mod_wrapped = BaseModelOutput(last_hidden_state=mod_enc.last_hidden_state)\n", " mod_gen = model.generate(encoder_outputs=mod_wrapped,\n", " attention_mask=inputs.attention_mask,\n", " max_new_tokens=64)\n", " mod_text = tokenizer.decode(mod_gen[0], skip_special_tokens=True)\n", "\n", " print(f\"\\n ORIG: {orig_text}\")\n", " print(f\" TRAINED: {mod_text}\")\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", "fig.suptitle(f\"Geometric Modulator Training — Final α={final_alpha.mean():.5f}\", fontsize=14)\n", "\n", "# 1. Alpha convergence\n", "alpha_arr = np.array(alpha_history)\n", "for layer in range(alpha_arr.shape[1]):\n", " axes[0].plot(alpha_arr[:, layer], alpha=0.5, label=f'L{layer}')\n", "axes[0].axhline(0.29154, color='red', ls='--', alpha=0.7, label='0.29154')\n", "axes[0].axhline(0.50, color='gray', ls=':', alpha=0.5, label='0.50')\n", "axes[0].set_xlabel(\"Checkpoint (every 10 batches)\")\n", "axes[0].set_ylabel(\"Alpha\")\n", "axes[0].set_title(\"Alpha convergence per layer\")\n", "axes[0].legend(fontsize=7)\n", "\n", "# 2. Loss curves\n", "axes[1].plot(loss_history, 'b-', label='Train')\n", "axes[1].plot(val_loss_history, 'r-', label='Val')\n", "axes[1].set_xlabel(\"Epoch\")\n", "axes[1].set_ylabel(\"Loss\")\n", "axes[1].set_title(\"Training loss\")\n", "axes[1].legend()\n", "\n", "# 3. 
Alpha distribution at end\n", "axes[2].bar(range(len(final_alpha)), final_alpha, color='teal')\n", "axes[2].axhline(0.29154, color='red', ls='--', label='0.29154')\n", "axes[2].set_xlabel(\"Layer\")\n", "axes[2].set_ylabel(\"Final alpha\")\n", "axes[2].set_title(\"Per-layer final alpha\")\n", "axes[2].legend()\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"/content/modulator_training.png\", dpi=150, bbox_inches='tight')\n", "plt.show()\n", "print(\"\\nSaved: /content/modulator_training.png\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"FINAL SUMMARY\")\n", "print(f\"{'='*70}\")\n", "print(f\"Model: {model_id}\")\n", "print(f\"Trainable params: {trainable:,} ({trainable/frozen*100:.3f}% of T5)\")\n", "print(f\"Epochs: {N_EPOCHS}, Time: {total_time:.0f}s\")\n", "print(f\"\")\n", "print(f\"Final alpha per layer: {final_alpha}\")\n", "print(f\"Mean alpha: {final_alpha.mean():.6f}\")\n", "print(f\"|α - 0.29154| = {abs(final_alpha.mean() - 0.29154):.6f}\")\n", "print(f\"\")\n", "print(f\"Pearson: orig={orig_p:.4f} trained={trained_p:.4f} Δ={trained_p-orig_p:+.4f}\")\n", "print(f\"Per-token preservation: {per_tok.mean():.4f}\")\n" ] }, { "cell_type": "markdown", "id": "4a099b84", "metadata": {}, "source": [ "## 10. 
The 0.29154 Constant\n", "*Section IX: phase boundary measurement*" ] }, { "cell_type": "code", "execution_count": null, "id": "69164414", "metadata": {}, "outputs": [], "source": [ "# Alpha 0.29154\n", "\n", "# ============================================================================\n", "# ALPHA = 0.29154 — THE CONSTANT\n", "# Run AFTER the full pipeline is loaded\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import math\n", "import time\n", "from scipy.stats import spearmanr\n", "from tqdm import tqdm\n", "\n", "TARGET_ALPHA = 0.29154\n", "\n", "# Set alpha\n", "logit = math.log(TARGET_ALPHA / (1 - TARGET_ALPHA))\n", "with torch.no_grad():\n", " modulator.alpha.fill_(logit)\n", "\n", "actual_alpha = torch.sigmoid(modulator.alpha).detach().cpu().numpy()\n", "print(f\"Alpha set to: {actual_alpha}\")\n", "\n", "print(f\"\\nEncoding {len(matched)} definitions at alpha={TARGET_ALPHA}...\")\n", "reps_29514 = np.zeros((len(matched), 512), dtype=np.float32)\n", "t0 = time.time()\n", "BATCH_SIZE = 64\n", "n_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE\n", "\n", "for batch_idx in range(n_batches):\n", " start = batch_idx * BATCH_SIZE\n", " end = min(start + BATCH_SIZE, len(texts))\n", " inputs = tokenizer(texts[start:end], return_tensors=\"pt\", padding=True,\n", " truncation=True, max_length=128).to(device)\n", " with torch.no_grad():\n", " enc_out = mod_encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " hidden = enc_out.last_hidden_state.float()\n", " mask = inputs.attention_mask.unsqueeze(-1).float()\n", " pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)\n", " reps_29514[start:end] = pooled.cpu().numpy()\n", "\n", "print(f\"Done in {time.time()-t0:.1f}s\")\n", "\n", "# Per-token preservation\n", "per_tok = (encoder_reps * reps_29514).sum(axis=1) / (\n", " np.linalg.norm(encoder_reps, axis=1) * np.linalg.norm(reps_29514, axis=1) 
+ 1e-8\n", ")\n", "print(f\"Per-token cos(orig, 0.29154): mean={per_tok.mean():.6f} std={per_tok.std():.6f}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(f\"RELATIONAL MEASUREMENT AT ALPHA = {TARGET_ALPHA}\")\n", "print(f\"{'='*70}\")\n", "\n", "rng2 = np.random.default_rng(42)\n", "N_REL = min(3000, len(matched))\n", "rel_idx2 = rng2.choice(len(matched), size=N_REL, replace=False)\n", "\n", "sets = {\n", " 'Static': static_reps[rel_idx2],\n", " 'Original': encoder_reps[rel_idx2],\n", " f'α={TARGET_ALPHA}': reps_29514[rel_idx2],\n", "}\n", "sets_n = {name: reps / (np.linalg.norm(reps, axis=1, keepdims=True) + 1e-8)\n", " for name, reps in sets.items()}\n", "\n", "pi = rng2.choice(N_REL, size=50000)\n", "pj = rng2.choice(N_REL, size=50000)\n", "valid = pi != pj\n", "pi, pj = pi[valid], pj[valid]\n", "\n", "wn_s = []\n", "cos_arrays = {name: [] for name in sets_n}\n", "\n", "for k in tqdm(range(len(pi)), desc=\"WN pairs\", miniters=5000):\n", " a, b = pi[k], pj[k]\n", " sim = synsets[rel_idx2[a]].path_similarity(synsets[rel_idx2[b]])\n", " if sim is not None and sim > 0:\n", " wn_s.append(sim)\n", " for name, normed in sets_n.items():\n", " cos_arrays[name].append(np.dot(normed[a], normed[b]))\n", "\n", "wn_s = np.array(wn_s)\n", "for name in cos_arrays:\n", " cos_arrays[name] = np.array(cos_arrays[name])\n", "\n", "orig_p = np.corrcoef(wn_s, cos_arrays['Original'])[0, 1]\n", "orig_sp, _ = spearmanr(wn_s, cos_arrays['Original'])\n", "\n", "print(f\"\\n{'Method':20s} {'Pearson':>10s} {'Spearman':>10s} {'ΔP':>10s} {'ΔS':>10s}\")\n", "for name, cos_arr in cos_arrays.items():\n", " p = np.corrcoef(wn_s, cos_arr)[0, 1]\n", " sp, _ = spearmanr(wn_s, cos_arr)\n", " dp = p - orig_p if name != 'Original' else 0.0\n", " ds = sp - orig_sp if name != 'Original' else 0.0\n", " print(f\" {name:18s} {p:10.6f} {sp:10.6f} {dp:+10.6f} {ds:+10.6f}\")\n", "\n", "# Distance bands\n", "print(f\"\\n--- DISTANCE BANDS ---\")\n", "bands = [(0.5, 1.0), (0.25, 0.5), (0.10, 
0.25), (0.05, 0.10), (0.0, 0.05)]\n", "key_29 = f'α={TARGET_ALPHA}'\n", "print(f\"{'Band':>12s} {'Orig':>8s} {'0.29154':>8s} {'Δ':>8s}\")\n", "for lo, hi in bands:\n", " mask = (wn_s >= lo) & (wn_s < hi) if hi < 1.0 else (wn_s >= lo) & (wn_s <= hi)\n", " if mask.sum() < 5:\n", " continue\n", " oc = cos_arrays['Original'][mask].mean()\n", " mc = cos_arrays[key_29][mask].mean()\n", " print(f\" [{lo:.2f},{hi:.2f}) {oc:8.4f} {mc:8.4f} {mc-oc:+8.4f}\")\n", "\n", "# Gradients\n", "high_mask = wn_s >= 0.25\n", "low_mask = wn_s < 0.10\n", "if high_mask.sum() > 0 and low_mask.sum() > 0:\n", " print(f\"\\n Gradients (high - low):\")\n", " for name, cos_arr in cos_arrays.items():\n", " grad = cos_arr[high_mask].mean() - cos_arr[low_mask].mean()\n", " print(f\" {name:18s}: {grad:.6f}\")\n", "\n", "# Pentachoron\n", "print(f\"\\n--- PENTACHORON GEOMETRY ---\")\n", "def cayley_menger_volume_sq(points):\n", " n = len(points)\n", " D = np.zeros((n + 1, n + 1))\n", " D[0, 1:] = 1; D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq; D[j + 1, i + 1] = d_sq\n", " k = n - 1\n", " det = np.linalg.det(D)\n", " return ((-1) ** (k + 1)) * det / ((2 ** k) * (math.factorial(k) ** 2))\n", "\n", "rng3 = np.random.default_rng(42)\n", "for name, reps in [('Original', encoder_reps), (f'α={TARGET_ALPHA}', reps_29514)]:\n", " vols = []\n", " for _ in range(500):\n", " idx = rng3.choice(len(matched), size=5, replace=False)\n", " v = cayley_menger_volume_sq(reps[idx])\n", " if v > 0:\n", " vols.append(np.sqrt(v))\n", " vols = np.array(vols)\n", " print(f\" {name:18s}: CV={vols.std()/vols.mean():.4f} mean={vols.mean():.4e}\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(f\"COHERENCE TEST AT ALPHA = {TARGET_ALPHA}\")\n", "print(f\"{'='*70}\")\n", "\n", "test_prompts = [\n", " \"summarize: The cat is a small domesticated carnivorous mammal with soft fur.\",\n", " \"translate English to German: The 
geometric structure of language is a universal attractor.\",\n", " \"summarize: Mathematics is the study of numbers, quantities, shapes, and patterns.\",\n", " \"summarize: Seven is a prime number that comes after six and before eight.\",\n", " \"summarize: A triangle is a polygon with three edges and three vertices.\",\n", "]\n", "\n", "for prompt in test_prompts:\n", " inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n", "\n", " with torch.no_grad():\n", " # Original\n", " orig_enc = model.encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " orig_gen = model.generate(encoder_outputs=orig_enc,\n", " attention_mask=inputs.attention_mask,\n", " max_new_tokens=64)\n", " orig_text = tokenizer.decode(orig_gen[0], skip_special_tokens=True)\n", "\n", " # Modulated\n", " mod_enc = mod_encoder(input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask)\n", " from transformers.modeling_outputs import BaseModelOutput\n", " mod_wrapped = BaseModelOutput(last_hidden_state=mod_enc.last_hidden_state)\n", " mod_gen = model.generate(encoder_outputs=mod_wrapped,\n", " attention_mask=inputs.attention_mask,\n", " max_new_tokens=64)\n", " mod_text = tokenizer.decode(mod_gen[0], skip_special_tokens=True)\n", "\n", " print(f\"\\n ORIG: {orig_text}\")\n", " print(f\" 0.29: {mod_text}\")\n", " if orig_text == mod_text:\n", " print(f\" >>> IDENTICAL\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(f\"SUMMARY AT ALPHA = {TARGET_ALPHA}\")\n", "print(f\"{'='*70}\")\n", "p_29 = np.corrcoef(wn_s, cos_arrays[key_29])[0, 1]\n", "sp_29, _ = spearmanr(wn_s, cos_arrays[key_29])\n", "grad_orig = cos_arrays['Original'][high_mask].mean() - cos_arrays['Original'][low_mask].mean()\n", "grad_29 = cos_arrays[key_29][high_mask].mean() - cos_arrays[key_29][low_mask].mean()\n", "\n", "print(f\"Per-token preservation: {per_tok.mean():.4f}\")\n", "print(f\"Pearson: orig={orig_p:.4f} α=0.2915={p_29:.4f} Δ={p_29-orig_p:+.4f}\")\n", "print(f\"Spearman: 
orig={orig_sp:.4f} α=0.2915={sp_29:.4f} Δ={sp_29-orig_sp:+.4f}\")\n", "print(f\"Gradient: orig={grad_orig:.4f} α=0.2915={grad_29:.4f} Δ={grad_29-grad_orig:+.4f}\")\n", "print(f\"\")\n", "print(f\"Reference points:\")\n", "print(f\" α=0.01: Pearson=0.099, gradient=0.022, preservation=0.9998\")\n", "print(f\" α=0.50: Pearson=0.185, gradient=0.076, preservation=0.176\")\n", "print(f\" α=0.29154: ??? — the constant decides\")\n" ] }, { "cell_type": "markdown", "id": "9d53ab8a", "metadata": {}, "source": [ "## 11. Talk to the Modulated T5\n", "*Qualitative coherence test + alpha sweep*" ] }, { "cell_type": "code", "execution_count": null, "id": "0f6cd243", "metadata": {}, "outputs": [], "source": [ "# Talk to modulated T5\n", "\n", "# ============================================================================\n", "# TALK TO THE MODULATED T5\n", "# Run AFTER the full pipeline (modulator, mod_encoder, model all in memory)\n", "# Compare what T5 says with and without geometric modulation\n", "# ============================================================================\n", "\n", "import torch\n", "\n", "def generate_comparison(prompt, max_new_tokens=128, temperature=0.7):\n", " \"\"\"Generate from both original and modulated encoder, compare outputs.\"\"\"\n", "\n", " inputs = tokenizer(prompt, return_tensors=\"pt\", padding=False).to(device)\n", "\n", " # --- Original encoder → decoder generation ---\n", " with torch.no_grad():\n", " orig_enc_out = model.encoder(\n", " input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask,\n", " )\n", " orig_generated = model.generate(\n", " encoder_outputs=orig_enc_out,\n", " attention_mask=inputs.attention_mask,\n", " max_new_tokens=max_new_tokens,\n", " do_sample=True if temperature > 0 else False,\n", " temperature=temperature,\n", " top_p=0.9,\n", " )\n", " orig_text = tokenizer.decode(orig_generated[0], skip_special_tokens=True)\n", "\n", " # --- Modulated encoder → decoder generation ---\n", " with 
torch.no_grad():\n", " mod_enc_out = mod_encoder(\n", " input_ids=inputs.input_ids,\n", " attention_mask=inputs.attention_mask,\n", " )\n", " # Wrap in the format model.generate expects\n", " from transformers.modeling_outputs import BaseModelOutput\n", " mod_enc_wrapped = BaseModelOutput(last_hidden_state=mod_enc_out.last_hidden_state)\n", "\n", " mod_generated = model.generate(\n", " encoder_outputs=mod_enc_wrapped,\n", " attention_mask=inputs.attention_mask,\n", " max_new_tokens=max_new_tokens,\n", " do_sample=True if temperature > 0 else False,\n", " temperature=temperature,\n", " top_p=0.9,\n", " )\n", " mod_text = tokenizer.decode(mod_generated[0], skip_special_tokens=True)\n", "\n", " return orig_text, mod_text\n", "\n", "\n", "def talk(prompt, max_new_tokens=128, temperature=0.0):\n", " \"\"\"Pretty-print comparison.\"\"\"\n", " print(f\"\\n{'='*70}\")\n", " print(f\"PROMPT: {prompt}\")\n", " print(f\"{'='*70}\")\n", "\n", " orig, mod = generate_comparison(prompt, max_new_tokens, temperature)\n", "\n", " print(f\"\\n ORIGINAL: {orig}\")\n", " print(f\" MODULATED: {mod}\")\n", "\n", " if orig == mod:\n", " print(f\"\\n >>> IDENTICAL (modulator too quiet or greedy decoding converged)\")\n", " else:\n", " # Count token-level differences\n", " orig_toks = tokenizer.encode(orig)\n", " mod_toks = tokenizer.encode(mod)\n", " max_len = max(len(orig_toks), len(mod_toks))\n", " diffs = sum(1 for i in range(min(len(orig_toks), len(mod_toks)))\n", " if orig_toks[i] != mod_toks[i])\n", " diffs += abs(len(orig_toks) - len(mod_toks))\n", " print(f\"\\n >>> {diffs} token differences out of {max_len} tokens\")\n", "\n", "\n", "print(\"Alpha:\", torch.sigmoid(modulator.alpha).detach().cpu().numpy())\n", "print()\n", "\n", "# Summarization\n", "talk(\"summarize: The cat is a small domesticated carnivorous mammal with soft fur, a short snout, and retractable claws. 
It is widely kept as a pet and valued for companionship and ability to hunt vermin.\")\n", "\n", "# Translation\n", "talk(\"translate English to German: The geometric structure of language is a universal attractor.\")\n", "\n", "# Question-style\n", "talk(\"summarize: Mathematics is the study of numbers, quantities, shapes, and patterns. It uses rigorous logical reasoning and abstraction to understand structures that exist independently of physical reality.\")\n", "\n", "# Semantic similarity test — do related concepts change differently?\n", "talk(\"summarize: A dog is a domesticated descendant of the wolf, characterized by loyalty and trainability.\")\n", "talk(\"summarize: A wolf is a large wild canine that lives and hunts in packs across the Northern Hemisphere.\")\n", "\n", "# Abstract concept\n", "talk(\"summarize: Love is a complex set of emotions, behaviors, and beliefs associated with strong feelings of affection, protectiveness, warmth, and respect for another person.\")\n", "\n", "# Numbers (digit manifold test)\n", "talk(\"summarize: Seven is a prime number that comes after six and before eight in the natural number sequence.\")\n", "\n", "print(f\"\\n{'='*70}\")\n", "print(\"ALPHA SENSITIVITY — same prompt, varying geometric strength\")\n", "print(f\"{'='*70}\")\n", "\n", "test_prompt = \"summarize: A triangle is a polygon with three edges and three vertices. 
It is one of the basic shapes in geometry.\"\n", "\n", "for alpha_val in [0.01, 0.05, 0.10, 0.20, 0.50]:\n", " logit = torch.tensor(alpha_val / (1 - alpha_val)).log()\n", " with torch.no_grad():\n", " modulator.alpha.fill_(logit.item())\n", "\n", " _, mod_text = generate_comparison(test_prompt, max_new_tokens=64, temperature=0.0)\n", " print(f\" alpha={alpha_val:.2f}: {mod_text}\")\n", "\n", "# Reset alpha\n", "import math\n", "with torch.no_grad():\n", " modulator.alpha.fill_(math.log(0.01 / 0.99))\n", "print(f\"\\nAlpha reset to 0.01\")\n" ] }, { "cell_type": "markdown", "id": "59c17c58", "metadata": {}, "source": [ "## 12. Geometric Field Modulator (Multi-Expert)\n", "*Section VIII: KSimplex experts k=1,2,4 + multiplicative gating. 38,552 params.*" ] }, { "cell_type": "code", "execution_count": null, "id": "e39a5eb9", "metadata": {}, "outputs": [], "source": [ "# Field modulator\n", "\n", "\"\"\"\n", "GeometricFieldModulator\n", "=======================\n", "\n", "Multi-expert geometric field constraints on transformer residual dynamics.\n", "\n", "Architecture:\n", " Multiple KSimplexChannel experts (different k) measure different scales\n", " of geometric coherence. Each produces a validity gate. The residual stream\n", " is multiplicatively modulated — valid regions pass, invalid regions are\n", " suppressed. The model naturally migrates toward the valid manifold.\n", "\n", "Principle:\n", " The geometry doesn't add a delta. 
It constrains the delta.\n", " Features × Π(sigmoid(expert_gates)) = constrained features.\n", "\n", "Safeguards:\n", " - Validity monitoring via CM determinants at every layer\n", " - Null space preservation (configurable fraction of dims untouched)\n", " - Per-layer per-expert learned alpha with sigmoid clamping\n", " - Gradient scaling on geometric params to prevent runaway\n", " - Health metrics computed without inference for diagnostics\n", " - Early stopping criteria based on CV drift from universal band\n", "\n", "Grounded in measurements from 2026-03-05:\n", " - Pentachoron CV universal band: 0.20–0.23\n", " - Participation ratio / dim: ~0.53–0.56\n", " - Q sparsity: 93–99% (the null space the model carved for us)\n", " - Depth gradient: monotonically increasing alpha (low early, high late)\n", " - Phase boundary at 0.29154 (binding/separation constant)\n", " - Best result: Pearson +152%, CV stayed in band, coherence preserved\n", "\n", "Authors: AbstractPhil + Claude\n", "License: Apache 2.0\n", "\"\"\"\n", "\n", "import math\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from itertools import combinations\n", "from typing import Tuple, Dict, Optional, List, NamedTuple\n", "from dataclasses import dataclass\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# CONFIGURATION\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "@dataclass\n", "class FieldModulatorConfig:\n", " \"\"\"Configuration for GeometricFieldModulator.\"\"\"\n", " d_model: int = 512\n", " vocab_size: int = 32128\n", " n_layers: int = 6\n", "\n", " # Expert configuration — each expert is a KSimplexChannel at different k\n", " expert_ks: Tuple[int, ...] 
= (1, 2, 4) # edge, triangle, pentachoron\n", " expert_edim: int = 8 # simplex embedding dimension\n", "\n", " # Alpha configuration\n", " initial_alpha: float = 0.01 # start quiet\n", " alpha_max: float = 0.35 # hard ceiling — never exceed binding/separation boundary\n", " alpha_min: float = 0.001 # floor — always some geometric signal\n", " per_layer_per_expert: bool = True # full granularity\n", "\n", " # Null space preservation\n", " null_space_fraction: float = 0.25 # 25% of dims untouched by modulator\n", " null_space_position: str = \"tail\" # \"tail\" = last 25%, \"random\" = scattered\n", "\n", " # Safeguards\n", " gradient_scale: float = 0.1 # scale gradients on geometric params\n", " cv_target: float = 0.21 # universal pentachoron CV target\n", " cv_tolerance: float = 0.05 # alert if CV drifts outside [0.16, 0.26]\n", " deform_scale: float = 0.05 # simplex deformation magnitude\n", "\n", " # Monitoring\n", " track_validity: bool = True # track CM volumes per expert per layer\n", " track_cv: bool = True # periodic CV measurement\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# CAYLEY-MENGER VALIDATOR — batch-friendly, differentiable\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "class CMValidator(nn.Module):\n", " \"\"\"\n", " Cayley-Menger determinant for k-simplex volume computation.\n", "\n", " Input: (..., n_vertices, embed_dim)\n", " Output: d2_pairs (..., n_pairs), vol2 (...,)\n", "\n", " For k=4: 5 vertices → 10 pairwise d² + 1 vol² = 11 geometric features.\n", " \"\"\"\n", "\n", " def __init__(self, k: int):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer(\"_pi\", torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer(\"_pj\", torch.tensor([p[1] for p in pairs], 
dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " @property\n", " def n_pairs(self):\n", " return self._npairs\n", "\n", " @property\n", " def out_dim(self):\n", " return self._npairs + 1 # d² pairs + vol²\n", "\n", " def forward(self, verts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", " \"\"\"\n", " Args:\n", " verts: (..., n_vertices, embed_dim)\n", " Returns:\n", " d2_pairs: (..., n_pairs)\n", " vol2: (...,)\n", " \"\"\"\n", " # Gram matrix → pairwise squared distances\n", " gram = torch.einsum(\"...ve,...we->...vw\", verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat) # numerical safety\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " # Bordered distance matrix → determinant → volume\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V + 1, V + 1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", "\n", " vol2 = self._prefactor * torch.linalg.det(cm)\n", "\n", " return d2_pairs, vol2\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# SIMPLEX TEMPLATE FACTORY — deterministic, frozen\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "def build_regular_simplex(k: int, edim: int) -> torch.Tensor:\n", " \"\"\"\n", " Build a regular k-simplex with k+1 vertices in edim dimensions.\n", " All edges have equal length. 
Centered at origin.\n", "\n", " Returns: (k+1, edim) tensor\n", " \"\"\"\n", " nv = k + 1\n", " if edim < k:\n", " raise ValueError(f\"edim ({edim}) must be >= k ({k})\")\n", "\n", " # Incremental construction: each new vertex is placed above the centroid\n", " # of the face built so far, at the height that makes every edge length 1.\n", " verts = torch.zeros(nv, edim)\n", " for m in range(1, nv):\n", " centroid = verts[:m].mean(dim=0)\n", " # Squared circumradius of the face (its centroid is equidistant from all vertices)\n", " r2 = ((centroid - verts[0]) ** 2).sum().item()\n", " verts[m] = centroid\n", " verts[m, m - 1] = math.sqrt(max(1.0 - r2, 0.0))\n", "\n", " # Center at origin (edges already have unit length by construction)\n", " centroid = verts.mean(dim=0)\n", " verts = verts - centroid\n", "\n", " return verts\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# K-SIMPLEX CHANNEL EXPERT — geometric features per position\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "class KSimplexChannelExpert(nn.Module):\n", " \"\"\"\n", " Single geometric expert. Transforms input features into simplex deformations,\n", " applies CM validation, outputs geometric features + validity signal.\n", "\n", " Input: (..., in_dim)\n", " Output: geo_features (..., out_dim), vol2 (...,), validity (..., 1)\n", "\n", " The geometric features are the d² pairs + vol² from the Cayley-Menger\n", " determinant. 
The validity signal is sigmoid(vol2) — 1.0 when the simplex\n", " is well-formed, 0.0 when degenerate.\n", " \"\"\"\n", "\n", " def __init__(self, k: int, in_dim: int, edim: int, deform_scale: float = 0.05):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", " self._edim = edim\n", " self._deform_scale = deform_scale\n", "\n", " self._cm = CMValidator(k)\n", "\n", " # Frozen regular simplex template\n", " template = build_regular_simplex(k, edim)\n", " self.register_buffer(\"_template\", template)\n", "\n", " # Learned deformation from input features\n", " self._to_deform = nn.Linear(in_dim, self._nv * edim)\n", "\n", " # LayerNorm on geometric output for stable gradients\n", " self._norm = nn.LayerNorm(self._cm.out_dim)\n", "\n", " # Validity projection — vol² → scalar gate\n", " self._validity_proj = nn.Sequential(\n", " nn.Linear(self._cm.out_dim, 1),\n", " nn.Sigmoid(),\n", " )\n", "\n", " @property\n", " def out_dim(self):\n", " return self._cm.out_dim\n", "\n", " @property\n", " def k(self):\n", " return self._k\n", "\n", " def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n", " \"\"\"\n", " Args:\n", " x: (..., in_dim)\n", " Returns:\n", " geo: (..., out_dim) — geometric features [d², vol²]\n", " vol2: (...,) — raw simplex volume squared (for monitoring)\n", " validity: (..., 1) — sigmoid validity gate\n", " \"\"\"\n", " # Input → deformation vectors for each vertex\n", " deform = self._to_deform(x).unflatten(-1, (self._nv, self._edim))\n", "\n", " # Template + scaled deformation = actual vertices\n", " verts = self._template + self._deform_scale * deform\n", "\n", " # CM computation → geometric features\n", " d2, vol2 = self._cm(verts)\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", " geo = self._norm(geo)\n", "\n", " # Validity gate from geometric features\n", " validity = self._validity_proj(geo)\n", "\n", " return geo, vol2, validity\n", "\n", "\n", "# 
══════════════════════════════════════════════════════════════════════════════\n", "# GEOMETRIC FIELD MODULATOR — the full system\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "class FieldModulatorOutput(NamedTuple):\n", " \"\"\"Output from the field modulator.\"\"\"\n", " residual: torch.Tensor # modulated residual stream\n", " validity_map: torch.Tensor # per-position combined validity\n", " expert_volumes: Dict[int, torch.Tensor] # per-expert vol² for monitoring\n", "\n", "\n", "class GeometricFieldModulator(nn.Module):\n", " \"\"\"\n", " Multi-expert geometric field constraints on transformer residual dynamics.\n", "\n", " Multiple KSimplexChannel experts at different k (edge, triangle, pentachoron)\n", " each measure a different scale of geometric coherence. Their validity gates\n", " are combined multiplicatively to constrain the residual stream.\n", "\n", " The model's features pass through where geometry is valid.\n", " They're suppressed where geometry is violated.\n", " The model naturally migrates toward the valid manifold.\n", " \"\"\"\n", "\n", " def __init__(self, config: FieldModulatorConfig):\n", " super().__init__()\n", " self.config = config\n", " d = config.d_model\n", "\n", " # Compute active dimensions (excluding null space)\n", " n_null = int(d * config.null_space_fraction)\n", " self.n_active = d - n_null\n", " self.n_null = n_null\n", "\n", " # Build null space mask\n", " if config.null_space_position == \"tail\":\n", " # Last n_null dims are untouched\n", " mask = torch.ones(d)\n", " mask[self.n_active:] = 0.0\n", " else:\n", " # Random scattered null space (deterministic seed)\n", " rng = torch.Generator().manual_seed(42)\n", " perm = torch.randperm(d, generator=rng)\n", " mask = torch.ones(d)\n", " mask[perm[:n_null]] = 0.0\n", " self.register_buffer(\"_active_mask\", mask.unsqueeze(0).unsqueeze(0)) # (1, 1, d)\n", "\n", " # Build experts\n", " self.experts = nn.ModuleDict()\n", " for k in 
config.expert_ks:\n", " self.experts[f\"k{k}\"] = KSimplexChannelExpert(\n", " k=k,\n", " in_dim=self.n_active,\n", " edim=config.expert_edim,\n", " deform_scale=config.deform_scale,\n", " )\n", "\n", " # Per-expert gate projection: geo_features → d_model gate\n", " self.gate_projs = nn.ModuleDict()\n", " for k in config.expert_ks:\n", " expert = self.experts[f\"k{k}\"]\n", " self.gate_projs[f\"k{k}\"] = nn.Sequential(\n", " nn.Linear(expert.out_dim, self.n_active),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Per-layer per-expert alpha (in logit space)\n", " n_experts = len(config.expert_ks)\n", " logit_init = math.log(config.initial_alpha / (1 - config.initial_alpha))\n", " self.alpha_logits = nn.Parameter(\n", " torch.full((config.n_layers, n_experts), logit_init)\n", " )\n", "\n", " # Monitoring buffers\n", " if config.track_validity:\n", " self.register_buffer(\"_validity_history\", torch.zeros(config.n_layers, n_experts))\n", " self.register_buffer(\"_step_count\", torch.tensor(0, dtype=torch.long))\n", "\n", " @property\n", " def alphas(self) -> torch.Tensor:\n", " \"\"\"Current alpha values after sigmoid + clamping.\"\"\"\n", " raw = torch.sigmoid(self.alpha_logits)\n", " return torch.clamp(raw, min=self.config.alpha_min, max=self.config.alpha_max)\n", "\n", " def _extract_active(self, x: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Extract active dimensions from residual stream.\"\"\"\n", " if self.config.null_space_position == \"tail\":\n", " return x[..., :self.n_active]\n", " else:\n", " return x * self._active_mask\n", "\n", " def forward(\n", " self,\n", " residual: torch.Tensor,\n", " layer_idx: int,\n", " ) -> FieldModulatorOutput:\n", " \"\"\"\n", " Apply geometric field modulation to residual stream.\n", "\n", " Args:\n", " residual: (batch, seq_len, d_model) current residual state\n", " layer_idx: which transformer layer (for per-layer alpha)\n", "\n", " Returns:\n", " FieldModulatorOutput with modulated residual, validity map, expert volumes\n", " 
\"\"\"\n", " B, S, D = residual.shape\n", " alphas = self.alphas[layer_idx] # (n_experts,)\n", "\n", " # Extract active dimensions for geometric computation\n", " active = self._extract_active(residual) # (B, S, n_active)\n", "\n", " # Compute combined validity gate from all experts\n", " combined_gate = torch.ones(B, S, self.n_active, device=residual.device, dtype=residual.dtype)\n", " expert_volumes = {}\n", "\n", " for i, k in enumerate(self.config.expert_ks):\n", " key = f\"k{k}\"\n", " expert = self.experts[key]\n", " gate_proj = self.gate_projs[key]\n", "\n", " # Expert computes geometric features\n", " geo, vol2, validity = expert(active)\n", " expert_volumes[k] = vol2.detach()\n", "\n", " # Project geometric features to gate dimensions\n", " expert_gate = gate_proj(geo) # (B, S, n_active)\n", "\n", " # Blend: combined = combined × (1 - α + α × expert_gate)\n", " # At α=0: no effect. At α=max: full gating.\n", " # This is softer than pure multiplication — allows gradual engagement\n", " alpha_k = alphas[i]\n", " blended_gate = (1 - alpha_k) + alpha_k * expert_gate\n", " combined_gate = combined_gate * blended_gate\n", "\n", " # Monitoring\n", " if self.config.track_validity and self.training:\n", " with torch.no_grad():\n", " valid_frac = (vol2 > 0).float().mean()\n", " self._validity_history[layer_idx, i] = (\n", " 0.99 * self._validity_history[layer_idx, i] + 0.01 * valid_frac\n", " )\n", "\n", " # Apply gate to active dimensions only (null space untouched)\n", " if self.config.null_space_position == \"tail\":\n", " modulated = residual.clone()\n", " modulated[..., :self.n_active] = active * combined_gate\n", " else:\n", " modulated = residual * (1 - self._active_mask) + (active * combined_gate) * self._active_mask\n", "\n", " # Update step count\n", " if self.training:\n", " self._step_count += 1\n", "\n", " return FieldModulatorOutput(\n", " residual=modulated,\n", " validity_map=combined_gate.mean(dim=-1), # (B, S)\n", " 
expert_volumes=expert_volumes,\n", " )\n", "\n", " # ──────────────────────────────────────────────────────────────────────\n", " # HEALTH METRICS — no inference required\n", " # ──────────────────────────────────────────────────────────────────────\n", "\n", " def health_report(self) -> Dict[str, float]:\n", " \"\"\"\n", " Compute health metrics from current parameters.\n", " No inference needed — reads directly from weights.\n", " \"\"\"\n", " report = {}\n", "\n", " # Alpha statistics\n", " alphas = self.alphas.detach()\n", " report[\"alpha_mean\"] = alphas.mean().item()\n", " report[\"alpha_std\"] = alphas.std().item()\n", " report[\"alpha_min\"] = alphas.min().item()\n", " report[\"alpha_max\"] = alphas.max().item()\n", "\n", " # Per-layer alpha means\n", " for layer in range(self.config.n_layers):\n", " report[f\"alpha_layer_{layer}\"] = alphas[layer].mean().item()\n", "\n", " # Validity history (if tracking)\n", " if self.config.track_validity:\n", " report[\"validity_mean\"] = self._validity_history.mean().item()\n", " for i, k in enumerate(self.config.expert_ks):\n", " report[f\"validity_k{k}\"] = self._validity_history[:, i].mean().item()\n", "\n", " # Expert template volumes (frozen — sanity check)\n", " for k in self.config.expert_ks:\n", " expert = self.experts[f\"k{k}\"]\n", " _, vol2 = expert._cm(expert._template.unsqueeze(0))\n", " report[f\"template_vol2_k{k}\"] = vol2.item()\n", "\n", " # Parameter norms by component\n", " total_params = 0\n", " for name, p in self.named_parameters():\n", " total_params += p.numel()\n", " if \"deform\" in name:\n", " report[f\"deform_norm_{name.split('.')[1]}\"] = p.data.norm().item()\n", " report[\"total_params\"] = total_params\n", "\n", " return report\n", "\n", " def cv_check(self, embeddings: torch.Tensor, n_simplices: int = 500) -> Dict[str, float]:\n", " \"\"\"\n", " Measure pentachoron CV on a set of embeddings.\n", " Uses the k=4 expert's CM validator if available, else builds one.\n", "\n", " Args:\n", " 
embeddings: (N, d_model) tensor of representations\n", " n_simplices: number of random 5-point simplices to measure\n", "\n", " Returns:\n", " dict with cv, mean_vol, valid_fraction\n", " \"\"\"\n", " if 4 in self.config.expert_ks:\n", " cm = self.experts[\"k4\"]._cm\n", " else:\n", " cm = CMValidator(4).to(embeddings.device)\n", "\n", " N = embeddings.shape[0]\n", " active = embeddings[..., :self.n_active]\n", "\n", " vols = []\n", " rng = torch.Generator().manual_seed(42)\n", "\n", " for _ in range(n_simplices):\n", " idx = torch.randint(N, (5,), generator=rng)\n", " pts = active[idx].unsqueeze(0) # (1, 5, n_active)\n", "\n", " # Pad to edim if needed\n", " edim = self.config.expert_edim\n", " if pts.shape[-1] > edim:\n", " # Project to edim dimensions (use first edim for simplicity)\n", " pts_proj = pts[..., :edim]\n", " else:\n", " pts_proj = pts\n", "\n", " _, vol2 = cm(pts_proj)\n", " if vol2.item() > 0:\n", " vols.append(math.sqrt(vol2.item()))\n", "\n", " if len(vols) == 0:\n", " return {\"cv\": float(\"nan\"), \"mean_vol\": 0.0, \"valid_fraction\": 0.0}\n", "\n", " vols_arr = torch.tensor(vols)\n", " cv = (vols_arr.std() / vols_arr.mean()).item()\n", " return {\n", " \"cv\": cv,\n", " \"mean_vol\": vols_arr.mean().item(),\n", " \"valid_fraction\": len(vols) / n_simplices,\n", " \"in_band\": abs(cv - self.config.cv_target) < self.config.cv_tolerance,\n", " }\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# MODULATED ENCODER WRAPPER — plugs into any encoder-decoder transformer\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "class FieldModulatedEncoder(nn.Module):\n", " \"\"\"\n", " Wraps a transformer encoder with geometric field modulation.\n", " Steps through blocks manually, applying the field modulator between blocks.\n", "\n", " The modulator constrains the residual stream multiplicatively.\n", " Null space dimensions are never touched.\n", " 
\"\"\"\n", "\n", " def __init__(\n", " self,\n", " encoder: nn.Module,\n", " modulator: GeometricFieldModulator,\n", " modulate_layers: Optional[List[int]] = None,\n", " ):\n", " super().__init__()\n", " self.encoder = encoder\n", " self.modulator = modulator\n", "\n", " n_layers = len(encoder.block)\n", " if modulate_layers is None:\n", " modulate_layers = list(range(n_layers))\n", " self.modulate_layers = set(modulate_layers)\n", "\n", " # Track per-forward validity for diagnostics\n", " self._last_validity = None\n", " self._last_expert_volumes = None\n", "\n", " def forward(\n", " self,\n", " input_ids: torch.Tensor,\n", " attention_mask: Optional[torch.Tensor] = None,\n", " output_hidden_states: bool = False,\n", " **kwargs,\n", " ):\n", " # Embed\n", " hidden_states = self.encoder.embed_tokens(input_ids)\n", " hidden_states = self.encoder.dropout(hidden_states)\n", "\n", " # Prepare attention mask\n", " if attention_mask is not None:\n", " extended_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)\n", " extended_mask = (1.0 - extended_mask) * torch.finfo(hidden_states.dtype).min\n", " else:\n", " extended_mask = None\n", "\n", " all_hidden = [hidden_states] if output_hidden_states else None\n", " position_bias = None\n", " cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)\n", "\n", " validity_maps = []\n", " expert_vol_maps = []\n", "\n", " for i, block in enumerate(self.encoder.block):\n", " # GEOMETRIC FIELD MODULATION — constrain before block processes\n", " if i in self.modulate_layers:\n", " mod_output = self.modulator(hidden_states, layer_idx=i)\n", " hidden_states = mod_output.residual\n", " validity_maps.append(mod_output.validity_map.detach())\n", " expert_vol_maps.append(mod_output.expert_volumes)\n", "\n", " # Standard transformer block\n", " block_output = block(\n", " hidden_states,\n", " attention_mask=extended_mask,\n", " position_bias=position_bias,\n", " cache_position=cache_position,\n", " )\n", " 
hidden_states = block_output[0]\n", "\n", " # Extract position bias from first block\n", " if position_bias is None:\n", " for out in block_output[1:]:\n", " if isinstance(out, torch.Tensor) and out.dim() == 4:\n", " position_bias = out\n", " break\n", "\n", " if output_hidden_states:\n", " all_hidden.append(hidden_states)\n", "\n", " # Final layer norm\n", " hidden_states = self.encoder.final_layer_norm(hidden_states)\n", " hidden_states = self.encoder.dropout(hidden_states)\n", "\n", " # Store diagnostics\n", " self._last_validity = validity_maps\n", " self._last_expert_volumes = expert_vol_maps\n", "\n", " if output_hidden_states:\n", " all_hidden.append(hidden_states)\n", "\n", " return type(\"Output\", (), {\n", " \"last_hidden_state\": hidden_states,\n", " \"hidden_states\": tuple(all_hidden) if all_hidden else None,\n", " })()\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# TRAINING UTILITIES\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "class FieldModulatorTrainer:\n", " \"\"\"\n", " Training utilities for the GeometricFieldModulator.\n", " Handles gradient scaling, alpha monitoring, CV checks, early stopping.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " modulator: GeometricFieldModulator,\n", " lr: float = 1e-3,\n", " weight_decay: float = 0.01,\n", " gradient_clip: float = 1.0,\n", " ):\n", " self.modulator = modulator\n", " self.gradient_clip = gradient_clip\n", "\n", " # Separate param groups: geometric params get scaled gradients\n", " geo_params = []\n", " alpha_params = []\n", " gate_params = []\n", "\n", " for name, param in modulator.named_parameters():\n", " if \"alpha\" in name:\n", " alpha_params.append(param)\n", " elif \"expert\" in name or \"deform\" in name:\n", " geo_params.append(param)\n", " else:\n", " gate_params.append(param)\n", "\n", " self.optimizer = torch.optim.AdamW([\n", " {\"params\": geo_params, \"lr\": lr * 
modulator.config.gradient_scale},\n", " {\"params\": alpha_params, \"lr\": lr * 0.5}, # alpha learns slower\n", " {\"params\": gate_params, \"lr\": lr},\n", " ], weight_decay=weight_decay)\n", "\n", " # Tracking\n", " self.alpha_history = []\n", " self.cv_history = []\n", " self.loss_history = []\n", "\n", " def step(self, loss: torch.Tensor):\n", " \"\"\"Single training step with safeguards.\"\"\"\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(self.modulator.parameters(), self.gradient_clip)\n", " self.optimizer.step()\n", " self.optimizer.zero_grad()\n", "\n", " # Record alpha\n", " alphas = self.modulator.alphas.detach().cpu().numpy()\n", " self.alpha_history.append(alphas.copy())\n", " self.loss_history.append(loss.item())\n", "\n", " def check_cv(self, embeddings: torch.Tensor) -> bool:\n", " \"\"\"\n", " Check if pentachoron CV is in the universal band.\n", " Returns True if healthy, False if drifting.\n", " \"\"\"\n", " result = self.modulator.cv_check(embeddings)\n", " self.cv_history.append(result)\n", " return result.get(\"in_band\", True)\n", "\n", " def should_stop(self, patience: int = 5) -> bool:\n", " \"\"\"\n", " Early stopping if CV has been outside the universal band\n", " for `patience` consecutive checks.\n", " \"\"\"\n", " if len(self.cv_history) < patience:\n", " return False\n", " recent = self.cv_history[-patience:]\n", " return all(not r.get(\"in_band\", True) for r in recent)\n", "\n", "\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "# QUICK TEST\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "if __name__ == \"__main__\":\n", " print(\"=\" * 70)\n", " print(\"GeometricFieldModulator — Self Test\")\n", " print(\"=\" * 70)\n", "\n", " config = FieldModulatorConfig(\n", " d_model=512,\n", " vocab_size=32128,\n", " n_layers=6,\n", " expert_ks=(1, 2, 4),\n", " expert_edim=8,\n", " initial_alpha=0.01,\n", " null_space_fraction=0.25,\n", " 
)\n", "\n", " modulator = GeometricFieldModulator(config)\n", "\n", " # Count params\n", " total = sum(p.numel() for p in modulator.parameters())\n", " trainable = sum(p.numel() for p in modulator.parameters() if p.requires_grad)\n", " print(f\"\\nTotal params: {total:,}\")\n", " print(f\"Trainable: {trainable:,}\")\n", "\n", " # Architecture summary\n", " print(f\"\\nExperts:\")\n", " for k in config.expert_ks:\n", " expert = modulator.experts[f\"k{k}\"]\n", " n = sum(p.numel() for p in expert.parameters())\n", " print(f\" k={k}: {expert._nv} vertices, {expert._cm.n_pairs} pairs, \"\n", " f\"out_dim={expert.out_dim}, params={n:,}\")\n", "\n", " print(f\"\\nActive dims: {modulator.n_active} / {config.d_model}\")\n", " print(f\"Null space: {modulator.n_null} / {config.d_model}\")\n", " print(f\"Alpha range: [{config.alpha_min}, {config.alpha_max}]\")\n", "\n", " # Forward pass test\n", " B, S, D = 2, 16, config.d_model\n", " x = torch.randn(B, S, D)\n", "\n", " print(f\"\\nForward pass test: ({B}, {S}, {D})\")\n", " for layer in range(config.n_layers):\n", " output = modulator(x, layer_idx=layer)\n", " alpha = modulator.alphas[layer]\n", " print(f\" Layer {layer}: validity={output.validity_map.mean():.4f}, \"\n", " f\"alpha=[{', '.join(f'{a:.4f}' for a in alpha.tolist())}]\")\n", "\n", " # Null space preservation check\n", " null_start = modulator.n_active\n", " original_null = x[..., null_start:]\n", " modulated_null = output.residual[..., null_start:]\n", " null_preserved = torch.allclose(original_null, modulated_null, atol=1e-6)\n", " print(f\"\\nNull space preserved: {null_preserved}\")\n", "\n", " # Health report\n", " print(f\"\\nHealth report:\")\n", " health = modulator.health_report()\n", " for k, v in health.items():\n", " if isinstance(v, float):\n", " print(f\" {k}: {v:.6f}\")\n", " else:\n", " print(f\" {k}: {v}\")\n", "\n", " # CV check with random embeddings\n", " test_embeds = torch.randn(1000, D)\n", " cv_result = modulator.cv_check(test_embeds)\n", 
" print(f\"\\nCV check (random embeddings):\")\n", " for k, v in cv_result.items():\n", " print(f\" {k}: {v}\")\n", "\n", " print(f\"\\n{'=' * 70}\")\n", " print(\"All tests passed.\")\n", " print(f\"{'=' * 70}\")\n" ] }, { "cell_type": "markdown", "id": "188a456e", "metadata": {}, "source": [ "## 13. T5Gemma2 Battery\n", "*Section XII: Gemma 2 adapted to enc-dec. Tests whether Q sparsity is architectural or pretraining.*\n", "\n", "Swap `MODEL_ID` for 1B-1B vs 4B-4B. **Requires:** Gemma license." ] }, { "cell_type": "code", "execution_count": null, "id": "d66e9616", "metadata": {}, "outputs": [], "source": [ "# T5Gemma2 battery\n", "\n", "# ============================================================================\n", "# T5GEMMA2 INACTIVE WEIGHT GEOMETRY\n", "# Gemma 2 adapted into encoder-decoder architecture\n", "# Tests whether Q sparsity survives decoder→encoder adaptation\n", "# Run on both 1B-1B (2.1B total) and 4B-4B (8.9B total)\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import time\n", "import gc\n", "from collections import defaultdict\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name()}\")\n", " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", "\n", "# ── CHANGE THIS TO SWITCH MODELS ──\n", "# MODEL_ID = \"google/t5gemma-2-1b-1b\" # 2.1B params — fits on T4\n", "MODEL_ID = \"google/t5gemma-2-4b-4b\" # 8.9B params — needs ~20GB\n", "# ───────────────────────────────────\n", "\n", "print(\"=\" * 70)\n", "print(f\"T5GEMMA2: {MODEL_ID}\")\n", "print(\"=\" * 70)\n", "\n", "# Load\n", "print(f\"\\nLoading {MODEL_ID} (fp16)...\")\n", "t0 = time.time()\n", "from transformers import AutoModelForSeq2SeqLM\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\n", " MODEL_ID, torch_dtype=torch.float16, 
device_map=\"auto\",\n", ")\n", "model.eval()\n", "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Loaded in {time.time()-t0:.0f}s, {total_params:,} params\")\n", "if torch.cuda.is_available():\n", " print(f\"VRAM used: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "\n", "# Print ALL config attributes to discover naming\n", "config = model.config\n", "print(f\"\\nConfig (all attributes):\")\n", "config_dict = config.to_dict() if hasattr(config, 'to_dict') else vars(config)\n", "for key, val in sorted(config_dict.items()):\n", " if key.startswith('_') or key in ('transformers_version', 'torch_dtype'): continue\n", " if isinstance(val, (int, float, str, bool, type(None))):\n", " print(f\" {key} = {val}\")\n", "\n", "# Adaptive config reading — try Gemma-style then T5-style\n", "d_model = getattr(config, 'hidden_size', None) or getattr(config, 'd_model', None)\n", "n_heads = getattr(config, 'num_attention_heads', None) or getattr(config, 'num_heads', None)\n", "n_kv_heads = getattr(config, 'num_key_value_heads', n_heads)\n", "n_enc_layers = getattr(config, 'num_hidden_layers', None) or getattr(config, 'num_layers', None)\n", "n_dec_layers = getattr(config, 'num_decoder_layers', n_enc_layers)\n", "d_ff = getattr(config, 'intermediate_size', None) or getattr(config, 'd_ff', None)\n", "head_dim = getattr(config, 'head_dim', None) or (d_model // n_heads if d_model and n_heads else None)\n", "\n", "# Check for separate encoder/decoder configs (T5Gemma2 may nest them)\n", "enc_config = getattr(config, 'encoder', None) or getattr(config, 'encoder_config', None)\n", "dec_config = getattr(config, 'decoder', None) or getattr(config, 'decoder_config', None)\n", "if enc_config and hasattr(enc_config, 'hidden_size'):\n", " print(f\"\\n [Nested encoder config found]\")\n", " d_model = d_model or enc_config.hidden_size\n", " n_heads = n_heads or getattr(enc_config, 'num_attention_heads', None)\n", " n_kv_heads = getattr(enc_config, 'num_key_value_heads', 
n_heads)\n", " n_enc_layers = n_enc_layers or getattr(enc_config, 'num_hidden_layers', None)\n", " d_ff = d_ff or getattr(enc_config, 'intermediate_size', None)\n", "if dec_config and hasattr(dec_config, 'hidden_size'):\n", " print(f\" [Nested decoder config found]\")\n", " n_dec_layers = n_dec_layers or getattr(dec_config, 'num_hidden_layers', None)\n", "\n", "# Last resort: infer from actual weight shapes\n", "if d_model is None:\n", " # Prefer decoder/language model weights over vision tower\n", " for name, param in model.named_parameters():\n", " if 'vision' in name:\n", " continue # skip vision tower\n", " if ('embed' in name or 'shared' in name) and param.dim() == 2:\n", " d_model = param.shape[-1]\n", " print(f\" [Inferred d_model={d_model} from {name}]\")\n", " break\n", " # Fallback to any embedding\n", " if d_model is None:\n", " for name, param in model.named_parameters():\n", " if 'embed' in name and param.dim() == 2:\n", " d_model = param.shape[-1]\n", " print(f\" [Inferred d_model={d_model} from {name} (fallback)]\")\n", " break\n", "\n", "if n_enc_layers is None or n_dec_layers is None:\n", " enc_layers_found = set()\n", " dec_layers_found = set()\n", " for name in dict(model.named_parameters()).keys():\n", " parts = name.split('.')\n", " for i, p in enumerate(parts):\n", " if p == 'layers' and i+1 < len(parts):\n", " try:\n", " l = int(parts[i+1])\n", " if 'encoder' in name: enc_layers_found.add(l)\n", " elif 'decoder' in name: dec_layers_found.add(l)\n", " except: pass\n", " if n_enc_layers is None and enc_layers_found:\n", " n_enc_layers = max(enc_layers_found) + 1\n", " print(f\" [Inferred {n_enc_layers} encoder layers from param names]\")\n", " if n_dec_layers is None and dec_layers_found:\n", " n_dec_layers = max(dec_layers_found) + 1\n", " print(f\" [Inferred {n_dec_layers} decoder layers from param names]\")\n", "\n", "print(f\"\\n RESOLVED: d_model={d_model}, d_ff={d_ff}, heads={n_heads}, kv_heads={n_kv_heads}, \"\n", " 
f\"head_dim={head_dim}, enc_layers={n_enc_layers}, dec_layers={n_dec_layers}\")\n", "if n_kv_heads and n_heads and n_kv_heads != n_heads:\n", " print(f\" GQA: {n_heads} query heads, {n_kv_heads} KV heads (ratio {n_heads // n_kv_heads}:1)\")\n", "\n", "# ── CLASSIFY WEIGHTS ──\n", "print(f\"\\n{'='*70}\\nCATALOG\\n{'='*70}\")\n", "\n", "catalog = defaultdict(list)\n", "for name, param in model.named_parameters():\n", " if param.dim() != 2:\n", " continue\n", " parts = name.split('.')\n", "\n", " # Location\n", " if 'encoder' in name:\n", " loc = 'encoder'\n", " elif 'decoder' in name:\n", " loc = 'decoder'\n", " else:\n", " loc = 'shared'\n", "\n", " # Layer number — try 'layers', 'block', 'layer'\n", " layer = -1\n", " for i, p in enumerate(parts):\n", " if p in ('layers', 'block', 'layer') and i+1 < len(parts):\n", " try: layer = int(parts[i+1])\n", " except: pass\n", "\n", " # Weight type — Gemma2 uses q_proj, k_proj, v_proj, o_proj\n", " if 'embed' in name or 'shared' in name:\n", " wt = 'embedding'\n", " elif 'norm' in name:\n", " wt = 'layernorm'\n", " elif 'self_attn' in name or 'SelfAttention' in name:\n", " if 'q_proj' in name: wt = 'self_attn_q'\n", " elif 'k_proj' in name: wt = 'self_attn_k'\n", " elif 'v_proj' in name: wt = 'self_attn_v'\n", " elif 'o_proj' in name: wt = 'self_attn_o'\n", " # T5-style naming fallback\n", " elif name.endswith('.q.weight'): wt = 'self_attn_q'\n", " elif name.endswith('.k.weight'): wt = 'self_attn_k'\n", " elif name.endswith('.v.weight'): wt = 'self_attn_v'\n", " elif name.endswith('.o.weight'): wt = 'self_attn_o'\n", " else: wt = 'self_attn_other'\n", " elif 'cross_attn' in name or 'EncDecAttention' in name or 'encoder_attn' in name:\n", " if 'q_proj' in name or name.endswith('.q.weight'): wt = 'cross_attn_q'\n", " elif 'k_proj' in name or name.endswith('.k.weight'): wt = 'cross_attn_k'\n", " elif 'v_proj' in name or name.endswith('.v.weight'): wt = 'cross_attn_v'\n", " elif 'o_proj' in name or name.endswith('.o.weight'): 
wt = 'cross_attn_o'\n", " else: wt = 'cross_attn_other'\n", " elif 'mlp' in name or 'DenseReluDense' in name:\n", " if 'gate_proj' in name or 'wi_0' in name: wt = 'mlp_gate'\n", " elif 'up_proj' in name or 'wi_1' in name: wt = 'mlp_up'\n", " elif 'down_proj' in name or 'wo' in name: wt = 'mlp_down'\n", " else: wt = 'mlp_other'\n", " elif 'lm_head' in name:\n", " wt = 'lm_head'\n", " else:\n", " wt = 'other'\n", "\n", " catalog[wt].append({'name': name, 'shape': tuple(param.shape), 'loc': loc,\n", " 'layer': layer, 'numel': param.numel()})\n", "\n", "for wt, entries in sorted(catalog.items()):\n", " t = sum(e['numel'] for e in entries)\n", " enc = sum(1 for e in entries if e['loc']=='encoder')\n", " dec = sum(1 for e in entries if e['loc']=='decoder')\n", " shapes = set(str(e['shape']) for e in entries)\n", " print(f\" {wt:25s}: {len(entries):4d} (E:{enc} D:{dec}) {t:>15,} {shapes}\")\n", "\n", "# Helper\n", "def get_w(name):\n", " parts = name.split('.')\n", " obj = model\n", " for p in parts:\n", " obj = obj[int(p)] if p.isdigit() else getattr(obj, p)\n", " return obj.detach().float()\n", "\n", "# All layers — no sampling\n", "enc_layers = list(range(n_enc_layers)) if n_enc_layers else []\n", "dec_layers = list(range(n_dec_layers)) if n_dec_layers else []\n", "print(f\"\\nEncoder layers: {len(enc_layers)}\")\n", "print(f\"Decoder layers: {len(dec_layers)}\")\n", "\n", "# ── SVD ──\n", "print(f\"\\n{'='*70}\\nSVD EFFECTIVE RANK\\n{'='*70}\")\n", "svd_results = []\n", "skip = {'embedding','layernorm','other','self_attn_other','cross_attn_other','mlp_other','lm_head'}\n", "all_entries = [(wt,e) for wt,entries in catalog.items() if wt not in skip for e in entries]\n", "total = len(all_entries)\n", "t0 = time.time()\n", "\n", "for idx, (wt, entry) in enumerate(all_entries):\n", " if (idx+1) % 10 == 0 or idx == 0:\n", " print(f\" [{idx+1}/{total}] {wt} {entry['loc']} L{entry['layer']} \", end=\"\\r\")\n", " try:\n", " W = get_w(entry['name'])\n", " S = 
torch.linalg.svdvals(W).cpu().numpy()\n", " svd_results.append({\n", " 'wt': wt, 'loc': entry['loc'], 'layer': entry['layer'], 'shape': entry['shape'],\n", " 'sr': (S**2).sum()/(S[0]**2) if S[0]>0 else 0,\n", " 'pr': (S.sum())**2/((S**2).sum()) if (S**2).sum()>0 else 0,\n", " 'af': (S>0.01*S[0]).sum()/len(S),\n", " 'r90': np.searchsorted(np.cumsum(S)/S.sum(),0.90)+1,\n", " 'cond': S[0]/(S[-1]+1e-10),\n", " })\n", " except Exception as e:\n", " print(f\"\\n SVD fail {entry['name']}: {e}\")\n", "\n", "print(f\" Done: {len(svd_results)} matrices in {time.time()-t0:.0f}s \")\n", "\n", "by_type = defaultdict(list)\n", "for r in svd_results: by_type[r['wt']].append(r)\n", "print(f\"\\n{'Type':25s} {'SR':>8s} {'PR':>8s} {'Act%':>6s} {'R90':>5s} {'Cond':>12s}\")\n", "for w in sorted(by_type):\n", " s = by_type[w]\n", " print(f\" {w:23s} {np.mean([r['sr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['pr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['af'] for r in s]):6.3f} \"\n", " f\"{np.mean([r['r90'] for r in s]):5.0f} \"\n", " f\"{np.mean([r['cond'] for r in s]):12.1f}\")\n", "\n", "# ── SPARSITY ──\n", "print(f\"\\n{'='*70}\\nSPARSITY\\n{'='*70}\")\n", "thresholds = [1e-4,1e-3,1e-2,1e-1]\n", "sparsity = defaultdict(lambda: {'total':0, **{t:0 for t in thresholds}})\n", "t0 = time.time()\n", "for wt, entries in catalog.items():\n", " if wt in skip: continue\n", " for entry in entries:\n", " W = get_w(entry['name']); a = W.abs(); n = a.numel()\n", " sparsity[wt]['total'] += n\n", " for t in thresholds:\n", " sparsity[wt][t] += (a<t).sum().item()\n", "\n", "print(f\"\\n{'Type':25s} {'<1e-4':>8s} {'<1e-3':>8s} {'<0.01':>8s} {'<0.1':>8s}\")\n", "for wt in sorted(sparsity):\n", " sc = sparsity[wt]\n", " if sc['total']==0: continue\n", " print(f\" {wt:23s} {sc[1e-4]/sc['total']:8.4f} {sc[1e-3]/sc['total']:8.4f} \"\n", " f\"{sc[1e-2]/sc['total']:8.4f} {sc[1e-1]/sc['total']:8.4f}\")\n", "\n", "# Encoder vs decoder\n", "print(f\"\\n--- ENCODER vs DECODER SPARSITY (<0.1) ---\")\n", "for wt in 
['self_attn_q','self_attn_k','self_attn_v','cross_attn_q','cross_attn_k','cross_attn_v']:\n", " for loc in ['encoder','decoder']:\n", " entries = [e for e in catalog.get(wt,[]) if e['loc']==loc]\n", " if not entries: continue\n", " total_n = 0; below = 0\n", " for e in entries:\n", " W = get_w(e['name']); a = W.abs()\n", " total_n += a.numel(); below += (a<0.1).sum().item()\n", " if total_n > 0:\n", " print(f\" {loc:8s} {wt:20s}: {below/total_n*100:.1f}%\")\n", "\n", "# GQA Q/K shape comparison\n", "print(f\"\\n--- GQA SHAPE ANALYSIS ---\")\n", "for wt in ['self_attn_q','self_attn_k','self_attn_v','self_attn_o']:\n", " entries = catalog.get(wt, [])\n", " if entries:\n", " shapes = set(str(e['shape']) for e in entries)\n", " print(f\" {wt:20s}: {shapes}\")\n", "\n", "# ── QK MANIFOLD — eigvalsh on CPU ──\n", "print(f\"\\n{'='*70}\\nQK MANIFOLD (eigvalsh on CPU)\\n{'='*70}\")\n", "\n", "for loc, layers in [('encoder', enc_layers), ('decoder', dec_layers)]:\n", " print(f\"\\n--- {loc.upper()} self-attention ---\")\n", " q_map = {e['layer']:e['name'] for e in catalog.get('self_attn_q',[]) if e['loc']==loc}\n", " k_map = {e['layer']:e['name'] for e in catalog.get('self_attn_k',[]) if e['loc']==loc}\n", "\n", " qk_results = []\n", " for layer in layers:\n", " if layer not in q_map or layer not in k_map: continue\n", " try:\n", " t0 = time.time()\n", " Wq = get_w(q_map[layer])\n", " Wk = get_w(k_map[layer])\n", "\n", " # GQA: Q is [q_dim, d_model], K is [kv_dim, d_model]\n", " # q_dim = n_heads * head_dim, kv_dim = n_kv_heads * head_dim\n", " # To compute QK^T we need matching output dims\n", " # Repeat each K head block to match Q's output dimension\n", " if Wq.shape[0] != Wk.shape[0]:\n", " repeat_factor = Wq.shape[0] // Wk.shape[0]\n", " # Reshape K to [n_kv_heads, head_dim, d_model], repeat each head\n", " # block, reshape back (row-wise repeat_interleave scrambles heads)\n", " hd = head_dim or Wk.shape[0] // (n_kv_heads or 1)\n", " Wk_expanded = Wk.view(-1, hd, Wk.shape[1]).repeat_interleave(repeat_factor, dim=0).reshape(-1, Wk.shape[1])\n", " QK = Wq @ Wk_expanded.T\n", " del Wk_expanded\n", " print(f\" L{layer:2d} (GQA {repeat_factor}:1): \", 
end=\"\")\n", " elif Wq.shape[1] != Wk.shape[1]:\n", " # Different input dims — shouldn't happen but handle gracefully\n", " min_dim = min(Wq.shape[1], Wk.shape[1])\n", " QK = Wq[:, :min_dim] @ Wk[:, :min_dim].T\n", " print(f\" L{layer:2d} (truncated): \", end=\"\")\n", " else:\n", " QK = Wq @ Wk.T\n", " print(f\" L{layer:2d}: \", end=\"\")\n", " del Wq, Wk\n", "\n", " sym = torch.norm(QK-QK.T).item()/(torch.norm(QK).item()+1e-10)\n", "\n", " # CPU for eigvalsh\n", " QK_cpu = ((QK+QK.T)/2).cpu()\n", " del QK\n", " if torch.cuda.is_available(): torch.cuda.empty_cache()\n", "\n", " eig = torch.linalg.eigvalsh(QK_cpu).numpy()[::-1]\n", " del QK_cpu\n", " n_pos=(eig>0).sum(); n_neg=(eig<0).sum(); dim=len(eig)\n", "\n", " print(f\"pos={n_pos}({n_pos/dim:.3f}), neg={n_neg}({n_neg/dim:.3f}), \"\n", " f\"sym={sym:.4f}, top={eig[0]:.2f} ({time.time()-t0:.1f}s)\")\n", " qk_results.append({'layer':layer,'n_pos':n_pos,'n_neg':n_neg,'dim':dim})\n", " del eig; gc.collect()\n", " except Exception as e:\n", " print(f\"FAIL — {e}\")\n", "\n", " if len(qk_results)>=2:\n", " f,l=qk_results[0],qk_results[-1]\n", " print(f\" Trend: L{f['layer']}={f['n_pos']/f['dim']:.3f} → L{l['layer']}={l['n_pos']/l['dim']:.3f}\")\n", "\n", "# Cross-attention QK\n", "if catalog.get('cross_attn_q') and catalog.get('cross_attn_k'):\n", " print(f\"\\n--- DECODER cross-attention ---\")\n", " xq_map = {e['layer']:e['name'] for e in catalog.get('cross_attn_q',[]) if e['loc']=='decoder'}\n", " xk_map = {e['layer']:e['name'] for e in catalog.get('cross_attn_k',[]) if e['loc']=='decoder'}\n", " for layer in dec_layers:\n", " if layer not in xq_map or layer not in xk_map: continue\n", " try:\n", " t0 = time.time()\n", " Wq = get_w(xq_map[layer]); Wk = get_w(xk_map[layer])\n", " if Wq.shape[0] != Wk.shape[0]:\n", " repeat = Wq.shape[0] // Wk.shape[0]\n", " hd = head_dim or Wk.shape[0] // (n_kv_heads or 1)\n", " # Repeat per head block (not per row) to keep head alignment\n", " Wk = Wk.view(-1, hd, Wk.shape[1]).repeat_interleave(repeat, dim=0).reshape(-1, Wk.shape[1])\n", " QK = Wq @ Wk.T; del Wq, Wk\n", " sym = torch.norm(QK-QK.T).item()/(torch.norm(QK).item()+1e-10)\n", " QK_cpu = 
((QK+QK.T)/2).cpu(); del QK\n", " if torch.cuda.is_available(): torch.cuda.empty_cache()\n", " eig = torch.linalg.eigvalsh(QK_cpu).numpy()[::-1]; del QK_cpu\n", " n_pos=(eig>0).sum(); n_neg=(eig<0).sum(); dim=len(eig)\n", " print(f\" L{layer:2d}: pos={n_pos}({n_pos/dim:.3f}), neg={n_neg}({n_neg/dim:.3f}), \"\n", " f\"sym={sym:.4f}, top={eig[0]:.2f} ({time.time()-t0:.1f}s)\")\n", " del eig; gc.collect()\n", " except Exception as e:\n", " print(f\" L{layer}: FAIL — {e}\")\n", "\n", "# ── DEAD NEURONS ──\n", "print(f\"\\n{'='*70}\\nMLP DEAD NEURONS (GeGLU)\\n{'='*70}\")\n", "for loc, layers in [('encoder', enc_layers), ('decoder', dec_layers)]:\n", " print(f\"\\n--- {loc.upper()} ---\")\n", " g_map = {e['layer']:e['name'] for e in catalog.get('mlp_gate',[]) if e['loc']==loc}\n", " u_map = {e['layer']:e['name'] for e in catalog.get('mlp_up',[]) if e['loc']==loc}\n", " d_map = {e['layer']:e['name'] for e in catalog.get('mlp_down',[]) if e['loc']==loc}\n", " td=tn=0\n", " for layer in layers:\n", " if layer not in g_map or layer not in u_map or layer not in d_map: continue\n", " Wg = get_w(g_map[layer]); gn = torch.norm(Wg,dim=1).cpu().numpy(); del Wg\n", " Wu = get_w(u_map[layer]); un = torch.norm(Wu,dim=1).cpu().numpy(); del Wu\n", " Wd = get_w(d_map[layer]); dn = torch.norm(Wd,dim=0).cpu().numpy(); del Wd\n", " c = gn*un*dn; d_ff=len(c); mc=c.mean()\n", " dead=(c<0.01*mc).sum(); weak=(c<0.10*mc).sum()\n", " td+=dead; tn+=d_ff\n", " print(f\" L{layer:2d}: d_ff={d_ff}, dead={dead}({dead/d_ff*100:.1f}%), weak={weak}({weak/d_ff*100:.1f}%)\")\n", " if tn: print(f\" Total: {td}/{tn} ({td/tn*100:.2f}%)\")\n", "\n", "# ── CROSS-LAYER Q ──\n", "print(f\"\\n{'='*70}\\nCROSS-LAYER Q CORRELATION\\n{'='*70}\")\n", "for loc, layers in [('encoder', enc_layers), ('decoder', dec_layers)]:\n", " q_map = {e['layer']:e['name'] for e in catalog.get('self_attn_q',[]) if e['loc']==loc}\n", " qf = {}\n", " for l in layers:\n", " if l in q_map:\n", " try: qf[l] = 
get_w(q_map[l]).cpu().flatten()\n", " except: pass\n", " if len(qf)>=2:\n", " ls = sorted(qf)\n", " adj = []\n", " for i in range(len(ls)-1):\n", " a,b = qf[ls[i]],qf[ls[i+1]]\n", " if a.shape==b.shape:\n", " adj.append((torch.dot(a,b)/(torch.norm(a)*torch.norm(b)+1e-8)).item())\n", " if adj:\n", " print(f\" {loc} adj Q cos: mean={np.mean(adj):.4f}, range=[{min(adj):.4f},{max(adj):.4f}]\")\n", " del qf; gc.collect()\n", "\n", "# ── SUMMARY ──\n", "print(f\"\\n{'='*70}\\nSUMMARY — {MODEL_ID}\\n{'='*70}\")\n", "print(f\"Params: {total_params:,}\")\n", "print(f\"d_model={d_model}, heads={n_heads}, kv_heads={n_kv_heads}\"\n", " + (f\" (GQA {n_heads//n_kv_heads}:1)\" if n_kv_heads and n_heads and n_kv_heads != n_heads else \"\"))\n", "print(f\"Layers: {n_enc_layers} enc + {n_dec_layers} dec\")\n", "print(f\"Architecture: t5gemma2 (Gemma 2 adapted to encoder-decoder)\")\n", "for w in ['self_attn_q','self_attn_k','self_attn_v','cross_attn_q']:\n", " sc = sparsity.get(w,{})\n", " if sc.get('total',0) > 0:\n", " print(f\" {w} (<0.1): {sc[1e-1]/sc['total']*100:.1f}%\")\n", "\n", "print(f\"\\nReference (T5 family):\")\n", "print(f\" T5-Small Q: 93.7% | T5-Base Q: 99.4% | T5-v1.1-XXL Q: 100.0%\")\n", "print(f\" BERT-large Q: 99.1% (all uniform) | DINOv2: 100% (all uniform)\")\n", "if torch.cuda.is_available():\n", " print(f\"\\nVRAM at end: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "print(\"Done.\")\n" ] }, { "cell_type": "markdown", "id": "43f4be8c", "metadata": {}, "source": [ "## 14. Diffusion UNet Geometry\n", "*Section XIII: SD 1.5 / SDXL UNet — U-path QK gradient, cross-attn lock*\n", "\n", "Swap `MODEL_ID` at top." 
] }, { "cell_type": "code", "execution_count": null, "id": "a9c7a012", "metadata": {}, "outputs": [], "source": [ "# UNet geometry\n", "\n", "# ============================================================================\n", "# STABLE DIFFUSION UNET — INACTIVE WEIGHT GEOMETRY\n", "# SD 1.5 (~860M) and SDXL (~2.6B)\n", "# Completely different architecture: UNet with ResNet + Attention\n", "# Tests whether geometric invariants hold outside transformers\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import time\n", "import gc\n", "from collections import defaultdict\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name()}\")\n", " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", "\n", "# ── CHANGE THIS TO SWITCH MODELS ──\n", "# MODEL_ID = \"stable-diffusion-v1-5/stable-diffusion-v1-5\" # SD 1.5\n", "MODEL_ID = \"stabilityai/stable-diffusion-xl-base-1.0\" # SDXL\n", "# ───────────────────────────────────\n", "\n", "print(\"=\" * 70)\n", "print(f\"UNET GEOMETRY: {MODEL_ID}\")\n", "print(\"=\" * 70)\n", "\n", "# Load UNet only — not the full pipeline\n", "print(f\"\\nLoading UNet (fp16)...\")\n", "t0 = time.time()\n", "from diffusers import UNet2DConditionModel\n", "unet = UNet2DConditionModel.from_pretrained(\n", " MODEL_ID, subfolder=\"unet\", torch_dtype=torch.float16,\n", ")\n", "unet = unet.to(device)\n", "unet.eval()\n", "total_params = sum(p.numel() for p in unet.parameters())\n", "print(f\"Loaded in {time.time()-t0:.0f}s, {total_params:,} params\")\n", "if torch.cuda.is_available():\n", " print(f\"VRAM used: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "\n", "# UNet config\n", "config = unet.config\n", "print(f\"\\nConfig:\")\n", "for key in ['sample_size', 'in_channels', 'out_channels', 'block_out_channels',\n", " 
'layers_per_block', 'cross_attention_dim', 'attention_head_dim',\n", " 'transformer_layers_per_block']:\n", " val = getattr(config, key, None)\n", " if val is not None:\n", " print(f\" {key} = {val}\")\n", "\n", "# ── CLASSIFY WEIGHTS ──\n", "# UNet structure:\n", "# down_blocks[i].attentions[j].transformer_blocks[k].attn1 (self-attention)\n", "# down_blocks[i].attentions[j].transformer_blocks[k].attn2 (cross-attention to text)\n", "# down_blocks[i].resnets[j] (ResNet blocks)\n", "# mid_block.attentions[0].transformer_blocks[k].attn1/attn2\n", "# up_blocks[i].attentions[j].transformer_blocks[k].attn1/attn2\n", "# up_blocks[i].resnets[j]\n", "\n", "print(f\"\\n{'='*70}\\nCATALOG\\n{'='*70}\")\n", "\n", "catalog = defaultdict(list)\n", "for name, param in unet.named_parameters():\n", " if param.dim() != 2:\n", " continue\n", " parts = name.split('.')\n", "\n", " # Location in U-net\n", " if 'down_blocks' in name:\n", " loc = 'down'\n", " elif 'mid_block' in name:\n", " loc = 'mid'\n", " elif 'up_blocks' in name:\n", " loc = 'up'\n", " else:\n", " loc = 'other'\n", "\n", " # Block index\n", " block_idx = -1\n", " for i, p in enumerate(parts):\n", " if p in ('down_blocks', 'up_blocks') and i+1 < len(parts):\n", " try: block_idx = int(parts[i+1])\n", " except: pass\n", "\n", " # Layer index within block\n", " layer_idx = -1\n", " for i, p in enumerate(parts):\n", " if p in ('attentions', 'resnets', 'transformer_blocks') and i+1 < len(parts):\n", " try: layer_idx = int(parts[i+1])\n", " except: pass\n", "\n", " # Weight type\n", " if 'attn1' in name: # self-attention\n", " if 'to_q' in name: wt = 'self_attn_q'\n", " elif 'to_k' in name: wt = 'self_attn_k'\n", " elif 'to_v' in name: wt = 'self_attn_v'\n", " elif 'to_out' in name: wt = 'self_attn_o'\n", " else: wt = 'self_attn_other'\n", " elif 'attn2' in name: # cross-attention (to text encoder)\n", " if 'to_q' in name: wt = 'cross_attn_q'\n", " elif 'to_k' in name: wt = 'cross_attn_k'\n", " elif 'to_v' in name: wt = 
'cross_attn_v'\n", " elif 'to_out' in name: wt = 'cross_attn_o'\n", " else: wt = 'cross_attn_other'\n", " elif 'ff' in name and 'net' in name: # feedforward in transformer block\n", " if 'net.0' in name: wt = 'ff_gate' # GEGLU first projection\n", " elif 'net.2' in name: wt = 'ff_down' # output projection\n", " else: wt = 'ff_other'\n", " elif 'resnets' in name:\n", " if 'conv' in name or 'in_layers' in name or 'out_layers' in name:\n", " continue # skip conv weights (not 2D linear)\n", " wt = 'resnet_linear'\n", " elif 'proj_in' in name: wt = 'proj_in'\n", " elif 'proj_out' in name: wt = 'proj_out'\n", " elif 'conv' in name:\n", " continue # skip convolutions\n", " elif 'norm' in name:\n", " wt = 'norm'\n", " continue # skip norms\n", " elif 'time_emb' in name or 'emb_layers' in name:\n", " wt = 'time_embed'\n", " else:\n", " wt = 'other'\n", "\n", " catalog[wt].append({\n", " 'name': name, 'shape': tuple(param.shape), 'loc': loc,\n", " 'block': block_idx, 'layer': layer_idx, 'numel': param.numel(),\n", " })\n", "\n", "for wt, entries in sorted(catalog.items()):\n", " t = sum(e['numel'] for e in entries)\n", " locs = defaultdict(int)\n", " for e in entries: locs[e['loc']] += 1\n", " loc_str = ' '.join(f\"{k}:{v}\" for k,v in sorted(locs.items()))\n", " shapes = set(str(e['shape']) for e in entries)\n", " print(f\" {wt:25s}: {len(entries):4d} ({loc_str}) {t:>12,} shapes={len(shapes)}\")\n", "\n", "# Helper\n", "def get_w(name):\n", " parts = name.split('.')\n", " obj = unet\n", " for p in parts:\n", " obj = obj[int(p)] if p.isdigit() else getattr(obj, p)\n", " return obj.detach().float()\n", "\n", "# ── SVD ──\n", "print(f\"\\n{'='*70}\\nSVD EFFECTIVE RANK\\n{'='*70}\")\n", "svd_results = []\n", "skip = {'other', 'norm', 'resnet_linear', 'time_embed'}\n", "all_entries = [(wt,e) for wt,entries in catalog.items() if wt not in skip for e in entries]\n", "total = len(all_entries)\n", "t0 = time.time()\n", "\n", "for idx, (wt, entry) in enumerate(all_entries):\n", " if 
(idx+1) % 10 == 0 or idx == 0:\n", " print(f\" [{idx+1}/{total}] {wt} {entry['loc']} \", end=\"\\r\")\n", " try:\n", " W = get_w(entry['name'])\n", " S = torch.linalg.svdvals(W).cpu().numpy()\n", " svd_results.append({\n", " 'wt': wt, 'loc': entry['loc'], 'block': entry['block'],\n", " 'shape': entry['shape'],\n", " 'sr': (S**2).sum()/(S[0]**2) if S[0]>0 else 0,\n", " 'pr': (S.sum())**2/((S**2).sum()) if (S**2).sum()>0 else 0,\n", " 'af': (S>0.01*S[0]).sum()/len(S),\n", " 'r90': np.searchsorted(np.cumsum(S)/S.sum(),0.90)+1,\n", " 'cond': S[0]/(S[-1]+1e-10),\n", " })\n", " except Exception as e:\n", " print(f\"\\n SVD fail {entry['name']}: {e}\")\n", "\n", "print(f\" Done: {len(svd_results)} matrices in {time.time()-t0:.0f}s \")\n", "\n", "by_type = defaultdict(list)\n", "for r in svd_results: by_type[r['wt']].append(r)\n", "print(f\"\\n{'Type':25s} {'SR':>8s} {'PR':>8s} {'Act%':>6s} {'R90':>5s} {'Cond':>12s}\")\n", "for w in sorted(by_type):\n", " s = by_type[w]\n", " print(f\" {w:23s} {np.mean([r['sr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['pr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['af'] for r in s]):6.3f} \"\n", " f\"{np.mean([r['r90'] for r in s]):5.0f} \"\n", " f\"{np.mean([r['cond'] for r in s]):12.1f}\")\n", "\n", "# SVD by U-net position (down/mid/up)\n", "print(f\"\\n--- SVD BY POSITION ---\")\n", "for loc in ['down', 'mid', 'up']:\n", " for wt in ['self_attn_q', 'self_attn_k', 'cross_attn_q', 'cross_attn_k']:\n", " subset = [r for r in svd_results if r['wt']==wt and r['loc']==loc]\n", " if subset:\n", " sr = np.mean([r['sr'] for r in subset])\n", " print(f\" {loc:4s} {wt:20s}: SR={sr:.2f} (n={len(subset)})\")\n", "\n", "# ── SPARSITY ──\n", "print(f\"\\n{'='*70}\\nSPARSITY\\n{'='*70}\")\n", "thresholds = [1e-4,1e-3,1e-2,1e-1]\n", "sparsity = defaultdict(lambda: {'total':0, **{t:0 for t in thresholds}})\n", "t0 = time.time()\n", "for wt, entries in catalog.items():\n", " if wt in skip: continue\n", " for entry in entries:\n", " W = 
get_w(entry['name']); a = W.abs(); n = a.numel()\n", " sparsity[wt]['total'] += n\n", " for t in thresholds:\n", " sparsity[wt][t] += (a<t).sum().item()\n", "\n", "print(f\"\\n{'Type':25s} {'<1e-4':>8s} {'<1e-3':>8s} {'<0.01':>8s} {'<0.1':>8s}\")\n", "for wt in sorted(sparsity):\n", " sc = sparsity[wt]\n", " if sc['total']==0: continue\n", " print(f\" {wt:23s} {sc[1e-4]/sc['total']:8.4f} {sc[1e-3]/sc['total']:8.4f} \"\n", " f\"{sc[1e-2]/sc['total']:8.4f} {sc[1e-1]/sc['total']:8.4f}\")\n", "\n", "# Q vs K comparison\n", "print(f\"\\n--- Q/K SPARSITY (<0.1) ---\")\n", "for wt in ['self_attn_q','self_attn_k','self_attn_v','self_attn_o',\n", " 'cross_attn_q','cross_attn_k','cross_attn_v','cross_attn_o']:\n", " sc = sparsity.get(wt,{})\n", " if sc.get('total',0) > 0:\n", " print(f\" {wt:25s}: {sc[1e-1]/sc['total']*100:.1f}%\")\n", "\n", "# Sparsity by position\n", "print(f\"\\n--- SPARSITY BY POSITION (<0.1) ---\")\n", "for loc in ['down', 'mid', 'up']:\n", " for wt in ['self_attn_q', 'self_attn_k', 'cross_attn_q', 'cross_attn_k']:\n", " entries = [e for e in catalog.get(wt,[]) if e['loc']==loc]\n", " if not entries: continue\n", " total_n = 0; below = 0\n", " for e in entries:\n", " W = get_w(e['name']); a = W.abs()\n", " total_n += a.numel(); below += (a<0.1).sum().item()\n", " if total_n > 0:\n", " print(f\" {loc:4s} {wt:20s}: {below/total_n*100:.1f}%\")\n", "\n", "# ── QK MANIFOLD — eigvalsh on CPU ──\n", "print(f\"\\n{'='*70}\\nQK MANIFOLD (eigvalsh on CPU)\\n{'='*70}\")\n", "\n", "# Group Q/K by their attention block\n", "q_map = {} # (loc, block, layer, attn_type) → name\n", "k_map = {}\n", "for wt in ['self_attn_q', 'self_attn_k', 'cross_attn_q', 'cross_attn_k']:\n", " for entry in catalog.get(wt, []):\n", " attn_type = 'self' if 'self' in wt else 'cross'\n", " proj = 'q' if '_q' in wt else 'k'\n", " key = (entry['loc'], entry['block'], entry['layer'], attn_type)\n", " name_parts = entry['name'].split('.')\n", " # Find transformer_blocks index for more precise matching\n", " tb_idx = -1\n", " for i, p in 
enumerate(name_parts):\n", " if p == 'transformer_blocks' and i+1 < len(name_parts):\n", " try: tb_idx = int(name_parts[i+1])\n", " except: pass\n", " full_key = (entry['loc'], entry['block'], entry['layer'], tb_idx, attn_type)\n", " if proj == 'q':\n", " q_map[full_key] = entry['name']\n", " else:\n", " k_map[full_key] = entry['name']\n", "\n", "for attn_type in ['self', 'cross']:\n", " print(f\"\\n--- {attn_type.upper()}-ATTENTION ---\")\n", " qk_results = []\n", " keys = sorted([k for k in q_map if k[4]==attn_type and k in k_map])\n", "\n", " for key in keys:\n", " try:\n", " t0 = time.time()\n", " Wq = get_w(q_map[key])\n", " Wk = get_w(k_map[key])\n", "\n", " # Handle different Q/K dims (cross-attn: Q from spatial, K from text)\n", " if Wq.shape[0] != Wk.shape[0]:\n", " # GQA or cross-modal: repeat K to match Q\n", " if Wq.shape[0] > Wk.shape[0] and Wq.shape[0] % Wk.shape[0] == 0:\n", " repeat = Wq.shape[0] // Wk.shape[0]\n", " Wk = Wk.repeat_interleave(repeat, dim=0)\n", " else:\n", " # Different dims entirely — use min\n", " min_out = min(Wq.shape[0], Wk.shape[0])\n", " Wq = Wq[:min_out]; Wk = Wk[:min_out]\n", "\n", " if Wq.shape[1] != Wk.shape[1]:\n", " # Cross-attention: Q input is spatial, K input is text,\n", " # so a weight-space QK^T is only an approximation —\n", " # truncate both to the shared input dims before forming it\n", " min_in = min(Wq.shape[1], Wk.shape[1])\n", " Wq = Wq[:, :min_in]; Wk = Wk[:, :min_in]\n", "\n", " QK = Wq @ Wk.T\n", " del Wq, Wk\n", "\n", " sym = torch.norm(QK-QK.T).item()/(torch.norm(QK).item()+1e-10)\n", " QK_cpu = ((QK+QK.T)/2).cpu()\n", " del QK\n", " if torch.cuda.is_available(): torch.cuda.empty_cache()\n", "\n", " eig = torch.linalg.eigvalsh(QK_cpu).numpy()[::-1]\n", " del QK_cpu\n", " n_pos=(eig>0).sum(); n_neg=(eig<0).sum(); dim=len(eig)\n", "\n", " loc, block, layer, tb, _ = key\n", " print(f\" {loc} B{block} L{layer} T{tb}: pos={n_pos}({n_pos/dim:.3f}), \"\n", " f\"neg={n_neg}({n_neg/dim:.3f}), sym={sym:.4f}, \"\n", " 
f\"top={eig[0]:.2f} ({time.time()-t0:.1f}s)\")\n", " qk_results.append({'key':key, 'n_pos':n_pos, 'n_neg':n_neg, 'dim':dim})\n", " del eig; gc.collect()\n", " except Exception as e:\n", " print(f\" {key}: FAIL — {e}\")\n", "\n", " if len(qk_results) >= 2:\n", " fracs = [r['n_pos']/r['dim'] for r in qk_results]\n", " print(f\"\\n Pos fraction: mean={np.mean(fracs):.3f}, std={np.std(fracs):.3f}, \"\n", " f\"range=[{min(fracs):.3f}, {max(fracs):.3f}]\")\n", "\n", "# ── DEAD NEURONS (FF blocks in transformer layers) ──\n", "print(f\"\\n{'='*70}\\nFF DEAD NEURONS (GEGLU)\\n{'='*70}\")\n", "gate_map = {(e['loc'],e['block'],e['layer']):e['name'] for e in catalog.get('ff_gate',[])}\n", "down_map = {(e['loc'],e['block'],e['layer']):e['name'] for e in catalog.get('ff_down',[])}\n", "td=tn=0\n", "for key in sorted(gate_map):\n", " if key not in down_map: continue\n", " try:\n", " Wg = get_w(gate_map[key])\n", " Wd = get_w(down_map[key])\n", " # GEGLU: gate is [2*d_ff, d_model] (first half is gate, second is value)\n", " # Or it might be [d_ff*2, d_model] → split\n", " if Wg.shape[0] > Wd.shape[1]:\n", " half = Wg.shape[0] // 2\n", " gn = torch.norm(Wg[:half], dim=1).cpu().numpy()\n", " un = torch.norm(Wg[half:], dim=1).cpu().numpy()\n", " else:\n", " gn = torch.norm(Wg, dim=1).cpu().numpy()\n", " un = gn # non-gated fallback\n", " dn = torch.norm(Wd, dim=0).cpu().numpy()\n", "\n", " # Align dimensions\n", " d_ff = min(len(gn), len(un), len(dn))\n", " gn = gn[:d_ff]; un = un[:d_ff]; dn = dn[:d_ff]\n", "\n", " c = gn*un*dn; mc = c.mean()\n", " dead=(c<0.01*mc).sum(); weak=(c<0.10*mc).sum()\n", " td+=dead; tn+=d_ff\n", " if dead > 0 or weak > 5:\n", " loc, block, layer = key\n", " print(f\" {loc} B{block} L{layer}: d_ff={d_ff}, dead={dead}({dead/d_ff*100:.1f}%), \"\n", " f\"weak={weak}({weak/d_ff*100:.1f}%)\")\n", " except Exception as e:\n", " pass\n", "if tn > 0:\n", " print(f\"\\n Total: {td}/{tn} ({td/tn*100:.2f}%)\")\n", " if td == 0:\n", " print(f\" All {tn} neurons 
alive\")\n", "\n", "# ── CROSS-LAYER Q CORRELATION ──\n", "print(f\"\\n{'='*70}\\nCROSS-BLOCK Q CORRELATION\\n{'='*70}\")\n", "for attn_type in ['self_attn_q', 'cross_attn_q']:\n", " entries = catalog.get(attn_type, [])\n", " if len(entries) < 2: continue\n", " qf = []\n", " for e in sorted(entries, key=lambda x: (x['loc'], x['block'], x['layer'])):\n", " try:\n", " W = get_w(e['name']).cpu().flatten()\n", " qf.append((f\"{e['loc']} B{e['block']}\", W))\n", " except: pass\n", " if len(qf) >= 2:\n", " adj = []\n", " for i in range(len(qf)-1):\n", " a, b = qf[i][1], qf[i+1][1]\n", " if a.shape == b.shape:\n", " adj.append((torch.dot(a,b)/(torch.norm(a)*torch.norm(b)+1e-8)).item())\n", " if adj:\n", " print(f\" {attn_type}: adj cos mean={np.mean(adj):.4f}, range=[{min(adj):.4f},{max(adj):.4f}]\")\n", " del qf; gc.collect()\n", "\n", "# ── SUMMARY ──\n", "print(f\"\\n{'='*70}\\nSUMMARY — {MODEL_ID} (UNet)\\n{'='*70}\")\n", "print(f\"Params: {total_params:,}\")\n", "print(f\"Architecture: UNet2DConditionModel\")\n", "if hasattr(config, 'block_out_channels'):\n", " print(f\"Channels: {config.block_out_channels}\")\n", "if hasattr(config, 'cross_attention_dim'):\n", " print(f\"Cross-attn dim: {config.cross_attention_dim}\")\n", "for w in ['self_attn_q','self_attn_k','self_attn_v',\n", " 'cross_attn_q','cross_attn_k','cross_attn_v']:\n", " sc = sparsity.get(w,{})\n", " if sc.get('total',0) > 0:\n", " print(f\" {w} (<0.1): {sc[1e-1]/sc['total']*100:.1f}%\")\n", "print(f\"\\nRef (transformers):\")\n", "print(f\" T5-v1.1-XXL Q: 100.0% | BERT: 99.1% | DINOv2: 100% | T5Gemma2: 100%\")\n", "if torch.cuda.is_available():\n", " print(f\"\\nVRAM at end: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "print(\"Done.\")\n" ] }, { "cell_type": "markdown", "id": "1c81b07f", "metadata": {}, "source": [ "## 15. VAE Weight Topology\n", "*Section XIV: SD 1.5, SDXL, Flux.1, Flux.2 — all four sequentially*\n", "\n", "**Requires:** Flux access." 
] }, { "cell_type": "code", "execution_count": null, "id": "3c63ff2d", "metadata": {}, "outputs": [], "source": [ "# VAE geometry\n", "\n", "# ============================================================================\n", "# DIFFUSION VAE — INACTIVE WEIGHT GEOMETRY\n", "# SD 1.5, SDXL, Flux.1, Flux.2\n", "# Mostly convolutional — Conv2d reshaped to [out, in*k*k] for analysis\n", "# Mid-block attention analyzed separately\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import time\n", "import gc\n", "from collections import defaultdict\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name()}\")\n", " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", "\n", "# ── SELECT VAE ──\n", "VAES = {\n", " 'sd15': (\"stable-diffusion-v1-5/stable-diffusion-v1-5\", \"vae\", \"SD 1.5\"),\n", " 'sdxl': (\"stabilityai/stable-diffusion-xl-base-1.0\", \"vae\", \"SDXL\"),\n", " 'flux1': (\"black-forest-labs/FLUX.1-dev\", \"vae\", \"Flux.1\"),\n", " 'flux2': (\"black-forest-labs/FLUX.2-dev\", \"vae\", \"Flux.2\"),\n", "}\n", "\n", "# Run all of them sequentially\n", "from diffusers import AutoencoderKL\n", "\n", "all_results = {}\n", "\n", "for vae_key, (repo_id, subfolder, label) in VAES.items():\n", " print(f\"\\n{'='*70}\")\n", " print(f\"VAE: {label} ({repo_id})\")\n", " print(f\"{'='*70}\")\n", "\n", " try:\n", " print(f\"Loading {subfolder} from {repo_id}...\")\n", " t0 = time.time()\n", " vae = AutoencoderKL.from_pretrained(repo_id, subfolder=subfolder, torch_dtype=torch.float16)\n", " vae = vae.to(device)\n", " vae.eval()\n", " total_params = sum(p.numel() for p in vae.parameters())\n", " print(f\"Loaded in {time.time()-t0:.0f}s, {total_params:,} params\")\n", " if torch.cuda.is_available():\n", " print(f\"VRAM: 
{torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", " except Exception as e:\n", " print(f\"FAILED to load: {e}\")\n", " all_results[vae_key] = {'error': str(e)}\n", " continue\n", "\n", " # Config\n", " config = vae.config\n", " print(f\"\\nConfig:\")\n", " for key in ['in_channels', 'out_channels', 'latent_channels',\n", " 'block_out_channels', 'layers_per_block', 'sample_size',\n", " 'scaling_factor', 'norm_num_groups']:\n", " val = getattr(config, key, None)\n", " if val is not None:\n", " print(f\" {key} = {val}\")\n", "\n", " # ── CATALOG — handle Conv2d (4D) and Linear (2D) ──\n", " print(f\"\\n--- CATALOG ---\")\n", " catalog = defaultdict(list)\n", " for name, param in vae.named_parameters():\n", " parts = name.split('.')\n", "\n", " # Location\n", " if 'encoder' in name: loc = 'encoder'\n", " elif 'decoder' in name: loc = 'decoder'\n", " elif 'quant_conv' in name: loc = 'quant'\n", " elif 'post_quant_conv' in name: loc = 'post_quant'\n", " else: loc = 'other'\n", "\n", " # Block depth\n", " block = -1\n", " for i, p in enumerate(parts):\n", " if p in ('down_blocks', 'up_blocks') and i+1 < len(parts):\n", " try: block = int(parts[i+1])\n", " except: pass\n", "\n", " # Weight type\n", " dim = param.dim()\n", " if dim == 4: # Conv2d\n", " if 'attentions' in name or 'attn' in name:\n", " if 'to_q' in name or 'query' in name or 'group_norm' not in name and 'to_q' in name:\n", " wt = 'attn_q_conv'\n", " elif 'to_k' in name or 'key' in name:\n", " wt = 'attn_k_conv'\n", " elif 'to_v' in name or 'value' in name:\n", " wt = 'attn_v_conv'\n", " elif 'to_out' in name or 'proj_attn' in name:\n", " wt = 'attn_o_conv'\n", " else:\n", " wt = 'attn_conv_other'\n", " elif 'mid_block' in name:\n", " wt = 'mid_conv'\n", " elif 'down' in name:\n", " wt = 'down_conv'\n", " elif 'up' in name:\n", " wt = 'up_conv'\n", " elif 'conv_in' in name:\n", " wt = 'conv_in'\n", " elif 'conv_out' in name:\n", " wt = 'conv_out'\n", " elif 'conv_norm_out' in name:\n", " continue # skip\n", " 
elif 'quant' in name:\n", " wt = 'quant_conv'\n", " else:\n", " wt = 'conv_other'\n", " elif dim == 2: # Linear (attention projections)\n", " if 'to_q' in name or 'query' in name:\n", " wt = 'attn_q'\n", " elif 'to_k' in name or 'key' in name:\n", " wt = 'attn_k'\n", " elif 'to_v' in name or 'value' in name:\n", " wt = 'attn_v'\n", " elif 'to_out' in name or 'proj_attn' in name:\n", " wt = 'attn_o'\n", " else:\n", " wt = 'linear_other'\n", " elif dim == 1: # bias/norm\n", " continue\n", " else:\n", " continue\n", "\n", " # Reshape info for Conv2d\n", " shape_2d = None\n", " if dim == 4:\n", " out_c, in_c, kh, kw = param.shape\n", " shape_2d = (out_c, in_c * kh * kw)\n", "\n", " catalog[wt].append({\n", " 'name': name, 'shape': tuple(param.shape), 'loc': loc,\n", " 'block': block, 'numel': param.numel(), 'dim': dim,\n", " 'shape_2d': shape_2d,\n", " })\n", "\n", " for wt, entries in sorted(catalog.items()):\n", " t = sum(e['numel'] for e in entries)\n", " locs = defaultdict(int)\n", " for e in entries: locs[e['loc']] += 1\n", " loc_str = ' '.join(f\"{k}:{v}\" for k,v in sorted(locs.items()))\n", " print(f\" {wt:25s}: {len(entries):4d} ({loc_str}) {t:>12,}\")\n", "\n", " # Helper — get param as fp32 on GPU, reshape Conv2d to 2D\n", " def get_w_2d(entry):\n", " parts = entry['name'].split('.')\n", " obj = vae\n", " for p in parts:\n", " obj = obj[int(p)] if p.isdigit() else getattr(obj, p)\n", " W = obj.detach().float()\n", " if entry['dim'] == 4:\n", " # Reshape Conv2d [out, in, kh, kw] → [out, in*kh*kw]\n", " W = W.reshape(W.shape[0], -1)\n", " return W\n", "\n", " # ── SVD ──\n", " print(f\"\\n--- SVD ---\")\n", " svd_results = []\n", " skip = {'conv_other', 'linear_other', 'attn_conv_other'}\n", " all_entries = [(wt,e) for wt,entries in catalog.items() if wt not in skip for e in entries]\n", " t0 = time.time()\n", "\n", " for idx, (wt, entry) in enumerate(all_entries):\n", " if (idx+1) % 10 == 0:\n", " print(f\" [{idx+1}/{len(all_entries)}] \", end=\"\\r\")\n", " 
try:\n", " W = get_w_2d(entry)\n", " S = torch.linalg.svdvals(W).cpu().numpy()\n", " svd_results.append({\n", " 'wt': wt, 'loc': entry['loc'], 'shape': entry['shape'],\n", " 'sr': (S**2).sum()/(S[0]**2) if S[0]>0 else 0,\n", " 'pr': (S.sum())**2/((S**2).sum()) if (S**2).sum()>0 else 0,\n", " 'af': (S>0.01*S[0]).sum()/len(S),\n", " 'cond': S[0]/(S[-1]+1e-10),\n", " })\n", " except Exception as e:\n", " pass\n", "\n", " print(f\" Done: {len(svd_results)} in {time.time()-t0:.0f}s \")\n", "\n", " by_type = defaultdict(list)\n", " for r in svd_results: by_type[r['wt']].append(r)\n", " print(f\" {'Type':25s} {'SR':>8s} {'PR':>8s} {'Act%':>6s} {'Cond':>12s}\")\n", " for w in sorted(by_type):\n", " s = by_type[w]\n", " print(f\" {w:23s} {np.mean([r['sr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['pr'] for r in s]):8.2f} \"\n", " f\"{np.mean([r['af'] for r in s]):6.3f} \"\n", " f\"{np.mean([r['cond'] for r in s]):12.1f}\")\n", "\n", " # ── SPARSITY ──\n", " print(f\"\\n--- SPARSITY ---\")\n", " thresholds = [1e-4, 1e-3, 1e-2, 1e-1]\n", " sparsity = defaultdict(lambda: {'total':0, **{t:0 for t in thresholds}})\n", " for wt, entries in catalog.items():\n", " if wt in skip: continue\n", " for entry in entries:\n", " W = get_w_2d(entry); a = W.abs(); n = a.numel()\n", " sparsity[wt]['total'] += n\n", " for t in thresholds:\n", " sparsity[wt][t] += (a<t).sum().item()\n", "\n", " print(f\" {'Type':25s} {'<1e-4':>8s} {'<1e-3':>8s} {'<0.01':>8s} {'<0.1':>8s}\")\n", " for wt in sorted(sparsity):\n", " sc = sparsity[wt]\n", " if sc['total']==0: continue\n", " print(f\" {wt:23s} {sc[1e-4]/sc['total']:8.4f} {sc[1e-3]/sc['total']:8.4f} \"\n", " f\"{sc[1e-2]/sc['total']:8.4f} {sc[1e-1]/sc['total']:8.4f}\")\n", "\n", " # Encoder vs decoder sparsity\n", " print(f\"\\n --- ENC vs DEC (<0.1) ---\")\n", " for loc in ['encoder', 'decoder']:\n", " total_n = 0; below = 0\n", " for wt, entries in catalog.items():\n", " if wt in skip: continue\n", " for e in entries:\n", " if e['loc'] != loc: continue\n", " W = get_w_2d(e); a = W.abs()\n", " total_n += 
a.numel(); below += (a<0.1).sum().item()\n", " if total_n > 0:\n", " print(f\" {loc:10s}: {below/total_n*100:.1f}%\")\n", "\n", " # ── QK MANIFOLD (mid-block attention if present) ──\n", " attn_q = [e for e in catalog.get('attn_q', []) + catalog.get('attn_q_conv', [])]\n", " attn_k = [e for e in catalog.get('attn_k', []) + catalog.get('attn_k_conv', [])]\n", " if attn_q and attn_k:\n", " print(f\"\\n--- QK MANIFOLD (mid-block attention) ---\")\n", " for q_entry, k_entry in zip(attn_q, attn_k):\n", " try:\n", " Wq = get_w_2d(q_entry)\n", " Wk = get_w_2d(k_entry)\n", " if Wq.shape != Wk.shape:\n", " min_d = min(Wq.shape[0], Wk.shape[0])\n", " Wq = Wq[:min_d]; Wk = Wk[:min_d]\n", " if Wq.shape[1] != Wk.shape[1]:\n", " min_d2 = min(Wq.shape[1], Wk.shape[1])\n", " Wq = Wq[:,:min_d2]; Wk = Wk[:,:min_d2]\n", " QK = Wq @ Wk.T; del Wq, Wk\n", " sym = torch.norm(QK-QK.T).item()/(torch.norm(QK).item()+1e-10)\n", " QK_cpu = ((QK+QK.T)/2).cpu(); del QK\n", " eig = torch.linalg.eigvalsh(QK_cpu).numpy()[::-1]; del QK_cpu\n", " n_pos=(eig>0).sum(); n_neg=(eig<0).sum(); dim=len(eig)\n", " print(f\" {q_entry['loc']}: pos={n_pos}({n_pos/dim:.3f}), \"\n", " f\"neg={n_neg}({n_neg/dim:.3f}), sym={sym:.4f}, top={eig[0]:.2f}\")\n", " del eig\n", " except Exception as e:\n", " print(f\" FAIL: {e}\")\n", " else:\n", " print(f\"\\n No attention Q/K found (pure conv VAE)\")\n", "\n", " # ── CROSS-LAYER CONV CORRELATION ──\n", " print(f\"\\n--- CROSS-LAYER CONV CORRELATION ---\")\n", " for loc in ['encoder', 'decoder']:\n", " for wt in ['down_conv', 'up_conv', 'mid_conv']:\n", " entries = sorted([e for e in catalog.get(wt, []) if e['loc']==loc],\n", " key=lambda x: x['name'])\n", " if len(entries) < 2: continue\n", " # Group by same shape only\n", " shape_groups = defaultdict(list)\n", " for e in entries:\n", " shape_groups[e['shape']].append(e)\n", " for shape, group in shape_groups.items():\n", " if len(group) < 2: continue\n", " adj = []\n", " for i in range(len(group)-1):\n", " try:\n", " 
a = get_w_2d(group[i]).cpu().flatten()\n", " b = get_w_2d(group[i+1]).cpu().flatten()\n", " if a.shape == b.shape:\n", " c = torch.dot(a,b)/(torch.norm(a)*torch.norm(b)+1e-8)\n", " adj.append(c.item())\n", " except: pass\n", " if adj:\n", " print(f\" {loc} {wt} {shape}: adj cos mean={np.mean(adj):.4f} \"\n", " f\"range=[{min(adj):.4f},{max(adj):.4f}] (n={len(adj)})\")\n", "\n", " # Store results\n", " result = {\n", " 'params': total_params,\n", " 'latent_ch': getattr(config, 'latent_channels', '?'),\n", " 'block_ch': getattr(config, 'block_out_channels', '?'),\n", " 'svd': {w: np.mean([r['sr'] for r in s]) for w,s in by_type.items()},\n", " 'sparsity': {wt: sc[1e-1]/sc['total']*100 if sc['total']>0 else 0\n", " for wt, sc in sparsity.items()},\n", " }\n", " all_results[vae_key] = result\n", "\n", " # Cleanup\n", " del vae\n", " gc.collect()\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", "\n", "# ── CROSS-VAE COMPARISON ──\n", "print(f\"\\n\\n{'='*70}\")\n", "print(\"CROSS-VAE COMPARISON\")\n", "print(f\"{'='*70}\")\n", "\n", "print(f\"\\n{'VAE':10s} {'Params':>12s} {'Latent Ch':>10s} {'Block Ch':>20s}\")\n", "for key, label in [('sd15','SD 1.5'),('sdxl','SDXL'),('flux1','Flux.1'),('flux2','Flux.2')]:\n", " r = all_results.get(key, {})\n", " if 'error' in r:\n", " print(f\" {label:8s} FAILED: {r['error'][:50]}\")\n", " continue\n", " print(f\" {label:8s} {r.get('params',0):>12,} {str(r.get('latent_ch','?')):>10s} \"\n", " f\"{str(r.get('block_ch','?')):>20s}\")\n", "\n", "print(f\"\\nDone.\")\n" ] }, { "cell_type": "markdown", "id": "a5cc4011", "metadata": {}, "source": [ "## 16. Procrustes Analysis — VAE Weight-Space Alignment\n", "*Section XV: pairwise orthogonal Procrustes on 4 VAEs × 68 matrices*\n", "\n", "70-76% aligned after rotation. Spectral corr 0.94-0.98." 
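The rotation-alignment step in the next cell is classical orthogonal Procrustes: scale both matrices to unit Frobenius norm, then solve for the orthogonal R minimizing ||A − BR||_F from one SVD of BᵀA. A minimal NumPy sketch of the same algebra as the GPU version below (the check at the end uses a synthetic matrix pair, not real VAE weights):

```python
import numpy as np

def procrustes(A, B):
    """Orthogonal Procrustes: rotate B onto A after unit-Frobenius scaling.

    Returns (residual, cosine) — residual is ||A_n - B_n R||_F, cosine is
    the Frobenius inner product after alignment. Note R ranges over the
    full orthogonal group, so it may include a reflection.
    """
    A_n = A / np.linalg.norm(A)
    B_n = B / np.linalg.norm(B)
    U, _, Vt = np.linalg.svd(B_n.T @ A_n)
    R = U @ Vt                      # optimal orthogonal map
    B_aligned = B_n @ R
    residual = float(np.linalg.norm(A_n - B_aligned))
    cosine = float((A_n * B_aligned).sum())
    return residual, cosine

# Sanity check: a rotated, rescaled copy of A should align perfectly,
# since the normalization absorbs scale and R undoes the rotation.
rng = np.random.default_rng(0)
A = rng.standard_normal((16, 16))
Q, _ = np.linalg.qr(rng.standard_normal((16, 16)))  # random orthogonal matrix
res, cos = procrustes(A, 3.0 * A @ Q)
print(res, cos)
```

Because scale and rotation are both factored out, a residual near zero (cosine near 1) means the two weight matrices differ only by an orthogonal change of basis — the quantity the pairwise tables below average over all common matrices.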
] }, { "cell_type": "code", "execution_count": null, "id": "3d073d0f", "metadata": {}, "outputs": [], "source": [ "# Procrustes VAE\n", "\n", "# ============================================================================\n", "# PROCRUSTES ANALYSIS — DIFFUSION VAEs (v2 — ALL GPU)\n", "# SD 1.5, SDXL, Flux.1, Flux.2\n", "# Everything on GPU. Proper normalization. Fixed spectral.\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import time\n", "import gc\n", "from collections import defaultdict\n", "from itertools import combinations\n", "from diffusers import AutoencoderKL\n", "\n", "device = torch.device(\"cuda\")\n", "print(f\"GPU: {torch.cuda.get_device_name()}\")\n", "\n", "VAES = {\n", " 'sd15': (\"stable-diffusion-v1-5/stable-diffusion-v1-5\", \"vae\", \"SD 1.5\"),\n", " 'sdxl': (\"stabilityai/stable-diffusion-xl-base-1.0\", \"vae\", \"SDXL\"),\n", " 'flux1': (\"black-forest-labs/FLUX.1-dev\", \"vae\", \"Flux.1\"),\n", " 'flux2': (\"black-forest-labs/FLUX.2-dev\", \"vae\", \"Flux.2\"),\n", "}\n", "\n", "print(\"=\" * 70)\n", "print(\"PROCRUSTES ANALYSIS — DIFFUSION VAEs\")\n", "print(\"=\" * 70)\n", "\n", "# ── Extract weights — store on GPU as fp32 ──\n", "def extract_weights(repo_id, subfolder):\n", " vae = AutoencoderKL.from_pretrained(repo_id, subfolder=subfolder, torch_dtype=torch.float32)\n", " vae.eval()\n", " weights = {}\n", " for name, param in vae.named_parameters():\n", " if param.dim() == 4:\n", " weights[name] = param.detach().reshape(param.shape[0], -1).to(device)\n", " elif param.dim() == 2:\n", " weights[name] = param.detach().to(device)\n", " del vae; gc.collect(); torch.cuda.empty_cache()\n", " return weights\n", "\n", "print(\"\\nExtracting weights (all on GPU)...\")\n", "all_weights = {}\n", "for key, (repo_id, subfolder, label) in VAES.items():\n", " print(f\" {label}...\", end=\" \")\n", " t0 = time.time()\n", " try:\n", " all_weights[key] = 
extract_weights(repo_id, subfolder)\n", " print(f\"{len(all_weights[key])} matrices in {time.time()-t0:.0f}s\")\n", " except Exception as e:\n", " print(f\"FAILED: {e}\")\n", " all_weights[key] = {}\n", "\n", "print(f\"VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "\n", "# Common matrices\n", "model_keys = [k for k in all_weights if all_weights[k]]\n", "common_names = None\n", "for key in model_keys:\n", " ns = {n: tuple(w.shape) for n, w in all_weights[key].items()}\n", " common_names = ns if common_names is None else {n: s for n, s in common_names.items() if n in ns and ns[n] == s}\n", "print(f\"\\nCommon matrices: {len(common_names)}\")\n", "\n", "# ── GPU Procrustes ──\n", "def procrustes_gpu(A, B):\n", " \"\"\"Orthogonal Procrustes on GPU. Normalizes to unit Frobenius norm.\"\"\"\n", " A_n = A / (torch.norm(A) + 1e-10)\n", " B_n = B / (torch.norm(B) + 1e-10)\n", " U, S, Vt = torch.linalg.svd(B_n.T @ A_n)\n", " R = U @ Vt\n", " B_aligned = B_n @ R\n", " residual = torch.norm(A_n - B_aligned).item()\n", " cosine = torch.dot(A_n.flatten(), B_aligned.flatten()).item()\n", " return residual, cosine\n", "\n", "def spectral_corr_gpu(A, B):\n", " \"\"\"Spectral correlation on GPU.\"\"\"\n", " Sa = torch.linalg.svdvals(A)\n", " Sb = torch.linalg.svdvals(B)\n", " n = min(len(Sa), len(Sb))\n", " if n < 3: return 0.0\n", " Sa = Sa[:n] / (Sa[0] + 1e-10)\n", " Sb = Sb[:n] / (Sb[0] + 1e-10)\n", " Sa_c = Sa - Sa.mean(); Sb_c = Sb - Sb.mean()\n", " return (torch.dot(Sa_c, Sb_c) / (torch.norm(Sa_c) * torch.norm(Sb_c) + 1e-10)).item()\n", "\n", "def raw_cosine_gpu(A, B):\n", " \"\"\"Raw weight cosine (no rotation).\"\"\"\n", " return (torch.dot(A.flatten(), B.flatten()) / (torch.norm(A) * torch.norm(B) + 1e-10)).item()\n", "\n", "# ── Pairwise analysis ──\n", "all_pairs = list(combinations(model_keys, 2))\n", "print(f\"\\n{'='*70}\\nPAIRWISE PROCRUSTES ({len(all_pairs)} pairs × {len(common_names)} matrices)\\n{'='*70}\")\n", "\n", "pair_results = {}\n", "for key_a, 
key_b in all_pairs:\n", " la, lb = VAES[key_a][2], VAES[key_b][2]\n", " print(f\"\\n--- {la} vs {lb} ---\")\n", " t0 = time.time()\n", "\n", " residuals = []; cosines = []; raw_cos = []; spec_corrs = []\n", " attn_res = []; conv_res = []\n", "\n", " for idx, name in enumerate(common_names):\n", " if (idx+1) % 20 == 0:\n", " print(f\" [{idx+1}/{len(common_names)}] \", end=\"\\r\")\n", " A = all_weights[key_a][name]\n", " B = all_weights[key_b][name]\n", " res, cos = procrustes_gpu(A, B)\n", " residuals.append(res); cosines.append(cos)\n", " raw_cos.append(raw_cosine_gpu(A, B))\n", " spec_corrs.append(spectral_corr_gpu(A, B))\n", " if 'attn' in name: attn_res.append(res)\n", " else: conv_res.append(res)\n", "\n", " elapsed = time.time() - t0\n", " gain = np.mean(cosines) - np.mean(raw_cos)\n", " print(f\" {len(residuals)} matrices in {elapsed:.0f}s \")\n", " print(f\" Procrustes residual: {np.mean(residuals):.4f} ±{np.std(residuals):.4f}\")\n", " print(f\" Procrustes cosine: {np.mean(cosines):.4f} ±{np.std(cosines):.4f}\")\n", " print(f\" Raw cosine (no rot): {np.mean(raw_cos):.4f} ±{np.std(raw_cos):.4f}\")\n", " print(f\" Spectral corr: {np.mean(spec_corrs):.4f} ±{np.std(spec_corrs):.4f}\")\n", " if attn_res: print(f\" Attn residual: {np.mean(attn_res):.4f} (n={len(attn_res)})\")\n", " if conv_res: print(f\" Conv residual: {np.mean(conv_res):.4f} (n={len(conv_res)})\")\n", " print(f\" Rotation gain: {gain:+.4f}\")\n", "\n", " pair_results[(key_a, key_b)] = {\n", " 'res': np.mean(residuals), 'cos': np.mean(cosines),\n", " 'raw': np.mean(raw_cos), 'spec': np.mean(spec_corrs),\n", " 'attn': np.mean(attn_res) if attn_res else None,\n", " 'conv': np.mean(conv_res) if conv_res else None, 'gain': gain,\n", " }\n", "\n", "# ── Distance matrices ──\n", "print(f\"\\n{'='*70}\\nDISTANCE MATRICES\\n{'='*70}\")\n", "labels = [VAES[k][2] for k in model_keys]\n", "n = len(model_keys)\n", "\n", "def gpv(i, j, f):\n", " if i == j: return 0.0 if 'res' in f else 1.0\n", " k = 
(model_keys[i], model_keys[j])\n", " if k in pair_results: return pair_results[k][f]\n", " k = (model_keys[j], model_keys[i])\n", " if k in pair_results: return pair_results[k][f]\n", " return float('nan')\n", "\n", "for f, t in [('res','PROCRUSTES RESIDUAL'), ('cos','PROCRUSTES COSINE'),\n", " ('raw','RAW COSINE'), ('spec','SPECTRAL CORRELATION')]:\n", " print(f\"\\n {t}\")\n", " print(f\" {'':10s}\", end=\"\")\n", " for l in labels: print(f\" {l:>8s}\", end=\"\")\n", " print()\n", " for i, l in enumerate(labels):\n", " print(f\" {l:10s}\", end=\"\")\n", " for j in range(n): print(f\" {gpv(i,j,f):8.4f}\", end=\"\")\n", " print()\n", "\n", "# ── Depth profile ──\n", "print(f\"\\n{'='*70}\\nDEPTH PROFILE: SD 1.5 vs Flux.2\\n{'='*70}\")\n", "if 'sd15' in all_weights and 'flux2' in all_weights:\n", " dr = defaultdict(list)\n", " for name in common_names:\n", " res, cos = procrustes_gpu(all_weights['sd15'][name], all_weights['flux2'][name])\n", " if 'encoder.down_blocks.0' in name: d='enc_d0'\n", " elif 'encoder.down_blocks.1' in name: d='enc_d1'\n", " elif 'encoder.down_blocks.2' in name: d='enc_d2'\n", " elif 'encoder.down_blocks.3' in name: d='enc_d3'\n", " elif 'encoder.mid_block' in name: d='enc_mid'\n", " elif 'decoder.up_blocks.0' in name: d='dec_u0'\n", " elif 'decoder.up_blocks.1' in name: d='dec_u1'\n", " elif 'decoder.up_blocks.2' in name: d='dec_u2'\n", " elif 'decoder.up_blocks.3' in name: d='dec_u3'\n", " elif 'decoder.mid_block' in name: d='dec_mid'\n", " elif 'quant' in name: d='quant'\n", " else: d='other'\n", " dr[d].append({'res':res,'cos':cos})\n", "\n", " print(f\" {'Depth':12s} {'N':>3s} {'Residual':>10s} {'Cosine':>10s}\")\n", " for d in ['enc_d0','enc_d1','enc_d2','enc_d3','enc_mid','dec_mid','dec_u0','dec_u1','dec_u2','dec_u3','quant','other']:\n", " if d in dr:\n", " r = dr[d]\n", " print(f\" {d:12s} {len(r):3d} {np.mean([x['res'] for x in r]):10.4f} \"\n", " f\"{np.mean([x['cos'] for x in r]):10.4f}\")\n", "\n", "# ── Encoder vs decoder 
within each VAE ──\n", "print(f\"\\n{'='*70}\\nENCODER vs DECODER (within each VAE)\\n{'='*70}\")\n", "for key in model_keys:\n", " label = VAES[key][2]; weights = all_weights[key]\n", " pairs_ed = []\n", " for name in weights:\n", " if 'encoder.mid_block' in name:\n", " dn = name.replace('encoder','decoder')\n", " if dn in weights and weights[name].shape == weights[dn].shape:\n", " pairs_ed.append((name, dn))\n", " if pairs_ed:\n", " rl=[]; cl=[]\n", " for en,dn in pairs_ed:\n", " r,c = procrustes_gpu(weights[en],weights[dn])\n", " rl.append(r); cl.append(c)\n", " print(f\" {label:8s}: {len(pairs_ed)} mid-block pairs, res={np.mean(rl):.4f}, cos={np.mean(cl):.4f}\")\n", "\n", "# ── Attention detail ──\n", "print(f\"\\n{'='*70}\\nATTENTION DETAIL\\n{'='*70}\")\n", "attn_names = sorted([n for n in common_names if 'attn' in n])\n", "if attn_names:\n", " for name in attn_names:\n", " short = '.'.join(name.split('.')[-3:])\n", " print(f\" {short:30s}\", end=\"\")\n", " for ka, kb in all_pairs:\n", " r,c = procrustes_gpu(all_weights[ka][name], all_weights[kb][name])\n", " la,lb = VAES[ka][2][:2], VAES[kb][2][:2]\n", " print(f\" {la}{lb}={c:.3f}\", end=\"\")\n", " print()\n", "\n", "# ── Summary ──\n", "print(f\"\\n{'='*70}\\nSUMMARY\\n{'='*70}\")\n", "print(f\"\\n{'Pair':20s} {'Residual':>10s} {'ProcrCos':>10s} {'RawCos':>10s} {'Spectral':>10s} {'Gain':>8s}\")\n", "for (ka,kb),r in sorted(pair_results.items()):\n", " la,lb = VAES[ka][2],VAES[kb][2]\n", " print(f\" {la+' vs '+lb:18s} {r['res']:10.4f} {r['cos']:10.4f} \"\n", " f\"{r['raw']:10.4f} {r['spec']:10.4f} {r['gain']:+8.4f}\")\n", "\n", "best = min(pair_results.items(), key=lambda x: x[1]['res'])\n", "worst = max(pair_results.items(), key=lambda x: x[1]['res'])\n", "print(f\"\\nClosest: {VAES[best[0][0]][2]} vs {VAES[best[0][1]][2]} (res={best[1]['res']:.4f})\")\n", "print(f\"Furthest: {VAES[worst[0][0]][2]} vs {VAES[worst[0][1]][2]} (res={worst[1]['res']:.4f})\")\n", "print(f\"\\nVRAM: 
{torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "print(\"Done.\")\n" ] }, { "cell_type": "markdown", "id": "bde2684e", "metadata": {}, "source": [ "## 17. Procrustes Analysis — BERT vs DINOv2\n", "*Cross-modal alignment: language (MLM) vs vision (self-supervised)*\n", "\n", "61% aligned overall, 70% in deep layers. O > V > K ≈ Q." ] }, { "cell_type": "code", "execution_count": null, "id": "38f8e8eb", "metadata": {}, "outputs": [], "source": [ "# Procrustes BERT vs DINOv2\n", "\n", "# ============================================================================\n", "# PROCRUSTES ANALYSIS — BERT-large vs DINOv2-large\n", "# Same architecture: 1024-d, 24 layers, 16 heads\n", "# Different modality: language (MLM) vs vision (self-supervised)\n", "# The cross-modal alignment test\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "import time\n", "import gc\n", "from collections import defaultdict\n", "from transformers import BertModel, Dinov2Model\n", "\n", "device = torch.device(\"cuda\")\n", "print(f\"GPU: {torch.cuda.get_device_name()}\")\n", "\n", "print(\"=\" * 70)\n", "print(\"PROCRUSTES: BERT-large vs DINOv2-large\")\n", "print(\"Both 1024-d, 24 layers, 16 heads\")\n", "print(\"Language (MLM) vs Vision (self-supervised)\")\n", "print(\"=\" * 70)\n", "\n", "# ── Load both models, extract 2D weights ──\n", "def extract_bert(model):\n", " weights = {}\n", " for name, param in model.named_parameters():\n", " if param.dim() != 2: continue\n", " # Normalize naming to generic format\n", " generic = name\n", " # Map BERT naming to generic\n", " generic = generic.replace('bert.encoder.layer.', 'layer.')\n", " generic = generic.replace('.attention.self.query.', '.attn.q.')\n", " generic = generic.replace('.attention.self.key.', '.attn.k.')\n", " generic = generic.replace('.attention.self.value.', '.attn.v.')\n", " generic = generic.replace('.attention.output.dense.', '.attn.o.')\n", " 
generic = generic.replace('.intermediate.dense.', '.mlp.up.')\n", " generic = generic.replace('.output.dense.', '.mlp.down.')\n", " weights[generic] = param.detach().float().to(device)\n", " return weights\n", "\n", "def extract_dino(model):\n", " weights = {}\n", " for name, param in model.named_parameters():\n", " if param.dim() != 2: continue\n", " generic = name\n", " # Dinov2Model names also start at 'encoder.layer.'\n", " generic = generic.replace('encoder.layer.', 'layer.')\n", " generic = generic.replace('.attention.attention.query.', '.attn.q.')\n", " generic = generic.replace('.attention.attention.key.', '.attn.k.')\n", " generic = generic.replace('.attention.attention.value.', '.attn.v.')\n", " generic = generic.replace('.attention.output.dense.', '.attn.o.')\n", " generic = generic.replace('.intermediate.dense.', '.mlp.up.')\n", " generic = generic.replace('.output.dense.', '.mlp.down.')\n", " weights[generic] = param.detach().float().to(device)\n", " return weights\n", "\n", "print(\"\\nLoading BERT-large...\")\n", "bert = BertModel.from_pretrained(\"google-bert/bert-large-uncased\", torch_dtype=torch.float32)\n", "bert.eval()\n", "bert_w = extract_bert(bert)\n", "del bert; gc.collect(); torch.cuda.empty_cache()\n", "print(f\" {len(bert_w)} 2D matrices\")\n", "\n", "print(\"Loading DINOv2-large...\")\n", "dino = Dinov2Model.from_pretrained(\"facebook/dinov2-large\", torch_dtype=torch.float32)\n", "dino.eval()\n", "dino_w = extract_dino(dino)\n", "del dino; gc.collect(); torch.cuda.empty_cache()\n", "print(f\" {len(dino_w)} 2D matrices\")\n", "\n", "print(f\"VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "\n", "# Find matching names with same shapes\n", "common = {}\n", "for name in bert_w:\n", " if name in dino_w and bert_w[name].shape == dino_w[name].shape:\n", " common[name] = bert_w[name].shape\n", "\n", "print(f\"\\nCommon matrices (same name + shape): {len(common)}\")\n", "# Group by type\n", "type_counts = defaultdict(int)\n", "type_shapes = defaultdict(set)\n", "for name, shape in 
common.items():\n", " if '.attn.q.' in name: t = 'attn_q'\n", " elif '.attn.k.' in name: t = 'attn_k'\n", " elif '.attn.v.' in name: t = 'attn_v'\n", " elif '.attn.o.' in name: t = 'attn_o'\n", " elif '.mlp.up.' in name: t = 'mlp_up'\n", " elif '.mlp.down.' in name: t = 'mlp_down'\n", " else: t = 'other'\n", " type_counts[t] += 1\n", " type_shapes[t].add(str(shape))\n", "\n", "for t in sorted(type_counts):\n", " print(f\" {t:15s}: {type_counts[t]:3d} shapes={type_shapes[t]}\")\n", "\n", "# If no common names, try by layer structure matching\n", "if len(common) == 0:\n", " print(\"\\n No name-matched matrices. Trying structural matching...\")\n", " # Match by layer number + weight type + shape\n", " bert_by_struct = {}\n", " for name, W in bert_w.items():\n", " # Extract layer num and type\n", " for l in range(24):\n", " if f'layer.{l}.' in name:\n", " for wt in ['attn.q', 'attn.k', 'attn.v', 'attn.o', 'mlp.up', 'mlp.down']:\n", " if wt in name:\n", " bert_by_struct[(l, wt)] = (name, W)\n", " break\n", "\n", " dino_by_struct = {}\n", " for name, W in dino_w.items():\n", " for l in range(24):\n", " if f'layer.{l}.' 
in name:\n", " for wt in ['attn.q', 'attn.k', 'attn.v', 'attn.o', 'mlp.up', 'mlp.down']:\n", " if wt in name:\n", " dino_by_struct[(l, wt)] = (name, W)\n", " break\n", "\n", " # Find matching structures\n", " common_struct = {}\n", " for key in bert_by_struct:\n", " if key in dino_by_struct:\n", " bn, bw = bert_by_struct[key]\n", " dn, dw = dino_by_struct[key]\n", " if bw.shape == dw.shape:\n", " common_struct[key] = (bn, dn, bw.shape)\n", "\n", " print(f\" Structurally matched: {len(common_struct)}\")\n", " for t in ['attn.q', 'attn.k', 'attn.v', 'attn.o', 'mlp.up', 'mlp.down']:\n", " n = sum(1 for k in common_struct if k[1] == t)\n", " if n > 0:\n", " shapes = set(str(common_struct[k][2]) for k in common_struct if k[1] == t)\n", " print(f\" {t:15s}: {n:3d} shapes={shapes}\")\n", "\n", "# ── GPU Procrustes functions ──\n", "def procrustes_gpu(A, B):\n", " A_n = A / (torch.norm(A) + 1e-10)\n", " B_n = B / (torch.norm(B) + 1e-10)\n", " U, S, Vt = torch.linalg.svd(B_n.T @ A_n)\n", " R = U @ Vt\n", " B_aligned = B_n @ R\n", " residual = torch.norm(A_n - B_aligned).item()\n", " cosine = torch.dot(A_n.flatten(), B_aligned.flatten()).item()\n", " return residual, cosine\n", "\n", "def spectral_corr_gpu(A, B):\n", " Sa = torch.linalg.svdvals(A)\n", " Sb = torch.linalg.svdvals(B)\n", " n = min(len(Sa), len(Sb))\n", " if n < 3: return 0.0\n", " Sa = Sa[:n] / (Sa[0] + 1e-10); Sb = Sb[:n] / (Sb[0] + 1e-10)\n", " Sa_c = Sa - Sa.mean(); Sb_c = Sb - Sb.mean()\n", " return (torch.dot(Sa_c, Sb_c) / (torch.norm(Sa_c) * torch.norm(Sb_c) + 1e-10)).item()\n", "\n", "def raw_cosine_gpu(A, B):\n", " return (torch.dot(A.flatten(), B.flatten()) / (torch.norm(A) * torch.norm(B) + 1e-10)).item()\n", "\n", "# ── Run Procrustes ──\n", "print(f\"\\n{'='*70}\")\n", "print(\"PAIRWISE PROCRUSTES — BERT vs DINOv2\")\n", "print(f\"{'='*70}\")\n", "\n", "# Use whichever matching method worked\n", "if len(common) > 0:\n", " # Name-matched\n", " items = [(name, bert_w[name], dino_w[name]) for name in 
common]\n", "elif len(common_struct) > 0:\n", " items = [(f\"L{k[0]}.{k[1]}\", bert_by_struct[k][1], dino_by_struct[k][1])\n", " for k in sorted(common_struct)]\n", "else:\n", " print(\"No matching matrices found!\")\n", " items = []\n", "\n", "if items:\n", " t0 = time.time()\n", " all_res = []; all_cos = []; all_raw = []; all_spec = []\n", " type_results = defaultdict(lambda: {'res':[], 'cos':[], 'raw':[], 'spec':[]})\n", "\n", " for idx, (name, A, B) in enumerate(items):\n", " if (idx+1) % 10 == 0:\n", " print(f\" [{idx+1}/{len(items)}] \", end=\"\\r\")\n", "\n", " res, cos = procrustes_gpu(A, B)\n", " raw = raw_cosine_gpu(A, B)\n", " spec = spectral_corr_gpu(A, B)\n", "\n", " all_res.append(res); all_cos.append(cos)\n", " all_raw.append(raw); all_spec.append(spec)\n", "\n", " # Categorize\n", " if 'attn.q' in name or '.q.' in name: t = 'attn_q'\n", " elif 'attn.k' in name or '.k.' in name: t = 'attn_k'\n", " elif 'attn.v' in name or '.v.' in name: t = 'attn_v'\n", " elif 'attn.o' in name or '.o.' 
in name: t = 'attn_o'\n", " elif 'mlp.up' in name: t = 'mlp_up'\n", " elif 'mlp.down' in name: t = 'mlp_down'\n", " else: t = 'other'\n", " type_results[t]['res'].append(res)\n", " type_results[t]['cos'].append(cos)\n", " type_results[t]['raw'].append(raw)\n", " type_results[t]['spec'].append(spec)\n", "\n", " elapsed = time.time() - t0\n", " gain = np.mean(all_cos) - np.mean(all_raw)\n", "\n", " print(f\" {len(items)} matrices in {elapsed:.0f}s \")\n", " print(f\"\\n OVERALL:\")\n", " print(f\" Procrustes residual: {np.mean(all_res):.4f} ±{np.std(all_res):.4f}\")\n", " print(f\" Procrustes cosine: {np.mean(all_cos):.4f} ±{np.std(all_cos):.4f}\")\n", " print(f\" Raw cosine (no rot): {np.mean(all_raw):.4f} ±{np.std(all_raw):.4f}\")\n", " print(f\" Spectral corr: {np.mean(all_spec):.4f} ±{np.std(all_spec):.4f}\")\n", " print(f\" Rotation gain: {gain:+.4f}\")\n", "\n", " print(f\"\\n BY WEIGHT TYPE:\")\n", " print(f\" {'Type':15s} {'N':>3s} {'Residual':>10s} {'ProcrCos':>10s} {'RawCos':>10s} {'Spectral':>10s}\")\n", " for t in ['attn_q', 'attn_k', 'attn_v', 'attn_o', 'mlp_up', 'mlp_down', 'other']:\n", " r = type_results[t]\n", " if r['res']:\n", " print(f\" {t:15s} {len(r['res']):3d} {np.mean(r['res']):10.4f} \"\n", " f\"{np.mean(r['cos']):10.4f} {np.mean(r['raw']):10.4f} \"\n", " f\"{np.mean(r['spec']):10.4f}\")\n", "\n", " # ── Per-layer depth profile ──\n", " print(f\"\\n{'='*70}\")\n", " print(\"DEPTH PROFILE — BERT vs DINOv2\")\n", " print(f\"{'='*70}\")\n", "\n", " layer_results = defaultdict(lambda: {'res':[], 'cos':[], 'raw':[], 'spec':[]})\n", " for name, A, B in items:\n", " for l in range(24):\n", " if f'L{l}.' in name or f'layer.{l}.' in name or f'.{l}.' 
in name:\n", " res, cos = procrustes_gpu(A, B)\n", " raw = raw_cosine_gpu(A, B)\n", " spec = spectral_corr_gpu(A, B)\n", " layer_results[l]['res'].append(res)\n", " layer_results[l]['cos'].append(cos)\n", " layer_results[l]['raw'].append(raw)\n", " layer_results[l]['spec'].append(spec)\n", " break\n", "\n", " print(f\" {'Layer':>6s} {'N':>3s} {'Residual':>10s} {'ProcrCos':>10s} {'RawCos':>10s} {'Spectral':>10s}\")\n", " for l in range(24):\n", " r = layer_results[l]\n", " if r['res']:\n", " print(f\" L{l:2d} {len(r['res']):3d} {np.mean(r['res']):10.4f} \"\n", " f\"{np.mean(r['cos']):10.4f} {np.mean(r['raw']):10.4f} \"\n", " f\"{np.mean(r['spec']):10.4f}\")\n", "\n", " if layer_results:\n", " early = [np.mean(layer_results[l]['cos']) for l in range(6) if layer_results[l]['cos']]\n", " late = [np.mean(layer_results[l]['cos']) for l in range(18,24) if layer_results[l]['cos']]\n", " if early and late:\n", " print(f\"\\n Early layers (0-5): mean Procrustes cos = {np.mean(early):.4f}\")\n", " print(f\" Late layers (18-23): mean Procrustes cos = {np.mean(late):.4f}\")\n", " if np.mean(early) > np.mean(late):\n", " print(f\" → Early layers more aligned (shared low-level structure)\")\n", " else:\n", " print(f\" → Late layers more aligned (shared high-level structure)\")\n", "\n", " # ── QK Procrustes per layer ──\n", " print(f\"\\n{'='*70}\")\n", " print(\"Q vs K ALIGNMENT — per layer\")\n", " print(f\"{'='*70}\")\n", "\n", " print(f\" {'Layer':>6s} {'Q cos':>8s} {'K cos':>8s} {'V cos':>8s} {'O cos':>8s}\")\n", " for l in range(24):\n", " row = f\" L{l:2d} \"\n", " for wt in ['attn.q', 'attn.k', 'attn.v', 'attn.o']:\n", " found = False\n", " for name, A, B in items:\n", " if f'layer.{l}.' in name and f'.{wt}.' 
in name:\n", " _, cos = procrustes_gpu(A, B)\n", " row += f\" {cos:8.4f}\"\n", " found = True\n", " break\n", " if not found:\n", " row += f\" -\"\n", " print(row)\n", "\n", "# ── Summary ──\n", "print(f\"\\n{'='*70}\")\n", "print(\"SUMMARY — BERT-large vs DINOv2-large\")\n", "print(f\"{'='*70}\")\n", "print(f\"Architecture: identical (1024-d, 24L, 16H)\")\n", "print(f\"Training: BERT=Masked LM (text), DINOv2=Self-supervised (vision)\")\n", "print(f\"Matrices compared: {len(items)}\")\n", "if items:\n", " print(f\"\\nProcrustes cosine: {np.mean(all_cos):.4f} (after optimal rotation)\")\n", " print(f\"Raw cosine: {np.mean(all_raw):.4f} (no rotation)\")\n", " print(f\"Rotation gain: {gain:+.4f}\")\n", " print(f\"Spectral corr: {np.mean(all_spec):.4f}\")\n", "\n", "print(f\"\\nReference (same-architecture, same-task — VAEs):\")\n", "print(f\" SD1.5 vs Flux.2: Procrustes cos=0.757, raw=0.000, spectral=0.979\")\n", "print(f\" SDXL vs Flux.1: Procrustes cos=0.675, raw=0.024, spectral=0.939\")\n", "print(f\"\\nVRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB\")\n", "print(\"Done.\")\n" ] }, { "cell_type": "markdown", "id": "cdb80b80", "metadata": {}, "source": [ "## 18. 
Qwen3.5 Embedding Probes (Prior Session)\n", "*Sections II, III*" ] }, { "cell_type": "code", "execution_count": null, "id": "31756fc9", "metadata": {}, "outputs": [], "source": [ "# Qwen3.5-0.8B embeddings\n", "\n", "# ============================================================================\n", "# PROBE QWEN3.5 EMBEDDING GEOMETRY\n", "# Run in Colab with GPU runtime\n", "# ============================================================================\n", "\n", "# # Probing Qwen3.5 Embedding Geometry\n", "# Analyzing the 248,320 × hidden_dim embedding matrix for geometric structure.\n", "\n", "# Qwen3.5 requires transformers from main — not in any stable release yet.\n", "# !pip install -q git+https://github.com/huggingface/transformers.git@main\n", "# !pip install -q torch numpy scipy matplotlib scikit-learn\n", "\n", "import torch\n", "import numpy as np\n", "from scipy.spatial.distance import pdist, squareform\n", "from scipy.stats import describe\n", "import matplotlib.pyplot as plt\n", "from collections import Counter\n", "import math\n", "import gc, time\n", "\n", "# Smallest Qwen3.5 dense model — same 248K vocab, same DeltaNet hybrid arch\n", "# Available sizes: 0.8B, 2B, 4B, 9B (dense) | 35B-A3B, 122B-A10B, 397B-A17B (MoE)\n", "MODEL_ID = \"Qwen/Qwen3.5-0.8B\"\n", "\n", "print(f\"Loading {MODEL_ID}...\")\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_ID,\n", " dtype=\"auto\",\n", " device_map=\"cpu\", # embeddings are cheap, keep on CPU for analysis\n", ")\n", "model.eval()\n", "\n", "# Extract the embedding matrix (convert to fp32 for analysis regardless of load dtype)\n", "embed_weight = model.model.embed_tokens.weight.detach().float().clone()\n", "lm_head_weight = model.lm_head.weight.detach().float().clone()\n", "\n", "vocab_size, hidden_dim = embed_weight.shape\n", "print(f\"\\nEmbed matrix: {vocab_size} × 
{hidden_dim}\")\n", "print(f\"Embed params: {vocab_size * hidden_dim:,} ({vocab_size * hidden_dim * 4 / 1e6:.1f}MB fp32)\")\n", "print(f\"LM head params: {lm_head_weight.shape[0] * lm_head_weight.shape[1]:,}\")\n", "print(f\"Tied weights: {torch.allclose(embed_weight, lm_head_weight)}\")\n", "\n", "# Free the model, keep only the matrices\n", "del model\n", "gc.collect()\n", "torch.cuda.empty_cache() if torch.cuda.is_available() else None\n", "\n", "print(\"=\" * 70)\n", "print(\"GLOBAL EMBEDDING STATISTICS\")\n", "print(\"=\" * 70)\n", "\n", "norms = embed_weight.norm(dim=1)\n", "print(f\"\\nVector norms:\")\n", "print(f\" Mean: {norms.mean():.6f}\")\n", "print(f\" Std: {norms.std():.6f}\")\n", "print(f\" Min: {norms.min():.6f} (token {norms.argmin().item()})\")\n", "print(f\" Max: {norms.max():.6f} (token {norms.argmax().item()})\")\n", "print(f\" Median: {norms.median():.6f}\")\n", "\n", "# Norm distribution\n", "print(f\"\\nNorm percentiles:\")\n", "for p in [1, 5, 25, 50, 75, 95, 99]:\n", " val = torch.quantile(norms, p / 100.0)\n", " print(f\" {p:3d}%: {val:.6f}\")\n", "\n", "# Check for dead/zero embeddings\n", "zero_mask = norms < 1e-6\n", "print(f\"\\nZero/near-zero embeddings: {zero_mask.sum().item()} / {vocab_size}\")\n", "\n", "# Per-dimension statistics\n", "dim_means = embed_weight.mean(dim=0)\n", "dim_stds = embed_weight.std(dim=0)\n", "print(f\"\\nPer-dimension mean of means: {dim_means.mean():.8f}\")\n", "print(f\"Per-dimension mean of stds: {dim_stds.mean():.8f}\")\n", "print(f\"Per-dimension std of means: {dim_means.std():.8f}\")\n", "print(f\"Per-dimension std of stds: {dim_stds.std():.8f}\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"COSINE SIMILARITY DISTRIBUTION (sampled)\")\n", "print(\"=\" * 70)\n", "\n", "N_SAMPLE = 5000\n", "rng = np.random.default_rng(42)\n", "sample_idx = rng.choice(vocab_size, size=N_SAMPLE, replace=False)\n", "sample_embeds = embed_weight[sample_idx]\n", "\n", "# Normalize for cosine\n", "sample_normed = 
sample_embeds / sample_embeds.norm(dim=1, keepdim=True).clamp(min=1e-8)\n", "\n", "# Pairwise cosine similarities (upper triangle)\n", "cos_sim = sample_normed @ sample_normed.T\n", "# Extract upper triangle\n", "triu_idx = torch.triu_indices(N_SAMPLE, N_SAMPLE, offset=1)\n", "cos_values = cos_sim[triu_idx[0], triu_idx[1]].numpy()\n", "\n", "print(f\"Pairwise cosine similarities ({len(cos_values):,} pairs):\")\n", "print(f\" Mean: {cos_values.mean():.6f}\")\n", "print(f\" Std: {cos_values.std():.6f}\")\n", "print(f\" Min: {cos_values.min():.6f}\")\n", "print(f\" Max: {cos_values.max():.6f}\")\n", "print(f\" Median: {np.median(cos_values):.6f}\")\n", "\n", "for p in [1, 5, 25, 50, 75, 95, 99]:\n", " val = np.percentile(cos_values, p)\n", " print(f\" {p:3d}%: {val:.6f}\")\n", "\n", "# Check for the 0.29514 constant\n", "target = 0.29514\n", "closest_to_target = np.abs(cos_values - target).min()\n", "print(f\"\\nClosest cosine sim to 0.29514: {cos_values[np.abs(cos_values - target).argmin()]:.6f}\")\n", "print(f\" Distance: {closest_to_target:.8f}\")\n", "\n", "# What fraction of pairs fall near common thresholds?\n", "for threshold in [0.0, 0.1, 0.2, 0.29514, 0.3, 0.5, 0.7, 0.9]:\n", " frac = (np.abs(cos_values - threshold) < 0.01).mean()\n", " print(f\" Pairs within ±0.01 of {threshold:.5f}: {frac*100:.3f}%\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"EUCLIDEAN DISTANCE DISTRIBUTION (sampled)\")\n", "print(\"=\" * 70)\n", "\n", "N_SMALL = 2000\n", "small_idx = rng.choice(vocab_size, size=N_SMALL, replace=False)\n", "small_embeds = embed_weight[small_idx].numpy()\n", "\n", "dists = pdist(small_embeds, metric='euclidean')\n", "print(f\"Pairwise Euclidean distances ({len(dists):,} pairs):\")\n", "print(f\" Mean: {dists.mean():.6f}\")\n", "print(f\" Std: {dists.std():.6f}\")\n", "print(f\" Min: {dists.min():.6f}\")\n", "print(f\" Max: {dists.max():.6f}\")\n", "print(f\" Median: {np.median(dists):.6f}\")\n", "\n", "# Normalized distances (by 
sqrt(dim))\n", "norm_dists = dists / np.sqrt(hidden_dim)\n", "print(f\"\\nNormalized by sqrt({hidden_dim}):\")\n", "print(f\" Mean: {norm_dists.mean():.6f}\")\n", "print(f\" Std: {norm_dists.std():.6f}\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"EIGENSPECTRUM & INTRINSIC DIMENSIONALITY\")\n", "print(\"=\" * 70)\n", "\n", "# Use a larger sample for PCA\n", "N_PCA = 10000\n", "pca_idx = rng.choice(vocab_size, size=N_PCA, replace=False)\n", "pca_embeds = embed_weight[pca_idx].numpy()\n", "\n", "# Center the data\n", "pca_centered = pca_embeds - pca_embeds.mean(axis=0)\n", "\n", "# Covariance matrix (hidden_dim × hidden_dim)\n", "t0 = time.time()\n", "cov = (pca_centered.T @ pca_centered) / (N_PCA - 1)\n", "eigenvalues = np.linalg.eigvalsh(cov)[::-1] # sorted descending\n", "print(f\"Eigendecomposition took {time.time()-t0:.1f}s\")\n", "\n", "# Explained variance\n", "total_var = eigenvalues.sum()\n", "cumvar = np.cumsum(eigenvalues) / total_var\n", "\n", "print(f\"\\nTotal variance: {total_var:.4f}\")\n", "print(f\"Top eigenvalues:\")\n", "for i in range(min(20, len(eigenvalues))):\n", " print(f\" λ_{i:3d}: {eigenvalues[i]:10.4f} ({eigenvalues[i]/total_var*100:5.2f}% | cumulative: {cumvar[i]*100:6.2f}%)\")\n", "\n", "# Intrinsic dimensionality estimates\n", "for threshold in [0.80, 0.90, 0.95, 0.99]:\n", " n_dims = np.searchsorted(cumvar, threshold) + 1\n", " print(f\" Dims for {threshold*100:.0f}% variance: {n_dims}\")\n", "\n", "# Participation ratio (effective dimensionality)\n", "participation_ratio = (eigenvalues.sum() ** 2) / (eigenvalues ** 2).sum()\n", "print(f\"\\nParticipation ratio (effective dim): {participation_ratio:.1f}\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"SIMPLEX GEOMETRY PROBING (Cayley-Menger)\")\n", "print(\"=\" * 70)\n", "\n", "def cayley_menger_volume_sq(points):\n", " \"\"\"\n", " Compute squared volume of simplex from Cayley-Menger determinant.\n", " points: (n_points, dim) — for a k-simplex, n_points = k+1\n", " 
Returns squared volume (can be negative if degenerate).\n", " \"\"\"\n", " n = len(points)\n", " # Build Cayley-Menger matrix\n", " D = np.zeros((n + 1, n + 1))\n", " D[0, 1:] = 1\n", " D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq\n", " D[j + 1, i + 1] = d_sq\n", "\n", " k = n - 1 # simplex dimension\n", " sign = (-1) ** (k + 1)\n", " factorial_sq = math.factorial(k) ** 2\n", " denom = (2 ** k) * factorial_sq\n", "\n", " det = np.linalg.det(D)\n", " vol_sq = sign * det / denom\n", " return vol_sq\n", "\n", "\n", "# Sample random 5-point simplices (4-simplices = pentachora)\n", "N_SIMPLICES = 1000\n", "volumes = []\n", "all_embed_np = embed_weight.numpy()\n", "\n", "for _ in range(N_SIMPLICES):\n", " idx = rng.choice(vocab_size, size=5, replace=False)\n", " pts = all_embed_np[idx]\n", " vol_sq = cayley_menger_volume_sq(pts)\n", " if vol_sq > 0:\n", " volumes.append(np.sqrt(vol_sq))\n", "\n", "volumes = np.array(volumes)\n", "print(f\"Valid pentachora: {len(volumes)} / {N_SIMPLICES}\")\n", "print(f\" Mean volume: {volumes.mean():.6f}\")\n", "print(f\" Std volume: {volumes.std():.6f}\")\n", "print(f\" Min volume: {volumes.min():.6f}\")\n", "print(f\" Max volume: {volumes.max():.6f}\")\n", "print(f\" Median volume: {np.median(volumes):.6f}\")\n", "\n", "# Scale-free dispersion (the raw mean has arbitrary units; only the spread is comparable)\n", "print(f\"\\n CV (std/mean): {volumes.std() / volumes.mean():.6f}\")\n", "\n", "# Compare: random Gaussian points in same dimension\n", "rand_vols = []\n", "for _ in range(N_SIMPLICES):\n", " pts = rng.standard_normal((5, hidden_dim)) * dim_stds.numpy()\n", " vol_sq = cayley_menger_volume_sq(pts)\n", " if vol_sq > 0:\n", " rand_vols.append(np.sqrt(vol_sq))\n", "rand_vols = np.array(rand_vols)\n", "\n", "print(f\"\\nRandom Gaussian comparison:\")\n", "print(f\" Mean volume: 
{rand_vols.mean():.6f}\")\n", "print(f\" Ratio (embed/random): {volumes.mean() / rand_vols.mean():.6f}\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"TOKEN CATEGORY STRUCTURE\")\n", "print(\"=\" * 70)\n", "\n", "# Decode a range of tokens to see what we're working with\n", "# Check specific token ranges for geometric clustering\n", "categories = {\n", " \"ASCII digits\": list(range(48, 58)), # 0-9 as individual chars\n", " \"ASCII upper\": list(range(65, 91)), # A-Z\n", " \"ASCII lower\": list(range(97, 123)), # a-z\n", " \"Special tokens (tail)\": list(range(vocab_size - 100, vocab_size)),\n", "}\n", "\n", "# Find actual token IDs for semantic categories\n", "digit_tokens = [tokenizer.encode(str(d), add_special_tokens=False)[0] for d in range(10)]\n", "print(f\"Digit token IDs: {digit_tokens}\")\n", "\n", "# Measure intra-category vs inter-category distances\n", "for cat_name, token_ids in categories.items():\n", " valid_ids = [t for t in token_ids if t < vocab_size]\n", " if len(valid_ids) < 2:\n", " continue\n", " cat_embeds = embed_weight[valid_ids]\n", " cat_normed = cat_embeds / cat_embeds.norm(dim=1, keepdim=True).clamp(min=1e-8)\n", " intra_cos = (cat_normed @ cat_normed.T)\n", " triu = torch.triu_indices(len(valid_ids), len(valid_ids), offset=1)\n", " intra_vals = intra_cos[triu[0], triu[1]]\n", " print(f\"\\n{cat_name} ({len(valid_ids)} tokens):\")\n", " print(f\" Intra-category cosine sim: mean={intra_vals.mean():.4f}, std={intra_vals.std():.4f}\")\n", " print(f\" vs global mean: {cos_values.mean():.4f}\")\n", "\n", "# Digit tokens specifically\n", "if len(digit_tokens) >= 2:\n", " digit_embeds = embed_weight[digit_tokens]\n", " digit_normed = digit_embeds / digit_embeds.norm(dim=1, keepdim=True).clamp(min=1e-8)\n", " digit_cos = (digit_normed @ digit_normed.T).numpy()\n", " print(f\"\\nDigit pairwise cosine similarities:\")\n", " for i in range(10):\n", " for j in range(i + 1, 10):\n", " print(f\" '{i}' vs '{j}': {digit_cos[i, j]:.4f}\")\n", 
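"\n", "# --- Hedged add-on (not in the original session): is digit ORDER encoded?\n", "# Spearman rank correlation between numeric gap |i-j| and pairwise cosine over\n", "# the 45 digit pairs; a clearly negative rho would mean numerically closer\n", "# digits also sit closer on the embedding manifold.\n", "from scipy.stats import spearmanr\n", "if len(digit_tokens) >= 2:\n", " gaps = [abs(i - j) for i in range(10) for j in range(i + 1, 10)]\n", " sims = [digit_cos[i, j] for i in range(10) for j in range(i + 1, 10)]\n", " rho, pval = spearmanr(gaps, sims)\n", " print(f\"\\nSpearman(|i-j|, cosine) over digit pairs: rho={rho:.4f}, p={pval:.3g}\")\n",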
"\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"EMBED vs LM_HEAD COMPARISON\")\n", "print(\"=\" * 70)\n", "\n", "# Are they related?\n", "cos_embed_head = torch.nn.functional.cosine_similarity(\n", " embed_weight, lm_head_weight, dim=1\n", ")\n", "print(f\"Per-token cosine(embed, lm_head):\")\n", "print(f\" Mean: {cos_embed_head.mean():.6f}\")\n", "print(f\" Std: {cos_embed_head.std():.6f}\")\n", "print(f\" Min: {cos_embed_head.min():.6f}\")\n", "print(f\" Max: {cos_embed_head.max():.6f}\")\n", "\n", "# Frobenius norm of difference\n", "diff_norm = (embed_weight - lm_head_weight).norm()\n", "embed_norm = embed_weight.norm()\n", "head_norm = lm_head_weight.norm()\n", "print(f\"\\nFrobenius norms:\")\n", "print(f\" Embed: {embed_norm:.4f}\")\n", "print(f\" LM head: {head_norm:.4f}\")\n", "print(f\" Difference: {diff_norm:.4f}\")\n", "print(f\" Relative: {diff_norm / embed_norm:.4f}\")\n", "\n", "fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", "fig.suptitle(f\"Qwen3.5 Embedding Geometry — {MODEL_ID}\", fontsize=14)\n", "\n", "# 1. Norm distribution\n", "axes[0, 0].hist(norms.numpy(), bins=100, alpha=0.7, color='steelblue')\n", "axes[0, 0].set_title(\"Embedding Norm Distribution\")\n", "axes[0, 0].set_xlabel(\"L2 Norm\")\n", "axes[0, 0].axvline(norms.mean(), color='red', linestyle='--', label=f'mean={norms.mean():.3f}')\n", "axes[0, 0].legend()\n", "\n", "# 2. Cosine similarity distribution\n", "axes[0, 1].hist(cos_values, bins=200, alpha=0.7, color='coral')\n", "axes[0, 1].set_title(\"Pairwise Cosine Similarity\")\n", "axes[0, 1].set_xlabel(\"Cosine Similarity\")\n", "axes[0, 1].axvline(0.29514, color='green', linestyle='--', label='0.29514')\n", "axes[0, 1].axvline(cos_values.mean(), color='red', linestyle='--', label=f'mean={cos_values.mean():.4f}')\n", "axes[0, 1].legend()\n", "\n", "# 3. 
Eigenspectrum\n", "axes[0, 2].semilogy(eigenvalues[:100], 'o-', markersize=2, color='darkgreen')\n", "axes[0, 2].set_title(\"Eigenspectrum (top 100)\")\n", "axes[0, 2].set_xlabel(\"Component\")\n", "axes[0, 2].set_ylabel(\"Eigenvalue (log)\")\n", "\n", "# 4. Cumulative variance\n", "axes[1, 0].plot(cumvar[:200], color='purple')\n", "axes[1, 0].axhline(0.95, color='red', linestyle='--', alpha=0.5, label='95%')\n", "axes[1, 0].axhline(0.99, color='orange', linestyle='--', alpha=0.5, label='99%')\n", "axes[1, 0].set_title(\"Cumulative Variance Explained\")\n", "axes[1, 0].set_xlabel(\"Components\")\n", "axes[1, 0].legend()\n", "\n", "# 5. Pentachoron volume distribution\n", "axes[1, 1].hist(np.log10(volumes + 1e-30), bins=50, alpha=0.7, color='teal', label='Embed')\n", "axes[1, 1].hist(np.log10(rand_vols + 1e-30), bins=50, alpha=0.5, color='gray', label='Random')\n", "axes[1, 1].set_title(\"Pentachoron Volumes (log10)\")\n", "axes[1, 1].set_xlabel(\"log10(volume)\")\n", "axes[1, 1].legend()\n", "\n", "# 6. 
Embed vs LM head per-token cosine\n", "axes[1, 2].hist(cos_embed_head.numpy(), bins=100, alpha=0.7, color='goldenrod')\n", "axes[1, 2].set_title(\"Embed ↔ LM Head Cosine Similarity\")\n", "axes[1, 2].set_xlabel(\"Cosine Similarity\")\n", "axes[1, 2].axvline(cos_embed_head.mean(), color='red', linestyle='--',\n", " label=f'mean={cos_embed_head.mean():.3f}')\n", "axes[1, 2].legend()\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"/content/qwen35_embedding_geometry.png\", dpi=150, bbox_inches='tight')\n", "plt.show()\n", "print(\"\\nSaved: /content/qwen35_embedding_geometry.png\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"SUMMARY\")\n", "print(\"=\" * 70)\n", "print(f\"Model: {MODEL_ID}\")\n", "print(f\"Vocab: {vocab_size:,} tokens\")\n", "print(f\"Hidden dim: {hidden_dim}\")\n", "print(f\"Embed params: {vocab_size * hidden_dim:,}\")\n", "print(f\"Mean norm: {norms.mean():.4f}\")\n", "print(f\"Mean pairwise cosine: {cos_values.mean():.6f}\")\n", "print(f\"Intrinsic dim (participation ratio): {participation_ratio:.1f}\")\n", "print(f\"95% variance dims: {np.searchsorted(cumvar, 0.95) + 1}\")\n", "print(f\"Pentachoron vol ratio (embed/random): {volumes.mean() / rand_vols.mean():.4f}\")\n", "print(f\"Embed-Head alignment: {cos_embed_head.mean():.4f}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "97259d95", "metadata": {}, "outputs": [], "source": [ "# Qwen3.5 cross-scale projection\n", "\n", "# ============================================================================\n", "# CROSS-SCALE PROJECTION: Qwen3.5-4B (2560d) → 0.8B (1024d)\n", "# Run AFTER the cross-check script (reuses embeddings dict in memory)\n", "# ============================================================================\n", "\n", "import torch\n", "import numpy as np\n", "from scipy.linalg import orthogonal_procrustes\n", "import matplotlib.pyplot as plt\n", "import math\n", "\n", "# Both models loaded from cross-check: embeddings[\"0.8B\"], embeddings[\"4B\"]\n", 
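"# Hedged guard (added; assumes the cross-check cell leaves an `embeddings` dict\n", "# in the notebook namespace): fail fast with a clear message instead of a\n", "# NameError deeper in this cell.\n", "msg = \"Run the cross-check cell first: embeddings['0.8B'] and embeddings['4B'] must be in memory\"\n", "assert 'embeddings' in globals() and {'0.8B', '4B'} <= set(embeddings), msg\n", "\n",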
"E_small = embeddings[\"0.8B\"] # [248320, 1024]\n", "E_large = embeddings[\"4B\"] # [248320, 2560]\n", "\n", "# Step 1: PCA the 4B down to 1024 dims (match small model's dimensionality)\n", "# Use a large anchor set for stable PCA\n", "N_PCA = 50000\n", "rng = np.random.default_rng(42)\n", "pca_idx = rng.choice(248320, size=N_PCA, replace=False)\n", "\n", "E_large_sample = E_large[pca_idx].numpy()\n", "mean_large = E_large_sample.mean(axis=0)\n", "E_large_centered = E_large_sample - mean_large\n", "\n", "# SVD for PCA\n", "print(\"Computing PCA projection (2560 → 1024)...\")\n", "U, S, Vt = np.linalg.svd(E_large_centered, full_matrices=False)\n", "# Top 1024 components\n", "V_proj = Vt[:1024].T # [2560, 1024] projection matrix\n", "\n", "# Project ALL 4B embeddings into 1024-d\n", "E_large_all = E_large.numpy()\n", "E_large_projected = (E_large_all - mean_large) @ V_proj # [248320, 1024]\n", "print(f\"Projected shape: {E_large_projected.shape}\")\n", "\n", "# Variance preserved by top 1024 of 2560\n", "total_var = (S ** 2).sum()\n", "kept_var = (S[:1024] ** 2).sum()\n", "print(f\"Variance retained: {kept_var/total_var:.4f} ({kept_var/total_var*100:.1f}%)\")\n", "\n", "# Step 2: Orthogonal Procrustes alignment\n", "# Find rotation R such that E_large_projected @ R ≈ E_small\n", "# Use anchor tokens for fitting, hold out rest for evaluation\n", "N_ANCHOR = 10000\n", "N_TEST = 5000\n", "all_idx = rng.permutation(248320)\n", "anchor_idx = all_idx[:N_ANCHOR]\n", "test_idx = all_idx[N_ANCHOR:N_ANCHOR + N_TEST]\n", "\n", "# Center both sets on anchor means\n", "E_small_np = E_small.numpy()\n", "anchor_small = E_small_np[anchor_idx]\n", "anchor_large = E_large_projected[anchor_idx]\n", "\n", "mean_s = anchor_small.mean(axis=0)\n", "mean_l = anchor_large.mean(axis=0)\n", "\n", "anchor_small_c = anchor_small - mean_s\n", "anchor_large_c = anchor_large - mean_l\n", "\n", "print(\"Computing Procrustes rotation...\")\n", "R, scale = orthogonal_procrustes(anchor_large_c, 
anchor_small_c)\n", "print(f\"Procrustes scale factor: {scale:.6f}\")\n", "\n", "# Apply: project + rotate + recenter\n", "E_large_aligned = (E_large_projected - mean_l) @ R + mean_s\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"PROJECTION QUALITY ON HELD-OUT TOKENS\")\n", "print(\"=\" * 70)\n", "\n", "test_small = E_small_np[test_idx]\n", "test_aligned = E_large_aligned[test_idx]\n", "\n", "# Per-token cosine similarity\n", "def cos_sim_rows(A, B):\n", " dot = (A * B).sum(axis=1)\n", " norm_a = np.linalg.norm(A, axis=1)\n", " norm_b = np.linalg.norm(B, axis=1)\n", " return dot / (norm_a * norm_b + 1e-8)\n", "\n", "token_cos = cos_sim_rows(test_small, test_aligned)\n", "print(f\"\\nPer-token cosine(0.8B, aligned_4B) on {N_TEST} held-out tokens:\")\n", "print(f\" Mean: {token_cos.mean():.6f}\")\n", "print(f\" Std: {token_cos.std():.6f}\")\n", "print(f\" Median: {np.median(token_cos):.6f}\")\n", "print(f\" Min: {token_cos.min():.6f} (token {test_idx[token_cos.argmin()]})\")\n", "print(f\" Max: {token_cos.max():.6f}\")\n", "for p in [1, 5, 10, 25, 50, 75, 90, 95, 99]:\n", " print(f\" {p:>3}%: {np.percentile(token_cos, p):.6f}\")\n", "\n", "# Euclidean distance after alignment\n", "token_dist = np.linalg.norm(test_small - test_aligned, axis=1)\n", "print(f\"\\nPer-token L2 distance:\")\n", "print(f\" Mean: {token_dist.mean():.6f}\")\n", "print(f\" Std: {token_dist.std():.6f}\")\n", "print(f\" Median: {np.median(token_dist):.6f}\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"RELATIONAL STRUCTURE AFTER PROJECTION\")\n", "print(\"=\" * 70)\n", "\n", "N_REL = 1000\n", "rel_idx = test_idx[:N_REL]\n", "rel_small = E_small_np[rel_idx]\n", "rel_aligned = E_large_aligned[rel_idx]\n", "\n", "# Normalize\n", "rel_small_n = rel_small / (np.linalg.norm(rel_small, axis=1, keepdims=True) + 1e-8)\n", "rel_aligned_n = rel_aligned / (np.linalg.norm(rel_aligned, axis=1, keepdims=True) + 1e-8)\n", "\n", "cos_mat_small = rel_small_n @ rel_small_n.T\n", "cos_mat_aligned = 
rel_aligned_n @ rel_aligned_n.T\n", "\n", "tri = np.triu_indices(N_REL, k=1)\n", "flat_s = cos_mat_small[tri[0], tri[1]]\n", "flat_a = cos_mat_aligned[tri[0], tri[1]]\n", "\n", "rel_corr = np.corrcoef(flat_s, flat_a)[0, 1]\n", "rel_diff = np.abs(flat_s - flat_a).mean()\n", "print(f\"Relational correlation (projected 4B vs native 0.8B): {rel_corr:.6f}\")\n", "print(f\"Mean |diff| in pairwise cosine: {rel_diff:.6f}\")\n", "\n", "# Compare to the raw (pre-projection) relational correlation from cross-check\n", "# (was 0.920 on 500 tokens — this should be higher after alignment)\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"DIGIT GEOMETRY AFTER PROJECTION\")\n", "print(\"=\" * 70)\n", "\n", "digit_tokens = [tokenizer.encode(str(d), add_special_tokens=False)[0] for d in range(10)]\n", "\n", "digit_small = E_small_np[digit_tokens]\n", "digit_aligned = E_large_aligned[digit_tokens]\n", "\n", "# Per-digit cosine alignment\n", "for d in range(10):\n", " cs = cos_sim_rows(digit_small[d:d+1], digit_aligned[d:d+1])[0]\n", " print(f\" '{d}': cosine = {cs:.6f}\")\n", "\n", "# Digit pairwise structure comparison\n", "ds_n = digit_small / (np.linalg.norm(digit_small, axis=1, keepdims=True) + 1e-8)\n", "da_n = digit_aligned / (np.linalg.norm(digit_aligned, axis=1, keepdims=True) + 1e-8)\n", "cos_digits_s = ds_n @ ds_n.T\n", "cos_digits_a = da_n @ da_n.T\n", "\n", "tri10 = np.triu_indices(10, k=1)\n", "r_digits = np.corrcoef(cos_digits_s[tri10], cos_digits_a[tri10])[0, 1]\n", "print(f\"\\n Digit pairwise structure correlation: {r_digits:.6f}\")\n", "print(f\" Mean |diff|: {np.abs(cos_digits_s[tri10] - cos_digits_a[tri10]).mean():.6f}\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"LARGEST DIVERGENCES (tokens where 0.8B ≠ projected 4B)\")\n", "print(\"=\" * 70)\n", "\n", "# Compute alignment for ALL tokens\n", "all_cos = cos_sim_rows(E_small_np, E_large_aligned)\n", "worst_idx = np.argsort(all_cos)[:50]\n", "best_idx = np.argsort(all_cos)[-50:][::-1]\n", "\n", 
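"\n", "# --- Hedged add-on (assumption, not in the original session): do low-norm\n", "# (typically rare) tokens project worst? Rank correlation of 0.8B embedding\n", "# norm against the per-token alignment cosine computed above.\n", "from scipy.stats import spearmanr\n", "small_norms = np.linalg.norm(E_small_np, axis=1)\n", "rho_norm, _ = spearmanr(small_norms, all_cos)\n", "print(f\"Spearman(0.8B norm, alignment cos): {rho_norm:.4f}\")\n",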
"print(\"\\n--- 20 WORST aligned tokens ---\")\n", "for i, idx in enumerate(worst_idx[:20]):\n", " tok = tokenizer.decode([idx]).replace('\\n', '\\\\n')\n", " print(f\" {i+1:3d}. token {idx:6d} cos={all_cos[idx]:.4f} '{tok}'\")\n", "\n", "print(\"\\n--- 20 BEST aligned tokens ---\")\n", "for i, idx in enumerate(best_idx[:20]):\n", " tok = tokenizer.decode([idx]).replace('\\n', '\\\\n')\n", " print(f\" {i+1:3d}. token {idx:6d} cos={all_cos[idx]:.4f} '{tok}'\")\n", "\n", "# Distribution of alignment quality\n", "print(\"\\n--- Alignment distribution ---\")\n", "for thresh in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:\n", " frac = (all_cos > thresh).mean()\n", " print(f\" cos > {thresh:.1f}: {frac*100:.2f}%\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"PENTACHORON GEOMETRY IN PROJECTED vs NATIVE SPACE\")\n", "print(\"=\" * 70)\n", "\n", "def cayley_menger_volume_sq(points):\n", " n = len(points)\n", " D = np.zeros((n + 1, n + 1))\n", " D[0, 1:] = 1\n", " D[1:, 0] = 1\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " d_sq = np.sum((points[i] - points[j]) ** 2)\n", " D[i + 1, j + 1] = d_sq\n", " D[j + 1, i + 1] = d_sq\n", " k = n - 1\n", " sign = (-1) ** (k + 1)\n", " factorial_sq = math.factorial(k) ** 2\n", " denom = (2 ** k) * factorial_sq\n", " det = np.linalg.det(D)\n", " vol_sq = sign * det / denom\n", " return vol_sq\n", "\n", "N_SIMP = 1000\n", "vols_native = []\n", "vols_projected = []\n", "\n", "for _ in range(N_SIMP):\n", " idx = rng.choice(248320, size=5, replace=False)\n", "\n", " vol_sq_n = cayley_menger_volume_sq(E_small_np[idx])\n", " vol_sq_p = cayley_menger_volume_sq(E_large_aligned[idx])\n", " # Keep only simplices valid in BOTH spaces, so the per-simplex correlation\n", " # below compares volumes of the same random index sets\n", " if vol_sq_n > 0 and vol_sq_p > 0:\n", " vols_native.append(np.sqrt(vol_sq_n))\n", " vols_projected.append(np.sqrt(vol_sq_p))\n", "\n", "vols_native = np.array(vols_native)\n", "vols_projected = np.array(vols_projected)\n", "\n", "print(f\"Native 0.8B: 
mean={vols_native.mean():.6f} CV={vols_native.std()/vols_native.mean():.4f}\")\n", "print(f\"Projected 4B: mean={vols_projected.mean():.6f} CV={vols_projected.std()/vols_projected.mean():.4f}\")\n", "print(f\"Volume ratio: {vols_projected.mean()/vols_native.mean():.6f}\")\n", "\n", "# Correlation of per-simplex volumes (same random indices)\n", "min_len = min(len(vols_native), len(vols_projected))\n", "vol_corr = np.corrcoef(vols_native[:min_len], vols_projected[:min_len])[0, 1]\n", "print(f\"Per-simplex volume correlation: {vol_corr:.6f}\")\n", "\n", "fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", "fig.suptitle(\"Qwen3.5: 4B→0.8B Procrustes Projection Analysis\", fontsize=14)\n", "\n", "# 1. Per-token alignment histogram\n", "axes[0, 0].hist(all_cos, bins=200, color='steelblue', alpha=0.8)\n", "axes[0, 0].axvline(all_cos.mean(), color='red', linestyle='--', label=f'mean={all_cos.mean():.3f}')\n", "axes[0, 0].set_title(\"Per-token cosine(native, projected)\")\n", "axes[0, 0].legend()\n", "\n", "# 2. Relational scatter: native 0.8B vs projected 4B pairwise cosines\n", "axes[0, 1].scatter(flat_s, flat_a, alpha=0.02, s=1, color='darkgreen')\n", "axes[0, 1].plot([flat_s.min(), flat_s.max()], [flat_s.min(), flat_s.max()], 'r--', alpha=0.5)\n", "axes[0, 1].set_xlabel(\"Native 0.8B pairwise cosine\")\n", "axes[0, 1].set_ylabel(\"Projected 4B pairwise cosine\")\n", "axes[0, 1].set_title(f\"Relational structure (r={rel_corr:.4f})\")\n", "\n", "# 3. Digit heatmap comparison\n", "diff_digits = cos_digits_s - cos_digits_a\n", "im = axes[0, 2].imshow(diff_digits, cmap='RdBu_r', vmin=-0.1, vmax=0.1)\n", "axes[0, 2].set_title(\"Digit cosine diff (native - projected)\")\n", "axes[0, 2].set_xticks(range(10))\n", "axes[0, 2].set_yticks(range(10))\n", "plt.colorbar(im, ax=axes[0, 2])\n", "\n", "# 4. 
Pentachoron volume comparison\n", "axes[1, 0].scatter(vols_native[:min_len], vols_projected[:min_len], alpha=0.3, s=5, color='purple')\n", "axes[1, 0].plot([0, vols_native.max()], [0, vols_native.max()], 'r--', alpha=0.5)\n", "axes[1, 0].set_xlabel(\"Native 0.8B volume\")\n", "axes[1, 0].set_ylabel(\"Projected 4B volume\")\n", "axes[1, 0].set_title(f\"Simplex volumes (r={vol_corr:.4f})\")\n", "\n", "# 5. Alignment by token ID (structure in vocab ordering?)\n", "# Sample every 100th token for visibility\n", "sample_step = 100\n", "x_tok = np.arange(0, len(all_cos), sample_step)\n", "y_cos = all_cos[::sample_step]\n", "axes[1, 1].scatter(x_tok, y_cos, alpha=0.3, s=1, color='darkorange')\n", "axes[1, 1].set_xlabel(\"Token ID\")\n", "axes[1, 1].set_ylabel(\"Cosine alignment\")\n", "axes[1, 1].set_title(\"Alignment vs token ID\")\n", "axes[1, 1].axhline(all_cos.mean(), color='red', linestyle='--', alpha=0.5)\n", "\n", "# 6. Eigenspectrum of the residual: what the native 0.8B embedding retains\n", "# that the projected 4B embedding fails to capture, on held-out tokens\n", "residual = E_small_np[test_idx] - E_large_aligned[test_idx]\n", "res_cov = (residual.T @ residual) / len(test_idx)\n", "res_eig = np.linalg.eigvalsh(res_cov)[::-1]\n", "axes[1, 2].semilogy(range(min(200, len(res_eig))), res_eig[:200], color='crimson')\n", "axes[1, 2].set_title(\"Eigenspectrum of residual (native - projected)\")\n", "axes[1, 2].set_xlabel(\"Component\")\n", "axes[1, 2].set_ylabel(\"Variance\")\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"/content/qwen35_projection_4B_to_08B.png\", dpi=150, bbox_inches='tight')\n", "plt.show()\n", "print(\"\\nSaved: /content/qwen35_projection_4B_to_08B.png\")\n", "\n", "print(\"\\n\" + \"=\" * 70)\n", "print(\"PROJECTION SUMMARY\")\n", "print(\"=\" * 70)\n", "print(f\"PCA variance retained (2560→1024): {kept_var/total_var*100:.1f}%\")\n", "print(f\"Per-token 
alignment mean cosine: {all_cos.mean():.4f}\")\n", "print(f\"Per-token alignment median cosine: {np.median(all_cos):.4f}\")\n", "print(f\"Relational structure correlation: {rel_corr:.4f}\")\n", "print(f\"Digit structure correlation: {r_digits:.4f}\")\n", "print(f\"Pentachoron per-simplex correlation: {vol_corr:.4f}\")\n", "print(f\"Pentachoron volume ratio (proj/native): {vols_projected.mean()/vols_native.mean():.4f}\")\n", "print(f\"Tokens with cos > 0.5: {(all_cos > 0.5).mean()*100:.1f}%\")\n", "print(f\"Tokens with cos > 0.7: {(all_cos > 0.7).mean()*100:.1f}%\")\n", "print(f\"Tokens with cos > 0.9: {(all_cos > 0.9).mean()*100:.1f}%\")\n" ] }, { "cell_type": "markdown", "id": "cb89d00a", "metadata": {}, "source": [ "---\n", "\n", "## Summary\n", "\n", "17 models profiled across 5 architecture families and 6 training objectives.\n", "\n", "**Universal invariants (hold across ALL models):**\n", "- Cross-layer weight decorrelation: ~0.000 (attention AND conv)\n", "- Full neuron/filter utilization: 0% dead (with minor exceptions)\n", "- Cross-modal QK eigenvalue balance: locked at 0.500\n", "\n", "**Training-specific findings:**\n", "- T5 Q sparsity asymmetry: 93.7% → 100.0% (T5 pretraining only, absent in T5Gemma2)\n", "- UNet QK U-gradient: repulsion in downpath, attraction in uppath\n", "- VAE decoder QK: breaks toward repulsion (reconstruction requires discrimination)\n", "\n", "**Procrustes alignment:**\n", "- VAE pairs: 70-76% cosine after rotation, spectral corr 0.94-0.98\n", "- BERT vs DINOv2 (cross-modal): 61% overall, 70% deep layers, 97% spectral\n", "- O projections most aligned, Q projections least aligned\n", "\n", "**Key architectural result:** The geometric field modulator targets structure that is 60-76% shared across architectures, modalities, and training objectives. 
Procrustes rotation enables cross-model transfer.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 5 }