Update code for transformers 5.5.4
#4
by sjzhou - opened
- modeling_moss_vl.py +268 -1317
modeling_moss_vl.py
CHANGED
|
@@ -14,21 +14,18 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
"""PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
|
| 16 |
|
| 17 |
-
import copy
|
| 18 |
from dataclasses import dataclass
|
| 19 |
-
import
|
| 20 |
-
import threading
|
| 21 |
-
from typing import Any, Callable, Dict, Optional, Union, Tuple, List
|
| 22 |
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
import torch.nn.functional as F
|
| 26 |
|
|
|
|
|
|
|
| 27 |
from transformers.activations import ACT2FN
|
| 28 |
from transformers.cache_utils import Cache, DynamicCache
|
| 29 |
from transformers.generation import GenerationMixin
|
| 30 |
-
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
|
| 31 |
-
from transformers.generation.streamers import TextIteratorStreamer
|
| 32 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 33 |
from transformers.masking_utils import create_causal_mask
|
| 34 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|
@@ -39,7 +36,8 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
| 39 |
from transformers.processing_utils import Unpack
|
| 40 |
from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
|
| 41 |
from transformers.utils.deprecation import deprecate_kwarg
|
| 42 |
-
from transformers.utils.generic import
|
|
|
|
| 43 |
|
| 44 |
from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionConfig
|
| 45 |
|
|
@@ -47,58 +45,6 @@ from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionC
|
|
| 47 |
|
| 48 |
logger = logging.get_logger(__name__)
|
| 49 |
|
| 50 |
-
_OFFLINE_SYSTEM_PROMPTS = {
|
| 51 |
-
"no_thinking": {
|
| 52 |
-
"text_image": "You are a helpful AI assistant. Respond to the user's request based on the provided text and/or images.",
|
| 53 |
-
"video": "You are a helpful AI assistant specializing in video analysis. Respond to the user's request based on the provided video content.",
|
| 54 |
-
},
|
| 55 |
-
"deep_thinking": {
|
| 56 |
-
"text_image": "A conversation between User and Assistant. The user makes a request, and the assistant responds to it based on the provided text and/or images. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
|
| 57 |
-
"video": "A conversation between User and Assistant specializing in video analysis. The user makes a request, and the assistant responds to it based on the provided video content. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking></thinking> and <answer></answer> tags, respectively, i.e., <thinking>reasoning process here</thinking><answer>answer here</answer>.",
|
| 58 |
-
},
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
class _OfflineCancelStoppingCriteria(StoppingCriteria):
|
| 63 |
-
def __init__(self, cancel_event: threading.Event):
|
| 64 |
-
self.cancel_event = cancel_event
|
| 65 |
-
|
| 66 |
-
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
| 67 |
-
return self.cancel_event.is_set()
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class _OfflineQueueStreamer(TextIteratorStreamer):
|
| 71 |
-
def __init__(self, tokenizer, output_text_queue: "queue.Queue[str]"):
|
| 72 |
-
super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 73 |
-
self.output_text_queue = output_text_queue
|
| 74 |
-
self.collected_chunks: List[str] = []
|
| 75 |
-
|
| 76 |
-
def on_finalized_text(self, text: str, stream_end: bool = False):
|
| 77 |
-
if text:
|
| 78 |
-
self.collected_chunks.append(text)
|
| 79 |
-
self.output_text_queue.put(text)
|
| 80 |
-
super().on_finalized_text(text, stream_end=stream_end)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
_OFFLINE_THINKING_MODE_ALIASES = {
|
| 84 |
-
"no_thinking": "no_thinking",
|
| 85 |
-
"default": "no_thinking",
|
| 86 |
-
"standard": "no_thinking",
|
| 87 |
-
"deep_thinking": "deep_thinking",
|
| 88 |
-
"thinking": "deep_thinking",
|
| 89 |
-
"reasoning": "deep_thinking",
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
_OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES = {
|
| 93 |
-
"text_image": "text_image",
|
| 94 |
-
"text-image": "text_image",
|
| 95 |
-
"image_text": "text_image",
|
| 96 |
-
"image-text": "text_image",
|
| 97 |
-
"text": "text_image",
|
| 98 |
-
"image": "text_image",
|
| 99 |
-
"video": "video",
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
|
| 103 |
@dataclass
|
| 104 |
class MossVLModelOutputWithPast(ModelOutput):
|
|
@@ -198,13 +144,21 @@ class MossVLVisionPatchEmbed(nn.Module):
|
|
| 198 |
|
| 199 |
|
| 200 |
class MossVLVisionRotaryEmbedding(nn.Module):
|
| 201 |
-
inv_freq: torch.Tensor
|
| 202 |
|
| 203 |
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 204 |
super().__init__()
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 207 |
|
|
|
|
|
|
|
|
|
|
| 208 |
def forward(self, seqlen: int) -> torch.Tensor:
|
| 209 |
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 210 |
freqs = torch.outer(seq, self.inv_freq)
|
|
@@ -233,11 +187,16 @@ class MossVLVisionPatchMerger(nn.Module):
|
|
| 233 |
self.act_fn = nn.GELU()
|
| 234 |
self.linear_fc2 = nn.Linear(self.input_hidden_size, config.out_hidden_size)
|
| 235 |
|
| 236 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
# 1. Collect all features: [last_hidden_state, deepstack_1, deepstack_2, ...]
|
| 238 |
# self.norms[0] corresponds to last_hidden_state
|
| 239 |
# self.norms[1:] corresponds to deepstack_features
|
| 240 |
-
|
|
|
|
| 241 |
all_inputs = [last_hidden_state] + deepstack_features
|
| 242 |
|
| 243 |
# 2. Apply Norm independently
|
|
@@ -346,11 +305,11 @@ class MossVLVisionAttention(nn.Module):
|
|
| 346 |
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
| 347 |
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
| 348 |
|
| 349 |
-
attention_interface: Callable =
|
| 350 |
-
|
| 351 |
-
|
| 352 |
|
| 353 |
-
if self.config
|
| 354 |
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
| 355 |
attn_output, _ = attention_interface(
|
| 356 |
self,
|
|
@@ -429,26 +388,44 @@ class MossVLTextRotaryEmbedding(nn.Module):
|
|
| 429 |
|
| 430 |
def __init__(self, config: MossVLTextConfig, device=None):
|
| 431 |
super().__init__()
|
| 432 |
-
# BC: "rope_type" was originally "type"
|
| 433 |
-
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
| 434 |
-
self.rope_type = config.rope_scaling.get("rope_type", "default")
|
| 435 |
-
else:
|
| 436 |
-
self.rope_type = "default"
|
| 437 |
self.max_seq_len_cached = config.max_position_embeddings
|
| 438 |
self.original_max_seq_len = config.max_position_embeddings
|
| 439 |
|
| 440 |
self.config = config
|
| 441 |
-
|
|
|
|
|
|
|
| 442 |
|
| 443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 445 |
-
self.original_inv_freq
|
| 446 |
|
|
|
|
| 447 |
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
def apply_interleaved_mrope(self, freqs, mrope_section):
|
| 454 |
"""Apply interleaved MRoPE to 3D rotary embeddings.
|
|
@@ -470,7 +447,6 @@ class MossVLTextRotaryEmbedding(nn.Module):
|
|
| 470 |
@torch.no_grad()
|
| 471 |
@dynamic_rope_update
|
| 472 |
def forward(self, x, position_ids):
|
| 473 |
-
|
| 474 |
if position_ids.ndim == 2:
|
| 475 |
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
| 476 |
|
|
@@ -571,12 +547,11 @@ class MossVLTextSelfAttention(nn.Module):
|
|
| 571 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 572 |
|
| 573 |
if past_key_values is not None:
|
| 574 |
-
|
| 575 |
-
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 576 |
|
| 577 |
-
attention_interface: Callable =
|
| 578 |
-
|
| 579 |
-
|
| 580 |
|
| 581 |
attn_output, attn_weights = attention_interface(
|
| 582 |
self,
|
|
@@ -625,7 +600,7 @@ class MossVLTextCrossAttention(nn.Module):
|
|
| 625 |
attention_mask: Optional[torch.Tensor] = None,
|
| 626 |
past_key_values: Optional[Cache] = None,
|
| 627 |
use_cache: bool = None,
|
| 628 |
-
cache_position: Optional[torch.LongTensor] = None,
|
| 629 |
query_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 630 |
vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 631 |
**kwargs,
|
|
@@ -659,9 +634,7 @@ class MossVLTextCrossAttention(nn.Module):
|
|
| 659 |
if past_key_values is not None:
|
| 660 |
# if we have a new image + new tokens, we only computed key_states on that new image
|
| 661 |
# we still update the cross key states, past_image, new_image. And use it!
|
| 662 |
-
key_states, value_states = past_key_values.update(
|
| 663 |
-
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
| 664 |
-
)
|
| 665 |
|
| 666 |
elif cache_position[0] != 0:
|
| 667 |
key_states, value_states = (
|
|
@@ -673,13 +646,13 @@ class MossVLTextCrossAttention(nn.Module):
|
|
| 673 |
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
|
| 674 |
)
|
| 675 |
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
|
| 684 |
attn_output, attn_weights = attention_interface(
|
| 685 |
self,
|
|
@@ -740,14 +713,14 @@ class MossVLSelfAttentionDecoderLayer(GradientCheckpointingLayer):
|
|
| 740 |
use_cache: Optional[bool] = False,
|
| 741 |
cache_position: Optional[torch.LongTensor] = None,
|
| 742 |
vision_position_ids: Optional[torch.LongTensor] = None,
|
| 743 |
-
vision_cache_position: Optional[torch.LongTensor] = None,
|
| 744 |
vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
|
|
|
| 745 |
**kwargs: Unpack[TransformersKwargs],
|
| 746 |
-
) -> torch.Tensor:
|
| 747 |
# Self Attention
|
| 748 |
residual = hidden_states
|
| 749 |
hidden_states = self.input_layernorm(hidden_states)
|
| 750 |
-
hidden_states,
|
| 751 |
hidden_states=hidden_states,
|
| 752 |
attention_mask=attention_mask,
|
| 753 |
past_key_values=past_key_values,
|
|
@@ -762,8 +735,11 @@ class MossVLSelfAttentionDecoderLayer(GradientCheckpointingLayer):
|
|
| 762 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 763 |
hidden_states = self.mlp(hidden_states)
|
| 764 |
hidden_states = residual + hidden_states
|
| 765 |
-
|
| 766 |
-
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
|
| 769 |
class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
|
|
@@ -799,21 +775,21 @@ class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
|
|
| 799 |
use_cache: Optional[bool] = False,
|
| 800 |
cache_position: Optional[torch.LongTensor] = None,
|
| 801 |
vision_position_ids: Optional[torch.LongTensor] = None,
|
| 802 |
-
vision_cache_position: Optional[torch.LongTensor] = None,
|
| 803 |
vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
|
|
|
| 804 |
**kwargs: Unpack[TransformersKwargs],
|
| 805 |
-
) -> torch.Tensor:
|
| 806 |
# Cross Attention
|
| 807 |
residual = hidden_states
|
| 808 |
hidden_states = self.input_layernorm(hidden_states)
|
| 809 |
|
| 810 |
-
hidden_states,
|
| 811 |
hidden_states=hidden_states,
|
| 812 |
cross_attention_states=cross_attention_states,
|
| 813 |
attention_mask=cross_attention_mask,
|
| 814 |
past_key_values=past_key_values,
|
| 815 |
use_cache=use_cache,
|
| 816 |
-
cache_position=
|
| 817 |
query_position_embeddings=position_embeddings,
|
| 818 |
vision_position_embeddings=vision_position_embeddings,
|
| 819 |
)
|
|
@@ -830,8 +806,11 @@ class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
|
|
| 830 |
hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
|
| 831 |
|
| 832 |
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
|
| 833 |
-
|
| 834 |
-
|
|
|
|
|
|
|
|
|
|
| 835 |
|
| 836 |
|
| 837 |
|
|
@@ -857,32 +836,10 @@ class MossVLPreTrainedModel(PreTrainedModel):
|
|
| 857 |
|
| 858 |
def _init_weights(self, module):
|
| 859 |
"""Initialize the weights.
|
| 860 |
-
|
| 861 |
-
Note: For loading pretrained weights:
|
| 862 |
-
- Cross attention: can be initialized from the previous layer's self attention weights
|
| 863 |
"""
|
| 864 |
-
|
| 865 |
-
if
|
| 866 |
-
|
| 867 |
-
|
| 868 |
-
if isinstance(module, MossVLVisionPatchMerger):
|
| 869 |
-
# Initialize merger weights
|
| 870 |
-
# Input: hidden_size * (1 + num_deepstack_features) -> Output: out_hidden_size
|
| 871 |
-
# This projection handles concatenated features, so we might want specific initialization
|
| 872 |
-
module.linear_fc1.weight.data.normal_(mean=0.0, std=std)
|
| 873 |
-
module.linear_fc2.weight.data.normal_(mean=0.0, std=std)
|
| 874 |
-
if module.linear_fc1.bias is not None:
|
| 875 |
-
module.linear_fc1.bias.data.zero_()
|
| 876 |
-
if module.linear_fc2.bias is not None:
|
| 877 |
-
module.linear_fc2.bias.data.zero_()
|
| 878 |
-
|
| 879 |
-
# Initialize separate LayerNorms
|
| 880 |
-
if hasattr(module, "norms"):
|
| 881 |
-
for norm in module.norms:
|
| 882 |
-
if hasattr(norm, "weight") and norm.weight is not None:
|
| 883 |
-
norm.weight.data.fill_(1.0)
|
| 884 |
-
if hasattr(norm, "bias") and norm.bias is not None:
|
| 885 |
-
norm.bias.data.zero_()
|
| 886 |
|
| 887 |
|
| 888 |
|
|
@@ -958,13 +915,15 @@ class MossVLVisionModel(MossVLPreTrainedModel):
|
|
| 958 |
|
| 959 |
def fast_pos_embed_interpolate(self, grid_thw):
|
| 960 |
grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
|
|
|
|
|
|
|
| 961 |
|
| 962 |
-
|
| 963 |
-
|
| 964 |
|
| 965 |
for t, h, w in zip(grid_ts, grid_hs, grid_ws):
|
| 966 |
-
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
|
| 967 |
-
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
|
| 968 |
|
| 969 |
h_idxs_floor = h_idxs.int()
|
| 970 |
w_idxs_floor = w_idxs.int()
|
|
@@ -992,13 +951,11 @@ class MossVLVisionModel(MossVLPreTrainedModel):
|
|
| 992 |
]
|
| 993 |
|
| 994 |
for i in range(4):
|
| 995 |
-
|
| 996 |
-
|
| 997 |
|
| 998 |
-
idx_tensor = torch.
|
| 999 |
-
weight_tensor = torch.
|
| 1000 |
-
weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
|
| 1001 |
-
)
|
| 1002 |
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
|
| 1003 |
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
| 1004 |
|
|
@@ -1127,7 +1084,9 @@ class MossVLTextModel(MossVLPreTrainedModel):
|
|
| 1127 |
vision_position_ids: Optional[torch.LongTensor] = None,
|
| 1128 |
use_cache: Optional[bool] = None,
|
| 1129 |
cache_position: Optional[torch.LongTensor] = None,
|
| 1130 |
-
|
|
|
|
|
|
|
| 1131 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 1132 |
) -> Union[tuple, BaseModelOutputWithPast]:
|
| 1133 |
"""
|
|
@@ -1140,9 +1099,15 @@ class MossVLTextModel(MossVLPreTrainedModel):
|
|
| 1140 |
Attention mask for cross-attention between text and vision. Shape: `(batch_size, 1, text_seq_len, vision_seq_len)`.
|
| 1141 |
vision_position_ids (`torch.LongTensor`, *optional*):
|
| 1142 |
Position IDs for vision tokens used in cross-attention. Shape: `(batch_size, vision_seq_len)`.
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1146 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1147 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1148 |
|
|
@@ -1164,7 +1129,7 @@ class MossVLTextModel(MossVLPreTrainedModel):
|
|
| 1164 |
|
| 1165 |
attention_mask = create_causal_mask(
|
| 1166 |
config=self.config,
|
| 1167 |
-
|
| 1168 |
attention_mask=attention_mask,
|
| 1169 |
cache_position=cache_position,
|
| 1170 |
past_key_values=past_key_values,
|
|
@@ -1179,14 +1144,15 @@ class MossVLTextModel(MossVLPreTrainedModel):
|
|
| 1179 |
# Compute vision position embeddings (for cross-attention key/value) if needed
|
| 1180 |
vision_position_embeddings = None
|
| 1181 |
|
| 1182 |
-
if vision_cache_position is None:
|
| 1183 |
-
# TODO:use cache_position now
|
| 1184 |
-
vision_cache_position = cache_position
|
| 1185 |
-
|
| 1186 |
if cross_attention_states is not None:
|
| 1187 |
if vision_position_ids is not None:
|
| 1188 |
vision_position_embeddings = self.rotary_emb(cross_attention_states, vision_position_ids)
|
| 1189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1190 |
|
| 1191 |
for idx, decoder_layer in enumerate(self.layers):
|
| 1192 |
# For text-only path we should skip cross attention layers.
|
|
@@ -1211,17 +1177,35 @@ class MossVLTextModel(MossVLPreTrainedModel):
|
|
| 1211 |
cross_attention_states=cross_attention_states,
|
| 1212 |
cross_attention_mask=cross_attention_mask,
|
| 1213 |
vision_position_ids=vision_position_ids,
|
| 1214 |
-
vision_cache_position=vision_cache_position,
|
| 1215 |
vision_position_embeddings=vision_position_embeddings,
|
|
|
|
| 1216 |
**kwargs,
|
| 1217 |
)
|
| 1218 |
-
hidden_states = layer_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1219 |
|
| 1220 |
hidden_states = self.norm(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1221 |
|
| 1222 |
return BaseModelOutputWithPast(
|
| 1223 |
last_hidden_state=hidden_states,
|
| 1224 |
past_key_values=past_key_values,
|
|
|
|
|
|
|
| 1225 |
)
|
| 1226 |
|
| 1227 |
|
|
@@ -1240,8 +1224,6 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1240 |
super().__init__(config)
|
| 1241 |
self.visual = MossVLVisionModel._from_config(config.vision_config)
|
| 1242 |
self.language_model = MossVLTextModel._from_config(config.text_config)
|
| 1243 |
-
self.vision_token_info = None # cache vision_token_info here for decode stage
|
| 1244 |
-
self.rope_deltas = None # cache position deltas for decode stage
|
| 1245 |
|
| 1246 |
# Learnable Separator Token: inserted after each image/frame's vision tokens
|
| 1247 |
# Initialized from LLM's separator_token_init_id embedding
|
|
@@ -1550,7 +1532,7 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1550 |
continue
|
| 1551 |
|
| 1552 |
# Collect repetition counts for all frames in this sample
|
| 1553 |
-
|
| 1554 |
for media in medias:
|
| 1555 |
num_frames = media.get('num_frames', 1)
|
| 1556 |
length = media['length']
|
|
@@ -1565,25 +1547,30 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1565 |
|
| 1566 |
# In convert_packed_to_batch we enforce strictly regular frames
|
| 1567 |
# so we can assume all frames have the same number of tokens
|
| 1568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1569 |
|
| 1570 |
-
num_valid_frames =
|
| 1571 |
if num_valid_frames == 0:
|
| 1572 |
continue
|
| 1573 |
|
| 1574 |
# If cross_attention_mask has more frames (e.g. padded), slice it
|
| 1575 |
# If it has fewer (shouldn't happen), slice repeats
|
| 1576 |
valid_mask_frames = min(num_valid_frames, cross_attention_mask.shape[-1])
|
|
|
|
| 1577 |
if valid_mask_frames < num_valid_frames:
|
| 1578 |
-
|
| 1579 |
|
| 1580 |
# Extract valid columns for this sample
|
| 1581 |
# (1, text_len, valid_mask_frames)
|
| 1582 |
source_mask = cross_attention_mask[i, :, :, :valid_mask_frames]
|
| 1583 |
|
| 1584 |
-
# Convert repeats to tensor
|
| 1585 |
-
repeats_tensor = torch.tensor(repeats, device=cross_attention_mask.device)
|
| 1586 |
-
|
| 1587 |
# Expand using repeat_interleave
|
| 1588 |
# output shape: (1, text_len, sum(repeats))
|
| 1589 |
expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1)
|
|
@@ -1602,7 +1589,8 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1602 |
self,
|
| 1603 |
input_ids: torch.Tensor,
|
| 1604 |
attention_mask: Optional[torch.Tensor] = None,
|
| 1605 |
-
|
|
|
|
| 1606 |
) -> torch.Tensor:
|
| 1607 |
"""
|
| 1608 |
Compute 3D position IDs for text tokens with special handling for image tokens.
|
|
@@ -1617,7 +1605,7 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1617 |
Args:
|
| 1618 |
input_ids: (batch_size, seq_len)
|
| 1619 |
attention_mask: (batch_size, seq_len), optional
|
| 1620 |
-
|
| 1621 |
|
| 1622 |
Returns:
|
| 1623 |
position_ids: (3, batch_size, seq_len)
|
|
@@ -1626,25 +1614,17 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1626 |
device = input_ids.device
|
| 1627 |
image_token_id = self.config.image_token_id
|
| 1628 |
|
| 1629 |
-
# Decode stage:
|
| 1630 |
-
|
| 1631 |
-
|
| 1632 |
-
# rope_deltas is per-sample: (batch_size,)
|
| 1633 |
position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
|
| 1634 |
-
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
| 1635 |
-
|
| 1636 |
-
|
| 1637 |
-
if
|
| 1638 |
-
position_ids = position_ids +
|
| 1639 |
-
|
| 1640 |
-
|
| 1641 |
-
# self.rope_deltas shape: (batch_size,), need to unsqueeze for broadcasting
|
| 1642 |
-
position_ids = position_ids + self.rope_deltas.unsqueeze(1) # (batch, seq_len)
|
| 1643 |
-
|
| 1644 |
-
# Expand to 3D: (3, batch, seq_len)
|
| 1645 |
-
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
| 1646 |
-
|
| 1647 |
-
return position_ids
|
| 1648 |
|
| 1649 |
# Prefill stage: compute full position_ids with image token awareness
|
| 1650 |
# Vectorized implementation
|
|
@@ -1723,7 +1703,7 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1723 |
rope_deltas: (batch_size,) - position offset due to vision tokens
|
| 1724 |
"""
|
| 1725 |
batch_size, max_vision_seq_len, _ = cross_attention_states.shape
|
| 1726 |
-
device =
|
| 1727 |
image_token_id = self.config.image_token_id
|
| 1728 |
merge_size = self.visual.spatial_merge_size
|
| 1729 |
|
|
@@ -1731,15 +1711,14 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1731 |
# We need to flatten the nested vision_token_info structure to align with image tokens in input_ids
|
| 1732 |
|
| 1733 |
# Find all image tokens in text: (num_occurrences, 2) -> [batch_idx, seq_idx]
|
| 1734 |
-
image_token_indices = (input_ids == image_token_id).nonzero()
|
| 1735 |
|
| 1736 |
# Flatten vision_token_info to parallel lists
|
| 1737 |
# We assume the order of medias in vision_token_info matches the appearance of image tokens in input_ids
|
| 1738 |
-
|
| 1739 |
-
|
| 1740 |
-
|
| 1741 |
-
|
| 1742 |
-
|
| 1743 |
# Processing metadata on CPU (fast enough for typical batch sizes)
|
| 1744 |
for b_idx, info in enumerate(vision_token_info):
|
| 1745 |
medias = info.get('medias', [])
|
|
@@ -1750,13 +1729,11 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1750 |
start = media['start']
|
| 1751 |
tok_per_frame = media['vision_tokens_per_frame']
|
| 1752 |
stride = tok_per_frame + 1 # +1 for separator
|
| 1753 |
-
|
| 1754 |
-
|
| 1755 |
-
|
| 1756 |
-
|
| 1757 |
-
|
| 1758 |
-
flat_vis_starts.append(start + f * stride)
|
| 1759 |
-
flat_batch_indices.append(b_idx)
|
| 1760 |
|
| 1761 |
# Pre-allocate output
|
| 1762 |
vision_pos_ids = torch.zeros(
|
|
@@ -1766,17 +1743,19 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1766 |
)
|
| 1767 |
|
| 1768 |
# Handle case where no image tokens or info
|
| 1769 |
-
if len(
|
| 1770 |
rope_deltas = position_ids.max(dim=0).values.max(dim=-1).values + 1 - input_ids.shape[1]
|
| 1771 |
return vision_pos_ids, position_ids, rope_deltas
|
| 1772 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1773 |
# Align lengths (handle truncation if text has fewer tokens or vice versa)
|
| 1774 |
-
num_matches = min(
|
| 1775 |
-
|
| 1776 |
-
|
| 1777 |
-
|
| 1778 |
-
flat_eff_w = torch.tensor(flat_eff_w[:num_matches], device=device, dtype=torch.long)
|
| 1779 |
-
flat_vis_starts = torch.tensor(flat_vis_starts[:num_matches], device=device, dtype=torch.long)
|
| 1780 |
|
| 1781 |
# Get corresponding text positions
|
| 1782 |
target_indices = image_token_indices[:num_matches]
|
|
@@ -1942,53 +1921,6 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 1942 |
)
|
| 1943 |
return vision_embeds, vision_token_info
|
| 1944 |
|
| 1945 |
-
def get_vision_features_chunked(
|
| 1946 |
-
self,
|
| 1947 |
-
pixel_values: torch.FloatTensor,
|
| 1948 |
-
grid_thw: Optional[torch.LongTensor] = None,
|
| 1949 |
-
media_nums_per_sample: Optional[List[int]] = None,
|
| 1950 |
-
vision_chunked_length: Optional[int] = None,
|
| 1951 |
-
):
|
| 1952 |
-
"""
|
| 1953 |
-
Chunk the visual encoder forward by media items, then reuse the same
|
| 1954 |
-
packed-to-batch conversion logic. This keeps output semantics identical
|
| 1955 |
-
to `get_vision_features(...)` while reducing prefill memory pressure.
|
| 1956 |
-
"""
|
| 1957 |
-
if (
|
| 1958 |
-
vision_chunked_length is None
|
| 1959 |
-
or vision_chunked_length <= 0
|
| 1960 |
-
or grid_thw is None
|
| 1961 |
-
or grid_thw.shape[0] <= vision_chunked_length
|
| 1962 |
-
):
|
| 1963 |
-
return self.get_vision_features(pixel_values, grid_thw, media_nums_per_sample)
|
| 1964 |
-
|
| 1965 |
-
pixel_values = pixel_values.type(self.visual.dtype)
|
| 1966 |
-
token_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
|
| 1967 |
-
|
| 1968 |
-
hidden_state_chunks = []
|
| 1969 |
-
token_offset = 0
|
| 1970 |
-
for media_start in range(0, grid_thw.shape[0], vision_chunked_length):
|
| 1971 |
-
media_end = min(media_start + vision_chunked_length, grid_thw.shape[0])
|
| 1972 |
-
chunk_grid_thw = grid_thw[media_start:media_end]
|
| 1973 |
-
chunk_token_count = sum(token_counts[media_start:media_end])
|
| 1974 |
-
chunk_pixel_values = pixel_values[token_offset:token_offset + chunk_token_count]
|
| 1975 |
-
token_offset += chunk_token_count
|
| 1976 |
-
|
| 1977 |
-
hidden_state_chunks.append(
|
| 1978 |
-
self.visual(
|
| 1979 |
-
chunk_pixel_values,
|
| 1980 |
-
grid_thw=chunk_grid_thw,
|
| 1981 |
-
)
|
| 1982 |
-
)
|
| 1983 |
-
|
| 1984 |
-
hidden_states = torch.cat(hidden_state_chunks, dim=0)
|
| 1985 |
-
vision_embeds, vision_token_info = self.convert_packed_to_batch(
|
| 1986 |
-
hidden_states,
|
| 1987 |
-
grid_thw,
|
| 1988 |
-
media_nums_per_sample,
|
| 1989 |
-
)
|
| 1990 |
-
return vision_embeds, vision_token_info
|
| 1991 |
-
|
| 1992 |
|
| 1993 |
|
| 1994 |
@auto_docstring
|
|
@@ -2004,7 +1936,11 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2004 |
media_nums_per_sample: Optional[List[int]] = None,
|
| 2005 |
vision_position_ids: Optional[torch.LongTensor] = None,
|
| 2006 |
cross_attention_mask: Optional[torch.Tensor] = None,
|
| 2007 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2008 |
**kwargs: Unpack[TransformersKwargs],
|
| 2009 |
) -> Union[tuple, BaseModelOutputWithPast]:
|
| 2010 |
"""
|
|
@@ -2021,11 +1957,20 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2021 |
cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
|
| 2022 |
Attention mask for cross-attention between text and vision. Controls which vision tokens each text
|
| 2023 |
token can attend to, enforcing causal visibility for video frames.
|
| 2024 |
-
|
| 2025 |
-
|
| 2026 |
-
|
|
|
|
|
|
|
|
|
|
| 2027 |
"""
|
| 2028 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2029 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 2030 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 2031 |
|
|
@@ -2034,8 +1979,7 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2034 |
|
| 2035 |
# Process vision features (images and videos are already merged by processor)
|
| 2036 |
cross_attention_states = None
|
| 2037 |
-
|
| 2038 |
-
|
| 2039 |
if pixel_values is not None:
|
| 2040 |
# Determine batch size
|
| 2041 |
batch_size = inputs_embeds.shape[0]
|
|
@@ -2050,23 +1994,12 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2050 |
|
| 2051 |
# Process all vision inputs together through VIT
|
| 2052 |
# pixel_values and grid_thw are already ordered by appearance in text
|
| 2053 |
-
vision_embeds, vision_token_info = self.
|
| 2054 |
-
pixel_values,
|
| 2055 |
-
grid_thw,
|
| 2056 |
-
media_nums_per_sample,
|
| 2057 |
-
vision_chunked_length=vision_chunked_length,
|
| 2058 |
)
|
| 2059 |
|
| 2060 |
# vision_embeds: [batch_size, max_seq_len, hidden_size]
|
| 2061 |
cross_attention_states = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 2062 |
-
num_vision_tokens = cross_attention_states.shape[1]
|
| 2063 |
-
|
| 2064 |
-
# Cache vision_token_info for decode stage (prefill only)
|
| 2065 |
-
|
| 2066 |
-
self.vision_token_info = vision_token_info
|
| 2067 |
-
else:
|
| 2068 |
-
# In decode stage, use cached vision_token_info
|
| 2069 |
-
vision_token_info = self.vision_token_info
|
| 2070 |
|
| 2071 |
# Generate 3D position IDs for text if not provided
|
| 2072 |
if position_ids is None:
|
|
@@ -2075,7 +2008,8 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2075 |
position_ids = self.compute_position_ids(
|
| 2076 |
input_ids=input_ids,
|
| 2077 |
attention_mask=attention_mask,
|
| 2078 |
-
|
|
|
|
| 2079 |
)
|
| 2080 |
|
| 2081 |
# Compute cross_attention_mask, vision_position_ids, and full_text_row_masked_out_mask
|
|
@@ -2099,8 +2033,6 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2099 |
(cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
|
| 2100 |
)
|
| 2101 |
cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask
|
| 2102 |
-
|
| 2103 |
-
|
| 2104 |
|
| 2105 |
if vision_position_ids is None and cross_attention_states is not None and input_ids is not None:
|
| 2106 |
vision_position_ids, position_ids, rope_deltas = self.compute_vision_position_ids(
|
|
@@ -2110,14 +2042,6 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2110 |
cross_attention_states,
|
| 2111 |
attention_mask
|
| 2112 |
)
|
| 2113 |
-
|
| 2114 |
-
# Cache rope_deltas for decode stage (only in prefill)
|
| 2115 |
-
# rope_deltas = max_position - sequence_length
|
| 2116 |
-
# This allows fast position computation in decode: position = cache_position + rope_deltas
|
| 2117 |
-
if cache_position is not None and cache_position[0] == 0:
|
| 2118 |
-
self.rope_deltas = rope_deltas
|
| 2119 |
-
|
| 2120 |
-
|
| 2121 |
|
| 2122 |
outputs = self.language_model(
|
| 2123 |
input_ids=None,
|
|
@@ -2130,16 +2054,33 @@ class MossVLModel(MossVLPreTrainedModel):
|
|
| 2130 |
cross_attention_mask=cross_attention_mask,
|
| 2131 |
vision_position_ids=vision_position_ids,
|
| 2132 |
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
|
|
|
|
|
|
|
|
| 2133 |
**kwargs,
|
| 2134 |
)
|
| 2135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2136 |
return MossVLModelOutputWithPast(
|
| 2137 |
last_hidden_state=outputs.last_hidden_state,
|
| 2138 |
past_key_values=outputs.past_key_values,
|
| 2139 |
hidden_states=outputs.hidden_states,
|
| 2140 |
attentions=outputs.attentions,
|
| 2141 |
-
vision_token_info=
|
| 2142 |
-
rope_deltas=
|
| 2143 |
)
|
| 2144 |
|
| 2145 |
|
|
@@ -2161,7 +2102,6 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2161 |
super().__init__(config)
|
| 2162 |
self.model = MossVLModel(config)
|
| 2163 |
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
| 2164 |
-
self._offline_processor_lock = threading.RLock()
|
| 2165 |
|
| 2166 |
self.post_init()
|
| 2167 |
|
|
@@ -2219,9 +2159,12 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2219 |
media_nums_per_sample: Optional[List[int]] = None,
|
| 2220 |
vision_position_ids: Optional[torch.LongTensor] = None,
|
| 2221 |
cross_attention_mask: Optional[torch.Tensor] = None,
|
| 2222 |
-
|
| 2223 |
-
|
| 2224 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
|
|
|
|
|
|
|
|
| 2225 |
**kwargs: Unpack[TransformersKwargs],
|
| 2226 |
) -> Union[tuple, CausalLMOutputWithPast]:
|
| 2227 |
"""
|
|
@@ -2238,10 +2181,13 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2238 |
cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
|
| 2239 |
Attention mask for cross-attention between text and vision. Controls which vision tokens each text
|
| 2240 |
token can attend to, enforcing causal visibility for video frames.
|
| 2241 |
-
|
| 2242 |
-
|
| 2243 |
-
|
|
|
|
|
|
|
| 2244 |
"""
|
|
|
|
| 2245 |
outputs = self.model(
|
| 2246 |
input_ids=input_ids,
|
| 2247 |
pixel_values=pixel_values,
|
|
@@ -2253,12 +2199,17 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2253 |
cross_attention_mask=cross_attention_mask,
|
| 2254 |
past_key_values=past_key_values,
|
| 2255 |
inputs_embeds=inputs_embeds,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2256 |
cache_position=cache_position,
|
| 2257 |
-
vision_chunked_length=vision_chunked_length,
|
| 2258 |
**kwargs,
|
| 2259 |
)
|
| 2260 |
|
| 2261 |
-
|
|
|
|
| 2262 |
|
| 2263 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 2264 |
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
@@ -2267,6 +2218,11 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2267 |
if labels is not None:
|
| 2268 |
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
| 2269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2270 |
return MossVLCausalLMOutputWithPast(
|
| 2271 |
loss=loss,
|
| 2272 |
logits=logits,
|
|
@@ -2283,15 +2239,15 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2283 |
past_key_values=None,
|
| 2284 |
attention_mask=None,
|
| 2285 |
inputs_embeds=None,
|
| 2286 |
-
cache_position=None,
|
| 2287 |
position_ids=None,
|
| 2288 |
use_cache=True,
|
| 2289 |
pixel_values=None,
|
| 2290 |
grid_thw=None,
|
| 2291 |
media_nums_per_sample=None, # One video is one meida.
|
| 2292 |
vision_position_ids=None,
|
|
|
|
|
|
|
| 2293 |
cross_attention_mask=None,
|
| 2294 |
-
vision_chunked_length=None,
|
| 2295 |
**kwargs,
|
| 2296 |
):
|
| 2297 |
"""
|
|
@@ -2304,12 +2260,12 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2304 |
Args:
|
| 2305 |
media_nums_per_sample: One video counts as one media item (regardless of frame count)
|
| 2306 |
"""
|
|
|
|
| 2307 |
model_inputs = super().prepare_inputs_for_generation(
|
| 2308 |
input_ids,
|
| 2309 |
past_key_values=past_key_values,
|
| 2310 |
attention_mask=attention_mask,
|
| 2311 |
inputs_embeds=inputs_embeds,
|
| 2312 |
-
cache_position=cache_position,
|
| 2313 |
position_ids=position_ids,
|
| 2314 |
pixel_values=pixel_values,
|
| 2315 |
grid_thw=grid_thw,
|
|
@@ -2318,21 +2274,27 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2318 |
**kwargs,
|
| 2319 |
)
|
| 2320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2321 |
|
| 2322 |
-
#
|
| 2323 |
-
# we can set them to None to let forward recompute them from cache_position.
|
| 2324 |
model_inputs["position_ids"] = None
|
|
|
|
|
|
|
| 2325 |
|
| 2326 |
# Handle cross attention mask
|
| 2327 |
if cross_attention_mask is not None:
|
| 2328 |
-
# Slice to current
|
| 2329 |
-
# Shape: [batch, 1, text_len, vision_len] -> [batch, 1,
|
| 2330 |
-
cross_attention_mask = cross_attention_mask[:, :, -
|
| 2331 |
model_inputs["cross_attention_mask"] = cross_attention_mask
|
| 2332 |
|
| 2333 |
-
# Vision inputs are only needed in prefill stage
|
| 2334 |
# In decode stage, vision features are retrieved from cross attention cache
|
| 2335 |
-
if
|
| 2336 |
model_inputs["pixel_values"] = None
|
| 2337 |
model_inputs["grid_thw"] = None
|
| 2338 |
model_inputs["media_nums_per_sample"] = None
|
|
@@ -2341,7 +2303,6 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2341 |
else:
|
| 2342 |
# In prefill stage, include all vision-related inputs
|
| 2343 |
model_inputs["vision_position_ids"] = vision_position_ids
|
| 2344 |
-
model_inputs["vision_chunked_length"] = vision_chunked_length
|
| 2345 |
|
| 2346 |
return model_inputs
|
| 2347 |
|
|
@@ -2362,1026 +2323,16 @@ class MossVLForConditionalGeneration(MossVLPreTrainedModel, GenerationMixin):
|
|
| 2362 |
**kwargs,
|
| 2363 |
)
|
| 2364 |
|
| 2365 |
-
# Extend cross_attention_mask for the new token
|
| 2366 |
-
# Copy the last token's mask pattern for the newly generated token
|
| 2367 |
if cross_attention_mask_prev is not None:
|
| 2368 |
-
model_kwargs["cross_attention_mask"] =
|
| 2369 |
-
|
| 2370 |
-
|
| 2371 |
-
|
|
|
|
|
|
|
| 2372 |
|
| 2373 |
return model_kwargs
|
| 2374 |
|
| 2375 |
-
@staticmethod
|
| 2376 |
-
def _offline_flatten_content_with_vision_tokens(content) -> str:
|
| 2377 |
-
if isinstance(content, str):
|
| 2378 |
-
return content
|
| 2379 |
-
if not isinstance(content, list):
|
| 2380 |
-
return str(content) if content else ""
|
| 2381 |
-
|
| 2382 |
-
parts = []
|
| 2383 |
-
for item in content:
|
| 2384 |
-
if isinstance(item, dict):
|
| 2385 |
-
if item.get("type") == "image" or "image" in item:
|
| 2386 |
-
parts.append("<|image|>")
|
| 2387 |
-
elif item.get("type") == "video" or "video" in item:
|
| 2388 |
-
parts.append("<|video|>")
|
| 2389 |
-
if "text" in item:
|
| 2390 |
-
parts.append(str(item["text"]))
|
| 2391 |
-
elif isinstance(item, str):
|
| 2392 |
-
parts.append(item)
|
| 2393 |
-
return "".join(parts)
|
| 2394 |
-
|
| 2395 |
-
@staticmethod
|
| 2396 |
-
def _offline_sanitize_prompt_text(processor, text: Any) -> str:
|
| 2397 |
-
if text is None:
|
| 2398 |
-
return ""
|
| 2399 |
-
|
| 2400 |
-
sanitized = str(text)
|
| 2401 |
-
replacements = [
|
| 2402 |
-
(getattr(processor, "image_placeholder", None), ""),
|
| 2403 |
-
(getattr(processor, "video_placeholder", None), ""),
|
| 2404 |
-
(getattr(processor, "image_token", None), ""),
|
| 2405 |
-
(getattr(processor, "video_token", None), ""),
|
| 2406 |
-
]
|
| 2407 |
-
for needle, replacement in replacements:
|
| 2408 |
-
if needle:
|
| 2409 |
-
sanitized = sanitized.replace(needle, replacement)
|
| 2410 |
-
return sanitized.lstrip("\n")
|
| 2411 |
-
|
| 2412 |
-
def _offline_sanitize_message_content(self, processor, content: Any) -> Any:
|
| 2413 |
-
if isinstance(content, str):
|
| 2414 |
-
return self._offline_sanitize_prompt_text(processor, content)
|
| 2415 |
-
if not isinstance(content, list):
|
| 2416 |
-
return content
|
| 2417 |
-
|
| 2418 |
-
sanitized_items = []
|
| 2419 |
-
for item in content:
|
| 2420 |
-
if isinstance(item, dict):
|
| 2421 |
-
item_copy = dict(item)
|
| 2422 |
-
if "text" in item_copy:
|
| 2423 |
-
item_copy["text"] = self._offline_sanitize_prompt_text(processor, item_copy.get("text"))
|
| 2424 |
-
sanitized_items.append(item_copy)
|
| 2425 |
-
elif isinstance(item, str):
|
| 2426 |
-
sanitized_items.append(self._offline_sanitize_prompt_text(processor, item))
|
| 2427 |
-
else:
|
| 2428 |
-
sanitized_items.append(item)
|
| 2429 |
-
return sanitized_items
|
| 2430 |
-
|
| 2431 |
-
def _offline_prepare_messages(self, processor, query: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 2432 |
-
messages = query.get("messages")
|
| 2433 |
-
if messages:
|
| 2434 |
-
prepared_messages = []
|
| 2435 |
-
for message in messages:
|
| 2436 |
-
if not isinstance(message, dict):
|
| 2437 |
-
continue
|
| 2438 |
-
message_copy = dict(message)
|
| 2439 |
-
message_copy["content"] = self._offline_sanitize_message_content(
|
| 2440 |
-
processor,
|
| 2441 |
-
message_copy.get("content", ""),
|
| 2442 |
-
)
|
| 2443 |
-
prepared_messages.append(message_copy)
|
| 2444 |
-
if prepared_messages:
|
| 2445 |
-
return prepared_messages
|
| 2446 |
-
|
| 2447 |
-
prompt = self._offline_sanitize_prompt_text(processor, query.get("prompt", ""))
|
| 2448 |
-
images = list(query.get("images") or [])
|
| 2449 |
-
videos = list(query.get("videos") or [])
|
| 2450 |
-
|
| 2451 |
-
content = []
|
| 2452 |
-
for image in images:
|
| 2453 |
-
content.append({"type": "image", "image": image})
|
| 2454 |
-
for video in videos:
|
| 2455 |
-
content.append({"type": "video", "video": video})
|
| 2456 |
-
if prompt:
|
| 2457 |
-
content.append({"type": "text", "text": prompt.lstrip("\n")})
|
| 2458 |
-
|
| 2459 |
-
if not content:
|
| 2460 |
-
content = [{"type": "text", "text": ""}]
|
| 2461 |
-
|
| 2462 |
-
return [{"role": "user", "content": content}]
|
| 2463 |
-
|
| 2464 |
-
@staticmethod
|
| 2465 |
-
def _offline_extract_content_parts(content: Any) -> Tuple[str, List[Any], List[Any]]:
|
| 2466 |
-
if isinstance(content, str):
|
| 2467 |
-
return content, [], []
|
| 2468 |
-
if not isinstance(content, list):
|
| 2469 |
-
return (str(content) if content else ""), [], []
|
| 2470 |
-
|
| 2471 |
-
text_parts: List[str] = []
|
| 2472 |
-
images: List[Any] = []
|
| 2473 |
-
videos: List[Any] = []
|
| 2474 |
-
for item in content:
|
| 2475 |
-
if isinstance(item, dict):
|
| 2476 |
-
if item.get("type") == "image" or "image" in item or "image_url" in item:
|
| 2477 |
-
image = item.get("image") or item.get("image_url")
|
| 2478 |
-
if image is not None:
|
| 2479 |
-
images.append(image)
|
| 2480 |
-
elif item.get("type") == "video" or "video" in item or "video_path" in item:
|
| 2481 |
-
video = item.get("video") or item.get("video_path")
|
| 2482 |
-
if video is not None:
|
| 2483 |
-
videos.append(video)
|
| 2484 |
-
|
| 2485 |
-
if "text" in item and item["text"] is not None:
|
| 2486 |
-
text_parts.append(str(item["text"]))
|
| 2487 |
-
elif isinstance(item, str):
|
| 2488 |
-
text_parts.append(item)
|
| 2489 |
-
|
| 2490 |
-
return "".join(text_parts), images, videos
|
| 2491 |
-
|
| 2492 |
-
@staticmethod
|
| 2493 |
-
def _offline_resolve_use_template(query: Dict[str, Any]) -> bool:
|
| 2494 |
-
return bool(query.get("use_template", False))
|
| 2495 |
-
|
| 2496 |
-
def _offline_prepare_input_text(
|
| 2497 |
-
self,
|
| 2498 |
-
processor,
|
| 2499 |
-
messages: List[Dict[str, Any]],
|
| 2500 |
-
use_template: bool,
|
| 2501 |
-
) -> str:
|
| 2502 |
-
if not use_template:
|
| 2503 |
-
if any(isinstance(message, dict) and message.get("role") == "system" for message in messages):
|
| 2504 |
-
raise ValueError("system messages require use_template=True")
|
| 2505 |
-
|
| 2506 |
-
parts = ["<|im_start|>"]
|
| 2507 |
-
for message in messages:
|
| 2508 |
-
role = message.get("role", "user") if isinstance(message, dict) else "user"
|
| 2509 |
-
content = message.get("content", "") if isinstance(message, dict) else message
|
| 2510 |
-
text, msg_images, msg_videos = self._offline_extract_content_parts(content)
|
| 2511 |
-
|
| 2512 |
-
if role == "user":
|
| 2513 |
-
media_tokens = ""
|
| 2514 |
-
if msg_images:
|
| 2515 |
-
media_tokens += "<|image|>" * len(msg_images)
|
| 2516 |
-
if msg_videos:
|
| 2517 |
-
media_tokens += "<|video|>" * len(msg_videos)
|
| 2518 |
-
parts.append(f"{media_tokens}{text}")
|
| 2519 |
-
else:
|
| 2520 |
-
parts.append(f"{text}<|im_end|>")
|
| 2521 |
-
return "".join(parts)
|
| 2522 |
-
|
| 2523 |
-
processed_messages = []
|
| 2524 |
-
for message in messages:
|
| 2525 |
-
message_copy = dict(message)
|
| 2526 |
-
message_copy["content"] = self._offline_flatten_content_with_vision_tokens(
|
| 2527 |
-
message_copy.get("content", "")
|
| 2528 |
-
)
|
| 2529 |
-
processed_messages.append(message_copy)
|
| 2530 |
-
return processor.apply_chat_template(
|
| 2531 |
-
processed_messages,
|
| 2532 |
-
tokenize=False,
|
| 2533 |
-
add_generation_prompt=True,
|
| 2534 |
-
)
|
| 2535 |
-
|
| 2536 |
-
@staticmethod
|
| 2537 |
-
def _offline_collect_media(messages: List[Dict[str, Any]]) -> tuple[List[Any], List[Any]]:
|
| 2538 |
-
all_images: List[Any] = []
|
| 2539 |
-
all_videos: List[Any] = []
|
| 2540 |
-
|
| 2541 |
-
for message in messages:
|
| 2542 |
-
content = message.get("content")
|
| 2543 |
-
if isinstance(content, list):
|
| 2544 |
-
for item in content:
|
| 2545 |
-
if not isinstance(item, dict):
|
| 2546 |
-
continue
|
| 2547 |
-
if item.get("type") == "image" or "image" in item:
|
| 2548 |
-
image = item.get("image") or item.get("image_url")
|
| 2549 |
-
if image is not None:
|
| 2550 |
-
all_images.append(image)
|
| 2551 |
-
elif item.get("type") == "video" or "video" in item:
|
| 2552 |
-
video = item.get("video")
|
| 2553 |
-
if video is not None:
|
| 2554 |
-
all_videos.append(video)
|
| 2555 |
-
|
| 2556 |
-
return all_images, all_videos
|
| 2557 |
-
|
| 2558 |
-
def _offline_build_processor_kwargs(
|
| 2559 |
-
self,
|
| 2560 |
-
input_text: Union[str, List[str]],
|
| 2561 |
-
all_images: List[Any],
|
| 2562 |
-
all_videos: List[Any],
|
| 2563 |
-
media_kwargs: Dict[str, Any],
|
| 2564 |
-
) -> Dict[str, Any]:
|
| 2565 |
-
processor_kwargs: Dict[str, Any] = {
|
| 2566 |
-
"text": input_text,
|
| 2567 |
-
"images": all_images or None,
|
| 2568 |
-
"videos": all_videos or None,
|
| 2569 |
-
"return_tensors": "pt",
|
| 2570 |
-
"padding": False,
|
| 2571 |
-
}
|
| 2572 |
-
|
| 2573 |
-
if media_kwargs.get("min_pixels") is not None:
|
| 2574 |
-
processor_kwargs["min_pixels"] = media_kwargs["min_pixels"]
|
| 2575 |
-
if media_kwargs.get("max_pixels") is not None:
|
| 2576 |
-
processor_kwargs["max_pixels"] = media_kwargs["max_pixels"]
|
| 2577 |
-
if media_kwargs.get("video_fps") is not None:
|
| 2578 |
-
processor_kwargs["video_fps"] = media_kwargs["video_fps"]
|
| 2579 |
-
|
| 2580 |
-
min_frames = media_kwargs.get("min_frames", media_kwargs.get("video_minlen"))
|
| 2581 |
-
max_frames = media_kwargs.get("max_frames", media_kwargs.get("video_maxlen"))
|
| 2582 |
-
if min_frames is not None:
|
| 2583 |
-
processor_kwargs["min_frames"] = min_frames
|
| 2584 |
-
if max_frames is not None:
|
| 2585 |
-
processor_kwargs["max_frames"] = max_frames
|
| 2586 |
-
|
| 2587 |
-
return processor_kwargs
|
| 2588 |
-
|
| 2589 |
-
def _offline_prepare_inputs(self, processor, query: Dict[str, Any]):
|
| 2590 |
-
messages = self._offline_prepare_messages(processor, query)
|
| 2591 |
-
input_text = self._offline_prepare_input_text(
|
| 2592 |
-
processor,
|
| 2593 |
-
messages,
|
| 2594 |
-
use_template=self._offline_resolve_use_template(query),
|
| 2595 |
-
)
|
| 2596 |
-
all_images, all_videos = self._offline_collect_media(messages)
|
| 2597 |
-
media_kwargs = dict(query.get("media_kwargs") or {})
|
| 2598 |
-
processor_kwargs = self._offline_build_processor_kwargs(
|
| 2599 |
-
input_text,
|
| 2600 |
-
all_images,
|
| 2601 |
-
all_videos,
|
| 2602 |
-
media_kwargs,
|
| 2603 |
-
)
|
| 2604 |
-
|
| 2605 |
-
image_proc = getattr(processor, "image_processor", None)
|
| 2606 |
-
video_proc = getattr(processor, "video_processor", None)
|
| 2607 |
-
modified_multi_image = False
|
| 2608 |
-
modified_video = False
|
| 2609 |
-
|
| 2610 |
-
with self._offline_processor_lock:
|
| 2611 |
-
try:
|
| 2612 |
-
multi_image_max_pixels = media_kwargs.get("multi_image_max_pixels")
|
| 2613 |
-
if multi_image_max_pixels is not None and image_proc is not None:
|
| 2614 |
-
orig_multi_image_max_pixels = getattr(image_proc, "multi_image_max_pixels", None)
|
| 2615 |
-
image_proc.multi_image_max_pixels = multi_image_max_pixels
|
| 2616 |
-
modified_multi_image = True
|
| 2617 |
-
|
| 2618 |
-
video_max_pixels = media_kwargs.get("video_max_pixels")
|
| 2619 |
-
if video_max_pixels is not None and video_proc is not None:
|
| 2620 |
-
orig_video_max_pixels = getattr(video_proc, "video_max_pixels", None)
|
| 2621 |
-
video_proc.video_max_pixels = video_max_pixels
|
| 2622 |
-
modified_video = True
|
| 2623 |
-
|
| 2624 |
-
inputs = processor(**processor_kwargs)
|
| 2625 |
-
finally:
|
| 2626 |
-
if modified_multi_image and image_proc is not None:
|
| 2627 |
-
image_proc.multi_image_max_pixels = orig_multi_image_max_pixels
|
| 2628 |
-
if modified_video and video_proc is not None:
|
| 2629 |
-
video_proc.video_max_pixels = orig_video_max_pixels
|
| 2630 |
-
|
| 2631 |
-
text_device = self.get_input_embeddings().weight.device
|
| 2632 |
-
vision_device = self.visual.patch_embed.proj.weight.device
|
| 2633 |
-
vision_input_keys = {"pixel_values", "grid_thw"}
|
| 2634 |
-
|
| 2635 |
-
for key, value in list(inputs.items()):
|
| 2636 |
-
if not isinstance(value, torch.Tensor):
|
| 2637 |
-
continue
|
| 2638 |
-
|
| 2639 |
-
target_device = vision_device if key in vision_input_keys else text_device
|
| 2640 |
-
moved_value = value.to(target_device)
|
| 2641 |
-
if moved_value.dtype == torch.float32:
|
| 2642 |
-
moved_value = moved_value.to(torch.bfloat16)
|
| 2643 |
-
inputs[key] = moved_value
|
| 2644 |
-
|
| 2645 |
-
return inputs, input_text
|
| 2646 |
-
|
| 2647 |
-
def _offline_build_session_messages(
|
| 2648 |
-
self,
|
| 2649 |
-
processor,
|
| 2650 |
-
query: Dict[str, Any],
|
| 2651 |
-
session_messages: List[Dict[str, Any]],
|
| 2652 |
-
) -> List[Dict[str, Any]]:
|
| 2653 |
-
has_explicit_messages = bool(query.get("messages"))
|
| 2654 |
-
if has_explicit_messages and not query.get("append_messages_to_session", False):
|
| 2655 |
-
base_messages: List[Dict[str, Any]] = []
|
| 2656 |
-
else:
|
| 2657 |
-
base_messages = [dict(message) for message in session_messages]
|
| 2658 |
-
|
| 2659 |
-
turn_messages = self._offline_prepare_messages(processor, query)
|
| 2660 |
-
has_system_message = any(
|
| 2661 |
-
isinstance(message, dict) and message.get("role") == "system"
|
| 2662 |
-
for message in (base_messages + turn_messages)
|
| 2663 |
-
)
|
| 2664 |
-
|
| 2665 |
-
should_add_system_prompt = (
|
| 2666 |
-
query.get("use_default_system_prompt", False)
|
| 2667 |
-
or query.get("system_prompt") is not None
|
| 2668 |
-
or query.get("system_prompt_type") is not None
|
| 2669 |
-
or query.get("thinking_mode") is not None
|
| 2670 |
-
)
|
| 2671 |
-
|
| 2672 |
-
if not base_messages and not has_system_message and should_add_system_prompt:
|
| 2673 |
-
system_prompt = self._offline_resolve_system_prompt(query, turn_messages)
|
| 2674 |
-
if system_prompt is not None:
|
| 2675 |
-
base_messages.append({"role": "system", "content": system_prompt})
|
| 2676 |
-
|
| 2677 |
-
return base_messages + turn_messages
|
| 2678 |
-
|
| 2679 |
-
@staticmethod
|
| 2680 |
-
def _offline_query_contains_video(query: Dict[str, Any], messages: List[Dict[str, Any]]) -> bool:
|
| 2681 |
-
if query.get("videos"):
|
| 2682 |
-
return True
|
| 2683 |
-
|
| 2684 |
-
for message in messages:
|
| 2685 |
-
content = message.get("content") if isinstance(message, dict) else None
|
| 2686 |
-
if isinstance(content, list) and any(
|
| 2687 |
-
isinstance(item, dict) and (item.get("type") == "video" or "video" in item)
|
| 2688 |
-
for item in content
|
| 2689 |
-
):
|
| 2690 |
-
return True
|
| 2691 |
-
return False
|
| 2692 |
-
|
| 2693 |
-
@staticmethod
|
| 2694 |
-
def _offline_normalize_thinking_mode(value: Optional[str]) -> str:
|
| 2695 |
-
if value is None:
|
| 2696 |
-
return "no_thinking"
|
| 2697 |
-
|
| 2698 |
-
normalized = _OFFLINE_THINKING_MODE_ALIASES.get(str(value).strip().lower())
|
| 2699 |
-
if normalized is None:
|
| 2700 |
-
allowed = ", ".join(sorted(set(_OFFLINE_THINKING_MODE_ALIASES.values())))
|
| 2701 |
-
raise ValueError(f"Unsupported thinking_mode: {value!r}. Supported values: {allowed}")
|
| 2702 |
-
return normalized
|
| 2703 |
-
|
| 2704 |
-
@staticmethod
|
| 2705 |
-
def _offline_normalize_system_prompt_type(value: Optional[str], has_video: bool) -> str:
|
| 2706 |
-
if value is None:
|
| 2707 |
-
return "video" if has_video else "text_image"
|
| 2708 |
-
|
| 2709 |
-
normalized_key = str(value).strip().lower().replace("/", "_").replace(" ", "_")
|
| 2710 |
-
while "__" in normalized_key:
|
| 2711 |
-
normalized_key = normalized_key.replace("__", "_")
|
| 2712 |
-
|
| 2713 |
-
normalized = _OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.get(normalized_key)
|
| 2714 |
-
if normalized is None:
|
| 2715 |
-
allowed = ", ".join(sorted(set(_OFFLINE_SYSTEM_PROMPT_TYPE_ALIASES.values())))
|
| 2716 |
-
raise ValueError(f"Unsupported system_prompt_type: {value!r}. Supported values: {allowed}")
|
| 2717 |
-
return normalized
|
| 2718 |
-
|
| 2719 |
-
def _offline_resolve_system_prompt(
|
| 2720 |
-
self,
|
| 2721 |
-
query: Dict[str, Any],
|
| 2722 |
-
turn_messages: List[Dict[str, Any]],
|
| 2723 |
-
) -> Optional[str]:
|
| 2724 |
-
explicit_system_prompt = query.get("system_prompt")
|
| 2725 |
-
if explicit_system_prompt is not None:
|
| 2726 |
-
return str(explicit_system_prompt)
|
| 2727 |
-
|
| 2728 |
-
has_video = self._offline_query_contains_video(query, turn_messages)
|
| 2729 |
-
thinking_mode = self._offline_normalize_thinking_mode(query.get("thinking_mode"))
|
| 2730 |
-
system_prompt_type = self._offline_normalize_system_prompt_type(
|
| 2731 |
-
query.get("system_prompt_type"),
|
| 2732 |
-
has_video=has_video,
|
| 2733 |
-
)
|
| 2734 |
-
return _OFFLINE_SYSTEM_PROMPTS[thinking_mode][system_prompt_type]
|
| 2735 |
-
|
| 2736 |
-
@staticmethod
|
| 2737 |
-
def _offline_finalize_session_messages(
|
| 2738 |
-
working_messages: List[Dict[str, Any]],
|
| 2739 |
-
assistant_text: str,
|
| 2740 |
-
) -> List[Dict[str, Any]]:
|
| 2741 |
-
next_messages = [dict(message) for message in working_messages]
|
| 2742 |
-
next_messages.append({"role": "assistant", "content": assistant_text})
|
| 2743 |
-
return next_messages
|
| 2744 |
-
-    def _offline_prepare_generation(self, processor, query: Dict[str, Any]):
-        inputs, input_text = self._offline_prepare_inputs(processor, query)
-        generate_kwargs = dict(query.get("generate_kwargs") or {})
-
-        max_new_tokens = generate_kwargs.pop("max_new_tokens", 1024)
-        temperature = generate_kwargs.pop("temperature", 1.0)
-        top_k = generate_kwargs.pop("top_k", 50)
-        top_p = generate_kwargs.pop("top_p", 1.0)
-        repetition_penalty = generate_kwargs.pop("repetition_penalty", 1.0)
-        do_sample = generate_kwargs.pop("do_sample", False)
-        vision_chunked_length = generate_kwargs.pop("vision_chunked_length", None)
-
-        if temperature is None:
-            temperature = 1.0
-        if temperature <= 0:
-            temperature = 1.0
-            do_sample = False
-
-        call_kwargs = dict(
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            vision_chunked_length=vision_chunked_length,
-            **generate_kwargs,
-        )
-        return inputs, input_text, call_kwargs
-
-    @staticmethod
-    def _offline_normalize_shared_mapping(
-        values: List[Dict[str, Any]],
-        mapping_name: str,
-    ) -> Dict[str, Any]:
-        normalized_values = [dict(value or {}) for value in values]
-        if not normalized_values:
-            return {}
-
-        all_keys = set()
-        for value in normalized_values:
-            all_keys.update(value.keys())
-
-        merged: Dict[str, Any] = {}
-        mismatched_keys: List[str] = []
-        for key in sorted(all_keys):
-            unique_values = {repr(value.get(key)) for value in normalized_values}
-            if len(unique_values) > 1:
-                mismatched_keys.append(key)
-            else:
-                merged[key] = normalized_values[0].get(key)
-
-        if mismatched_keys:
-            mismatch_text = ", ".join(mismatched_keys)
-            raise ValueError(
-                f"All batch queries must share the same {mapping_name}. "
-                f"Mismatched keys: {mismatch_text}"
-            )
-        return merged
-
-    def _offline_prepare_batch_generation(
-        self,
-        processor,
-        queries: List[Dict[str, Any]],
-        session_states: Optional[List[List[Dict[str, Any]]]] = None,
-    ):
-        if not queries:
-            raise ValueError("`queries` must contain at least one query.")
-
-        if session_states is None:
-            session_states = [[] for _ in queries]
-        elif len(session_states) != len(queries):
-            raise ValueError("`session_states` must have the same length as `queries`.")
-
-        working_messages_list: List[List[Dict[str, Any]]] = []
-        input_texts: List[str] = []
-        all_images_per_query: List[List[Any]] = []
-        all_videos_per_query: List[List[Any]] = []
-
-        for query, session_state in zip(queries, session_states):
-            if not isinstance(query, dict):
-                raise TypeError("Each batch query must be a dict.")
-            if query.get("stop_offline_generate"):
-                raise ValueError("`stop_offline_generate` is not supported in offline_batch_generate.")
-            if query.get("stream_output", query.get("stream", False)):
-                raise ValueError("Streaming is not supported in offline_batch_generate.")
-            if query.get("cancel_current_generate") or query.get("stop_generation"):
-                raise ValueError("Cancel / stop controls are not supported in offline_batch_generate.")
-
-            current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
-            working_messages = self._offline_build_session_messages(
-                processor,
-                query,
-                current_session,
-            )
-            working_messages_list.append(working_messages)
-            input_texts.append(
-                self._offline_prepare_input_text(
-                    processor,
-                    working_messages,
-                    use_template=self._offline_resolve_use_template(query),
-                )
-            )
-
-            all_images, all_videos = self._offline_collect_media(working_messages)
-            all_images_per_query.append(all_images)
-            all_videos_per_query.append(all_videos)
-
-        media_kwargs = self._offline_normalize_shared_mapping(
-            [query.get("media_kwargs") or {} for query in queries],
-            mapping_name="media_kwargs",
-        )
-        processor_kwargs = self._offline_build_processor_kwargs(
-            input_text=input_texts,
-            all_images=[image for images in all_images_per_query for image in images],
-            all_videos=[video for videos in all_videos_per_query for video in videos],
-            media_kwargs=media_kwargs,
-        )
-        processor_kwargs["padding"] = True
-
-        image_proc = getattr(processor, "image_processor", None)
-        video_proc = getattr(processor, "video_processor", None)
-        tokenizer = getattr(processor, "tokenizer", None)
-        modified_multi_image = False
-        modified_video = False
-        orig_padding_side = None
-
-        with self._offline_processor_lock:
-            try:
-                multi_image_max_pixels = media_kwargs.get("multi_image_max_pixels")
-                if multi_image_max_pixels is not None and image_proc is not None:
-                    orig_multi_image_max_pixels = getattr(image_proc, "multi_image_max_pixels", None)
-                    image_proc.multi_image_max_pixels = multi_image_max_pixels
-                    modified_multi_image = True
-
-                video_max_pixels = media_kwargs.get("video_max_pixels")
-                if video_max_pixels is not None and video_proc is not None:
-                    orig_video_max_pixels = getattr(video_proc, "video_max_pixels", None)
-                    video_proc.video_max_pixels = video_max_pixels
-                    modified_video = True
-
-                if tokenizer is not None and hasattr(tokenizer, "padding_side"):
-                    orig_padding_side = tokenizer.padding_side
-                    tokenizer.padding_side = "left"
-
-                inputs = processor(**processor_kwargs)
-            finally:
-                if modified_multi_image and image_proc is not None:
-                    image_proc.multi_image_max_pixels = orig_multi_image_max_pixels
-                if modified_video and video_proc is not None:
-                    video_proc.video_max_pixels = orig_video_max_pixels
-                if tokenizer is not None and orig_padding_side is not None:
-                    tokenizer.padding_side = orig_padding_side
-
-        text_device = self.get_input_embeddings().weight.device
-        vision_device = self.visual.patch_embed.proj.weight.device
-        vision_input_keys = {"pixel_values", "grid_thw"}
-
-        for key, value in list(inputs.items()):
-            if not isinstance(value, torch.Tensor):
-                continue
-
-            target_device = vision_device if key in vision_input_keys else text_device
-            moved_value = value.to(target_device)
-            if moved_value.dtype == torch.float32:
-                moved_value = moved_value.to(torch.bfloat16)
-            inputs[key] = moved_value
-
-        generate_kwargs = self._offline_normalize_shared_mapping(
-            [query.get("generate_kwargs") or {} for query in queries],
-            mapping_name="generate_kwargs",
-        )
-        max_new_tokens = generate_kwargs.pop("max_new_tokens", 1024)
-        temperature = generate_kwargs.pop("temperature", 1.0)
-        top_k = generate_kwargs.pop("top_k", 50)
-        top_p = generate_kwargs.pop("top_p", 1.0)
-        repetition_penalty = generate_kwargs.pop("repetition_penalty", 1.0)
-        do_sample = generate_kwargs.pop("do_sample", False)
-        vision_chunked_length = generate_kwargs.pop("vision_chunked_length", None)
-
-        if temperature is None:
-            temperature = 1.0
-        if temperature <= 0:
-            temperature = 1.0
-            do_sample = False
-
-        call_kwargs = dict(
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            vision_chunked_length=vision_chunked_length,
-            **generate_kwargs,
-        )
-        return inputs, input_texts, working_messages_list, call_kwargs
-
-    def offline_batch_generate(
-        self,
-        processor,
-        queries: List[Dict[str, Any]],
-        session_states: Optional[List[List[Dict[str, Any]]]] = None,
-        vision_chunked_length: int = 64,
-    ) -> Dict[str, Any]:
-        """
-        Batch offline generation for multiple independent samples.
-
-        This method supports:
-        - batched single-turn generation
-        - batched multi-turn continuation through `session_states`
-
-        It intentionally does not support queue-style controls such as:
-        - `stream_output`
-        - `cancel_current_generate`
-        - `stop_generation`
-        - `stop_offline_generate`
-        """
-        if not queries:
-            return {"results": [], "session_states": []}
-
-        prepared_queries = [dict(query) for query in queries]
-        for query in prepared_queries:
-            generate_kwargs = query.setdefault("generate_kwargs", {})
-            generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
-        if session_states is None:
-            session_states = [[] for _ in prepared_queries]
-        elif len(session_states) != len(prepared_queries):
-            raise ValueError("`session_states` must have the same length as `queries`.")
-
-        tokenizer = getattr(processor, "tokenizer", None)
-        bucketed_indices: Dict[Any, List[int]] = {}
-        for index, (query, session_state) in enumerate(zip(prepared_queries, session_states)):
-            current_session = [] if query.get("reset_session") or query.get("clear_history") else session_state
-            working_messages = self._offline_build_session_messages(processor, query, current_session)
-            input_text = self._offline_prepare_input_text(
-                processor,
-                working_messages,
-                use_template=self._offline_resolve_use_template(query),
-            )
-
-            if tokenizer is not None:
-                token_ids = tokenizer(input_text, add_special_tokens=False)["input_ids"]
-                bucket_key = len(token_ids)
-            else:
-                bucket_key = len(input_text)
-            bucketed_indices.setdefault(bucket_key, []).append(index)
-
-        results: List[Optional[Dict[str, Any]]] = [None] * len(prepared_queries)
-        next_session_states: List[Optional[List[Dict[str, Any]]]] = [None] * len(prepared_queries)
-
-        for bucket_indices in bucketed_indices.values():
-            bucket_queries = [prepared_queries[index] for index in bucket_indices]
-            bucket_session_states = [session_states[index] for index in bucket_indices]
-            inputs, input_texts, working_messages_list, call_kwargs = self._offline_prepare_batch_generation(
-                processor,
-                bucket_queries,
-                session_states=bucket_session_states,
-            )
-
-            with torch.no_grad():
-                outputs = self.generate(
-                    **inputs,
-                    **call_kwargs,
-                )
-
-            input_seq_len = inputs["input_ids"].shape[1]
-            generated_tokens = outputs[:, input_seq_len:]
-            decoded_texts = processor.batch_decode(generated_tokens, skip_special_tokens=True)
-
-            for local_index, (query, input_text, working_messages, text) in enumerate(
-                zip(bucket_queries, input_texts, working_messages_list, decoded_texts)
-            ):
-                original_index = bucket_indices[local_index]
-                if query.get("persist_session", True):
-                    next_session_state = self._offline_finalize_session_messages(working_messages, text)
-                else:
-                    next_session_state = working_messages
-                next_session_states[original_index] = next_session_state
-                results[original_index] = {
-                    "index": original_index,
-                    "text": text,
-                    "input_text": input_text,
-                    "messages": working_messages,
-                }
-
-        return {
-            "results": [item for item in results if item is not None],
-            "session_states": [item for item in next_session_states if item is not None],
-        }
-
-    def _offline_generate_one(self, processor, query: Dict[str, Any]) -> str:
-        working_messages = self._offline_build_session_messages(processor, query, [])
-        generation_query = dict(query)
-        generation_query["messages"] = working_messages
-        inputs, _, call_kwargs = self._offline_prepare_generation(processor, generation_query)
-
-        with torch.no_grad():
-            outputs = self.generate(
-                **inputs,
-                **call_kwargs,
-            )
-
-        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
-        return processor.decode(new_tokens, skip_special_tokens=True)
-
-    @staticmethod
-    def _offline_capture_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
-        if target is None or not overrides:
-            return None
-        return {name: copy.deepcopy(getattr(target, name)) for name in overrides}
-
-    @staticmethod
-    def _offline_apply_processor_attrs(target, overrides: Optional[Dict[str, Any]]) -> None:
-        if target is None or not overrides:
-            return
-        for name, value in overrides.items():
-            setattr(target, name, copy.deepcopy(value))
-
-    @staticmethod
-    def _offline_restore_processor_attrs(target, snapshot: Optional[Dict[str, Any]]) -> None:
-        if target is None or snapshot is None:
-            return
-        for name, value in snapshot.items():
-            setattr(target, name, copy.deepcopy(value))
-
-    def _offline_generate_one_with_processor_overrides(
-        self,
-        processor,
-        query: Dict[str, Any],
-        image_processor_overrides: Optional[Dict[str, Any]] = None,
-        video_processor_overrides: Optional[Dict[str, Any]] = None,
-    ) -> str:
-        image_proc = getattr(processor, "image_processor", None)
-        video_proc = getattr(processor, "video_processor", None)
-        image_snapshot = self._offline_capture_processor_attrs(image_proc, image_processor_overrides)
-        video_snapshot = self._offline_capture_processor_attrs(video_proc, video_processor_overrides)
-
-        with self._offline_processor_lock:
-            try:
-                self._offline_apply_processor_attrs(image_proc, image_processor_overrides)
-                self._offline_apply_processor_attrs(video_proc, video_processor_overrides)
-                return self._offline_generate_one(processor, query)
-            finally:
-                self._offline_restore_processor_attrs(image_proc, image_snapshot)
-                self._offline_restore_processor_attrs(video_proc, video_snapshot)
-
-    def offline_image_generate(
-        self,
-        processor,
-        prompt: str = "",
-        image: Any = None,
-        *,
-        shortest_edge: int = 4096,
-        longest_edge: int = 16777216,
-        multi_image_max_pixels: int = 201326592,
-        patch_size: int = 16,
-        temporal_patch_size: int = 1,
-        merge_size: int = 2,
-        image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
-        image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
-        max_new_tokens: int = 1024,
-        temperature: float = 1.0,
-        top_k: int = 50,
-        top_p: float = 1.0,
-        repetition_penalty: float = 1.0,
-        do_sample: bool = False,
-        vision_chunked_length: int = 64,
-        use_template: bool = False,
-        thinking_mode: Optional[str] = None,
-        system_prompt_type: Optional[str] = None,
-        system_prompt: Optional[str] = None,
-    ) -> str:
-        """
-        Single-image offline generation with explicit image preprocessor defaults.
-
-        The default values mirror `preprocessor_config.json` so README examples can
-        surface the full image preprocessing setup without requiring a batch wrapper.
-        """
-        if image is None:
-            raise ValueError("`image` is required.")
-        query: Dict[str, Any] = {
-            "prompt": prompt,
-            "images": [image],
-            "videos": [],
-            "media_kwargs": {
-                "min_pixels": shortest_edge,
-                "max_pixels": longest_edge,
-                "multi_image_max_pixels": multi_image_max_pixels,
-            },
-            "generate_kwargs": {
-                "max_new_tokens": max_new_tokens,
-                "temperature": temperature,
-                "top_k": top_k,
-                "top_p": top_p,
-                "repetition_penalty": repetition_penalty,
-                "do_sample": do_sample,
-                "vision_chunked_length": vision_chunked_length,
-            },
-            "use_template": use_template,
-        }
-        if thinking_mode is not None:
-            query["thinking_mode"] = thinking_mode
-        if system_prompt_type is not None:
-            query["system_prompt_type"] = system_prompt_type
-        if system_prompt is not None:
-            query["system_prompt"] = system_prompt
-
-        image_processor_overrides = {
-            "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
-            "multi_image_max_pixels": multi_image_max_pixels,
-            "patch_size": patch_size,
-            "temporal_patch_size": temporal_patch_size,
-            "merge_size": merge_size,
-            "image_mean": list(image_mean) if image_mean is not None else None,
-            "image_std": list(image_std) if image_std is not None else None,
-        }
-        return self._offline_generate_one_with_processor_overrides(
-            processor,
-            query,
-            image_processor_overrides=image_processor_overrides,
-        )
-
-    def offline_video_generate(
-        self,
-        processor,
-        prompt: str = "",
-        video: Any = None,
-        *,
-        shortest_edge: int = 4096,
-        longest_edge: int = 16777216,
-        video_max_pixels: int = 201326592,
-        patch_size: int = 16,
-        temporal_patch_size: int = 1,
-        merge_size: int = 2,
-        video_fps: float = 1.0,
-        min_frames: int = 1,
-        max_frames: int = 256,
-        num_extract_threads: int = 4,
-        image_mean: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
-        image_std: Optional[Union[List[float], Tuple[float, ...]]] = (0.5, 0.5, 0.5),
-        max_new_tokens: int = 1024,
-        temperature: float = 1.0,
-        top_k: int = 50,
-        top_p: float = 1.0,
-        repetition_penalty: float = 1.0,
-        do_sample: bool = False,
-        vision_chunked_length: int = 64,
-        use_template: bool = False,
-        thinking_mode: Optional[str] = None,
-        system_prompt_type: Optional[str] = None,
-        system_prompt: Optional[str] = None,
-    ) -> str:
-        """
-        Single-video offline generation with explicit video preprocessor defaults.
-
-        The default values mirror `video_preprocessor_config.json` so README examples
-        can show a standalone video entry point with the effective preprocessing knobs.
-        """
-        if video is None:
-            raise ValueError("`video` is required.")
-        query: Dict[str, Any] = {
-            "prompt": prompt,
-            "images": [],
-            "videos": [video],
-            "media_kwargs": {
-                "min_pixels": shortest_edge,
-                "max_pixels": longest_edge,
-                "video_max_pixels": video_max_pixels,
-                "video_fps": video_fps,
-                "min_frames": min_frames,
-                "max_frames": max_frames,
-            },
-            "generate_kwargs": {
-                "max_new_tokens": max_new_tokens,
-                "temperature": temperature,
-                "top_k": top_k,
-                "top_p": top_p,
-                "repetition_penalty": repetition_penalty,
-                "do_sample": do_sample,
-                "vision_chunked_length": vision_chunked_length,
-            },
-            "use_template": use_template,
-        }
-        if thinking_mode is not None:
-            query["thinking_mode"] = thinking_mode
-        if system_prompt_type is not None:
-            query["system_prompt_type"] = system_prompt_type
-        if system_prompt is not None:
-            query["system_prompt"] = system_prompt
-
-        video_processor_overrides = {
-            "size": {"shortest_edge": shortest_edge, "longest_edge": longest_edge},
-            "video_max_pixels": video_max_pixels,
-            "patch_size": patch_size,
-            "temporal_patch_size": temporal_patch_size,
-            "merge_size": merge_size,
-            "video_fps": video_fps,
-            "min_frames": min_frames,
-            "max_frames": max_frames,
-            "num_extract_threads": num_extract_threads,
-            "image_mean": list(image_mean) if image_mean is not None else None,
-            "image_std": list(image_std) if image_std is not None else None,
-        }
-        return self._offline_generate_one_with_processor_overrides(
-            processor,
-            query,
-            video_processor_overrides=video_processor_overrides,
-        )
-
-    def offline_generate(
-        self,
-        processor,
-        new_queries: "queue.Queue[dict]",
-        output_text_queue: "queue.Queue[str]",
-        vision_chunked_length: int = 64,
-    ) -> None:
-        """
-        HF-style offline inference wrapper aligned with the previous backend output path.
-
-        This method intentionally reuses the checkpoint's existing processor and
-        `generate()` flow so that outputs stay consistent with the old external
-        backend inference implementation.
-
-        Supported query keys include:
-        - `prompt` / `messages`
-        - `images` / `videos`
-        - `media_kwargs` / `generate_kwargs`
-        - `use_template` to switch between backend-style pretrain prompting
-          (`False`, default for base) and tokenizer chat template prompting (`True`)
-        - `thinking_mode` (`no_thinking` or `deep_thinking`, plus compatible aliases)
-        - `system_prompt_type` (`text_image` or `video`, plus compatible aliases)
-        - `system_prompt` for an explicit override
-        - `stream_output` / `stream`
-        - `reset_session` / `clear_history`
-        - `cancel_current_generate` / `stop_generation` / `stop_offline_generate`
-        """
-        buffered_queries: List[Dict[str, Any]] = []
-        session_messages: List[Dict[str, Any]] = []
-
-        while True:
-            if buffered_queries:
-                query = buffered_queries.pop(0)
-            else:
-                query = new_queries.get()
-            if not isinstance(query, dict):
-                continue
-
-            if query.get("stop_offline_generate"):
-                break
-
-            if query.get("reset_session") or query.get("clear_history"):
-                session_messages = []
-
-            try:
-                generate_kwargs = query.setdefault("generate_kwargs", {})
-                generate_kwargs.setdefault("vision_chunked_length", vision_chunked_length)
-                working_messages = self._offline_build_session_messages(
-                    processor,
-                    query,
-                    session_messages,
-                )
-
-                generation_query = dict(query)
-                generation_query["messages"] = working_messages
-                inputs, input_text, call_kwargs = self._offline_prepare_generation(processor, generation_query)
-
-                stream_output = bool(query.get("stream_output", query.get("stream", False)))
-                cancel_event = threading.Event()
-                stopping_criteria = StoppingCriteriaList([_OfflineCancelStoppingCriteria(cancel_event)])
-                generation_state: Dict[str, Any] = {}
-
-                if stream_output:
-                    output_text_queue.put("<|round_start|>")
-                    streamer = _OfflineQueueStreamer(getattr(processor, "tokenizer", processor), output_text_queue)
-                else:
-                    streamer = None
-
-                def _run_generation():
-                    try:
-                        with torch.no_grad():
-                            generation_state["outputs"] = self.generate(
-                                **inputs,
-                                stopping_criteria=stopping_criteria,
-                                streamer=streamer,
-                                **call_kwargs,
-                            )
-                    except Exception as exc:
-                        generation_state["exception"] = exc
-
-                worker = threading.Thread(target=_run_generation, daemon=True)
-                worker.start()
-
-                stop_conversation_after_turn = False
-                while worker.is_alive():
-                    try:
-                        control_query = new_queries.get(timeout=0.1)
-                    except queue.Empty:
-                        continue
-
-                    if not isinstance(control_query, dict):
-                        continue
-
-                    if control_query.get("cancel_current_generate") or control_query.get("stop_generation"):
-                        cancel_event.set()
-                        stop_conversation_after_turn = stop_conversation_after_turn or control_query.get("stop_offline_generate", False)
-                        continue
-
-                    if control_query.get("stop_offline_generate"):
-                        cancel_event.set()
-                        stop_conversation_after_turn = True
-                        continue
-
-                    buffered_queries.append(control_query)
-
-                worker.join()
-                was_cancelled = cancel_event.is_set()
-
-                if "exception" in generation_state:
-                    raise generation_state["exception"]
-
-                if stream_output and streamer is not None:
-                    text = "".join(streamer.collected_chunks)
-                else:
-                    outputs = generation_state["outputs"]
-                    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
-                    text = processor.decode(new_tokens, skip_special_tokens=True)
-                    output_text_queue.put(text)
-
-                if query.get("persist_session", True) and (not was_cancelled or query.get("persist_cancelled_turn", False)):
-                    session_messages = self._offline_finalize_session_messages(working_messages, text)
-
-                output_text_queue.put("<|round_end|>")
-
-                if stop_conversation_after_turn:
-                    break
-            except Exception as exc:
-                output_text_queue.put(f"[ERROR] {exc}")
-                output_text_queue.put("<|round_end|>")
-
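For anyone migrating off the deleted API, a minimal sketch of the query dicts the removed `offline_generate` loop consumed; the key names come from its docstring above, while the file path and values here are purely illustrative:

```python
import queue

new_queries: "queue.Queue[dict]" = queue.Queue()
output_text_queue: "queue.Queue[str]" = queue.Queue()

# Illustrative query using the documented keys; "example.jpg" is a placeholder.
new_queries.put({
    "prompt": "Describe the image.",
    "images": ["example.jpg"],
    "videos": [],
    "media_kwargs": {"max_pixels": 16777216},
    "generate_kwargs": {"max_new_tokens": 256, "do_sample": False},
    "use_template": False,
    "stream_output": True,
})
new_queries.put({"stop_offline_generate": True})
# model.offline_generate(processor, new_queries, output_text_queue)  # removed in this PR
```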
 
 __all__ = [
     "MossVLVisionModel",
@@
 # limitations under the License.
 """PyTorch MossVL model - Qwen3VL Vision + Text with Cross Attention"""
 
 from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union, Tuple, List
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from transformers import initialization as init
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
 from transformers.integrations import use_kernel_forward_from_hub
 from transformers.masking_utils import create_causal_mask
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
@@
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging
 from transformers.utils.deprecation import deprecate_kwarg
+from transformers.utils.generic import is_flash_attention_requested
+from transformers.utils.output_capturing import OutputRecorder
 
 from .configuration_moss_vl import MossVLConfig, MossVLTextConfig, MossVLVisionConfig
 
 
 logger = logging.get_logger(__name__)
@@
 
 @dataclass
 class MossVLModelOutputWithPast(ModelOutput):
@@
 
 
 class MossVLVisionRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
+        # Keep dim / theta so that `_init_weights` can rebuild `inv_freq` after
+        # from_pretrained materializes the module (it is a non-persistent buffer
+        # and therefore never populated by the checkpoint).
+        self.dim = dim
+        self.theta = theta
+        inv_freq = self.compute_inv_freq()
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+    def compute_inv_freq(self) -> torch.Tensor:
+        return 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))
+
     def forward(self, seqlen: int) -> torch.Tensor:
         seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
         freqs = torch.outer(seq, self.inv_freq)
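A small, self-contained sketch of why the rebuild hook is needed (the class below is hypothetical, not part of this checkpoint): buffers registered with `persistent=False` never enter `state_dict`, so `from_pretrained` cannot restore them and they must be recomputed at initialization time.

```python
import torch
import torch.nn as nn

class RotaryBuffer(nn.Module):
    def __init__(self, dim: int = 8, theta: float = 10000.0):
        super().__init__()
        self.dim, self.theta = dim, theta
        # Non-persistent: excluded from state_dict, so a checkpoint never carries it.
        self.register_buffer("inv_freq", self.compute_inv_freq(), persistent=False)

    def compute_inv_freq(self) -> torch.Tensor:
        return 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim))

m = RotaryBuffer()
assert "inv_freq" not in m.state_dict()  # hence the explicit rebuild in `_init_weights`
```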
@@
         self.act_fn = nn.GELU()
         self.linear_fc2 = nn.Linear(self.input_hidden_size, config.out_hidden_size)
 
+    def forward(
+        self,
+        last_hidden_state: torch.Tensor,
+        deepstack_features: Optional[List[torch.Tensor]] = None,
+    ) -> torch.Tensor:
         # 1. Collect all features: [last_hidden_state, deepstack_1, deepstack_2, ...]
         # self.norms[0] corresponds to last_hidden_state
         # self.norms[1:] corresponds to deepstack_features
+        if deepstack_features is None:
+            deepstack_features = []
         all_inputs = [last_hidden_state] + deepstack_features
 
         # 2. Apply Norm independently
@@
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
 
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
+        if is_flash_attention_requested(self.config):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
@@
 
     def __init__(self, config: MossVLTextConfig, device=None):
         super().__init__()
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
+        rope_parameters = getattr(config, "rope_parameters", None)
+        if rope_parameters is None:
+            rope_parameters = getattr(config, "rope_scaling", None) or {"rope_type": "default"}
 
+        self.rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
+        self.mrope_section = rope_parameters.get("mrope_section", [24, 20, 20])
 
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[MossVLTextConfig] = None,
+        device: Optional[torch.device] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, float]:
+        rope_parameters = getattr(config, "rope_parameters", None) or {}
+        base = rope_parameters.get("rope_theta", getattr(config, "rope_theta", 10000.0))
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        partial_rotary_factor = rope_parameters.get(
+            "partial_rotary_factor", getattr(config, "partial_rotary_factor", 1.0)
+        )
+        dim = int(head_dim * partial_rotary_factor)
+
+        attention_factor = 1.0
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
     def apply_interleaved_mrope(self, freqs, mrope_section):
         """Apply interleaved MRoPE to 3D rotary embeddings.
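For intuition, a small numeric sketch of the default inverse-frequency formula above; the `dim` and `base` values are illustrative, not read from any config:

```python
import torch

dim, base = 128, 10000.0  # hypothetical head_dim and rope_theta
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq[0].item())   # 1.0      -> fastest-rotating channel pair
print(inv_freq[-1].item())  # ~1.15e-4 -> slowest-rotating channel pair
# The rotation angle for position p and frequency index i is p * inv_freq[i].
```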
@@
     @torch.no_grad()
     @dynamic_rope_update
     def forward(self, x, position_ids):
         if position_ids.ndim == 2:
             position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
 
@@
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
         attn_output, attn_weights = attention_interface(
             self,
@@
         attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Cache] = None,
         use_cache: bool = None,
+        cache_position: Optional[torch.LongTensor] = None,
         query_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs,
@@
         if past_key_values is not None:
             # if we have a new image + new tokens, we only computed key_states on that new image
             # we still update the cross key states, past_image, new_image. And use it!
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
         elif cache_position[0] != 0:
             key_states, value_states = (
@@
                 "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
             )
 
+        if is_flash_attention_requested(self.config):
+            # Cross attention still relies on an explicit dense mask.
+            attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
+        else:
+            attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
+                self.config._attn_implementation, eager_attention_forward
+            )
 
         attn_output, attn_weights = attention_interface(
             self,
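A short aside on the SDPA fallback above: unlike the flash path, `torch.nn.functional.scaled_dot_product_attention` accepts the dense additive mask that this cross-attention layer builds per text/vision token pair. A toy shape check, with all sizes illustrative:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 3, 8)   # (batch, heads, text_len, head_dim)
k = torch.randn(1, 2, 5, 8)   # (batch, heads, vision_len, head_dim)
v = torch.randn(1, 2, 5, 8)
mask = torch.zeros(1, 1, 3, 5)
mask[..., 3:] = float("-inf")  # e.g. hide later vision tokens from all text tokens
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 3, 8])
```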
@@
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         vision_position_ids: Optional[torch.LongTensor] = None,
         vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        output_attentions: bool = False,
         **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, ...]:
         # Self Attention
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
@@
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (attn_weights,)
+        return outputs
 
 
 class MossVLCrossAttentionDecoderLayer(GradientCheckpointingLayer):
@@
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         vision_position_ids: Optional[torch.LongTensor] = None,
         vision_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        output_attentions: bool = False,
         **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, ...]:
         # Cross Attention
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
+        hidden_states, attn_weights = self.cross_attn(
             hidden_states=hidden_states,
             cross_attention_states=cross_attention_states,
             attention_mask=cross_attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
+            cache_position=cache_position,
             query_position_embeddings=position_embeddings,
             vision_position_embeddings=vision_position_embeddings,
         )
@@
         hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
 
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (attn_weights,)
+        return outputs
 
 
 
@@
 
     def _init_weights(self, module):
         """Initialize the weights.
         """
+        super()._init_weights(module)
+        if isinstance(module, MossVLVisionRotaryEmbedding):
+            init.copy_(module.inv_freq, module.compute_inv_freq())
 
 
 
@@
 
     def fast_pos_embed_interpolate(self, grid_thw):
         grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+        device = self.pos_embed.weight.device
+        dtype = self.pos_embed.weight.dtype
 
+        idx_parts = [[] for _ in range(4)]
+        weight_parts = [[] for _ in range(4)]
 
         for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
+            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
 
             h_idxs_floor = h_idxs.int()
             w_idxs_floor = w_idxs.int()
@@
             ]
 
             for i in range(4):
+                idx_parts[i].append(indices[i])
+                weight_parts[i].append(weights[i])
 
+        idx_tensor = torch.stack([torch.cat(parts) for parts in idx_parts]).to(dtype=torch.long)
+        weight_tensor = torch.stack([torch.cat(parts) for parts in weight_parts]).to(dtype=dtype)
         pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
         patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
 
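To make the interpolation above concrete, here is a 1D sketch of the floor/ceil index and weight computation; the grid sizes are made up, and the real method applies this jointly over h and w (four corners instead of two neighbors):

```python
import torch

num_grid_per_side, target = 4, 6
idxs = torch.linspace(0, num_grid_per_side - 1, target)
floor = idxs.int()
ceil = (floor + 1).clamp(max=num_grid_per_side - 1)
frac = idxs - floor
# Each target position mixes two neighboring grid embeddings with
# weights (1 - frac) and frac; summing the weighted lookups resamples
# the learned position table to an arbitrary resolution.
print(list(zip(floor.tolist(), ceil.tolist(), (1 - frac).tolist(), frac.tolist())))
```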
@@
         vision_position_ids: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[tuple, BaseModelOutputWithPast]:
         """
@@
             Attention mask for cross-attention between text and vision. Shape: `(batch_size, 1, text_seq_len, vision_seq_len)`.
         vision_position_ids (`torch.LongTensor`, *optional*):
             Position IDs for vision tokens used in cross-attention. Shape: `(batch_size, vision_seq_len)`.
+        cache_position (`torch.LongTensor`, *optional*):
+            Absolute cache positions for the current text tokens during incremental decoding.
         """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@
 
         attention_mask = create_causal_mask(
             config=self.config,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
@@
         # Compute vision position embeddings (for cross-attention key/value) if needed
         vision_position_embeddings = None
         if cross_attention_states is not None:
             if vision_position_ids is not None:
                 vision_position_embeddings = self.rotary_emb(cross_attention_states, vision_position_ids)
 
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
 
         for idx, decoder_layer in enumerate(self.layers):
             # For text-only path we should skip cross attention layers.
@@
                 cross_attention_states=cross_attention_states,
                 cross_attention_mask=cross_attention_mask,
                 vision_position_ids=vision_position_ids,
                 vision_position_embeddings=vision_position_embeddings,
+                output_attentions=output_attentions,
                 **kwargs,
             )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions += (layer_outputs[1],)
+
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
 
         hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states[:-1] + (hidden_states,)
+
+        if not return_dict:
+            outputs = (hidden_states, past_key_values)
+            if output_hidden_states:
+                outputs += (all_hidden_states,)
+            if output_attentions:
+                outputs += (all_attentions,)
+            return outputs
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
         )
 
 
@@
         super().__init__(config)
         self.visual = MossVLVisionModel._from_config(config.vision_config)
         self.language_model = MossVLTextModel._from_config(config.text_config)
 
         # Learnable Separator Token: inserted after each image/frame's vision tokens
         # Initialized from LLM's separator_token_init_id embedding
@@
                 continue
 
             # Collect repetition counts for all frames in this sample
+            repeats_parts = []
             for media in medias:
                 num_frames = media.get('num_frames', 1)
                 length = media['length']
@@
 
                 # In convert_packed_to_batch we enforce strictly regular frames
                 # so we can assume all frames have the same number of tokens
+                repeats_parts.append(
+                    torch.full(
+                        (num_frames,),
+                        tokens_per_frame_with_sep,
+                        dtype=torch.long,
+                        device=cross_attention_mask.device,
+                    )
+                )
 
+            num_valid_frames = sum(part.numel() for part in repeats_parts)
             if num_valid_frames == 0:
                 continue
 
             # If cross_attention_mask has more frames (e.g. padded), slice it
             # If it has fewer (shouldn't happen), slice repeats
             valid_mask_frames = min(num_valid_frames, cross_attention_mask.shape[-1])
+            repeats_tensor = torch.cat(repeats_parts)
             if valid_mask_frames < num_valid_frames:
+                repeats_tensor = repeats_tensor[:valid_mask_frames]
 
             # Extract valid columns for this sample
             # (1, text_len, valid_mask_frames)
             source_mask = cross_attention_mask[i, :, :, :valid_mask_frames]
 
             # Expand using repeat_interleave
             # output shape: (1, text_len, sum(repeats))
             expanded_mask = source_mask.repeat_interleave(repeats_tensor, dim=-1)
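The frame-to-token expansion above is easy to verify in isolation; a tiny sketch with made-up sizes (2 frames, 3 text tokens, 4 vision tokens per frame including the separator):

```python
import torch

frame_mask = torch.tensor([[1, 0], [1, 1], [0, 1]])        # (text_len=3, frames=2)
repeats = torch.tensor([4, 4])                              # vision tokens per frame
token_mask = frame_mask.repeat_interleave(repeats, dim=-1)  # (3, 8)
print(token_mask)
# Each frame column is duplicated once per vision token of that frame,
# so a text token that could see a frame now sees all of its tokens.
```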
@@
         self,
         input_ids: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         """
         Compute 3D position IDs for text tokens with special handling for image tokens.
@@
         Args:
             input_ids: (batch_size, seq_len)
             attention_mask: (batch_size, seq_len), optional
+            past_key_values: cache object used to infer decode offset from the current text cache length
 
         Returns:
             position_ids: (3, batch_size, seq_len)
@@
         device = input_ids.device
         image_token_id = self.config.image_token_id
 
+        # Decode stage: always advance positions from the current text cache length.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        if past_seen_tokens > 0:
             position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
+            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+            position_ids = position_ids + past_seen_tokens
+
+            if rope_deltas is not None:
+                position_ids = position_ids + rope_deltas.unsqueeze(1)
+
+            return position_ids.unsqueeze(0).expand(3, -1, -1)
 
         # Prefill stage: compute full position_ids with image token awareness
         # Vectorized implementation
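A self-contained sketch of the decode-stage arithmetic above, with all numbers illustrative (cache of 10 text tokens, batch of 2, one new token each, per-sample vision offsets):

```python
import torch

past_seen_tokens, batch_size, seq_len = 10, 2, 1
rope_deltas = torch.tensor([3, 5])  # extra positions contributed by vision tokens
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + past_seen_tokens
position_ids = position_ids + rope_deltas.unsqueeze(1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # identical t/h/w planes
print(position_ids[:, 0, 0])  # tensor([13, 13, 13])
print(position_ids[:, 1, 0])  # tensor([15, 15, 15])
```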
@@
         rope_deltas: (batch_size,) - position offset due to vision tokens
         """
         batch_size, max_vision_seq_len, _ = cross_attention_states.shape
+        device = cross_attention_states.device
         image_token_id = self.config.image_token_id
         merge_size = self.visual.spatial_merge_size
@@
         # We need to flatten the nested vision_token_info structure to align with image tokens in input_ids
 
         # Find all image tokens in text: (num_occurrences, 2) -> [batch_idx, seq_idx]
+        image_token_indices = (input_ids == image_token_id).nonzero()
 
         # Flatten vision_token_info to parallel lists
         # We assume the order of medias in vision_token_info matches the appearance of image tokens in input_ids
+        flat_eff_h_parts = []
+        flat_eff_w_parts = []
+        flat_vis_start_parts = []
+
         # Processing metadata on CPU (fast enough for typical batch sizes)
         for b_idx, info in enumerate(vision_token_info):
             medias = info.get('medias', [])
@@
                 start = media['start']
                 tok_per_frame = media['vision_tokens_per_frame']
                 stride = tok_per_frame + 1  # +1 for separator
+
+                frame_offsets = start + torch.arange(num_frames, device=device, dtype=torch.long) * stride
+                flat_vis_start_parts.append(frame_offsets)
+                flat_eff_h_parts.append(torch.full((num_frames,), eh, device=device, dtype=torch.long))
+                flat_eff_w_parts.append(torch.full((num_frames,), ew, device=device, dtype=torch.long))
 
         # Pre-allocate output
         vision_pos_ids = torch.zeros(
@@
         )
 
         # Handle case where no image tokens or info
+        if len(flat_eff_h_parts) == 0 or len(image_token_indices) == 0:
             rope_deltas = position_ids.max(dim=0).values.max(dim=-1).values + 1 - input_ids.shape[1]
             return vision_pos_ids, position_ids, rope_deltas
 
+        flat_eff_h = torch.cat(flat_eff_h_parts)
+        flat_eff_w = torch.cat(flat_eff_w_parts)
+        flat_vis_starts = torch.cat(flat_vis_start_parts)
+
         # Align lengths (handle truncation if text has fewer tokens or vice versa)
+        num_matches = min(flat_eff_h.shape[0], image_token_indices.shape[0])
+        flat_eff_h = flat_eff_h[:num_matches]
+        flat_eff_w = flat_eff_w[:num_matches]
+        flat_vis_starts = flat_vis_starts[:num_matches]
 
         # Get corresponding text positions
         target_indices = image_token_indices[:num_matches]
@@
         )
         return vision_embeds, vision_token_info
 
 
 @auto_docstring
@@
         media_nums_per_sample: Optional[List[int]] = None,
         vision_position_ids: Optional[torch.LongTensor] = None,
         cross_attention_mask: Optional[torch.Tensor] = None,
+        vision_token_info: Optional[List[dict]] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, BaseModelOutputWithPast]:
         """
@@
         cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
             Attention mask for cross-attention between text and vision. Controls which vision tokens each text
             token can attend to, enforcing causal visibility for video frames.
+        vision_token_info (`List[dict]`, *optional*):
+            Cached metadata describing how packed vision tokens were regrouped per sample. Reused in decode
+            to expand frame-level cross-attention masks to token-level masks without recomputing vision features.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Cached offsets between text sequence length and multimodal RoPE positions. Reused in decode to
+            reconstruct text position ids from the current cache length.
         """
+        cache_position = kwargs.pop("cache_position", None)
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@
 
         # Process vision features (images and videos are already merged by processor)
         cross_attention_states = None
+
         if pixel_values is not None:
             # Determine batch size
             batch_size = inputs_embeds.shape[0]
@@
 
             # Process all vision inputs together through VIT
             # pixel_values and grid_thw are already ordered by appearance in text
+            vision_embeds, vision_token_info = self.get_vision_features(
+                pixel_values, grid_thw, media_nums_per_sample
             )
 
             # vision_embeds: [batch_size, max_seq_len, hidden_size]
             cross_attention_states = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
 
         # Generate 3D position IDs for text if not provided
         if position_ids is None:
@@
             position_ids = self.compute_position_ids(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                rope_deltas=rope_deltas,
             )
 
         # Compute cross_attention_mask, vision_position_ids, and full_text_row_masked_out_mask
@@
             (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
         )
         cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask
 
         if vision_position_ids is None and cross_attention_states is not None and input_ids is not None:
             vision_position_ids, position_ids, rope_deltas = self.compute_vision_position_ids(
@@
                 cross_attention_states,
                 attention_mask
             )
 
         outputs = self.language_model(
             input_ids=None,
@@
             cross_attention_mask=cross_attention_mask,
             vision_position_ids=vision_position_ids,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
             **kwargs,
         )
 
+        if not return_dict:
+            last_hidden_state = outputs[0]
+            model_outputs = (
+                last_hidden_state,
+                outputs[1] if len(outputs) > 1 else past_key_values,
+            )
+            if output_hidden_states:
+                model_outputs += (outputs[2],)
+            if output_attentions:
+                attn_idx = 3 if output_hidden_states else 2
+                model_outputs += (outputs[attn_idx],)
+            model_outputs += (vision_token_info, rope_deltas)
+            return model_outputs
+
         return MossVLModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            vision_token_info=vision_token_info,
+            rope_deltas=rope_deltas,
         )
 
 
@@
         super().__init__(config)
         self.model = MossVLModel(config)
         self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
 
         self.post_init()
| 2159 |
media_nums_per_sample: Optional[List[int]] = None,
|
| 2160 |
vision_position_ids: Optional[torch.LongTensor] = None,
|
| 2161 |
cross_attention_mask: Optional[torch.Tensor] = None,
|
| 2162 |
+
vision_token_info: Optional[List[dict]] = None,
|
| 2163 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
| 2164 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 2165 |
+
output_attentions: Optional[bool] = None,
|
| 2166 |
+
output_hidden_states: Optional[bool] = None,
|
| 2167 |
+
return_dict: Optional[bool] = None,
|
| 2168 |
**kwargs: Unpack[TransformersKwargs],
|
| 2169 |
) -> Union[tuple, CausalLMOutputWithPast]:
|
| 2170 |
"""
|
|
|
|
| 2181 |
cross_attention_mask (`torch.Tensor` of shape `(batch_size, 1, text_seq_len, vision_seq_len)`, *optional*):
|
| 2182 |
Attention mask for cross-attention between text and vision. Controls which vision tokens each text
|
| 2183 |
token can attend to, enforcing causal visibility for video frames.
|
| 2184 |
+
vision_token_info (`List[dict]`, *optional*):
|
| 2185 |
+
Cached metadata describing how packed vision tokens were regrouped per sample. Reused across decode
|
| 2186 |
+
steps to expand cross-attention masks without re-running the vision encoder.
|
| 2187 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 2188 |
+
Cached multimodal RoPE offsets returned by the base model during prefill and reused during decode.
|
| 2189 |
"""
|
| 2190 |
+        cache_position = kwargs.pop("cache_position", None)
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,

            cross_attention_mask=cross_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
+            vision_token_info=vision_token_info,
+            rope_deltas=rope_deltas,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

+        if not return_dict:
+            output = (logits,)
+            output += outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
        return MossVLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
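# Usage sketch for `logits_to_keep` above (tensor names are illustrative):
# keeping only the final position avoids materializing full-vocab logits for
# the whole prompt during generation.
#     out = model(input_ids=input_ids, logits_to_keep=1)
#     next_token = out.logits[:, -1, :].argmax(dim=-1)   # shape [batch]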
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        grid_thw=None,
        media_nums_per_sample=None,  # One video is one media item.
        vision_position_ids=None,
+        vision_token_info=None,
+        rope_deltas=None,
        cross_attention_mask=None,
        **kwargs,
    ):
        """

        Args:
            media_nums_per_sample: One video counts as one media item (regardless of frame count)
        """
+        kwargs.pop("cache_position", None)
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            pixel_values=pixel_values,
            grid_thw=grid_thw,

            **kwargs,
        )

+        model_input = model_inputs.get("input_ids")
+        if model_input is None:
+            model_input = model_inputs.get("inputs_embeds")
+        current_length = model_input.shape[1]
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

+        # Let the model recompute multimodal position ids from the current cache length.
        model_inputs["position_ids"] = None
+        model_inputs["vision_token_info"] = vision_token_info
+        model_inputs["rope_deltas"] = rope_deltas

        # Handle cross attention mask
        if cross_attention_mask is not None:
+            # Slice to the current text slice on the text dimension (dim=2).
+            # Shape: [batch, 1, text_len, vision_len] -> [batch, 1, current_len, vision_len]
+            cross_attention_mask = cross_attention_mask[:, :, -current_length:, :]
        model_inputs["cross_attention_mask"] = cross_attention_mask

+        # Vision inputs are only needed in the prefill stage.
        # In the decode stage, vision features are retrieved from the cross-attention cache.
+        if past_seen_tokens > 0:
            model_inputs["pixel_values"] = None
            model_inputs["grid_thw"] = None
            model_inputs["media_nums_per_sample"] = None

        else:
            # In the prefill stage, include all vision-related inputs.
            model_inputs["vision_position_ids"] = vision_position_ids

        return model_inputs
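# Shape sketch for the decode-step mask slice above (sizes are illustrative):
#     prefill: current_length == text_len, e.g. [2, 1, 10, 64] stays [2, 1, 10, 64]
#     decode:  current_length == 1, so [2, 1, 10, 64][:, :, -1:, :] -> [2, 1, 1, 64]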
            **kwargs,
        )

        if cross_attention_mask_prev is not None:
+            model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
+
+        if getattr(outputs, "vision_token_info", None) is not None:
+            model_kwargs["vision_token_info"] = outputs.vision_token_info
+        if getattr(outputs, "rope_deltas", None) is not None:
+            model_kwargs["rope_deltas"] = outputs.rope_deltas

        return model_kwargs
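# Flow sketch: during generate(), cached multimodal state hops between hooks:
#     forward() returns vision_token_info / rope_deltas on its output,
#     _update_model_kwargs_for_generation() copies them into model_kwargs (above),
#     prepare_inputs_for_generation() feeds them into the next decode step.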

__all__ = [
    "MossVLVisionModel",