WrinkleBrane / Wrinkle /09_standalone_model /experiments_round2.py
WCNegentropy's picture
Upload 510 files
3d7f6c5 verified
#!/usr/bin/env python3
"""Direction 9 — Round 2: Extended Experiments.
Deeper exploration of the standalone WrinkleBrane model, probing:
1. Synthetic Grammar LM training (never tested in round 1!)
2. Extended associative recall training (push accuracy higher)
3. Longer sequences & overload behavior (T >> K)
4. Autoregressive generation from sequential mode
5. Head specialization analysis (do heads learn different things?)
6. Temperature sensitivity sweep
7. WrinkleBrane vs Transformer on longer sequences (crossover point)
8. Gate opening dynamics (how does the GatedFFN gate evolve?)
9. Codebook evolution (how much do learned codes drift from Hadamard init?)
10. Copy task at increasing difficulty (longer sequences)
Usage:
PYTHONPATH=src python3 experiments_round2.py
Tuned for CPU: all 10 experiments should complete in ~5 minutes.
"""
from __future__ import annotations
import math
import sys
import time
from typing import Dict, List
import torch
from torch import nn, Tensor
from wrinklebrane.standalone_model import WrinkleBraneConfig, WrinkleBraneModel
from wrinklebrane.baseline_transformer import SmallTransformer, SmallTransformerConfig
from wrinklebrane.tasks import (
SequenceCopyTask,
AssociativeRecallTask,
SyntheticGrammarTask,
compute_accuracy,
)
from wrinklebrane.train import train_loop, evaluate
# ===========================================================================
# Helpers
# ===========================================================================
def make_config(**overrides) -> WrinkleBraneConfig:
    """Build a compact ``WrinkleBraneConfig`` for quick CPU experiments.

    Args:
        **overrides: Keyword arguments that replace the baseline values
            below, or add fields the baseline does not set.

    Returns:
        The config constructed from the baseline merged with ``overrides``.
    """
    baseline = {
        "vocab_size": 32,
        "d_model": 64,
        "n_layers": 2,
        "n_heads": 2,
        "L": 16,
        "K": 32,
        "max_seq_len": 128,
        "dropout": 0.0,
        "ffn_expansion": 2,
        "ortho_lambda": 0.01,
    }
    # Caller-supplied values win over the baseline.
    merged = {**baseline, **overrides}
    return WrinkleBraneConfig(**merged)
def separator(title: str) -> None:
    """Print ``title`` framed by two 70-character ``=`` rules.

    The banner is preceded and followed by a blank line, and stdout is
    flushed so it appears immediately.
    """
    rule = "=" * 70
    banner = "\n".join(("", rule, f" {title}", rule, ""))
    print(banner, flush=True)
