# Source: tiny-random-gemma4-dense / modeling_gemma4.py
# Uploaded by katuni4ka ("Upload 10 files"), commit 9c3dbaf (verified).
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted for transformers 4.57.1 (trust_remote_code=True usage).
import math
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cached_property
from typing import Optional
import torch
import torch.nn.init as init
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.masking_utils import (
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
# Local import for trust_remote_code usage
from configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
# --- 4.57.1 compat stubs ---
@contextmanager
def maybe_autocast(device_type=None, enabled=True, dtype=None):
    """Enter, or explicitly disable, ``torch.autocast`` for ``device_type``.

    Compat stand-in for the transformers 5.x helper of the same name.

    Args:
        device_type: Device type string forwarded to ``torch.autocast`` (e.g. ``"cpu"``, ``"cuda"``).
            When ``None`` the context manager does nothing.
        enabled: ``True`` enables autocast. ``False`` *disables* any autocast region that is active in
            an enclosing scope, so the wrapped code runs in the tensors' own dtypes.
        dtype: Optional autocast dtype; ``None`` uses torch's per-device default.
    """
    if device_type is None:
        yield
        return
    # Always enter torch.autocast and forward `enabled`. Callers use `enabled=False` to force
    # full-precision math even inside an outer autocast region. Merely yielding (the previous
    # behaviour) left that outer region active.
    with torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled):
        yield
def merge_with_config_defaults(kwargs, config=None, *keys):
    """No-op stand-in for the transformers 5.x kwargs/config merger.

    Args:
        kwargs: Keyword-argument mapping that is handed straight back.
        config: Accepted for signature compatibility; ignored.
        *keys: Accepted for signature compatibility; ignored.

    Returns:
        The very same ``kwargs`` object (not a copy).
    """
    del config, keys  # intentionally unused in this compat shim
    return kwargs
def use_experts_implementation(cls):
    """Identity class decorator (compat shim): hands ``cls`` back unchanged."""
    decorated = cls
    return decorated
def use_kernelized_func(fn):
    """Compat shim for a decorator factory that would attach a kernelized ``fn`` to a class.

    ``fn`` is ignored here; the returned decorator returns the decorated class unchanged.
    """

    def _passthrough(target_cls):
        return target_cls

    return _passthrough
def torch_compilable_check(condition, message=""):
    """Raise ``ValueError(message)`` when ``condition`` is falsy; otherwise return ``None``."""
    if condition:
        return
    raise ValueError(message)
class OutputRecorder:
    """Inert placeholder for transformers' ``OutputRecorder``; accepts and discards any arguments."""

    def __init__(self, *args, **kwargs):
        del args, kwargs  # nothing is recorded in this compat shim
def capture_outputs(fn):
    """Identity decorator (compat shim): returns ``fn`` itself without wrapping it."""
    wrapped = fn
    return wrapped
def create_bidirectional_mask(attention_mask=None, dtype=None, device=None, config=None, inputs_embeds=None,
                              and_mask_function=None, **kwargs):
    """Build an additive, non-causal (bidirectional) attention mask.

    Compat stand-in for the transformers 5.x helper. ``config``, ``and_mask_function`` and extra
    ``kwargs`` are accepted for signature compatibility but are not used.

    Args:
        attention_mask: Optional mask.
            - A 2D ``[batch, seq_len]`` tensor marks valid keys with non-zero / ``True`` and padded keys
              with ``0`` / ``False``.
            - A 4D tensor is treated as an already-expanded mask. A boolean one (``True`` = attend) is
              converted to additive form; any other dtype is returned cast to ``dtype``/``device``.
        dtype: Dtype of the returned mask. Defaults to ``inputs_embeds.dtype``, else ``torch.float32``.
        device: Device of the returned mask. Defaults to the device of ``inputs_embeds``, else that of
            ``attention_mask``.
        inputs_embeds: Optional tensor whose first two dims give ``(batch, seq_len)``.

    Returns:
        Tensor of shape ``[batch, 1, seq_len, seq_len]`` (or the 4D input's shape). It holds ``0`` where
        attention is allowed and ``torch.finfo(dtype).min`` where a key must be ignored.

    Raises:
        ValueError: If neither ``attention_mask`` nor ``inputs_embeds`` is provided.
    """
    if attention_mask is None and inputs_embeds is None:
        raise ValueError(
            "create_bidirectional_mask needs `attention_mask` or `inputs_embeds` to infer the mask shape."
        )

    if dtype is None:
        dtype = inputs_embeds.dtype if inputs_embeds is not None else torch.float32
    if device is None:
        device = (inputs_embeds if inputs_embeds is not None else attention_mask).device
    min_value = torch.finfo(dtype).min

    if attention_mask is not None and attention_mask.dim() == 4:
        # Already expanded. Upstream passes 4D masks through untouched; previously this function
        # silently discarded them and returned an all-zero (attend-everywhere) mask.
        if attention_mask.dtype == torch.bool:
            additive = torch.zeros(attention_mask.shape, dtype=dtype, device=device)
            return additive.masked_fill(~attention_mask.to(device), min_value)
        return attention_mask.to(dtype=dtype, device=device)

    if inputs_embeds is not None:
        bsz, seq_len = inputs_embeds.shape[:2]
    else:
        bsz, seq_len = attention_mask.shape[:2]

    # Full bidirectional mask: all zeros means no position is masked.
    mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=device)
    if attention_mask is not None and attention_mask.dim() == 2:
        # 1/True = valid, 0/False = padding; mask out padded *keys* for every query.
        key_is_padding = (attention_mask == 0).to(device)  # [bsz, seq_len]
        mask = mask.masked_fill(key_is_padding[:, None, None, :], min_value)
    return mask
def _compute_proportional_rope_parameters(
config,
device=None,
seq_len=None,
layer_type=None,
head_dim_key="head_dim",
):
"""
Proportional RoPE: partial rotary applied to head_dim,
rest filled with zeros. Ported from transformers 5.x.
"""
if layer_type is not None:
rope_params_dict = config.rope_parameters[layer_type]
else:
rope_params_dict = config.rope_parameters
head_dim = getattr(config, head_dim_key, None) or config.hidden_size // config.num_attention_heads
base = rope_params_dict["rope_theta"]
factor = rope_params_dict.get("factor", 1.0)
rope_proportion = rope_params_dict.get("partial_rotary_factor", 1.0)
attention_factor = 1.0
rope_angles = int(rope_proportion * head_dim // 2)
inv_freq_rotated = 1.0 / (
base
** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / head_dim)
)
nope_angles = head_dim // 2 - rope_angles
if nope_angles > 0:
inv_freq = torch.cat(
(inv_freq_rotated, torch.zeros(nope_angles, dtype=torch.float32, device=device)),
dim=0,
)
else:
inv_freq = inv_freq_rotated
inv_freq = inv_freq / factor
return inv_freq, attention_factor
# Register the proportional RoPE initializer only when this transformers version has none,
# so a native implementation (when present) always wins.
ROPE_INIT_FUNCTIONS.setdefault("proportional", _compute_proportional_rope_parameters)
# --- Output dataclasses ---
@dataclass
class Gemma4ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    Base class for Gemma4 outputs: the fields of `BaseModelOutputWithPast` plus optional encoder states.

    image_hidden_states (`torch.FloatTensor`, *optional*):
        Hidden states produced by the vision encoder; `None` when not set.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        Hidden states produced by the audio encoder; `None` when not set.
    """
    # Field order is part of the dataclass interface (positional construction / tuple conversion).
    image_hidden_states: Optional[torch.FloatTensor] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma4CausalLMOutputWithPast(ModelOutput):
    r"""
    Causal LM output for Gemma4. Every field defaults to `None`.

    loss (`torch.FloatTensor`, *optional*):
        Loss value, if one was computed (presumably the language-modeling loss; confirm with the caller).
    logits (`torch.FloatTensor`, *optional*):
        Output logits of the language-model head.
    past_key_values (`Cache`, *optional*):
        Key/value cache that can be reused for subsequent decoding steps.
    hidden_states (`tuple`, *optional*):
        Per-layer hidden states, when requested.
    attentions (`tuple`, *optional*):
        Per-layer attention weights, when requested.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Hidden states from the vision encoder.
    audio_hidden_states (`torch.FloatTensor`, *optional*):
        Hidden states from the audio encoder.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
    r"""
    Audio encoder output: the fields of `BaseModelOutputWithPooling` plus the output-resolution mask.

    attention_mask (`torch.BoolTensor`, *optional*):
        True for valid positions, False for padding.
    """
    attention_mask: Optional[torch.BoolTensor] = None
# --- Modules ---
class Gemma4ClippableLinear(nn.Module):
    """Bias-free linear projection with optional clamping of its input and output.

    When ``config.use_clipped_linears`` is true, four scalar buffers are registered:
    ``input_min``, ``input_max``, ``output_min`` and ``output_max``. They start at -inf/+inf, so nothing
    is clipped until finite bounds are loaded; the buffers are persistent, so they live in the state
    dict. Otherwise the module behaves exactly like ``nn.Linear(in_features, out_features, bias=False)``.
    """

    def __init__(
        self,
        config,
        in_features: int,
        out_features: int,
    ) -> None:
        super().__init__()
        self.use_clipped_linears = config.use_clipped_linears
        self.linear = nn.Linear(in_features, out_features, bias=False)
        if not self.use_clipped_linears:
            return
        for buffer_name, bound in (
            ("input_min", -float("inf")),
            ("input_max", float("inf")),
            ("output_min", -float("inf")),
            ("output_max", float("inf")),
        ):
            self.register_buffer(buffer_name, torch.tensor(bound))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        clip = self.use_clipped_linears
        if clip:
            hidden_states = hidden_states.clamp(self.input_min, self.input_max)
        projected = self.linear(hidden_states)
        return projected.clamp(self.output_min, self.output_max) if clip else projected
class Gemma4RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension, computed in float32.

    Args:
        dim: Size of the last dimension; length of the learnable scale.
        eps: Added to the mean square before taking its inverse square root.
        with_scale: If true, multiply by a learnable ``weight`` initialised to ones. If false, the
            module has no parameters at all.

    The output is cast back to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
        super().__init__()
        self.eps = eps
        self.with_scale = with_scale
        if with_scale:
            self.weight = nn.Parameter(torch.ones(dim), requires_grad=True)

    def _norm(self, hidden_states: torch.Tensor):
        mean_sq = hidden_states.pow(2).mean(dim=-1, keepdim=True)
        return hidden_states * (mean_sq + self.eps).pow(-0.5)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self._norm(hidden_states.float())
        if self.with_scale:
            out = out * self.weight.float()
        return out.type_as(hidden_states)
