# gemma4_optimization.py
"""
Hayson Cheung, 2026, Original Script written to optimize Gemma4 on Hugging Face's Transformers library.
LICENSED UNDER THE MIT LICENSE.

This file contains optimized variants of Gemma4 text model components, including a mixin for remapping
weights from original Gemma4 models to optimized versions. The optimizations include support for an
additional zero-compute expert in the MoE router and experts, as well as adjustments to the router's
projection and scaling parameters to accommodate the expanded expert set. The load_optimization_weights
method enables loading weights from a base Gemma4 model while remapping tensors as needed for the
optimized architecture.
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn

from .modeling_gemma4 import (
    Gemma4ForCausalLM,
    Gemma4TextDecoderLayer,
    Gemma4TextExperts,
    Gemma4TextModel,
    Gemma4TextRouter,
)


@dataclass(frozen=True)
class Gemma4OptimizationLoadResult:
    """Immutable summary returned by ``load_optimization_weights``."""

    # Sorted names of every tensor that was handed to ``load_state_dict``.
    loaded_keys: tuple[str, ...]
    # Sorted names of base-model tensors that were not loaded (missing in the
    # target or shape-mismatched without a remap).
    skipped_keys: tuple[str, ...]

    @property
    def loaded_count(self) -> int:
        """Return the number of loaded tensors."""
        return len(self.loaded_keys)

    @property
    def skipped_count(self) -> int:
        """Return the number of base-model tensors that were skipped."""
        return len(self.skipped_keys)


class Gemma4OptimizationWeightsMixin:
    """
    Mixin for modules that need a custom remapping step when loading weights from an original Gemma4
    model into an optimized variant.

    Subclasses override ``_remap_optimization_tensors`` to supply tensors whose shape differs from the
    base model; everything else is copied by name when the shapes match.
    """

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Return replacement tensors keyed by names local to this module.

        Args:
            base_state_dict: State dict of the corresponding module in the base model.
            target_state_dict: State dict of this (optimized) module.

        Returns:
            A mapping from local tensor name to the tensor that should be loaded. The default
            implementation remaps nothing.
        """
        return {}

    def load_optimization_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Load weights from ``base_model`` into this module, remapping where required.

        Every submodule (including ``self``) that uses this mixin and has a same-named counterpart in
        ``base_model`` is asked for remapped tensors first. Remaining base tensors are then copied when
        a tensor of the same name and shape exists in this module. All tensors are converted to the
        device and dtype of their target before ``load_state_dict(..., strict=False)`` is called.

        Args:
            base_model: The original model to read weights from.

        Returns:
            A ``Gemma4OptimizationLoadResult`` listing the loaded and skipped tensor names.

        Raises:
            TypeError: If ``self`` or ``base_model`` is not an ``nn.Module``.
            KeyError: If a remap hook returns a tensor name that this module does not own.
        """
        if not isinstance(self, nn.Module):
            raise TypeError("Gemma4OptimizationWeightsMixin can only be used with nn.Module subclasses.")
        if not isinstance(base_model, nn.Module):
            raise TypeError("base_model must be an nn.Module.")

        target_state_dict = self.state_dict()
        # Built once: walking the full module tree of a large model is not free, and the
        # result is needed both for the copy loop and for the skipped-key report.
        base_state_dict = base_model.state_dict()
        loaded: dict[str, torch.Tensor] = {}

        # Pass 1: tensors that need reshaping, provided by mixin-aware submodules.
        for module_name, module in self.named_modules():
            if not isinstance(module, Gemma4OptimizationWeightsMixin):
                continue
            try:
                # named_modules() reports the root module under the empty name.
                base_module = base_model if module_name == "" else base_model.get_submodule(module_name)
            except AttributeError:
                # The base model has no counterpart for this submodule; nothing to remap.
                continue

            remapped_tensors = module._remap_optimization_tensors(base_module.state_dict(), module.state_dict())
            for tensor_name, tensor_value in remapped_tensors.items():
                full_name = f"{module_name}.{tensor_name}" if module_name else tensor_name
                target_tensor = target_state_dict.get(full_name)
                if target_tensor is None:
                    raise KeyError(
                        f"Remapped tensor '{full_name}' is not present in the target state dict."
                    )
                loaded[full_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        # Pass 2: straight copies for everything that already lines up by name and shape.
        for tensor_name, tensor_value in base_state_dict.items():
            if tensor_name in loaded:
                continue
            target_tensor = target_state_dict.get(tensor_name)
            if target_tensor is None or target_tensor.shape != tensor_value.shape:
                continue
            loaded[tensor_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        # strict=False: tensors of this module that were neither remapped nor copied keep
        # their current values.
        self.load_state_dict(loaded, strict=False)

        skipped = tuple(sorted(set(base_state_dict) - set(loaded)))
        return Gemma4OptimizationLoadResult(tuple(sorted(loaded)), skipped)

    def _load_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Alias for ``load_optimization_weights``."""
        return self.load_optimization_weights(base_model)


