| |
| |
| from __future__ import annotations |
|
|
| import os |
|
|
# Allocator options are set before `import torch` below, so they are in place
# before the CUDA caching allocator is configured. setdefault keeps any value
# the user already exported.
# NOTE(review): older PyTorch releases only read PYTORCH_CUDA_ALLOC_CONF —
# confirm the installed version honours PYTORCH_ALLOC_CONF.
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
# Prepend the CUDA toolkit bin dir to PATH (presumably so compile toolchains
# such as Triton/Inductor find nvcc/ptxas — TODO confirm). When PATH is unset,
# avoid leaving a trailing empty entry: on POSIX an empty PATH component means
# "current directory".
_existing_path = os.environ.get("PATH", "")
os.environ["PATH"] = (
    "/usr/local/cuda/bin" + os.pathsep + _existing_path
    if _existing_path
    else "/usr/local/cuda/bin"
)
|
|
| import statistics |
| import time |
| import gc |
| from dataclasses import dataclass |
| from typing import Callable |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from train import ( |
| BATCH_SIZE, |
| DATASET_CONFIG, |
| DATASET_NAME, |
| SEQ_LEN, |
| StreamingTokenBuffer, |
| ) |
| from nemotron_kan import NemotronKAN, NemotronKANConfig |
| from nemotron_kan.layers.kan_layers import GRKANActivation, GRKANMLPReplacement |
| from nemotron_kan.model import CausalSelfAttention |
|
|
|
|
# Number of timed steps per level; reported numbers are medians over these steps.
ITERATIONS = 50
# Untimed warm-up calls for Level 7 (torch.compile'd NemotronKAN fwd+bwd), so the
# one-time compilation happens before the timed loop starts.
WARMUP_COMPILE_STEPS = 3
# Untimed warm-up calls for Levels 3d and 9 (compiled GPT-2 fwd+bwd / compiled full training step).
WARMUP_COMPILE_TRAIN_STEPS = 5
# Batch size for the GPT-2 / NemotronKAN levels 3c-6 and the component micro-benchmarks 10-14.
NEMO_BATCH_SIZE = 2
# Sequence length for every level except 0-3 (those use train.SEQ_LEN).
NEMO_SEQ_LEN = 256
# Batch size for Levels 7-9 (torch.compile and full-training-step levels).
COMPILE_BATCH_SIZE = 1
|
|
|
|
@dataclass
class BenchResult:
    """Outcome of a single benchmark level.

    In this script, results with ``status != "ok"`` are always built with
    ``median_tok_s`` and ``median_ms`` set to ``None``.
    """

    level: str  # level label as printed in the report, e.g. "3c"
    description: str  # human-readable description of what the level measures
    median_tok_s: float | None  # median tokens/second over the timed steps
    median_ms: float | None  # median wall-clock milliseconds per step
    status: str = "ok"  # "ok", "oom" or "error"
    error: str | None = None  # exception message when the level did not succeed
|
|
|
|
class TinyTransformerLM(nn.Module):
    """Small Transformer-encoder language model used as a cheap baseline.

    Token and learned absolute position embeddings feed a stack of
    ``nn.TransformerEncoderLayer`` blocks, followed by a linear LM head. The
    encoder is called without any mask, so attention is bidirectional.

    Args:
        vocab_size: Vocabulary size for the embedding and output projection.
        d_model: Hidden width.
        nhead: Attention heads per layer.
        layers: Number of encoder layers.
    """

    def __init__(
        self, vocab_size: int, d_model: int = 256, nhead: int = 4, layers: int = 4
    ):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # One learned position vector per slot of the training sequence length.
        self.pos_emb = nn.Embedding(SEQ_LEN, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=4 * d_model,
                dropout=0.0,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=layers,
        )
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab)."""
        batch, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device, dtype=torch.long)
        hidden = self.tok_emb(idx) + self.pos_emb(positions.expand(batch, seq_len))
        return self.lm_head(self.encoder(hidden))
|
|
|
|
class VanillaGPT2Block(nn.Module):
    """Pre-LayerNorm GPT-2 Transformer block: causal self-attention, then a GELU MLP.

    Args:
        n_embd: Model (embedding) width.
        n_head: Number of attention heads (nn.MultiheadAttention requires it
            to divide ``n_embd``).
    """

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (batch, seq, n_embd); output has the same shape.

        Attention is causal: position ``i`` only attends to positions ``<= i``,
        as in GPT-2. Without a mask the "GPT-2" baseline would attend to
        future tokens. That is not a valid LM, and it is not the same attention
        workload as the causal models it is benchmarked against.
        """
        seq_len = x.size(1)
        # True marks key positions that may NOT be attended (strictly above the diagonal).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        h = self.ln1(x)
        # is_causal is only a hint and must accompany the explicit mask; with
        # need_weights=False it may let PyTorch use a fused causal kernel.
        attn_out = self.attn(
            h, h, h, attn_mask=causal_mask, need_weights=False, is_causal=True
        )[0]
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x
|
|
|
|
class VanillaGPT2(nn.Module):
    """GPT-2-shaped language model built from ``VanillaGPT2Block`` layers.

    Uses learned token/position embeddings, a final LayerNorm and a bias-free
    LM head whose weight is shared with the token embedding.

    Args:
        vocab_size: Vocabulary size.
        n_layer: Number of Transformer blocks.
        n_embd: Model width.
        n_head: Attention heads per block.
        block_size: Maximum sequence length (rows of the position embedding).
    """

    def __init__(
        self,
        vocab_size: int = 50304,
        n_layer: int = 12,
        n_embd: int = 768,
        n_head: int = 12,
        block_size: int = 256,
    ):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.wpe = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList(
            VanillaGPT2Block(n_embd, n_head) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the embedding and the output projection share one Parameter.
        self.wte.weight = self.lm_head.weight

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab)."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device, dtype=torch.long)[None, :]
        hidden = self.wte(idx) + self.wpe(positions)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(self.ln_f(hidden))