class Gemma4AudioRelPositionalEncoding(nn.Module):
    """Sinusoidal relative positional encoding for the audio encoder.

    Produces one ``[sin | cos]`` embedding per relative offset covered by the chunked local attention
    window. Offsets run from the largest past offset (``attention_context_left - 1``) down to the
    largest future offset (``-attention_context_right``). This ordering matches the column layout that
    ``Gemma4AudioAttention._rel_shift`` realigns.
    """

    inv_timescales: torch.Tensor

    def __init__(self, config: "Gemma4AudioConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.context_size = (
            config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
        )
        # Relative offset range derived from the config. It was previously hard-coded as
        # arange(12, -1, -1), which is only correct for context_left=13, context_right=0.
        self.max_past_horizon = max(0, config.attention_context_left - 1)
        self.max_future_horizon = max(0, config.attention_context_right)
        min_timescale = 1.0
        max_timescale = 10000.0
        num_timescales = self.hidden_size // 2
        log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
        self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False)

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return relative-position embeddings.

        Args:
            hidden_states: Used only for its device and dtype.

        Returns:
            Tensor of shape ``[1, max_past_horizon + max_future_horizon + 1, 2 * (hidden_size // 2)]``
            in ``hidden_states.dtype``.
        """
        position_ids = torch.arange(
            self.max_past_horizon, -self.max_future_horizon - 1, -1, device=hidden_states.device
        )
        position_ids = position_ids[..., None]
        scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device)
        pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
        return pos_embed.to(dtype=hidden_states.dtype)
class Gemma4AudioAttention(nn.Module):
    """Chunked local attention with relative position bias.

    The sequence is split into blocks of ``attention_chunk_size`` queries. Each query block attends to a
    key/value window of ``context_size = chunk_size + (attention_context_left - 1) + attention_context_right``
    positions around it. Attention logits are computed as follows:
      - a content term (query . key) plus a relative-position term (query . projected position
        embedding, realigned by ``_rel_shift``);
      - soft-capped as ``tanh(x / cap) * cap``;
      - optionally masked;
      - softmaxed in float32.
    """
    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_logits_soft_cap = config.attention_logit_cap
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        # head_dim**-0.5 divided by log(2) == softplus(0). Since per_dim_scale starts at zero, the
        # softplus(per_dim_scale) factor applied in forward makes the initial effective scale
        # exactly head_dim**-0.5.
        self.q_scale = (self.head_dim**-0.5) / math.log(2)
        # log(1 + e) / log(2) == softplus(1) / softplus(0): a fixed key scale.
        self.k_scale = math.log(1 + math.e) / math.log(2)
        self.chunk_size = config.attention_chunk_size
        self.max_past_horizon = config.attention_context_left - 1
        self.max_future_horizon = config.attention_context_right
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)
        # Projects relative-position embeddings into per-head key space for the position-bias term.
        self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        # Learnable per-channel query scale, passed through softplus in forward.
        self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim))
        self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False)
    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pad dim 1 up to a multiple of chunk_size and reshape
        ``[B, S, H, D] -> [B, num_blocks, chunk_size, H, D]``."""
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size
        pad = num_blocks * self.chunk_size - seq_len
        # The pad tuple covers dims (-1, -2, -3): only dim -3 (the sequence axis) is right-padded.
        hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad))
        return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous()
    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Gather the key/value window for every query block.

        ``[B, S, H, D] -> [B, num_blocks, context_size, H, D]``. Window ``b`` covers original positions
        ``[b * chunk_size - max_past_horizon, b * chunk_size + chunk_size + max_future_horizon)``, with
        zeros outside the sequence.
        """
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        # Left pad by the past horizon. Right pad by the future horizon plus (chunk_size - 1), so the
        # unfold below yields ceil(S / chunk_size) windows, matching _convert_to_block.
        hidden_states = F.pad(
            hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)
        )
        # unfold appends the window axis last: [B, num_blocks, H, D, context_size].
        hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size)
        hidden_states = torch.movedim(hidden_states, -1, 2)
        return hidden_states.contiguous()
    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Realign relative-position logits from "offset" columns to "key position" columns.

        Input ``[B, H, num_blocks, block_size, position_length]``; output
        ``[B, H, num_blocks, block_size, context_size]``. The pad / flatten / truncate / reshape trick
        gives ``out[..., i, j] == x[..., i, j - i]`` for ``0 <= j - i < position_length``. Other
        entries come from the zero padding, assuming
        ``position_length <= context_size + 1 - block_size`` (TODO confirm for the configured sizes).
        """
        batch_size, num_heads, num_blocks, block_size, position_length = x.shape
        context_size = self.context_size
        x = F.pad(x, (0, context_size + 1 - position_length))
        x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1))
        x = x[..., : block_size * context_size]
        return x.view(batch_size, num_heads, num_blocks, block_size, context_size)
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask=None,
    ):
        """Run chunked local attention.

        Args:
            hidden_states: ``[batch, seq_len, hidden_size]``.
            position_embeddings: Relative-position embeddings. They are projected and viewed as
                ``[num_offsets, num_heads, head_dim]`` (presumably ``[1, num_offsets, hidden_size]``
                — confirm with the caller).
            attention_mask: Optional boolean mask (True = keep). It must broadcast to
                ``[batch, num_heads, num_blocks, chunk_size, context_size]`` (assumed — TODO confirm).

        Returns:
            ``(attn_output [batch, seq_len, hidden_size], attn_weights)``. ``attn_weights`` are the
            post-softmax probabilities, shaped ``[batch, num_heads, num_blocks, chunk_size, context_size]``.
        """
        batch_size, seq_length, _ = hidden_states.shape
        hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim)
        # All attention math runs in float32.
        query_states = self.q_proj(hidden_states).float().view(hidden_shape)
        key_states = self.k_proj(hidden_states).float().view(hidden_shape)
        value_states = self.v_proj(hidden_states).float().view(hidden_shape)
        query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale)
        key_states = key_states * self.k_scale
        query_states = self._convert_to_block(query_states)  # [B, nb, chunk, H, D]
        key_states = self._extract_block_context(key_states)  # [B, nb, ctx, H, D]
        value_states = self._extract_block_context(value_states)  # [B, nb, ctx, H, D]
        num_blocks = query_states.shape[1]
        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim)  # [P, H, D]
        relative_key_states = relative_key_states.to(dtype=query_states.dtype)
        queries = query_states.permute(0, 3, 1, 2, 4)  # [B, H, nb, chunk, D]
        # Content term: [B, H, nb, chunk, ctx].
        matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2)
        # Position term: every query against every relative offset, then realigned to key columns.
        queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim)
        matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0)  # [B, H, nb * chunk, P]
        matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1)
        matrix_bd = self._rel_shift(matrix_bd)  # [B, H, nb, chunk, ctx]
        attn_weights = matrix_ac + matrix_bd
        # Soft-cap logits to (-softcap, softcap).
        attn_weights = attn_weights / self.softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * self.softcap
        if attention_mask is not None:
            attn_weights = attn_weights.masked_fill(
                attention_mask.logical_not(), self.config.attention_invalid_logits_value
            )
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4)  # [B, H, nb, chunk, D]
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1)
        # Drop the block padding added by _convert_to_block.
        attn_output = attn_output[:, :seq_length].contiguous()
        attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype))
        return attn_output, attn_weights
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
def __init__(self, in_channels, out_channels, norm_eps):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=1,
bias=False,
)
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
self.act = nn.ReLU()
def forward(self, hidden_states: torch.Tensor, mask=None):
if mask is not None:
mask = mask.to(device=hidden_states.device)
hidden_states = hidden_states * mask[:, None, :, None]
hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype))
hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
if mask is not None:
mask = mask[:, ::2]
return hidden_states, mask
class Gemma4AudioSubSampleConvProjection(nn.Module):
    """Two stride-2 convolutional subsampling stages followed by a linear projection to ``hidden_size``.

    ``input_features`` gains a singleton channel axis, passes through both stages, and is flattened to
    ``[batch, seq_len, width]`` (with channels as the fastest-varying axis) before projection. The
    subsampled mask is returned alongside.
    """

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__()
        first_channels = config.subsampling_conv_channels[0]
        second_channels = config.subsampling_conv_channels[1]
        self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
            in_channels=1,
            out_channels=first_channels,
            norm_eps=config.rms_norm_eps,
        )
        self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
            in_channels=first_channels,
            out_channels=second_channels,
            norm_eps=config.rms_norm_eps,
        )
        # NOTE(review): the flattened width assumes the last axis after both stride-2 stages equals
        # first_channels // 4. That holds only if the input feature width equals
        # subsampling_conv_channels[0]; confirm against the audio config / checkpoint.
        projected_width = (first_channels // 4) * second_channels
        self.input_proj_linear = nn.Linear(projected_width, config.hidden_size, bias=False)

    def forward(self, input_features: torch.Tensor, input_features_mask=None):
        features = input_features.unsqueeze(1)
        mask = input_features_mask
        for stage in (self.layer0, self.layer1):
            features, mask = stage(features, mask)
        batch_size, _, seq_len, _ = features.shape
        flattened = features.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
        return self.input_proj_linear(flattened), mask
class Gemma4AudioFeedForward(nn.Module):
    """Residual feed-forward sub-block of the audio encoder layer.

    Computes ``x + residual_weight * post_norm(clip(W2 @ act(W1 @ pre_norm(clip(x)))))``, where:
      - ``clip`` clamps to ``+/- gradient_clipping``, capped at the largest finite value of
        ``ffw_layer_1``'s weight dtype;
      - the hidden width is ``4 * hidden_size``.
    """

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.ffw_layer_1 = Gemma4ClippableLinear(config, width, width * 4)
        self.ffw_layer_2 = Gemma4ClippableLinear(config, width * 4, width)
        self.pre_layer_norm = Gemma4RMSNorm(width)
        self.post_layer_norm = Gemma4RMSNorm(width)
        self.act_fn = ACT2FN[config.hidden_act]
        self.gradient_clipping = config.gradient_clipping
        self.post_layer_scale = config.residual_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        limit = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)
        residual = hidden_states
        out = self.pre_layer_norm(hidden_states.clamp(-limit, limit))
        out = self.ffw_layer_2(self.act_fn(self.ffw_layer_1(out)))
        out = self.post_layer_norm(out.clamp(-limit, limit))
        out *= self.post_layer_scale
        out += residual
        return out