def quick_train(model, task, n_steps, batch_size=32, lr=1e-3, ignore_index=-100,
                log_every=50, label=""):
    """Run a short AdamW training loop and record progress snapshots.

    Each step draws ``task.generate_batch(batch_size)``, computes the
    cross-entropy of ``model(inp)`` against the targets, optionally adds the
    model's orthogonality penalty, clips the gradient norm to 1.0, and takes
    an optimizer step.

    Args:
        model: Module called as ``model(inp)``; its logits are unpacked as a
            3-tuple shape ``(B, T, V)``. If it has ``config.ortho_lambda > 0``
            and an ``ortho_loss()`` method, that penalty is added to the
            optimized objective.
        task: Object providing ``generate_batch(batch_size) -> (inp, tgt)``.
        n_steps: Number of optimizer steps.
        batch_size: Batch size passed to ``task.generate_batch``.
        lr: AdamW learning rate (weight decay is fixed at 0.01).
        ignore_index: Target value excluded from the loss and the accuracy.
        log_every: Snapshot interval; the final step is always recorded.
        label: When non-empty, each snapshot is also printed with this prefix.

    Returns:
        A list of ``{"step", "loss", "acc"}`` dicts. ``loss`` is the
        cross-entropy only (without the orthogonality term), and ``acc`` is
        computed on the same batch's pre-update logits.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # The penalty weight comes from the model's config when one is attached.
    cfg = getattr(model, 'config', None)
    if cfg and hasattr(cfg, 'ortho_lambda'):
        ortho_weight = cfg.ortho_lambda
    else:
        ortho_weight = 0.0

    final_step = n_steps - 1
    snapshots = []
    for step in range(n_steps):
        optimizer.zero_grad()
        inputs, targets = task.generate_batch(batch_size)
        logits = model(inputs)
        n_batch, n_time, n_vocab = logits.shape
        ce_loss = nn.functional.cross_entropy(
            logits.reshape(n_batch * n_time, n_vocab),
            targets.reshape(n_batch * n_time),
            ignore_index=ignore_index,
        )

        objective = ce_loss
        if ortho_weight > 0 and hasattr(model, 'ortho_loss'):
            objective = objective + ortho_weight * model.ortho_loss()
        objective.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Only snapshot on the logging interval and on the last step.
        if step % log_every != 0 and step != final_step:
            continue
        with torch.no_grad():
            acc = compute_accuracy(logits.detach(), targets, ignore_index)
        snapshot = {"step": step, "loss": float(ce_loss.detach()), "acc": acc}
        snapshots.append(snapshot)
        if label:
            print(f" {label} step {step:4d}: loss={snapshot['loss']:.4f}, acc={acc:.4f}",
                  flush=True)
    return snapshots
# ===========================================================================
# Experiment 1: Synthetic Grammar LM
# ===========================================================================
def experiment_1_grammar_lm():
    """Train on the synthetic grammar and compare eval loss to entropy bounds.

    An entropy estimate is formed from the task's rule tables: deterministic
    rules contribute zero, stochastic rules contribute the mean binary
    entropy of their probabilities, and wild tokens contribute
    ``log(n_data)``; each term is weighted by its share of the data tokens.
    The model is then trained for 300 steps and evaluated.

    Returns:
        Dict with keys ``eval`` (the ``evaluate`` result), ``theoretical_H``
        and ``beat_uniform`` (eval loss below 95% of the uniform loss).
    """
    separator("Experiment 1: Synthetic Grammar Language Modeling")
    torch.manual_seed(42)
    config = make_config(K=32, max_seq_len=64)
    task = SyntheticGrammarTask(vocab_size=config.vocab_size, seq_len=16)

    # Rule-table sizes; data tokens are those at or above token_offset.
    n_data = config.vocab_size - task.token_offset
    n_det = len(task.det_rules)
    n_stoch = len(task.stoch_rules)
    n_wild = len(task.wild_tokens)

    # Binary entropy (nats) of each stochastic rule's probability.
    stoch_entropies = [
        -(p * math.log(p) + (1 - p) * math.log(1 - p))
        for _a, _b, p in task.stoch_rules.values()
    ]
    avg_stoch_H = sum(stoch_entropies) / max(len(stoch_entropies), 1)

    uniform_loss = math.log(n_data)
    wild_H = uniform_loss
    frac_det = n_det / n_data
    frac_stoch = n_stoch / n_data
    frac_wild = n_wild / n_data
    theoretical_H = frac_det * 0 + frac_stoch * avg_stoch_H + frac_wild * wild_H

    print(f" Grammar: {n_det} det ({frac_det:.0%}), {n_stoch} stoch ({frac_stoch:.0%}), "
          f"{n_wild} wild ({frac_wild:.0%})")
    print(f" Theoretical min H: {theoretical_H:.3f} nats | Uniform: {uniform_loss:.3f} nats")

    model = WrinkleBraneModel(config)
    n_params = model.count_parameters()['total']
    print(f" Model: {n_params:,} params")

    quick_train(model, task, n_steps=300, batch_size=32, log_every=50,
                label="grammar")
    eval_result = evaluate(model, task, n_batches=10, ignore_index=-100)
    print(f"\n Final eval: loss={eval_result['loss']:.4f}, "
          f"acc={eval_result['accuracy']:.4f}, "
          f"ppl={eval_result['perplexity']:.2f}")

    # Require a 5% margin under the uniform baseline.
    beat_uniform = eval_result["loss"] < uniform_loss * 0.95
    verdict = 'YES' if beat_uniform else 'NO'
    print(f" Beats uniform ({uniform_loss:.3f})? {verdict}")
    return {"eval": eval_result, "theoretical_H": theoretical_H, "beat_uniform": beat_uniform}
# ===========================================================================
# Experiment 2: Extended Associative Recall
# ===========================================================================
def experiment_2_extended_recall():
    """Train and evaluate associative recall with 4 pairs, then with 2 pairs.

    Both runs share one config (4 heads, ``d_model=128``) but use freshly
    constructed models. The 4-pair accuracy is also reported as a multiple of
    chance, where chance is one over the number of data tokens.

    Returns:
        Dict with keys ``4pair`` and ``2pair`` holding the ``evaluate`` results.
    """
    separator("Experiment 2: Extended Associative Recall")
    torch.manual_seed(42)
    config = make_config(K=32, n_layers=2, n_heads=4, d_model=128)
    vocab = config.vocab_size

    # --- 4-pair recall (harder) ---
    hard_task = AssociativeRecallTask(vocab_size=vocab, n_pairs=4)
    hard_model = WrinkleBraneModel(config)
    print(f" 4-pair: {hard_model.count_parameters()['total']:,} params")
    quick_train(hard_model, hard_task, n_steps=600, batch_size=32,
                ignore_index=hard_task.ignore_index, log_every=100, label="4-pair")
    hard_eval = evaluate(hard_model, hard_task, n_batches=20,
                         ignore_index=hard_task.ignore_index)
    chance = 1.0 / (vocab - hard_task.token_offset)
    times_chance = hard_eval['accuracy'] / chance
    print(f" 4-pair eval: acc={hard_eval['accuracy']:.4f} "
          f"({times_chance:.1f}x chance), loss={hard_eval['loss']:.4f}")

    # --- 2-pair recall (easier), reseeded ---
    torch.manual_seed(123)
    easy_task = AssociativeRecallTask(vocab_size=vocab, n_pairs=2)
    easy_model = WrinkleBraneModel(config)
    quick_train(easy_model, easy_task, n_steps=500, batch_size=32,
                ignore_index=easy_task.ignore_index, log_every=100, label="2-pair")
    easy_eval = evaluate(easy_model, easy_task, n_batches=20,
                         ignore_index=easy_task.ignore_index)
    print(f" 2-pair eval: acc={easy_eval['accuracy']:.4f}, loss={easy_eval['loss']:.4f}")

    return {"4pair": hard_eval, "2pair": easy_eval}
# ===========================================================================
# Experiment 3: Sequence Length Stress Test
# ===========================================================================
def experiment_3_length_stress():
    """Time untrained forward passes over increasing sequence lengths.

    For each length ``T`` a random ``(2, T)`` token batch is pushed through
    the model in eval mode under ``torch.no_grad()``. The mean time per
    forward pass, the largest absolute logit, and whether every logit is
    finite are recorded. Lengths above ``config.K`` are labelled as overload.

    Returns:
        Dict mapping each length to ``{"time_ms", "max_abs", "ok"}``.
    """
    separator("Experiment 3: Sequence Length Stress Test")
    torch.manual_seed(42)
    config = make_config(K=32, max_seq_len=512)
    model = WrinkleBraneModel(config)
    model.eval()

    lengths = [8, 16, 32, 64, 128, 256]
    results = {}
    for T in lengths:
        tokens = torch.randint(0, config.vocab_size, (2, T))
        # Fewer repetitions for the long sequences to keep the run short.
        reps = 5 if T <= 64 else 2
        # perf_counter is monotonic and high-resolution, unlike wall-clock
        # time.time(), which matters for millisecond-scale measurements.
        t0 = time.perf_counter()
        with torch.no_grad():
            for _ in range(reps):
                logits = model(tokens)
        elapsed = (time.perf_counter() - t0) / reps

        # Finite means neither NaN nor +/-inf anywhere in the logits.
        ok = bool(torch.isfinite(logits).all().item())
        max_abs = logits.abs().max().item()
        overload = "OVERLOAD" if T > config.K else "within cap"
        results[T] = {"time_ms": elapsed * 1000, "max_abs": max_abs, "ok": ok}
        print(f" T={T:4d} ({overload:10s}): {elapsed*1000:7.1f}ms, "
              f"max|logit|={max_abs:.1f}, {'OK' if ok else 'FAIL'}", flush=True)

    # Compare the growth in time with the growth in length, shortest vs longest.
    t_short = results[lengths[0]]["time_ms"]
    t_long = results[lengths[-1]]["time_ms"]
    ratio = t_long / max(t_short, 0.01)  # floor avoids dividing by ~0 ms
    len_ratio = lengths[-1] / lengths[0]
    # The two lengths were previously printed fused together (e.g. "8256").
    print(f"\n Time scaling: {lengths[0]} → {lengths[-1]}: "
          f"{len_ratio:.0f}x length, {ratio:.1f}x time")
    print(f" Sub-quadratic? {'YES' if ratio < len_ratio**2 * 0.5 else 'UNCLEAR'}")
    return results
# ===========================================================================
# Experiment 4: Autoregressive Generation
# ===========================================================================
def experiment_4_autoregressive_generation():
    """Sample sequences via ``forward_sequential`` and score rule-following.

    A model is trained briefly on the synthetic grammar. Each sample is then
    seeded with ``[bos, random data token]`` and extended one sampled token at
    a time, carrying the recurrent ``states`` between calls. Adjacent token
    pairs are checked against the task's deterministic and stochastic rules.

    Returns:
        Dict with ``det_rate`` (fraction of deterministic transitions
        followed) and ``stoch_rate`` (fraction of stochastic transitions that
        landed on one of the two allowed successors).
    """
    separator("Experiment 4: Autoregressive Generation (Sequential Mode)")
    torch.manual_seed(42)
    config = make_config(K=32, max_seq_len=64, persistence_lambda=1.0)
    task = SyntheticGrammarTask(vocab_size=config.vocab_size, seq_len=16)
    model = WrinkleBraneModel(config)
    print(" Training on grammar for 200 steps...")
    quick_train(model, task, n_steps=200, batch_size=32, log_every=200, label="gen-train")
    model.eval()

    gen_len = 30
    n_samples = 5
    n_data = config.vocab_size - task.token_offset
    print(f"\n Generating {n_samples} sequences of length {gen_len}:")

    total_det, total_det_ok = 0, 0
    total_stoch, total_stoch_ok = 0, 0
    for s in range(n_samples):
        start = task.token_offset + torch.randint(0, n_data, (1,)).item()
        tokens = [task.bos_token, start]

        with torch.no_grad():
            # Prime the state with the prompt. Its last-position logits are
            # used as the distribution for the first generated token
            # (assumes logits are (batch, time, vocab) — as the indexing
            # below requires).
            prompt = torch.tensor([tokens], dtype=torch.long)
            logits, states = model.forward_sequential(prompt, None)

            # Previously the start token was fed a second time after the
            # prompt, so the state saw it twice and the prompt logits were
            # discarded. Sample first, then feed only the new token.
            while len(tokens) < gen_len:
                probs = torch.softmax(logits[0, -1], dim=-1)
                current = torch.multinomial(probs, 1).item()
                tokens.append(current)
                if len(tokens) < gen_len:
                    step_inp = torch.tensor([[current]], dtype=torch.long)
                    logits, states = model.forward_sequential(step_inp, states)

        # Analyze rule-following over adjacent token pairs.
        n_det_ok, n_det_tot, n_stoch_ok, n_stoch_tot = 0, 0, 0, 0
        for t, t_next in zip(tokens, tokens[1:]):
            if t in task.det_rules:
                n_det_tot += 1
                if t_next == task.det_rules[t]:
                    n_det_ok += 1
            elif t in task.stoch_rules:
                n_stoch_tot += 1
                a, b, _ = task.stoch_rules[t]
                if t_next in (a, b):
                    n_stoch_ok += 1
        total_det += n_det_tot
        total_det_ok += n_det_ok
        total_stoch += n_stoch_tot
        total_stoch_ok += n_stoch_ok

        seq_str = " ".join(str(t) for t in tokens[:20])
        print(f" Sample {s+1}: [{seq_str}...] "
              f"det={n_det_ok}/{n_det_tot} stoch={n_stoch_ok}/{n_stoch_tot}", flush=True)

    # max(..., 1) keeps the rates defined when no rule of a kind was hit.
    det_rate = total_det_ok / max(total_det, 1)
    stoch_rate = total_stoch_ok / max(total_stoch, 1)
    print(f"\n Overall: det rules followed {det_rate:.0%}, "
          f"stoch rules valid {stoch_rate:.0%}")
    return {"det_rate": det_rate, "stoch_rate": stoch_rate}
# ===========================================================================
# Experiment 5: Head Specialization Analysis
# ===========================================================================
def experiment_5_head_specialization():
    """Train on the copy task, then print per-head diagnostics for each layer.

    For every head this reports the temperature, the mean and max absolute
    off-diagonal entry of ``C.T @ C`` for its codebook ``C`` (labelled
    interference), and the mean entropy of the softmax read weights for 100
    random queries, next to the maximum ``log(K)``. A layer's FFN gate value
    is printed when the FFN has a ``gate`` attribute.

    Returns:
        ``True`` once the report has been printed.
    """
    separator("Experiment 5: Head Specialization Analysis")
    torch.manual_seed(42)
    config = make_config(K=32, n_layers=2, n_heads=4, d_model=128)
    task = SequenceCopyTask(vocab_size=config.vocab_size, seq_len=6)
    model = WrinkleBraneModel(config)
    quick_train(model, task, n_steps=300, batch_size=32, log_every=300,
                ignore_index=task.ignore_index, label="heads")
    model.eval()

    max_H = math.log(config.K)  # entropy of a uniform read over K slots
    for layer_idx, layer in enumerate(model.layers):
        print(f"\n Layer {layer_idx}:", flush=True)
        for head in range(config.n_heads):
            codes = layer.codebooks[head]()
            temp = layer.temperatures[head].item()
            read_proj = layer.read_projections[head]

            # Off-diagonal Gram entries measure overlap between codes.
            gram = codes.T @ codes
            abs_off_diag = (gram - torch.diag(torch.diag(gram))).abs()
            max_interf = abs_off_diag.max().item()
            mean_interf = abs_off_diag.mean().item()

            # Read entropy over random probe queries.
            queries = torch.randn(100, config.d_head)
            weights = torch.softmax(queries @ read_proj / temp, dim=-1)
            per_query_H = -(weights * (weights + 1e-10).log()).sum(-1)
            entropy = per_query_H.mean().item()

            print(f" Head {head}: T={temp:.4f}, "
                  f"interf={mean_interf:.4f}(mean)/{max_interf:.4f}(max), "
                  f"read_H={entropy:.2f}/{max_H:.2f}")
        if hasattr(layer.ffn, 'gate'):
            gate_value = layer.ffn.gate.item()
            print(f" Gate: {gate_value:.4f}")
    return True
# ===========================================================================
# Experiment 6: Temperature Sensitivity Sweep
# ===========================================================================
def experiment_6_temperature_sweep():
    """Sweep the initial temperature and compare copy-task results.

    For each initial temperature a freshly seeded model is trained for 200
    steps and evaluated. The mean of all per-head learned temperatures is
    reported next to the eval loss and accuracy.

    Returns:
        Dict mapping each initial temperature to
        ``{"eval_loss", "eval_acc", "learned_temp"}``.
    """
    separator("Experiment 6: Temperature Sensitivity Sweep")
    task = SequenceCopyTask(vocab_size=32, seq_len=6)
    results = {}
    for temp in (0.01, 0.05, 0.1, 0.5, 1.0):
        torch.manual_seed(42)  # identical seed for every sweep point
        model = WrinkleBraneModel(make_config(temperature=temp, K=32))
        quick_train(model, task, n_steps=200, batch_size=32, log_every=200,
                    ignore_index=task.ignore_index)
        ev = evaluate(model, task, n_batches=5, ignore_index=task.ignore_index)

        # Average the temperatures across all layers and heads.
        learned_temps = []
        for layer in model.layers:
            learned_temps.extend(t.item() for t in layer.temperatures)
        avg_t = sum(learned_temps) / len(learned_temps)

        results[temp] = {
            "eval_loss": ev["loss"],
            "eval_acc": ev["accuracy"],
            "learned_temp": avg_t,
        }
        print(f" T_init={temp:.2f} → T_learned={avg_t:.4f}, "
              f"loss={ev['loss']:.4f}, acc={ev['accuracy']:.4f}", flush=True)

    # Lowest eval loss wins; ties resolve to the earliest sweep point.
    best, _ = min(results.items(), key=lambda item: item[1]["eval_loss"])
    print(f"\n Best initial temperature: {best}")
    return results
# ===========================================================================
# Experiment 7: WB vs Transformer — Sequence Length Comparison
# ===========================================================================
def experiment_7_length_comparison():
    """Compare WrinkleBrane and the small Transformer on the copy task.

    For each copy length both models are built from matching hyperparameters,
    trained for 300 steps, and evaluated. The WB/TF eval-loss ratio is
    recorded per length, and the shortest and longest lengths are compared.

    Returns:
        Dict mapping each copy length to
        ``{"wb_loss", "wb_acc", "tf_loss", "tf_acc", "ratio"}``.
    """
    separator("Experiment 7: WB vs Transformer — Length Comparison")
    seq_lengths = [4, 8, 16]
    results = {}
    for seq_len in seq_lengths:
        total_len = 2 * seq_len
        torch.manual_seed(42)
        wb_cfg = make_config(
            K=max(32, seq_len * 2),
            max_seq_len=max(64, total_len + 8),
        )
        # The Transformer baseline mirrors the WrinkleBrane hyperparameters.
        tf_cfg = SmallTransformerConfig(
            vocab_size=wb_cfg.vocab_size,
            d_model=wb_cfg.d_model,
            max_seq_len=wb_cfg.max_seq_len,
            n_layers=wb_cfg.n_layers,
            n_heads=wb_cfg.n_heads,
            ffn_expansion=wb_cfg.ffn_expansion,
            dropout=0.0,
            weight_tying=True,
        )
        task = SequenceCopyTask(vocab_size=wb_cfg.vocab_size, seq_len=seq_len)
        pad = task.ignore_index

        wb_model = WrinkleBraneModel(wb_cfg)
        quick_train(wb_model, task, n_steps=300, batch_size=32, log_every=300,
                    ignore_index=pad)
        wb_ev = evaluate(wb_model, task, n_batches=5, ignore_index=pad)

        # Reseed before building the baseline.
        torch.manual_seed(42)
        tf_model = SmallTransformer(tf_cfg)
        quick_train(tf_model, task, n_steps=300, batch_size=32, log_every=300,
                    ignore_index=pad)
        tf_ev = evaluate(tf_model, task, n_batches=5, ignore_index=pad)

        # Floor the denominator so a near-zero baseline loss cannot blow up.
        ratio = wb_ev["loss"] / max(tf_ev["loss"], 1e-6)
        results[seq_len] = {
            "wb_loss": wb_ev["loss"],
            "wb_acc": wb_ev["accuracy"],
            "tf_loss": tf_ev["loss"],
            "tf_acc": tf_ev["accuracy"],
            "ratio": ratio,
        }
        print(f" L={seq_len:2d} (T={total_len:3d}): "
              f"WB={wb_ev['loss']:.4f}/{wb_ev['accuracy']:.2%} | "
              f"TF={tf_ev['loss']:.4f}/{tf_ev['accuracy']:.2%} | "
              f"ratio={ratio:.2f}x", flush=True)

    shortest, longest = seq_lengths[0], seq_lengths[-1]
    r_short = results[shortest]["ratio"]
    r_long = results[longest]["ratio"]
    print(f"\n Ratio: {r_short:.2f}x (short) → {r_long:.2f}x (long)")
    improving = 'YES' if r_long < r_short else 'NO'
    print(f" WB improving at longer seqs? {improving}")
    return results
# ===========================================================================
# Experiment 8: Gate Opening Dynamics
# ===========================================================================
def experiment_8_gate_dynamics():
    """Track each layer's FFN gate value while training on the copy task.

    A 3-layer model is trained for 300 steps with the same recipe as
    ``quick_train`` (AdamW, gradient-norm clipping at 1.0, optional
    orthogonality penalty). Every 30 steps and at the final step the
    ``layer.ffn.gate`` values are read and printed. The logged loss includes
    the orthogonality term when it is enabled.

    Returns:
        List of ``{"step", "gates", "loss"}`` snapshots in step order.
    """
    separator("Experiment 8: Gate Opening Dynamics")
    torch.manual_seed(42)
    config = make_config(K=32, n_layers=3)
    task = SequenceCopyTask(vocab_size=config.vocab_size, seq_len=6)
    model = WrinkleBraneModel(config)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

    n_steps = 300
    snapshot_every = 30
    gate_history = []
    model.train()  # nothing in the loop leaves train mode, so set it once
    for step in range(n_steps):
        opt.zero_grad()
        inp, tgt = task.generate_batch(32)
        logits = model(inp)
        B, T, V = logits.shape
        loss = nn.functional.cross_entropy(
            logits.reshape(B * T, V), tgt.reshape(B * T),
            ignore_index=task.ignore_index)
        if config.ortho_lambda > 0:
            loss = loss + config.ortho_lambda * model.ortho_loss()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if step % snapshot_every == 0 or step == n_steps - 1:
            gates = [layer.ffn.gate.item() for layer in model.layers]
            loss_value = float(loss.detach())
            gate_history.append({"step": step, "gates": gates, "loss": loss_value})
            gate_str = ", ".join(f"L{i}={value:.4f}" for i, value in enumerate(gates))
            print(f" step {step:3d}: loss={loss_value:.4f}, gates: {gate_str}", flush=True)

    print("\n Gate evolution:")
    first, last = gate_history[0]["gates"], gate_history[-1]["gates"]
    for i in range(config.n_layers):
        g0, gf = first[i], last[i]
        # "opened" means the magnitude grew by more than 0.001.
        status = 'opened' if abs(gf) > abs(g0) + 0.001 else 'minimal change'
        # The start and end values were previously printed fused together.
        print(f" Layer {i}: {g0:.4f} → {gf:.4f} ({status})")
    return gate_history
# ===========================================================================
# Experiment 9: Codebook Evolution
# ===========================================================================
def experiment_9_codebook_evolution():
    """Measure how far each codebook moves from its initial value in training.

    Every head's codebook output is snapshotted before training. After 300
    copy-task steps this reports, per layer and head, the relative Frobenius
    drift, the column-wise cosine similarity to the snapshot (mean and min),
    and the largest absolute off-diagonal entry of ``Cf.T @ Cf``.

    Returns:
        ``True`` once the report has been printed.
    """
    separator("Experiment 9: Codebook Evolution from Hadamard Init")
    torch.manual_seed(42)
    config = make_config(K=32, n_layers=2, n_heads=2)
    model = WrinkleBraneModel(config)

    # Snapshot every codebook before any optimizer step touches it.
    initial_codes = {
        (layer_idx, head_idx): codebook().detach().clone()
        for layer_idx, layer in enumerate(model.layers)
        for head_idx, codebook in enumerate(layer.codebooks)
    }

    task = SequenceCopyTask(vocab_size=config.vocab_size, seq_len=6)
    print(" Training for 300 steps...", flush=True)
    quick_train(model, task, n_steps=300, batch_size=32, log_every=300,
                ignore_index=task.ignore_index)

    print(f"\n Codebook drift analysis:")
    for layer_idx, layer in enumerate(model.layers):
        for head_idx, codebook in enumerate(layer.codebooks):
            before = initial_codes[(layer_idx, head_idx)]
            after = codebook().detach()

            # Relative change in norm terms.
            drift = (after - before).norm().item() / before.norm().item()

            # Largest off-diagonal Gram entry as an orthogonality check.
            gram = after.T @ after
            off_diag = gram - torch.diag(torch.diag(gram))
            ortho_viol = off_diag.abs().max().item()

            cos = torch.nn.functional.cosine_similarity(before, after, dim=0)
            cos_avg = cos.mean().item()
            cos_min = cos.min().item()
            print(f" L{layer_idx}H{head_idx}: drift={drift:.4f}, "
                  f"cos_sim={cos_avg:.4f}(avg)/{cos_min:.4f}(min), "
                  f"ortho_violation={ortho_viol:.4f}")
    return True
# ===========================================================================
# Experiment 10: Copy Task Difficulty Scaling
# ===========================================================================
def experiment_10_copy_difficulty():
    """Train the copy task at increasing lengths and report where it breaks.

    For each copy length a freshly seeded model is built, with ``K`` and
    ``max_seq_len`` grown to fit the doubled sequence, trained for 300 steps,
    and evaluated. The first length whose accuracy falls below 0.5 is
    reported, or a success message if none does.

    Returns:
        Dict mapping each copy length to ``{"loss", "acc", "params"}``.
    """
    separator("Experiment 10: Copy Task — Difficulty Scaling")
    seq_lengths = [4, 8, 12, 16, 24]
    results = {}
    for seq_len in seq_lengths:
        torch.manual_seed(42)
        total_len = 2 * seq_len
        capacity = total_len + 8  # headroom beyond the doubled sequence
        config = make_config(
            K=max(32, capacity),
            max_seq_len=max(64, capacity),
        )
        task = SequenceCopyTask(vocab_size=config.vocab_size, seq_len=seq_len)
        model = WrinkleBraneModel(config)
        quick_train(model, task, n_steps=300, batch_size=32, log_every=300,
                    ignore_index=task.ignore_index)
        ev = evaluate(model, task, n_batches=5, ignore_index=task.ignore_index)
        params = model.count_parameters()["total"]
        results[seq_len] = {"loss": ev["loss"], "acc": ev["accuracy"], "params": params}
        print(f" L={seq_len:2d} (T={total_len:3d}, {params:,}p): "
              f"loss={ev['loss']:.4f}, acc={ev['accuracy']:.4f}", flush=True)

    # First length, in sweep order, at which accuracy falls under one half.
    first_failure = next(
        (n for n in seq_lengths if results[n]["acc"] < 0.5),
        None,
    )
    if first_failure is not None:
        print(f"\n Accuracy drops below 50% at seq_len={first_failure}")
    else:
        print(f"\n Model maintains >50% accuracy through seq_len={seq_lengths[-1]}!")
    return results
# ===========================================================================
# Main
# ===========================================================================
def main():
    """Run every round-2 experiment in order and print a timing summary.

    Each experiment runs inside its own ``try`` so that one failure is
    reported, with a traceback, without stopping the remaining experiments.
    Failed experiments have no entry in the returned dict.

    Returns:
        Dict mapping experiment name to that experiment's return value.
    """
    import traceback

    print("=" * 70)
    print(" Direction 9: Standalone WrinkleBrane — Round 2 Experiments")
    print("=" * 70, flush=True)
    t0 = time.time()
    results = {}
    completed = 0
    experiments = [
        ("grammar", experiment_1_grammar_lm),
        ("recall", experiment_2_extended_recall),
        ("length_stress", experiment_3_length_stress),
        ("generation", experiment_4_autoregressive_generation),
        ("heads", experiment_5_head_specialization),
        ("temperature", experiment_6_temperature_sweep),
        ("length_comparison", experiment_7_length_comparison),
        ("gates", experiment_8_gate_dynamics),
        ("codebooks", experiment_9_codebook_evolution),
        ("copy_difficulty", experiment_10_copy_difficulty),
    ]
    for name, fn in experiments:
        exp_t0 = time.time()
        try:
            results[name] = fn()
            completed += 1
            print(f" [{name}] done in {time.time()-exp_t0:.1f}s", flush=True)
        except Exception as e:
            # Top-level boundary: report the failure and move on.
            print(f"\n *** {name} FAILED: {e} ***", flush=True)
            traceback.print_exc()
    elapsed = time.time() - t0
    separator("ROUND 2 COMPLETE")
    print(f" Total time: {elapsed:.1f}s ({elapsed/60:.1f} minutes)")
    # Derive the denominator from the list so it cannot drift out of sync
    # when experiments are added or removed.
    print(f" {completed}/{len(experiments)} experiments completed")
    print(flush=True)
    return results
# Script entry point: run all experiments only when executed directly, and
# bind the returned per-experiment dict to a module-level ``results`` name.
if __name__ == "__main__":
    results = main()