diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index b1e261412203..a0e78e824a71 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -49,6 +49,7 @@ is_vision_available, logging, ) +from .utils.import_utils import is_rocm_platform if is_vision_available(): @@ -279,8 +280,34 @@ def resize( "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got" f" {size}." ) + # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs + # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209 + # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd) + if torch.compiler.is_compiling() and is_rocm_platform(): + return self.compile_friendly_resize(image, new_size, interpolation, antialias) return F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + @staticmethod + def compile_friendly_resize( + image: "torch.Tensor", + new_size: tuple[int, int], + interpolation: Optional["F.InterpolationMode"] = None, + antialias: bool = True, + ) -> "torch.Tensor": + """ + A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor. + """ + if image.dtype == torch.uint8: + image = image.float() / 256 + image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + image = image * 256 + image = image.masked_fill(image > 255, 255) + image = image.masked_fill(image < 0, 0) + image = image.round().to(torch.uint8) + else: + image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + return image + def rescale( self, image: "torch.Tensor", diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py index 68d5b2c0f899..01df06c52934 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py @@ -164,13 +164,18 @@ def resize( raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}") shorter = size.shortest_edge longer = int(1333 / 800 * shorter) - output_size = get_resize_output_image_size( + output_height, output_width = get_resize_output_image_size( image, shorter=shorter, longer=longer, size_divisor=size_divisor, ) - return F.resize(image, output_size, interpolation=interpolation, antialias=antialias) + return super().resize( + image=image, + size=SizeDict(height=output_height, width=output_width), + interpolation=interpolation, + antialias=antialias, + ) def center_crop( self, diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py index ac90290cef48..2f7a398479b0 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -137,7 +137,11 @@ def _resize_for_patching( new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) # Resize the image - resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + resized_image = self.resize( + image=image, + size=SizeDict(height=new_height, width=new_width), + interpolation=interpolation, + ) return resized_image diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index a29631fcb6af..cf5483e226b9 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -142,7 +142,11 @@ def _resize_for_patching( new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) # Resize the image - resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + resized_image = self.resize( + image=image, + size=SizeDict(height=new_height, width=new_width), + interpolation=interpolation, + ) return resized_image diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py index adaf369473b2..ee8077acb2de 100644 --- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py @@ -143,8 +143,10 @@ def _preprocess( min_pixels=min_pixels, max_pixels=max_pixels, ) - stacked_videos = F.resize( - stacked_videos, size=(resized_height, resized_width), interpolation=interpolation + stacked_videos = self.resize( + image=stacked_videos, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, ) resized_videos_grouped[shape] = stacked_videos resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ff097849f971..58d4165c24e0 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -481,7 +481,7 @@ def is_cuda_platform(): if is_torch_available(): import torch - torch.version.cuda is not None + return torch.version.cuda is not None else: return False @@ -490,7 +490,7 @@ def is_rocm_platform(): if is_torch_available(): import torch - torch.version.hip is not None + return torch.version.hip is not None else: return False