|
|
import os |
|
|
|
|
|
import numpy as np |
|
|
import cv2 as cv |
|
|
|
|
|
from .base_dataloader import _BaseImageLoader |
|
|
from ..factory import DATALOADERS |
|
|
|
|
|
@DATALOADERS.register
class ClassificationImageLoader(_BaseImageLoader):
    """Image loader that prepares images for classification models.

    Each file listed in ``self._files`` is read with OpenCV from
    ``self._path``, optionally converted from BGR to RGB, and then either
    yielded unchanged (when ``[0, 0]`` is among ``self._sizes``) or
    resized to every entry of ``self._sizes`` with an optional square
    center crop applied after each resize.

    ``self._files``, ``self._path`` and ``self._sizes`` are not set in this
    class; presumably ``_BaseImageLoader.__init__`` provides them -- verify
    against the base class.

    Keyword Args:
        toRGB: When truthy, convert every loaded image from BGR to RGB.
            Defaults to ``False``.
        centerCrop: Side length in pixels of the square center crop applied
            after resizing. ``None`` (default) or ``0`` disables cropping.
        **kwargs: Forwarded untouched to ``_BaseImageLoader.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The base class has already received the full kwargs, so there is
        # nothing to gain from removing the keys here; ``get`` states the
        # intent (read-only lookup) and behaves the same as the former pop.
        self._to_rgb = kwargs.get('toRGB', False)
        self._center_crop = kwargs.get('centerCrop', None)

    def _toRGB(self, image):
        """Return ``image`` with its channels converted from BGR to RGB."""
        return cv.cvtColor(image, cv.COLOR_BGR2RGB)

    def _centerCrop(self, image):
        """Return the central ``centerCrop`` x ``centerCrop`` region of ``image``.

        Args:
            image: Array whose first two axes are (height, width).

        Returns:
            A view of ``image`` exactly ``centerCrop`` pixels high and wide.

        Raises:
            ValueError: If the crop size exceeds the image height or width.
        """
        crop = int(self._center_crop)
        h, w = image.shape[:2]
        if crop > h or crop > w:
            raise ValueError(
                'centerCrop ({}) exceeds image size ({}x{})'.format(crop, w, h)
            )

        top = (h - crop) // 2
        left = (w - crop) // 2
        # Axis 0 is rows (height) and axis 1 is columns (width). Using
        # ``start + crop`` as the end keeps the output exactly ``crop`` wide
        # even when the margin is an odd number of pixels.
        return image[top:top + crop, left:left + crop]

    def __iter__(self):
        """Yield ``(filename, image)`` pairs for every file and target size.

        When ``[0, 0]`` is listed in ``self._sizes`` the image is yielded
        once at its original size; otherwise one pair is yielded per entry
        of ``self._sizes``.
        """
        for filename in self._files:
            # NOTE: cv.imread returns None for unreadable files; that value
            # is passed on as-is, exactly as before.
            image = cv.imread(os.path.join(self._path, filename))

            if self._to_rgb:
                image = self._toRGB(image)

            if [0, 0] in self._sizes:
                yield filename, image
                continue

            for size in self._sizes:
                # Always start from the freshly loaded image so one target
                # size never depends on the result of a previous one.
                resized = cv.resize(image, size)
                if self._center_crop:
                    resized = self._centerCrop(resized)
                yield filename, resized