class Gemma4AudioCausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that left-pads its input so each output step only sees current and past inputs."""

    @cached_property
    def left_pad(self):
        """Amount of left zero-padding: dilated kernel span minus stride."""
        (kernel,), (dilation,), (stride,) = self.kernel_size, self.dilation, self.stride
        return (kernel - 1) * dilation + 1 - stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        padded = nn.functional.pad(x, (self.left_pad, 0))
        return super().forward(padded)
class Gemma4AudioLightConv1d(nn.Module):
    """Residual lightweight-convolution sub-block of the audio encoder layer.

    Pipeline:
      1. pre-norm;
      2. linear to ``2 * hidden_size``;
      3. GLU;
      4. causal depthwise conv over the sequence axis;
      5. clamp;
      6. norm;
      7. activation;
      8. linear;
      9. add the residual.
    """

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.linear_start = Gemma4ClippableLinear(config, width, width * 2)
        self.linear_end = Gemma4ClippableLinear(config, width, width)
        self.depthwise_conv1d = Gemma4AudioCausalConv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=config.conv_kernel_size,
            groups=width,
            bias=False,
        )
        self.pre_layer_norm = Gemma4RMSNorm(width, eps=config.rms_norm_eps, with_scale=True)
        self.conv_norm = Gemma4RMSNorm(width, eps=config.rms_norm_eps, with_scale=True)
        self.act_fn = ACT2FN[config.hidden_act]
        self.gradient_clipping = config.gradient_clipping

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        gated = nn.functional.glu(self.linear_start(self.pre_layer_norm(hidden_states)), dim=-1)
        # Conv1d wants channels in dim 1, so swap dims 1 and 2 around the depthwise conv.
        conv_out = self.depthwise_conv1d(gated.transpose(1, 2)).transpose(1, 2)
        limit = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max)
        out = self.linear_end(self.act_fn(self.conv_norm(conv_out.clamp(-limit, limit))))
        out += residual
        return out
class Gemma4AudioLayer(nn.Module):
    """Conformer-style audio encoder layer.

    Sequence of sub-blocks:
      1. feed-forward;
      2. clipped, sandwich-normed self-attention with a residual add;
      3. light conv;
      4. feed-forward;
      5. clamp followed by an output norm.
    """

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.feed_forward1 = Gemma4AudioFeedForward(config)
        self.feed_forward2 = Gemma4AudioFeedForward(config)
        self.self_attn = Gemma4AudioAttention(config, layer_idx)
        self.lconv1d = Gemma4AudioLightConv1d(config)
        self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
        self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
        self.norm_out = Gemma4RMSNorm(config.hidden_size)
        self.gradient_clipping = config.gradient_clipping

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        position_embeddings: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        limit = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)
        hidden_states = self.feed_forward1(hidden_states)
        residual = hidden_states
        attn_in = self.norm_pre_attn(hidden_states.clamp(-limit, limit))
        attn_out, _ = self.self_attn(
            hidden_states=attn_in,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
        )
        hidden_states = self.norm_post_attn(attn_out.clamp(-limit, limit))
        hidden_states += residual
        hidden_states = self.feed_forward2(self.lconv1d(hidden_states))
        return self.norm_out(hidden_states.clamp(-limit, limit))
# ---- Vision Encoder Layers ----
class Gemma4VisionPatchEmbedder(nn.Module):
    """Project flattened image patches to ``hidden_size`` and add learned 2D position embeddings.

    Position embeddings come from a ``[2, position_embedding_size, hidden_size]`` table:
      - each of the two position axes indexes its own slice via a one-hot matmul;
      - the two per-axis embeddings are summed;
      - patches flagged in ``padding_positions`` get a zero position embedding.
    """

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.position_embedding_size = config.position_embedding_size
        self.input_proj = nn.Linear(3 * self.patch_size**2, self.hidden_size, bias=False)
        self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))

    def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
        # Negative ids are clamped to 0 so one_hot is valid; padding rows are zeroed afterwards.
        indices = pixel_position_ids.clamp(min=0)
        one_hot = F.one_hot(indices, num_classes=self.position_embedding_size)
        # Move the axis dimension ahead of the patch dimension so it lines up with the table's leading dim.
        one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
        per_axis = one_hot @ self.position_embedding_table
        summed = per_axis.sum(dim=1)
        return torch.where(padding_positions.unsqueeze(-1), 0.0, summed)

    def forward(
        self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
    ) -> torch.Tensor:
        # Affine rescale x -> 2 * (x - 0.5) (maps [0, 1] to [-1, 1] if inputs lie in [0, 1]).
        rescaled = 2 * (pixel_values - 0.5)
        patch_embeddings = self.input_proj(rescaled.to(self.input_proj.weight.dtype))
        return patch_embeddings + self._position_embeddings(pixel_position_ids, padding_positions)
class Gemma4VisionPooler(nn.Module):
    """Scaling and optional spatial average-pooling for vision encodings.

    Padding patches are zeroed. If a smaller ``output_length`` is requested, patches are averaged over
    non-overlapping ``k x k`` position windows. The result is multiplied by ``sqrt(hidden_size)``.
    """

    def __init__(self, config: "Gemma4VisionConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.root_hidden_size = self.hidden_size**0.5

    def _avg_pool_by_positions(self, hidden_states, pixel_position_ids, length):
        """Average ``hidden_states`` into ``length`` cells using their 2D positions.

        ``k = sqrt(seq_len / length)`` must be an integer, and ``k**2 * length`` must equal ``seq_len``.
        Each patch at ``(x, y)`` is assigned to cell ``x // k + (max_x // k) * (y // k)`` with weight
        ``1 / k**2``.

        Returns:
            ``(pooled [batch, length, hidden], valid_mask [batch, length])``. A cell is marked valid
            when at least one patch maps to it.

        Raises:
            ValueError: If the sequence cannot be pooled to ``length`` with a square kernel.
        """
        input_seq_len = hidden_states.shape[1]
        k = int((input_seq_len // length) ** 0.5)
        k_squared = k**2
        if k_squared * length != input_seq_len:
            raise ValueError(
                f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
            )
        clamped_positions = pixel_position_ids.clamp(min=0)
        max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
        kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
        kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
        weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
        output = weights.transpose(1, 2) @ hidden_states.float()
        mask = torch.logical_not((weights == 0).all(dim=1))
        return output.to(hidden_states.dtype), mask

    def forward(self, hidden_states, pixel_position_ids, padding_positions, output_length=None):
        """Zero padding patches, optionally pool to ``output_length`` tokens, and scale.

        Args:
            hidden_states: ``[batch, num_patches, hidden]`` encoder outputs.
            pixel_position_ids: Per-patch 2D positions; negative values are clamped to 0 when pooling.
            padding_positions: Boolean ``[batch, num_patches]``, True for padding patches.
            output_length: Number of output tokens. ``None`` (the default) means "no pooling".
                Previously ``None`` crashed with a TypeError.

        Returns:
            ``(hidden_states, mask)``.
            NOTE(review): without pooling, ``mask`` is the input ``padding_positions``
            (True = padding). With pooling it is the validity mask from ``_avg_pool_by_positions``
            (True = valid). Confirm which convention callers expect.

        Raises:
            ValueError: If ``output_length`` exceeds the number of patches.
        """
        if output_length is None:
            output_length = hidden_states.shape[1]
        if output_length > hidden_states.shape[1]:
            raise ValueError(
                f"Cannot output more soft tokens (requested {output_length}) than there are patches"
                f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
            )
        hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)
        if hidden_states.shape[1] != output_length:
            hidden_states, padding_positions = self._avg_pool_by_positions(
                hidden_states, pixel_position_ids, output_length
            )
        hidden_states *= self.root_hidden_size
        return hidden_states, padding_positions
class Gemma4VisionMLP(nn.Module):
    """Gated MLP: ``down_proj(act(gate_proj(x)) * up_proj(x))`` with clippable linear layers."""

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
        self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
        self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class Gemma4VisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding for the vision encoder.

    For each of the first two position axes (``position_ids[:, :, 0]`` and ``position_ids[:, :, 1]``),
    a cos/sin table is built from the same ``inv_freq``. The two tables are concatenated on the last
    axis, giving a last dimension of ``4 * len(inv_freq)``.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: Gemma4VisionConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the 2D-aware initializer below; any other type comes from the shared registry.
        rope_init_fn = (
            self.compute_default_rope_parameters
            if self.rope_type == "default"
            else ROPE_INIT_FUNCTIONS[self.rope_type]
        )
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(config=None, device=None, seq_len=None):
        """Standard RoPE frequencies over half the head dimension (one half per spatial axis).

        Returns:
            ``(inv_freq, 1.0)``, where ``inv_freq`` has ``ceil((head_dim // 2) / 2)`` entries.
        """
        theta = config.rope_parameters["rope_theta"]
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        axis_dim = head_dim // 2
        exponents = torch.arange(0, axis_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / axis_dim
        return 1.0 / (theta ** exponents), 1.0

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        if isinstance(x.device.type, str) and x.device.type != "mps":
            autocast_device = x.device.type
        else:
            autocast_device = "cpu"
        cos_tables, sin_tables = [], []
        for axis in range(2):
            axis_positions = position_ids[:, :, axis][:, None, :].float()
            # Force float32 trigonometry regardless of any ambient autocast.
            with maybe_autocast(device_type=autocast_device, enabled=False):
                angles = (freq_col.float() @ axis_positions.float()).transpose(1, 2)
                doubled = torch.cat((angles, angles), dim=-1)
                cos_tables.append(doubled.cos() * self.attention_scaling)
                sin_tables.append(doubled.sin() * self.attention_scaling)
        cos = torch.cat(cos_tables, dim=-1).to(dtype=x.dtype)
        sin = torch.cat(sin_tables, dim=-1).to(dtype=x.dtype)
        return cos, sin
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    ``a`` holds the first ``d // 2`` entries and ``b`` the rest, so for odd ``d`` ``b`` is one longer.
    """
    half = x.shape[-1] // 2
    front, back = x.split([half, x.shape[-1] - half], dim=-1)
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1):
    """Apply rotary position embedding: ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` gain a singleton axis at ``unsqueeze_dim`` so they broadcast against ``x``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated = rotate_half(x)
    return x * cos_b + rotated * sin_b
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    ``[batch, kv_heads, seq, head_dim] -> [batch, kv_heads * n_rep, seq, head_dim]``. Copies of a head
    are adjacent. With ``n_rep == 1`` the input is returned as-is (the same object).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask,
    dropout: float = 0.0,
    scaling=None,
    softcap=None,
    **kwargs,
):
    """Reference (pure PyTorch) scaled dot-product attention with optional logit soft-capping.

    Args:
        module: Provides ``head_dim`` (used when ``scaling`` is None), ``num_key_value_groups`` and
            ``training``.
        query, key, value: Tensors with heads on dim 1 and sequence on dim 2. ``key`` and ``value`` are
            repeated ``num_key_value_groups`` times to match the query heads.
        attention_mask: Optional additive mask, added to the logits.
        dropout: Dropout probability on the attention probabilities (only active in training).
        scaling: Logit scale; defaults to ``head_dim ** -0.5``.
        softcap: If set, logits become ``tanh(logits / softcap) * softcap`` before masking.

    Returns:
        ``(output, probs)``. ``output`` has sequence and heads swapped back (dims 1 and 2);
        ``probs`` are in ``query.dtype``.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)
    logits = torch.matmul(query, keys.transpose(2, 3)) * scale
    if softcap is not None:
        logits = torch.tanh(logits / softcap) * softcap
    if attention_mask is not None:
        logits = logits + attention_mask
    probs = nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return output, probs
