yiyexy committed on
Commit
ceb04eb
·
verified ·
1 Parent(s): 3b2a6ee

Update modeling_onevision_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_onevision_encoder.py +1327 -345
modeling_onevision_encoder.py CHANGED
@@ -1,113 +1,91 @@
1
- from typing import Optional, Tuple, Union
 
2
 
3
  import torch
4
  import torch.nn as nn
 
5
 
6
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
 
 
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.models.siglip.modeling_siglip import SiglipMLP
 
9
  from transformers.utils import (
10
- add_start_docstrings,
11
- add_start_docstrings_to_model_forward,
12
- logging,
 
13
  replace_return_docstrings,
14
  )
15
 
16
- from .configuration_onevision_encoder import OneVisionEncoderConfig
17
 
18
 
19
- try:
20
  from flash_attn import flash_attn_func
21
 
22
- _flash_attn_available = True
23
- except ImportError:
24
- _flash_attn_available = False
25
 
26
- logger = logging.get_logger(__name__)
27
-
28
-
29
- # ---------------------------------------------------------------------------
30
- # Model Docstrings
31
- # ---------------------------------------------------------------------------
32
-
33
- ONEVISION_ENCODER_START_DOCSTRING = r"""
34
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
35
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
36
- etc.)
37
-
38
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
39
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
40
- and behavior.
41
-
42
- Parameters:
43
- config ([`OneVisionEncoderConfig`]): Model configuration class with all the parameters of the model.
44
- Initializing with a config file does not load the weights associated with the model, only the
45
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
46
- """
47
-
48
- ONEVISION_ENCODER_INPUTS_DOCSTRING = r"""
49
- Args:
50
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_channels, num_frames, height, width)`):
51
- Pixel values. Pixel values can be obtained using [`AutoImageProcessor`].
52
- visible_indices (`torch.Tensor`, *optional*):
53
- Indices of visible patches for masking. Used in MAE-style pretraining or inference.
54
- output_attentions (`bool`, *optional*):
55
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
56
- tensors for more detail.
57
- output_hidden_states (`bool`, *optional*):
58
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
59
- more detail.
60
- return_dict (`bool`, *optional*):
61
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
62
- """
63
-
64
-
65
- # ---------------------------------------------------------------------------
66
- # Helper Functions & Layers
67
- # ---------------------------------------------------------------------------
68
 
 
 
 
69
 
70
- def get_norm_layer(config):
71
- if config.layer_norm_type == "rms_norm":
72
- return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
73
- else:
74
- return nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
75
 
76
 
77
def rotate_half(x):
    """Interleaved half-rotation used by this model's RoPE formulation.

    Maps (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...) along the last
    dimension, i.e. each even/odd channel pair is rotated by 90 degrees.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    paired = torch.stack((-odds, evens), dim=-1)
    return paired.flatten(-2)
85
-
86
-
87
def apply_rotary_pos_emb(q, k, freqs):
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q, k: (B, H, L, D) attention projections.
        freqs: (B, L, D) per-position rotation angles, broadcast over heads.

    Returns:
        The rotated (q, k) pair, same shapes and dtype as the inputs.

    Note: the cos/sin tables are cast to ``q.dtype`` immediately. ``freqs``
    is typically float32, and without this cast ``q * cos`` would upcast the
    activations to float32, which breaks FlashAttention downstream.
    """
    # (B, L, D) -> (B, 1, L, D) so the tables broadcast across heads.
    cos = freqs.cos().unsqueeze(1).to(q.dtype)
    sin = freqs.sin().unsqueeze(1).to(q.dtype)

    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
103
 
104
 
105
- class VideoRotaryEmbeddingSplit466(nn.Module):
106
  """
107
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
 
108
  """
109
 
110
- def __init__(self, config: OneVisionEncoderConfig):
111
  super().__init__()
112
  head_dim = config.hidden_size // config.num_attention_heads
113
  base = config.rope_theta
@@ -120,6 +98,7 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
120
  self.head_dim = head_dim
121
  self.half = half
122
 
 
123
  unit = half // 16
124
  self.t_size = 4 * unit
125
  self.h_size = 6 * unit
@@ -141,90 +120,121 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
141
  persistent=False,
142
  )
143
 
144
def forward(self, t: int, h: int, w: int, device=None):
    """Dense-grid 3D RoPE angle table for a (t, h, w) patch grid.

    Args:
        t: number of temporal frames.
        h: number of patch rows.
        w: number of patch columns.
        device: target device; defaults to where the inv-freq buffers live.

    Returns:
        (t*h*w, half) float32 tensor; rows are ordered t-major, then h,
        then w, matching the patch flattening order.
    """
    if device is None:
        device = self.inv_freq_t.device

    # Per-axis angle tables: (axis_len, axis_dim).
    freq_tables = [
        torch.outer(torch.arange(size, device=device, dtype=torch.float32), inv.to(device=device))
        for size, inv in (
            (t, self.inv_freq_t),
            (h, self.inv_freq_h),
            (w, self.inv_freq_w),
        )
    ]

    # Flattened grid coordinates, one entry per patch (t-major order).
    coords = (
        torch.arange(t, device=device).repeat_interleave(h * w),
        torch.arange(h, device=device).repeat_interleave(w).repeat(t),
        torch.arange(w, device=device).repeat(h).repeat(t),
    )

    return torch.cat([table[idx] for table, idx in zip(freq_tables, coords)], dim=-1)
 
 
 
 
 
 
 
 
162
 
163
def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
    """Compute RoPE angles from explicit per-patch positions.

    Args:
        patch_positions: (batch, seq_len, 3) tensor holding the
            [t, h, w] coordinates of every patch.

    Returns:
        (batch, seq_len, half) float32 angle tensor.
    """
    device = patch_positions.device
    coords = patch_positions.float()

    # Batched outer product per axis: (B, S, 1) * (d,) -> (B, S, d),
    # equivalent to einsum("bs,d->bsd").
    per_axis = [
        coords[..., axis].unsqueeze(-1) * inv.to(device=device)
        for axis, inv in enumerate((self.inv_freq_t, self.inv_freq_h, self.inv_freq_w))
    ]
    return torch.cat(per_axis, dim=-1)
188
 
 
 
 
189
 
190
class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    """Attention pooling with a single learned probe token (PMA-style).

    A one-token probe attends over the whole sequence via multi-head
    attention, and the pooled vector is refined by a norm + MLP residual
    branch before the probe position is returned.
    """

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Learned query token shared across the batch.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_states):
        # Tile the probe so every sample gets its own query token.
        queries = self.probe.repeat(hidden_states.shape[0], 1, 1)
        pooled, _ = self.attention(queries, hidden_states, hidden_states)
        # Residual MLP refinement: pooled + MLP(norm(pooled)).
        pooled = pooled + self.mlp(self.norm(pooled))
        # Drop the probe axis: (B, 1, C) -> (B, C).
        return pooled[:, 0]
 
214
 
215
 
216
  # ---------------------------------------------------------------------------
217
- # Modeling Components
218
  # ---------------------------------------------------------------------------
219
 
220
 
221
- class OneVisionEncoderEmbeddings(nn.Module):
222
- def __init__(self, config: OneVisionEncoderConfig):
 
 
 
 
 
 
 
 
 
 
223
  super().__init__()
224
  self.config = config
225
  self.embed_dim = config.hidden_size
226
  self.image_size = config.image_size
227
  self.patch_size = config.patch_size
 
228
 
229
  self.patch_embedding = nn.Conv2d(
230
  in_channels=config.num_channels,
@@ -234,101 +244,284 @@ class OneVisionEncoderEmbeddings(nn.Module):
234
  bias=False,
235
  )
236
 
