TYTTYTTYT committed on
Commit
3054cf2
·
verified ·
1 Parent(s): 9118e49

Fixed bug in resize logic

Browse files
image_processing_qwen2_vl.py CHANGED
@@ -2,15 +2,15 @@ import math
2
 
3
  import numpy as np
4
 
5
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
 
6
  from transformers.image_transforms import (
7
  convert_to_rgb,
8
  resize,
9
  to_channel_dimension_format,
10
  )
 
11
  from transformers.image_utils import (
12
- OPENAI_CLIP_MEAN,
13
- OPENAI_CLIP_STD,
14
  ChannelDimension,
15
  ImageInput,
16
  PILImageResampling,
@@ -23,7 +23,8 @@ from transformers.image_utils import (
23
  validate_preprocess_arguments,
24
  )
25
  from transformers.processing_utils import ImagesKwargs
26
- from transformers.utils import TensorType, logging
 
27
  from transformers.video_utils import VideoInput
28
 
29
 
@@ -137,6 +138,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
137
  patch_size: int = 14,
138
  temporal_patch_size: int = 2,
139
  merge_size: int = 2,
 
140
  **kwargs,
141
  ) -> None:
142
  super().__init__(**kwargs)
@@ -165,6 +167,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
165
  self.patch_size = patch_size
166
  self.temporal_patch_size = temporal_patch_size
167
  self.merge_size = merge_size
 
168
  self.do_convert_rgb = do_convert_rgb
169
 
170
  def _preprocess(
@@ -181,6 +184,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
181
  patch_size: int | None = None,
182
  temporal_patch_size: int | None = None,
183
  merge_size: int | None = None,
 
184
  do_convert_rgb: bool | None = None,
185
  data_format: ChannelDimension | None = ChannelDimension.FIRST,
186
  input_data_format: str | ChannelDimension | None = None,
@@ -228,16 +232,16 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
228
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
229
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
230
  """
231
- images = make_flat_list_of_images(images)
232
 
233
  if do_convert_rgb:
234
- images = [convert_to_rgb(image) for image in images]
235
 
236
  # All transformations expect numpy arrays.
237
- images = [to_numpy_array(image) for image in images]
238
 
239
  if do_rescale and is_scaled_image(images[0]):
240
- logger.warning_once(
241
  "It looks like you are trying to rescale already rescaled images. If the input"
242
  " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
243
  )
@@ -245,7 +249,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
245
  # We assume that all images have the same channel dimension format.
246
  input_data_format = infer_channel_dimension_format(images[0])
247
 
248
- height, width = get_image_size(images[0], channel_dim=input_data_format)
249
  resized_height, resized_width = height, width
250
  processed_images = []
251
  for image in images:
@@ -253,23 +257,23 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
253
  resized_height, resized_width = smart_resize(
254
  height,
255
  width,
256
- factor=patch_size * merge_size,
257
- min_pixels=size["shortest_edge"],
258
- max_pixels=size["longest_edge"],
259
  )
260
  image = resize(
261
  image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
262
  )
263
 
264
  if do_rescale:
265
- image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
266
 
267
  if do_normalize:
268
  image = self.normalize(
269
- image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
270
  )
271
 
272
- image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
273
  processed_images.append(image)
274
 
275
  patches = np.array(processed_images)
@@ -282,17 +286,17 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
282
  patches = np.concatenate([patches, repeats], axis=0)
283
  channel = patches.shape[1]
284
  grid_t = patches.shape[0] // temporal_patch_size
285
- grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
286
  patches = patches.reshape(
287
  grid_t,
288
- temporal_patch_size,
289
  channel,
290
- grid_h // merge_size,
291
- merge_size,
292
- patch_size,
293
- grid_w // merge_size,
294
- merge_size,
295
- patch_size,
296
  )
297
  patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
298
  flatten_patches = patches.reshape(
@@ -301,7 +305,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
301
 
302
  return flatten_patches, (grid_t, grid_h, grid_w)
303
 
304
- def preprocess(
305
  self,
306
  images: ImageInput,
307
  do_resize: bool | None = None,
@@ -403,7 +407,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
403
  do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
404
 
405
  if images is not None:
406
- images = self.fetch_images(images)
407
  images = make_flat_list_of_images(images)
408
 
409
  if images is not None and not valid_images(images):
@@ -421,7 +425,7 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
421
 
422
  data = {}
423
  pixel_values, vision_grid_thws = [], []
424
- for image in images:
425
  patches, image_grid_thw = self._preprocess(
426
  image,
427
  do_resize=do_resize,
@@ -461,12 +465,13 @@ class ZFQwen2VLImageProcessor(BaseImageProcessor):
461
  Returns:
462
  `int`: Number of image patches per image.
463
  """
464
- min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
465
- max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
466
- patch_size = images_kwargs.get("patch_size", self.patch_size)
467
- merge_size = images_kwargs.get("merge_size", self.merge_size)
 
468
 
469
- factor = patch_size * merge_size
470
  resized_height, resized_width = smart_resize(
471
  height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
472
  )
 
2
 
3
  import numpy as np
4
 
5
+ from transformers.image_processing_utils import BaseImageProcessor
6
+ from transformers.image_processing_base import BatchFeature
7
  from transformers.image_transforms import (
8
  convert_to_rgb,
9
  resize,
10
  to_channel_dimension_format,
11
  )
12
+ from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
13
  from transformers.image_utils import (
 
 
14
  ChannelDimension,
15
  ImageInput,
16
  PILImageResampling,
 
23
  validate_preprocess_arguments,
24
  )
25
  from transformers.processing_utils import ImagesKwargs
26
+ from transformers.utils.generic import TensorType
27
+ from transformers.utils import logging
28
  from transformers.video_utils import VideoInput
29
 
30
 
 
138
  patch_size: int = 14,
139
  temporal_patch_size: int = 2,
140
  merge_size: int = 2,
141
+ focus_size: int = 2,
142
  **kwargs,
143
  ) -> None:
144
  super().__init__(**kwargs)
 
167
  self.patch_size = patch_size
168
  self.temporal_patch_size = temporal_patch_size
169
  self.merge_size = merge_size
170
+ self.focus_size = focus_size
171
  self.do_convert_rgb = do_convert_rgb
172
 
173
  def _preprocess(
 
184
  patch_size: int | None = None,
185
  temporal_patch_size: int | None = None,
186
  merge_size: int | None = None,
187
+ focus_size: int | None = None,
188
  do_convert_rgb: bool | None = None,
189
  data_format: ChannelDimension | None = ChannelDimension.FIRST,
190
  input_data_format: str | ChannelDimension | None = None,
 
232
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
233
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
234
  """
235
+ images = make_flat_list_of_images(images) # type: ignore
236
 
237
  if do_convert_rgb:
238
+ images = [convert_to_rgb(image) for image in images] # type: ignore
239
 
240
  # All transformations expect numpy arrays.
241
+ images = [to_numpy_array(image) for image in images] # type: ignore
242
 
243
  if do_rescale and is_scaled_image(images[0]):
244
+ logger.warning_once( # type: ignore
245
  "It looks like you are trying to rescale already rescaled images. If the input"
246
  " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
247
  )
 
249
  # We assume that all images have the same channel dimension format.
250
  input_data_format = infer_channel_dimension_format(images[0])
251
 
252
+ height, width = get_image_size(images[0], channel_dim=input_data_format) # type: ignore
253
  resized_height, resized_width = height, width
254
  processed_images = []
255
  for image in images:
 
257
  resized_height, resized_width = smart_resize(
258
  height,
259
  width,
260
+ factor=patch_size * merge_size * focus_size, # type: ignore
261
+ min_pixels=size["shortest_edge"], # type: ignore
262
+ max_pixels=size["longest_edge"], # type: ignore
263
  )
264
  image = resize(
265
  image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
266
  )
267
 
268
  if do_rescale:
269
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) # type: ignore
270
 
271
  if do_normalize:
272
  image = self.normalize(
273
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format # type: ignore
274
  )
275
 
276
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) # type: ignore
277
  processed_images.append(image)
278
 
279
  patches = np.array(processed_images)
 
286
  patches = np.concatenate([patches, repeats], axis=0)
287
  channel = patches.shape[1]
288
  grid_t = patches.shape[0] // temporal_patch_size
289
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size # type: ignore
290
  patches = patches.reshape(
291
  grid_t,
292
+ temporal_patch_size, # type: ignore
293
  channel,
294
+ grid_h // merge_size, # type: ignore
295
+ merge_size, # type: ignore
296
+ patch_size, # type: ignore
297
+ grid_w // merge_size, # type: ignore
298
+ merge_size, # type: ignore
299
+ patch_size, # type: ignore
300
  )
301
  patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
302
  flatten_patches = patches.reshape(
 
305
 
306
  return flatten_patches, (grid_t, grid_h, grid_w)
307
 
308
+ def preprocess( # type: ignore
309
  self,
310
  images: ImageInput,
311
  do_resize: bool | None = None,
 
407
  do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
408
 
409
  if images is not None:
410
+ images = self.fetch_images(images) # type: ignore
411
  images = make_flat_list_of_images(images)
412
 
413
  if images is not None and not valid_images(images):
 
425
 
426
  data = {}
427
  pixel_values, vision_grid_thws = [], []
428
+ for image in images: # type: ignore
429
  patches, image_grid_thw = self._preprocess(
430
  image,
431
  do_resize=do_resize,
 
465
  Returns:
466
  `int`: Number of image patches per image.
467
  """
468
+ min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"] # type: ignore
469
+ max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"] # type: ignore
470
+ patch_size = images_kwargs.get("patch_size", self.patch_size) # type: ignore
471
+ merge_size = images_kwargs.get("merge_size", self.merge_size) # type: ignore
472
+ focus_size = images_kwargs.get("focus_size", self.focus_size) # type: ignore
473
 
474
+ factor = patch_size * merge_size * focus_size
475
  resized_height, resized_width = smart_resize(
476
  height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
477
  )
processing_qwen3_vl.py CHANGED
@@ -25,42 +25,42 @@ class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
25
  @auto_docstring
26
  class ZFQwen3VLProcessor(ProcessorMixin):
27
  def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
28
- self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
29
- self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
30
  self.image_token_id = (
31
- tokenizer.image_token_id
32
  if getattr(tokenizer, "image_token_id", None)
33
- else tokenizer.convert_tokens_to_ids(self.image_token)
34
  )
35
  self.video_token_id = (
36
- tokenizer.video_token_id
37
  if getattr(tokenizer, "video_token_id", None)
38
- else tokenizer.convert_tokens_to_ids(self.video_token)
39
  )
40
  super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
41
  self.vision_start_token = (
42
- "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
43
  )
44
  self.vision_end_token = (
45
- "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
46
  )
47
  self.vision_start_token_id = (
48
- tokenizer.vision_start_token_id
49
  if getattr(tokenizer, "vision_start_token_id", None)
50
- else tokenizer.convert_tokens_to_ids(self.vision_start_token)
51
  )
52
  self.vision_end_token_id = (
53
- tokenizer.vision_end_token_id
54
  if getattr(tokenizer, "vision_end_token_id", None)
55
- else tokenizer.convert_tokens_to_ids(self.vision_end_token)
56
  )
57
 
58
  @auto_docstring
59
- def __call__(
60
  self,
61
- images: ImageInput = None,
62
- text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
63
- videos: VideoInput = None,
64
  **kwargs: Unpack[Qwen3VLProcessorKwargs],
65
  ) -> BatchFeature:
66
  r"""
@@ -77,19 +77,19 @@ class ZFQwen3VLProcessor(ProcessorMixin):
77
  - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
78
  """
79
  output_kwargs = self._merge_kwargs(
80
- Qwen3VLProcessorKwargs,
81
- tokenizer_init_kwargs=self.tokenizer.init_kwargs,
82
  **kwargs,
83
  )
84
  if images is not None:
85
- image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
86
  image_grid_thw = image_inputs["image_grid_thw"]
87
  else:
88
  image_inputs = {}
89
  image_grid_thw = None
90
 
91
  if videos is not None:
92
- videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
93
  video_grid_thw = videos_inputs["video_grid_thw"]
94
  # If user has not requested video metadata, pop it
95
  if not kwargs.get("return_metadata"):
@@ -105,23 +105,23 @@ class ZFQwen3VLProcessor(ProcessorMixin):
105
 
106
  text = text.copy() # below lines change text in-place
107
  if image_grid_thw is not None:
108
- merge_length = self.image_processor.merge_size**2
109
  index = 0
110
  for i in range(len(text)):
111
  while self.image_token in text[i]:
112
  num_image_tokens = image_grid_thw[index].prod() // merge_length
113
- text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
114
  index += 1
115
- text[i] = text[i].replace("<|placeholder|>", self.image_token)
116
 
117
  if video_grid_thw is not None:
118
- merge_length = self.video_processor.merge_size**2
119
  index = 0
120
  for i in range(len(text)):
121
  while self.video_token in text[i]:
122
- metadata = video_metadata[index]
123
  if metadata.fps is None:
124
- logger.warning_once(
125
  "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
126
  "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
127
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
@@ -132,9 +132,9 @@ class ZFQwen3VLProcessor(ProcessorMixin):
132
  curr_timestamp = self._calculate_timestamps(
133
  metadata.frames_indices,
134
  metadata.fps,
135
- self.video_processor.merge_size,
 
136
  )
137
-
138
  video_placeholder = ""
139
  frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
140
  for frame_idx in range(video_grid_thw[index][0]):
@@ -144,20 +144,20 @@ class ZFQwen3VLProcessor(ProcessorMixin):
144
  self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
145
  )
146
  if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
147
- text[i] = text[i].replace(
148
  f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
149
  )
150
  else:
151
  # vllm may input video token directly
152
- text[i] = text[i].replace(self.video_token, video_placeholder, 1)
153
  index += 1
154
 
155
- text[i] = text[i].replace("<|placeholder|>", self.video_token)
156
 
157
  return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
158
  return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
159
- text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
160
- self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
161
 
162
  if return_mm_token_type_ids:
163
  array_ids = np.array(text_inputs["input_ids"])
@@ -184,10 +184,10 @@ class ZFQwen3VLProcessor(ProcessorMixin):
184
  if image_sizes is not None:
185
  images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
186
  images_kwargs.update(kwargs)
187
- merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
188
 
189
  num_image_patches = [
190
- self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
191
  for image_size in image_sizes
192
  ]
193
  num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
@@ -197,10 +197,10 @@ class ZFQwen3VLProcessor(ProcessorMixin):
197
  videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
198
  videos_kwargs.update(kwargs)
199
  num_video_patches = [
200
- self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
201
  for video_size in video_sizes
202
  ]
203
- num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
204
  vision_data["num_video_tokens"] = num_video_tokens
205
 
206
  return MultiModalData(**vision_data)
@@ -225,18 +225,26 @@ class ZFQwen3VLProcessor(ProcessorMixin):
225
  Returns:
226
  `list[str]`: The decoded text.
227
  """
228
- return self.tokenizer.batch_decode(
229
  generated_outputs,
230
  skip_special_tokens=skip_special_tokens,
231
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
232
  **kwargs,
233
  )
234
 
235
- def _calculate_timestamps(self, indices: list[int] | np.ndarray, video_fps: float, merge_size: int = 2):
 
 
 
 
 
 
236
  if not isinstance(indices, list):
237
  indices = indices.tolist()
238
- if len(indices) % merge_size != 0:
239
- indices.extend(indices[-1] for _ in range(merge_size - len(indices) % merge_size))
 
 
240
  timestamps = [idx / video_fps for idx in indices]
241
  # @JJJYmmm frames are merged by self.merge_size, \
242
  # so we need to average the timestamps between the first/last frame within the temporal patch
 
25
  @auto_docstring
26
  class ZFQwen3VLProcessor(ProcessorMixin):
27
  def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
28
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token # type: ignore
29
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token # type: ignore
30
  self.image_token_id = (
31
+ tokenizer.image_token_id # type: ignore
32
  if getattr(tokenizer, "image_token_id", None)
33
+ else tokenizer.convert_tokens_to_ids(self.image_token) # type: ignore
34
  )
35
  self.video_token_id = (
36
+ tokenizer.video_token_id # type: ignore
37
  if getattr(tokenizer, "video_token_id", None)
38
+ else tokenizer.convert_tokens_to_ids(self.video_token) # type: ignore
39
  )
40
  super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
41
  self.vision_start_token = (
42
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token # type: ignore
43
  )
44
  self.vision_end_token = (
45
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token # type: ignore
46
  )
47
  self.vision_start_token_id = (
48
+ tokenizer.vision_start_token_id # type: ignore
49
  if getattr(tokenizer, "vision_start_token_id", None)
50
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token) # type: ignore
51
  )
52
  self.vision_end_token_id = (
53
+ tokenizer.vision_end_token_id # type: ignore
54
  if getattr(tokenizer, "vision_end_token_id", None)
55
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token) # type: ignore
56
  )
57
 
58
  @auto_docstring
59
+ def __call__( # type: ignore
60
  self,
61
+ images: ImageInput = None, # type: ignore
62
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, # type: ignore
63
+ videos: VideoInput = None, # type: ignore
64
  **kwargs: Unpack[Qwen3VLProcessorKwargs],
65
  ) -> BatchFeature:
66
  r"""
 
77
  - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
78
  """
79
  output_kwargs = self._merge_kwargs(
80
+ Qwen3VLProcessorKwargs, # type: ignore
81
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs, # type: ignore
82
  **kwargs,
83
  )
84
  if images is not None:
85
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) # type: ignore
86
  image_grid_thw = image_inputs["image_grid_thw"]
87
  else:
88
  image_inputs = {}
89
  image_grid_thw = None
90
 
91
  if videos is not None:
92
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) # type: ignore
93
  video_grid_thw = videos_inputs["video_grid_thw"]
94
  # If user has not requested video metadata, pop it
95
  if not kwargs.get("return_metadata"):
 
105
 
106
  text = text.copy() # below lines change text in-place
107
  if image_grid_thw is not None:
108
+ merge_length = self.image_processor.merge_size**2 # type: ignore
109
  index = 0
110
  for i in range(len(text)):
111
  while self.image_token in text[i]:
112
  num_image_tokens = image_grid_thw[index].prod() // merge_length
113
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) # type: ignore
114
  index += 1
115
+ text[i] = text[i].replace("<|placeholder|>", self.image_token) # type: ignore
116
 
117
  if video_grid_thw is not None:
118
+ merge_length = self.video_processor.merge_size**2 # type: ignore
119
  index = 0
120
  for i in range(len(text)):
121
  while self.video_token in text[i]:
122
+ metadata = video_metadata[index] # type: ignore
123
  if metadata.fps is None:
124
+ logger.warning_once( # type: ignore
125
  "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
126
  "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
127
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
 
132
  curr_timestamp = self._calculate_timestamps(
133
  metadata.frames_indices,
134
  metadata.fps,
135
+ self.video_processor.merge_size, # type: ignore
136
+ self.video_processor.focus_size, # type: ignore
137
  )
 
138
  video_placeholder = ""
139
  frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
140
  for frame_idx in range(video_grid_thw[index][0]):
 
144
  self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
145
  )
146
  if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
147
+ text[i] = text[i].replace( # type: ignore
148
  f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
149
  )
150
  else:
151
  # vllm may input video token directly
152
+ text[i] = text[i].replace(self.video_token, video_placeholder, 1) # type: ignore
153
  index += 1
154
 
155
+ text[i] = text[i].replace("<|placeholder|>", self.video_token) # type: ignore
156
 
157
  return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
158
  return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
159
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) # type: ignore
160
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) # type: ignore
161
 
162
  if return_mm_token_type_ids:
163
  array_ids = np.array(text_inputs["input_ids"])
 
184
  if image_sizes is not None:
185
  images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
186
  images_kwargs.update(kwargs)
187
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size # type: ignore
188
 
189
  num_image_patches = [
190
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) # type: ignore
191
  for image_size in image_sizes
192
  ]
193
  num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
 
197
  videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
198
  videos_kwargs.update(kwargs)
199
  num_video_patches = [
200
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) # type: ignore
201
  for video_size in video_sizes
202
  ]
203
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] # type: ignore
204
  vision_data["num_video_tokens"] = num_video_tokens
205
 
206
  return MultiModalData(**vision_data)
 
225
  Returns:
226
  `list[str]`: The decoded text.
227
  """
228
+ return self.tokenizer.batch_decode( # type: ignore
229
  generated_outputs,
230
  skip_special_tokens=skip_special_tokens,
231
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
232
  **kwargs,
233
  )
234
 
235
+ def _calculate_timestamps(
236
+ self,
237
+ indices: list[int] | np.ndarray,
238
+ video_fps: float,
239
+ merge_size: int = 2,
240
+ focus_size: int = 2
241
+ ):
242
  if not isinstance(indices, list):
243
  indices = indices.tolist()
244
+ if len(indices) % (merge_size * focus_size) != 0:
245
+ indices.extend( # type: ignore
246
+ indices[-1] for _ in range((merge_size * focus_size) - len(indices) % (merge_size * focus_size))
247
+ )
248
  timestamps = [idx / video_fps for idx in indices]
249
  # @JJJYmmm frames are merged by self.merge_size, \
250
  # so we need to average the timestamps between the first/last frame within the temporal patch
processor_config.json CHANGED
@@ -25,7 +25,7 @@
25
  0.5
26
  ],
