yiyexy commited on
Commit
9908b86
·
verified ·
1 Parent(s): ceb04eb

Update modeling_onevision_encoder.py

Browse files
Files changed (1) hide show
  1. modeling_onevision_encoder.py +346 -1328
modeling_onevision_encoder.py CHANGED
@@ -1,91 +1,113 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Optional, Union
3
 
4
  import torch
5
  import torch.nn as nn
6
- from torch.nn import LayerNorm
7
 
8
- from transformers import AutoModel
9
- from transformers.cache_utils import Cache
10
- from transformers.generation import GenerationMixin
11
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
12
  from transformers.modeling_utils import PreTrainedModel
13
  from transformers.models.siglip.modeling_siglip import SiglipMLP
14
- from transformers.processing_utils import Unpack
15
  from transformers.utils import (
16
- TransformersKwargs,
17
- auto_docstring,
18
- can_return_tuple,
19
- is_flash_attn_2_available,
20
  replace_return_docstrings,
21
  )
22
 
23
- from .configuration_llava_onevision2 import LlavaOnevision2Config, LlavaOnevision2VisionConfig
24
 
25
 
26
- if is_flash_attn_2_available():
27
  from flash_attn import flash_attn_func
28
 
 
 
 
29
 
30
- @dataclass
31
- @auto_docstring(
32
- custom_intro="""
33
- Base class for Llava-Onevision-1.5 outputs, with hidden states and attentions.
34
- """
35
- )
36
- class LlavaOnevision2ModelOutputWithPast(ModelOutput):
37
- r"""
38
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
39
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
40
 
41
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
42
- `past_key_values` input) to speed up sequential decoding.
43
- """
44
 
45
- last_hidden_state: Optional[torch.FloatTensor] = None
46
- past_key_values: Optional[Cache] = None
47
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
48
- attentions: Optional[tuple[torch.FloatTensor]] = None
49
 
 
 
 
 
50
 
51
- @dataclass
52
- @auto_docstring(
53
- custom_intro="""
54
- Base class for Llava-Onevision-1.5 causal language model (or autoregressive) outputs.
55
- """
56
- )
57
- class LlavaOnevision2CausalLMOutputWithPast(ModelOutput):
58
- r"""
59
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
60
- Language modeling loss (for next-token prediction).
61
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
62
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
63
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
64
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
65
-
66
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
67
- `past_key_values` input) to speed up sequential decoding.
68
- """
69
 
70
- loss: Optional[torch.FloatTensor] = None
71
- logits: Optional[torch.FloatTensor] = None
72
- past_key_values: Optional[Cache] = None
73
- hidden_states: Optional[tuple[torch.FloatTensor]] = None
74
- attentions: Optional[tuple[torch.FloatTensor]] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
  # ---------------------------------------------------------------------------
78
- # Vision Rotary Embedding
79
  # ---------------------------------------------------------------------------
80
 
81
 
82
- class VisionRotaryEmbedding(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  """
84
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
85
- Supports both grid_thw-based and explicit position-based RoPE computation.
86
  """
87
 
88
- def __init__(self, config: LlavaOnevision2VisionConfig):
89
  super().__init__()
90
  head_dim = config.hidden_size // config.num_attention_heads
91
  base = config.rope_theta
@@ -98,7 +120,6 @@ class VisionRotaryEmbedding(nn.Module):
98
  self.head_dim = head_dim
99
  self.half = half
100
 
101
- # 4:6:6 split for T:H:W
102
  unit = half // 16
103
  self.t_size = 4 * unit
104
  self.h_size = 6 * unit
@@ -120,121 +141,90 @@ class VisionRotaryEmbedding(nn.Module):
120
  persistent=False,
121
  )
122
 
123
- def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
124
- """
125
- Compute rotary position embeddings from grid_thw (Qwen2VL style).
126
-
127
- Args:
128
- grid_thw: [num_samples, 3] tensor with [t, h, w] for each sample
129
 
130
- Returns:
131
- freqs: [total_seq_len, half] tensor of position frequencies
132
- """
133
- device = grid_thw.device
134
  inv_t = self.inv_freq_t.to(device=device)
135
  inv_h = self.inv_freq_h.to(device=device)
136
  inv_w = self.inv_freq_w.to(device=device)
137
 
138
- all_freqs = []
139
- for sample_thw in grid_thw:
140
- t, h, w = sample_thw[0].item(), sample_thw[1].item(), sample_thw[2].item()
141
-
142
- # Compute frequency tables
143
- ft = torch.outer(torch.arange(t, device=device, dtype=torch.float32), inv_t)
144
- fh = torch.outer(torch.arange(h, device=device, dtype=torch.float32), inv_h)
145
- fw = torch.outer(torch.arange(w, device=device, dtype=torch.float32), inv_w)
146
-
147
- # Build position indices for this sample
148
- t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
149
- h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
150
- w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
151
 
152
- # Concatenate frequencies: [seq_len, half]
153
- sample_freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
154
- all_freqs.append(sample_freqs)
155
 
156
- return torch.cat(all_freqs, dim=0)
 
157
 
158
  def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
159
  """
160
  Compute rotary position embeddings from explicit patch positions.
161
 
162
  Args:
163
- patch_positions: [seq_len, 3] tensor with [t, h, w] positions for each patch
164
 
165
  Returns:
166
- freqs: [seq_len, half] tensor of position frequencies
167
  """
168
  device = patch_positions.device
169
  inv_t = self.inv_freq_t.to(device=device)
170
  inv_h = self.inv_freq_h.to(device=device)
171
  inv_w = self.inv_freq_w.to(device=device)
172
 
173
- t_pos = patch_positions[:, 0].float()
174
- h_pos = patch_positions[:, 1].float()
175
- w_pos = patch_positions[:, 2].float()
176
 
177
- ft = torch.outer(t_pos, inv_t)
178
- fh = torch.outer(h_pos, inv_h)
179
- fw = torch.outer(w_pos, inv_w)
 
180
 
181
  return torch.cat([ft, fh, fw], dim=-1)
182
 
183
- def forward_with_thw(self, t: int, h: int, w: int, device=None) -> torch.Tensor:
184
- """
185
- Compute rotary position embeddings from explicit t, h, w dimensions.
186
 
187
- Args:
188
- t: Number of temporal frames
189
- h: Number of height patches
190
- w: Number of width patches
191
- device: Target device
192
 
193
- Returns:
194
- freqs: [t*h*w, half] tensor of position frequencies
195
- """
196
- if device is None:
197
- device = self.inv_freq_t.device
 
 
198
 
199
- inv_t = self.inv_freq_t.to(device=device)
200
- inv_h = self.inv_freq_h.to(device=device)
201
- inv_w = self.inv_freq_w.to(device=device)
202
 
203
- ft = torch.outer(torch.arange(t, device=device, dtype=torch.float32), inv_t)
204
- fh = torch.outer(torch.arange(h, device=device, dtype=torch.float32), inv_h)
205
- fw = torch.outer(torch.arange(w, device=device, dtype=torch.float32), inv_w)
206
 
207
- t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
208
- h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
209
- w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
210
 
211
- freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
212
- return freqs
213
 
214
 
215
  # ---------------------------------------------------------------------------
216
- # Patch Embedding
217
  # ---------------------------------------------------------------------------
218
 
219
 
220
- class LlavaViTEmbeddings(nn.Module):
221
- """
222
- Patch embedding layer that converts pre-processed patches to embeddings.
223
-
224
- This module is designed to receive patches that have already been extracted
225
- and arranged by the Qwen2VL image processor in 2x2 block spatial order.
226
-
227
- Input format: [total_patches, num_channels, patch_size, patch_size]
228
- Output format: [total_patches, embed_dim]
229
- """
230
-
231
- def __init__(self, config: LlavaOnevision2VisionConfig):
232
  super().__init__()
233
  self.config = config
234
  self.embed_dim = config.hidden_size
235
  self.image_size = config.image_size
236
  self.patch_size = config.patch_size
237
- self.in_channels = config.num_channels
238
 
239
  self.patch_embedding = nn.Conv2d(
240
  in_channels=config.num_channels,
@@ -244,284 +234,101 @@ class LlavaViTEmbeddings(nn.Module):
244
  bias=False,
245
  )
246
 
