"""Experiment functions for the reFlow interpretability demo, adapted for Gradio.""" import torch import torch.nn.functional as F import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.ticker as ticker import seaborn as sns from sklearn.decomposition import PCA from sklearn.metrics import silhouette_score try: from adjustText import adjust_text except ImportError: adjust_text = lambda texts, **kwargs: None from model_loader import get_model, get_cached_tensors REAL_VOCAB = 50257 # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _embed(model, ids): result = model.transformer.wte(ids) return result[0] if isinstance(result, tuple) else result def _get_vocab_signals(model): wte = model.transformer.wte if hasattr(wte, '_apply_sparsity'): return wte._apply_sparsity(wte.vocab_to_signals.weight.data) return wte.vocab_to_signals.weight.data def _forward_through_layers(model, ids): with torch.no_grad(): x = _embed(model, ids) freqs_cis = model.freqs_cis[:ids.size(1)] for block in model.transformer.h: x = block(x, freqs_cis) return x def _get_logits_from_hidden(model, x_norm): vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix() return F.linear(x_norm, vocab_matrix) def _gini(arr): arr = np.sort(np.abs(arr)) n = len(arr) if n == 0 or np.sum(arr) == 0: return 0.0 index = np.arange(1, n + 1) return (2 * np.sum(index * arr) / (n * np.sum(arr))) - (n + 1) / n # --------------------------------------------------------------------------- # 1. 
# ---------------------------------------------------------------------------
# 1. Semantic Galaxy (PCA)
# ---------------------------------------------------------------------------

DEFAULT_CLUSTERS = {
    "Countries": ["China", "France", "Germany", "Japan", "India", "Russia"],
    "Animals": ["cat", "dog", "fish", "bird", "horse", "bear"],
    "Numbers": ["one", "two", "three", "four", "five", "ten"],
    "Colors": ["red", "blue", "green", "black", "white", "yellow"],
    "Emotions": ["happy", "sad", "angry", "love", "fear", "hate"],
}


