# gemma4_optimization.py
"""
Hayson Cheung, 2026. Original script written to optimize
Gemma4 on Hugging Face's Transformers library.
LICENSED UNDER THE MIT LICENSE.
This file contains optimized variants of Gemma4 text model components, including a mixin for remapping weights from original Gemma4 models to optimized versions. The optimizations include support for an additional zero-compute expert in the MoE router and experts, as well as adjustments to the router's projection and scaling parameters to accommodate the expanded expert set. The load_optimization_weights method enables loading weights from a base Gemma4 model while remapping tensors as needed for the optimized architecture.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import nn
from .modeling_gemma4 import (
Gemma4ForCausalLM,
Gemma4TextDecoderLayer,
Gemma4TextExperts,
Gemma4TextModel,
Gemma4TextRouter,
)
@dataclass(frozen=True)
class Gemma4OptimizationLoadResult:
    """Immutable report of which tensors a weight-loading pass used.

    Attributes:
        loaded_keys: Names of tensors that were loaded into the target module.
        skipped_keys: Names of base-model tensors that were not loaded.
    """

    loaded_keys: tuple[str, ...]
    skipped_keys: tuple[str, ...]

    @property
    def loaded_count(self) -> int:
        """How many tensor names are recorded in ``loaded_keys``."""
        keys = self.loaded_keys
        return len(keys)

    @property
    def skipped_count(self) -> int:
        """How many tensor names are recorded in ``skipped_keys``."""
        keys = self.skipped_keys
        return len(keys)
class Gemma4OptimizationWeightsMixin:
    """
    Mixin for modules that need a custom remap step when loading weights
    from an original Gemma4 model into an optimized variant.

    Subclasses override ``_remap_optimization_tensors`` to translate tensors
    whose layout differs between the base and the optimized architecture.
    Every other base tensor is copied verbatim when a target tensor with the
    same name and shape exists.
    """

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Hook returning tensors that must be remapped before loading.

        Args:
            base_state_dict: State dict of the same-named submodule in the base
                model, with keys relative to that submodule.
            target_state_dict: State dict of this module, with keys relative to it.

        Returns:
            Mapping from relative tensor name to the value to load. The default
            implementation remaps nothing.
        """
        return {}

    def load_optimization_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """
        Load weights from ``base_model`` into this module.

        For every submodule (including this one) that is a
        ``Gemma4OptimizationWeightsMixin`` with an overridden remap hook, the
        hook is given the state dicts of the matching base submodule and of the
        submodule itself. Its results take precedence. All remaining base
        tensors whose name and shape match a target tensor are copied directly.
        Values are cast to the target tensor's device and dtype. Target tensors
        that receive nothing keep their current values, because
        ``load_state_dict`` is called with ``strict=False``.

        Args:
            base_model: The original (unoptimized) model to read weights from.

        Returns:
            A ``Gemma4OptimizationLoadResult`` with the sorted loaded keys and the
            sorted base keys that were not used.

        Raises:
            TypeError: If this object or ``base_model`` is not an ``nn.Module``.
            KeyError: If a remap hook returns a tensor name that this module
                does not have.
        """
        if not isinstance(self, nn.Module):
            raise TypeError("Gemma4OptimizationWeightsMixin can only be used with nn.Module subclasses.")
        if not isinstance(base_model, nn.Module):
            raise TypeError("base_model must be an nn.Module.")

        # Build each full state dict exactly once; both are reused below.
        target_state_dict = self.state_dict()
        base_state_dict = base_model.state_dict()
        default_hook = Gemma4OptimizationWeightsMixin._remap_optimization_tensors
        loaded: dict[str, torch.Tensor] = {}

        for module_name, module in self.named_modules():
            if not isinstance(module, Gemma4OptimizationWeightsMixin):
                continue
            # The default hook always returns {}. Skipping it avoids building two
            # throwaway (possibly full-model) state dicts for wrapper modules.
            if getattr(module._remap_optimization_tensors, "__func__", None) is default_hook:
                continue
            try:
                base_module = base_model if module_name == "" else base_model.get_submodule(module_name)
            except AttributeError:
                # The base model has no counterpart for this submodule.
                continue
            remapped_tensors = module._remap_optimization_tensors(base_module.state_dict(), module.state_dict())
            prefix = f"{module_name}." if module_name else ""
            for tensor_name, tensor_value in remapped_tensors.items():
                full_name = prefix + tensor_name
                target_tensor = target_state_dict.get(full_name)
                if target_tensor is None:
                    raise KeyError(
                        f"{type(module).__name__}._remap_optimization_tensors returned "
                        f"{tensor_name!r}, but the target has no tensor named {full_name!r}."
                    )
                loaded[full_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        # Copy every remaining base tensor whose name and shape line up.
        for tensor_name, tensor_value in base_state_dict.items():
            if tensor_name in loaded:
                continue
            target_tensor = target_state_dict.get(tensor_name)
            if target_tensor is None or target_tensor.shape != tensor_value.shape:
                continue
            loaded[tensor_name] = tensor_value.to(device=target_tensor.device, dtype=target_tensor.dtype)

        self.load_state_dict(loaded, strict=False)
        skipped = tuple(sorted(base_state_dict.keys() - loaded.keys()))
        return Gemma4OptimizationLoadResult(tuple(sorted(loaded)), skipped)

    def _load_weights(self, base_model: nn.Module) -> Gemma4OptimizationLoadResult:
        """Alias of ``load_optimization_weights``."""
        return self.load_optimization_weights(base_model)
def get_total_optimized_experts(num_experts: int, add_zero_compute_expert: bool) -> int:
    """
    Return the total expert count, including the optional zero-compute expert.

    Args:
        num_experts: Number of real (compute) experts.
        add_zero_compute_expert: Whether a single extra zero-compute expert is
            appended. Evaluated by truthiness, so ``None`` (an explicitly unset
            config field) counts as ``False``.

    Returns:
        ``num_experts + 1`` if the flag is truthy, otherwise ``num_experts``.
    """
    # Truthiness rather than int(): int(None) raises TypeError and int(2) would
    # add two slots. The router remap only supports exactly one extra expert.
    return num_experts + 1 if add_zero_compute_expert else num_experts
class OptimizedGemma4TextExperts(Gemma4TextExperts):
    """
    Gemma4 MoE experts that also accept routes to a zero-compute expert.

    Indices in ``[num_experts, total_num_experts)`` are valid routing targets
    but perform no computation, so those slots add nothing to the output.
    """

    def __init__(self, config):
        super().__init__(config)
        wants_zero_expert = getattr(config, "add_zero_compute_expert", False)
        self.total_num_experts = get_total_optimized_experts(self.num_experts, wants_zero_expert)

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Mix per-token expert outputs weighted by the router's top-k weights.

        Args:
            hidden_states: Token features; dim 0 indexes tokens.
            top_k_index: Rank-2 expert indices, (tokens, top_k).
            top_k_weights: Routing weights indexed as [token, top_k slot].

        Returns:
            A tensor shaped like ``hidden_states`` holding the weighted sum of
            the selected compute experts' outputs.
        """
        output = torch.zeros_like(hidden_states)

        with torch.no_grad():
            # one_hot gives (tokens, top_k, E); permute to (E, top_k, tokens).
            routing_mask = nn.functional.one_hot(top_k_index, num_classes=self.total_num_experts).permute(2, 1, 0)
            used_experts = torch.greater(routing_mask.sum(dim=(-1, -2)), 0).nonzero()

        for row in used_experts:
            idx = row[0]
            # The zero-compute expert (idx >= num_experts) contributes nothing.
            if idx >= self.num_experts:
                continue
            slot, tokens = torch.where(routing_mask[idx])
            gate, up = nn.functional.linear(hidden_states[tokens], self.gate_up_proj[idx]).chunk(2, dim=-1)
            expert_out = nn.functional.linear(self.act_fn(gate) * up, self.down_proj[idx])
            expert_out = expert_out * top_k_weights[tokens, slot, None]
            output.index_add_(0, tokens, expert_out.to(output.dtype))

        return output
class OptimizedGemma4TextRouter(Gemma4OptimizationWeightsMixin, Gemma4TextRouter):
    """
    Gemma4 MoE router sized for ``num_experts`` plus an optional zero-compute expert.

    When weights come from a base router with exactly one fewer expert, the
    base values fill the leading entries. The extra expert gets a zero
    projection row and a per-expert scale of 1.0.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_experts = config.num_experts
        self.total_num_experts = get_total_optimized_experts(
            self.num_experts, getattr(config, "add_zero_compute_expert", False)
        )
        # Replace the parameters built by the base router so that their leading
        # dimension also covers the zero-compute expert.
        self.proj = nn.Linear(config.hidden_size, self.total_num_experts, bias=False)
        self.per_expert_scale = nn.Parameter(torch.ones(self.total_num_experts))

    @staticmethod
    def _expand_leading_dim(base: torch.Tensor, target: torch.Tensor, fill_value: float) -> torch.Tensor:
        """Return a tensor like ``target`` with ``base`` in its leading rows and ``fill_value`` elsewhere."""
        # full_like allocates directly, avoiding a clone() that would copy
        # target's data only for it to be overwritten.
        expanded = torch.full_like(target, fill_value)
        expanded[: base.shape[0]].copy_(base)  # copy_ casts base to target's dtype/device
        return expanded

    def _remap_optimization_tensors(
        self,
        base_state_dict: dict[str, torch.Tensor],
        target_state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Expand base router tensors by one expert.

        Only applies when the target has exactly one more expert than the base.
        Otherwise nothing is remapped, and matching shapes are copied by the
        generic loader.
        """
        remapped: dict[str, torch.Tensor] = {}

        base_proj = base_state_dict.get("proj.weight")
        target_proj = target_state_dict.get("proj.weight")
        if (
            base_proj is not None
            and target_proj is not None
            and target_proj.shape[1] == base_proj.shape[1]
            and target_proj.shape[0] == base_proj.shape[0] + 1
        ):
            # proj has no bias, so a zero row gives the extra expert a constant output of 0.
            remapped["proj.weight"] = self._expand_leading_dim(base_proj, target_proj, 0.0)

        base_per_expert_scale = base_state_dict.get("per_expert_scale")
        target_per_expert_scale = target_state_dict.get("per_expert_scale")
        if (
            base_per_expert_scale is not None
            and target_per_expert_scale is not None
            and target_per_expert_scale.shape[0] == base_per_expert_scale.shape[0] + 1
        ):
            remapped["per_expert_scale"] = self._expand_leading_dim(
                base_per_expert_scale, target_per_expert_scale, 1.0
            )

        return remapped
class OptimizedGemma4TextDecoderLayer(Gemma4TextDecoderLayer):
    """Gemma4 decoder layer pointing its ``router_class``/``experts_class`` at the optimized MoE parts.

    NOTE(review): assumes the base layer instantiates its router and experts
    from these class attributes; confirm in modeling_gemma4.
    """

    experts_class = OptimizedGemma4TextExperts
    router_class = OptimizedGemma4TextRouter
class OptimizedGemma4TextModel(Gemma4OptimizationWeightsMixin, Gemma4TextModel):
    """Gemma4 text model with ``load_optimization_weights`` support.

    Its ``decoder_layer_class`` points at the optimized decoder layer.
    NOTE(review): assumes the base model builds its layers from that
    attribute; confirm in modeling_gemma4.
    """

    decoder_layer_class = OptimizedGemma4TextDecoderLayer
class OptimizedGemma4ForCausalLM(Gemma4OptimizationWeightsMixin, Gemma4ForCausalLM):
    """Gemma4 causal LM with ``load_optimization_weights`` support.

    Its ``text_model_class`` points at the optimized text model.
    NOTE(review): assumes the base class builds its text model from that
    attribute; confirm in modeling_gemma4.
    """

    text_model_class = OptimizedGemma4TextModel
__all__ = [
"Gemma4OptimizationLoadResult",
"Gemma4OptimizationWeightsMixin",
"OptimizedGemma4ForCausalLM",
"OptimizedGemma4TextDecoderLayer",
"OptimizedGemma4TextExperts",
"OptimizedGemma4TextModel",
"OptimizedGemma4TextRouter",
"get_total_optimized_experts",
]