Spaces:
Configuration error
Configuration error
| from __future__ import division | |
| import random | |
| import typing | |
| import warnings | |
| from collections import defaultdict | |
| import numpy as np | |
| from .. import random_utils | |
| from .bbox_utils import BboxParams, BboxProcessor | |
| from .keypoints_utils import KeypointParams, KeypointsProcessor | |
| from .serialization import ( | |
| SERIALIZABLE_REGISTRY, | |
| Serializable, | |
| get_shortest_class_fullname, | |
| instantiate_nonserializable, | |
| ) | |
| from .transforms_interface import BasicTransform | |
| from .utils import format_args, get_shape | |
# Public names re-exported by ``from <this module> import *``.
__all__ = [
    "BaseCompose",
    "Compose",
    "SomeOf",
    "OneOf",
    "OneOrOther",
    "BboxParams",
    "KeypointParams",
    "ReplayCompose",
    "Sequential",
]

# Number of spaces added per nesting level by ``BaseCompose.indented_repr``.
REPR_INDENT_STEP = 2

# One element of a pipeline: a leaf transform or a nested composition.
TransformType = typing.Union[BasicTransform, "BaseCompose"]
# The ordered collection of elements accepted by the composition constructors.
TransformsSeqType = typing.Sequence[TransformType]
def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType]) -> TransformsSeqType:
    """Collect the leaf transforms that must run even when a pipeline is skipped.

    Nested compositions are flattened recursively. A leaf transform is kept only
    when its ``always_apply`` attribute is truthy. Relative order is preserved.

    Args:
        transforms: a composition, or a sequence of transforms/compositions.

    Returns:
        A flat list of the leaf transforms flagged ``always_apply``.
    """
    collected: typing.List[TransformType] = []
    for item in transforms:  # type: ignore
        if isinstance(item, BaseCompose):
            # A composition is iterable through its __getitem__/__len__, so recurse into it.
            collected.extend(get_always_apply(item))
            continue
        if item.always_apply:
            collected.append(item)
    return collected
