harness / diffs /38540.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index b1e261412203..a0e78e824a71 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -49,6 +49,7 @@
is_vision_available,
logging,
)
+from .utils.import_utils import is_rocm_platform
if is_vision_available():
@@ -279,8 +280,34 @@ def resize(
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
+ # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs
+ # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209
+ # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd)
+ if torch.compiler.is_compiling() and is_rocm_platform():
+ return self.compile_friendly_resize(image, new_size, interpolation, antialias)
return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+ @staticmethod
+ def compile_friendly_resize(
+ image: "torch.Tensor",
+ new_size: tuple[int, int],
+ interpolation: Optional["F.InterpolationMode"] = None,
+ antialias: bool = True,
+ ) -> "torch.Tensor":
+ """
+ A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor.
+ """
+ if image.dtype == torch.uint8:
+ image = image.float() / 256
+ image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+ image = image * 256
+ image = image.masked_fill(image > 255, 255)
+ image = image.masked_fill(image < 0, 0)
+ image = image.round().to(torch.uint8)
+ else:
+ image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+ return image
+
def rescale(
self,
image: "torch.Tensor",
diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
index 68d5b2c0f899..01df06c52934 100644
--- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
+++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
@@ -164,13 +164,18 @@ def resize(
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size.shortest_edge
longer = int(1333 / 800 * shorter)
- output_size = get_resize_output_image_size(
+ output_height, output_width = get_resize_output_image_size(
image,
shorter=shorter,
longer=longer,
size_divisor=size_divisor,
)
- return F.resize(image, output_size, interpolation=interpolation, antialias=antialias)
+ return super().resize(
+ image=image,
+ size=SizeDict(height=output_height, width=output_width),
+ interpolation=interpolation,
+ antialias=antialias,
+ )
def center_crop(
self,
diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py
index ac90290cef48..2f7a398479b0 100644
--- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py
+++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py
@@ -137,7 +137,11 @@ def _resize_for_patching(
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
- resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+ resized_image = self.resize(
+ image=image,
+ size=SizeDict(height=new_height, width=new_width),
+ interpolation=interpolation,
+ )
return resized_image
diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
index a29631fcb6af..cf5483e226b9 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
@@ -142,7 +142,11 @@ def _resize_for_patching(
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
- resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
+ resized_image = self.resize(
+ image=image,
+ size=SizeDict(height=new_height, width=new_width),
+ interpolation=interpolation,
+ )
return resized_image
diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
index adaf369473b2..ee8077acb2de 100644
--- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -143,8 +143,10 @@ def _preprocess(
min_pixels=min_pixels,
max_pixels=max_pixels,
)
- stacked_videos = F.resize(
- stacked_videos, size=(resized_height, resized_width), interpolation=interpolation
+ stacked_videos = self.resize(
+ image=stacked_videos,
+ size=SizeDict(height=resized_height, width=resized_width),
+ interpolation=interpolation,
)
resized_videos_grouped[shape] = stacked_videos
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ff097849f971..58d4165c24e0 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -481,7 +481,7 @@ def is_cuda_platform():
if is_torch_available():
import torch
- torch.version.cuda is not None
+ return torch.version.cuda is not None
else:
return False
@@ -490,7 +490,7 @@ def is_rocm_platform():
if is_torch_available():
import torch
- torch.version.hip is not None
+ return torch.version.hip is not None
else:
return False