Fraser committed on
Commit
99c356d
·
verified ·
1 Parent(s): e29c68a

Update modeling_gemma3_tiled.py

Browse files
Files changed (1) hide show
  1. modeling_gemma3_tiled.py +147 -156
modeling_gemma3_tiled.py CHANGED
@@ -6,12 +6,9 @@ are tiled into grids, with spatial rearrangement of embeddings and
6
  linebreak tokens between rows.
7
  """
8
 
9
- from typing import Optional, Union
10
-
11
  import torch
12
- import torch.nn as nn
13
-
14
- from transformers import Gemma3ForConditionalGeneration, Gemma3Model, AutoTokenizer
15
  from transformers.cache_utils import Cache
16
 
17
  from .configuration_gemma3_tiled import Gemma3TiledConfig
@@ -20,171 +17,167 @@ from .configuration_gemma3_tiled import Gemma3TiledConfig
20
  class Gemma3TiledModel(Gemma3Model):
21
  """
22
  Gemma3 model with tiled image support.
23
-
24
  Key differences from Gemma3Model:
25
  - get_image_features() handles tile grids and spatial rearrangement
26
  - get_placeholder_mask() validates tiled structure
27
  - Inserts linebreak embeddings (from "\n" token) between rows
28
  """
29
-
30
  config_class = Gemma3TiledConfig
31
-
32
  def __init__(self, config: Gemma3TiledConfig):
33
  super().__init__(config)
34
  self.tokens_per_tile = config.mm_tokens_per_image # 256
35
- self.tokens_per_tile_side = int(self.tokens_per_tile ** 0.5) # 16
36
 
37
  # Look up newline token ID from tokenizer vocab
38
  tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
39
  vocab = tokenizer.get_vocab()
40
  if "\n" not in vocab:
41
- raise ValueError(f"Tokenizer vocab does not contain '\\n' token")
42
  self._linebreak_token_id = vocab["\n"]
43
 
44
  def get_linebreak_embedding(self) -> torch.Tensor:
45
  """Get the embedding for the linebreak token."""
46
  embedding_layer = self.get_input_embeddings()
47
  return embedding_layer.weight[self._linebreak_token_id]
48
-
49
- def get_image_features_tiled(
50
  self,
51
  pixel_values: torch.Tensor,
52
- tile_grid_shape: tuple[int, int],
 
53
  ) -> torch.Tensor:
54
  """
55
- Process tiled image and return spatially arranged embeddings with linebreaks.
56
-
57
  Args:
58
  pixel_values: Tensor of shape [num_tiles, 3, 896, 896]
59
- tile_grid_shape: Tuple of (grid_h, grid_w)
60
-
 
61
  Returns:
62
  Tensor of shape [total_tokens, hidden_size] where:
63
  total_tokens = (grid_h * 16) * (grid_w * 16) + (grid_h * 16 - 1)
