Comparison of Megablocks and Yamoe Kernels
This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.
Megablocks kernel
Cell: setup2 | 18.93s | FAILED
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate>=1.10.1",
#     "torch>=2.7.0",
#     "kernels==0.10.0",
#     "transformers@https://github.com/huggingface/transformers.git",
#     "ipdb>=0.13.13",
#     "matplotlib>=3.7.2",
#     "numpy>=1.24.3",
# ]
# ///
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
import time
import torch.nn as nn
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
import sys
import torch.profiler
import gc
import logging

# Log at INFO level so kernel resolution messages are visible.
logging.basicConfig(level=logging.INFO)


# Memory and override helpers, defined for convenience; not called in this run.
def reset_peak_memory_stats():
    """Clear CUDA cache and reset memory allocation counters."""
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    gc.collect()


def get_memory_stats():
    """Get current and peak CUDA memory usage."""
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    return {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
    }


def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Helper to dynamically override the kernel_layer_name in a model class."""
    for mod in sys.modules.values():
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            setattr(obj, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False


# Init the model the normal way.
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
quantization_config = Mxfp4Config(dequantize=True)

from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm

# Keep the stock RMSNorm forward; only the MoE MLP is swapped from the Hub.
replace_kernel_forward_from_hub(GptOssRMSNorm, None)  # direct, type-safe

# Route the "Yamoe" layer name to the drbh/yamoe kernel for CUDA inference.
custom_mapping = {
    "Yamoe": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="drbh/yamoe",
                layer_name="Yamoe",
                revision="v0.3.0",
            )
        }
    }
}
register_kernel_mapping(custom_mapping)

# use_kernels=True makes from_pretrained honor the registered kernel mapping.
model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=quantization_config,
).eval()

messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

max_tokens = 256

with torch.inference_mode():
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,  # greedy decoding for a deterministic comparison
        temperature=None,
    )
    end_time = time.perf_counter()

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {end_time - start_time:.2f} seconds")
Output:

Downloading cpython-3.13.7-linux-x86_64-gnu (32.0MiB)
Updating https://github.com/huggingface/transformers.git (HEAD)
 Updated https://github.com/huggingface/transformers.git (e691f84412563b6abca098f3e044980725d8daa3)
  × No solution found when resolving script dependencies:
  ╰─▶ Because only transformers==4.57.0.dev0 is available and
      transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1,
      we can conclude that all versions of transformers depend on
      huggingface-hub==1.0.0rc1.
      And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0,
      we can conclude that kernels==0.10.0 and all versions of transformers
      are incompatible.
      And because you require kernels==0.10.0 and transformers, we can
      conclude that your requirements are unsatisfiable.
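
The cell fails before the kernel is ever exercised: it is a dependency resolution conflict, not a kernel problem. The transformers build pulled from git HEAD pins huggingface-hub==1.0.0rc1, while kernels==0.10.0 requires huggingface-hub>=0.26.0,<1.0, so no install can satisfy both. A minimal sketch of one way out, assuming a released transformers (GPT-OSS support landed in 4.55) still depends on huggingface-hub<1.0 and therefore co-resolves with kernels==0.10.0; the exact version bound here is an assumption, not something verified in this run:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate>=1.10.1",
#     "torch>=2.7.0",
#     "kernels==0.10.0",
#     # Assumption: a released transformers (not git HEAD) depends on
#     # huggingface-hub<1.0 and is therefore compatible with kernels==0.10.0.
#     "transformers>=4.55",
#     "ipdb>=0.13.13",
#     "matplotlib>=3.7.2",
#     "numpy>=1.24.3",
# ]
# ///

Alternatively, pinning the git dependency to an older transformers commit that predates the huggingface-hub 1.0 pin, or moving to a newer kernels release that allows huggingface-hub>=1.0 (if one exists), would resolve the same conflict.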