class BaseCompose(Serializable):
    """Base class for containers that hold and apply a sequence of transforms.

    Args:
        transforms: sequence of transforms and/or nested compositions. A single
            transform or composition is also accepted; it is wrapped into a list
            and a warning is emitted.
        p: probability associated with this composition; how it is used is up
            to each subclass's ``__call__``.
    """

    def __init__(self, transforms: TransformsSeqType, p: float):
        if isinstance(transforms, (BaseCompose, BasicTransform)):
            warnings.warn(
                "transforms is single transform, but a sequence is expected! Transform will be wrapped into list."
            )
            transforms = [transforms]
        self.transforms = transforms
        self.p = p

        # Set to True by ReplayCompose._restore_for_replay when a pipeline is rebuilt for replay.
        self.replay_mode = False
        self.applied_in_replay = False

    def __len__(self) -> int:
        return len(self.transforms)

    def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
        raise NotImplementedError

    def __getitem__(self, item: int) -> TransformType:  # type: ignore
        return self.transforms[item]

    def __repr__(self) -> str:
        return self.indented_repr()

    def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
        """Return a multi-line repr in which nested compositions are indented.

        Args:
            indent: number of spaces used for this composition's children.
                Nested compositions receive ``indent + REPR_INDENT_STEP``.
        """
        # Show every serialized argument except internal ``__*`` keys and the children themselves.
        args = {k: v for k, v in self._to_dict().items() if not (k.startswith("__") or k == "transforms")}
        parts = [self.__class__.__name__ + "(["]
        for t in self.transforms:
            if hasattr(t, "indented_repr"):
                t_repr = t.indented_repr(indent + REPR_INDENT_STEP)  # type: ignore
            else:
                t_repr = repr(t)
            parts.append("\n" + " " * indent + t_repr + ",")
        parts.append("\n" + " " * (indent - REPR_INDENT_STEP) + "], {args})".format(args=format_args(args)))
        return "".join(parts)

    @classmethod
    def get_class_fullname(cls) -> str:
        """Return the (shortest) registry name of this class, used as ``__class_fullname__``."""
        return get_shortest_class_fullname(cls)

    @classmethod
    def is_serializable(cls) -> bool:
        """All compositions can be serialized."""
        return True

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize this composition and, recursively, its children."""
        return {
            "__class_fullname__": self.get_class_fullname(),
            "p": self.p,
            "transforms": [t._to_dict() for t in self.transforms],  # skipcq: PYL-W0212
        }

    def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
        """Serialize with each node tagged by ``id(obj)`` and a ``params`` placeholder.

        ReplayCompose later replaces the ids with the recorded params.
        """
        return {
            "__class_fullname__": self.get_class_fullname(),
            "id": id(self),
            "params": None,
            "transforms": [t.get_dict_with_id() for t in self.transforms],
        }

    def add_targets(self, additional_targets: typing.Optional[typing.Dict[str, str]]) -> None:
        """Propagate extra target-name aliases to every child (no-op when empty or None)."""
        if additional_targets:
            for t in self.transforms:
                t.add_targets(additional_targets)

    def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
        """Propagate the deterministic/replay-recording flag to every child."""
        for t in self.transforms:
            t.set_deterministic(flag, save_key)
class Compose(BaseCompose):
    """Compose transforms and handle all transformations regarding bounding boxes

    Args:
        transforms (list): list of transformations to compose.
        bbox_params (BboxParams): Parameters for bounding boxes transforms
        keypoint_params (KeypointParams): Parameters for keypoints transforms
        additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
        p (float): probability of applying all list of transforms. Default: 1.0.
        is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you
            would like to disable this check - pass False (do it only if you are sure in your data consistency).
    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
        keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
        additional_targets: typing.Optional[typing.Dict[str, str]] = None,
        p: float = 1.0,
        is_check_shapes: bool = True,
    ):
        super(Compose, self).__init__(transforms, p)

        self.processors: typing.Dict[str, typing.Union[BboxProcessor, KeypointsProcessor]] = {}
        if bbox_params:
            if isinstance(bbox_params, dict):
                b_params = BboxParams(**bbox_params)
            elif isinstance(bbox_params, BboxParams):
                b_params = bbox_params
            else:
                raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`")
            self.processors["bboxes"] = BboxProcessor(b_params, additional_targets)

        if keypoint_params:
            if isinstance(keypoint_params, dict):
                k_params = KeypointParams(**keypoint_params)
            elif isinstance(keypoint_params, KeypointParams):
                k_params = keypoint_params
            else:
                raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`")
            self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets)

        if additional_targets is None:
            additional_targets = {}
        self.additional_targets = additional_targets

        for proc in self.processors.values():
            proc.ensure_transforms_valid(self.transforms)

        self.add_targets(additional_targets)

        self.is_check_args = True
        # Nested Compose instances skip argument checking; this (outer) Compose already does it.
        self._disable_check_args_for_transforms(self.transforms)

        self.is_check_shapes = is_check_shapes

    @staticmethod
    def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None:
        """Recursively turn off ``is_check_args`` on every Compose nested inside ``transforms``."""
        for transform in transforms:
            if isinstance(transform, BaseCompose):
                Compose._disable_check_args_for_transforms(transform.transforms)
            if isinstance(transform, Compose):
                transform._disable_check_args()

    def _disable_check_args(self) -> None:
        self.is_check_args = False

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        """Run the pipeline on named targets (e.g. ``image=..., mask=..., bboxes=...``).

        With probability ``p`` (or always when ``force_apply`` is truthy) every
        transform runs; otherwise only the transforms flagged ``always_apply`` run.

        Raises:
            KeyError: if data is passed positionally.
        """
        if args:
            raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
        if self.is_check_args:
            self._check_args(**data)
        assert isinstance(force_apply, (bool, int)), "force_apply must have bool or int type"
        need_to_run = force_apply or random.random() < self.p
        for proc in self.processors.values():
            proc.ensure_data_valid(data)
        transforms = self.transforms if need_to_run else get_always_apply(self.transforms)

        check_each_transform = any(
            getattr(proc.params, "check_each_transform", False) for proc in self.processors.values()
        )

        for proc in self.processors.values():
            proc.preprocess(data)

        for t in transforms:
            data = t(**data)
            if check_each_transform:
                data = self._check_data_post_transform(data)
        data = Compose._make_targets_contiguous(data)  # ensure output targets are contiguous

        for proc in self.processors.values():
            proc.postprocess(data)

        return data

    def _check_data_post_transform(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Filter bbox/keypoint targets against the current image size for processors that request it."""
        rows, cols = get_shape(data["image"])

        for proc in self.processors.values():
            if not getattr(proc.params, "check_each_transform", False):
                continue

            for data_name in proc.data_fields:
                data[data_name] = proc.filter(data[data_name], rows, cols)
        return data

    def _processors_to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialized processor params, extra targets and shape-check flag shared by both dict forms."""
        bbox_processor = self.processors.get("bboxes")
        keypoints_processor = self.processors.get("keypoints")
        return {
            "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None,  # skipcq: PYL-W0212
            "keypoint_params": keypoints_processor.params._to_dict()  # skipcq: PYL-W0212
            if keypoints_processor
            else None,
            "additional_targets": self.additional_targets,
            "is_check_shapes": self.is_check_shapes,
        }

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        dictionary = super(Compose, self)._to_dict()
        dictionary.update(self._processors_to_dict())
        return dictionary

    def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
        dictionary = super().get_dict_with_id()
        dictionary.update(self._processors_to_dict())
        dictionary["params"] = None
        return dictionary

    def _check_args(self, **kwargs) -> None:
        """Validate target types and (optionally) that image/mask/masks share height and width.

        Raises:
            TypeError: if an image/mask target is not a numpy array, or masks is not a list of arrays.
            ValueError: if bboxes are passed without ``bbox_params``, or shapes disagree while
                ``is_check_shapes`` is True.
        """
        checked_single = ["image", "mask"]
        checked_multi = ["masks"]
        check_bbox_param = ["bboxes"]
        # ["bboxes", "keypoints"] could be almost any type, no need to check them
        shapes = []
        for data_name, data in kwargs.items():
            # Aliased targets (additional_targets) are checked like the target they alias.
            internal_data_name = self.additional_targets.get(data_name, data_name)
            if internal_data_name in checked_single:
                if not isinstance(data, np.ndarray):
                    raise TypeError("{} must be numpy array type".format(data_name))
                shapes.append(data.shape[:2])
            if internal_data_name in checked_multi:
                if data is not None and len(data):
                    if not isinstance(data[0], np.ndarray):
                        raise TypeError("{} must be list of numpy arrays".format(data_name))
                    shapes.append(data[0].shape[:2])
            if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None:
                raise ValueError("bbox_params must be specified for bbox transformations")

        if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
            raise ValueError(
                "Height and Width of image, mask or masks should be equal. You can disable shapes check "
                "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
                "about your data consistency)."
            )

    @staticmethod
    def _make_targets_contiguous(data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Return a new dict in which every numpy array value is C-contiguous."""
        return {
            key: np.ascontiguousarray(value) if isinstance(value, np.ndarray) else value
            for key, value in data.items()
        }
class OneOf(BaseCompose):
    """Select one of transforms to apply. Selected transform will be called with `force_apply=True`.

    The children's own ``p`` values are normalized to sum to 1 and act as
    relative selection weights rather than independent probabilities.

    Args:
        transforms (list): list of transformations to compose.
        p (float): probability of applying selected transform. Default: 0.5.
    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super(OneOf, self).__init__(transforms, p)
        weights = [child.p for child in self.transforms]
        total = sum(weights)
        self.transforms_ps = [w / total for w in weights]

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        if self.replay_mode:
            # Replay: visit every child; ReplayCompose has marked each one with applied_in_replay.
            for child in self.transforms:
                data = child(**data)
            return data

        if not self.transforms_ps:
            return data
        if force_apply or random.random() < self.p:
            chosen_index: int = random_utils.choice(len(self.transforms), p=self.transforms_ps)
            data = self.transforms[chosen_index](force_apply=True, **data)
        return data
class SomeOf(BaseCompose):
    """Select N transforms to apply. Selected transforms will be called with `force_apply=True`.

    The children's own ``p`` values are normalized to sum to 1 and act as
    relative selection weights rather than independent probabilities.

    Args:
        transforms (list): list of transformations to compose.
        n (int): number of transforms to apply.
        replace (bool): Whether the sampled transforms are with or without replacement. Default: True.
        p (float): probability of applying selected transform. Default: 1.
    """

    def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1):
        super(SomeOf, self).__init__(transforms, p)
        self.n = n
        self.replace = replace
        weights = [child.p for child in self.transforms]
        total = sum(weights)
        self.transforms_ps = [w / total for w in weights]

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        if self.replay_mode:
            # Replay: visit every child; ReplayCompose has marked each one with applied_in_replay.
            for child in self.transforms:
                data = child(**data)
            return data

        if not self.transforms_ps:
            return data
        if not force_apply and not (random.random() < self.p):
            return data

        picked = random_utils.choice(len(self.transforms), size=self.n, replace=self.replace, p=self.transforms_ps)
        for index in picked:  # type: ignore
            data = self.transforms[index](force_apply=True, **data)
        return data

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize, including the sample size ``n`` and the ``replace`` flag."""
        return {**super(SomeOf, self)._to_dict(), "n": self.n, "replace": self.replace}
class OneOrOther(BaseCompose):
    """Select one or another transform to apply. Selected transform will be called with `force_apply=True`.

    With probability ``p`` the first transform is applied, otherwise the last one.

    Args:
        first: transform used with probability ``p``.
        second: transform used otherwise.
        transforms: alternative to ``first``/``second``. A warning is emitted if its length is not 2.
        p (float): probability of choosing the first transform. Default: 0.5.

    Raises:
        ValueError: if ``transforms`` is not given and ``first`` or ``second`` is missing.
    """

    def __init__(
        self,
        first: typing.Optional[TransformType] = None,
        second: typing.Optional[TransformType] = None,
        transforms: typing.Optional[TransformsSeqType] = None,
        p: float = 0.5,
    ):
        if transforms is None:
            both_given = first is not None and second is not None
            if not both_given:
                raise ValueError("You must set both first and second or set transforms argument.")
            transforms = [first, second]
        super(OneOrOther, self).__init__(transforms, p)
        if len(self.transforms) != 2:
            warnings.warn("Length of transforms is not equal to 2.")

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        if self.replay_mode:
            # Replay: visit every child; ReplayCompose has marked each one with applied_in_replay.
            for child in self.transforms:
                data = child(**data)
            return data

        chosen = self.transforms[0] if random.random() < self.p else self.transforms[-1]
        return chosen(force_apply=True, **data)
class PerChannel(BaseCompose):
    """Apply transformations per-channel

    Each transform is called on every selected 2-D channel slice of ``data["image"]``
    (without ``force_apply``), and the result is written back into that channel.

    Args:
        transforms (list): list of transformations to compose.
        channels (sequence): channels to apply the transform to. Pass None to apply to all.
            Default: None (apply to all)
        p (float): probability of applying the transform. Default: 0.5.
    """

    def __init__(
        self, transforms: TransformsSeqType, channels: typing.Optional[typing.Sequence[int]] = None, p: float = 0.5
    ):
        super(PerChannel, self).__init__(transforms, p)
        self.channels = channels

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        if force_apply or random.random() < self.p:
            # Work on a copy: the channels are written in place below, and the caller's
            # array must not be modified.
            image = np.array(data["image"], copy=True)
            # Expand mono images to have a single channel
            if image.ndim == 2:
                image = np.expand_dims(image, -1)

            # Resolve "all channels" per call instead of caching it on self; the next
            # image may have a different number of channels.
            channels = range(image.shape[2]) if self.channels is None else self.channels

            for c in channels:
                for t in self.transforms:
                    image[:, :, c] = t(image=image[:, :, c])["image"]

            data["image"] = image

        return data

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        """Serialize, including ``channels`` so that a round-trip keeps the channel selection."""
        dictionary = super(PerChannel, self)._to_dict()
        dictionary.update({"channels": None if self.channels is None else list(self.channels)})
        return dictionary
class ReplayCompose(Compose):
    """Compose that records what it did, so the same augmentation can be replayed later.

    After each call the result contains, under ``save_key``, a serialized copy of the
    pipeline. In it, each node carries the params recorded for it during this call and
    an ``applied`` flag. ``ReplayCompose.replay`` rebuilds a pipeline from that record
    and applies it again.

    Args:
        Same as ``Compose``, plus
        save_key (str): output key under which the replay record is stored. Default: "replay".
    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
        keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
        additional_targets: typing.Optional[typing.Dict[str, str]] = None,
        p: float = 1.0,
        is_check_shapes: bool = True,
        save_key: str = "replay",
    ):
        super(ReplayCompose, self).__init__(
            transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes
        )
        self.set_deterministic(True, save_key=save_key)
        self.save_key = save_key

    def __call__(self, *args, force_apply: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
        # Per-call record of params, keyed by object id (see fill_with_params).
        kwargs[self.save_key] = defaultdict(dict)
        # Forward positional args so Compose.__call__ rejects them with its usual KeyError
        # instead of silently ignoring them.
        result = super(ReplayCompose, self).__call__(*args, force_apply=force_apply, **kwargs)
        serialized = self.get_dict_with_id()
        self.fill_with_params(serialized, result[self.save_key])
        self.fill_applied(serialized)
        result[self.save_key] = serialized
        return result

    @staticmethod
    def replay(saved_augmentations: typing.Dict[str, typing.Any], **kwargs) -> typing.Dict[str, typing.Any]:
        """Rebuild the pipeline from a replay record and apply it (forced) to ``kwargs``."""
        augs = ReplayCompose._restore_for_replay(saved_augmentations)
        return augs(force_apply=True, **kwargs)

    @staticmethod
    def _restore_for_replay(
        transform_dict: typing.Dict[str, typing.Any], lambda_transforms: typing.Optional[dict] = None
    ) -> TransformType:
        """
        Args:
            lambda_transforms (dict): A dictionary that contains lambda transforms, that
                is instances of the Lambda class.
                This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
                in that dictionary should be named same as `name` arguments in respective lambda transforms from
                a serialized pipeline.
        """
        applied = transform_dict["applied"]
        params = transform_dict["params"]
        lmbd = instantiate_nonserializable(transform_dict, lambda_transforms)
        if lmbd:
            transform = lmbd
        else:
            name = transform_dict["__class_fullname__"]
            args = {k: v for k, v in transform_dict.items() if k not in ["__class_fullname__", "applied", "params"]}
            cls = SERIALIZABLE_REGISTRY[name]
            if "transforms" in args:
                args["transforms"] = [
                    ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms)
                    for t in args["transforms"]
                ]
            transform = cls(**args)

        transform = typing.cast(BasicTransform, transform)
        if isinstance(transform, BasicTransform):
            transform.params = params
        transform.replay_mode = True
        transform.applied_in_replay = applied
        return transform

    def fill_with_params(self, serialized: dict, all_params: dict) -> None:
        """Recursively replace each node's ``id`` with the params recorded under that id."""
        params = all_params.get(serialized.get("id"))
        serialized["params"] = params
        del serialized["id"]
        for transform in serialized.get("transforms", []):
            self.fill_with_params(transform, all_params)

    def fill_applied(self, serialized: typing.Dict[str, typing.Any]) -> bool:
        """Recursively set ``applied``: a leaf is applied if it has params, a composition if any child is.

        Returns:
            The ``applied`` value computed for ``serialized``.
        """
        if "transforms" in serialized:
            # Build the full list first (not a lazy any()) so every child gets its flag set.
            applied = [self.fill_applied(t) for t in serialized["transforms"]]
            serialized["applied"] = any(applied)
        else:
            serialized["applied"] = serialized.get("params") is not None
        return serialized["applied"]

    def _to_dict(self) -> typing.Dict[str, typing.Any]:
        dictionary = super(ReplayCompose, self)._to_dict()
        dictionary.update({"save_key": self.save_key})
        return dictionary
