Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import imgaug | |
| import imgaug.augmenters as iaa | |
| import mmcv | |
| import numpy as np | |
| from mmdet.core.mask import PolygonMasks | |
| from mmdet.datasets.builder import PIPELINES | |
class AugmenterBuilder:
    """Turn a nested list/dict specification into imgaug augmenter objects."""

    def __init__(self):
        """No state is kept; the builder is a plain namespace for ``build``."""

    def build(self, args, root=True):
        """Recursively convert ``args`` into imgaug objects.

        Args:
            args: The specification to convert. ``None`` and ``int`` /
                ``float`` / ``str`` values are returned unchanged. A list is
                either the top-level sequence of augmenters (``root=True``)
                or ``[name, *positional_args]`` for one augmenter looked up
                by name in ``imgaug.augmenters``. A dict with a ``'cls'``
                key names the augmenter and supplies keyword arguments; a
                dict without it is rebuilt value by value.
            root (bool): Whether ``args`` is the outermost specification.

        Returns:
            The built object; its type depends on ``args`` as described
            above.

        Raises:
            RuntimeError: If ``args`` is of an unsupported type.
        """
        # Scalars (and None) pass straight through.
        if args is None or isinstance(args, (int, float, str)):
            return args

        if isinstance(args, list):
            if root:
                children = [self.build(item, root=False) for item in args]
                return iaa.Sequential(children)
            # Non-root list: first entry is the augmenter name, the rest are
            # positional arguments (lists are passed to imgaug as tuples).
            positional = [self.to_tuple_if_list(item) for item in args[1:]]
            factory = getattr(iaa, args[0])
            return factory(*positional)

        if isinstance(args, dict):
            if 'cls' not in args:
                return {
                    name: self.build(item, root=False)
                    for name, item in args.items()
                }
            keyword_args = {
                name: self.to_tuple_if_list(item)
                for name, item in args.items() if name != 'cls'
            }
            return getattr(iaa, args['cls'])(**keyword_args)

        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def to_tuple_if_list(self, obj):
        """Return ``tuple(obj)`` when ``obj`` is a list, else ``obj`` itself."""
        return tuple(obj) if isinstance(obj, list) else obj