def apply_multidimensional_rope(x, cos, sin, position_ids, unsqueeze_dim=2):
    """Apply RoPE independently to equal channel slices, one slice per positional axis.

    The number of axes is ``position_ids.shape[-1]``. Each slice is
    ``2 * (channels // (2 * ndim))`` wide, and ``x``, ``cos`` and ``sin`` are split identically on
    the last dim.

    Raises:
        ValueError: If the per-axis slice width would be zero.
    """
    ndim = position_ids.shape[-1]
    num_input_channels = x.shape[-1]
    chunk = 2 * (num_input_channels // (2 * ndim))
    if chunk <= 0:
        raise ValueError(
            f"Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
            f" {chunk} (num_input_channels={num_input_channels}, ndim={ndim})"
        )
    sizes = [chunk] * ndim
    rotated = [
        apply_rotary_pos_emb(x=x_part, cos=cos_part, sin=sin_part, unsqueeze_dim=unsqueeze_dim)
        for x_part, cos_part, sin_part in zip(
            torch.split(x, sizes, dim=-1),
            torch.split(cos, sizes, dim=-1),
            torch.split(sin, sizes, dim=-1),
        )
    ]
    return torch.cat(rotated, dim=-1)
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4VisionAttention(nn.Module):
    """Bidirectional multi-head self-attention for the vision encoder.

    Per head:
      - queries and keys are RMS-normalised and rotated with multidimensional RoPE;
      - values are RMS-normalised without a learnable scale.
    The logit scale is fixed at 1.0, so no ``1/sqrt(head_dim)`` factor is applied (presumably the q/k
    norms take its place).
    """

    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        # Same fallback as the rotary embedding: a missing *or* None head_dim falls back to
        # hidden_size // num_attention_heads.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = False
        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
        self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
        # Use the resolved self.head_dim (previously config.head_dim), so configs without an explicit
        # head_dim work and the norms always match the per-head size used by the projections.
        self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings=None,
        attention_mask=None,
        position_ids=None,
        **kwargs,
    ):
        """Attend over all positions (non-causal).

        Args:
            hidden_states: ``[..., seq, hidden_size]`` input.
            position_embeddings: ``(cos, sin)`` from the vision rotary embedding.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            position_ids: Multidimensional positions; their last dim sets the number of RoPE axes.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        cos, sin = position_embeddings

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)

        key_states = self.k_proj(hidden_states).view(hidden_shape)
        key_states = self.k_norm(key_states)
        key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
        key_states = key_states.transpose(1, 2)

        value_states = self.v_proj(hidden_states).view(hidden_shape)
        value_states = self.v_norm(value_states)
        value_states = value_states.transpose(1, 2)

        # In 4.57.1 ALL_ATTENTION_FUNCTIONS is dict-like; fall back to the eager implementation
        # for "eager" or any implementation name it does not know.
        attn_impl = self.config._attn_implementation
        if attn_impl != "eager" and attn_impl in ALL_ATTENTION_FUNCTIONS:
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
        else:
            attention_interface = eager_attention_forward

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Gemma4VisionEncoderLayer(GradientCheckpointingLayer):
    """Vision transformer block with sandwich normalisation.

    Both the attention and MLP sub-blocks normalise their input, normalise their output, and then add
    the residual.
    """

    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma4VisionMLP(config)
        self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings=None,
        attention_mask=None,
        position_ids=None,
        **kwargs,
    ):
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        return hidden_states + self.post_feedforward_layernorm(mlp_out)
class Gemma4VisionEncoder(nn.Module):
    """Stack of ``Gemma4VisionEncoderLayer`` blocks.

    All layers share one rotary embedding and one bidirectional attention mask.
    """

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.num_layers = config.num_hidden_layers
        self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
        blocks = [Gemma4VisionEncoderLayer(config=config, layer_idx=idx) for idx in range(self.num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, inputs_embeds, attention_mask, pixel_position_ids=None, **kwargs):
        """Encode patch embeddings.

        Args:
            inputs_embeds: embeddings fed to the first layer.
            attention_mask: padding mask, expanded with ``create_bidirectional_mask``.
            pixel_position_ids: passed to the rotary embedding and to every layer as
                ``position_ids``.
            **kwargs: forwarded to every layer.

        Returns:
            ``BaseModelOutputWithPast`` with only ``last_hidden_state`` set.
        """
        mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        states = inputs_embeds
        rope = self.rotary_emb(states, pixel_position_ids)
        for layer in self.layers[: self.config.num_hidden_layers]:
            states = layer(
                states,
                attention_mask=mask,
                position_embeddings=rope,
                position_ids=pixel_position_ids,
                **kwargs,
            )
        return BaseModelOutputWithPast(last_hidden_state=states)
class Gemma4TextMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Layers that reuse another layer's KV cache get twice the intermediate size when
    ``config.use_double_wide_mlp`` is set. These are the trailing
    ``num_kv_shared_layers`` layers.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        # Derive KV sharing exactly like Gemma4TextAttention does: a config without
        # `num_kv_shared_layers` means "no shared layers" instead of an AttributeError.
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
        first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers
        # Chained comparison: at/after the first shared index AND that index is positive.
        is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Gemma4TextRotaryEmbedding(nn.Module):
    """Rotary position embeddings for the text decoder, with one frequency table per layer type.

    ``__init__`` handles every layer type in ``config.layer_types`` whose
    ``config.rope_parameters`` entry is not None. For each, it registers the
    non-persistent buffers ``{layer_type}_inv_freq`` and ``{layer_type}_original_inv_freq``
    and sets the attribute ``{layer_type}_attention_scaling``. ``forward`` picks the table
    via its ``layer_type`` argument.
    """

    # NOTE(review): annotation only. No buffer named `inv_freq` is registered; the real
    # buffers are `{layer_type}_inv_freq` (see __init__).
    inv_freq: torch.Tensor

    def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
        # NOTE(review): the `layer_type` argument is unused; tables for all layer types are built.
        super().__init__()
        # Sequence-length bookkeeping, presumably read by `dynamic_rope_update` (TODO confirm).
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.layer_types = set(config.layer_types)
        # Per layer type: the init function used and the rope type name.
        self.rope_init_fns = {}
        self.rope_type = {}
        for lt in self.layer_types:
            rope_params = self.config.rope_parameters[lt]
            if rope_params is None:
                # No RoPE configured for this layer type: nothing is registered for it.
                continue
            rope_type_name = rope_params["rope_type"]
            if rope_type_name != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_name]
            else:
                rope_init_fn = self.compute_default_rope_parameters
            self.rope_init_fns[lt] = rope_init_fn
            self.rope_type[lt] = rope_type_name
            rope_init_fn_kwargs = {"device": device, "layer_type": lt}
            # Proportional RoPE on full-attention layers: `head_dim_key` presumably makes the
            # init fn read `config.global_head_dim` instead of `head_dim` (TODO confirm).
            if lt == "full_attention" and rope_type_name == "proportional":
                rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs)
            self.register_buffer(f"{lt}_inv_freq", curr_inv_freq, persistent=False)
            # Separate copy of the initial frequencies (cloned, so later in-place updates of
            # `{lt}_inv_freq` do not touch it).
            self.register_buffer(f"{lt}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{lt}_attention_scaling", curr_attention_scaling)

    @staticmethod
    def compute_default_rope_parameters(config=None, device=None, seq_len=None, layer_type=None, **kwargs):
        """Standard RoPE inverse frequencies ``1 / theta ** (2i / dim)`` for ``i`` in ``[0, dim / 2)``.

        ``theta`` is ``config.rope_parameters[layer_type]["rope_theta"]``. ``dim`` is
        ``config.head_dim``, or ``hidden_size // num_attention_heads`` when ``head_dim`` is
        missing or falsy. ``seq_len`` and extra kwargs (e.g. ``head_dim_key``) are ignored.

        Returns:
            ``(inv_freq, attention_factor)`` with ``attention_factor == 1.0``.
        """
        base = config.rope_parameters[layer_type]["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        attention_factor = 1.0
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids, layer_type=None):
        """Compute ``(cos, sin)`` for ``position_ids`` from the ``layer_type`` frequency table.

        ``position_ids`` is indexed as a 2-D ``(batch, seq)`` tensor. Both outputs have shape
        ``(batch, seq, 2 * len(inv_freq))`` because the frequencies are duplicated along the
        last dim. They are multiplied by the layer type's attention scaling and cast to
        ``x.dtype``; ``x`` only supplies device and dtype.
        """
        inv_freq = getattr(self, f"{layer_type}_inv_freq")
        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # MPS (and non-string device types) fall back to the "cpu" autocast device type.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast is entered with enabled=False (presumably a torch.autocast wrapper), so the
        # float32 operands prepared above are multiplied in float32.
        with maybe_autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, as used by the Gemma 4 text decoder.

    Each layer is sliding-window or full attention, per ``config.layer_types[layer_idx]``.
    Queries, keys and values are RMS-normalized per head (values without a learned scale),
    and rotary embeddings are applied to queries and keys. Layers at index
    ``>= num_hidden_layers - num_kv_shared_layers`` do not project keys/values themselves.
    They reuse the states that the last earlier layer of the same layer type stored in
    ``shared_kv_states``.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.config = config
        self.layer_idx = layer_idx
        self.is_sliding = self.layer_type == "sliding_attention"
        self.sliding_window = config.sliding_window if self.is_sliding else None
        # Full-attention layers use `global_head_dim` when the config sets a truthy value.
        self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
        # With `attention_k_eq_v` on full-attention layers, no v_proj is created (values are
        # taken from the key projection in forward) and the KV head count is
        # `num_global_key_value_heads`.
        self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
        num_key_value_heads = (
            config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads
        )
        self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
        # Logits are not scaled by 1/sqrt(head_dim). NOTE(review): presumably q_norm/k_norm
        # take that role; confirm against the reference implementation.
        self.scaling = 1.0
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = config.use_bidirectional_attention != "all"
        # Shared kv cache
        # Layers at index >= first_kv_shared_layer_idx reuse K/V, but only when that index is > 0.
        first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0)
        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        # Layer types of the layers that compute their own K/V.
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        if self.is_kv_shared_layer:
            # Index of the last non-shared layer with this layer's type; its K/V are reused.
            self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
            self.store_full_length_kv = False
        else:
            self.kv_shared_layer_index = None
            # True only for the last non-shared layer of each layer type. That layer writes its
            # K/V into `shared_kv_states` in forward, even when no layer reads them.
            self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
                config.layer_types[layer_idx]
            )
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        if not self.is_kv_shared_layer:
            # K/V projections and norms exist only on layers that compute their own K/V.
            self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
            self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)
            self.k_proj = nn.Linear(
                config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias
            )
            self.v_proj = (
                nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
                if not self.use_alternative_attention
                else None
            )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask,
        shared_kv_states,
        past_key_values=None,
        **kwargs,
    ):
        """Attend over this layer's (or the shared) keys/values and project back to hidden size.

        Args:
            hidden_states: input states; the last dim is projected and split into heads of
                ``head_dim``. NOTE(review): presumably ``(batch, seq, hidden_size)``.
            position_embeddings: ``(cos, sin)`` pair for this layer type.
            attention_mask: passed unchanged to the attention implementation.
            shared_kv_states: dict mapping a layer index to ``(key_states, value_states)``.
                KV-shared layers read from it; ``store_full_length_kv`` layers write to it.
            past_key_values: optional cache. Only layers that compute their own K/V update it.
            **kwargs: forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)`` from the attention implementation, with
            ``attn_output`` reshaped back to ``(*input_shape, -1)`` and passed through ``o_proj``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        cos, sin = position_embeddings
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        # unsqueeze_dim=2: RoPE is applied before the head axis is moved to dim 1.
        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
        query_states = query_states.transpose(1, 2)
        if self.is_kv_shared_layer:
            # Already normalized, rotated, transposed (and cache-updated) by the source layer.
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            key_states = self.k_proj(hidden_states).view(hidden_shape)
            # Without v_proj (k_eq_v), values start from the raw key projection (before k_norm/RoPE).
            value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
            key_states = self.k_norm(key_states)
            key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
            key_states = key_states.transpose(1, 2)
            value_states = self.v_norm(value_states)
            value_states = value_states.transpose(1, 2)
        if past_key_values is not None and not self.is_kv_shared_layer:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
        # Publish the (cache-extended, if any) K/V for the KV-shared layers that follow.
        if self.store_full_length_kv:
            shared_kv_states[self.layer_idx] = key_states, value_states
        attn_impl = self.config._attn_implementation
        if attn_impl != "eager" and attn_impl in ALL_ATTENTION_FUNCTIONS:
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
        else:
            attention_interface = eager_attention_forward
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
@use_experts_implementation
class Gemma4TextExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors.

    ``gate_up_proj`` has shape ``(num_experts, 2 * moe_intermediate_size, hidden_size)`` and
    ``down_proj`` has shape ``(num_experts, hidden_size, moe_intermediate_size)``. Each expert
    is a gated MLP ``down(act(gate(x)) * up(x))``.
    """

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, hidden_states, top_k_index, top_k_weights):
        """Send each token to its selected experts and sum the weighted expert outputs.

        Args:
            hidden_states: token states indexed by token along dim 0,
                presumably ``(num_tokens, hidden_size)``.
            top_k_index: expert id per token and slot, ``(num_tokens, top_k)``.
            top_k_weights: routing weight per token and slot, same shape as ``top_k_index``.

        Returns:
            Tensor shaped like ``hidden_states``. Tokens accumulate the weighted outputs of
            their selected experts; ``index_add_`` sums a token's contributions.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # one_hot -> (tokens, top_k, experts); permute -> (experts, top_k, tokens).
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only visit experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        # `one_hot(..., num_classes=num_experts)` only produces ids in [0, num_experts), so no
        # out-of-range / sentinel expert id can show up here and no skip check is needed.
        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
        return final_hidden_states
