| import sys
|
| import os
|
| import torch
|
| import torch.nn.functional as F
|
| import tiktoken
|
| import numpy as np
|
| import matplotlib.pyplot as plt
|
| import seaborn as sns
|
| import matplotlib.ticker as ticker
|
| from sklearn.decomposition import PCA
|
| from sklearn.metrics import silhouette_score
|
| from scipy.cluster.hierarchy import linkage, leaves_list
|
| from scipy.spatial.distance import squareform
|
|
|
# Optional dependency: adjustText nudges overlapping scatter-plot labels apart
# (used by the PCA plot in exp_4). When it is missing, fall back to a no-op.
try:
    from adjustText import adjust_text
except ImportError:
    print("[WARNING] 未安装 adjustText,PCA图表的文本可能会重叠。建议运行: pip install adjustText")
    # Must remain a lambda: exp_4 detects this fallback via
    # adjust_text.__name__ == '<lambda>' and skips the call.
    adjust_text = lambda texts, **kwargs: None

# Global figure theme applied to every plot this script produces.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("paper", font_scale=1.2)
|
|
|
def load_setup_and_model():
    """Load the experiment config, model definition, checkpoint and tokenizer.

    Expects exactly one command-line argument: the path of a Python config
    file. The config is executed and may define ``out_dir`` (default
    ``'out/reflow-1'``) and ``model_config`` (default ``'reflow'``). The model
    source ``models/<model_config>.py`` is executed into this module's globals
    so that the ``GPT`` / ``GPTConfig`` classes it defines become available.

    Returns:
        (model, enc, device, report_dir): the model in eval mode on ``device``,
        the GPT-2 tiktoken encoding, the device string (``'cuda'`` if
        available, else ``'cpu'``), and ``<out_dir>/audit_reports`` (created
        if needed).

    Exits the process with status 1, after printing an error to stderr, when
    the argument count is wrong or the config file, the model source file or
    ``<out_dir>/ckpt.pt`` does not exist.
    """
    def _fail(msg):
        # Errors go to stderr so they are not lost when stdout is redirected.
        print(msg, file=sys.stderr)
        sys.exit(1)

    if len(sys.argv) != 2:
        _fail("[ERROR] 必须指定配置文件!\n用法: python experiment.py <config_file>")

    config_file = sys.argv[1]
    if not os.path.exists(config_file):
        _fail(f"[ERROR] 找不到配置文件: {config_file}")

    print(f"\n[INFO] 正在加载配置: {config_file}")
    # NOTE(security): the config and model files are executed as Python code;
    # only run this script on files you trust.
    config_dict = {}
    with open(config_file, encoding='utf-8') as f:
        exec(f.read(), config_dict)

    out_dir = config_dict.get('out_dir', 'out/reflow-1')
    model_config = config_dict.get('model_config', 'reflow')

    model_file = os.path.join("models", f"{model_config}.py")
    if not os.path.exists(model_file):
        _fail(f"[ERROR] 找不到模型定义文件: {model_file}")
    # Executed into globals() so GPT / GPTConfig are reachable below.
    with open(model_file, encoding='utf-8') as f:
        exec(f.read(), globals())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    report_dir = os.path.join(out_dir, 'audit_reports')
    os.makedirs(report_dir, exist_ok=True)

    print(f"[INFO] 正在从 {out_dir} 加载 reFlow 模型 (Device: {device})...")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(ckpt_path):
        _fail(f"[ERROR] 找不到权重文件: {ckpt_path}")

    # weights_only=False unpickles arbitrary objects; load trusted checkpoints only.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = globals()['GPT'](globals()['GPTConfig'](**checkpoint['model_args']))

    # Keys saved from a torch.compile()-wrapped model carry an '_orig_mod.'
    # prefix; strip it so they match the plain module's parameter names.
    prefix = '_orig_mod.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['model'].items()
    }
    model.load_state_dict(state_dict)
    model.eval().to(device)

    enc = tiktoken.get_encoding("gpt2")
    return model, enc, device, report_dir
|
|
|
| def _embed(model, ids):
|
| """兼容 wte() 返回值:reflow-topk 返回元组,其他返回单张量"""
|
| result = model.transformer.wte(ids)
|
| return result[0] if isinstance(result, tuple) else result
|
|
|
| def _get_vocab_signals(model):
|
| """获取有效 vocab→signals 权重,topk 版本会应用稀疏化"""
|
| wte = model.transformer.wte
|
| if hasattr(wte, '_apply_sparsity'):
|
| return wte._apply_sparsity(wte.vocab_to_signals.weight.data)
|
| return wte.vocab_to_signals.weight.data
|
|
|
def _forward_through_layers(model, ids):
    """Embed ``ids`` and run them through every transformer block.

    Gradients are disabled. The final ``ln_f`` normalisation is NOT applied:
    the raw output of the last block is returned.
    """
    with torch.no_grad():
        seq_len = ids.size(1)
        hidden = _embed(model, ids)
        # Positional table trimmed to the current sequence length.
        positions = model.freqs_cis[:seq_len]
        for layer in model.transformer.h:
            hidden = layer(hidden, positions)
    return hidden
