
No Kernels


First, we run the model without any custom kernels to get a reference point.


Forward
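
The baseline cell is not reproduced in this section. As a rough, hypothetical sketch of what it does, assuming the same openai/gpt-oss-20b setup used in the Megablocks cell below but with use_kernels left at its default so no hub kernels are loaded:

# Hypothetical sketch: baseline forward-only run without custom kernels.
# Mirrors the Megablocks cell further below, minus use_kernels.
import time
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config

model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    quantization_config=Mxfp4Config(dequantize=True),
).eval()  # use_kernels is left at its default (False), so the stock PyTorch ops run

inputs = tokenizer.apply_chat_template(
    [{"role": "system", "content": "What is Tensor Parallelism?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

with torch.inference_mode():
    start = time.perf_counter()
    generated = model.generate(**inputs, max_new_tokens=256, do_sample=False, temperature=None)
    elapsed = time.perf_counter() - start

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {elapsed:.2f} seconds")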


Forward and Backward


Next, we'll attempt to run a forward and backward pass without any custom kernels. This will likely run out of memory since the default implementation is not optimized for memory usage.
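
The corresponding cell is not shown in this section either. A minimal, hypothetical sketch of the idea, reusing model and inputs from the sketch above and keeping autograd enabled so the backward pass allocates activation memory (this is the part expected to OOM):

# Hypothetical sketch: forward + backward without custom kernels.
# Reuses model and inputs from the baseline sketch above.
import torch

torch.cuda.reset_peak_memory_stats()

input_ids = inputs["input_ids"]
outputs = model(input_ids=input_ids, labels=input_ids)  # causal-LM loss over the prompt
outputs.loss.backward()                                 # expected to exhaust GPU memory here

print(f"loss: {outputs.loss.item():.4f}")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")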


Kernels


Next, we run the same passes with the Megablocks kernels enabled.


Forward


First, we run a forward pass with Megablocks kernels.

Cell: forward_only | 118.48s | FAILED

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate>=1.10.1",
#     "torch>=2.7.0",
#     "kernels==0.10.0",
#     "transformers@https://github.com/huggingface/transformers.git",
#     "ipdb>=0.13.13",
#     "matplotlib>=3.7.2",
#     "numpy>=1.24.3",
# ]
# ///

import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
import time
import torch.nn as nn
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
import sys
import torch.profiler
import gc
import logging
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm


replace_kernel_forward_from_hub(GptOssRMSNorm, None)

# enable info-level logging
logging.basicConfig(level=logging.INFO)

def reset_peak_memory_stats():
    """Clear CUDA cache and reset memory allocation counters."""
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    gc.collect()

def get_memory_stats():
    """Get current and peak CUDA memory usage."""
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    return {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
    }

def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Helper to dynamically override the kernel_layer_name in a model class."""
    for mod in sys.modules.values():
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            setattr(obj, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False


# Init the model the normal way
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
quantization_config = Mxfp4Config(dequantize=True)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=quantization_config,
).eval()

messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

max_tokens = 256

with torch.inference_mode():
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {end_time - start_time:.2f} seconds")
Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]
Fetching 3 files:   0%|          | 0/3 [00:50<?, ?it/s]
Traceback (most recent call last):
  File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks/.uvnote/cells/forward_only.py", line 68, in <module>
    model = GptOssForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 285, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4904, in from_pretrained
    checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 1239, in _get_resolved_checkpoint_files
    checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 1116, in get_checkpoint_shard_files
    cached_filenames = cached_files(
                       ^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 564, in cached_files
    raise e
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 491, in cached_files
    snapshot_download(
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 332, in snapshot_download
    thread_map(
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 69, in thread_map
    return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 51, in _executor_map
    return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 619, in result_iterator
    yield _result_or_cancel(fs.pop())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 317, in _result_or_cancel
    return fut.result(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 306, in _inner_hf_hub_download
    return hf_hub_download(
           ^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1010, in hf_hub_download
    return _hf_hub_download_to_cache_dir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1171, in _hf_hub_download_to_cache_dir
    _download_to_tmp_and_move(
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1723, in _download_to_tmp_and_move
    xet_get(
  File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 629, in xet_get
    download_files(
RuntimeError: Data processing error: CAS service error : IO Error: No space left on device (os error 28)

This run failed before the model could be loaded: downloading the checkpoint shards filled the runner's disk ("No space left on device"), so no generation output or timing was produced for this configuration.

Forward and Backward


Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.
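
The cell itself is not reproduced in this section. A minimal, hypothetical sketch of the shape of that run, assuming it follows the same pattern as the forward-only cell above with use_kernels=True plus a loss and a backward call:

# Hypothetical sketch: forward + backward with Megablocks kernels enabled.
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config

model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,  # pull the Megablocks MoE kernels from the hub
    quantization_config=Mxfp4Config(dequantize=True),
)

inputs = tokenizer.apply_chat_template(
    [{"role": "system", "content": "What is Tensor Parallelism?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

torch.cuda.reset_peak_memory_stats()
out = model(input_ids=inputs["input_ids"], labels=inputs["input_ids"])
out.loss.backward()

print(f"loss: {out.loss.item():.4f}")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")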