|
|
|
|
def cleanup_cuda(*objs: object) -> None:
    """Release GPU memory held by finished objects, then run GC and empty the CUDA cache.

    A ``del`` inside this function cannot remove the *caller's* references, so
    the objects themselves stay alive for as long as the caller holds them.
    To still reclaim a large share of their memory, this function:

    * sets the gradients of ``nn.Module`` arguments to ``None``;
    * clears the per-parameter state of ``torch.optim.Optimizer`` arguments.

    Callers should also drop their own references (for example, rebind the
    variable to ``None``) and call ``cleanup_cuda()`` again to free the weights.

    Args:
        *objs: Objects the caller is finished with. Any other object type is
            accepted and left untouched.
    """
    for obj in objs:
        if isinstance(obj, nn.Module):
            obj.zero_grad(set_to_none=True)
        elif isinstance(obj, torch.optim.Optimizer):
            obj.state.clear()
    # Drop this frame's own references before collecting. Otherwise the
    # argument tuple and the loop variable would pin the objects during
    # empty_cache().
    obj = None
    objs = ()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
|
|
def result_from_exception(level: str, description: str, exc: Exception) -> BenchResult:
    """Convert an exception raised while setting up or running a level into a failed result.

    Args:
        level: Level label for the result.
        description: Level description for the result.
        exc: The exception that aborted the level.

    Returns:
        A ``BenchResult`` with status ``"oom"`` when ``exc`` is a
        ``RuntimeError`` whose message mentions "out of memory", otherwise
        status ``"error"``. Medians are ``None``. The garbage collector is run
        in both cases; the CUDA cache is emptied only for OOM.
    """
    # Some exceptions (e.g. a bare ``AssertionError()``) stringify to "", which
    # would leave the printed report with no hint of what failed.
    msg = str(exc) or type(exc).__name__
    is_oom = isinstance(exc, RuntimeError) and "out of memory" in msg.lower()
    gc.collect()
    if is_oom:
        torch.cuda.empty_cache()
        return BenchResult(level, description, None, None, status="oom", error=msg)
    return BenchResult(level, description, None, None, status="error", error=msg)
