File size: 5,474 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
diff --git a/src/transformers/models/pixtral/image_processing_pixtral_fast.py b/src/transformers/models/pixtral/image_processing_pixtral_fast.py
index 82fbf3b2c094..5fa23923fe74 100644
--- a/src/transformers/models/pixtral/image_processing_pixtral_fast.py
+++ b/src/transformers/models/pixtral/image_processing_pixtral_fast.py
@@ -346,4 +346,7 @@ def preprocess(
             batch_images.append(images)
             batch_image_sizes.append(image_sizes)
 
-        return BatchMixFeature(data={"pixel_values": batch_images, "image_sizes": batch_image_sizes}, tensor_type=None)
+        return BatchMixFeature(
+            data={"pixel_values": batch_images, "image_sizes": batch_image_sizes},
+            tensor_type=None,
+        )
diff --git a/tests/models/pixtral/test_image_processing_pixtral.py b/tests/models/pixtral/test_image_processing_pixtral.py
index a45ead506129..1377b676917f 100644
--- a/tests/models/pixtral/test_image_processing_pixtral.py
+++ b/tests/models/pixtral/test_image_processing_pixtral.py
@@ -19,8 +19,15 @@
 
 import numpy as np
 import requests
-
-from transformers.testing_utils import require_torch, require_vision
+from packaging import version
+
+from transformers.testing_utils import (
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -157,6 +164,9 @@ def test_image_processor_properties(self):
             self.assertTrue(hasattr(image_processing, "image_std"))
             self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
 
+    # The following tests are overriden as PixtralImageProcessor can return images of different sizes
+    # and thus doesn't support returning batched tensors
+
     def test_call_pil(self):
         for image_processing_class in self.image_processor_list:
             # Initialize image_processing
@@ -273,6 +283,25 @@ def test_slow_fast_equivalence(self):
 
         self.assertTrue(torch.allclose(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], atol=1e-2))
 
+    @slow
+    @require_torch_gpu
+    @require_vision
+    def test_can_compile_fast_image_processor(self):
+        if self.fast_image_processing_class is None:
+            self.skipTest("Skipping compilation test as fast image processor is not defined")
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        torch.compiler.reset()
+        input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
+        image_processor = self.fast_image_processing_class(**self.image_processor_dict)
+        output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
+
+        image_processor = torch.compile(image_processor, mode="reduce-overhead")
+        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
+
+        self.assertTrue(torch.allclose(output_eager.pixel_values[0][0], output_compiled.pixel_values[0][0], atol=1e-4))
+
     @unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
     def test_call_numpy_4_channels(self):
         pass
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index 221552175a93..1cb92174df1d 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -23,10 +23,18 @@
 
 import numpy as np
 import requests
+from packaging import version
 
 from transformers import AutoImageProcessor, BatchFeature
 from transformers.image_utils import AnnotationFormat, AnnotionFormat
-from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
+from transformers.testing_utils import (
+    check_json_file_has_correct_format,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_torch_available, is_vision_available
 
 
@@ -463,6 +471,25 @@ def test_image_processor_preprocess_arguments(self):
         if not is_tested:
             self.skipTest(reason="No validation found for `preprocess` method")
 
+    @slow
+    @require_torch_gpu
+    @require_vision
+    def test_can_compile_fast_image_processor(self):
+        if self.fast_image_processing_class is None:
+            self.skipTest("Skipping compilation test as fast image processor is not defined")
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        torch.compiler.reset()
+        input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
+        image_processor = self.fast_image_processing_class(**self.image_processor_dict)
+        output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
+
+        image_processor = torch.compile(image_processor, mode="reduce-overhead")
+        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
+
+        self.assertTrue(torch.allclose(output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4))
+
 
 class AnnotationFormatTestMixin:
     # this mixin adds a test to assert that usages of the