import functools
import inspect
from dataclasses import dataclass
from typing import Any, Iterable, Optional, Union

import torch
from transformers.cache_utils import Cache, DynamicCache, DynamicLayer, DynamicSlidingWindowLayer
from transformers.integrations import mxfp4
from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
from transformers.masking_utils import create_sliding_window_causal_mask
from transformers.models.gpt_oss import modeling_gpt_oss
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssDecoderLayer, GptOssForCausalLM

from .configuration_gpt_oss_puzzle import GptOssPuzzleConfig


@dataclass
class SlidingWindowCausalMaskPlaceholder:
    """
    A deferred request to build a sliding-window causal mask.

    `GptOssPuzzleForCausalLM.forward` patches `create_sliding_window_causal_mask` to return this
    placeholder instead of a real mask; each `GptOssPuzzleDecoderLayer` then materializes the mask
    with its own per-layer config (see `GptOssPuzzleDecoderLayer.forward`).
    """

    kwargs: dict[str, Any]


class GptOssPuzzleDecoderLayer(GptOssDecoderLayer):
    """
    Extends `GptOssDecoderLayer` to support per-layer configs.
    """

    def __init__(self, config: GptOssPuzzleConfig, layer_idx: int):
        layer_config = config.get_gpt_oss_config_for_layer(layer_idx)
        super().__init__(layer_config, layer_idx)
        self.config = layer_config
        self.layer_idx = layer_idx

    def forward(self, *args, **kwargs):
        # Resolve a deferred mask request with this layer's own config, so each layer can use
        # its own sliding window size.
        if "attention_mask" in kwargs and isinstance(kwargs["attention_mask"], SlidingWindowCausalMaskPlaceholder):
            mask_kwargs = dict(kwargs["attention_mask"].kwargs)
            mask_kwargs["config"] = self.config
            if mask_kwargs["past_key_values"] is not None:
                mask_kwargs["past_key_values"] = CacheViewForSlidingWindowMask(
                    mask_kwargs["past_key_values"], self.layer_idx
                )
            kwargs["attention_mask"] = create_sliding_window_causal_mask(**mask_kwargs)
        return super().forward(*args, **kwargs)
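

# Sketch of the deferred-mask handshake (illustrative values): the model-level forward swaps
# `create_sliding_window_causal_mask` for a placeholder factory, and each layer materializes
# the real mask with its own config:
#
#     placeholder = SlidingWindowCausalMaskPlaceholder(kwargs={"past_key_values": None})
#     # GptOssPuzzleDecoderLayer.forward(..., attention_mask=placeholder) then calls
#     # create_sliding_window_causal_mask(config=layer.config, past_key_values=None, ...)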


class CacheViewForSlidingWindowMask:
    """
    A view wrapper around a `Cache` that makes `create_sliding_window_causal_mask` use the
    correct layer index.

    `create_sliding_window_causal_mask` derives mask sizes from the first sliding layer it finds
    in `past_key_values.is_sliding`, i.e. from the index of the first `True` entry. Since
    gpt-oss-puzzle has heterogeneous sliding window sizes across layers, each layer must use its
    own sliding window size. This view returns an `is_sliding` list that marks only the current
    layer as sliding, so `create_sliding_window_causal_mask` computes the mask from the correct
    layer index.
    """

    def __init__(self, cache: Cache, layer_idx: int):
        self._cache = cache
        self._layer_idx = layer_idx

    @property
    def is_sliding(self) -> list[bool]:
        # The first (and only) `True` entry sits at `self._layer_idx`.
        return [False] * self._layer_idx + [True]

    def __getattr__(self, name: str):
        # Delegate everything else to the wrapped cache.
        return getattr(self._cache, name)
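

# Illustration: only the wrapped layer is reported as sliding, so the first `True` entry in
# `is_sliding` (which the mask helper looks up) corresponds to the layer itself.
#
#     view = CacheViewForSlidingWindowMask(cache, layer_idx=2)  # `cache`: any existing Cache
#     view.is_sliding               # [False, False, True]
#     view.is_sliding.index(True)   # 2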


class Mxfp4GptOssPuzzleExperts(Mxfp4GptOssExperts):
    """
    Extends `Mxfp4GptOssExperts` to support per-layer configs.

    This class is instantiated without a layer index, so the index is inferred from the caller's
    stack: the quantizer exposes either `current_key_name` (a list of module-path components) or
    `module_name` (a dotted module path such as `model.layers.<idx>.mlp.experts`).
    """

    def __init__(self, config: GptOssPuzzleConfig):
        current_key_name = _get_variable_from_stack(["current_key_name"])
        if current_key_name is None:
            module_name = _get_variable_from_stack(["module_name"])
            if module_name is None:
                raise RuntimeError("`current_key_name`/`module_name` variable not found in caller stack")
            # e.g. "model.layers.5.mlp.experts" -> 5
            layer_idx = int(module_name.split(".")[-3])
        else:
            # e.g. ["model", "layers", "5", "mlp", "experts"] -> 5
            layer_idx = int(current_key_name[-3])

        layer_config = config.get_gpt_oss_config_for_layer(layer_idx)
        super().__init__(layer_config)


def _get_variable_from_stack(names: list[str]) -> Optional[Any]:
    """Walk up the call stack and return the first local variable whose name is in `names`."""
    f = inspect.currentframe().f_back
    while f:
        for name in names:
            if name in f.f_locals:
                return f.f_locals[name]
        f = f.f_back
    return None
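

# Self-contained sketch of the stack walk (hypothetical helper names):
#
#     def _outer():
#         current_key_name = ["model", "layers", "3", "mlp", "experts"]  # local in a caller frame
#         return _inner()
#
#     def _inner():
#         return _get_variable_from_stack(["current_key_name"])
#
#     _outer()  # -> ["model", "layers", "3", "mlp", "experts"]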


class PuzzleDynamicCache(DynamicCache):
    """
    A `DynamicCache` subclass that supports heterogeneous layer configurations.

    `__init__` mirrors `DynamicCache.__init__`, except that the sliding window size is resolved
    per layer from the config's `block_configs` instead of from a single model-wide attribute.
    """

    def __init__(
        self,
        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
        config: Optional[GptOssPuzzleConfig] = None,
        offloading: bool = False,
        offload_only_non_sliding: bool = False,
    ):
        layers = []

        if config is not None:
            decoder_config = config.get_text_config(decoder=True)

            def _resolve_sliding_window(layer_idx: int) -> Optional[int]:
                # Prefer the per-layer block config; fall back to the model-wide attribute.
                for attr_name in ("sliding_window", "attention_chunk_size"):
                    value = getattr(
                        decoder_config.block_configs[layer_idx],
                        attr_name,
                        getattr(decoder_config, attr_name, None),
                    )
                    if value is not None:
                        return value
                return None

            layer_types = getattr(decoder_config, "layer_types", None)
            if layer_types is None:
                layer_types = [
                    "sliding_attention" if _resolve_sliding_window(layer_idx) is not None else "full_attention"
                    for layer_idx in range(decoder_config.num_hidden_layers)
                ]

            # Some models share the KV cache of the last layers, so no cache is needed for them.
            # Guard against `num_kv_shared_layers == 0`, where `layer_types[:-0]` would be empty.
            if getattr(decoder_config, "num_kv_shared_layers", 0):
                layer_types = layer_types[: -decoder_config.num_kv_shared_layers]

            for layer_idx, layer_type in enumerate(layer_types):
                if layer_type in ("sliding_attention", "chunked_attention"):
                    layers.append(DynamicSlidingWindowLayer(sliding_window=_resolve_sliding_window(layer_idx)))
                else:
                    layers.append(DynamicLayer())

        if ddp_cache_data is not None:
            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
                # Without a config we cannot know the layer types, so default to full attention.
                if config is None:
                    layers.append(DynamicLayer())
                _, _ = layers[layer_idx].update(key_states, value_states)

        # Skip `DynamicCache.__init__` and call `Cache.__init__` directly: the layers are already
        # built here, or replicated lazily when neither config nor cache data is available.
        if len(layers) == 0:
            super(DynamicCache, self).__init__(
                layer_class_to_replicate=DynamicLayer,
                offloading=offloading,
                offload_only_non_sliding=offload_only_non_sliding,
            )
        else:
            super(DynamicCache, self).__init__(
                layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding
            )
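

# Example (hedged sketch; assumes a loaded `GptOssPuzzleConfig` named `cfg` whose `block_configs`
# mix sliding-window and full-attention layers):
#
#     cache = PuzzleDynamicCache(config=cfg)
#     # Layers whose block config defines `sliding_window` (or `attention_chunk_size`) become
#     # DynamicSlidingWindowLayer instances with their own window sizes; the rest are DynamicLayer.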


original_load_balancing_loss_func = modeling_gpt_oss.load_balancing_loss_func


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,  # kept for signature compatibility; superseded by `num_experts_per_layer`
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
    num_experts_per_layer: Optional[tuple[int, ...]] = None,
) -> Union[torch.Tensor, int]:
    """Per-layer load balancing loss: each layer's loss is computed with its own expert count."""
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    compute_device = gate_logits[0].device
    overall_loss = 0

    for layer_idx, layer_gate_logits in enumerate(gate_logits):
        layer_loss = original_load_balancing_loss_func(
            gate_logits=(layer_gate_logits,),
            num_experts=num_experts_per_layer[layer_idx],
            top_k=top_k,
            attention_mask=attention_mask,
        )
        overall_loss += layer_loss.to(compute_device)

    return overall_loss
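

# Usage sketch (illustrative tensors): with three layers of 4, 8 and 4 experts,
#
#     gate_logits = (torch.randn(6, 4), torch.randn(6, 8), torch.randn(6, 4))
#     loss = load_balancing_loss_func(gate_logits, top_k=2, num_experts_per_layer=(4, 8, 4))
#
# each element of `gate_logits` is scored against its own layer's expert count, then the
# per-layer losses are summed.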


class GptOssPuzzleForCausalLM(GptOssForCausalLM):
    """
    A `GptOssForCausalLM` subclass that supports heterogeneous layer configurations.

    This class uses monkey-patching to inject custom behavior into the parent class while
    maximizing code reuse and minimizing duplication. During `__init__`, it temporarily replaces
    the decoder layer class with `GptOssPuzzleDecoderLayer`. During `forward`, it patches mask
    creation, cache handling, and load balancing loss computation to account for per-layer
    variations.
    """

    config_class = GptOssPuzzleConfig
    _no_split_modules = ["GptOssPuzzleDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"\.k_scale$", r"\.v_scale$"]

    def __init__(self, config):
        # Poison the model-wide attribute so any code path that reads it instead of the
        # per-block value fails loudly.
        config.num_local_experts = "PER_BLOCK_ATTRIBUTE"

        # Temporarily swap the decoder layer class so the parent constructor builds
        # `GptOssPuzzleDecoderLayer` instances with per-layer configs.
        original_decoder_layer_cls = modeling_gpt_oss.GptOssDecoderLayer
        modeling_gpt_oss.GptOssDecoderLayer = GptOssPuzzleDecoderLayer
        try:
            super().__init__(config)
            self.config = config
        finally:
            modeling_gpt_oss.GptOssDecoderLayer = original_decoder_layer_cls

        # Permanent patch: mxfp4 quantization must always build the per-layer experts class.
        mxfp4.Mxfp4GptOssExperts = Mxfp4GptOssPuzzleExperts

    def forward(self, *args, **kwargs):
        original_create_sliding_window_causal_mask = modeling_gpt_oss.create_sliding_window_causal_mask
        original_dynamic_cache = modeling_gpt_oss.DynamicCache

        modeling_gpt_oss.load_balancing_loss_func = functools.partial(
            load_balancing_loss_func,
            num_experts_per_layer=tuple(block_config.num_local_experts for block_config in self.config.block_configs),
        )
        # Defer sliding-window mask creation to each decoder layer (see
        # `SlidingWindowCausalMaskPlaceholder` and `GptOssPuzzleDecoderLayer.forward`).
        modeling_gpt_oss.create_sliding_window_causal_mask = lambda **kwargs: SlidingWindowCausalMaskPlaceholder(
            kwargs=kwargs
        )
        modeling_gpt_oss.DynamicCache = PuzzleDynamicCache
        try:
            return super().forward(*args, **kwargs)
        finally:
            modeling_gpt_oss.create_sliding_window_causal_mask = original_create_sliding_window_causal_mask
            modeling_gpt_oss.load_balancing_loss_func = original_load_balancing_loss_func
            modeling_gpt_oss.DynamicCache = original_dynamic_cache

    def _prepare_cache_for_generation(self, *args, **kwargs):
        from transformers.generation import utils as generation_utils

        # `generate` builds its cache via `generation_utils.DynamicCache`; swap it so the cache
        # honors per-layer sliding windows.
        original_dynamic_cache = generation_utils.DynamicCache
        generation_utils.DynamicCache = PuzzleDynamicCache
        try:
            return super()._prepare_cache_for_generation(*args, **kwargs)
        finally:
            generation_utils.DynamicCache = original_dynamic_cache
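

# End-to-end usage sketch ("path/to/gpt-oss-puzzle" is a placeholder checkpoint path):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/gpt-oss-puzzle")
#     model = GptOssPuzzleForCausalLM.from_pretrained("path/to/gpt-oss-puzzle")
#     inputs = tokenizer("Hello", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=8)  # uses PuzzleDynamicCache internally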