# mivolo_v2 / mivolo_image_processor.py
# Source: Hugging Face Hub upload by iitolstykh ("Upload 8 files", commit a088f53).
from typing import List, Optional
import numpy as np
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import is_numpy_array
from mivolo.data.misc import prepare_classification_images
class MiVOLOImageProcessor(BaseImageProcessor):
    """Image processor that converts a batch of numpy images into MiVOLO ``pixel_values``.

    The actual resizing/normalization is delegated to
    ``mivolo.data.misc.prepare_classification_images``; this class stores the
    parameters that function receives (``input_size``, ``mean``, ``std``) and
    wraps its output in a :class:`transformers.BatchFeature`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        input_size: int = 384,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        **kwargs,
    ) -> None:
        """Store preprocessing parameters.

        Args:
            input_size: Size argument forwarded to
                ``prepare_classification_images`` (presumably the target side
                length in pixels — confirm against that helper).
            mean: Per-channel normalization mean. ``None`` selects
                ``[0.485, 0.456, 0.406]`` (the widely used ImageNet statistics).
            std: Per-channel normalization std. ``None`` selects
                ``[0.229, 0.224, 0.225]`` (the widely used ImageNet statistics).
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        # The defaults used to be list literals in the signature, i.e. one list
        # object shared by every instance; build a fresh list per instance so
        # mutating ``self.mean``/``self.std`` cannot leak into other processors.
        self.mean = mean if mean is not None else [0.485, 0.456, 0.406]
        self.std = std if std is not None else [0.229, 0.224, 0.225]
        self.input_size = input_size

    def preprocess(
        self,
        images: List[Optional[np.ndarray]],
    ) -> BatchFeature:
        """Validate ``images`` and convert them into model inputs.

        Args:
            images: A list or tuple whose items are numpy arrays or ``None``.
                ``None`` entries pass validation and are handed to
                ``prepare_classification_images`` as-is; how that helper treats
                them is not visible here — confirm.

        Returns:
            A ``BatchFeature`` holding a single ``"pixel_values"`` entry,
            converted with ``tensor_type="pt"`` (PyTorch).

        Raises:
            ValueError: If ``images`` is not a list/tuple, or any non-``None``
                item is not a numpy array.
        """
        # Transformations expect numpy arrays or None.
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type List[numpy.ndarray]."
            )
        # Named ``pixel_values`` rather than ``input`` to avoid shadowing the builtin.
        pixel_values = prepare_classification_images(
            images, self.input_size, self.mean, self.std
        )
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type="pt")
def valid_images(imgs):
    """Check that ``imgs`` is a batch this processor can accept.

    Args:
        imgs: Candidate batch of images.

    Returns:
        ``True`` when ``imgs`` is a list or tuple in which every item is either
        ``None`` or a numpy array (an empty list/tuple qualifies); ``False``
        for any other container type or as soon as one item fails the check.
    """
    # Only list/tuple batches are supported; anything else is rejected outright.
    if not isinstance(imgs, (list, tuple)):
        return False
    # ``None`` placeholders are allowed; ``or`` short-circuits so the numpy
    # check only runs on non-None items, and ``all`` stops at the first failure.
    return all(img is None or is_numpy_array(img) for img in imgs)