@torch.inference_mode()
def exp_semantic_galaxy(use_countries, use_animals, use_numbers, use_colors,
                        use_emotions, custom_words):
    """Project word recipes onto 2 PCA components and plot them by cluster.

    Args:
        use_*: checkbox booleans selecting the built-in clusters.
        custom_words: optional comma-separated extra words (cluster "Custom").

    Returns:
        A matplotlib Figure (an explanatory placeholder if < 3 valid words).
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model).cpu().numpy()

    # Build clusters from checkboxes
    clusters = {}
    if use_countries:
        clusters["Countries"] = DEFAULT_CLUSTERS["Countries"]
    if use_animals:
        clusters["Animals"] = DEFAULT_CLUSTERS["Animals"]
    if use_numbers:
        clusters["Numbers"] = DEFAULT_CLUSTERS["Numbers"]
    if use_colors:
        clusters["Colors"] = DEFAULT_CLUSTERS["Colors"]
    if use_emotions:
        clusters["Emotions"] = DEFAULT_CLUSTERS["Emotions"]

    # Custom words
    if custom_words and custom_words.strip():
        custom_list = [w.strip() for w in custom_words.split(",") if w.strip()]
        if custom_list:
            clusters["Custom"] = custom_list
    if not clusters:
        clusters = DEFAULT_CLUSTERS

    # Keep only words whose leading token is a real vocabulary entry.
    recipes, labels, words = [], [], []
    for cat, wl in clusters.items():
        for w in wl:
            tids = enc.encode(" " + w)
            if tids and tids[0] < REAL_VOCAB:
                recipes.append(W_v2s[tids[0]])
                labels.append(cat)
                words.append(w)

    if len(words) < 3:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, "Need at least 3 valid words", ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig

    recipes_arr = np.array(recipes)
    coords = PCA(n_components=2).fit_transform(recipes_arr)

    # FIX: O(1) dict lookup instead of list(...).index() per label (was O(n^2)).
    cat_to_id = {cat: i for i, cat in enumerate(clusters)}
    label_ids = [cat_to_id[l] for l in labels]
    # FIX: silhouette_score requires 2 <= n_labels <= n_samples - 1; the old
    # guard only checked the lower bound and could raise ValueError.
    n_clusters = len(set(label_ids))
    if 2 <= n_clusters <= len(label_ids) - 1:
        sil = silhouette_score(recipes_arr, label_ids)
    else:
        sil = 0.0

    fig = plt.figure(figsize=(12, 9))
    color_map = dict(zip(clusters.keys(), sns.color_palette("Set2", len(clusters))))
    texts = []
    for i, w in enumerate(words):
        plt.scatter(coords[i, 0], coords[i, 1], color=color_map[labels[i]],
                    s=150, alpha=0.7, edgecolors='white', linewidths=0.5)
        texts.append(plt.text(coords[i, 0], coords[i, 1], w, fontsize=11))
    # FIX: the previous guard (callable + non-empty __name__) was vacuously
    # true for both the real adjust_text and the fallback; call directly —
    # the fallback is a safe no-op.
    adjust_text(texts, arrowprops=dict(arrowstyle="-", color='gray'))

    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[l],
                          markersize=12, label=l) for l in clusters]
    plt.legend(handles=handles, title="Clusters", fontsize=10)
    plt.title(f"reFlow Semantic Galaxy (PCA)\nSilhouette Score = {sil:.4f}",
              fontsize=14, fontweight='bold')
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    return fig


# ---------------------------------------------------------------------------
# 2. Semantic Algebra
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_semantic_algebra(positive_words, negative_words):
    """Recipe-space word arithmetic: sum positives, subtract negatives, rank by cosine.

    Args:
        positive_words / negative_words: comma-separated word lists.

    Returns:
        A formatted text table of the nearest vocabulary entries.
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model)
    W_valid = W_v2s[:REAL_VOCAB]

    pos_list = [w.strip() for w in positive_words.split(",") if w.strip()]
    neg_list = [w.strip() for w in negative_words.split(",") if w.strip()]
    if not pos_list:
        return "Please enter at least one positive word."

    target_vec = torch.zeros(model.config.n_signals, device=device)
    exclude_ids = set()  # query tokens themselves are excluded from the ranking
    for w in pos_list:
        tids = enc.encode(" " + w)
        if tids and tids[0] < REAL_VOCAB:
            target_vec += W_v2s[tids[0]]
            exclude_ids.add(tids[0])
    for w in neg_list:
        tids = enc.encode(" " + w)
        if tids and tids[0] < REAL_VOCAB:
            target_vec -= W_v2s[tids[0]]
            exclude_ids.add(tids[0])

    sims = F.cosine_similarity(target_vec.unsqueeze(0), W_valid)
    for tid in exclude_ids:
        sims[tid] = -1.0
    # Over-fetch 20 so decode failures still leave 15 printable rows.
    top_vals, top_ids = torch.topk(sims, 20)

    expr = " + ".join(pos_list)
    if neg_list:
        expr += " - " + " - ".join(neg_list)

    rows = []
    for i in range(len(top_ids)):
        try:
            w = enc.decode([top_ids[i].item()]).strip()
            if len(w) >= 1:
                rows.append(f"#{len(rows)+1:2d} {w:<20s} cos={top_vals[i].item():.4f}")
        except Exception:
            continue
        if len(rows) >= 15:
            break

    header = f"Expression: {expr}\n{'='*50}\nRank Word Similarity\n{'-'*50}\n"
    return header + "\n".join(rows)