64
  """
65
- grid_h, grid_w = tile_grid_shape
66
  num_tiles = grid_h * grid_w
67
-
68
  assert pixel_values.shape[0] == num_tiles, (
69
- f"Expected {num_tiles} tiles for {grid_h}x{grid_w} grid, "
70
- f"got {pixel_values.shape[0]}"
71
  )
72
-
73
  # Process each tile through vision tower
74
  vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
75
-
76
  # Project through multimodal projector
77
  # Output shape: [num_tiles, 256, hidden_size]
78
  tile_embeds = self.multi_modal_projector(vision_outputs)
79
-
80
  # Reshape to spatial grid
81
  # [num_tiles, 256, hidden] -> [grid_h, grid_w, 16, 16, hidden]
82
  hidden_size = tile_embeds.shape[-1]
83
  tile_embeds = tile_embeds.view(
84
- grid_h, grid_w,
85
- self.tokens_per_tile_side, self.tokens_per_tile_side,
86
- hidden_size
87
  )
88
-
89
  # Rearrange to merge tiles spatially
90
  # We want: for each row of tiles, merge their columns
91
  # [grid_h, grid_w, 16, 16, hidden] -> [grid_h, 16, grid_w, 16, hidden]
92
  tile_embeds = tile_embeds.permute(0, 2, 1, 3, 4)
93
-
94
  # Merge into full spatial grid
95
  # [grid_h, 16, grid_w, 16, hidden] -> [grid_h * 16, grid_w * 16, hidden]
96
  total_rows = grid_h * self.tokens_per_tile_side
97
  total_cols = grid_w * self.tokens_per_tile_side
98
  tile_embeds = tile_embeds.reshape(total_rows, total_cols, hidden_size)
99
-
100
  # Now insert linebreak embeddings between rows
101
  linebreak_emb = self.get_linebreak_embedding() # [hidden_size]
102
-
103
  # Build output by interleaving rows with linebreaks
104
  output_parts = []
105
  for row_idx in range(total_rows):
106
  # Add the row (all columns)
107
  row = tile_embeds[row_idx] # [total_cols, hidden_size]
108
  output_parts.append(row)
109
-
110
  # Add linebreak after each row except the last
111
  if row_idx < total_rows - 1:
112
  output_parts.append(linebreak_emb.unsqueeze(0)) # [1, hidden_size]
113
-
114
  # Concatenate all parts
115
  output = torch.cat(output_parts, dim=0) # [total_tokens, hidden_size]
116
-
117
  return output
118
-
119
  def get_image_features(
120
  self,
121
- pixel_values,
122
- tile_grid_shape=None,
123
  ) -> torch.Tensor:
124
  """
125
- Get image features, handling both single images and tiled images.
126
- Supports batched inputs where pixel_values is a list and tile_grid_shape is a list of tuples.
127
-
128
  Args:
129
- pixel_values: Image tensor(s) - can be a single tensor or list of tensors
130
- tile_grid_shape: If provided, treats input as tiled. Can be a single tuple or list of tuples.
131
-
 
132
  Returns:
133
- Image features tensor
134
  """
135
  if tile_grid_shape is None:
136
- # Standard single-image processing
137
  return super().get_image_features(pixel_values)
138
-
139
  # Get device and dtype from vision tower weights
140
  vision_weight = self.vision_tower.vision_model.embeddings.patch_embedding.weight
141
  target_device = vision_weight.device
142
  target_dtype = vision_weight.dtype
143
-
144
- # Handle batched inputs: pixel_values is list of tensors, tile_grid_shape is list of tuples
145
  if isinstance(tile_grid_shape, list):
146
- all_features = []
147
- # pixel_values can be list of numpy arrays or tensors
148
- if isinstance(pixel_values, (list, tuple)):
149
- for pv, grid_shape in zip(pixel_values, tile_grid_shape):
150
- # Convert to tensor if needed and move to correct device/dtype
151
- if not isinstance(pv, torch.Tensor):
152
- pv = torch.tensor(pv, dtype=target_dtype, device=target_device)
153
- else:
154
- pv = pv.to(device=target_device, dtype=target_dtype)
155
- features = self.get_image_features_tiled(pv, grid_shape)
156
- all_features.append(features)
157
- # Concatenate all image features
158
- return torch.cat(all_features, dim=0)
159
- else:
160
- # pixel_values is already concatenated, but we have multiple grid shapes
161
- # This shouldn't happen with proper preprocessing, fall back to first grid shape
162
- return self.get_image_features_tiled(pixel_values, tile_grid_shape[0])
163
  else:
164
- # Single image case - ensure correct device/dtype
165
- if not isinstance(pixel_values, torch.Tensor):
166
- pixel_values = torch.tensor(pixel_values, dtype=target_dtype, device=target_device)
167
- else:
168
- pixel_values = pixel_values.to(device=target_device, dtype=target_dtype)
169
- return self.get_image_features_tiled(pixel_values, tile_grid_shape)
170
-
 
 
 
 
 
 
 
 
 
 
171
  def get_placeholder_mask(
172
  self,
173
  input_ids: torch.LongTensor,
174
  inputs_embeds: torch.FloatTensor,
175
  image_features: torch.FloatTensor,
176
- tile_grid_shape=None,
177
  ) -> torch.Tensor:
178
  """
