File size: 6,682 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index b1e261412203..a0e78e824a71 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -49,6 +49,7 @@
     is_vision_available,
     logging,
 )
+from .utils.import_utils import is_rocm_platform
 
 
 if is_vision_available():
@@ -279,8 +280,34 @@ def resize(
                 "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                 f" {size}."
             )
+        # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
+        # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
+        # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
+        if torch.compiler.is_compiling() and is_rocm_platform():
+            return self.compile_friendly_resize(image, new_size, interpolation, antialias)
         return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
 
+    @staticmethod
+    def compile_friendly_resize(
+        image: "torch.Tensor",
+        new_size: tuple[int, int],
+        interpolation: Optional["F.InterpolationMode"] = None,
+        antialias: bool = True,
+    ) -> "torch.Tensor":
+        """
+        A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
+        """
+        if image.dtype == torch.uint8:
+            image = image.float() / 256
+            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+            image = image * 256
+            image = image.masked_fill(image > 255, 255)
+            image = image.masked_fill(image < 0, 0)
+            image = image.round().to(torch.uint8)
+        else:
+            image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+        return image
+
     def rescale(
         self,
         image: "torch.Tensor",
diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
index 68d5b2c0f899..01df06c52934 100644
--- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
+++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
@@ -164,13 +164,18 @@ def resize(
             raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
         shorter = size.shortest_edge
         longer = int(1333 / 800 * shorter)
-        output_size = get_resize_output_image_size(
+        output_height, output_width = get_resize_output_image_size(
             image,
             shorter=shorter,
             longer=longer,
             size_divisor=size_divisor,
         )
-        return F.resize(image, output_size, interpolation=interpolation, antialias=antialias)
+        return super().resize(
+            image=image,
+            size=SizeDict(height=output_height, width=output_width),
+            interpolation=interpolation,
+            antialias=antialias,
+        )
 
     def center_crop(
         self,
diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py
index ac90290cef48..2f7a398479b0 100644
--- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py
+++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py
@@ -137,7 +137,11 @@ def _resize_for_patching(
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
 
         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+        resized_image = self.resize(
+            image=image,
+            size=SizeDict(height=new_height, width=new_width),
+            interpolation=interpolation,
+        )
 
         return resized_image
 
diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
index a29631fcb6af..cf5483e226b9 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
@@ -142,7 +142,11 @@ def _resize_for_patching(
         new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
 
         # Resize the image
-        resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+        resized_image = self.resize(
+            image=image,
+            size=SizeDict(height=new_height, width=new_width),
+            interpolation=interpolation,
+        )
 
         return resized_image
 
diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
index adaf369473b2..ee8077acb2de 100644
--- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -143,8 +143,10 @@ def _preprocess(
                     min_pixels=min_pixels,
                     max_pixels=max_pixels,
                 )
-                stacked_videos = F.resize(
-                    stacked_videos, size=(resized_height, resized_width), interpolation=interpolation
+                stacked_videos = self.resize(
+                    image=stacked_videos,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=interpolation,
                 )
             resized_videos_grouped[shape] = stacked_videos
         resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ff097849f971..58d4165c24e0 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -481,7 +481,7 @@ def is_cuda_platform():
     if is_torch_available():
         import torch
 
-        torch.version.cuda is not None
+        return torch.version.cuda is not None
     else:
         return False
 
@@ -490,7 +490,7 @@ def is_rocm_platform():
     if is_torch_available():
         import torch
 
-        torch.version.hip is not None
+        return torch.version.hip is not None
     else:
         return False