l45k committed on
Commit
1f2e2c7
·
verified ·
1 Parent(s): 44cadd4

Upload processor

Browse files
Files changed (2) hide show
  1. preprocessor_config.json +2 -2
  2. preprocessor_lenet.py +24 -83
preprocessor_config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "auto_map": {
3
- "AutoImageProcessor": "preprocessor_resnet.ResNetProcessor"
4
  },
5
- "image_processor_type": "ResNetProcessor"
6
  }
 
1
  {
2
  "auto_map": {
3
+ "AutoImageProcessor": "preprocessor_lenet.LeNetProcessor"
4
  },
5
+ "image_processor_type": "LeNetProcessor"
6
  }
preprocessor_lenet.py CHANGED
@@ -1,99 +1,40 @@
1
  import numpy as np
2
  from PIL import Image
3
  from transformers import BaseImageProcessor, BatchFeature
4
- from transformers.image_transforms import (
5
- normalize,
6
- to_channel_dimension_format
7
- )
8
- from transformers.image_utils import (
9
- ImageInput,
10
- ChannelDimension
11
- )
12
 
13
  class LeNetProcessor(BaseImageProcessor):
14
  """
15
  A custom processor that only normalizes a grayscale image
16
  and prepares it for a model.
17
  """
 
18
  model_input_names = ["pixel_values"]
19
 
20
- def __init__(
21
- self,
22
- mean: float = 0.1307,
23
- std: float = 0.3081,
24
- **kwargs
25
- ):
26
  """
27
  Args:
28
- mean (float): The mean to use for normalization.
29
- std (float): The std dev to use for normalization.
30
  """
31
  super().__init__(**kwargs)
32
- self.mean = mean
33
- self.std = std
34
 
35
- def preprocess(
36
- self,
37
- images: ImageInput,
38
- return_tensors=None,
39
- **kwargs
40
- ) -> BatchFeature:
41
- class GrayscaleNormalizeProcessor(BaseImageProcessor):
42
- """
43
- A custom processor that only normalizes a grayscale image
44
- and prepares it for a model.
45
- """
46
- model_input_names = ["pixel_values"]
47
-
48
- def __init__(
49
- self,
50
- mean: float = 0.5,
51
- std: float = 0.5,
52
- **kwargs
53
- ):
54
- super().__init__(**kwargs)
55
- self.mean = mean
56
- self.std = std
57
-
58
- def preprocess(
59
- self,
60
- images: ImageInput,
61
- return_tensors=None,
62
- **kwargs
63
- ) -> BatchFeature:
64
- """
65
- Preprocess a batch of grayscale images.
66
- """
67
- if not isinstance(images, list):
68
- images = [images]
69
-
70
- # --- THIS IS THE FIX ---
71
- # Call the built-in self.to_numpy_array method.
72
- # It handles all validation (PIL, numpy, torch, tf)
73
- # and conversion, raising an error if the type is invalid.
74
- # No more manual validation or imports needed.
75
- try:
76
- images = [self.to_numpy_array(img) for img in images]
77
- except ValueError as e:
78
- raise ValueError(
79
- "Input must be a list of PIL Images, NumPy arrays, "
80
- f"PyTorch tensors, or TensorFlow tensors. Error: {e}"
81
- )
82
- # --- END FIX ---
83
-
84
- processed_images = []
85
- for img in images:
86
- if img.ndim == 3 and img.shape[2] == 1:
87
- img = img.squeeze(-1)
88
- elif img.ndim == 3:
89
- raise ValueError(
90
- "Image is not grayscale. "
91
- f"Expected 2D array, but got shape {img.shape}"
92
- )
93
-
94
- img = normalize(img, mean=self.mean, std=self.std)
95
- img = to_channel_dimension_format(img, ChannelDimension.FIRST)
96
- processed_images.append(img)
97
-
98
- data = {"pixel_values": processed_images}
99
- return BatchFeature(data=data, tensor_type=return_tensors)
 
1
  import numpy as np
2
  from PIL import Image
3
  from transformers import BaseImageProcessor, BatchFeature
4
+ from transformers.image_utils import ImageInput
5
+ import torch
6
+ from torchvision.transforms import v2
7
+
 
 
 
 
8
 
9
  class LeNetProcessor(BaseImageProcessor):
10
  """
11
  A custom processor that only normalizes a grayscale image
12
  and prepares it for a model.
13
  """
14
+
15
  model_input_names = ["pixel_values"]
16
 
17
+ def __init__(self, **kwargs):
 
 
 
 
 
18
  """
19
  Args:
 
 
20
  """
21
  super().__init__(**kwargs)
 
 
22
 
23
+ def preprocess(self, images: ImageInput, return_tensors=None, **kwargs) -> BatchFeature:
24
+ """
25
+ Preprocess a batch of grayscale images.
26
+ """
27
+ if not isinstance(images, list):
28
+ images = [images]
29
+
30
+ transform = v2.Compose([
31
+ v2.RandomResizedCrop(size=(28, 28), antialias=True),
32
+ v2.ToDtype(torch.float32, scale=True),
33
+ v2.Normalize(
34
+ mean=[0.1307],
35
+ std=[0.3081]
36
+ ),
37
+ ])
38
+
39
+ data = {"pixel_values": transform(images)}
40
+ return BatchFeature(data=data, tensor_type=return_tensors)