| import logging |
| import os |
| from functools import partial |
| from multiprocessing.pool import ThreadPool |
| from typing import Dict, List, Optional, Tuple |
|
|
| import cv2 |
| import numpy as np |
| from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file |
| from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts |
| from timm.data.readers.reader import Reader |
| from tqdm import tqdm |
|
|
| CROP_ROUND_TOL = 0.3 |
| MIN_PERSON_SIZE = 100 |
| MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4 |
|
|
| _logger = logging.getLogger("ReaderAgeGender") |
|
|
|
|
class ReaderAgeGender(Reader):
    """
    Reader for almost original imdb-wiki cleaned dataset.
    Two changes:
    1. Your annotation must be in ./annotation subdir of dataset root
    2. Images must be in images subdir

    Each dataset item is ((face_crop, person_crop), [age, gender]); a crop
    that is unavailable or disabled is replaced by ``self.empty_crop``.
    """

    def __init__(
        self,
        images_path,
        annotations_path,
        split="validation",
        target_size=224,
        min_size=5,
        seed=1234,
        with_persons=False,
        min_person_size=MIN_PERSON_SIZE,
        disable_faces=False,
        only_age=False,
        min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
        crop_round_tol=CROP_ROUND_TOL,
    ):
        """
        Args:
            images_path: root directory with images.
            annotations_path: directory with csv annotation files.
            split: substring used to select csv files belonging to this split.
            target_size: output side (pixels) of each square crop.
            min_size: minimal allowed width/height of a face bbox.
            seed: random seed (stored; not used directly in this class).
            with_persons: if True, person crops are produced alongside faces.
            min_person_size: minimal allowed width/height of a person bbox.
            disable_faces: if True (together with with_persons), face crops
                are replaced by the empty crop.
            only_age: if True, only age ground truth is required per sample.
            min_person_aftercut_ratio: minimal surviving area ratio of a
                person crop after associated objects are cut out.
            crop_round_tol: tolerance passed to cropout_black_parts.
        """
        super().__init__()

        self.with_persons = with_persons
        self.disable_faces = disable_faces
        self.only_age = only_age

        # color used both for cut-out regions and for the empty placeholder crop
        self.crop_out_color = (0, 0, 0)

        self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
        self.empty_crop = self.empty_crop.astype(np.uint8)

        self.min_person_size = min_person_size
        self.min_person_aftercut_ratio = min_person_aftercut_ratio
        self.crop_round_tol = crop_round_tol

        self.split = split
        self.min_size = min_size
        self.seed = seed
        self.target_size = target_size

        # img_path -> list of per-sample annotations
        self._ann: Dict[str, List[PictureInfo]] = {}
        # img_path -> {sample index -> bboxes of other objects inside its person bbox}
        self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
        # flat dataset index: (img_path, index into self._ann[img_path])
        self._faces_list: List[Tuple[str, int]] = []

        self._read_annotations(images_path, annotations_path)
        _logger.info(f"Dataset length: {len(self._faces_list)} crops")

    def __getitem__(self, index):
        return self._read_img_and_label(index)

    def __len__(self):
        return len(self._faces_list)

    def _filename(self, index, basename=False, absolute=False):
        # NOTE(review): `absolute` is accepted for Reader API compatibility but
        # ignored — the stored path is returned as-is unless basename is set.
        img_p = self._faces_list[index][0]
        return os.path.basename(img_p) if basename else img_p

    def _read_annotations(self, images_path, csvs_path):
        """Load csv annotations for this split, verify them and build the flat index."""
        self._ann = {}
        self._faces_list = []
        self._associated_objects = {}

        csvs = get_all_files(csvs_path, [".csv"])
        # keep only the csv files that belong to the requested split
        csvs = [c for c in csvs if self.split in os.path.basename(c)]

        for csv in csvs:
            db, ann_type = read_csv_annotation_file(csv, images_path)
            if self.with_persons and ann_type != AnnotType.PERSONS:
                raise ValueError(
                    f"Annotation type in file {csv} contains no persons, "
                    f"but annotations with persons are requested."
                )
            self._ann.update(db)

        if len(self._ann) == 0:
            raise ValueError("Annotations are empty!")

        self._ann, self._associated_objects = self.prepare_annotations()
        images_list = list(self._ann.keys())

        for img_path in images_list:
            for index, image_sample_info in enumerate(self._ann[img_path]):
                assert image_sample_info.has_gt(
                    self.only_age
                ), "Annotations must be checked with self.prepare_annotations() func"
                self._faces_list.append((img_path, index))

    def _read_img_and_label(self, index):
        """
        Read one sample: returns ((face_crop, person_crop), [age, gender]).

        Raises:
            TypeError: if index is not an int.
            ValueError: if the annotation yields no usable crop (should not
                happen after prepare_annotations()).
        """
        if not isinstance(index, int):
            raise TypeError("ReaderAgeGender expected index to be integer")

        img_p, face_index = self._faces_list[index]
        ann: PictureInfo = self._ann[img_p][face_index]
        img = cv2.imread(img_p)

        face_empty = True
        if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
            face_crop, face_empty = self._get_crop(ann.bbox, img)

        if not self.with_persons and face_empty:
            # face-only mode must always produce a face crop
            raise ValueError("Annotations must be checked with self.prepare_annotations() func")

        if face_empty:
            face_crop = self.empty_crop

        person_empty = True
        if self.with_persons or self.disable_faces:
            if ann.has_person_bbox:
                # cut out other faces/persons intersecting this person bbox
                objects = self._associated_objects[img_p][face_index]
                person_crop, person_empty = self._get_crop(
                    ann.person_bbox,
                    img,
                    crop_out_color=self.crop_out_color,
                    asced_objects=objects,
                )

        if face_empty and person_empty:
            raise ValueError("Annotations must be checked with self.prepare_annotations() func")

        if person_empty:
            person_crop = self.empty_crop

        return (face_crop, person_crop), [ann.age, ann.gender]

    def _get_crop(
        self,
        bbox,
        img,
        asced_objects=None,
        crop_out_color=(0, 0, 0),
    ) -> Tuple[np.ndarray, bool]:
        """
        Cut bbox out of img, optionally blanking associated objects, and
        letterbox it to (target_size, target_size).

        Returns:
            (crop, empty): empty is True if the crop was rejected after
            cutting out the associated objects.
        """
        empty_bbox = False

        xmin, ymin, xmax, ymax = bbox
        assert not (
            ymax - ymin < self.min_size or xmax - xmin < self.min_size
        ), "Annotations must be checked with self.prepare_annotations() func"

        crop = img[ymin:ymax, xmin:xmax]

        if asced_objects:
            # paint over other faces/persons inside this person crop
            crop, empty_bbox = _cropout_asced_objs(
                asced_objects,
                bbox,
                crop.copy(),
                crop_out_color=crop_out_color,
                min_person_size=self.min_person_size,
                crop_round_tol=self.crop_round_tol,
                min_person_aftercut_ratio=self.min_person_aftercut_ratio,
            )
            if empty_bbox:
                crop = self.empty_crop

        crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
        return crop, empty_bbox

    def prepare_annotations(self):
        """
        Verify every annotated image in a thread pool and keep only usable
        samples.

        Returns:
            (good_anns, all_associated_objects): the filtered annotations
            and, per image, the bboxes of other objects intersecting each
            sample's person bbox.
        """
        good_anns: Dict[str, List[PictureInfo]] = {}
        all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}

        if not self.with_persons:
            # person bboxes are unused in face-only mode
            for img_path, bboxes in self._ann.items():
                for sample in bboxes:
                    sample.clear_person_bbox()

        verify_images_func = partial(
            verify_images,
            min_size=self.min_size,
            min_person_size=self.min_person_size,
            with_persons=self.with_persons,
            disable_faces=self.disable_faces,
            crop_round_tol=self.crop_round_tol,
            min_person_aftercut_ratio=self.min_person_aftercut_ratio,
            only_age=self.only_age,
        )
        # os.cpu_count() can return None on some platforms; fall back to 1 thread
        num_threads = min(8, os.cpu_count() or 1)

        all_msgs = []
        broken = 0
        skipped = 0
        all_skipped_crops = 0
        desc = "Check annotations..."
        with ThreadPool(num_threads) as pool:
            pbar = tqdm(
                pool.imap_unordered(verify_images_func, list(self._ann.items())),
                desc=desc,
                total=len(self._ann),
            )

            for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
                broken += 1 if is_corrupted else 0
                all_msgs.extend(msgs)
                all_skipped_crops += skipped_crops
                skipped += 1 if is_empty_annotations else 0
                if img_info is not None:
                    img_path, img_samples = img_info
                    good_anns[img_path] = img_samples
                    all_associated_objects.update({img_path: associated_objects})

                pbar.desc = (
                    f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
                    f"{broken} images corrupted"
                )

            pbar.close()

        for msg in all_msgs:
            print(msg)
        print(f"\nLeft images: {len(good_anns)}")

        return good_anns, all_associated_objects
