import numpy as np
import torch
from PIL import Image
from typing import List, Optional, Union, Dict
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import logging
# Local import of your existing logic
# (Assuming smart_resize and convert_image_to_patches are in the same folder or copied here)
from .image_processor import smart_resize, convert_image_to_patches, pad_along_first_dim
logger = logging.get_logger(__name__)
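
# For reference, the assumed behavior of the imported helpers (an illustrative
# sketch only; the real implementations live in image_processor.py):
#
#   smart_resize(height, width, factor, min_pixels, max_pixels) -> (new_h, new_w)
#       returns dimensions rounded to multiples of `factor`, with the total
#       pixel count clamped to the [min_pixels, max_pixels] budget
#   convert_image_to_patches(img, patch_size) -> (num_patches, patch_dim) tensor
#   pad_along_first_dim(t, target_len, pad_value) -> (padded_tensor, mask)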

class AMOEImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values", "padding_mask", "spatial_shapes"]

    def __init__(
        self,
        patch_size: int = 16,
        min_pixels: int = 128 * 128,
        max_pixels: int = 256 * 256,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        do_resize: bool = True,
        do_rescale: bool = True,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
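
    # Note: BaseImageProcessor serializes the attributes set above to
    # preprocessor_config.json via save_pretrained() and restores them with
    # from_pretrained(), so no extra config handling is needed here.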
    def preprocess_single(self, image: Image.Image) -> Dict:
        """Standard preprocessing for a single PIL image."""
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")
        width, height = image.size  # PIL uses (W, H)

        # 1. Smart resize: pick dimensions divisible by patch_size whose pixel
        # count falls within the [min_pixels, max_pixels] budget
        if self.do_resize:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=self.patch_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
            image = image.resize((resized_width, resized_height), Image.BICUBIC)
        else:
            # Without resizing, the input dimensions must already be multiples
            # of patch_size for the patchify step below to work
            resized_height, resized_width = height, width

        image_np = np.array(image).astype(np.float32)

        # 2. Rescale from [0, 255] to [0, 1]
        if self.do_rescale:
            image_np = image_np / 255.0

        # 3. Normalize each channel with the configured mean and std
        if self.do_normalize:
            mean = np.array(self.image_mean, dtype=np.float32)
            std = np.array(self.image_std, dtype=np.float32)
            image_np = (image_np - mean) / std

        spatial_shape = (resized_height // self.patch_size, resized_width // self.patch_size)

        # Convert to a tensor and split into flattened patches
        img_tensor = torch.from_numpy(image_np)
        patches = convert_image_to_patches(img_tensor, self.patch_size)

        return {
            "patches": patches,
            "spatial_shape": spatial_shape,
        }
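
    # Shape walk-through (illustrative; assumes convert_image_to_patches flattens
    # each patch_size x patch_size x 3 patch into one row):
    #   224x224 RGB input, patch_size=16 -> spatial_shape (14, 14)
    #   patches                          -> (196, 768), i.e. 14*14 rows of 16*16*3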
    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        max_num_patches: int = 256,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Main entry point for the transformers image-processor API."""
        if not isinstance(images, (list, tuple)):
            images = [images]

        results = [self.preprocess_single(img) for img in images]

        batched_pixels = []
        batched_masks = []
        batched_shapes = []
        for res in results:
            patches = res["patches"]
            shape = res["spatial_shape"]
            # Pad every patch sequence to max_num_patches so the batch can be
            # stacked; the mask marks which positions hold real patches
            patches_padded, mask = pad_along_first_dim(
                patches,
                max_num_patches,
                pad_value=0.0,
            )
            batched_pixels.append(patches_padded)
            batched_masks.append(mask)
            batched_shapes.append(list(shape))

        data = {
            "pixel_values": torch.stack(batched_pixels),
            "padding_mask": torch.stack(batched_masks),
            "spatial_shapes": torch.tensor(batched_shapes),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
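
# Example usage (a minimal sketch; the expected shapes assume the helper behavior
# sketched above and a 224x224 input with the default patch_size of 16):
if __name__ == "__main__":
    processor = AMOEImageProcessor()
    dummy = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
    batch = processor.preprocess(dummy, max_num_patches=256)
    print(batch["pixel_values"].shape)  # expected: torch.Size([1, 256, 768])
    print(batch["padding_mask"].shape)  # expected: torch.Size([1, 256])
    print(batch["spatial_shapes"])      # expected: tensor([[14, 14]])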