247
- def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
248
- target_dtype = self.patch_embedding.weight.dtype
249
- hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
250
- hidden_states = self.patch_embedding(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
251
-
252
- return hidden_states
253
-
254
-
255
- # ---------------------------------------------------------------------------
256
- # Patch Merger
257
- # ---------------------------------------------------------------------------
258
-
259
-
260
- class LlavaOnevision2VisionPatchMerger(nn.Module):
261
- """
262
- Patch merger that merges spatial_merge_size x spatial_merge_size patches into one.
263
 
264
- This module is designed to work with Qwen2VL-style patch processing where patches
265
- are already arranged in 2x2 block order by the image processor.
266
- """
267
 
268
- def __init__(
269
- self,
270
- dim: int,
271
- context_dim: int,
272
- spatial_merge_size: int = 2,
273
- layer_norm_eps: float = 1e-05,
274
- ) -> None:
275
- super().__init__()
276
- self.hidden_size = context_dim * (spatial_merge_size**2)
277
- self.ln_q = LayerNorm(context_dim, eps=layer_norm_eps)
278
- self.mlp = nn.Sequential(
279
- nn.Linear(self.hidden_size, self.hidden_size),
280
- nn.GELU(),
281
- nn.Linear(self.hidden_size, dim),
282
- )
283
- self.spatial_merge_size = spatial_merge_size
284
 
285
- def forward(self, x: torch.Tensor) -> torch.Tensor:
286
- """
287
- Merge patches from Qwen2VL-style input.
288
 
289
- The input patches are already arranged in 2x2 block order by the image processor,
290
- so we simply need to apply LayerNorm, reshape, and project through MLP.
 
291
 
292
- Args:
293
- x: Input tensor of shape [batch_size, seq_len, hidden_size] or [seq_len, hidden_size]
294
- where seq_len = t * h * w (patches in 2x2 block order)
295
 
296
- Returns:
297
- Merged tensor of shape [batch_size, seq_len // spatial_merge_size^2, dim]
298
- or [seq_len // spatial_merge_size^2, dim]
299
- """
300
- x = self.ln_q(x).view(-1, self.hidden_size)
301
- x = self.mlp(x)
302
- return x
303
 
 
 
304
 
305
- def rotate_half(x):
306
- """
307
- Interleaved rotation to match Source model's implementation.
308
- (x1, x2, x3, x4) -> (-x2, x1, -x4, x3)
309
- """
310
- x_even = x[..., ::2]
311
- x_odd = x[..., 1::2]
312
- return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
313
-
314
-
315
- def get_norm_layer(config):
316
- if config.layer_norm_type == "rms_norm":
317
- return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
318
- else:
319
- return nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
320
-
321
-
322
- def apply_rotary_pos_emb(q, k, freqs):
323
- # q, k: (B, H, L, D)
324
- # freqs: (B, L, D)
325
- orig_q_dtype = q.dtype
326
- orig_k_dtype = k.dtype
327
- q, k = q.float(), k.float()
328
- # We need to broadcast freqs to match heads
329
- # (B, L, D) -> (B, 1, L, D)
330
- # Keep the same dtype as q, k to avoid memory doubling from float32 promotion
331
- cos = freqs.cos().unsqueeze(1).float()
332
- sin = freqs.sin().unsqueeze(1).float()
333
-
334
- q_embed = (q * cos) + (rotate_half(q) * sin)
335
- k_embed = (k * cos) + (rotate_half(k) * sin)
336
- q_embed = q_embed.to(orig_q_dtype)
337
- k_embed = k_embed.to(orig_k_dtype)
338
- return q_embed, k_embed
339
-
340
- def convert_rope_to_block_layout(
341
- freqs: torch.Tensor, t: int, h: int, w: int, spatial_merge_size: int = 2
342
- ) -> torch.Tensor:
343
- """
344
- Convert RoPE from row-major order (1x1 layout) to 2x2 block layout.
345
-
346
- The image processor arranges patches in 2x2 blocks when spatial_merge_size=2:
347
- - Row-major order: [p(0,0), p(0,1), p(0,2), p(0,3), ..., p(1,0), p(1,1), ...]
348
- - Block order: [p(0,0), p(0,1), p(1,0), p(1,1)], [p(0,2), p(0,3), p(1,2), p(1,3)], ...
349
-
350
- Args:
351
- freqs: RoPE frequencies in row-major order, shape [t*h*w, half]
352
- t: temporal dimension
353
- h: height (unmerged patch count)
354
- w: width (unmerged patch count)
355
- spatial_merge_size: size of spatial merge blocks (default: 2)
356
-
357
- Returns:
358
- torch.Tensor: RoPE frequencies in 2x2 block order, same shape [t*h*w, half]
359
- """
360
- sms = spatial_merge_size
361
- if sms == 1:
362
- return freqs
363
-
364
- half = freqs.shape[-1]
365
-
366
- # freqs shape: [t*h*w, half]
367
- # Reshape to [t, h, w, half]
368
- freqs = freqs.view(t, h, w, half)
369
-
370
- # Calculate merged dimensions
371
- h_merged = h // sms
372
- w_merged = w // sms
373
-
374
- # Reshape to [t, h_merged, sms, w_merged, sms, half]
375
- freqs = freqs.view(t, h_merged, sms, w_merged, sms, half)
376
-
377
- # Permute to [t, h_merged, w_merged, sms_h, sms_w, half] - 2x2 block order
378
- freqs = freqs.permute(0, 1, 3, 2, 4, 5).contiguous()
379
-
380
- # Reshape back to [t*h*w, half]
381
- freqs = freqs.view(t * h * w, half)
382
-
383
- return freqs
384
-
385
-
386
- def convert_rope_to_block_layout_by_positions(
387
- freqs: torch.Tensor,
388
- patch_positions: torch.Tensor,
389
- spatial_merge_size: int = 2,
390
- grid_thw: Optional[torch.Tensor] = None,
391
- ) -> torch.Tensor:
392
- """
393
- Convert RoPE from row-major order to 2x2 block layout, grouping by temporal index.
394
-
395
- This function automatically groups patches by their temporal index (t) from patch_positions,
396
- then applies 2x2 spatial reordering within each temporal group.
397
-
398
- Optimized version: if all frames have the same spatial size, use vectorized operations.
399
-
400
- Args:
401
- freqs: RoPE frequencies in row-major order, shape [seq_len, half]
402
- patch_positions: Patch positions tensor, shape [seq_len, 3] with [t, h, w] for each patch
403
- spatial_merge_size: size of spatial merge blocks (default: 2)
404
- grid_thw: Optional grid_thw tensor for reliable h, w extraction
405
-
406
- Returns:
407
- torch.Tensor: RoPE frequencies in 2x2 block order, same shape [seq_len, half]
408
- """
409
- sms = spatial_merge_size
410
- if sms == 1:
411
- return freqs
412
-
413
- half = freqs.shape[-1]
414
- seq_len = freqs.shape[0]
415
-
416
- # Get temporal indices
417
- t_indices = patch_positions[:, 0]
418
-
419
- # Find unique t values and their counts (preserving order)
420
- unique_t, inverse_indices, counts = torch.unique_consecutive(
421
- t_indices, return_inverse=True, return_counts=True
422
- )
423
-
424
- num_groups = unique_t.shape[0]
425
-
426
- # Fast path: single image with grid_thw available
427
- if num_groups == 1 and grid_thw is not None:
428
- height = grid_thw[0, 1].item()
429
- width = grid_thw[0, 2].item()
430
- return convert_rope_to_block_layout(freqs, t=1, h=height, w=width, spatial_merge_size=sms)
431
-
432
- # Fast path: single image, square
433
- if num_groups == 1:
434
- hw = int(seq_len ** 0.5)
435
- if hw * hw == seq_len:
436
- return convert_rope_to_block_layout(freqs, t=1, h=hw, w=hw, spatial_merge_size=sms)
437
-
438
- # Check if all groups have the same size (common case for videos)
439
- # This allows vectorized processing
440
- first_count = counts[0].item()
441
- all_same_size = torch.all(counts == first_count).item()
442
-
443
- if all_same_size:
444
- # Vectorized path: all frames have same spatial size
445
- group_size = first_count
446
- hw = int(group_size ** 0.5)
447
-
448
- if hw * hw == group_size:
449
- # Square frames: use fully vectorized convert_rope_to_block_layout
450
- # Reshape freqs to [num_groups, h, w, half] and process as batch
451
- return convert_rope_to_block_layout(
452
- freqs, t=num_groups, h=hw, w=hw, spatial_merge_size=sms
453
- )
454
- elif grid_thw is not None:
455
- # Non-square but have grid_thw: get h, w from grid_thw
456
- height = grid_thw[0, 1].item()
457
- width = grid_thw[0, 2].item()
458
- return convert_rope_to_block_layout(
459
- freqs, t=num_groups, h=height, w=width, spatial_merge_size=sms
460
  )
461
 
462
- # Slow path: variable frame sizes, process each group separately
463
- # Pre-compute cumulative offsets to avoid repeated slicing
464
- cum_counts = torch.cumsum(counts, dim=0)
465
- start_indices = torch.cat([torch.tensor([0], device=counts.device), cum_counts[:-1]])
466
 
467
- result_freqs = torch.empty_like(freqs)
 
 
 
468
 
469
- for group_idx in range(num_groups):
470
- start_idx = start_indices[group_idx].item()
471
- group_size = counts[group_idx].item()
472
- end_idx = start_idx + group_size
 
 
 
 
473
 
474
- # Infer spatial dimensions
475
- hw = int(group_size ** 0.5)
476
- if hw * hw == group_size:
477
- h, w = hw, hw
478
- else:
479
- h, w = _infer_hw_from_positions(patch_positions[start_idx:end_idx], sms)
480
 
481
- # Apply block layout conversion
482
- result_freqs[start_idx:end_idx] = convert_rope_to_block_layout(
483
- freqs[start_idx:end_idx], t=1, h=h, w=w, spatial_merge_size=sms
484
- )
485
 
486
- return result_freqs
 
487
 
 
 
488
 
489
- def _infer_hw_from_positions(
490
- group_positions: torch.Tensor,
491
- spatial_merge_size: int = 2
492
- ) -> tuple[int, int]:
493
- """
494
- Infer height and width from patch positions within a temporal group.
495
 
496
- Args:
497
- group_positions: Patch positions for one temporal group, shape [group_size, 3]
498
- spatial_merge_size: size of spatial merge blocks
499
 
500
- Returns:
501
- tuple[int, int]: (height, width) of the spatial grid
502
- """
503
- # Get unique h and w values
504
- h_values = group_positions[:, 1]
505
- w_values = group_positions[:, 2]
506
 
507
- h_unique = torch.unique(h_values)
508
- w_unique = torch.unique(w_values)
509
 
510
- h = h_unique.shape[0]
511
- w = w_unique.shape[0]
512
 
513
- # Validate dimensions are divisible by spatial_merge_size
514
- assert h % spatial_merge_size == 0, f"Height {h} not divisible by {spatial_merge_size}"
515
- assert w % spatial_merge_size == 0, f"Width {w} not divisible by {spatial_merge_size}"
516
 
517
- return h, w
518
 
519
- class LlavaViTFlashAttention2(nn.Module):
520
  """
521
  Multi-headed attention with RoPE support using Flash Attention 2.
 
 
522
  """
523
 
524
- def __init__(self, config: LlavaOnevision2VisionConfig):
525
  super().__init__()
526
  self.config = config
527
  self.embed_dim = config.hidden_size
@@ -534,8 +341,11 @@ class LlavaViTFlashAttention2(nn.Module):
534
 
535
  self.scale = self.head_dim**-0.5
536
  self.dropout = config.attention_dropout
537
- self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3)
538
- self.proj = nn.Linear(self.embed_dim, self.embed_dim)
 
 
 
539
 