237
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Patchify image or video pixels into a flat token sequence.

    Accepts (B, C, H, W) images or (B, C, T, H, W) clips; an image is
    treated as a one-frame clip. Frames are folded into the batch so the
    2D conv patch embedding runs once over all of them, then every patch
    is flattened back to (B, T * Hp * Wp, embed_dim).
    """
    if pixel_values.dim() == 4:
        # Promote an image to a single-frame video: (B, C, 1, H, W).
        pixel_values = pixel_values.unsqueeze(2)

    bsz, n_chan, n_frames, img_h, img_w = pixel_values.shape

    # (B, T, C, H, W) -> (B*T, C, H, W) so Conv2d sees each frame alone.
    frames = pixel_values.permute(0, 2, 1, 3, 4).reshape(bsz * n_frames, n_chan, img_h, img_w)

    patch_tokens = self.patch_embedding(frames)             # (B*T, C, Hp, Wp)
    patch_tokens = patch_tokens.flatten(2).transpose(1, 2)  # (B*T, Hp*Wp, C)

    # Re-fold frames: every patch of every frame becomes one token.
    n_tokens = n_frames * (img_h // self.patch_size) * (img_w // self.patch_size)
    return patch_tokens.reshape(bsz, n_tokens, self.embed_dim)
256
 
 
 
 
257
 
258
- class OneVisionEncoderAttention(nn.Module):
259
- """Multi-headed attention with RoPE support"""
 
260
 
261
- def __init__(self, config: OneVisionEncoderConfig):
 
 
 
 
 
 
262
  super().__init__()
263
- self.config = config
264
- self.embed_dim = config.hidden_size
265
- self.num_heads = config.num_attention_heads
266
- self.head_dim = self.embed_dim // self.num_heads
267
- if self.head_dim * self.num_heads != self.embed_dim:
268
- raise ValueError(
269
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  )
271
 
272
- self.scale = self.head_dim**-0.5
273
- self.dropout = config.attention_dropout
 
 
274
 
275
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
276
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
279
 
280
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Eager scaled-dot-product self-attention with optional RoPE.

    Args:
        hidden_states: (B, L, E) input sequence.
        attention_mask: optional additive mask, (B, 1, L, L) or (B, L, L).
        rotary_pos_emb: optional (B, L, D) rotation angles for RoPE.
        output_attentions: when True, also return the attention weights.

    Returns:
        Tuple of the (B, L, E) attention output and, when requested, the
        post-dropout attention probabilities.
    """
    bsz, seq_len, _ = hidden_states.size()

    def split_heads(proj):
        # (B, L, E) -> (B, H, L, D)
        return proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    q = split_heads(self.q_proj)
    k = split_heads(self.k_proj)
    v = split_heads(self.v_proj)

    if rotary_pos_emb is not None:
        q, k = apply_rotary_pos_emb(q, k, rotary_pos_emb)

    scores = torch.matmul(q, k.transpose(2, 3)) * self.scale

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, seq_len, seq_len):
            # Accept (B, L, L) masks by inserting the head axis.
            if attention_mask.dim() == 3:
                attention_mask = attention_mask.unsqueeze(1)
        scores = scores + attention_mask

    # Softmax intentionally stays in the activation dtype (bf16/fp16);
    # upcasting here would break downstream dtype expectations.
    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=self.dropout, training=self.training)

    context = torch.matmul(probs, v)
    context = context.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.embed_dim)

    attn_output = self.out_proj(context)
    return attn_output, probs if output_attentions else None
 
 
322
 
 
323
 
324
- class OneVisionEncoderFlashAttention2(nn.Module):
325
  """
326
  Multi-headed attention with RoPE support using Flash Attention 2.
327
- This module implements the same attention mechanism as OneVisionEncoderAttention but uses
328
- Flash Attention for improved performance and memory efficiency.
329
  """
330
 
331
- def __init__(self, config: OneVisionEncoderConfig):
332
  super().__init__()
333
  self.config = config
334
  self.embed_dim = config.hidden_size
@@ -341,11 +534,8 @@ class OneVisionEncoderFlashAttention2(nn.Module):
341
 
342
  self.scale = self.head_dim**-0.5
343
  self.dropout = config.attention_dropout
344
-
345
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
346
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
347
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
348
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
349
 
