"""Per-subsystem timing to find the tok/s bottleneck.

Runs forward and forward+backward passes at (B=8, T=MAX_SEQ_LEN) and times each
stage via torch.cuda.Event. Reports ms/stage, the measured tok/s of a full
training step, and an MFU-based expectation derived from the model's actual
parameter count.
"""
import os
import sys
import time

# NOTE(review): the dynamic loader reads LD_LIBRARY_PATH at process start, so
# setting it here most likely only affects child processes, not libraries this
# interpreter loads afterwards -- confirm it is actually needed.
os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64")
# Make the parent directory importable so `train` resolves from this subdirectory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN, norm

B, T = 8, MAX_SEQ_LEN

# Assumed dense bf16 peak of the target GPU and the utilisation we'd like to hit.
RTX3060_BF16_FLOPS = 25e12
TARGET_MFU = 0.2


def timeit(name, fn, warmup=1, n=3):
    """Time ``fn`` on the GPU with CUDA events and print a one-line summary.

    Args:
        name: Label printed in the report.
        fn: Zero-argument callable running the workload.
        warmup: Untimed calls made first (lazy init, allocator/kernel warmup).
        n: Number of timed calls; must be >= 1.

    Returns:
        Mean elapsed time in milliseconds over the ``n`` timed calls.

    Raises:
        ValueError: If ``n`` < 1 (the mean would be undefined).
    """
    if n < 1:
        raise ValueError(f"timeit needs n >= 1, got {n}")
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(n):
        # Drain queued work so the start event isn't recorded behind it.
        torch.cuda.synchronize()
        start.record()
        fn()
        end.record()
        # elapsed_time requires both events to have completed.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    avg = sum(times) / len(times)
    print(f" {name:30s} {avg:8.2f} ms (min {min(times):.2f} max {max(times):.2f})")
    return avg


cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()

# Random tokens; targets are just a copy because only timing matters, not the loss value.
idx = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
y = idx.clone()
n_params = sum(p.numel() for p in model.parameters())

print(f"== Profile at B={B} T={T} n_params={n_params / 1e6:.1f}M ==\n")

# Warm up one full forward before timing anything.
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model(idx, y)
torch.cuda.synchronize()

print("Stage times (3 iter avg):\n")

# 1) wte
timeit("wte embedding", lambda: model.wte(idx).sum().item())

# 2) sdr_semantic (STE forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(idx).sum().item())

# 3) sdr binary_only
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(idx).sum().item())

# 4) HTM full forward (with reset/learn).
# The binary SDR is computed once up front so this stage measures HTM alone
# instead of also re-paying the stage-3 cost on every call.
with torch.no_grad():
    sdr_bits = model.sdr_semantic.binary_only(idx)
    timeit(f"HTM forward (B={B}, T={T})", lambda: model.htm(sdr_bits).sum().item())


# 5) Mamba block stack only (also includes the wte lookup timed in stage 1).
def _blocks():
    """Run embedding -> norm -> mHC-wrapped blocks -> merge; return a scalar to force sync."""
    x = norm(model.wte(idx))
    streams = model.mhc[0].init_streams(x)
    for block, mhc_layer in zip(model.blocks, model.mhc):
        # Bind `block` as a default argument so each closure keeps its own layer.
        def _bfn(h, _b=block):
            return _b(norm(h))
        streams = mhc_layer(streams, _bfn)
    x = model.mhc[-1].merge_streams(streams)
    return x.sum().item()


# autocast is applied at call time, so wrapping the timeit call covers _blocks.
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit(f"Mamba+mHC blocks (n_layer={len(model.blocks)})", _blocks)

# 6) Full forward+loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(idx, y).item())


# 7) Full forward+loss+backward
def full_fwd_bwd():
    """One training step without the optimizer: zero grads, autocast forward, backward."""
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, y)
    # Backward runs outside autocast, as recommended for torch.amp.
    loss.backward()
    return loss.item()


t_full = timeit("FULL forward+backward", full_fwd_bwd)

tokens = B * T
step_s = t_full / 1000
# Training FLOPs/token ~= 6 * N (2N forward + 4N backward). N here counts all
# parameters including embeddings, so treat both numbers as rough estimates.
flops_per_token = 6 * n_params
expected_tok_s = RTX3060_BF16_FLOPS * TARGET_MFU / flops_per_token
achieved_mfu = flops_per_token * tokens / step_s / RTX3060_BF16_FLOPS

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {tokens} tokens")
print(f"tok/s (fwd+bwd): {tokens / step_s:.0f}")
print(
    f"Expected @MFU={TARGET_MFU:.0%} on RTX3060 "
    f"(~{RTX3060_BF16_FLOPS / 1e12:.0f} TFLOPS bf16): ~{expected_tok_s / 1000:.0f}k tok/s"
)
print(f"Achieved MFU (6*N approx): {achieved_mfu:.1%}")