179
  Get mask for placeholder tokens, with validation for tiled images.
180
-
181
  Args:
182
  input_ids: Input token IDs
183
  inputs_embeds: Input embeddings
184
  image_features: Image feature embeddings
185
- tile_grid_shape: If provided, validates against expected tiled structure.
186
- Can be a single tuple or list of tuples.
187
-
188
  Returns:
189
  Boolean mask tensor
190
  """
@@ -195,76 +188,73 @@ class Gemma3TiledModel(Gemma3Model):
195
  special_image_mask = special_image_mask.all(-1)
196
  else:
197
  special_image_mask = input_ids == self.config.image_token_id
198
-
199
  n_image_tokens = special_image_mask.sum().item()
200
-
201
  # Validate tiled structure if applicable
202
  if tile_grid_shape is not None:
203
- tokens_per_tile_side = int(self.config.mm_tokens_per_image ** 0.5)
204
-
205
- # Handle list of grid shapes (multiple images)
206
  if isinstance(tile_grid_shape, list):
207
- expected_total = 0
208
- for grid_h, grid_w in tile_grid_shape:
209
- total_rows = grid_h * tokens_per_tile_side
210
- total_cols = grid_w * tokens_per_tile_side
211
- expected_img_tokens = total_rows * total_cols
212
- expected_linebreaks = total_rows - 1
213
- expected_total += expected_img_tokens + expected_linebreaks
214
- else:
215
- grid_h, grid_w = tile_grid_shape
216
  total_rows = grid_h * tokens_per_tile_side
217
  total_cols = grid_w * tokens_per_tile_side
218
  expected_img_tokens = total_rows * total_cols
219
  expected_linebreaks = total_rows - 1
220
- expected_total = expected_img_tokens + expected_linebreaks
221
-
222
  if n_image_tokens != expected_total:
223
  raise ValueError(
224
  f"Tiled image validation failed: expected {expected_total} tokens "
225
- f"for tile grid(s) {tile_grid_shape}, but found {n_image_tokens} placeholder tokens"
226
  )
227
-
228
  # Standard validation
229
  special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
230
-
231
  if inputs_embeds[special_image_mask].numel() != image_features.numel():
232
  raise ValueError(
233
  f"Image features and image tokens do not match: "
234
  f"tokens: {n_image_tokens}, features: {image_features.numel() // image_features.shape[-1]}"
235
  )
236
-
237
  return special_image_mask
238
-
239
  def forward(
240
  self,
241
- input_ids: Optional[torch.LongTensor] = None,
242
- pixel_values: Optional[torch.FloatTensor] = None,
243
- attention_mask: Optional[torch.Tensor] = None,
244
- position_ids: Optional[torch.LongTensor] = None,
245
- past_key_values: Optional[Cache] = None,
246
- token_type_ids: Optional[torch.LongTensor] = None,
247
- cache_position: Optional[torch.LongTensor] = None,
248
- inputs_embeds: Optional[torch.FloatTensor] = None,
249
- labels: Optional[torch.LongTensor] = None,
250
- use_cache: Optional[bool] = None,
251
- output_attentions: Optional[bool] = None,
252
- output_hidden_states: Optional[bool] = None,
253
- return_dict: Optional[bool] = None,
254
- tile_grid_shape: Optional[tuple[int, int]] = None, # NEW
255
  **lm_kwargs,
256
  ):
257
  """Forward pass with support for tiled images."""
258
-
259
  if (input_ids is None) ^ (inputs_embeds is not None):
260
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
261
-
262
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
263
  output_hidden_states = (
264
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
265
  )
266
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
267
-
268
  # Replace image id with PAD if the image token is OOV
269
  if input_ids is not None and self.config.image_token_id >= self.vocab_size:
270
  special_image_mask = input_ids == self.config.image_token_id
@@ -272,37 +262,38 @@ class Gemma3TiledModel(Gemma3Model):
272
  llm_input_ids[special_image_mask] = 0
273
  else:
274
  llm_input_ids = input_ids
275
-
276
  if inputs_embeds is None:
277
  inputs_embeds = self.get_input_embeddings()(llm_input_ids)
278
-
279
  if cache_position is None:
280
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
281
  cache_position = torch.arange(
282
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
283
  )
284
-
285
  # Merge text and images
286
  image_features = None
287
- if pixel_values is not None:
 
 
288
  # Get image features (handles tiled if tile_grid_shape provided)
289
  image_features = self.get_image_features(pixel_values, tile_grid_shape)
290
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
291
-
292
  # Ensure correct shape for scatter
293
  if image_features.dim() == 2:
294
  # [total_tokens, hidden] -> [1, total_tokens, hidden] for batch dim
295
  image_features = image_features.unsqueeze(0)
296
-
297
  special_image_mask = self.get_placeholder_mask(
298
- input_ids, inputs_embeds=inputs_embeds, image_features=image_features,
299
- tile_grid_shape=tile_grid_shape
300
  )
301
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
302
-
303
  # Rest is same as parent - create attention masks and run through LM
304
  # ... (inheriting the attention mask logic from parent)
305
-
306
  return super().forward(
307
  input_ids=None, # We've already embedded
308
  pixel_values=None, # Already processed
@@ -324,44 +315,44 @@ class Gemma3TiledModel(Gemma3Model):
324
  class Gemma3TiledForConditionalGeneration(Gemma3ForConditionalGeneration):
325
  """