class Gemma4TextRouter(nn.Module):
    """Mixture-of-experts router.

    Normalizes and rescales the input, scores it against every expert with a softmax, and keeps
    the ``top_k_experts`` best. The kept probabilities are renormalized to sum to one and then
    multiplied by a learned per-expert scale.
    """

    def __init__(self, config: Gemma4TextConfig):
        super().__init__()
        self.config = config
        self.hidden_size = width = config.hidden_size
        self.scalar_root_size = width**-0.5
        self.eps = config.rms_norm_eps
        self.norm = Gemma4RMSNorm(width, eps=self.eps, with_scale=False)
        self.proj = nn.Linear(width, config.num_experts, bias=False)
        self.scale = nn.Parameter(torch.ones(width))
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

    def forward(self, hidden_states):
        """Route tokens to experts.

        Returns:
            ``(router_probabilities, top_k_weights, top_k_index)``:
            - the full softmax over experts;
            - the renormalized, per-expert-scaled weights of the chosen experts;
            - the chosen expert ids.
        """
        normed = self.norm(hidden_states) * self.scale * self.scalar_root_size
        probs = nn.functional.softmax(self.proj(normed), dim=-1)
        weights, chosen = torch.topk(probs, k=self.config.top_k_experts, dim=-1)
        weights /= weights.sum(dim=-1, keepdim=True)
        return probs, weights * self.per_expert_scale[chosen], chosen
class Gemma4TextDecoderLayer(GradientCheckpointingLayer):
    """One Gemma 4 text decoder block.

    Sub-blocks, each added back onto the residual stream:
      1. self-attention between ``input_layernorm`` and ``post_attention_layernorm``;
      2. feed-forward between ``pre_feedforward_layernorm`` and ``post_feedforward_layernorm``.
         With ``enable_moe_block``, the dense MLP output and a mixture-of-experts output are
         each normalized and summed first;
      3. when ``hidden_size_per_layer_input`` is set, a gated projection of this layer's
         per-layer input.
    The block output is finally multiplied by the ``layer_scalar`` buffer.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma4TextAttention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma4TextMLP(config, layer_idx)
        self.input_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        # Output multiplier. A (persistent) buffer initialized to 1, not a trainable parameter.
        self.register_buffer("layer_scalar", torch.ones(1))
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        if self.hidden_size_per_layer_input:
            # Gate: hidden -> per-layer dim; projection: per-layer dim -> hidden.
            self.act_fn = ACT2FN[config.hidden_activation]
            self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
            self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
            self.post_per_layer_input_norm = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.enable_moe_block = config.enable_moe_block
        if self.enable_moe_block:
            self.router = Gemma4TextRouter(config)
            self.experts = Gemma4TextExperts(config)
            self.post_feedforward_layernorm_1 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
            self.post_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
            self.pre_feedforward_layernorm_2 = Gemma4RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        per_layer_input=None,
        shared_kv_states=None,
        position_embeddings=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        **kwargs,
    ):
        """Apply the block and return the new hidden states.

        Args:
            hidden_states: residual stream entering the block.
            per_layer_input: this layer's slice of the per-layer inputs. It is multiplied
                elementwise with the gated projection, so it must broadcast against
                ``(..., hidden_size_per_layer_input)``. Used only when
                ``hidden_size_per_layer_input`` is set.
            shared_kv_states: dict shared across layers for KV sharing (see Gemma4TextAttention).
            position_embeddings: ``(cos, sin)`` for this layer's type.
            attention_mask: mask for this layer's type.
            position_ids: forwarded to the attention module.
            past_key_values: optional cache.
            **kwargs: forwarded to the attention module.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            shared_kv_states=shared_kv_states,
            position_ids=position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if self.enable_moe_block:
            hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)
            # The router and experts read the post-attention residual stream (flattened to
            # tokens), not the `pre_feedforward_layernorm` output; the experts get their own
            # pre-norm.
            hidden_states_flat = residual.reshape(-1, residual.shape[-1])
            _, top_k_weights, top_k_index = self.router(hidden_states_flat)
            hidden_states_2 = self.pre_feedforward_layernorm_2(hidden_states_flat)
            hidden_states_2 = self.experts(hidden_states_2, top_k_index, top_k_weights)
            hidden_states_2 = hidden_states_2.reshape(residual.shape)
            hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)
            hidden_states = hidden_states_1 + hidden_states_2
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        if self.hidden_size_per_layer_input:
            residual = hidden_states
            hidden_states = self.per_layer_input_gate(hidden_states)
            hidden_states = self.act_fn(hidden_states)
            hidden_states = hidden_states * per_layer_input
            hidden_states = self.per_layer_projection(hidden_states)
            hidden_states = self.post_per_layer_input_norm(hidden_states)
            hidden_states = residual + hidden_states
        hidden_states *= self.layer_scalar
        return hidden_states
class Gemma4TextScaledWordEmbedding(nn.Embedding):
    """Token embedding whose lookup result is multiplied by a constant scale.

    The scale is kept twice:
    - as a Python float in ``scalar_embed_scale``, which weight initialization uses to reset
      the buffer;
    - as a non-persistent 0-d buffer ``embed_scale``, so it moves with the module but is not
      saved in checkpoints.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.scalar_embed_scale = embed_scale
        scale_tensor = torch.tensor(embed_scale)
        self.register_buffer("embed_scale", scale_tensor, persistent=False)

    def forward(self, input_ids: torch.Tensor):
        looked_up = super().forward(input_ids)
        # Cast the scale to the weight dtype so low-precision weights stay low precision.
        scale = self.embed_scale.to(self.weight.dtype)
        return looked_up * scale
# ---- Model Classes ----
class Gemma4PreTrainedModel(PreTrainedModel):
    """Common base class of the Gemma 4 models.

    Declares the attention backends and features the models support. ``_init_weights``
    initializes Gemma 4 specific parameters and buffers after the generic
    ``PreTrainedModel._init_weights``.
    """

    config: Gemma4Config
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Blocks that must stay on a single device when the model is split across devices.
    _no_split_modules = ["Gemma4TextDecoderLayer", "Gemma4VisionEncoderLayer", "Gemma4AudioLayer"]
    # NOTE(review): presumably skipped by automatic device placement because they are
    # per-call state containers rather than tensors to move; confirm.
    _skip_keys_device_placement = ["past_key_values", "shared_kv_states"]
    input_modalities = ("image", "text", "video", "audio")

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize ``module``: generic init first, then Gemma 4 specific overrides.

        The overrides form an ``if/elif`` chain, so at most one branch applies to a module.
        """
        super()._init_weights(module)
        if isinstance(module, Gemma4VisionPatchEmbedder):
            init.ones_(module.position_embedding_table)
        elif isinstance(module, Gemma4AudioRelPositionalEncoding):
            # `hidden_size // 2` inverse timescales, spaced geometrically from 1 down to 1e-4
            # (1 / max_timescale), stored with two leading singleton dims.
            min_timescale = 1.0
            max_timescale = 10000.0
            num_timescales = module.hidden_size // 2
            log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
            inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
            module.inv_timescales.copy_(inv_timescales.unsqueeze(0).unsqueeze(0))
        elif isinstance(module, Gemma4AudioAttention):
            init.constant_(module.softcap, module.attention_logits_soft_cap)
            init.zeros_(module.per_dim_scale)
        elif isinstance(module, Gemma4TextRotaryEmbedding):
            # Recompute every layer type's frequency table the same way its __init__ does
            # (device is not passed here; copy_ moves values onto the buffer's device).
            for layer_type, rope_init_fn in module.rope_init_fns.items():
                rope_init_fn_kwargs = {"layer_type": layer_type}
                if layer_type == "full_attention" and module.rope_type[layer_type] == "proportional":
                    rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
                curr_inv_freq, _ = rope_init_fn(module.config, **rope_init_fn_kwargs)
                getattr(module, f"{layer_type}_inv_freq").copy_(curr_inv_freq)
                getattr(module, f"{layer_type}_original_inv_freq").copy_(curr_inv_freq)
        elif isinstance(module, Gemma4VisionRotaryEmbedding):
            rope_fn = (
                ROPE_INIT_FUNCTIONS[module.rope_type]
                if module.rope_type != "default"
                else module.compute_default_rope_parameters
            )
            buffer_value, _ = rope_fn(module.config)
            module.inv_freq.copy_(buffer_value)
            module.original_inv_freq.copy_(buffer_value)
        elif isinstance(module, Gemma4TextScaledWordEmbedding):
            # Restore the non-persistent scale buffer from its Python-float copy.
            init.constant_(module.embed_scale, module.scalar_embed_scale)
        elif isinstance(module, Gemma4TextRouter):
            init.ones_(module.scale)
            init.ones_(module.per_expert_scale)
        elif isinstance(module, Gemma4TextExperts):
            std = self.config.initializer_range
            init.normal_(module.gate_up_proj, mean=0.0, std=std)
            init.normal_(module.down_proj, mean=0.0, std=std)
        elif isinstance(module, Gemma4TextDecoderLayer):
            init.ones_(module.layer_scalar)
        elif isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears:
            # Unbounded clip range; presumably clipping is a no-op until real bounds are loaded.
            init.constant_(module.input_min, -float("inf"))
            init.constant_(module.input_max, float("inf"))
            init.constant_(module.output_min, -float("inf"))
            init.constant_(module.output_max, float("inf"))
        elif isinstance(module, Gemma4VisionModel) and module.config.standardize:
            # Identity standardization: (x - 0) * 1.
            init.zeros_(module.std_bias)
            init.ones_(module.std_scale)