class ImgAug:
    """Apply imgaug (https://github.com/aleju/imgaug) augmenters to an image
    and carry its polygon masks and boxes through the same transform.

    Args:
        args (list[list | dict], optional): The augmentation list, in the
            format understood by ``AugmenterBuilder.build``; see the imgaug
            documentation for the available augmenters. For example,
            ``args=[['Fliplr', 0.5], dict(cls='Affine', rotate=[-10, 10]),
            ['Resize', [0.5, 3.0]]]`` flips images horizontally with
            probability 0.5, then rotates by a random angle in [-10, 10],
            then resizes each side by an independent scale in [0.5, 3.0].
    """

    def __init__(self, args=None):
        self.augmenter_args = args
        self.augmenter = AugmenterBuilder().build(self.augmenter_args)

    def __call__(self, results):
        """Augment ``results['img']`` and its annotations in place.

        Nothing is modified when no augmenter is configured.

        Returns:
            dict: The same ``results`` object.
        """
        # img is bgr
        image = results['img']
        shape = image.shape
        if not self.augmenter:
            return results

        # One deterministic copy is used for the image and the annotations.
        aug = self.augmenter.to_deterministic()
        augmented = aug.augment_image(image)
        results['img'] = augmented
        results['img_shape'] = augmented.shape
        # The flip state cannot be recovered from an arbitrary augmenter.
        results['flip'] = 'unknown'
        results['flip_direction'] = 'unknown'
        self.may_augment_annotation(aug, shape, results['img_shape'],
                                    results)
        return results

    def may_augment_annotation(self, aug, shape, target_shape, results):
        """Transform every mask and bbox field of ``results`` with ``aug``.

        Args:
            aug: Deterministic augmenter, or ``None`` to do nothing.
            shape (tuple): Shape of the image before augmentation.
            target_shape (tuple): Shape of the image after augmentation.
            results (dict): Holds ``'mask_fields'`` and ``'bbox_fields'``
                plus the fields they name.

        Returns:
            dict: The same ``results`` object.
        """
        if aug is None:
            return results

        # Polygon masks: an empty result leaves the field as it was.
        for field in results['mask_fields']:
            new_masks = self.may_augment_poly(aug, shape, results[field])
            if len(new_masks) > 0:
                results[field] = PolygonMasks(new_masks, *target_shape[:2])

        # Boxes go through the same keypoint path as polygons.
        for field in results['bbox_fields']:
            new_boxes = self.may_augment_poly(
                aug, shape, results[field], mask_flag=False)
            if len(new_boxes) > 0:
                results[field] = np.stack(new_boxes)
            else:
                results[field] = np.zeros(0)

        return results

    def may_augment_poly(self, aug, img_shape, polys, mask_flag=True):
        """Run the points of ``polys`` through ``aug`` as imgaug keypoints.

        Args:
            aug: Deterministic augmenter providing ``augment_keypoints``.
            img_shape (tuple): Shape of the image the points refer to.
            polys: Iterable of polygons. With ``mask_flag=True`` each item
                is indexed with ``[0]`` to get its coordinate array;
                otherwise each item is the coordinate array itself.
            mask_flag (bool): Selects the input/output layout above.

        Returns:
            list: One flattened coordinate array per input polygon, each
            wrapped in a one-element list when ``mask_flag`` is true.
        """
        # Flatten all polygons into one keypoint list, remembering how many
        # points each one contributed so they can be regrouped afterwards.
        flat_points, counts = [], []
        for poly in polys:
            coords = (poly[0] if mask_flag else poly).reshape(-1, 2)
            for point in coords:
                flat_points.append(imgaug.Keypoint(point[0], point[1]))
            counts.append(coords.shape[0])

        on_image = imgaug.KeypointsOnImage(
            keypoints=flat_points, shape=img_shape)
        moved = aug.augment_keypoints([on_image])[0].keypoints

        regrouped = []
        offset = 0
        for count in counts:
            chunk = moved[offset:offset + count]
            offset += count
            flat = np.array([[kp.x, kp.y] for kp in chunk]).flatten()
            regrouped.append([flat] if mask_flag else flat)
        return regrouped

    def __repr__(self):
        return self.__class__.__name__