326
  Gemma3 model for conditional generation with tiled image support.
327
-
328
  This is the main model class to use for both training and inference.
329
  """
330
-
331
  config_class = Gemma3TiledConfig
332
-
333
  def __init__(self, config: Gemma3TiledConfig):
334
  super().__init__(config)
335
  # Replace the model with our tiled version
336
  self.model = Gemma3TiledModel(config)
337
-
338
  def forward(
339
  self,
340
- input_ids: Optional[torch.LongTensor] = None,
341
- pixel_values: Optional[torch.FloatTensor] = None,
342
- attention_mask: Optional[torch.Tensor] = None,
343
- position_ids: Optional[torch.LongTensor] = None,
344
- past_key_values: Optional[Cache] = None,
345
- token_type_ids: Optional[torch.LongTensor] = None,
346
- cache_position: Optional[torch.LongTensor] = None,
347
- inputs_embeds: Optional[torch.FloatTensor] = None,
348
- labels: Optional[torch.LongTensor] = None,
349
- use_cache: Optional[bool] = None,
350
- output_attentions: Optional[bool] = None,
351
- output_hidden_states: Optional[bool] = None,
352
- return_dict: Optional[bool] = None,
353
- logits_to_keep: Union[int, torch.Tensor] = 0,
354
- tile_grid_shape: Optional[tuple[int, int]] = None, # NEW
355
  **lm_kwargs,
356
  ):
357
  """Forward pass with tiled image support."""
358
-
359
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
  output_hidden_states = (
361
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
  )
363
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
-
365
  outputs = self.model(
366
  input_ids=input_ids,
367
  pixel_values=pixel_values,
@@ -379,13 +370,13 @@ class Gemma3TiledForConditionalGeneration(Gemma3ForConditionalGeneration):
379
  tile_grid_shape=tile_grid_shape, # Pass through
380
  **lm_kwargs,
381
  )
382
-
383
  hidden_states = outputs[0]
384
-
385
  # Compute logits
386
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
387
  logits = self.lm_head(hidden_states[:, slice_indices, :])
388
-
389
  loss = None
390
  if labels is not None:
391
  # Use parent's loss computation logic
@@ -393,35 +384,35 @@ class Gemma3TiledForConditionalGeneration(Gemma3ForConditionalGeneration):
393
  shift_logits = logits_float[..., :-1, :]
394
  shift_labels = labels[..., 1:]
395
  if attention_mask is not None:
396
- shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
397
  shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
398
  shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
399
  else:
400
  shift_logits = shift_logits.contiguous()
401
  shift_labels = shift_labels.contiguous()
402
-
403
  loss_fct = nn.CrossEntropyLoss()
404
  flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
405
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
406
  loss = loss_fct(flat_logits, flat_labels)
407
-
408
  if not return_dict:
409
  output = (logits,) + outputs[1:]
410
  return (loss,) + output if loss is not None else output
411
-
412
  from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
413
-
414
  return Gemma3CausalLMOutputWithPast(
415
  loss=loss,
416
  logits=logits,
417
  past_key_values=outputs.past_key_values,
418
  hidden_states=outputs.hidden_states,
419
  attentions=outputs.attentions,
420
- image_hidden_states=getattr(outputs, 'image_hidden_states', None),
421
  )
422
 
423
 
424
  __all__ = [
425
- "Gemma3TiledModel",
426
  "Gemma3TiledForConditionalGeneration",
 
427
  ]
 
6
  linebreak tokens between rows.
7
  """