class Gemma4TextModel(Gemma4PreTrainedModel):
    """The base Gemma 4 language model without a language modeling head.

    Consists of a scaled token embedding, the decoder stack and a final RMS norm. When
    ``config.hidden_size_per_layer_input`` is set, it also builds a small per-layer input for
    every decoder layer. That input combines a second embedding table with a projection of
    the main embeddings.
    """

    config: Gemma4TextConfig
    input_modalities = ("text",)

    def __init__(self, config: Gemma4TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = Gemma4TextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
        )
        self.layers = nn.ModuleList(
            [Gemma4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma4TextRotaryEmbedding(config)
        self.gradient_checkpointing = False
        self.unique_layer_types = set(self.config.layer_types)
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        if self.hidden_size_per_layer_input:
            # One row of `num_hidden_layers * hidden_size_per_layer_input` values per token.
            self.embed_tokens_per_layer = Gemma4TextScaledWordEmbedding(
                config.vocab_size_per_layer_input,
                config.num_hidden_layers * config.hidden_size_per_layer_input,
                self.padding_idx,
                embed_scale=config.hidden_size_per_layer_input**0.5,
            )
            self.per_layer_input_scale = 2.0**-0.5
            self.per_layer_model_projection = nn.Linear(
                config.hidden_size,
                config.num_hidden_layers * config.hidden_size_per_layer_input,
                bias=False,
            )
            self.per_layer_model_projection_scale = config.hidden_size**-0.5
            self.per_layer_projection_norm = Gemma4RMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
        # Update `_keys_to_ignore_on_load_unexpected` for shared kv layers: their attention
        # modules have no k_proj/v_proj/k_norm/v_norm, so checkpoint entries for them are unexpected.
        self._keys_to_ignore_on_load_unexpected = []
        for i, layer in enumerate(self.layers):
            if layer.self_attn.is_kv_shared_layer:
                self._keys_to_ignore_on_load_unexpected.extend(
                    [f"layers.{i}.self_attn.{name}" for name in ("k_proj", "v_proj", "k_norm", "v_norm")]
                )
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        per_layer_inputs=None,
        use_cache=None,
        **kwargs,
    ):
        """Run the decoder stack.

        Args:
            input_ids: token ids; exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
            attention_mask: padding mask, or an already-built dict mapping layer type to mask.
            position_ids: defaults to ``arange(seq_len) + <tokens already in the cache>``.
            past_key_values: optional cache; a ``DynamicCache`` is created when ``use_cache`` is
                truthy and none is given.
            inputs_embeds: precomputed (already scaled) token embeddings.
            per_layer_inputs: precomputed per-layer token embeddings. They are looked up from
                the ids (or reverse-looked-up from ``inputs_embeds``) when omitted.
            use_cache: whether to create a cache when none is passed.
            **kwargs: forwarded to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` and ``past_key_values``.

        Raises:
            ValueError: unless exactly one of ``input_ids`` / ``inputs_embeds`` is given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
        if self.hidden_size_per_layer_input:
            if per_layer_inputs is None:
                per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds)
            per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        # Tokens already in the cache; the current tokens are placed right after them.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)
        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": torch.arange(
                    past_seen_tokens,
                    past_seen_tokens + inputs_embeds.shape[1],
                    device=inputs_embeds.device,
                ),
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }
        hidden_states = inputs_embeds
        # One (cos, sin) pair per layer type, shared by all layers of that type.
        position_embeddings = {}
        for layer_type in self.unique_layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
        # Filled by "store_full_length_kv" layers, read by KV-shared layers later in the stack.
        shared_kv_states = {}
        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            per_layer_input = per_layer_inputs[:, :, i, :] if per_layer_inputs is not None else None
            hidden_states = decoder_layer(
                hidden_states,
                per_layer_input,
                shared_kv_states=shared_kv_states,
                position_embeddings=position_embeddings[self.config.layer_types[i]],
                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def get_per_layer_inputs(self, input_ids, inputs_embeds):
        """Look up per-layer token embeddings.

        The result is reshaped to
        ``(*input_ids.shape, num_hidden_layers, hidden_size_per_layer_input)``.

        When ``input_ids`` is None, the ids are recovered from ``inputs_embeds`` by exact
        matching against the main embedding table, scaled the same way
        ``self.embed_tokens`` scales it.

        Raises:
            RuntimeError: if per-layer embeddings are not configured, or if the ids cannot be
                recovered unambiguously (one exact match per position) from ``inputs_embeds``.
        """
        if not self.hidden_size_per_layer_input:
            raise RuntimeError("Model not configured with per-layer embeddings.")
        if input_ids is None:
            with torch.no_grad():
                embed_weight = self.embed_tokens.weight
                # Scale exactly like Gemma4TextScaledWordEmbedding.forward: the `embed_scale`
                # buffer cast to the weight dtype. Multiplying by the Python float
                # `hidden_size**0.5` can round differently in bf16/fp16, and then the
                # exact-equality match below misses tokens.
                scaled_table = embed_weight * self.embed_tokens.embed_scale.to(embed_weight.dtype)
                # NOTE(review): this materializes a (batch, seq, vocab, hidden) boolean tensor.
                input_ids = (
                    (inputs_embeds[:, :, None, :] == scaled_table[None, None, :, :])
                    .all(dim=3)
                    .nonzero()[:, 2]
                )
            try:
                input_ids = input_ids.view(inputs_embeds.shape[:2])
            except RuntimeError as err:
                raise RuntimeError(
                    "Cannot reverse embedding to recover input_ids. Provide both `input_ids` and `inputs_embeds`."
                ) from err
        return self.embed_tokens_per_layer(input_ids).reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs=None):
        """Combine a projection of ``inputs_embeds`` with the per-layer token embeddings.

        The main embeddings are projected to
        ``(*inputs_embeds.shape[:-1], num_hidden_layers, hidden_size_per_layer_input)`` and then
        RMS-normalized. If ``per_layer_inputs`` is None, the projection is returned as-is.
        Otherwise the result is ``(projection + per_layer_inputs) * 2**-0.5``.

        Raises:
            RuntimeError: if per-layer embeddings are not configured.
        """
        if not self.hidden_size_per_layer_input:
            raise RuntimeError("Model not configured with per-layer embeddings.")
        per_layer_projection = self.per_layer_model_projection(inputs_embeds) * self.per_layer_model_projection_scale
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
        if per_layer_inputs is None:
            return per_layer_projection
        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
class Gemma4ForCausalLM(Gemma4PreTrainedModel, GenerationMixin):
    """The base Gemma 4 language model with a language modeling head.

    Wraps ``Gemma4TextModel`` as ``self.model`` and adds a bias-free ``lm_head``. Logits are
    optionally soft-capped as ``tanh(logits / cap) * cap``.
    """

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    config: Gemma4TextConfig
    base_model_prefix = "model"

    def __init__(self, config: Gemma4TextConfig):
        super().__init__(config)
        self.model = Gemma4TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Re-root the decoder's ignore list under this wrapper's "model." prefix.
        self._keys_to_ignore_on_load_unexpected = list(
            map("model.{}".format, self.model._keys_to_ignore_on_load_unexpected)
        )
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        logits_to_keep=0,
        **kwargs,
    ):
        """Compute next-token logits (and the loss when ``labels`` is given).

        Args:
            logits_to_keep: an int keeps the last ``logits_to_keep`` positions (0 keeps all).
                Any other value is used directly as the index on the sequence dimension.
            labels: if given, ``self.loss_function`` is applied to the logits.
            **kwargs: forwarded to both the decoder and the loss function.

        Returns:
            ``CausalLMOutputWithPast``.
        """
        decoder_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(decoder_out.last_hidden_state[:, keep, :])
        cap = self.config.final_logit_softcapping
        if cap is not None:
            logits = torch.tanh(logits / cap) * cap
        loss = None if labels is None else self.loss_function(logits, labels, self.vocab_size, **kwargs)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=decoder_out.past_key_values,
            hidden_states=decoder_out.hidden_states,
            attentions=decoder_out.attentions,
        )
def sliding_window_mask_function(sliding_window):
    """Creates uni/bidirectional attention mask with sliding window.

    Args:
        sliding_window: ``(left_window_size, right_window_size)``. With
            ``d = q_idx - kv_idx``, a key is visible when ``0 <= d < left_window_size``
            (the query itself and earlier keys) or when ``0 < -d < right_window_size``
            (later keys).

    Returns:
        A predicate ``(batch_idx, head_idx, q_idx, kv_idx) -> bool`` (elementwise for tensors).
    """

    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        left, right = sliding_window
        offset = q_idx - kv_idx
        looking_back = (offset >= 0) & (offset < left)
        looking_ahead = (offset < 0) & (-offset < right)
        return looking_back | looking_ahead

    return inner_mask
class Gemma4AudioModel(Gemma4PreTrainedModel):
    """An audio encoder based on the Universal Speech Model architecture.

    Pipeline:
      1. a convolutional subsampling projection (which also returns an updated mask);
      2. relative positional encodings;
      3. a stack of ``Gemma4AudioLayer`` blocks attending over chunked local windows;
      4. a final linear projection to ``output_proj_dims``.
    """

    config: Gemma4AudioConfig
    main_input_name = "input_features"
    base_model_prefix = "model.audio_tower"

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__(config)
        self.config = config
        self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(config)
        self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config)
        self.layers = nn.ModuleList(
            [Gemma4AudioLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.output_proj = nn.Linear(config.hidden_size, config.output_proj_dims, bias=True)
        self.post_init()

    def _convert_4d_mask_to_blocked_5d(self, mask_4d: torch.Tensor) -> torch.Tensor:
        """Re-pack a dense ``(batch, 1, seq, seq)`` mask into chunked local-attention blocks.

        Queries are grouped into blocks of ``attention_chunk_size``. Block ``b`` sees the
        original keys in ``[b * chunk - (attention_context_left - 1),
        (b + 1) * chunk + attention_context_right)``. Positions outside the sequence are False.

        NOTE(review): the padding value ``False`` assumes a boolean mask (True = attend). An
        additive float mask would make padded slots 0.0, i.e. visible. A None mask would fail
        on ``.shape``. Confirm what ``create_bidirectional_mask`` returns here.

        Returns:
            Tensor of shape
            ``(batch, 1, num_blocks, chunk_size, chunk_size + max_past_horizon + max_future_horizon)``.
        """
        batch_size, _, seq_len, _ = mask_4d.shape
        device = mask_4d.device
        chunk_size = self.config.attention_chunk_size
        max_past_horizon = self.config.attention_context_left - 1
        max_future_horizon = self.config.attention_context_right
        num_blocks = (seq_len + chunk_size - 1) // chunk_size
        padded_seq_len = num_blocks * chunk_size
        pad_amount = padded_seq_len - seq_len
        # Pad both the key (last) and query (second-to-last) axes to a whole number of chunks.
        mask_4d = F.pad(mask_4d, (0, pad_amount, 0, pad_amount), value=False)
        # Split the query axis into (num_blocks, chunk_size).
        mask_5d = mask_4d.reshape(batch_size, 1, num_blocks, chunk_size, padded_seq_len)
        # Pad the key axis on both sides so every block's window is in bounds. Original key k
        # now sits at index k + max_past_horizon.
        mask_5d = F.pad(mask_5d, (max_past_horizon, max_future_horizon), value=False)
        # For block b, gather padded key indices [b * chunk, b * chunk + chunk + past + future).
        block_starts = torch.arange(num_blocks, device=device) * chunk_size
        offsets = torch.arange(chunk_size + max_past_horizon + max_future_horizon, device=device)
        kv_indices = block_starts[:, None] + offsets[None, :]
        kv_indices = kv_indices[None, None, :, None, :].expand(batch_size, 1, -1, chunk_size, -1)
        return mask_5d.gather(-1, kv_indices)

    def forward(self, input_features, attention_mask=None, **kwargs):
        """Encode audio features.

        Args:
            input_features: encoder input (``main_input_name``).
            attention_mask: optional mask over ``input_features``. The subsampling projection
                turns it into ``output_mask`` for the subsampled sequence.
            **kwargs: forwarded to every encoder layer.

        Returns:
            ``Gemma4AudioModelOutput`` with ``last_hidden_state`` projected to
            ``output_proj_dims`` and ``attention_mask`` set to the subsampled mask.
        """
        hidden_states, output_mask = self.subsample_conv_projection(input_features, attention_mask)
        position_embeddings = self.rel_pos_enc(hidden_states)
        attention_mask_4d = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=hidden_states,
            attention_mask=output_mask,
        )
        attention_mask_4d = self._convert_4d_mask_to_blocked_5d(attention_mask_4d)
        for encoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask_4d,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.output_proj(hidden_states)
        return Gemma4AudioModelOutput(last_hidden_state=hidden_states, attention_mask=output_mask)
class Gemma4VisionModel(Gemma4PreTrainedModel):
    """The Gemma 4 Vision Encoder.

    Pipeline:
      1. patch embedding;
      2. bidirectional transformer encoder;
      3. pooling, after which padding slots are dropped.
    If ``config.standardize`` is set, the affine map ``(x - std_bias) * std_scale`` is then
    applied.
    """

    # Must be an annotation, not an assignment. transformers derives `config_class` from the
    # annotated `config` attribute, as on the other Gemma 4 model classes in this file. With a
    # plain assignment this class has no annotation, and `config_class` is inherited from the
    # parent (Gemma4Config) instead of Gemma4VisionConfig.
    config: Gemma4VisionConfig

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__(config)
        self.patch_embedder = Gemma4VisionPatchEmbedder(config)
        self.encoder = Gemma4VisionEncoder(config)
        self.pooler = Gemma4VisionPooler(config)
        if self.config.standardize:
            # Persistent buffers; weight init resets them to the identity (bias 0, scale 1).
            self.register_buffer("std_bias", torch.empty(self.config.hidden_size))
            self.register_buffer("std_scale", torch.empty(self.config.hidden_size))
        self.post_init()

    def forward(self, pixel_values, pixel_position_ids, **kwargs):
        """Encode patches and return the pooled states of the non-padding slots.

        Args:
            pixel_values: patchified pixels. ``shape[-2]`` divided by
                ``pooling_kernel_size ** 2`` gives the pooled output length, so it is presumably
                the patch count.
            pixel_position_ids: per-patch position ids. Patches whose ids are all -1 along the
                last dim are treated as padding.
            **kwargs: forwarded to the encoder.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` holds the pooled states
            selected by the pooler mask (boolean indexing flattens the masked dims).
        """
        pooling_kernel_size = self.config.pooling_kernel_size
        output_length = pixel_values.shape[-2] // (pooling_kernel_size * pooling_kernel_size)
        padding_positions = (pixel_position_ids == -1).all(dim=-1)
        inputs_embeds = self.patch_embedder(pixel_values, pixel_position_ids, padding_positions)
        output = self.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=~padding_positions,
            pixel_position_ids=pixel_position_ids,
            **kwargs,
        )
        hidden_states, pooler_mask = self.pooler(
            hidden_states=output.last_hidden_state,
            pixel_position_ids=pixel_position_ids,
            padding_positions=padding_positions,
            output_length=output_length,
        )
        hidden_states = hidden_states[pooler_mask]
        if self.config.standardize:
            hidden_states = (hidden_states - self.std_bias) * self.std_scale
        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
class Gemma4MultimodalEmbedder(nn.Module):
    """Embeds token ids or soft tokens for multimodal content into language model space.

    Inputs are RMS-normalized without a learned scale, then mapped to the text hidden size by
    a bias-free linear layer. The input width is the encoder config's ``output_proj_dims``
    when that attribute exists, otherwise its ``hidden_size``.
    """

    def __init__(self, multimodal_config, text_config):
        super().__init__()
        fallback_width = multimodal_config.hidden_size
        self.multimodal_hidden_size = getattr(multimodal_config, "output_proj_dims", fallback_width)
        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size
        self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
        self.embedding_pre_projection_norm = Gemma4RMSNorm(self.multimodal_hidden_size, eps=self.eps, with_scale=False)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Normalize ``inputs_embeds`` and project them to the text hidden size."""
        return self.embedding_projection(self.embedding_pre_projection_norm(inputs_embeds))
def token_type_ids_mask_function(token_type_ids, image_group_ids):
    """Build a mask predicate that lets tokens of the same vision group see each other.

    Args:
        token_type_ids: only checked for None; when None, no predicate is built.
        image_group_ids: ``(batch, seq)`` group id per token, with -1 for tokens outside any
            group.

    Returns:
        None if ``token_type_ids`` is None. Otherwise a predicate
        ``(batch_idx, head_idx, q_idx, kv_idx)`` that is True when the query and key belong to
        the same non-negative group. Indices past the end of ``image_group_ids`` count as
        "no group".
    """
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        seq_len = image_group_ids.shape[-1]

        def _group_of(idx):
            # Clamp so the lookup never goes out of bounds, then map out-of-range positions to -1.
            gathered = image_group_ids[batch_idx, idx.clamp(max=seq_len - 1)]
            return torch.where(idx < seq_len, gathered, -1)

        query_group = _group_of(q_idx)
        key_group = _group_of(kv_idx)
        return (query_group >= 0) & (query_group == key_group)

    return inner_mask
def create_causal_mask_mapping(
    config,
    inputs_embeds,
    attention_mask,
    past_key_values,
    position_ids,
    mm_token_type_ids=None,
    pixel_values=None,
    is_training=False,
    is_first_iteration=None,
    **kwargs,
):
    """
    Build the `{"full_attention": ..., "sliding_attention": ...}` mask mapping consumed by the text decoder.

    On the first iteration, when `mm_token_type_ids` is available, the sliding-window mask is OR-ed with a predicate
    that lets every token of a contiguous vision span (`mm_token_type_ids` equal to 1 or 2) attend to every other
    token of the same span. The full-attention mask is built without that predicate.

    Args:
        config: Model config; its text sub-config (`config.get_text_config()`) is handed to the mask builders.
        inputs_embeds: Current input embeddings; only `shape[1]` and `device` are read here.
        attention_mask: Padding mask (or `None`) forwarded unchanged to the mask builders.
        past_key_values: Cache (or `None`); `get_seq_length()` gives the offset of the current tokens.
        position_ids: Forwarded unchanged to the mask builders.
        mm_token_type_ids: Per-token modality ids; values 1 and 2 are treated as vision tokens.
        pixel_values: Only used to infer `is_first_iteration` when it is not given.
        is_training: When true, `mm_token_type_ids` is mandatory.
        is_first_iteration: Whether this is the first (prefill) step. When `None`, it is inferred as
            "no cache object yet, or pixel values were supplied".
        **kwargs: An optional `cache_position` is honoured; other keys are ignored.

    Returns:
        dict with `"full_attention"` and `"sliding_attention"` masks.

    Raises:
        ValueError: if `is_training` is set and `mm_token_type_ids` is `None`.
    """
    if is_training and mm_token_type_ids is None:
        raise ValueError("`mm_token_type_ids` is required as a model input when training")

    # Prefer a caller-supplied cache_position; otherwise derive it from the cache length.
    cache_position = kwargs.get("cache_position")
    if cache_position is None:
        past_seq_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        cur_seq_len = inputs_embeds.shape[1]
        cache_position = torch.arange(past_seq_len, past_seq_len + cur_seq_len, device=inputs_embeds.device)

    mask_kwargs = {
        "config": config.get_text_config(),
        "input_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
    }
    # Separate copy so only the sliding-window mask receives the vision `or_mask_function`.
    sliding_mask_kwargs = mask_kwargs.copy()

    if is_first_iteration is None:
        is_first_iteration = past_key_values is None or pixel_values is not None

    if mm_token_type_ids is not None and is_first_iteration:
        # The predicate indexes `vision_group_ids` with the mask builders' index tensors, so the lookup table must
        # live on their device. NOTE(review): assumed to be `cache_position.device`.
        mm_token_type_ids = mm_token_type_ids.to(cache_position.device)
        is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2)
        # A vision span starts where a vision token follows a non-vision token (or the sequence start).
        is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
        is_prev_vision[..., 0] = False
        new_vision_starts = is_vision & ~is_prev_vision
        # Number the spans 0, 1, 2, ... along the sequence; non-vision tokens get -1.
        vision_group_ids = torch.cumsum(new_vision_starts.int(), dim=1) - 1
        vision_group_ids = torch.where(is_vision, vision_group_ids, -1)
        sliding_mask_kwargs["or_mask_function"] = token_type_ids_mask_function(mm_token_type_ids, vision_group_ids)

    return {
        "full_attention": create_causal_mask(**mask_kwargs),
        "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
    }
class Gemma4Model(Gemma4PreTrainedModel):
    """
    The base Gemma 4 model comprising a vision backbone, an audio backbone,
    and a language model without a language modeling head.
    """

    accepts_loss_kwargs = False

    def __init__(self, config: Gemma4Config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = Gemma4TextModel(config.text_config)
        self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
        # Vision/audio towers and their projectors are optional and only built when the sub-config exists.
        self.vision_tower = Gemma4VisionModel(config.vision_config) if config.vision_config is not None else None
        self.embed_vision = (
            Gemma4MultimodalEmbedder(config.vision_config, config.text_config)
            if config.vision_config is not None
            else None
        )
        self.audio_tower = Gemma4AudioModel(config.audio_config) if config.audio_config is not None else None
        self.embed_audio = (
            Gemma4MultimodalEmbedder(config.audio_config, config.text_config)
            if config.audio_config is not None
            else None
        )
        # Re-root the text model's ignore patterns under `language_model.`; tolerate a `None` list.
        self._keys_to_ignore_on_load_unexpected = [
            f"language_model.{name}" for name in (self.language_model._keys_to_ignore_on_load_unexpected or ())
        ]
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    @can_return_tuple
    def get_image_features(self, pixel_values, image_position_ids=None, **kwargs):
        """
        Encode images with the vision tower and project them into the text embedding space.

        Returns the vision tower output with `pooler_output` set to the projected features.

        Raises:
            ValueError: if the model was built without a vision config.
        """
        if self.vision_tower is None:
            raise ValueError("Image features were requested, but the model was initialized without a vision_config.")
        vision_outputs = self.vision_tower(
            pixel_values=pixel_values,
            pixel_position_ids=image_position_ids,
            **kwargs,
        )
        vision_outputs.pooler_output = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state)
        return vision_outputs

    def get_placeholder_mask(self, input_ids=None, inputs_embeds=None):
        """
        Locate the image, video and audio placeholder positions.

        With `input_ids`, positions are matched by token id. Otherwise every position of `inputs_embeds` is compared,
        by exact equality over the last dimension, with the embedding of each placeholder token id.

        Returns:
            tuple `(image_mask, video_mask, audio_mask)`.
        """
        placeholder_ids = (self.config.image_token_id, self.config.video_token_id, self.config.audio_token_id)
        if input_ids is not None:
            return tuple(input_ids == token_id for token_id in placeholder_ids)
        embed_tokens = self.get_input_embeddings()
        return tuple(
            (
                inputs_embeds
                == embed_tokens(torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device))
            ).all(-1)
            for token_id in placeholder_ids
        )

    @staticmethod
    def _merge_modality_features(inputs_embeds, placeholder_mask, features, modality):
        """
        Scatter `features` into the `placeholder_mask` positions of `inputs_embeds` (out of place).

        `features` is cast to the device and dtype of `inputs_embeds` for the scatter.

        Raises:
            ValueError: if the placeholder slots and the feature elements differ in number.
        """
        hidden_size = inputs_embeds.shape[-1]
        n_tokens = int(placeholder_mask.sum())
        if n_tokens * hidden_size != features.numel():
            # Every dimension except the trailing hidden one counts feature vectors.
            n_features = features.shape[:-1].numel()
            raise ValueError(
                f"{modality.capitalize()} features and {modality} tokens do not match, tokens: {n_tokens}, features:"
                f" {n_features}"
            )
        mask_exp = placeholder_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return inputs_embeds.masked_scatter(mask_exp, features.to(inputs_embeds.device, inputs_embeds.dtype))

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        pixel_values_videos=None,
        input_features=None,
        attention_mask=None,
        input_features_mask=None,
        position_ids=None,
        past_key_values=None,
        mm_token_type_ids=None,
        inputs_embeds=None,
        use_cache=None,
        image_position_ids=None,
        video_position_ids=None,
        **kwargs,
    ):
        """
        Embed text plus optional image/video/audio inputs and run the language model.

        Exactly one of `input_ids` / `inputs_embeds` must be given. Multimodal placeholder positions are filled with
        the projected encoder features. A dict `attention_mask` is used as a ready-made mask mapping.

        Returns:
            `Gemma4ModelOutputWithPast`. `image_hidden_states` / `audio_hidden_states` hold the features that were
            merged, or `None` when that modality was not processed.

        Raises:
            ValueError: on a bad `input_ids` / `inputs_embeds` combination, or a placeholder/feature count mismatch.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        image_mask, video_mask, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds)
        multimodal_mask = image_mask | video_mask | audio_mask

        # Placeholder ids are replaced by the pad id before the text embedding lookup.
        llm_input_ids = None
        if inputs_embeds is None:
            llm_input_ids = input_ids.clone()
            llm_input_ids[multimodal_mask] = self.config.text_config.pad_token_id
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if self.config.get_text_config().hidden_size_per_layer_input:
            # Per-layer inputs see the raw pad-token embedding row at every multimodal placeholder.
            pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
            llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
            per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
        else:
            per_layer_inputs = None

        # Bound up front so the output below never references an unassigned name.
        image_features = None
        audio_features = None

        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, image_position_ids, return_dict=True).pooler_output
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = self._merge_modality_features(inputs_embeds, image_mask, image_features, "image")

        if pixel_values_videos is not None:
            video_features = self.get_video_features(
                pixel_values_videos, video_position_ids, return_dict=True
            ).pooler_output
            video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = self._merge_modality_features(inputs_embeds, video_mask, video_features, "video")

        if input_features is not None and input_features_mask is not None:
            audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True)
            # Keep only the encoder positions its own attention mask marks as valid.
            audio_features = audio_output.pooler_output[audio_output.attention_mask]
            inputs_embeds = self._merge_modality_features(inputs_embeds, audio_mask, audio_features, "audio")

        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            if getattr(self.config.get_text_config(), "use_bidirectional_attention", None) == "vision":
                causal_mask_mapping = create_causal_mask_mapping(
                    self.config,
                    inputs_embeds,
                    attention_mask,
                    past_key_values,
                    position_ids,
                    mm_token_type_ids,
                    pixel_values,
                    is_training=self.training,
                )
            else:
                past_seq_len = past_key_values.get_seq_length() if past_key_values is not None else 0
                cur_seq_len = inputs_embeds.shape[1]
                cache_position = torch.arange(past_seq_len, past_seq_len + cur_seq_len, device=inputs_embeds.device)
                causal_mask_mapping = create_masks_for_generate(
                    self.config,
                    inputs_embeds,
                    attention_mask,
                    cache_position,
                    past_key_values,
                    position_ids,
                )

        outputs = self.language_model(
            per_layer_inputs=per_layer_inputs,
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            return_dict=True,
            **kwargs,
        )
        return Gemma4ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features,
            audio_hidden_states=audio_features,
        )

    @can_return_tuple
    def get_audio_features(self, input_features, input_features_mask, **kwargs):
        """
        Encode audio with the audio tower and project it into the text embedding space.

        Raises:
            ValueError: if the model was built without an audio config.
        """
        if self.audio_tower is None:
            raise ValueError(
                "Audio features were requested, but the model was initialized without an audio_config."
            )
        audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True, **kwargs)
        audio_outputs.pooler_output = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state)
        return audio_outputs

    @can_return_tuple
    def get_video_features(self, pixel_values_videos, video_position_ids=None, **kwargs):
        """
        Encode video frames with the vision tower and project them into the text embedding space.

        The first two dimensions of `pixel_values_videos` (and of `video_position_ids`, when given) are flattened into
        one before encoding. NOTE(review): presumably `(num_videos, num_frames, ...)`.

        Raises:
            ValueError: if the model was built without a vision config.
        """
        if self.vision_tower is None:
            raise ValueError("Video features were requested, but the model was initialized without a vision_config.")
        pixel_values_videos = pixel_values_videos.flatten(0, 1)
        if video_position_ids is not None:
            video_position_ids = video_position_ids.flatten(0, 1)
        vision_outputs = self.vision_tower(
            pixel_values=pixel_values_videos,
            pixel_position_ids=video_position_ids,
            **kwargs,
        )
        vision_outputs.pooler_output = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state)
        return vision_outputs
class Gemma4ForConditionalGeneration(Gemma4PreTrainedModel, GenerationMixin):
    """
    The base Gemma 4 model comprising a vision backbone, an audio backbone,
    a language model, and a language modeling head.
    """

    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    base_model_prefix = "model"

    def __init__(self, config: Gemma4Config):
        super().__init__(config)
        self.model = Gemma4Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        # Re-root the wrapped model's ignore patterns under `model.`; tolerate a `None` list.
        self._keys_to_ignore_on_load_unexpected = [
            f"model.{name}" for name in (self.model._keys_to_ignore_on_load_unexpected or ())
        ]
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_image_features(self, pixel_values, image_position_ids=None, **kwargs):
        return self.model.get_image_features(pixel_values, image_position_ids, **kwargs)

    @staticmethod
    def _shifted_cross_entropy(logits, labels, attention_mask, vocab_size):
        """
        Next-token cross-entropy between `logits[..., :-1, :]` and `labels[..., 1:]`.

        When `attention_mask` is a 2-D tensor, its right-most columns (cropped to the shifted length) select which
        positions contribute. Any other mask form is ignored and every position contributes.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        if isinstance(attention_mask, torch.Tensor) and attention_mask.dim() == 2:
            # Crop from the left in case the mask is longer than the logits.
            shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
            shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
            shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
        else:
            shift_logits = shift_logits.contiguous()
            shift_labels = shift_labels.contiguous()
        flat_logits = shift_logits.view(-1, vocab_size)
        flat_labels = shift_labels.view(-1).to(shift_logits.device)
        return nn.CrossEntropyLoss()(flat_logits, flat_labels)

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        pixel_values_videos=None,
        input_features=None,
        attention_mask=None,
        input_features_mask=None,
        position_ids=None,
        image_position_ids=None,
        video_position_ids=None,
        past_key_values=None,
        mm_token_type_ids=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        logits_to_keep=0,
        **kwargs,
    ):
        """
        Run `Gemma4Model` and the LM head, with optional final-logit soft-capping and shifted cross-entropy loss.

        `logits_to_keep` keeps the last N positions when it is an int (0 or `None` keeps all). Otherwise it is used
        directly as the index along the sequence axis. When `labels` is given, the returned logits are float32.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            attention_mask=attention_mask,
            input_features_mask=input_features_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            mm_token_type_ids=mm_token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            image_position_ids=image_position_ids,
            video_position_ids=video_position_ids,
            return_dict=True,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        if logits_to_keep is None:
            # `None` would otherwise become `hidden_states[:, None, :]`, inserting a new axis.
            logits_to_keep = 0
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        text_config = self.config.get_text_config()
        if (final_logit_softcapping := text_config.final_logit_softcapping) is not None:
            logits = torch.tanh(logits / final_logit_softcapping) * final_logit_softcapping

        loss = None
        if labels is not None:
            # Upcast before the loss; the upcast logits are also what is returned.
            logits = logits.float()
            loss = self._shifted_cross_entropy(logits, labels, attention_mask, text_config.vocab_size)

        return Gemma4CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            audio_hidden_states=outputs.audio_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        position_ids=None,
        pixel_values=None,
        pixel_values_videos=None,
        input_features=None,
        attention_mask=None,
        input_features_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Build the per-step model inputs for `generate`.

        Multimodal inputs are only forwarded on the prefill step, or on every step when the cache is disabled. The
        prefill step is recognised either by `is_first_iteration` or by a `cache_position` (in `kwargs`) that starts
        at 0. NOTE(review): the latter covers `generate` versions that do not pass `is_first_iteration`.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )
        cache_position = kwargs.get("cache_position")
        starts_at_zero = (
            isinstance(cache_position, torch.Tensor)
            and cache_position.numel() > 0
            and int(cache_position[0]) == 0
        )
        if is_first_iteration or starts_at_zero or not use_cache:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["pixel_values_videos"] = pixel_values_videos
            model_inputs["input_features"] = input_features
            model_inputs["input_features_mask"] = input_features_mask
        return model_inputs

    @staticmethod
    def create_masks_for_generate(
        config,
        inputs_embeds=None,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        mm_token_type_ids=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Mask-creation hook used by `generate`.

        Accepts the embeddings as `inputs_embeds` or, for callers using the older keyword, `input_embeds`. A
        `cache_position` keyword is accepted and used instead of re-deriving one. NOTE(review): some transformers
        4.5x `generate` versions call this hook with `input_embeds=` / `cache_position=` keywords.

        Raises:
            TypeError: if no embeddings are supplied under either name.
        """
        legacy_embeds = kwargs.pop("input_embeds", None)
        if inputs_embeds is None:
            inputs_embeds = legacy_embeds
        if inputs_embeds is None:
            raise TypeError("create_masks_for_generate() requires `inputs_embeds` (or `input_embeds`).")

        if getattr(config.get_text_config(), "use_bidirectional_attention", None) == "vision":
            return create_causal_mask_mapping(
                config,
                inputs_embeds,
                attention_mask,
                past_key_values,
                position_ids,
                mm_token_type_ids,
                is_first_iteration=is_first_iteration,
                **{k: v for k, v in kwargs.items() if k != "pixel_values"},
            )

        # Pop it so it is not passed twice (positionally and via **kwargs) below.
        cache_position = kwargs.pop("cache_position", None)
        if cache_position is None:
            past_seq_len = past_key_values.get_seq_length() if past_key_values is not None else 0
            cur_seq_len = inputs_embeds.shape[1]
            cache_position = torch.arange(past_seq_len, past_seq_len + cur_seq_len, device=inputs_embeds.device)
        return create_masks_for_generate(
            config, inputs_embeds, attention_mask, cache_position, past_key_values, position_ids, **kwargs
        )
# Public names exported by `from modeling_gemma4 import *`.
__all__ = [
"Gemma4AudioModel",
"Gemma4ForCausalLM",
"Gemma4ForConditionalGeneration",
"Gemma4Model",
"Gemma4PreTrainedModel",
"Gemma4TextModel",
"Gemma4VisionModel",
]