class EastRandomCrop:
    """Randomly crop an image without cutting through text, then resize the
    crop and zero-pad it to a fixed size.

    Args:
        target_size (tuple[int, int]): Output size as ``(width, height)``.
        max_tries (int): Maximum number of random crops to try before
            falling back to the whole image.
        min_crop_side_ratio (float): Minimum crop side length, as a fraction
            of the corresponding image side.
    """

    def __init__(self,
                 target_size=(640, 640),
                 max_tries=10,
                 min_crop_side_ratio=0.1):
        self.target_size = target_size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio

    def __call__(self, results):
        """Crop, rescale and pad the image, boxes and masks in ``results``.

        Reads ``'img'``, ``'gt_masks'``, ``'bbox_fields'`` and
        ``'mask_fields'``; rewrites the image, every listed box/mask field,
        ``'img_shape'`` and (for ``'gt_masks'``) ``'gt_labels'``.

        Returns:
            dict: The same ``results`` object.
        """
        img = results['img']
        crop_x, crop_y, crop_w, crop_h = self.crop_area(
            img, results['gt_masks'])

        target_w, target_h = self.target_size[0], self.target_size[1]
        # Keep the aspect ratio: use the tighter of the two scale factors.
        scale = min(target_w / crop_w, target_h / crop_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)

        padded_img = np.zeros((target_h, target_w, img.shape[2]), img.dtype)
        padded_img[:h, :w] = mmcv.imresize(
            img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))

        offset = (crop_x, crop_y)

        # Boxes: each is treated as two (x, y) points.
        for key in results['bbox_fields']:
            lines = []
            for box in results[key]:
                poly = (box.reshape(2, 2) - offset) * scale
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    lines.append(poly.flatten())
            results[key] = np.array(lines)

        # Masks: drop polygons that fall entirely outside the crop.
        for key in results['mask_fields']:
            polys = []
            polys_label = []
            for poly in results[key]:
                poly = (np.array(poly).reshape(-1, 2) - offset) * scale
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    polys.append([poly])
                    polys_label.append(0)
            # PolygonMasks takes (height, width); target_size is
            # (width, height), so use the padded image's own shape.
            results[key] = PolygonMasks(polys, *padded_img.shape[:2])
            if key == 'gt_masks':
                results['gt_labels'] = polys_label

        results['img'] = padded_img
        results['img_shape'] = padded_img.shape
        return results

    def is_poly_in_rect(self, poly, x, y, w, h):
        """Return True if every point of ``poly`` (N x 2) lies within the
        rectangle with corner ``(x, y)`` and size ``(w, h)``."""
        poly = np.array(poly)
        xs, ys = poly[:, 0], poly[:, 1]
        if xs.min() < x or xs.max() > x + w:
            return False
        if ys.min() < y or ys.max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        """Return True if the bounding extent of ``poly`` does not overlap
        the rectangle with corner ``(x, y)`` and size ``(w, h)``."""
        poly = np.array(poly).reshape(-1, 2)
        xs, ys = poly[:, 0], poly[:, 1]
        if xs.max() < x or xs.min() > x + w:
            return True
        if ys.max() < y or ys.min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        """Split a sorted index array into runs of consecutive integers.

        Args:
            axis (np.ndarray): 1-D array of increasing indices.

        Returns:
            list[np.ndarray]: One slice of ``axis`` per consecutive run, in
            order. Empty when ``axis`` is empty.
        """
        regions = []
        start = 0
        for i in range(1, axis.shape[0]):
            if axis[i] != axis[i - 1] + 1:
                regions.append(axis[start:i])
                start = i
        # The final run has no following gap to close it inside the loop, so
        # add it here; otherwise the last free region is never sampled.
        if axis.shape[0] > 0:
            regions.append(axis[start:])
        return regions

    def random_select(self, axis, max_size):
        """Pick two random entries of ``axis`` and return them as
        ``(min, max)``, each clipped to ``[0, max_size - 1]``."""
        picked = np.random.choice(axis, size=2)
        low = np.clip(np.min(picked), 0, max_size - 1)
        high = np.clip(np.max(picked), 0, max_size - 1)
        return low, high

    def region_wise_random_select(self, regions):
        """Pick two regions at random (with replacement), take one random
        index from each, and return the pair as ``(min, max)``."""
        chosen = list(np.random.choice(len(regions), 2))
        # Draw a scalar directly: converting a 1-element array to int is
        # deprecated in recent NumPy releases.
        values = [int(np.random.choice(regions[index])) for index in chosen]
        return min(values), max(values)

    def crop_area(self, img, polys):
        """Choose a crop rectangle whose borders do not pass through text.

        Args:
            img (np.ndarray): Image of shape ``(h, w, c)``.
            polys: Iterable of polygons; each is converted to an ``N x 2``
                array of ``(x, y)`` points.

        Returns:
            tuple: ``(x, y, width, height)`` of the crop. The whole image is
            returned when no text-free row or column exists, or when no
            acceptable crop is found within ``max_tries`` attempts.
        """
        h, w, _ = img.shape
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in polys:
            points = np.round(
                points, decimals=0).astype(np.int32).reshape(-1, 2)
            # Clip to the image: a negative bound would be read as a
            # from-the-end index and mark the wrong columns/rows.
            min_x = np.clip(np.min(points[:, 0]), 0, w)
            max_x = np.clip(np.max(points[:, 0]), 0, w)
            w_array[min_x:max_x] = 1
            min_y = np.clip(np.min(points[:, 1]), 0, h)
            max_y = np.clip(np.max(points[:, 1]), 0, h)
            h_array[min_y:max_y] = 1

        # ensure the cropped area not across a text
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h

        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)

        for _ in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions)
            else:
                ymin, ymax = self.random_select(h_axis, h)

            crop_w = xmax - xmin
            crop_h = ymax - ymin
            if (crop_w < self.min_crop_side_ratio * w
                    or crop_h < self.min_crop_side_ratio * h):
                # area too small
                continue

            # Accept the crop as soon as one polygon overlaps it.
            if any(not self.is_poly_outside_rect(poly, xmin, ymin, crop_w,
                                                 crop_h) for poly in polys):
                return xmin, ymin, crop_w, crop_h

        return 0, 0, w, h