# coding=utf-8 # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Gemmagain Multimodal - Gemma3 multimodal model with layer looping support for the text decoder. This model allows running the same physical text decoder layers multiple times in sequence, enabling parameter-efficient deep networks. The vision tower is unchanged. Compatible with standard Gemma3 multimodal weights (Gemma3ForConditionalGeneration). """ import copy from dataclasses import dataclass from typing import Callable, Optional, Union import torch import torch.nn as nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, DynamicLayer from transformers.configuration_utils import PretrainedConfig from transformers.generation import GenerationMixin from transformers.masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging from transformers.utils.deprecation import deprecate_kwarg from transformers.models.auto import AutoModel try: from .configuration_gemmagain import GemmagainConfig, GemmagainTextConfig except ImportError: from configuration_gemmagain import GemmagainConfig, GemmagainTextConfig logger = logging.get_logger(__name__) @dataclass @auto_docstring( custom_intro=""" Base class for Gemmagain outputs, with hidden states and attentions. """ ) class GemmagainModelOutputWithPast(BaseModelOutputWithPast): r""" image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for Gemmagain causal language model (or autoregressive) outputs. """ ) class GemmagainCausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*): Contains pre-computed hidden-states for sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): Image hidden states from the vision encoder. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None class Gemma3TextScaledWordEmbedding(nn.Embedding): """Embedding with scaling factor.""" def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0): super().__init__(num_embeddings, embedding_dim, padding_idx) self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) class Gemma3MLP(nn.Module): def __init__(self, config: GemmagainTextConfig): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_activation] def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class Gemma3RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.zeros(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()) output = output * (1.0 + self.weight.float()) return output.type_as(x) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" class Gemma3RotaryEmbedding(nn.Module): inv_freq: torch.Tensor def __init__(self, config: GemmagainTextConfig, device=None): super().__init__() if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, softcap: Optional[float] = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if scaling is None: scaling = module.head_dim**-0.5 key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if softcap is not None: attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) attn_weights = attn_weights * softcap if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Gemma3Attention(nn.Module): """Multi-headed attention with support for looping (cache_slot_idx).""" def __init__(self, config: GemmagainTextConfig, layer_idx: int): super().__init__() self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar**-0.5 self.attention_dropout = self.config.attention_dropout self.is_causal = not self.config.use_bidirectional_attention self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if self.is_sliding else None self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, cache_slot_idx: Optional[int] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} slot_idx = cache_slot_idx if cache_slot_idx is not None else self.layer_idx key_states, value_states = past_key_values.update(key_states, value_states, slot_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Gemma3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GemmagainTextConfig, layer_idx: int): super().__init__() self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma3MLP(config) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings_global: torch.Tensor, position_embeddings_local: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, cache_slot_idx: Optional[int] = None, **kwargs, ) -> tuple[torch.FloatTensor, ...]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) if self.self_attn.is_sliding: position_embeddings = position_embeddings_local else: position_embeddings = position_embeddings_global hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, cache_slot_idx=cache_slot_idx, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.pre_feedforward_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = self.post_feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class GemmagainPreTrainedModel(PreTrainedModel): config_class = GemmagainConfig base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["Gemma3DecoderLayer", "SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": Gemma3DecoderLayer, "attentions": Gemma3Attention, } def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Gemma3MultiModalProjector): module.mm_input_projection_weight.data.zero_() elif "RMSNorm" in module.__class__.__name__: module.weight.data.zero_() def _expand_layer_sequence(layer_sequence, num_hidden_layers): """Expand layer_sequence config into a flat list of layer indices.""" l_seq = [] for item in layer_sequence: if isinstance(item, int): l_seq.append(item) elif isinstance(item, list): if len(item) == 2: start, end = item l_seq += list(range(start, min(end, num_hidden_layers))) elif len(item) == 3: start, end, repeats = item l_seq += list(range(start, min(end, num_hidden_layers))) * repeats else: raise ValueError(f"Invalid layer_sequence item: {item}") else: raise ValueError(f"Invalid layer_sequence item type: {type(item)}") return l_seq def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return abs(q_idx - kv_idx) < sliding_window return inner_mask def token_type_ids_mask_function( token_type_ids: Optional[torch.Tensor], image_group_ids: Optional[torch.Tensor], tokens_per_image: int, ) -> Optional[Callable]: """Mask function for bidirectional attention on image tokens.""" if token_type_ids is None: return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx] token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0) image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx] image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1) is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1) same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx return is_image_block & same_image_block return inner_mask @auto_docstring class GemmagainTextModel(GemmagainPreTrainedModel): """Text model with layer looping support.""" config_class = GemmagainTextConfig def __init__(self, config: GemmagainTextConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = Gemma3TextScaledWordEmbedding( config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=config.hidden_size**0.5 ) self.layers = nn.ModuleList( [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Gemma3RotaryEmbedding(config=config) self.gradient_checkpointing = False # Local RoPE with different theta local_config = copy.deepcopy(config) local_config.rope_theta = config.rope_local_base_freq local_config.rope_scaling = {"rope_type": "default"} self.rotary_emb_local = Gemma3RotaryEmbedding(config=local_config) # Pre-compute expanded layer sequence for looping self._layer_sequence = _expand_layer_sequence(config.layer_sequence, config.num_hidden_layers) self._num_cache_slots = len(self._layer_sequence) self.post_init() @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) # Only create cache if use_cache is explicitly True or default (config.use_cache) effective_use_cache = use_cache if use_cache is not None else self.config.use_cache if effective_use_cache and not self.training: if past_key_values is None: cache_config = copy.copy(self.config) cache_config.num_hidden_layers = self._num_cache_slots past_key_values = DynamicCache(config=cache_config) elif isinstance(past_key_values, DynamicCache) and len(past_key_values.layers) < self._num_cache_slots: while len(past_key_values.layers) < self._num_cache_slots: past_key_values.layers.append(DynamicLayer()) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) # Prepare attention masks if not isinstance(causal_mask_mapping := attention_mask, dict): mask_kwargs = { "config": self.config, "input_embeds": inputs_embeds, "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, "position_ids": position_ids, } sliding_mask_kwargs = mask_kwargs.copy() if self.config.use_bidirectional_attention: mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool) sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window) causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs), } hidden_states = inputs_embeds position_embeddings_global = self.rotary_emb(hidden_states, position_ids) position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None # Execute layers in the configured sequence with looping support for cache_slot_idx, layer_idx in enumerate(self._layer_sequence): if output_hidden_states: all_hidden_states += (hidden_states,) decoder_layer = self.layers[layer_idx] layer_outputs = decoder_layer( hidden_states, position_embeddings_global=position_embeddings_global, position_embeddings_local=position_embeddings_local, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, cache_slot_idx=cache_slot_idx, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) class Gemma3MultiModalProjector(nn.Module): def __init__(self, config: GemmagainConfig): super().__init__() self.mm_input_projection_weight = nn.Parameter( torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) ) self.mm_soft_emb_norm = Gemma3RMSNorm( config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps ) self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size) self.tokens_per_side = int(config.mm_tokens_per_image**0.5) self.kernel_size = self.patches_per_image // self.tokens_per_side self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) def forward(self, vision_outputs: torch.Tensor): batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) reshaped_vision_outputs = reshaped_vision_outputs.reshape( batch_size, seq_length, self.patches_per_image, self.patches_per_image ) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) pooled_vision_outputs = pooled_vision_outputs.flatten(2) pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight) return projected_vision_outputs.type_as(vision_outputs) @auto_docstring( custom_intro=""" Gemmagain multimodal model with layer looping support for the text decoder. """ ) class GemmagainModel(GemmagainPreTrainedModel): """Multimodal model combining vision tower with looping text decoder.""" _checkpoint_conversion_mapping = {"language_model.model": "language_model"} accepts_loss_kwargs = False def __init__(self, config: GemmagainConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = Gemma3MultiModalProjector(config) self.vocab_size = config.text_config.vocab_size # Use our custom text model with looping self.language_model = GemmagainTextModel(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() def get_input_embeddings(self): return self.language_model.embed_tokens def set_input_embeddings(self, value): self.language_model.embed_tokens = value def set_decoder(self, decoder): self.language_model = decoder def get_decoder(self): return self.language_model def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state image_features = self.multi_modal_projector(vision_outputs) return image_features def get_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor ): if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = special_image_mask.all(-1) else: special_image_mask = input_ids == self.config.image_token_id n_image_tokens = special_image_mask.sum() special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_features = image_features.shape[0] * image_features.shape[1] if inputs_embeds[special_image_mask].numel() != image_features.numel(): raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) return special_image_mask @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, token_type_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **lm_kwargs, ) -> Union[tuple, GemmagainModelOutputWithPast]: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Replace image id with PAD if the image token is OOV if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() llm_input_ids[special_image_mask] = 0 else: llm_input_ids = input_ids if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(llm_input_ids) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) image_features = None # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_features) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Prepare attention masks if not isinstance(causal_mask_mapping := attention_mask, dict): mask_kwargs = { "config": self.config.get_text_config(), "input_embeds": inputs_embeds, "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, "position_ids": position_ids, } is_prefill = ( not use_cache or past_key_values is None or not past_key_values.is_initialized or pixel_values is not None ) if token_type_ids is not None and is_prefill: is_image = (token_type_ids == 1).to(cache_position.device) new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device)) mask_kwargs["or_mask_function"] = token_type_ids_mask_function( token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image ) causal_mask_mapping = { "full_attention": create_causal_mask(**mask_kwargs), "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), } outputs = self.language_model( attention_mask=causal_mask_mapping, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, **lm_kwargs, ) return GemmagainModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values if use_cache else None, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=image_features if pixel_values is not None else None, ) @auto_docstring( custom_intro=""" Gemmagain multimodal model for conditional generation with layer looping support. """ ) class GemmagainForConditionalGeneration(GemmagainPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = { "^language_model.model": "model.language_model", "^vision_tower": "model.vision_tower", "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } _tied_weights_keys = ["lm_head.weight"] accepts_loss_kwargs = False def __init__(self, config: GemmagainConfig): super().__init__(config) self.model = GemmagainModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def set_decoder(self, decoder): self.model.set_decoder(decoder) def get_decoder(self): return self.model.get_decoder() def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) @property def language_model(self): return self.model.language_model @property def vision_tower(self): return self.model.vision_tower @property def multi_modal_projector(self): return self.model.multi_modal_projector @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, token_type_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[tuple, GemmagainCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, pixel_values=pixel_values, token_type_ids=token_type_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, **lm_kwargs, ) hidden_states = outputs[0] slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: logits = logits.float() shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] if attention_mask is not None: shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device) shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() else: shift_logits = shift_logits.contiguous() shift_labels = shift_labels.contiguous() loss_fct = nn.CrossEntropyLoss() flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return GemmagainCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, cache_position=None, position_ids=None, pixel_values=None, attention_mask=None, token_type_ids=None, use_cache=True, logits_to_keep=None, labels=None, **kwargs, ): model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, cache_position=cache_position, use_cache=use_cache, logits_to_keep=logits_to_keep, token_type_ids=token_type_ids, **kwargs, ) if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values return model_inputs @staticmethod def create_masks_for_generate( config: PretrainedConfig, input_embeds: torch.Tensor, attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], position_ids: Optional[torch.Tensor], token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: mask_kwargs = { "config": config.get_text_config(), "input_embeds": input_embeds, "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, "position_ids": position_ids, } if token_type_ids is not None and input_embeds.shape[1] != 1: is_image = (token_type_ids == 1).to(cache_position.device) new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) mask_kwargs["or_mask_function"] = token_type_ids_mask_function( token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image ) return create_masks_for_generate(**mask_kwargs) __all__ = [ "GemmagainForConditionalGeneration", "GemmagainModel", "GemmagainTextModel", "GemmagainPreTrainedModel", ]