nemotron-kan-350m / bench_throughput.py
icarus112's picture
feat(arch): sync with Nemotron-4 architecture alignment (d823c2c)
eca97aa verified
#!/usr/bin/env python3
# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownParameterType=false, reportUnknownArgumentType=false, reportAny=false, reportUnusedCallResult=false, reportUnannotatedClassAttribute=false, reportImplicitOverride=false, reportDeprecated=false, reportCallIssue=false, reportArgumentType=false, reportFunctionMemberAccess=false, reportPossiblyUnboundVariable=false, reportOptionalMemberAccess=false, reportOptionalCall=false, reportAttributeAccessIssue=false, reportPrivateImportUsage=false
from __future__ import annotations
import os
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ.get("PATH", "")
import statistics
import time
import gc
from dataclasses import dataclass
from typing import Callable
import torch
import torch.nn as nn
from train import (
BATCH_SIZE,
DATASET_CONFIG,
DATASET_NAME,
SEQ_LEN,
StreamingTokenBuffer,
)
from nemotron_kan import NemotronKAN, NemotronKANConfig
from nemotron_kan.layers.kan_layers import GRKANActivation, GRKANMLPReplacement
from nemotron_kan.model import CausalSelfAttention
# Number of timed iterations per level; run_step_benchmark reports the median over them.
ITERATIONS = 50
# Untimed warmup steps passed to run_step_benchmark for the compiled level 7.
WARMUP_COMPILE_STEPS = 3
# Untimed warmup steps passed to run_step_benchmark for the compiled levels 3d and 9.
WARMUP_COMPILE_TRAIN_STEPS = 5
# Batch size and sequence length of the synthetic inputs used by levels 3c-6 and 10-14.
NEMO_BATCH_SIZE = 2
NEMO_SEQ_LEN = 256
# Batch size of the synthetic token inputs used by levels 7-9 (sequence length is NEMO_SEQ_LEN).
COMPILE_BATCH_SIZE = 1
@dataclass
class BenchResult:
    """Outcome of benchmarking one level.

    A successful run carries both medians and keeps the default ``status`` of
    ``"ok"``. Failed runs are created with both medians set to ``None``, a
    ``status`` of ``"oom"`` or ``"error"``, and the exception text in ``error``.
    """

    # Short label of the level (e.g. "0", "3c") used in the printed reports.
    level: str
    # Human-readable description printed next to the label.
    description: str
    # Median tokens per second over the timed iterations; None on failure.
    median_tok_s: float | None
    # Median wall-clock milliseconds per step; None on failure.
    median_ms: float | None
    # "ok", "oom" or "error".
    status: str = "ok"
    # Text of the exception for failed runs.
    error: str | None = None