class Sequential(BaseCompose):
    """Sequentially applies all transforms to targets.

    Note:
        This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
        the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
        create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
        chosen sequence to input data (see the `Example` section for an example definition of such pipeline).

    Args:
        transforms (list): list of transformations to apply in order.
        p (float): probability of applying the whole sequence. It is bypassed when the sequence is
            called with ``force_apply=True`` (as `OneOf`, `SomeOf` and `OneOrOther` do for the
            selected child) and in replay mode. Default: 0.5.

    Example:
        >>> import custom_albumentations as A
        >>> transform = A.Compose([
        >>>    A.OneOf([
        >>>        A.Sequential([
        >>>            A.HorizontalFlip(p=0.5),
        >>>            A.ShiftScaleRotate(p=0.5),
        >>>        ]),
        >>>        A.Sequential([
        >>>            A.VerticalFlip(p=0.5),
        >>>            A.RandomBrightnessContrast(p=0.5),
        >>>        ]),
        >>>    ], p=1)
        >>> ])
    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super().__init__(transforms, p)

    def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
        # ``force_apply`` is taken as its own keyword (like the sibling compositions) rather than
        # being left inside ``data``, where it would be forwarded to the children.
        # In replay mode every child is visited and decides from its replay flags.
        if self.replay_mode or force_apply or random.random() < self.p:
            for t in self.transforms:
                data = t(**data)
        return data