|
|
|
|
def gpu_sync(device: torch.device) -> None:
    """Wait for all queued work on ``device`` to finish; a no-op for non-CUDA devices.

    Used around timed regions so asynchronous CUDA kernels are counted in
    the wall-clock measurement.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
|
|
|
|
def run_step_benchmark(
    level: str,
    description: str,
    tokens_per_step: int,
    step_fn: Callable[[], object],
    device: torch.device,
    warmup_steps: int = 0,
) -> BenchResult:
    """Time ``step_fn`` for ``ITERATIONS`` steps and report median throughput and latency.

    Each timed step is bracketed by ``gpu_sync(device)`` so that asynchronous
    CUDA work is included in its wall-clock time.

    Args:
        level: Level label stored in the result.
        description: Human-readable description stored in the result.
        tokens_per_step: Tokens processed by one call of ``step_fn``.
        step_fn: Zero-argument callable performing one step; its return value is ignored.
        device: Device to synchronise on (only CUDA devices are synchronised).
        warmup_steps: Untimed calls made before measuring (e.g. to trigger compilation).

    Returns:
        On success, a ``BenchResult`` with status ``"ok"`` and the median
        tokens/s and ms/step. A ``RuntimeError`` raised during warm-up or timing
        becomes a result with status ``"oom"`` (message mentions "out of
        memory") or ``"error"``. Other exception types propagate to the caller.
    """
    try:
        for _ in range(warmup_steps):
            step_fn()
        gpu_sync(device)

        step_times_ms: list[float] = []
        tok_s: list[float] = []
        for _ in range(ITERATIONS):
            gpu_sync(device)
            t0 = time.perf_counter()
            step_fn()
            gpu_sync(device)
            dt_s = time.perf_counter() - t0
            step_times_ms.append(dt_s * 1000.0)
            # A zero-length interval (timer resolution on a trivially cheap step)
            # would raise ZeroDivisionError. That is not caught here and could
            # abort the whole run.
            tok_s.append(tokens_per_step / dt_s if dt_s > 0 else float("inf"))

        return BenchResult(
            level=level,
            description=description,
            median_tok_s=statistics.median(tok_s),
            median_ms=statistics.median(step_times_ms),
        )
    except RuntimeError as exc:
        msg = str(exc)
        status = "oom" if "out of memory" in msg.lower() else "error"

    # Free memory only after the except clause has ended. Until then, `exc` and
    # its traceback keep the failed step's frames (and their tensors) alive,
    # so emptying the cache inside the handler could not return that memory.
    if status == "oom":
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
    return BenchResult(
        level=level,
        description=description,
        median_tok_s=None,
        median_ms=None,
        status=status,
        error=msg,
    )
|
|
|
|
def print_level_result(result: BenchResult) -> None:
    """Print a one-line report for a single level.

    A successful level shows its median tokens/s and ms/step. A failed level
    shows "OOM" or "ERROR" in both columns, followed by its error message if
    it has one.
    """
    header = f"Level {result.level}: {result.description}"
    if result.status == "ok":
        print(
            f"{header} | {result.median_tok_s:,.0f} tok/s | {result.median_ms:.1f} ms/step"
        )
        return
    label = "OOM" if result.status == "oom" else "ERROR"
    print(f"{header} | {label} | {label}")
    if result.error:
        print(f" error: {result.error}")
|
|
|
|
def print_summary(results: list[BenchResult]) -> None:
    """Print the throughput table, level-to-level ratios, succeeded levels and bottleneck.

    Args:
        results: Benchmark results in execution order. Ratios compare each
            entry with the one immediately before it.
    """
    # Size the description column to the longest description, never narrower
    # than the original 50. Long descriptions (e.g. Level 9) otherwise push
    # the numeric columns out of alignment.
    desc_w = max([50, *(len(r.description) for r in results)])

    print("\n=== Throughput Summary ===")
    print(
        f"{'Level':<8} {'Description':<{desc_w}} {'Median tok/s':>14} {'Median ms':>12}"
    )
    print("-" * (desc_w + 40))
    for r in results:
        tok_s = (
            f"{r.median_tok_s:,.0f}" if r.median_tok_s is not None else r.status.upper()
        )
        ms = f"{r.median_ms:.1f}" if r.median_ms is not None else r.status.upper()
        print(f"{r.level:<8} {r.description:<{desc_w}} {tok_s:>14} {ms:>12}")

    print("\n=== Bottleneck Ratios (vs previous level) ===")
    for prev, curr in zip(results, results[1:]):
        if prev.median_tok_s is None or curr.median_tok_s is None:
            ratio_txt = "N/A"
        else:
            ratio_txt = f"{curr.median_tok_s / prev.median_tok_s:.3f}x"
        print(f"L{curr.level} / L{prev.level}: {ratio_txt}")

    valid = [r for r in results if r.median_tok_s is not None]
    print("\n=== Succeeded Levels ===")
    print(", ".join(f"L{r.level}" for r in valid) if valid else "None")

    if valid:
        bottleneck = min(valid, key=lambda r: r.median_tok_s or 0.0)
        best = max(valid, key=lambda r: r.median_tok_s or 0.0)
        print("\n=== Bottleneck ===")
        print(
            f"Slowest measured stage: Level {bottleneck.level} ({bottleneck.description}) at {bottleneck.median_tok_s:,.0f} tok/s"
        )
        print(
            f"Best measured stage: Level {best.level} ({best.description}) at {best.median_tok_s:,.0f} tok/s"
        )
        if bottleneck.median_tok_s and best.median_tok_s:
            print(
                f"Gap (best/slowest): {best.median_tok_s / bottleneck.median_tok_s:.2f}x"
            )
|
|
|
|
def make_level4_config(
    *,
    block_size: int,
    gradient_checkpointing: bool = False,
    mhc: bool = False,
    hc_num_streams: int = 2,
    hc_disable: bool = False,
) -> NemotronKANConfig:
    """Build the NemotronKAN config shared by the Level 4-9 benchmarks.

    The config uses a 12-layer / 768-wide / 12-head backbone, a GR-KAN MLP,
    every optional ``use_*`` feature switched off, and fused cross-entropy
    disabled. Only the keyword arguments below vary between levels.

    Args:
        block_size: Maximum sequence length.
        gradient_checkpointing: Forwarded to the config unchanged.
        mhc: Forwarded as ``mhc`` (hyper-connections toggle).
        hc_num_streams: Forwarded as ``hc_num_streams``.
        hc_disable: Forwarded as ``hc_disable``.
    """
    backbone = dict(
        n_layer=12,
        n_embd=768,
        n_head=12,
        vocab_size=50304,
        block_size=block_size,
        dropout=0.0,
    )
    kan_settings = dict(
        kan_type="grkan",
        kan_num_grids=4,
        kan_hidden_mult=4,
        kan_grkan_num_groups=8,
    )
    optional_features_off = {
        flag: False
        for flag in (
            "use_engram",
            "use_sdr_compression",
            "use_axiomatic_attention",
            "use_holographic_memory",
            "use_jit_cache",
            "use_synaptic_offload",
        )
    }
    return NemotronKANConfig(
        **backbone,
        **kan_settings,
        **optional_features_off,
        mhc=mhc,
        hc_num_streams=hc_num_streams,
        hc_disable=hc_disable,
        gradient_checkpointing=gradient_checkpointing,
        use_fused_ce=False,
    )
|
|
|
|
def main() -> None:
    """Run every benchmark level in order, then print a diagnosis and a summary table.

    Level groups:

    * 0-1: data loading (CPU, then with CUDA transfer).
    * 2-3: a tiny Transformer.
    * 3c/3d: a vanilla GPT-2 (eager and torch.compile).
    * 4-9: NemotronKAN variants.
    * 10-14: isolated sub-modules.

    Each level builds its model inside ``try`` and rebinds the model,
    optimizer and scaler variables to ``None`` in ``finally``. This rebinding
    is what actually frees GPU memory. The level's step closure shares the
    same variable, so once it is ``None`` nothing else keeps the model alive,
    and the ``cleanup_cuda()`` that follows can return the memory. Without it,
    every level's model would stay resident until ``main`` returns, and later
    levels could OOM because of earlier ones.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark. No GPU detected.")

    device = torch.device("cuda")
    cpu = torch.device("cpu")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    print(f"GPU: {torch.cuda.get_device_name(device)}")
    print(f"CUDA: {torch.version.cuda}")
    print(f"Levels 0-3 batch size: {BATCH_SIZE}, seq len: {SEQ_LEN}")
    print(f"Levels 4-6 batch size: {NEMO_BATCH_SIZE}, seq len: {NEMO_SEQ_LEN}")
    print(f"Levels 7-9 batch size: {COMPILE_BATCH_SIZE}, seq len: {NEMO_SEQ_LEN}")
    print(f"Iterations: {ITERATIONS}")
    print()

    tokens_per_step_main = BATCH_SIZE * SEQ_LEN
    tokens_per_step_nemo = NEMO_BATCH_SIZE * NEMO_SEQ_LEN
    tokens_per_step_compile = COMPILE_BATCH_SIZE * NEMO_SEQ_LEN
    results: list[BenchResult] = []

    stream = StreamingTokenBuffer(
        split="train",
        block_size=SEQ_LEN,
        batch_size=BATCH_SIZE,
        dataset_name=DATASET_NAME,
        dataset_config=DATASET_CONFIG,
        accelerator=None,
    )
    stream.prefill(32)

    # ---- Level 0: data loading only ----
    level0 = run_step_benchmark(
        level="0",
        description="Pure data loading (CPU)",
        tokens_per_step=tokens_per_step_main,
        step_fn=lambda: stream.get_batch(cpu),
        device=cpu,
    )
    print_level_result(level0)
    results.append(level0)
    cleanup_cuda()

    # ---- Level 1: data loading + host-to-device transfer ----
    level1 = run_step_benchmark(
        level="1",
        description="Data loading + CUDA transfer",
        tokens_per_step=tokens_per_step_main,
        step_fn=lambda: stream.get_batch(device),
        device=device,
    )
    print_level_result(level1)
    results.append(level1)
    cleanup_cuda()

    # Fixed random inputs/targets for the model levels (data loading excluded).
    x_static = torch.randint(
        0, 50304, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.long
    )
    y_static = torch.randint(
        0, 50304, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.long
    )
    x_nemo = torch.randint(
        0, 50304, (NEMO_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long
    )
    y_nemo = torch.randint(
        0, 50304, (NEMO_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long
    )
    x_compile = torch.randint(
        0, 50304, (COMPILE_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long
    )
    y_compile = torch.randint(
        0, 50304, (COMPILE_BATCH_SIZE, NEMO_SEQ_LEN), device=device, dtype=torch.long
    )

    # ---- Level 2: tiny Transformer, forward only ----
    level2_desc = "Tiny transformer forward only"
    tiny_model: TinyTransformerLM | None = None
    try:
        tiny_model = TinyTransformerLM(vocab_size=50304).to(device)
        tiny_model.train()

        def level2_step() -> None:
            with torch.inference_mode():
                tiny_model(x_static)

        level2 = run_step_benchmark(
            level="2",
            description=level2_desc,
            tokens_per_step=tokens_per_step_main,
            step_fn=level2_step,
            device=device,
        )
    except Exception as exc:
        level2 = result_from_exception("2", level2_desc, exc)
    finally:
        # Drop the last strong reference (shared with the step closure) so the
        # cleanup_cuda() below can reclaim the model's memory.
        tiny_model = None
    print_level_result(level2)
    results.append(level2)
    cleanup_cuda()

    # ---- Level 3: tiny Transformer, forward + backward ----
    level3_desc = "Tiny transformer forward + backward"
    tiny_model_bwd: TinyTransformerLM | None = None
    tiny_optim: torch.optim.Optimizer | None = None
    try:
        tiny_model_bwd = TinyTransformerLM(vocab_size=50304).to(device)
        tiny_model_bwd.train()
        tiny_optim = torch.optim.AdamW(tiny_model_bwd.parameters(), lr=1e-3)

        def level3_step() -> None:
            tiny_optim.zero_grad(set_to_none=True)
            logits = tiny_model_bwd(x_static)
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), y_static.view(-1)
            )
            loss.backward()

        level3 = run_step_benchmark(
            level="3",
            description=level3_desc,
            tokens_per_step=tokens_per_step_main,
            step_fn=level3_step,
            device=device,
        )
    except Exception as exc:
        level3 = result_from_exception("3", level3_desc, exc)
    finally:
        tiny_model_bwd = None
        tiny_optim = None
    print_level_result(level3)
    results.append(level3)
    cleanup_cuda()

    # ---- Level 3c: vanilla GPT-2 124M fwd+bwd ----
    level3c_desc = "Vanilla GPT-2 124M fwd+bwd (AMP)"
    model3c: VanillaGPT2 | None = None
    try:
        model3c = VanillaGPT2(block_size=NEMO_SEQ_LEN).to(device)
        model3c.train()
        params3c = sum(p.numel() for p in model3c.parameters())
        print(f"Level 3c param count: {params3c:,}")

        def level3c_step() -> None:
            model3c.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                logits = model3c(x_nemo)
                loss = nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)), y_nemo.view(-1)
                )
            loss.backward()

        level3c = run_step_benchmark(
            level="3c",
            description=level3c_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level3c_step,
            device=device,
        )
    except Exception as exc:
        level3c = result_from_exception("3c", level3c_desc, exc)
    finally:
        model3c = None
    print_level_result(level3c)
    results.append(level3c)
    cleanup_cuda()

    # ---- Level 3d: vanilla GPT-2 124M + torch.compile ----
    level3d_desc = "Vanilla GPT-2 124M + torch.compile (AMP)"
    model3d: object | None = None
    try:
        model3d = VanillaGPT2(block_size=NEMO_SEQ_LEN).to(device)
        model3d.train()
        model3d = torch.compile(model3d, backend="inductor")

        def level3d_step() -> None:
            model3d.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                logits = model3d(x_nemo)
                loss = nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)), y_nemo.view(-1)
                )
            loss.backward()

        level3d = run_step_benchmark(
            level="3d",
            description=level3d_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level3d_step,
            device=device,
            warmup_steps=WARMUP_COMPILE_TRAIN_STEPS,
        )
    except Exception as exc:
        level3d = result_from_exception("3d", level3d_desc, exc)
    finally:
        model3d = None
        # Clear Dynamo/Inductor caches (which may reference the compiled
        # module), so later compiled levels start from a clean state.
        torch.compiler.reset()
    print_level_result(level3d)
    results.append(level3d)
    cleanup_cuda()

    # ---- Level 4: NemotronKAN baseline ----
    level4_desc = "NemotronKAN fwd+bwd (all trees off, AMP)"
    model4: NemotronKAN | None = None
    try:
        model4 = NemotronKAN(make_level4_config(block_size=NEMO_SEQ_LEN)).to(device)
        model4.train()

        def level4_step() -> None:
            model4.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model4(x_nemo, y_nemo)
            if loss is None:
                raise RuntimeError("Level 4 loss is None")
            loss.backward()

        level4 = run_step_benchmark(
            level="4",
            description=level4_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level4_step,
            device=device,
        )
    except Exception as exc:
        level4 = result_from_exception("4", level4_desc, exc)
    finally:
        model4 = None
    print_level_result(level4)
    results.append(level4)
    cleanup_cuda()

    # ---- Level 4b: NemotronKAN + gradient checkpointing ----
    level4b_desc = "NemotronKAN fwd+bwd + grad checkpointing (AMP)"
    model4b: NemotronKAN | None = None
    try:
        model4b = NemotronKAN(
            make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True)
        ).to(device)
        model4b.train()

        def level4b_step() -> None:
            model4b.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model4b(x_nemo, y_nemo)
            if loss is None:
                raise RuntimeError("Level 4b loss is None")
            loss.backward()

        level4b = run_step_benchmark(
            level="4b",
            description=level4b_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level4b_step,
            device=device,
        )
    except Exception as exc:
        level4b = result_from_exception("4b", level4b_desc, exc)
    finally:
        model4b = None
    print_level_result(level4b)
    results.append(level4b)
    cleanup_cuda()

    # ---- Level 4c: NemotronKAN with hyper-connections disabled ----
    level4c_desc = "NemotronKAN fwd+bwd + HC disabled (AMP)"
    model4c: NemotronKAN | None = None
    try:
        model4c = NemotronKAN(
            make_level4_config(
                block_size=NEMO_SEQ_LEN,
                gradient_checkpointing=False,
                mhc=False,
                hc_num_streams=1,
                hc_disable=True,
            )
        ).to(device)
        model4c.train()

        def level4c_step() -> None:
            model4c.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model4c(x_nemo, y_nemo)
            if loss is None:
                raise RuntimeError("Level 4c loss is None")
            loss.backward()

        level4c = run_step_benchmark(
            level="4c",
            description=level4c_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level4c_step,
            device=device,
        )
    except Exception as exc:
        level4c = result_from_exception("4c", level4c_desc, exc)
    finally:
        model4c = None
    print_level_result(level4c)
    results.append(level4c)
    cleanup_cuda()

    # ---- Level 5: NemotronKAN + hyper-connections ----
    level5_desc = "NemotronKAN + HyperConnections (AMP)"
    model5: NemotronKAN | None = None
    try:
        model5 = NemotronKAN(
            make_level4_config(block_size=NEMO_SEQ_LEN, mhc=True, hc_num_streams=2)
        ).to(device)
        model5.train()

        def level5_step() -> None:
            model5.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model5(x_nemo, y_nemo)
            if loss is None:
                raise RuntimeError("Level 5 loss is None")
            loss.backward()

        level5 = run_step_benchmark(
            level="5",
            description=level5_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level5_step,
            device=device,
        )
    except Exception as exc:
        level5 = result_from_exception("5", level5_desc, exc)
    finally:
        model5 = None
    print_level_result(level5)
    results.append(level5)
    cleanup_cuda()

    # ---- Level 6: NemotronKAN + gradient checkpointing ----
    # NOTE(review): same config and inputs as Level 4b; confirm whether this
    # level is meant to differ.
    level6_desc = "NemotronKAN + gradient checkpointing (AMP)"
    model6: NemotronKAN | None = None
    try:
        model6 = NemotronKAN(
            make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True)
        ).to(device)
        model6.train()

        def level6_step() -> None:
            model6.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model6(x_nemo, y_nemo)
            if loss is None:
                raise RuntimeError("Level 6 loss is None")
            loss.backward()

        level6 = run_step_benchmark(
            level="6",
            description=level6_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level6_step,
            device=device,
        )
    except Exception as exc:
        level6 = result_from_exception("6", level6_desc, exc)
    finally:
        model6 = None
    print_level_result(level6)
    results.append(level6)
    cleanup_cuda()

    # ---- Level 7: NemotronKAN + torch.compile ----
    level7_desc = "NemotronKAN + torch.compile (AMP)"
    model7: object | None = None
    try:
        model7 = NemotronKAN(make_level4_config(block_size=NEMO_SEQ_LEN)).to(device)
        model7.train()
        model7 = torch.compile(model7, backend="inductor")

        def level7_step() -> None:
            model7.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model7(x_compile, y_compile)
            if loss is None:
                raise RuntimeError("Level 7 loss is None")
            loss.backward()

        level7 = run_step_benchmark(
            level="7",
            description=level7_desc,
            tokens_per_step=tokens_per_step_compile,
            step_fn=level7_step,
            device=device,
            warmup_steps=WARMUP_COMPILE_STEPS,
        )
    except Exception as exc:
        level7 = result_from_exception("7", level7_desc, exc)
    finally:
        model7 = None
        torch.compiler.reset()
    print_level_result(level7)
    results.append(level7)
    cleanup_cuda()

    # ---- Level 8: full training step (eager) ----
    level8_desc = "Full training step (AMP+AdamW fused+ckpt)"
    model8: NemotronKAN | None = None
    optim8: torch.optim.Optimizer | None = None
    scaler8: torch.amp.GradScaler | None = None
    try:
        model8 = NemotronKAN(
            make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True)
        ).to(device)
        model8.train()
        optim8 = torch.optim.AdamW(model8.parameters(), lr=1e-4, fused=True)
        scaler8 = torch.amp.GradScaler("cuda", enabled=True)

        def level8_step() -> None:
            optim8.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model8(x_compile, y_compile)
            if loss is None:
                raise RuntimeError("Level 8 loss is None")
            scaler8.scale(loss).backward()
            scaler8.unscale_(optim8)
            torch.nn.utils.clip_grad_norm_(model8.parameters(), max_norm=1.0)
            scaler8.step(optim8)
            scaler8.update()

        level8 = run_step_benchmark(
            level="8",
            description=level8_desc,
            tokens_per_step=tokens_per_step_compile,
            step_fn=level8_step,
            device=device,
        )
    except Exception as exc:
        level8 = result_from_exception("8", level8_desc, exc)
    finally:
        model8 = None
        optim8 = None
        scaler8 = None
    print_level_result(level8)
    results.append(level8)
    cleanup_cuda()

    # ---- Level 9: full training step + torch.compile ----
    level9_desc = "Full training step + torch.compile (AMP+AdamW fused+ckpt)"
    model9: object | None = None
    optim9: torch.optim.Optimizer | None = None
    scaler9: torch.amp.GradScaler | None = None
    try:
        model9 = NemotronKAN(
            make_level4_config(block_size=NEMO_SEQ_LEN, gradient_checkpointing=True)
        ).to(device)
        model9.train()
        model9 = torch.compile(model9, backend="inductor")
        optim9 = torch.optim.AdamW(model9.parameters(), lr=1e-4, fused=True)
        scaler9 = torch.amp.GradScaler("cuda", enabled=True)

        def level9_step() -> None:
            optim9.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                _logits, loss, _info = model9(x_compile, y_compile)
            if loss is None:
                raise RuntimeError("Level 9 loss is None")
            scaler9.scale(loss).backward()
            scaler9.unscale_(optim9)
            torch.nn.utils.clip_grad_norm_(model9.parameters(), max_norm=1.0)
            scaler9.step(optim9)
            scaler9.update()

        level9 = run_step_benchmark(
            level="9",
            description=level9_desc,
            tokens_per_step=tokens_per_step_compile,
            step_fn=level9_step,
            device=device,
            warmup_steps=WARMUP_COMPILE_TRAIN_STEPS,
        )
    except Exception as exc:
        level9 = result_from_exception("9", level9_desc, exc)
    finally:
        model9 = None
        optim9 = None
        scaler9 = None
        torch.compiler.reset()
    print_level_result(level9)
    results.append(level9)
    cleanup_cuda()

    # ---- Level 10: GRKANActivation alone (forward, no grad) ----
    level10_desc = "GRKANActivation only (Triton kernel)"
    level10: BenchResult
    act10: GRKANActivation | None = None
    try:
        act10 = GRKANActivation(num_groups=8).to(device)
        x10 = torch.randn(
            NEMO_BATCH_SIZE,
            NEMO_SEQ_LEN,
            768,
            device=device,
            dtype=torch.float32,
            requires_grad=False,
        )

        def level10_step() -> None:
            _ = act10(x10)

        level10 = run_step_benchmark(
            level="10",
            description=level10_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level10_step,
            device=device,
        )
    except Exception as exc:
        level10 = result_from_exception("10", level10_desc, exc)
    finally:
        act10 = None
    print_level_result(level10)
    results.append(level10)
    cleanup_cuda()

    # ---- Level 11: a single fp16 nn.Linear 768->3072 (forward) ----
    level11_desc = "nn.Linear 768->3072 only"
    level11: BenchResult
    fc11: nn.Linear | None = None
    try:
        fc11 = nn.Linear(768, 3072).to(device).half()
        x11 = torch.randn(
            NEMO_BATCH_SIZE,
            NEMO_SEQ_LEN,
            768,
            device=device,
            dtype=torch.float16,
            requires_grad=False,
        )

        def level11_step() -> None:
            _ = fc11(x11)

        level11 = run_step_benchmark(
            level="11",
            description=level11_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level11_step,
            device=device,
        )
    except Exception as exc:
        level11 = result_from_exception("11", level11_desc, exc)
    finally:
        fc11 = None
    print_level_result(level11)
    results.append(level11)
    cleanup_cuda()

    # ---- Level 12: GR-KAN MLP alone fwd+bwd ----
    level12_desc = "GRKANMLPReplacement alone fwd+bwd (AMP)"
    level12: BenchResult
    mlp12: GRKANMLPReplacement | None = None
    try:
        mlp12 = GRKANMLPReplacement(
            768,
            hidden_mult=4,
            num_groups=8,
            dropout=0.0,
            use_checkpoint=False,
        ).to(device)
        mlp12.train()
        x12 = torch.randn(
            NEMO_BATCH_SIZE,
            NEMO_SEQ_LEN,
            768,
            device=device,
            dtype=torch.float32,
        )

        def level12_step() -> None:
            mlp12.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                y12 = mlp12(x12)
                loss12 = y12.float().pow(2).mean()
            loss12.backward()

        level12 = run_step_benchmark(
            level="12",
            description=level12_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level12_step,
            device=device,
        )
    except Exception as exc:
        level12 = result_from_exception("12", level12_desc, exc)
    finally:
        mlp12 = None
    print_level_result(level12)
    results.append(level12)
    cleanup_cuda()

    # ---- Level 13: standard GELU MLP alone fwd+bwd ----
    level13_desc = "Standard MLP alone fwd+bwd (AMP)"
    level13: BenchResult
    mlp13: nn.Sequential | None = None
    try:
        mlp13 = nn.Sequential(
            nn.Linear(768, 3072),
            nn.GELU(),
            nn.Linear(3072, 768),
        ).to(device)
        mlp13.train()
        x13 = torch.randn(
            NEMO_BATCH_SIZE,
            NEMO_SEQ_LEN,
            768,
            device=device,
            dtype=torch.float16,
        )

        def level13_step() -> None:
            mlp13.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                y13 = mlp13(x13)
                loss13 = y13.float().pow(2).mean()
            loss13.backward()

        level13 = run_step_benchmark(
            level="13",
            description=level13_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level13_step,
            device=device,
        )
    except Exception as exc:
        level13 = result_from_exception("13", level13_desc, exc)
    finally:
        mlp13 = None
    print_level_result(level13)
    results.append(level13)
    cleanup_cuda()

    # ---- Level 14: CausalSelfAttention alone fwd+bwd ----
    level14_desc = "CausalSelfAttention alone fwd+bwd (AMP)"
    level14: BenchResult
    attn14: CausalSelfAttention | None = None
    try:
        # block_size must cover the input length; tie it to NEMO_SEQ_LEN instead
        # of a hard-coded 256, so changing the constant cannot break this level.
        cfg14 = NemotronKANConfig(
            n_embd=768, n_head=12, block_size=NEMO_SEQ_LEN, dropout=0.0
        )
        attn14 = CausalSelfAttention(cfg14).to(device)
        attn14.train()
        x14 = torch.randn(
            NEMO_BATCH_SIZE,
            NEMO_SEQ_LEN,
            768,
            device=device,
            dtype=torch.float32,
        )

        def level14_step() -> None:
            attn14.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", dtype=torch.float16):
                y14 = attn14(x14)
                loss14 = y14.float().pow(2).mean()
            loss14.backward()

        level14 = run_step_benchmark(
            level="14",
            description=level14_desc,
            tokens_per_step=tokens_per_step_nemo,
            step_fn=level14_step,
            device=device,
        )
    except Exception as exc:
        level14 = result_from_exception("14", level14_desc, exc)
    finally:
        attn14 = None
    print_level_result(level14)
    results.append(level14)
    cleanup_cuda()

    print("\nDIAGNOSIS:")
    if level10.median_ms is not None:
        print(f" GRKANActivation overhead vs no-op: {level10.median_ms:.1f}ms")
    else:
        print(" GRKANActivation overhead vs no-op: N/A")
    if level12.median_ms is not None and level13.median_ms is not None:
        print(
            f" GR-KAN MLP vs Standard MLP: {level12.median_ms / level13.median_ms:.1f}x slower"
        )
    else:
        print(" GR-KAN MLP vs Standard MLP: N/A")
    if level14.median_ms is not None:
        print(f" Attention: {level14.median_ms:.1f}ms per call")
    else:
        print(" Attention: N/A")
    if level12.median_ms is not None:
        print(f" MLP: {level12.median_ms:.1f}ms per call")
    else:
        print(" MLP: N/A")
    if level14.median_ms is not None and level12.median_ms is not None:
        print(f" Attention/MLP ratio: {level14.median_ms / level12.median_ms:.2f}")
    else:
        print(" Attention/MLP ratio: N/A")

    print_summary(results)
|
|
|
|
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|