TYTTYTTYT committed on
Commit
bbaf57f
·
verified ·
1 Parent(s): b6d4848

Ignore type errors in the processing code

Browse files
image_processing_qwen2_vl.py CHANGED
@@ -3,15 +3,15 @@ from typing import Optional, Union
3
 
4
  import numpy as np
5
 
6
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
 
7
  from transformers.image_transforms import (
8
  convert_to_rgb,
9
  resize,
10
  to_channel_dimension_format,
11
  )
 
12
  from transformers.image_utils import (
13
- OPENAI_CLIP_MEAN,
14
- OPENAI_CLIP_STD,
15
  ChannelDimension,
16
  ImageInput,
17
  PILImageResampling,
@@ -23,7 +23,8 @@ from transformers.image_utils import (
23
  valid_images,
24
  validate_preprocess_arguments,
25
  )
26
- from transformers.utils import TensorType, logging
 
27
  from transformers.video_utils import VideoInput, make_batched_videos
28
 
29
 
@@ -205,16 +206,16 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
205
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
206
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
207
  """
208
- images = make_flat_list_of_images(images)
209
 
210
  if do_convert_rgb:
211
- images = [convert_to_rgb(image) for image in images]
212
 
213
  # All transformations expect numpy arrays.
214
- images = [to_numpy_array(image) for image in images]
215
 
216
  if do_rescale and is_scaled_image(images[0]):
217
- logger.warning_once(
218
  "It looks like you are trying to rescale already rescaled images. If the input"
219
  " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
220
  )
@@ -222,7 +223,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
222
  # We assume that all images have the same channel dimension format.
223
  input_data_format = infer_channel_dimension_format(images[0])
224
 
225
- height, width = get_image_size(images[0], channel_dim=input_data_format)
226
  resized_height, resized_width = height, width
227
  processed_images = []
228
  for image in images:
@@ -230,55 +231,55 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
230
  resized_height, resized_width = smart_resize(
231
  height,
232
  width,
233
- factor=patch_size * merge_size * focus_size,
234
- min_pixels=size["shortest_edge"],
235
- max_pixels=size["longest_edge"],
236
  )
237
  image = resize(
238
  image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
239
  )
240
 
241
  if do_rescale:
242
- image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
243
 
244
  if do_normalize:
245
  image = self.normalize(
246
- image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
247
  )
248
 
249
- image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
250
  processed_images.append(image)
251
 
252
  patches = np.array(processed_images)
253
  if data_format == ChannelDimension.LAST:
254
  patches = patches.transpose(0, 3, 1, 2)
255
- if patches.shape[0] % temporal_patch_size != 0:
256
  repeats = np.repeat(
257
- patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0
258
  )
259
  patches = np.concatenate([patches, repeats], axis=0)
260
  channel = patches.shape[1]
261
- grid_t = patches.shape[0] // temporal_patch_size
262
- grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
263
  patches = patches.reshape(
264
  grid_t,
265
- temporal_patch_size,
266
  channel,
267
- grid_h // merge_size,
268
- merge_size,
269
- patch_size,
270
- grid_w // merge_size,
271
- merge_size,
272
- patch_size,
273
  )
274
  patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
275
  flatten_patches = patches.reshape(
276
- grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
277
  )
278
 
279
  return flatten_patches, (grid_t, grid_h, grid_w)
280
 
281
- def preprocess(
282
  self,
283
  images: ImageInput,
284
  videos: Optional[VideoInput] = None,
@@ -386,7 +387,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
386
  do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
387
 
388
  if images is not None:
389
- images = self.fetch_images(images)
390
  images = make_flat_list_of_images(images)
391
 
392
  if images is not None and not valid_images(images):
@@ -408,7 +409,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
408
  data = {}
409
  if images is not None:
410
  pixel_values, vision_grid_thws = [], []
411
- for image in images:
412
  patches, image_grid_thw = self._preprocess(
413
  image,
414
  do_resize=do_resize,
@@ -439,9 +440,9 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
439
  "This is a deprecated behavior and will be removed in v5.0. "
440
  "Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
441
  )
442
- videos = make_batched_videos(videos)
443
  pixel_values_videos, vision_grid_thws_videos = [], []
444
- for images in videos:
445
  patches, video_grid_thw = self._preprocess(
446
  images,
447
  do_resize=do_resize,
@@ -484,11 +485,11 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
484
  Returns:
485
  `int`: Number of image patches per image.
486
  """
487
- min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
488
- max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
489
- patch_size = images_kwargs.get("patch_size", self.patch_size)
490
- merge_size = images_kwargs.get("merge_size", self.merge_size)
491
- focus_size = images_kwargs.get("focus_size", self.focus_size)
492
 
493
  factor = patch_size * merge_size * focus_size
494
  resized_height, resized_width = smart_resize(
 
3
 
4
  import numpy as np
5
 
6
+ from transformers.image_processing_utils import BaseImageProcessor
7
+ from transformers.image_processing_base import BatchFeature
8
  from transformers.image_transforms import (
9
  convert_to_rgb,
10
  resize,
11
  to_channel_dimension_format,
12
  )
13
+ from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
14
  from transformers.image_utils import (
 
 
15
  ChannelDimension,
16
  ImageInput,
17
  PILImageResampling,
 
23
  valid_images,
24
  validate_preprocess_arguments,
25
  )
26
+ from transformers.utils.generic import TensorType
27
+ from transformers.utils import logging
28
  from transformers.video_utils import VideoInput, make_batched_videos
29
 
30
 
 
206
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
207
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
208
  """
209
+ images = make_flat_list_of_images(images) # type: ignore
210
 
211
  if do_convert_rgb:
212
+ images = [convert_to_rgb(image) for image in images] # type: ignore
213
 
214
  # All transformations expect numpy arrays.
215
+ images = [to_numpy_array(image) for image in images] # type: ignore
216
 
217
  if do_rescale and is_scaled_image(images[0]):
218
+ logger.warning_once( # type: ignore
219
  "It looks like you are trying to rescale already rescaled images. If the input"
220
  " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
221
  )
 
223
  # We assume that all images have the same channel dimension format.
224
  input_data_format = infer_channel_dimension_format(images[0])
225
 
226
+ height, width = get_image_size(images[0], channel_dim=input_data_format) # type: ignore
227
  resized_height, resized_width = height, width
228
  processed_images = []
229
  for image in images:
 
231
  resized_height, resized_width = smart_resize(
232
  height,
233
  width,
234
+ factor=patch_size * merge_size * focus_size, # type: ignore
235
+ min_pixels=size["shortest_edge"], # type: ignore
236
+ max_pixels=size["longest_edge"], # type: ignore
237
  )
238
  image = resize(
239
  image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
240
  )
241
 
242
  if do_rescale:
243
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) # type: ignore
244
 
245
  if do_normalize:
246
  image = self.normalize(
247
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format # type: ignore
248
  )
249
 
250
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) # type: ignore
251
  processed_images.append(image)
252
 
253
  patches = np.array(processed_images)
254
  if data_format == ChannelDimension.LAST:
255
  patches = patches.transpose(0, 3, 1, 2)
256
+ if patches.shape[0] % temporal_patch_size != 0: # type: ignore
257
  repeats = np.repeat(
258
+ patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0 # type: ignore
259
  )
260
  patches = np.concatenate([patches, repeats], axis=0)
261
  channel = patches.shape[1]
262
+ grid_t = patches.shape[0] // temporal_patch_size # type: ignore
263
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size # type: ignore
264
  patches = patches.reshape(
265
  grid_t,
266
+ temporal_patch_size, # type: ignore
267
  channel,
268
+ grid_h // merge_size, # type: ignore
269
+ merge_size, # type: ignore
270
+ patch_size, # type: ignore
271
+ grid_w // merge_size, # type: ignore
272
+ merge_size, # type: ignore
273
+ patch_size, # type: ignore
274
  )
275
  patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
276
  flatten_patches = patches.reshape(
277
+ grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size # type: ignore
278
  )
279
 
280
  return flatten_patches, (grid_t, grid_h, grid_w)
281
 
282
+ def preprocess( # type: ignore
283
  self,
284
  images: ImageInput,
285
  videos: Optional[VideoInput] = None,
 
387
  do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
388
 
389
  if images is not None:
390
+ images = self.fetch_images(images) # type: ignore
391
  images = make_flat_list_of_images(images)
392
 
393
  if images is not None and not valid_images(images):
 
409
  data = {}
410
  if images is not None:
411
  pixel_values, vision_grid_thws = [], []
412
+ for image in images: # type: ignore
413
  patches, image_grid_thw = self._preprocess(
414
  image,
415
  do_resize=do_resize,
 
440
  "This is a deprecated behavior and will be removed in v5.0. "
441
  "Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
442
  )
443
+ videos = make_batched_videos(videos) # type: ignore
444
  pixel_values_videos, vision_grid_thws_videos = [], []
445
+ for images in videos: # type: ignore
446
  patches, video_grid_thw = self._preprocess(
447
  images,
448
  do_resize=do_resize,
 
485
  Returns:
486
  `int`: Number of image patches per image.
487
  """
488
+ min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"] # type: ignore
489
+ max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"] # type: ignore
490
+ patch_size = images_kwargs.get("patch_size", self.patch_size) # type: ignore
491
+ merge_size = images_kwargs.get("merge_size", self.merge_size) # type: ignore
492
+ focus_size = images_kwargs.get("focus_size", self.focus_size) # type: ignore
493
 
494
  factor = patch_size * merge_size * focus_size
495
  resized_height, resized_width = smart_resize(
image_processing_qwen2_vl_fast.py CHANGED
@@ -3,27 +3,23 @@ from typing import Optional, Union
3
  import torch
4
  from torchvision.transforms.v2 import functional as F
5
 
6
- from transformers.image_processing_utils import BatchFeature
7
  from transformers.image_processing_utils_fast import (
8
  BaseImageProcessorFast,
9
  DefaultFastImageProcessorKwargs,
10
- group_images_by_shape,
11
- reorder_images,
12
  )
 
 
13
  from transformers.image_utils import (
14
- OPENAI_CLIP_MEAN,
15
- OPENAI_CLIP_STD,
16
  ChannelDimension,
17
  ImageInput,
18
  PILImageResampling,
19
  SizeDict,
20
  )
21
  from transformers.processing_utils import Unpack
22
- from transformers.utils import (
23
- TensorType,
24
- auto_docstring,
25
- logging,
26
- )
27
  from transformers.video_utils import VideoInput, make_batched_videos
28
  from .image_processing_qwen2_vl import smart_resize
29
 
@@ -81,17 +77,17 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
81
  # backward compatibility: override size with min_pixels and max_pixels if they are provided
82
  size = self.size if size is None else size
83
  if min_pixels is not None:
84
- size["shortest_edge"] = min_pixels
85
  size.pop("min_pixels", None)
86
  if max_pixels is not None:
87
- size["longest_edge"] = max_pixels
88
  size.pop("max_pixels", None)
89
  if "shortest_edge" not in size or "longest_edge" not in size:
90
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
91
 
92
- super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
93
 
94
- def _further_process_kwargs(
95
  self,
96
  size: Optional[SizeDict] = None,
97
  min_pixels: Optional[int] = None,
@@ -103,19 +99,19 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
103
  Can be overridden by subclasses to customize the processing of kwargs.
104
  """
105
  if min_pixels is not None and max_pixels is not None:
106
- size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
107
  elif size is not None:
108
  if "shortest_edge" not in size or "longest_edge" not in size:
109
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
110
  min_pixels = size["shortest_edge"]
111
  max_pixels = size["longest_edge"]
112
  else:
113
- size = {**self.size}
114
 
115
  return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
116
 
117
  @auto_docstring
118
- def preprocess(
119
  self,
120
  images: ImageInput,
121
  videos: Optional[VideoInput] = None,
@@ -123,14 +119,14 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
123
  ) -> BatchFeature:
124
  return super().preprocess(images, videos, **kwargs)
125
 
126
- def _preprocess_image_like_inputs(
127
  self,
128
  images: ImageInput,
129
  videos: VideoInput,
130
  do_convert_rgb: bool,
131
  input_data_format: ChannelDimension,
132
  device: Optional[Union[str, "torch.device"]] = None,
133
- **kwargs: Unpack[DefaultFastImageProcessorKwargs],
134
  ) -> BatchFeature:
135
  """
136
  Preprocess image-like inputs.
@@ -141,9 +137,9 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
141
  batch_feature = BatchFeature()
142
  if images is not None:
143
  images = self._prepare_image_like_inputs(
144
- images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
145
  )
146
- batch_feature = self._preprocess(images, **kwargs)
147
  if videos is not None:
148
  logger.warning(
149
  "`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
@@ -151,18 +147,18 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
151
  "Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
152
  )
153
  # Can't change _prepare_images_structure to work with videos because it also needs to work with images.
154
- videos = make_batched_videos(videos)
155
  videos = [
156
- torch.stack(self._prepare_image_like_inputs(video, do_convert_rgb, input_data_format, device))
157
  for video in videos
158
  ]
159
- video_outputs = self._preprocess(videos, **kwargs)
160
  batch_feature.update(
161
  {"pixel_values_videos": video_outputs.pixel_values, "video_grid_thw": video_outputs.image_grid_thw}
162
  )
163
  return batch_feature
164
 
165
- def _preprocess(
166
  self,
167
  images: list["torch.Tensor"],
168
  do_resize: bool,
@@ -182,10 +178,10 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
182
  **kwargs,
183
  ):
184
  # Group images by size for batched resizing
185
- grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
186
  resized_images_grouped = {}
187
  for shape, stacked_images in grouped_images.items():
188
- height, width = stacked_images.shape[-2:]
189
  if do_resize:
190
  resized_height, resized_width = smart_resize(
191
  height,
@@ -195,7 +191,7 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
195
  max_pixels=size["longest_edge"],
196
  )
197
  stacked_images = self.resize(
198
- image=stacked_images,
199
  size=SizeDict(height=resized_height, width=resized_width),
200
  interpolation=interpolation,
201
  )
@@ -204,14 +200,14 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
204
 
205
  # Group images by size for further processing
206
  # Needed in case do_resize is False, or resize returns images with different sizes
207
- grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
208
  processed_images_grouped = {}
209
  processed_grids = {}
210
  for shape, stacked_images in grouped_images.items():
211
- resized_height, resized_width = stacked_images.shape[-2:]
212
  # Fused rescale and normalize
213
  patches = self.rescale_and_normalize(
214
- stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
215
  )
216
  if patches.ndim == 4:
217
  # add a temporal dimension if we have images
@@ -249,7 +245,7 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
249
 
250
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
251
  processed_grids = reorder_images(processed_grids, grouped_images_index)
252
- pixel_values = torch.cat(processed_images, dim=0)
253
  image_grid_thw = torch.tensor(processed_grids)
254
 
255
  return BatchFeature(
@@ -273,11 +269,11 @@ class ZFQwen2VLImageProcessorFast(BaseImageProcessorFast):
273
  Returns:
274
  `int`: Number of image patches per image.
275
  """
276
- min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
277
- max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
278
- patch_size = images_kwargs.get("patch_size", self.patch_size)
279
- merge_size = images_kwargs.get("merge_size", self.merge_size)
280
- focus_size = images_kwargs.get("focus_size", self.focus_size)
281
 
282
  factor = patch_size * merge_size * focus_size
283
  resized_height, resized_width = smart_resize(
 
3
  import torch
4
  from torchvision.transforms.v2 import functional as F
5
 
6
+ from transformers.image_processing_base import BatchFeature
7
  from transformers.image_processing_utils_fast import (
8
  BaseImageProcessorFast,
9
  DefaultFastImageProcessorKwargs,
 
 
10
  )
11
+ from transformers.image_transforms import group_images_by_shape, reorder_images
12
+ from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
13
  from transformers.image_utils import (
 
 
14
  ChannelDimension,
15
  ImageInput,
16
  PILImageResampling,
17
  SizeDict,
18
  )
19
  from transformers.processing_utils import Unpack
20
+ from transformers.utils.generic import TensorType
21
+ from transformers.utils.auto_docstring import auto_docstring
22
+ from transformers.utils import logging
 
 
23
  from transformers.video_utils import VideoInput, make_batched_videos
24
  from .image_processing_qwen2_vl import smart_resize
25
 
 
77
  # backward compatibility: override size with min_pixels and max_pixels if they are provided
78
  size = self.size if size is None else size
79
  if min_pixels is not None:
80
+ size["shortest_edge"] = min_pixels # type: ignore
81
  size.pop("min_pixels", None)
82
  if max_pixels is not None:
83
+ size["longest_edge"] = max_pixels # type: ignore
84
  size.pop("max_pixels", None)
85
  if "shortest_edge" not in size or "longest_edge" not in size:
86
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
87
 
88
+ super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs) # type: ignore
89
 
90
+ def _further_process_kwargs( # type: ignore
91
  self,
92
  size: Optional[SizeDict] = None,
93
  min_pixels: Optional[int] = None,
 
99
  Can be overridden by subclasses to customize the processing of kwargs.
100
  """
101
  if min_pixels is not None and max_pixels is not None:
102
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels} # type: ignore
103
  elif size is not None:
104
  if "shortest_edge" not in size or "longest_edge" not in size:
105
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
106
  min_pixels = size["shortest_edge"]
107
  max_pixels = size["longest_edge"]
108
  else:
109
+ size = {**self.size} # type: ignore
110
 
111
  return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
112
 
113
  @auto_docstring
114
+ def preprocess( # type: ignore
115
  self,
116
  images: ImageInput,
117
  videos: Optional[VideoInput] = None,
 
119
  ) -> BatchFeature:
120
  return super().preprocess(images, videos, **kwargs)
121
 
122
+ def _preprocess_image_like_inputs( # type: ignore
123
  self,
124
  images: ImageInput,
125
  videos: VideoInput,
126
  do_convert_rgb: bool,
127
  input_data_format: ChannelDimension,
128
  device: Optional[Union[str, "torch.device"]] = None,
129
+ **kwargs: Unpack[DefaultFastImageProcessorKwargs], # type: ignore
130
  ) -> BatchFeature:
131
  """
132
  Preprocess image-like inputs.
 
137
  batch_feature = BatchFeature()
138
  if images is not None:
139
  images = self._prepare_image_like_inputs(
140
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device # type: ignore
141
  )
142
+ batch_feature = self._preprocess(images, **kwargs) # type: ignore
143
  if videos is not None:
144
  logger.warning(
145
  "`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
 
147
  "Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
148
  )
149
  # Can't change _prepare_images_structure to work with videos because it also needs to work with images.
150
+ videos = make_batched_videos(videos) # type: ignore
151
  videos = [
152
+ torch.stack(self._prepare_image_like_inputs(video, do_convert_rgb, input_data_format, device)) # type: ignore
153
  for video in videos
154
  ]
155
+ video_outputs = self._preprocess(videos, **kwargs) # type: ignore
156
  batch_feature.update(
157
  {"pixel_values_videos": video_outputs.pixel_values, "video_grid_thw": video_outputs.image_grid_thw}
158
  )
159
  return batch_feature
160
 
161
+ def _preprocess( # type: ignore
162
  self,
163
  images: list["torch.Tensor"],
164
  do_resize: bool,
 
178
  **kwargs,
179
  ):
180
  # Group images by size for batched resizing
181
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) # type: ignore
182
  resized_images_grouped = {}
183
  for shape, stacked_images in grouped_images.items():
184
+ height, width = stacked_images.shape[-2:] # type: ignore
185
  if do_resize:
186
  resized_height, resized_width = smart_resize(
187
  height,
 
191
  max_pixels=size["longest_edge"],
192
  )
193
  stacked_images = self.resize(
194
+ image=stacked_images, # type: ignore
195
  size=SizeDict(height=resized_height, width=resized_width),
196
  interpolation=interpolation,
197
  )
 
200
 
201
  # Group images by size for further processing
202
  # Needed in case do_resize is False, or resize returns images with different sizes
203
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) # type: ignore
204
  processed_images_grouped = {}
205
  processed_grids = {}
206
  for shape, stacked_images in grouped_images.items():
207
+ resized_height, resized_width = stacked_images.shape[-2:] # type: ignore
208
  # Fused rescale and normalize
209
  patches = self.rescale_and_normalize(
210
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std # type: ignore
211
  )
212
  if patches.ndim == 4:
213
  # add a temporal dimension if we have images
 
245
 
246
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
247
  processed_grids = reorder_images(processed_grids, grouped_images_index)
248
+ pixel_values = torch.cat(processed_images, dim=0) # type: ignore
249
  image_grid_thw = torch.tensor(processed_grids)
250
 
251
  return BatchFeature(
 
269
  Returns:
270
  `int`: Number of image patches per image.
271
  """
272
+ min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"] # type: ignore
273
+ max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"] # type: ignore
274
+ patch_size = images_kwargs.get("patch_size", self.patch_size) # type: ignore
275
+ merge_size = images_kwargs.get("merge_size", self.merge_size) # type: ignore
276
+ focus_size = images_kwargs.get("focus_size", self.focus_size) # type: ignore
277
 
278
  factor = patch_size * merge_size * focus_size
279
  resized_height, resized_width = smart_resize(
processing_qwen3_vl.py CHANGED
@@ -27,9 +27,9 @@ class Qwen3VLImagesKwargs(ImagesKwargs):
27
 
28
 
29
  class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
30
- images_kwargs: Qwen3VLImagesKwargs
31
- videos_kwargs: Qwen3VLVideosProcessorKwargs
32
- _defaults = {
33
  "text_kwargs": {
34
  "padding": False,
35
  "return_token_type_ids": False,
@@ -62,40 +62,40 @@ class ZFQwen3VLProcessor(ProcessorMixin):
62
 
63
  def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
64
  super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
65
- self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
66
- self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
67
  self.image_token_id = (
68
- tokenizer.image_token_id
69
  if getattr(tokenizer, "image_token_id", None)
70
- else tokenizer.convert_tokens_to_ids(self.image_token)
71
  )
72
  self.video_token_id = (
73
- tokenizer.video_token_id
74
  if getattr(tokenizer, "video_token_id", None)
75
- else tokenizer.convert_tokens_to_ids(self.video_token)
76
  )
77
  self.vision_start_token = (
78
- "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token
79
  )
80
  self.vision_end_token = (
81
- "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token
82
  )
83
  self.vision_start_token_id = (
84
- tokenizer.vision_start_token_id
85
  if getattr(tokenizer, "vision_start_token_id", None)
86
- else tokenizer.convert_tokens_to_ids(self.vision_start_token)
87
  )
88
  self.vision_end_token_id = (
89
- tokenizer.vision_end_token_id
90
  if getattr(tokenizer, "vision_end_token_id", None)
91
- else tokenizer.convert_tokens_to_ids(self.vision_end_token)
92
  )
93
 
94
- def __call__(
95
  self,
96
- images: ImageInput = None,
97
- text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
98
- videos: VideoInput = None,
99
  **kwargs: Unpack[Qwen3VLProcessorKwargs],
100
  ) -> BatchFeature:
101
  """
@@ -135,19 +135,19 @@ class ZFQwen3VLProcessor(ProcessorMixin):
135
  - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
136
  """
137
  output_kwargs = self._merge_kwargs(
138
- Qwen3VLProcessorKwargs,
139
- tokenizer_init_kwargs=self.tokenizer.init_kwargs,
140
  **kwargs,
141
  )
142
  if images is not None:
143
- image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
144
  image_grid_thw = image_inputs["image_grid_thw"]
145
  else:
146
  image_inputs = {}
147
  image_grid_thw = None
148
 
149
  if videos is not None:
150
- videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
151
  video_grid_thw = videos_inputs["video_grid_thw"]
152
  # If user has not requested video metadata, pop it
153
  if "return_metadata" not in kwargs:
@@ -164,23 +164,23 @@ class ZFQwen3VLProcessor(ProcessorMixin):
164
 
165
  text = text.copy() # below lines change text in-place
166
  if image_grid_thw is not None:
167
- merge_length = self.image_processor.merge_size**2
168
  index = 0
169
  for i in range(len(text)):
170
  while self.image_token in text[i]:
171
  num_image_tokens = image_grid_thw[index].prod() // merge_length
172
- text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
173
  index += 1
174
- text[i] = text[i].replace("<|placeholder|>", self.image_token)
175
 
176
  if video_grid_thw is not None:
177
- merge_length = self.video_processor.merge_size**2
178
  index = 0
179
  for i in range(len(text)):
180
  while self.video_token in text[i]:
181
- metadata = video_metadata[index]
182
  if metadata.fps is None:
183
- logger.warning_once(
184
  "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
185
  "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
186
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
@@ -191,8 +191,8 @@ class ZFQwen3VLProcessor(ProcessorMixin):
191
  curr_timestamp = self._calculate_timestamps(
192
  metadata.frames_indices,
193
  metadata.fps,
194
- self.video_processor.merge_size,
195
- self.video_processor.focus_size,
196
  )
197
 
198
  print(len(curr_timestamp), curr_timestamp)
@@ -206,20 +206,20 @@ class ZFQwen3VLProcessor(ProcessorMixin):
206
  self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
207
  )
208
  if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
209
- text[i] = text[i].replace(
210
  f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
211
  )
212
  else:
213
  # vllm may input video token directly
214
- text[i] = text[i].replace(self.video_token, video_placeholder, 1)
215
  index += 1
216
 
217
- text[i] = text[i].replace("<|placeholder|>", self.video_token)
218
 
219
  return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
220
  return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
221
- text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
222
- self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
223
 
224
  if return_mm_token_type_ids:
225
  array_ids = np.array(text_inputs["input_ids"])
@@ -246,10 +246,10 @@ class ZFQwen3VLProcessor(ProcessorMixin):
246
  if image_sizes is not None:
247
  images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
248
  images_kwargs.update(kwargs)
249
- merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
250
 
251
  num_image_patches = [
252
- self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
253
  for image_size in image_sizes
254
  ]
255
  num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
@@ -259,10 +259,10 @@ class ZFQwen3VLProcessor(ProcessorMixin):
259
  videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
260
  videos_kwargs.update(kwargs)
261
  num_video_patches = [
262
- self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
263
  for video_size in video_sizes
264
  ]
265
- num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
266
  vision_data["num_video_tokens"] = num_video_tokens
267
 
268
  return MultiModalData(**vision_data)
@@ -287,7 +287,7 @@ class ZFQwen3VLProcessor(ProcessorMixin):
287
  Returns:
288
  `list[str]`: The decoded text.
289
  """
290
- return self.tokenizer.batch_decode(
291
  generated_outputs,
292
  skip_special_tokens=skip_special_tokens,
293
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
@@ -306,7 +306,7 @@ class ZFQwen3VLProcessor(ProcessorMixin):
306
  print(len(indices), indices)
307
  b_size = merge_size * focus_size
308
  if len(indices) % b_size != 0:
309
- indices.extend(indices[-1] for _ in range(b_size - len(indices) % b_size))
310
  print(len(indices), indices)
311
  timestamps = [idx / video_fps for idx in indices]
312
  # @JJJYmmm frames are merged by self.merge_size, \
 
27
 
28
 
29
  class Qwen3VLProcessorKwargs(ProcessingKwargs, total=False):
30
+ images_kwargs: Qwen3VLImagesKwargs # type: ignore
31
+ videos_kwargs: Qwen3VLVideosProcessorKwargs # type: ignore
32
+ _defaults = { # type: ignore
33
  "text_kwargs": {
34
  "padding": False,
35
  "return_token_type_ids": False,
 
62
 
63
  def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
64
  super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
65
+ self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token # type: ignore
66
+ self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token # type: ignore
67
  self.image_token_id = (
68
+ tokenizer.image_token_id # type: ignore
69
  if getattr(tokenizer, "image_token_id", None)
70
+ else tokenizer.convert_tokens_to_ids(self.image_token) # type: ignore
71
  )
72
  self.video_token_id = (
73
+ tokenizer.video_token_id # type: ignore
74
  if getattr(tokenizer, "video_token_id", None)
75
+ else tokenizer.convert_tokens_to_ids(self.video_token) # type: ignore
76
  )
77
  self.vision_start_token = (
78
+ "<|vision_start|>" if not hasattr(tokenizer, "vision_start_token") else tokenizer.vision_start_token # type: ignore
79
  )
80
  self.vision_end_token = (
81
+ "<|vision_end|>" if not hasattr(tokenizer, "vision_end_token") else tokenizer.vision_end_token # type: ignore
82
  )
83
  self.vision_start_token_id = (
84
+ tokenizer.vision_start_token_id # type: ignore
85
  if getattr(tokenizer, "vision_start_token_id", None)
86
+ else tokenizer.convert_tokens_to_ids(self.vision_start_token) # type: ignore
87
  )
88
  self.vision_end_token_id = (
89
+ tokenizer.vision_end_token_id # type: ignore
90
  if getattr(tokenizer, "vision_end_token_id", None)
91
+ else tokenizer.convert_tokens_to_ids(self.vision_end_token) # type: ignore
92
  )
93
 
94
+ def __call__( # type: ignore
95
  self,
96
+ images: ImageInput = None, # type: ignore
97
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, # type: ignore
98
+ videos: VideoInput = None, # type: ignore
99
  **kwargs: Unpack[Qwen3VLProcessorKwargs],
100
  ) -> BatchFeature:
101
  """
 
135
  - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
136
  """
137
  output_kwargs = self._merge_kwargs(
138
+ Qwen3VLProcessorKwargs, # type: ignore
139
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs, # type: ignore
140
  **kwargs,
141
  )
142
  if images is not None:
143
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) # type: ignore
144
  image_grid_thw = image_inputs["image_grid_thw"]
145
  else:
146
  image_inputs = {}
147
  image_grid_thw = None
148
 
149
  if videos is not None:
150
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) # type: ignore
151
  video_grid_thw = videos_inputs["video_grid_thw"]
152
  # If user has not requested video metadata, pop it
153
  if "return_metadata" not in kwargs:
 
164
 
165
  text = text.copy() # below lines change text in-place
166
  if image_grid_thw is not None:
167
+ merge_length = self.image_processor.merge_size**2 # type: ignore
168
  index = 0
169
  for i in range(len(text)):
170
  while self.image_token in text[i]:
171
  num_image_tokens = image_grid_thw[index].prod() // merge_length
172
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) # type: ignore
173
  index += 1
174
+ text[i] = text[i].replace("<|placeholder|>", self.image_token) # type: ignore
175
 
176
  if video_grid_thw is not None:
177
+ merge_length = self.video_processor.merge_size**2 # type: ignore
178
  index = 0
179
  for i in range(len(text)):
180
  while self.video_token in text[i]:
181
+ metadata = video_metadata[index] # type: ignore
182
  if metadata.fps is None:
183
+ logger.warning_once( # type: ignore
184
  "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
185
  "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
186
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
 
191
  curr_timestamp = self._calculate_timestamps(
192
  metadata.frames_indices,
193
  metadata.fps,
194
+ self.video_processor.merge_size, # type: ignore
195
+ self.video_processor.focus_size, # type: ignore
196
  )
197
 
198
  print(len(curr_timestamp), curr_timestamp)
 
206
  self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
207
  )
208
  if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
209
+ text[i] = text[i].replace( # type: ignore
210
  f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
211
  )
212
  else:
213
  # vllm may input video token directly
214
+ text[i] = text[i].replace(self.video_token, video_placeholder, 1) # type: ignore
215
  index += 1
216
 
217
+ text[i] = text[i].replace("<|placeholder|>", self.video_token) # type: ignore
218
 
219
  return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
220
  return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
221
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) # type: ignore
222
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) # type: ignore
223
 
224
  if return_mm_token_type_ids:
225
  array_ids = np.array(text_inputs["input_ids"])
 
246
  if image_sizes is not None:
247
  images_kwargs = Qwen3VLProcessorKwargs._defaults.get("images_kwargs", {})
248
  images_kwargs.update(kwargs)
249
+ merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size # type: ignore
250
 
251
  num_image_patches = [
252
+ self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) # type: ignore
253
  for image_size in image_sizes
254
  ]
255
  num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
 
259
  videos_kwargs = Qwen3VLProcessorKwargs._defaults.get("videos_kwargs", {})
260
  videos_kwargs.update(kwargs)
261
  num_video_patches = [
262
+ self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) # type: ignore
263
  for video_size in video_sizes
264
  ]
265
+ num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] # type: ignore
266
  vision_data["num_video_tokens"] = num_video_tokens
267
 
268
  return MultiModalData(**vision_data)
 
287
  Returns:
288
  `list[str]`: The decoded text.
289
  """
290
+ return self.tokenizer.batch_decode( # type: ignore
291
  generated_outputs,
292
  skip_special_tokens=skip_special_tokens,
293
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
 
306
  print(len(indices), indices)
307
  b_size = merge_size * focus_size
308
  if len(indices) % b_size != 0:
309
+ indices.extend(indices[-1] for _ in range(b_size - len(indices) % b_size)) # type: ignore
310
  print(len(indices), indices)
311
  timestamps = [idx / video_fps for idx in indices]
312
  # @JJJYmmm frames are merged by self.merge_size, \
video_processing_qwen3_vl.py CHANGED
@@ -7,7 +7,9 @@ import torch
7
  from transformers.feature_extraction_utils import BatchFeature
8
  from transformers.image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size
9
  from transformers.processing_utils import Unpack, VideosKwargs
10
- from transformers.utils import TensorType, add_start_docstrings, logging
 
 
11
  from transformers.video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
12
  from transformers.video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
13
 
@@ -96,7 +98,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
96
  ):
97
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
98
 
99
- def _further_process_kwargs(
100
  self,
101
  size: Optional[SizeDict] = None,
102
  **kwargs,
@@ -110,7 +112,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
110
 
111
  return super()._further_process_kwargs(size=size, **kwargs)
112
 
113
- def sample_frames(
114
  self,
115
  metadata: VideoMetadata,
116
  num_frames: Optional[int] = None,
@@ -145,7 +147,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
145
  if num_frames is None and fps is not None:
146
  if metadata.fps is None:
147
  metadata.fps = 24
148
- logger.warning_once(
149
  "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
150
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
151
  )
@@ -159,7 +161,7 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
159
 
160
  return indices
161
 
162
- def _preprocess(
163
  self,
164
  videos: list[torch.Tensor],
165
  do_convert_rgb: bool = True,
@@ -189,16 +191,16 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
189
  num_frames=num_frames,
190
  height=height,
191
  width=width,
192
- temporal_factor=temporal_patch_size,
193
- factor=patch_size * merge_size * focus_size,
194
- min_pixels=size.shortest_edge,
195
- max_pixels=size.longest_edge,
196
  )
197
  stacked_videos = stacked_videos.view(B * T, C, H, W)
198
  stacked_videos = self.resize(
199
  stacked_videos,
200
  size=SizeDict(height=resized_height, width=resized_width),
201
- interpolation=interpolation,
202
  )
203
  stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
204
  resized_videos_grouped[shape] = stacked_videos
@@ -210,40 +212,40 @@ class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
210
  processed_videos_grouped = {}
211
  processed_grids = {}
212
  for shape, stacked_videos in grouped_videos.items():
213
- resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
214
 
215
  # Fused rescale and normalize
216
  stacked_videos = self.rescale_and_normalize(
217
- stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
218
  )
219
  patches = stacked_videos
220
 
221
- temporal_focus_size = temporal_patch_size * focus_size
222
  # Check that videos have `num_frames` divisible by `temporal_patch_size`
223
  if res := patches.shape[1] % temporal_focus_size:
224
  repeats = patches[:, -1:].repeat(1, temporal_focus_size - res, 1, 1, 1)
225
  patches = torch.cat([patches, repeats], dim=1)
226
  batch_size, grid_t, channel = patches.shape[:3]
227
- grid_t = grid_t // temporal_patch_size
228
- grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
229
 
230
  patches = patches.view(
231
  batch_size,
232
  grid_t,
233
- temporal_patch_size,
234
  channel,
235
- grid_h // merge_size,
236
- merge_size,
237
- patch_size,
238
- grid_w // merge_size,
239
- merge_size,
240
- patch_size,
241
  )
242
  patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
243
  flatten_patches = patches.reshape(
244
  batch_size,
245
  grid_t * grid_h * grid_w,
246
- channel * temporal_patch_size * patch_size * patch_size,
247
  )
248
 
249
  processed_videos_grouped[shape] = flatten_patches
 
7
  from transformers.feature_extraction_utils import BatchFeature
8
  from transformers.image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size
9
  from transformers.processing_utils import Unpack, VideosKwargs
10
+ from transformers.utils.generic import TensorType
11
+ from transformers.utils.doc import add_start_docstrings
12
+ from transformers.utils import logging
13
  from transformers.video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
14
  from transformers.video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
15
 
 
98
  ):
99
  raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
100
 
101
+ def _further_process_kwargs( # type: ignore
102
  self,
103
  size: Optional[SizeDict] = None,
104
  **kwargs,
 
112
 
113
  return super()._further_process_kwargs(size=size, **kwargs)
114
 
115
+ def sample_frames( # type: ignore
116
  self,
117
  metadata: VideoMetadata,
118
  num_frames: Optional[int] = None,
 
147
  if num_frames is None and fps is not None:
148
  if metadata.fps is None:
149
  metadata.fps = 24
150
+ logger.warning_once( # type: ignore
151
  "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
152
  "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
153
  )
 
161
 
162
  return indices
163
 
164
+ def _preprocess( # type: ignore
165
  self,
166
  videos: list[torch.Tensor],
167
  do_convert_rgb: bool = True,
 
191
  num_frames=num_frames,
192
  height=height,
193
  width=width,
194
+ temporal_factor=temporal_patch_size, # type: ignore
195
+ factor=patch_size * merge_size * focus_size, # type: ignore
196
+ min_pixels=size.shortest_edge, # type: ignore
197
+ max_pixels=size.longest_edge, # type: ignore
198
  )
199
  stacked_videos = stacked_videos.view(B * T, C, H, W)
200
  stacked_videos = self.resize(
201
  stacked_videos,
202
  size=SizeDict(height=resized_height, width=resized_width),
203
+ interpolation=interpolation, # type: ignore
204
  )
205
  stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
206
  resized_videos_grouped[shape] = stacked_videos
 
212
  processed_videos_grouped = {}
213
  processed_grids = {}
214
  for shape, stacked_videos in grouped_videos.items():
215
+ resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST) # type: ignore
216
 
217
  # Fused rescale and normalize
218
  stacked_videos = self.rescale_and_normalize(
219
+ stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std # type: ignore
220
  )
221
  patches = stacked_videos
222
 
223
+ temporal_focus_size = temporal_patch_size * focus_size # type: ignore
224
  # Check that videos have `num_frames` divisible by `temporal_patch_size`
225
  if res := patches.shape[1] % temporal_focus_size:
226
  repeats = patches[:, -1:].repeat(1, temporal_focus_size - res, 1, 1, 1)
227
  patches = torch.cat([patches, repeats], dim=1)
228
  batch_size, grid_t, channel = patches.shape[:3]
229
+ grid_t = grid_t // temporal_patch_size # type: ignore
230
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size # type: ignore
231
 
232
  patches = patches.view(
233
  batch_size,
234
  grid_t,
235
+ temporal_patch_size, # type: ignore
236
  channel,
237
+ grid_h // merge_size, # type: ignore
238
+ merge_size, # type: ignore
239
+ patch_size, # type: ignore
240
+ grid_w // merge_size, # type: ignore
241
+ merge_size, # type: ignore
242
+ patch_size, # type: ignore
243
  )
244
  patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
245
  flatten_patches = patches.reshape(
246
  batch_size,
247
  grid_t * grid_h * grid_w,
248
+ channel * temporal_patch_size * patch_size * patch_size, # type: ignore
249
  )
250
 
251
  processed_videos_grouped[shape] = flatten_patches