No Kernels

First, we run the model without any custom kernels to get a reference point.

Forward

Forward and Backward

Next, we'll attempt to run a forward and backward pass without any custom kernels. This will likely run out of memory since the default implementation is not optimized for memory usage.

Kernels

Next we can run with Megablocks kernels enabled.

Forward

First, we run a forward pass with Megablocks kernels.

▼ code ▼ output ▶ uv-logs | Cell: forward_only | 118.48s | FAILED | Raw
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate>=1.10.1",
#     "torch>=2.7.0",
#     "kernels==0.10.0",
#     "transformers@https://github.com/huggingface/transformers.git",
#     "ipdb>=0.13.13",
#     "matplotlib>=3.7.2",
#     "numpy>=1.24.3",
# ]
# ///

import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
import time
import torch.nn as nn
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
import sys
import torch.profiler
import gc
import logging
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm


# Register GptOssRMSNorm with a kernel layer name of None before the model is built.
# NOTE(review): presumably this opts RMSNorm out of hub-kernel replacement so only
# the other kernelized layers are swapped — confirm against the `kernels` API.
replace_kernel_forward_from_hub(GptOssRMSNorm, None)

# Configure the root logger at INFO level (note: INFO, not DEBUG).
logging.basicConfig(level=logging.INFO)

def reset_peak_memory_stats():
    """Free unused memory and reset the CUDA peak-memory counters.

    The Python garbage collector runs first so that tensors kept alive only by
    unreachable reference cycles are released *before* the CUDA caching
    allocator is asked to give back its unused blocks; otherwise that memory
    would still be cached after this call. When CUDA is unavailable only the
    garbage-collection step is performed.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def get_memory_stats():
    """Report CUDA memory usage in gigabytes (bytes divided by 1e9).

    Returns a dict with the keys ``allocated_gb``, ``peak_gb`` and
    ``reserved_gb``. All three values are ``0`` when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        bytes_per_gb = 1e9
        allocated_bytes = torch.cuda.memory_allocated()
        peak_bytes = torch.cuda.max_memory_allocated()
        reserved_bytes = torch.cuda.memory_reserved()
        return {
            "allocated_gb": allocated_bytes / bytes_per_gb,
            "peak_gb": peak_bytes / bytes_per_gb,
            "reserved_gb": reserved_bytes / bytes_per_gb,
        }
    # No GPU: report zeros so callers can format the result unconditionally.
    return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}

def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Set ``kernel_layer_name`` on the first loaded ``nn.Module`` subclass named *cls_name*.

    Every module currently present in ``sys.modules`` is searched for an
    attribute called *cls_name*; the first one that is an ``nn.Module``
    subclass gets its ``kernel_layer_name`` class attribute set to *value*.

    Args:
        cls_name: Name of the class to look up in the loaded modules.
        value: New value for the class attribute ``kernel_layer_name``.

    Returns:
        ``True`` if a matching class was found and patched, ``False`` otherwise.
    """
    # Iterate over a snapshot: getattr() on lazily-loading modules can import
    # further modules, and mutating sys.modules while iterating it directly
    # raises "RuntimeError: dictionary changed size during iteration".
    for mod in tuple(sys.modules.values()):
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            setattr(obj, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False


# Init the model the normal way
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
# NOTE(review): dequantize=True presumably loads the MXFP4 checkpoint as
# dequantized (bf16) weights — confirm against the Mxfp4Config documentation.
quantization_config = Mxfp4Config(dequantize=True)



# use_kernels=True asks transformers to swap in hub kernels for supported layers;
# device_map="auto" lets accelerate decide where the weights are placed.
model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=quantization_config,
).eval()

# NOTE(review): the question is sent with the "system" role — confirm this is
# intended rather than "user".
messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]

# Place the inputs on the model's own device rather than a hard-coded "cuda",
# so the script follows whatever placement device_map="auto" selected.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to(model.device)

max_tokens = 256

with torch.inference_mode():
    start_time = time.perf_counter()
    # Greedy decoding; temperature=None avoids passing a sampling-only option.
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    # CUDA kernels run asynchronously: wait for outstanding work so the
    # measured wall-clock time covers the whole generation.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {end_time - start_time:.2f} seconds")
▶ UV Install Logs
Fetching 3 files: 0%| | 0/3 [00:00<?, ?it/s] Fetching 3 files: 0%| | 0/3 [00:50<?, ?it/s] Traceback (most recent call last): File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks/.uvnote/cells/forward_only.py", line 68, in <module> model = GptOssForCausalLM.from_pretrained( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 285, in _wrapper return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4904, in from_pretrained checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 1239, in _get_resolved_checkpoint_files checkpoint_files, sharded_metadata = get_checkpoint_shard_files( ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 1116, in get_checkpoint_shard_files cached_filenames = cached_files( ^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 564, in cached_files raise e File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 491, in cached_files snapshot_download( File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File 
"/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 332, in snapshot_download thread_map( File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 69, in thread_map return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 51, in _executor_map return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__ for obj in iterable: File "/usr/lib/python3.12/concurrent/futures/_base.py", line 619, in result_iterator yield _result_or_cancel(fs.pop()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/lib/python3.12/concurrent/futures/_base.py", line 317, in _result_or_cancel return fut.result(timeout) ^^^^^^^^^^^^^^^^^^^ File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result return self.__get_result() ^^^^^^^^^^^^^^^^^^^ File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result raise self._exception File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 306, in _inner_hf_hub_download return hf_hub_download( ^^^^^^^^^^^^^^^^ File 
"/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn return fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1010, in hf_hub_download return _hf_hub_download_to_cache_dir( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1171, in _hf_hub_download_to_cache_dir _download_to_tmp_and_move( File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1723, in _download_to_tmp_and_move xet_get( File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 629, in xet_get download_files( RuntimeError: Data processing error: CAS service error : IO Error: No space left on device (os error 28)

Forward and Backward

Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.