class TinyTransformerLM(nn.Module):
    """Small transformer language model used as a lightweight baseline.

    Token and learned position embeddings are summed, passed through a stack
    of ``nn.TransformerEncoderLayer`` blocks (GELU, no dropout, batch-first)
    and projected to vocabulary logits. The position table has ``SEQ_LEN``
    rows, so inputs longer than ``SEQ_LEN`` are not supported. No attention
    mask is passed to the encoder.
    """

    def __init__(
        self, vocab_size: int, d_model: int = 256, nhead: int = 4, layers: int = 4
    ):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(SEQ_LEN, d_model)
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=0.0,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer_template, num_layers=layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return logits for the 2-D token-id tensor ``idx`` (batch, time)."""
        batch, seq_len = idx.shape
        positions = torch.arange(0, seq_len, device=idx.device, dtype=torch.long)
        # Same position ids for every row of the batch.
        positions = positions.unsqueeze(0).expand(batch, seq_len)
        hidden = self.tok_emb(idx) + self.pos_emb(positions)
        return self.lm_head(self.encoder(hidden))
class VanillaGPT2Block(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x GELU MLP.

    Both sub-layers are wrapped in residual connections. The attention call
    passes no mask.
    """

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        inner_dim = 4 * n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, n_embd),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-layers to ``x`` and return the result."""
        normed = self.ln1(x)
        # need_weights=False: only the attention output is used.
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        after_attn = x + attn_out
        return after_attn + self.mlp(self.ln2(after_attn))
class VanillaGPT2(nn.Module):
    """GPT-2 style decoder stack built from ``VanillaGPT2Block`` layers.

    The token embedding shares its weight tensor with the bias-free output
    projection. Learned position embeddings cover ``block_size`` positions.
    """

    def __init__(
        self,
        vocab_size: int = 50304,
        n_layer: int = 12,
        n_embd: int = 768,
        n_head: int = 12,
        block_size: int = 256,
    ):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList(
            [VanillaGPT2Block(n_embd, n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the embedding table reuses the output projection's parameter.
        self.wte.weight = self.lm_head.weight

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Return logits for the 2-D token-id tensor ``idx`` (batch, time)."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device, dtype=torch.long)
        # Shape (1, time): broadcasts against the token embeddings over the batch.
        hidden = self.wte(idx) + self.wpe(positions.unsqueeze(0))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(self.ln_f(hidden))
def cleanup_cuda(*objs: object) -> None:
    """Run the garbage collector and release cached CUDA memory.

    Args:
        *objs: Objects the caller is finished with. They are accepted for
            backward compatibility; this function can only drop its *own*
            references to them. The caller must also unbind its names
            (``del model`` / ``model = None``) for the memory to be freed.
    """
    # Deleting a loop variable (`for obj in objs: del obj`) releases nothing,
    # because the varargs tuple still holds every object. Drop the tuple itself
    # so this frame does not keep the arguments alive during collection.
    del objs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def result_from_exception(level: str, description: str, exc: Exception) -> BenchResult:
    """Convert an exception raised while running a level into a failed result.

    A ``RuntimeError`` whose message contains "out of memory" (case
    insensitive) is reported with status ``"oom"``, and the CUDA cache is
    emptied. Any other exception is reported with status ``"error"``. The
    garbage collector runs in both cases.
    """
    message = str(exc)
    is_oom = isinstance(exc, RuntimeError) and "out of memory" in message.lower()
    gc.collect()
    if is_oom:
        torch.cuda.empty_cache()
    status = "oom" if is_oom else "error"
    return BenchResult(level, description, None, None, status=status, error=message)
def gpu_sync(device: torch.device) -> None:
    """Wait for pending work on ``device`` if it is a CUDA device; otherwise do nothing."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def run_step_benchmark(
    level: str,
    description: str,
    tokens_per_step: int,
    step_fn: Callable[[], object],
    device: torch.device,
    warmup_steps: int = 0,
) -> BenchResult:
    """Time ``step_fn`` for ``ITERATIONS`` iterations and report the medians.

    Args:
        level: Label stored in the result.
        description: Description stored in the result.
        tokens_per_step: Tokens processed by one call of ``step_fn``; used to
            turn each step time into tokens per second.
        step_fn: Zero-argument callable executing one step.
        device: Device to synchronize around each timed step.
        warmup_steps: Number of untimed calls made before measuring.

    Returns:
        A ``BenchResult`` with median tokens/s and median ms/step. If a
        ``RuntimeError`` is raised, the result has status ``"oom"`` (message
        contains "out of memory") or ``"error"``, with no medians.
    """
    try:
        for _ in range(warmup_steps):
            step_fn()
        gpu_sync(device)

        elapsed_ms: list[float] = []
        for _ in range(ITERATIONS):
            # Synchronize on both sides so only this step's work is timed.
            gpu_sync(device)
            started = time.perf_counter()
            step_fn()
            gpu_sync(device)
            elapsed_ms.append((time.perf_counter() - started) * 1000.0)

        throughputs = [tokens_per_step / (ms / 1000.0) for ms in elapsed_ms]
        return BenchResult(
            level=level,
            description=description,
            median_tok_s=statistics.median(throughputs),
            median_ms=statistics.median(elapsed_ms),
        )
    except RuntimeError as exc:
        message = str(exc)
        ran_out_of_memory = "out of memory" in message.lower()
        if ran_out_of_memory and device.type == "cuda":
            torch.cuda.empty_cache()
        return BenchResult(
            level=level,
            description=description,
            median_tok_s=None,
            median_ms=None,
            status="oom" if ran_out_of_memory else "error",
            error=message,
        )