350
  def forward(
351
  self,
@@ -353,20 +543,22 @@ class OneVisionEncoderFlashAttention2(nn.Module):
353
  attention_mask: Optional[torch.Tensor] = None,
354
  rotary_pos_emb: Optional[torch.Tensor] = None,
355
  output_attentions: bool = False,
356
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
357
  """
358
  Forward pass using Flash Attention 2.
359
  """
360
  batch_size, q_len, _ = hidden_states.size()
361
-
362
- query_states = self.q_proj(hidden_states)
363
- key_states = self.k_proj(hidden_states)
364
- value_states = self.v_proj(hidden_states)
 
 
365
 
366
  # Flash Attention requires (B, L, H, D) format
367
- query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
368
- key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
369
- value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
370
 
371
  # Apply RoPE if provided
372
  if rotary_pos_emb is not None:
@@ -379,10 +571,10 @@ class OneVisionEncoderFlashAttention2(nn.Module):
379
  query_states = query_states.transpose(1, 2)
380
  key_states = key_states.transpose(1, 2)
381
 
382
- # Flash Attention forward pass
383
- if not _flash_attn_available:
384
- raise ImportError("flash_attn is not installed. Please install it to use OneVisionEncoderFlashAttention2.")
385
 
 
386
  attn_output = flash_attn_func(
387
  query_states,
388
  key_states,
@@ -396,33 +588,19 @@ class OneVisionEncoderFlashAttention2(nn.Module):
396
  attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
397
 
398
  # No extra casting here.
399
- attn_output = self.out_proj(attn_output)
 
400
 
401
  return attn_output, None
402
 
403
 
404
- ONEVISION_ENCODER_ATTENTION_CLASSES = {
405
- "eager": OneVisionEncoderAttention,
406
- "flash_attention_2": OneVisionEncoderFlashAttention2,
407
- }
408
-
409
 
410
- class OneVisionEncoderEncoderLayer(nn.Module):
411
- def __init__(self, config: OneVisionEncoderConfig):
412
  super().__init__()
413
  self.embed_dim = config.hidden_size
414
- # Get attention implementation from config, default to "flash_attention_2"
415
- attn_implementation = getattr(config, "_attn_implementation", "flash_attention_2")
416
- if attn_implementation not in ONEVISION_ENCODER_ATTENTION_CLASSES:
417
- # Fallback to eager if flash_attention_2 is not available
418
- if not _flash_attn_available and attn_implementation == "flash_attention_2":
419
- attn_implementation = "eager"
420
- else:
421
- raise ValueError(
422
- f"Unknown attention implementation: {attn_implementation}. "
423
- f"Available implementations: {list(ONEVISION_ENCODER_ATTENTION_CLASSES.keys())}"
424
- )
425
- self.self_attn = ONEVISION_ENCODER_ATTENTION_CLASSES[attn_implementation](config)
426
  self.layer_norm1 = get_norm_layer(config)
427
  self.mlp = SiglipMLP(config)
428
  self.layer_norm2 = get_norm_layer(config)
@@ -433,7 +611,7 @@ class OneVisionEncoderEncoderLayer(nn.Module):
433
  attention_mask: Optional[torch.Tensor] = None,
434
  rotary_pos_emb: Optional[torch.Tensor] = None,
435
  output_attentions: bool = False,
436
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
437
  residual = hidden_states
438
  hidden_states = self.layer_norm1(hidden_states)
439
 
@@ -454,11 +632,13 @@ class OneVisionEncoderEncoderLayer(nn.Module):
454
  return outputs
455
 
456
 
457
- class OneVisionEncoderEncoder(nn.Module):
458
- def __init__(self, config: OneVisionEncoderConfig):
459
  super().__init__()
460
  self.config = config
461
- self.layers = nn.ModuleList([OneVisionEncoderEncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
462
 
463
  def forward(
464
  self,
@@ -476,12 +656,21 @@ class OneVisionEncoderEncoder(nn.Module):
476
  if output_hidden_states:
477
  all_hidden_states = all_hidden_states + (hidden_states,)
478
 
479
- layer_outputs = layer(
480
- hidden_states,
481
- attention_mask=attention_mask,
482
- rotary_pos_emb=rotary_pos_emb,
483
- output_attentions=output_attentions,
484
- )
 
 
 
 
 
 
 
 
 
485
 
486
  hidden_states = layer_outputs[0]
487
 
@@ -500,54 +689,124 @@ class OneVisionEncoderEncoder(nn.Module):
500
  attentions=all_self_attentions,
501
  )
502
 
 
 
 
 
 
 
 
 
503
 
504
- # ---------------------------------------------------------------------------
505
- # Main Models
506
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
507
 
 
 
508
 
509
- @add_start_docstrings(
510
- "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
511
- ONEVISION_ENCODER_START_DOCSTRING,
512
- )
513
- class OneVisionEncoderPreTrainedModel(PreTrainedModel):
514
- config_class = OneVisionEncoderConfig
515
- base_model_prefix = "onevision_encoder"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  supports_gradient_checkpointing = True
517
- _no_split_modules = ["OneVisionEncoderEncoderLayer"]
518
  _supports_flash_attn_2 = True
 
 
 
 
519
 
520
  def _init_weights(self, module):
521
- """Initialize the weights"""
522
- std = self.config.initializer_range
523
- if isinstance(module, (nn.Linear, nn.Conv2d)):
524
- module.weight.data.normal_(mean=0.0, std=std)
525
- if module.bias is not None:
526
- module.bias.data.zero_()
527
- elif isinstance(module, nn.Embedding):
528
- module.weight.data.normal_(mean=0.0, std=std)
529
- if module.padding_idx is not None:
530
- module.weight.data[module.padding_idx].zero_()
531
- elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
532
- # Fix: RMSNorm doesn't have bias, must check hasattr first
533
- module.weight.data.fill_(1.0)
534
- if hasattr(module, "bias") and module.bias is not None:
535
- module.bias.data.zero_()
536
-
537
-
538
- @add_start_docstrings(
539
- "OneVision Encoder Model with a vision transformer encoder.",
540
- ONEVISION_ENCODER_START_DOCSTRING,
541
- )
542
- class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
543
- def __init__(self, config: OneVisionEncoderConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  super().__init__(config)
545
  self.config = config
 
546
 
547
- self.embeddings = OneVisionEncoderEmbeddings(config)
 
548
  self.layernorm_pre = get_norm_layer(config)
549
- self.encoder = OneVisionEncoderEncoder(config)
550
- self.video_rope = VideoRotaryEmbeddingSplit466(config)
551
 
552
  if config.use_head:
553
  self.layernorm_post = get_norm_layer(config)
@@ -556,119 +815,842 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
556
  self.layernorm_post = None
557
  self.head = None
558
 
 
 
 
 
 
 
 
559
  self.post_init()
560
 
561
- @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
562
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
563
  def forward(
564
  self,
565
- pixel_values: torch.Tensor,
566
- visible_indices: Optional[torch.Tensor] = None,
567
  patch_positions: Optional[torch.Tensor] = None,
568
  output_attentions: Optional[bool] = None,
569
  output_hidden_states: Optional[bool] = None,
570
  return_dict: Optional[bool] = None,
 
571
  ) -> Union[tuple, BaseModelOutputWithPooling]:
572
  r"""
573
- Returns:
574
 
575
- Examples:
 
 
576
 
577
- ```python
578
- >>> from transformers import AutoModel, AutoImageProcessor
579
- >>> from PIL import Image
 
 
 
 
 
 
 
580
 
581
- >>> model = AutoModel.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
582
- >>> preprocessor = AutoImageProcessor.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
583
- >>> image = Image.open("path/to/your/image.jpg") # Replace with your image path
584
- >>> pixel_values = preprocessor(images=image, return_tensors="pt")["pixel_values"]
585
- >>> outputs = model(pixel_values)
586
- >>> last_hidden_states = outputs.last_hidden_state
587
- >>> pooled_output = outputs.pooler_output
588
- ```
589
  """
590
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
591
  output_hidden_states = (
592
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
593
  )
594
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
595
-
596
- # Determine video dimensions for RoPE
597
- # Note: pixel_values passed to embeddings can be 4D or 5D
598
- if pixel_values.dim() == 5:
599
- # Use config.rope_temporal_size if set, otherwise use actual frame count
600
- t_frames = (
601
- self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
602
- )
603
- height = pixel_values.shape[3]
604
- width = pixel_values.shape[4]
605
- else:
606
- t_frames = 1
607
- height = pixel_values.shape[2]
608
- width = pixel_values.shape[3]
609
 
610
  # 1. Embeddings
611
- hidden_states = self.embeddings(pixel_values)
 
 
 
612
  batch_size, total_patches, _ = hidden_states.shape
613
 
614
- # 2. Visible Indices Handling
615
- if visible_indices is None:
616
- visible_indices = (
617
- torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
618
- )
619
-
620
- # 3. RoPE Construction
621
- if patch_positions is not None:
622
- freqs_visible = self.video_rope.forward_from_positions(patch_positions)
623
  else:
624
- freqs_full = self.video_rope(
625
- t=t_frames,
626
- h=height // self.config.patch_size,
627
- w=width // self.config.patch_size,
628
- device=pixel_values.device,
629
- )
630
- freqs_visible = freqs_full[visible_indices]
 
 
 
 
 
 
 
 
631
 
632
  # Concatenate D/2 + D/2 -> D for applying rope
633
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
 
 
634
 
635
- # 4. Pre-Norm & Encoder
636
  hidden_states = self.layernorm_pre(hidden_states)
637
 
638
- # fix: gather hidden_states to match freqs_visible when using sparse visible_indices
639
- num_visible = visible_indices.shape[1]
640
- if num_visible != total_patches:
641
- # sparse mode: select only visible patches
642
- hidden_states = hidden_states.gather(
643
- 1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
644
- )
645
-
646
  encoder_outputs = self.encoder(
647
  hidden_states,
648
  attention_mask=None,
649
  rotary_pos_emb=freqs_visible,
650
  output_attentions=output_attentions,
651
- output_hidden_states=output_hidden_states,
652
- return_dict=return_dict,
653
  )
654
 
655
- sequence_output = encoder_outputs[0]
 
 
 
 
656
 
657
- # Apply post-norm if configured
658
  if self.layernorm_post is not None:
659
  sequence_output = self.layernorm_post(sequence_output)
660
 
661
- # 5. Pooling Head
662
- pooled_output = None
663
- if self.head is not None:
664
- pooled_output = self.head(sequence_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
 
666
  if not return_dict:
667
- return (sequence_output, pooled_output) + encoder_outputs[1:]
668
 
669
  return BaseModelOutputWithPooling(
670
- last_hidden_state=sequence_output,
671
- pooler_output=pooled_output,
672
- hidden_states=encoder_outputs.hidden_states,
673
- attentions=encoder_outputs.attentions,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Optional, Union
3
 
4
  import torch
5
  import torch.nn as nn
6
+ from torch.nn import LayerNorm
7
 
8
+ from transformers import AutoModel
9
+ from transformers.cache_utils import Cache
10
+ from transformers.generation import GenerationMixin
11
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
12
  from transformers.modeling_utils import PreTrainedModel
13
  from transformers.models.siglip.modeling_siglip import SiglipMLP
14
+ from transformers.processing_utils import Unpack
15
  from transformers.utils import (
16
+ TransformersKwargs,
17
+ auto_docstring,
18
+ can_return_tuple,
19
+ is_flash_attn_2_available,
20
  replace_return_docstrings,
21
  )
22
 
23
+ from .configuration_llava_onevision2 import LlavaOnevision2Config, LlavaOnevision2VisionConfig
24
 
25
 
26
+ if is_flash_attn_2_available():
27
  from flash_attn import flash_attn_func
28
 
 
 
 
29
 
30
+ @dataclass
31
+ @auto_docstring(
32
+ custom_intro="""
33
+ Base class for Llava-Onevision-1.5 outputs, with hidden states and attentions.
34
+ """
35
+ )
36
+ class LlavaOnevision2ModelOutputWithPast(ModelOutput):
37
+ r"""
38
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
39
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
42
+ `past_key_values` input) to speed up sequential decoding.
43
+ """
44
 
45
+ last_hidden_state: Optional[torch.FloatTensor] = None
46
+ past_key_values: Optional[Cache] = None
47
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
48
+ attentions: Optional[tuple[torch.FloatTensor]] = None
 
49
 
50
 
51
+ @dataclass
52
+ @auto_docstring(
53
+ custom_intro="""
54
+ Base class for Llava-Onevision-1.5 causal language model (or autoregressive) outputs.
55
  """
56
+ )
57
+ class LlavaOnevision2CausalLMOutputWithPast(ModelOutput):
58
+ r"""
59
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
60
+ Language modeling loss (for next-token prediction).
61
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
62
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
63
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
64
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
65
+
66
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
67
+ `past_key_values` input) to speed up sequential decoding.
68
  """
 
 
 
 
 
 
 
 
69
 
70
+ loss: Optional[torch.FloatTensor] = None
71
+ logits: Optional[torch.FloatTensor] = None
72
+ past_key_values: Optional[Cache] = None
73
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
74
+ attentions: Optional[tuple[torch.FloatTensor]] = None
75
 
 
 
 
 
 
76
 
77
+ # ---------------------------------------------------------------------------
78
+ # Vision Rotary Embedding
79
+ # ---------------------------------------------------------------------------
80
 
81
 
82
+ class VisionRotaryEmbedding(nn.Module):
83
  """
84
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
85
+ Supports both grid_thw-based and explicit position-based RoPE computation.
86
  """
87
 
88
+ def __init__(self, config: LlavaOnevision2VisionConfig):
89
  super().__init__()
90
  head_dim = config.hidden_size // config.num_attention_heads
91
  base = config.rope_theta
 
98
  self.head_dim = head_dim
99
  self.half = half
100
 
101
+ # 4:6:6 split for T:H:W
102
  unit = half // 16
103
  self.t_size = 4 * unit
104
  self.h_size = 6 * unit
 
120
  persistent=False,
121
  )
122
 
123
def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
    """Compute RoPE angles for a batch of dense (t, h, w) grids.

    Args:
        grid_thw: (num_samples, 3) tensor of per-sample [t, h, w] sizes.

    Returns:
        (sum_i t_i*h_i*w_i, half) float32 angle tensor. Samples are
        concatenated in order; each sample is flattened t-major, then h,
        then w, matching the patch flattening order.
    """
    device = grid_thw.device
    inv_freqs = (
        self.inv_freq_t.to(device=device),
        self.inv_freq_h.to(device=device),
        self.inv_freq_w.to(device=device),
    )

    chunks = []
    for row in grid_thw:
        t, h, w = row[0].item(), row[1].item(), row[2].item()

        # Per-axis angle tables: (axis_len, axis_dim).
        tables = [
            torch.outer(torch.arange(size, device=device, dtype=torch.float32), inv)
            for size, inv in zip((t, h, w), inv_freqs)
        ]

        # Flattened grid coordinates for this sample (t-major order).
        ids = (
            torch.arange(t, device=device).repeat_interleave(h * w),
            torch.arange(h, device=device).repeat_interleave(w).repeat(t),
            torch.arange(w, device=device).repeat(h).repeat(t),
        )

        chunks.append(torch.cat([tbl[i] for tbl, i in zip(tables, ids)], dim=-1))

    return torch.cat(chunks, dim=0)
157
 
158
def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
    """Compute RoPE angles from explicit per-patch (t, h, w) coordinates.

    Args:
        patch_positions: (seq_len, 3) tensor of positions for each patch.

    Returns:
        (seq_len, half) float32 angle tensor.
    """
    device = patch_positions.device

    # Outer product per axis: (S,) x (d,) -> (S, d), then concat to (S, half).
    axis_angles = [
        torch.outer(patch_positions[:, axis].float(), inv.to(device=device))
        for axis, inv in enumerate((self.inv_freq_t, self.inv_freq_h, self.inv_freq_w))
    ]
    return torch.cat(axis_angles, dim=-1)
182
 
183
+ def forward_with_thw(self, t: int, h: int, w: int, device=None) -> torch.Tensor:
184
+ """
185
+ Compute rotary position embeddings from explicit t, h, w dimensions.
186
 
187
+ Args:
188
+ t: Number of temporal frames
189
+ h: Number of height patches
190
+ w: Number of width patches
191
+ device: Target device
192
 
193
+ Returns:
194
+ freqs: [t*h*w, half] tensor of position frequencies
195
+ """
196
+ if device is None:
197
+ device = self.inv_freq_t.device
 
 
198
 
199
+ inv_t = self.inv_freq_t.to(device=device)
200
+ inv_h = self.inv_freq_h.to(device=device)
201
+ inv_w = self.inv_freq_w.to(device=device)
202
 
203
+ ft = torch.outer(torch.arange(t, device=device, dtype=torch.float32), inv_t)
204
+ fh = torch.outer(torch.arange(h, device=device, dtype=torch.float32), inv_h)
205
+ fw = torch.outer(torch.arange(w, device=device, dtype=torch.float32), inv_w)
206
 
207
+ t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
208
+ h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
209
+ w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
210
 
211
+ freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
212
+ return freqs
213
 
214
 
215
  # ---------------------------------------------------------------------------
216
+ # Patch Embedding
217
  # ---------------------------------------------------------------------------
218
 
219
 
220
+ class LlavaViTEmbeddings(nn.Module):
221
+ """
222
+ Patch embedding layer that converts pre-processed patches to embeddings.
223
+
224
+ This module is designed to receive patches that have already been extracted
225
+ and arranged by the Qwen2VL image processor in 2x2 block spatial order.
226
+
227
+ Input format: [total_patches, num_channels, patch_size, patch_size]
228
+ Output format: [total_patches, embed_dim]
229
+ """
230
+
231
+ def __init__(self, config: LlavaOnevision2VisionConfig):
232
  super().__init__()
233
  self.config = config
234
  self.embed_dim = config.hidden_size
235
  self.image_size = config.image_size
236
  self.patch_size = config.patch_size
237
+ self.in_channels = config.num_channels
238
 
239
  self.patch_embedding = nn.Conv2d(
240
  in_channels=config.num_channels,
 
244
  bias=False,
245
  )
246
 
247
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
248
+ target_dtype = self.patch_embedding.weight.dtype
249
+ hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
250
+ hidden_states = self.patch_embedding(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
 
 
251
 
252
+ return hidden_states
 
253
 
 
 
 
254
 
255
+ # ---------------------------------------------------------------------------
256
+ # Patch Merger
257
+ # ---------------------------------------------------------------------------
258
 
 
259
 
260
class LlavaOnevision2VisionPatchMerger(nn.Module):
    """
    Merge each spatial_merge_size x spatial_merge_size group of patch features
    into a single output feature.

    Assumes the incoming patches were already arranged into 2x2 block order by
    the Qwen2VL-style image processor, so merging reduces to a LayerNorm, a
    reshape that concatenates each block's features, and an MLP projection.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        layer_norm_eps: float = 1e-05,
    ) -> None:
        super().__init__()
        # Each merged token concatenates spatial_merge_size**2 patch features.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = LayerNorm(context_dim, eps=layer_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )
        self.spatial_merge_size = spatial_merge_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize, group, and project patch features.

        Args:
            x: [batch_size, seq_len, context_dim] or [seq_len, context_dim]
               features in 2x2 block order; seq_len (times batch, if present)
               must be a multiple of spatial_merge_size**2.

        Returns:
            [total_seq_len // spatial_merge_size**2, dim] merged features
            (any leading batch dimension is folded into the sequence axis).
        """
        normalized = self.ln_q(x)
        grouped = normalized.view(-1, self.hidden_size)
        return self.mlp(grouped)
303
+
304
+
305
def rotate_half(x):
    """
    Interleaved half-rotation used when applying RoPE.

    Maps consecutive pairs (x1, x2, x3, x4, ...) to (-x2, x1, -x4, x3, ...),
    matching the source model's interleaved rotary layout.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.stack((-second, first), dim=-1).flatten(-2)
313
+
314
+
315
def get_norm_layer(config):
    """
    Instantiate the normalization layer selected by ``config.layer_norm_type``.

    ``"rms_norm"`` yields ``nn.RMSNorm``; any other value falls back to
    ``nn.LayerNorm``. Both are sized by ``config.hidden_size`` with
    ``config.layer_norm_eps``.
    """
    norm_cls = nn.RMSNorm if config.layer_norm_type == "rms_norm" else nn.LayerNorm
    return norm_cls(config.hidden_size, eps=config.layer_norm_eps)
320
+
321
+
322
def apply_rotary_pos_emb(q, k, freqs):
    """
    Apply rotary position embeddings to query and key tensors.

    Args:
        q: Query tensor of shape (B, H, L, D).
        k: Key tensor of shape (B, H, L, D).
        freqs: Per-position rotation angles of shape (B, L, D).

    Returns:
        (q_rot, k_rot) with RoPE applied, cast back to the inputs' dtypes.
    """
    q_dtype, k_dtype = q.dtype, k.dtype

    # Rotate in float32 for numerical stability; restore dtype afterwards.
    # Broadcast the angles over the head dimension: (B, L, D) -> (B, 1, L, D).
    cos = freqs.cos().unsqueeze(1).float()
    sin = freqs.sin().unsqueeze(1).float()

    def _rotate(x):
        x = x.float()
        return (x * cos) + (rotate_half(x) * sin)

    return _rotate(q).to(q_dtype), _rotate(k).to(k_dtype)
339
+
340
def convert_rope_to_block_layout(
    freqs: torch.Tensor, t: int, h: int, w: int, spatial_merge_size: int = 2
) -> torch.Tensor:
    """
    Reorder row-major RoPE frequencies into spatial-merge block order.

    The image processor emits patches grouped into spatial_merge_size x
    spatial_merge_size blocks; this permutes a row-major frequency table so
    every entry lines up with the processor's patch ordering:
    row-major  [p(0,0), p(0,1), p(0,2), p(0,3), ..., p(1,0), ...]  becomes
    block order [p(0,0), p(0,1), p(1,0), p(1,1)], [p(0,2), p(0,3), ...], ...

    Args:
        freqs: [t*h*w, half] frequencies in row-major order.
        t: Temporal extent.
        h: Height in (unmerged) patches.
        w: Width in (unmerged) patches.
        spatial_merge_size: Block edge length; 1 means no reordering.

    Returns:
        [t*h*w, half] frequencies in block order.
    """
    block = spatial_merge_size
    if block == 1:
        return freqs

    half_dim = freqs.size(-1)

    # [t*h*w, half] -> [t, h_blocks, block_h, w_blocks, block_w, half]
    grid = freqs.view(t, h // block, block, w // block, block, half_dim)

    # Bring the two intra-block axes together:
    # [t, h_blocks, w_blocks, block_h, block_w, half] == 2x2 block order.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()

    return grid.view(t * h * w, half_dim)
384
+
385
+
386
def convert_rope_to_block_layout_by_positions(
    freqs: torch.Tensor,
    patch_positions: torch.Tensor,
    spatial_merge_size: int = 2,
    grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Convert RoPE from row-major order to 2x2 block layout, grouping by temporal index.

    Patches are grouped by their temporal index (column 0 of patch_positions),
    then 2x2 spatial reordering is applied within each temporal group.

    Optimized: when all frames share the same spatial size, a single vectorized
    `convert_rope_to_block_layout` call handles every frame at once.

    Args:
        freqs: RoPE frequencies in row-major order, shape [seq_len, half]
        patch_positions: Patch positions tensor, shape [seq_len, 3] with [t, h, w] per patch.
            Patches of the same frame are expected to be contiguous (unique_consecutive
            is used for grouping).
        spatial_merge_size: size of spatial merge blocks (default: 2)
        grid_thw: Optional grid_thw tensor for reliable h, w extraction

    Returns:
        torch.Tensor: RoPE frequencies in 2x2 block order, same shape [seq_len, half]
    """
    sms = spatial_merge_size
    if sms == 1:
        # No merging -> layout already matches.
        return freqs

    half = freqs.shape[-1]
    seq_len = freqs.shape[0]

    # Temporal index of every patch; consecutive runs define frame groups.
    t_indices = patch_positions[:, 0]

    # Find unique t values and their counts (preserving order).
    # NOTE(review): inverse_indices is unused; kept to preserve the call shape.
    unique_t, inverse_indices, counts = torch.unique_consecutive(
        t_indices, return_inverse=True, return_counts=True
    )

    num_groups = unique_t.shape[0]

    # Fast path: single image with grid_thw available.
    # NOTE(review): uses grid_thw[0] — assumes the first row describes this sample.
    if num_groups == 1 and grid_thw is not None:
        height = grid_thw[0, 1].item()
        width = grid_thw[0, 2].item()
        return convert_rope_to_block_layout(freqs, t=1, h=height, w=width, spatial_merge_size=sms)

    # Fast path: single image, assumed square when seq_len is a perfect square.
    if num_groups == 1:
        hw = int(seq_len ** 0.5)
        if hw * hw == seq_len:
            return convert_rope_to_block_layout(freqs, t=1, h=hw, w=hw, spatial_merge_size=sms)

    # Check if all groups have the same size (common case for videos);
    # this allows fully vectorized processing below.
    first_count = counts[0].item()
    all_same_size = torch.all(counts == first_count).item()

    if all_same_size:
        # Vectorized path: all frames have the same spatial size.
        group_size = first_count
        hw = int(group_size ** 0.5)

        if hw * hw == group_size:
            # Square frames: one batched convert_rope_to_block_layout call
            # treats the frames as the temporal axis.
            return convert_rope_to_block_layout(
                freqs, t=num_groups, h=hw, w=hw, spatial_merge_size=sms
            )
        elif grid_thw is not None:
            # Non-square but grid_thw available: take h, w from grid_thw.
            # NOTE(review): assumes every frame shares grid_thw[0]'s h/w — frames
            # with equal patch counts but different aspect ratios would be
            # mishandled here; confirm upstream guarantees uniform grids.
            height = grid_thw[0, 1].item()
            width = grid_thw[0, 2].item()
            return convert_rope_to_block_layout(
                freqs, t=num_groups, h=height, w=width, spatial_merge_size=sms
            )

    # Slow path: variable frame sizes, process each group separately.
    # Pre-compute cumulative offsets to avoid repeated slicing.
    cum_counts = torch.cumsum(counts, dim=0)
    start_indices = torch.cat([torch.tensor([0], device=counts.device), cum_counts[:-1]])

    result_freqs = torch.empty_like(freqs)

    for group_idx in range(num_groups):
        start_idx = start_indices[group_idx].item()
        group_size = counts[group_idx].item()
        end_idx = start_idx + group_size

        # Infer spatial dimensions: prefer the square assumption, otherwise
        # count distinct coordinates in the group's positions.
        hw = int(group_size ** 0.5)
        if hw * hw == group_size:
            h, w = hw, hw
        else:
            h, w = _infer_hw_from_positions(patch_positions[start_idx:end_idx], sms)

        # Apply block layout conversion to this frame's slice in place.
        result_freqs[start_idx:end_idx] = convert_rope_to_block_layout(
            freqs[start_idx:end_idx], t=1, h=h, w=w, spatial_merge_size=sms
        )

    return result_freqs
 
487
 
 
 
488
 
489
+ def _infer_hw_from_positions(
490
+ group_positions: torch.Tensor,
491
+ spatial_merge_size: int = 2
492
+ ) -> tuple[int, int]:
493
+ """
494
+ Infer height and width from patch positions within a temporal group.
495
 
496
+ Args:
497
+ group_positions: Patch positions for one temporal group, shape [group_size, 3]
498
+ spatial_merge_size: size of spatial merge blocks
499
 
500
+ Returns:
501
+ tuple[int, int]: (height, width) of the spatial grid
502
+ """
503
+ # Get unique h and w values
504
+ h_values = group_positions[:, 1]
505
+ w_values = group_positions[:, 2]
506
 
507
+ h_unique = torch.unique(h_values)
508
+ w_unique = torch.unique(w_values)
509
 
510
+ h = h_unique.shape[0]
511
+ w = w_unique.shape[0]
512
 
513
+ # Validate dimensions are divisible by spatial_merge_size
514
+ assert h % spatial_merge_size == 0, f"Height {h} not divisible by {spatial_merge_size}"
515
+ assert w % spatial_merge_size == 0, f"Width {w} not divisible by {spatial_merge_size}"
516
 
517
+ return h, w
518
 
519
+ class LlavaViTFlashAttention2(nn.Module):
520
  """
521
  Multi-headed attention with RoPE support using Flash Attention 2.
 
 
522
  """
523
 
524
+ def __init__(self, config: LlavaOnevision2VisionConfig):
525
  super().__init__()
526
  self.config = config
527
  self.embed_dim = config.hidden_size
 
534
 
535
  self.scale = self.head_dim**-0.5
536
  self.dropout = config.attention_dropout
537
+ self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3)
538
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
 
 
 
539
 
540
  def forward(
541
  self,
 
543
  attention_mask: Optional[torch.Tensor] = None,
544
  rotary_pos_emb: Optional[torch.Tensor] = None,
545
  output_attentions: bool = False,
546
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
547
  """
548
  Forward pass using Flash Attention 2.
549
  """
550
  batch_size, q_len, _ = hidden_states.size()
551
+ q, k, v = (
552
+ self.qkv(hidden_states)
553
+ .reshape(batch_size, q_len, 3, self.num_heads, self.head_dim)
554
+ .permute(2, 0, 1, 3, 4)
555
+ .unbind(0)
556
+ )
557
 
558
  # Flash Attention requires (B, L, H, D) format
559
+ query_states = q
560
+ key_states = k
561
+ value_states = v
562
 
563
  # Apply RoPE if provided
564
  if rotary_pos_emb is not None:
 
571
  query_states = query_states.transpose(1, 2)
572
  key_states = key_states.transpose(1, 2)
573
 
574
+ # FIX: Removed the explicit float32 check and downcast.
575
+ # We assume input is already correct (bf16/fp16) thanks to RoPE fix.
 
576
 
577
+ # Flash Attention forward pass
578
  attn_output = flash_attn_func(
579
  query_states,
580
  key_states,
 
588
  attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
589
 
590
  # No extra casting here.
591
+ # attn_output = self.out_proj(attn_output)
592
+ attn_output = self.proj(attn_output)
593
 
594
  return attn_output, None
595
 
596
 
597
+ class LlavaViTEncoderLayer(nn.Module):
598
+ """Vision encoder layer with pre-norm and Flash Attention 2."""
 
 
 
599
 
600
    def __init__(self, config: LlavaOnevision2VisionConfig):
        """Build the pre-norm transformer layer: self-attention + MLP sub-blocks."""
        super().__init__()
        self.embed_dim = config.hidden_size
        # Flash-Attention-2 based self-attention with RoPE support.
        self.self_attn = LlavaViTFlashAttention2(config)
        # Pre-attention norm (LayerNorm or RMSNorm depending on config).
        self.layer_norm1 = get_norm_layer(config)
        self.mlp = SiglipMLP(config)
        # Pre-MLP norm.
        self.layer_norm2 = get_norm_layer(config)
 
611
  attention_mask: Optional[torch.Tensor] = None,
612
  rotary_pos_emb: Optional[torch.Tensor] = None,
613
  output_attentions: bool = False,
614
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
615
  residual = hidden_states
616
  hidden_states = self.layer_norm1(hidden_states)
617
 
 
632
  return outputs
633
 
634
 
635
+ class LlavaViTEncoder(nn.Module):
636
    def __init__(self, config: LlavaOnevision2VisionConfig):
        """Stack of ``config.num_hidden_layers`` vision encoder layers."""
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([LlavaViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Gradient checkpointing support (toggled by PreTrainedModel helpers).
        self.gradient_checkpointing = False
642
 
643
  def forward(
644
  self,
 
656
  if output_hidden_states:
657
  all_hidden_states = all_hidden_states + (hidden_states,)
658
 
659
+ if self.gradient_checkpointing and self.training:
660
+ layer_outputs = self._gradient_checkpointing_func(
661
+ layer.__call__,
662
+ hidden_states,
663
+ attention_mask,
664
+ rotary_pos_emb,
665
+ output_attentions,
666
+ )
667
+ else:
668
+ layer_outputs = layer(
669
+ hidden_states,
670
+ attention_mask=attention_mask,
671
+ rotary_pos_emb=rotary_pos_emb,
672
+ output_attentions=output_attentions,
673
+ )
674
 
675
  hidden_states = layer_outputs[0]
676
 
 
689
  attentions=all_self_attentions,
690
  )
691
 
692
+ def forward_debug(
693
+ self,
694
+ hidden_states: torch.Tensor,
695
+ attention_mask: Optional[torch.Tensor] = None,
696
+ rotary_pos_emb: Optional[torch.Tensor] = None,
697
+ ) -> dict:
698
+ """
699
+ Forward pass with layer-by-layer debug outputs for consistency checking.
700
 
701
+ Returns:
702
+ dict: Contains:
703
+ - 'input_hidden_states': Input to the encoder
704
+ - 'input_rotary_pos_emb': Rotary position embeddings input
705
+ - 'layer_outputs': Dict mapping layer index to output after that layer
706
+ - 'final_output': Final encoder output
707
+ """
708
+ output = {}
709
+
710
+ # Save input
711
+ output["input_hidden_states"] = hidden_states.clone()
712
+ if rotary_pos_emb is not None:
713
+ output["input_rotary_pos_emb"] = rotary_pos_emb.clone()
714
 
715
+ # Layer-by-layer outputs
716
+ layer_outputs = {}
717
 
718
+ for layer_idx, layer in enumerate(self.layers):
719
+ # Save input to this layer
720
+ layer_outputs[f"layer_{layer_idx}_input"] = hidden_states.clone()
721
+
722
+ # Forward through layer
723
+ layer_result = layer(
724
+ hidden_states,
725
+ attention_mask=attention_mask,
726
+ rotary_pos_emb=rotary_pos_emb,
727
+ output_attentions=False,
728
+ )
729
+ hidden_states = layer_result[0]
730
+
731
+ # Save output of this layer
732
+ layer_outputs[f"layer_{layer_idx}_output"] = hidden_states.clone()
733
+
734
+ output["layer_outputs"] = layer_outputs
735
+ output["final_output"] = hidden_states.clone()
736
+
737
+ return output
738
+
739
+
740
class LlavaOnevision2PreTrainedModel(PreTrainedModel):
    """Base class wiring LLaVA-OneVision 2 models into the HF PreTrainedModel API."""

    config_class = LlavaOnevision2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    # Keep each encoder layer intact when sharding the model across devices.
    _no_split_modules = ["LlavaViTEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    def _init_weights(self, module):
        # Delegate to the default transformers initialization scheme;
        # custom per-module initialization can be added here if ever needed.
        super()._init_weights(module)
754
+
755
+
756
class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    """
    Attention pooling with a single learned probe token (PMA-style).

    The probe attends over the full sequence, and the attended result is
    refined by a norm + MLP residual block before being returned.
    """

    def __init__(self, config: LlavaOnevision2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # One learnable query token, broadcast across the batch at runtime.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_states):
        """Pool ``hidden_states`` [B, L, D] down to a single [B, D] vector."""
        queries = self.probe.repeat(hidden_states.shape[0], 1, 1)

        # Probe attends over the whole sequence (keys == values == inputs).
        pooled, _ = self.attention(queries, hidden_states, hidden_states)

        # Residual MLP refinement of the pooled token.
        refined = pooled + self.mlp(self.norm(pooled))

        # Drop the probe/sequence dimension.
        return refined[:, 0]
780
+
781
+ # ---------------------------------------------------------------------------
782
+ # Vision Model
783
+ # ---------------------------------------------------------------------------
784
+
785
+
786
+ class LlavaOnevision2VisionPretrainedModel(LlavaOnevision2PreTrainedModel):
787
+ """
788
+ LLaVA-OneVision 2.0 Vision Model.
789
+
790
+ This vision model is designed to work with Qwen2VL-style image processing:
791
+ - Receives pre-processed patches in 2x2 block spatial order
792
+ - Applies RoPE with matching 2x2 block layout conversion
793
+ - Accepts explicit patch_positions for RoPE computation
794
+
795
+ Input format:
796
+ hidden_state: [total_patches, num_channels, patch_size, patch_size]
797
+ grid_thw: [num_samples, 3] with [t, h, w] for each sample
798
+ """
799
+
800
+ def __init__(self, config: LlavaOnevision2VisionConfig):
801
  super().__init__(config)
802
  self.config = config
803
+ self.spatial_merge_size = config.spatial_merge_size
804
 
805
+ # Vision components
806
+ self.embeddings = LlavaViTEmbeddings(config)
807
  self.layernorm_pre = get_norm_layer(config)
808
+ self.encoder = LlavaViTEncoder(config)
809
+ self.video_rope = VisionRotaryEmbedding(config)
810
 
811
  if config.use_head:
812
  self.layernorm_post = get_norm_layer(config)
 
815
  self.layernorm_post = None
816
  self.head = None
817
 
818
+ self.merger = LlavaOnevision2VisionPatchMerger(
819
+ dim=config.out_hidden_size,
820
+ context_dim=config.hidden_size,
821
+ spatial_merge_size=config.spatial_merge_size,
822
+ layer_norm_eps=config.layer_norm_eps,
823
+ )
824
+
825
  self.post_init()
826
 
827
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=LlavaOnevision2VisionConfig)
 
828
  def forward(
829
  self,
830
+ hidden_state: torch.Tensor,
831
+ grid_thw: Optional[torch.Tensor] = None,
832
  patch_positions: Optional[torch.Tensor] = None,
833
  output_attentions: Optional[bool] = None,
834
  output_hidden_states: Optional[bool] = None,
835
  return_dict: Optional[bool] = None,
836
+ skip_merger: Optional[bool] = False,
837
  ) -> Union[tuple, BaseModelOutputWithPooling]:
838
  r"""
839
+ Forward pass for vision model.
840
 
841
+ This method accepts pre-processed patches from Qwen2VL image processor and applies
842
+ RoPE (Rotary Position Embedding) in 2x2 block layout to match the spatial arrangement
843
+ of patches.
844
 
845
+ Args:
846
+ hidden_state: Pre-processed patches from Qwen2VL processor.
847
+ Shape: [total_patches, num_channels, patch_size, patch_size]
848
+ grid_thw: Grid sizes tensor of shape [num_samples, 3] with [t, h, w] for each sample.
849
+ Required for computing RoPE and handling visible indices.
850
+ patch_positions: Optional explicit patch positions for RoPE computation.
851
+ output_attentions: Whether to return attention weights.
852
+ output_hidden_states: Whether to return all hidden states.
853
+ return_dict: Whether to return a ModelOutput instead of tuple.
854
+ skip_merger: If True, skip patch merger (useful for consistency checking).
855
 
856
+ Returns:
857
+ BaseModelOutputWithPooling with last_hidden_state containing merged features.
 
 
 
 
 
 
858
  """
859
+ output_attentions = (
860
+ output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
861
+ )
862
  output_hidden_states = (
863
+ output_hidden_states
864
+ if output_hidden_states is not None
865
+ else getattr(self.config, "output_hidden_states", False)
866
  )
867
+ return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
  # 1. Embeddings
870
+ # Note: embeddings returns [total_patches, embed_dim], we need to add batch dimension
871
+ hidden_states = self.embeddings(hidden_state)
872
+ if hidden_states.dim() == 2:
873
+ hidden_states = hidden_states.unsqueeze(0) # [1, total_patches, embed_dim]
874
  batch_size, total_patches, _ = hidden_states.shape
875
 
876
+ # 2. RoPE Construction
877
+ # Get dimensions from grid_thw for block layout conversion
878
+ if grid_thw is not None:
879
+ t_frames = grid_thw[0, 0].item()
880
+ height = grid_thw[0, 1].item()
881
+ width = grid_thw[0, 2].item()
 
 
 
882
  else:
883
+ # Fallback: infer from total_patches (assume single frame, square)
884
+ t_frames = 1
885
+ height = int(total_patches ** 0.5)
886
+ width = height
887
+
888
+ if patch_positions is not None and patch_positions.dim() == 3:
889
+ patch_positions = patch_positions.squeeze(0)
890
+ freqs_visible = self.video_rope.forward_from_positions(patch_positions)
891
+
892
+ # Convert RoPE from row-major to block layout (matching Qwen2VL processor output)
893
+ # Use position-based grouping for videos with variable frame sizes
894
+ # Pass grid_thw for reliable h, w extraction (especially for non-square images)
895
+ freqs_visible = convert_rope_to_block_layout_by_positions(
896
+ freqs_visible, patch_positions, spatial_merge_size=2, grid_thw=grid_thw
897
+ )
898
 
899
  # Concatenate D/2 + D/2 -> D for applying rope
900
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
901
+ if freqs_visible.dim() == 2:
902
+ freqs_visible = freqs_visible.unsqueeze(0)
903
 
904
+ # 3. Pre-Norm & Encoder
905
  hidden_states = self.layernorm_pre(hidden_states)
906
 
 
 
 
 
 
 
 
 
907
  encoder_outputs = self.encoder(
908
  hidden_states,
909
  attention_mask=None,
910
  rotary_pos_emb=freqs_visible,
911
  output_attentions=output_attentions,
912
+ output_hidden_states=True, # Always get hidden states to use -2 layer
913
+ return_dict=True,
914
  )
915
 
916
+ # Use second-to-last layer output for better feature representation
917
+ if encoder_outputs.hidden_states is not None and len(encoder_outputs.hidden_states) >= 2 and not skip_merger:
918
+ sequence_output = encoder_outputs.hidden_states[-2]
919
+ else:
920
+ sequence_output = encoder_outputs[0]
921
 
922
+ # Post-Norm
923
  if self.layernorm_post is not None:
924
  sequence_output = self.layernorm_post(sequence_output)
925
 
926
+ # Skip merger for consistency check with original ViT
927
+ if skip_merger:
928
+ pooled_output = None
929
+ if self.head is not None:
930
+ pooled_output = self.head(sequence_output)
931
+
932
+ if not return_dict:
933
+ return (sequence_output, pooled_output) + (
934
+ encoder_outputs.hidden_states if output_hidden_states else None,
935
+ )
936
+ return BaseModelOutputWithPooling(
937
+ last_hidden_state=sequence_output,
938
+ pooler_output=pooled_output,
939
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
940
+ attentions=encoder_outputs.attentions if output_attentions else None,
941
+ )
942
+
943
+ # Patch merger: input patches are already in 2x2 block order from Qwen2VL processor
944
+ merged_output = self.merger(sequence_output)
945
 
946
  if not return_dict:
947
+ return (merged_output,) + (encoder_outputs.hidden_states if output_hidden_states else None,)
948
 
949
  return BaseModelOutputWithPooling(
950
+ last_hidden_state=merged_output,
951
+ pooler_output=None,
952
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
953
+ attentions=encoder_outputs.attentions if output_attentions else None,
954
+ )
955
+
956
    def forward_debug(
        self,
        hidden_state: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> dict:
        """
        Debug version of forward pass that captures intermediate states.

        Mirrors forward() but saves intermediate outputs at key stages for
        debugging and consistency checking purposes. Only batch_size=1 is
        supported.

        Args:
            hidden_state: Pre-processed patches from Qwen2VL processor.
                Shape: [total_patches, num_channels, patch_size, patch_size] or [total_patches, patch_dim]
            grid_thw: Grid sizes tensor of shape [num_samples, 3] with [t, h, w] for each sample.

        Returns:
            dict: Dictionary containing intermediate outputs:
                - "input_pixel_values": Input to the model
                - "after_patch_embed": Embeddings after patch projection
                - "rotary_pos_emb": Rotary position embeddings
                - "after_pre_layernorm": Embeddings after pre-normalization
                - "layer_outputs": Dict mapping layer index to input/output
                - "before_adapter": Final encoder output before merger
                - "after_merger": Output after patch merger
        """
        output = {}

        # Store inputs for consistency checking.
        output["input_pixel_values"] = hidden_state.clone()
        output["input_grid_thw"] = grid_thw.clone()

        batch_size = grid_thw.size(0)
        assert batch_size == 1, "Currently only batch_size=1 is supported for forward_debug."

        # Determine video dimensions for RoPE.
        t_frames = grid_thw[0, 0].item()
        height = grid_thw[0, 1].item()
        width = grid_thw[0, 2].item()

        # 1. Embeddings
        hidden_states = self.embeddings(hidden_state)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)  # [1, total_patches, embed_dim]
        output["after_patch_embed"] = hidden_states.clone()

        batch_size, total_patches, _ = hidden_states.shape

        # 2. Visible indices (simplified for debug - use all patches).
        # NOTE(review): visible_indices is computed but never consumed below.
        visible_indices = (
            torch.arange(total_patches, device=hidden_state.device).unsqueeze(0).expand(batch_size, -1)
        )

        # 3. RoPE construction.
        # NOTE(review): the temporal extent is forced to 64 for any multi-frame
        # input (t=64 if t_frames > 1) — presumably matching a fixed-length
        # reference setup; confirm against the comparison harness.
        freqs_full = self.video_rope.forward_with_thw(
            t=64 if t_frames > 1 else 1,
            h=height,
            w=width,
            device=hidden_state.device,
        )

        # Convert RoPE from row-major to block layout (hardcoded 2x2 blocks).
        freqs_full_block = convert_rope_to_block_layout(
            freqs_full, 1 if t_frames == 1 else 64, height, width, spatial_merge_size=2
        ).unsqueeze(0)

        # Concatenate D/2 + D/2 -> D for applying rope.
        freqs_visible = torch.cat([freqs_full_block, freqs_full_block], dim=-1)
        output["rotary_pos_emb"] = freqs_visible.clone()

        # 4. Pre-Norm
        hidden_states = self.layernorm_pre(hidden_states)
        output["after_pre_layernorm"] = hidden_states.clone()

        # 5. Encoder with layer-by-layer debug capture.
        encoder_debug_output = self.encoder.forward_debug(
            hidden_states,
            attention_mask=None,
            rotary_pos_emb=freqs_visible,
        )

        # Extract per-layer traces.
        output["layer_outputs"] = encoder_debug_output.get("layer_outputs", {})

        # Last-layer output (forward() itself uses the second-to-last layer
        # before the merger; forward_debug only exposes the final layer, so
        # "before_adapter" here is NOT identical to forward()'s merger input).
        final_hidden_states = encoder_debug_output.get("final_output", hidden_states)

        output["before_adapter"] = final_hidden_states.clone()

        # 6. Post-Norm (if present).
        if self.layernorm_post is not None:
            final_hidden_states = self.layernorm_post(final_hidden_states)

        # 7. Merger
        merged_output = self.merger(final_hidden_states)
        output["after_merger"] = merged_output.clone()

        return output
1058
+
1059
+
1060
+ @auto_docstring
1061
+ class LlavaOnevision2Model(LlavaOnevision2PreTrainedModel):
1062
+ base_model_prefix = ""
1063
+ _checkpoint_conversion_mapping = {"^model": "language_model"}
1064
+ # Reference: fix gemma3 grad acc #37208
1065
+ accepts_loss_kwargs = False
1066
+ config: LlavaOnevision2Config
1067
+ _no_split_modules = ["LlavaViTEncoderLayer"]
1068
+
1069
    def __init__(self, config: LlavaOnevision2Config):
        """Compose the vision tower (`visual`) with the text backbone (`language_model`)."""
        super().__init__(config)
        self.visual = LlavaOnevision2VisionPretrainedModel._from_config(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
        # Initialize weights and apply final processing
        self.post_init()
1075
+
1076
    def get_input_embeddings(self):
        """Return the token embedding module of the language model."""
        return self.language_model.get_input_embeddings()
1078
+
1079
    def set_input_embeddings(self, value):
        """Replace the token embedding module of the language model."""
        self.language_model.set_input_embeddings(value)
1081
+
1082
    def set_decoder(self, decoder):
        """Swap in a new language-model decoder."""
        self.language_model = decoder
1084
+
1085
+ def get_decoder(self):
1086
+ return self.language_model
1087
+
1088
+ def get_video_features(
1089
+ self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None, patch_positions=None
1090
+ ):
1091
+ """
1092
+ Encodes videos into continuous embeddings that can be forwarded to the language model.
1093
+
1094
+ Args:
1095
+ pixel_values_videos: Pre-processed patches from Qwen2VL processor.
1096
+ `torch.FloatTensor` of shape `(total_patches, num_channels, patch_size, patch_size)`
1097
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1098
+ The temporal, height and width of feature shape of each video in LLM.
1099
+ """
1100
+ # Convert to correct dtype
1101
+ pixel_values_videos = pixel_values_videos.type(self.visual.embeddings.patch_embedding.weight.dtype)
1102
+
1103
+ # Forward through vision model with grid_thw
1104
+ vision_output = self.visual(pixel_values_videos, grid_thw=video_grid_thw, patch_positions=patch_positions)
1105
+
1106
+ # Extract the actual tensor from BaseModelOutputWithPooling
1107
+ if hasattr(vision_output, "last_hidden_state"):
1108
+ video_embeds = vision_output.last_hidden_state
1109
+ else:
1110
+ video_embeds = vision_output[0] # Fallback for tuple output
1111
+
1112
+ # Compute split sizes from video_grid_thw or from input shape
1113
+ if video_grid_thw is not None:
1114
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1115
+ else:
1116
+ # Compute from input shape
1117
+ batch_size = pixel_values_videos.shape[0]
1118
+ split_sizes = [video_embeds.shape[1]] * batch_size
1119
+
1120
+ # Split embeddings per video
1121
+ if len(split_sizes) > 1:
1122
+ video_embeds = torch.split(video_embeds.view(-1, video_embeds.shape[-1]), split_sizes)
1123
+ else:
1124
+ video_embeds = [video_embeds.view(-1, video_embeds.shape[-1])]
1125
+
1126
+ return video_embeds
1127
+
1128
+ def get_image_features(self, pixel_values, image_grid_thw: Optional[torch.LongTensor] = None, patch_positions=None):
1129
+ """
1130
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1131
+
1132
+ Args:
1133
+ pixel_values: Pre-processed patches from Qwen2VL processor.
1134
+ - `torch.FloatTensor` of shape `(total_patches, num_channels, patch_size, patch_size)`
1135
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1136
+ The temporal, height and width of feature shape of each image in LLM.
1137
+ """
1138
+ # Standard format from Qwen2VL processor
1139
+ if pixel_values.dim() == 2:
1140
+ # Convert to correct dtype
1141
+ pixel_values = pixel_values.type(self.visual.embeddings.patch_embedding.weight.dtype)
1142
+
1143
+ # Forward through vision model with grid_thw
1144
+ vision_output = self.visual(pixel_values, grid_thw=image_grid_thw, patch_positions=patch_positions)
1145
+
1146
+ # Extract the actual tensor from BaseModelOutputWithPooling
1147
+ if hasattr(vision_output, "last_hidden_state"):
1148
+ image_embeds = vision_output.last_hidden_state
1149
+ else:
1150
+ image_embeds = vision_output[0]
1151
+
1152
+ # Compute split sizes from grid_thw
1153
+ if image_grid_thw is not None:
1154
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1155
+ else:
1156
+ # Fallback: assume single image
1157
+ split_sizes = [image_embeds.shape[0] if image_embeds.dim() == 2 else image_embeds.shape[1]]
1158
+
1159
+ # Split embeddings per image
1160
+ image_embeds_flat = image_embeds.view(-1, image_embeds.shape[-1])
1161
+ if len(split_sizes) > 1:
1162
+ image_embeds = list(torch.split(image_embeds_flat, split_sizes))
1163
+ else:
1164
+ image_embeds = [image_embeds_flat]
1165
+
1166
+ return image_embeds
1167
+ else:
1168
+ raise ValueError(
1169
+ f"Unsupported pixel_values shape: expected 4D tensor [total_patches, C, H, W], "
1170
+ f"got {pixel_values.shape if hasattr(pixel_values, 'shape') else type(pixel_values)}"
1171
+ )
1172
+
1173
+ def get_placeholder_mask(
1174
+ self,
1175
+ input_ids: torch.LongTensor,
1176
+ inputs_embeds: torch.FloatTensor,
1177
+ image_features: Optional[torch.FloatTensor] = None,
1178
+ video_features: Optional[torch.FloatTensor] = None,
1179
+ ):
1180
+ """
1181
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1182
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1183
+ """
1184
+ if input_ids is None:
1185
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
1186
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1187
+ )
1188
+ special_image_mask = special_image_mask.all(-1)
1189
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
1190
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
1191
+ )
1192
+ special_video_mask = special_video_mask.all(-1)
1193
+ else:
1194
+ special_image_mask = input_ids == self.config.image_token_id
1195
+ special_video_mask = input_ids == self.config.video_token_id
1196
+
1197
+ n_image_tokens = special_image_mask.sum()
1198
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1199
+ if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
1200
+ raise ValueError(
1201
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
1202
+ )
1203
+
1204
+ n_video_tokens = special_video_mask.sum()
1205
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1206
+ if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
1207
+ raise ValueError(
1208
+ f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
1209
+ )
1210
+
1211
+ return special_image_mask, special_video_mask
1212
+
1213
+ @auto_docstring
1214
+ def forward(
1215
+ self,
1216
+ input_ids: Optional[torch.LongTensor] = None,
1217
+ attention_mask: Optional[torch.Tensor] = None,
1218
+ position_ids: Optional[torch.LongTensor] = None,
1219
+ past_key_values: Optional[Cache] = None,
1220
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1221
+ use_cache: Optional[bool] = None,
1222
+ output_attentions: Optional[bool] = None,
1223
+ output_hidden_states: Optional[bool] = None,
1224
+ return_dict: Optional[bool] = None,
1225
+ pixel_values: Optional[torch.Tensor] = None,
1226
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1227
+ image_grid_thw: Optional[torch.LongTensor] = None,
1228
+ patch_positions: Optional[torch.LongTensor] = None,
1229
+ video_grid_thw: Optional[torch.LongTensor] = None,
1230
+ cache_position: Optional[torch.LongTensor] = None,
1231
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1232
+ **kwargs: Unpack[TransformersKwargs],
1233
+ ) -> Union[tuple, LlavaOnevision2ModelOutputWithPast]:
1234
+ r"""
1235
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1236
+ The temporal, height and width of feature shape of each image in LLM.
1237
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1238
+ The temporal, height and width of feature shape of each video in LLM.
1239
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
1240
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
1241
+ """
1242
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1243
+ output_hidden_states = (
1244
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1245
+ )
1246
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1247
+
1248
+ if inputs_embeds is None:
1249
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1250
+
1251
+ image_embeds = None
1252
+
1253
+ if pixel_values is not None:
1254
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw, patch_positions=patch_positions)
1255
+
1256
+ if image_embeds is not None:
1257
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1258
+ image_mask, _ = self.get_placeholder_mask(
1259
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1260
+ )
1261
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1262
+
1263
+ if pixel_values_videos is not None:
1264
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
1265
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1266
+ _, video_mask = self.get_placeholder_mask(
1267
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1268
+ )
1269
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1270
+
1271
+ # Use simple 1D position_ids
1272
+ if position_ids is None:
1273
+ batch_size, seq_length, _ = inputs_embeds.shape
1274
+ if attention_mask is not None:
1275
+ position_ids = attention_mask.long().cumsum(-1) - 1
1276
+ position_ids.masked_fill_(attention_mask == 0, 1)
1277
+ else:
1278
+ position_ids = (
1279
+ torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
1280
+ )
1281
+
1282
+ # Handle cache_position for generation
1283
+ if cache_position is not None and cache_position[0] != 0:
1284
+ position_ids = position_ids + cache_position[0]
1285
+
1286
+ outputs = self.language_model(
1287
+ input_ids=None,
1288
+ position_ids=position_ids,
1289
+ attention_mask=attention_mask,
1290
+ past_key_values=past_key_values,
1291
+ inputs_embeds=inputs_embeds,
1292
+ use_cache=use_cache,
1293
+ output_attentions=output_attentions,
1294
+ output_hidden_states=output_hidden_states,
1295
+ return_dict=True,
1296
+ cache_position=cache_position,
1297
+ **kwargs,
1298
  )
1299
+
1300
+ output = LlavaOnevision2ModelOutputWithPast(
1301
+ last_hidden_state=outputs.last_hidden_state,
1302
+ past_key_values=outputs.past_key_values,
1303
+ hidden_states=outputs.hidden_states,
1304
+ attentions=outputs.attentions,
1305
+ )
1306
+ return output if return_dict else output.to_tuple()
1307
+
1308
+
1309
@auto_docstring
class LlavaOnevision2ForConditionalGeneration(LlavaOnevision2PreTrainedModel, GenerationMixin):
    """Causal-LM head on top of `LlavaOnevision2Model` for multimodal generation."""

    _checkpoint_conversion_mapping = {
        "^visual": "model.visual",
        r"^model(?!\.(language_model|visual))": "model.language_model",
    }
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False

    def __init__(self, config):
        super().__init__(config)
        self.model = LlavaOnevision2Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the inner model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the inner model."""
        self.model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Swap in a different text decoder on the inner model."""
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Return the inner model's text decoder."""
        return self.model.get_decoder()

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
        patch_positions: Optional[torch.LongTensor] = None,
    ):
        """Delegate to the inner model; see `LlavaOnevision2Model.get_video_features`.

        `patch_positions` is now forwarded (backward-compatible keyword, default None),
        matching the inner model's signature.
        """
        return self.model.get_video_features(pixel_values_videos, video_grid_thw, patch_positions=patch_positions)

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
        patch_positions: Optional[torch.LongTensor] = None,
    ):
        """Delegate to the inner model; see `LlavaOnevision2Model.get_image_features`.

        `patch_positions` is now forwarded (backward-compatible keyword, default None),
        matching the inner model's signature.
        """
        return self.model.get_image_features(pixel_values, image_grid_thw, patch_positions=patch_positions)

    # Make modules available through conditional class for BC
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        return self.model.visual

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        patch_positions: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, LlavaOnevision2CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaOnevision2ForConditionalGeneration

        >>> model = LlavaOnevision2ForConditionalGeneration.from_pretrained("Deep-VLM/LLaVA-OneVision-1.5-8B-Instruct-hf", trust_remote_code=True)
        >>> processor = AutoProcessor.from_pretrained("Deep-VLM/LLaVA-OneVision-1.5-8B-Instruct-hf", trust_remote_code=True)

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            patch_positions=patch_positions,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return LlavaOnevision2CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        patch_positions=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            patch_positions=patch_positions,
            use_cache=use_cache,
            **kwargs,
        )

        # Drop pixel inputs after the prefill step: vision features are already merged
        # into the cached sequence. Guard against cache_position=None for consistency
        # with the None-check performed in the model's forward.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs

    def _get_image_nums_and_video_nums(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        if inputs_embeds is not None:
            # Recover token positions by comparing embeddings against the special-token
            # embeddings (only the first embedding dimension is compared, matching the
            # placeholder-mask logic elsewhere in this file).
            vision_start_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
            image_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
            video_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            )[..., 0]
        else:
            vision_start_mask = input_ids == vision_start_token_id
            image_mask = input_ids == image_token_id
            video_mask = input_ids == video_token_id

        # A vision block begins with a vision-start token immediately followed by an
        # image/video placeholder; count those adjacencies per sample.
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

        return image_nums, video_nums

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        # Overwritten -- Support for expanding tensors without a batch size dimension
        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
        # pixel_values.shape[0] is sum(seqlen_images for samples)
        # image_grid_thw.shape[0] is sum(num_images for samples)

        if expand_size == 1:
            return input_ids, model_kwargs

        # NOTE(review): `patch_positions` is not listed here, so it is expanded with the
        # generic batch-dim repeat_interleave below — confirm that is correct for
        # per-patch tensors.
        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]

        def _expand_dict_for_generation_visual(dict_to_expand):
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(
                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
            )

            def _repeat_interleave_samples(x, lengths, repeat_times):
                # Repeat each per-sample slice `repeat_times` times along dim 0.
                samples = torch.split(x, lengths)
                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
                return result

            for key in dict_to_expand:
                if key == "pixel_values":
                    # split images into samples
                    samples = torch.split(image_grid_thw, list(image_nums))
                    # compute the sequence length of images for each sample
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "image_grid_thw":
                    # get the num of images for each sample
                    lengths = list(image_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "pixel_values_videos":
                    samples = torch.split(video_grid_thw, list(video_nums))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "video_grid_thw":
                    lengths = list(video_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "second_per_grid_ts":
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
                    )
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
            # Generic expansion for batch-major tensors (skips visual keys handled above).
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs
1641
+
1642
+
1643
# Public API of this module: the composite multimodal models plus the individual
# vision-tower components so they can be imported and reused directly.
__all__ = [
    "LlavaOnevision2ForConditionalGeneration",
    "LlavaOnevision2Model",
    "LlavaOnevision2PreTrainedModel",
    "LlavaOnevision2VisionPretrainedModel",
    # Vision components
    "VisionRotaryEmbedding",
    "LlavaViTEmbeddings",
    "LlavaViTFlashAttention2",
    "LlavaViTEncoderLayer",
    "LlavaViTEncoder",
    "LlavaOnevision2VisionPatchMerger",
    "Siglip2MultiheadAttentionPoolingHead",
]