DuoNeural committed on
Commit
a45bc43
·
verified ·
1 Parent(s): f1a4ead

Add ctm_world_model_v30.py

Files changed (1)
  1. experiments/ctm_world_model_v30.py +326 -0
experiments/ctm_world_model_v30.py ADDED
@@ -0,0 +1,326 @@
+ #!/usr/bin/env python3
+ """
+ CTM World Model v30 — Gap Test: Different Sine Period
+ Archon / DuoNeural 2026-04-30
+
+ Robustness check #1 for the Tripartite Temporal Principle (Paper 4).
+
+ If the finding is genuine and not a fluke of T=8 being conveniently equal to
+ some architectural constant, then doubling the period to T=16 should shift
+ the learned effective delay to ~16 as well.
+
+ v29 used SINE_PERIOD=8 and found eff_delay ≈ 8 at T_GATE=32.
+ Here we use SINE_PERIOD=16, T_GATE sweep {16, 32}.
+ Both T_GATE values are >= SINE_PERIOD, so the gate has enough room to find it.
+
+ Verdict: "PERIOD_SHIFTED" if eff_delay ≈ 16 at T_GATE=32
+ Cross-ref: compare to v29's eff_delay ≈ 8 to confirm proportional shift.
+ """
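+ # eff_delay, the quantity under test, is the gate-weighted mean lag over the
+ # input window (computed in analyze_gates below):
+ #     eff_delay = sum_d d * g(d),   g = softmax(gate_logits)
+ # where d counts steps back from the most recent input. Under the hypothesis
+ # it should move from ~8 (v29) to ~16 here.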
+
+ import json
+ import math
+ import os
+ import time
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Device: {DEVICE}")
+
+ # ── Config ────────────────────────────────────────────────────────────────────
+ N_OBJ = 8
+ SINE_PERIOD = 16        # KEY CHANGE: doubled from v29's T=8
+ T_GATE_LIST = [16, 32]  # must be >= SINE_PERIOD to give the gate a chance
+ TRAIN_STEPS = 40_000    # robustness check, shorter than main experiments
+ BATCH_SIZE = 128
+ K_PRED = 4              # predict k steps ahead
+ HIDDEN_DIM = 128
+ N_SLOTS = 4
+ LOG_FILE = os.path.expanduser("~/duoneural/ctm_world_model_v30/wm_v30.log")
+ OUT_DIR = os.path.expanduser("~/duoneural/ctm_world_model_v30")
+ os.makedirs(OUT_DIR, exist_ok=True)
+
+ # v29 reference results (period=8) for comparison in verdict
+ V29_REF_DELAY = {8: 8.0, 16: 8.0, 32: 8.0}  # approximate, update after v29 runs
+
+ def ts():
+     return time.strftime('%Y-%m-%d %H:%M:%S')
+
+ def log(msg):
+     print(msg, flush=True)
+     with open(LOG_FILE, 'a') as f:
+         f.write(f"[{ts()}] {msg}\n")
+
+ # ── Sine wave data ─────────────────────────────────────────────────────────────
+ def generate_sine_batch(batch_size, seq_len, n_obj=N_OBJ, period=SINE_PERIOD):
+     """
+     N independent sine waves, random phase + slight amplitude jitter.
+     OBJ_DIM=1 (position only, partial obs).
+     Returns: (B, seq_len, N_OBJ)
+     """
+     t = torch.arange(seq_len, dtype=torch.float32)
+     phases = torch.rand(batch_size, n_obj) * 2 * math.pi   # (B, N_OBJ)
+     amps = 0.8 + 0.4 * torch.rand(batch_size, n_obj)       # (B, N_OBJ) ∈ [0.8, 1.2]
+     noise = 0.02 * torch.randn(batch_size, seq_len, n_obj)
+
+     omega = 2 * math.pi / period
+     t_exp = t.unsqueeze(0).unsqueeze(-1)   # (1, seq_len, 1)
+     phases_e = phases.unsqueeze(1)         # (B, 1, N_OBJ)
+     amps_e = amps.unsqueeze(1)             # (B, 1, N_OBJ)
+
+     x = amps_e * torch.sin(omega * t_exp + phases_e) + noise  # (B, seq_len, N_OBJ)
+     return x
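+
+ # Example (shapes only, with the defaults above; illustrative):
+ #     >>> generate_sine_batch(4, 32).shape
+ #     torch.Size([4, 32, 8])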
+
+ # ── Architecture (identical to v29) ───────────────────────────────────────────
+ class LearnedTemporalGateEncoder(nn.Module):
+     """
+     Softmax gate over T_GATE past timesteps — the thing we're studying.
+     One global gate shared across all objects (keeps it clean for analysis).
+     """
+     def __init__(self, t_gate, obj_dim, hidden_dim):
+         super().__init__()
+         self.t_gate = t_gate
+         self.obj_dim = obj_dim
+         self.hidden_dim = hidden_dim
+         # THE gate: T_GATE learnable logits → softmax → weighted sum over time
+         self.gate_logits = nn.Parameter(torch.zeros(t_gate))
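+         # Zero-init logits give a uniform softmax over all T_GATE steps,
+         # i.e. training starts with no temporal preference.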
+         # Shared per-timestep encoder
+         self.encoder = nn.Sequential(
+             nn.Linear(obj_dim, hidden_dim),
+             nn.LayerNorm(hidden_dim),
+             nn.GELU(),
+             nn.Linear(hidden_dim, hidden_dim),
+         )
+
+     def forward(self, history):
+         """
+         history: (B, T_GATE, N_OBJ, obj_dim)
+         returns: encoded (B, N_OBJ, hidden_dim), gates (T_GATE,)
+         """
+         B, T, N, D = history.shape
+         gates = torch.softmax(self.gate_logits, dim=0)   # (T,)
+
+         # Encode every timestep independently (shared weights)
+         h_flat = history.reshape(B * T * N, D)
+         enc_flat = self.encoder(h_flat)   # (B*T*N, hidden_dim)
+         enc = enc_flat.reshape(B, T, N, self.hidden_dim)
+
+         # Temporal attention: weighted sum over T dimension
+         gates_e = gates.view(1, T, 1, 1)
+         out = (enc * gates_e).sum(dim=1)   # (B, N_OBJ, hidden_dim)
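+         # (equivalent one-liner: out = torch.einsum('t,btnh->bnh', gates, enc))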
+         return out, gates
+
+
+ class SlotDynamics(nn.Module):
+     """Self-attention over objects + MLP, then linear decode. Same as v29."""
+     def __init__(self, hidden_dim, n_slots, obj_dim):
+         super().__init__()
+         self.n_slots = n_slots
+         self.attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
+         self.norm1 = nn.LayerNorm(hidden_dim)
+         self.ff = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim * 2),
+             nn.GELU(),
+             nn.Linear(hidden_dim * 2, hidden_dim),
+         )
+         self.norm2 = nn.LayerNorm(hidden_dim)
+         self.decoder = nn.Linear(hidden_dim, obj_dim)
+
+     def forward(self, enc):
+         """enc: (B, N_OBJ, hidden_dim) → pred: (B, N_OBJ, obj_dim)"""
+         x, _ = self.attn(enc, enc, enc)
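+         # With batch_first=True the "sequence" axis here is the object axis;
+         # time was already collapsed by the gate encoder.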
+         x = self.norm1(enc + x)
+         x = self.norm2(x + self.ff(x))
+         return self.decoder(x)
+
+
+ class SineCTM(nn.Module):
+     """Full model: gate encoder → slot dynamics → prediction."""
+     def __init__(self, t_gate, obj_dim=1, hidden_dim=HIDDEN_DIM, n_slots=N_SLOTS):
+         super().__init__()
+         self.t_gate = t_gate
+         self.gate_enc = LearnedTemporalGateEncoder(t_gate, obj_dim, hidden_dim)
+         self.dynamics = SlotDynamics(hidden_dim, n_slots, obj_dim)
+
+     def forward(self, history):
+         """history: (B, T_GATE, N_OBJ, 1) → pred: (B, N_OBJ), gates: (T_GATE,)"""
+         enc, gates = self.gate_enc(history)   # (B, N_OBJ, hidden_dim), (T_GATE,)
+         pred = self.dynamics(enc)             # (B, N_OBJ, 1)
+         return pred.squeeze(-1), gates        # (B, N_OBJ), (T_GATE,)
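+
+ # Usage sketch (illustrative; shapes assume the config above, N_OBJ=8):
+ #     >>> m = SineCTM(t_gate=16)
+ #     >>> pred, gates = m(torch.randn(2, 16, 8, 1))
+ #     >>> pred.shape, gates.shape
+ #     (torch.Size([2, 8]), torch.Size([16]))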
+
+ # ── Gate analysis helper ───────────────────────────────────────────────────────
+ def analyze_gates(model, t_gate):
+     """Returns dict of gate metrics given trained model."""
+     with torch.no_grad():
+         g = torch.softmax(model.gate_enc.gate_logits, dim=0).cpu().numpy()
+     peak_idx = int(np.argmax(g))
+     peak_prob = float(g[peak_idx])
+     # Index convention: history[:, 0] is the oldest step and history[:, T-1]
+     # the most recent, so gate index i attends (T-1 - i) steps back.
+     # eff_delay is the gate-weighted mean lag: sum_i delay(i) * g[i].
+     delays = np.arange(t_gate)[::-1].copy()   # delays[0]=T-1, ..., delays[T-1]=0
+     eff_delay = float(np.sum(delays * g))
+     gate_spec = float(np.sum(g * (delays - eff_delay)**2) ** 0.5)
+     return {
+         "gate_dist": g,
+         "peak_idx": peak_idx,
+         "peak_delay": t_gate - 1 - peak_idx,
+         "peak_prob": peak_prob,
+         "eff_delay": eff_delay,
+         "gate_spec": gate_spec,
+     }
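+
+ # Sanity check of the index convention (a minimal sketch; the helper
+ # `_check_delay_convention` is illustrative, nothing below depends on it):
+ # a one-hot gate on the most recent step (index T-1) must give delay 0,
+ # and on the oldest step (index 0) delay T-1.
+ def _check_delay_convention(T=8):
+     delays = np.arange(T)[::-1]
+     assert np.sum(delays * np.eye(T)[T - 1]) == 0      # most recent: delay 0
+     assert np.sum(delays * np.eye(T)[0]) == T - 1      # oldest: delay T-1
+
+ _check_delay_convention()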
+
+ # ── Training loop ──────────────────────────────────────────────────────────────
+ def run_experiment(t_gate):
+     log(f"\n{'='*60}")
+     log(f"T_GATE={t_gate} — Period Shift Gap Test (SINE_PERIOD={SINE_PERIOD})")
+     log(f"{'='*60}")
+
+     SEQ_LEN = t_gate + K_PRED + 10
+
+     # Resume check
+     ckpt_path = os.path.join(OUT_DIR, f"ckpt_v30_tg{t_gate}.pt")
+     model = SineCTM(t_gate=t_gate).to(DEVICE)
+     opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
+     sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, TRAIN_STEPS)
+     start_step = 0
+     best_mse = float('inf')
+
+     if os.path.exists(ckpt_path):
+         log(f"  Resuming from checkpoint: {ckpt_path}")
+         ckpt = torch.load(ckpt_path, map_location=DEVICE)
+         model.load_state_dict(ckpt['model'])
+         opt.load_state_dict(ckpt['opt'])
+         sched.load_state_dict(ckpt['sched'])
+         start_step = ckpt['step']
+         best_mse = ckpt.get('best_mse', float('inf'))
+         log(f"  Resumed at step {start_step}, best_mse={best_mse:.6f}")
+
+     for step in range(start_step + 1, TRAIN_STEPS + 1):
+         model.train()
+         seq = generate_sine_batch(BATCH_SIZE, SEQ_LEN).to(DEVICE)
+
+         t_start = torch.randint(0, SEQ_LEN - t_gate - K_PRED, (1,)).item()
+         history = seq[:, t_start:t_start+t_gate, :].unsqueeze(-1)   # (B, T, N, 1)
+         target = seq[:, t_start+t_gate+K_PRED-1, :]                 # (B, N)
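+         # e.g. with t_start=0 and t_gate=16: history covers steps 0..15 and
+         # the target is step 0+16+4-1 = 19, i.e. K_PRED steps past the last input.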
+
+         pred, gates = model(history)
+         loss = ((pred - target)**2).mean()
+
+         opt.zero_grad()
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         opt.step()
+         sched.step()
+
+         if loss.item() < best_mse:
+             best_mse = loss.item()
+
+         if step % 5000 == 0 or step == TRAIN_STEPS:
+             m = analyze_gates(model, t_gate)
+             log(f"  step {step:6d} | loss={loss.item():.6f} | "
+                 f"peak@t-{m['peak_delay']}({m['peak_prob']:.3f}) | "
+                 f"eff_delay={m['eff_delay']:.2f}")
+
+             # Save checkpoint every 5k steps
+             torch.save({
+                 'model': model.state_dict(),
+                 'opt': opt.state_dict(),
+                 'sched': sched.state_dict(),
+                 'step': step,
+                 'best_mse': best_mse,
+             }, ckpt_path)
+
+     # Final analysis
+     m = analyze_gates(model, t_gate)
+     period_delta = round(m['eff_delay'] - SINE_PERIOD, 2)
+
+     log(f"\n  ── T_GATE={t_gate} FINAL ──")
+     log(f"  gate dist:       {np.round(m['gate_dist'], 3).tolist()}")
+     log(f"  peak:            t-{m['peak_delay']} (prob={m['peak_prob']:.4f})")
+     log(f"  eff_delay:       {m['eff_delay']:.2f}")
+     log(f"  SINE_PERIOD:     {SINE_PERIOD}")
+     log(f"  delta vs period: {period_delta:+.2f}  ← key test")
+     log(f"  best_loss:       {best_mse:.6f}")
+
+     result = {
+         "t_gate": t_gate,
+         "sine_period": SINE_PERIOD,
+         "max_delay_used": round(m['eff_delay'], 2),   # i.e. the gate-weighted eff_delay
+         "peak_delay": m['peak_delay'],
+         "peak_prob": round(m['peak_prob'], 4),
+         "gate_spec": round(m['gate_spec'], 4),
+         "best_loss": round(best_mse, 8),
+         "delta_vs_period": period_delta,
+         "gate_distribution": [round(float(x), 4) for x in m['gate_dist']],
+     }
+
+     # Verdict logic
+     if abs(period_delta) < 2.5 and t_gate >= SINE_PERIOD:
+         # eff_delay is within 2.5 steps of the sine period — tracking it
+         log(f"  *** PERIOD_TRACKING: eff_delay ≈ SINE_PERIOD={SINE_PERIOD} ***")
+         result["verdict"] = "PERIOD_TRACKING"
+     elif m['eff_delay'] < 3.0:
+         log(f"  *** MARKOVIAN: gate collapsed to present ***")
+         result["verdict"] = "MARKOVIAN"
+     else:
+         log(f"  *** EXTENDED: uses history but not cleanly period-aligned ***")
+         result["verdict"] = "EXTENDED"
+
+     return result
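+
+ # Entries in results_v30.json mirror the `result` dict above; sketch with
+ # illustrative values (not real outputs):
+ #     {"32": {"t_gate": 32, "sine_period": 16, "max_delay_used": 15.8,
+ #             "delta_vs_period": -0.2, "verdict": "PERIOD_TRACKING", ...},
+ #      "global_verdict": "PERIOD_SHIFTED"}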
+
+ # ── Main ───────────────────────────────────────────────────────────────────────
+ log(f"CTM World Model v30 — Gap Test: SINE_PERIOD={SINE_PERIOD} (was 8 in v29)")
+ log(f"N_OBJ={N_OBJ}, TRAIN_STEPS={TRAIN_STEPS}, T_GATE sweep={T_GATE_LIST}")
+ log(f"Hypothesis: eff_delay should shift to ≈{SINE_PERIOD} (doubled from v29)")
+ log(f"Device: {DEVICE}")
+
+ all_results = {}
+
+ # Load prior results if resuming
+ results_path = os.path.join(OUT_DIR, "results_v30.json")
+ if os.path.exists(results_path):
+     with open(results_path) as f:
+         all_results = json.load(f)
+     log(f"Loaded existing results for T_GATE keys: {list(all_results.keys())}")
+
+ for tg in T_GATE_LIST:
+     if str(tg) in all_results:
+         log(f"T_GATE={tg} already in results, skipping "
+             f"(remove its entry from results_v30.json to re-run)")
+         continue
+     r = run_experiment(tg)
+     all_results[str(tg)] = r
+     with open(results_path, 'w') as f:
+         json.dump(all_results, f, indent=2)
+     log(f"[checkpoint] results_v30.json saved (T_GATE={tg} done)")
+
+ # ── Summary & comparison vs v29 ───────────────────────────────────────────────
+ log(f"\n{'='*60}")
+ log(f"V30 COMPLETE — Period Shift Test Summary")
+ log(f"{'='*60}")
+ log(f"{'T_GATE':>8} | {'eff_delay':>10} | {'vs_period':>10} | {'verdict'}")
+ for tg_str, r in all_results.items():
+     if not isinstance(r, dict):
+         continue   # skip the top-level "global_verdict" string when resuming
+     log(f"{tg_str:>8} | {r['max_delay_used']:>10.2f} | {r['delta_vs_period']:>+10.2f} | {r['verdict']}")
+
+ # Final cross-experiment verdict
+ if "32" in all_results:
+     r32 = all_results["32"]
+     # The key question: did the delay shift proportionally from v29's ~8 to ~16?
+     if r32["verdict"] == "PERIOD_TRACKING":
+         log(f"\n *** VERDICT: PERIOD_SHIFTED ***")
+         log(f"     eff_delay({SINE_PERIOD}) ≈ {r32['max_delay_used']:.1f} — proportional shift confirmed")
+         log(f"     Tripartite principle is NOT a T=8 fluke. Genuine period tracking.")
+         all_results["global_verdict"] = "PERIOD_SHIFTED"
+     else:
+         log(f"\n *** VERDICT: NO_SHIFT ***")
+         log(f"     eff_delay={r32['max_delay_used']:.1f} ≠ SINE_PERIOD={SINE_PERIOD}")
+         log(f"     Period tracking does NOT generalize — need to investigate further.")
+         all_results["global_verdict"] = "NO_SHIFT"
+ elif "16" in all_results:
+     r16 = all_results["16"]
+     if r16["verdict"] == "PERIOD_TRACKING":
+         log(f"\n *** VERDICT: PERIOD_SHIFTED (at T_GATE=16) ***")
+         all_results["global_verdict"] = "PERIOD_SHIFTED"
+
+ with open(results_path, 'w') as f:
+     json.dump(all_results, f, indent=2)
+ log(f"All results saved to {results_path}")