def print_level_result(result: BenchResult) -> None:
    """Print a one-line report for ``result``.

    Successful results show median throughput and step time. OOM and error
    results show a placeholder, and error results are followed by the error
    text when one is present.
    """
    if result.status == "ok":
        print(
            f"Level {result.level}: {result.description} | {result.median_tok_s:,.0f} tok/s | {result.median_ms:.1f} ms/step"
        )
        return
    if result.status == "oom":
        print(f"Level {result.level}: {result.description} | OOM | OOM")
        return
    # Anything else is reported as a generic error.
    print(f"Level {result.level}: {result.description} | ERROR | ERROR")
    if result.error:
        print(f" error: {result.error}")
def print_summary(results: list[BenchResult]) -> None:
    """Print the summary table, level-to-level ratios and slowest/best stages.

    Levels without a measured throughput show their upper-cased status in the
    table, yield "N/A" ratios, and are excluded from the slowest/best
    comparison.
    """
    print("\n=== Throughput Summary ===")
    print(f"{'Level':<8} {'Description':<50} {'Median tok/s':>14} {'Median ms':>12}")
    print("-" * 90)
    for r in results:
        if r.median_tok_s is None:
            tok_s = r.status.upper()
        else:
            tok_s = f"{r.median_tok_s:,.0f}"
        if r.median_ms is None:
            ms = r.status.upper()
        else:
            ms = f"{r.median_ms:.1f}"
        print(f"{r.level:<8} {r.description:<50} {tok_s:>14} {ms:>12}")

    print("\n=== Bottleneck Ratios (vs previous level) ===")
    # Walk consecutive (previous, current) pairs.
    for prev, curr in zip(results, results[1:]):
        if prev.median_tok_s is not None and curr.median_tok_s is not None:
            ratio = curr.median_tok_s / prev.median_tok_s
            ratio_txt = f"{ratio:.3f}x"
        else:
            ratio_txt = "N/A"
        print(f"L{curr.level} / L{prev.level}: {ratio_txt}")

    valid = [r for r in results if r.median_tok_s is not None]
    print("\n=== Succeeded Levels ===")
    if not valid:
        print("None")
        return
    print(", ".join([f"L{r.level}" for r in valid]))

    def throughput(entry: BenchResult) -> float:
        return entry.median_tok_s or 0.0

    bottleneck = min(valid, key=throughput)
    best = max(valid, key=throughput)
    print("\n=== Bottleneck ===")
    print(
        f"Slowest measured stage: Level {bottleneck.level} ({bottleneck.description}) at {bottleneck.median_tok_s:,.0f} tok/s"
    )
    print(
        f"Best measured stage: Level {best.level} ({best.description}) at {best.median_tok_s:,.0f} tok/s"
    )
    # Skip the gap when either throughput is zero to avoid dividing by zero.
    if bottleneck.median_tok_s and best.median_tok_s:
        print(
            f"Gap (best/slowest): {best.median_tok_s / bottleneck.median_tok_s:.2f}x"
        )
def make_level4_config(
    *,
    block_size: int,
    gradient_checkpointing: bool = False,
    mhc: bool = False,
    hc_num_streams: int = 2,
    hc_disable: bool = False,
) -> NemotronKANConfig:
    """Build the NemotronKAN configuration shared by the model-level benchmarks.

    The architecture is fixed (12 layers, width 768, 12 heads, GR-KAN MLP) and
    every optional ``use_*`` feature is switched off. Only the keyword
    arguments of this function vary between levels.
    """
    # Settings that are identical for every level.
    fixed_settings = dict(
        n_layer=12,
        n_embd=768,
        n_head=12,
        vocab_size=50304,
        dropout=0.0,
        kan_type="grkan",
        kan_num_grids=4,
        kan_hidden_mult=4,
        kan_grkan_num_groups=8,
        use_engram=False,
        use_sdr_compression=False,
        use_axiomatic_attention=False,
        use_holographic_memory=False,
        use_jit_cache=False,
        use_synaptic_offload=False,
        use_fused_ce=False,
    )
    # Settings chosen by the caller.
    per_level_settings = dict(
        block_size=block_size,
        mhc=mhc,
        hc_num_streams=hc_num_streams,
        hc_disable=hc_disable,
        gradient_checkpointing=gradient_checkpointing,
    )
    return NemotronKANConfig(**fixed_settings, **per_level_settings)
