File size: 3,217 Bytes
c2bf4b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
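"""Per-subsystem timing to find the tok/s bottleneck.

Times each stage (embedding, SDR, HTM, Mamba+mHC blocks, full forward, and
full forward+backward) at B=8, T=MAX_SEQ_LEN using torch.cuda.Event, with one
warmup call and a 3-iteration average per stage. Reports ms/stage and the
derived tok/s budget.
"""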
import os, sys, time  # NOTE(review): `time` appears unused in this script
# Only applied when LD_LIBRARY_PATH is not already set (setdefault).
# NOTE(review): the dynamic loader normally reads LD_LIBRARY_PATH once at
# process start, so setting it here likely affects only child processes, not
# libraries loaded by torch in this process. Confirm it is needed, or export
# it in the shell before launching.
os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64")
# Put the parent of this script's directory first on sys.path so the `train`
# import below resolves (presumably that is where train.py lives).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN

# Batch size and sequence length shared by every profiled stage; T tracks the
# training script's MAX_SEQ_LEN.
B, T = 8, MAX_SEQ_LEN

def timeit(name, fn, warmup=1, n=3):
    """Time ``fn`` on the current CUDA device and print a one-line summary.

    ``fn`` is first called ``warmup`` times untimed. It is then called ``n``
    more times, each call bracketed by a pair of ``torch.cuda.Event`` records
    with a full device synchronize before and after.

    Args:
        name: Label shown in the printed summary line.
        fn: Zero-argument callable to time.
        warmup: Number of untimed calls made before measuring.
        n: Number of timed calls.

    Returns:
        Mean elapsed time of the timed calls, in milliseconds.
    """
    for _ in range(warmup):
        fn()
        torch.cuda.synchronize()

    # One reusable event pair; safe because every sample synchronizes fully.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)

    def _timed_call():
        # Drain pending work so only fn's kernels fall between the events.
        torch.cuda.synchronize()
        start_evt.record()
        fn()
        end_evt.record()
        torch.cuda.synchronize()
        return start_evt.elapsed_time(end_evt)

    samples = [_timed_call() for _ in range(n)]
    mean_ms = sum(samples) / len(samples)
    lo, hi = min(samples), max(samples)
    print(f"  {name:30s} {mean_ms:8.2f} ms   (min {lo:.2f} max {hi:.2f})")
    return mean_ms

cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()
# Random token ids; targets are a copy of the inputs. Only the cost of the
# computation matters here, not the loss value.
idx = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
y = idx.clone()

# Counted once; reused by the header and by the MFU estimate at the end.
n_params = sum(p.numel() for p in model.parameters())

print(f"== Profile at B={B} T={T} n_params={n_params/1e6:.1f}M ==\n")

# Warmup full forward
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model(idx, y)
torch.cuda.synchronize()

print("Stage times (3 iter avg):\n")

# 1) wte
timeit("wte embedding", lambda: model.wte(idx).sum().item())

# 2) sdr_semantic (STE forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(idx).sum().item())

# 3) sdr binary_only
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(idx).sum().item())

# 4) HTM full forward (with reset/learn)
with torch.no_grad():
    timeit(f"HTM forward (B={B}, T={T})", lambda: model.htm(model.sdr_semantic.binary_only(idx)).sum().item())

# 5) Mamba block stack only
# Imported here, outside the timed closure, so the import machinery is not
# re-entered on every timed call.
from train import norm

with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    def _blocks():
        x = norm(model.wte(idx))
        streams = model.mhc[0].init_streams(x)
        for block, mhc_layer in zip(model.blocks, model.mhc):
            # Bind `block` as a default so each closure keeps its own layer.
            def _bfn(h, _b=block):
                return _b(norm(h))
            streams = mhc_layer(streams, _bfn)
        x = model.mhc[-1].merge_streams(streams)
        return x.sum().item()
    timeit(f"Mamba+mHC blocks (n_layer={len(model.blocks)})", _blocks)

# 6) Full forward+loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(idx, y).item())

# 7) Full forward+loss+backward
def full_fwd_bwd():
    """Run one training-style step (zero grads, autocast forward, backward); return the loss."""
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, y)
    loss.backward()
    return loss.item()
t_full = timeit("FULL forward+backward", full_fwd_bwd)

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {B*T} tokens")
# t_full covers forward AND backward, so this is training-step throughput.
print(f"tok/s per fwd+bwd step: {B*T / (t_full/1000):.0f}")
# Rough ceiling from the 6*N FLOPs/token rule of thumb (2N forward + 4N
# backward), using this model's measured parameter count rather than a
# hard-coded one. N includes all parameters (embeddings too), and the rule
# ignores sequence-mixing cost, so treat the result as an estimate only.
PEAK_BF16_FLOPS = 25e12  # RTX 3060 bf16 peak assumed by this estimate
TARGET_MFU = 0.2
expected_tok_s = PEAK_BF16_FLOPS * TARGET_MFU / (6 * n_params)
print(f"Expected @MFU=20% on RTX3060 (~25 TFLOPS bf16, N={n_params/1e6:.1f}M): ~{expected_tok_s / 1000:.0f}k tok/s")