|
|
|
|
def verify_images(
    img_info,
    min_size: int,
    min_person_size: int,
    with_persons: bool,
    disable_faces: bool,
    crop_round_tol: float,
    min_person_aftercut_ratio: float,
    only_age: bool,
):
    """
    Validate one image and its annotation samples.

    Clips face/person bboxes to the image, rejects too-small crops, finds
    intersecting objects per person bbox and filters out unusable samples.

    Args:
        img_info: (img_path, list of PictureInfo samples).

    Returns:
        (img_info_or_None, associated_objects, msgs, is_corrupted,
         is_empty_annotations, skipped_crops)
    """

    # face crops can only be disabled when person crops are requested instead
    disable_faces = disable_faces and with_persons
    kwargs = dict(
        min_person_size=min_person_size,
        disable_faces=disable_faces,
        with_persons=with_persons,
        crop_round_tol=crop_round_tol,
        min_person_aftercut_ratio=min_person_aftercut_ratio,
        only_age=only_age,
    )

    def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
        # clip the bbox to the image extent; reject crops smaller than min_size
        ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
        crop_h, crop_w = ymax - ymin, xmax - xmin
        if crop_h < min_size or crop_w < min_size:
            return False, [-1, -1, -1, -1]
        bbox = [xmin, ymin, xmax, ymax]
        return True, bbox

    msgs = []
    skipped_crops = 0
    is_corrupted = False
    is_empty_annotations = False

    img_path: str = img_info[0]
    img_samples: List[PictureInfo] = img_info[1]
    try:
        # cv2.imread returns None for unreadable files; .shape then raises
        im_cv = cv2.imread(img_path)
        im_h, im_w = im_cv.shape[:2]
    except Exception:
        msgs.append(f"Can not load image {img_path}")
        is_corrupted = True
        return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops

    out_samples: List[PictureInfo] = []
    for sample in img_samples:
        # clip and size-check the face bbox (invalidated in place on failure)
        if sample.has_face_bbox:
            is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
            if not is_correct and sample.has_gt(only_age):
                msgs.append("Small face. Passing..")
                skipped_crops += 1

        # clip and size-check the person bbox
        if sample.has_person_bbox:
            is_correct, sample.person_bbox = bbox_correct(
                sample.person_bbox, max(min_person_size, min_size), im_h, im_w
            )
            if not is_correct and sample.has_gt(only_age):
                msgs.append(f"Small person {img_path}. Passing..")
                skipped_crops += 1

        if sample.has_face_bbox or sample.has_person_bbox:
            out_samples.append(sample)
        elif sample.has_gt(only_age):
            msgs.append("Sample has no face and no body. Passing..")
            skipped_crops += 1

    # put samples with ground truth first
    out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)

    # for each sample: other faces/persons intersecting its person bbox
    associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)

    out_samples, associated_objects, skipped_crops = filter_bad_samples(
        out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
    )

    out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
    if len(out_samples) == 0:
        out_img_info = None
        is_empty_annotations = True

    return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
