|
|
from typing import List, Optional |
|
|
import numpy as np |
|
|
|
|
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
|
|
from transformers.image_utils import is_numpy_array |
|
|
|
|
|
from mivolo.data.misc import prepare_classification_images |
|
|
|
|
|
|
|
|
class MiVOLOImageProcessor(BaseImageProcessor):
    """Image processor that prepares inputs for a MiVOLO model.

    Delegates the actual image preparation to
    ``mivolo.data.misc.prepare_classification_images`` and wraps the result
    in a ``BatchFeature`` under the ``"pixel_values"`` key.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        input_size: int = 384,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        **kwargs,
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            input_size: Forwarded to ``prepare_classification_images`` as the
                target size argument.
            mean: Per-channel mean forwarded to
                ``prepare_classification_images``. Defaults to
                ``[0.485, 0.456, 0.406]`` when ``None``.
            std: Per-channel standard deviation forwarded to
                ``prepare_classification_images``. Defaults to
                ``[0.229, 0.224, 0.225]`` when ``None``.
            **kwargs: Passed through to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

        # Build the default lists per instance: a list literal in the
        # signature would be one shared object, so mutating ``self.mean`` on
        # one processor would silently change every other processor (and the
        # default itself).
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std
        self.input_size = input_size

    def preprocess(
        self,
        images: List[Optional[np.ndarray]],
    ) -> BatchFeature:
        """Validate ``images`` and convert them into model inputs.

        Args:
            images: A list (or tuple) whose entries are numpy arrays or
                ``None``.

        Returns:
            A ``BatchFeature`` built with ``tensor_type="pt"`` whose
            ``"pixel_values"`` entry holds the output of
            ``prepare_classification_images``.

        Raises:
            ValueError: If ``images`` is not a list/tuple, or contains an
                entry that is neither ``None`` nor a numpy array.
        """
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type List[numpy.ndarray]."
            )

        # NOTE(review): handling of ``None`` entries is left entirely to
        # ``prepare_classification_images`` — confirm against that helper.
        pixel_values = prepare_classification_images(
            images, self.input_size, self.mean, self.std
        )

        data = {"pixel_values": pixel_values}
        return BatchFeature(data=data, tensor_type="pt")
|
|
|
|
|
|
|
|
def valid_images(imgs):
    """Check that ``imgs`` is a list or tuple of numpy arrays and/or ``None``.

    Args:
        imgs: Candidate container of images.

    Returns:
        ``True`` when ``imgs`` is a list or tuple in which every entry is
        either ``None`` or accepted by ``is_numpy_array`` (an empty container
        is valid); ``False`` otherwise.
    """
    # Anything that is not a list/tuple is rejected outright.
    if not isinstance(imgs, (list, tuple)):
        return False

    # ``None`` entries are allowed and skip the array check entirely.
    return all(item is None or is_numpy_array(item) for item in imgs)
|
|
|