# gemmagain-mm / modeling_gemmagain.py
# Uploaded by ToastyPigeon using huggingface_hub (commit 414fb1a, verified).
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Gemmagain Multimodal - Gemma3 multimodal model with layer looping support for the text decoder.
This model allows running the same physical text decoder layers multiple times in sequence,
enabling parameter-efficient deep networks. The vision tower is unchanged.
Compatible with standard Gemma3 multimodal weights (Gemma3ForConditionalGeneration).
"""
import copy
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, DynamicLayer
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.models.auto import AutoModel
# Import the configuration classes relative to this package when the file lives
# inside one; fall back to a top-level import when it is loaded as a standalone module.
try:
    from .configuration_gemmagain import GemmagainConfig, GemmagainTextConfig
except ImportError:
    from configuration_gemmagain import GemmagainConfig, GemmagainTextConfig
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Gemmagain outputs, with hidden states and attentions.
"""
)
class GemmagainModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(num_images, tokens_per_image, hidden_size)`.
        Image features from the vision encoder after the multi-modal projector, i.e. already mapped
        to the text model's hidden size. `None` when no `pixel_values` were provided.
    """

    image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Gemmagain causal language model (or autoregressive) outputs.
"""
)
class GemmagainCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Next-token prediction (language modeling) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Unnormalized vocabulary scores produced by the language modeling head (before SoftMax).
    past_key_values (`Cache`, *optional*):
        Key/value cache that can be passed back in to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        Projected image features computed from the vision encoder output.
    """

    # Loss first, then logits, then the optional auxiliary outputs.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """Token embedding whose looked-up vectors are multiplied by a constant ``embed_scale``.

    The scale lives in a non-persistent buffer: it moves with the module across devices
    but is not written to the state dict.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        scale_tensor = torch.tensor(embed_scale)
        self.register_buffer("embed_scale", scale_tensor, persistent=False)

    def forward(self, input_ids: torch.Tensor):
        vectors = super().forward(input_ids)
        # Bring the scale to the embedding weight's dtype before multiplying.
        scale = self.embed_scale.to(self.weight.dtype)
        return vectors * scale
class Gemma3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``, all bias-free."""

    def __init__(self, config: GemmagainTextConfig):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class Gemma3RMSNorm(nn.Module):
    """Root-mean-square normalisation with a zero-initialised gain applied as ``(1 + weight)``.

    The statistics and scaling are computed in float32; the result is cast back to the
    input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        gain = 1.0 + self.weight.float()
        normalized = self._norm(x.float()) * gain
        return normalized.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
class Gemma3RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    ``forward`` returns ``(cos, sin)`` tensors for the given positions, to be consumed by
    ``apply_rotary_pos_emb``. The attributes assigned in ``__init__`` (``rope_type``,
    ``rope_init_fn``, ``max_seq_len_cached``, ``original_max_seq_len``, ``original_inv_freq``,
    ``attention_scaling``) are presumably read by the external ``dynamic_rope_update``
    decorator on ``forward`` -- NOTE(review): keep these names stable and confirm against
    the installed transformers version.
    """

    # Declared for type checkers; the tensor is registered as a buffer in __init__.
    inv_freq: torch.Tensor

    def __init__(self, config: GemmagainTextConfig, device=None):
        super().__init__()
        # Choose the RoPE variant from `config.rope_scaling` (key "rope_type", or "type" as a
        # fallback); plain "default" RoPE when no scaling dict is configured.
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: derived from the config, so it is not saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Compute RoPE ``cos``/``sin`` for ``position_ids`` (indexed as 2-D ``(batch, seq)``).

        ``x`` only supplies the target device and dtype. Both outputs have shape
        ``(batch, seq, 2 * len(inv_freq))`` and are scaled by ``attention_scaling``.
        """
        # (batch, len(inv_freq), 1): one copy of the inverse frequencies per batch row.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        # (batch, 1, seq)
        position_ids_expanded = position_ids[:, None, :].float()
        # autocast is given "cpu" for MPS tensors (or a non-string device type).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast disabled so the angle computation runs in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            # (batch, seq, len(inv_freq)) angles; duplicated to cover both rotated halves.
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotate the last dimension by half: split it at the midpoint into ``(a, b)`` and return ``(-b, a)``."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine table; a singleton axis is inserted at ``unsqueeze_dim`` so it broadcasts
            against ``q`` and ``k``.
        sin: sine table, unsqueezed the same way.
        position_ids: unused; kept for call-signature compatibility.
        unsqueeze_dim: axis at which ``cos``/``sin`` are unsqueezed.

    Returns:
        ``(q_rotated, k_rotated)``, each computed as ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis (grouped-query attention).

    Maps ``(batch, num_kv_heads, seq, head_dim)`` to ``(batch, num_kv_heads * n_rep, seq, head_dim)``,
    keeping the copies of each head adjacent (same result as ``repeat_interleave`` on dim 1).
    The input is returned unchanged when ``n_rep == 1``.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (pure PyTorch) scaled dot-product attention.

    Key/value heads are expanded with ``repeat_kv`` to match the query heads. Logits are scaled
    by ``scaling`` (default ``module.head_dim ** -0.5``), optionally soft-capped with
    ``softcap * tanh(logits / softcap)``, and offset by the additive ``attention_mask`` (cropped
    to the key length). Softmax runs in float32 before casting back to the query dtype.

    Returns:
        ``(attn_output, attn_weights)``; the output has its head and sequence axes swapped
        (``transpose(1, 2)``) and is made contiguous.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling
    keys = repeat_kv(key, module.num_key_value_groups)
    values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scale
    if softcap is not None:
        scores = torch.tanh(scores / softcap) * softcap
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
class Gemma3Attention(nn.Module):
    """Multi-headed attention with support for looping (cache_slot_idx).

    Grouped-query attention with RMS-normalised queries/keys and rotary position embeddings.
    Because one physical layer may run several times per forward pass, KV-cache updates are
    written to ``cache_slot_idx`` (the pass index) when it is given, otherwise to ``layer_idx``.
    """

    def __init__(self, config: GemmagainTextConfig, layer_idx: int):
        super().__init__()
        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
        self.config = config
        self.layer_idx = layer_idx
        # Resolve head_dim once: fall back to hidden_size // num_attention_heads when the config
        # does not define it (or defines it as None). Every size below uses this value.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        # The logit scale comes from query_pre_attn_scalar and is passed explicitly as `scaling`.
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = not self.config.use_bidirectional_attention

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        # NOTE(review): stored but not passed to `attention_interface` in `forward`.
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.is_sliding else None

        # Sized with the resolved head_dim (not config.head_dim) so the norms always match the
        # per-head width of the projections above.
        self.q_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_slot_idx: Optional[int] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: input whose last axis is the hidden size.
            position_embeddings: ``(cos, sin)`` RoPE tables.
            attention_mask: mask forwarded to the attention implementation.
            past_key_values: optional KV cache, updated in place via ``update``.
            cache_position: forwarded to the cache update.
            cache_slot_idx: cache slot to write to; defaults to ``layer_idx``.
            **kwargs: forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``; ``attn_weights`` is whatever the attention
            implementation returns (possibly None).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads and move the head axis before the sequence axis.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # Looped layers write to their per-pass slot so repeated passes do not collide.
            slot_idx = cache_slot_idx if cache_slot_idx is not None else self.layer_idx
            key_states, value_states = past_key_values.update(key_states, value_states, slot_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Gemma3DecoderLayer(GradientCheckpointingLayer):
    """One Gemma3 transformer block with "sandwich" normalisation.

    Each sub-block is applied as ``x + post_norm(sublayer(pre_norm(x)))``, first for
    self-attention and then for the MLP. Sliding-window layers receive the local RoPE tables;
    full-attention layers receive the global ones.
    """

    def __init__(self, config: GemmagainTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma3MLP(config)
        eps = config.rms_norm_eps
        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=eps)
        self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=eps)
        self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=eps)
        self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        cache_slot_idx: Optional[int] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, ...]:
        """Apply the block.

        Returns ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
        ``output_attentions`` is truthy. ``cache_slot_idx`` is forwarded to the attention
        layer to pick the KV-cache slot for this pass.
        """
        rope = position_embeddings_local if self.self_attn.is_sliding else position_embeddings_global

        # Attention sub-block.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=rope,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            cache_slot_idx=cache_slot_idx,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = hidden_states + self.post_feedforward_layernorm(mlp_out)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class GemmagainPreTrainedModel(PreTrainedModel):
    """Common base for every Gemmagain model: config binding, capability flags and weight init."""

    config_class = GemmagainConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma3DecoderLayer", "SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Gemma3DecoderLayer,
        "attentions": Gemma3Attention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        # Projector matrices and RMSNorm gains start at zero. RMSNorm applies (1 + weight),
        # so a zero gain means unit scaling.
        if isinstance(module, Gemma3MultiModalProjector):
            tensor = module.mm_input_projection_weight
        elif "RMSNorm" in type(module).__name__:
            tensor = module.weight
        else:
            return
        tensor.data.zero_()
def _expand_layer_sequence(layer_sequence, num_hidden_layers):
"""Expand layer_sequence config into a flat list of layer indices."""
l_seq = []
for item in layer_sequence:
if isinstance(item, int):
l_seq.append(item)
elif isinstance(item, list):
if len(item) == 2:
start, end = item
l_seq += list(range(start, min(end, num_hidden_layers)))
elif len(item) == 3:
start, end, repeats = item
l_seq += list(range(start, min(end, num_hidden_layers))) * repeats
else:
raise ValueError(f"Invalid layer_sequence item: {item}")
else:
raise ValueError(f"Invalid layer_sequence item type: {type(item)}")
return l_seq
def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return abs(q_idx - kv_idx) < sliding_window
return inner_mask
def token_type_ids_mask_function(
    token_type_ids: Optional[torch.Tensor],
    image_group_ids: Optional[torch.Tensor],
    tokens_per_image: int,
) -> Optional[Callable]:
    """Mask function for bidirectional attention on image tokens.

    Returns None when ``token_type_ids`` is None. Otherwise returns a predicate
    ``(batch_idx, head_idx, q_idx, kv_idx)`` that is True only when both the query and the
    key position are image tokens (type 1) with the same ``image_group_ids`` value.
    Key positions past the end of the id tensors count as text (type 0, group -1).
    ``tokens_per_image`` is accepted but not used.
    """
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Branch-free bounds handling (torch.where instead of `if`), presumably so the
        # predicate stays usable under vectorised mask construction.
        in_range = kv_idx < token_type_ids.shape[1]
        safe_idx = torch.where(in_range, kv_idx, 0)
        kv_type = torch.where(in_range, token_type_ids[batch_idx, safe_idx], 0)
        kv_group = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids[batch_idx, safe_idx], -1)

        both_image = (token_type_ids[batch_idx, q_idx] == 1) & (kv_type == 1)
        same_group = image_group_ids[batch_idx, q_idx] == kv_group
        return both_image & same_group

    return inner_mask
@auto_docstring
class GemmagainTextModel(GemmagainPreTrainedModel):
    """Text model with layer looping support.

    The physical decoder layers in ``self.layers`` run in the order given by
    ``config.layer_sequence`` (expanded by ``_expand_layer_sequence``), so one physical layer may
    run several times per forward pass. Each pass owns one KV-cache slot (its position in the
    expanded sequence). Each slot is typed (sliding-window vs. full) after the physical layer
    that fills it.
    """

    config_class = GemmagainTextConfig

    def __init__(self, config: GemmagainTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=config.hidden_size**0.5
        )
        self.layers = nn.ModuleList(
            [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Local (sliding-window) layers use plain RoPE with their own base frequency.
        local_config = copy.deepcopy(config)
        local_config.rope_theta = config.rope_local_base_freq
        local_config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=local_config)
        # Pre-compute the expanded layer sequence for looping: one cache slot per pass.
        self._layer_sequence = _expand_layer_sequence(config.layer_sequence, config.num_hidden_layers)
        self._num_cache_slots = len(self._layer_sequence)
        self.post_init()

    def _prepare_loop_cache(self, past_key_values: Optional[Cache]) -> Optional[Cache]:
        """Return a KV cache with one correctly typed layer per loop slot.

        Slot ``i`` holds the keys/values of the ``i``-th pass, so its type must follow
        ``config.layer_types[self._layer_sequence[i]]`` (the physical layer run in that slot),
        not ``config.layer_types[i]``.

        * ``None``: a new ``DynamicCache`` is created with one typed layer per slot.
        * Empty ``DynamicCache`` whose layers do not match the slot layout (for instance one
          built from the plain config, with one layer per *physical* layer): its layers are
          replaced with the per-slot layout.
        * Non-empty ``DynamicCache`` with fewer layers than slots: padded with typed layers.
        * Any other cache: returned unchanged.
        """
        # Local import: same transformers module that provides DynamicLayer.
        from transformers.cache_utils import DynamicSlidingWindowLayer

        slot_types = [self.config.layer_types[layer_idx] for layer_idx in self._layer_sequence]

        def new_layer(layer_type):
            if layer_type == "sliding_attention":
                return DynamicSlidingWindowLayer(sliding_window=self.config.sliding_window)
            return DynamicLayer()

        if past_key_values is None:
            cache_config = copy.copy(self.config)
            cache_config.num_hidden_layers = self._num_cache_slots
            # DynamicCache creates one cache layer per `layer_types` entry; hand it the per-slot
            # list (the physical per-layer list would mis-type and under-allocate looped slots).
            cache_config.layer_types = slot_types
            return DynamicCache(config=cache_config)

        if not isinstance(past_key_values, DynamicCache):
            return past_key_values

        if past_key_values.get_seq_length() == 0:
            expected = [layer_type == "sliding_attention" for layer_type in slot_types]
            actual = [isinstance(layer, DynamicSlidingWindowLayer) for layer in past_key_values.layers]
            if actual != expected:
                past_key_values.layers = [new_layer(layer_type) for layer_type in slot_types]
        else:
            while len(past_key_values.layers) < self._num_cache_slots:
                past_key_values.layers.append(new_layer(slot_types[len(past_key_values.layers)]))
        return past_key_values

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
            use_cache = False
        # Wrappers (GemmagainModel) forward `return_dict=True`; this method always returns a
        # BaseModelOutputWithPast, so drop it instead of leaking it into the attention kwargs.
        kwargs.pop("return_dict", None)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Build (or adapt) a per-slot cache for inference only.
        if use_cache and not self.training:
            past_key_values = self._prepare_loop_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build the masks here unless the caller already passed a {attention_type: mask} mapping.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            sliding_mask_kwargs = mask_kwargs.copy()
            if self.config.use_bidirectional_attention:
                mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
                sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
            }

        hidden_states = inputs_embeds
        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Execute layers in the configured sequence; each pass writes to its own cache slot.
        for cache_slot_idx, layer_idx in enumerate(self._layer_sequence):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            decoder_layer = self.layers[layer_idx]
            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings_global=position_embeddings_global,
                position_embeddings_local=position_embeddings_local,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                cache_slot_idx=cache_slot_idx,
                **kwargs,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class Gemma3MultiModalProjector(nn.Module):
    """Pool vision-tower patch features and project them into the text embedding space.

    The square grid of ``patches_per_image x patches_per_image`` patch vectors is average-pooled
    with a ``kernel_size`` window/stride, flattened back into a token sequence, RMS-normalised,
    and multiplied by ``mm_input_projection_weight`` (vision hidden size -> text hidden size).
    """

    def __init__(self, config: GemmagainConfig):
        super().__init__()
        vision_cfg = config.vision_config
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_cfg.hidden_size, config.text_config.hidden_size)
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(vision_cfg.hidden_size, eps=vision_cfg.layer_norm_eps)
        self.patches_per_image = int(vision_cfg.image_size // vision_cfg.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        batch_size, _, channels = vision_outputs.shape
        # (batch, patches, channels) -> (batch, channels, side, side) for 2-D pooling.
        grid = vision_outputs.transpose(1, 2).reshape(
            batch_size, channels, self.patches_per_image, self.patches_per_image
        ).contiguous()
        # Pool, then back to (batch, pooled_tokens, channels).
        tokens = self.avg_pool(grid).flatten(2).transpose(1, 2)
        projected = torch.matmul(self.mm_soft_emb_norm(tokens), self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)
@auto_docstring(
    custom_intro="""
Gemmagain multimodal model with layer looping support for the text decoder.
"""
)
class GemmagainModel(GemmagainPreTrainedModel):
    """Multimodal model combining vision tower with looping text decoder.

    Images are encoded by ``vision_tower``, pooled and projected by ``multi_modal_projector``,
    and scattered into the text embeddings at the image-token positions before the looping
    ``GemmagainTextModel`` runs.
    """

    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    accepts_loss_kwargs = False

    def __init__(self, config: GemmagainConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        # Use our custom text model with looping.
        self.language_model = GemmagainTextModel(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.embed_tokens

    def set_input_embeddings(self, value):
        self.language_model.embed_tokens = value

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` with the vision tower and project them to text-embedding space."""
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """Return a boolean mask (expanded to ``inputs_embeds``' shape) marking image-token positions.

        Without ``input_ids``, positions are found by comparing each embedding with the embedding of
        ``config.image_token_id``.

        Raises:
            ValueError: if the masked embedding elements do not match ``image_features.numel()``.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **lm_kwargs,
    ) -> Union[tuple, GemmagainModelOutputWithPast]:
        # `labels` is accepted for API compatibility; no loss is computed here.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Resolve `use_cache` the same way the text model does. Otherwise the default None makes
        # the prefill check below always true and drops the cache the language model builds
        # from the returned output.
        use_cache = use_cache if use_cache is not None else self.config.get_text_config().use_cache

        # If the image token id lies outside the text vocabulary, look up id 0 instead; those
        # positions are overwritten with image features below when pixel_values are given.
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)

        image_features = None
        # Merge text and images.
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_features)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # Prepare attention masks (unless a precomputed mapping was passed in).
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config.get_text_config(),
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            is_prefill = (
                not use_cache
                or past_key_values is None
                or not past_key_values.is_initialized
                or pixel_values is not None
            )
            if token_type_ids is not None and is_prefill:
                # Number each contiguous run of image tokens 0, 1, ...; text tokens get -1.
                is_image = (token_type_ids == 1).to(cache_position.device)
                new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
                image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
                image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device))
                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
                    token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
                )
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **lm_kwargs,
        )
        return GemmagainModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