8
 
 
 
9
  import torch
10
+ from torch import nn
11
+ from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Model
 
12
  from transformers.cache_utils import Cache
13
 
14
  from .configuration_gemma3_tiled import Gemma3TiledConfig
 
17
  class Gemma3TiledModel(Gemma3Model):
18
  """
19
  Gemma3 model with tiled image support.
20
+
21
  Key differences from Gemma3Model:
22
  - get_image_features() handles tile grids and spatial rearrangement
23
  - get_placeholder_mask() validates tiled structure
24
  - Inserts linebreak embeddings (from "\n" token) between rows
25
  """
26
+
27
  config_class = Gemma3TiledConfig
28
+
29
  def __init__(self, config: Gemma3TiledConfig):
30
  super().__init__(config)
31
  self.tokens_per_tile = config.mm_tokens_per_image # 256
32
+ self.tokens_per_tile_side = int(self.tokens_per_tile**0.5) # 16
33
 
34
  # Look up newline token ID from tokenizer vocab
35
  tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
36
  vocab = tokenizer.get_vocab()
37
  if "\n" not in vocab:
38
+ raise ValueError("Tokenizer vocab does not contain '\\n' token")
39
  self._linebreak_token_id = vocab["\n"]
40
 
41
  def get_linebreak_embedding(self) -> torch.Tensor:
42
  """Get the embedding for the linebreak token."""
43
  embedding_layer = self.get_input_embeddings()
44
  return embedding_layer.weight[self._linebreak_token_id]
45
+
46
+ def _process_tiled_image(
47
  self,
48
  pixel_values: torch.Tensor,
49
+ grid_h: int,
50
+ grid_w: int,
51
  ) -> torch.Tensor:
52
  """
53
+ Process a single tiled image and return spatially arranged embeddings with linebreaks.
54
+
55
  Args:
56
  pixel_values: Tensor of shape [num_tiles, 3, 896, 896]
57
+ grid_h: Number of tile rows
58
+ grid_w: Number of tile columns
59
+
60
  Returns:
61
  Tensor of shape [total_tokens, hidden_size] where:
62
  total_tokens = (grid_h * 16) * (grid_w * 16) + (grid_h * 16 - 1)
63
  """
 
64
  num_tiles = grid_h * grid_w
65
+
66
  assert pixel_values.shape[0] == num_tiles, (
67
+ f"Expected {num_tiles} tiles for {grid_h}x{grid_w} grid, got {pixel_values.shape[0]}"
 
68
  )
69
+
70
  # Process each tile through vision tower
71
  vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
72
+
73
  # Project through multimodal projector
74
  # Output shape: [num_tiles, 256, hidden_size]
75
  tile_embeds = self.multi_modal_projector(vision_outputs)
76
+
77
  # Reshape to spatial grid
78
  # [num_tiles, 256, hidden] -> [grid_h, grid_w, 16, 16, hidden]
79
  hidden_size = tile_embeds.shape[-1]
80
  tile_embeds = tile_embeds.view(
81
+ grid_h, grid_w, self.tokens_per_tile_side, self.tokens_per_tile_side, hidden_size
 
 
82
  )
83
+
84
  # Rearrange to merge tiles spatially
85
  # We want: for each row of tiles, merge their columns
86
  # [grid_h, grid_w, 16, 16, hidden] -> [grid_h, 16, grid_w, 16, hidden]