540
  def forward(
541
  self,
@@ -543,22 +353,20 @@ class LlavaViTFlashAttention2(nn.Module):
543
  attention_mask: Optional[torch.Tensor] = None,
544
  rotary_pos_emb: Optional[torch.Tensor] = None,
545
  output_attentions: bool = False,
546
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
547
  """
548
  Forward pass using Flash Attention 2.
549
  """
550
  batch_size, q_len, _ = hidden_states.size()
551
- q, k, v = (
552
- self.qkv(hidden_states)
553
- .reshape(batch_size, q_len, 3, self.num_heads, self.head_dim)
554
- .permute(2, 0, 1, 3, 4)
555
- .unbind(0)
556
- )
557
 
558
  # Flash Attention requires (B, L, H, D) format
559
- query_states = q
560
- key_states = k
561
- value_states = v
562
 
563
  # Apply RoPE if provided
564
  if rotary_pos_emb is not None:
@@ -571,10 +379,10 @@ class LlavaViTFlashAttention2(nn.Module):
571
  query_states = query_states.transpose(1, 2)
572
  key_states = key_states.transpose(1, 2)
573
 
574
- # FIX: Removed the explicit float32 check and downcast.
575
- # We assume input is already correct (bf16/fp16) thanks to RoPE fix.
576
-
577
  # Flash Attention forward pass
 
 
 
578
  attn_output = flash_attn_func(
579
  query_states,
580
  key_states,
@@ -588,19 +396,33 @@ class LlavaViTFlashAttention2(nn.Module):
588
  attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
589
 
590
  # No extra casting here.
591
- # attn_output = self.out_proj(attn_output)
592
- attn_output = self.proj(attn_output)
593
 
594
  return attn_output, None
595
 
596
 
597
- class LlavaViTEncoderLayer(nn.Module):
598
- """Vision encoder layer with pre-norm and Flash Attention 2."""
 
 
599
 
600
- def __init__(self, config: LlavaOnevision2VisionConfig):
 
 
601
  super().__init__()
602
  self.embed_dim = config.hidden_size
603
- self.self_attn = LlavaViTFlashAttention2(config)
 
 
 
 
 
 
 
 
 
 
 
604
  self.layer_norm1 = get_norm_layer(config)
605
  self.mlp = SiglipMLP(config)
606
  self.layer_norm2 = get_norm_layer(config)
@@ -611,7 +433,7 @@ class LlavaViTEncoderLayer(nn.Module):
611
  attention_mask: Optional[torch.Tensor] = None,
612
  rotary_pos_emb: Optional[torch.Tensor] = None,
613
  output_attentions: bool = False,
614
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
615
  residual = hidden_states
616
  hidden_states = self.layer_norm1(hidden_states)
617
 
@@ -632,13 +454,11 @@ class LlavaViTEncoderLayer(nn.Module):
632
  return outputs
633
 
634
 
635
- class LlavaViTEncoder(nn.Module):
636
- def __init__(self, config: LlavaOnevision2VisionConfig):
637
  super().__init__()
638
  self.config = config
639
- self.layers = nn.ModuleList([LlavaViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
640
- # Gradient checkpointing support
641
- self.gradient_checkpointing = False
642
 
643
  def forward(
644
  self,
@@ -656,21 +476,12 @@ class LlavaViTEncoder(nn.Module):
656
  if output_hidden_states:
657
  all_hidden_states = all_hidden_states + (hidden_states,)
658
 
659
- if self.gradient_checkpointing and self.training:
660
- layer_outputs = self._gradient_checkpointing_func(
661
- layer.__call__,
662
- hidden_states,
663
- attention_mask,
664
- rotary_pos_emb,
665
- output_attentions,
666
- )
667
- else:
668
- layer_outputs = layer(
669
- hidden_states,
670
- attention_mask=attention_mask,
671
- rotary_pos_emb=rotary_pos_emb,
672
- output_attentions=output_attentions,
673
- )
674
 
675
  hidden_states = layer_outputs[0]
676
 
@@ -689,124 +500,54 @@ class LlavaViTEncoder(nn.Module):
689
  attentions=all_self_attentions,
690
  )
691
 
692
- def forward_debug(
693
- self,
694
- hidden_states: torch.Tensor,
695
- attention_mask: Optional[torch.Tensor] = None,
696
- rotary_pos_emb: Optional[torch.Tensor] = None,
697
- ) -> dict:
698
- """
699
- Forward pass with layer-by-layer debug outputs for consistency checking.
700
 
701
- Returns:
702
- dict: Contains:
703
- - 'input_hidden_states': Input to the encoder
704
- - 'input_rotary_pos_emb': Rotary position embeddings input
705
- - 'layer_outputs': Dict mapping layer index to output after that layer
706
- - 'final_output': Final encoder output
707
- """
708
- output = {}
709
-
710
- # Save input
711
- output["input_hidden_states"] = hidden_states.clone()
712
- if rotary_pos_emb is not None:
713
- output["input_rotary_pos_emb"] = rotary_pos_emb.clone()
714
-
715
- # Layer-by-layer outputs
716
- layer_outputs = {}
717
-
718
- for layer_idx, layer in enumerate(self.layers):
719
- # Save input to this layer
720
- layer_outputs[f"layer_{layer_idx}_input"] = hidden_states.clone()
721
-
722
- # Forward through layer
723
- layer_result = layer(
724
- hidden_states,
725
- attention_mask=attention_mask,
726
- rotary_pos_emb=rotary_pos_emb,
727
- output_attentions=False,
728
- )
729
- hidden_states = layer_result[0]
730
-
731
- # Save output of this layer
732
- layer_outputs[f"layer_{layer_idx}_output"] = hidden_states.clone()
733
-
734
- output["layer_outputs"] = layer_outputs
735
- output["final_output"] = hidden_states.clone()
736
-
737
- return output
738
 
739
 
740
- class LlavaOnevision2PreTrainedModel(PreTrainedModel):
741
- config_class = LlavaOnevision2Config
742
- base_model_prefix = "model"
 
 
 
 
743
  supports_gradient_checkpointing = True
 
744
  _supports_flash_attn_2 = True
745
- _no_split_modules = ["LlavaViTEncoderLayer"]
746
- _skip_keys_device_placement = "past_key_values"
747
- _supports_flash_attn = True
748
- _supports_sdpa = True
749
 
750
  def _init_weights(self, module):
751
- super()._init_weights(module)
752
- # Custom weight initialization can be added here if needed
753
- # For LlavaOnevision2VisionPretrainedModel, we rely on default initialization
754
-
755
-
756
- class Siglip2MultiheadAttentionPoolingHead(nn.Module):
757
- """
758
- Multi-Head Attention Pooling with a learned probe (PMA-style).
759
- """
760
-
761
- def __init__(self, config: LlavaOnevision2VisionConfig):
762
- super().__init__()
763
- self.embed_dim = config.hidden_size
764
- self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
765
- self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
766
- self.norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
767
- self.mlp = SiglipMLP(config)
768
-
769
- def forward(self, hidden_states):
770
- batch_size = hidden_states.shape[0]
771
- probe = self.probe.repeat(batch_size, 1, 1)
772
-
773
- attn_output, _ = self.attention(probe, hidden_states, hidden_states)
774
-
775
- residual = attn_output
776
- attn_output = self.norm(attn_output)
777
- attn_output = residual + self.mlp(attn_output)
778
-
779
- return attn_output[:, 0]
780
-
781
- # ---------------------------------------------------------------------------
782
- # Vision Model
783
- # ---------------------------------------------------------------------------
784
-
785
-
786
- class LlavaOnevision2VisionPretrainedModel(LlavaOnevision2PreTrainedModel):
787
- """
788
- LLaVA-OneVision 2.0 Vision Model.
789
-
790
- This vision model is designed to work with Qwen2VL-style image processing:
791
- - Receives pre-processed patches in 2x2 block spatial order
792
- - Applies RoPE with matching 2x2 block layout conversion
793
- - Accepts explicit patch_positions for RoPE computation
794
-
795
- Input format:
796
- hidden_state: [total_patches, num_channels, patch_size, patch_size]
797
- grid_thw: [num_samples, 3] with [t, h, w] for each sample
798
- """
799
-
800
- def __init__(self, config: LlavaOnevision2VisionConfig):
801
  super().__init__(config)
802
  self.config = config
803
- self.spatial_merge_size = config.spatial_merge_size
804
 
805
- # Vision components
806
- self.embeddings = LlavaViTEmbeddings(config)
807
  self.layernorm_pre = get_norm_layer(config)
808
- self.encoder = LlavaViTEncoder(config)
809
- self.video_rope = VisionRotaryEmbedding(config)
810
 
811
  if config.use_head:
812
  self.layernorm_post = get_norm_layer(config)
@@ -815,842 +556,119 @@ class LlavaOnevision2VisionPretrainedModel(LlavaOnevision2PreTrainedModel):
815
  self.layernorm_post = None
816
  self.head = None
817
 
818
- self.merger = LlavaOnevision2VisionPatchMerger(
819
- dim=config.out_hidden_size,
820
- context_dim=config.hidden_size,
821
- spatial_merge_size=config.spatial_merge_size,
822
- layer_norm_eps=config.layer_norm_eps,
823
- )
824
-
825
  self.post_init()
826
 
827
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=LlavaOnevision2VisionConfig)
 
828
  def forward(
829
  self,
830
- hidden_state: torch.Tensor,
831
- grid_thw: Optional[torch.Tensor] = None,
832
  patch_positions: Optional[torch.Tensor] = None,
833
  output_attentions: Optional[bool] = None,
834
  output_hidden_states: Optional[bool] = None,
835
  return_dict: Optional[bool] = None,
836
- skip_merger: Optional[bool] = False,
837
  ) -> Union[tuple, BaseModelOutputWithPooling]:
838
  r"""
839
- Forward pass for vision model.
840
 
841
- This method accepts pre-processed patches from Qwen2VL image processor and applies
842
- RoPE (Rotary Position Embedding) in 2x2 block layout to match the spatial arrangement
843
- of patches.
844
 
845
- Args:
846
- hidden_state: Pre-processed patches from Qwen2VL processor.
847
- Shape: [total_patches, num_channels, patch_size, patch_size]
848
- grid_thw: Grid sizes tensor of shape [num_samples, 3] with [t, h, w] for each sample.
849
- Required for computing RoPE and handling visible indices.
850
- patch_positions: Optional explicit patch positions for RoPE computation.
851
- output_attentions: Whether to return attention weights.
852
- output_hidden_states: Whether to return all hidden states.
853
- return_dict: Whether to return a ModelOutput instead of tuple.
854
- skip_merger: If True, skip patch merger (useful for consistency checking).
855
 
856
- Returns:
857
- BaseModelOutputWithPooling with last_hidden_state containing merged features.
 
 
 
 
 
 
858
  """
859
- output_attentions = (
860
- output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
861
- )
862
  output_hidden_states = (
863
- output_hidden_states
864
- if output_hidden_states is not None
865
- else getattr(self.config, "output_hidden_states", False)
866
  )
867
- return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
  # 1. Embeddings
870
- # Note: embeddings returns [total_patches, embed_dim], we need to add batch dimension
871
- hidden_states = self.embeddings(hidden_state)
872
- if hidden_states.dim() == 2:
873
- hidden_states = hidden_states.unsqueeze(0) # [1, total_patches, embed_dim]
874
  batch_size, total_patches, _ = hidden_states.shape
875
 
876
- # 2. RoPE Construction
877
- # Get dimensions from grid_thw for block layout conversion
878
- if grid_thw is not None:
879
- t_frames = grid_thw[0, 0].item()
880
- height = grid_thw[0, 1].item()
881
- width = grid_thw[0, 2].item()
 
 
 
882
  else:
883
- # Fallback: infer from total_patches (assume single frame, square)
884
- t_frames = 1
885
- height = int(total_patches ** 0.5)
886
- width = height
887
-
888
- if patch_positions is not None and patch_positions.dim() == 3:
889
- patch_positions = patch_positions.squeeze(0)
890
- freqs_visible = self.video_rope.forward_from_positions(patch_positions)
891
-
892
- # Convert RoPE from row-major to block layout (matching Qwen2VL processor output)
893
- # Use position-based grouping for videos with variable frame sizes
894
- # Pass grid_thw for reliable h, w extraction (especially for non-square images)
895
- freqs_visible = convert_rope_to_block_layout_by_positions(
896
- freqs_visible, patch_positions, spatial_merge_size=2, grid_thw=grid_thw
897
- )
898
 
899
  # Concatenate D/2 + D/2 -> D for applying rope
900
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
901
- if freqs_visible.dim() == 2:
902
- freqs_visible = freqs_visible.unsqueeze(0)
903
 
904
- # 3. Pre-Norm & Encoder
905
  hidden_states = self.layernorm_pre(hidden_states)
906
 
 
 
 
 
 
 
 
 
907
  encoder_outputs = self.encoder(
908
  hidden_states,
909
  attention_mask=None,
910
  rotary_pos_emb=freqs_visible,
911
  output_attentions=output_attentions,
912
- output_hidden_states=True, # Always get hidden states to use -2 layer
913
- return_dict=True,
914
  )
915
 
916
- # Use second-to-last layer output for better feature representation
917
- if encoder_outputs.hidden_states is not None and len(encoder_outputs.hidden_states) >= 2 and not skip_merger:
918
- sequence_output = encoder_outputs.hidden_states[-2]
919
- else:
920
- sequence_output = encoder_outputs[0]
921
 
922
- # Post-Norm
923
  if self.layernorm_post is not None:
924
  sequence_output = self.layernorm_post(sequence_output)
925
 
926
- # Skip merger for consistency check with original ViT
927
- if skip_merger:
928
- pooled_output = None
929
- if self.head is not None:
930
- pooled_output = self.head(sequence_output)
931
-
932
- if not return_dict:
933
- return (sequence_output, pooled_output) + (
934
- encoder_outputs.hidden_states if output_hidden_states else None,
935
- )
936
- return BaseModelOutputWithPooling(
937
- last_hidden_state=sequence_output,
938
- pooler_output=pooled_output,
939
- hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
940
- attentions=encoder_outputs.attentions if output_attentions else None,
941
- )
942
-
943
- # Patch merger: input patches are already in 2x2 block order from Qwen2VL processor
944
- merged_output = self.merger(sequence_output)
945
 
946
  if not return_dict:
947
- return (merged_output,) + (encoder_outputs.hidden_states if output_hidden_states else None,)
948
 
949
  return BaseModelOutputWithPooling(
950
- last_hidden_state=merged_output,
951
- pooler_output=None,
952
- hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
953
- attentions=encoder_outputs.attentions if output_attentions else None,
954
- )
955
-
956
- def forward_debug(
957
- self,
958
- hidden_state: torch.Tensor,
959
- grid_thw: torch.Tensor,
960
- ) -> dict:
961
- """
962
- Debug version of forward pass that captures intermediate states.
963
-
964
- Identical to forward() but saves intermediate outputs at key stages
965
- for debugging and consistency checking purposes.
966
-
967
- Args:
968
- hidden_state: Pre-processed patches from Qwen2VL processor.
969
- Shape: [total_patches, num_channels, patch_size, patch_size] or [total_patches, patch_dim]
970
- grid_thw: Grid sizes tensor of shape [num_samples, 3] with [t, h, w] for each sample.
971
-
972
- Returns:
973
- dict: Dictionary containing intermediate outputs:
974
- - "input_pixel_values": Input to the model
975
- - "after_patch_embed": Embeddings after patch projection
976
- - "rotary_pos_emb": Rotary position embeddings
977
- - "after_pre_layernorm": Embeddings after pre-normalization
978
- - "layer_outputs": Dict mapping layer index to input/output
979
- - "before_adapter": Final output before merger (same as after_decoder)
980
- - "after_merger": Output after patch merger
981
- """
982
- output = {}
983
-
984
- # Store input for consistency checking
985
- output["input_pixel_values"] = hidden_state.clone()
986
- output["input_grid_thw"] = grid_thw.clone()
987
-
988
- batch_size = grid_thw.size(0)
989
- assert batch_size == 1, "Currently only batch_size=1 is supported for forward_debug."
990
-
991
- # Determine video dimensions for RoPE
992
- t_frames = grid_thw[0, 0].item()
993
- height = grid_thw[0, 1].item()
994
- width = grid_thw[0, 2].item()
995
-
996
- # 1. Embeddings
997
- hidden_states = self.embeddings(hidden_state)
998
- if hidden_states.dim() == 2:
999
- hidden_states = hidden_states.unsqueeze(0) # [1, total_patches, embed_dim]
1000
- output["after_patch_embed"] = hidden_states.clone()
1001
-
1002
- batch_size, total_patches, _ = hidden_states.shape
1003
-
1004
- # 2. Visible Indices (simplified for debug - use all patches)
1005
- visible_indices = (
1006
- torch.arange(total_patches, device=hidden_state.device).unsqueeze(0).expand(batch_size, -1)
1007
- )
1008
-
1009
- # 3. RoPE Construction
1010
- freqs_full = self.video_rope.forward_with_thw(
1011
- t=64 if t_frames > 1 else 1,
1012
- h=height,
1013
- w=width,
1014
- device=hidden_state.device,
1015
- )
1016
-
1017
- # Convert RoPE from row-major to block layout
1018
- freqs_full_block = convert_rope_to_block_layout(
1019
- freqs_full, 1 if t_frames == 1 else 64, height, width, spatial_merge_size=2
1020
- ).unsqueeze(0)
1021
-
1022
- # Concatenate D/2 + D/2 -> D for applying rope
1023
- freqs_visible = torch.cat([freqs_full_block, freqs_full_block], dim=-1)
1024
- output["rotary_pos_emb"] = freqs_visible.clone()
1025
-
1026
- # 4. Pre-Norm
1027
- hidden_states = self.layernorm_pre(hidden_states)
1028
- output["after_pre_layernorm"] = hidden_states.clone()
1029
-
1030
- # 5. Encoder with layer-by-layer debug
1031
- encoder_debug_output = self.encoder.forward_debug(
1032
- hidden_states,
1033
- attention_mask=None,
1034
- rotary_pos_emb=freqs_visible,
1035
- )
1036
-
1037
- # Extract layer outputs
1038
- output["layer_outputs"] = encoder_debug_output.get("layer_outputs", {})
1039
-
1040
- # Get second-to-last layer output for merger (matching forward behavior)
1041
- # In forward_debug of encoder, final_output is the last layer output
1042
- final_hidden_states = encoder_debug_output.get("final_output", hidden_states)
1043
-
1044
- # For consistency with Megatron, we use the second-to-last layer
1045
- # But forward_debug doesn't easily give us that, so we'll use final
1046
- # and note that this is the output before merger
1047
- output["before_adapter"] = final_hidden_states.clone()
1048
-
1049
- # 6. Post-Norm (if exists)
1050
- if self.layernorm_post is not None:
1051
- final_hidden_states = self.layernorm_post(final_hidden_states)
1052
-
1053
- # 7. Merger
1054
- merged_output = self.merger(final_hidden_states)
1055
- output["after_merger"] = merged_output.clone()
1056
-
1057
- return output
1058
-
1059
-
1060
- @auto_docstring
1061
- class LlavaOnevision2Model(LlavaOnevision2PreTrainedModel):
1062
- base_model_prefix = ""
1063
- _checkpoint_conversion_mapping = {"^model": "language_model"}
1064
- # Reference: fix gemma3 grad acc #37208
1065
- accepts_loss_kwargs = False
1066
- config: LlavaOnevision2Config
1067
- _no_split_modules = ["LlavaViTEncoderLayer"]
1068
-
1069
- def __init__(self, config: LlavaOnevision2Config):
1070
- super().__init__(config)
1071
- self.visual = LlavaOnevision2VisionPretrainedModel._from_config(config.vision_config)
1072
- self.language_model = AutoModel.from_config(config.text_config)
1073
- # Initialize weights and apply final processing
1074
- self.post_init()
1075
-
1076
- def get_input_embeddings(self):
1077
- return self.language_model.get_input_embeddings()
1078
-
1079
- def set_input_embeddings(self, value):
1080
- self.language_model.set_input_embeddings(value)
1081
-
1082
- def set_decoder(self, decoder):
1083
- self.language_model = decoder
1084
-
1085
- def get_decoder(self):
1086
- return self.language_model
1087
-
1088
- def get_video_features(
1089
- self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None, patch_positions=None
1090
- ):
1091
- """
1092
- Encodes videos into continuous embeddings that can be forwarded to the language model.
1093
-
1094
- Args:
1095
- pixel_values_videos: Pre-processed patches from Qwen2VL processor.
1096
- `torch.FloatTensor` of shape `(total_patches, num_channels, patch_size, patch_size)`
1097
- video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1098
- The temporal, height and width of feature shape of each video in LLM.
1099
- """
1100
- # Convert to correct dtype
1101
- pixel_values_videos = pixel_values_videos.type(self.visual.embeddings.patch_embedding.weight.dtype)
1102
-
1103
- # Forward through vision model with grid_thw
1104
- vision_output = self.visual(pixel_values_videos, grid_thw=video_grid_thw, patch_positions=patch_positions)
1105
-
1106
- # Extract the actual tensor from BaseModelOutputWithPooling
1107
- if hasattr(vision_output, "last_hidden_state"):
1108
- video_embeds = vision_output.last_hidden_state
1109
- else:
1110
- video_embeds = vision_output[0] # Fallback for tuple output
1111
-
1112
- # Compute split sizes from video_grid_thw or from input shape
1113
- if video_grid_thw is not None:
1114
- split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1115
- else:
1116
- # Compute from input shape
1117
- batch_size = pixel_values_videos.shape[0]
1118
- split_sizes = [video_embeds.shape[1]] * batch_size
1119
-
1120
- # Split embeddings per video
1121
- if len(split_sizes) > 1:
1122
- video_embeds = torch.split(video_embeds.view(-1, video_embeds.shape[-1]), split_sizes)
1123
- else:
1124
- video_embeds = [video_embeds.view(-1, video_embeds.shape[-1])]
1125
-
1126
- return video_embeds
1127
-
1128
- def get_image_features(self, pixel_values, image_grid_thw: Optional[torch.LongTensor] = None, patch_positions=None):
1129
- """
1130
- Encodes images into continuous embeddings that can be forwarded to the language model.
1131
-
1132
- Args:
1133
- pixel_values: Pre-processed patches from Qwen2VL processor.
1134
- - `torch.FloatTensor` of shape `(total_patches, num_channels, patch_size, patch_size)`
1135
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1136
- The temporal, height and width of feature shape of each image in LLM.
1137
- """
1138
- # Standard format from Qwen2VL processor
1139
- if pixel_values.dim() == 2:
1140
- # Convert to correct dtype
1141
- pixel_values = pixel_values.type(self.visual.embeddings.patch_embedding.weight.dtype)
1142
-
1143
- # Forward through vision model with grid_thw
1144
- vision_output = self.visual(pixel_values, grid_thw=image_grid_thw, patch_positions=patch_positions)
1145
-
1146
- # Extract the actual tensor from BaseModelOutputWithPooling
1147
- if hasattr(vision_output, "last_hidden_state"):
1148
- image_embeds = vision_output.last_hidden_state
1149
- else:
1150
- image_embeds = vision_output[0]
1151
-
1152
- # Compute split sizes from grid_thw
1153
- if image_grid_thw is not None:
1154
- split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1155
- else:
1156
- # Fallback: assume single image
1157
- split_sizes = [image_embeds.shape[0] if image_embeds.dim() == 2 else image_embeds.shape[1]]
1158
-
1159
- # Split embeddings per image
1160
- image_embeds_flat = image_embeds.view(-1, image_embeds.shape[-1])
1161
- if len(split_sizes) > 1:
1162
- image_embeds = list(torch.split(image_embeds_flat, split_sizes))
1163
- else:
1164
- image_embeds = [image_embeds_flat]
1165
-
1166
- return image_embeds
1167
- else:
1168
- raise ValueError(
1169
- f"Unsupported pixel_values shape: expected 4D tensor [total_patches, C, H, W], "
1170
- f"got {pixel_values.shape if hasattr(pixel_values, 'shape') else type(pixel_values)}"
1171
- )
1172
-
1173
- def get_placeholder_mask(
1174
- self,
1175
- input_ids: torch.LongTensor,
1176
- inputs_embeds: torch.FloatTensor,
1177
- image_features: Optional[torch.FloatTensor] = None,
1178
- video_features: Optional[torch.FloatTensor] = None,
1179
- ):
1180
- """
1181
- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1182
- equal to the length of multimodal features. If the lengths are different, an error is raised.
1183
- """
1184
- if input_ids is None:
1185
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
1186
- torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1187
- )
1188
- special_image_mask = special_image_mask.all(-1)
1189
- special_video_mask = inputs_embeds == self.get_input_embeddings()(
1190
- torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
1191
- )
1192
- special_video_mask = special_video_mask.all(-1)
1193
- else:
1194
- special_image_mask = input_ids == self.config.image_token_id
1195
- special_video_mask = input_ids == self.config.video_token_id
1196
-
1197
- n_image_tokens = special_image_mask.sum()
1198
- special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1199
- if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
1200
- raise ValueError(
1201
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
1202
- )
1203
-
1204
- n_video_tokens = special_video_mask.sum()
1205
- special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1206
- if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
1207
- raise ValueError(
1208
- f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
1209
- )
1210
-
1211
- return special_image_mask, special_video_mask
1212
-
1213
- @auto_docstring
1214
- def forward(
1215
- self,
1216
- input_ids: Optional[torch.LongTensor] = None,
1217
- attention_mask: Optional[torch.Tensor] = None,
1218
- position_ids: Optional[torch.LongTensor] = None,
1219
- past_key_values: Optional[Cache] = None,
1220
- inputs_embeds: Optional[torch.FloatTensor] = None,
1221
- use_cache: Optional[bool] = None,
1222
- output_attentions: Optional[bool] = None,
1223
- output_hidden_states: Optional[bool] = None,
1224
- return_dict: Optional[bool] = None,
1225
- pixel_values: Optional[torch.Tensor] = None,
1226
- pixel_values_videos: Optional[torch.FloatTensor] = None,
1227
- image_grid_thw: Optional[torch.LongTensor] = None,
1228
- patch_positions: Optional[torch.LongTensor] = None,
1229
- video_grid_thw: Optional[torch.LongTensor] = None,
1230
- cache_position: Optional[torch.LongTensor] = None,
1231
- second_per_grid_ts: Optional[torch.Tensor] = None,
1232
- **kwargs: Unpack[TransformersKwargs],
1233
- ) -> Union[tuple, LlavaOnevision2ModelOutputWithPast]:
1234
- r"""
1235
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1236
- The temporal, height and width of feature shape of each image in LLM.
1237
- video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1238
- The temporal, height and width of feature shape of each video in LLM.
1239
- second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
1240
- The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
1241
- """
1242
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1243
- output_hidden_states = (
1244
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1245
- )
1246
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1247
-
1248
- if inputs_embeds is None:
1249
- inputs_embeds = self.get_input_embeddings()(input_ids)
1250
-
1251
- image_embeds = None
1252
-
1253
- if pixel_values is not None:
1254
- image_embeds = self.get_image_features(pixel_values, image_grid_thw, patch_positions=patch_positions)
1255
-
1256
- if image_embeds is not None:
1257
- image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1258
- image_mask, _ = self.get_placeholder_mask(
1259
- input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1260
- )
1261
- inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1262
-
1263
- if pixel_values_videos is not None:
1264
- video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
1265
- video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1266
- _, video_mask = self.get_placeholder_mask(
1267
- input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1268
- )
1269
- inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1270
-
1271
- # Use simple 1D position_ids
1272
- if position_ids is None:
1273
- batch_size, seq_length, _ = inputs_embeds.shape
1274
- if attention_mask is not None:
1275
- position_ids = attention_mask.long().cumsum(-1) - 1
1276
- position_ids.masked_fill_(attention_mask == 0, 1)
1277
- else:
1278
- position_ids = (
1279
- torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
1280
- )
1281
-
1282
- # Handle cache_position for generation
1283
- if cache_position is not None and cache_position[0] != 0:
1284
- position_ids = position_ids + cache_position[0]
1285
-
1286
- outputs = self.language_model(
1287
- input_ids=None,
1288
- position_ids=position_ids,
1289
- attention_mask=attention_mask,
1290
- past_key_values=past_key_values,
1291
- inputs_embeds=inputs_embeds,
1292
- use_cache=use_cache,
1293
- output_attentions=output_attentions,
1294
- output_hidden_states=output_hidden_states,
1295
- return_dict=True,
1296
- cache_position=cache_position,
1297
- **kwargs,
1298
- )
1299
-
1300
- output = LlavaOnevision2ModelOutputWithPast(
1301
- last_hidden_state=outputs.last_hidden_state,
1302
- past_key_values=outputs.past_key_values,
1303
- hidden_states=outputs.hidden_states,
1304
- attentions=outputs.attentions,
1305
- )
1306
- return output if return_dict else output.to_tuple()
1307
-
1308
-
1309
- @auto_docstring
1310
- class LlavaOnevision2ForConditionalGeneration(LlavaOnevision2PreTrainedModel, GenerationMixin):
1311
- _checkpoint_conversion_mapping = {
1312
- "^visual": "model.visual",
1313
- r"^model(?!\.(language_model|visual))": "model.language_model",
1314
- }
1315
- _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
1316
- # Reference: fix gemma3 grad acc #37208
1317
- accepts_loss_kwargs = False
1318
-
1319
- def __init__(self, config):
1320
- super().__init__(config)
1321
- self.model = LlavaOnevision2Model(config)
1322
- self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1323
- self.post_init()
1324
-
1325
- def get_input_embeddings(self):
1326
- return self.model.get_input_embeddings()
1327
-
1328
- def set_input_embeddings(self, value):
1329
- self.model.set_input_embeddings(value)
1330
-
1331
- def set_decoder(self, decoder):
1332
- self.model.set_decoder(decoder)
1333
-
1334
- def get_decoder(self):
1335
- return self.model.get_decoder()
1336
-
1337
- def get_video_features(
1338
- self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
1339
- ):
1340
- return self.model.get_video_features(pixel_values_videos, video_grid_thw)
1341
-
1342
- def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
1343
- return self.model.get_image_features(pixel_values, image_grid_thw)
1344
-
1345
- # Make modules available through conditional class for BC
1346
- @property
1347
- def language_model(self):
1348
- return self.model.language_model
1349
-
1350
- @property
1351
- def visual(self):
1352
- return self.model.visual
1353
-
1354
- @can_return_tuple
1355
- @auto_docstring
1356
- def forward(
1357
- self,
1358
- input_ids: Optional[torch.LongTensor] = None,
1359
- attention_mask: Optional[torch.Tensor] = None,
1360
- position_ids: Optional[torch.LongTensor] = None,
1361
- past_key_values: Optional[Cache] = None,
1362
- inputs_embeds: Optional[torch.FloatTensor] = None,
1363
- labels: Optional[torch.LongTensor] = None,
1364
- use_cache: Optional[bool] = None,
1365
- output_attentions: Optional[bool] = None,
1366
- output_hidden_states: Optional[bool] = None,
1367
- pixel_values: Optional[torch.Tensor] = None,
1368
- pixel_values_videos: Optional[torch.FloatTensor] = None,
1369
- image_grid_thw: Optional[torch.LongTensor] = None,
1370
- patch_positions: Optional[torch.LongTensor] = None,
1371
- video_grid_thw: Optional[torch.LongTensor] = None,
1372
- cache_position: Optional[torch.LongTensor] = None,
1373
- second_per_grid_ts: Optional[torch.Tensor] = None,
1374
- logits_to_keep: Union[int, torch.Tensor] = 0,
1375
- **kwargs: Unpack[TransformersKwargs],
1376
- ) -> Union[tuple, LlavaOnevision2CausalLMOutputWithPast]:
1377
- r"""
1378
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1379
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1380
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1381
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1382
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1383
- The temporal, height and width of feature shape of each image in LLM.
1384
- video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1385
- The temporal, height and width of feature shape of each video in LLM.
1386
- second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
1387
- The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
1388
-
1389
- Example:
1390
-
1391
- ```python
1392
- >>> from PIL import Image
1393
- >>> import requests
1394
- >>> from transformers import AutoProcessor, LlavaOnevision2ForConditionalGeneration
1395
-
1396
- >>> model = LlavaOnevision2ForConditionalGeneration.from_pretrained("Deep-VLM/LLaVA-OneVision-1.5-8B-Instruct-hf", trust_remote_code=True)
1397
- >>> processor = AutoProcessor.from_pretrained("Deep-VLM/LLaVA-OneVision-1.5-8B-Instruct-hf", trust_remote_code=True)
1398
-
1399
- >>> messages = [
1400
- {
1401
- "role": "user",
1402
- "content": [
1403
- {"type": "image"},
1404
- {"type": "text", "text": "What is shown in this image?"},
1405
- ],
1406
- },
1407
- ]
1408
- >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1409
- >>> image = Image.open(requests.get(url, stream=True).raw)
1410
-
1411
- >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1412
- >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
1413
-
1414
- >>> # Generate
1415
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1416
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1417
- "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1418
- ```"""
1419
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1420
- output_hidden_states = (
1421
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1422
- )
1423
- outputs = self.model(
1424
- input_ids=input_ids,
1425
- pixel_values=pixel_values,
1426
- pixel_values_videos=pixel_values_videos,
1427
- image_grid_thw=image_grid_thw,
1428
- patch_positions=patch_positions,
1429
- video_grid_thw=video_grid_thw,
1430
- second_per_grid_ts=second_per_grid_ts,
1431
- position_ids=position_ids,
1432
- attention_mask=attention_mask,
1433
- past_key_values=past_key_values,
1434
- inputs_embeds=inputs_embeds,
1435
- use_cache=use_cache,
1436
- output_attentions=output_attentions,
1437
- output_hidden_states=output_hidden_states,
1438
- return_dict=True,
1439
- cache_position=cache_position,
1440
- **kwargs,
1441
- )
1442
-
1443
- hidden_states = outputs[0]
1444
-
1445
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1446
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1447
- logits = self.lm_head(hidden_states[:, slice_indices, :])
1448
-
1449
- loss = None
1450
- if labels is not None:
1451
- loss = self.loss_function(
1452
- logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
1453
- )
1454
-
1455
- return LlavaOnevision2CausalLMOutputWithPast(
1456
- loss=loss,
1457
- logits=logits,
1458
- past_key_values=outputs.past_key_values,
1459
- hidden_states=outputs.hidden_states,
1460
- attentions=outputs.attentions,
1461
- )
1462
-
1463
- def prepare_inputs_for_generation(
1464
- self,
1465
- input_ids,
1466
- past_key_values=None,
1467
- attention_mask=None,
1468
- inputs_embeds=None,
1469
- cache_position=None,
1470
- position_ids=None,
1471
- use_cache=True,
1472
- pixel_values=None,
1473
- pixel_values_videos=None,
1474
- image_grid_thw=None,
1475
- patch_positions=None,
1476
- video_grid_thw=None,
1477
- second_per_grid_ts=None,
1478
- **kwargs,
1479
- ):
1480
- # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1481
- model_inputs = super().prepare_inputs_for_generation(
1482
- input_ids,
1483
- past_key_values=past_key_values,
1484
- attention_mask=attention_mask,
1485
- inputs_embeds=inputs_embeds,
1486
- cache_position=cache_position,
1487
- position_ids=position_ids,
1488
- pixel_values=pixel_values,
1489
- pixel_values_videos=pixel_values_videos,
1490
- image_grid_thw=image_grid_thw,
1491
- video_grid_thw=video_grid_thw,
1492
- second_per_grid_ts=second_per_grid_ts,
1493
- patch_positions=patch_positions,
1494
- use_cache=use_cache,
1495
- **kwargs,
1496
- )
1497
-
1498
- if cache_position[0] != 0:
1499
- model_inputs["pixel_values"] = None
1500
- model_inputs["pixel_values_videos"] = None
1501
-
1502
- return model_inputs
1503
-
1504
- def _get_image_nums_and_video_nums(
1505
- self,
1506
- input_ids: Optional[torch.LongTensor],
1507
- inputs_embeds: Optional[torch.Tensor] = None,
1508
- ) -> tuple[torch.Tensor, torch.Tensor]:
1509
- """
1510
- Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
1511
- These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
1512
-
1513
- Args:
1514
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1515
- Indices of input sequence tokens in the vocabulary.
1516
-
1517
- Returns:
1518
- image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
1519
- video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
1520
- """
1521
- image_token_id = self.config.image_token_id
1522
- video_token_id = self.config.video_token_id
1523
- vision_start_token_id = self.config.vision_start_token_id
1524
-
1525
- if inputs_embeds is not None:
1526
- vision_start_mask = (
1527
- inputs_embeds
1528
- == self.get_input_embeddings()(
1529
- torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
1530
- )
1531
- )[..., 0]
1532
- image_mask = (
1533
- inputs_embeds
1534
- == self.get_input_embeddings()(
1535
- torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
1536
- )
1537
- )[..., 0]
1538
- video_mask = (
1539
- inputs_embeds
1540
- == self.get_input_embeddings()(
1541
- torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
1542
- )
1543
- )[..., 0]
1544
- else:
1545
- vision_start_mask = input_ids == vision_start_token_id
1546
- image_mask = input_ids == image_token_id
1547
- video_mask = input_ids == video_token_id
1548
-
1549
- vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
1550
- image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
1551
- video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
1552
-
1553
- return image_nums, video_nums
1554
-
1555
- def _expand_inputs_for_generation(
1556
- self,
1557
- expand_size: int = 1,
1558
- is_encoder_decoder: bool = False,
1559
- input_ids: Optional[torch.LongTensor] = None,
1560
- **model_kwargs,
1561
- ) -> tuple[torch.LongTensor, dict[str, Any]]:
1562
- # Overwritten -- Support for expanding tensors without a batch size dimension
1563
- # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
1564
- # pixel_values.shape[0] is sum(seqlen_images for samples)
1565
- # image_grid_thw.shape[0] is sum(num_images for samples)
1566
-
1567
- if expand_size == 1:
1568
- return input_ids, model_kwargs
1569
-
1570
- visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
1571
-
1572
- def _expand_dict_for_generation_visual(dict_to_expand):
1573
- image_grid_thw = model_kwargs.get("image_grid_thw", None)
1574
- video_grid_thw = model_kwargs.get("video_grid_thw", None)
1575
- image_nums, video_nums = self._get_image_nums_and_video_nums(
1576
- input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
1577
- )
1578
-
1579
- def _repeat_interleave_samples(x, lengths, repeat_times):
1580
- samples = torch.split(x, lengths)
1581
- repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1582
- result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1583
- return result
1584
-
1585
- for key in dict_to_expand:
1586
- if key == "pixel_values":
1587
- # split images into samples
1588
- samples = torch.split(image_grid_thw, list(image_nums))
1589
- # compute the sequence length of images for each sample
1590
- lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1591
- dict_to_expand[key] = _repeat_interleave_samples(
1592
- dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1593
- )
1594
- elif key == "image_grid_thw":
1595
- # get the num of images for each sample
1596
- lengths = list(image_nums)
1597
- dict_to_expand[key] = _repeat_interleave_samples(
1598
- dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1599
- )
1600
- elif key == "pixel_values_videos":
1601
- samples = torch.split(video_grid_thw, list(video_nums))
1602
- lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1603
- dict_to_expand[key] = _repeat_interleave_samples(
1604
- dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1605
- )
1606
- elif key == "video_grid_thw":
1607
- lengths = list(video_nums)
1608
- dict_to_expand[key] = _repeat_interleave_samples(
1609
- dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1610
- )
1611
- elif key == "second_per_grid_ts":
1612
- dict_to_expand[key] = _repeat_interleave_samples(
1613
- dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1614
- )
1615
- return dict_to_expand
1616
-
1617
- def _expand_dict_for_generation(dict_to_expand):
1618
- for key in dict_to_expand:
1619
- if (
1620
- key != "cache_position"
1621
- and dict_to_expand[key] is not None
1622
- and isinstance(dict_to_expand[key], torch.Tensor)
1623
- and key not in visual_keys
1624
- ):
1625
- dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1626
- return dict_to_expand
1627
-
1628
- model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1629
-
1630
- if input_ids is not None:
1631
- input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1632
-
1633
- model_kwargs = _expand_dict_for_generation(model_kwargs)
1634
-
1635
- if is_encoder_decoder:
1636
- if model_kwargs.get("encoder_outputs") is None:
1637
- raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1638
- model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1639
-
1640
- return input_ids, model_kwargs
1641
-
1642
-
1643
- __all__ = [
1644
- "LlavaOnevision2ForConditionalGeneration",
1645
- "LlavaOnevision2Model",
1646
- "LlavaOnevision2PreTrainedModel",
1647
- "LlavaOnevision2VisionPretrainedModel",
1648
- # Vision components
1649
- "VisionRotaryEmbedding",
1650
- "LlavaViTEmbeddings",
1651
- "LlavaViTFlashAttention2",
1652
- "LlavaViTEncoderLayer",
1653
- "LlavaViTEncoder",
1654
- "LlavaOnevision2VisionPatchMerger",
1655
- "Siglip2MultiheadAttentionPoolingHead",
1656
- ]
 
1
+ from typing import Optional, Tuple, Union
 
2
 
3
  import torch
4
  import torch.nn as nn
 
5
 
6
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
 
 
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.models.siglip.modeling_siglip import SiglipMLP
 
9
  from transformers.utils import (
10
+ add_start_docstrings,
11
+ add_start_docstrings_to_model_forward,
12
+ logging,
 
13
  replace_return_docstrings,
14
  )
15
 
16
+ from .configuration_onevision_encoder import OneVisionEncoderConfig
17
 
18
 
19
+ try:
20
  from flash_attn import flash_attn_func
21
 
22
+ _flash_attn_available = True
23
+ except ImportError:
24
+ _flash_attn_available = False
25
 
26
+ logger = logging.get_logger(__name__)
 
 
 
 
 
 
 
 
 
27
 
 
 
 
28
 
29
+ # ---------------------------------------------------------------------------
30
+ # Model Docstrings
31
+ # ---------------------------------------------------------------------------
 
32
 
33
+ ONEVISION_ENCODER_START_DOCSTRING = r"""
34
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
35
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
36
+ etc.)
37
 
38
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
39
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
40
+ and behavior.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ Parameters:
43
+ config ([`OneVisionEncoderConfig`]): Model configuration class with all the parameters of the model.
44
+ Initializing with a config file does not load the weights associated with the model, only the
45
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
46
+ """
47
+
48
+ ONEVISION_ENCODER_INPUTS_DOCSTRING = r"""
49
+ Args:
50
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_channels, num_frames, height, width)`):
51
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`].
52
+ visible_indices (`torch.Tensor`, *optional*):
53
+ Indices of visible patches for masking. Used in MAE-style pretraining or inference.
54
+ output_attentions (`bool`, *optional*):
55
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
56
+ tensors for more detail.
57
+ output_hidden_states (`bool`, *optional*):
58
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
59
+ more detail.
60
+ return_dict (`bool`, *optional*):
61
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
62
+ """
63
 
64
 
65
  # ---------------------------------------------------------------------------
66
+ # Helper Functions & Layers
67
  # ---------------------------------------------------------------------------
68
 
69
 
70
+ def get_norm_layer(config):
71
+ if config.layer_norm_type == "rms_norm":
72
+ return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
73
+ else:
74
+ return nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
75
+
76
+
77
+ def rotate_half(x):
78
+ """
79
+ Interleaved rotation to match Source model's implementation.
80
+ (x1, x2, x3, x4) -> (-x2, x1, -x4, x3)
81
+ """
82
+ x_even = x[..., ::2]
83
+ x_odd = x[..., 1::2]
84
+ return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
85
+
86
+
87
+ def apply_rotary_pos_emb(q, k, freqs):
88
+ # q, k: (B, H, L, D)
89
+ # freqs: (B, L, D)
90
+
91
+ # We need to broadcast freqs to match heads
92
+ # (B, L, D) -> (B, 1, L, D)
93
+
94
+ # !!! CRITICAL FIX: Cast cos/sin to q.dtype (bf16/fp16) immediately
95
+ # freqs are typically float32, so cos() returns float32.
96
+ # Without this cast, (q * cos) upcasts q to float32, causing FlashAttention to fail.
97
+ cos = freqs.cos().unsqueeze(1).to(q.dtype)
98
+ sin = freqs.sin().unsqueeze(1).to(q.dtype)
99
+
100
+ q_embed = (q * cos) + (rotate_half(q) * sin)
101
+ k_embed = (k * cos) + (rotate_half(k) * sin)
102
+ return q_embed, k_embed
103
+
104
+
105
+ class VideoRotaryEmbeddingSplit466(nn.Module):
106
  """
107
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
 
108
  """
109
 
110
+ def __init__(self, config: OneVisionEncoderConfig):
111
  super().__init__()
112
  head_dim = config.hidden_size // config.num_attention_heads
113
  base = config.rope_theta
 
120
  self.head_dim = head_dim
121
  self.half = half
122
 
 
123
  unit = half // 16
124
  self.t_size = 4 * unit
125
  self.h_size = 6 * unit
 
141
  persistent=False,
142
  )
143
 
144
+ def forward(self, t: int, h: int, w: int, device=None):
145
+ if device is None:
146
+ device = self.inv_freq_t.device
 
 
 
147
 
 
 
 
 
148
  inv_t = self.inv_freq_t.to(device=device)
149
  inv_h = self.inv_freq_h.to(device=device)
150
  inv_w = self.inv_freq_w.to(device=device)
151
 
152
+ ft = torch.outer(torch.arange(t, device=device, dtype=torch.float32), inv_t)
153
+ fh = torch.outer(torch.arange(h, device=device, dtype=torch.float32), inv_h)
154
+ fw = torch.outer(torch.arange(w, device=device, dtype=torch.float32), inv_w)
 
 
 
 
 
 
 
 
 
 
155
 
156
+ t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
157
+ h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
158
+ w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
159
 
160
+ freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
161
+ return freqs
162
 
163
  def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
164
  """
165
  Compute rotary position embeddings from explicit patch positions.
166
 
167
  Args:
168
+ patch_positions: [batch_size, seq_len, 3] tensor with [t, h, w] positions for each patch
169
 
170
  Returns:
171
+ freqs: [batch_size, seq_len, half] tensor of position frequencies
172
  """
173
  device = patch_positions.device
174
  inv_t = self.inv_freq_t.to(device=device)
175
  inv_h = self.inv_freq_h.to(device=device)
176
  inv_w = self.inv_freq_w.to(device=device)
177
 
178
+ t_pos = patch_positions[..., 0].float() # [batch_size, seq_len]
179
+ h_pos = patch_positions[..., 1].float() # [batch_size, seq_len]
180
+ w_pos = patch_positions[..., 2].float() # [batch_size, seq_len]
181
 
182
+ # Use einsum for batched outer product: [batch_size, seq_len] x [dim] -> [batch_size, seq_len, dim]
183
+ ft = torch.einsum("bs,d->bsd", t_pos, inv_t)
184
+ fh = torch.einsum("bs,d->bsd", h_pos, inv_h)
185
+ fw = torch.einsum("bs,d->bsd", w_pos, inv_w)
186
 
187
  return torch.cat([ft, fh, fw], dim=-1)
188
 
 
 
 
189
 
190
+ class Siglip2MultiheadAttentionPoolingHead(nn.Module):
191
+ """
192
+ Multi-Head Attention Pooling with a learned probe (PMA-style).
193
+ """
 
194
 
195
+ def __init__(self, config: OneVisionEncoderConfig):
196
+ super().__init__()
197
+ self.embed_dim = config.hidden_size
198
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
199
+ self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
200
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
201
+ self.mlp = SiglipMLP(config)
202
 
203
+ def forward(self, hidden_states):
204
+ batch_size = hidden_states.shape[0]
205
+ probe = self.probe.repeat(batch_size, 1, 1)
206
 
207
+ attn_output, _ = self.attention(probe, hidden_states, hidden_states)
 
 
208
 
209
+ residual = attn_output
210
+ attn_output = self.norm(attn_output)
211
+ attn_output = residual + self.mlp(attn_output)
212
 
213
+ return attn_output[:, 0]
 
214
 
215
 
216
  # ---------------------------------------------------------------------------
217
+ # Modeling Components
218
  # ---------------------------------------------------------------------------
219
 
220
 
221
+ class OneVisionEncoderEmbeddings(nn.Module):
222
+ def __init__(self, config: OneVisionEncoderConfig):
 
 
 
 
 
 
 
 
 
 
223
  super().__init__()
224
  self.config = config
225
  self.embed_dim = config.hidden_size
226
  self.image_size = config.image_size
227
  self.patch_size = config.patch_size
 
228
 
229
  self.patch_embedding = nn.Conv2d(
230
  in_channels=config.num_channels,
 
234
  bias=False,
235
  )
236
 
237
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
238
+ # Handle 4D (B, C, H, W) or 5D (B, C, T, H, W) inputs
239
+ if pixel_values.dim() == 4:
240
+ pixel_values = pixel_values.unsqueeze(2) # (B, C, 1, H, W)
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ batch_size, channels, t_frames, height, width = pixel_values.shape
 
 
243
 
244
+ # Merge time into batch for Conv2d
245
+ x_2d = pixel_values.permute(0, 2, 1, 3, 4).reshape(batch_size * t_frames, channels, height, width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
+ # Patch Embed
248
+ embeddings = self.patch_embedding(x_2d) # (B*T, C, Hp, Wp)
249
+ embeddings = embeddings.flatten(2).transpose(1, 2) # (B*T, L_frame, C)
250
 
251
+ # Flatten all patches
252
+ total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
253
+ embeddings = embeddings.reshape(batch_size, total_patches, self.embed_dim)
254
 
255
+ return embeddings
 
 
256
 
 
 
 
 
 
 
 
257
 
258
+ class OneVisionEncoderAttention(nn.Module):
259
+ """Multi-headed attention with RoPE support"""
260
 
261
+ def __init__(self, config: OneVisionEncoderConfig):
262
+ super().__init__()
263
+ self.config = config
264
+ self.embed_dim = config.hidden_size
265
+ self.num_heads = config.num_attention_heads
266
+ self.head_dim = self.embed_dim // self.num_heads
267
+ if self.head_dim * self.num_heads != self.embed_dim:
268
+ raise ValueError(
269
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  )
271
 
272
+ self.scale = self.head_dim**-0.5
273
+ self.dropout = config.attention_dropout
 
 
274
 
275
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
276
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
279
 
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ rotary_pos_emb: Optional[torch.Tensor] = None,
285
+ output_attentions: bool = False,
286
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
287
+ batch_size, q_len, _ = hidden_states.size()
288
 
289
+ query_states = self.q_proj(hidden_states)
290
+ key_states = self.k_proj(hidden_states)
291
+ value_states = self.v_proj(hidden_states)
 
 
 
292
 
293
+ # (B, L, H, D) -> Transpose to (B, H, L, D)
294
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
295
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
296
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
297
 
298
+ if rotary_pos_emb is not None:
299
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, rotary_pos_emb)
300
 
301
+ # Calculate attention scores
302
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
303
 
304
+ if attention_mask is not None:
305
+ if attention_mask.size() != (batch_size, 1, q_len, q_len):
306
+ if attention_mask.dim() == 3:
307
+ attention_mask = attention_mask.unsqueeze(1)
308
+ attn_weights = attn_weights + attention_mask
 
309
 
310
+ # FIX: Remove dtype=torch.float32 to stay in original dtype (bf16/fp16)
311
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
312
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
313
 
314
+ attn_output = torch.matmul(attn_weights, value_states)
 
 
 
 
 
315
 
316
+ attn_output = attn_output.transpose(1, 2).contiguous()
317
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
318
 
319
+ attn_output = self.out_proj(attn_output)
 
320
 
321
+ return attn_output, attn_weights if output_attentions else None
 
 
322
 
 
323
 
324
+ class OneVisionEncoderFlashAttention2(nn.Module):
325
  """
326
  Multi-headed attention with RoPE support using Flash Attention 2.
327
+ This module implements the same attention mechanism as OneVisionEncoderAttention but uses
328
+ Flash Attention for improved performance and memory efficiency.
329
  """
330
 
331
+ def __init__(self, config: OneVisionEncoderConfig):
332
  super().__init__()
333
  self.config = config
334
  self.embed_dim = config.hidden_size
 
341
 
342
  self.scale = self.head_dim**-0.5
343
  self.dropout = config.attention_dropout
344
+
345
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
346
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
347
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
348
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
349
 
350
  def forward(
351
  self,
 
353
  attention_mask: Optional[torch.Tensor] = None,
354
  rotary_pos_emb: Optional[torch.Tensor] = None,
355
  output_attentions: bool = False,
356
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
357
  """
358
  Forward pass using Flash Attention 2.
359
  """
360
  batch_size, q_len, _ = hidden_states.size()
361
+
362
+ query_states = self.q_proj(hidden_states)
363
+ key_states = self.k_proj(hidden_states)
364
+ value_states = self.v_proj(hidden_states)
 
 
365
 
366
  # Flash Attention requires (B, L, H, D) format
367
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
368
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
369
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
370
 
371
  # Apply RoPE if provided
372
  if rotary_pos_emb is not None:
 
379
  query_states = query_states.transpose(1, 2)
380
  key_states = key_states.transpose(1, 2)
381
 
 
 
 
382
  # Flash Attention forward pass
383
+ if not _flash_attn_available:
384
+ raise ImportError("flash_attn is not installed. Please install it to use OneVisionEncoderFlashAttention2.")
385
+
386
  attn_output = flash_attn_func(
387
  query_states,
388
  key_states,
 
396
  attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
397
 
398
  # No extra casting here.
399
+ attn_output = self.out_proj(attn_output)
 
400
 
401
  return attn_output, None
402
 
403
 
404
+ ONEVISION_ENCODER_ATTENTION_CLASSES = {
405
+ "eager": OneVisionEncoderAttention,
406
+ "flash_attention_2": OneVisionEncoderFlashAttention2,
407
+ }
408
 
409
+
410
+ class OneVisionEncoderEncoderLayer(nn.Module):
411
+ def __init__(self, config: OneVisionEncoderConfig):
412
  super().__init__()
413
  self.embed_dim = config.hidden_size
414
+ # Get attention implementation from config, default to "flash_attention_2"
415
+ attn_implementation = getattr(config, "_attn_implementation", "flash_attention_2")
416
+ if attn_implementation not in ONEVISION_ENCODER_ATTENTION_CLASSES:
417
+ # Fallback to eager if flash_attention_2 is not available
418
+ if not _flash_attn_available and attn_implementation == "flash_attention_2":
419
+ attn_implementation = "eager"
420
+ else:
421
+ raise ValueError(
422
+ f"Unknown attention implementation: {attn_implementation}. "
423
+ f"Available implementations: {list(ONEVISION_ENCODER_ATTENTION_CLASSES.keys())}"
424
+ )
425
+ self.self_attn = ONEVISION_ENCODER_ATTENTION_CLASSES[attn_implementation](config)
426
  self.layer_norm1 = get_norm_layer(config)
427
  self.mlp = SiglipMLP(config)
428
  self.layer_norm2 = get_norm_layer(config)
 
433
  attention_mask: Optional[torch.Tensor] = None,
434
  rotary_pos_emb: Optional[torch.Tensor] = None,
435
  output_attentions: bool = False,
436
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
437
  residual = hidden_states
438
  hidden_states = self.layer_norm1(hidden_states)
439
 
 
454
  return outputs
455
 
456
 
457
+ class OneVisionEncoderEncoder(nn.Module):
458
+ def __init__(self, config: OneVisionEncoderConfig):
459
  super().__init__()
460
  self.config = config
461
+ self.layers = nn.ModuleList([OneVisionEncoderEncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
462
 
463
  def forward(
464
  self,
 
476
  if output_hidden_states:
477
  all_hidden_states = all_hidden_states + (hidden_states,)
478
 
479
+ layer_outputs = layer(
480
+ hidden_states,
481
+ attention_mask=attention_mask,
482
+ rotary_pos_emb=rotary_pos_emb,
483
+ output_attentions=output_attentions,
484
+ )
 
 
 
 
 
 
 
 
 
485
 
486
  hidden_states = layer_outputs[0]
487
 
 
500
  attentions=all_self_attentions,
501
  )
502
 
 
 
 
 
 
 
 
 
503
 
504
+ # ---------------------------------------------------------------------------
505
+ # Main Models
506
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
 
509
+ @add_start_docstrings(
510
+ "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
511
+ ONEVISION_ENCODER_START_DOCSTRING,
512
+ )
513
+ class OneVisionEncoderPreTrainedModel(PreTrainedModel):
514
+ config_class = OneVisionEncoderConfig
515
+ base_model_prefix = "onevision_encoder"
516
  supports_gradient_checkpointing = True
517
+ _no_split_modules = ["OneVisionEncoderEncoderLayer"]
518
  _supports_flash_attn_2 = True
 
 
 
 
519
 
520
  def _init_weights(self, module):
521
+ """Initialize the weights"""
522
+ std = self.config.initializer_range
523
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
524
+ module.weight.data.normal_(mean=0.0, std=std)
525
+ if module.bias is not None:
526
+ module.bias.data.zero_()
527
+ elif isinstance(module, nn.Embedding):
528
+ module.weight.data.normal_(mean=0.0, std=std)
529
+ if module.padding_idx is not None:
530
+ module.weight.data[module.padding_idx].zero_()
531
+ elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
532
+ # Fix: RMSNorm doesn't have bias, must check hasattr first
533
+ module.weight.data.fill_(1.0)
534
+ if hasattr(module, "bias") and module.bias is not None:
535
+ module.bias.data.zero_()
536
+
537
+
538
+ @add_start_docstrings(
539
+ "OneVision Encoder Model with a vision transformer encoder.",
540
+ ONEVISION_ENCODER_START_DOCSTRING,
541
+ )
542
+ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
543
+ def __init__(self, config: OneVisionEncoderConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  super().__init__(config)
545
  self.config = config
 
546
 
547
+ self.embeddings = OneVisionEncoderEmbeddings(config)
 
548
  self.layernorm_pre = get_norm_layer(config)
549
+ self.encoder = OneVisionEncoderEncoder(config)
550
+ self.video_rope = VideoRotaryEmbeddingSplit466(config)
551
 
552
  if config.use_head:
553
  self.layernorm_post = get_norm_layer(config)
 
556
  self.layernorm_post = None
557
  self.head = None
558
 
 
 
 
 
 
 
 
559
  self.post_init()
560
 
561
+ @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
562
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
563
  def forward(
564
  self,
565
+ pixel_values: torch.Tensor,
566
+ visible_indices: Optional[torch.Tensor] = None,
567
  patch_positions: Optional[torch.Tensor] = None,
568
  output_attentions: Optional[bool] = None,
569
  output_hidden_states: Optional[bool] = None,
570
  return_dict: Optional[bool] = None,
 
571
  ) -> Union[tuple, BaseModelOutputWithPooling]:
572
  r"""
573
+ Returns:
574
 
575
+ Examples:
 
 
576
 
577
+ ```python
578
+ >>> from transformers import AutoModel, AutoImageProcessor
579
+ >>> from PIL import Image
 
 
 
 
 
 
 
580
 
581
+ >>> model = AutoModel.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
582
+ >>> preprocessor = AutoImageProcessor.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
583
+ >>> image = Image.open("path/to/your/image.jpg") # Replace with your image path
584
+ >>> pixel_values = preprocessor(images=image, return_tensors="pt")["pixel_values"]
585
+ >>> outputs = model(pixel_values)
586
+ >>> last_hidden_states = outputs.last_hidden_state
587
+ >>> pooled_output = outputs.pooler_output
588
+ ```
589
  """
590
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
591
  output_hidden_states = (
592
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
593
  )
594
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
595
+
596
+ # Determine video dimensions for RoPE
597
+ # Note: pixel_values passed to embeddings can be 4D or 5D
598
+ if pixel_values.dim() == 5:
599
+ # Use config.rope_temporal_size if set, otherwise use actual frame count
600
+ t_frames = (
601
+ self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
602
+ )
603
+ height = pixel_values.shape[3]
604
+ width = pixel_values.shape[4]
605
+ else:
606
+ t_frames = 1
607
+ height = pixel_values.shape[2]
608
+ width = pixel_values.shape[3]
609
 
610
  # 1. Embeddings
611
+ hidden_states = self.embeddings(pixel_values)
 
 
 
612
  batch_size, total_patches, _ = hidden_states.shape
613
 
614
+ # 2. Visible Indices Handling
615
+ if visible_indices is None:
616
+ visible_indices = (
617
+ torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
618
+ )
619
+
620
+ # 3. RoPE Construction
621
+ if patch_positions is not None:
622
+ freqs_visible = self.video_rope.forward_from_positions(patch_positions)
623
  else:
624
+ freqs_full = self.video_rope(
625
+ t=t_frames,
626
+ h=height // self.config.patch_size,
627
+ w=width // self.config.patch_size,
628
+ device=pixel_values.device,
629
+ )
630
+ freqs_visible = freqs_full[visible_indices]
 
 
 
 
 
 
 
 
631
 
632
  # Concatenate D/2 + D/2 -> D for applying rope
633
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
 
 
634
 
635
+ # 4. Pre-Norm & Encoder
636
  hidden_states = self.layernorm_pre(hidden_states)
637
 
638
+ # fix: gather hidden_states to match freqs_visible when using sparse visible_indices
639
+ num_visible = visible_indices.shape[1]
640
+ if num_visible != total_patches:
641
+ # sparse mode: select only visible patches
642
+ hidden_states = hidden_states.gather(
643
+ 1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
644
+ )
645
+
646
  encoder_outputs = self.encoder(
647
  hidden_states,
648
  attention_mask=None,
649
  rotary_pos_emb=freqs_visible,
650
  output_attentions=output_attentions,
651
+ output_hidden_states=output_hidden_states,
652
+ return_dict=return_dict,
653
  )
654
 
655
+ sequence_output = encoder_outputs[0]
 
 
 
 
656
 
657
+ # Apply post-norm if configured
658
  if self.layernorm_post is not None:
659
  sequence_output = self.layernorm_post(sequence_output)
660
 
661
+ # 5. Pooling Head
662
+ pooled_output = None
663
+ if self.head is not None:
664
+ pooled_output = self.head(sequence_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
 
666
  if not return_dict:
667
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
668
 
669
  return BaseModelOutputWithPooling(
670
+ last_hidden_state=sequence_output,
671
+ pooler_output=pooled_output,
672
+ hidden_states=encoder_outputs.hidden_states,
673
+ attentions=encoder_outputs.attentions,
674
+ )