# feather-runtime/overlay/scripts/profile_forward.py
# Author: Jackoatmon
# Update Feather H200 runtime: Nemotron streaming and HTM force-CPU canary fixes
# (commit c2bf4b6, verified)
"""Per-subsystem timing to find the tok/s bottleneck.
Runs a single forward+backward at (B=8, T=2048) and times each stage via
torch.cuda.Event. Reports ms/stage and derived tok/s budget.
"""
import os, sys, time
os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64")
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN
B, T = 8, MAX_SEQ_LEN
def timeit(name, fn, warmup=1, n=3):
    """Time ``fn`` on the GPU via CUDA events and print one summary line.

    Performs ``warmup`` untimed calls first (each followed by a full device
    sync), then ``n`` timed calls bracketed by start/end events. Prints the
    average, min, and max in milliseconds and returns the average.
    """
    for _ in range(warmup):
        fn()
        torch.cuda.synchronize()
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    samples = []
    for _ in range(n):
        # Drain any pending work so each sample times only this call.
        torch.cuda.synchronize()
        start_evt.record()
        fn()
        end_evt.record()
        torch.cuda.synchronize()
        samples.append(start_evt.elapsed_time(end_evt))
    mean_ms = sum(samples) / len(samples)
    print(f" {name:30s} {mean_ms:8.2f} ms (min {min(samples):.2f} max {max(samples):.2f})")
    return mean_ms
# ---- model + synthetic batch setup --------------------------------------
cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()

# Random token ids stand in for real data; targets are a copy so the loss is
# well-defined (its value is irrelevant — we only care about timing).
tokens = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
targets = tokens.clone()

print(f"== Profile at B={B} T={T} n_params={sum(p.numel() for p in model.parameters())/1e6:.1f}M ==\n")

# One untimed full forward to populate kernel/autotune caches before timing.
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model(tokens, targets)
torch.cuda.synchronize()

print("Stage times (3 iter avg):\n")

# 1) token embedding table alone (.item() forces execution + sync)
timeit("wte embedding", lambda: model.wte(tokens).sum().item())

# 2) semantic SDR with straight-through estimator, under bf16 autocast
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(tokens).sum().item())

# 3) binary-only SDR path (deliberately outside autocast)
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(tokens).sum().item())

# 4) full HTM pass (includes its reset/learn work); no gradients needed
with torch.no_grad():
    timeit("HTM forward (B=8, T=2048)", lambda: model.htm(model.sdr_semantic.binary_only(tokens)).sum().item())

# 5) just the Mamba block stack plus mHC stream plumbing
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    def _block_stack():
        from train import norm
        h = norm(model.wte(tokens))
        streams = model.mhc[0].init_streams(h)
        for blk, mhc_layer in zip(model.blocks, model.mhc):
            # Bind blk as a default arg so each closure sees its own block.
            streams = mhc_layer(streams, lambda s, _b=blk: _b(norm(s)))
        merged = model.mhc[-1].merge_streams(streams)
        return merged.sum().item()
    timeit("Mamba+mHC blocks (n_layer=4)", _block_stack)

# 6) whole model forward including loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(tokens, targets).item())

# 7) forward + loss + backward — the number that actually bounds tok/s
def _train_step():
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(tokens, targets)
    loss.backward()
    return loss.item()

t_full = timeit("FULL forward+backward", _train_step)

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {B*T} tokens")
print(f"tok/s per forward: {B*T / (t_full/1000):.0f}")
# NOTE(review): 7.5e6 below presumably hard-codes the model's parameter count
# for the 6*N FLOPs/token estimate — confirm it matches n_params printed above.
print(f"Expected @MFU=20% on RTX3060 (~25 TFLOPS bf16): ~{25e12*0.2 / (6*7.5e6) / 1000:.0f}k tok/s")