Thastp commited on
Commit
3f02cf8
·
verified ·
1 Parent(s): 17b99e3

Upload processor

Browse files
Files changed (1) hide show
  1. image_processing_efficientnet.py +48 -8
image_processing_efficientnet.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
2
  from timm import create_model
3
  from timm.data import resolve_data_config
@@ -6,19 +11,54 @@ from timm.data.transforms_factory import create_transform
6
  class EfficientNetImageProcessor(BaseImageProcessor):
7
  model_input_names = ["pixel_values"]
8
 
9
- def __init__(self,
10
- model_name: str,
11
- **kwargs
12
- ):
13
- super().__init__(**kwargs)
14
-
15
  self.model_name = model_name
16
  self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
18
 
19
- def preprocess(self, image):
20
  transforms = create_transform(**self.config)
21
- data = {'pixel_values': transforms(image).unsqueeze(0)}
 
 
 
 
 
22
  return BatchFeature(data=data)
23
 
24
  __all__ = [
 
1
+ from PIL import Image
2
+ from torch import Tensor, stack
3
+ from numpy import ndarray
4
+ from typing import Union, List
5
+
6
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
7
  from timm import create_model
8
  from timm.data import resolve_data_config
 
11
  class EfficientNetImageProcessor(BaseImageProcessor):
12
  model_input_names = ["pixel_values"]
13
 
14
+ def __init__(
15
+ self,
16
+ model_name: str,
17
+ **kwargs,
18
+ ):
 
19
  self.model_name = model_name
20
  self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
21
+ super().__init__(**kwargs)
22
+
23
+ def preprocess(
24
+ self,
25
+ images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
26
+ ) -> BatchFeature:
27
+ """
28
+ Preprocesses input images by applying transformations and returning them as a BatchFeature.
29
+
30
+ Parameters
31
+ ----------
32
+ images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
33
+ A single image or a list of images in one of the accepted formats.
34
+
35
+ Returns
36
+ -------
37
+ BatchFeature
38
+ A batch of transformed images
39
+ """
40
+ images = [images] if not isinstance(images, list) else images
41
+
42
+ # TEST: empty list
43
+ if len(images) == 0:
44
+ raise ValueError("Received an empty list of images")
45
 
46
+ # TEST: validate input type
47
+ test_image = images[0]
48
+ if not isinstance(images[0], (Image.Image, Tensor)):
49
+ raise TypeError(
50
+ f"Expected image to be of type PIL.Image.Image, torch.Tensor, or numpy.ndarray, "
51
+ f"but got {type(test_image).__name__} instead."
52
+ )
53
 
54
+ # Apply transformations
55
  transforms = create_transform(**self.config)
56
+ transformed_images = [transforms(image) for image in images]
57
+
58
+ # Convert to batch tensor
59
+ transformed_image_tensors = stack(transformed_images)
60
+
61
+ data = {'pixel_values': transformed_image_tensors}
62
  return BatchFeature(data=data)
63
 
64
  __all__ = [