def get_total_optimized_experts(num_experts: int, add_zero_compute_expert: bool) -> int:
    """Return the expert count including the optional zero-compute expert."""
    return num_experts + int(add_zero_compute_expert)


class OptimizedGemma4TextExperts(Gemma4TextExperts):
    """Experts module that tolerates routing to an extra, weightless expert index."""

    def __init__(self, config):
        super().__init__(config)
        # ``add_zero_compute_expert`` is optional on the config; absent means disabled.
        self.total_num_experts = get_total_optimized_experts(
            self.num_experts, getattr(config, "add_zero_compute_expert", False)
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the selected experts to each token and sum their weighted outputs.

        Args:
            hidden_states: Token states, indexed by token along dim 0
                (presumably ``(num_tokens, hidden_size)`` -- confirm against the caller).
            top_k_index: Integer expert indices per token, 2-D ``(num_tokens, top_k)``.
            top_k_weights: Routing weights aligned with ``top_k_index``.

        Returns:
            A tensor shaped like ``hidden_states``. Slots routed to an index
            ``>= self.num_experts`` (the zero-compute expert) contribute nothing.
        """
        final_hidden_states = torch.zeros_like(hidden_states)

        # Routing bookkeeping only; no gradients are needed for the masks.
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_num_experts)
            # -> (expert, top_k, token)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Indices of experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx >= self.num_experts:
                # Zero-compute expert: has no weights, so its contribution stays zero.
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states


class OptimizedGemma4TextRouter(Gemma4OptimizationWeightsMixin, Gemma4TextRouter):
    """Router whose projection and per-expert scale cover the expanded expert set."""

    def __init__(self, config):
        super().__init__(config)
        self.num_experts = config.num_experts
        self.total_num_experts = get_total_optimized_experts(
            self.num_experts, getattr(config, "add_zero_compute_expert", False)
        )
        # Replace the parameters created by the base class with versions sized for
        # ``total_num_experts``.
        self.proj = nn.Linear(config.hidden_size, self.total_num_experts, bias=False)
        self.per_expert_scale = nn.Parameter(torch.ones(self.total_num_experts))

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Expand the base router's tensors by one expert slot.

        ``proj.weight`` gains one all-zero row and ``per_expert_scale`` gains one entry of
        ``1.0``; the leading entries are copied from the base router. A tensor is only
        remapped when the target is exactly one expert larger than the base.

        Args:
            base_state_dict: State dict of the base router.
            target_state_dict: State dict of this router.

        Returns:
            The expanded tensors keyed by their local names (possibly empty).
        """
        remapped: dict[str, torch.Tensor] = {}

        base_proj = base_state_dict.get("proj.weight")
        target_proj = target_state_dict.get("proj.weight")
        if (
            base_proj is not None
            and target_proj is not None
            and target_proj.shape[1] == base_proj.shape[1]
            and target_proj.shape[0] == base_proj.shape[0] + 1
        ):
            expanded_proj = torch.zeros_like(target_proj)
            expanded_proj[: base_proj.shape[0]].copy_(base_proj)
            remapped["proj.weight"] = expanded_proj

        base_per_expert_scale = base_state_dict.get("per_expert_scale")
        target_per_expert_scale = target_state_dict.get("per_expert_scale")
        if (
            base_per_expert_scale is not None
            and target_per_expert_scale is not None
            and target_per_expert_scale.shape[0] == base_per_expert_scale.shape[0] + 1
        ):
            expanded_per_expert_scale = torch.ones_like(target_per_expert_scale)
            expanded_per_expert_scale[: base_per_expert_scale.shape[0]].copy_(base_per_expert_scale)
            remapped["per_expert_scale"] = expanded_per_expert_scale

        return remapped


class OptimizedGemma4TextDecoderLayer(Gemma4TextDecoderLayer):
    """Decoder layer wired to the optimized router and experts classes."""

    # NOTE(review): presumably read by the base class when it builds its MoE
    # submodules -- confirm in modeling_gemma4.
    router_class = OptimizedGemma4TextRouter
    experts_class = OptimizedGemma4TextExperts


class OptimizedGemma4TextModel(Gemma4OptimizationWeightsMixin, Gemma4TextModel):
    """Text model wired to the optimized decoder layer class."""

    decoder_layer_class = OptimizedGemma4TextDecoderLayer


class OptimizedGemma4ForCausalLM(Gemma4OptimizationWeightsMixin, Gemma4ForCausalLM):
    """Causal LM wired to the optimized text model class."""

    text_model_class = OptimizedGemma4TextModel


__all__ = [
    "Gemma4OptimizationLoadResult",
    "Gemma4OptimizationWeightsMixin",
    "OptimizedGemma4ForCausalLM",
    "OptimizedGemma4TextDecoderLayer",
    "OptimizedGemma4TextExperts",
    "OptimizedGemma4TextModel",
    "OptimizedGemma4TextRouter",
    "get_total_optimized_experts",
]