| """Per-subsystem timing to find the tok/s bottleneck. | |
| Runs a single forward+backward at (B=8, T=2048) and times each stage via | |
| torch.cuda.Event. Reports ms/stage and derived tok/s budget. | |
| """ | |
| import os, sys, time | |
| os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64") | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN | |
| B, T = 8, MAX_SEQ_LEN | |
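
# CUDA kernels launch asynchronously, so host-side wall-clock timing would mostly
# measure launch overhead. torch.cuda.Event records markers on the CUDA stream and
# elapsed_time() returns device-side milliseconds; the synchronize() calls fence
# each run so one measurement cannot bleed into the next.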
def timeit(name, fn, warmup=1, n=3):
    for _ in range(warmup):
        fn(); torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(n):
        torch.cuda.synchronize()
        s.record(); fn(); e.record(); torch.cuda.synchronize()
        times.append(s.elapsed_time(e))
    avg = sum(times) / len(times)
    print(f" {name:30s} {avg:8.2f} ms (min {min(times):.2f} max {max(times):.2f})")
    return avg

cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()

idx = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
y = idx.clone()
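# Targets are just the input ids: for timing only the compute path matters,
# not the loss value, so any well-formed labels will do.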
| print(f"== Profile at B={B} T={T} n_params={sum(p.numel() for p in model.parameters())/1e6:.1f}M ==\n") | |
| # Warmup full forward | |
| with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): | |
| _ = model(idx, y) | |
| torch.cuda.synchronize() | |
| print("Stage times (3 iter avg):\n") | |

# 1) wte
timeit("wte embedding", lambda: model.wte(idx).sum().item())

# 2) sdr_semantic (STE forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(idx).sum().item())

# 3) sdr binary_only
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(idx).sum().item())

# 4) HTM full forward (with reset/learn)
with torch.no_grad():
    timeit("HTM forward (B=8, T=2048)", lambda: model.htm(model.sdr_semantic.binary_only(idx)).sum().item())

# 5) Mamba block stack only
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    def _blocks():
        x = model.wte(idx)
        x = norm(x)
        streams = model.mhc[0].init_streams(x)
        for block, mhc_layer in zip(model.blocks, model.mhc):
            def _bfn(h, _b=block): return _b(norm(h))  # bind block by value to avoid late-binding closures
            streams = mhc_layer(streams, _bfn)
        x = model.mhc[-1].merge_streams(streams)
        return x.sum().item()
    timeit("Mamba+mHC blocks (n_layer=4)", _blocks)

# 6) Full forward+loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(idx, y).item())

# 7) Full forward+loss+backward
def full_fwd_bwd():
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, y)
    loss.backward()  # backward runs outside autocast, as recommended for AMP
    return loss.item()

t_full = timeit("FULL forward+backward", full_fwd_bwd)

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {B*T} tokens")
print(f"tok/s (fwd+bwd): {B*T / (t_full/1000):.0f}")
print(f"Expected @MFU=20% on RTX3060 (~25 TFLOPS bf16): ~{25e12*0.2 / (6*7.5e6) / 1000:.0f}k tok/s")