No Kernels
First, we run the model without any custom kernels to get a reference point.
Forward
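The fully instrumented script appears later in this post; below is a minimal sketch of the baseline forward pass. It assumes the same model id, prompt, and generation settings as that script and simply omits use_kernels, so the stock Transformers implementations are used.

# Baseline forward-only generation without hub kernels (sketch).
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config

model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    quantization_config=Mxfp4Config(dequantize=True),  # note: no use_kernels=True here
).eval()

inputs = tokenizer.apply_chat_template(
    [{"role": "system", "content": "What is Tensor Parallelism?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=256, do_sample=False, temperature=None)

print(tokenizer.decode(generated[0], skip_special_tokens=False))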
Forward and Backward
Next, we'll attempt to run a forward and backward pass without any custom kernels. This will likely run out of memory since the default implementation is not optimized for memory usage.
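A sketch of that attempt is shown below. It assumes the model, tokenizer, and inputs from the baseline script above; using labels=input_ids to obtain a loss is an illustrative choice rather than something the original run prescribes.

# Forward + backward without custom kernels (sketch); expected to exhaust GPU memory.
torch.cuda.reset_peak_memory_stats()
model.train()

try:
    outputs = model(**inputs, labels=inputs["input_ids"])  # forward pass with a language-modeling loss
    outputs.loss.backward()  # backward pass allocates activations and gradients
    print(f"Loss: {outputs.loss.item():.4f}")
except torch.cuda.OutOfMemoryError:
    print("Ran out of memory during forward/backward, as expected without optimized kernels")

print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")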
Kernels
Next, we can run with Megablocks kernels enabled.
Forward
First, we run a forward pass with Megablocks kernels.
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "accelerate>=1.10.1",
#   "torch>=2.7.0",
#   "kernels==0.10.0",
#   "transformers @ git+https://github.com/huggingface/transformers.git",
#   "ipdb>=0.13.13",
#   "matplotlib>=3.7.2",
#   "numpy>=1.24.3",
# ]
# ///

import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
import time
import torch.nn as nn
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
import sys
import torch.profiler
import gc
import logging
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm


# Keep the stock RMSNorm forward: passing None opts GptOssRMSNorm out of hub kernel replacement
replace_kernel_forward_from_hub(GptOssRMSNorm, None)

# set logging to INFO so kernel-loading messages are visible
logging.basicConfig(level=logging.INFO)

def reset_peak_memory_stats():
    """Clear CUDA cache and reset memory allocation counters."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    gc.collect()

def get_memory_stats():
    """Get current and peak CUDA memory usage."""
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    return {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
    }

def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Helper to dynamically override the kernel_layer_name in a model class."""
    for mod in sys.modules.values():
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            setattr(obj, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False


# Initialize the model the standard way, with hub kernels enabled (use_kernels=True)
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
quantization_config = Mxfp4Config(dequantize=True)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=quantization_config,
).eval()

messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

max_tokens = 256

with torch.inference_mode():
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {end_time - start_time:.2f} seconds")
Forward and Backward
Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.
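A sketch of that run is shown below. It reuses the kernel-enabled model, inputs, and memory helpers from the script above; as before, the labels=input_ids loss is only an illustrative choice.

# Forward + backward with Megablocks kernels enabled (sketch).
reset_peak_memory_stats()
model.train()

outputs = model(**inputs, labels=inputs["input_ids"])  # forward pass with a language-modeling loss
outputs.loss.backward()  # backward pass through the kernel-backed layers

stats = get_memory_stats()
print(f"Loss: {outputs.loss.item():.4f}")
print(f"Peak GPU memory during forward/backward: {stats['peak_gb']:.2f} GB")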