| """ |
| Gemma3Tiled model with support for tiled high-resolution images. |
| |
| This model extends Gemma3ForConditionalGeneration to handle images that |
| are tiled into grids, with spatial rearrangement of embeddings and |
| linebreak tokens between rows. |
| """ |

import torch
from torch import nn
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Model
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast

from .configuration_gemma3_tiled import Gemma3TiledConfig


class Gemma3TiledModel(Gemma3Model):
    """
    Gemma3 model with tiled image support.

    Key differences from Gemma3Model:
    - get_image_features() handles tile grids and spatial rearrangement
    - get_placeholder_mask() validates tiled structure
    - Inserts linebreak embeddings (from "\n" token) between rows
    """

    config_class = Gemma3TiledConfig

    def __init__(self, config: Gemma3TiledConfig):
        super().__init__(config)
        self.tokens_per_tile = config.mm_tokens_per_image
        self.tokens_per_tile_side = int(self.tokens_per_tile**0.5)

        # Look up the "\n" token id so its embedding can be spliced in between
        # rows of image tokens.
        tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        vocab = tokenizer.get_vocab()
        if "\n" not in vocab:
            raise ValueError("Tokenizer vocab does not contain '\\n' token")
        self._linebreak_token_id = vocab["\n"]

    def get_linebreak_embedding(self) -> torch.Tensor:
        """Return the embedding of the linebreak token.

        The embedding layer is called rather than indexed so that any scaling
        it applies in its forward pass (Gemma3 scales word embeddings) is
        applied here as well, keeping the linebreak embedding consistent with
        the text embeddings it sits between.
        """
        embedding_layer = self.get_input_embeddings()
        token_id = torch.tensor(
            self._linebreak_token_id, dtype=torch.long, device=embedding_layer.weight.device
        )
        return embedding_layer(token_id)

    def _process_tiled_image(
        self,
        pixel_values: torch.Tensor,
        grid_h: int,
        grid_w: int,
    ) -> torch.Tensor:
        """
        Process a single tiled image and return spatially arranged embeddings with linebreaks.

        Args:
            pixel_values: Tensor of shape [num_tiles, 3, H, W] (896x896 for Gemma3)
            grid_h: Number of tile rows
            grid_w: Number of tile columns

        Returns:
            Tensor of shape [total_tokens, hidden_size] where, with
            s = tokens_per_tile_side:
                total_tokens = (grid_h * s) * (grid_w * s) + (grid_h * s - 1)
            i.e. one token per spatial position plus one linebreak per row
            except the last.
        """
        num_tiles = grid_h * grid_w
        if pixel_values.shape[0] != num_tiles:
            raise ValueError(
                f"Expected {num_tiles} tiles for {grid_h}x{grid_w} grid, got {pixel_values.shape[0]}"
            )

        # Encode all tiles in one batch: [num_tiles, tokens_per_tile, vision_hidden].
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state

        # Project into the language model's embedding space:
        # [num_tiles, tokens_per_tile, hidden_size].
        tile_embeds = self.multi_modal_projector(vision_outputs)

        # Unfold each tile's tokens into their 2D layout:
        # [grid_h, grid_w, side, side, hidden_size].
        hidden_size = tile_embeds.shape[-1]
        tile_embeds = tile_embeds.view(
            grid_h, grid_w, self.tokens_per_tile_side, self.tokens_per_tile_side, hidden_size
        )

        # Bring the within-tile row dimension next to the tile-row dimension
        # so that spatially adjacent tokens across tile boundaries become
        # adjacent in memory: [grid_h, side, grid_w, side, hidden_size].
        tile_embeds = tile_embeds.permute(0, 2, 1, 3, 4)

        # Collapse into one global token grid: [total_rows, total_cols, hidden_size].
        total_rows = grid_h * self.tokens_per_tile_side
        total_cols = grid_w * self.tokens_per_tile_side
        tile_embeds = tile_embeds.reshape(total_rows, total_cols, hidden_size)

        linebreak_emb = self.get_linebreak_embedding()

        # Flatten row by row, appending a linebreak embedding after every row
        # except the last.
        output_parts = []
        for row_idx in range(total_rows):
            row = tile_embeds[row_idx]
            output_parts.append(row)
            if row_idx < total_rows - 1:
                output_parts.append(linebreak_emb.unsqueeze(0))

        output = torch.cat(output_parts, dim=0)

        return output
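
    # Worked example (illustrative): with the Gemma3 default
    # mm_tokens_per_image = 256, tokens_per_tile_side = 16, so a 2x2 grid
    # yields total_rows = total_cols = 32 and
    #     32 * 32 image tokens + (32 - 1) linebreaks = 1024 + 31 = 1055
    # output rows of hidden_size each.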

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        tile_grid_shape: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Get image features for tiled images.

        Args:
            pixel_values: Concatenated tiles tensor of shape [total_tiles, 3, H, W]
            tile_grid_shape: Tensor of shape [num_images, 2] where each row is (grid_h, grid_w).
                If None, falls back to the parent's non-tiled processing.

        Returns:
            Image features tensor of shape [total_tokens, hidden_size]
        """
        if tile_grid_shape is None:
            return super().get_image_features(pixel_values)

        # Match the vision tower's device and dtype so inputs prepared on the
        # CPU (or in a different precision) are moved before encoding.
        vision_weight = self.vision_tower.vision_model.embeddings.patch_embedding.weight
        target_device = vision_weight.device
        target_dtype = vision_weight.dtype

        if isinstance(tile_grid_shape, list):
            tile_grid_shape = torch.tensor(tile_grid_shape, device=target_device)

        if not isinstance(pixel_values, torch.Tensor):
            pixel_values = torch.tensor(pixel_values, dtype=target_dtype, device=target_device)
        else:
            pixel_values = pixel_values.to(device=target_device, dtype=target_dtype)

        # Number of tiles contributed by each image, used to split the
        # concatenated tile batch back into per-image chunks.
        tile_counts = [int(c) for c in (tile_grid_shape[:, 0] * tile_grid_shape[:, 1]).tolist()]
        pixel_splits = torch.split(pixel_values, tile_counts, dim=0)

        # Process each image's tiles independently, then concatenate the
        # resulting token sequences in image order.
        all_features = []
        for pv, grid_shape in zip(pixel_splits, tile_grid_shape.tolist()):
            grid_h, grid_w = int(grid_shape[0]), int(grid_shape[1])
            features = self._process_tiled_image(pv, grid_h, grid_w)
            all_features.append(features)

        return torch.cat(all_features, dim=0)
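
    # Worked example (illustrative): two images with grids (2, 2) and (1, 3)
    # arrive concatenated as pixel_values of shape [7, 3, 896, 896] (4 + 3
    # tiles). With tokens_per_tile_side = 16 the returned features stack
    #     (32 * 32 + 31) + (16 * 48 + 15) = 1055 + 783 = 1838
    # rows of hidden_size each.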

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: torch.FloatTensor,
        tile_grid_shape: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Get mask for placeholder tokens, with validation for tiled images.

        Args:
            input_ids: Input token IDs
            inputs_embeds: Input embeddings
            image_features: Image feature embeddings
            tile_grid_shape: Tensor of shape [num_images, 2] where each row is (grid_h, grid_w).
                If provided, validates against the expected tiled structure.

        Returns:
            Boolean mask tensor
        """
        if input_ids is None:
            # Without token ids, identify placeholder positions by comparing
            # each embedding against the image token's embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum().item()

        if tile_grid_shape is not None:
            tokens_per_tile_side = self.tokens_per_tile_side

            if isinstance(tile_grid_shape, list):
                tile_grid_shape = torch.tensor(tile_grid_shape)

            # Each image contributes one placeholder per spatial token plus
            # one per linebreak (one linebreak after every row but the last).
            expected_total = 0
            for grid_shape in tile_grid_shape.tolist():
                grid_h, grid_w = int(grid_shape[0]), int(grid_shape[1])
                total_rows = grid_h * tokens_per_tile_side
                total_cols = grid_w * tokens_per_tile_side
                expected_img_tokens = total_rows * total_cols
                expected_linebreaks = total_rows - 1
                expected_total += expected_img_tokens + expected_linebreaks

            if n_image_tokens != expected_total:
                raise ValueError(
                    f"Tiled image validation failed: expected {expected_total} tokens "
                    f"for tile grid(s) {tile_grid_shape.tolist()}, but found {n_image_tokens} placeholder tokens"
                )

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: "
                f"tokens: {n_image_tokens}, features: {image_features.numel() // image_features.shape[-1]}"
            )

        return special_image_mask
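
    # Placeholder budget example (illustrative): the processor must emit one
    # placeholder token per output row of _process_tiled_image, including the
    # positions later overwritten by "\n" embeddings. For a single 1x1 image
    # with tokens_per_tile_side = 16:
    #     16 * 16 image tokens + 15 linebreaks = 271 placeholders.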

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        tile_grid_shape: torch.Tensor | None = None,
        **lm_kwargs,
    ):
        """Forward pass with support for tiled images."""
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The image placeholder id may lie outside the text vocabulary;
        # replace it with a valid id (0) before the embedding lookup. The
        # placeholder positions are overwritten with image features below.
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        image_features = None
        # Treat an empty list/tuple of pixel values as "no images".
        has_images = pixel_values is not None and (not isinstance(pixel_values, (list, tuple)) or len(pixel_values) > 0)
        if has_images:
            image_features = self.get_image_features(pixel_values, tile_grid_shape)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

            # The tiled path returns [total_tokens, hidden]; add a leading
            # batch dimension so masked_scatter sees matching ranks.
            if image_features.dim() == 2:
                image_features = image_features.unsqueeze(0)

            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features, tile_grid_shape=tile_grid_shape
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
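            # Note: masked_scatter fills the True positions of the mask, in
            # sequence order, with values read row-major from image_features;
            # e.g. (illustrative) embeds [t0, IMG, IMG, t1] with features
            # [f0, f1] become [t0, f0, f1, t1].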

        # Image features are already merged into inputs_embeds, so the parent
        # forward receives inputs_embeds only (input_ids and pixel_values are
        # dropped to avoid re-merging).
        return super().forward(
            input_ids=None,
            pixel_values=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **lm_kwargs,
        )


class Gemma3TiledForConditionalGeneration(Gemma3ForConditionalGeneration):
    """
    Gemma3 model for conditional generation with tiled image support.

    This is the main model class to use for both training and inference.
    """

    config_class = Gemma3TiledConfig

    def __init__(self, config: Gemma3TiledConfig):
        super().__init__(config)
        # Replace the base multimodal model built by the parent constructor
        # with the tiled variant; the submodule names are unchanged, so
        # parameter names (and thus checkpoints) stay compatible.
        self.model = Gemma3TiledModel(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        tile_grid_shape: torch.Tensor | None = None,
        **lm_kwargs,
    ):
        """Forward pass with tiled image support."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            tile_grid_shape=tile_grid_shape,
            **lm_kwargs,
        )

        hidden_states = outputs[0]

        # logits_to_keep == 0 keeps logits for every position; an int n keeps
        # only the last n positions, and a tensor selects explicit indices.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Standard causal LM shift: position i predicts token i + 1.
            # Compute in float32 for numerical stability.
            logits_float = logits.float()
            shift_logits = logits_float[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # Drop positions masked out by the attention mask.
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()

            loss_fct = nn.CrossEntropyLoss()
            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=getattr(outputs, "image_hidden_states", None),
        )


__all__ = [
    "Gemma3TiledForConditionalGeneration",
    "Gemma3TiledModel",
]
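
# Minimal usage sketch (illustrative; the checkpoint path is a placeholder and
# the accompanying processor is assumed to emit `tile_grid_shape` alongside
# `pixel_values` and the matching placeholder tokens):
#
#     model = Gemma3TiledForConditionalGeneration.from_pretrained("path/to/checkpoint")
#     outputs = model(
#         input_ids=input_ids,                      # includes tiled placeholders
#         pixel_values=pixel_values,                # [total_tiles, 3, 896, 896]
#         tile_grid_shape=torch.tensor([[2, 2]]),   # one image, 2x2 tile grid
#     )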
|
|