Spaces:
Configuration error
Configuration error
| from __future__ import absolute_import | |
| import random | |
| from copy import deepcopy | |
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast | |
| from warnings import warn | |
| import cv2 | |
| import numpy as np | |
| from .serialization import Serializable, get_shortest_class_fullname | |
| from .utils import format_args | |
__all__ = [
    "to_tuple",
    "BasicTransform",
    "DualTransform",
    "ImageOnlyTransform",
    "NoOp",
    "BoxType",
    "KeypointType",
    "ImageColorType",
    "ScaleFloatType",
    "ScaleIntType",
]

# A plain scalar or a numpy array.
NumType = Union[int, float, np.ndarray]

# Bounding box as passed to ``apply_to_bbox``: exactly four coordinates.
# NOTE(review): the coordinate convention is not fixed in this module — confirm
# against the bbox utilities before relying on it.
BoxInternalType = Tuple[float, float, float, float]
# User-facing bounding box: four coordinates, optionally followed by extra
# fields (e.g. a label). DualTransform.apply_to_bboxes re-appends ``bbox[4:]``.
BoxType = Union[BoxInternalType, Tuple[float, float, float, float, Any]]

# Keypoint as passed to ``apply_to_keypoint``: four components
# (presumably x, y, angle, scale — confirm against the keypoint utilities).
KeypointInternalType = Tuple[float, float, float, float]
# User-facing keypoint: four components, optionally followed by extra fields.
KeypointType = Union[KeypointInternalType, Tuple[float, float, float, float, Any]]

# A single colour value or a sequence of values (presumably one per channel).
ImageColorType = Union[float, Sequence[float]]

# A single value or an explicit (low, high) pair, as accepted by ``to_tuple``.
ScaleFloatType = Union[float, Tuple[float, float]]
ScaleIntType = Union[int, Tuple[int, int]]

# Optional fill value: scalar or sequence of ints/floats.
FillValueType = Optional[Union[int, float, Sequence[int], Sequence[float]]]
def to_tuple(param, low=None, bias=None):
    """Convert input argument to a (min, max) tuple.

    Args:
        param (scalar, tuple or list of exactly 2 elements): Input value.
            If value is a scalar and ``low`` is not given, the result is
            ``(bias - value, bias + value)`` (``bias`` defaults to 0).
            If value is a scalar and ``low`` is given, the result is
            ``(min(low, value), max(low, value))``.
            If value is a 2-element sequence, the result is that pair with
            ``bias`` added to each element; ``low`` is ignored in this case.
        low: Second element of the tuple, used only for scalar ``param``.
        bias: An offset added to each element of the result.

    Returns:
        ``None`` if ``param`` is ``None``, otherwise a 2-tuple.

    Raises:
        ValueError: If both ``low`` and ``bias`` are given, if a sequence does
            not have exactly two elements, or if ``param`` has an unsupported
            type (strings and bytes are rejected even though they are sequences).
    """
    if low is not None and bias is not None:
        raise ValueError("Arguments low and bias are mutually exclusive")

    if param is None:
        return param

    if isinstance(param, (int, float)):
        if low is None:
            param = -param, +param
        else:
            param = (low, param) if low < param else (param, low)
    # str/bytes satisfy ``Sequence`` but would be split into characters,
    # which is never a meaningful range.
    elif isinstance(param, Sequence) and not isinstance(param, (str, bytes)):
        if len(param) != 2:
            raise ValueError("to_tuple expects 1 or 2 values")
        param = tuple(param)
    else:
        raise ValueError("Argument param must be either scalar (int, float) or tuple")

    if bias is not None:
        return tuple(bias + x for x in param)

    return tuple(param)
