| |
| |
| |
| |
| @@ -49,6 +49,7 @@ |
| is_vision_available, |
| logging, |
| ) |
| +from .utils.import_utils import is_rocm_platform |
| |
| |
| if is_vision_available(): |
| @@ -279,8 +280,34 @@ def resize( |
| "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got" |
| f" {size}." |
| ) |
| + # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs |
| + # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209 |
| + # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd) |
| + if torch.compiler.is_compiling() and is_rocm_platform(): |
| + return self.compile_friendly_resize(image, new_size, interpolation, antialias) |
| return F.resize(image, new_size, interpolation=interpolation, antialias=antialias) |
| |
| + @staticmethod |
| + def compile_friendly_resize( |
| + image: "torch.Tensor", |
| + new_size: tuple[int, int], |
| + interpolation: Optional["F.InterpolationMode"] = None, |
| + antialias: bool = True, |
| + ) -> "torch.Tensor": |
| + """ |
| + A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor. uint8 input is resized as float and converted back to uint8; any other dtype is passed straight to `F.resize`. Returns the resized tensor. |
| + """ |
| + if image.dtype == torch.uint8:  # only uint8 input takes the float round-trip below |
| + image = image.float() / 256  # NOTE(review): divisor is 256 rather than 255 - confirm this is intentional |
| + image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)  # resize runs on the float tensor |
| + image = image * 256  # undo the 1/256 scaling applied before the resize |
| + image = image.masked_fill(image > 255, 255)  # cap at 255; NOTE(review): masked_fill instead of clamp is presumably part of the workaround - confirm |
| + image = image.masked_fill(image < 0, 0)  # raise negative values to 0 before the uint8 cast |
| + image = image.round().to(torch.uint8)  # round, then cast back to the input dtype |
| + else: |
| + image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)  # non-uint8 input: plain resize, no dtype conversion |
| + return image |
| + |
| def rescale( |
| self, |
| image: "torch.Tensor", |
| |
| |
| |
| |
| @@ -164,13 +164,18 @@ def resize( |
| raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}") |
| shorter = size.shortest_edge |
| longer = int(1333 / 800 * shorter) |
| - output_size = get_resize_output_image_size( |
| + output_height, output_width = get_resize_output_image_size( |
| image, |
| shorter=shorter, |
| longer=longer, |
| size_divisor=size_divisor, |
| ) |
| - return F.resize(image, output_size, interpolation=interpolation, antialias=antialias) |
| + return super().resize( |
| + image=image, |
| + size=SizeDict(height=output_height, width=output_width), |
| + interpolation=interpolation, |
| + antialias=antialias, |
| + ) |
| |
| def center_crop( |
| self, |
| |
| |
| |
| |
| @@ -137,7 +137,11 @@ def _resize_for_patching( |
| new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) |
| |
| # Resize the image |
| - resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) |
| + resized_image = self.resize( |
| + image=image, |
| + size=SizeDict(height=new_height, width=new_width), |
| + interpolation=interpolation, |
| + ) |
| |
| return resized_image |
| |
| |
| |
| |
| |
| @@ -142,7 +142,11 @@ def _resize_for_patching( |
| new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) |
| |
| # Resize the image |
| - resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) |
| + resized_image = self.resize( |
| + image=image, |
| + size=SizeDict(height=new_height, width=new_width), |
| + interpolation=interpolation, |
| + ) |
| |
| return resized_image |
| |
| |
| |
| |
| |
| @@ -143,8 +143,10 @@ def _preprocess( |
| min_pixels=min_pixels, |
| max_pixels=max_pixels, |
| ) |
| - stacked_videos = F.resize( |
| - stacked_videos, size=(resized_height, resized_width), interpolation=interpolation |
| + stacked_videos = self.resize( |
| + image=stacked_videos, |
| + size=SizeDict(height=resized_height, width=resized_width), |
| + interpolation=interpolation, |
| ) |
| resized_videos_grouped[shape] = stacked_videos |
| resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) |
| |
| |
| |
| |
| @@ -481,7 +481,7 @@ def is_cuda_platform(): |
| if is_torch_available(): |
| import torch |
| |
| - torch.version.cuda is not None |
| + return torch.version.cuda is not None |
| else: |
| return False |
| |
| @@ -490,7 +490,7 @@ def is_rocm_platform(): |
| if is_torch_available(): |
| import torch |
| |
| - torch.version.hip is not None |
| + return torch.version.hip is not None |
| else: |
| return False |
| |
|
|