File size: 1,513 Bytes
31c8080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import List, Optional
import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import is_numpy_array

from mivolo.data.misc import prepare_classification_images


class MiVOLOImageProcessor(BaseImageProcessor):
    """Image processor that converts a list of numpy images into model inputs.

    The heavy lifting is delegated to ``prepare_classification_images``; this
    class stores the preprocessing configuration and wraps the result in a
    ``BatchFeature`` under the ``"pixel_values"`` key.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        input_size: int = 384,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        **kwargs
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            input_size: Target size forwarded to ``prepare_classification_images``.
            mean: Per-channel mean forwarded to ``prepare_classification_images``.
                Defaults to ``[0.485, 0.456, 0.406]`` when ``None``.
            std: Per-channel standard deviation forwarded to
                ``prepare_classification_images``. Defaults to
                ``[0.229, 0.224, 0.225]`` when ``None``.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

        # Build the default lists per instance: a list literal used as a
        # default argument is created once and would be shared (and mutable)
        # across every instance constructed without explicit values.
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std
        self.input_size = input_size

    def preprocess(
        self,
        images: List[Optional[np.ndarray]],
    ):
        """Validate ``images`` and convert them into a ``BatchFeature``.

        Args:
            images: A list (or tuple) whose entries are numpy arrays or
                ``None``. How ``None`` entries are treated is decided by
                ``prepare_classification_images``.

        Returns:
            A ``BatchFeature`` built with ``tensor_type="pt"`` whose
            ``"pixel_values"`` entry is the output of
            ``prepare_classification_images``.

        Raises:
            ValueError: If ``images`` is not a list/tuple or contains an entry
                that is neither ``None`` nor a numpy array.
        """
        # Transformations expect numpy arrays or None.
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type List[numpy.ndarray]."
            )

        # NOTE(review): presumably returns an already batched tensor - confirm
        # against mivolo.data.misc.
        pixel_values = prepare_classification_images(
            images, self.input_size, self.mean, self.std
        )

        data = {"pixel_values": pixel_values}
        return BatchFeature(data=data, tensor_type="pt")


def valid_images(imgs):
    """Return whether ``imgs`` is an acceptable batch of images.

    A batch is acceptable when it is a list or tuple and every entry is either
    ``None`` or a numpy array (as judged by ``is_numpy_array``). An empty
    list/tuple is acceptable; any other container type is not.
    """
    # Anything that is not a list/tuple is rejected outright.
    if not isinstance(imgs, (list, tuple)):
        return False

    # None entries are allowed as placeholders; everything else must be an array.
    return all(entry is None or is_numpy_array(entry) for entry in imgs)