huiwon committed on
Commit faaf53b · verified · 1 Parent(s): c6c73d4

Add modeling_qwen3_vl.py

Files changed (1)
  1. modeling_qwen3_vl.py +1575 -0
modeling_qwen3_vl.py ADDED
@@ -0,0 +1,1575 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change is needed, please apply it to the
# modular_qwen3_vl.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig


class Qwen3VLVisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))


class Qwen3VLVisionPatchEmbed(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states

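# A shape sketch for the patch embed above (illustrative comment only, not
# executed; the concrete sizes are assumptions, not values read from a config):
#
#   with in_channels=3, temporal_patch_size=2, patch_size=16, hidden_size=1152,
#   the processor hands this module rows of 3 * 2 * 16 * 16 = 1536 pixels, so
#   for N patches:
#
#       pixels = torch.randn(N, 1536)
#       embeds = patch_embed(pixels)  # view -> (N, 3, 2, 16, 16), Conv3d -> (N, 1152)
#
#   since kernel_size == stride, the Conv3d reduces to a non-overlapping linear
#   projection of each (2, 16, 16) spatio-temporal patch.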

class Qwen3VLVisionRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class Qwen3VLVisionPatchMerger(nn.Module):
    def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
        x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        return x


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class Qwen3VLVisionAttention(nn.Module):
    def __init__(self, config: Qwen3VLVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen3VLVisionBlock(GradientCheckpointingLayer):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen3VLVisionAttention(config=config)
        self.mlp = Qwen3VLVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class Qwen3VLTextRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3VLTextConfig, device=None):
        super().__init__()
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", "default")
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

        self.mrope_section = (config.rope_scaling or {}).get("mrope_section", [24, 20, 20])

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHWTHW...TT], preserving frequency continuity.
        Args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In contrast to other models, Qwen3VL has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@use_kernel_forward_from_hub("RMSNorm")
class Qwen3VLTextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen3VLTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Qwen3VLTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3VLTextRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen3VLTextMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)

        self.mlp = Qwen3VLTextMLP(config)
        self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3VL outputs, with hidden states and attentions.
    """
)
class Qwen3VLModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None


@auto_docstring
class Qwen3VLPreTrainedModel(PreTrainedModel):
    config: Qwen3VLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen3VLTextDecoderLayer,
        "attentions": Qwen3VLTextAttention,
    }


class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
    config: Qwen3VLVisionConfig
    _no_split_modules = ["Qwen3VLVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(
            config=config,
        )

        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3VLVisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3VLVisionPatchMerger(
                    config=config,
                    use_postshuffle_norm=True,
                )
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )

        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        merge_size = self.spatial_merge_size

        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)  # block row indices
            block_cols = torch.arange(merged_w, device=device)  # block col indices
            intra_row = torch.arange(merge_size, device=device)  # intra-block row offsets
            intra_col = torch.arange(merge_size, device=device)  # intra-block col offsets

            # Compute full-resolution positions
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
        weight_tensor = torch.tensor(
            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
        )
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The flattened pixel patches to embed.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width dimensions of the patch grid for each image or video.

        Returns:
            The merged hidden states together with the list of deepstack features.
        """
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)

        hidden_states = self.merger(hidden_states)

        return hidden_states, deepstack_feature_lists


@auto_docstring(
    custom_intro=(
        "Text part of Qwen3VL, "
        "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
    )
)
class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
    config: Qwen3VLTextConfig
    _no_split_modules = ["Qwen3VLTextDecoderLayer"]

    def __init__(self, config: Qwen3VLTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # args for deepstack
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
            The mask of the visual positions.
        deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
            The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
            The features are extracted from different visual encoder layers and added to the decoder
            hidden states, following DeepStack (https://arxiv.org/abs/2406.04334).
        """
        # if (input_ids is None) ^ (inputs_embeds is not None):
        #     raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]

        attention_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_outputs = decoder_layer(
                hidden_states,
                input_ids,
                attention_mask=attention_mask,
                position_ids=text_position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            ## FIXME: HARD CODING
            # (assumption) the layer is expected to be a custom wrapper that takes
            # input_ids positionally and returns (hidden_states, updated_kwargs);
            # any kwargs it overrides are carried over to the next layer.
            hidden_states = layer_outputs[0]
            if "attention_mask" in layer_outputs[1]:
                attention_mask = layer_outputs[1]["attention_mask"]
            if "position_ids" in layer_outputs[1]:
                text_position_ids = layer_outputs[1]["position_ids"]
            if "past_key_values" in layer_outputs[1]:
                past_key_values = layer_outputs[1]["past_key_values"]
            if "cache_position" in layer_outputs[1]:
                cache_position = layer_outputs[1]["cache_position"]
            if "position_embeddings" in layer_outputs[1]:
                position_embeddings = layer_outputs[1]["position_embeddings"]

            # add visual features to the hidden states of first several layers
            if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _deepstack_process(
        self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
    ):
        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
        hidden_states[visual_pos_masks, :] = local_this
        return hidden_states


@auto_docstring
class Qwen3VLModel(Qwen3VLPreTrainedModel):
    base_model_prefix = ""
    _checkpoint_conversion_mapping = {}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    config: Qwen3VLConfig
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen3VLVisionModel._from_config(config.vision_config)
        self.language_model = Qwen3VLTextModel._from_config(config.text_config)
        self.rope_deltas = None  # cache rope_deltas here

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Different from the original implementation, Qwen3VL uses timestamps rather than absolute time position ids."""

        # Since we use timestamps to separate videos, like <t1> <vision_start> <frame1> <vision_end>
        # <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split.
        if video_grid_thw is not None:
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # t_index is always 0 because llm_grid_t is always 1
                    # (we use timestamps to encode the temporal information for videos)
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """
        Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.

        Args:
            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input videos.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width dimensions of the feature grid for each video fed to the LLM.
        """
        # Same implementation as for images
        return self.get_image_features(pixel_values_videos, video_grid_thw)

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width dimensions of the feature grid for each image fed to the LLM.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds, deepstack_image_embeds

    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: Optional[torch.FloatTensor] = None,
        video_features: Optional[torch.FloatTensor] = None,
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
            special_video_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_video_mask = special_video_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )

        n_video_tokens = special_video_mask.sum()
        special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
            )

        return special_image_mask, special_video_mask

1118
+ @auto_docstring
1119
+ @check_model_inputs
1120
+ def forward(
1121
+ self,
1122
+ input_ids: torch.LongTensor = None,
1123
+ attention_mask: Optional[torch.Tensor] = None,
1124
+ position_ids: Optional[torch.LongTensor] = None,
1125
+ past_key_values: Optional[Cache] = None,
1126
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1127
+ pixel_values: Optional[torch.Tensor] = None,
1128
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1129
+ image_grid_thw: Optional[torch.LongTensor] = None,
1130
+ video_grid_thw: Optional[torch.LongTensor] = None,
1131
+ cache_position: Optional[torch.LongTensor] = None,
1132
+ **kwargs: Unpack[TransformersKwargs],
1133
+ ) -> Union[tuple, Qwen3VLModelOutputWithPast]:
1134
+ r"""
1135
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1136
+ The temporal, height and width of feature shape of each image in LLM.
1137
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1138
+ The temporal, height and width of feature shape of each video in LLM.
1139
+ """
1140
+ if (input_ids is None) ^ (inputs_embeds is not None):
1141
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1142
+
1143
+ if inputs_embeds is None:
1144
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1145
+
1146
+ image_mask = None
1147
+ video_mask = None
1148
+
1149
+ if pixel_values is not None:
1150
+ image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
1151
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1152
+ image_mask, _ = self.get_placeholder_mask(
1153
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
1154
+ )
1155
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1156
+
1157
+ if pixel_values_videos is not None:
1158
+ video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
1159
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1160
+ _, video_mask = self.get_placeholder_mask(
1161
+ input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
1162
+ )
1163
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1164
+
1165
+ visual_pos_masks = None
1166
+ deepstack_visual_embeds = None
1167
+ if image_mask is not None and video_mask is not None:
1168
+ # aggregate visual_pos_masks and deepstack_visual_embeds
1169
+ image_mask = image_mask[..., 0]
1170
+ video_mask = video_mask[..., 0]
1171
+ visual_pos_masks = image_mask | video_mask
1172
+ deepstack_visual_embeds = []
1173
+ image_mask_joint = image_mask[visual_pos_masks]
1174
+ video_mask_joint = video_mask[visual_pos_masks]
1175
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
1176
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
1177
+ embed_joint[image_mask_joint, :] = img_embed
1178
+ embed_joint[video_mask_joint, :] = vid_embed
1179
+ deepstack_visual_embeds.append(embed_joint)
1180
+ elif image_mask is not None:
1181
+ image_mask = image_mask[..., 0]
1182
+ visual_pos_masks = image_mask
1183
+ deepstack_visual_embeds = deepstack_image_embeds
1184
+ elif video_mask is not None:
1185
+ video_mask = video_mask[..., 0]
1186
+ visual_pos_masks = video_mask
1187
+ deepstack_visual_embeds = deepstack_video_embeds
1188
+
1189
+ if position_ids is None:
1190
+ attention_mask_tensor = (
1191
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
1192
+ )
1193
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
1194
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
1195
+ # Only apply conversion for floating point tensors (inverted masks)
1196
+ if attention_mask_tensor.dtype.is_floating_point:
1197
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
1198
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
+
+             # Calculate the RoPE index once per generation, in the pre-fill stage only.
+             # When compiling, we can't check tensor values, so we check only the input length.
+             # It is safe to assume that `length != 1` means we're in pre-fill, because compiled
+             # models currently cannot do assisted decoding.
+             prefill_compiled_stage = is_torchdynamo_compiling() and (
+                 (input_ids is not None and input_ids.shape[1] != 1)
+                 or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+             )
+             prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
+                 (cache_position is not None and cache_position[0] == 0)
+                 or (past_key_values is None or past_key_values.get_seq_length() == 0)
+             )
+             if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
+                 position_ids, rope_deltas = self.get_rope_index(
+                     input_ids,
+                     image_grid_thw,
+                     video_grid_thw,
+                     attention_mask=attention_mask_tensor,
+                 )
+                 self.rope_deltas = rope_deltas
+             # otherwise, reuse the previously calculated rope-deltas to get the correct position ids
+             else:
+                 batch_size, seq_length, _ = inputs_embeds.shape
+                 delta = (
+                     (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                     if cache_position is not None
+                     else 0
+                 )
+                 position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                 if cache_position is not None:  # otherwise `delta` is the int `0`
+                     delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                 position_ids = position_ids.add(delta)
+                 position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
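+         # Worked example: if pre-fill produced self.rope_deltas = [-2] for a batch of
+         # one and decoding resumes at cache_position = [10], the text position id for
+         # this step is 10 + (-2) = 8, broadcast over the three multimodal RoPE axes
+         # (temporal, height, width) by the expand(3, -1, -1) above.
+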
+         # FIXED: hand-patched relative to the auto-generated upstream file.
+         outputs = self.language_model(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             cache_position=cache_position,
+             visual_pos_masks=visual_pos_masks,
+             deepstack_visual_embeds=deepstack_visual_embeds,
+             **kwargs,
+         )
+
+         return Qwen3VLModelOutputWithPast(
+             last_hidden_state=outputs.last_hidden_state,
+             past_key_values=outputs.past_key_values,
+             rope_deltas=self.rope_deltas,
+         )
+
+
+ @dataclass
+ @auto_docstring(
+     custom_intro="""
+     Base class for Qwen3VL causal language model (or autoregressive) outputs.
+     """
+ )
+ class Qwen3VLCausalLMOutputWithPast(ModelOutput):
+     r"""
+     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+         Language modeling loss (for next-token prediction).
+     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+     past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+         It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+         Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+         `past_key_values` input) to speed up sequential decoding.
+     rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+         The difference between the multimodal RoPE position index and the plain sequence-position index.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[torch.FloatTensor] = None
+     past_key_values: Optional[Cache] = None
+     hidden_states: Optional[tuple[torch.FloatTensor]] = None
+     attentions: Optional[tuple[torch.FloatTensor]] = None
+     rope_deltas: Optional[torch.LongTensor] = None
+
+
+ class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
+     _checkpoint_conversion_mapping = {}
+     _tied_weights_keys = ["lm_head.weight"]
+     # Reference: fix gemma3 grad acc #37208
+     accepts_loss_kwargs = False
+     config: Qwen3VLConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = Qwen3VLModel(config)
+         self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.model.set_input_embeddings(value)
+
+     def set_decoder(self, decoder):
+         self.model.set_decoder(decoder)
+
+     def get_decoder(self):
+         return self.model.get_decoder()
+
+     def get_video_features(
+         self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+     ):
+         return self.model.get_video_features(pixel_values_videos, video_grid_thw)
+
+     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+         return self.model.get_image_features(pixel_values, image_grid_thw)
+
+     # Make the submodules available on the conditional-generation class for backward compatibility
+     @property
+     def language_model(self):
+         return self.model.language_model
+
+     @property
+     def visual(self):
+         return self.model.visual
+
+     @check_model_inputs
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         pixel_values_videos: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         video_grid_thw: Optional[torch.LongTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+             The temporal, height and width dimensions of the feature grid for each image fed to the LLM.
+         video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+             The temporal, height and width dimensions of the feature grid for each video fed to the LLM.
+
+         Example:
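+
+         The following is an illustrative sketch: the checkpoint id is a placeholder (not a verified released
+         model) and the image URL is a stand-in.
+
+         ```python
+         >>> from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+
+         >>> # Placeholder checkpoint id -- substitute a real Qwen3-VL checkpoint.
+         >>> model = Qwen3VLForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL", torch_dtype="auto")
+         >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL")
+
+         >>> messages = [
+         ...     {
+         ...         "role": "user",
+         ...         "content": [
+         ...             {"type": "image", "url": "https://example.com/demo.jpeg"},
+         ...             {"type": "text", "text": "Describe this image."},
+         ...         ],
+         ...     }
+         ... ]
+         >>> inputs = processor.apply_chat_template(
+         ...     messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+         ... )
+         >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
+         >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+         ```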
+         """
+         outputs = self.model(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             pixel_values_videos=pixel_values_videos,
+             image_grid_thw=image_grid_thw,
+             video_grid_thw=video_grid_thw,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
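+         # (logits_to_keep == 0 keeps every position, since slice(-0, None) == slice(0, None).)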
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             # FIXED: hand-patched relative to the auto-generated file -- align the labels
+             # with the (possibly sliced) logits before computing the loss.
+             loss = self.loss_function(
+                 logits=logits,
+                 labels=labels[..., -logits.shape[1]:],
+                 vocab_size=self.config.text_config.vocab_size,
+             )
+
+         return Qwen3VLCausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             rope_deltas=outputs.rope_deltas,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         cache_position=None,
+         position_ids=None,
+         use_cache=True,
+         pixel_values=None,
+         pixel_values_videos=None,
+         image_grid_thw=None,
+         video_grid_thw=None,
+         **kwargs,
+     ):
+         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+         model_inputs = super().prepare_inputs_for_generation(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             pixel_values=pixel_values,
+             pixel_values_videos=pixel_values_videos,
+             image_grid_thw=image_grid_thw,
+             video_grid_thw=video_grid_thw,
+             use_cache=use_cache,
+             **kwargs,
+         )
+
+         # Qwen3VL position_ids are prepared with rope_deltas in forward
+         model_inputs["position_ids"] = None
+
+         if cache_position[0] != 0:
+             model_inputs["pixel_values"] = None
+             model_inputs["pixel_values_videos"] = None
+
+         return model_inputs
+
+     def _get_image_nums_and_video_nums(
+         self,
+         input_ids: Optional[torch.LongTensor],
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Get the number of images and videos in each sample, used to compute the per-sample split lengths of the
+         visual tensors. These counts are recomputed here rather than passed through the processor, so that changes
+         to the processor interface cannot affect them.
+
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary.
+
+         Returns:
+             image_nums (`torch.LongTensor` of shape `(batch_size,)`): number of images in each sample
+             video_nums (`torch.LongTensor` of shape `(batch_size,)`): number of videos in each sample
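+
+         Example (token ids are illustrative, not real vocabulary values): with
+         `vision_start_token_id=9`, `image_token_id=7` and `video_token_id=8`,
+         `input_ids = [[1, 9, 7, 7, 2, 9, 8, 8]]` gives `image_nums = [1]` and
+         `video_nums = [1]`: rolling the vision-start mask right by one flags the token
+         immediately after each vision-start token, and intersecting with the image/video
+         masks counts each image/video block exactly once.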
+         """
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+
+         if inputs_embeds is not None:
+             vision_start_mask = (
+                 inputs_embeds
+                 == self.get_input_embeddings()(
+                     torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                 )
+             )[..., 0]
+             image_mask = (
+                 inputs_embeds
+                 == self.get_input_embeddings()(
+                     torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                 )
+             )[..., 0]
+             video_mask = (
+                 inputs_embeds
+                 == self.get_input_embeddings()(
+                     torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                 )
+             )[..., 0]
+         else:
+             vision_start_mask = input_ids == vision_start_token_id
+             image_mask = input_ids == image_token_id
+             video_mask = input_ids == video_token_id
+
+         vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+         image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+         video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+         return image_nums, video_nums
+
+     def _expand_inputs_for_generation(
+         self,
+         expand_size: int = 1,
+         is_encoder_decoder: bool = False,
+         input_ids: Optional[torch.LongTensor] = None,
+         **model_kwargs,
+     ) -> tuple[torch.LongTensor, dict[str, Any]]:
+         # Overwritten -- Support for expanding tensors without a batch size dimension
+         # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_ts
+         # pixel_values.shape[0] is sum(seqlen_images for samples)
+         # image_grid_thw.shape[0] is sum(num_images for samples)
+
+         if expand_size == 1:
+             return input_ids, model_kwargs
+
+         visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+         def _expand_dict_for_generation_visual(dict_to_expand):
+             image_grid_thw = model_kwargs.get("image_grid_thw", None)
+             video_grid_thw = model_kwargs.get("video_grid_thw", None)
+             image_nums, video_nums = self._get_image_nums_and_video_nums(
+                 input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+             )
+
+             def _repeat_interleave_samples(x, lengths, repeat_times):
+                 samples = torch.split(x, lengths)
+                 repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                 result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                 return result
+
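+             # Illustrative sketch (not executed): for two samples with image_nums = [1, 2]
+             # and per-image patch counts [60, 24, 24], pixel_values has 108 rows and is
+             # split into per-sample chunks of [60, 48]; with expand_size=2 each chunk is
+             # repeated back-to-back, yielding rows ordered [sample0, sample0, sample1, sample1].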
+             for key in dict_to_expand:
+                 if key == "pixel_values":
+                     # split images into samples
+                     samples = torch.split(image_grid_thw, list(image_nums))
+                     # compute the sequence length of images for each sample
+                     lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                     dict_to_expand[key] = _repeat_interleave_samples(
+                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                     )
+                 elif key == "image_grid_thw":
+                     # get the num of images for each sample
+                     lengths = list(image_nums)
+                     dict_to_expand[key] = _repeat_interleave_samples(
+                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                     )
+                 elif key == "pixel_values_videos":
+                     samples = torch.split(video_grid_thw, list(video_nums))
+                     lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                     dict_to_expand[key] = _repeat_interleave_samples(
+                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                     )
+                 elif key == "video_grid_thw":
+                     lengths = list(video_nums)
+                     dict_to_expand[key] = _repeat_interleave_samples(
+                         dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                     )
+                 elif key == "second_per_grid_ts":
+                     dict_to_expand[key] = _repeat_interleave_samples(
+                         dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
+                     )
+             return dict_to_expand
+
+         def _expand_dict_for_generation(dict_to_expand):
+             for key in dict_to_expand:
+                 if (
+                     key != "cache_position"
+                     and dict_to_expand[key] is not None
+                     and isinstance(dict_to_expand[key], torch.Tensor)
+                     and key not in visual_keys
+                 ):
+                     dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+             return dict_to_expand
+
+         model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+         if input_ids is not None:
+             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+         model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+         if is_encoder_decoder:
+             if model_kwargs.get("encoder_outputs") is None:
+                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+             model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+         return input_ids, model_kwargs
+
+
+ __all__ = [
+     "Qwen3VLVisionModel",
+     "Qwen3VLForConditionalGeneration",
+     "Qwen3VLModel",
+     "Qwen3VLPreTrainedModel",
+     "Qwen3VLTextModel",
+ ]