Update code for transformers 5.5.4

#6
by sjzhou - opened
Files changed (1)
  1. processing_moss_vl.py +78 -16
processing_moss_vl.py CHANGED
@@ -23,8 +23,8 @@ import torch
 from torchvision.transforms.v2 import functional as F
 from PIL import Image
 from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import ImageInput, SizeDict
-from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images
+from transformers.image_utils import ImageInput, PILImageResampling, SizeDict
+from transformers.image_transforms import group_images_by_shape, reorder_images
 from transformers.utils import TensorType
 from transformers.processing_utils import (
     ImagesKwargs,
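
Note: group_images_by_shape and reorder_images moved homes; this PR imports them from transformers.image_transforms instead of transformers.image_processing_utils_fast. A version-tolerant import sketch (not part of the PR) would be:

    try:
        from transformers.image_transforms import group_images_by_shape, reorder_images
    except ImportError:
        # Older transformers exposed these from the fast-processor utils.
        from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images
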
@@ -35,17 +35,16 @@ from transformers.processing_utils import (
 )
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 from transformers.utils import logging
-from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor, smart_resize


 logger = logging.get_logger(__name__)


-class MossVLImageProcessorFast(Qwen2VLImageProcessorFast):
+class MossVLImageProcessor(Qwen2VLImageProcessor):
     """
     Custom image processor that overrides _preprocess to support multi_image_max_pixels.
-    Inherits from Qwen2VLImageProcessorFast.
+    Inherits from Qwen2VLImageProcessor.
     """
     # Multi-image batch total pixels limit (read from config)
     multi_image_max_pixels = None
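
Note: the custom class now subclasses the slow Qwen2VLImageProcessor rather than Qwen2VLImageProcessorFast, and smart_resize comes from the same module. A minimal loading sketch, assuming the file ships as remote code with the model repo (the repo id is a placeholder, not from this PR):

    from transformers import AutoProcessor

    # Placeholder repo id; trust_remote_code=True is what lets
    # processing_moss_vl.py (and MossVLImageProcessor) be loaded.
    processor = AutoProcessor.from_pretrained("org/MOSS-VL", trust_remote_code=True)
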
@@ -56,7 +55,7 @@ class MossVLImageProcessorFast(Qwen2VLImageProcessorFast):
         images: list["torch.Tensor"],
         do_resize: bool,
         size: SizeDict,
-        interpolation: Optional["F.InterpolationMode"],
+        resample: Optional[Union["PILImageResampling", "F.InterpolationMode", int]],
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
@@ -75,6 +74,8 @@ class MossVLImageProcessorFast(Qwen2VLImageProcessorFast):
         to each image based on its original pixel count. min_pixels remains a per-image
         constraint. multi_image_max_pixels can be configured separately from longest_edge.
         """
+        if resample is None:
+            resample = kwargs.pop("interpolation", None)
         min_pixels = size["shortest_edge"]
         max_pixels = size["longest_edge"]  # Per-image upper limit
         # Use multi_image_max_pixels if configured, otherwise fall back to longest_edge
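
Note: the two added lines are a compatibility shim. The new slow-processor signature receives resample, while callers on the old fast-processor path still pass interpolation as a keyword. The same logic as a standalone sketch (resolve_resample is a hypothetical name, not in the PR):

    from torchvision.transforms.v2 import functional as F

    def resolve_resample(resample, kwargs):
        # Prefer the new argument; fall back to the legacy keyword.
        if resample is None:
            resample = kwargs.pop("interpolation", None)
        return resample

    assert resolve_resample(None, {"interpolation": F.InterpolationMode.BICUBIC}) == F.InterpolationMode.BICUBIC
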
@@ -115,7 +116,7 @@ class MossVLImageProcessorFast(Qwen2VLImageProcessorFast):
            stacked_images = self.resize(
                image=stacked_images,
                size=SizeDict(height=resized_height, width=resized_width),
-               interpolation=interpolation,
+               interpolation=resample,
            )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
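
Note: resized_height and resized_width are produced by Qwen2VL's smart_resize, which picks dimensions inside [min_pixels, max_pixels] that are divisible by the patch factor. A quick illustration, assuming the Qwen2VL default factor of 28:

    from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

    h, w = smart_resize(height=1080, width=1920, factor=28,
                        min_pixels=56 * 56, max_pixels=1280 * 28 * 28)
    assert h % 28 == 0 and w % 28 == 0
    assert 56 * 56 <= h * w <= 1280 * 28 * 28
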
@@ -214,6 +215,58 @@ def _to_numpy(x):
     return np.array(x)


+def _split_array_or_tensor(x, split_indices):
+    """Split along the first dimension while preserving tensor/array type."""
+    split_indices = [int(idx) for idx in split_indices]
+    if isinstance(x, torch.Tensor):
+        if not split_indices:
+            return [x]
+        chunks = []
+        start = 0
+        for end in split_indices:
+            chunks.append(x[start:end])
+            start = end
+        chunks.append(x[start:])
+        return chunks
+    return np.split(x, split_indices)
+
+
+def _concat_array_or_tensor(items, axis=0):
+    """Concatenate while preserving tensor/array type and device."""
+    if not items:
+        return None
+
+    if any(isinstance(item, torch.Tensor) for item in items):
+        ref = next(item for item in items if isinstance(item, torch.Tensor))
+        tensor_items = [
+            item
+            if isinstance(item, torch.Tensor)
+            else torch.as_tensor(item, device=ref.device, dtype=ref.dtype)
+            for item in items
+        ]
+        return torch.cat(tensor_items, dim=axis)
+
+    return np.concatenate(items, axis=axis)
+
+
+def _stack_array_or_tensor(items, axis=0):
+    """Stack while preserving tensor/array type and device."""
+    if not items:
+        return None
+
+    if any(isinstance(item, torch.Tensor) for item in items):
+        ref = next(item for item in items if isinstance(item, torch.Tensor))
+        tensor_items = [
+            item
+            if isinstance(item, torch.Tensor)
+            else torch.as_tensor(item, device=ref.device, dtype=ref.dtype)
+            for item in items
+        ]
+        return torch.stack(tensor_items, dim=axis)
+
+    return np.stack(items, axis=axis)
+
+
 class MossVLImagesKwargs(ImagesKwargs):
     min_pixels: Optional[int]
     max_pixels: Optional[int]
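
Note: the three helpers above exist because newer transformers can hand the processor torch tensors where numpy arrays used to appear, and np.split / np.concatenate / np.stack would silently convert them to numpy (dropping device and dtype) or fail outright on CUDA tensors. A quick round-trip check (shapes illustrative):

    import numpy as np
    import torch

    flat = torch.randn(10, 1176)                  # 10 flattened patches
    parts = _split_array_or_tensor(flat, np.cumsum([4, 6])[:-1])
    assert [p.shape[0] for p in parts] == [4, 6]
    assert torch.equal(_concat_array_or_tensor(parts, axis=0), flat)
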
@@ -272,8 +325,6 @@ class MossVLProcessor(ProcessorMixin):
     """

     attributes = ["image_processor", "tokenizer", "video_processor"]
-    image_processor_class = "AutoImageProcessor"
-    video_processor_class = "AutoVideoProcessor"
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

     def __init__(
@@ -485,7 +536,9 @@ class MossVLProcessor(ProcessorMixin):
        elif len(patch_counts) > 1:
            # Multiple images: split by cumulative counts
            split_indices = np.cumsum(patch_counts)[:-1]
-           image_pixel_values_list = np.split(flat_pixel_values, split_indices)
+           image_pixel_values_list = _split_array_or_tensor(
+               flat_pixel_values, split_indices
+           )

        if has_videos:
            flat_video_values = videos_inputs["pixel_values_videos"]
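
Note: the cumulative-count trick used in both branches: for patch counts [4, 6, 2] the split points are the cumsum without its last entry, i.e. [4, 10], which recovers groups of 4, 6 and 2 patches. For example:

    import numpy as np

    patch_counts = [4, 6, 2]
    split_indices = np.cumsum(patch_counts)[:-1]   # array([ 4, 10])
    parts = np.split(np.zeros((12, 8)), split_indices)
    assert [p.shape[0] for p in parts] == [4, 6, 2]
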
@@ -497,7 +550,9 @@ class MossVLProcessor(ProcessorMixin):
        elif len(video_patch_counts) > 1:
            # Multiple videos: split by cumulative counts
            split_indices = np.cumsum(video_patch_counts)[:-1]
-           video_pixel_values_list = np.split(flat_video_values, split_indices)
+           video_pixel_values_list = _split_array_or_tensor(
+               flat_video_values, split_indices
+           )

        # Step 3.1: Replace placeholders (simple replacement, no expansion yet)
        # In MossVL, one image placeholder = one image token
@@ -713,10 +768,14 @@ class MossVLProcessor(ProcessorMixin):

        # Concatenate/stack to unified format
        if final_pixel_values:
-           output_data["pixel_values"] = np.concatenate(final_pixel_values, axis=0)
+           output_data["pixel_values"] = _concat_array_or_tensor(
+               final_pixel_values, axis=0
+           )

        if final_grid_thw:
-           output_data["grid_thw"] = np.stack(final_grid_thw, axis=0)
+           output_data["grid_thw"] = _stack_array_or_tensor(
+               final_grid_thw, axis=0
+           )

        # Don't add media_nums_per_sample to output_data yet
        # Will add it after BatchFeature to keep it as list
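
Note on concatenate vs. stack: per-image pixel_values differ in patch count along dim 0, so they are joined end to end; each grid_thw entry is a fixed (t, h, w) triple, so stacking yields the (N, 3) table downstream code expects. Illustration (shapes made up):

    import torch

    pixel_parts = [torch.zeros(4, 1176), torch.zeros(6, 1176)]
    thw_rows = [torch.tensor([1, 2, 2]), torch.tensor([1, 4, 4])]
    assert _concat_array_or_tensor(pixel_parts, axis=0).shape == (10, 1176)
    assert _stack_array_or_tensor(thw_rows, axis=0).shape == (2, 3)
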
@@ -773,6 +832,10 @@ class MossVLProcessor(ProcessorMixin):
        for _ in range(num_media):
            # grid_thw is (N, 3) where first dim is t (num_frames)
            t = grid_thw[media_idx][0]
+           if isinstance(t, torch.Tensor):
+               t = int(t.item())
+           else:
+               t = int(t)
            sample_frames += t
            media_idx += 1
        total_frames_per_sample.append(sample_frames)
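
Note: the added isinstance branch matters because grid_thw may now arrive as a torch tensor; indexing one then yields a 0-d tensor, and .item() is what keeps sample_frames a plain Python int. For instance:

    import torch

    grid_thw = torch.tensor([[2, 14, 14]])   # one video: t=2 frames, 14x14 grid
    t = grid_thw[0][0]                       # tensor(2), a 0-d tensor
    assert int(t.item()) == 2
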
@@ -1075,5 +1138,4 @@ class MossVLProcessor(ProcessorMixin):
        **kwargs,
    )

-
-__all__ = ["MossVLProcessor", "MossVLImageProcessorFast"]
+__all__ = ["MossVLProcessor", "MossVLImageProcessor"]
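
Note: with the export renamed, downstream imports of the old symbol need updating; after this PR the module's public surface is:

    from processing_moss_vl import MossVLProcessor, MossVLImageProcessor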