87
  tile_embeds = tile_embeds.permute(0, 2, 1, 3, 4)
88
+
89
  # Merge into full spatial grid
90
  # [grid_h, 16, grid_w, 16, hidden] -> [grid_h * 16, grid_w * 16, hidden]
91
  total_rows = grid_h * self.tokens_per_tile_side
92
  total_cols = grid_w * self.tokens_per_tile_side
93
  tile_embeds = tile_embeds.reshape(total_rows, total_cols, hidden_size)
94
+
95
  # Now insert linebreak embeddings between rows
96
  linebreak_emb = self.get_linebreak_embedding() # [hidden_size]
97
+
98
  # Build output by interleaving rows with linebreaks
99
  output_parts = []
100
  for row_idx in range(total_rows):
101
  # Add the row (all columns)
102
  row = tile_embeds[row_idx] # [total_cols, hidden_size]
103
  output_parts.append(row)
104
+
105
  # Add linebreak after each row except the last
106
  if row_idx < total_rows - 1:
107
  output_parts.append(linebreak_emb.unsqueeze(0)) # [1, hidden_size]
108
+
109
  # Concatenate all parts
110
  output = torch.cat(output_parts, dim=0) # [total_tokens, hidden_size]
111
+
112
  return output
113
+
114
  def get_image_features(
115
  self,
116
+ pixel_values: torch.Tensor,
117
+ tile_grid_shape: torch.Tensor | None = None,
118
  ) -> torch.Tensor:
119
  """
120
+ Get image features for tiled images.
121
+
 
122
  Args:
123
+ pixel_values: Concatenated tiles tensor of shape [total_tiles, 3, H, W]
124
+ tile_grid_shape: Tensor of shape [num_images, 2] where each row is (grid_h, grid_w).
125
+ If None, falls back to parent's non-tiled processing.
126
+
127
  Returns:
128
+ Image features tensor of shape [total_tokens, hidden_size]
129
  """
130
  if tile_grid_shape is None:
131
+ # Standard single-image processing (non-tiled)
132
  return super().get_image_features(pixel_values)
133
+
134
  # Get device and dtype from vision tower weights
135
  vision_weight = self.vision_tower.vision_model.embeddings.patch_embedding.weight
136
  target_device = vision_weight.device
137
  target_dtype = vision_weight.dtype
138
+
139
+ # Normalize tile_grid_shape: list -> tensor
140
  if isinstance(tile_grid_shape, list):
141
+ tile_grid_shape = torch.tensor(tile_grid_shape, device=target_device)
142
+
143
+ # Ensure pixel_values is tensor on correct device/dtype
144
+ if not isinstance(pixel_values, torch.Tensor):
145
+ pixel_values = torch.tensor(pixel_values, dtype=target_dtype, device=target_device)
 
 
 
 
 
 
 
 
 
 
 
 
146
  else:
147
+ pixel_values = pixel_values.to(device=target_device, dtype=target_dtype)
148
+
149
+ # Calculate tile counts per image for splitting concatenated pixel_values
150
+ tile_counts = (tile_grid_shape[:, 0] * tile_grid_shape[:, 1]).tolist()
151
+
152
+ # Split concatenated pixel_values by image
153
+ pixel_splits = torch.split(pixel_values, tile_counts, dim=0)
154
+
155
+ # Process each image
156
+ all_features = []
157
+ for pv, grid_shape in zip(pixel_splits, tile_grid_shape.tolist()):
158
+ grid_h, grid_w = int(grid_shape[0]), int(grid_shape[1])
159
+ features = self._process_tiled_image(pv, grid_h, grid_w)
160
+ all_features.append(features)
161
+
162
+ return torch.cat(all_features, dim=0)
163
+
164
  def get_placeholder_mask(
165
  self,
166
  input_ids: torch.LongTensor,
167
  inputs_embeds: torch.FloatTensor,
168
  image_features: torch.FloatTensor,
169
+ tile_grid_shape: torch.Tensor | None = None,
170
  ) -> torch.Tensor:
171
  """
172
  Get mask for placeholder tokens, with validation for tiled images.
173
+
174
  Args:
175
  input_ids: Input token IDs
176
  inputs_embeds: Input embeddings
177
  image_features: Image feature embeddings
178
+ tile_grid_shape: Tensor of shape [num_images, 2] where each row is (grid_h, grid_w).
179
+ If provided, validates against expected tiled structure.
180
+
181
  Returns:
182
  Boolean mask tensor
183
  """
 
188
  special_image_mask = special_image_mask.all(-1)
189
  else:
190
  special_image_mask = input_ids == self.config.image_token_id
191
+
192
  n_image_tokens = special_image_mask.sum().item()
193
+
194
  # Validate tiled structure if applicable
195
  if tile_grid_shape is not None:
196
+ tokens_per_tile_side = int(self.config.mm_tokens_per_image**0.5)
197
+
198
+ # Normalize to tensor if list
199
  if isinstance(tile_grid_shape, list):
200
+ tile_grid_shape = torch.tensor(tile_grid_shape)
201
+
202
+ # Calculate expected tokens for all images
203
+ expected_total = 0
204
+ for grid_shape in tile_grid_shape.tolist():
205
+ grid_h, grid_w = int(grid_shape[0]), int(grid_shape[1])
 
 
 
206
  total_rows = grid_h * tokens_per_tile_side
207
  total_cols = grid_w * tokens_per_tile_side
208
  expected_img_tokens = total_rows * total_cols
209
  expected_linebreaks = total_rows - 1
210
+ expected_total += expected_img_tokens + expected_linebreaks
211
+
212
  if n_image_tokens != expected_total:
213
  raise ValueError(
214
  f"Tiled image validation failed: expected {expected_total} tokens "
215
+ f"for tile grid(s) {tile_grid_shape.tolist()}, but found {n_image_tokens} placeholder tokens"
216
  )
217
+
218
  # Standard validation
219
  special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
220
+
221
  if inputs_embeds[special_image_mask].numel() != image_features.numel():
222
  raise ValueError(
223
  f"Image features and image tokens do not match: "
224
  f"tokens: {n_image_tokens}, features: {image_features.numel() // image_features.shape[-1]}"
225
  )
226
+
227
  return special_image_mask
228
+
229
  def forward(
230
  self,
231
+ input_ids: torch.LongTensor | None = None,
232
+ pixel_values: torch.FloatTensor | None = None,
233
+ attention_mask: torch.Tensor | None = None,
234
+ position_ids: torch.LongTensor | None = None,
235
+ past_key_values: Cache | None = None,
236
+ token_type_ids: torch.LongTensor | None = None,
237
+ cache_position: torch.LongTensor | None = None,
238
+ inputs_embeds: torch.FloatTensor | None = None,
239
+ labels: torch.LongTensor | None = None,
240
+ use_cache: bool | None = None,
241
+ output_attentions: bool | None = None,
242
+ output_hidden_states: bool | None = None,
243
+ return_dict: bool | None = None,
244
+ tile_grid_shape: torch.Tensor | None = None,
245
  **lm_kwargs,
246
  ):
247
  """Forward pass with support for tiled images."""
248
+
249
  if (input_ids is None) ^ (inputs_embeds is not None):
250
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
251
+
252
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
253
  output_hidden_states = (
254
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
255
  )
256
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
257
+
258
  # Replace image id with PAD if the image token is OOV
259
  if input_ids is not None and self.config.image_token_id >= self.vocab_size:
260
  special_image_mask = input_ids == self.config.image_token_id
 
262
  llm_input_ids[special_image_mask] = 0
263
  else:
264
  llm_input_ids = input_ids
265
+
266
  if inputs_embeds is None:
267
  inputs_embeds = self.get_input_embeddings()(llm_input_ids)
268
+
269
  if cache_position is None:
270
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
271
  cache_position = torch.arange(
272
  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
273
  )
274
+
275
  # Merge text and images
276
  image_features = None
277
+ # Check for non-empty pixel_values (empty list would pass "is not None" check)
278
+ has_images = pixel_values is not None and (not isinstance(pixel_values, (list, tuple)) or len(pixel_values) > 0)
279
+ if has_images:
280
  # Get image features (handles tiled if tile_grid_shape provided)