27
  "merge_size": 2,
28
- "patch_size": 14,
29
  "resample": 3,
30
  "rescale_factor": 0.00392156862745098,
31
  "size": {
@@ -67,7 +67,7 @@
67
  "rescale_factor": 0.00392156862745098,
68
  "return_metadata": false,
69
  "size": {
70
- "longest_edge": 25165824,
71
  "shortest_edge": 4096
72
  },
73
  "temporal_patch_size": 2,
 
25
  0.5
26
  ],
27
  "merge_size": 2,
28
+ "patch_size": 16,
29
  "resample": 3,
30
  "rescale_factor": 0.00392156862745098,
31
  "size": {
 
67
  "rescale_factor": 0.00392156862745098,
68
  "return_metadata": false,
69
  "size": {
70
+ "longest_edge": 251658240,
71
  "shortest_edge": 4096
72
  },
73
  "temporal_patch_size": 2,
video_processing_qwen3_vl.py CHANGED
@@ -47,10 +47,11 @@ def smart_resize(
47
  return h_bar, w_bar
48
 
49
 
50
- class Qwen3VLVideoProcessorInitKwargs(VideosKwargs, total=False):
51
  patch_size: int
52
  temporal_patch_size: int
53
  merge_size: int
 
54
  min_frames: int
55
  max_frames: int
56
 
@@ -79,21 +80,22 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
79
  patch_size = 16
80
  temporal_patch_size = 2
81
  merge_size = 2
 
82
  fps = 2
83
  min_frames = 4
84
  max_frames = 768
85
  do_sample_frames = True
86
- valid_kwargs = Qwen3VLVideoProcessorInitKwargs
87
  model_input_names = ["pixel_values_videos", "video_grid_thw"]
88
 
89
- def __init__(self, **kwargs: Unpack[Qwen3VLVideoProcessorInitKwargs]):
90
  super().__init__(**kwargs)
91
  if self.size is not None and (
92
  self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None
93
  ):
94
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
95
 
96
- def _further_process_kwargs(
97
  self,
98
  size: SizeDict | None = None,
99
  **kwargs,
@@ -107,7 +109,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
107
 
108
  return super()._further_process_kwargs(size=size, **kwargs)
109
 
110
- def sample_frames(
111
  self,
112
  metadata: VideoMetadata,
113
  num_frames: int | None = None,
@@ -142,7 +144,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
142
  if num_frames is None and fps is not None:
143
  if metadata.fps is None:
144
  metadata.fps = 24
145
- logger.warning_once(
146
  "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
147
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
148
  )
@@ -156,7 +158,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
156
 
157
  return indices
158
 
159
- def _preprocess(
160
  self,
161
  videos: list[torch.Tensor],
162
  do_convert_rgb: bool = True,
@@ -171,6 +173,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
171
  patch_size: int | None = None,
172
  temporal_patch_size: int | None = None,
173
  merge_size: int | None = None,
 
174
  return_tensors: str | TensorType | None = None,
175
  **kwargs,
176
  ):
@@ -185,16 +188,16 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
185
  num_frames=num_frames,
186
  height=height,
187
  width=width,
188
- temporal_factor=temporal_patch_size,
189
- factor=patch_size * merge_size,
190
- min_pixels=size.shortest_edge,
191
- max_pixels=size.longest_edge,
192
  )
193
  stacked_videos = stacked_videos.view(B * T, C, H, W)
194
  stacked_videos = self.resize(
195
  stacked_videos,
196
  size=SizeDict(height=resized_height, width=resized_width),
197
- interpolation=interpolation,
198
  )
199
  stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
200
  resized_videos_grouped[shape] = stacked_videos
@@ -206,40 +209,40 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
206
  processed_videos_grouped = {}
207
  processed_grids = {}
208
  for shape, stacked_videos in grouped_videos.items():
209
- resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
210
 
211
  # Fused rescale and normalize
212
  stacked_videos = self.rescale_and_normalize(
213
- stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
214
  )
215
  patches = stacked_videos
216
 
217
  # Check that videos have `num_frames` divisible by `temporal_patch_size`
218
  T = patches.shape[1]
219
- if pad := -T % temporal_patch_size:
220
  repeats = patches[:, -1:].expand(-1, pad, -1, -1, -1)
221
  patches = torch.cat((patches, repeats), dim=1)
222
  batch_size, grid_t, channel = patches.shape[:3]
223
- grid_t = grid_t // temporal_patch_size
224
- grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
225
 
226
  patches = patches.view(
227
  batch_size,
228
  grid_t,
229
- temporal_patch_size,
230
  channel,
231
- grid_h // merge_size,
232
- merge_size,
233
- patch_size,
234
- grid_w // merge_size,
235
- merge_size,
236
- patch_size,
237
  )
238
  patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
239
  flatten_patches = patches.reshape(
240
  batch_size,
241
  grid_t * grid_h * grid_w,
242
- channel * temporal_patch_size * patch_size * patch_size,
243
  )
244
 
245
  processed_videos_grouped[shape] = flatten_patches
 
47
  return h_bar, w_bar
48
 
49
 
50
+ class ZFQwen3VLVideoProcessorInitKwargs(VideosKwargs, total=False):
51
  patch_size: int
52
  temporal_patch_size: int
53
  merge_size: int
54
+ focus_size: int
55
  min_frames: int
56
  max_frames: int
57
 
 
80
  patch_size = 16
81
  temporal_patch_size = 2
82
  merge_size = 2
83
+ focus_size = 2
84
  fps = 2
85
  min_frames = 4
86
  max_frames = 768
87
  do_sample_frames = True
88
+ valid_kwargs = ZFQwen3VLVideoProcessorInitKwargs
89
  model_input_names = ["pixel_values_videos", "video_grid_thw"]
90
 
91
+ def __init__(self, **kwargs: Unpack[ZFQwen3VLVideoProcessorInitKwargs]):
92
  super().__init__(**kwargs)
93
  if self.size is not None and (
94
  self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None
95
  ):
96
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
97
 
98
+ def _further_process_kwargs( # type: ignore
99
  self,
100
  size: SizeDict | None = None,
101
  **kwargs,
 
109
 
110
  return super()._further_process_kwargs(size=size, **kwargs)
111
 
112
+ def sample_frames( # type: ignore
113
  self,
114
  metadata: VideoMetadata,
115
  num_frames: int | None = None,
 
144
  if num_frames is None and fps is not None:
145
  if metadata.fps is None:
146
  metadata.fps = 24
147
+ logger.warning_once( # type: ignore
148
  "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
149
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
150
  )
 
158
 
159
  return indices
160
 
161
+ def _preprocess( # type: ignore
162
  self,
163
  videos: list[torch.Tensor],
164
  do_convert_rgb: bool = True,
 
173
  patch_size: int | None = None,
174
  temporal_patch_size: int | None = None,
175
  merge_size: int | None = None,
176
+ focus_size: int | None = None,
177
  return_tensors: str | TensorType | None = None,
178
  **kwargs,
179
  ):
 
188
  num_frames=num_frames,
189
  height=height,
190
  width=width,
191
+ temporal_factor=temporal_patch_size, # type: ignore
192
+ factor=patch_size * merge_size * focus_size, # type: ignore
193
+ min_pixels=size.shortest_edge, # type: ignore
194
+ max_pixels=size.longest_edge, # type: ignore
195
  )
196
  stacked_videos = stacked_videos.view(B * T, C, H, W)
197
  stacked_videos = self.resize(
198
  stacked_videos,
199
  size=SizeDict(height=resized_height, width=resized_width),
200
+ interpolation=interpolation, # type: ignore
201
  )
202
  stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
203
  resized_videos_grouped[shape] = stacked_videos
 
209
  processed_videos_grouped = {}
210
  processed_grids = {}
211
  for shape, stacked_videos in grouped_videos.items():
212
+ resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST) # type: ignore
213
 
214
  # Fused rescale and normalize
215
  stacked_videos = self.rescale_and_normalize(
216
+ stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std # type: ignore
217
  )
218
  patches = stacked_videos
219
 
220
  # Check that videos have `num_frames` divisible by `temporal_patch_size`
221
  T = patches.shape[1]
222
+ if pad := -T % (temporal_patch_size * focus_size): # type: ignore
223
  repeats = patches[:, -1:].expand(-1, pad, -1, -1, -1)
224
  patches = torch.cat((patches, repeats), dim=1)
225
  batch_size, grid_t, channel = patches.shape[:3]
226
+ grid_t = grid_t // temporal_patch_size # type: ignore
227
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size # type: ignore
228
 
229
  patches = patches.view(
230
  batch_size,
231
  grid_t,
232
+ temporal_patch_size, # type: ignore
233
  channel,
234
+ grid_h // merge_size, # type: ignore
235
+ merge_size, # type: ignore
236
+ patch_size, # type: ignore
237
+ grid_w // merge_size, # type: ignore
238
+ merge_size, # type: ignore
239
+ patch_size, # type: ignore
240
  )
241
  patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
242
  flatten_patches = patches.reshape(
243
  batch_size,
244
  grid_t * grid_h * grid_w,
245
+ channel * temporal_patch_size * patch_size * patch_size, # type: ignore
246
  )
247
 
248
  processed_videos_grouped[shape] = flatten_patches