@auto_docstring(
    custom_intro="""
Gemmagain multimodal model for conditional generation with layer looping support.
"""
)
class GemmagainForConditionalGeneration(GemmagainPreTrainedModel, GenerationMixin):
    """``GemmagainModel`` backbone plus a bias-free linear language-modeling head.

    ``forward`` returns vocabulary logits, plus a shifted next-token cross-entropy loss when
    ``labels`` are given. ``prepare_inputs_for_generation`` and ``create_masks_for_generate``
    adapt ``GenerationMixin`` to image inputs.
    """

    # Regex-anchored key renames from a flat checkpoint layout to this class's `model.*` /
    # `lm_head` layout (NOTE(review): presumably for older-format checkpoints -- confirm).
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    # Parameter(s) eligible for tying with the input embeddings.
    _tied_weights_keys = ["lm_head.weight"]
    accepts_loss_kwargs = False

    def __init__(self, config: GemmagainConfig):
        super().__init__(config)
        self.model = GemmagainModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    def get_image_features(self, pixel_values):
        return self.model.get_image_features(pixel_values)

    # Convenience accessors for the backbone's sub-modules.
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[tuple, GemmagainCausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **lm_kwargs,
        )
        # First output: the backbone's final hidden states.
        hidden_states = outputs[0]
        # An int N keeps the last N positions (0 keeps all, since slice(-0, None) == slice(0, None));
        # a tensor selects explicit sequence indices.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # Compute the loss in float32.
            logits = logits.float()
            # Position t predicts label t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # Keep only positions where the (right-aligned, cropped) mask is non-zero.
                # NOTE(review): assumes a 2-D (batch, seq) attention_mask.
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Mean cross-entropy over the flattened tokens (default CrossEntropyLoss settings).
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return GemmagainCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        """Build the per-step model inputs for ``generate``.

        Delegates to ``GenerationMixin`` and then adds ``pixel_values`` only on the first step
        (``cache_position[0] == 0``), so images are encoded once. ``labels`` is accepted but not
        forwarded.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        return model_inputs

    @staticmethod
    def create_masks_for_generate(
        config: PretrainedConfig,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cache_position: torch.Tensor,
        past_key_values: Optional[Cache],
        position_ids: Optional[torch.Tensor],
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Build the attention masks used during generation.

        When ``token_type_ids`` is given and more than one new token is processed, an extra
        ``or_mask_function`` lets tokens within the same image attend to each other
        bidirectionally. Masks are built from the text sub-config.
        """
        mask_kwargs = {
            "config": config.get_text_config(),
            "input_embeds": input_embeds,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
        }
        if token_type_ids is not None and input_embeds.shape[1] != 1:
            # Number each contiguous run of image tokens 0, 1, ...; text tokens get -1.
            is_image = (token_type_ids == 1).to(cache_position.device)
            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
                token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
            )
        return create_masks_for_generate(**mask_kwargs)
# Public model classes exported by `from <this module> import *`.
__all__ = [
    "GemmagainForConditionalGeneration",
    "GemmagainModel",
    "GemmagainTextModel",
    "GemmagainPreTrainedModel",
]