|
|
|
| def _get_logits_from_hidden(model, x_norm):
|
| """从 layer-normed 隐藏状态计算 logits。"""
|
| vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix()
|
| return F.linear(x_norm, vocab_matrix)
|
|
|
|
|
def exp_1_recipe_atlas(model, enc, device, report_dir):
    """Experiment 1: map the structure of the vocab→signals "recipe" space.

    Each row of the (optionally sparsified) vocab→signals matrix is treated as
    that token's recipe. The experiment
      * prints the 20 most cosine-similar pairs within a fixed probe-word list,
      * prints the 5 nearest full-vocabulary neighbours of the first 20 probes,
      * summarises per-signal variance across the vocabulary (incl. its Gini),
      * saves ``recipe_atlas.png`` in ``report_dir`` with a hierarchically
        ordered similarity heatmap, the sorted variance spectrum and a
        nearest-neighbour table.

    Only probe words that encode (with a leading space) to a single GPT-2
    token are used, so every label really corresponds to its recipe row.
    ``device`` is unused; it is kept for the uniform experiment signature.
    """
    print("\n" + "="*60)
    print(" [实验 1] 配方空间图谱 (Recipe Atlas)")
    print("="*60)

    W_v2s = _get_vocab_signals(model)
    real_vocab = 50257  # GPT-2 vocab size; rows past it are presumably padding
    W = W_v2s[:real_vocab]

    probe_words = [
        "China", "France", "Japan", "Germany", "India", "Russia",
        "Paris", "London", "Tokyo", "Berlin", "Beijing", "Rome",
        "cat", "dog", "fish", "bird", "horse", "bear", "wolf",
        "red", "blue", "green", "black", "white", "yellow",
        "happy", "sad", "angry", "love", "fear", "hate", "joy",
        "one", "two", "three", "four", "five", "ten", "hundred",
        "run", "walk", "think", "eat", "write", "read", "speak",
        "big", "small", "hot", "cold", "fast", "slow", "good", "bad",
        "king", "queen", "man", "woman", "boy", "girl",
        "water", "fire", "earth", "light", "dark", "sun", "moon",
    ]
    probe_ids, probe_labels = [], []
    for w in probe_words:
        tids = enc.encode(" " + w)
        # Skip words split into several BPE pieces: the first piece's recipe
        # would otherwise be mislabelled as the whole word.
        if len(tids) == 1 and tids[0] < real_vocab:
            probe_ids.append(tids[0])
            probe_labels.append(w)

    probe_recipes = W[probe_ids]
    probe_normed = F.normalize(probe_recipes, dim=1)
    sim_matrix = (probe_normed @ probe_normed.t()).cpu().numpy()

    # Upper-triangle pairs only (i < j), so the diagonal never enters.
    n_probe = len(probe_labels)
    pairs = []
    for i in range(n_probe):
        for j in range(i + 1, n_probe):
            pairs.append((sim_matrix[i, j], probe_labels[i], probe_labels[j]))
    pairs.sort(reverse=True)

    print("\n 配方空间最近邻词对 (Top-20):")
    print(" " + "-"*50)
    for rank, (sim, w1, w2) in enumerate(pairs[:20]):
        print(f" #{rank+1:2d} | {w1:>10s} ↔ {w2:<10s} | cos={sim:.4f}")

    W_normed = F.normalize(W, dim=1)
    print("\n 全词表配方近邻 (每词 Top-5):")
    print(" " + "-"*60)
    nn_table = []
    for tid, label in zip(probe_ids[:20], probe_labels[:20]):
        sims = W_normed[tid] @ W_normed.t()
        sims[tid] = -1  # exclude the word itself
        top_vals, top_ids = torch.topk(sims, 5)
        neighbors = []
        for v, nid in zip(top_vals.tolist(), top_ids.tolist()):
            try:
                nw = enc.decode([nid]).strip()
                neighbors.append(f"{nw}({v:.3f})")
            except Exception:
                # Best-effort display: fall back to the raw token id.
                neighbors.append(f"[{nid}]({v:.3f})")
        nn_str = ", ".join(neighbors)
        print(f" {label:>10s} → {nn_str}")
        nn_table.append((label, neighbors))

    sig_var = W.var(dim=0).cpu().numpy()
    var_order = np.argsort(sig_var)
    top_var_idx = var_order[::-1][:20]
    bottom_var_idx = var_order[:20]

    print(f"\n 信号方差分析:")
    print(f" > 最高方差信号 (最具区分力): {top_var_idx[:10].tolist()}")
    print(f" > 最低方差信号 (近似常数): {bottom_var_idx[:10].tolist()}")
    print(f" > 方差 Gini 系数: {_gini(sig_var):.4f}")

    fig = plt.figure(figsize=(20, 14))
    gs = fig.add_gridspec(2, 2, height_ratios=[1.2, 1])

    ax1 = fig.add_subplot(gs[0, :])
    # Ward linkage is only defined for Euclidean distances, so cluster the
    # unit-normalised recipes directly: the Euclidean distance between unit
    # vectors is sqrt(2 - 2*cos), monotone in cosine distance. This also
    # avoids squareform()'s exact-symmetry check, which a floating-point
    # Gram matrix is not guaranteed to pass.
    Z = linkage(probe_normed.cpu().numpy(), method='ward', metric='euclidean')
    order = leaves_list(Z)
    sim_ordered = sim_matrix[np.ix_(order, order)]
    labels_ordered = [probe_labels[i] for i in order]
    im = ax1.imshow(sim_ordered, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax1.set_xticks(range(n_probe))
    ax1.set_xticklabels(labels_ordered, rotation=90, fontsize=7)
    ax1.set_yticks(range(n_probe))
    ax1.set_yticklabels(labels_ordered, fontsize=7)
    plt.colorbar(im, ax=ax1, fraction=0.02)
    ax1.set_title("Recipe Cosine Similarity (hierarchical clustering order)", fontsize=12, fontweight='bold')

    ax2 = fig.add_subplot(gs[1, 0])
    sorted_var = np.sort(sig_var)[::-1]
    ax2.bar(range(len(sorted_var)), sorted_var, color='steelblue', alpha=0.7, width=1.0)
    ax2.set_title("Signal Variance Across Vocabulary (sorted)", fontsize=11, fontweight='bold')
    ax2.set_xlabel("Signal (sorted by variance)")
    ax2.set_ylabel("Variance")
    ax2.axhline(y=np.mean(sig_var), color='red', linestyle='--', label=f'Mean: {np.mean(sig_var):.4f}')
    ax2.legend()

    ax3 = fig.add_subplot(gs[1, 1])
    ax3.axis('off')
    table_data = []
    for label, neighbors in nn_table[:15]:
        nn_short = ", ".join(n.split("(")[0] for n in neighbors[:4])
        table_data.append([label, nn_short])
    table = ax3.table(cellText=table_data, colLabels=["Word", "Top-4 Recipe Neighbors"],
                      loc='center', cellLoc='left')
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1.0, 1.5)
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(color='white', fontweight='bold')
        elif row % 2 == 0:
            cell.set_facecolor('#D9E2F3')
    ax3.set_title("Vocabulary Recipe Nearest Neighbors", fontsize=11, fontweight='bold', pad=15)

    plt.suptitle("reFlow Recipe Atlas — Signal Recipe Space Structure", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    save_path = os.path.join(report_dir, "recipe_atlas.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f"\n > 图表已保存: {save_path}")
|
|
|
|
|
| def _gini(arr):
|
| """计算 Gini 系数,衡量分布不均匀度。0=完全均匀,1=完全集中。"""
|
| arr = np.sort(np.abs(arr))
|
| n = len(arr)
|
| if n == 0 or np.sum(arr) == 0:
|
| return 0.0
|
| index = np.arange(1, n + 1)
|
| return (2 * np.sum(index * arr) / (n * np.sum(arr))) - (n + 1) / n
|
|
|
|
|
def exp_2_sparsity_profile(model, enc, device, report_dir):
    """Experiment 2: profile how sparsely the vocabulary uses the signals.

    The mode depends on whether the embedding module has ``_apply_sparsity``:
      * TopK mode: a signal is "active" for a word iff its (sparsified) weight
        is non-zero; the median active count per word is reported as k.
      * Dense mode: a signal is "active" iff |weight| > mean(|W|) + std(|W|).
    Saves ``sparsity_profile.png`` to ``report_dir``, then prints the per-word
    active-count histogram and the descending per-signal utilisation as
    plain-text blocks delimited by BEGIN/END markers (for TikZ/PGFPlots).
    ``enc`` and ``device`` are unused.
    """
    print("\n" + "="*60)
    print(" [实验 2] 信号稀疏性分析 (Sparsity Profile)")
    print("="*60)

    W_v2s = _get_vocab_signals(model)
    # Keep only the 50257 GPT-2 token rows.
    real_vocab = 50257
    W = W_v2s[:real_vocab]
    vocab_size, n_signals = W.shape

    is_topk = hasattr(model.transformer.wte, '_apply_sparsity')

    if is_topk:
        # After sparsification, "active" simply means a non-zero weight.
        nonzero_mask = W.abs() > 0
        active_per_word = nonzero_mask.sum(dim=1).float()
        k = int(active_per_word.median().item())
        print(f" > 检测到 TopK 稀疏模式,固定 k={k}")

        # Magnitudes of all surviving weights, and per-signal usage counts.
        nonzero_vals = W[nonzero_mask].abs().cpu().numpy()
        active_per_signal = nonzero_mask.sum(dim=0).cpu().numpy()

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
        ax1.hist(nonzero_vals, bins=80, color='teal', alpha=0.7, edgecolor='black')
        ax1.set_title(f"Active Signal Amplitude Distribution (k={k})")
        ax1.set_xlabel("Absolute Amplitude")
        ax1.set_ylabel("Frequency")

        ax2.bar(range(n_signals), active_per_signal, color='coral', alpha=0.7, width=1.0)
        ax2.set_title("Signal Utilization (# words activating each signal)")
        ax2.set_xlabel("Signal Index")
        ax2.set_ylabel("# Words")
        ax2.axhline(y=np.mean(active_per_signal), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_signal):.0f}')
        ax2.legend()
    else:
        # Dense weights: count a signal as active when its magnitude exceeds
        # one standard deviation above the global mean magnitude.
        mean_val = W.abs().mean().item()
        std_val = W.abs().std().item()
        threshold = mean_val + std_val
        active_mask = W.abs() > threshold

        active_per_word = active_mask.sum(dim=1).cpu().numpy()
        active_per_signal = active_mask.sum(dim=0).cpu().numpy()

        print(f" > 活跃阈值: {threshold:.4f} (mean + std)")
        print(f" > 平均每词活跃信号: {np.mean(active_per_word):.1f} / {n_signals}")
        print(f" > 全局激活率: {active_mask.float().mean().item():.2%}")

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
        # Bin edges at half-integers so each integer count gets its own bar.
        int_bins = np.arange(active_per_word.min(), active_per_word.max() + 2) - 0.5
        ax1.hist(active_per_word, bins=int_bins, color='teal', alpha=0.7, edgecolor='black')
        ax1.axvline(x=np.mean(active_per_word), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_word):.1f}')
        ax1.set_title("Per-Word Sparsity (# Active Signals)")
        ax1.set_xlabel("Number of Active Signals")
        ax1.set_ylabel("Frequency")
        ax1.legend()

        ax2.bar(range(n_signals), active_per_signal, color='coral', alpha=0.7, width=1.0)
        ax2.set_title("Signal Utilization (# words activating each signal)")
        ax2.set_xlabel("Signal Index")
        ax2.set_ylabel("# Words")
        ax2.axhline(y=np.mean(active_per_signal), color='red', linestyle='--',
                    label=f'Mean: {np.mean(active_per_signal):.0f}')
        ax2.legend()

    plt.suptitle("reFlow Sparsity Profile", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "sparsity_profile.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")

    # Plain-text data export for re-plotting in TikZ/PGFPlots.
    print("\n" + "="*60)
    print(" [论文数据导出] 用于 TikZ/PGFPlots 绘图")
    print("="*60)

    # active_per_word is a torch tensor in TopK mode and already numpy otherwise.
    if is_topk:
        active_per_word_np = active_per_word.cpu().numpy()
    else:
        active_per_word_np = active_per_word

    # One bin per integer count: edges min, min+1, ..., max+1 (numpy closes
    # the last bin, so max itself is included). Empty bins are not printed.
    hist_min = int(active_per_word_np.min())
    hist_max = int(active_per_word_np.max())
    hist_bins = np.arange(hist_min, hist_max + 2)
    hist_counts, hist_edges = np.histogram(active_per_word_np, bins=hist_bins)
    print(f"\n [直方图] 每词活跃信号数分布 (bin_start, count):")
    print(f" mean={np.mean(active_per_word_np):.1f}, min={hist_min}, max={hist_max}")
    print(" ---BEGIN_HISTOGRAM_DATA---")
    for i in range(len(hist_counts)):
        if hist_counts[i] > 0:
            print(f" {int(hist_edges[i])} {hist_counts[i]}")
    print(" ---END_HISTOGRAM_DATA---")

    # Per-signal utilisation sorted descending; printed as (rank, count) rows.
    sorted_utilization = np.sort(active_per_signal)[::-1]
    print(f"\n [柱状图] 信号利用率 (按降序排列, signal_rank, n_words):")
    print(f" mean={np.mean(active_per_signal):.0f}, min={np.min(active_per_signal)}, max={np.max(active_per_signal)}")
    print(" ---BEGIN_UTILIZATION_DATA---")
    for i, val in enumerate(sorted_utilization):
        print(f" {i} {val}")
    print(" ---END_UTILIZATION_DATA---")
