Update code for transformers 5.5.4

#4
by sjzhou - opened
Files changed (1)
  1. modeling_moss_vl.py +268 -1317
modeling_moss_vl.py CHANGED
@@ -14,21 +14,18 @@
14
  # limitations under the License.
15
  """PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
16
 
17
- import copy
18
  from dataclasses import dataclass
19
- import queue
20
- import threading
21
- from typing import Any, Callable, Dict, Optional, Union, Tuple, List
22
 
23
  import torch
24
  import torch.nn as nn
25
  import torch.nn.functional as F
26
 
 
 
27
  from transformers.activations import ACT2FN
28
  from transformers.cache_utils import Cache, DynamicCache
29
  from transformers.generation import GenerationMixin
30
- from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
31
- from transformers.generation.streamers import TextIteratorStreamer
32
  from transformers.integrations import use_kernel_forward_from_hub
33
  from transformers.masking_utils import create_causal_mask
34
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
@@ -39,7 +36,8 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
  from transformers.processing_utils import Unpack
40
  from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
41
  from transformers.utils.deprecation import deprecate_kwarg
42
- from transformers.utils.generic import OutputRecorder
 
43
 
44
  from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionConfig
45
 
@@ -47,58 +45,6 @@ from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionC
47
 
48
  logger = logging.get_logger(__name__)
49
 
50
- _OFFLINE_SYSTEM_PROMPTS = {
51
- "no_thinking": {
52
- "text_image": "You are a helpful AI assistant. Respond to the user's request based on the provided text and/or images.",
53
- "video": "You are a helpful AI assistant specializing in video analysis. Respond to the user's request based on the provided video content.",
54
- },
55
- "deep_thinking": {
56
- "text_image": "A conversation between User and Assistant. The user makes a request, and the assistant responds to it based on the provided text and/or images. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
57
- "video": "A conversation between User and Assistant specializing in video analysis. The user makes a request, and the assistant responds to it based on the provided video content. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
58
- },
59
- }
60
-
61
-
62
- class _OfflineCancelStoppingCriteria(StoppingCriteria):
63
- def __init__(self, cancel_event: threading.Event):
64
- self.cancel_event = cancel_event
65
-
66
- def __call__(self, input_ids, scores, **kwargs) -> bool:
67
- return self.cancel_event.is_set()
68
-
69
-
70
- class _OfflineQueueStreamer(TextIteratorStreamer):
71
- def __init__(self, tokenizer, output_text_queue: "queue.Queue[str]"):
72
- super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
73
- self.output_text_queue = output_text_queue
74
- self.collected_chunks: List[str] = []
75
-
76
- def on_finalized_text(self, text: str, stream_end: bool = False):
77
- if text:
78
- self.collected_chunks.append(text)
79
- self.output_text_queue.put(text)
80
- super().on_finalized_text(text, stream_end=stream_end)
81
-
82
-
83
- _OFFLINE_THINKING_MODE_ALIASES = {
84
- "no_thinking": "no_thinking",
85
- "default": "no_thinking",
86
- "standard": "no_thinking",
87
- "deep_thinking": "deep_thinking",
88
- "thinking": "deep_thinking",
89
- "reasoning": "deep_thinking",
90
- }
91
-
92
- _OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES = {
93
- "text_image": "text_image",
94
- "text-image": "text_image",
95
- "image_text": "text_image",
96
- "image-text": "text_image",
97
- "text": "text_image",
98
- "image": "text_image",
99
- "video": "video",
100
- }
101
-
102
 
103
  @dataclass
104
  class MossVLModelOutputWithPast(ModelOutput):
@@ -198,13 +144,21 @@ class MossVLVisionPatchEmbed(nn.Module):
198
 
199
 
200
  class MossVLVisionRotaryEmbedding(nn.Module):
201
- inv_freq: torch.Tensor
202
 
203
  def __init__(self, dim: int, theta: float = 10000.0) -> None:
204
  super().__init__()
205
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
 
 
 
 
 
206
  self.register_buffer("inv_freq", inv_freq, persistent=False)
207
 
 
 
 
208
  def forward(self, seqlen: int) -> torch.Tensor:
209
  seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
210
  freqs = torch.outer(seq, self.inv_freq)
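Reviewer note: the frequency table built here follows the standard RoPE recipe (the `inv_freq` formula on the removed line above is the usual one). A minimal standalone sketch of the same computation:

```python
import torch

def rope_freq_table(seqlen: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # One inverse frequency per pair of channels, as in the removed inv_freq line.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    seq = torch.arange(seqlen, dtype=inv_freq.dtype)
    # Outer product -> (seqlen, dim // 2) table of rotation angles.
    return torch.outer(seq, inv_freq)

print(rope_freq_table(16, 64).shape)  # torch.Size([16, 32])
```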
@@ -233,11 +187,16 @@ class MossVLVisionPatchMerger(nn.Module):
233
  self.act_fn = nn.GELU()
234
  self.linear_fc2 = nn.Linear(self.input_hidden_size, config.out_hidden_size)
235
 
236
- def forward(self, last_hidden_state: torch.Tensor, deepstack_features: List[torch.Tensor] = []) -> torch.Tensor:
 
 
 
 
237
  # 1. Collect all features: [last_hidden_state, deepstack_1, deepstack_2, ...]
238
  # self.norms[0] corresponds to last_hidden_state
239
  # self.norms[1:] corresponds to deepstack_features
240
-
 
241
  all_inputs = [last_hidden_state] + deepstack_features
242
 
243
  # 2. Apply Norm independently
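For context, the merger pattern described by these comments appears to be: normalize each feature stream separately, concatenate along the channel dimension, then project through fc1 → GELU → fc2. A minimal sketch under that assumption (layer widths below are illustrative, not the model's actual ones):

```python
import torch
import torch.nn as nn

class PatchMergerSketch(nn.Module):
    def __init__(self, hidden_size: int, out_hidden_size: int, num_deepstack: int):
        super().__init__()
        num_streams = 1 + num_deepstack               # last_hidden_state + deepstack features
        self.norms = nn.ModuleList(nn.LayerNorm(hidden_size) for _ in range(num_streams))
        input_hidden_size = hidden_size * num_streams
        self.linear_fc1 = nn.Linear(input_hidden_size, input_hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(input_hidden_size, out_hidden_size)

    def forward(self, last_hidden_state, deepstack_features):
        all_inputs = [last_hidden_state] + list(deepstack_features)
        normed = [norm(x) for norm, x in zip(self.norms, all_inputs)]   # norm each stream
        merged = torch.cat(normed, dim=-1)                              # concat channels
        return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))

out = PatchMergerSketch(8, 16, num_deepstack=2)(torch.randn(4, 8), [torch.randn(4, 8)] * 2)
print(out.shape)  # torch.Size([4, 16])
```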
@@ -346,11 +305,11 @@ class MossVLVisionAttention(nn.Module):
346
  key_states = key_states.transpose(0, 1).unsqueeze(0)
347
  value_states = value_states.transpose(0, 1).unsqueeze(0)
348
 
349
- attention_interface: Callable = eager_attention_forward
350
- if self.config._attn_implementation != "eager":
351
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
352
 
353
- if self.config._attn_implementation == "flash_attention_2":
354
  max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
355
  attn_output, _ = attention_interface(
356
  self,
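A small worked example of the `cu_seqlens` bookkeeping used above (packed sequences, flash-attention style):

```python
import torch

# cu_seqlens holds cumulative boundaries of packed sequences:
# three sequences of lengths 4, 7 and 2 packed into one batch of 13 tokens.
cu_seqlens = torch.tensor([0, 4, 11, 13])
seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([4, 7, 2])
max_seqlen = seq_lengths.max()                   # tensor(7), passed to the attention kernel
print(seq_lengths.tolist(), int(max_seqlen))     # [4, 7, 2] 7
```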
@@ -429,26 +388,44 @@ class MossVLTextRotaryEmbedding(nn.Module):
429
 
430
  def __init__(self, config: MossVLTextConfig, device=None):
431
  super().__init__()
432
- # BC: "rope_type" was originally "type"
433
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
434
- self.rope_type = config.rope_scaling.get("rope_type", "default")
435
- else:
436
- self.rope_type = "default"
437
  self.max_seq_len_cached = config.max_position_embeddings
438
  self.original_max_seq_len = config.max_position_embeddings
439
 
440
  self.config = config
441
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
442
 
443
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
 
 
 
 
 
444
  self.register_buffer("inv_freq", inv_freq, persistent=False)
445
- self.original_inv_freq = self.inv_freq
446
 
 
447
 
448
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
449
- self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
450
- else:
451
- self.mrope_section = [24, 20, 20]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
  def apply_interleaved_mrope(self, freqs, mrope_section):
454
  """Apply interleaved MRoPE to 3D rotary embeddings.
@@ -470,7 +447,6 @@ class MossVLTextRotaryEmbedding(nn.Module):
470
  @torch.no_grad()
471
  @dynamic_rope_update
472
  def forward(self, x, position_ids):
473
-
474
  if position_ids.ndim == 2:
475
  position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
476
 
@@ -571,12 +547,11 @@ class MossVLTextSelfAttention(nn.Module):
571
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
572
 
573
  if past_key_values is not None:
574
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
575
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
576
 
577
- attention_interface: Callable = eager_attention_forward
578
- if self.config._attn_implementation != "eager":
579
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
580
 
581
  attn_output, attn_weights = attention_interface(
582
  self,
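The `past_key_values.update(...)` call on the removed line above is the standard `DynamicCache` API from `transformers.cache_utils`. A minimal usage sketch (API as of recent transformers releases; shapes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = torch.randn(1, 2, 4, 8)   # (batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(1, 2, 4, 8)
cache.update(k, v, layer_idx=0)                       # prefill: 4 cached positions
k1, v1 = torch.randn(1, 2, 1, 8), torch.randn(1, 2, 1, 8)
k_all, v_all = cache.update(k1, v1, layer_idx=0)      # decode step appends one position
print(k_all.shape)                                    # torch.Size([1, 2, 5, 8])
```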
@@ -625,7 +600,7 @@ class MossVLTextCrossAttention(nn.Module):
625
  attention_mask: Optional[torch.Tensor] = None,
626
  past_key_values: Optional[Cache] = None,
627
  use_cache: bool = None,
628
- cache_position: Optional[torch.LongTensor] = None, # vision_cache_position
629
  query_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
630
  vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
631
  **kwargs,
@@ -659,9 +634,7 @@ class MossVLTextCrossAttention(nn.Module):
659
  if past_key_values is not None:
660
  # If we have a new image plus new tokens, key_states was only computed on that new image;
661
  # we still update the cached cross-attention key/value states (past image + new image) and use the full set.
662
- key_states, value_states = past_key_values.update(
663
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
664
- )
665
 
666
  elif cache_position[0] != 0:
667
  key_states, value_states = (
@@ -673,13 +646,13 @@ class MossVLTextCrossAttention(nn.Module):
673
  "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
674
  )
675
 
676
- attention_interface: Callable = eager_attention_forward
677
- if self.config._attn_implementation != "eager":
678
- # If flash attention 2/3 is selected, fall back to sdpa_attention_forward for cross attention
679
- if self.config._attn_implementation == "flash_attention_3" or self.config._attn_implementation == "flash_attention_2":
680
- attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
681
- else:
682
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
683
 
684
  attn_output, attn_weights = attention_interface(
685
  self,
@@ -740,14 +713,14 @@ class MossVLSelfAttentionDecoderLayer(GradientCheckpointingLayer):
740
  use_cache: Optional[bool] = False,
741
  cache_position: Optional[torch.LongTensor] = None,
742
  vision_position_ids: Optional[torch.LongTensor] = None,
743
- vision_cache_position: Optional[torch.LongTensor] = None,
744
  vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
 
745
  **kwargs: Unpack[TransformersKwargs],
746
- ) -> torch.Tensor:
747
  # Self Attention
748
  residual = hidden_states
749
  hidden_states = self.input_layernorm(hidden_states)
750
- hidden_states, _ = self.self_attn(
751
  hidden_states=hidden_states,
752
  attention_mask=attention_mask,
753
  past_key_values=past_key_values,
@@ -762,8 +735,11 @@ class MossVLSelfAttentionDecoderLayer(GradientCheckpointingLayer):
762
  hidden_states = self.post_attention_layernorm(hidden_states)
763
  hidden_states = self.mlp(hidden_states)
764
  hidden_states = residual + hidden_states
765
-
766
- return hidden_states
 
 
 
767
 
768
 
769
  class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
@@ -799,21 +775,21 @@ class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
799
  use_cache: Optional[bool] = False,
800
  cache_position: Optional[torch.LongTensor] = None,
801
  vision_position_ids: Optional[torch.LongTensor] = None,
802
- vision_cache_position: Optional[torch.LongTensor] = None,
803
  vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
 
804
  **kwargs: Unpack[TransformersKwargs],
805
- ) -> torch.Tensor:
806
  # Cross Attention
807
  residual = hidden_states
808
  hidden_states = self.input_layernorm(hidden_states)
809
 
810
- hidden_states, _ = self.cross_attn(
811
  hidden_states=hidden_states,
812
  cross_attention_states=cross_attention_states,
813
  attention_mask=cross_attention_mask,
814
  past_key_values=past_key_values,
815
  use_cache=use_cache,
816
- cache_position=vision_cache_position,
817
  query_position_embeddings=position_embeddings,
818
  vision_position_embeddings=vision_position_embeddings,
819
  )
@@ -830,8 +806,11 @@ class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
830
  hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
831
 
832
  hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
833
-
834
- return hidden_states
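Worth noting for reviewers: `cross_attn_mlp_gate.tanh()` is the usual gated cross-attention residual. Assuming the gate parameter is zero-initialized (an assumption, not shown in this diff), the cross-attention branch starts as a no-op and the model learns how much visual signal to mix in:

```python
import torch
import torch.nn as nn

gate = nn.Parameter(torch.zeros(1))              # assumed zero init
residual = torch.randn(2, 5, 8)
cross_branch = torch.randn(2, 5, 8)
hidden_states = residual + gate.tanh() * cross_branch
assert torch.allclose(hidden_states, residual)   # no-op until the gate moves away from 0
```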
 
 
 
835
 
836
 
837
 
@@ -857,32 +836,10 @@ class MossVLPreTrainedModel(PreTrainedModel):
857
 
858
  def _init_weights(self, module):
859
  """Initialize the weights.
860
-
861
- Note: For loading pretrained weights:
862
- - Cross attention: can be initialized from the previous layer's self attention weights
863
  """
864
- std = getattr(self.config, "initializer_range", 0.02)
865
- if hasattr(self.config, "text_config") and hasattr(self.config.text_config, "initializer_range"):
866
- std = self.config.text_config.initializer_range
867
-
868
- if isinstance(module, MossVLVisionPatchMerger):
869
- # Initialize merger weights
870
- # Input: hidden_size * (1 + num_deepstack_features) -> Output: out_hidden_size
871
- # This projection handles concatenated features, so we might want specific initialization
872
- module.linear_fc1.weight.data.normal_(mean=0.0, std=std)
873
- module.linear_fc2.weight.data.normal_(mean=0.0, std=std)
874
- if module.linear_fc1.bias is not None:
875
- module.linear_fc1.bias.data.zero_()
876
- if module.linear_fc2.bias is not None:
877
- module.linear_fc2.bias.data.zero_()
878
-
879
- # Initialize separate LayerNorms
880
- if hasattr(module, "norms"):
881
- for norm in module.norms:
882
- if hasattr(norm, "weight") and norm.weight is not None:
883
- norm.weight.data.fill_(1.0)
884
- if hasattr(norm, "bias") and norm.bias is not None:
885
- norm.bias.data.zero_()
886
 
887
 
888
 
@@ -958,13 +915,15 @@ class MossVLVisionModel(MossVLPreTrainedModel):
958
 
959
  def fast_pos_embed_interpolate(self, grid_thw):
960
  grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
 
 
961
 
962
- idx_list = [[] for _ in range(4)]
963
- weight_list = [[] for _ in range(4)]
964
 
965
  for t, h, w in zip(grid_ts, grid_hs, grid_ws):
966
- h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
967
- w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
968
 
969
  h_idxs_floor = h_idxs.int()
970
  w_idxs_floor = w_idxs.int()
@@ -992,13 +951,11 @@ class MossVLVisionModel(MossVLPreTrainedModel):
992
  ]
993
 
994
  for i in range(4):
995
- idx_list[i].extend(indices[i].tolist())
996
- weight_list[i].extend(weights[i].tolist())
997
 
998
- idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
999
- weight_tensor = torch.tensor(
1000
- weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
1001
- )
1002
  pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
1003
  patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
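The interpolation above is plain bilinear resampling of a fixed G x G grid of learned position embeddings onto an h x w patch grid; the four index/weight pairs come from the floor/ceil split. A small self-contained illustration of the weight construction:

```python
import torch

G, h, w = 4, 6, 5
h_idxs = torch.linspace(0, G - 1, h)
w_idxs = torch.linspace(0, G - 1, w)
h_floor, w_floor = h_idxs.int(), w_idxs.int()
h_ceil = (h_floor + 1).clamp(max=G - 1)          # pairs with h_floor for the 4 corners
w_ceil = (w_floor + 1).clamp(max=G - 1)
dh, dw = h_idxs - h_floor, w_idxs - w_floor
w00 = (1 - dh)[:, None] * (1 - dw)[None, :]      # weight for (h_floor, w_floor)
w01 = (1 - dh)[:, None] * dw[None, :]            # (h_floor, w_ceil)
w10 = dh[:, None] * (1 - dw)[None, :]            # (h_ceil, w_floor)
w11 = dh[:, None] * dw[None, :]                  # (h_ceil, w_ceil)
assert torch.allclose(w00 + w01 + w10 + w11, torch.ones(h, w))
```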
1004
 
@@ -1127,7 +1084,9 @@ class MossVLTextModel(MossVLPreTrainedModel):
1127
  vision_position_ids: Optional[torch.LongTensor] = None,
1128
  use_cache: Optional[bool] = None,
1129
  cache_position: Optional[torch.LongTensor] = None,
1130
- vision_cache_position: Optional[torch.LongTensor] = None,
 
 
1131
  **kwargs: Unpack[FlashAttentionKwargs],
1132
  ) -> Union[tuple, BaseModelOutputWithPast]:
1133
  """
@@ -1140,9 +1099,15 @@ class MossVLTextModel(MossVLPreTrainedModel):
1140
  Attention mask for cross-attention between text and vision. Shape: `(batch_size, 1, text_seq_len, vision_seq_len)`.
1141
  vision_position_ids (`torch.LongTensor`, *optional*):
1142
  Position IDs for vision tokens used in cross-attention. Shape: `(batch_size, vision_seq_len)`.
1143
- vision_cache_position (`torch.LongTensor`, *optional*):
1144
- Cache position for vision tokens. Shape: `(vision_seq_len,)`.
1145
  """
 
 
 
 
 
 
1146
  if (input_ids is None) ^ (inputs_embeds is not None):
1147
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1148
 
@@ -1164,7 +1129,7 @@ class MossVLTextModel(MossVLPreTrainedModel):
1164
 
1165
  attention_mask = create_causal_mask(
1166
  config=self.config,
1167
- input_embeds=inputs_embeds,
1168
  attention_mask=attention_mask,
1169
  cache_position=cache_position,
1170
  past_key_values=past_key_values,
@@ -1179,14 +1144,15 @@ class MossVLTextModel(MossVLPreTrainedModel):
1179
  # Compute vision position embeddings (for cross-attention key/value) if needed
1180
  vision_position_embeddings = None
1181
 
1182
- if vision_cache_position is None:
1183
- # TODO:use cache_position now
1184
- vision_cache_position = cache_position
1185
-
1186
  if cross_attention_states is not None:
1187
  if vision_position_ids is not None:
1188
  vision_position_embeddings = self.rotary_emb(cross_attention_states, vision_position_ids)
1189
 
 
 
 
 
 
1190
 
1191
  for idx, decoder_layer in enumerate(self.layers):
1192
  # For text-only path we should skip cross attention layers.
@@ -1211,17 +1177,35 @@ class MossVLTextModel(MossVLPreTrainedModel):
1211
  cross_attention_states=cross_attention_states,
1212
  cross_attention_mask=cross_attention_mask,
1213
  vision_position_ids=vision_position_ids,
1214
- vision_cache_position=vision_cache_position,
1215
  vision_position_embeddings=vision_position_embeddings,
 
1216
  **kwargs,
1217
  )
1218
- hidden_states = layer_outputs
 
 
 
 
 
 
1219
 
1220
  hidden_states = self.norm(hidden_states)
 
 
 
 
 
 
 
 
 
 
1221
 
1222
  return BaseModelOutputWithPast(
1223
  last_hidden_state=hidden_states,
1224
  past_key_values=past_key_values,
 
 
1225
  )
1226
 
1227
 
@@ -1240,8 +1224,6 @@ class MossVLModel(MossVLPreTrainedModel):
1240
  super().__init__(config)
1241
  self.visual = MossVLVisionModel._from_config(config.vision_config)
1242
  self.language_model = MossVLTextModel._from_config(config.text_config)
1243
- self.vision_token_info = None # cache vision_token_info here for decode stage
1244
- self.rope_deltas = None # cache position deltas for decode stage
1245
 
1246
  # Learnable Separator Token: inserted after each image/frame's vision tokens
1247
  # Initialized from LLM's separator_token_init_id embedding
@@ -1550,7 +1532,7 @@ class MossVLModel(MossVLPreTrainedModel):
1550
  continue
1551
 
1552
  # Collect repetition counts for all frames in this sample
1553
- repeats = []
1554
  for media in medias:
1555
  num_frames = media.get('num_frames', 1)
1556
  length = media['length']
@@ -1565,25 +1547,30 @@ class MossVLModel(MossVLPreTrainedModel):
1565
 
1566
  # In convert_packed_to_batch we enforce strictly regular frames
1567
  # so we can assume all frames have the same number of tokens
1568
- repeats.extend([tokens_per_frame_with_sep] * num_frames)
 
 
 
 
 
 
 
1569
 
1570
- num_valid_frames = len(repeats)
1571
  if num_valid_frames == 0:
1572
  continue
1573
 
1574
  # If cross_attention_mask has more frames (e.g. padded), slice it
1575
  # If it has fewer (shouldn't happen), slice repeats
1576
  valid_mask_frames = min(num_valid_frames, cross_attention_mask.shape[-1])
 
1577
  if valid_mask_frames < num_valid_frames:
1578
- repeats = repeats[:valid_mask_frames]
1579
 
1580
  # Extract valid columns for this sample
1581
  # (1, text_len, valid_mask_frames)
1582
  source_mask = cross_attention_mask[i, :, :, :valid_mask_frames]
1583
 
1584
- # Convert repeats to tensor
1585
- repeats_tensor = torch.tensor(repeats, device=cross_attention_mask.device)
1586
-
1587
  # Expand using repeat_interleave
1588
  # output shape: (1, text_len, sum(repeats))
1589
  expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1)
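A toy example of the frame-to-token expansion performed here: each frame-level mask column is repeated once per vision token belonging to that frame (tokens per frame plus the separator in the real code; 3 per frame below for brevity):

```python
import torch

source_mask = torch.tensor([[1, 0],
                            [1, 1]])                 # (text_len=2, num_frames=2)
repeats = torch.tensor([3, 3])                       # vision tokens per frame
expanded = source_mask.repeat_interleave(repeats, dim=-1)
print(expanded)
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1]])
```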
@@ -1602,7 +1589,8 @@ class MossVLModel(MossVLPreTrainedModel):
1602
  self,
1603
  input_ids: torch.Tensor,
1604
  attention_mask: Optional[torch.Tensor] = None,
1605
- cache_position: Optional[torch.LongTensor] = None,
 
1606
  ) -> torch.Tensor:
1607
  """
1608
  Compute 3D position IDs for text tokens with special handling for image tokens.
@@ -1617,7 +1605,7 @@ class MossVLModel(MossVLPreTrainedModel):
1617
  Args:
1618
  input_ids: (batch_size, seq_len)
1619
  attention_mask: (batch_size, seq_len), optional
1620
- cache_position: (seq_len,), position in cache
1621
 
1622
  Returns:
1623
  position_ids: (3, batch_size, seq_len)
@@ -1626,25 +1614,17 @@ class MossVLModel(MossVLPreTrainedModel):
1626
  device = input_ids.device
1627
  image_token_id = self.config.image_token_id
1628
 
1629
- # Decode stage: use cached rope_deltas for fast computation
1630
- if cache_position is not None and cache_position[0] != 0 and self.rope_deltas is not None:
1631
- # In decode, position = cache_position + rope_deltas
1632
- # rope_deltas is per-sample: (batch_size,)
1633
  position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
1634
- position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) # (batch, seq_len)
1635
-
1636
- # Add cache_position offset
1637
- if cache_position is not None:
1638
- position_ids = position_ids + cache_position[0]
1639
-
1640
- # Add rope_deltas (position offset due to vision tokens)
1641
- # self.rope_deltas shape: (batch_size,), need to unsqueeze for broadcasting
1642
- position_ids = position_ids + self.rope_deltas.unsqueeze(1) # (batch, seq_len)
1643
-
1644
- # Expand to 3D: (3, batch, seq_len)
1645
- position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1646
-
1647
- return position_ids
1648
 
1649
  # Prefill stage: compute full position_ids with image token awareness
1650
  # Vectorized implementation
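How the (now removed) decode shortcut used `rope_deltas`, in plain numbers: prefill records how far the 3D positions ran past the raw text length, and each decode step adds that per-sample offset to `cache_position`. The values below are made up for illustration:

```python
import torch

seq_len = 10                                              # prompt length in text tokens
max_position_after_prefill = torch.tensor([14, 11])       # per-sample max 3D position id
rope_deltas = max_position_after_prefill + 1 - seq_len    # tensor([5, 2])
cache_position = torch.tensor([10])                       # first decode step
decode_positions = cache_position + rope_deltas           # tensor([15, 12])
print(rope_deltas.tolist(), decode_positions.tolist())
```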
@@ -1723,7 +1703,7 @@ class MossVLModel(MossVLPreTrainedModel):
1723
  rope_deltas: (batch_size,) - position offset due to vision tokens
1724
  """
1725
  batch_size, max_vision_seq_len, _ = cross_attention_states.shape
1726
- device = position_ids.device if position_ids is not None else input_ids.device
1727
  image_token_id = self.config.image_token_id
1728
  merge_size = self.visual.spatial_merge_size
1729
 
@@ -1731,15 +1711,14 @@ class MossVLModel(MossVLPreTrainedModel):
1731
  # We need to flatten the nested vision_token_info structure to align with image tokens in input_ids
1732
 
1733
  # Find all image tokens in text: (num_occurrences, 2) -> [batch_idx, seq_idx]
1734
- image_token_indices = (input_ids == image_token_id).nonzero().to(device)
1735
 
1736
  # Flatten vision_token_info to parallel lists
1737
  # We assume the order of medias in vision_token_info matches the appearance of image tokens in input_ids
1738
- flat_eff_h = []
1739
- flat_eff_w = []
1740
- flat_vis_starts = []
1741
- flat_batch_indices = []
1742
-
1743
  # Processing metadata on CPU (fast enough for typical batch sizes)
1744
  for b_idx, info in enumerate(vision_token_info):
1745
  medias = info.get('medias', [])
@@ -1750,13 +1729,11 @@ class MossVLModel(MossVLPreTrainedModel):
1750
  start = media['start']
1751
  tok_per_frame = media['vision_tokens_per_frame']
1752
  stride = tok_per_frame + 1 # +1 for separator
1753
-
1754
- # Generate entries for all frames in this media
1755
- for f in range(num_frames):
1756
- flat_eff_h.append(eh)
1757
- flat_eff_w.append(ew)
1758
- flat_vis_starts.append(start + f * stride)
1759
- flat_batch_indices.append(b_idx)
1760
 
1761
  # Pre-allocate output
1762
  vision_pos_ids = torch.zeros(
@@ -1766,17 +1743,19 @@ class MossVLModel(MossVLPreTrainedModel):
1766
  )
1767
 
1768
  # Handle case where no image tokens or info
1769
- if len(flat_eff_h) == 0 or len(image_token_indices) == 0:
1770
  rope_deltas = position_ids.max(dim=0).values.max(dim=-1).values + 1 - input_ids.shape[1]
1771
  return vision_pos_ids, position_ids, rope_deltas
1772
 
 
 
 
 
1773
  # Align lengths (handle truncation if text has fewer tokens or vice versa)
1774
- num_matches = min(len(flat_eff_h), len(image_token_indices))
1775
-
1776
- # Convert to tensors
1777
- flat_eff_h = torch.tensor(flat_eff_h[:num_matches], device=device, dtype=torch.long)
1778
- flat_eff_w = torch.tensor(flat_eff_w[:num_matches], device=device, dtype=torch.long)
1779
- flat_vis_starts = torch.tensor(flat_vis_starts[:num_matches], device=device, dtype=torch.long)
1780
 
1781
  # Get corresponding text positions
1782
  target_indices = image_token_indices[:num_matches]
@@ -1942,53 +1921,6 @@ class MossVLModel(MossVLPreTrainedModel):
1942
  )
1943
  return vision_embeds, vision_token_info
1944
 
1945
- def get_vision_features_chunked(
1946
- self,
1947
- pixel_values: torch.FloatTensor,
1948
- grid_thw: Optional[torch.LongTensor] = None,
1949
- media_nums_per_sample: Optional[List[int]] = None,
1950
- vision_chunked_length: Optional[int] = None,
1951
- ):
1952
- """
1953
- Chunk the visual encoder forward by media items, then reuse the same
1954
- packed-to-batch conversion logic. This keeps output semantics identical
1955
- to `get_vision_features(...)` while reducing prefill memory pressure.
1956
- """
1957
- if (
1958
- vision_chunked_length is None
1959
- or vision_chunked_length <= 0
1960
- or grid_thw is None
1961
- or grid_thw.shape[0] <= vision_chunked_length
1962
- ):
1963
- return self.get_vision_features(pixel_values, grid_thw, media_nums_per_sample)
1964
-
1965
- pixel_values = pixel_values.type(self.visual.dtype)
1966
- token_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
1967
-
1968
- hidden_state_chunks = []
1969
- token_offset = 0
1970
- for media_start in range(0, grid_thw.shape[0], vision_chunked_length):
1971
- media_end = min(media_start + vision_chunked_length, grid_thw.shape[0])
1972
- chunk_grid_thw = grid_thw[media_start:media_end]
1973
- chunk_token_count = sum(token_counts[media_start:media_end])
1974
- chunk_pixel_values = pixel_values[token_offset:token_offset + chunk_token_count]
1975
- token_offset += chunk_token_count
1976
-
1977
- hidden_state_chunks.append(
1978
- self.visual(
1979
- chunk_pixel_values,
1980
- grid_thw=chunk_grid_thw,
1981
- )
1982
- )
1983
-
1984
- hidden_states = torch.cat(hidden_state_chunks, dim=0)
1985
- vision_embeds, vision_token_info = self.convert_packed_to_batch(
1986
- hidden_states,
1987
- grid_thw,
1988
- media_nums_per_sample,
1989
- )
1990
- return vision_embeds, vision_token_info
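For reference, the chunking removed here relied on `grid_thw` to know how many packed vision tokens each media item owns (t * h * w), which gives the slice boundaries into `pixel_values`; a small illustration:

```python
import torch

grid_thw = torch.tensor([[1, 4, 4],     # image: 1 frame, 4 x 4 patches -> 16 tokens
                         [2, 4, 4],     # video: 2 frames               -> 32 tokens
                         [1, 2, 2]])    # small image                   ->  4 tokens
token_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
offsets = [0]
for count in token_counts:
    offsets.append(offsets[-1] + count)
print(token_counts, offsets)            # [16, 32, 4] [0, 16, 48, 52]
```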
1991
-
1992
 
1993
 
1994
  @auto_docstring
@@ -2004,7 +1936,11 @@ class MossVLModel(MossVLPreTrainedModel):
2004
  media_nums_per_sample: Optional[List[int]] = None,
2005
  vision_position_ids: Optional[torch.LongTensor] = None,
2006
  cross_attention_mask: Optional[torch.Tensor] = None,
2007
- cache_position: Optional[torch.LongTensor] = None,
 
 
 
 
2008
  **kwargs: Unpack[TransformersKwargs],
2009
  ) -> Union[tuple, BaseModelOutputWithPast]:
2010
  """
@@ -2021,11 +1957,20 @@ class MossVLModel(MossVLPreTrainedModel):
2021
  cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
2022
  Attention mask for cross-attention between text and vision. Controls which vision tokens each text
2023
  token can attend to, enforcing causal visibility for video frames.
2024
- vision_chunked_length (`int`, *optional*):
2025
- Number of media items to process per visual-encoder chunk during prefill. This only changes
2026
- how the vision tower is executed, not the final prompt or decoding logic.
 
 
 
2027
  """
2028
- vision_chunked_length = kwargs.pop("vision_chunked_length", None)
 
 
 
 
 
 
2029
  if (input_ids is None) ^ (inputs_embeds is not None):
2030
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
2031
 
@@ -2034,8 +1979,7 @@ class MossVLModel(MossVLPreTrainedModel):
2034
 
2035
  # Process vision features (images and videos are already merged by processor)
2036
  cross_attention_states = None
2037
- num_vision_tokens = 0
2038
-
2039
  if pixel_values is not None:
2040
  # Determine batch size
2041
  batch_size = inputs_embeds.shape[0]
@@ -2050,23 +1994,12 @@ class MossVLModel(MossVLPreTrainedModel):
2050
 
2051
  # Process all vision inputs together through VIT
2052
  # pixel_values and grid_thw are already ordered by appearance in text
2053
- vision_embeds, vision_token_info = self.get_vision_features_chunked(
2054
- pixel_values,
2055
- grid_thw,
2056
- media_nums_per_sample,
2057
- vision_chunked_length=vision_chunked_length,
2058
  )
2059
 
2060
  # vision_embeds: [batch_size, max_seq_len, hidden_size]
2061
  cross_attention_states = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
2062
- num_vision_tokens = cross_attention_states.shape[1]
2063
-
2064
- # Cache vision_token_info for decode stage (prefill only)
2065
-
2066
- self.vision_token_info = vision_token_info
2067
- else:
2068
- # In decode stage, use cached vision_token_info
2069
- vision_token_info = self.vision_token_info
2070
 
2071
  # Generate 3D position IDs for text if not provided
2072
  if position_ids is None:
@@ -2075,7 +2008,8 @@ class MossVLModel(MossVLPreTrainedModel):
2075
  position_ids = self.compute_position_ids(
2076
  input_ids=input_ids,
2077
  attention_mask=attention_mask,
2078
- cache_position=cache_position,
 
2079
  )
2080
 
2081
  # Compute cross_attention_mask, vision_position_ids, and full_text_row_masked_out_mask
@@ -2099,8 +2033,6 @@ class MossVLModel(MossVLPreTrainedModel):
2099
  (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
2100
  )
2101
  cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask
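A toy version of the masking step above, with shapes as documented earlier (`(batch, 1, text_len, vision_len)`); `negative_inf_value` is assumed here to be the dtype's most negative finite float:

```python
import torch

min_value = torch.finfo(torch.float32).min               # assumed value of negative_inf_value
mask = torch.full((1, 1, 3, 4), min_value)               # 3 text tokens, 4 vision tokens
mask[0, 0, 1, :2] = 0.0                                  # only text token 1 may see vision tokens 0-1
full_text_row_masked_out_mask = (mask != min_value).any(dim=-1).type_as(mask)[..., None]
print(full_text_row_masked_out_mask[0, 0, :, 0])         # tensor([0., 1., 0.])
mask = mask * full_text_row_masked_out_mask              # fully-masked text rows collapse to 0
```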
2102
-
2103
-
2104
 
2105
  if vision_position_ids is None and cross_attention_states is not None and input_ids is not None:
2106
  vision_position_ids, position_ids, rope_deltas = self.compute_vision_position_ids(
@@ -2110,14 +2042,6 @@ class MossVLModel(MossVLPreTrainedModel):
2110
  cross_attention_states,
2111
  attention_mask
2112
  )
2113
-
2114
- # Cache rope_deltas for decode stage (only in prefill)
2115
- # rope_deltas = max_position - sequence_length
2116
- # This allows fast position computation in decode: position = cache_position + rope_deltas
2117
- if cache_position is not None and cache_position[0] == 0:
2118
- self.rope_deltas = rope_deltas
2119
-
2120
-
2121
 
2122
  outputs = self.language_model(
2123
  input_ids=None,
@@ -2130,16 +2054,33 @@ class MossVLModel(MossVLPreTrainedModel):
2130
  cross_attention_mask=cross_attention_mask,
2131
  vision_position_ids=vision_position_ids,
2132
  full_text_row_masked_out_mask=full_text_row_masked_out_mask,
 
 
 
2133
  **kwargs,
2134
  )
2135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2136
  return MossVLModelOutputWithPast(
2137
  last_hidden_state=outputs.last_hidden_state,
2138
  past_key_values=outputs.past_key_values,
2139
  hidden_states=outputs.hidden_states,
2140
  attentions=outputs.attentions,
2141
- vision_token_info=self.vision_token_info,
2142
- rope_deltas=self.rope_deltas,
2143
  )
2144
 
2145
 
@@ -2161,7 +2102,6 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2161
  super().__init__(config)
2162
  self.model = MossVLModel(config)
2163
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
2164
- self._offline_processor_lock = threading.RLock()
2165
 
2166
  self.post_init()
2167
 
@@ -2219,9 +2159,12 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2219
  media_nums_per_sample: Optional[List[int]] = None,
2220
  vision_position_ids: Optional[torch.LongTensor] = None,
2221
  cross_attention_mask: Optional[torch.Tensor] = None,
2222
- cache_position: Optional[torch.LongTensor] = None,
2223
- vision_chunked_length: Optional[int] = None,
2224
  logits_to_keep: Union[int, torch.Tensor] = 0,
 
 
 
2225
  **kwargs: Unpack[TransformersKwargs],
2226
  ) -> Union[tuple, CausalLMOutputWithPast]:
2227
  """
@@ -2238,10 +2181,13 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2238
  cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
2239
  Attention mask for cross-attention between text and vision. Controls which vision tokens each text
2240
  token can attend to, enforcing causal visibility for video frames.
2241
- vision_chunked_length (`int`, *optional*):
2242
- Number of media items to process per visual-encoder chunk during prefill. This only changes
2243
- how the vision tower is executed, not the final prompt or decoding logic.
 
 
2244
  """
 
2245
  outputs = self.model(
2246
  input_ids=input_ids,
2247
  pixel_values=pixel_values,
@@ -2253,12 +2199,17 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2253
  cross_attention_mask=cross_attention_mask,
2254
  past_key_values=past_key_values,
2255
  inputs_embeds=inputs_embeds,
 
 
 
 
 
2256
  cache_position=cache_position,
2257
- vision_chunked_length=vision_chunked_length,
2258
  **kwargs,
2259
  )
2260
 
2261
- hidden_states = outputs[0]
 
2262
 
2263
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
2264
  logits = self.lm_head(hidden_states[:, slice_indices, :])
@@ -2267,6 +2218,11 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2267
  if labels is not None:
2268
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
2269
 
 
 
 
 
 
2270
  return MossVLCausalLMOutputWithPast(
2271
  loss=loss,
2272
  logits=logits,
@@ -2283,15 +2239,15 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2283
  past_key_values=None,
2284
  attention_mask=None,
2285
  inputs_embeds=None,
2286
- cache_position=None,
2287
  position_ids=None,
2288
  use_cache=True,
2289
  pixel_values=None,
2290
  grid_thw=None,
2291
  media_nums_per_sample=None, # One video is one media.
2292
  vision_position_ids=None,
 
 
2293
  cross_attention_mask=None,
2294
- vision_chunked_length=None,
2295
  **kwargs,
2296
  ):
2297
  """
@@ -2304,12 +2260,12 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2304
  Args:
2305
  media_nums_per_sample: One video counts as one media item (regardless of frame count)
2306
  """
 
2307
  model_inputs = super().prepare_inputs_for_generation(
2308
  input_ids,
2309
  past_key_values=past_key_values,
2310
  attention_mask=attention_mask,
2311
  inputs_embeds=inputs_embeds,
2312
- cache_position=cache_position,
2313
  position_ids=position_ids,
2314
  pixel_values=pixel_values,
2315
  grid_thw=grid_thw,
@@ -2318,21 +2274,27 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2318
  **kwargs,
2319
  )
2320
 
 
 
 
 
 
2321
 
2322
- # For decoding stage, if position_ids are generated by GenerationMixin (2D),
2323
- # we can set them to None to let forward recompute them from cache_position.
2324
  model_inputs["position_ids"] = None
 
 
2325
 
2326
  # Handle cross attention mask
2327
  if cross_attention_mask is not None:
2328
- # Slice to current sequence length on text dimension (dim=2)
2329
- # Shape: [batch, 1, text_len, vision_len] -> [batch, 1, cache_len, vision_len]
2330
- cross_attention_mask = cross_attention_mask[:, :, -cache_position.shape[0]:, :]
2331
  model_inputs["cross_attention_mask"] = cross_attention_mask
2332
 
2333
- # Vision inputs are only needed in prefill stage (cache_position[0] == 0)
2334
  # In decode stage, vision features are retrieved from cross attention cache
2335
- if cache_position[0] != 0:
2336
  model_inputs["pixel_values"] = None
2337
  model_inputs["grid_thw"] = None
2338
  model_inputs["media_nums_per_sample"] = None
@@ -2341,7 +2303,6 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2341
  else:
2342
  # In prefill stage, include all vision-related inputs
2343
  model_inputs["vision_position_ids"] = vision_position_ids
2344
- model_inputs["vision_chunked_length"] = vision_chunked_length
2345
 
2346
  return model_inputs
2347
 
@@ -2362,1026 +2323,16 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
2362
  **kwargs,
2363
  )
2364
 
2365
- # Extend cross_attention_mask for the new token
2366
- # Copy the last token's mask pattern for the newly generated token
2367
  if cross_attention_mask_prev is not None:
2368
- model_kwargs["cross_attention_mask"] = torch.cat(
2369
- [cross_attention_mask_prev, cross_attention_mask_prev[:, :, -1:, :]],
2370
- dim=2 # Concatenate along text sequence dimension
2371
- )
 
 
2372
 
2373
  return model_kwargs
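The removed lines above grew the mask by duplicating its last text row for each newly generated token; in isolation:

```python
import torch

prev = torch.zeros(1, 1, 4, 6)                  # (batch, 1, text_len=4, vision_len=6)
prev[0, 0, -1, :3] = 1.0                        # last text token attends to 3 vision tokens
extended = torch.cat([prev, prev[:, :, -1:, :]], dim=2)
print(extended.shape)                           # torch.Size([1, 1, 5, 6])
assert torch.equal(extended[0, 0, -1], prev[0, 0, -1])   # new row copies the previous last row
```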
2374
 
2375
- @staticmethod
2376
- def _offline_flatten_content_with_vision_tokens(content) -> str:
2377
- if isinstance(content, str):
2378
- return content
2379
- if not isinstance(content, list):
2380
- return str(content) if content else ""
2381
-
2382
- parts = []
2383
- for item in content:
2384
- if isinstance(item, dict):
2385
- if item.get("type") == "image" or "image" in item:
2386
- parts.append("<|image|>")
2387
- elif item.get("type") == "video" or "video" in item:
2388
- parts.append("<|video|>")
2389
- if "text" in item:
2390
- parts.append(str(item["text"]))
2391
- elif isinstance(item, str):
2392
- parts.append(item)
2393
- return "".join(parts)
2394
-
2395
- @staticmethod
2396
- def _offline_sanitize_prompt_text(processor, text: Any) -> str:
2397
- if text is None:
2398
- return ""
2399
-
2400
- sanitized = str(text)
2401
- replacements = [
2402
- (getattr(processor, "image_placeholder", None), ""),
2403
- (getattr(processor, "video_placeholder", None), ""),
2404
- (getattr(processor, "image_token", None), ""),
2405
- (getattr(processor, "video_token", None), ""),
2406
- ]
2407
- for needle, replacement in replacements:
2408
- if needle:
2409
- sanitized = sanitized.replace(needle, replacement)
2410
- return sanitized.lstrip("\n")
2411
-
2412
- def _offline_sanitize_message_content(self, processor, content: Any) -> Any:
2413
- if isinstance(content, str):
2414
- return self._offline_sanitize_prompt_text(processor, content)
2415
- if not isinstance(content, list):
2416
- return content
2417
-
2418
- sanitized_items = []
2419
- for item in content:
2420
- if isinstance(item, dict):
2421
- item_copy = dict(item)
2422
- if "text" in item_copy:
2423
- item_copy["text"] = self._offline_sanitize_prompt_text(processor, item_copy.get("text"))
2424
- sanitized_items.append(item_copy)
2425
- elif isinstance(item, str):
2426
- sanitized_items.append(self._offline_sanitize_prompt_text(processor, item))
2427
- else:
2428
- sanitized_items.append(item)
2429
- return sanitized_items
2430
-
2431
- def _offline_prepare_messages(self, processor, query: Dict[str, Any]) -> List[Dict[str, Any]]:
2432
- messages = query.get("messages")
2433
- if messages:
2434
- prepared_messages = []
2435
- for message in messages:
2436
- if not isinstance(message, dict):
2437
- continue
2438
- message_copy = dict(message)
2439
- message_copy["content"] = self._offline_sanitize_message_content(
2440
- processor,
2441
- message_copy.get("content", ""),
2442
- )
2443
- prepared_messages.append(message_copy)
2444
- if prepared_messages:
2445
- return prepared_messages
2446
-
2447
- prompt = self._offline_sanitize_prompt_text(processor, query.get("prompt", ""))
2448
- images = list(query.get("images") or [])
2449
- videos = list(query.get("videos") or [])
2450
-
2451
- content = []
2452
- for image in images:
2453
- content.append({"type": "image", "image": image})
2454
- for video in videos:
2455
- content.append({"type": "video", "video": video})
2456
- if prompt:
2457
- content.append({"type": "text", "text": prompt.lstrip("\n")})
2458
-
2459
- if not content:
2460
- content = [{"type": "text", "text": ""}]
2461
-
2462
- return [{"role": "user", "content": content}]
2463
-
2464
- @staticmethod
2465
- def _offline_extract_content_parts(content: Any) -> Tuple[str, List[Any], List[Any]]:
2466
- if isinstance(content, str):
2467
- return content, [], []
2468
- if not isinstance(content, list):
2469
- return (str(content) if content else ""), [], []
2470
-
2471
- text_parts: List[str] = []
2472
- images: List[Any] = []
2473
- videos: List[Any] = []
2474
- for item in content:
2475
- if isinstance(item, dict):
2476
- if item.get("type") == "image" or "image" in item or "image_url" in item:
2477
- image = item.get("image") or item.get("image_url")
2478
- if image is not None:
2479
- images.append(image)
2480
- elif item.get("type") == "video" or "video" in item or "video_path" in item:
2481
- video = item.get("video") or item.get("video_path")
2482
- if video is not None:
2483
- videos.append(video)
2484
-
2485
- if "text" in item and item["text"] is not None:
2486
- text_parts.append(str(item["text"]))
2487
- elif isinstance(item, str):
2488
- text_parts.append(item)
2489
-
2490
- return "".join(text_parts), images, videos
2491
-
2492
- @staticmethod
2493
- def _offline_resolve_use_template(query: Dict[str, Any]) -> bool:
2494
- return bool(query.get("use_template", False))
2495
-
2496
- def _offline_prepare_input_text(
2497
- self,
2498
- processor,
2499
- messages: List[Dict[str, Any]],
2500
- use_template: bool,
2501
- ) -> str:
2502
- if not use_template:
2503
- if any(isinstance(message, dict) and message.get("role") == "system" for message in messages):
2504
- raise ValueError("system messages require use_template=True")
2505
-
2506
- parts = ["<|im_start|>"]
2507
- for message in messages:
2508
- role = message.get("role", "user") if isinstance(message, dict) else "user"
2509
- content = message.get("content", "") if isinstance(message, dict) else message
2510
- text, msg_images, msg_videos = self._offline_extract_content_parts(content)
2511
-
2512
- if role == "user":
2513
- media_tokens = ""
2514
- if msg_images:
2515
- media_tokens += "<|image|>" * len(msg_images)
2516
- if msg_videos:
2517
- media_tokens += "<|video|>" * len(msg_videos)
2518
- parts.append(f"{media_tokens}{text}")
2519
- else:
2520
- parts.append(f"{text}<|im_end|>")
2521
- return "".join(parts)
2522
-
2523
- processed_messages = []
2524
- for message in messages:
2525
- message_copy = dict(message)
2526
- message_copy["content"] = self._offline_flatten_content_with_vision_tokens(
2527
- message_copy.get("content", "")
2528
- )
2529
- processed_messages.append(message_copy)
2530
- return processor.apply_chat_template(
2531
- processed_messages,
2532
- tokenize=False,
2533
- add_generation_prompt=True,
2534
- )
2535
-
2536
- @staticmethod
2537
- def _offline_collect_media(messages: List[Dict[str, Any]]) -> tuple[List[Any], List[Any]]:
2538
- all_images: List[Any] = []
2539
- all_videos: List[Any] = []
2540
-
2541
- for message in messages:
2542
- content = message.get("content")
2543
- if isinstance(content, list):
2544
- for item in content:
2545
- if not isinstance(item, dict):
2546
- continue
2547
- if item.get("type") == "image" or "image" in item:
2548
- image = item.get("image") or item.get("image_url")
2549
- if image is not None:
2550
- all_images.append(image)
2551
- elif item.get("type") == "video" or "video" in item:
2552
- video = item.get("video")
2553
- if video is not None:
2554
- all_videos.append(video)
2555
-
2556
- return all_images, all_videos
2557
-
2558
- def _offline_build_processor_kwargs(
2559
- self,
2560
- input_text: Union[str, List[str]],
2561
- all_images: List[Any],
2562
- all_videos: List[Any],
2563
- media_kwargs: Dict[str, Any],
2564
- ) -> Dict[str, Any]:
2565
- processor_kwargs: Dict[str, Any] = {
2566
- "text": input_text,
2567
- "images": all_images or None,
2568
- "videos": all_videos or None,
2569
- "return_tensors": "pt",
2570
- "padding": False,
2571
- }
2572
-
2573
- if media_kwargs.get("min_pixels") is not None:
2574
- processor_kwargs["min_pixels"] = media_kwargs["min_pixels"]
2575
- if media_kwargs.get("max_pixels") is not None:
2576
- processor_kwargs["max_pixels"] = media_kwargs["max_pixels"]
2577
- if media_kwargs.get("video_fps") is not None:
2578
- processor_kwargs["video_fps"] = media_kwargs["video_fps"]
2579
-
2580
- min_frames = media_kwargs.get("min_frames", media_kwargs.get("video_minlen"))
2581
- max_frames = media_kwargs.get("max_frames", media_kwargs.get("video_maxlen"))
2582
- if min_frames is not None:
2583
- processor_kwargs["min_frames"] = min_frames
2584
- if max_frames is not None:
2585
- processor_kwargs["max_frames"] = max_frames
2586
-
2587
- return processor_kwargs
2588
-
2589
- def _offline_prepare_inputs(self, processor, query: Dict[str, Any]):
2590
- messages = self._offline_prepare_messages(processor, query)
2591
- input_text = self._offline_prepare_input_text(
2592
- processor,
2593
- messages,
2594
- use_template=self._offline_resolve_use_template(query),
2595
- )
2596
- all_images, all_videos = self._offline_collect_media(messages)
2597
- media_kwargs = dict(query.get("media_kwargs") or {})
2598
- processor_kwargs = self._offline_build_processor_kwargs(
2599
- input_text,
2600
- all_images,
2601
- all_videos,
2602
- media_kwargs,
2603
- )
2604
-
2605
- image_proc = getattr(processor, "image_processor", None)
2606
- video_proc = getattr(processor, "video_processor", None)
2607
- modified_multi_image = False
2608
- modified_video = False
2609
-
2610
- with self._offline_processor_lock:
2611
- try:
2612
- multi_image_max_pixels = media_kwargs.get("multi_image_max_pixels")
2613
- if multi_image_max_pixels is not None and image_proc is not None:
2614
- orig_multi_image_max_pixels = getattr(image_proc, "multi_image_max_pixels", None)
2615
- image_proc.multi_image_max_pixels = multi_image_max_pixels
2616
- modified_multi_image = True
2617
-
2618
- video_max_pixels = media_kwargs.get("video_max_pixels")
2619
- if video_max_pixels is not None and video_proc is not None:
2620
- orig_video_max_pixels = getattr(video_proc, "video_max_pixels", None)
2621
- video_proc.video_max_pixels = video_max_pixels
2622
- modified_video = True
2623
-
2624
- inputs = processor(**processor_kwargs)
2625
- finally:
2626
- if modified_multi_image and image_proc is not None:
2627
- image_proc.multi_image_max_pixels = orig_multi_image_max_pixels
2628
- if modified_video and video_proc is not None:
2629
- video_proc.video_max_pixels = orig_video_max_pixels
2630
-
2631
- text_device = self.get_input_embeddings().weight.device
2632
- vision_device = self.visual.patch_embed.proj.weight.device
2633
- vision_input_keys = {"pixel_values", "grid_thw"}
2634
-
2635
- for key, value in list(inputs.items()):
2636
- if not isinstance(value, torch.Tensor):
2637
- continue
2638
-
2639
- target_device = vision_device if key in vision_input_keys else text_device
2640
- moved_value = value.to(target_device)
2641
- if moved_value.dtype == torch.float32:
2642
- moved_value = moved_value.to(torch.bfloat16)
2643
- inputs[key] = moved_value
2644
-
2645
- return inputs, input_text
2646
-
2647
- def _offline_build_session_messages(
2648
- self,
2649
- processor,
2650
- query: Dict[str, Any],
2651
- session_messages: List[Dict[str, Any]],
2652
- ) -> List[Dict[str, Any]]:
2653
- has_explicit_messages = bool(query.get("messages"))
2654
- if has_explicit_messages and not query.get("append_messages_to_session", False):
2655
- base_messages: List[Dict[str, Any]] = []
2656
- else:
2657
- base_messages = [dict(message) for message in session_messages]
2658
-
2659
- turn_messages = self._offline_prepare_messages(processor, query)
2660
- has_system_message = any(
2661
- isinstance(message, dict) and message.get("role") == "system"
2662
- for message in (base_messages + turn_messages)
2663
- )
2664
-
2665
- should_add_system_prompt = (
2666
- query.get("use_default_system_prompt", False)
2667
- or query.get("system_prompt") is not None
2668
- or query.get("system_prompt_type") is not None
2669
- or query.get("thinking_mode") is not None
2670
- )
2671
-
2672
- if not base_messages and not has_system_message and should_add_system_prompt:
2673
- system_prompt = self._offline_resolve_system_prompt(query, turn_messages)
2674
- if system_prompt is not None:
2675
- base_messages.append({"role": "system", "content": system_prompt})
2676
-
2677
- return base_messages + turn_messages
2678
-
2679
- @staticmethod
2680
- def _offline_query_contains_video(query: Dict[str, Any], messages: List[Dict[str, Any]]) -> bool:
2681
- if query.get("videos"):
2682
- return True
2683
-
2684
- for message in messages:
2685
- content = message.get("content") if isinstance(message, dict) else None
2686
- if isinstance(content, list) and any(
2687
- isinstance(item, dict) and (item.get("type") == "video" or "video" in item)
2688
- for item in content
2689
- ):
2690
- return True
2691
- return False
2692
-
2693
- @staticmethod
2694
- def _offline_normalize_thinking_mode(value: Optional[str]) -> str:
2695
- if value is None:
2696
- return "no_thinking"
2697
-
2698
- normalized = _OFFLINE_THINKING_MODE_ALIASES.get(str(value).strip().lower())
2699
- if normalized is None:
2700
- allowed = ", ".join(sorted(set(_OFFLINE_THINKING_MODE_ALIASES.values())))
2701
- raise ValueError(f"Unsupported thinking_mode: {value!r}. Supported values: {allowed}")
2702
- return normalized
2703
-
2704
- @staticmethod
2705
- def _offline_normalize_system_prompt_type(value: Optional[str], has_video: bool) -> str:
2706
- if value is None:
2707
- return "video" if has_video else "text_image"
2708
-
2709
- normalized_key = str(value).strip().lower().replace("/", "_").replace(" ", "_")
2710
- while "__" in normalized_key:
2711
- normalized_key = normalized_key.replace("__", "_")
2712
-
2713
- normalized = _OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.get(normalized_key)
2714
- if normalized is None:
2715
- allowed = ", ".join(sorted(set(_OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.values())))
2716
- raise ValueError(f"Unsupported system_prompt_type: {value!r}. Supported values: {allowed}")
2717
- return normalized
2718
-
2719
- def _offline_resolve_system_prompt(
2720
- self,
2721
- query: Dict[str, Any],
2722
- turn_messages: List[Dict[str, Any]],
2723
- ) -> Optional[str]:
2724
- explicit_system_prompt = query.get("system_prompt")
2725
- if explicit_system_prompt is not None:
2726
- return str(explicit_system_prompt)
2727
-
2728
- has_video = self._offline_query_contains_video(query, turn_messages)
2729
- thinking_mode = self._offline_normalize_thinking_mode(query.get("thinking_mode"))
2730
- system_prompt_type = self._offline_normalize_system_prompt_type(
2731
- query.get("system_prompt_type"),
2732
- has_video=has_video,
2733
- )
2734
- return _OFFLINE_SYSTEM_PROMPTS[thinking_mode][system_prompt_type]
2735
-
2736
- @staticmethod
2737
- def _offline_finalize_session_messages(
2738
- working_messages: List[Dict[str, Any]],
2739
- assistant_text: str,
2740
- ) -> List[Dict[str, Any]]:
2741
- next_messages = [dict(message) for message in working_messages]
2742
- next_messages.append({"role": "assistant", "content": assistant_text})
2743
- return next_messages
2744
-
2745
- def _offline_prepare_generation(self, processor, query: Dict[str, Any]):
2746
- inputs, input_text = self._offline_prepare_inputs(processor, query)
2747
- generate_kwargs = dict(query.get("generate_kwargs") or {})
2748
-
2749
- max_new_tokens = generate_kwargs.pop("max_new_tokens", 1024)
2750
- temperature = generate_kwargs.pop("temperature", 1.0)
2751
- top_k = generate_kwargs.pop("top_k", 50)
2752
- top_p = generate_kwargs.pop("top_p", 1.0)
2753
- repetition_penalty = generate_kwargs.pop("repetition_penalty", 1.0)
2754
- do_sample = generate_kwargs.pop("do_sample", False)
2755
- vision_chunked_length = generate_kwargs.pop("vision_chunked_length", None)
2756
-
2757
- if temperature is None:
2758
- temperature = 1.0
2759
- if temperature <= 0:
2760
- temperature = 1.0
2761
- do_sample = False
2762
-
2763
- call_kwargs = dict(
2764
- max_new_tokens=max_new_tokens,
2765
- temperature=temperature,
2766
- top_k=top_k,
2767
- top_p=top_p,
2768
- repetition_penalty=repetition_penalty,
2769
- do_sample=do_sample,
2770
- vision_chunked_length=vision_chunked_length,
2771
- **generate_kwargs,
2772
- )
2773
- return inputs, input_text, call_kwargs
2774
-
2775
- @staticmethod
2776
- def _offline_normalize_shared_mapping(
2777
- values: List[Dict[str, Any]],
2778
- mapping_name: str,
2779
- ) -> Dict[str, Any]:
2780
- normalized_values = [dict(value or {}) for value in values]
2781
- if not normalized_values:
2782
- return {}
2783
-
2784
- all_keys = set()
2785
- for value in normalized_values:
2786
- all_keys.update(value.keys())
2787
-
2788
- merged: Dict[str, Any] = {}
2789
- mismatched_keys: List[str] = []
2790
- for key in sorted(all_keys):
2791
- unique_values = {repr(value.get(key)) for value in normalized_values}
2792
- if len(unique_values) > 1:
2793
- mismatched_keys.append(key)
2794
- else:
2795
- merged[key] = normalized_values[0].get(key)
2796
-
2797
- if mismatched_keys:
2798
- mismatch_text = ", ".join(mismatched_keys)
2799
- raise ValueError(
2800
- f"All batch queries must share the same {mapping_name}. "
2801
- f"Mismatched keys: {mismatch_text}"
2802
- )
2803
- return merged
2804
-
2805
- def _offline_prepare_batch_generation(
2806
- self,
2807
- processor,
2808
- queries: List[Dict[str, Any]],
2809
- session_states: Optional[List[List[Dict[str, Any]]]] = None,
2810
- ):
2811
- if not queries:
2812
- raise ValueError("`queries` must contain at least one query.")
2813
-
2814
- if session_states is None:
2815
- session_states = [[] for _ in queries]
2816
- elif len(session_states) != len(queries):
2817
- raise ValueError("`session_states` must have the same length as `queries`.")
2818
-
2819
- working_messages_list: List[List[Dict[str, Any]]] = []
2820
- input_texts: List[str] = []
2821
- all_images_per_query: List[List[Any]] = []
2822
- all_videos_per_query: List[List[Any]] = []
2823
-
2824
- for query, session_state in zip(queries, session_states):
2825
- if not isinstance(query, dict):
2826
- raise TypeError("Each batch query must be a dict.")
2827
- if query.get("stop_offline_generate"):
2828
- raise ValueError("`stop_offline_generate` is not supported in offline_batch_generate.")
2829
- if query.get("stream_output", query.get("stream", False)):
2830
- raise ValueError("Streaming is not supported in offline_batch_generate.")
2831
- if query.get("cancel_current_generate") or query.get("stop_generation"):
2832
- raise ValueError("Cancel / stop controls are not supported in offline_batch_generate.")
2833
-
2834
- current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
2835
- working_messages = self._offline_build_session_messages(
2836
- processor,
2837
- query,
2838
- current_session,
2839
- )
2840
- working_messages_list.append(working_messages)
2841
- input_texts.append(
2842
- self._offline_prepare_input_text(
2843
- processor,
2844
- working_messages,
2845
- use_template=self._offline_resolve_use_template(query),
2846
- )
2847
- )
2848
-
2849
- all_images, all_videos = self._offline_collect_media(working_messages)
2850
- all_images_per_query.append(all_images)
2851
- all_videos_per_query.append(all_videos)
2852
-
2853
- media_kwargs = self._offline_normalize_shared_mapping(
2854
- [query.get("media_kwargs") or {} for query in queries],
2855
- mapping_name="media_kwargs",
2856
- )
2857
- processor_kwargs = self._offline_build_processor_kwargs(
2858
- input_text=input_texts,
2859
- all_images=[image for images in all_images_per_query for image in images],
2860
- all_videos=[video for videos in all_videos_per_query for video in videos],
2861
- media_kwargs=media_kwargs,
2862
- )
2863
- processor_kwargs["padding"] = True
2864
-
2865
- image_proc = getattr(processor, "image_processor", None)
2866
- video_proc = getattr(processor, "video_processor", None)
2867
- tokenizer = getattr(processor, "tokenizer", None)
2868
- modified_multi_image = False
2869
- modified_video = False
2870
- orig_padding_side = None
2871
-
2872
- with self._offline_processor_lock:
2873
- try:
2874
- multi_image_max_pixels = media_kwargs.get("multi_image_max_pixels")
2875
- if multi_image_max_pixels is not None and image_proc is not None:
2876
- orig_multi_image_max_pixels = getattr(image_proc, "multi_image_max_pixels", None)
2877
- image_proc.multi_image_max_pixels = multi_image_max_pixels
2878
- modified_multi_image = True
2879
-
2880
- video_max_pixels = media_kwargs.get("video_max_pixels")
2881
- if video_max_pixels is not None and video_proc is not None:
2882
- orig_video_max_pixels = getattr(video_proc, "video_max_pixels", None)
2883
- video_proc.video_max_pixels = video_max_pixels
2884
- modified_video = True
2885
-
2886
- if tokenizer is not None and hasattr(tokenizer, "padding_side"):
2887
- orig_padding_side = tokenizer.padding_side
2888
- tokenizer.padding_side = "left"
2889
-
2890
- inputs = processor(**processor_kwargs)
2891
- finally:
2892
- if modified_multi_image and image_proc is not None:
2893
- image_proc.multi_image_max_pixels = orig_multi_image_max_pixels
2894
- if modified_video and video_proc is not None:
2895
- video_proc.video_max_pixels = orig_video_max_pixels
2896
- if tokenizer is not None and orig_padding_side is not None:
2897
- tokenizer.padding_side = orig_padding_side
2898
-
2899
- text_device = self.get_input_embeddings().weight.device
2900
- vision_device = self.visual.patch_embed.proj.weight.device
2901
- vision_input_keys = {"pixel_values", "grid_thw"}
2902
-
2903
- for key, value in list(inputs.items()):
2904
- if not isinstance(value, torch.Tensor):
2905
- continue
2906
-
2907
- target_device = vision_device if key in vision_input_keys else text_device
2908
- moved_value = value.to(target_device)
2909
- if moved_value.dtype == torch.float32:
2910
- moved_value = moved_value.to(torch.bfloat16)
2911
- inputs[key] = moved_value
2912
-
2913
- generate_kwargs = self._offline_normalize_shared_mapping(
2914
- [query.get("generate_kwargs") or {} for query in queries],
2915
- mapping_name="generate_kwargs",
2916
- )
2917
- max_new_tokens = generate_kwargs.pop("max_new_tokens", 1024)
2918
- temperature = generate_kwargs.pop("temperature", 1.0)
2919
- top_k = generate_kwargs.pop("top_k", 50)
2920
- top_p = generate_kwargs.pop("top_p", 1.0)
2921
- repetition_penalty = generate_kwargs.pop("repetition_penalty", 1.0)
2922
- do_sample = generate_kwargs.pop("do_sample", False)
2923
- vision_chunked_length = generate_kwargs.pop("vision_chunked_length", None)
2924
-
2925
- if temperature is None:
2926
- temperature = 1.0
2927
- if temperature <= 0:
2928
- temperature = 1.0
2929
- do_sample = False
2930
-
2931
- call_kwargs = dict(
2932
- max_new_tokens=max_new_tokens,
2933
- temperature=temperature,
2934
- top_k=top_k,
2935
- top_p=top_p,
2936
- repetition_penalty=repetition_penalty,
2937
- do_sample=do_sample,
2938
- vision_chunked_length=vision_chunked_length,
2939
- **generate_kwargs,
2940
- )
2941
- return inputs, input_texts, working_messages_list, call_kwargs
2942
-
2943
- def offline_batch_generate(
2944
- self,
2945
- processor,
2946
- queries: List[Dict[str, Any]],
2947
- session_states: Optional[List[List[Dict[str, Any]]]] = None,
2948
- vision_chunked_length: int = 64,
2949
- ) -> Dict[str, Any]:
2950
- """
2951
- Batch offline generation for multiple independent samples.
2952
-
2953
- This method supports:
2954
- - batched single-turn generation
2955
- - batched multi-turn continuation through `session_states`
2956
-
2957
- It intentionally does not support queue-style controls such as:
2958
- - `stream_output`
2959
- - `cancel_current_generate`
2960
- - `stop_generation`
2961
- - `stop_offline_generate`
2962
- """
2963
- if not queries:
2964
- return {"results": [], "session_states": []}
2965
-
2966
- prepared_queries = [dict(query) for query in queries]
2967
- for query in prepared_queries:
2968
- generate_kwargs = query.setdefault("generate_kwargs", {})
2969
- generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
2970
- if session_states is None:
2971
- session_states = [[] for _ in prepared_queries]
2972
- elif len(session_states) != len(prepared_queries):
2973
- raise ValueError("`session_states` must have the same length as `queries`.")
2974
-
2975
- tokenizer = getattr(processor, "tokenizer", None)
2976
- bucketed_indices: Dict[Any, List[int]] = {}
2977
- for index, (query, session_state) in enumerate(zip(prepared_queries, session_states)):
2978
- current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
2979
- working_messages = self._offline_build_session_messages(processor, query, current_session)
2980
- input_text = self._offline_prepare_input_text(
2981
- processor,
2982
- working_messages,
2983
- use_template=self._offline_resolve_use_template(query),
2984
- )
2985
-
2986
- if tokenizer is not None:
2987
- token_ids = tokenizer(input_text, add_special_tokens=False)["input_ids"]
2988
- bucket_key = len(token_ids)
2989
- else:
2990
- bucket_key = len(input_text)
2991
- bucketed_indices.setdefault(bucket_key, []).append(index)
2992
-
2993
- results: List[Optional[Dict[str, Any]]] = [None] * len(prepared_queries)
2994
- next_session_states: List[Optional[List[Dict[str, Any]]]] = [None] * len(prepared_queries)
2995
-
2996
- for bucket_indices in bucketed_indices.values():
2997
- bucket_queries = [prepared_queries[index] for index in bucket_indices]
2998
- bucket_session_states = [session_states[index] for index in bucket_indices]
2999
- inputs, input_texts, working_messages_list, call_kwargs = self._offline_prepare_batch_generation(
3000
- processor,
3001
- bucket_queries,
3002
- session_states=bucket_session_states,
3003
- )
3004
-
3005
- with torch.no_grad():
3006
- outputs = self.generate(
3007
- **inputs,
3008
- **call_kwargs,
3009
- )
3010
-
3011
- input_seq_len = inputs["input_ids"].shape[1]
3012
- generated_tokens = outputs[:, input_seq_len:]
3013
- decoded_texts = processor.batch_decode(generated_tokens, skip_special_tokens=True)
3014
-
3015
- for local_index, (query, input_text, working_messages, text) in enumerate(
3016
- zip(bucket_queries, input_texts, working_messages_list, decoded_texts)
3017
- ):
3018
- original_index = bucket_indices[local_index]
3019
- if query.get("persist_session", True):
3020
- next_session_state = self._offline_finalize_session_messages(working_messages, text)
3021
- else:
3022
- next_session_state = working_messages
3023
- next_session_states[original_index] = next_session_state
3024
- results[original_index] = {
3025
- "index": original_index,
3026
- "text": text,
3027
- "input_text": input_text,
3028
- "messages": working_messages,
3029
- }
3030
-
3031
- return {
3032
- "results": [item for item in results if item is not None],
3033
- "session_states": [item for item in next_session_states if item is not None],
3034
- }
3035
-
3036
- def _offline_generate_one(self, processor, query: Dict[str, Any]) -> str:
3037
- working_messages = self._offline_build_session_messages(processor, query, [])
3038
- generation_query = dict(query)
3039
- generation_query["messages"] = working_messages
3040
- inputs, _, call_kwargs = self._offline_prepare_generation(processor, generation_query)
3041
-
3042
- with torch.no_grad():
3043
- outputs = self.generate(
3044
- **inputs,
3045
- **call_kwargs,
3046
- )
3047
-
3048
- new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
3049
- return processor.decode(new_tokens, skip_special_tokens=True)
3050
-
3051
- @staticmethod
3052
- def _offline_capture_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
3053
- if target is None or not overrides:
3054
- return None
3055
- return {name: copy.deepcopy(getattr(target, name)) for name in overrides}
3056
-
3057
- @staticmethod
3058
- def _offline_apply_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> None:
3059
- if target is None or not overrides:
3060
- return
3061
- for name, value in overrides.items():
3062
- setattr(target, name, copy.deepcopy(value))
3063
-
3064
- @staticmethod
3065
- def _offline_restore_processor_attrs(target, snapshot: Optional[Dict[str, Any]]) -> None:
3066
- if target is None or snapshot is None:
3067
- return
3068
- for name, value in snapshot.items():
3069
- setattr(target, name, copy.deepcopy(value))
3070
-
3071
- def _offline_generate_one_with_processor_overrides(
3072
- self,
3073
- processor,
3074
- query: Dict[str, Any],
3075
- image_processor_overrides: Optional[Dict[str, Any]] = None,
3076
- video_processor_overrides: Optional[Dict[str, Any]] = None,
3077
- ) -> str:
3078
- image_proc = getattr(processor, "image_processor", None)
3079
- video_proc = getattr(processor, "video_processor", None)
3080
- image_snapshot = self._offline_capture_processor_attrs(image_proc, image_processor_overrides)
3081
- video_snapshot = self._offline_capture_processor_attrs(video_proc, video_processor_overrides)
3082
-
3083
- with self._offline_processor_lock:
3084
- try:
3085
- self._offline_apply_processor_attrs(image_proc, image_processor_overrides)
3086
- self._offline_apply_processor_attrs(video_proc, video_processor_overrides)
3087
- return self._offline_generate_one(processor, query)
3088
- finally:
3089
- self._offline_restore_processor_attrs(image_proc, image_snapshot)
3090
- self._offline_restore_processor_attrs(video_proc, video_snapshot)
3091
-
3092
- def offline_image_generate(
3093
- self,
3094
- processor,
3095
- prompt: str = "",
3096
- image: Any = None,
3097
- *,
3098
- shortest_edge: int = 4096,
3099
- longest_edge: int = 16777216,
3100
- multi_image_max_pixels: int = 201326592,
3101
- patch_size: int = 16,
3102
- temporal_patch_size: int = 1,
3103
- merge_size: int = 2,
3104
- image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
3105
- image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
3106
- max_new_tokens: int = 1024,
3107
- temperature: float = 1.0,
3108
- top_k: int = 50,
3109
- top_p: float = 1.0,
3110
- repetition_penalty: float = 1.0,
3111
- do_sample: bool = False,
3112
- vision_chunked_length: int = 64,
3113
- use_template: bool = False,
3114
- thinking_mode: Optional[str] = None,
3115
- system_prompt_type: Optional[str] = None,
3116
- system_prompt: Optional[str] = None,
3117
- ) -> str:
3118
- """
3119
- Single-image offline generation with explicit image preprocessor defaults.
3120
-
3121
- The default values mirror `preprocessor_config.json` so README examples can
3122
- surface the full image preprocessing setup without requiring a batch wrapper.
3123
- """
3124
- if image is None:
3125
- raise ValueError("`image` is required.")
3126
- query: Dict[str, Any] = {
3127
- "prompt": prompt,
3128
- "images": [image],
3129
- "videos": [],
3130
- "media_kwargs": {
3131
- "min_pixels": shortest_edge,
3132
- "max_pixels": longest_edge,
3133
- "multi_image_max_pixels": multi_image_max_pixels,
3134
- },
3135
- "generate_kwargs": {
3136
- "max_new_tokens": max_new_tokens,
3137
- "temperature": temperature,
3138
- "top_k": top_k,
3139
- "top_p": top_p,
3140
- "repetition_penalty": repetition_penalty,
3141
- "do_sample": do_sample,
3142
- "vision_chunked_length": vision_chunked_length,
3143
- },
3144
- "use_template": use_template,
3145
- }
3146
- if thinking_mode is not None:
3147
- query["thinking_mode"] = thinking_mode
3148
- if system_prompt_type is not None:
3149
- query["system_prompt_type"] = system_prompt_type
3150
- if system_prompt is not None:
3151
- query["system_prompt"] = system_prompt
3152
-
3153
- image_processor_overrides = {
3154
- "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
3155
- "multi_image_max_pixels": multi_image_max_pixels,
3156
- "patch_size": patch_size,
3157
- "temporal_patch_size": temporal_patch_size,
3158
- "merge_size": merge_size,
3159
- "image_mean": list(image_mean) if image_mean is not None else None,
3160
- "image_std": list(image_std) if image_std is not None else None,
3161
- }
3162
- return self._offline_generate_one_with_processor_overrides(
3163
- processor,
3164
- query,
3165
- image_processor_overrides=image_processor_overrides,
3166
- )
3167
-
3168
- def offline_video_generate(
3169
- self,
3170
- processor,
3171
- prompt: str = "",
3172
- video: Any = None,
3173
- *,
3174
- shortest_edge: int = 4096,
3175
- longest_edge: int = 16777216,
3176
- video_max_pixels: int = 201326592,
3177
- patch_size: int = 16,
3178
- temporal_patch_size: int = 1,
3179
- merge_size: int = 2,
3180
- video_fps: float = 1.0,
3181
- min_frames: int = 1,
3182
- max_frames: int = 256,
3183
- num_extract_threads: int = 4,
3184
- image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
3185
- image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
3186
- max_new_tokens: int = 1024,
3187
- temperature: float = 1.0,
3188
- top_k: int = 50,
3189
- top_p: float = 1.0,
3190
- repetition_penalty: float = 1.0,
3191
- do_sample: bool = False,
3192
- vision_chunked_length: int = 64,
3193
- use_template: bool = False,
3194
- thinking_mode: Optional[str] = None,
3195
- system_prompt_type: Optional[str] = None,
3196
- system_prompt: Optional[str] = None,
3197
- ) -> str:
3198
- """
3199
- Single-video offline generation with explicit video preprocessor defaults.
3200
-
3201
- The default values mirror `video_preprocessor_config.json` so README examples
3202
- can show a standalone video entry point with the effective preprocessing knobs.
3203
- """
3204
- if video is None:
3205
- raise ValueError("`video` is required.")
3206
- query: Dict[str, Any] = {
3207
- "prompt": prompt,
3208
- "images": [],
3209
- "videos": [video],
3210
- "media_kwargs": {
3211
- "min_pixels": shortest_edge,
3212
- "max_pixels": longest_edge,
3213
- "video_max_pixels": video_max_pixels,
3214
- "video_fps": video_fps,
3215
- "min_frames": min_frames,
3216
- "max_frames": max_frames,
3217
- },
3218
- "generate_kwargs": {
3219
- "max_new_tokens": max_new_tokens,
3220
- "temperature": temperature,
3221
- "top_k": top_k,
3222
- "top_p": top_p,
3223
- "repetition_penalty": repetition_penalty,
3224
- "do_sample": do_sample,
3225
- "vision_chunked_length": vision_chunked_length,
3226
- },
3227
- "use_template": use_template,
3228
- }
3229
- if thinking_mode is not None:
3230
- query["thinking_mode"] = thinking_mode
3231
- if system_prompt_type is not None:
3232
- query["system_prompt_type"] = system_prompt_type
3233
- if system_prompt is not None:
3234
- query["system_prompt"] = system_prompt
3235
-
3236
- video_processor_overrides = {
3237
- "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
3238
- "video_max_pixels": video_max_pixels,
3239
- "patch_size": patch_size,
3240
- "temporal_patch_size": temporal_patch_size,
3241
- "merge_size": merge_size,
3242
- "video_fps": video_fps,
3243
- "min_frames": min_frames,
3244
- "max_frames": max_frames,
3245
- "num_extract_threads": num_extract_threads,
3246
- "image_mean": list(image_mean) if image_mean is not None else None,
3247
- "image_std": list(image_std) if image_std is not None else None,
3248
- }
3249
- return self._offline_generate_one_with_processor_overrides(
3250
- processor,
3251
- query,
3252
- video_processor_overrides=video_processor_overrides,
3253
- )
3254
-
3255
- def offline_generate(
3256
- self,
3257
- processor,
3258
- new_queries: "queue.Queue[dict]",
3259
- output_text_queue: "queue.Queue[str]",
3260
- vision_chunked_length: int = 64,
3261
- ) -> None:
3262
- """
3263
- HF-style offline inference wrapper aligned with the previous backend output path.
3264
-
3265
- This method intentionally reuses the checkpoint's existing processor and
3266
- `generate()` flow so that outputs stay consistent with the old external
3267
- backend inference implementation.
3268
-
3269
- Supported query keys include:
3270
- - `prompt` / `messages`
3271
- - `images` / `videos`
3272
- - `media_kwargs` / `generate_kwargs`
3273
- - `use_template` to switch between backend-style pretrain prompting
3274
- (`False`, default for base) and tokenizer chat template prompting (`True`)
3275
- - `thinking_mode` (`no_thinking` or `deep_thinking`, plus compatible aliases)
3276
- - `system_prompt_type` (`text_image` or `video`, plus compatible aliases)
3277
- - `system_prompt` for an explicit override
3278
- - `stream_output` / `stream`
3279
- - `reset_session` / `clear_history`
3280
- - `cancel_current_generate` / `stop_generation` / `stop_offline_generate`
3281
- """
3282
- buffered_queries: List[Dict[str, Any]] = []
3283
- session_messages: List[Dict[str, Any]] = []
3284
-
3285
- while True:
3286
- if buffered_queries:
3287
- query = buffered_queries.pop(0)
3288
- else:
3289
- query = new_queries.get()
3290
- if not isinstance(query, dict):
3291
- continue
3292
-
3293
- if query.get("stop_offline_generate"):
3294
- break
3295
-
3296
- if query.get("reset_session") or query.get("clear_history"):
3297
- session_messages = []
3298
-
3299
- try:
3300
- generate_kwargs = query.setdefault("generate_kwargs", {})
3301
- generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
3302
- working_messages = self._offline_build_session_messages(
3303
- processor,
3304
- query,
3305
- session_messages,
3306
- )
3307
-
3308
- generation_query = dict(query)
3309
- generation_query["messages"] = working_messages
3310
- inputs, input_text, call_kwargs = self._offline_prepare_generation(processor, generation_query)
3311
-
3312
- stream_output = bool(query.get("stream_output", query.get("stream", False)))
3313
- cancel_event = threading.Event()
3314
- stopping_criteria = StoppingCriteriaList([_OfflineCancelStoppingCriteria(cancel_event)])
3315
- generation_state: Dict[str, Any] = {}
3316
-
3317
- if stream_output:
3318
- output_text_queue.put("<|round_start|>")
3319
- streamer = _OfflineQueueStreamer(getattr(processor, "tokenizer", processor), output_text_queue)
3320
- else:
3321
- streamer = None
3322
-
3323
- def _run_generation():
3324
- try:
3325
- with torch.no_grad():
3326
- generation_state["outputs"] = self.generate(
3327
- **inputs,
3328
- stopping_criteria=stopping_criteria,
3329
- streamer=streamer,
3330
- **call_kwargs,
3331
- )
3332
- except Exception as exc:
3333
- generation_state["exception"] = exc
3334
-
3335
- worker = threading.Thread(target=_run_generation, daemon=True)
3336
- worker.start()
3337
-
3338
- stop_conversation_after_turn = False
3339
- while worker.is_alive():
3340
- try:
3341
- control_query = new_queries.get(timeout=0.1)
3342
- except queue.Empty:
3343
- continue
3344
-
3345
- if not isinstance(control_query, dict):
3346
- continue
3347
-
3348
- if control_query.get("cancel_current_generate") or control_query.get("stop_generation"):
3349
- cancel_event.set()
3350
- stop_conversation_after_turn = stop_conversation_after_turn or control_query.get("stop_offline_generate", False)
3351
- continue
3352
-
3353
- if control_query.get("stop_offline_generate"):
3354
- cancel_event.set()
3355
- stop_conversation_after_turn = True
3356
- continue
3357
-
3358
- buffered_queries.append(control_query)
3359
-
3360
- worker.join()
3361
- was_cancelled = cancel_event.is_set()
3362
-
3363
- if "exception" in generation_state:
3364
- raise generation_state["exception"]
3365
-
3366
- if stream_output and streamer is not None:
3367
- text = "".join(streamer.collected_chunks)
3368
- else:
3369
- outputs = generation_state["outputs"]
3370
- new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
3371
- text = processor.decode(new_tokens, skip_special_tokens=True)
3372
- output_text_queue.put(text)
3373
-
3374
- if query.get("persist_session", True) and (not was_cancelled or query.get("persist_cancelled_turn", False)):
3375
- session_messages = self._offline_finalize_session_messages(working_messages, text)
3376
-
3377
- output_text_queue.put("<|round_end|>")
3378
-
3379
- if stop_conversation_after_turn:
3380
- break
3381
- except Exception as exc:
3382
- output_text_queue.put(f"[ERROR] {exc}")
3383
- output_text_queue.put("<|round_end|>")
3384
-
3385
 
3386
  __all__ = [
3387
  "MossVLVisionModel",
 
14
  # limitations under the License.
15
  """PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
16
 
 
17
  from dataclasses import dataclass
18
+ from typing import Any, Callable, Optional, Union, Tuple, List
 
 
19
 
20
  import torch
21
  import torch.nn as nn
22
  import torch.nn.functional as F
23
 
24
+ from transformers import initialization as init
25
+
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.generation import GenerationMixin
 
 
29
  from transformers.integrations import use_kernel_forward_from_hub
30
  from transformers.masking_utils import create_causal_mask
31
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 
36
  from transformers.processing_utils import Unpack
37
  from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
38
  from transformers.utils.deprecation import deprecate_kwarg
39
+ from transformers.utils.generic import is_flash_attention_requested
40
+ from transformers.utils.output_capturing import OutputRecorder
41
 
42
  from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionConfig
43
 
 
45
 
46
  logger = logging.get_logger(__name__)
47
 
48
 
49
  @dataclass
50
  class MossVLModelOutputWithPast(ModelOutput):
 
144
 
145
 
146
  class MossVLVisionRotaryEmbedding(nn.Module):
147
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
148
 
149
  def __init__(self, dim: int, theta: float = 10000.0) -> None:
150
  super().__init__()
151
+ # Keep dim / theta so that `_init_weights` can rebuild `inv_freq` after
152
+ # from_pretrained materializes the module (it is a non-persistent buffer
153
+ # and therefore never populated by the checkpoint).
154
+ self.dim = dim
155
+ self.theta = theta
156
+ inv_freq = self.compute_inv_freq()
157
  self.register_buffer("inv_freq", inv_freq, persistent=False)
158
 
159
+ def compute_inv_freq(self) -> torch.Tensor:
160
+ return 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
161
+
162
  def forward(self, seqlen: int) -> torch.Tensor:
163
  seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
164
  freqs = torch.outer(seq, self.inv_freq)
 
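As a quick sanity check of the frequency table defined above (a standalone sketch with toy sizes, not the model's real dimensions), `compute_inv_freq` is the standard RoPE inverse-frequency ladder and `forward` just takes its outer product with the position index:

import torch

dim, theta, seqlen = 8, 10000.0, 4  # toy values for illustration only
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq)
print(inv_freq.shape, freqs.shape)  # torch.Size([4]) torch.Size([4, 4])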
187
  self.act_fn = nn.GELU()
188
  self.linear_fc2 = nn.Linear(self.input_hidden_size, config.out_hidden_size)
189
 
190
+ def forward(
191
+ self,
192
+ last_hidden_state: torch.Tensor,
193
+ deepstack_features: Optional[List[torch.Tensor]] = None,
194
+ ) -> torch.Tensor:
195
  # 1. Collect all features: [last_hidden_state, deepstack_1, deepstack_2, ...]
196
  # self.norms[0] corresponds to last_hidden_state
197
  # self.norms[1:] corresponds to deepstack_features
198
+ if deepstack_features is None:
199
+ deepstack_features = []
200
  all_inputs = [last_hidden_state] + deepstack_features
201
 
202
  # 2. Apply Norm independently
 
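A minimal sketch of the norm/feature pairing described in the comments above, assuming `self.norms` holds one norm per input; the concatenation at the end is illustrative, not necessarily the module's exact merge step:

import torch
import torch.nn as nn

hidden = 16
norms = nn.ModuleList([nn.LayerNorm(hidden) for _ in range(3)])  # norms[0] -> last_hidden_state
last_hidden_state = torch.randn(5, hidden)
deepstack_features = [torch.randn(5, hidden), torch.randn(5, hidden)]
all_inputs = [last_hidden_state] + deepstack_features
normed = [norm(x) for norm, x in zip(norms, all_inputs)]
print(torch.cat(normed, dim=-1).shape)  # torch.Size([5, 48])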
305
  key_states = key_states.transpose(0, 1).unsqueeze(0)
306
  value_states = value_states.transpose(0, 1).unsqueeze(0)
307
 
308
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
309
+ self.config._attn_implementation, eager_attention_forward
310
+ )
311
 
312
+ if is_flash_attention_requested(self.config):
313
  max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
314
  attn_output, _ = attention_interface(
315
  self,
 
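The selection logic the new lines depend on can be pictured as a registry lookup with an eager fallback. This standalone sketch mirrors (but does not call) `ALL_ATTENTION_FUNCTIONS.get_interface` and `is_flash_attention_requested`, whose exact behavior is taken on trust from transformers 5.x:

from typing import Callable, Dict

def pick_attention(registry: Dict[str, Callable], attn_implementation: str, eager_fn: Callable) -> Callable:
    # Fall back to the eager implementation when no kernel is registered under that name.
    return registry.get(attn_implementation, eager_fn)

def eager_attention(*args, **kwargs):
    return "eager"  # stand-in for eager_attention_forward

print(pick_attention({"sdpa": lambda *a, **k: "sdpa"}, "flash_attention_2", eager_attention)())  # eager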
388
 
389
  def __init__(self, config: MossVLTextConfig, device=None):
390
  super().__init__()
 
 
 
 
 
391
  self.max_seq_len_cached = config.max_position_embeddings
392
  self.original_max_seq_len = config.max_position_embeddings
393
 
394
  self.config = config
395
+ rope_parameters = getattr(config, "rope_parameters", None)
396
+ if rope_parameters is None:
397
+ rope_parameters = getattr(config, "rope_scaling", None) or {"rope_type": "default"}
398
 
399
+ self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
400
+ rope_init_fn: Callable = self.compute_default_rope_parameters
401
+ if self.rope_type != "default":
402
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
403
+
404
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
405
  self.register_buffer("inv_freq", inv_freq, persistent=False)
406
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
407
 
408
+ self.mrope_section = rope_parameters.get("mrope_section", [24, 20, 20])
409
 
410
+ @staticmethod
411
+ def compute_default_rope_parameters(
412
+ config: Optional[MossVLTextConfig] = None,
413
+ device: Optional[torch.device] = None,
414
+ seq_len: Optional[int] = None,
415
+ ) -> tuple[torch.Tensor, float]:
416
+ rope_parameters = getattr(config, "rope_parameters", None) or {}
417
+ base = rope_parameters.get("rope_theta", getattr(config, "rope_theta", 10000.0))
418
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
419
+ partial_rotary_factor = rope_parameters.get(
420
+ "partial_rotary_factor", getattr(config, "partial_rotary_factor", 1.0)
421
+ )
422
+ dim = int(head_dim * partial_rotary_factor)
423
+
424
+ attention_factor = 1.0
425
+ inv_freq = 1.0 / (
426
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
427
+ )
428
+ return inv_freq, attention_factor
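Worked example of the default branch above with toy numbers (head_dim = 8, rope_theta = 10000, partial_rotary_factor = 1.0, none of which are the checkpoint's real values):

import torch

base, head_dim, partial_rotary_factor = 10000.0, 8, 1.0
dim = int(head_dim * partial_rotary_factor)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])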
429
 
430
  def apply_interleaved_mrope(self, freqs, mrope_section):
431
  """Apply interleaved MRoPE to 3D rotary embeddings.
 
447
  @torch.no_grad()
448
  @dynamic_rope_update
449
  def forward(self, x, position_ids):
 
450
  if position_ids.ndim == 2:
451
  position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
452
 
 
547
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
548
 
549
  if past_key_values is not None:
550
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
551
 
552
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
553
+ self.config._attn_implementation, eager_attention_forward
554
+ )
555
 
556
  attn_output, attn_weights = attention_interface(
557
  self,
 
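A standalone sketch of the `Cache.update` contract the new call sites use (toy shapes; `DynamicCache` is already imported at the top of this file): prefill returns the states unchanged, and each decode step returns the concatenation of cached and new states for that layer.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 2, 4, 8)    # (batch, num_heads, seq_len, head_dim)
value = torch.randn(1, 2, 4, 8)
k, v = cache.update(key, value, layer_idx=0)                          # prefill: seq_len 4
k, v = cache.update(key[:, :, -1:], value[:, :, -1:], layer_idx=0)    # decode: seq_len 5
print(k.shape)  # torch.Size([1, 2, 5, 8])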
600
  attention_mask: Optional[torch.Tensor] = None,
601
  past_key_values: Optional[Cache] = None,
602
  use_cache: bool = None,
603
+ cache_position: Optional[torch.LongTensor] = None,
604
  query_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
605
  vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
606
  **kwargs,
 
634
  if past_key_values is not None:
635
  # if we have a new image + new tokens, we only computed key_states on that new image
636
  # we still update the cross key states, past_image, new_image. And use it!
637
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
 
638
 
639
  elif cache_position[0] != 0:
640
  key_states, value_states = (
 
646
  "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
647
  )
648
 
649
+ if is_flash_attention_requested(self.config):
650
+ # Cross attention still relies on an explicit dense mask, so fall back to SDPA here.
651
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
652
+ else:
653
+ attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
654
+ self.config._attn_implementation, eager_attention_forward
655
+ )
656
 
657
  attn_output, attn_weights = attention_interface(
658
  self,
 
713
  use_cache: Optional[bool] = False,
714
  cache_position: Optional[torch.LongTensor] = None,
715
  vision_position_ids: Optional[torch.LongTensor] = None,
 
716
  vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
717
+ output_attentions: bool = False,
718
  **kwargs: Unpack[TransformersKwargs],
719
+ ) -> tuple[torch.Tensor, ...]:
720
  # Self Attention
721
  residual = hidden_states
722
  hidden_states = self.input_layernorm(hidden_states)
723
+ hidden_states, attn_weights = self.self_attn(
724
  hidden_states=hidden_states,
725
  attention_mask=attention_mask,
726
  past_key_values=past_key_values,
 
735
  hidden_states = self.post_attention_layernorm(hidden_states)
736
  hidden_states = self.mlp(hidden_states)
737
  hidden_states = residual + hidden_states
738
+
739
+ outputs = (hidden_states,)
740
+ if output_attentions:
741
+ outputs += (attn_weights,)
742
+ return outputs
743
 
744
 
745
  class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
 
775
  use_cache: Optional[bool] = False,
776
  cache_position: Optional[torch.LongTensor] = None,
777
  vision_position_ids: Optional[torch.LongTensor] = None,
 
778
  vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
779
+ output_attentions: bool = False,
780
  **kwargs: Unpack[TransformersKwargs],
781
+ ) -> tuple[torch.Tensor, ...]:
782
  # Cross Attention
783
  residual = hidden_states
784
  hidden_states = self.input_layernorm(hidden_states)
785
 
786
+ hidden_states, attn_weights = self.cross_attn(
787
  hidden_states=hidden_states,
788
  cross_attention_states=cross_attention_states,
789
  attention_mask=cross_attention_mask,
790
  past_key_values=past_key_values,
791
  use_cache=use_cache,
792
+ cache_position=cache_position,
793
  query_position_embeddings=position_embeddings,
794
  vision_position_embeddings=vision_position_embeddings,
795
  )
 
806
  hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
807
 
808
  hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
809
+
810
+ outputs = (hidden_states,)
811
+ if output_attentions:
812
+ outputs += (attn_weights,)
813
+ return outputs
814
 
815
 
816
 
 
836
 
837
  def _init_weights(self, module):
838
  """Initialize the weights.
 
 
 
839
  """
840
+ super()._init_weights(module)
841
+ if isinstance(module, MossVLVisionRotaryEmbedding):
842
+ init.copy_(module.inv_freq, module.compute_inv_freq())
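Why the override matters, shown with a standalone toy module: buffers registered with `persistent=False` are excluded from the state_dict, so nothing in the checkpoint restores `inv_freq` and it must be recomputed after loading (the `init.copy_` helper above is assumed to perform an in-place copy into the buffer):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("inv_freq", torch.ones(4), persistent=False)

print("inv_freq" in Toy().state_dict())  # False -> must be rebuilt after from_pretrained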
843
 
844
 
845
 
 
915
 
916
  def fast_pos_embed_interpolate(self, grid_thw):
917
  grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
918
+ device = self.pos_embed.weight.device
919
+ dtype = self.pos_embed.weight.dtype
920
 
921
+ idx_parts = [[] for _ in range(4)]
922
+ weight_parts = [[] for _ in range(4)]
923
 
924
  for t, h, w in zip(grid_ts, grid_hs, grid_ws):
925
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
926
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
927
 
928
  h_idxs_floor = h_idxs.int()
929
  w_idxs_floor = w_idxs.int()
 
951
  ]
952
 
953
  for i in range(4):
954
+ idx_parts[i].append(indices[i])
955
+ weight_parts[i].append(weights[i])
956
 
957
+ idx_tensor = torch.stack([torch.cat(parts) for parts in idx_parts]).to(dtype=torch.long)
958
+ weight_tensor = torch.stack([torch.cat(parts) for parts in weight_parts]).to(dtype=dtype)
 
 
959
  pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
960
  patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
961
 
 
1084
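Toy illustration of the per-axis bilinear setup feeding the gather above (names are local to this sketch, not module attributes): each target coordinate gets a floor/ceil source index pair whose weights sum to one.

import torch

num_grid_per_side, h = 4, 6
h_idxs = torch.linspace(0, num_grid_per_side - 1, h)
floor = h_idxs.int()
ceil = torch.clamp(floor + 1, max=num_grid_per_side - 1)
frac = h_idxs - floor
print(floor.tolist(), ceil.tolist())   # [0, 0, 1, 1, 2, 3] [1, 1, 2, 2, 3, 3]
print(((1 - frac) + frac).tolist())    # weights for the two neighbours sum to 1 (up to float rounding)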
  vision_position_ids: Optional[torch.LongTensor] = None,
1085
  use_cache: Optional[bool] = None,
1086
  cache_position: Optional[torch.LongTensor] = None,
1087
+ output_attentions: Optional[bool] = None,
1088
+ output_hidden_states: Optional[bool] = None,
1089
+ return_dict: Optional[bool] = None,
1090
  **kwargs: Unpack[FlashAttentionKwargs],
1091
  ) -> Union[tuple, BaseModelOutputWithPast]:
1092
  """
 
1099
  Attention mask for cross-attention between text and vision. Shape: `(batch_size, 1, text_seq_len, vision_seq_len)`.
1100
  vision_position_ids (`torch.LongTensor`, *optional*):
1101
  Position IDs for vision tokens used in cross-attention. Shape: `(batch_size, vision_seq_len)`.
1102
+ cache_position (`torch.LongTensor`, *optional*):
1103
+ Absolute cache positions for the current text tokens during incremental decoding.
1104
  """
1105
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1106
+ output_hidden_states = (
1107
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1108
+ )
1109
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1110
+
1111
  if (input_ids is None) ^ (inputs_embeds is not None):
1112
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1113
 
 
1129
 
1130
  attention_mask = create_causal_mask(
1131
  config=self.config,
1132
+ inputs_embeds=inputs_embeds,
1133
  attention_mask=attention_mask,
1134
  cache_position=cache_position,
1135
  past_key_values=past_key_values,
 
1144
  # Compute vision position embeddings (for cross-attention key/value) if needed
1145
  vision_position_embeddings = None
1146
 
 
 
 
 
1147
  if cross_attention_states is not None:
1148
  if vision_position_ids is not None:
1149
  vision_position_embeddings = self.rotary_emb(cross_attention_states, vision_position_ids)
1150
 
1151
+ all_hidden_states = () if output_hidden_states else None
1152
+ all_attentions = () if output_attentions else None
1153
+
1154
+ if output_hidden_states:
1155
+ all_hidden_states += (hidden_states,)
1156
 
1157
  for idx, decoder_layer in enumerate(self.layers):
1158
  # For text-only path we should skip cross attention layers.
 
1177
  cross_attention_states=cross_attention_states,
1178
  cross_attention_mask=cross_attention_mask,
1179
  vision_position_ids=vision_position_ids,
 
1180
  vision_position_embeddings=vision_position_embeddings,
1181
+ output_attentions=output_attentions,
1182
  **kwargs,
1183
  )
1184
+ hidden_states = layer_outputs[0]
1185
+
1186
+ if output_attentions:
1187
+ all_attentions += (layer_outputs[1],)
1188
+
1189
+ if output_hidden_states:
1190
+ all_hidden_states += (hidden_states,)
1191
 
1192
  hidden_states = self.norm(hidden_states)
1193
+ if output_hidden_states:
1194
+ all_hidden_states = all_hidden_states[:-1] + (hidden_states,)
1195
+
1196
+ if not return_dict:
1197
+ outputs = (hidden_states, past_key_values)
1198
+ if output_hidden_states:
1199
+ outputs += (all_hidden_states,)
1200
+ if output_attentions:
1201
+ outputs += (all_attentions,)
1202
+ return outputs
1203
 
1204
  return BaseModelOutputWithPast(
1205
  last_hidden_state=hidden_states,
1206
  past_key_values=past_key_values,
1207
+ hidden_states=all_hidden_states,
1208
+ attentions=all_attentions,
1209
  )
1210
 
1211
 
 
1224
  super().__init__(config)
1225
  self.visual = MossVLVisionModel._from_config(config.vision_config)
1226
  self.language_model = MossVLTextModel._from_config(config.text_config)
 
 
1227
 
1228
  # Learnable Separator Token: inserted after each image/frame's vision tokens
1229
  # Initialized from LLM's separator_token_init_id embedding
 
1532
  continue
1533
 
1534
  # Collect repetition counts for all frames in this sample
1535
+ repeats_parts = []
1536
  for media in medias:
1537
  num_frames = media.get('num_frames', 1)
1538
  length = media['length']
 
1547
 
1548
  # In convert_packed_to_batch we enforce strictly regular frames
1549
  # so we can assume all frames have the same number of tokens
1550
+ repeats_parts.append(
1551
+ torch.full(
1552
+ (num_frames,),
1553
+ tokens_per_frame_with_sep,
1554
+ dtype=torch.long,
1555
+ device=cross_attention_mask.device,
1556
+ )
1557
+ )
1558
 
1559
+ num_valid_frames = sum(part.numel() for part in repeats_parts)
1560
  if num_valid_frames == 0:
1561
  continue
1562
 
1563
  # If cross_attention_mask has more frames (e.g. padded), slice it
1564
  # If it has fewer (shouldn't happen), slice repeats
1565
  valid_mask_frames = min(num_valid_frames, cross_attention_mask.shape[-1])
1566
+ repeats_tensor = torch.cat(repeats_parts)
1567
  if valid_mask_frames < num_valid_frames:
1568
+ repeats_tensor = repeats_tensor[:valid_mask_frames]
1569
 
1570
  # Extract valid columns for this sample
1571
  # (1, text_len, valid_mask_frames)
1572
  source_mask = cross_attention_mask[i, :, :, :valid_mask_frames]
1573
 
 
 
 
1574
  # Expand using repeat_interleave
1575
  # output shape: (1, text_len, sum(repeats))
1576
  expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1)
 
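Worked toy example of the expansion above: with two frames and three tokens per frame (including the separator), every frame-level mask column is stretched into three token-level columns.

import torch

source_mask = torch.tensor([[[1.0, 0.0]]])   # (1, text_len=1, num_frames=2)
repeats_tensor = torch.tensor([3, 3])        # tokens_per_frame_with_sep for each frame
expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1)
print(expanded_mask)  # tensor([[[1., 1., 1., 0., 0., 0.]]])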
1589
  self,
1590
  input_ids: torch.Tensor,
1591
  attention_mask: Optional[torch.Tensor] = None,
1592
+ past_key_values: Optional[Cache] = None,
1593
+ rope_deltas: Optional[torch.LongTensor] = None,
1594
  ) -> torch.Tensor:
1595
  """
1596
  Compute 3D position IDs for text tokens with special handling for image tokens.
 
1605
  Args:
1606
  input_ids: (batch_size, seq_len)
1607
  attention_mask: (batch_size, seq_len), optional
1608
+ past_key_values: cache object used to infer decode offset from the current text cache length
1609
 
1610
  Returns:
1611
  position_ids: (3, batch_size, seq_len)
 
1614
  device = input_ids.device
1615
  image_token_id = self.config.image_token_id
1616
 
1617
+ # Decode stage: always advance positions from the current text cache length.
1618
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1619
+ if past_seen_tokens > 0:
 
1620
  position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
1621
+ position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
1622
+ position_ids = position_ids + past_seen_tokens
1623
+
1624
+ if rope_deltas is not None:
1625
+ position_ids = position_ids + rope_deltas.unsqueeze(1)
1626
+
1627
+ return position_ids.unsqueeze(0).expand(3, -1, -1)
 
 
 
 
 
 
 
1628
 
1629
  # Prefill stage: compute full position_ids with image token awareness
1630
  # Vectorized implementation
 
1703
  rope_deltas: (batch_size,) - position offset due to vision tokens
1704
  """
1705
  batch_size, max_vision_seq_len, _ = cross_attention_states.shape
1706
+ device = cross_attention_states.device
1707
  image_token_id = self.config.image_token_id
1708
  merge_size = self.visual.spatial_merge_size
1709
 
 
1711
  # We need to flatten the nested vision_token_info structure to align with image tokens in input_ids
1712
 
1713
  # Find all image tokens in text: (num_occurrences, 2) -> [batch_idx, seq_idx]
1714
+ image_token_indices = (input_ids == image_token_id).nonzero()
1715
 
1716
  # Flatten vision_token_info to parallel lists
1717
  # We assume the order of medias in vision_token_info matches the appearance of image tokens in input_ids
1718
+ flat_eff_h_parts = []
1719
+ flat_eff_w_parts = []
1720
+ flat_vis_start_parts = []
1721
+
 
1722
  # Processing metadata on CPU (fast enough for typical batch sizes)
1723
  for b_idx, info in enumerate(vision_token_info):
1724
  medias = info.get('medias', [])
 
1729
  start = media['start']
1730
  tok_per_frame = media['vision_tokens_per_frame']
1731
  stride = tok_per_frame + 1 # +1 for separator
1732
+
1733
+ frame_offsets = start + torch.arange(num_frames, device=device, dtype=torch.long) * stride
1734
+ flat_vis_start_parts.append(frame_offsets)
1735
+ flat_eff_h_parts.append(torch.full((num_frames,), eh, device=device, dtype=torch.long))
1736
+ flat_eff_w_parts.append(torch.full((num_frames,), ew, device=device, dtype=torch.long))
 
 
1737
 
1738
  # Pre-allocate output
1739
  vision_pos_ids = torch.zeros(
 
1743
  )
1744
 
1745
  # Handle case where no image tokens or info
1746
+ if len(flat_eff_h_parts) == 0 or len(image_token_indices) == 0:
1747
  rope_deltas = position_ids.max(dim=0).values.max(dim=-1).values + 1 - input_ids.shape[1]
1748
  return vision_pos_ids, position_ids, rope_deltas
1749
 
1750
+ flat_eff_h = torch.cat(flat_eff_h_parts)
1751
+ flat_eff_w = torch.cat(flat_eff_w_parts)
1752
+ flat_vis_starts = torch.cat(flat_vis_start_parts)
1753
+
1754
  # Align lengths (handle truncation if text has fewer tokens or vice versa)
1755
+ num_matches = min(flat_eff_h.shape[0], image_token_indices.shape[0])
1756
+ flat_eff_h = flat_eff_h[:num_matches]
1757
+ flat_eff_w = flat_eff_w[:num_matches]
1758
+ flat_vis_starts = flat_vis_starts[:num_matches]
 
 
1759
 
1760
  # Get corresponding text positions
1761
  target_indices = image_token_indices[:num_matches]
 
1921
  )
1922
  return vision_embeds, vision_token_info
1923
 
1924
 
1925
 
1926
  @auto_docstring
 
1936
  media_nums_per_sample: Optional[List[int]] = None,
1937
  vision_position_ids: Optional[torch.LongTensor] = None,
1938
  cross_attention_mask: Optional[torch.Tensor] = None,
1939
+ vision_token_info: Optional[List[dict]] = None,
1940
+ rope_deltas: Optional[torch.LongTensor] = None,
1941
+ output_attentions: Optional[bool] = None,
1942
+ output_hidden_states: Optional[bool] = None,
1943
+ return_dict: Optional[bool] = None,
1944
  **kwargs: Unpack[TransformersKwargs],
1945
  ) -> Union[tuple, BaseModelOutputWithPast]:
1946
  """
 
1957
  cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
1958
  Attention mask for cross-attention between text and vision. Controls which vision tokens each text
1959
  token can attend to, enforcing causal visibility for video frames.
1960
+ vision_token_info (`List[dict]`, *optional*):
1961
+ Cached metadata describing how packed vision tokens were regrouped per sample. Reused in decode
1962
+ to expand frame-level cross-attention masks to token-level masks without recomputing vision features.
1963
+ rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1964
+ Cached offsets between text sequence length and multimodal RoPE positions. Reused in decode to
1965
+ reconstruct text position ids from the current cache length.
1966
  """
1967
+ cache_position = kwargs.pop("cache_position", None)
1968
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1969
+ output_hidden_states = (
1970
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1971
+ )
1972
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1973
+
1974
  if (input_ids is None) ^ (inputs_embeds is not None):
1975
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1976
 
 
1979
 
1980
  # Process vision features (images and videos are already merged by processor)
1981
  cross_attention_states = None
1982
+
 
1983
  if pixel_values is not None:
1984
  # Determine batch size
1985
  batch_size = inputs_embeds.shape[0]
 
1994
 
1995
  # Process all vision inputs together through VIT
1996
  # pixel_values and grid_thw are already ordered by appearance in text
1997
+ vision_embeds, vision_token_info = self.get_vision_features(
1998
+ pixel_values, grid_thw, media_nums_per_sample
 
 
 
1999
  )
2000
 
2001
  # vision_embeds: [batch_size, max_seq_len, hidden_size]
2002
  cross_attention_states = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
 
 
 
 
 
 
 
2003
 
2004
  # Generate 3D position IDs for text if not provided
2005
  if position_ids is None:
 
2008
  position_ids = self.compute_position_ids(
2009
  input_ids=input_ids,
2010
  attention_mask=attention_mask,
2011
+ past_key_values=past_key_values,
2012
+ rope_deltas=rope_deltas,
2013
  )
2014
 
2015
  # Compute cross_attention_mask, vision_position_ids, and full_text_row_masked_out_mask
 
2033
  (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
2034
  )
2035
  cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask
 
 
2036
 
2037
  if vision_position_ids is None and cross_attention_states is not None and input_ids is not None:
2038
  vision_position_ids, position_ids, rope_deltas = self.compute_vision_position_ids(
 
2042
  cross_attention_states,
2043
  attention_mask
2044
  )
 
 
 
 
 
 
 
 
2045
 
2046
  outputs = self.language_model(
2047
  input_ids=None,
 
2054
  cross_attention_mask=cross_attention_mask,
2055
  vision_position_ids=vision_position_ids,
2056
  full_text_row_masked_out_mask=full_text_row_masked_out_mask,
2057
+ output_attentions=output_attentions,
2058
+ output_hidden_states=output_hidden_states,
2059
+ return_dict=return_dict,
2060
  **kwargs,
2061
  )
2062
 
2063
+ if not return_dict:
2064
+ last_hidden_state = outputs[0]
2065
+ model_outputs = (
2066
+ last_hidden_state,
2067
+ outputs[1] if len(outputs) > 1 else past_key_values,
2068
+ )
2069
+ if output_hidden_states:
2070
+ model_outputs += (outputs[2],)
2071
+ if output_attentions:
2072
+ attn_idx = 3 if output_hidden_states else 2
2073
+ model_outputs += (outputs[attn_idx],)
2074
+ model_outputs += (vision_token_info, rope_deltas)
2075
+ return model_outputs
2076
+
2077
  return MossVLModelOutputWithPast(
2078
  last_hidden_state=outputs.last_hidden_state,
2079
  past_key_values=outputs.past_key_values,
2080
  hidden_states=outputs.hidden_states,
2081
  attentions=outputs.attentions,
2082
+ vision_token_info=vision_token_info,
2083
+ rope_deltas=rope_deltas,
2084
  )
2085
 
2086
 
 
2102
  super().__init__(config)
2103
  self.model = MossVLModel(config)
2104
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
 
2105
 
2106
  self.post_init()
2107
 
 
2159
  media_nums_per_sample: Optional[List[int]] = None,
2160
  vision_position_ids: Optional[torch.LongTensor] = None,
2161
  cross_attention_mask: Optional[torch.Tensor] = None,
2162
+ vision_token_info: Optional[List[dict]] = None,
2163
+ rope_deltas: Optional[torch.LongTensor] = None,
2164
  logits_to_keep: Union[int, torch.Tensor] = 0,
2165
+ output_attentions: Optional[bool] = None,
2166
+ output_hidden_states: Optional[bool] = None,
2167
+ return_dict: Optional[bool] = None,
2168
  **kwargs: Unpack[TransformersKwargs],
2169
  ) -> Union[tuple, CausalLMOutputWithPast]:
2170
  """
 
2181
  cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
2182
  Attention mask for cross-attention between text and vision. Controls which vision tokens each text
2183
  token can attend to, enforcing causal visibility for video frames.
2184
+ vision_token_info (`List[dict]`, *optional*):
2185
+ Cached metadata describing how packed vision tokens were regrouped per sample. Reused across decode
2186
+ steps to expand cross-attention masks without re-running the vision encoder.
2187
+ rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2188
+ Cached multimodal RoPE offsets returned by the base model during prefill and reused during decode.
2189
  """
2190
+ cache_position = kwargs.pop("cache_position", None)
2191
  outputs = self.model(
2192
  input_ids=input_ids,
2193
  pixel_values=pixel_values,
 
2199
  cross_attention_mask=cross_attention_mask,
2200
  past_key_values=past_key_values,
2201
  inputs_embeds=inputs_embeds,
2202
+ vision_token_info=vision_token_info,
2203
+ rope_deltas=rope_deltas,
2204
+ output_attentions=output_attentions,
2205
+ output_hidden_states=output_hidden_states,
2206
+ return_dict=return_dict,
2207
  cache_position=cache_position,
 
2208
  **kwargs,
2209
  )
2210
 
2211
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2212
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
2213
 
2214
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
2215
  logits = self.lm_head(hidden_states[:, slice_indices, :])
 
2218
  if labels is not None:
2219
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
2220
 
2221
+ if not return_dict:
2222
+ output = (logits,)
2223
+ output += outputs[1:]
2224
+ return ((loss,) + output) if loss is not None else output
2225
+
2226
  return MossVLCausalLMOutputWithPast(
2227
  loss=loss,
2228
  logits=logits,
 
2239
  past_key_values=None,
2240
  attention_mask=None,
2241
  inputs_embeds=None,
 
2242
  position_ids=None,
2243
  use_cache=True,
2244
  pixel_values=None,
2245
  grid_thw=None,
2246
  media_nums_per_sample=None, # One video counts as one media item.
2247
  vision_position_ids=None,
2248
+ vision_token_info=None,
2249
+ rope_deltas=None,
2250
  cross_attention_mask=None,
 
2251
  **kwargs,
2252
  ):
2253
  """
 
2260
  Args:
2261
  media_nums_per_sample: One video counts as one media item (regardless of frame count)
2262
  """
2263
+ kwargs.pop("cache_position", None)
2264
  model_inputs = super().prepare_inputs_for_generation(
2265
  input_ids,
2266
  past_key_values=past_key_values,
2267
  attention_mask=attention_mask,
2268
  inputs_embeds=inputs_embeds,
 
2269
  position_ids=position_ids,
2270
  pixel_values=pixel_values,
2271
  grid_thw=grid_thw,
 
2274
  **kwargs,
2275
  )
2276
 
2277
+ model_input = model_inputs.get("input_ids")
2278
+ if model_input is None:
2279
+ model_input = model_inputs.get("inputs_embeds")
2280
+ current_length = model_input.shape[1]
2281
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
2282
 
2283
+ # Let the model recompute multimodal position ids from the current cache length.
 
2284
  model_inputs["position_ids"] = None
2285
+ model_inputs["vision_token_info"] = vision_token_info
2286
+ model_inputs["rope_deltas"] = rope_deltas
2287
 
2288
  # Handle cross attention mask
2289
  if cross_attention_mask is not None:
2290
+ # Slice along the text dimension (dim=2) to keep only the current step's text tokens.
2291
+ # Shape: [batch, 1, text_len, vision_len] -> [batch, 1, current_len, vision_len]
2292
+ cross_attention_mask = cross_attention_mask[:, :, -current_length:, :]
2293
  model_inputs["cross_attention_mask"] = cross_attention_mask
2294
 
2295
+ # Vision inputs are only needed during the prefill stage.
2296
  # In decode stage, vision features are retrieved from cross attention cache
2297
+ if past_seen_tokens > 0:
2298
  model_inputs["pixel_values"] = None
2299
  model_inputs["grid_thw"] = None
2300
  model_inputs["media_nums_per_sample"] = None
 
2303
  else:
2304
  # In prefill stage, include all vision-related inputs
2305
  model_inputs["vision_position_ids"] = vision_position_ids
 
2306
 
2307
  return model_inputs
2308
 
 
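Taken together, the prefill/decode split above is what a plain `generate` call exercises: vision tensors are consumed once at prefill, and `vision_token_info` / `rope_deltas` are carried across steps by `_update_model_kwargs_for_generation`. A hedged usage sketch follows; the checkpoint id, auto class mapping, prompt format, and blank test image are placeholders, not guaranteed by this repository.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "OpenMOSS/MOSS-VL"  # placeholder id
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, trust_remote_code=True)

image = Image.new("RGB", (448, 448))  # stand-in for a real image
inputs = processor(text=["Describe the image."], images=[image], return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True))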
2323
  **kwargs,
2324
  )
2325
 
 
 
2326
  if cross_attention_mask_prev is not None:
2327
+ model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
2328
+
2329
+ if getattr(outputs, "vision_token_info", None) is not None:
2330
+ model_kwargs["vision_token_info"] = outputs.vision_token_info
2331
+ if getattr(outputs, "rope_deltas", None) is not None:
2332
+ model_kwargs["rope_deltas"] = outputs.rope_deltas
2333
 
2334
  return model_kwargs
2335
 
 
2336
 
2337
  __all__ = [
2338
  "MossVLVisionModel",