# gpt-oss-puzzle-88B / modeling_gpt_oss_puzzle.py
import functools
import inspect
from dataclasses import dataclass
from typing import Any, Iterable, Optional, Union

import torch
from transformers.cache_utils import Cache, DynamicCache, DynamicLayer, DynamicSlidingWindowLayer
from transformers.integrations import mxfp4
from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
from transformers.masking_utils import create_sliding_window_causal_mask
from transformers.models.gpt_oss import modeling_gpt_oss
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssDecoderLayer, GptOssForCausalLM

from .configuration_gpt_oss_puzzle import GptOssPuzzleConfig


@dataclass
class SlidingWindowCausalMaskPlaceholder:
    """Records the kwargs of a deferred `create_sliding_window_causal_mask` call so that
    each decoder layer can rebuild the mask with its own per-layer config."""

    kwargs: dict[str, Any]


class GptOssPuzzleDecoderLayer(GptOssDecoderLayer):
    """
    Extends GptOssDecoderLayer to support per-layer configs.
    """

    def __init__(self, config: GptOssPuzzleConfig, layer_idx: int):
        layer_config = config.get_gpt_oss_config_for_layer(layer_idx)
        super().__init__(layer_config, layer_idx)
        self.config = layer_config
        self.layer_idx = layer_idx

    def forward(self, *args, **kwargs):
        if "attention_mask" in kwargs and isinstance(kwargs["attention_mask"], SlidingWindowCausalMaskPlaceholder):
            mask_kwargs = dict(kwargs["attention_mask"].kwargs)
            mask_kwargs["config"] = self.config
            if mask_kwargs["past_key_values"] is not None:
                mask_kwargs["past_key_values"] = CacheViewForSlidingWindowMask(
                    mask_kwargs["past_key_values"], self.layer_idx
                )
            kwargs["attention_mask"] = create_sliding_window_causal_mask(**mask_kwargs)
        return super().forward(*args, **kwargs)

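
# Minimal sketch of the placeholder hand-off (illustrative only, never called by the
# model): `GptOssPuzzleForCausalLM.forward` patches the module-level mask builder to
# return a placeholder that merely records its kwargs, and each decoder layer later
# rebuilds the mask with its own config. The single kwarg below stands in for the much
# larger set the real call site passes.
def _placeholder_resolution_sketch() -> dict[str, Any]:
    placeholder = SlidingWindowCausalMaskPlaceholder(kwargs={"past_key_values": None})
    mask_kwargs = dict(placeholder.kwargs)
    mask_kwargs["config"] = "<per-layer config>"  # the layer substitutes its own config here
    return mask_kwargs
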

class CacheViewForSlidingWindowMask:
    """
    A view wrapper around a Cache that makes `create_sliding_window_causal_mask` use the correct layer index.

    `create_sliding_window_causal_mask` iterates over `past_key_values.is_sliding` to determine which layer
    to use for deriving mask sizes, effectively using the first sliding layer's index. Since gpt-oss-puzzle
    has heterogeneous sliding window sizes across layers, each layer must use its own sliding window size.
    This view returns an `is_sliding` list that only marks the current layer as sliding, causing
    `create_sliding_window_causal_mask` to use the correct layer index for mask computation.
    """

    def __init__(self, cache: Cache, layer_idx: int):
        self._cache = cache
        self._layer_idx = layer_idx

    @property
    def is_sliding(self) -> list[bool]:
        return [False] * self._layer_idx + [True]

    def __getattr__(self, name: str):
        # Everything other than `is_sliding` is forwarded to the wrapped cache.
        return getattr(self._cache, name)

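
# Minimal sketch of the view's behavior (illustrative only, never called by the model):
# wrapping any cache for layer_idx=2 yields `is_sliding == [False, False, True]`, so the
# mask utilities derive sliding-window sizes from layer 2; all other attribute access
# falls through to the wrapped cache via `__getattr__`.
def _cache_view_sketch() -> None:
    cache = DynamicCache()
    view = CacheViewForSlidingWindowMask(cache, layer_idx=2)
    assert view.is_sliding == [False, False, True]
    assert view.update.__func__ is cache.update.__func__  # forwarded bound method
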

class Mxfp4GptOssPuzzleExperts(Mxfp4GptOssExperts):
    """
    Extends Mxfp4GptOssExperts to support per-layer configs.

    Since this class is created without being passed the layer index, the index is inferred
    from the call stack: the caller holds the name of the experts module being replaced in a
    local variable (`current_key_name` or `module_name`).
    """

    def __init__(self, config: GptOssPuzzleConfig):
        # module_name is of the form *.{layer_idx}.mlp.experts
        current_key_name = _get_variable_from_stack(["current_key_name"])
        if current_key_name is None:
            module_name = _get_variable_from_stack(["module_name"])
            if module_name is None:
                raise RuntimeError("`current_key_name`/`module_name` variable not found in caller stack")
            layer_idx = int(module_name.split(".")[-3])
        else:
            # current_key_name is a list of module-name parts, e.g. [..., "{layer_idx}", "mlp", "experts"]
            layer_idx = int(current_key_name[-3])
        layer_config = config.get_gpt_oss_config_for_layer(layer_idx)
        super().__init__(layer_config)


def _get_variable_from_stack(names: list[str]) -> Optional[Any]:
    """Walk up the call stack and return the first local variable whose name is in `names`."""
    f = inspect.currentframe().f_back
    while f:
        for name in names:
            if name in f.f_locals:
                return f.f_locals[name]
        f = f.f_back
    return None

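
# Minimal sketch of the stack lookup (illustrative only, never called by the model):
# a local `module_name` defined in the caller is visible to `_get_variable_from_stack`
# one frame up. The dotted path below is a hypothetical experts-module name.
def _stack_lookup_sketch() -> int:
    def _caller() -> int:
        module_name = "model.layers.7.mlp.experts"  # noqa: F841 - read via the stack
        found = _get_variable_from_stack(["module_name"])
        return int(found.split(".")[-3])  # the layer index is the third-from-last part

    return _caller()  # -> 7
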

class PuzzleDynamicCache(DynamicCache):
    """
    A child class of DynamicCache that supports heterogeneous layer configurations.

    `__init__` is the same as in DynamicCache, except that the sliding window size is obtained
    per layer from `block_configs`.
    """

    def __init__(
        self,
        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
        config: Optional[GptOssPuzzleConfig] = None,
        offloading: bool = False,
        offload_only_non_sliding: bool = False,
    ):
        layers = []
        # If a config is passed, use it to infer the layer types and initialize accordingly
        if config is not None:
            decoder_config = config.get_text_config(decoder=True)
            layer_types = getattr(decoder_config, "layer_types", None)
            if layer_types is None:
                layer_types = []
                for layer_idx in range(decoder_config.num_hidden_layers):
                    sliding_window = None
                    for attr_name in ("sliding_window", "attention_chunk_size"):
                        sliding_window = getattr(
                            config.block_configs[layer_idx],
                            attr_name,
                            getattr(decoder_config, attr_name, None),
                        )
                        if sliding_window is not None:
                            break
                    layer_types.append("sliding_attention" if sliding_window is not None else "full_attention")
            # Some models have shared layers, so no cache is needed for them (e.g. Gemma3n).
            # Guard against a zero value, for which the slice below would empty the list.
            if getattr(decoder_config, "num_kv_shared_layers", 0) > 0:
                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
            for layer_idx, layer_type in enumerate(layer_types):
                # From a cache point of view, both sliding and chunked are the same in how they should
                # behave and how many states they should return - only the mask differs at the end!
                if layer_type in ("sliding_attention", "chunked_attention"):
                    sliding_window = None
                    for attr_name in ("sliding_window", "attention_chunk_size"):
                        sliding_window = getattr(
                            config.block_configs[layer_idx],
                            attr_name,
                            getattr(decoder_config, attr_name, None),
                        )
                        if sliding_window is not None:
                            break
                    layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
                else:
                    layers.append(DynamicLayer())

        # In this case, use the passed data to already fill in the Cache
        if ddp_cache_data is not None:
            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
                # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
                if config is None:
                    layers.append(DynamicLayer())
                # Update the layer with the data
                _, _ = layers[layer_idx].update(key_states, value_states)

        # If neither config nor ddp_cache_data was passed, simply lazy-init a full cache of DynamicLayer.
        # `super(DynamicCache, self)` skips DynamicCache.__init__ and calls Cache.__init__ directly.
        if len(layers) == 0:
            super(DynamicCache, self).__init__(
                layer_class_to_replicate=DynamicLayer,
                offloading=offloading,
                offload_only_non_sliding=offload_only_non_sliding,
            )
        else:
            super(DynamicCache, self).__init__(
                layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding
            )

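
# Minimal sketch of the config-free path (illustrative only, never called by the model):
# with `ddp_cache_data` and no config, one DynamicLayer is created per (key, value) pair
# and pre-filled. The shapes (batch=1, heads=2, seq=4, head_dim=8) are made up.
def _puzzle_cache_sketch() -> PuzzleDynamicCache:
    k = torch.zeros(1, 2, 4, 8)
    v = torch.zeros(1, 2, 4, 8)
    cache = PuzzleDynamicCache(ddp_cache_data=[(k, v), (k, v)])
    seq_len = cache.get_seq_length(0)  # both layers should now hold 4 cached positions
    return cache
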

original_load_balancing_loss_func = modeling_gpt_oss.load_balancing_loss_func


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,  # unused; kept for signature compatibility
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
    num_experts_per_layer: Optional[tuple[int, ...]] = None,
) -> Union[torch.Tensor, int]:
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0
    compute_device = gate_logits[0].device
    overall_loss = 0
    # Call the original loss one layer at a time, since each layer may have a different
    # number of experts and the logits therefore cannot be concatenated.
    for layer_idx, layer_gate_logits in enumerate(gate_logits):
        layer_loss = original_load_balancing_loss_func(
            gate_logits=(layer_gate_logits,),
            num_experts=num_experts_per_layer[layer_idx],
            top_k=top_k,
            attention_mask=attention_mask,
        )
        overall_loss += layer_loss.to(compute_device)
    return overall_loss

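
# Minimal sketch of the per-layer aggregation (illustrative only, never called by the
# model): two layers with different expert counts (4 and 8) each get their own call to
# the original loss, which a single concatenated call could not handle. The shapes
# (batch * seq = 6) are made up.
def _aux_loss_sketch() -> torch.Tensor:
    gate_logits = (torch.randn(6, 4), torch.randn(6, 8))
    return load_balancing_loss_func(gate_logits, top_k=2, num_experts_per_layer=(4, 8))
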

class GptOssPuzzleForCausalLM(GptOssForCausalLM):
    """
    A child class of GptOssForCausalLM that supports heterogeneous layer configurations.

    This class uses monkey-patching to inject custom behavior into the parent class while maximizing
    code reuse and minimizing duplication. During `__init__`, it temporarily replaces the decoder layer
    class with `GptOssPuzzleDecoderLayer`. During `forward`, it patches mask creation, cache handling,
    and load balancing loss computation to account for per-layer variations.
    """

    config_class = GptOssPuzzleConfig
    _no_split_modules = ["GptOssPuzzleDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"\.k_scale$", r"\.v_scale$"]

    def __init__(self, config):
        # Placeholder for a per-block attribute whose model-level value must not be used; it is set
        # only because GptOssForCausalLM's __init__ accesses it.
        config.num_local_experts = "PER_BLOCK_ATTRIBUTE"
        original_decoder_layer_cls = modeling_gpt_oss.GptOssDecoderLayer
        modeling_gpt_oss.GptOssDecoderLayer = GptOssPuzzleDecoderLayer
        try:
            super().__init__(config)
            self.config = config  # Used for load_balancing_loss_func
        finally:
            modeling_gpt_oss.GptOssDecoderLayer = original_decoder_layer_cls
        mxfp4.Mxfp4GptOssExperts = Mxfp4GptOssPuzzleExperts  # Used after the model is initialized

    def forward(self, *args, **kwargs):
        original_create_sliding_window_causal_mask = modeling_gpt_oss.create_sliding_window_causal_mask
        original_dynamic_cache = modeling_gpt_oss.DynamicCache
        modeling_gpt_oss.load_balancing_loss_func = functools.partial(
            load_balancing_loss_func,
            num_experts_per_layer=tuple(block_config.num_local_experts for block_config in self.config.block_configs),
        )
        modeling_gpt_oss.create_sliding_window_causal_mask = lambda **kwargs: SlidingWindowCausalMaskPlaceholder(
            kwargs=kwargs
        )
        modeling_gpt_oss.DynamicCache = PuzzleDynamicCache
        try:
            return super().forward(*args, **kwargs)
        finally:
            modeling_gpt_oss.create_sliding_window_causal_mask = original_create_sliding_window_causal_mask
            modeling_gpt_oss.load_balancing_loss_func = original_load_balancing_loss_func
            modeling_gpt_oss.DynamicCache = original_dynamic_cache

    def _prepare_cache_for_generation(self, *args, **kwargs):
        from transformers.generation import utils as generation_utils

        original_dynamic_cache = generation_utils.DynamicCache
        generation_utils.DynamicCache = PuzzleDynamicCache
        try:
            return super()._prepare_cache_for_generation(*args, **kwargs)
        finally:
            generation_utils.DynamicCache = original_dynamic_cache

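
# Illustrative usage sketch (assumptions: this module ships as custom modeling code
# alongside the checkpoint, and the repo id matches the file header; adjust both for
# your deployment). `trust_remote_code=True` routes `from_pretrained` through
# `GptOssPuzzleForCausalLM`, so generation exercises the patched forward and cache paths.
def _usage_sketch() -> None:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt-oss-puzzle-88B", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("gpt-oss-puzzle-88B", trust_remote_code=True)
    inputs = tokenizer("Hello", return_tensors="pt")
    model.generate(**inputs, max_new_tokens=8)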