No Kernels
First, we run the model without any custom kernels to get a reference point.
Forward
Forward and Backward
Next, we run a forward and backward pass without any custom kernels. This is likely to run out of GPU memory, since the default (kernel-free) implementation is not optimized for memory usage.
Kernels
Next, we run the model with Megablocks kernels enabled.
Forward
First, we run a forward pass with Megablocks kernels.
▼ code
▼ output
▶ uv-logs
|
Cell: forward_only | 118.48s | FAILED
|
Raw
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "accelerate>=1.10.1",
# "torch>=2.7.0",
# "kernels==0.10.0",
# "transformers@https://github.com/huggingface/transformers.git",
# "ipdb>=0.13.13",
# "matplotlib>=3.7.2",
# "numpy>=1.24.3",
# ]
# ///
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
import time
import torch.nn as nn
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
import sys
import torch.profiler
import gc
import logging
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
# NOTE(review): passing None as the layer name presumably opts GptOssRMSNorm
# out of hub-kernel replacement, so only other layers receive kernels when
# the model is loaded with use_kernels=True — confirm against the `kernels`
# documentation for replace_kernel_forward_from_hub.
replace_kernel_forward_from_hub(GptOssRMSNorm, None)
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
def reset_peak_memory_stats():
    """Free unreferenced memory, release cached CUDA blocks, and reset peak counters.

    Python garbage collection always runs. The CUDA cache release and the
    peak-allocation reset only happen when CUDA is available.
    """
    # Collect first: tensors kept alive only by reference cycles are freed
    # here, so the following empty_cache() can actually hand their blocks
    # back. In the reverse order those blocks would stay in the allocator's
    # cache and inflate the reserved-memory figure.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
def get_memory_stats():
    """Return a snapshot of CUDA memory usage in gigabytes (1e9 bytes).

    Keys (in this order):
        allocated_gb: value of ``torch.cuda.memory_allocated()``
        peak_gb: value of ``torch.cuda.max_memory_allocated()``
        reserved_gb: value of ``torch.cuda.memory_reserved()``

    When CUDA is unavailable, every value is the integer 0.
    """
    if torch.cuda.is_available():
        bytes_per_gb = 1e9
        return {
            "allocated_gb": torch.cuda.memory_allocated() / bytes_per_gb,
            "peak_gb": torch.cuda.max_memory_allocated() / bytes_per_gb,
            "reserved_gb": torch.cuda.memory_reserved() / bytes_per_gb,
        }
    return dict(allocated_gb=0, peak_gb=0, reserved_gb=0)
def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Set ``kernel_layer_name`` on the first loaded nn.Module subclass named *cls_name*.

    Every module currently registered in ``sys.modules`` is searched for an
    attribute called *cls_name*. The first such attribute that is an
    ``nn.Module`` subclass gets ``kernel_layer_name = value``, a message is
    printed, and the search stops.

    Args:
        cls_name: Name of the class to look up.
        value: New value for the class attribute ``kernel_layer_name``.

    Returns:
        True if a matching class was found and updated, otherwise False.

    NOTE(review): only the first match is updated. If several distinct
    classes share *cls_name* across modules, which one wins depends on
    module registration order.
    """
    # Iterate over a snapshot: getattr() may run a module-level __getattr__
    # (lazy imports), which adds entries to sys.modules. Mutating the dict
    # during live iteration would raise
    # "RuntimeError: dictionary changed size during iteration".
    for mod in sys.modules.copy().values():
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            obj.kernel_layer_name = value
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False
# Init the model the normal way
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
# NOTE(review): dequantize=True presumably unpacks the MXFP4 checkpoint into
# the bfloat16 dtype requested below — confirm against the Mxfp4Config docs.
quantization_config = Mxfp4Config(dequantize=True)
model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,  # enable hub kernels for this (kernels-enabled) run
    quantization_config=quantization_config,
).eval()

# NOTE(review): the question is sent with the "system" role; a "user" turn is
# the usual place for a prompt — confirm this is intended.
messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]
# Move the prompt to wherever device_map="auto" placed the model instead of
# hard-coding "cuda", so the inputs follow the model's actual placement
# (including a CPU-only fallback).
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to(model.device)

max_tokens = 256
with torch.inference_mode():
    start_time = time.perf_counter()
    # Greedy decoding. temperature=None keeps a sampling-only setting out of
    # the call; presumably it overrides a checkpoint generation default.
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()
print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {end_time - start_time:.2f} seconds")
▶ UV Install Logs
Fetching 3 files: 0%| | 0/3 [00:00<?, ?it/s]
Fetching 3 files: 0%| | 0/3 [00:50<?, ?it/s]
Traceback (most recent call last):
File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks/.uvnote/cells/forward_only.py", line 68, in <module>
model = GptOssForCausalLM.from_pretrained(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 285, in _wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4904, in from_pretrained
checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 1239, in _get_resolved_checkpoint_files
checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 1116, in get_checkpoint_shard_files
cached_filenames = cached_files(
^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 564, in cached_files
raise e
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 491, in cached_files
snapshot_download(
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 332, in snapshot_download
thread_map(
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 69, in thread_map
return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 51, in _executor_map
return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__
for obj in iterable:
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 619, in result_iterator
yield _result_or_cancel(fs.pop())
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 317, in _result_or_cancel
return fut.result(timeout)
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 306, in _inner_hf_hub_download
return hf_hub_download(
^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1010, in hf_hub_download
return _hf_hub_download_to_cache_dir(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1171, in _hf_hub_download_to_cache_dir
_download_to_tmp_and_move(
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1723, in _download_to_tmp_and_move
xet_get(
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 629, in xet_get
download_files(
RuntimeError: Data processing error: CAS service error : IO Error: No space left on device (os error 28)
Forward and Backward
Next, we run a forward and backward pass with Megablocks kernels enabled. This should use less memory than the kernel-free run and allow the backward pass to complete without running out of memory.