# Snapshot a088f53 (file size: 1,513 bytes).
from typing import List, Optional, Sequence

import numpy as np
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import is_numpy_array

from mivolo.data.misc import prepare_classification_images
class MiVOLOImageProcessor(BaseImageProcessor):
    """Hugging Face image processor for MiVOLO.

    Validates a batch of images and delegates the actual resizing and
    normalization to ``prepare_classification_images``, returning the result
    as a ``BatchFeature`` keyed by ``"pixel_values"``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        input_size: int = 384,
        mean: Sequence[float] = (0.485, 0.456, 0.406),
        std: Sequence[float] = (0.229, 0.224, 0.225),
        **kwargs
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            input_size: target size passed to ``prepare_classification_images``.
            mean: per-channel normalization mean passed to
                ``prepare_classification_images`` (the defaults appear to be
                the usual ImageNet statistics).
            std: per-channel normalization std passed to
                ``prepare_classification_images``.
            **kwargs: forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        # Copy into fresh lists. The defaults used to be list literals stored
        # by reference, so mutating one instance's mean/std silently changed
        # the defaults for every later instance. Lists (not tuples) keep the
        # attribute type the same as before.
        self.mean = list(mean)
        self.std = list(std)
        self.input_size = input_size

    def preprocess(
        self,
        images: List[Optional[np.ndarray]],
    ) -> BatchFeature:
        """Convert a batch of images into model inputs.

        Args:
            images: list or tuple of numpy arrays; ``None`` entries are
                accepted and handed to ``prepare_classification_images``
                unchanged.

        Returns:
            A ``BatchFeature`` holding a single ``"pixel_values"`` entry,
            requested as PyTorch tensors (``tensor_type="pt"``).

        Raises:
            ValueError: if ``images`` is not a list/tuple, or contains an
                entry that is neither ``None`` nor a numpy array.
        """
        # Transformations expect numpy arrays or None.
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type List[numpy.ndarray]."
            )
        # Named ``pixel_values`` (not ``input``) to avoid shadowing the builtin.
        pixel_values = prepare_classification_images(
            images, self.input_size, self.mean, self.std
        )
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type="pt")
def valid_images(imgs):
    """Report whether ``imgs`` is an acceptable batch for ``preprocess``.

    A batch is acceptable when it is a list or tuple whose entries are each
    either ``None`` or a numpy array. An empty list/tuple is acceptable.
    Any other container type, or a bare array, is rejected.

    Returns:
        ``True`` if the batch is acceptable, otherwise ``False``.
    """
    if not isinstance(imgs, (list, tuple)):
        return False
    # ``None`` placeholders are allowed; everything else must be a numpy array.
    return all(img is None or is_numpy_array(img) for img in imgs)
|