File size: 4,389 Bytes
372980e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import List, Optional, Union, Tuple
import PIL
import torch
from torchvision.transforms.v2 import (
    Compose,
    Lambda,
    Resize,
    Normalize,
    InterpolationMode,
)
import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ChannelDimension, to_numpy_array
from transformers.utils import TensorType, logging


logger = logging.get_logger(__name__)


class VAEImageProcessor(BaseImageProcessor):
    """Image processor that prepares images/videos for a VAE and maps
    model outputs back to displayable pixels.

    ``preprocess`` converts PIL/NumPy inputs into a batched, channel-first
    float tensor (optionally resized, rescaled and normalized);
    ``postprocess`` inverts the normalization/rescaling and returns
    channel-last ``uint8`` images.
    """

    model_input_names = ["pixel_values"]

    def __init__(
            self,
            do_resize: bool = True,
            # NOTE: immutable tuple defaults replace the original mutable
            # list defaults; behavior for callers is unchanged.
            image_size: Tuple[int, int] = (64, 64),
            do_rescale: bool = True,
            rescale_factor: Union[int, float] = 1 / 255,
            do_normalize: bool = True,
            image_mean: Optional[List[float]] = (0.5, 0.5, 0.5),
            image_std: Optional[List[float]] = (0.5, 0.5, 0.5),
            *args,
            **kwargs
        ):
        """
        Args:
            do_resize: Whether to resize inputs to ``image_size``.
            image_size: Target ``(height, width)`` used when resizing.
            do_rescale: Whether to multiply pixels by ``rescale_factor``.
            rescale_factor: Scale applied during rescaling. The default
                maps uint8 ``[0, 255]`` pixels into ``[0, 1]``.
            do_normalize: Whether to normalize with mean/std below.
            image_mean: Per-channel mean for normalization.
            image_std: Per-channel std for normalization.
        """
        super().__init__(*args, **kwargs)
        self.do_resize = do_resize
        self.image_size = image_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std

    def preprocess(
            self,
            images: Union["PIL.Image.Image", np.ndarray, List["PIL.Image.Image"], List[np.ndarray]],
            is_video: bool = False,
            return_tensors: Optional[Union[str, TensorType]] = "pt",
            input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.LAST,
            **kwargs
        ):
        """Convert raw images (or videos) into batched model inputs.

        Args:
            images: A single image/array or a list of them. Videos are
                expected as arrays with a leading frame dimension.
            is_video: Treat inputs as 5-D ``(batch, frames, ...)`` data.
            return_tensors: Tensor type for the returned ``BatchFeature``.
            input_data_format: Channel position of the incoming data;
                channel-last inputs are permuted to channel-first.

        Returns:
            ``BatchFeature`` with key ``"pixel_values"``.
        """
        if isinstance(images, list):
            stacked = np.stack([to_numpy_array(image) for image in images], axis=0)
            images = torch.from_numpy(stacked).float()
        else:
            images = torch.from_numpy(to_numpy_array(images)).float()

        if is_video:
            # Videos: ensure (batch, frames, H, W, C) then move channels
            # to (batch, frames, C, H, W).
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.LAST:
                images = images.permute(0, 1, 4, 2, 3)
        else:
            # Images: ensure (batch, H, W, C) then move channels first.
            if images.dim() == 3:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.LAST:
                images = images.permute(0, 3, 1, 2)

        # Build only the transforms that are enabled; identity Lambdas
        # from the original implementation were pure noise.
        transforms = []
        if self.do_resize:
            transforms.append(Resize(self.image_size, interpolation=InterpolationMode.BICUBIC))
        if self.do_rescale:
            # Bug fix: honor the configured rescale_factor instead of a
            # hard-coded division by 255.
            transforms.append(Lambda(lambda x: x * self.rescale_factor))
        if self.do_normalize:
            transforms.append(Normalize(self.image_mean, self.image_std))
        if transforms:
            images = Compose(transforms)(images)

        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

    def postprocess(
            self,
            images: "torch.Tensor",
            is_video: bool = False,
            return_tensors: Optional[Union[str, TensorType]] = "np",
            input_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
            **kwargs
            ):
        """Map model outputs back to channel-last uint8 pixel values.

        Inverts ``preprocess``: de-normalizes with the stored mean/std,
        clamps to [0, 1], and un-rescales with ``rescale_factor``.

        Args:
            images: Tensor, ndarray, or list of either (stacked on dim 0).
            is_video: Treat inputs as 5-D ``(batch, frames, ...)`` data.
            return_tensors: Tensor type for the returned ``BatchFeature``.
            input_data_format: Channel position of the incoming data;
                channel-first inputs are permuted to channel-last.

        Returns:
            ``BatchFeature`` with key ``"pixel_values"``.

        Raises:
            ValueError: If ``images`` cannot be converted to a tensor.
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images).float()
        if isinstance(images, list):
            # Bug fix: the original torch.stack failed on lists of
            # ndarrays; convert each element first.
            images = torch.stack(
                [
                    item if isinstance(item, torch.Tensor)
                    else torch.from_numpy(np.asarray(item)).float()
                    for item in images
                ],
                dim=0,
            )
        if not isinstance(images, torch.Tensor):
            raise ValueError("images must be a torch.Tensor")

        if is_video:
            if images.dim() == 4:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.FIRST:
                images = images.permute(0, 1, 3, 4, 2)
        else:
            if images.dim() == 3:
                images = images.unsqueeze(0)
            if input_data_format == ChannelDimension.FIRST:
                images = images.permute(0, 2, 3, 1)

        if self.do_normalize:
            # Broadcasts over the trailing channel dimension.
            images = (images * torch.tensor(self.image_std)) + torch.tensor(self.image_mean)
        if self.do_rescale:
            images = torch.clamp(images, 0, 1)
            # Bug fix: invert the configured rescale_factor (default is
            # equivalent to the original hard-coded * 255).
            images = (images / self.rescale_factor).type(torch.uint8)

        if return_tensors == TensorType.NUMPY:
            images = images.numpy()

        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)