# reFlow / experiments.py
# reuAC: Add bilingual interpretability demo with 11 experiments (commit bf44358)
"""Experiment functions for the reFlow interpretability demo, adapted for Gradio."""
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
# adjustText is an optional dependency used only to de-overlap scatter labels;
# fall back to a no-op lambda so plots still render without it.
try:
    from adjustText import adjust_text
except ImportError:
    adjust_text = lambda texts, **kwargs: None
from model_loader import get_model, get_cached_tensors
REAL_VOCAB = 50257
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _embed(model, ids):
result = model.transformer.wte(ids)
return result[0] if isinstance(result, tuple) else result
def _get_vocab_signals(model):
wte = model.transformer.wte
if hasattr(wte, '_apply_sparsity'):
return wte._apply_sparsity(wte.vocab_to_signals.weight.data)
return wte.vocab_to_signals.weight.data
def _forward_through_layers(model, ids):
    """Embed `ids` and run them through every transformer block (no final
    layer norm), with gradient tracking disabled."""
    with torch.no_grad():
        hidden = _embed(model, ids)
        freqs = model.freqs_cis[:ids.size(1)]
        for layer in model.transformer.h:
            hidden = layer(hidden, freqs)
    return hidden
def _get_logits_from_hidden(model, x_norm):
vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix()
return F.linear(x_norm, vocab_matrix)
def _gini(arr):
arr = np.sort(np.abs(arr))
n = len(arr)
if n == 0 or np.sum(arr) == 0:
return 0.0
index = np.arange(1, n + 1)
return (2 * np.sum(index * arr) / (n * np.sum(arr))) - (n + 1) / n
# ---------------------------------------------------------------------------
# 1. Semantic Galaxy (PCA)
# ---------------------------------------------------------------------------
# Preset word groups for the Semantic Galaxy experiment; each category maps
# to a checkbox in the Gradio UI and is plotted as one colored cluster.
DEFAULT_CLUSTERS = {
    "Countries": ["China", "France", "Germany", "Japan", "India", "Russia"],
    "Animals": ["cat", "dog", "fish", "bird", "horse", "bear"],
    "Numbers": ["one", "two", "three", "four", "five", "ten"],
    "Colors": ["red", "blue", "green", "black", "white", "yellow"],
    "Emotions": ["happy", "sad", "angry", "love", "fear", "hate"],
}
@torch.inference_mode()
def exp_semantic_galaxy(
    use_countries, use_animals, use_numbers, use_colors, use_emotions, custom_words
):
    """Project word recipe vectors to 2D with PCA and scatter-plot them,
    colored by cluster, with a silhouette score in the title.

    Args:
        use_countries..use_emotions: booleans toggling each default cluster.
        custom_words: optional comma-separated words forming a "Custom" cluster.

    Returns:
        A matplotlib Figure (a placeholder-text figure if < 3 valid words).
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model).cpu().numpy()
    # Build clusters from checkboxes (table-driven instead of five `if`s).
    toggles = {
        "Countries": use_countries,
        "Animals": use_animals,
        "Numbers": use_numbers,
        "Colors": use_colors,
        "Emotions": use_emotions,
    }
    clusters = {name: DEFAULT_CLUSTERS[name] for name, on in toggles.items() if on}
    # Optional user-supplied cluster.
    if custom_words and custom_words.strip():
        custom_list = [w.strip() for w in custom_words.split(",") if w.strip()]
        if custom_list:
            clusters["Custom"] = custom_list
    if not clusters:
        clusters = DEFAULT_CLUSTERS
    recipes, labels, words = [], [], []
    for cat, wl in clusters.items():
        for w in wl:
            tids = enc.encode(" " + w)
            # Only use words whose first token is a real (non-padding) token.
            if tids and tids[0] < REAL_VOCAB:
                recipes.append(W_v2s[tids[0]])
                labels.append(cat)
                words.append(w)
    if len(words) < 3:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, "Need at least 3 valid words", ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig
    recipes_arr = np.array(recipes)
    coords = PCA(n_components=2).fit_transform(recipes_arr)
    cluster_names = list(clusters.keys())
    label_ids = [cluster_names.index(l) for l in labels]
    # silhouette_score requires 2 <= n_labels <= n_samples - 1. The original
    # only checked the lower bound and raised ValueError when every cluster
    # was a singleton (n_labels == n_samples).
    n_label_values = len(set(label_ids))
    if 2 <= n_label_values <= len(words) - 1:
        sil = silhouette_score(recipes_arr, label_ids)
    else:
        sil = 0.0
    fig = plt.figure(figsize=(12, 9))
    color_map = dict(zip(cluster_names, sns.color_palette("Set2", len(clusters))))
    texts = []
    for i, w in enumerate(words):
        plt.scatter(coords[i, 0], coords[i, 1], color=color_map[labels[i]],
                    s=150, alpha=0.7, edgecolors='white', linewidths=0.5)
        texts.append(plt.text(coords[i, 0], coords[i, 1], w, fontsize=11))
    # Only invoke the real adjustText (the ImportError fallback is a lambda).
    if callable(adjust_text) and getattr(adjust_text, '__name__', '') != '<lambda>':
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='gray'))
    handles = [plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=color_map[l], markersize=12, label=l) for l in clusters]
    plt.legend(handles=handles, title="Clusters", fontsize=10)
    plt.title(f"reFlow Semantic Galaxy (PCA)\nSilhouette Score = {sil:.4f}",
              fontsize=14, fontweight='bold')
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# 2. Semantic Algebra
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_semantic_algebra(positive_words, negative_words):
    """Vector arithmetic over token recipes: sum(+words) - sum(-words),
    then rank the real vocabulary by cosine similarity to the result.

    Returns a formatted text report of the nearest tokens (input words
    themselves are excluded from the ranking).
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model)
    W_valid = W_v2s[:REAL_VOCAB]

    def _parse(csv_text):
        return [t.strip() for t in csv_text.split(",") if t.strip()]

    pos_list = _parse(positive_words)
    neg_list = _parse(negative_words)
    if not pos_list:
        return "Please enter at least one positive word."
    target_vec = torch.zeros(model.config.n_signals, device=device)
    exclude_ids = set()
    # Accumulate signed recipes; remember ids so we don't report the inputs.
    for sign, group in ((1.0, pos_list), (-1.0, neg_list)):
        for w in group:
            tids = enc.encode(" " + w)
            if tids and tids[0] < REAL_VOCAB:
                target_vec = target_vec + sign * W_v2s[tids[0]]
                exclude_ids.add(tids[0])
    sims = F.cosine_similarity(target_vec.unsqueeze(0), W_valid)
    for tid in exclude_ids:
        sims[tid] = -1.0
    top_vals, top_ids = torch.topk(sims, 20)
    expr = " + ".join(pos_list)
    if neg_list:
        expr += " - " + " - ".join(neg_list)
    rows = []
    for val, tok in zip(top_vals, top_ids):
        try:
            decoded = enc.decode([tok.item()]).strip()
            if len(decoded) >= 1:
                rows.append(f"#{len(rows)+1:2d} {decoded:<20s} cos={val.item():.4f}")
        except Exception:
            continue
        if len(rows) >= 15:
            break
    header = f"Expression: {expr}\n{'='*50}\nRank Word Similarity\n{'-'*50}\n"
    return header + "\n".join(rows)