class BasicTransform(Serializable):
    """Base class for transforms applied to named targets.

    Calling an instance with keyword targets (``aug(image=image, ...)``) decides,
    using ``p`` / ``always_apply`` / ``force_apply``, whether to apply it. If it
    applies, random parameters are sampled via ``get_params`` (plus
    ``get_params_dependent_on_targets`` when ``targets_as_params`` is non-empty),
    and each target is dispatched to the function registered under its key in
    ``targets``.

    Args:
        always_apply: Apply the transform regardless of ``p``.
        p: Probability of applying the transform.
    """

    call_backup = None
    # Declared but not assigned here: ``update_params`` forwards each of these
    # only when a subclass actually sets the attribute.
    interpolation: Any
    fill_value: Any
    mask_fill_value: Any

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        self.p = p
        self.always_apply = always_apply
        # Extra target name -> existing target name (see ``add_targets``).
        self._additional_targets: Dict[str, str] = {}

        # replay mode params
        self.deterministic = False
        self.save_key = "replay"
        self.params: Dict[Any, Any] = {}
        self.replay_mode = False
        self.applied_in_replay = False

    def __call__(self, *args, force_apply: bool = False, **kwargs) -> Dict[str, Any]:
        """Apply the transform to keyword targets and return the resulting dict.

        Raises:
            KeyError: If any positional argument is passed.
        """
        if args:
            raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
        if self.replay_mode:
            # Replay: reuse stored params, or pass data through untouched.
            if self.applied_in_replay:
                return self.apply_with_params(self.params, **kwargs)

            return kwargs

        if (random.random() < self.p) or self.always_apply or force_apply:
            params = self.get_params()

            if self.targets_as_params:
                assert all(key in kwargs for key in self.targets_as_params), "{} requires {}".format(
                    self.__class__.__name__, self.targets_as_params
                )
                targets_as_params = {k: kwargs[k] for k in self.targets_as_params}
                params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
                params.update(params_dependent_on_targets)
            if self.deterministic:
                if self.targets_as_params:
                    warn(
                        self.get_class_fullname() + " could work incorrectly in ReplayMode for other input data"
                        " because its' params depend on targets."
                    )
                # Record the sampled params (keyed by this object's id) so they can be replayed.
                kwargs[self.save_key][id(self)] = deepcopy(params)
            return self.apply_with_params(params, **kwargs)

        return kwargs

    def apply_with_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]:  # skipcq: PYL-W0613
        """Run every non-None target through its target function using ``params``.

        ``None`` targets are kept as ``None``. If ``params`` is ``None`` the
        input dict is returned unchanged.
        """
        if params is None:
            return kwargs
        # NOTE: update_params mutates and returns the same ``params`` dict.
        params = self.update_params(params, **kwargs)
        res = {}
        for key, arg in kwargs.items():
            if arg is not None:
                target_function = self._get_target_function(key)
                # Some targets need other targets as extra keyword arguments.
                target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
                res[key] = target_function(arg, **dict(params, **target_dependencies))
            else:
                res[key] = None
        return res

    def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform":
        """Enable/disable recording of sampled params under ``save_key``; returns ``self``."""
        assert save_key != "params", "params save_key is reserved"
        self.deterministic = flag
        self.save_key = save_key
        return self

    def __repr__(self) -> str:
        state = self.get_base_init_args()
        state.update(self.get_transform_init_args())
        return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state))

    def _get_target_function(self, key: str) -> Callable:
        """Return the function handling ``key``; unknown keys map to identity."""
        # Additional targets are aliases of an existing target name.
        transform_key = self._additional_targets.get(key, key)
        return self.targets.get(transform_key, lambda x, **p: x)

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        raise NotImplementedError

    def get_params(self) -> Dict:
        """Return randomly sampled parameters; the base implementation returns ``{}``."""
        return {}

    @property
    def targets(self) -> Dict[str, Callable]:
        # you must specify targets in subclass
        # for example: ('image', 'mask')
        #              ('image', 'boxes')
        raise NotImplementedError

    def update_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Add interpolation/fill values (when defined) and image size to ``params``.

        ``params`` is modified in place and returned.

        Raises:
            KeyError: If ``kwargs`` has no ``"image"`` entry.
        """
        if hasattr(self, "interpolation"):
            params["interpolation"] = self.interpolation
        if hasattr(self, "fill_value"):
            params["fill_value"] = self.fill_value
        if hasattr(self, "mask_fill_value"):
            params["mask_fill_value"] = self.mask_fill_value
        params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]})
        return params

    @property
    def target_dependence(self) -> Dict:
        """Map target name -> list of other target names passed to its function."""
        return {}

    def add_targets(self, additional_targets: Dict[str, str]):
        """Add targets to transform them the same way as one of existing targets
        ex: {'target_image': 'image'}
        ex: {'obj1_mask': 'mask', 'obj2_mask': 'mask'}
        by the way you must have at least one object with key 'image'

        Note: this replaces any previously added targets rather than merging.

        Args:
            additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'}
        """
        self._additional_targets = additional_targets

    @property
    def targets_as_params(self) -> List[str]:
        """Target names whose data is needed to compute parameters."""
        return []

    def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError(
            "Method get_params_dependent_on_targets is not implemented in class " + self.__class__.__name__
        )

    @classmethod
    def get_class_fullname(cls) -> str:
        return get_shortest_class_fullname(cls)

    @classmethod
    def is_serializable(cls):
        return True

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "Class {name} is not serializable because the `get_transform_init_args_names` method is not "
            "implemented".format(name=self.get_class_fullname())
        )

    def get_base_init_args(self) -> Dict[str, Any]:
        return {"always_apply": self.always_apply, "p": self.p}

    def get_transform_init_args(self) -> Dict[str, Any]:
        return {k: getattr(self, k) for k in self.get_transform_init_args_names()}

    def _to_dict(self) -> Dict[str, Any]:
        """Serialize as class full name plus base and transform-specific init args."""
        state = {"__class_fullname__": self.get_class_fullname()}
        state.update(self.get_base_init_args())
        state.update(self.get_transform_init_args())
        return state

    def get_dict_with_id(self) -> Dict[str, Any]:
        """Like ``_to_dict`` but also records ``id(self)`` under ``"id"``."""
        d = self._to_dict()
        d["id"] = id(self)
        return d
class DualTransform(BasicTransform):
    """Transform for segmentation task.

    Dispatches ``image``, ``mask``, ``masks``, ``bboxes`` and ``keypoints``
    targets to dedicated ``apply*`` hooks. Subclasses implement ``apply`` and,
    when boxes/keypoints are used, ``apply_to_bbox`` / ``apply_to_keypoint``.
    """

    @property
    def targets(self) -> Dict[str, Callable]:
        # Accessed as an attribute by BasicTransform (``self.targets.get(...)``).
        return {
            "image": self.apply,
            "mask": self.apply_to_mask,
            "masks": self.apply_to_masks,
            "bboxes": self.apply_to_bboxes,
            "keypoints": self.apply_to_keypoints,
        }

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__)

    def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
        raise NotImplementedError("Method apply_to_keypoint is not implemented in class " + self.__class__.__name__)

    def apply_to_bboxes(self, bboxes: Sequence[BoxType], **params) -> List[BoxType]:
        """Transform the first four values of each box; extra fields are re-appended unchanged."""
        return [self.apply_to_bbox(tuple(bbox[:4]), **params) + tuple(bbox[4:]) for bbox in bboxes]  # type: ignore

    def apply_to_keypoints(self, keypoints: Sequence[KeypointType], **params) -> List[KeypointType]:
        """Transform the first four values of each keypoint; extra fields are re-appended unchanged."""
        return [  # type: ignore
            self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:])  # type: ignore
            for keypoint in keypoints
        ]

    def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
        """Apply ``apply`` to a mask, forcing nearest-neighbour interpolation.

        Only an existing ``interpolation`` param is overridden; none is added.
        """
        # Presumably so discrete mask values are not blended by interpolation.
        return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()})

    def apply_to_masks(self, masks: Sequence[np.ndarray], **params) -> List[np.ndarray]:
        """Apply ``apply_to_mask`` to every mask in ``masks``."""
        return [self.apply_to_mask(mask, **params) for mask in masks]
class ImageOnlyTransform(BasicTransform):
    """Transform applied to image only.

    Only the ``image`` target is registered. Any other key falls back to the
    identity function in ``BasicTransform`` and is returned unchanged.
    """

    @property
    def targets(self) -> Dict[str, Callable]:
        # Accessed as an attribute by BasicTransform (``self.targets.get(...)``).
        return {"image": self.apply}
class NoOp(DualTransform):
    """Identity transform: every target is returned exactly as received."""

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        # The image passes straight through; params are ignored.
        return img

    def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
        # Overridden so masks also pass straight through unmodified.
        return img

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        return bbox

    def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
        return keypoint

    def get_transform_init_args_names(self) -> Tuple:
        # No transform-specific constructor arguments beyond always_apply and p.
        return tuple()