|
|
|
|
def filter_bad_samples(
    out_samples: List[PictureInfo],
    associated_objects: dict,
    im_cv: np.ndarray,
    msgs: List[str],
    skipped_crops: int,
    **kwargs,
):
    """
    Drop samples that cannot produce a usable crop.

    Filters in order: samples without required ground truth; samples without
    a person bbox when faces are disabled; samples whose person crop does not
    survive the cut-out of associated objects (and that have no face bbox).

    Returns:
        (out_samples, associated_objects, skipped_crops)
    """
    with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
        kwargs["with_persons"],
        kwargs["disable_faces"],
        kwargs["min_person_size"],
        kwargs["crop_round_tol"],
        kwargs["min_person_aftercut_ratio"],
        kwargs["only_age"],
    )

    # keep only samples that have the required ground truth
    inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
    out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    if disable_faces:
        # face crops are disabled: drop face bboxes and require a person bbox
        for sample in out_samples:
            sample.clear_face_bbox()

        inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
        out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    if with_persons or disable_faces:
        # person crops will be used: check that each person crop survives the
        # cut-out of associated objects; drop the sample otherwise, unless it
        # still has a usable face bbox
        inds = []
        for ind, sample in enumerate(out_samples):
            person_empty = True
            if sample.has_person_bbox:
                xmin, ymin, xmax, ymax = sample.person_bbox
                crop = im_cv[ymin:ymax, xmin:xmax]
                # throwaway cut-out purely to test whether the crop survives
                _, person_empty = _cropout_asced_objs(
                    associated_objects[ind],
                    sample.person_bbox,
                    crop.copy(),
                    min_person_size=min_person_size,
                    crop_round_tol=crop_round_tol,
                    min_person_aftercut_ratio=min_person_aftercut_ratio,
                )

            if person_empty and not sample.has_face_bbox:
                msgs.append("Small person after preprocessing. Passing..")
                skipped_crops += 1
            else:
                inds.append(ind)
        out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)

    assert len(associated_objects) == len(out_samples)
    return out_samples, associated_objects, skipped_crops