# Dimensions shared by the synthetic inputs and the stand-alone module levels.
_VOCAB_SIZE = 50304
_HIDDEN_DIM = 768


def _fp16_autocast():
    """Return the float16 CUDA autocast context used by all AMP levels."""
    return torch.amp.autocast("cuda", dtype=torch.float16)


def _record_result(results: list[BenchResult], result: BenchResult) -> None:
    """Print ``result``, append it to ``results`` and release cached GPU memory."""
    print_level_result(result)
    results.append(result)
    cleanup_cuda()


def _run_level(
    results: list[BenchResult],
    device: torch.device,
    level: str,
    description: str,
    tokens_per_step: int,
    build_step: Callable[[], Callable[[], object]],
    warmup_steps: int = 0,
) -> BenchResult:
    """Build, benchmark, report and tear down a single level.

    ``build_step`` creates the model/optimizer and returns the step closure.
    It is called inside the ``try`` so that construction failures (including
    OOM) become a failed ``BenchResult`` instead of aborting the whole run.
    The closure is the only owner of the level's objects; dropping it before
    the cache is emptied ensures one level's GPU memory is actually released
    before the next level starts.
    """
    step_fn: Callable[[], object] | None = None
    try:
        step_fn = build_step()
        result = run_step_benchmark(
            level=level,
            description=description,
            tokens_per_step=tokens_per_step,
            step_fn=step_fn,
            device=device,
            warmup_steps=warmup_steps,
        )
    except Exception as exc:
        result = result_from_exception(level, description, exc)
    finally:
        # Release the model/optimizer captured by the closure, then free the cache.
        step_fn = None
        cleanup_cuda()
    _record_result(results, result)
    return result


