File size: 4,123 Bytes
fe365dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import torch
from PIL import Image
from typing import List, Optional, Union, Dict
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import logging

# Local import of your existing logic
# (Assuming smart_resize and convert_image_to_patches are in the same folder or copied here)
from .image_processor import smart_resize, convert_image_to_patches, pad_along_first_dim

logger = logging.get_logger(__name__)

class AMOEImageProcessor(BaseImageProcessor):
    """Image processor that turns PIL images into padded patch sequences.

    Thin transformers-compatible wrapper around the local helpers
    ``smart_resize``, ``convert_image_to_patches`` and
    ``pad_along_first_dim``: each image is (optionally) smart-resized to a
    patch-aligned resolution, rescaled to [0, 1], normalized, patchified,
    and padded to a fixed number of patches so a batch can be stacked.
    """

    model_input_names = ["pixel_values", "padding_mask", "spatial_shapes"]

    def __init__(
        self,
        patch_size: int = 16,
        min_pixels: int = 128 * 128,
        max_pixels: int = 256 * 256,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        do_resize: bool = True,
        do_rescale: bool = True,
        do_normalize: bool = True,
        **kwargs
    ):
        """Store preprocessing configuration.

        Args:
            patch_size: Side length of a square patch; also the alignment
                factor handed to ``smart_resize``.
            min_pixels / max_pixels: Pixel-count budget for smart resizing.
            image_mean / image_std: Per-channel normalization stats;
                default to 0.5 each, i.e. a symmetric [-1, 1] mapping.
            do_resize / do_rescale / do_normalize: Toggles for the three
                preprocessing stages.
        """
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.image_mean = [0.5, 0.5, 0.5] if image_mean is None else image_mean
        self.image_std = [0.5, 0.5, 0.5] if image_std is None else image_std
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize

    def preprocess_single(self, image: Image.Image) -> Dict:
        """Resize, rescale, normalize and patchify a single image.

        Returns:
            Dict with ``patches`` (tensor from ``convert_image_to_patches``)
            and ``spatial_shape`` (rows, cols) of the patch grid.
        """
        # Accept raw arrays as a convenience; everything else goes through PIL.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        image = image.convert("RGB")
        width, height = image.size  # PIL reports (W, H)

        # 1. Smart resize to a patch-aligned resolution within the pixel budget.
        if self.do_resize:
            target_h, target_w = smart_resize(
                height, width,
                factor=self.patch_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
            image = image.resize((target_w, target_h), Image.BICUBIC)
        else:
            # NOTE(review): with resizing disabled, H/W may not be divisible
            # by patch_size; spatial_shape floor-divides below — confirm the
            # patchify helper tolerates (or callers avoid) that case.
            target_h, target_w = height, width

        pixels = np.asarray(image, dtype=np.float32)

        # 2. Rescale raw 0..255 values into [0, 1].
        if self.do_rescale:
            pixels = pixels / 255.0

        # 3. Channel-wise normalization (broadcasts over the last axis).
        if self.do_normalize:
            mean_arr = np.asarray(self.image_mean, dtype=np.float32)
            std_arr = np.asarray(self.image_std, dtype=np.float32)
            pixels = (pixels - mean_arr) / std_arr

        grid = (target_h // self.patch_size, target_w // self.patch_size)

        patches = convert_image_to_patches(torch.from_numpy(pixels), self.patch_size)

        return {
            "patches": patches,
            "spatial_shape": grid
        }

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        max_num_patches: int = 256,
        return_tensors: Optional[str] = "pt",
        **kwargs
    ) -> BatchFeature:
        """Batch entry point: preprocess each image, pad, and stack.

        Args:
            images: One image or a list of images.
            max_num_patches: Fixed patch-sequence length every image is
                padded to (padding semantics live in ``pad_along_first_dim``).
            return_tensors: Tensor type forwarded to ``BatchFeature``.

        Returns:
            ``BatchFeature`` with ``pixel_values``, ``padding_mask`` and
            ``spatial_shapes``.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        pixel_rows, mask_rows, shape_rows = [], [], []
        for img in images:
            single = self.preprocess_single(img)
            # Pad every patch sequence to the same length so torch.stack works.
            padded, mask = pad_along_first_dim(
                single["patches"],
                max_num_patches,
                pad_value=0.0
            )
            pixel_rows.append(padded)
            mask_rows.append(mask)
            shape_rows.append(list(single["spatial_shape"]))

        return BatchFeature(
            data={
                "pixel_values": torch.stack(pixel_rows),
                "padding_mask": torch.stack(mask_rows),
                "spatial_shapes": torch.tensor(shape_rows),
            },
            tensor_type=return_tensors,
        )