|
|
|
|
| def _filter_by_ind(out_samples, associated_objects, inds): |
| _associated_objects = {} |
| _out_samples = [] |
| for ind, sample in enumerate(out_samples): |
| if ind in inds: |
| _associated_objects[len(_out_samples)] = associated_objects[ind] |
| _out_samples.append(sample) |
|
|
| return _out_samples, _associated_objects |
|
|
|
|
def find_associated_objects(
    image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
) -> Dict[int, List[List[int]]]:
    """
    For each person (which has gt age and gt gender) find other faces and persons bboxes, intersected with it
    """
    associated_objects: Dict[int, List[List[int]]] = {}

    for i, info in enumerate(image_samples):
        # the sample's own face bbox always counts as an associated object
        associated_objects[i] = [info.bbox] if info.has_face_bbox else []

        # only samples with a person bbox and required ground truth get matched
        if not info.has_person_bbox or not info.has_gt(only_age):
            continue

        person_box = info.person_bbox
        for j, other in enumerate(image_samples):
            if j == i:
                continue
            if other.has_face_bbox and _get_iou(other.bbox, person_box) >= iou_thresh:
                associated_objects[i].append(other.bbox)
            if other.has_person_bbox and _get_iou(other.person_bbox, person_box) >= iou_thresh:
                associated_objects[i].append(other.person_bbox)

    return associated_objects
|
|
|
|
def _cropout_asced_objs(
    asced_objects,
    person_bbox,
    crop,
    min_person_size,
    crop_round_tol,
    min_person_aftercut_ratio,
    crop_out_color=(0, 0, 0),
):
    """
    Paint over every associated object inside ``crop`` with
    ``crop_out_color``, trim the resulting black borders, and decide whether
    the remaining person crop is still usable.

    Args:
        asced_objects: bboxes (image coordinates) of objects intersecting
            person_bbox; callers build them with an IOU > 0 threshold, so
            each is presumed to intersect the bbox.
        person_bbox: (xmin, ymin, xmax, ymax) of the person in the image.
        crop: the already-extracted person crop (mutated in place).

    Returns:
        (crop_or_None, empty): empty is True when the surviving crop is
        smaller than min_person_size or lost more than the allowed area.
    """
    empty = False
    xmin, ymin, xmax, ymax = person_bbox
    crop_h, crop_w = ymax - ymin, xmax - xmin

    for a_obj in asced_objects:
        aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj

        # translate to crop-local coordinates and clamp to the crop area on
        # both sides: clamping the max coords to >= 0 prevents negative slice
        # indices (which would blank a region counted from the array's end)
        # for an object that does not actually intersect the bbox
        aobj_ymin = int(min(max(aobj_ymin - ymin, 0), crop_h))
        aobj_xmin = int(min(max(aobj_xmin - xmin, 0), crop_w))
        aobj_ymax = int(max(min(aobj_ymax - ymin, crop_h), 0))
        aobj_xmax = int(max(min(aobj_xmax - xmin, crop_w), 0))

        crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color

    # trim black borders introduced by the cut-outs
    crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
    if (
        crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
    ) or cropped_ratio < min_person_aftercut_ratio:
        crop = None
        empty = True

    return crop, empty
|
|
|
|
| def _correct_bbox(bbox, h, w): |
| xmin, ymin, xmax, ymax = bbox |
| ymin = min(max(ymin, 0), h) |
| ymax = min(max(ymax, 0), h) |
| xmin = min(max(xmin, 0), w) |
| xmax = min(max(xmax, 0), w) |
| return ymin, ymax, xmin, xmax |
|
|
|
|
def _get_iou(bbox1, bbox2):
    """Compute IOU of two (xmin, ymin, xmax, ymax) bboxes by reordering them
    into the (ymin, xmin, ymax, xmax) layout the IOU helper expects."""
    x1a, y1a, x2a, y2a = bbox1
    x1b, y1b, x2b, y2b = bbox2
    return IOU([y1a, x1a, y2a, x2a], [y1b, x1b, y2b, x2b])
|
|