# ---------------------------------------------------------------------------
# 3. Typo Resilience
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_typo_resilience(sent_normal, sent_typo, sent_diff):
    """Compare deep (post-stack) signal representations of three sentences.

    Plots cosine similarity between the clean sentence and (a) its typo'd
    variant, (b) a semantically different sentence.

    Returns:
        A matplotlib Figure (bar chart).
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data

    def get_deep_signal(text):
        # Last-token hidden state after all blocks, projected into signal space.
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        x = _forward_through_layers(model, ids)
        x_norm = model.transformer.ln_f(x[0, -1, :])
        return x_norm @ W_basis.t()

    sig_normal = get_deep_signal(sent_normal)
    sig_typo = get_deep_signal(sent_typo)
    sig_diff = get_deep_signal(sent_diff)
    sim_typo = F.cosine_similarity(sig_normal.unsqueeze(0), sig_typo.unsqueeze(0)).item()
    sim_diff = F.cosine_similarity(sig_normal.unsqueeze(0), sig_diff.unsqueeze(0)).item()

    fig, ax = plt.subplots(figsize=(8, 5))
    categories = ['Self\n(baseline)', 'Normal vs Typo\n(same meaning)',
                  'Normal vs Different\n(different meaning)']
    values = [1.0, sim_typo, sim_diff]
    colors = ['#2ecc71', '#f39c12', '#e74c3c']
    bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', width=0.5)
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f'{val:.4f}',
                ha='center', fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.15)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title("reFlow Typo Resilience - Deep Signal Similarity", fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    return fig


# ---------------------------------------------------------------------------
# 4. Sparsity Profile
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_sparsity_profile(word_to_inspect):
    """Histogram per-word recipe sparsity and per-signal utilization.

    Args:
        word_to_inspect: optional word whose individual recipe is summarized.

    Returns:
        (fig, stats_text): a matplotlib Figure and a plain-text summary.
    """
    model, enc, device = get_model()
    W_v2s = _get_vocab_signals(model)
    W = W_v2s[:REAL_VOCAB]
    vocab_size, n_signals = W.shape

    # A signal counts as "active" for a word when |amplitude| > mean + 1 std.
    mean_val = W.abs().mean().item()
    std_val = W.abs().std().item()
    threshold = mean_val + std_val
    active_mask = W.abs() > threshold
    active_per_word = active_mask.sum(dim=1).cpu().numpy()
    active_per_signal = active_mask.sum(dim=0).cpu().numpy()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Histogram of active signals per word (bins centered on integers).
    int_bins = np.arange(active_per_word.min(), active_per_word.max() + 2) - 0.5
    axes[0].hist(active_per_word, bins=int_bins, color='teal', alpha=0.7, edgecolor='black')
    axes[0].axvline(x=np.mean(active_per_word), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_word):.1f}')
    axes[0].set_title("Per-Word Sparsity (# Active Signals)")
    axes[0].set_xlabel("Number of Active Signals")
    axes[0].set_ylabel("Frequency")
    axes[0].legend()

    # Signal utilization
    axes[1].bar(range(n_signals), active_per_signal, color='coral', alpha=0.7, width=1.0)
    axes[1].set_title("Signal Utilization (# words activating each signal)")
    axes[1].set_xlabel("Signal Index")
    axes[1].set_ylabel("# Words")
    axes[1].axhline(y=np.mean(active_per_signal), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_signal):.0f}')
    axes[1].legend()

    plt.suptitle("reFlow Sparsity Profile", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    # Per-word stats
    stats_text = f"Threshold: {threshold:.4f} (mean + std)\n"
    stats_text += f"Avg active signals per word: {np.mean(active_per_word):.1f} / {n_signals}\n"
    stats_text += f"Global activation rate: {active_mask.float().mean().item():.2%}\n"
    if word_to_inspect and word_to_inspect.strip():
        w = word_to_inspect.strip()
        tids = enc.encode(" " + w)
        if tids and tids[0] < REAL_VOCAB:
            word_recipe = W[tids[0]]
            word_active = (word_recipe.abs() > threshold).sum().item()
            top_sigs = torch.argsort(word_recipe.abs(), descending=True)[:10]
            stats_text += f"\n--- '{w}' ---\n"
            stats_text += f"Active signals: {word_active}\n"
            stats_text += f"Top 10 signal indices: {top_sigs.tolist()}\n"
            stats_text += f"Top 10 amplitudes: {[f'{word_recipe[s].item():.4f}' for s in top_sigs]}\n"
        else:
            stats_text += f"\n'{w}' not found in vocabulary.\n"
    return fig, stats_text
# ---------------------------------------------------------------------------
# 5. Layer Evolution
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_layer_evolution(prompt_text):
    """Track the next-token distribution and its entropy after each block.

    Args:
        prompt_text: prompt whose final-position prediction is traced.

    Returns:
        A matplotlib Figure (probability flow of top-6 tokens + entropy decay).
    """
    model, enc, device = get_model()
    # FIX: removed unused local `vocab_matrix` (logits come from
    # _get_logits_from_hidden, which fetches the matrix itself).
    n_layers = len(model.transformer.h)
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)

    layer_probs = []
    layer_entropies = []
    x = _embed(model, ids)
    freqs_cis = model.freqs_cis[:ids.size(1)]
    for block in model.transformer.h:
        x = block(x, freqs_cis)
        # Read out the would-be prediction at this depth via the final LN head.
        x_norm = model.transformer.ln_f(x[0, -1, :])
        probs = F.softmax(_get_logits_from_hidden(model, x_norm), dim=-1)
        layer_probs.append(probs.cpu().numpy())
        entropy = -torch.sum(probs * torch.log(probs + 1e-9)).item()
        layer_entropies.append(entropy)

    final_probs = layer_probs[-1][:REAL_VOCAB]
    top_idx = np.argsort(final_probs)[-6:]
    prob_flow = np.array([[p[i] for i in top_idx] for p in layer_probs])
    layers = np.arange(1, n_layers + 1)

    fig, (ax_prob, ax_ent) = plt.subplots(1, 2, figsize=(16, 5))
    colors_palette = sns.color_palette("husl", len(top_idx))
    for i, idx in enumerate(top_idx):
        # FIX: int() — np.argsort yields numpy ints; some tokenizers reject them.
        label = repr(enc.decode([int(idx)])).strip("'")
        ax_prob.plot(layers, prob_flow[:, i], label=label, lw=2.5, color=colors_palette[i])
    ax_prob.set_title(f"Probability Evolution: '{prompt_text}'", fontsize=11, fontweight='bold')
    ax_prob.set_xlabel("Layer")
    ax_prob.set_ylabel("Probability")
    ax_prob.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax_prob.legend(fontsize=8, loc='upper left')
    ax_prob.grid(True, alpha=0.3)

    ax_ent.plot(layers, layer_entropies, color='#FF6B35', lw=2.5, marker='o', markersize=3)
    ax_ent.set_title(f"Entropy Decay: '{prompt_text}'", fontsize=11, fontweight='bold')
    ax_ent.set_xlabel("Layer")
    ax_ent.set_ylabel("Entropy (nats)")
    ax_ent.grid(True, alpha=0.3)

    predicted = enc.decode([int(np.argmax(final_probs))])
    plt.suptitle(f"reFlow Layer Evolution | Prediction: '{predicted}' (p={final_probs.max():.2%})",
                 fontsize=13, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig


# ---------------------------------------------------------------------------
# 6. Causal Ablation
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_causal_ablation(prompt_text):
    """Zero the top-contributing signals for the prediction and watch it collapse.

    Args:
        prompt_text: prompt whose final prediction is ablated.

    Returns:
        A matplotlib Figure (log-log decay curve + text summary).
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    ablation_steps = [1, 2, 4, 8, 16, 32, 64, 128]

    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    sig_acts = x_norm @ W_basis.t()
    logits_base = sig_acts @ W_v2s[:REAL_VOCAB].t()
    probs_base = F.softmax(logits_base, dim=-1)
    pred_id = torch.argmax(probs_base).item()
    pred_word = enc.decode([pred_id])
    pred_prob = probs_base[pred_id].item()

    # Per-signal contribution to the predicted token's logit, most important first.
    contribs = sig_acts * W_v2s[pred_id]
    sorted_sig_ids = torch.argsort(contribs, descending=True)

    steps, probs_list, new_preds = [], [], []
    for n_ablate in ablation_steps:
        if n_ablate > len(sorted_sig_ids):
            break
        ablated = sig_acts.clone()
        ablated[sorted_sig_ids[:n_ablate]] = 0.0
        logits_abl = ablated @ W_v2s[:REAL_VOCAB].t()
        probs_abl = F.softmax(logits_abl, dim=-1)
        new_pred_id = torch.argmax(probs_abl).item()
        steps.append(n_ablate)
        probs_list.append(probs_abl[pred_id].item())
        new_preds.append(enc.decode([new_pred_id]))

    # Codebook for top signal: which words rely on it most.
    top_sig = sorted_sig_ids[0].item()
    col = W_v2s[:REAL_VOCAB, top_sig]
    top_vals, top_ids = torch.topk(col, 8)
    cb_words = []
    for tid in top_ids:
        try:
            cb_words.append(enc.decode([tid.item()]).strip())
        except Exception:
            cb_words.append(f"[{tid.item()}]")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    # Clamp to 1e-8 so the log axis never receives zero.
    ax1.plot(steps, [max(p, 1e-8) for p in probs_list], 'o-', color='#e74c3c', lw=2.5, markersize=6)
    ax1.axhline(y=pred_prob, color='blue', linestyle='--', alpha=0.5,
                label=f"Baseline: {pred_prob:.1%}")
    ax1.set_title(f"'{prompt_text}'\nPrediction: '{pred_word}'", fontsize=10, fontweight='bold')
    ax1.set_xlabel("# Signals Ablated")
    ax1.set_ylabel("P(original prediction)")
    ax1.set_yscale('log')
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=2))
    ax1.set_xscale('log', base=2)
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

    # Text summary
    ax2.axis('off')
    summary = f"Baseline: '{pred_word}' (p={pred_prob:.2%})\n"
    summary += f"Key Signal: #{top_sig}\n"
    summary += f"Codebook: {', '.join(cb_words[:6])}\n\n"
    summary += "Ablation Results:\n" + "-"*40 + "\n"
    for s, p, nw in zip(steps, probs_list, new_preds):
        summary += f" {s:3d} signals removed -> p={p:.2%}, pred='{nw}'\n"
    ax2.text(0.05, 0.95, summary, transform=ax2.transAxes, fontsize=10,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    plt.suptitle("reFlow Causal Ablation", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    return fig


# ---------------------------------------------------------------------------
# 7. Concept Inception
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_concept_inception(prompt_text, target_word, alpha_max):
    """Steer the final signal activations toward a target word's recipe.

    Adds alpha * recipe(target_word) to the deep signal vector and binary-
    searches the critical alpha where the prediction flips to the target.

    Returns:
        (fig, info) on success, or (None, message) for an invalid target —
        the same failure convention as exp_task_crystallization.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)

    # FIX: the old `enc.encode(" " + target_word)[0]` raised IndexError on an
    # empty target, and an out-of-vocab first token indexed past the
    # REAL_VOCAB-sized probability vector below.
    target_clean = target_word.strip() if target_word else ""
    target_tokens = enc.encode(" " + target_clean) if target_clean else []
    if not target_tokens or target_tokens[0] >= REAL_VOCAB:
        return None, f"Target word '{target_word}' not found in vocabulary."
    tid = target_tokens[0]
    target_recipe = W_v2s[tid]

    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    base_sig = x_norm @ W_basis.t()
    logits_base = base_sig @ W_v2s[:REAL_VOCAB].t()
    probs_base = F.softmax(logits_base, dim=-1)
    orig_word = enc.decode([torch.argmax(probs_base).item()])
    orig_prob = probs_base[tid].item()

    # Binary search for the critical alpha — only meaningful when the maximum
    # strength actually flips the prediction.
    lo, hi = 0.0, float(alpha_max)
    critical_alpha = None
    probs_hi = F.softmax((base_sig + hi * target_recipe) @ W_v2s[:REAL_VOCAB].t(), dim=-1)
    if torch.argmax(probs_hi).item() == tid:
        for _ in range(20):
            mid = (lo + hi) / 2
            probs_mid = F.softmax((base_sig + mid * target_recipe) @ W_v2s[:REAL_VOCAB].t(), dim=-1)
            if torch.argmax(probs_mid).item() == tid:
                hi = mid
            else:
                lo = mid
        critical_alpha = hi

    # Probability curve: extend 50% beyond the flip point, capped at alpha_max.
    alpha_range = min(float(alpha_max), (critical_alpha or float(alpha_max)) * 1.5)
    alphas = np.linspace(0, alpha_range, 50)
    target_probs = []
    for a in alphas:
        probs = F.softmax((base_sig + a * target_recipe) @ W_v2s[:REAL_VOCAB].t(), dim=-1)
        target_probs.append(probs[tid].item())

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(alphas, target_probs, 'o-', color='#9b59b6', lw=2, markersize=3)
    if critical_alpha:
        ax.axvline(critical_alpha, color='red', linestyle='--',
                   label=f"Critical alpha={critical_alpha:.1f}")
    ax.axhline(y=orig_prob, color='gray', linestyle=':', alpha=0.5,
               label=f"Baseline P('{target_word}')={orig_prob:.1e}")
    ax.set_title(f"'{prompt_text}'\n'{orig_word}' -> '{target_word}'", fontsize=11, fontweight='bold')
    ax.set_xlabel("Steering Strength (alpha)")
    ax.set_ylabel(f"P('{target_word}')")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    info = f"Original prediction: '{orig_word}'\n"
    info += f"Target: '{target_word}'\n"
    if critical_alpha:
        info += f"Critical flip point: alpha = {critical_alpha:.2f}\n"
    else:
        info += f"Target not reached within alpha <= {alpha_max}\n"
    return fig, info


# ---------------------------------------------------------------------------
# 8. Text Generation
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_generate(prompt_text, num_samples, max_tokens, temperature, top_k, repetition_penalty):
    """Sample continuations of a prompt with temperature / top-k / repetition penalty.

    Args:
        prompt_text: seed text.
        num_samples: how many independent samples to draw.
        max_tokens: tokens generated per sample.
        temperature: softmax temperature (clamped away from 0).
        top_k: keep only the k most likely tokens (<= 0 disables).
        repetition_penalty: CTRL-style penalty on already-generated tokens.

    Returns:
        The generated text, samples separated by a divider.
    """
    model, enc, device = get_model()
    num_samples = int(num_samples)
    max_tokens = int(max_tokens)
    top_k = int(top_k) if top_k and top_k > 0 else None
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)
    if not prompt_text.strip():
        return "Please enter a prompt."

    # FIX: the old expand(num_samples, -1).contiguous() + per-sample slice was
    # needless copying — every sample starts from the same immutable prompt.
    prompt_ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)

    results = []
    for _ in range(num_samples):
        x = prompt_ids
        for _ in range(max_tokens):
            # Crop the context to the model's block size.
            x_cond = x if x.size(1) <= model.config.block_size else x[:, -model.config.block_size:]
            logits, _ = model(x_cond)
            logits = logits[:, -1, :]
            # Repetition penalty: dampen (or further suppress) seen tokens.
            if repetition_penalty != 1.0:
                for token_id in set(x[0].tolist()):
                    if logits[0, token_id] > 0:
                        logits[0, token_id] /= repetition_penalty
                    else:
                        logits[0, token_id] *= repetition_penalty
            # Temperature
            logits = logits / max(temperature, 1e-8)
            # Top-k filtering
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, idx_next), dim=1)
        results.append(enc.decode(x[0].tolist()))

    separator = "\n" + "=" * 60 + "\n"
    output = ""
    for i, text in enumerate(results):
        if num_samples > 1:
            output += f"--- Sample {i+1}/{num_samples} ---\n"
        output += text + "\n"
        if i < len(results) - 1:
            output += separator
    return output
# ---------------------------------------------------------------------------
# 9. Signal Basis Geometry
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_basis_geometry():
    """Visualize the geometry of the learned signal basis.

    Left panel: singular-value spectrum of the basis versus a same-shape
    random Gaussian matrix, annotated with effective ranks. Right panel:
    pairwise cosine similarity of the first 64 signal vectors.

    Returns:
        (fig, stats): a matplotlib Figure and a text summary.
    """
    model, enc, device = get_model()
    basis = model.transformer.wte.signal_basis.data.cpu().float()
    n_signals, n_embd = basis.shape

    def spectrum_and_rank(mat):
        # Effective rank = exp(entropy of the normalized singular values).
        sv = torch.linalg.svd(mat, full_matrices=False)[1].numpy()
        p = sv / sv.sum()
        return sv, np.exp(-np.sum(p * np.log(p + 1e-12)))

    S_np, effective_rank = spectrum_and_rank(basis)
    S_rand_np, effective_rank_rand = spectrum_and_rank(torch.randn_like(basis))

    show_n = min(64, n_signals)
    unit_rows = F.normalize(basis[:show_n], dim=1)
    cos_sim = (unit_rows @ unit_rows.t()).numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ax1.plot(S_np / S_np[0], 'b-', lw=2, label='Learned Basis')
    ax1.plot(S_rand_np / S_rand_np[0], 'r--', lw=1.5, label='Random Gaussian')
    ax1.set_title(f"Singular Value Spectrum\n(Eff. rank: learned={effective_rank:.0f}, random={effective_rank_rand:.0f})")
    ax1.set_xlabel("Component Index")
    ax1.set_ylabel("Normalized Singular Value")
    ax1.set_yscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    im = ax2.imshow(cos_sim, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax2.set_title(f"Cosine Similarity (first {show_n} signals)")
    ax2.set_xlabel("Signal Index")
    ax2.set_ylabel("Signal Index")
    plt.colorbar(im, ax=ax2, fraction=0.046)

    plt.suptitle("reFlow Signal Basis Geometry", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    stats = f"Signal basis shape: ({n_signals}, {n_embd})\n"
    stats += f"Effective rank (learned): {effective_rank:.1f} / {min(n_signals, n_embd)}\n"
    stats += f"Effective rank (random): {effective_rank_rand:.1f} / {min(n_signals, n_embd)}\n"
    return fig, stats
# ---------------------------------------------------------------------------
# 10. Recipe Neighbors (Nearest Neighbor Lookup)
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_recipe_neighbors(query_word, top_n):
    """List the top-N nearest vocabulary recipes (by cosine) for each query word.

    Args:
        query_word: one word or a comma-separated list of words.
        top_n: number of neighbors per word.

    Returns:
        A plain-text report, one section per query word.
    """
    model, enc, device = get_model()
    unit_recipes = F.normalize(_get_vocab_signals(model)[:REAL_VOCAB], dim=1)
    k = int(top_n)

    queries = [w.strip() for w in query_word.split(",") if w.strip()]
    if not queries:
        return "Please enter at least one word."

    sections = []
    for q in queries:
        token_ids = enc.encode(" " + q)
        if not token_ids or token_ids[0] >= REAL_VOCAB:
            sections.append(f"'{q}' not found in vocabulary.\n\n")
            continue
        qid = token_ids[0]
        scores = unit_recipes[qid] @ unit_recipes.t()
        scores[qid] = -1  # never report the query word as its own neighbor
        vals, idxs = torch.topk(scores, k)

        lines = [f"Nearest neighbors for '{q}':\n" + "-" * 40 + "\n"]
        for rank, (v, nid) in enumerate(zip(vals, idxs)):
            try:
                neighbor = enc.decode([nid.item()]).strip()
            except Exception:
                neighbor = f"[{nid.item()}]"
            lines.append(f" #{rank+1:2d} {neighbor:<20s} cos={v.item():.4f}\n")
        sections.append("".join(lines) + "\n")
    return "".join(sections)
# ---------------------------------------------------------------------------
# 11. Task Crystallization
# ---------------------------------------------------------------------------

@torch.inference_mode()
def exp_task_crystallization(prompt_text, target_word, max_alpha, start_layer):
    """Find the layer depth past which steering toward `target_word` stops working.

    First searches for a steering strength that flips the prediction when the
    intervention begins at `start_layer`, then re-runs that strength starting
    at every layer; the first layer where the flip fails is reported as the
    "crystallization boundary".

    Returns:
        (fig, info) on success, or (None, message) when no strength up to
        `max_alpha` flips the prediction at `start_layer`.
    """
    model, enc, device = get_model()
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    n_layers = len(model.transformer.h)
    start_layer = int(start_layer)
    max_alpha = float(max_alpha)
    # NOTE(review): raises IndexError if target_word is empty; assumes the
    # first token id is a real vocabulary entry — TODO confirm upstream UI
    # guarantees a non-empty in-vocab word.
    target_tid = enc.encode(" " + target_word.strip())[0]
    ids = torch.tensor(enc.encode(prompt_text), device=device).unsqueeze(0)

    # Get baseline prediction
    x = _forward_through_layers(model, ids)
    x_norm = model.transformer.ln_f(x[0, -1, :])
    logits_base = _get_logits_from_hidden(model, x_norm)
    base_pred_id = torch.argmax(logits_base).item()
    base_pred = enc.decode([base_pred_id])

    # Find working alpha
    def continuous_steer(alpha, intercept_layer):
        # Steering direction: target recipe minus baseline-prediction recipe,
        # mapped back into embedding space through the signal basis.
        steer_vec = W_v2s[target_tid] - W_v2s[base_pred_id]
        x = _embed(model, ids)
        if intercept_layer == 0:
            # Intercepting at layer 0 also perturbs the raw embedding.
            x[:, -1, :] += (alpha * steer_vec) @ W_basis
        freqs_cis = model.freqs_cis[:ids.size(1)]
        for i, block in enumerate(model.transformer.h):
            x = block(x, freqs_cis)
            if i + 1 >= intercept_layer:
                # "Continuous" steering: re-inject after every block at or
                # beyond the interception depth.
                x[:, -1, :] += (alpha * steer_vec) @ W_basis
        x_norm = model.transformer.ln_f(x[0, -1, :])
        logits = _get_logits_from_hidden(model, x_norm)
        probs = F.softmax(logits, dim=-1)
        pred_id = torch.argmax(logits).item()
        return probs[target_tid].item(), enc.decode([pred_id]).strip()

    # Find minimum alpha that works at start_layer (coarse grid, step 2.0).
    working_alpha = None
    for a in np.arange(2.0, max_alpha, 2.0):
        _, pred = continuous_steer(a, start_layer)
        if pred.strip() == target_word.strip():
            working_alpha = a * 1.2  # 20% headroom over first working strength
            break
    if working_alpha is None:
        return None, f"Cannot steer to '{target_word}' within alpha <= {max_alpha}"

    # Scan across layers
    layer_probs = []
    c_layer = n_layers  # sentinel: n_layers means steering worked everywhere
    for L in range(n_layers):
        p_target, pred = continuous_steer(working_alpha, L)
        layer_probs.append(p_target)
        if pred.strip() != target_word.strip() and c_layer == n_layers:
            c_layer = L  # first layer where steering fails

    # Plot
    fig, ax = plt.subplots(figsize=(10, 6))
    layers_x = np.arange(n_layers)
    ax.plot(layers_x, layer_probs, 'o-', color='#9b59b6', lw=2.5, markersize=4)
    if c_layer < n_layers:
        ax.scatter(c_layer, layer_probs[c_layer], color='red', s=150, marker='X',
                   edgecolors='black', zorder=5)
        ax.axvline(c_layer, color='red', linestyle='--', alpha=0.5,
                   label=f'Crystallization boundary: Layer {c_layer}')
    ax.set_title(f"Task Crystallization: '{prompt_text}' → '{target_word}'",
                 fontsize=11, fontweight='bold')
    ax.set_xlabel("Intervention Start Layer")
    ax.set_ylabel(f"P('{target_word}')")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    info = f"Base prediction: '{base_pred}'\n"
    info += f"Target: '{target_word}'\n"
    info += f"Working alpha: {working_alpha:.1f}\n"
    info += f"Crystallization boundary: Layer {c_layer}\n"
    return fig, info