|
|
|
|
|
def _effective_rank(singular_values):
    """Entropy effective rank exp(-sum p*log p), with p = s / sum(s).

    The small epsilon inside the log keeps zero singular values finite.
    """
    p = singular_values / singular_values.sum()
    return np.exp(-np.sum(p * np.log(p + 1e-12)))


def exp_3_basis_geometry(model, enc, device, report_dir, seed=0):
    """Experiment 3: geometry of the learned signal basis.

    Compares the singular-value spectrum and entropy effective rank of
    ``model.transformer.wte.signal_basis`` against a Gaussian random matrix
    of the same shape, and plots the pairwise cosine similarity of the first
    (up to) 64 basis rows. Saves ``basis_geometry.png`` to ``report_dir``.

    Args:
        seed: seed of the random-baseline matrix, so the reported random
            effective rank and the saved figure are reproducible (default 0).
    ``enc`` and ``device`` are unused (all work is done on CPU).
    """
    print("\n" + "="*60)
    print(" [实验 3] 信号基底几何结构 (Signal Basis Geometry)")
    print("="*60)

    W_basis = model.transformer.wte.signal_basis.data.cpu().float()
    n_signals, n_embd = W_basis.shape

    # Only singular values are needed, so skip computing U and V.
    S_np = torch.linalg.svdvals(W_basis).numpy()
    effective_rank = _effective_rank(S_np)

    # Seeded baseline: identical numbers on every run.
    gen = torch.Generator().manual_seed(seed)
    random_mat = torch.randn(W_basis.shape, generator=gen, dtype=W_basis.dtype)
    S_rand_np = torch.linalg.svdvals(random_mat).numpy()
    effective_rank_rand = _effective_rank(S_rand_np)

    print(f" > Signal basis shape: ({n_signals}, {n_embd})")
    print(f" > Effective rank (learned): {effective_rank:.1f} / {min(n_signals, n_embd)}")
    print(f" > Effective rank (random): {effective_rank_rand:.1f} / {min(n_signals, n_embd)}")

    show_n = min(64, n_signals)
    W_show = W_basis[:show_n]
    W_normed = F.normalize(W_show, dim=1)
    cos_sim = (W_normed @ W_normed.t()).numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Both spectra are scaled by their largest singular value for comparison.
    ax1.plot(S_np / S_np[0], 'b-', lw=2, label='Learned Basis')
    ax1.plot(S_rand_np / S_rand_np[0], 'r--', lw=1.5, label='Random Gaussian')
    ax1.set_title(f"Singular Value Spectrum\n(Eff. rank: learned={effective_rank:.0f}, random={effective_rank_rand:.0f})")
    ax1.set_xlabel("Component Index")
    ax1.set_ylabel("Normalized Singular Value")
    ax1.set_yscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    im = ax2.imshow(cos_sim, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax2.set_title(f"Cosine Similarity (first {show_n} signals)")
    ax2.set_xlabel("Signal Index")
    ax2.set_ylabel("Signal Index")
    plt.colorbar(im, ax=ax2, fraction=0.046)

    plt.suptitle("reFlow Signal Basis Geometry", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "basis_geometry.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_4_semantic_galaxy(model, enc, device, report_dir):
    """Experiment 4: 2-D PCA "galaxy" of the recipes of hand-picked word groups.

    The vocab→signals recipes of probe words from five semantic categories
    are projected onto their first two principal components and scattered,
    coloured by category. The silhouette score of the category labels in the
    full recipe space is printed and shown in the title. Only words that
    encode (with a leading space) to a single GPT-2 token are used. Saves
    ``semantic_galaxy.png`` to ``report_dir``. ``device`` is unused.
    """
    print("\n" + "="*60)
    print(" [实验 4] 语义星空图 PCA (Semantic Galaxy)")
    print("="*60)

    W_v2s = _get_vocab_signals(model).cpu().numpy()
    real_vocab = 50257
    clusters = {
        "Countries": ["China", "France", "Germany", "Japan", "India", "Russia"],
        "Animals": ["cat", "dog", "fish", "bird", "horse", "bear"],
        "Numbers": ["one", "two", "three", "four", "five", "ten"],
        "Colors": ["red", "blue", "green", "black", "white", "yellow"],
        "Emotions": ["happy", "sad", "angry", "love", "fear", "hate"],
    }
    cluster_index = {name: i for i, name in enumerate(clusters)}

    recipes, labels, words = [], [], []
    for cat, wl in clusters.items():
        for w in wl:
            tids = enc.encode(" " + w)
            # Skip multi-token words: their first BPE piece's recipe would be
            # mislabelled as the whole word.
            if len(tids) == 1 and tids[0] < real_vocab:
                recipes.append(W_v2s[tids[0]])
                labels.append(cat)
                words.append(w)

    recipes_arr = np.array(recipes)
    coords = PCA(n_components=2).fit_transform(recipes_arr)

    label_ids = [cluster_index[lab] for lab in labels]
    sil = silhouette_score(recipes_arr, label_ids)
    print(f" > Silhouette Score: {sil:.4f}")

    plt.figure(figsize=(12, 9))
    color_map = dict(zip(clusters.keys(), sns.color_palette("Set2", len(clusters))))

    texts = []
    for i, w in enumerate(words):
        plt.scatter(coords[i, 0], coords[i, 1], color=color_map[labels[i]],
                    s=150, alpha=0.7, edgecolors='white', linewidths=0.5)
        texts.append(plt.text(coords[i, 0], coords[i, 1], w, fontsize=11))

    # The no-op fallback used when adjustText is not installed is a lambda.
    if adjust_text.__name__ != '<lambda>':
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='gray'))

    handles = [plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=color_map[lab], markersize=12, label=lab)
               for lab in clusters]
    plt.legend(handles=handles, title="Clusters", fontsize=10)
    plt.title(f"reFlow Semantic Galaxy (PCA)\nSilhouette Score = {sil:.4f}",
              fontsize=14, fontweight='bold')
    plt.xlabel("PC1")
    plt.ylabel("PC2")

    save_path = os.path.join(report_dir, "semantic_galaxy.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_5_semantic_algebra(model, enc, device, report_dir):
    """Experiment 5: word-analogy arithmetic in recipe (vocab→signals) space.

    For each case ``sum(pos) - sum(neg) ≈ expected``, the recipes of the
    positive words are added, the negative ones subtracted, and every GPT-2
    token is ranked by cosine similarity to the result (the input words
    themselves are excluded). Prints the expected word's rank (searched up to
    rank 500) and the top hits, then saves ``semantic_algebra.png`` with a
    summary table. Scoring is done in float32 on the weights' device.
    """
    print("\n" + "="*60)
    print(" [实验 5] 语义代数运算 (Semantic Algebra)")
    print("="*60)

    W_v2s = _get_vocab_signals(model)
    # Score in float32 regardless of storage dtype, so the accumulated query
    # and the candidate rows always share a dtype for cosine_similarity.
    W_valid = W_v2s[:50257].float()

    test_cases = [
        (["Paris", "China"], ["France"], "Beijing"),
        (["king", "woman"], ["man"], "queen"),
        (["walked", "running"], ["walking"], "ran"),
    ]

    def token_id(word):
        # First BPE id of " " + word; warn when the word is split into pieces.
        tids = enc.encode(" " + word)
        if len(tids) != 1:
            print(f" [WARNING] '{word}' 被切分为 {len(tids)} 个 token,仅使用第一个")
        return tids[0]

    results = []
    for pos_words, neg_words, expected in test_cases:
        expr = " + ".join(pos_words)
        if neg_words:
            expr += " - " + " - ".join(neg_words)

        target_vec = torch.zeros(W_valid.size(1), device=W_valid.device)
        exclude_ids = set()
        for w in pos_words:
            tid = token_id(w)
            target_vec += W_valid[tid]
            exclude_ids.add(tid)
        for w in neg_words:
            tid = token_id(w)
            target_vec -= W_valid[tid]
            exclude_ids.add(tid)

        sims = F.cosine_similarity(target_vec.unsqueeze(0), W_valid)
        for tid in exclude_ids:
            sims[tid] = -1.0

        top_vals, top_ids = torch.topk(sims, 20)

        expected_tid = token_id(expected)
        expected_rank = -1
        hit_words = []
        for rank_i, (val, tid) in enumerate(zip(top_vals.tolist(), top_ids.tolist()), start=1):
            # Match on the id first so a decode failure cannot hide a hit.
            if tid == expected_tid:
                expected_rank = rank_i
            try:
                w = enc.decode([tid]).strip()
            except Exception:
                continue  # best-effort display only
            if len(w) >= 2:  # skip single-character / whitespace tokens
                hit_words.append((w, val))

        if expected_rank == -1:
            # Exact rank = 1 + number of candidates scoring strictly higher;
            # only ranks up to 500 are reported.
            rank = int((sims > sims[expected_tid]).sum().item()) + 1
            if rank <= 500:
                expected_rank = rank

        results.append((expr, expected, expected_rank, hit_words[:5]))

        print(f"\n {expr} → 期望: '{expected}'")
        if expected_rank > 0:
            marker = "HIT!" if expected_rank <= 10 else ""
            print(f" > '{expected}' 排名: #{expected_rank} {marker}")
        else:
            print(f" > '{expected}' 未在 top-500 中找到")
        print(f" > Top-5: {', '.join(f'{w}({s:.3f})' for w, s in hit_words[:5])}")

    fig, ax = plt.subplots(figsize=(14, 4 + len(results)))
    ax.axis('off')
    table_data = []
    for expr, expected, rank, hits in results:
        rank_str = f"#{rank}" if rank > 0 else "Not found"
        hit_str = ", ".join(w for w, _ in hits[:4])
        table_data.append([expr, expected, rank_str, hit_str])

    table = ax.table(
        cellText=table_data,
        colLabels=["Expression", "Expected", "Rank", "Top Hits"],
        loc='center', cellLoc='left'
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.0, 1.8)
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(color='white', fontweight='bold')
    ax.set_title("reFlow Semantic Algebra Results", fontsize=14, fontweight='bold', pad=20)

    save_path = os.path.join(report_dir, "semantic_algebra.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f"\n > 图表已保存: {save_path}")
|
|
|
|
|
def exp_6_typo_resilience(model, enc, device, report_dir):
    """Experiment 6: does the deep signal representation survive typos?

    Each sentence is mapped to its last-token, final-layer signal vector
    (``ln_f`` output projected onto the signal basis). The clean sentence is
    compared by cosine similarity with a misspelled copy and with an
    unrelated sentence; the similarities and their gap are printed and a bar
    chart (with the self-similarity 1.0 as reference) is saved as
    ``typo_resilience.png``.
    """
    divider = "=" * 60
    print("\n" + divider)
    print(" [实验 6] 拼写鲁棒性 (Typo Resilience)")
    print(divider)

    sent_normal, sent_typo, sent_diff = (
        "The scientist is very intelligent",
        "The scientsit is vary intellgent",
        "The dog runs in the park",
    )
    basis = model.transformer.wte.signal_basis.data

    def get_deep_signal(sentence):
        batch = torch.tensor(enc.encode(sentence), device=device).unsqueeze(0)
        with torch.no_grad():
            last_hidden = _forward_through_layers(model, batch)[0, -1, :]
            return model.transformer.ln_f(last_hidden) @ basis.t()

    sig_normal, sig_typo, sig_diff = (
        get_deep_signal(s) for s in (sent_normal, sent_typo, sent_diff)
    )

    def cosine(a, b):
        return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

    sim_typo = cosine(sig_normal, sig_typo)
    sim_diff = cosine(sig_normal, sig_diff)
    sim_self = 1.0  # reference bar: a vector compared with itself

    print(f" > 正常: '{sent_normal}'")
    print(f" > 拼错: '{sent_typo}'")
    print(f" > 无关: '{sent_diff}'")
    print(f"\n [正常 vs 拼错] 深层语义相似度: \033[93m{sim_typo:.4f}\033[0m")
    print(f" [正常 vs 无关] 深层语义相似度: {sim_diff:.4f}")
    print(f" > 鲁棒性指标 (差值): {sim_typo - sim_diff:.4f}")

    bar_specs = [
        ('Self\n(baseline)', sim_self, '#2ecc71'),
        ('Normal vs Typo\n(same meaning)', sim_typo, '#f39c12'),
        ('Normal vs Different\n(different meaning)', sim_diff, '#e74c3c'),
    ]
    categories = [name for name, _, _ in bar_specs]
    values = [value for _, value, _ in bar_specs]
    colors = [colour for _, _, colour in bar_specs]

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', width=0.5)
    for rect, value in zip(bars, values):
        ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + 0.01,
                f'{value:.4f}', ha='center', fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.15)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title("reFlow Typo Resilience — Deep Signal Similarity", fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)

    save_path = os.path.join(report_dir, "typo_resilience.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_7_layer_evolution(model, enc, device, report_dir):
    """Experiment 7: logit-lens view of how predictions form layer by layer.

    For each prompt, the last position's hidden state after every block is
    passed through ``ln_f`` and the vocab projection, then a softmax. The
    figure ``layer_evolution.png`` shows, per prompt, the per-layer
    probability of the final layer's top-6 tokens and the entropy (nats) of
    the full distribution per layer. The final-layer prediction is printed.
    """
    print("\n" + "="*60)
    print(" [实验 7] 层级概率演化 (Layer Probability Evolution)")
    print("="*60)

    prompts = [
        "The capital of France is",
        "The cat sat on the",
        "The sun is very",
    ]
    real_vocab = 50257
    n_layers = len(model.transformer.h)

    # Assumes the vocab projection depends only on the (fixed) model weights,
    # so it is built once without autograd rather than once per layer per prompt.
    with torch.no_grad():
        vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix()

    fig, axes = plt.subplots(len(prompts), 2, figsize=(18, 5 * len(prompts)),
                             gridspec_kw={'width_ratios': [1.3, 1]})
    if len(prompts) == 1:
        axes = axes[np.newaxis, :]

    for pi, text in enumerate(prompts):
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        layer_probs = []
        layer_entropies = []

        with torch.no_grad():
            x = _embed(model, ids)
            freqs_cis = model.freqs_cis[:ids.size(1)]
            for block in model.transformer.h:
                x = block(x, freqs_cis)
                # Read out the last position after this block.
                x_norm = model.transformer.ln_f(x[0, -1, :])
                probs = F.softmax(F.linear(x_norm, vocab_matrix), dim=-1)
                layer_probs.append(probs.cpu().numpy())
                entropy = -torch.sum(probs * torch.log(probs + 1e-9)).item()
                layer_entropies.append(entropy)

        final_probs = layer_probs[-1][:real_vocab]
        top_idx = np.argsort(final_probs)[-6:]  # ascending: most likely token last
        prob_flow = np.stack(layer_probs)[:, top_idx]  # (n_layers, 6)
        layers = np.arange(1, n_layers + 1)

        ax_prob = axes[pi, 0]
        colors = sns.color_palette("husl", len(top_idx))
        for i, idx in enumerate(top_idx):
            label = repr(enc.decode([int(idx)])).strip("'")
            ax_prob.plot(layers, prob_flow[:, i], label=label, lw=2.5, color=colors[i])
        ax_prob.set_title(f"Probability Evolution: '{text}'", fontsize=11, fontweight='bold')
        ax_prob.set_xlabel("Layer")
        ax_prob.set_ylabel("Probability")
        ax_prob.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
        ax_prob.legend(fontsize=8, loc='upper left')
        ax_prob.grid(True, alpha=0.3)

        ax_ent = axes[pi, 1]
        ax_ent.plot(layers, layer_entropies, color='#FF6B35', lw=2.5, marker='o', markersize=3)
        ax_ent.set_title(f"Entropy Decay: '{text}'", fontsize=11, fontweight='bold')
        ax_ent.set_xlabel("Layer")
        ax_ent.set_ylabel("Entropy (nats)")
        ax_ent.grid(True, alpha=0.3)

        predicted = enc.decode([int(np.argmax(final_probs))])
        print(f" > Prompt: '{text}' → Prediction: '{predicted}' (p={final_probs.max():.2%})")

    plt.suptitle("reFlow Layer Probability Evolution", fontsize=15, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    save_path = os.path.join(report_dir, "layer_evolution.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_8_signal_flow(model, enc, device, report_dir):
    """Experiment 8: track signal activations through depth and across tokens.

    For the prompt "The capital of France is", the hidden state after every
    block is normalised with ``ln_f`` and projected onto the signal basis.
    Left heatmap: the 15 signals whose last-token activation varies most
    across layers. Right heatmap: the 20 signals whose final-layer activation
    varies most across token positions. Saves ``signal_flow.png``.
    """
    separator = "=" * 60
    print("\n" + separator)
    print(" [实验 8] 信号流追踪 (Signal Flow Tracking)")
    print(separator)

    text = "The capital of France is"
    basis = model.transformer.wte.signal_basis.data
    ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
    n_layers = len(model.transformer.h)
    tokens = [repr(enc.decode([t])).strip("'") for t in ids[0].tolist()]

    # One (seq_len, n_signals) array per block.
    signals_by_layer = []
    with torch.no_grad():
        hidden = _embed(model, ids)
        positions = model.freqs_cis[:ids.size(1)]
        for layer in model.transformer.h:
            hidden = layer(hidden, positions)
            projected = model.transformer.ln_f(hidden[0]) @ basis.t()
            signals_by_layer.append(projected.cpu().numpy())

    # (n_layers, n_signals): the last token's signals after each block.
    sig_arr = np.array([layer_sigs[-1] for layer_sigs in signals_by_layer])
    # (seq_len, n_signals): every token's signals after the final block.
    final_layer_all_tokens = signals_by_layer[-1]

    top_layer_sig_idx = np.argsort(np.var(sig_arr, axis=0))[-15:][::-1]
    top_time_sig_idx = np.argsort(np.var(final_layer_all_tokens, axis=0))[-20:][::-1]

    fig, (ax_depth, ax_tokens) = plt.subplots(
        1, 2, figsize=(20, 8), gridspec_kw={'width_ratios': [1.2, 1]})

    sns.heatmap(sig_arr[:, top_layer_sig_idx].T, cmap='RdBu_r', center=0, ax=ax_depth,
                xticklabels=np.arange(1, n_layers + 1),
                yticklabels=[f"Sig {i}" for i in top_layer_sig_idx])
    ax_depth.set_title("Signal Flow Across Layers (last token)", fontsize=11, fontweight='bold')
    ax_depth.set_xlabel("Layer")
    ax_depth.set_ylabel("Signal (by layer variance)")

    sns.heatmap(final_layer_all_tokens[:, top_time_sig_idx].T, cmap='mako', ax=ax_tokens,
                xticklabels=tokens,
                yticklabels=[f"Sig {i}" for i in top_time_sig_idx])
    ax_tokens.set_title("Signal Activation Across Tokens (final layer)", fontsize=11, fontweight='bold')
    ax_tokens.set_xlabel("Token")
    ax_tokens.set_ylabel("Signal (by token variance)")
    plt.setp(ax_tokens.get_xticklabels(), rotation=45, ha='right')

    plt.suptitle(f"reFlow Signal Flow — '{text}'", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "signal_flow.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_9_causal_ablation(model, enc, device, report_dir):
    """Experiment 9: causal ablation of the signals driving each prediction.

    For each prompt, the final-layer last-token state is normalised with
    ``ln_f`` and projected onto the signal basis (``sig_acts``). Logits are
    then computed through the signal bottleneck as ``sig_acts @ W_v2s.T``
    (first 50257 rows). Signals are ranked by their contribution
    ``sig_acts * W_v2s[pred_id]`` to the baseline prediction's logit, and the
    top 1, 2, 4, ..., 128 of them are zeroed to see how the prediction's
    probability falls. The 8 words with the largest weight on the single
    most-contributing signal are listed as its "codebook". Prints all results
    and saves ``ablation_curve.png`` to ``report_dir``.
    """
    print("\n" + "="*60)
    print(" [实验 9] 因果消融曲线 (Causal Ablation Curve)")
    print("="*60)

    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    prompts = [
        "The capital of France is",
        "The cat sat on the",
        "The sun is very",
    ]
    # Numbers of top-contributing signals to zero out (log2-spaced).
    ablation_steps = [1, 2, 4, 8, 16, 32, 64, 128]

    all_results = []
    codebook_info = []

    for text in prompts:
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)

        with torch.no_grad():
            x = _forward_through_layers(model, ids)
            x_norm = model.transformer.ln_f(x[0, -1, :])

        # Signal activations of the last position.
        sig_acts = x_norm @ W_basis.t()

        # Baseline logits computed through the signal space.
        logits_base = sig_acts @ W_v2s[:50257].t()
        probs_base = F.softmax(logits_base, dim=-1)
        pred_id = torch.argmax(probs_base).item()
        pred_word = enc.decode([pred_id])
        pred_prob = probs_base[pred_id].item()

        # Per-signal contribution to the predicted token's logit, largest first.
        contribs = sig_acts * W_v2s[pred_id]
        sorted_sig_ids = torch.argsort(contribs, descending=True)

        result = {'text': text, 'pred': pred_word, 'base_prob': pred_prob,
                  'steps': [], 'probs': [], 'new_preds': []}
        for n_ablate in ablation_steps:
            # Stop once the step exceeds the number of signals.
            if n_ablate > len(sorted_sig_ids):
                break
            ablated = sig_acts.clone()
            ablated[sorted_sig_ids[:n_ablate]] = 0.0
            logits_abl = ablated @ W_v2s[:50257].t()
            probs_abl = F.softmax(logits_abl, dim=-1)
            new_pred_id = torch.argmax(probs_abl).item()
            result['steps'].append(n_ablate)
            result['probs'].append(probs_abl[pred_id].item())
            result['new_preds'].append(enc.decode([new_pred_id]))

        all_results.append(result)

        # "Codebook" of the top contributing signal: the 8 vocabulary words
        # with the largest weight on that signal.
        top_sig = sorted_sig_ids[0].item()
        col = W_v2s[:50257, top_sig]
        top_vals, top_ids = torch.topk(col, 8)
        cb_words = []
        for tid in top_ids:
            try:
                cb_words.append(enc.decode([tid.item()]).strip())
            except Exception:
                cb_words.append(f"[{tid.item()}]")
        codebook_info.append((text, top_sig, cb_words))

        print(f"\n Prompt: '{text}'")
        print(f" > 基线预测: '{pred_word}' (p={pred_prob:.2%})")
        print(f" > 关键信号 #{top_sig} codebook: {', '.join(cb_words[:6])}")
        for step, prob, new in zip(result['steps'], result['probs'], result['new_preds']):
            print(f" 消融 {step:3d} 信号 → p('{pred_word}')={prob:.2%}, 新预测='{new}'")

    n_prompts = len(all_results)

    # Grid: one column per prompt plus a narrower text column for the codebooks.
    fig, axes = plt.subplots(2, n_prompts + 1, figsize=(5.5 * (n_prompts + 1), 9),
                             gridspec_kw={'width_ratios': [1] * n_prompts + [0.8]})

    # Row 0: absolute probability of the original prediction (log-log axes;
    # values are floored at 1e-8 so the log scale stays finite).
    for i, res in enumerate(all_results):
        ax = axes[0][i]
        ax.plot(res['steps'], [max(p, 1e-8) for p in res['probs']],
                'o-', color='#e74c3c', lw=2.5, markersize=6)
        ax.axhline(y=res['base_prob'], color='blue', linestyle='--', alpha=0.5,
                   label=f"Baseline: {res['base_prob']:.1%}")
        ax.set_title(f"'{res['text']}'\nPrediction: '{res['pred']}'", fontsize=10, fontweight='bold')
        ax.set_xlabel("# Signals Ablated")
        ax.set_ylabel("P(original prediction)")
        ax.set_yscale('log')
        ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=2))
        ax.set_xscale('log', base=2)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Row 1: the same curve as a percentage of the baseline probability.
    for i, res in enumerate(all_results):
        ax = axes[1][i]
        retention = [p / res['base_prob'] * 100 for p in res['probs']]
        ax.plot(res['steps'], [max(r, 1e-4) for r in retention],
                's-', color='#2ecc71', lw=2.5, markersize=6)
        ax.axhline(y=100, color='blue', linestyle='--', alpha=0.5, label="Baseline: 100%")
        ax.set_title(f"Retention rate", fontsize=10)
        ax.set_xlabel("# Signals Ablated")
        ax.set_ylabel("% of baseline probability retained")
        ax.set_yscale('log')
        ax.set_xscale('log', base=2)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Last column: hide both axes, then draw the codebook text in the top one.
    for row in range(2):
        ax_cb = axes[row][-1]
        ax_cb.axis('off')
    ax_cb = axes[0][-1]
    cb_text = "Critical Signal Codebook\n" + "="*30
    for text, sig_id, words in codebook_info:
        short = text[:25] + "..." if len(text) > 25 else text
        cb_text += f"\n\n'{short}'\n Key Sig #{sig_id}:\n {', '.join(words[:6])}"
    ax_cb.text(0.05, 0.95, cb_text, transform=ax_cb.transAxes, fontsize=9,
               verticalalignment='top', fontfamily='monospace',
               bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    plt.suptitle("reFlow Causal Ablation Curve", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "ablation_curve.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f"\n > 图表已保存: {save_path}")
|
|
|
|
|
def exp_10_emotion_surgery(model, enc, device, report_dir):
    """Experiment 10: steer a negative review toward positive sentiment.

    A sentiment direction is built in signal space as the mean vocab-signal
    recipe of five positive words minus that of five negative words. During a
    forward pass over a fixed negative sentence, that direction (times
    ``alpha``) is projected back through the signal basis and added to the
    last-token hidden state after every block whose index is >= a chosen
    layer. The experiment records, per layer, the dot product of the
    last-token signal vector with the positive and negative recipes, and the
    greedy next token over a grid of (layer, alpha) settings. A two-panel
    figure is written to ``report_dir/emotion_surgery.png``.

    Args:
        model: reFlow model (reads ``transformer.wte.signal_basis``,
            ``transformer.h``, ``transformer.ln_f`` and ``freqs_cis``).
        enc: tokenizer providing ``encode``/``decode``.
        device: device on which the prompt token ids are created.
        report_dir: directory that receives the figure.
    """
    banner = "=" * 60
    print("\n" + banner)
    print(" [实验 10] 情绪手术 (Emotion Surgery)")
    print(banner)

    vocab_signals = _get_vocab_signals(model)
    basis = model.transformer.wte.signal_basis.data

    def mean_recipe(words):
        # Average the recipe rows of the first token of " <word>" for each word.
        rows = [vocab_signals[enc.encode(" " + w)[0]] for w in words]
        return torch.stack(rows).mean(dim=0)

    positive = mean_recipe(["excellent", "wonderful", "amazing", "great", "good"])
    negative = mean_recipe(["terrible", "awful", "horrible", "bad", "poor"])
    direction = positive - negative

    text = "The food was absolutely terrible and the service was "
    depth = len(model.transformer.h)

    # Roughly six evenly spaced layers, always including the last one.
    stride = max(1, depth // 6)
    scan_layers = list(range(0, depth, stride))
    if scan_layers[-1] != depth - 1:
        scan_layers.append(depth - 1)
    scan_alphas = [0.0, 1.0, 3.0, 5.0, 8.0, 12.0]

    def trace_emotion(intercept_layer=None, alpha=0.0):
        """Run ``text`` through the model, optionally injecting the direction.

        Returns ``(decoded greedy next token, positive dots, negative dots)``.
        The per-layer dots are recorded after each block but before that
        layer's injection, so layer ``intercept_layer`` itself is unaffected.
        """
        token_ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        pos_trace, neg_trace = [], []
        with torch.no_grad():
            hidden = _embed(model, token_ids)
            freqs_cis = model.freqs_cis[:token_ids.size(1)]
            for layer_idx, block in enumerate(model.transformer.h):
                hidden = block(hidden, freqs_cis)
                signals = model.transformer.ln_f(hidden[0, -1, :]) @ basis.t()
                pos_trace.append(torch.dot(signals, positive).item())
                neg_trace.append(torch.dot(signals, negative).item())
                steering_active = intercept_layer is not None and layer_idx >= intercept_layer
                if steering_active:
                    hidden[:, -1, :] += (direction * alpha) @ basis
            last_state = model.transformer.ln_f(hidden[0, -1, :])
            next_probs = F.softmax(_get_logits_from_hidden(model, last_state), dim=-1)
            predicted = enc.decode([torch.argmax(next_probs).item()])
        return predicted, pos_trace, neg_trace

    word_base, pos_base, neg_base = trace_emotion()
    print(f" > [基线] '{text}' → '{word_base}'")

    # Alpha 0 means "no intervention", so the baseline word is reused there.
    grid_results = {}
    for layer in scan_layers:
        for alpha in scan_alphas:
            if alpha != 0.0:
                word, _, _ = trace_emotion(intercept_layer=layer, alpha=alpha)
                grid_results[(layer, alpha)] = word
                if alpha == 5.0:
                    print(f" > [Layer {layer:2d}, α={alpha}] → '{word}'")
            else:
                grid_results[(layer, alpha)] = word_base

    fig, (ax_flow, ax_grid) = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'width_ratios': [1.2, 1]})

    layer_axis = range(depth)
    ax_flow.plot(layer_axis, neg_base, 'r--', lw=2, label='Negative (base)')
    ax_flow.plot(layer_axis, pos_base, 'b--', lw=2, label='Positive (base)')

    # The middle scanned layer is used as the showcase intervention point.
    surgery_layer = scan_layers[len(scan_layers) // 2]
    _, pos_steered, neg_steered = trace_emotion(intercept_layer=surgery_layer, alpha=5.0)
    ax_flow.plot(layer_axis, neg_steered, 'r', lw=2.5, label=f'Negative (L{surgery_layer}, α=5)')
    ax_flow.plot(layer_axis, pos_steered, 'b', lw=2.5, label=f'Positive (L{surgery_layer}, α=5)')
    ax_flow.axvline(surgery_layer, color='green', linestyle=':', lw=2, label=f'Surgery @ Layer {surgery_layer}')
    ax_flow.set_title("Emotion Signal Flow", fontsize=11, fontweight='bold')
    ax_flow.set_xlabel("Layer")
    ax_flow.set_ylabel("Dot Product with Emotion Vector")
    ax_flow.legend(fontsize=8)
    ax_flow.grid(True, alpha=0.3)

    ax_grid.axis('off')
    table_data = [
        [f"L{layer}"] + [grid_results.get((layer, a), "?") for a in scan_alphas]
        for layer in scan_layers
    ]
    col_labels = ["Layer"] + [f"α={a}" for a in scan_alphas]
    table = ax_grid.table(cellText=table_data, colLabels=col_labels, loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1.0, 1.6)
    # Row index 0 holds the column labels; style it as a header row.
    header_cells = [cell for (r, _c), cell in table.get_celld().items() if r == 0]
    for cell in header_cells:
        cell.set_facecolor('#4472C4')
        cell.set_text_props(color='white', fontweight='bold')
    ax_grid.set_title("Intervention Grid (predicted next word)", fontsize=11, fontweight='bold', pad=20)

    plt.suptitle(f"reFlow Emotion Surgery\n'{text}'", fontsize=13, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.92])
    save_path = os.path.join(report_dir, "emotion_surgery.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_11_concept_inception(model, enc, device, report_dir):
    """Experiment 11: measure how much signal-space steering forces a chosen next token.

    For each ``(prompt, target)`` pair:
    - the last-token hidden state after all layers is normalized with
      ``ln_f`` and projected onto the signal basis;
    - ``alpha`` times the target token's vocab-signal recipe is added;
    - the result is scored against every vocabulary recipe.

    A 20-step bisection over ``alpha`` in ``[0, 200]`` estimates the smallest
    strength at which the target becomes the greedy prediction (the "critical
    alpha"). The target probability is then sampled at 50 alphas and plotted
    to ``report_dir/concept_inception.png``.

    Args:
        model: reFlow model (reads ``transformer.wte.signal_basis`` and
            ``transformer.ln_f``).
        enc: tokenizer providing ``encode``/``decode``; its ``n_vocab`` is
            used when available.
        device: device on which the prompt token ids are created.
        report_dir: directory that receives the figure.
    """
    print("\n" + "="*60)
    print(" [实验 11] 概念注入 (Concept Inception)")
    print("="*60)

    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)

    # Score only rows that belong to real tokenizer ids; the recipe table may be
    # padded beyond the tokenizer vocabulary. 50257 (the GPT-2 vocabulary size
    # this code used to hard-code) is the fallback when the tokenizer does not
    # report its size.
    n_vocab = min(getattr(enc, 'n_vocab', 50257), W_v2s.size(0))
    vocab_recipes_t = W_v2s[:n_vocab].t()
    max_alpha = 200.0
    bisection_steps = 20

    def steered_probs(base_sig, recipe, alpha):
        """Next-token distribution after adding ``alpha * recipe`` in signal space."""
        return F.softmax((base_sig + alpha * recipe) @ vocab_recipes_t, dim=-1)

    test_cases = [
        ("The capital of France is", "London"),
        ("The cat sat on the", "moon"),
        ("The sun is very", "cold"),
    ]

    all_curves = []
    for text, target in test_cases:
        tid = enc.encode(" " + target)[0]
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)

        # Everything below is pure inference; no autograd graph is needed.
        with torch.no_grad():
            target_recipe = W_v2s[tid]
            x = _forward_through_layers(model, ids)
            x_norm = model.transformer.ln_f(x[0, -1, :])
            base_sig = x_norm @ W_basis.t()

            probs_base = F.softmax(base_sig @ vocab_recipes_t, dim=-1)
            orig_id = torch.argmax(probs_base).item()
            orig_word = enc.decode([orig_id])
            orig_prob = probs_base[tid].item()

            critical_alpha = None
            if orig_id == tid:
                # Already the greedy prediction: no steering is required.
                critical_alpha = 0.0
            elif torch.argmax(steered_probs(base_sig, target_recipe, max_alpha)).item() == tid:
                # Bisection assumes that once the target wins it keeps winning at
                # larger alpha; `hi` always holds an alpha where it wins.
                lo, hi = 0.0, max_alpha
                for _ in range(bisection_steps):
                    mid = (lo + hi) / 2
                    if torch.argmax(steered_probs(base_sig, target_recipe, mid)).item() == tid:
                        hi = mid
                    else:
                        lo = mid
                critical_alpha = hi

            # Zoom to 1.5x the critical strength when there is one; otherwise
            # (no flip, or no steering needed) show the full search range.
            plot_max = min(max_alpha, critical_alpha * 1.5) if critical_alpha else max_alpha
            alphas = np.linspace(0, plot_max, 50)
            # float(): keep numpy scalars out of torch arithmetic.
            target_probs = [
                steered_probs(base_sig, target_recipe, float(a))[tid].item() for a in alphas
            ]

        all_curves.append({
            'text': text, 'target': target, 'orig': orig_word,
            'critical_alpha': critical_alpha, 'alphas': alphas,
            'target_probs': target_probs, 'orig_prob': orig_prob
        })

        if critical_alpha is not None:
            print(f" > '{text}' → '{target}': 临界 α={critical_alpha:.1f} (原: '{orig_word}')")
        else:
            print(f" > '{text}' → '{target}': 在 α≤200 范围内未突破 (原: '{orig_word}')")

    # squeeze=False always yields a 2-D array, so a single case needs no special path.
    fig, axes = plt.subplots(1, len(all_curves), figsize=(6 * len(all_curves), 5), squeeze=False)
    for ax, curve in zip(axes[0], all_curves):
        ax.plot(curve['alphas'], curve['target_probs'], 'o-', color='#9b59b6',
                lw=2, markersize=3)
        if curve['critical_alpha'] is not None:
            ax.axvline(curve['critical_alpha'], color='red', linestyle='--',
                       label=f"Critical α={curve['critical_alpha']:.1f}")
        ax.axhline(y=curve['orig_prob'], color='gray', linestyle=':', alpha=0.5,
                   label=f"Baseline P('{curve['target']}')={curve['orig_prob']:.1e}")
        ax.set_title(f"'{curve['text']}'\n'{curve['orig']}' → '{curve['target']}'",
                     fontsize=10, fontweight='bold')
        ax.set_xlabel("Steering Strength (α)")
        ax.set_ylabel(f"P('{curve['target']}')")
        ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    plt.suptitle("reFlow Concept Inception — Steering Curves", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "concept_inception.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > 图表已保存: {save_path}")
|
|
|
|
|
def exp_12_genetic_hijack(model, enc, device, report_dir):
    """Experiment 12: temporarily edit the vocab-to-signal table and compare generations.

    A sentiment delta, ``alpha * (mean positive recipe - mean negative
    recipe)``, is built from the effective vocab-signal recipes. It is
    broadcast-added in place to ``transformer.wte.vocab_to_signals.weight``.
    This assumes the raw table's rows have the same width as the effective
    recipes; confirm against ``_get_vocab_signals``.

    A 15-token greedy continuation of a negative prompt is printed before and
    after the edit. The original weights are restored in a ``finally`` block,
    even if the edit or the generation raises.

    Args:
        model: reFlow model (reads ``config.block_size`` and
            ``transformer.wte.vocab_to_signals``).
        enc: tokenizer providing ``encode``/``decode``.
        device: device on which token ids are created.
        report_dir: unused; kept for the common experiment signature.
    """
    print("\n" + "="*60)
    print(" [实验 12] 基因库篡改 (Genetic Hijack)")
    print("="*60)

    W_v2s_eff = _get_vocab_signals(model)
    W_v2s_raw = model.transformer.wte.vocab_to_signals.weight.data
    pos_words = ["excellent", "perfect", "wonderful", "amazing"]
    neg_words = ["terrible", "bad", "disgusting", "awful"]
    pos_rec = torch.stack([W_v2s_eff[enc.encode(" " + w)[0]] for w in pos_words]).mean(dim=0)
    neg_rec = torch.stack([W_v2s_eff[enc.encode(" " + w)[0]] for w in neg_words]).mean(dim=0)

    def gen(prompt, max_tokens=15):
        """Greedy-decode ``max_tokens`` tokens after ``prompt``; return the full decoded text."""
        block_size = model.config.block_size
        x = torch.tensor(enc.encode(prompt), device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_tokens):
                # Slicing is a no-op when the context already fits in block_size.
                idx = x[:, -block_size:]
                logits, _ = model(idx)
                # argmax of the logits equals argmax of their softmax; keepdim
                # keeps a (batch, 1) shape for concatenation.
                next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                x = torch.cat((x, next_id), dim=1)
        return enc.decode(x[0].tolist())

    prompt = "The food was disgusting."

    text_control = gen(prompt)
    print(f" [对照组] 自然生成:")
    print(f" \033[90m{text_control}\033[0m")

    alpha = 1.5
    delta = alpha * pos_rec - alpha * neg_rec
    orig_W = W_v2s_raw.clone()
    print(f" * 注入积极基因, 抹除消极基因 (Alpha={alpha})...")

    # The in-place edit sits inside the try so the restore always covers it.
    try:
        W_v2s_raw.add_(delta)
        text_hijacked = gen(prompt)
        print(f" [干预组] 篡改后生成:")
        print(f" \033[92m{text_hijacked}\033[0m")
    finally:
        W_v2s_raw.copy_(orig_W)
        print(" * 基因库已恢复原状,防止污染后续实验。")

    print(f"\n > 实验完成。对照组与干预组的文本对比即为结果。")
|
|
|
def exp_13_task_crystallization_shift(model, enc, device, report_dir):
    """Experiment 13: compare the "crystallization boundary" across prompt types.

    A prompt is skipped when its greedy prediction already equals the target
    word. Otherwise the steering vector ``recipe(target) - recipe(baseline
    prediction)`` is projected through the signal basis and added to the
    last-token hidden state. The procedure per prompt is:

    1. Alpha search: find the smallest alpha in 2, 4, ..., 48 that flips the
       prediction when injection starts at the embedding. Prompts with no
       such alpha are skipped; otherwise the alpha is scaled by 1.2.
    2. Layer sweep: sweep the injection start layer L over
       ``range(n_layers)``.
    3. Boundary: the first L at which the prediction no longer becomes the
       target is the crystallization layer (``n_layers`` if it never fails).

    Results are grouped by prompt type, plotted to
    ``report_dir/task_crystallization_shift.png`` and summarized on stdout.

    Args:
        model: reFlow model (reads ``transformer.wte.signal_basis``,
            ``transformer.h``, ``transformer.ln_f`` and ``freqs_cis``).
        enc: tokenizer providing ``encode``/``decode``.
        device: device on which prompt token ids are created.
        report_dir: directory that receives the figure.
    """
    print("\n" + "="*60)
    print(" [实验 13] 任务类型与结晶边界偏移 (Context-Dependent Crystallization)")
    print("="*60)

    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    n_layers = len(model.transformer.h)

    task_groups = {
        "Shallow (Short Context)": [
            ("The capital of France is", "London"),
            ("The cat sat on the", "moon"),
            ("The sky is", "red"),
            ("Open the door with a", "car")
        ],
        "Deep (Long Context / Clauses)": [
            ("When the geography teacher asked the students, they answered that the capital of France is", "London"),
            ("After carefully reviewing all the evidence presented in court, the judge decided that the defendant was", "guilty"),
            ("When you look outside the window at the beautiful nature, the color of the clear sky is", "red"),
            ("I was locked out of my house yesterday, and to open the locked door, you need a", "car")
        ],
        "Code (Structured Logic)": [
            ("def add(a, b): return a +", "None"),
            ("x = 1 + 2\ny =", "None"),
            ("for i in range(10):\n print(", "None"),
            ("if x > 0:\n result =", "None")
        ]
    }
    # Per-group plot colour and short boxplot label, keyed by the names above.
    colors = {"Shallow (Short Context)": "#2ecc71", "Deep (Long Context / Clauses)": "#9b59b6", "Code (Structured Logic)": "#e67e22"}
    box_label_of = {
        "Shallow (Short Context)": "Shallow\n(Short)",
        "Deep (Long Context / Clauses)": "Deep\n(Long)",
        "Code (Structured Logic)": "Code\n(Structured)",
    }

    def continuous_steer(prompt, target_tid, base_tid, alpha, intercept_layer):
        """Run ``prompt`` with continuous steering and return ``(p_target, pred_str, pred_id)``.

        The steering is ``alpha * (W_v2s[target_tid] - W_v2s[base_tid])``,
        projected through the signal basis.
        - ``intercept_layer == 0`` injects into the embedding and after every block.
        - ``intercept_layer == L > 0`` injects after every block with index >= L - 1.
        - ``None`` disables injection.

        ``pred_str`` is the decoded greedy token with surrounding whitespace stripped.
        """
        ids = torch.tensor(enc.encode(prompt), device=device).unsqueeze(0)
        with torch.no_grad():
            steer_vec = W_v2s[target_tid] - W_v2s[base_tid]
            # The same hidden-space delta is added at every injection point.
            delta = (alpha * steer_vec) @ W_basis

            x = _embed(model, ids)
            if intercept_layer == 0:
                x[:, -1, :] += delta

            freqs_cis = model.freqs_cis[:ids.size(1)]
            for i, block in enumerate(model.transformer.h):
                x = block(x, freqs_cis)
                # Block i's output feeds layer i + 1, hence the `i + 1` comparison.
                if intercept_layer is not None and i + 1 >= intercept_layer:
                    x[:, -1, :] += delta

            x_norm = model.transformer.ln_f(x[0, -1, :])
            logits = _get_logits_from_hidden(model, x_norm)
            probs = F.softmax(logits, dim=-1)
            pred_id = torch.argmax(logits).item()
        return probs[target_tid].item(), enc.decode([pred_id]).strip(), pred_id

    # One result list per group, derived from task_groups so the names cannot drift.
    results = {group_name: [] for group_name in task_groups}

    print(" 开始执行层级连续干预扫描 (Continuous Intervention Sweep)...\n")

    for group_name, tasks in task_groups.items():
        print(f" [{group_name}]")
        for prompt, target in tasks:
            target_clean = target.strip()
            target_tid = enc.encode(" " + target)[0]

            # Zero-strength run gives the natural prediction and its token id.
            _, base_pred, base_tid = continuous_steer(prompt, target_tid, target_tid, 0.0, None)
            if base_pred == target_clean:
                print(f" [Skip] '{prompt[:20]}...' 自然预测已是 '{target_clean}'。")
                continue

            # Smallest alpha that flips the prediction when injected from layer 0.
            # .tolist() yields plain Python floats.
            working_alpha = next(
                (a for a in np.arange(2.0, 50.0, 2.0).tolist()
                 if continuous_steer(prompt, target_tid, base_tid, a, 0)[1] == target_clean),
                None,
            )
            if working_alpha is None:
                print(f" [Skip] '{prompt[:20]}...': Alpha在50内无法干预,跳过。")
                continue

            # A 20% margin so that early-layer injection clearly succeeds.
            final_alpha = working_alpha * 1.2

            layer_probs = []
            c_layer = n_layers  # n_layers means steering never failed
            for L in range(n_layers):
                p_target, pred, _ = continuous_steer(prompt, target_tid, base_tid, final_alpha, L)
                layer_probs.append(p_target)
                # Record the first start layer at which the injection stops working.
                if pred != target_clean and c_layer == n_layers:
                    c_layer = L

            results[group_name].append({
                'prompt': prompt,
                'target': target_clean,
                'alpha': final_alpha,
                'base_pred': base_pred,
                'c_layer': c_layer,
                'layer_probs': layer_probs
            })

            short_prompt = prompt[:35] + "..." if len(prompt) > 35 else prompt
            print(f" - '{short_prompt}' (原预测: '{base_pred}')")
            print(f" -> 持续注入 '{target_clean}' (α={final_alpha:.1f}) | 结晶失效边界: \033[96mLayer {c_layer}\033[0m")
        print()

    c_layers_by_group = {
        group_name: [res['c_layer'] for res in res_list]
        for group_name, res_list in results.items()
    }

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'width_ratios': [2, 1]})
    layers_x = np.arange(0, n_layers)

    for group_name, res_list in results.items():
        color = colors[group_name]
        for i, res in enumerate(res_list):
            # Only the first curve of each group gets a legend entry.
            label = group_name if i == 0 else "_nolegend_"
            ax1.plot(layers_x, res['layer_probs'], color=color, alpha=0.6, lw=2.5, label=label)
            c_idx = res['c_layer']
            if c_idx < n_layers:
                ax1.scatter(c_idx, res['layer_probs'][c_idx], color=color, s=120, marker='X', edgecolors='black', zorder=5)

    ax1.set_title("Target Concept Viability vs. Injection Delay", fontsize=12, fontweight='bold')
    ax1.set_xlabel("Intervention Start Layer (Later start = Context already crystallized)")
    ax1.set_ylabel("Final Probability of Injected Concept")
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    if any(results.values()):  # avoid matplotlib's "no artists with labels" warning
        ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    populated = [g for g in task_groups if c_layers_by_group[g]]
    box_data = [c_layers_by_group[g] for g in populated]
    box_labels = [box_label_of[g] for g in populated]
    box_colors_list = [colors[g] for g in populated]

    if len(box_data) >= 2:
        bplot = ax2.boxplot(box_data, patch_artist=True, widths=0.5)
        ax2.set_xticks(range(1, len(box_data) + 1))
        ax2.set_xticklabels(box_labels)
        for patch, c in zip(bplot['boxes'], box_colors_list):
            patch.set_facecolor(c)
            patch.set_alpha(0.6)
        # Local seeded RNG: reproducible jitter without consuming numpy's global state.
        jitter_rng = np.random.default_rng(0)
        for pos, (data, c) in enumerate(zip(box_data, box_colors_list), start=1):
            ax2.scatter(jitter_rng.normal(pos, 0.05, len(data)), data, color=c, alpha=0.9, s=50)

    ax2.set_title("Crystallization Boundary Distribution", fontsize=12, fontweight='bold')
    ax2.set_ylabel("Crystallization Layer (Point of No Return)")
    ax2.set_ylim(-1, n_layers + 2)
    ax2.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax2.grid(True, axis='y', alpha=0.3)

    plt.suptitle("reFlow Causal Audit: Context Type Affects Information Crystallization", fontsize=15, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    save_path = os.path.join(report_dir, "task_crystallization_shift.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()

    c_layers_shallow = c_layers_by_group["Shallow (Short Context)"]
    c_layers_deep = c_layers_by_group["Deep (Long Context / Clauses)"]
    c_layers_code = c_layers_by_group["Code (Structured Logic)"]

    print(" ================= 实验结论 =================")
    if c_layers_shallow:
        avg_shallow = np.mean(c_layers_shallow)
        print(f" > 短上下文 (浅层任务) 平均结晶边界: Layer {avg_shallow:.1f}")
    if c_layers_deep:
        avg_deep = np.mean(c_layers_deep)
        print(f" > 长上下文 (深层任务) 平均结晶边界: Layer {avg_deep:.1f}")
    if c_layers_code:
        avg_code = np.mean(c_layers_code)
        print(f" > 代码 (结构化逻辑) 平均结晶边界: Layer {avg_code:.1f}")
    if c_layers_shallow and c_layers_deep:
        print(f" > 短→长 边界延迟量: \033[93m{np.mean(c_layers_deep) - np.mean(c_layers_shallow):+.1f} Layers\033[0m")
    if c_layers_shallow and c_layers_code:
        print(f" > 短→代码 边界延迟量: \033[93m{np.mean(c_layers_code) - np.mean(c_layers_shallow):+.1f} Layers\033[0m")
    print(f" > 实验表明:不同任务类型的上下文复杂度影响模型内部表征的结晶边界,")
    print(f" 更复杂的上下文倾向于在更深层级保持内部表征的流动性。")
    print(f" > 图表已保存: {save_path}")
|
|
|
def main_menu():
    """Load the model once, then run experiments selected interactively.

    Input forms:
    - a space-separated list of experiment ids (e.g. ``1 3 5``);
    - ``all`` to run every experiment in order;
    - ``q`` or ``quit`` to exit.

    End of input (closed stdin / Ctrl+D) or Ctrl+C at the prompt also exits
    cleanly. A failing experiment prints its traceback, and the remaining
    selections still run. Any figures it left open are closed so they do not
    accumulate across runs.
    """
    import traceback

    model, enc, device, report_dir = load_setup_and_model()

    experiments = {
        '1': ("配方空间图谱 (Recipe Atlas)", exp_1_recipe_atlas),
        '2': ("信号稀疏性分析 (Sparsity Profile)", exp_2_sparsity_profile),
        '3': ("信号基底几何 (Basis Geometry)", exp_3_basis_geometry),
        '4': ("语义星空图 PCA (Semantic Galaxy)", exp_4_semantic_galaxy),
        '5': ("语义代数运算 (Semantic Algebra)", exp_5_semantic_algebra),
        '6': ("拼写鲁棒性 (Typo Resilience)", exp_6_typo_resilience),
        '7': ("层级概率演化 (Layer Evolution)", exp_7_layer_evolution),
        '8': ("信号流追踪 (Signal Flow)", exp_8_signal_flow),
        '9': ("因果消融曲线 (Causal Ablation)", exp_9_causal_ablation),
        '10': ("情绪手术 (Emotion Surgery)", exp_10_emotion_surgery),
        '11': ("概念注入 (Concept Inception)", exp_11_concept_inception),
        '12': ("基因库篡改 (Genetic Hijack)", exp_12_genetic_hijack),
        '13': ("任务结晶边界偏移 (Task Shift)", exp_13_task_crystallization_shift),
    }

    while True:
        print("\n" + "#"*60)
        print(" reFlow 可解释性实验套件 (Interpretability Suite)".center(56))
        print("#"*60)
        for k, (title, _func) in experiments.items():
            print(f" [{k.rjust(2)}] {title}")
        print(" [all] 运行所有实验")
        print(" [ q ] 退出系统")
        print("#"*60)

        try:
            choice = input("请输入要运行的实验编号 (空格分隔, 如 '1 3 5'): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl+C at the prompt: leave the menu without a traceback.
            print("\n退出系统。")
            break

        if choice in ('q', 'quit'):
            print("退出系统。")
            break

        selected_keys = list(experiments) if choice == 'all' else choice.split()

        for k in selected_keys:
            entry = experiments.get(k)
            if entry is None:
                print(f"[忽略] 无效的选项: {k}")
                continue
            try:
                entry[1](model, enc, device, report_dir)
            except Exception as e:
                print(f"\n [ERROR] 实验 {k} 运行失败: {e}")
                traceback.print_exc()
            finally:
                # An experiment that raised mid-plot never reaches its own plt.close().
                plt.close('all')

        if selected_keys:
            print(f"\n[INFO] 当前批次实验已完成。图表报告保存在: {report_dir}")


if __name__ == "__main__":
    main_menu()
|
|
|