File size: 6,682 Bytes
dfefe0b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index b1e261412203..a0e78e824a71 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -49,6 +49,7 @@
is_vision_available,
logging,
)
+from .utils.import_utils import is_rocm_platform
if is_vision_available():
@@ -279,8 +280,34 @@ def resize(
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
+ # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
+ # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
+ # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
+ if torch.compiler.is_compiling() and is_rocm_platform():
+ return self.compile_friendly_resize(image, new_size, interpolation, antialias)
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+ @staticmethod
+ def compile_friendly_resize(
+ image: "torch.Tensor",
+ new_size: tuple[int, int],
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ ) -> "torch.Tensor":
+ """
+ A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
+
+ `resize` dispatches here only while `torch.compiler.is_compiling()` is true on a ROCm platform.
+ uint8 inputs are never handed to `F.resize` directly: they are converted to float, scaled down by 256,
+ resized, scaled back up by 256, limited to the range [0, 255], rounded and cast back to uint8.
+ Tensors of any other dtype are passed straight to `F.resize`.
+
+ Args:
+ image (`torch.Tensor`):
+ Image tensor to resize.
+ new_size (`tuple[int, int]`):
+ Target size, forwarded unchanged to `F.resize` (presumably `(height, width)`; verify against caller).
+ interpolation (`F.InterpolationMode`, *optional*):
+ Interpolation mode, forwarded unchanged to `F.resize`.
+ NOTE(review): the default `None` is forwarded as-is; confirm callers always pass a concrete mode.
+ antialias (`bool`, *optional*, defaults to `True`):
+ Antialias flag, forwarded unchanged to `F.resize`.
+
+ Returns:
+ `torch.Tensor`: The resized image. A uint8 input yields a uint8 output; for other dtypes the result of
+ `F.resize` is returned as-is.
+ """
+ if image.dtype == torch.uint8:
+ # Do the resize in float so that `F.resize` never receives a uint8 tensor.
+ image = image.float() / 256
+ image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+ # Undo the scaling applied before the resize.
+ image = image * 256
+ # Limit values to [0, 255] so that the cast below only sees values representable as uint8.
+ image = image.masked_fill(image > 255, 255)
+ image = image.masked_fill(image < 0, 0)
+ image = image.round().to(torch.uint8)
+ else:
+ # Non-uint8 tensors need no workaround.
+ image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+ return image
+
def rescale(
self,
image: "torch.Tensor",
diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
index 68d5b2c0f899..01df06c52934 100644
--- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
+++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
@@ -164,13 +164,18 @@ def resize(
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size.shortest_edge
longer = int(1333 / 800 * shorter)
- output_size = get_resize_output_image_size(
+ output_height, output_width = get_resize_output_image_size(
image,
shorter=shorter,
longer=longer,
size_divisor=size_divisor,
)
- return F.resize(image, output_size, interpolation=interpolation, antialias=antialias)
+ return super().resize(
+ image=image,
+ size=SizeDict(height=output_height, width=output_width),
+ interpolation=interpolation,
+ antialias=antialias,
+ )
def center_crop(
self,
diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py
index ac90290cef48..2f7a398479b0 100644
--- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py
+++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py
@@ -137,7 +137,11 @@ def _resize_for_patching(
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
- resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+ resized_image = self.resize(
+ image=image,
+ size=SizeDict(height=new_height, width=new_width),
+ interpolation=interpolation,
+ )
return resized_image
diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
index a29631fcb6af..cf5483e226b9 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
@@ -142,7 +142,11 @@ def _resize_for_patching(
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
- resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+ resized_image = self.resize(
+ image=image,
+ size=SizeDict(height=new_height, width=new_width),
+ interpolation=interpolation,
+ )
return resized_image
diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
index adaf369473b2..ee8077acb2de 100644
--- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -143,8 +143,10 @@ def _preprocess(
min_pixels=min_pixels,
max_pixels=max_pixels,
)
- stacked_videos = F.resize(
- stacked_videos, size=(resized_height, resized_width), interpolation=interpolation
+ stacked_videos = self.resize(
+ image=stacked_videos,
+ size=SizeDict(height=resized_height, width=resized_width),
+ interpolation=interpolation,
)
resized_videos_grouped[shape] = stacked_videos
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ff097849f971..58d4165c24e0 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -481,7 +481,7 @@ def is_cuda_platform():
if is_torch_available():
import torch
- torch.version.cuda is not None
+ return torch.version.cuda is not None
else:
return False
@@ -490,7 +490,7 @@ def is_rocm_platform():
if is_torch_available():
import torch
- torch.version.hip is not None
+ return torch.version.hip is not None
else:
return False
|