gemma-3-tiled-4b-it / modeling_gemma3_tiled.py
"""
Gemma3Tiled model with support for tiled high-resolution images.
This model extends Gemma3ForConditionalGeneration to handle images that
are tiled into grids, with spatial rearrangement of embeddings and
linebreak tokens between rows.
"""
import torch
from torch import nn
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Model
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
from .configuration_gemma3_tiled import Gemma3TiledConfig
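# Illustrative helper (an addition, not part of the original model code):
# the expected placeholder-token count for one tiled image under the layout
# built below, i.e. a (grid_h*16) x (grid_w*16) token grid plus one "\n"
# token after every row except the last. For example, a 2x2 grid gives
# 32*32 + 31 = 1055 tokens.
def _expected_tiled_token_count(grid_h: int, grid_w: int, tokens_per_tile_side: int = 16) -> int:
    total_rows = grid_h * tokens_per_tile_side
    total_cols = grid_w * tokens_per_tile_side
    return total_rows * total_cols + (total_rows - 1)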
class Gemma3TiledModel(Gemma3Model):
"""
Gemma3 model with tiled image support.
Key differences from Gemma3Model:
- get_image_features() handles tile grids and spatial rearrangement
- get_placeholder_mask() validates tiled structure
- Inserts linebreak embeddings (from "\n" token) between rows
"""
config_class = Gemma3TiledConfig
def __init__(self, config: Gemma3TiledConfig):
super().__init__(config)
self.tokens_per_tile = config.mm_tokens_per_image  # 256 for Gemma3
self.tokens_per_tile_side = int(self.tokens_per_tile**0.5)  # 16 (each tile is a 16x16 token grid)
# Look up the newline token ID from the tokenizer vocab
# (requires the tokenizer files to be resolvable from config._name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
vocab = tokenizer.get_vocab()
if "\n" not in vocab:
raise ValueError("Tokenizer vocab does not contain '\\n' token")
self._linebreak_token_id = vocab["\n"]
def get_linebreak_embedding(self) -> torch.Tensor:
"""Get the embedding for the linebreak token."""
embedding_layer = self.get_input_embeddings()
return embedding_layer.weight[self._linebreak_token_id]
def _process_tiled_image(
self,
pixel_values: torch.Tensor,
grid_h: int,
grid_w: int,
) -> torch.Tensor:
"""
Process a single tiled image and return spatially arranged embeddings with linebreaks.
Args:
pixel_values: Tensor of shape [num_tiles, 3, 896, 896]
grid_h: Number of tile rows
grid_w: Number of tile columns
Returns:
Tensor of shape [total_tokens, hidden_size] where:
total_tokens = (grid_h * 16) * (grid_w * 16) + (grid_h * 16 - 1)
"""
num_tiles = grid_h * grid_w
assert pixel_values.shape[0] == num_tiles, (
f"Expected {num_tiles} tiles for {grid_h}x{grid_w} grid, got {pixel_values.shape[0]}"
)
# Process each tile through vision tower
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
# Project through multimodal projector
# Output shape: [num_tiles, 256, hidden_size]
tile_embeds = self.multi_modal_projector(vision_outputs)
# Reshape to spatial grid
# [num_tiles, 256, hidden] -> [grid_h, grid_w, 16, 16, hidden]
hidden_size = tile_embeds.shape[-1]
tile_embeds = tile_embeds.view(
grid_h, grid_w, self.tokens_per_tile_side, self.tokens_per_tile_side, hidden_size
)
# Rearrange to merge tiles spatially
# We want: for each row of tiles, merge their columns
# [grid_h, grid_w, 16, 16, hidden] -> [grid_h, 16, grid_w, 16, hidden]
tile_embeds = tile_embeds.permute(0, 2, 1, 3, 4)
# Merge into full spatial grid
# [grid_h, 16, grid_w, 16, hidden] -> [grid_h * 16, grid_w * 16, hidden]
total_rows = grid_h * self.tokens_per_tile_side
total_cols = grid_w * self.tokens_per_tile_side
tile_embeds = tile_embeds.reshape(total_rows, total_cols, hidden_size)
# Now insert linebreak embeddings between rows
linebreak_emb = self.get_linebreak_embedding() # [hidden_size]
# Build output by interleaving rows with linebreaks
output_parts = []
for row_idx in range(total_rows):
# Add the row (all columns)
row = tile_embeds[row_idx] # [total_cols, hidden_size]
output_parts.append(row)
# Add linebreak after each row except the last
if row_idx < total_rows - 1:
output_parts.append(linebreak_emb.unsqueeze(0)) # [1, hidden_size]
# Concatenate all parts
output = torch.cat(output_parts, dim=0) # [total_tokens, hidden_size]
return output
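# Minimal self-contained sketch (illustrative, not used by the model) of the
# permute/reshape above: tile-major [num_tiles, 256, hidden] features become a
# row-major [grid_h*16, grid_w*16, hidden] grid. Random tensors stand in for
# projected tile embeddings; no vision tower is needed.
@staticmethod
def _demo_spatial_rearrangement(grid_h: int = 2, grid_w: int = 3, side: int = 16, hidden: int = 8) -> torch.Tensor:
    tiles = torch.randn(grid_h * grid_w, side * side, hidden)
    grid = tiles.view(grid_h, grid_w, side, side, hidden)
    grid = grid.permute(0, 2, 1, 3, 4).reshape(grid_h * side, grid_w * side, hidden)
    # Row 0 of the merged grid is row 0 of tiles 0..grid_w-1, laid side by side.
    assert torch.equal(grid[0, :side], tiles[0].view(side, side, hidden)[0])
    assert torch.equal(grid[0, side : 2 * side], tiles[1].view(side, side, hidden)[0])
    return grid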
def get_image_features(
self,
pixel_values: torch.Tensor,
tile_grid_shape: torch.Tensor | list | None = None,
) -> torch.Tensor:
"""
Get image features for tiled images.
Args:
pixel_values: Concatenated tiles tensor of shape [total_tiles, 3, H, W]
tile_grid_shape: Tensor (or nested list) of shape [num_images, 2], where each row is (grid_h, grid_w).
If None, falls back to the parent's non-tiled processing.
Returns:
Image features tensor of shape [total_tokens, hidden_size]
"""
if tile_grid_shape is None:
# Standard single-image processing (non-tiled)
return super().get_image_features(pixel_values)
# Get device and dtype from vision tower weights
vision_weight = self.vision_tower.vision_model.embeddings.patch_embedding.weight
target_device = vision_weight.device
target_dtype = vision_weight.dtype
# Normalize tile_grid_shape: list -> tensor
if isinstance(tile_grid_shape, list):
tile_grid_shape = torch.tensor(tile_grid_shape, device=target_device)
# Ensure pixel_values is tensor on correct device/dtype
if not isinstance(pixel_values, torch.Tensor):
pixel_values = torch.tensor(pixel_values, dtype=target_dtype, device=target_device)
else:
pixel_values = pixel_values.to(device=target_device, dtype=target_dtype)
# Calculate tile counts per image for splitting concatenated pixel_values
tile_counts = (tile_grid_shape[:, 0] * tile_grid_shape[:, 1]).tolist()
# Split concatenated pixel_values by image
pixel_splits = torch.split(pixel_values, tile_counts, dim=0)
# Process each image
all_features = []
for pv, grid_shape in zip(pixel_splits, tile_grid_shape.tolist()):
grid_h, grid_w = int(grid_shape[0]), int(grid_shape[1])
features = self._process_tiled_image(pv, grid_h, grid_w)
all_features.append(features)
return torch.cat(all_features, dim=0)
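# Sketch (illustrative, an added helper) of the per-image split above: two
# images tiled 1x2 and 2x2 give tile counts [2, 4], so a batch of 6
# concatenated tiles splits back into per-image chunks of 2 and 4 tiles.
@staticmethod
def _demo_tile_split() -> tuple:
    tile_grid_shape = torch.tensor([[1, 2], [2, 2]])
    pixel_values = torch.randn(6, 3, 8, 8)  # tiny stand-in for [6, 3, 896, 896]
    tile_counts = (tile_grid_shape[:, 0] * tile_grid_shape[:, 1]).tolist()
    splits = torch.split(pixel_values, tile_counts, dim=0)
    assert [s.shape[0] for s in splits] == [2, 4]
    return splits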
def get_placeholder_mask(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
image_features: torch.FloatTensor,
tile_grid_shape: torch.Tensor | list | None = None,
) -> torch.Tensor:
"""
Get mask for placeholder tokens, with validation for tiled images.
Args:
input_ids: Input token IDs
inputs_embeds: Input embeddings
image_features: Image feature embeddings
tile_grid_shape: Tensor (or nested list) of shape [num_images, 2], where each row is (grid_h, grid_w).
If provided, the number of placeholder tokens is validated against the expected tiled structure.
Returns:
Boolean mask tensor
"""
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = special_image_mask.sum().item()
# Validate tiled structure if applicable
if tile_grid_shape is not None:
tokens_per_tile_side = self.tokens_per_tile_side
# Normalize to tensor if list
if isinstance(tile_grid_shape, list):
tile_grid_shape = torch.tensor(tile_grid_shape)
# Calculate expected tokens for all images
expected_total = 0
for grid_shape in tile_grid_shape.tolist():
grid_h, grid_w = int(grid_shape[0]), int(grid_shape[1])
total_rows = grid_h * tokens_per_tile_side
total_cols = grid_w * tokens_per_tile_side
expected_img_tokens = total_rows * total_cols
expected_linebreaks = total_rows - 1
expected_total += expected_img_tokens + expected_linebreaks
if n_image_tokens != expected_total:
raise ValueError(
f"Tiled image validation failed: expected {expected_total} tokens "
f"for tile grid(s) {tile_grid_shape.tolist()}, but found {n_image_tokens} placeholder tokens"
)
# Standard validation
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if inputs_embeds[special_image_mask].numel() != image_features.numel():
raise ValueError(
f"Image features and image tokens do not match: "
f"tokens: {n_image_tokens}, features: {image_features.numel() // image_features.shape[-1]}"
)
return special_image_mask
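# Worked check (illustrative, using the _expected_tiled_token_count helper
# added above): for tile grids [[1, 2], [2, 2]] the expected placeholder
# count is _expected_tiled_token_count(1, 2) + _expected_tiled_token_count(2, 2)
# = (16*32 + 15) + (32*32 + 31) = 527 + 1055 = 1582 tokens, so the prompt's
# image-placeholder span must be exactly 1582 tokens long.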
def forward(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
token_type_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
tile_grid_shape: torch.Tensor | None = None,
**lm_kwargs,
):
"""Forward pass with support for tiled images."""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Replace image id with PAD if the image token is OOV
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
image_features = None
# Guard against an empty list/tuple of pixel values, which would pass the `is not None` check
has_images = pixel_values is not None and (not isinstance(pixel_values, (list, tuple)) or len(pixel_values) > 0)
if has_images:
# Get image features (handles tiled if tile_grid_shape provided)
image_features = self.get_image_features(pixel_values, tile_grid_shape)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
# Ensure correct shape for scatter
if image_features.dim() == 2:
# [total_tokens, hidden] -> [1, total_tokens, hidden] for batch dim
image_features = image_features.unsqueeze(0)
special_image_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features, tile_grid_shape=tile_grid_shape
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# The remaining steps (attention-mask creation and the language-model pass)
# are inherited from the parent forward, called with the merged embeddings
return super().forward(
input_ids=None, # We've already embedded
pixel_values=None, # Already processed
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
token_type_ids=token_type_ids,
cache_position=cache_position,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**lm_kwargs,
)
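# Minimal sketch (illustrative, an added module-level helper) of the
# masked_scatter merge used in Gemma3TiledModel.forward: image features
# overwrite the rows of the text embeddings wherever the placeholder mask is
# True, consumed in order from the source tensor.
def _demo_masked_scatter() -> torch.Tensor:
    embeds = torch.zeros(1, 5, 4)
    mask = torch.tensor([[False, True, True, False, True]])
    feats = torch.arange(12, dtype=torch.float32).view(1, 3, 4)
    expanded = mask.unsqueeze(-1).expand_as(embeds)  # [1, 5] -> [1, 5, 4]
    merged = embeds.masked_scatter(expanded, feats)
    assert torch.equal(merged[0, 1], feats[0, 0])
    assert torch.equal(merged[0, 4], feats[0, 2])
    return merged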
class Gemma3TiledForConditionalGeneration(Gemma3ForConditionalGeneration):
"""
Gemma3 model for conditional generation with tiled image support.
This is the main model class to use for both training and inference.
"""
config_class = Gemma3TiledConfig
def __init__(self, config: Gemma3TiledConfig):
super().__init__(config)
# Replace the model with our tiled version
self.model = Gemma3TiledModel(config)
def forward(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
token_type_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
tile_grid_shape: torch.Tensor | None = None,
**lm_kwargs,
):
"""Forward pass with tiled image support."""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
tile_grid_shape=tile_grid_shape, # Pass through
**lm_kwargs,
)
hidden_states = outputs[0]
# Compute logits
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Mirror the parent's shift-and-mask cross-entropy loss computation
logits_float = logits.float()
shift_logits = logits_float[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
loss_fct = nn.CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=getattr(outputs, "image_hidden_states", None),
)
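# End-to-end usage sketch (illustrative; the repo id and processor behavior
# are assumptions — the paired processor is expected to emit `pixel_values`
# as concatenated tiles and `tile_grid_shape` as a [num_images, 2] tensor):
#   model = Gemma3TiledForConditionalGeneration.from_pretrained(
#       "Fraser/gemma-3-tiled-4b-it", trust_remote_code=True
#   )
#   out = model(
#       input_ids=ids,                        # prompt with placeholder tokens
#       pixel_values=tiles,                   # [total_tiles, 3, 896, 896]
#       tile_grid_shape=torch.tensor([[2, 2]]),
#   )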
__all__ = [
"Gemma3TiledForConditionalGeneration",
"Gemma3TiledModel",
]