#!/usr/bin/env python3 # pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownParameterType=false, reportUnknownArgumentType=false, reportAny=false, reportUnusedCallResult=false, reportUnannotatedClassAttribute=false, reportImplicitOverride=false, reportDeprecated=false, reportCallIssue=false, reportArgumentType=false, reportFunctionMemberAccess=false, reportPossiblyUnboundVariable=false, reportOptionalMemberAccess=false, reportOptionalCall=false, reportAttributeAccessIssue=false, reportPrivateImportUsage=false from __future__ import annotations import os os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ.get("PATH", "") import statistics import time import gc from dataclasses import dataclass from typing import Callable import torch import torch.nn as nn from train import ( BATCH_SIZE, DATASET_CONFIG, DATASET_NAME, SEQ_LEN, StreamingTokenBuffer, ) from nemotron_kan import NemotronKAN, NemotronKANConfig from nemotron_kan.layers.kan_layers import GRKANActivation, GRKANMLPReplacement from nemotron_kan.model import CausalSelfAttention ITERATIONS = 50 WARMUP_COMPILE_STEPS = 3 WARMUP_COMPILE_TRAIN_STEPS = 5 NEMO_BATCH_SIZE = 2 NEMO_SEQ_LEN = 256 COMPILE_BATCH_SIZE = 1 @dataclass class BenchResult: level: str description: str median_tok_s: float | None median_ms: float | None status: str = "ok" error: str | None = None class TinyTransformerLM(nn.Module): def __init__( self, vocab_size: int, d_model: int = 256, nhead: int = 4, layers: int = 4 ): super().__init__() self.tok_emb = nn.Embedding(vocab_size, d_model) self.pos_emb = nn.Embedding(SEQ_LEN, d_model) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4, dropout=0.0, activation="gelu", batch_first=True, ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers) self.lm_head = nn.Linear(d_model, vocab_size) def forward(self, idx: torch.Tensor) -> torch.Tensor: b, t = idx.shape pos = ( torch.arange(0, t, device=idx.device, dtype=torch.long) .unsqueeze(0) .expand(b, t) ) x = self.tok_emb(idx) + self.pos_emb(pos) x = self.encoder(x) return self.lm_head(x) class VanillaGPT2Block(nn.Module): def __init__(self, n_embd: int, n_head: int): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True) self.ln2 = nn.LayerNorm(n_embd) self.mlp = nn.Sequential( nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd), ) def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.ln1(x) x = x + self.attn(h, h, h, need_weights=False)[0] x = x + self.mlp(self.ln2(x)) return x class VanillaGPT2(nn.Module): def __init__( self, vocab_size: int = 50304, n_layer: int = 12, n_embd: int = 768, n_head: int = 12, block_size: int = 256, ): super().__init__() self.wte = nn.Embedding(vocab_size, n_embd) self.wpe = nn.Embedding(block_size, n_embd) self.blocks = nn.ModuleList( [VanillaGPT2Block(n_embd, n_head) for _ in range(n_layer)] ) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) self.wte.weight = self.lm_head.weight def forward(self, idx: torch.Tensor) -> torch.Tensor: _b, t = idx.shape pos = torch.arange(t, device=idx.device, dtype=torch.long).unsqueeze(0) x = self.wte(idx) + self.wpe(pos) for block in self.blocks: x = block(x) x = self.ln_f(x) return self.lm_head(x) def cleanup_cuda(*objs: object) -> None: for obj in objs: del obj gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def result_from_exception(level: str, description: str, exc: Exception) -> BenchResult: msg = str(exc) if isinstance(exc, RuntimeError) and "out of memory" in msg.lower(): gc.collect() torch.cuda.empty_cache() return BenchResult(level, description, None, None, status="oom", error=msg) gc.collect() return BenchResult(level, description, None, None, status="error", error=msg) def gpu_sync(device: torch.device) -> None: if device.type == "cuda": torch.cuda.synchronize(device) def run_step_benchmark( level: str, description: str, tokens_per_step: int, step_fn: Callable[[], object], device: torch.device, warmup_steps: int = 0, ) -> BenchResult: try: for _ in range(warmup_steps): step_fn() gpu_sync(device) step_times_ms: list[float] = [] tok_s: list[float] = [] for _ in range(ITERATIONS): gpu_sync(device) t0 = time.perf_counter() step_fn() gpu_sync(device) dt = (time.perf_counter() - t0) * 1000.0 step_times_ms.append(dt) tok_s.append(tokens_per_step / (dt / 1000.0)) return BenchResult( level=level, description=description, median_tok_s=statistics.median(tok_s), median_ms=statistics.median(step_times_ms), ) except RuntimeError as exc: msg = str(exc) if "out of memory" in msg.lower(): if device.type == "cuda": torch.cuda.empty_cache() return BenchResult( level=level, description=description, median_tok_s=None, median_ms=None, status="oom", error=msg, ) return BenchResult( level=level, description=description, median_tok_s=None, median_ms=None, status="error", error=msg, ) def print_level_result(result: BenchResult) -> None: if result.status == "ok": print( f"Level {result.level}: {result.description} | {result.median_tok_s:,.0f} tok/s | {result.median_ms:.1f} ms/step" ) elif result.status == "oom": print(f"Level {result.level}: {result.description} | OOM | OOM") else: print(f"Level {result.level}: {result.description} | ERROR | ERROR") if result.error: print(f" error: {result.error}") def print_summary(results: list[BenchResult]) -> None: print("\n=== Throughput Summary ===") print(f"{'Level':<8} {'Description':<50} {'Median tok/s':>14} {'Median ms':>12}") print("-" * 90) for r in results: tok_s = ( f"{r.median_tok_s:,.0f}" if r.median_tok_s is not None else r.status.upper() ) ms = f"{r.median_ms:.1f}" if r.median_ms is not None else r.status.upper() print(f"{r.level:<8} {r.description:<50} {tok_s:>14} {ms:>12}") print("\n=== Bottleneck Ratios (vs previous level) ===") for i in range(1, len(results)): prev = results[i - 1] curr = results[i] if prev.median_tok_s is None or curr.median_tok_s is None: ratio_txt = "N/A" else: ratio = curr.median_tok_s / prev.median_tok_s ratio_txt = f"{ratio:.3f}x" print(f"L{curr.level} / L{prev.level}: {ratio_txt}") valid = [r for r in results if r.median_tok_s is not None] print("\n=== Succeeded Levels ===") if valid: print(", ".join([f"L{r.level}" for r in valid])) else: print("None") if valid: bottleneck = min(valid, key=lambda x: x.median_tok_s or 0.0) best = max(valid, key=lambda x: x.median_tok_s or 0.0) print("\n=== Bottleneck ===") print( f"Slowest measured stage: Level {bottleneck.level} ({bottleneck.description}) at {bottleneck.median_tok_s:,.0f} tok/s" ) print( f"Best measured stage: Level {best.level} ({best.description}) at {best.median_tok_s:,.0f} tok/s" ) if bottleneck.median_tok_s and best.median_tok_s: print( f"Gap (best/slowest): {best.median_tok_s / bottleneck.median_tok_s:.2f}x" ) def make_level4_config( *, block_size: int, gradient_checkpointing: bool = False, mhc: bool = False, hc_num_streams: int = 2, hc_disable: bool = False, ) -> NemotronKANConfig: return NemotronKANConfig( n_layer=12, n_embd=768, n_head=12, vocab_size=50304, block_size=block_size, dropout=0.0, kan_type="grkan", kan_num_grids=4, kan_hidden_mult=4, kan_grkan_num_groups=8, use_engram=False, use_sdr_compression=False, use_axiomatic_attention=False, use_holographic_memory=False, use_jit_cache=False, use_synaptic_offload=False, mhc=mhc, hc_num_streams=hc_num_streams, hc_disable=hc_disable, gradient_checkpointing=gradient_checkpointing, use_fused_ce=False, ) def main() -> None: if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this benchmark. No GPU detected.") device = torch.device("cuda") cpu = torch.device("cpu") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision("high") print(f"GPU: {torch.cuda.get_device_name(device)}") print(f"CUDA: {torch.version.cuda}") print(f"Levels 0-3 batch size: {BATCH_SIZE}, seq len: {SEQ_LEN}") print(f"Levels 4-6 batch size: {NEMO_BATCH_SIZE}, seq len: {NEMO_SEQ_LEN}") print(f"Levels 7-9 batch size: {COMPILE_BATCH_SIZE}, seq len: {NEMO_SEQ_LEN}") print(f"Iterations: {ITERATIONS}") print() tokens_per_step_main = BATCH_SIZE * SEQ_LEN tokens_per_step_nemo = NEMO_BATCH_SIZE * NEMO_SEQ_LEN tokens_per_step_compile = COMPILE_BATCH_SIZE * NEMO_SEQ_LEN results: list[BenchResult] = [] stream = StreamingTokenBuffer( split="train", block_size=SEQ_LEN, batch_size=BATCH_SIZE, dataset_name=DATASET_NAME, dataset_config=DATASET_CONFIG, accelerator=None, ) stream.prefill(32) level0 = run_step_benchmark( level="0", description="Pure data loading (CPU)", tokens_per_step=tokens_per_step_main, step_fn=lambda: stream.get_batch(cpu), device=cpu, ) print_level_result(level0) results.append(level0) cleanup_cuda() level1 = run_step_benchmark( level="1", description="Data loading + CUDA transfer", tokens_per_step=tokens_per_step_main, step_fn=lambda: stream.get_batch(device), device=device, ) print_level_result(level1) results.append(level1) cleanup_cuda() x_static = torch.randint( 0, 50304, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.long ) y_static = torch.randint( 0, 50304, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.long ) x_nemo = torch.randint( 0, 50304, (NEMO_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long ) y_nemo = torch.randint( 0, 50304, (NEMO_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long ) x_compile = torch.randint( 0, 50304, (COMPILE_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long ) y_compile = torch.randint( 0, 50304, (COMPILE_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long ) level2_desc = "Tiny transformer forward only" tiny_model: TinyTransformerLM | None = None try: tiny_model = TinyTransformerLM(vocab_size=50304).to(device) tiny_model.train() def level2_step() -> None: with torch.inference_mode(): tiny_model(x_static) level2 = run_step_benchmark( level="2", description=level2_desc, tokens_per_step=tokens_per_step_main, step_fn=level2_step, device=device, ) except Exception as exc: level2 = result_from_exception("2", level2_desc, exc) finally: if tiny_model is not None: cleanup_cuda(tiny_model) print_level_result(level2) results.append(level2) cleanup_cuda() level3_desc = "Tiny transformer forward + backward" tiny_model_bwd: TinyTransformerLM | None = None tiny_optim: torch.optim.Optimizer | None = None try: tiny_model_bwd = TinyTransformerLM(vocab_size=50304).to(device) tiny_model_bwd.train() tiny_optim = torch.optim.AdamW(tiny_model_bwd.parameters(), lr=1e-3) def level3_step() -> None: tiny_optim.zero_grad(set_to_none=True) logits = tiny_model_bwd(x_static) loss = nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), y_static.view(-1) ) loss.backward() level3 = run_step_benchmark( level="3", description=level3_desc, tokens_per_step=tokens_per_step_main, step_fn=level3_step, device=device, ) except Exception as exc: level3 = result_from_exception("3", level3_desc, exc) finally: if tiny_model_bwd is not None and tiny_optim is not None: cleanup_cuda(tiny_model_bwd, tiny_optim) print_level_result(level3) results.append(level3) cleanup_cuda() level3c_desc = "Vanilla GPT-2 124M fwd+bwd (AMP)" model3c: VanillaGPT2 | None = None try: model3c = VanillaGPT2(block_size=NEMO_SEQ_LEN).to(device) model3c.train() params3c = sum(p.numel() for p in model3c.parameters()) print(f"Level 3c param count: {params3c:,}") def level3c_step() -> None: model3c.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): logits = model3c(x_nemo) loss = nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), y_nemo.view(-1) ) loss.backward() level3c = run_step_benchmark( level="3c", description=level3c_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level3c_step, device=device, ) except Exception as exc: level3c = result_from_exception("3c", level3c_desc, exc) finally: if model3c is not None: cleanup_cuda(model3c) print_level_result(level3c) results.append(level3c) cleanup_cuda() level3d_desc = "Vanilla GPT-2 124M + torch.compile (AMP)" model3d: object | None = None try: model3d = VanillaGPT2(block_size=NEMO_SEQ_LEN).to(device) model3d.train() model3d = torch.compile(model3d, backend="inductor") def level3d_step() -> None: model3d.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): logits = model3d(x_nemo) loss = nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), y_nemo.view(-1) ) loss.backward() level3d = run_step_benchmark( level="3d", description=level3d_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level3d_step, device=device, warmup_steps=WARMUP_COMPILE_TRAIN_STEPS, ) except Exception as exc: level3d = result_from_exception("3d", level3d_desc, exc) finally: if model3d is not None: cleanup_cuda(model3d) print_level_result(level3d) results.append(level3d) cleanup_cuda() level4_desc = "NemotronKAN fwd+bwd (all trees off, AMP)" model4: NemotronKAN | None = None try: model4 = NemotronKAN(make_level4_config(block_size=NEMO_SEQ_LEN)).to(device) model4.train() def level4_step() -> None: model4.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model4(x_nemo, y_nemo) if loss is None: raise RuntimeError("Level 4 loss is None") loss.backward() level4 = run_step_benchmark( level="4", description=level4_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level4_step, device=device, ) except Exception as exc: level4 = result_from_exception("4", level4_desc, exc) finally: if model4 is not None: cleanup_cuda(model4) print_level_result(level4) results.append(level4) cleanup_cuda() level4b_desc = "NemotronKAN fwd+bwd + grad checkpointing (AMP)" model4b: NemotronKAN | None = None try: model4b = NemotronKAN( make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True) ).to(device) model4b.train() def level4b_step() -> None: model4b.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model4b(x_nemo, y_nemo) if loss is None: raise RuntimeError("Level 4b loss is None") loss.backward() level4b = run_step_benchmark( level="4b", description=level4b_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level4b_step, device=device, ) except Exception as exc: level4b = result_from_exception("4b", level4b_desc, exc) finally: if model4b is not None: cleanup_cuda(model4b) print_level_result(level4b) results.append(level4b) cleanup_cuda() level4c_desc = "NemotronKAN fwd+bwd + HC disabled (AMP)" model4c: NemotronKAN | None = None try: model4c = NemotronKAN( make_level4_config( block_size=NEMO_SEQ_LEN, gradient_checkpointing=False, mhc=False, hc_num_streams=1, hc_disable=True, ) ).to(device) model4c.train() def level4c_step() -> None: model4c.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model4c(x_nemo, y_nemo) if loss is None: raise RuntimeError("Level 4c loss is None") loss.backward() level4c = run_step_benchmark( level="4c", description=level4c_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level4c_step, device=device, ) except Exception as exc: level4c = result_from_exception("4c", level4c_desc, exc) finally: if model4c is not None: cleanup_cuda(model4c) print_level_result(level4c) results.append(level4c) cleanup_cuda() level5_desc = "NemotronKAN + HyperConnections (AMP)" model5: NemotronKAN | None = None try: model5 = NemotronKAN( make_level4_config(block_size=NEMO_SEQ_LEN, mhc=True, hc_num_streams=2) ).to(device) model5.train() def level5_step() -> None: model5.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model5(x_nemo, y_nemo) if loss is None: raise RuntimeError("Level 5 loss is None") loss.backward() level5 = run_step_benchmark( level="5", description=level5_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level5_step, device=device, ) except Exception as exc: level5 = result_from_exception("5", level5_desc, exc) finally: if model5 is not None: cleanup_cuda(model5) print_level_result(level5) results.append(level5) cleanup_cuda() level6_desc = "NemotronKAN + gradient checkpointing (AMP)" model6: NemotronKAN | None = None try: model6 = NemotronKAN( make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True) ).to(device) model6.train() def level6_step() -> None: model6.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model6(x_nemo, y_nemo) if loss is None: raise RuntimeError("Level 6 loss is None") loss.backward() level6 = run_step_benchmark( level="6", description=level6_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level6_step, device=device, ) except Exception as exc: level6 = result_from_exception("6", level6_desc, exc) finally: if model6 is not None: cleanup_cuda(model6) print_level_result(level6) results.append(level6) cleanup_cuda() level7_desc = "NemotronKAN + torch.compile (AMP)" model7: object | None = None try: model7 = NemotronKAN(make_level4_config(block_size=NEMO_SEQ_LEN)).to(device) model7.train() model7 = torch.compile(model7, backend="inductor") def level7_step() -> None: model7.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model7(x_compile, y_compile) if loss is None: raise RuntimeError("Level 7 loss is None") loss.backward() level7 = run_step_benchmark( level="7", description=level7_desc, tokens_per_step=tokens_per_step_compile, step_fn=level7_step, device=device, warmup_steps=WARMUP_COMPILE_STEPS, ) except Exception as exc: level7 = result_from_exception("7", level7_desc, exc) finally: if model7 is not None: cleanup_cuda(model7) print_level_result(level7) results.append(level7) cleanup_cuda() level8_desc = "Full training step (AMP+AdamW fused+ckpt)" model8: NemotronKAN | None = None optim8: torch.optim.Optimizer | None = None scaler8: torch.amp.GradScaler | None = None try: model8 = NemotronKAN( make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True) ).to(device) model8.train() optim8 = torch.optim.AdamW(model8.parameters(), lr=1e-4, fused=True) scaler8 = torch.amp.GradScaler("cuda", enabled=True) def level8_step() -> None: optim8.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model8(x_compile, y_compile) if loss is None: raise RuntimeError("Level 8 loss is None") scaler8.scale(loss).backward() scaler8.unscale_(optim8) torch.nn.utils.clip_grad_norm_(model8.parameters(), max_norm=1.0) scaler8.step(optim8) scaler8.update() level8 = run_step_benchmark( level="8", description=level8_desc, tokens_per_step=tokens_per_step_compile, step_fn=level8_step, device=device, ) except Exception as exc: level8 = result_from_exception("8", level8_desc, exc) finally: if model8 is not None and optim8 is not None and scaler8 is not None: cleanup_cuda(model8, optim8, scaler8) print_level_result(level8) results.append(level8) cleanup_cuda() level9_desc = "Full training step + torch.compile (AMP+AdamW fused+ckpt)" model9: object | None = None optim9: torch.optim.Optimizer | None = None scaler9: torch.amp.GradScaler | None = None try: model9 = NemotronKAN( make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True) ).to(device) model9.train() model9 = torch.compile(model9, backend="inductor") optim9 = torch.optim.AdamW(model9.parameters(), lr=1e-4, fused=True) scaler9 = torch.amp.GradScaler("cuda", enabled=True) def level9_step() -> None: optim9.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): _logits, loss, _info = model9(x_compile, y_compile) if loss is None: raise RuntimeError("Level 9 loss is None") scaler9.scale(loss).backward() scaler9.unscale_(optim9) torch.nn.utils.clip_grad_norm_(model9.parameters(), max_norm=1.0) scaler9.step(optim9) scaler9.update() level9 = run_step_benchmark( level="9", description=level9_desc, tokens_per_step=tokens_per_step_compile, step_fn=level9_step, device=device, warmup_steps=WARMUP_COMPILE_TRAIN_STEPS, ) except Exception as exc: level9 = result_from_exception("9", level9_desc, exc) finally: if model9 is not None and optim9 is not None and scaler9 is not None: cleanup_cuda(model9, optim9, scaler9) print_level_result(level9) results.append(level9) cleanup_cuda() level10_desc = "GRKANActivation only (Triton kernel)" level10: BenchResult act10: GRKANActivation | None = None try: act10 = GRKANActivation(num_groups=8).to(device) x10 = torch.randn( NEMO_BATCH_SIZE, NEMO_SEQ_LEN, 768, device=device, dtype=torch.float32, requires_grad=False, ) def level10_step() -> None: _ = act10(x10) level10 = run_step_benchmark( level="10", description=level10_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level10_step, device=device, ) except Exception as exc: level10 = result_from_exception("10", level10_desc, exc) finally: if act10 is not None: cleanup_cuda(act10) print_level_result(level10) results.append(level10) cleanup_cuda() level11_desc = "nn.Linear 768->3072 only" level11: BenchResult fc11: nn.Linear | None = None try: fc11 = nn.Linear(768, 3072).to(device).half() x11 = torch.randn( NEMO_BATCH_SIZE, NEMO_SEQ_LEN, 768, device=device, dtype=torch.float16, requires_grad=False, ) def level11_step() -> None: _ = fc11(x11) level11 = run_step_benchmark( level="11", description=level11_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level11_step, device=device, ) except Exception as exc: level11 = result_from_exception("11", level11_desc, exc) finally: if fc11 is not None: cleanup_cuda(fc11) print_level_result(level11) results.append(level11) cleanup_cuda() level12_desc = "GRKANMLPReplacement alone fwd+bwd (AMP)" level12: BenchResult mlp12: GRKANMLPReplacement | None = None try: mlp12 = GRKANMLPReplacement( 768, hidden_mult=4, num_groups=8, dropout=0.0, use_checkpoint=False, ).to(device) mlp12.train() x12 = torch.randn( NEMO_BATCH_SIZE, NEMO_SEQ_LEN, 768, device=device, dtype=torch.float32, ) def level12_step() -> None: mlp12.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): y12 = mlp12(x12) loss12 = y12.float().pow(2).mean() loss12.backward() level12 = run_step_benchmark( level="12", description=level12_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level12_step, device=device, ) except Exception as exc: level12 = result_from_exception("12", level12_desc, exc) finally: if mlp12 is not None: cleanup_cuda(mlp12) print_level_result(level12) results.append(level12) cleanup_cuda() level13_desc = "Standard MLP alone fwd+bwd (AMP)" level13: BenchResult mlp13: nn.Sequential | None = None try: mlp13 = nn.Sequential( nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768), ).to(device) mlp13.train() x13 = torch.randn( NEMO_BATCH_SIZE, NEMO_SEQ_LEN, 768, device=device, dtype=torch.float16, ) def level13_step() -> None: mlp13.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): y13 = mlp13(x13) loss13 = y13.float().pow(2).mean() loss13.backward() level13 = run_step_benchmark( level="13", description=level13_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level13_step, device=device, ) except Exception as exc: level13 = result_from_exception("13", level13_desc, exc) finally: if mlp13 is not None: cleanup_cuda(mlp13) print_level_result(level13) results.append(level13) cleanup_cuda() level14_desc = "CausalSelfAttention alone fwd+bwd (AMP)" level14: BenchResult attn14: CausalSelfAttention | None = None try: cfg14 = NemotronKANConfig(n_embd=768, n_head=12, block_size=256, dropout=0.0) attn14 = CausalSelfAttention(cfg14).to(device) attn14.train() x14 = torch.randn( NEMO_BATCH_SIZE, NEMO_SEQ_LEN, 768, device=device, dtype=torch.float32, ) def level14_step() -> None: attn14.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", dtype=torch.float16): y14 = attn14(x14) loss14 = y14.float().pow(2).mean() loss14.backward() level14 = run_step_benchmark( level="14", description=level14_desc, tokens_per_step=tokens_per_step_nemo, step_fn=level14_step, device=device, ) except Exception as exc: level14 = result_from_exception("14", level14_desc, exc) finally: if attn14 is not None: cleanup_cuda(attn14) print_level_result(level14) results.append(level14) cleanup_cuda() print("\nDIAGNOSIS:") if level10.median_ms is not None: print(f" GRKANActivation overhead vs no-op: {level10.median_ms:.1f}ms") else: print(" GRKANActivation overhead vs no-op: N/A") if level12.median_ms is not None and level13.median_ms is not None: print( f" GR-KAN MLP vs Standard MLP: {level12.median_ms / level13.median_ms:.1f}x slower" ) else: print(" GR-KAN MLP vs Standard MLP: N/A") if level14.median_ms is not None: print(f" Attention: {level14.median_ms:.1f}ms per call") else: print(" Attention: N/A") if level12.median_ms is not None: print(f" MLP: {level12.median_ms:.1f}ms per call") else: print(" MLP: N/A") if level14.median_ms is not None and level12.median_ms is not None: print(f" Attention/MLP ratio: {level14.median_ms / level12.median_ms:.2f}") else: print(" Attention/MLP ratio: N/A") print_summary(results) if __name__ == "__main__": main()