# ---------------------------------------------------------------------------
# 3. Typo Resilience
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_typo_resilience(sent_normal, sent_typo, sent_diff):
    """Bar-plot deep-signal cosine similarity of a sentence against a typo'd
    variant (same meaning) and an unrelated sentence."""
    model, enc, device = get_model()
    basis = model.transformer.wte.signal_basis.data

    def deep_signal(text):
        # Final hidden state of the last token, projected onto the signal basis.
        token_ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        hidden = _forward_through_layers(model, token_ids)
        final = model.transformer.ln_f(hidden[0, -1, :])
        return final @ basis.t()

    ref = deep_signal(sent_normal)
    typo = deep_signal(sent_typo)
    diff = deep_signal(sent_diff)
    sim_typo = F.cosine_similarity(ref.unsqueeze(0), typo.unsqueeze(0)).item()
    sim_diff = F.cosine_similarity(ref.unsqueeze(0), diff.unsqueeze(0)).item()
    fig, ax = plt.subplots(figsize=(8, 5))
    categories = ['Self\n(baseline)', 'Normal vs Typo\n(same meaning)', 'Normal vs Different\n(different meaning)']
    values = [1.0, sim_typo, sim_diff]
    palette = ['#2ecc71', '#f39c12', '#e74c3c']
    bars = ax.bar(categories, values, color=palette, alpha=0.8, edgecolor='black', width=0.5)
    for rect, val in zip(bars, values):
        ax.text(rect.get_x() + rect.get_width() / 2, rect.get_height() + 0.01,
                f'{val:.4f}', ha='center', fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.15)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title("reFlow Typo Resilience - Deep Signal Similarity", fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# 4. Sparsity Profile
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_sparsity_profile(word_to_inspect):
    """Visualize how sparse the vocabulary recipes are.

    A signal counts as "active" for a word when its absolute amplitude
    exceeds mean + std over all amplitudes. Left panel: histogram of
    active-signal counts per word. Right panel: words activating each signal.

    Returns (figure, stats_text); stats_text optionally includes a
    per-word breakdown for `word_to_inspect`.
    """
    model, enc, device = get_model()
    recipes = _get_vocab_signals(model)[:REAL_VOCAB]
    n_signals = recipes.shape[1]
    abs_recipes = recipes.abs()
    threshold = abs_recipes.mean().item() + abs_recipes.std().item()
    active_mask = abs_recipes > threshold
    per_word = active_mask.sum(dim=1).cpu().numpy()
    per_signal = active_mask.sum(dim=0).cpu().numpy()
    mean_active = np.mean(per_word)
    fig, (ax_hist, ax_util) = plt.subplots(1, 2, figsize=(14, 5))
    # Histogram with integer-centered bins.
    edges = np.arange(per_word.min(), per_word.max() + 2) - 0.5
    ax_hist.hist(per_word, bins=edges, color='teal', alpha=0.7, edgecolor='black')
    ax_hist.axvline(x=mean_active, color='red', linestyle='--',
                    label=f'Mean: {mean_active:.1f}')
    ax_hist.set_title("Per-Word Sparsity (# Active Signals)")
    ax_hist.set_xlabel("Number of Active Signals")
    ax_hist.set_ylabel("Frequency")
    ax_hist.legend()
    # How many words light up each individual signal.
    ax_util.bar(range(n_signals), per_signal, color='coral', alpha=0.7, width=1.0)
    ax_util.set_title("Signal Utilization (# words activating each signal)")
    ax_util.set_xlabel("Signal Index")
    ax_util.set_ylabel("# Words")
    ax_util.axhline(y=np.mean(per_signal), color='red', linestyle='--',
                    label=f'Mean: {np.mean(per_signal):.0f}')
    ax_util.legend()
    plt.suptitle("reFlow Sparsity Profile", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    report = f"Threshold: {threshold:.4f} (mean + std)\n"
    report += f"Avg active signals per word: {mean_active:.1f} / {n_signals}\n"
    report += f"Global activation rate: {active_mask.float().mean().item():.2%}\n"
    if word_to_inspect and word_to_inspect.strip():
        query = word_to_inspect.strip()
        tids = enc.encode(" " + query)
        if tids and tids[0] < REAL_VOCAB:
            recipe = recipes[tids[0]]
            n_active = (recipe.abs() > threshold).sum().item()
            strongest = torch.argsort(recipe.abs(), descending=True)[:10]
            report += f"\n--- '{query}' ---\n"
            report += f"Active signals: {n_active}\n"
            report += f"Top 10 signal indices: {strongest.tolist()}\n"
            report += f"Top 10 amplitudes: {[f'{recipe[s].item():.4f}' for s in strongest]}\n"
        else:
            report += f"\n'{query}' not found in vocabulary.\n"
    return fig, report
# ---------------------------------------------------------------------------
# 5. Layer Evolution
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_layer_evolution(prompt_text):
    """Decode the next-token distribution after every transformer block.

    Plots (left) the per-layer probability of the six strongest final
    candidates and (right) the per-layer entropy of the full distribution.
    """
    model, enc, device = get_model()
    # NOTE(review): this local was unused in the original implementation too;
    # _get_logits_from_hidden fetches the matrix itself each layer. Kept in
    # case get_dynamic_vocab_matrix() has side effects — verify.
    vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix()
    n_layers = len(model.transformer.h)
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    per_layer_probs, per_layer_entropy = [], []
    hidden = _embed(model, ids)
    freqs_cis = model.freqs_cis[:ids.size(1)]
    for block in model.transformer.h:
        hidden = block(hidden, freqs_cis)
        normed = model.transformer.ln_f(hidden[0, -1, :])
        probs = F.softmax(_get_logits_from_hidden(model, normed), dim=-1)
        per_layer_probs.append(probs.cpu().numpy())
        per_layer_entropy.append(-torch.sum(probs * torch.log(probs + 1e-9)).item())
    final_probs = per_layer_probs[-1][:REAL_VOCAB]
    top_idx = np.argsort(final_probs)[-6:]
    prob_flow = np.array([[p[i] for i in top_idx] for p in per_layer_probs])
    layers = np.arange(1, n_layers + 1)
    fig, (ax_prob, ax_ent) = plt.subplots(1, 2, figsize=(16, 5))
    palette = sns.color_palette("husl", len(top_idx))
    for i, idx in enumerate(top_idx):
        token_label = repr(enc.decode([idx])).strip("'")
        ax_prob.plot(layers, prob_flow[:, i], label=token_label, lw=2.5, color=palette[i])
    ax_prob.set_title(f"Probability Evolution: '{prompt_text}'", fontsize=11, fontweight='bold')
    ax_prob.set_xlabel("Layer")
    ax_prob.set_ylabel("Probability")
    ax_prob.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax_prob.legend(fontsize=8, loc='upper left')
    ax_prob.grid(True, alpha=0.3)
    ax_ent.plot(layers, per_layer_entropy, color='#FF6B35', lw=2.5, marker='o', markersize=3)
    ax_ent.set_title(f"Entropy Decay: '{prompt_text}'", fontsize=11, fontweight='bold')
    ax_ent.set_xlabel("Layer")
    ax_ent.set_ylabel("Entropy (nats)")
    ax_ent.grid(True, alpha=0.3)
    predicted = enc.decode([np.argmax(final_probs)])
    plt.suptitle(f"reFlow Layer Evolution | Prediction: '{predicted}' (p={final_probs.max():.2%})",
                 fontsize=13, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
# ---------------------------------------------------------------------------
# 6. Causal Ablation
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_causal_ablation(prompt_text):
    """Measure how the model's top prediction collapses as its strongest
    contributing signals are zeroed out in power-of-two batches.

    Left panel: P(original prediction) vs number of ablated signals
    (log-log). Right panel: text summary including the top signal's
    "codebook" (tokens with the largest weight on that signal).

    Returns a matplotlib Figure.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    ablation_steps = [1, 2, 4, 8, 16, 32, 64, 128]
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    # Signal-space activations of the final position.
    sig_acts = x_norm @ W_basis.t()
    logits_base = sig_acts @ W_v2s[:REAL_VOCAB].t()
    probs_base = F.softmax(logits_base, dim=-1)
    pred_id = torch.argmax(probs_base).item()
    pred_word = enc.decode([pred_id])
    pred_prob = probs_base[pred_id].item()
    # Per-signal contribution to the predicted token's logit; the largest
    # contributors are ablated first.
    contribs = sig_acts * W_v2s[pred_id]
    sorted_sig_ids = torch.argsort(contribs, descending=True)
    steps, probs_list, new_preds = [], [], []
    for n_ablate in ablation_steps:
        if n_ablate > len(sorted_sig_ids):
            break
        ablated = sig_acts.clone()
        ablated[sorted_sig_ids[:n_ablate]] = 0.0
        logits_abl = ablated @ W_v2s[:REAL_VOCAB].t()
        probs_abl = F.softmax(logits_abl, dim=-1)
        new_pred_id = torch.argmax(probs_abl).item()
        steps.append(n_ablate)
        # Track both the original token's surviving probability and the
        # token that replaces it as argmax.
        probs_list.append(probs_abl[pred_id].item())
        new_preds.append(enc.decode([new_pred_id]))
    # Codebook for top signal
    top_sig = sorted_sig_ids[0].item()
    col = W_v2s[:REAL_VOCAB, top_sig]
    top_vals, top_ids = torch.topk(col, 8)  # NOTE(review): top_vals is unused
    cb_words = []
    for tid in top_ids:
        try:
            cb_words.append(enc.decode([tid.item()]).strip())
        except Exception:
            # Fall back to the raw id when a token fails to decode.
            cb_words.append(f"[{tid.item()}]")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    # Clamp to 1e-8 so the log-scale axis never receives zero.
    ax1.plot(steps, [max(p, 1e-8) for p in probs_list],
             'o-', color='#e74c3c', lw=2.5, markersize=6)
    ax1.axhline(y=pred_prob, color='blue', linestyle='--', alpha=0.5,
                label=f"Baseline: {pred_prob:.1%}")
    ax1.set_title(f"'{prompt_text}'\nPrediction: '{pred_word}'", fontsize=10, fontweight='bold')
    ax1.set_xlabel("# Signals Ablated")
    ax1.set_ylabel("P(original prediction)")
    ax1.set_yscale('log')
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=2))
    ax1.set_xscale('log', base=2)
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)
    # Text summary
    ax2.axis('off')
    summary = f"Baseline: '{pred_word}' (p={pred_prob:.2%})\n"
    summary += f"Key Signal: #{top_sig}\n"
    summary += f"Codebook: {', '.join(cb_words[:6])}\n\n"
    summary += "Ablation Results:\n" + "-"*40 + "\n"
    for s, p, nw in zip(steps, probs_list, new_preds):
        summary += f" {s:3d} signals removed -> p={p:.2%}, pred='{nw}'\n"
    ax2.text(0.05, 0.95, summary, transform=ax2.transAxes, fontsize=10,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    plt.suptitle("reFlow Causal Ablation", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
# ---------------------------------------------------------------------------
# 7. Concept Inception
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_concept_inception(prompt_text, target_word, alpha_max):
    """Steer the model toward `target_word` by adding alpha * target recipe
    to the final-layer signal activations; plot P(target) vs alpha.

    Binary search locates the smallest alpha whose argmax flips to the
    target (the "critical" steering strength).

    Returns:
        (figure, info_text)
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    tids = enc.encode(" " + target_word)
    # Guard against an un-encodable target instead of crashing with IndexError.
    if not tids:
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.text(0.5, 0.5, f"Cannot encode target word '{target_word}'",
                ha='center', va='center', fontsize=12)
        ax.axis('off')
        return fig, f"Cannot encode target word '{target_word}'"
    tid = tids[0]
    target_recipe = W_v2s[tid]
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    base_sig = x_norm @ W_basis.t()
    vocab_t = W_v2s[:REAL_VOCAB].t()

    def probs_at(alpha):
        # Next-token distribution after adding `alpha` units of the recipe.
        return F.softmax((base_sig + alpha * target_recipe) @ vocab_t, dim=-1)

    probs_base = probs_at(0.0)
    orig_word = enc.decode([torch.argmax(probs_base).item()])
    orig_prob = probs_base[tid].item()
    # Binary search for the critical alpha where the argmax flips to target.
    lo, hi = 0.0, float(alpha_max)
    critical_alpha = None
    if torch.argmax(probs_at(hi)).item() == tid:
        for _ in range(20):
            mid = (lo + hi) / 2
            if torch.argmax(probs_at(mid)).item() == tid:
                hi = mid
            else:
                lo = mid
        critical_alpha = hi
    # Curve range: 1.5x the flip point (when found), capped at alpha_max.
    # BUGFIX: use `is not None` rather than truthiness — a critical alpha
    # near 0.0 was previously treated as "not found".
    if critical_alpha is not None:
        alpha_range = min(float(alpha_max), critical_alpha * 1.5)
    else:
        alpha_range = float(alpha_max)
    alphas = np.linspace(0, alpha_range, 50)
    target_probs = [probs_at(a)[tid].item() for a in alphas]
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(alphas, target_probs, 'o-', color='#9b59b6', lw=2, markersize=3)
    if critical_alpha is not None:
        ax.axvline(critical_alpha, color='red', linestyle='--',
                   label=f"Critical alpha={critical_alpha:.1f}")
    ax.axhline(y=orig_prob, color='gray', linestyle=':', alpha=0.5,
               label=f"Baseline P('{target_word}')={orig_prob:.1e}")
    ax.set_title(f"'{prompt_text}'\n'{orig_word}' -> '{target_word}'",
                 fontsize=11, fontweight='bold')
    ax.set_xlabel("Steering Strength (alpha)")
    ax.set_ylabel(f"P('{target_word}')")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    info = f"Original prediction: '{orig_word}'\n"
    info += f"Target: '{target_word}'\n"
    if critical_alpha is not None:
        info += f"Critical flip point: alpha = {critical_alpha:.2f}\n"
    else:
        info += f"Target not reached within alpha <= {alpha_max}\n"
    return fig, info
# ---------------------------------------------------------------------------
# 8. Text Generation
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_generate(prompt_text, num_samples, max_tokens, temperature, top_k, repetition_penalty):
    """Sample continuations of a prompt with temperature / top-k /
    repetition-penalty controls, returning all samples as one text blob."""
    model, enc, device = get_model()
    num_samples = int(num_samples)
    max_tokens = int(max_tokens)
    top_k = int(top_k) if top_k and top_k > 0 else None
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)
    if not prompt_text.strip():
        return "Please enter a prompt."
    prompt_ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    # Replicate the prompt once per requested sample.
    prompt_ids = prompt_ids.expand(num_samples, -1).contiguous()
    samples = []
    for row in range(num_samples):
        seq = prompt_ids[row:row + 1]
        for _ in range(max_tokens):
            # Crop to the model's context window.
            window = seq if seq.size(1) <= model.config.block_size else seq[:, -model.config.block_size:]
            logits, _ = model(window)
            logits = logits[:, -1, :]
            # Repetition penalty: dampen every token already generated.
            if repetition_penalty != 1.0:
                for token_id in set(seq[0].tolist()):
                    if logits[0, token_id] > 0:
                        logits[0, token_id] /= repetition_penalty
                    else:
                        logits[0, token_id] *= repetition_penalty
            # Temperature scaling (clamped to avoid division by zero).
            logits = logits / max(temperature, 1e-8)
            # Top-k filtering.
            if top_k is not None and top_k > 0:
                kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][:, [-1]]
                logits[logits < kth] = -float('Inf')
            next_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            seq = torch.cat((seq, next_id), dim=1)
        samples.append(enc.decode(seq[0].tolist()))
    divider = "\n" + "=" * 60 + "\n"
    pieces = []
    for i, text in enumerate(samples):
        chunk = ""
        if num_samples > 1:
            chunk += f"--- Sample {i+1}/{num_samples} ---\n"
        chunk += text + "\n"
        if i < len(samples) - 1:
            chunk += divider
        pieces.append(chunk)
    return "".join(pieces)
# ---------------------------------------------------------------------------
# 9. Signal Basis Geometry
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_basis_geometry():
    """Inspect the learned signal basis geometry.

    Left: normalized singular-value spectrum vs a random Gaussian matrix of
    the same shape (Shannon effective ranks in the title). Right: cosine
    similarity heatmap of the first 64 signal directions.

    Returns (figure, stats_text).
    """
    model, enc, device = get_model()
    basis = model.transformer.wte.signal_basis.data.cpu().float()
    n_signals, n_embd = basis.shape

    def _spectrum(mat):
        # Singular values plus the exp-entropy "effective rank" of the
        # normalized spectrum.
        s = torch.linalg.svd(mat, full_matrices=False)[1].numpy()
        p = s / s.sum()
        return s, np.exp(-np.sum(p * np.log(p + 1e-12)))

    S_np, effective_rank = _spectrum(basis)
    S_rand_np, effective_rank_rand = _spectrum(torch.randn_like(basis))
    show_n = min(64, n_signals)
    unit_rows = F.normalize(basis[:show_n], dim=1)
    cos_sim = (unit_rows @ unit_rows.t()).numpy()
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ax1.plot(S_np / S_np[0], 'b-', lw=2, label='Learned Basis')
    ax1.plot(S_rand_np / S_rand_np[0], 'r--', lw=1.5, label='Random Gaussian')
    ax1.set_title(f"Singular Value Spectrum\n(Eff. rank: learned={effective_rank:.0f}, random={effective_rank_rand:.0f})")
    ax1.set_xlabel("Component Index")
    ax1.set_ylabel("Normalized Singular Value")
    ax1.set_yscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    im = ax2.imshow(cos_sim, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax2.set_title(f"Cosine Similarity (first {show_n} signals)")
    ax2.set_xlabel("Signal Index")
    ax2.set_ylabel("Signal Index")
    plt.colorbar(im, ax=ax2, fraction=0.046)
    plt.suptitle("reFlow Signal Basis Geometry", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    stats = f"Signal basis shape: ({n_signals}, {n_embd})\n"
    stats += f"Effective rank (learned): {effective_rank:.1f} / {min(n_signals, n_embd)}\n"
    stats += f"Effective rank (random): {effective_rank_rand:.1f} / {min(n_signals, n_embd)}\n"
    return fig, stats
# ---------------------------------------------------------------------------
# 10. Recipe Neighbors (Nearest Neighbor Lookup)
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_recipe_neighbors(query_word, top_n):
    """List the top-N cosine-nearest vocabulary tokens for each query word
    (comma-separated input). Returns a plain-text report."""
    model, enc, device = get_model()
    unit = F.normalize(_get_vocab_signals(model)[:REAL_VOCAB], dim=1)
    top_n = int(top_n)
    queries = [q.strip() for q in query_word.split(",") if q.strip()]
    if not queries:
        return "Please enter at least one word."
    report = []
    for q in queries:
        tids = enc.encode(" " + q)
        if not tids or tids[0] >= REAL_VOCAB:
            report.append(f"'{q}' not found in vocabulary.\n\n")
            continue
        tid = tids[0]
        sims = unit[tid] @ unit.t()
        sims[tid] = -1  # never report the query token itself
        vals, neighbor_ids = torch.topk(sims, top_n)
        section = [f"Nearest neighbors for '{q}':\n" + "-" * 40 + "\n"]
        for rank, (v, nid) in enumerate(zip(vals, neighbor_ids)):
            try:
                neighbor = enc.decode([nid.item()]).strip()
            except Exception:
                neighbor = f"[{nid.item()}]"
            section.append(f" #{rank+1:2d} {neighbor:<20s} cos={v.item():.4f}\n")
        section.append("\n")
        report.append("".join(section))
    return "".join(report)
# ---------------------------------------------------------------------------
# 11. Task Crystallization
# ---------------------------------------------------------------------------
@torch.inference_mode()
def exp_task_crystallization(prompt_text, target_word, max_alpha, start_layer):
    """Find the layer after which steering can no longer flip the model's
    prediction to `target_word` (the "crystallization" boundary).

    A steering vector (target recipe minus baseline-prediction recipe,
    mapped back to embedding space through the signal basis) is injected
    into the last position's residual stream from a chosen layer onward;
    the scan repeats this for every start layer at a fixed working alpha.

    Returns:
        (figure_or_None, info_text) — figure is None when no alpha up to
        `max_alpha` flips the prediction at `start_layer`.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    n_layers = len(model.transformer.h)
    start_layer = int(start_layer)
    max_alpha = float(max_alpha)
    target_tid = enc.encode(" " + target_word.strip())[0]
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    # Get baseline prediction
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    logits_base = _get_logits_from_hidden(model, x_norm)
    base_pred_id = torch.argmax(logits_base).item()
    base_pred = enc.decode([base_pred_id])
    # Find working alpha
    def continuous_steer(alpha, intercept_layer):
        # Inject alpha * (target - baseline) recipes into the last token's
        # hidden state after every block from `intercept_layer` on; layer 0
        # additionally steers the raw embedding before the first block.
        steer_vec = W_v2s[target_tid] - W_v2s[base_pred_id]
        x = _embed(model, ids)
        if intercept_layer == 0:
            x[:, -1, :] += (alpha * steer_vec) @ W_basis
        freqs_cis = model.freqs_cis[:ids.size(1)]
        for i, block in enumerate(model.transformer.h):
            x = block(x, freqs_cis)
            if i + 1 >= intercept_layer:
                x[:, -1, :] += (alpha * steer_vec) @ W_basis
        x_norm = model.transformer.ln_f(x[0, -1, :])
        logits = _get_logits_from_hidden(model, x_norm)
        probs = F.softmax(logits, dim=-1)
        pred_id = torch.argmax(logits).item()
        return probs[target_tid].item(), enc.decode([pred_id]).strip()
    # Find minimum alpha that works at start_layer
    working_alpha = None
    for a in np.arange(2.0, max_alpha, 2.0):
        _, pred = continuous_steer(a, start_layer)
        if pred.strip() == target_word.strip():
            working_alpha = a * 1.2  # 20% headroom above the first success
            break
    if working_alpha is None:
        return None, f"Cannot steer to '{target_word}' within alpha <= {max_alpha}"
    # Scan across layers: c_layer records the first start layer at which
    # steering fails to produce the target (stays n_layers if it never fails).
    layer_probs = []
    c_layer = n_layers
    for L in range(n_layers):
        p_target, pred = continuous_steer(working_alpha, L)
        layer_probs.append(p_target)
        if pred.strip() != target_word.strip() and c_layer == n_layers:
            c_layer = L
    # Plot
    fig, ax = plt.subplots(figsize=(10, 6))
    layers_x = np.arange(n_layers)
    ax.plot(layers_x, layer_probs, 'o-', color='#9b59b6', lw=2.5, markersize=4)
    if c_layer < n_layers:
        ax.scatter(c_layer, layer_probs[c_layer], color='red', s=150, marker='X', edgecolors='black', zorder=5)
        ax.axvline(c_layer, color='red', linestyle='--', alpha=0.5, label=f'Crystallization boundary: Layer {c_layer}')
    ax.set_title(f"Task Crystallization: '{prompt_text}' → '{target_word}'", fontsize=11, fontweight='bold')
    ax.set_xlabel("Intervention Start Layer")
    ax.set_ylabel(f"P('{target_word}')")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    info = f"Base prediction: '{base_pred}'\n"
    info += f"Target: '{target_word}'\n"
    info += f"Working alpha: {working_alpha:.1f}\n"
    info += f"Crystallization boundary: Layer {c_layer}\n"
    return fig, info