281
  image_features = self.get_image_features(pixel_values, tile_grid_shape)
282
  image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
283
+
284
  # Ensure correct shape for scatter
285
  if image_features.dim() == 2:
286
  # [total_tokens, hidden] -> [1, total_tokens, hidden] for batch dim
287
  image_features = image_features.unsqueeze(0)
288
+
289
  special_image_mask = self.get_placeholder_mask(
290
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features, tile_grid_shape=tile_grid_shape
 
291
  )
292
  inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
293
+
294
  # Rest is same as parent - create attention masks and run through LM
295
  # ... (inheriting the attention mask logic from parent)
296
+
297
  return super().forward(
298
  input_ids=None, # We've already embedded
299
  pixel_values=None, # Already processed
 
315
  class Gemma3TiledForConditionalGeneration(Gemma3ForConditionalGeneration):
316
  """
317
  Gemma3 model for conditional generation with tiled image support.
318
+
319
  This is the main model class to use for both training and inference.
320
  """
321
+
322
  config_class = Gemma3TiledConfig
323
+
324
  def __init__(self, config: Gemma3TiledConfig):
325
  super().__init__(config)
326
  # Replace the model with our tiled version
327
  self.model = Gemma3TiledModel(config)
328
+
329
  def forward(
330
  self,
331
+ input_ids: torch.LongTensor | None = None,
332
+ pixel_values: torch.FloatTensor | None = None,
333
+ attention_mask: torch.Tensor | None = None,
334
+ position_ids: torch.LongTensor | None = None,
335
+ past_key_values: Cache | None = None,
336
+ token_type_ids: torch.LongTensor | None = None,
337
+ cache_position: torch.LongTensor | None = None,
338
+ inputs_embeds: torch.FloatTensor | None = None,
339
+ labels: torch.LongTensor | None = None,
340
+ use_cache: bool | None = None,
341
+ output_attentions: bool | None = None,
342
+ output_hidden_states: bool | None = None,
343
+ return_dict: bool | None = None,
344
+ logits_to_keep: int | torch.Tensor = 0,
345
+ tile_grid_shape: torch.Tensor | None = None,
346
  **lm_kwargs,
347
  ):
348
  """Forward pass with tiled image support."""
349
+
350
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
351
  output_hidden_states = (
352
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
353
  )
354
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
355
+
356
  outputs = self.model(
357
  input_ids=input_ids,
358
  pixel_values=pixel_values,
 
370
  tile_grid_shape=tile_grid_shape, # Pass through
371
  **lm_kwargs,
372
  )
373
+
374
  hidden_states = outputs[0]
375
+
376
  # Compute logits
377
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
378
  logits = self.lm_head(hidden_states[:, slice_indices, :])
379
+
380
  loss = None
381
  if labels is not None:
382
  # Use parent's loss computation logic
 
384
  shift_logits = logits_float[..., :-1, :]
385
  shift_labels = labels[..., 1:]
386
  if attention_mask is not None:
387
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
388
  shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
389
  shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
390
  else:
391
  shift_logits = shift_logits.contiguous()
392
  shift_labels = shift_labels.contiguous()
393
+
394
  loss_fct = nn.CrossEntropyLoss()
395
  flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
396
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
397
  loss = loss_fct(flat_logits, flat_labels)
398
+
399
  if not return_dict:
400
  output = (logits,) + outputs[1:]
401
  return (loss,) + output if loss is not None else output
402
+
403
  from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
404
+
405
  return Gemma3CausalLMOutputWithPast(
406
  loss=loss,
407
  logits=logits,
408
  past_key_values=outputs.past_key_values,
409
  hidden_states=outputs.hidden_states,
410
  attentions=outputs.attentions,
411
+ image_hidden_states=getattr(outputs, "image_hidden_states", None),
412
  )
413
 
414
 
415
  __all__ = [
 
416
  "Gemma3TiledForConditionalGeneration",
417
+ "Gemma3TiledModel",
418
  ]