def _random_activations(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Random (NEMO_BATCH_SIZE, NEMO_SEQ_LEN, _HIDDEN_DIM) input for module-only levels."""
    return torch.randn(
        NEMO_BATCH_SIZE, NEMO_SEQ_LEN, _HIDDEN_DIM, device=device, dtype=dtype
    )


def _tiny_forward_step(x: torch.Tensor, device: torch.device) -> Callable[[], None]:
    """Step closure: forward pass of TinyTransformerLM under inference mode."""
    model = TinyTransformerLM(vocab_size=_VOCAB_SIZE).to(device)
    model.train()

    def step() -> None:
        with torch.inference_mode():
            model(x)

    return step


def _tiny_fwd_bwd_step(
    x: torch.Tensor, y: torch.Tensor, device: torch.device
) -> Callable[[], None]:
    """Step closure: forward + backward of TinyTransformerLM (no optimizer step)."""
    model = TinyTransformerLM(vocab_size=_VOCAB_SIZE).to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    def step() -> None:
        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), y.view(-1)
        )
        loss.backward()

    return step


def _gpt2_fwd_bwd_step(
    x: torch.Tensor,
    y: torch.Tensor,
    device: torch.device,
    *,
    compile_model: bool = False,
    param_count_label: str | None = None,
) -> Callable[[], None]:
    """Step closure: AMP forward + backward of VanillaGPT2.

    When ``param_count_label`` is given, the parameter count is printed under
    that level label before benchmarking.
    """
    model = VanillaGPT2(block_size=NEMO_SEQ_LEN).to(device)
    model.train()
    if param_count_label is not None:
        n_params = sum(p.numel() for p in model.parameters())
        print(f"Level {param_count_label} param count: {n_params:,}")
    if compile_model:
        model = torch.compile(model, backend="inductor")

    def step() -> None:
        model.zero_grad(set_to_none=True)
        with _fp16_autocast():
            logits = model(x)
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y.view(-1)
            )
        loss.backward()

    return step


def _kan_fwd_bwd_step(
    level: str,
    x: torch.Tensor,
    y: torch.Tensor,
    device: torch.device,
    *,
    compile_model: bool = False,
    **config_overrides: object,
) -> Callable[[], None]:
    """Step closure: AMP forward + backward of NemotronKAN.

    ``config_overrides`` are forwarded to ``make_level4_config`` on top of
    ``block_size=NEMO_SEQ_LEN``. ``level`` is used only in the error message
    raised when the model returns no loss.
    """
    config = make_level4_config(block_size=NEMO_SEQ_LEN, **config_overrides)
    model = NemotronKAN(config).to(device)
    model.train()
    if compile_model:
        model = torch.compile(model, backend="inductor")

    def step() -> None:
        model.zero_grad(set_to_none=True)
        with _fp16_autocast():
            _logits, loss, _info = model(x, y)
        if loss is None:
            raise RuntimeError(f"Level {level} loss is None")
        loss.backward()

    return step


def _kan_train_step(
    level: str,
    x: torch.Tensor,
    y: torch.Tensor,
    device: torch.device,
    *,
    compile_model: bool = False,
) -> Callable[[], None]:
    """Step closure: full training step of NemotronKAN.

    Uses gradient checkpointing, fp16 autocast with a GradScaler, gradient
    clipping at norm 1.0 and fused AdamW.
    """
    config = make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True)
    model = NemotronKAN(config).to(device)
    model.train()
    if compile_model:
        model = torch.compile(model, backend="inductor")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, fused=True)
    scaler = torch.amp.GradScaler("cuda", enabled=True)

    def step() -> None:
        optimizer.zero_grad(set_to_none=True)
        with _fp16_autocast():
            _logits, loss, _info = model(x, y)
        if loss is None:
            raise RuntimeError(f"Level {level} loss is None")
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is computed on true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

    return step


def _forward_only_step(
    module: nn.Module, dtype: torch.dtype, device: torch.device
) -> Callable[[], None]:
    """Step closure: a single forward call of ``module`` on random activations."""
    x = _random_activations(device, dtype)

    def step() -> None:
        _ = module(x)

    return step


def _module_fwd_bwd_step(
    module: nn.Module, dtype: torch.dtype, device: torch.device
) -> Callable[[], None]:
    """Step closure: AMP forward + backward of ``module`` with a mean-square loss."""
    module.train()
    x = _random_activations(device, dtype)

    def step() -> None:
        module.zero_grad(set_to_none=True)
        with _fp16_autocast():
            out = module(x)
            loss = out.float().pow(2).mean()
        loss.backward()

    return step


def main() -> None:
    """Run every benchmark level in order and print the per-level and summary reports.

    Levels 0-1 measure data loading, 2-3 a tiny transformer, 3c-3d a vanilla
    GPT-2, 4-9 NemotronKAN variants and 10-14 individual modules. Each level
    from 2 onwards is built and torn down inside ``_run_level`` so its GPU
    memory is released before the next level is measured.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark. No GPU detected.")
    device = torch.device("cuda")
    cpu = torch.device("cpu")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    print(f"GPU: {torch.cuda.get_device_name(device)}")
    print(f"CUDA: {torch.version.cuda}")
    print(f"Levels 0-3 batch size: {BATCH_SIZE}, seq len: {SEQ_LEN}")
    print(f"Levels 4-6 batch size: {NEMO_BATCH_SIZE}, seq len: {NEMO_SEQ_LEN}")
    print(f"Levels 7-9 batch size: {COMPILE_BATCH_SIZE}, seq len: {NEMO_SEQ_LEN}")
    print(f"Iterations: {ITERATIONS}")
    print()

    tokens_per_step_main = BATCH_SIZE * SEQ_LEN
    tokens_per_step_nemo = NEMO_BATCH_SIZE * NEMO_SEQ_LEN
    tokens_per_step_compile = COMPILE_BATCH_SIZE * NEMO_SEQ_LEN
    results: list[BenchResult] = []

    # ---- Levels 0-1: data pipeline ------------------------------------------
    stream = StreamingTokenBuffer(
        split="train",
        block_size=SEQ_LEN,
        batch_size=BATCH_SIZE,
        dataset_name=DATASET_NAME,
        dataset_config=DATASET_CONFIG,
        accelerator=None,
    )
    stream.prefill(32)
    _record_result(
        results,
        run_step_benchmark(
            level="0",
            description="Pure data loading (CPU)",
            tokens_per_step=tokens_per_step_main,
            step_fn=lambda: stream.get_batch(cpu),
            device=cpu,
        ),
    )
    _record_result(
        results,
        run_step_benchmark(
            level="1",
            description="Data loading + CUDA transfer",
            tokens_per_step=tokens_per_step_main,
            step_fn=lambda: stream.get_batch(device),
            device=device,
        ),
    )

    # ---- Synthetic token inputs shared by the model levels ------------------
    def random_tokens(batch: int, seq_len: int) -> torch.Tensor:
        return torch.randint(
            0, _VOCAB_SIZE, (batch, seq_len), device=device, dtype=torch.long
        )

    x_static = random_tokens(BATCH_SIZE, SEQ_LEN)
    y_static = random_tokens(BATCH_SIZE, SEQ_LEN)
    x_nemo = random_tokens(NEMO_BATCH_SIZE, NEMO_SEQ_LEN)
    y_nemo = random_tokens(NEMO_BATCH_SIZE, NEMO_SEQ_LEN)
    x_compile = random_tokens(COMPILE_BATCH_SIZE, NEMO_SEQ_LEN)
    y_compile = random_tokens(COMPILE_BATCH_SIZE, NEMO_SEQ_LEN)

    # ---- Levels 2-3: tiny transformer ---------------------------------------
    _run_level(
        results,
        device,
        "2",
        "Tiny transformer forward only",
        tokens_per_step_main,
        lambda: _tiny_forward_step(x_static, device),
    )
    _run_level(
        results,
        device,
        "3",
        "Tiny transformer forward + backward",
        tokens_per_step_main,
        lambda: _tiny_fwd_bwd_step(x_static, y_static, device),
    )

    # ---- Levels 3c-3d: vanilla GPT-2 ----------------------------------------
    _run_level(
        results,
        device,
        "3c",
        "Vanilla GPT-2 124M fwd+bwd (AMP)",
        tokens_per_step_nemo,
        lambda: _gpt2_fwd_bwd_step(x_nemo, y_nemo, device, param_count_label="3c"),
    )
    _run_level(
        results,
        device,
        "3d",
        "Vanilla GPT-2 124M + torch.compile (AMP)",
        tokens_per_step_nemo,
        lambda: _gpt2_fwd_bwd_step(x_nemo, y_nemo, device, compile_model=True),
        warmup_steps=WARMUP_COMPILE_TRAIN_STEPS,
    )

    # ---- Levels 4-7: NemotronKAN forward + backward -------------------------
    _run_level(
        results,
        device,
        "4",
        "NemotronKAN fwd+bwd (all trees off, AMP)",
        tokens_per_step_nemo,
        lambda: _kan_fwd_bwd_step("4", x_nemo, y_nemo, device),
    )
    _run_level(
        results,
        device,
        "4b",
        "NemotronKAN fwd+bwd + grad checkpointing (AMP)",
        tokens_per_step_nemo,
        lambda: _kan_fwd_bwd_step(
            "4b", x_nemo, y_nemo, device, gradient_checkpointing=True
        ),
    )
    _run_level(
        results,
        device,
        "4c",
        "NemotronKAN fwd+bwd + HC disabled (AMP)",
        tokens_per_step_nemo,
        lambda: _kan_fwd_bwd_step(
            "4c",
            x_nemo,
            y_nemo,
            device,
            gradient_checkpointing=False,
            mhc=False,
            hc_num_streams=1,
            hc_disable=True,
        ),
    )
    _run_level(
        results,
        device,
        "5",
        "NemotronKAN + HyperConnections (AMP)",
        tokens_per_step_nemo,
        lambda: _kan_fwd_bwd_step(
            "5", x_nemo, y_nemo, device, mhc=True, hc_num_streams=2
        ),
    )
    _run_level(
        results,
        device,
        "6",
        "NemotronKAN + gradient checkpointing (AMP)",
        tokens_per_step_nemo,
        lambda: _kan_fwd_bwd_step(
            "6", x_nemo, y_nemo, device, gradient_checkpointing=True
        ),
    )
    _run_level(
        results,
        device,
        "7",
        "NemotronKAN + torch.compile (AMP)",
        tokens_per_step_compile,
        lambda: _kan_fwd_bwd_step(
            "7", x_compile, y_compile, device, compile_model=True
        ),
        warmup_steps=WARMUP_COMPILE_STEPS,
    )

    # ---- Levels 8-9: full training step -------------------------------------
    _run_level(
        results,
        device,
        "8",
        "Full training step (AMP+AdamW fused+ckpt)",
        tokens_per_step_compile,
        lambda: _kan_train_step("8", x_compile, y_compile, device),
    )
    _run_level(
        results,
        device,
        "9",
        "Full training step + torch.compile (AMP+AdamW fused+ckpt)",
        tokens_per_step_compile,
        lambda: _kan_train_step(
            "9", x_compile, y_compile, device, compile_model=True
        ),
        warmup_steps=WARMUP_COMPILE_TRAIN_STEPS,
    )

    # ---- Levels 10-14: individual modules -----------------------------------
    level10 = _run_level(
        results,
        device,
        "10",
        "GRKANActivation only (Triton kernel)",
        tokens_per_step_nemo,
        lambda: _forward_only_step(
            GRKANActivation(num_groups=8).to(device), torch.float32, device
        ),
    )
    _run_level(
        results,
        device,
        "11",
        "nn.Linear 768->3072 only",
        tokens_per_step_nemo,
        lambda: _forward_only_step(
            nn.Linear(_HIDDEN_DIM, 3072).to(device).half(), torch.float16, device
        ),
    )
    level12 = _run_level(
        results,
        device,
        "12",
        "GRKANMLPReplacement alone fwd+bwd (AMP)",
        tokens_per_step_nemo,
        lambda: _module_fwd_bwd_step(
            GRKANMLPReplacement(
                _HIDDEN_DIM,
                hidden_mult=4,
                num_groups=8,
                dropout=0.0,
                use_checkpoint=False,
            ).to(device),
            torch.float32,
            device,
        ),
    )
    level13 = _run_level(
        results,
        device,
        "13",
        "Standard MLP alone fwd+bwd (AMP)",
        tokens_per_step_nemo,
        lambda: _module_fwd_bwd_step(
            nn.Sequential(
                nn.Linear(_HIDDEN_DIM, 3072),
                nn.GELU(),
                nn.Linear(3072, _HIDDEN_DIM),
            ).to(device),
            torch.float16,
            device,
        ),
    )
    level14 = _run_level(
        results,
        device,
        "14",
        "CausalSelfAttention alone fwd+bwd (AMP)",
        tokens_per_step_nemo,
        lambda: _module_fwd_bwd_step(
            CausalSelfAttention(
                # block_size follows the input length instead of a literal 256.
                NemotronKANConfig(
                    n_embd=_HIDDEN_DIM,
                    n_head=12,
                    block_size=NEMO_SEQ_LEN,
                    dropout=0.0,
                )
            ).to(device),
            torch.float32,
            device,
        ),
    )

    # ---- Diagnosis ----------------------------------------------------------
    print("\nDIAGNOSIS:")
    if level10.median_ms is not None:
        print(f" GRKANActivation overhead vs no-op: {level10.median_ms:.1f}ms")
    else:
        print(" GRKANActivation overhead vs no-op: N/A")
    if level12.median_ms is not None and level13.median_ms is not None:
        print(
            f" GR-KAN MLP vs Standard MLP: {level12.median_ms / level13.median_ms:.1f}x slower"
        )
    else:
        print(" GR-KAN MLP vs Standard MLP: N/A")
    if level14.median_ms is not None:
        print(f" Attention: {level14.median_ms:.1f}ms per call")
    else:
        print(" Attention: N/A")
    if level12.median_ms is not None:
        print(f" MLP: {level12.median_ms:.1f}ms per call")
    else:
        print(" MLP: N/A")
    if level14.median_ms is not None and level12.median_ms is not None:
        print(f" Attention/MLP ratio: {level14.median_ms / level12.median_ms:.2f}")
    else:
        print(" Attention/MLP ratio: N/A")
    print_summary(results)
# Run the benchmark suite only when this file is executed as a script.
if __name__ == "__main__":
    main()