# reFlow / experiment_en.py
# Uploaded by reuAC ("Upload 4 files", commit 1d9e33d, verified)
import sys
import os
import torch
import torch.nn.functional as F
import tiktoken
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as ticker
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import squareform
# Optional dependency: adjustText nudges overlapping scatter labels apart in the
# PCA chart. When it is unavailable, a do-nothing lambda stands in for it.
try:
    from adjustText import adjust_text
except ImportError:
    print("[WARNING] adjustText not installed. PCA chart labels may overlap. Run: pip install adjustText")
    adjust_text = lambda texts, **_ignored: None  # noqa: E731

# Shared look for every chart this script produces.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("paper", font_scale=1.2)
def load_setup_and_model():
    """Parse the CLI config, rebuild the reFlow model from its checkpoint, return handles.

    Usage: ``python experiment_en.py <config_file>``. The config file is executed as
    Python; ``out_dir`` (default ``'out/reflow-1'``) and ``model_config`` (default
    ``'reflow'``) are read from the resulting namespace. ``models/<model_config>.py`` is
    then executed into this module's globals, which must provide ``GPT`` and
    ``GPTConfig``.

    Returns:
        ``(model, enc, device, report_dir)``: the model in eval mode on ``device``, the
        GPT-2 tiktoken encoding, the device string (``'cuda'`` if available, else
        ``'cpu'``), and ``<out_dir>/audit_reports`` (created if missing).

    Exits with status 1 on a wrong argument count, a missing config file, or a missing
    ``<out_dir>/ckpt.pt``.

    NOTE(security): both ``exec`` calls and ``torch.load(..., weights_only=False)``
    (pickle) execute arbitrary code -- only run this on trusted configs/checkpoints.
    """
    def _abort(message):
        print(message)
        sys.exit(1)

    if len(sys.argv) != 2:
        _abort("[ERROR] Config file required!\nUsage: python experiment_en.py <config_file>")
    cfg_path = sys.argv[1]
    if not os.path.exists(cfg_path):
        _abort(f"[ERROR] Config file not found: {cfg_path}")

    print(f"\n[INFO] Loading config: {cfg_path}")
    cfg_namespace = {}
    with open(cfg_path, encoding='utf-8') as handle:
        exec(handle.read(), cfg_namespace)
    out_dir = cfg_namespace.get('out_dir', 'out/reflow-1')
    model_name = cfg_namespace.get('model_config', 'reflow')

    # The model definition is injected into this module's global namespace.
    with open(f"models/{model_name}.py", encoding='utf-8') as handle:
        exec(handle.read(), globals())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    report_dir = os.path.join(out_dir, 'audit_reports')
    os.makedirs(report_dir, exist_ok=True)

    print(f"[INFO] Loading reFlow model from {out_dir} (Device: {device})...")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(ckpt_path):
        _abort(f"[ERROR] Checkpoint not found: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    namespace = globals()
    model = namespace['GPT'](namespace['GPTConfig'](**checkpoint['model_args']))

    # Strip the '_orig_mod.' key prefix (presumably left by torch.compile) in place.
    prefix = '_orig_mod.'
    state_dict = checkpoint['model']
    for key in [name for name in state_dict if name.startswith(prefix)]:
        state_dict[key[len(prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval().to(device)

    enc = tiktoken.get_encoding("gpt2")
    return model, enc, device, report_dir
def _embed(model, ids):
"""Compatible with wte() return value: reflow-topk returns tuple, others return tensor"""
result = model.transformer.wte(ids)
return result[0] if isinstance(result, tuple) else result
def _get_vocab_signals(model):
"""Get effective vocab->signals weights; topk variant applies sparsification"""
wte = model.transformer.wte
if hasattr(wte, '_apply_sparsity'):
return wte._apply_sparsity(wte.vocab_to_signals.weight.data)
return wte.vocab_to_signals.weight.data
def _forward_through_layers(model, ids):
    """Embed ``ids`` and run them through every transformer block, without gradients.

    ``model.freqs_cis`` is sliced to the sequence length ``ids.size(1)`` and passed to
    each block. The final layer norm is NOT applied; the raw last-block output is
    returned.
    """
    with torch.no_grad():
        hidden = _embed(model, ids)
        rope_table = model.freqs_cis[:ids.size(1)]
        for layer in model.transformer.h:
            hidden = layer(hidden, rope_table)
    return hidden
def _get_logits_from_hidden(model, x_norm):
"""Compute logits from layer-normed hidden state."""
vocab_matrix = model.transformer.wte.get_dynamic_vocab_matrix()
return F.linear(x_norm, vocab_matrix)
def exp_1_recipe_atlas(model, enc, device, report_dir):
    """Experiment 1: explore the vocab->signal "recipe" space around a fixed probe set.

    Each probe word is represented by the recipe row of the first BPE token of
    ``" " + word``. Prints the 20 most similar probe pairs, the top-5
    full-vocabulary recipe neighbours of the first 20 probes, and per-signal variance
    statistics (including their Gini coefficient). Saves a three-panel chart
    (hierarchically ordered similarity heatmap, sorted signal variance, neighbour
    table) to ``<report_dir>/recipe_atlas.png``. ``device`` is not used.
    """
    print("\n" + "="*60)
    print(" [Exp 1] Recipe Atlas")
    print("="*60)
    W_v2s = _get_vocab_signals(model)
    real_vocab = 50257  # GPT-2 vocab size; rows beyond it (if any) are presumably padding
    W = W_v2s[:real_vocab]
    probe_words = [
        "China", "France", "Japan", "Germany", "India", "Russia",
        "Paris", "London", "Tokyo", "Berlin", "Beijing", "Rome",
        "cat", "dog", "fish", "bird", "horse", "bear", "wolf",
        "red", "blue", "green", "black", "white", "yellow",
        "happy", "sad", "angry", "love", "fear", "hate", "joy",
        "one", "two", "three", "four", "five", "ten", "hundred",
        "run", "walk", "think", "eat", "write", "read", "speak",
        "big", "small", "hot", "cold", "fast", "slow", "good", "bad",
        "king", "queen", "man", "woman", "boy", "girl",
        "water", "fire", "earth", "light", "dark", "sun", "moon",
    ]
    probe_ids, probe_labels = [], []
    for w in probe_words:
        tids = enc.encode(" " + w)
        if tids and tids[0] < real_vocab:
            probe_ids.append(tids[0])
            probe_labels.append(w)
    probe_normed = F.normalize(W[probe_ids], dim=1)
    sim_matrix = (probe_normed @ probe_normed.t()).cpu().numpy()
    # A float A @ A.T is not guaranteed to be bit-exactly symmetric, and squareform()
    # below raises on any asymmetry -- enforce exact symmetry.
    sim_matrix = (sim_matrix + sim_matrix.T) / 2
    n_probe = len(probe_labels)

    # --- Most similar probe pairs (upper triangle only, so the diagonal never appears) ---
    pairs = sorted(
        ((sim_matrix[i, j], probe_labels[i], probe_labels[j])
         for i in range(n_probe) for j in range(i + 1, n_probe)),
        reverse=True,
    )
    print("\n Recipe-Space Nearest Neighbor Pairs (Top-20):")
    print(" " + "-"*50)
    for rank, (sim, w1, w2) in enumerate(pairs[:20], start=1):
        print(f" #{rank:2d} | {w1:>10s} <-> {w2:<10s} | cos={sim:.4f}")

    # --- Full-vocabulary neighbours of the first 20 probes ---
    W_normed = F.normalize(W, dim=1)
    print("\n Full-Vocabulary Recipe Neighbors (Top-5 per word):")
    print(" " + "-"*60)
    nn_table = []  # (label, [(neighbour_text, cosine), ...])
    for tid, label in zip(probe_ids[:20], probe_labels[:20]):
        sims = W_normed[tid] @ W_normed.t()
        sims[tid] = -1  # exclude the probe itself
        top_vals, top_ids = torch.topk(sims, 5)
        neighbors = []
        for v, nid in zip(top_vals.tolist(), top_ids.tolist()):
            try:
                nw = enc.decode([nid]).strip()
            except Exception:  # best-effort: fall back to the raw token id
                nw = f"[{nid}]"
            neighbors.append((nw, v))
        nn_str = ", ".join(f"{nw}({v:.3f})" for nw, v in neighbors)
        print(f" {label:>10s} -> {nn_str}")
        nn_table.append((label, neighbors))

    # --- Per-signal variance across the vocabulary ---
    sig_var = W.var(dim=0).cpu().numpy()
    var_order = np.argsort(sig_var)
    top_var_idx = var_order[::-1][:20]
    bottom_var_idx = var_order[:20]
    print(f"\n Signal Variance Analysis:")
    print(f" > Highest variance signals (most discriminative): {top_var_idx[:10].tolist()}")
    print(f" > Lowest variance signals (near-constant): {bottom_var_idx[:10].tolist()}")
    print(f" > Variance Gini coefficient: {_gini(sig_var):.4f}")

    fig = plt.figure(figsize=(20, 14))
    gs = fig.add_gridspec(2, 2, height_ratios=[1.2, 1])

    # Panel 1: similarity heatmap ordered by hierarchical clustering.
    ax1 = fig.add_subplot(gs[0, :])
    # Ward linkage is only well-defined for Euclidean distances. For unit-norm recipes
    # the Euclidean (chord) distance is sqrt(2 - 2*cos), so cluster on that, not 1 - cos.
    dist_matrix = np.sqrt(np.clip(2.0 - 2.0 * sim_matrix, 0, None))
    np.fill_diagonal(dist_matrix, 0)
    order = leaves_list(linkage(squareform(dist_matrix), method='ward'))
    sim_ordered = sim_matrix[np.ix_(order, order)]
    labels_ordered = [probe_labels[i] for i in order]
    im = ax1.imshow(sim_ordered, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax1.set_xticks(range(n_probe))
    ax1.set_xticklabels(labels_ordered, rotation=90, fontsize=7)
    ax1.set_yticks(range(n_probe))
    ax1.set_yticklabels(labels_ordered, fontsize=7)
    plt.colorbar(im, ax=ax1, fraction=0.02)
    ax1.set_title("Recipe Cosine Similarity (hierarchical clustering order)", fontsize=12, fontweight='bold')

    # Panel 2: sorted per-signal variance.
    ax2 = fig.add_subplot(gs[1, 0])
    sorted_var = np.sort(sig_var)[::-1]
    ax2.bar(range(len(sorted_var)), sorted_var, color='steelblue', alpha=0.7, width=1.0)
    ax2.set_title("Signal Variance Across Vocabulary (sorted)", fontsize=11, fontweight='bold')
    ax2.set_xlabel("Signal (sorted by variance)")
    ax2.set_ylabel("Variance")
    ax2.axhline(y=np.mean(sig_var), color='red', linestyle='--', label=f'Mean: {np.mean(sig_var):.4f}')
    ax2.legend()

    # Panel 3: neighbour table (word text kept separately, so tokens containing '(' survive).
    ax3 = fig.add_subplot(gs[1, 1])
    ax3.axis('off')
    table_data = [
        [label, ", ".join(nw for nw, _ in neighbors[:4])]
        for label, neighbors in nn_table[:15]
    ]
    table = ax3.table(cellText=table_data, colLabels=["Word", "Top-4 Recipe Neighbors"],
                      loc='center', cellLoc='left')
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1.0, 1.5)
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(color='white', fontweight='bold')
        elif row % 2 == 0:
            cell.set_facecolor('#D9E2F3')
    ax3.set_title("Vocabulary Recipe Nearest Neighbors", fontsize=11, fontweight='bold', pad=15)

    plt.suptitle("reFlow Recipe Atlas — Signal Recipe Space Structure", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    save_path = os.path.join(report_dir, "recipe_atlas.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f"\n > Chart saved: {save_path}")
def _gini(arr):
"""Compute Gini coefficient measuring distribution inequality. 0=uniform, 1=concentrated."""
arr = np.sort(np.abs(arr))
n = len(arr)
if n == 0 or np.sum(arr) == 0:
return 0.0
index = np.arange(1, n + 1)
return (2 * np.sum(index * arr) / (n * np.sum(arr))) - (n + 1) / n
def exp_2_sparsity_profile(model, enc, device, report_dir):
    """Experiment 2: how many signals each word's recipe actually uses.

    If the embedding exposes ``_apply_sparsity`` (TopK variant), a signal is active for
    a word when its recipe weight is nonzero. Otherwise it is active when its magnitude
    exceeds ``mean(|W|) + std(|W|)`` over the real-vocabulary recipe matrix. Saves
    ``<report_dir>/sparsity_profile.png``. Also prints two plain-text data blocks for
    TikZ/PGFPlots: the per-word active-count histogram and the per-signal utilisation
    in descending order. ``enc`` and ``device`` are not used.
    """
    banner = "="*60
    print("\n" + banner)
    print(" [Exp 2] Sparsity Profile")
    print(banner)
    W = _get_vocab_signals(model)[:50257]  # real GPT-2 vocabulary rows only
    _, n_signals = W.shape
    magnitudes = W.abs()
    is_topk = hasattr(model.transformer.wte, '_apply_sparsity')

    # --- Decide which (word, signal) entries count as active ---
    if is_topk:
        active = magnitudes > 0
        per_word = active.sum(dim=1).float()
        k = int(per_word.median().item())
        print(f" > TopK sparse mode detected, fixed k={k}")
        amplitudes = magnitudes[active].cpu().numpy()
        per_signal = active.sum(dim=0).cpu().numpy()
        per_word_np = per_word.cpu().numpy()
    else:
        threshold = magnitudes.mean().item() + magnitudes.std().item()
        active = magnitudes > threshold
        per_word_np = active.sum(dim=1).cpu().numpy()
        per_signal = active.sum(dim=0).cpu().numpy()
        print(f" > Activation threshold: {threshold:.4f} (mean + std)")
        print(f" > Mean active signals per word: {np.mean(per_word_np):.1f} / {n_signals}")
        print(f" > Global activation rate: {active.float().mean().item():.2%}")

    # --- Left panel depends on the mode; right panel (signal utilisation) is shared ---
    _fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(14, 5))
    if is_topk:
        ax_left.hist(amplitudes, bins=80, color='teal', alpha=0.7, edgecolor='black')
        ax_left.set_title(f"Active Signal Amplitude Distribution (k={k})")
        ax_left.set_xlabel("Absolute Amplitude")
        ax_left.set_ylabel("Frequency")
    else:
        # Integer-centred bins: one bar per possible active-signal count.
        int_bins = np.arange(per_word_np.min(), per_word_np.max() + 2) - 0.5
        mean_active = np.mean(per_word_np)
        ax_left.hist(per_word_np, bins=int_bins, color='teal', alpha=0.7, edgecolor='black')
        ax_left.axvline(x=mean_active, color='red', linestyle='--',
                        label=f'Mean: {mean_active:.1f}')
        ax_left.set_title("Per-Word Sparsity (# Active Signals)")
        ax_left.set_xlabel("Number of Active Signals")
        ax_left.set_ylabel("Frequency")
        ax_left.legend()
    mean_usage = np.mean(per_signal)
    ax_right.bar(range(n_signals), per_signal, color='coral', alpha=0.7, width=1.0)
    ax_right.set_title("Signal Utilization (# words activating each signal)")
    ax_right.set_xlabel("Signal Index")
    ax_right.set_ylabel("# Words")
    ax_right.axhline(y=mean_usage, color='red', linestyle='--',
                     label=f'Mean: {mean_usage:.0f}')
    ax_right.legend()
    plt.suptitle("reFlow Sparsity Profile", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "sparsity_profile.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")

    # === Export data for paper plotting ===
    print("\n" + banner)
    print(" [Paper Data Export] For TikZ/PGFPlots")
    print(banner)

    # Histogram of active signals per word: one integer-wide bin per count.
    lo, hi = int(per_word_np.min()), int(per_word_np.max())
    counts, edges = np.histogram(per_word_np, bins=np.arange(lo, hi + 2))
    print(f"\n [Histogram] Active signals per word distribution (bin_start, count):")
    print(f" mean={np.mean(per_word_np):.1f}, min={lo}, max={hi}")
    print(" ---BEGIN_HISTOGRAM_DATA---")
    for edge, count in zip(edges, counts):
        if count > 0:
            print(f" {int(edge)} {count}")
    print(" ---END_HISTOGRAM_DATA---")

    # Signal utilisation, most-used signal first.
    print(f"\n [Bar chart] Signal utilization (descending order, signal_rank, n_words):")
    print(f" mean={np.mean(per_signal):.0f}, min={np.min(per_signal)}, max={np.max(per_signal)}")
    print(" ---BEGIN_UTILIZATION_DATA---")
    for rank, usage in enumerate(np.sort(per_signal)[::-1]):
        print(f" {rank} {usage}")
    print(" ---END_UTILIZATION_DATA---")
def exp_3_basis_geometry(model, enc, device, report_dir, *, seed=0):
    """Experiment 3: geometry of the learned signal basis (``wte.signal_basis``).

    Compares the singular-value spectrum and effective rank of the basis with those of
    a Gaussian random matrix of the same shape. Plots the pairwise cosine similarity
    of the first (up to) 64 basis vectors. Saves ``<report_dir>/basis_geometry.png``.

    Args:
        seed: keyword-only seed for the random-matrix baseline. It makes the reported
            "random" effective rank reproducible across runs. Defaults to 0.

    ``enc`` and ``device`` are not used.
    """
    print("\n" + "="*60)
    print(" [Exp 3] Signal Basis Geometry")
    print("="*60)
    W_basis = model.transformer.wte.signal_basis.data.cpu().float()
    n_signals, n_embd = W_basis.shape

    def _effective_rank(singular_values):
        # exp(Shannon entropy) of the singular values normalised to sum to one.
        p = singular_values / singular_values.sum()
        return np.exp(-np.sum(p * np.log(p + 1e-12)))

    # Only singular values are needed, so skip computing U / V.
    S_np = torch.linalg.svdvals(W_basis).numpy()
    # A dedicated seeded generator keeps the baseline reproducible and leaves the
    # global RNG untouched.
    gen = torch.Generator().manual_seed(seed)
    random_mat = torch.randn(W_basis.shape, generator=gen, dtype=W_basis.dtype)
    S_rand_np = torch.linalg.svdvals(random_mat).numpy()
    effective_rank = _effective_rank(S_np)
    effective_rank_rand = _effective_rank(S_rand_np)
    print(f" > Signal basis shape: ({n_signals}, {n_embd})")
    print(f" > Effective rank (learned): {effective_rank:.1f} / {min(n_signals, n_embd)}")
    print(f" > Effective rank (random): {effective_rank_rand:.1f} / {min(n_signals, n_embd)}")

    show_n = min(64, n_signals)
    W_normed = F.normalize(W_basis[:show_n], dim=1)
    cos_sim = (W_normed @ W_normed.t()).numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    ax1.plot(S_np / S_np[0], 'b-', lw=2, label='Learned Basis')
    ax1.plot(S_rand_np / S_rand_np[0], 'r--', lw=1.5, label='Random Gaussian')
    ax1.set_title(f"Singular Value Spectrum\n(Eff. rank: learned={effective_rank:.0f}, random={effective_rank_rand:.0f})")
    ax1.set_xlabel("Component Index")
    ax1.set_ylabel("Normalized Singular Value")
    ax1.set_yscale('log')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    im = ax2.imshow(cos_sim, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
    ax2.set_title(f"Cosine Similarity (first {show_n} signals)")
    ax2.set_xlabel("Signal Index")
    ax2.set_ylabel("Signal Index")
    plt.colorbar(im, ax=ax2, fraction=0.046)
    plt.suptitle("reFlow Signal Basis Geometry", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "basis_geometry.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_4_semantic_galaxy(model, enc, device, report_dir):
    """Experiment 4: 2-D PCA of word recipes for five hand-picked semantic categories.

    Each word is represented by the recipe row of the first BPE token of ``" " + word``.
    The silhouette score is computed on the full recipes, not on the 2-D projection.
    Saves ``<report_dir>/semantic_galaxy.png``.
    """
    banner = "="*60
    print("\n" + banner)
    print(" [Exp 4] Semantic Galaxy (PCA)")
    print(banner)
    W_v2s = _get_vocab_signals(model).cpu().numpy()
    real_vocab = 50257
    clusters = {
        "Countries": ["China", "France", "Germany", "Japan", "India", "Russia"],
        "Animals": ["cat", "dog", "fish", "bird", "horse", "bear"],
        "Numbers": ["one", "two", "three", "four", "five", "ten"],
        "Colors": ["red", "blue", "green", "black", "white", "yellow"],
        "Emotions": ["happy", "sad", "angry", "love", "fear", "hate"],
    }
    cluster_index = {name: idx for idx, name in enumerate(clusters)}

    entries = []  # (category, word, recipe row)
    for category, members in clusters.items():
        for word in members:
            tids = enc.encode(" " + word)
            if tids and tids[0] < real_vocab:
                entries.append((category, word, W_v2s[tids[0]]))
    labels = [category for category, _, _ in entries]
    words = [word for _, word, _ in entries]
    recipes_arr = np.array([recipe for _, _, recipe in entries])

    coords = PCA(n_components=2).fit_transform(recipes_arr)
    sil = silhouette_score(recipes_arr, [cluster_index[category] for category in labels])
    print(f" > Silhouette Score: {sil:.4f}")

    plt.figure(figsize=(12, 9))
    color_map = dict(zip(clusters.keys(), sns.color_palette("Set2", len(clusters))))
    texts = []
    for (pc1, pc2), category, word in zip(coords, labels, words):
        plt.scatter(pc1, pc2, color=color_map[category],
                    s=150, alpha=0.7, edgecolors='white', linewidths=0.5)
        texts.append(plt.text(pc1, pc2, word, fontsize=11))
    # Only de-overlap labels when the real adjustText is installed (the fallback is a lambda).
    if adjust_text.__name__ != '<lambda>':
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='gray'))
    handles = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=color_map[name], markersize=12, label=name)
        for name in clusters
    ]
    plt.legend(handles=handles, title="Clusters", fontsize=10)
    plt.title(f"reFlow Semantic Galaxy (PCA)\nSilhouette Score = {sil:.4f}",
              fontsize=14, fontweight='bold')
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    save_path = os.path.join(report_dir, "semantic_galaxy.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_5_semantic_algebra(model, enc, device, report_dir):
    """Experiment 5: word-analogy arithmetic in recipe space (e.g. Paris + China - France).

    For each case, the positive recipes are summed and the negative ones subtracted.
    Every real-vocabulary recipe is then ranked by cosine similarity to the result;
    the input words themselves are pushed to the bottom. Prints the exact 1-based rank
    of the expected answer and the top hits, and saves a results table to
    ``<report_dir>/semantic_algebra.png``. Only the first BPE token of
    ``" " + word`` is used for every word.
    """
    print("\n" + "="*60)
    print(" [Exp 5] Semantic Algebra")
    print("="*60)
    W_v2s = _get_vocab_signals(model)
    W_valid = W_v2s[:50257]
    test_cases = [
        (["Paris", "China"], ["France"], "Beijing"),
        (["king", "woman"], ["man"], "queen"),
        (["walked", "running"], ["walking"], "ran"),
    ]

    def first_token(word):
        return enc.encode(" " + word)[0]

    results = []
    for pos_words, neg_words, expected in test_cases:
        expr = " + ".join(pos_words) + " - " + " - ".join(neg_words)
        target_vec = torch.zeros(model.config.n_signals, device=device)
        exclude_ids = set()
        for w in pos_words:
            tid = first_token(w)
            target_vec += W_v2s[tid]
            exclude_ids.add(tid)
        for w in neg_words:
            tid = first_token(w)
            target_vec -= W_v2s[tid]
            exclude_ids.add(tid)
        sims = F.cosine_similarity(target_vec.unsqueeze(0), W_valid)
        for tid in exclude_ids:
            sims[tid] = -1.0  # never answer with one of the input words

        # Exact 1-based rank over the whole real vocabulary (ties share the better rank).
        # Previously ranks beyond 500 were reported as "not found".
        expected_tid = first_token(expected)
        expected_rank = int((sims > sims[expected_tid]).sum().item()) + 1

        top_vals, top_ids = torch.topk(sims, 20)
        hit_words = []
        for val, tok in zip(top_vals.tolist(), top_ids.tolist()):
            try:
                w = enc.decode([tok]).strip()
            except Exception:  # best-effort: skip undecodable ids in the hit list
                continue
            if len(w) >= 2:  # drop 0/1-character fragments from the display list
                hit_words.append((w, val))
        results.append((expr, expected, expected_rank, hit_words[:5]))
        print(f"\n {expr} -> Expected: '{expected}'")
        marker = "HIT!" if expected_rank <= 10 else ""
        print(f" > '{expected}' rank: #{expected_rank} {marker}")
        print(f" > Top-5: {', '.join(f'{w}({s:.3f})' for w, s in hit_words[:5])}")

    fig, ax = plt.subplots(figsize=(14, 4 + len(results)))
    ax.axis('off')
    table_data = [
        [expr, expected, f"#{rank}", ", ".join(w for w, _ in hits[:4])]
        for expr, expected, rank, hits in results
    ]
    table = ax.table(
        cellText=table_data,
        colLabels=["Expression", "Expected", "Rank", "Top Hits"],
        loc='center', cellLoc='left'
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.0, 1.8)
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(color='white', fontweight='bold')
    ax.set_title("reFlow Semantic Algebra Results", fontsize=14, fontweight='bold', pad=20)
    save_path = os.path.join(report_dir, "semantic_algebra.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f"\n > Chart saved: {save_path}")
def exp_6_typo_resilience(model, enc, device, report_dir):
    """Experiment 6: are deep representations robust to spelling mistakes?

    Each sentence is run through all blocks. The last position's final hidden state is
    passed through ``ln_f`` and projected onto the signal basis. The function then
    compares the cosine similarity of (correct vs misspelt) against (correct vs
    unrelated), and saves a bar chart to ``<report_dir>/typo_resilience.png``.
    """
    banner = "="*60
    print("\n" + banner)
    print(" [Exp 6] Typo Resilience")
    print(banner)
    sent_normal = "The scientist is very intelligent"
    sent_typo = "The scientsit is vary intellgent"
    sent_diff = "The dog runs in the park"
    W_basis = model.transformer.wte.signal_basis.data

    def get_deep_signal(sentence):
        token_ids = torch.tensor(enc.encode(sentence), device=device).unsqueeze(0)
        with torch.no_grad():
            hidden = _forward_through_layers(model, token_ids)
            return model.transformer.ln_f(hidden[0, -1, :]) @ W_basis.t()

    def cosine(a, b):
        return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

    sig_normal, sig_typo, sig_diff = map(get_deep_signal, (sent_normal, sent_typo, sent_diff))
    sim_typo = cosine(sig_normal, sig_typo)
    sim_diff = cosine(sig_normal, sig_diff)
    sim_self = 1.0  # reference bar: any vector against itself

    print(f" > Normal: '{sent_normal}'")
    print(f" > Misspelt: '{sent_typo}'")
    print(f" > Unrelated: '{sent_diff}'")
    print(f"\n [Normal vs Misspelt] Deep semantic similarity: \033[93m{sim_typo:.4f}\033[0m")
    print(f" [Normal vs Unrelated] Deep semantic similarity: {sim_diff:.4f}")
    print(f" > Robustness metric (difference): {sim_typo - sim_diff:.4f}")

    bar_specs = [
        ('Self\n(baseline)', sim_self, '#2ecc71'),
        ('Normal vs Typo\n(same meaning)', sim_typo, '#f39c12'),
        ('Normal vs Different\n(different meaning)', sim_diff, '#e74c3c'),
    ]
    categories, values, colors = map(list, zip(*bar_specs))
    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', width=0.5)
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{val:.4f}', ha='center', fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.15)
    ax.set_ylabel("Cosine Similarity")
    ax.set_title("reFlow Typo Resilience — Deep Signal Similarity", fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    save_path = os.path.join(report_dir, "typo_resilience.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_7_layer_evolution(model, enc, device, report_dir):
    """Experiment 7: logit-lens view of how next-token predictions form layer by layer.

    For each prompt, the last position's hidden state after every block is passed
    through ``ln_f`` and projected with ``_get_logits_from_hidden``.
    - Left panel: the probability, across layers, of the final layer's six most
      likely real-vocabulary tokens.
    - Right panel: the entropy (nats) of the full output distribution per layer.

    Saves ``<report_dir>/layer_evolution.png``.
    """
    print("\n" + "="*60)
    print(" [Exp 7] Layer Probability Evolution")
    print("="*60)
    prompts = [
        "The capital of France is",
        "The cat sat on the",
        "The sun is very",
    ]
    real_vocab = 50257
    n_layers = len(model.transformer.h)
    layers = np.arange(1, n_layers + 1)
    # squeeze=False keeps `axes` 2-D even when there is a single prompt.
    fig, axes = plt.subplots(len(prompts), 2, figsize=(18, 5 * len(prompts)),
                             gridspec_kw={'width_ratios': [1.3, 1]}, squeeze=False)
    for pi, text in enumerate(prompts):
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        layer_probs = []
        layer_entropies = []
        with torch.no_grad():
            x = _embed(model, ids)
            freqs_cis = model.freqs_cis[:ids.size(1)]
            for block in model.transformer.h:
                x = block(x, freqs_cis)
                x_norm = model.transformer.ln_f(x[0, -1, :])
                probs = F.softmax(_get_logits_from_hidden(model, x_norm), dim=-1)
                layer_probs.append(probs.cpu().numpy())
                layer_entropies.append(-torch.sum(probs * torch.log(probs + 1e-9)).item())
        final_probs = layer_probs[-1][:real_vocab]
        top_idx = np.argsort(final_probs)[-6:]  # ascending: most likely token last
        prob_flow = np.stack(layer_probs)[:, top_idx]  # rows = layers, cols = tracked tokens

        ax_prob = axes[pi, 0]
        colors = sns.color_palette("husl", len(top_idx))
        for i, idx in enumerate(top_idx):
            # repr() makes whitespace visible; [1:-1] removes its quotes whether it chose
            # ' or " (strip("'") left stray double quotes on tokens containing ').
            label = repr(enc.decode([int(idx)]))[1:-1]
            ax_prob.plot(layers, prob_flow[:, i], label=label, lw=2.5, color=colors[i])
        ax_prob.set_title(f"Probability Evolution: '{text}'", fontsize=11, fontweight='bold')
        ax_prob.set_xlabel("Layer")
        ax_prob.set_ylabel("Probability")
        ax_prob.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
        ax_prob.legend(fontsize=8, loc='upper left')
        ax_prob.grid(True, alpha=0.3)

        ax_ent = axes[pi, 1]
        ax_ent.plot(layers, layer_entropies, color='#FF6B35', lw=2.5, marker='o', markersize=3)
        ax_ent.set_title(f"Entropy Decay: '{text}'", fontsize=11, fontweight='bold')
        ax_ent.set_xlabel("Layer")
        ax_ent.set_ylabel("Entropy (nats)")
        ax_ent.grid(True, alpha=0.3)

        predicted = enc.decode([int(np.argmax(final_probs))])
        print(f" > Prompt: '{text}' -> Prediction: '{predicted}' (p={final_probs.max():.2%})")
    plt.suptitle("reFlow Layer Probability Evolution", fontsize=15, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    save_path = os.path.join(report_dir, "layer_evolution.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_8_signal_flow(model, enc, device, report_dir):
    """Experiment 8: how signal activations evolve across layers and across tokens.

    For a fixed prompt, every block's output is passed through ``ln_f`` and projected
    onto the signal basis.
    - Left heatmap: the 15 signals whose last-token activation varies most across
      layers.
    - Right heatmap: the 20 signals whose final-layer activation varies most across
      tokens.

    Saves ``<report_dir>/signal_flow.png``.
    """
    print("\n" + "="*60)
    print(" [Exp 8] Signal Flow Tracking")
    print("="*60)
    text = "The capital of France is"
    W_basis = model.transformer.wte.signal_basis.data
    ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
    n_layers = len(model.transformer.h)
    # repr() makes whitespace visible; [1:-1] removes its quotes whichever kind it chose.
    tokens = [repr(enc.decode([t]))[1:-1] for t in ids[0].tolist()]
    layer_signals_last_token = []
    final_layer_all_tokens = None
    with torch.no_grad():
        x = _embed(model, ids)
        freqs_cis = model.freqs_cis[:ids.size(1)]
        for block in model.transformer.h:
            x = block(x, freqs_cis)
            sigs = (model.transformer.ln_f(x[0]) @ W_basis.t()).cpu().numpy()
            layer_signals_last_token.append(sigs[-1])
            final_layer_all_tokens = sigs  # ends up holding the last block's activations
    sig_arr = np.array(layer_signals_last_token)
    var_across_layers = np.var(sig_arr, axis=0)
    top_layer_sig_idx = np.argsort(var_across_layers)[-15:][::-1]
    var_across_tokens = np.var(final_layer_all_tokens, axis=0)
    top_time_sig_idx = np.argsort(var_across_tokens)[-20:][::-1]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), gridspec_kw={'width_ratios': [1.2, 1]})
    layer_data = sig_arr[:, top_layer_sig_idx].T
    sns.heatmap(layer_data, cmap='RdBu_r', center=0, ax=ax1,
                xticklabels=np.arange(1, n_layers + 1),
                yticklabels=[f"Sig {i}" for i in top_layer_sig_idx])
    ax1.set_title("Signal Flow Across Layers (last token)", fontsize=11, fontweight='bold')
    ax1.set_xlabel("Layer")
    ax1.set_ylabel("Signal (by layer variance)")
    time_data = final_layer_all_tokens[:, top_time_sig_idx].T
    sns.heatmap(time_data, cmap='mako', ax=ax2,
                xticklabels=tokens,
                yticklabels=[f"Sig {i}" for i in top_time_sig_idx])
    ax2.set_title("Signal Activation Across Tokens (final layer)", fontsize=11, fontweight='bold')
    ax2.set_xlabel("Token")
    ax2.set_ylabel("Signal (by token variance)")
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
    plt.suptitle(f"reFlow Signal Flow — '{text}'", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "signal_flow.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_9_causal_ablation(model, enc, device, report_dir):
    """Experiment 9: knock out the signals that most support the model's prediction.

    For each prompt, the last position's final hidden state becomes signal activations
    (``ln_f(x) @ signal_basis.T``), which are read out through the real-vocabulary
    recipes. Signals are ranked by their contribution ``activation * recipe[pred]`` to
    the predicted token's logit. The top 1, 2, 4, ..., 128 are zeroed and the
    prediction's probability is re-measured. The function also lists the tokens with
    the largest recipe weight on the single most supportive signal, and saves
    ``<report_dir>/ablation_curve.png``.
    """
    banner = "="*60
    print("\n" + banner)
    print(" [Exp 9] Causal Ablation Curve")
    print(banner)
    W_basis = model.transformer.wte.signal_basis.data
    recipes = _get_vocab_signals(model)[:50257]  # real-vocabulary rows only
    readout = recipes.t()
    prompts = [
        "The capital of France is",
        "The cat sat on the",
        "The sun is very",
    ]
    ablation_steps = [1, 2, 4, 8, 16, 32, 64, 128]
    all_results, codebook_info = [], []

    def decode_or_id(token_id):
        try:
            return enc.decode([token_id]).strip()
        except Exception:  # best-effort label for undecodable ids
            return f"[{token_id}]"

    for text in prompts:
        ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)
        with torch.no_grad():
            hidden = _forward_through_layers(model, ids)
            x_norm = model.transformer.ln_f(hidden[0, -1, :])
        sig_acts = x_norm @ W_basis.t()
        probs_base = F.softmax(sig_acts @ readout, dim=-1)
        pred_id = torch.argmax(probs_base).item()
        pred_word = enc.decode([pred_id])
        pred_prob = probs_base[pred_id].item()
        # Per-signal share of the predicted token's logit, most supportive first.
        ranked_sigs = torch.argsort(sig_acts * recipes[pred_id], descending=True)
        n_total = len(ranked_sigs)

        result = {'text': text, 'pred': pred_word, 'base_prob': pred_prob,
                  'steps': [], 'probs': [], 'new_preds': []}
        for n_ablate in ablation_steps:
            if n_ablate > n_total:
                break
            ablated = sig_acts.clone()
            ablated[ranked_sigs[:n_ablate]] = 0.0
            probs_abl = F.softmax(ablated @ readout, dim=-1)
            result['steps'].append(n_ablate)
            result['probs'].append(probs_abl[pred_id].item())
            result['new_preds'].append(enc.decode([torch.argmax(probs_abl).item()]))
        all_results.append(result)

        top_sig = ranked_sigs[0].item()
        _, top_ids = torch.topk(recipes[:, top_sig], 8)
        cb_words = [decode_or_id(tok) for tok in top_ids.tolist()]
        codebook_info.append((text, top_sig, cb_words))

        print(f"\n Prompt: '{text}'")
        print(f" > Baseline prediction: '{pred_word}' (p={pred_prob:.2%})")
        print(f" > Key signal #{top_sig} codebook: {', '.join(cb_words[:6])}")
        for step, prob, new in zip(result['steps'], result['probs'], result['new_preds']):
            print(f" Ablate {step:3d} signals -> p('{pred_word}')={prob:.2%}, new prediction='{new}'")

    n_prompts = len(all_results)
    fig, axes = plt.subplots(2, n_prompts + 1, figsize=(5.5 * (n_prompts + 1), 9),
                             gridspec_kw={'width_ratios': [1] * n_prompts + [0.8]})
    for i, res in enumerate(all_results):
        # Top row: absolute probability of the original prediction.
        ax_top = axes[0][i]
        ax_top.plot(res['steps'], [max(p, 1e-8) for p in res['probs']],
                    'o-', color='#e74c3c', lw=2.5, markersize=6)
        ax_top.axhline(y=res['base_prob'], color='blue', linestyle='--', alpha=0.5,
                       label=f"Baseline: {res['base_prob']:.1%}")
        ax_top.set_title(f"'{res['text']}'\nPrediction: '{res['pred']}'", fontsize=10, fontweight='bold')
        ax_top.set_xlabel("# Signals Ablated")
        ax_top.set_ylabel("P(original prediction)")
        ax_top.set_yscale('log')
        ax_top.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=2))
        ax_top.set_xscale('log', base=2)
        ax_top.legend(fontsize=8)
        ax_top.grid(True, alpha=0.3)

        # Bottom row: the same probability relative to the unablated baseline.
        ax_bottom = axes[1][i]
        retention = [p / res['base_prob'] * 100 for p in res['probs']]
        ax_bottom.plot(res['steps'], [max(r, 1e-4) for r in retention],
                       's-', color='#2ecc71', lw=2.5, markersize=6)
        ax_bottom.axhline(y=100, color='blue', linestyle='--', alpha=0.5, label="Baseline: 100%")
        ax_bottom.set_title(f"Retention rate", fontsize=10)
        ax_bottom.set_xlabel("# Signals Ablated")
        ax_bottom.set_ylabel("% of baseline probability retained")
        ax_bottom.set_yscale('log')
        ax_bottom.set_xscale('log', base=2)
        ax_bottom.legend(fontsize=8)
        ax_bottom.grid(True, alpha=0.3)

    # Last column: both axes hidden; the top one carries the codebook text box.
    axes[0][-1].axis('off')
    axes[1][-1].axis('off')
    entries = []
    for prompt, sig_id, words in codebook_info:
        short = prompt[:25] + "..." if len(prompt) > 25 else prompt
        entries.append(f"\n\n'{short}'\n Key Sig #{sig_id}:\n {', '.join(words[:6])}")
    cb_text = "Critical Signal Codebook\n" + "="*30 + "".join(entries)
    ax_cb = axes[0][-1]
    ax_cb.text(0.05, 0.95, cb_text, transform=ax_cb.transAxes, fontsize=9,
               verticalalignment='top', fontfamily='monospace',
               bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
    plt.suptitle("reFlow Causal Ablation Curve", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "ablation_curve.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f"\n > Chart saved: {save_path}")
def exp_10_emotion_surgery(model, enc, device, report_dir):
    """Experiment 10: steer a negative review toward positive sentiment in signal space.

    The steering direction is the mean recipe of five positive words minus that of
    five negative words. ``trace_emotion`` runs the prompt and, after each block, does
    two things:
    1. Records the dot products of the last position's signal activations with the
       positive and negative mean recipes.
    2. From the intercept layer onward, adds ``(alpha * steer_vec) @ signal_basis`` to
       the last position's hidden state.

    A grid of intercept layers x alphas records the resulting next-token prediction.
    Saves ``<report_dir>/emotion_surgery.png``.
    """
    banner = "="*60
    print("\n" + banner)
    print(" [Exp 10] Emotion Surgery")
    print(banner)
    W_v2s = _get_vocab_signals(model)
    W_basis = model.transformer.wte.signal_basis.data

    def mean_recipe(words):
        # Average recipe of the first BPE token of " " + word.
        return torch.stack([W_v2s[enc.encode(" " + w)[0]] for w in words]).mean(dim=0)

    pos_vec = mean_recipe(["excellent", "wonderful", "amazing", "great", "good"])
    neg_vec = mean_recipe(["terrible", "awful", "horrible", "bad", "poor"])
    steer_vec = pos_vec - neg_vec
    text = "The food was absolutely terrible and the service was"
    n_layers = len(model.transformer.h)
    scan_layers = list(range(0, n_layers, max(1, n_layers // 6)))
    if scan_layers[-1] != n_layers - 1:
        scan_layers.append(n_layers - 1)  # always include the last layer
    scan_alphas = [0.0, 1.0, 3.0, 5.0, 8.0, 12.0]
    ids = torch.tensor(enc.encode(text), device=device).unsqueeze(0)

    def trace_emotion(intercept_layer=None, alpha=0.0):
        pos_acts, neg_acts = [], []
        with torch.no_grad():
            steer = None if intercept_layer is None else (steer_vec * alpha) @ W_basis
            x = _embed(model, ids)
            freqs_cis = model.freqs_cis[:ids.size(1)]
            for i, block in enumerate(model.transformer.h):
                x = block(x, freqs_cis)
                # Read the emotion projections before this layer's intervention.
                sig = model.transformer.ln_f(x[0, -1, :]) @ W_basis.t()
                pos_acts.append(torch.dot(sig, pos_vec).item())
                neg_acts.append(torch.dot(sig, neg_vec).item())
                if steer is not None and i >= intercept_layer:
                    x[:, -1, :] += steer
            x_norm = model.transformer.ln_f(x[0, -1, :])
            probs = F.softmax(_get_logits_from_hidden(model, x_norm), dim=-1)
        return enc.decode([torch.argmax(probs).item()]), pos_acts, neg_acts

    word_base, p_base, n_base = trace_emotion()
    print(f" > [Baseline] '{text}' -> '{word_base}'")
    grid_results = {}
    for layer in scan_layers:
        for alpha in scan_alphas:
            if alpha == 0.0:
                grid_results[(layer, alpha)] = word_base  # no steering == baseline
                continue
            word, _, _ = trace_emotion(intercept_layer=layer, alpha=alpha)
            grid_results[(layer, alpha)] = word
            if alpha == 5.0:
                print(f" > [Layer {layer:2d}, a={alpha}] -> '{word}'")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'width_ratios': [1.2, 1]})
    layers_x = range(n_layers)
    ax1.plot(layers_x, n_base, 'r--', lw=2, label='Negative (base)')
    ax1.plot(layers_x, p_base, 'b--', lw=2, label='Positive (base)')
    demo_layer = scan_layers[len(scan_layers) // 2]  # middle scanned layer, used as the example
    _, p_hack, n_hack = trace_emotion(intercept_layer=demo_layer, alpha=5.0)
    ax1.plot(layers_x, n_hack, 'r', lw=2.5, label=f'Negative (L{demo_layer}, a=5)')
    ax1.plot(layers_x, p_hack, 'b', lw=2.5, label=f'Positive (L{demo_layer}, a=5)')
    ax1.axvline(demo_layer, color='green', linestyle=':', lw=2, label=f'Surgery @ Layer {demo_layer}')
    ax1.set_title("Emotion Signal Flow", fontsize=11, fontweight='bold')
    ax1.set_xlabel("Layer")
    ax1.set_ylabel("Dot Product with Emotion Vector")
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

    ax2.axis('off')
    table_data = [
        [f"L{layer}"] + [grid_results.get((layer, alpha), "?") for alpha in scan_alphas]
        for layer in scan_layers
    ]
    col_labels = ["Layer"] + [f"a={a}" for a in scan_alphas]
    table = ax2.table(cellText=table_data, colLabels=col_labels, loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(8)
    table.scale(1.0, 1.6)
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(color='white', fontweight='bold')
    ax2.set_title("Intervention Grid (predicted next word)", fontsize=11, fontweight='bold', pad=20)
    plt.suptitle(f"reFlow Emotion Surgery\n'{text}'", fontsize=13, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.92])
    save_path = os.path.join(report_dir, "emotion_surgery.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_11_concept_inception(model, enc, device, report_dir):
    """Find how hard the final signal must be pushed to force an implausible word.

    Each prompt is run through ``_forward_through_layers``. The last position is
    normalised with ``ln_f`` and projected onto ``signal_basis`` to obtain a
    signal vector. The target word's row of the vocab-signal matrix (its
    "recipe") is added with strength ``a``, and next-token probabilities are
    recomputed against the first 50257 vocab rows.

    * A 20-step bisection on ``a`` in [0, 200] finds the smallest strength at
      which the target becomes the argmax. It is attempted only if ``a = 200``
      already succeeds; otherwise the critical strength is ``None``.
    * P(target) is sampled at 50 evenly spaced strengths up to 1.5x the
      critical strength (capped at 200) and plotted once per prompt.

    Saves ``concept_inception.png`` into ``report_dir``.
    """
    divider = "=" * 60
    print("\n" + divider)
    print(" [Exp 11] Concept Inception")
    print(divider)

    basis = model.transformer.wte.signal_basis.data
    recipes = _get_vocab_signals(model)
    # Score only the first 50257 rows (GPT-2 vocabulary size). Presumably any
    # further rows are padding -- TODO confirm against the model config.
    vocab_proj = recipes[:50257].t()

    def steer(signal, recipe, strength):
        """Next-token distribution after adding ``strength * recipe`` to ``signal``."""
        return F.softmax((signal + strength * recipe) @ vocab_proj, dim=-1)

    def predicts(probs, token_id):
        """True when ``token_id`` is the argmax of ``probs``."""
        return torch.argmax(probs).item() == token_id

    cases = [
        ("The capital of France is", "London"),
        ("The cat sat on the", "moon"),
        ("The sun is very", "cold"),
    ]

    curves = []
    for prompt, goal in cases:
        goal_id = enc.encode(" " + goal)[0]  # first BPE token of " <goal>"
        recipe = recipes[goal_id]
        token_ids = torch.tensor(enc.encode(prompt), device=device).unsqueeze(0)

        with torch.no_grad():
            hidden = _forward_through_layers(model, token_ids)
            signal = model.transformer.ln_f(hidden[0, -1, :]) @ basis.t()
            baseline = F.softmax(signal @ vocab_proj, dim=-1)
            start_word = enc.decode([torch.argmax(baseline).item()])
            start_prob = baseline[goal_id].item()

            # Bisection invariant: `high` is always a strength known to succeed.
            critical = None
            low, high = 0.0, 200.0
            if predicts(steer(signal, recipe, high), goal_id):
                for _ in range(20):
                    mid = (low + high) / 2
                    if predicts(steer(signal, recipe, mid), goal_id):
                        high = mid
                    else:
                        low = mid
                critical = high

            reach = critical if critical else 200
            strengths = np.linspace(0, min(200, reach * 1.5), 50)
            goal_probs = [steer(signal, recipe, a)[goal_id].item() for a in strengths]

        curves.append({
            'text': prompt, 'target': goal, 'orig': start_word,
            'critical_alpha': critical, 'alphas': strengths,
            'target_probs': goal_probs, 'orig_prob': start_prob,
        })
        if critical:
            print(f" > '{prompt}' -> '{goal}': critical a={critical:.1f} (original: '{start_word}')")
        else:
            print(f" > '{prompt}' -> '{goal}': not achieved within a<=200 (original: '{start_word}')")

    fig, axes = plt.subplots(1, len(curves), figsize=(6 * len(curves), 5))
    # plt.subplots returns a bare Axes (not an array) when there is one column.
    panels = [axes] if len(curves) == 1 else axes
    for ax, curve in zip(panels, curves):
        goal = curve['target']
        ax.plot(curve['alphas'], curve['target_probs'], 'o-', color='#9b59b6',
                lw=2, markersize=3)
        if curve['critical_alpha']:
            ax.axvline(curve['critical_alpha'], color='red', linestyle='--',
                       label=f"Critical a={curve['critical_alpha']:.1f}")
        ax.axhline(y=curve['orig_prob'], color='gray', linestyle=':', alpha=0.5,
                   label=f"Baseline P('{goal}')={curve['orig_prob']:.1e}")
        ax.set_title(f"'{curve['text']}'\n'{curve['orig']}' -> '{goal}'",
                     fontsize=10, fontweight='bold')
        ax.set_xlabel("Steering Strength (a)")
        ax.set_ylabel(f"P('{goal}')")
        ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    plt.suptitle("reFlow Concept Inception — Steering Curves", fontsize=14, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "concept_inception.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()
    print(f" > Chart saved: {save_path}")
def exp_12_genetic_hijack(model, enc, device, report_dir):
    """Compare greedy generations before and after shifting the raw recipe matrix.

    A positive and a negative sentiment recipe are formed by averaging rows of
    the effective vocab-signal matrix for a few sentiment words. The vector
    ``alpha * positive - alpha * negative`` is then added in place to every row
    of ``transformer.wte.vocab_to_signals.weight``. A fixed prompt is generated
    greedily before and after the shift. The raw weights are restored in a
    ``finally`` block, so later experiments see the unmodified model even if
    generation fails.

    ``report_dir`` is accepted for interface uniformity; nothing is saved here.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" [Exp 12] Genetic Hijack")
    print(rule)

    effective = _get_vocab_signals(model)
    raw = model.transformer.wte.vocab_to_signals.weight.data

    def mean_recipe(words):
        """Average effective recipe of the first BPE token of each ' <word>'."""
        rows = [effective[enc.encode(" " + word)[0]] for word in words]
        return torch.stack(rows).mean(dim=0)

    positive = mean_recipe(["excellent", "perfect", "wonderful", "amazing"])
    negative = mean_recipe(["terrible", "bad", "disgusting", "awful"])

    def greedy(prompt, max_tokens=15):
        """Greedy-decode ``max_tokens`` tokens after ``prompt`` and return the full text."""
        seq = torch.tensor(enc.encode(prompt), device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_tokens):
                limit = model.config.block_size
                # Crop to the context window from the left once the sequence outgrows it.
                window = seq[:, -limit:] if seq.size(1) > limit else seq
                logits, _ = model(window)
                dist = F.softmax(logits[:, -1, :], dim=-1)
                chosen = torch.argmax(dist, dim=-1).unsqueeze(0)
                seq = torch.cat((seq, chosen), dim=1)
        return enc.decode(seq[0].tolist())

    prompt = "The food was disgusting."
    baseline_text = greedy(prompt)
    print(" [Control] Natural generation:")
    print(f" \033[90m{baseline_text}\033[0m")

    backup = raw.clone()
    alpha = 1.5
    print(f" * Injecting positive genes, erasing negative genes (Alpha={alpha})...")
    # Broadcast add: the same shift vector is applied to every row of the raw matrix.
    raw.add_(alpha * positive - alpha * negative)
    try:
        hijacked_text = greedy(prompt)
        print(" [Hijacked] Post-manipulation generation:")
        print(f" \033[92m{hijacked_text}\033[0m")
    finally:
        raw.copy_(backup)
        print(" * Recipe matrix restored to prevent contamination of subsequent experiments.")
    print("\n > Experiment complete. Compare the control and hijacked texts above.")
def exp_13_task_crystallization_shift(model, enc, device, report_dir):
    """Measure how late a steering injection can start and still change the prediction.

    Prompts are grouped by context type (short context, long context with
    clauses, code). For each prompt, a steering vector
    ``W_v2s[target] - W_v2s[base]`` is mapped through the signal basis and added
    to the last position's hidden state. Here ``base`` is the model's unsteered
    prediction.

    The injection strength is the first value on the grid 2, 4, ..., 48 that
    makes the model predict the target when injecting from layer 0, scaled by
    1.2. The injection start layer is then swept over ``0..n_layers-1``. The
    first start layer at which the target is no longer predicted is recorded as
    the crystallization boundary (``n_layers`` if the target survives every
    start layer).

    Prompts the model already completes with the target, or that cannot be
    steered on that grid, are skipped.

    Args:
        model: reFlow model (uses ``transformer.wte.signal_basis``,
            ``transformer.h``, ``transformer.ln_f`` and ``freqs_cis``).
        enc: tokenizer with ``encode``/``decode``.
        device: torch device for the input ids.
        report_dir: directory where ``task_crystallization_shift.png`` is saved.
    """
    print("\n" + "="*60)
    print(" [Exp 13] Task-Dependent Crystallization Boundary")
    print("="*60)
    W_basis = model.transformer.wte.signal_basis.data
    W_v2s = _get_vocab_signals(model)
    n_layers = len(model.transformer.h)
    task_groups = {
        "Shallow (Short Context)": [
            ("The capital of France is", "London"),
            ("The cat sat on the", "moon"),
            ("The sky is", "red"),
            ("Open the door with a", "car")
        ],
        "Deep (Long Context / Clauses)": [
            ("When the geography teacher asked the students, they answered that the capital of France is", "London"),
            ("After carefully reviewing all the evidence presented in court, the judge decided that the defendant was", "guilty"),
            ("When you look outside the window at the beautiful nature, the color of the clear sky is", "red"),
            ("I was locked out of my house yesterday, and to open the locked door, you need a", "car")
        ],
        "Code (Structured Logic)": [
            ("def add(a, b): return a +", "None"),
            ("x = 1 + 2\ny =", "None"),
            ("for i in range(10):\n print(", "None"),
            ("if x > 0:\n result =", "None")
        ]
    }
    colors = {"Shallow (Short Context)": "#2ecc71", "Deep (Long Context / Clauses)": "#9b59b6", "Code (Structured Logic)": "#e67e22"}
    # Two-line box-plot tick labels, keyed like task_groups.
    box_tick_labels = {
        "Shallow (Short Context)": "Shallow\n(Short)",
        "Deep (Long Context / Clauses)": "Deep\n(Long)",
        "Code (Structured Logic)": "Code\n(Structured)",
    }

    def continuous_steer(prompt, target_tid, base_tid, alpha, intercept_layer):
        """Run ``prompt`` with continuous steering at the last position.

        ``intercept_layer=None`` disables steering. ``0`` injects at the
        embedding output and after every block. ``k >= 1`` injects after
        blocks ``k-1, k, ..., last``.

        Returns (P(target), stripped predicted text, predicted token id).
        """
        steer_vec = W_v2s[target_tid] - W_v2s[base_tid]
        ids = torch.tensor(enc.encode(prompt), device=device).unsqueeze(0)
        with torch.no_grad():
            x = _embed(model, ids)
            if intercept_layer == 0:
                x[:, -1, :] += (alpha * steer_vec) @ W_basis
            freqs_cis = model.freqs_cis[:ids.size(1)]
            for i, block in enumerate(model.transformer.h):
                x = block(x, freqs_cis)
                if intercept_layer is not None and i + 1 >= intercept_layer:
                    x[:, -1, :] += (alpha * steer_vec) @ W_basis
            x_norm = model.transformer.ln_f(x[0, -1, :])
            logits = _get_logits_from_hidden(model, x_norm)
            probs = F.softmax(logits, dim=-1)
            pred_id = torch.argmax(logits).item()
        return probs[target_tid].item(), enc.decode([pred_id]).strip(), pred_id

    results = {group_name: [] for group_name in task_groups}
    print(" Starting continuous intervention sweep...\n")
    for group_name, tasks in task_groups.items():
        print(f" [{group_name}]")
        for prompt, target in tasks:
            target_clean = target.strip()
            target_tid = enc.encode(" " + target)[0]
            # Unsteered run (steer_vec is zero and intercept is None) gives the base prediction.
            _, base_pred, base_tid = continuous_steer(prompt, target_tid, target_tid, 0.0, None)
            if base_pred == target_clean:
                print(f" [Skip] '{prompt[:20]}...' already predicts '{target_clean}'.")
                continue
            working_alpha = None
            for a in np.arange(2.0, 50.0, 2.0):
                _, pred, _ = continuous_steer(prompt, target_tid, base_tid, a, 0)
                if pred == target_clean:
                    working_alpha = a
                    break
            if working_alpha is None:
                print(f" [Skip] '{prompt[:20]}...': Cannot steer within alpha<50.")
                continue
            # 20% margin above the minimal working strength.
            final_alpha = working_alpha * 1.2
            layer_probs = []
            c_layer = n_layers  # sentinel: target survived every start layer
            for L in range(n_layers):
                p_target, pred, _ = continuous_steer(prompt, target_tid, base_tid, final_alpha, L)
                layer_probs.append(p_target)
                if pred != target_clean and c_layer == n_layers:
                    c_layer = L
            results[group_name].append({
                'prompt': prompt,
                'target': target_clean,
                'alpha': final_alpha,
                'base_pred': base_pred,
                'c_layer': c_layer,
                'layer_probs': layer_probs
            })
            short_prompt = prompt[:35] + "..." if len(prompt) > 35 else prompt
            print(f" - '{short_prompt}' (base: '{base_pred}')")
            print(f" -> Inject '{target_clean}' (α={final_alpha:.1f}) | Crystallization boundary: \033[96mLayer {c_layer}\033[0m")
        print()

    # Crystallization layers per group (empty list when every prompt was skipped).
    boundaries = {name: [res['c_layer'] for res in results[name]] for name in task_groups}
    c_layers_shallow = boundaries["Shallow (Short Context)"]
    c_layers_deep = boundaries["Deep (Long Context / Clauses)"]
    c_layers_code = boundaries["Code (Structured Logic)"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'width_ratios': [2, 1]})
    layers_x = np.arange(0, n_layers)
    for group_name, res_list in results.items():
        color = colors[group_name]
        for i, res in enumerate(res_list):
            label = group_name if i == 0 else "_nolegend_"
            ax1.plot(layers_x, res['layer_probs'], color=color, alpha=0.6, lw=2.5, label=label)
            c_idx = res['c_layer']
            if c_idx < n_layers:
                ax1.scatter(c_idx, res['layer_probs'][c_idx], color=color, s=120, marker='X', edgecolors='black', zorder=5)
    ax1.set_title("Target Concept Viability vs. Injection Delay", fontsize=12, fontweight='bold')
    ax1.set_xlabel("Intervention Start Layer (Later start = Context already crystallized)")
    ax1.set_ylabel("Final Probability of Injected Concept")
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0, decimals=0))
    if any(results.values()):  # avoid matplotlib's "no artists with labels" warning
        ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    populated = [name for name in task_groups if boundaries[name]]
    box_data = [boundaries[name] for name in populated]
    box_colors_list = [colors[name] for name in populated]
    # Draw as soon as any group has data; a single group still gets its own box.
    if box_data:
        bplot = ax2.boxplot(box_data, patch_artist=True, widths=0.5)
        ax2.set_xticks(range(1, len(box_data) + 1))
        ax2.set_xticklabels([box_tick_labels[name] for name in populated])
        for patch, c in zip(bplot['boxes'], box_colors_list):
            patch.set_facecolor(c)
            patch.set_alpha(0.6)
        # Fixed seed so the horizontal jitter, and thus the saved chart, is reproducible.
        jitter_rng = np.random.default_rng(0)
        for idx, (data, c) in enumerate(zip(box_data, box_colors_list)):
            ax2.scatter(jitter_rng.normal(idx + 1, 0.05, len(data)), data, color=c, alpha=0.9, s=50)
    ax2.set_title("Crystallization Boundary Distribution", fontsize=12, fontweight='bold')
    ax2.set_ylabel("Crystallization Layer (Point of No Return)")
    ax2.set_ylim(-1, n_layers + 2)
    ax2.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax2.grid(True, axis='y', alpha=0.3)
    plt.suptitle("reFlow Causal Audit: Context Type Affects Information Crystallization", fontsize=15, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(report_dir, "task_crystallization_shift.png")
    plt.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close()

    print(" ================= Conclusions =================")
    if c_layers_shallow:
        avg_shallow = np.mean(c_layers_shallow)
        print(f" > Shallow (short context) avg boundary: Layer {avg_shallow:.1f}")
    if c_layers_deep:
        avg_deep = np.mean(c_layers_deep)
        print(f" > Deep (long context) avg boundary: Layer {avg_deep:.1f}")
    if c_layers_code:
        avg_code = np.mean(c_layers_code)
        print(f" > Code (structured logic) avg boundary: Layer {avg_code:.1f}")
    if c_layers_shallow and c_layers_deep:
        print(f" > Shallow→Deep boundary shift: \033[93m{np.mean(c_layers_deep) - np.mean(c_layers_shallow):+.1f} Layers\033[0m")
    if c_layers_shallow and c_layers_code:
        print(f" > Shallow→Code boundary shift: \033[93m{np.mean(c_layers_code) - np.mean(c_layers_shallow):+.1f} Layers\033[0m")
    # Only claim a cross-group effect when at least two groups produced data.
    if len(populated) >= 2:
        print(" > Results show: Context complexity affects crystallization boundary.")
        print(" More complex contexts tend to maintain representation fluidity at deeper layers.")
    else:
        print(" > Not enough task groups produced results to compare boundaries.")
    print(f" > Chart saved: {save_path}")
def main_menu():
    """Load the model once, then run experiments chosen interactively.

    At the prompt the user can enter:

    * one or more experiment numbers separated by spaces;
    * ``all`` to run every experiment in menu order;
    * ``q`` or ``quit`` to exit.

    Each experiment runs in its own ``try``, so a failure is reported (with a
    traceback) without ending the session. End of input at the prompt (e.g.
    piped stdin running out) or Ctrl-C exits cleanly instead of raising.
    """
    import traceback

    model, enc, device, report_dir = load_setup_and_model()
    experiments = {
        '1': ("Recipe Atlas", exp_1_recipe_atlas),
        '2': ("Sparsity Profile", exp_2_sparsity_profile),
        '3': ("Signal Basis Geometry", exp_3_basis_geometry),
        '4': ("Semantic Galaxy (PCA)", exp_4_semantic_galaxy),
        '5': ("Semantic Algebra", exp_5_semantic_algebra),
        '6': ("Typo Resilience", exp_6_typo_resilience),
        '7': ("Layer Probability Evolution", exp_7_layer_evolution),
        '8': ("Signal Flow Tracking", exp_8_signal_flow),
        '9': ("Causal Ablation Curve", exp_9_causal_ablation),
        '10': ("Emotion Surgery", exp_10_emotion_surgery),
        '11': ("Concept Inception", exp_11_concept_inception),
        '12': ("Genetic Hijack", exp_12_genetic_hijack),
        '13': ("Task Crystallization Shift", exp_13_task_crystallization_shift),
    }
    while True:
        print("\n" + "#"*60)
        print(" reFlow Interpretability Experiment Suite".center(56))
        print("#"*60)
        for k, v in experiments.items():
            print(f" [{k.rjust(2)}] {v[0]}")
        print(" [all] Run all experiments")
        print(" [ q ] Quit")
        print("#"*60)
        try:
            choice = input("Enter experiment number(s) to run (space-separated, e.g. '1 3 5'): ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is exhausted; treat it (and Ctrl-C) as quit.
            print("\nExiting.")
            break
        if choice in ('q', 'quit'):
            print("Exiting.")
            break
        selected_keys = list(experiments) if choice == 'all' else choice.split()
        ran_any = False
        for k in selected_keys:
            if k not in experiments:
                print(f"[Ignored] Invalid option: {k}")
                continue
            ran_any = True
            func = experiments[k][1]
            try:
                func(model, enc, device, report_dir)
            except Exception as e:
                # Session boundary: report and keep the menu alive for the next choice.
                print(f"\n [ERROR] Experiment {k} failed: {e}")
                traceback.print_exc()
        # Only report a completed batch when at least one valid experiment was run.
        if ran_any:
            print(f"\n[INFO] Batch complete. Reports saved to: {report_dir}")
# Script entry point: start the interactive experiment menu only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main_menu()