Spaces:
Configuration error
Configuration error
| from __future__ import division | |
| import math | |
| import typing | |
| import warnings | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| from .utils import DataProcessor, Params | |
# Names exported by ``from <module> import *``.
__all__ = [
    "angle_to_2pi_range",
    "check_keypoints",
    "convert_keypoints_from_albumentations",
    "convert_keypoints_to_albumentations",
    "filter_keypoints",
    "KeypointsProcessor",
    "KeypointParams",
]
# Accepted keypoint layouts. Letter legend: x / y - coordinates,
# a - orientation angle, s - scale. The order of letters is the order of
# the values inside a keypoint.
keypoint_formats = {
    "xy",
    "yx",
    "xya",
    "xys",
    "xyas",
    "xysa",
}
def angle_to_2pi_range(angle: float) -> float:
    """Reduce ``angle`` (in radians) modulo a full turn of ``2 * pi``.

    Negative angles are mapped onto their positive equivalent because
    Python's ``%`` takes the sign of the divisor.
    """
    return angle % (2 * math.pi)
class KeypointParams(Params):
    """
    Parameters of keypoints

    Args:
        format (str): format of keypoints. Should be 'xy', 'yx', 'xya', 'xys', 'xyas', 'xysa'.

            x - X coordinate,

            y - Y coordinate

            s - Keypoint scale

            a - Keypoint orientation in radians or degrees (depending on KeypointParams.angle_in_degrees)
        label_fields (list): list of fields that are joined with keypoints, e.g labels.
            Should be same type as keypoints.
        remove_invisible (bool): to remove invisible points after transform or not
        angle_in_degrees (bool): angle in degrees or radians in 'xya', 'xyas', 'xysa' keypoints
        check_each_transform (bool): if `True`, then keypoints will be checked after each dual transform.
            Default: `True`
    """

    def __init__(
        self,
        format: str,  # skipcq: PYL-W0622
        label_fields: Optional[Sequence[str]] = None,
        remove_invisible: bool = True,
        angle_in_degrees: bool = True,
        check_each_transform: bool = True,
    ):
        # ``format`` and ``label_fields`` are handled by the parent class;
        # the remaining options are keypoint-specific and stored here.
        super().__init__(format, label_fields)
        self.remove_invisible = remove_invisible
        self.angle_in_degrees = angle_in_degrees
        self.check_each_transform = check_each_transform

    def _to_dict(self) -> Dict[str, Any]:
        """Return the parent's dict extended with the keypoint-specific options."""
        data = super()._to_dict()
        data.update(
            {
                "remove_invisible": self.remove_invisible,
                "angle_in_degrees": self.angle_in_degrees,
                "check_each_transform": self.check_each_transform,
            }
        )
        return data

    # Both methods below take ``cls`` and use no instance state, so they are
    # declared as classmethods: they can be called on the class itself as
    # well as on an instance.
    @classmethod
    def is_serializable(cls) -> bool:
        """Always ``True`` for this class."""
        return True

    @classmethod
    def get_class_fullname(cls) -> str:
        """Return the fixed name ``"KeypointParams"``."""
        return "KeypointParams"
class KeypointsProcessor(DataProcessor):
    """Data processor that validates, filters and converts keypoints.

    All behaviour is driven by the ``KeypointParams`` instance passed to the
    constructor and exposed as ``self.params``.
    """

    def __init__(self, params: KeypointParams, additional_targets: Optional[Dict[str, str]] = None):
        super().__init__(params, additional_targets)

    # NOTE(review): the base class may expect this to be a property; confirm
    # whether a ``@property`` decorator is missing here.
    def default_data_name(self) -> str:
        """Return the name of the data field handled by this processor."""
        return "keypoints"

    def ensure_data_valid(self, data: Dict[str, Any]) -> None:
        """Raise ``ValueError`` if any configured label field is absent from ``data``."""
        label_fields = self.params.label_fields
        if not label_fields:
            return
        if any(field not in data.keys() for field in label_fields):
            raise ValueError(
                "Your 'label_fields' are not valid - them must have same names as params in "
                "'keypoint_params' dict"
            )

    def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
        """Warn once if an imgaug-based dual transform is used with non-'xy' keypoints."""
        # IAA-based augmentations only transform the xy part of a keypoint.
        # For any other format the user is warned that angle and scale will
        # be left untouched.
        try:
            from custom_albumentations.imgaug.transforms import DualIAATransform
        except ImportError:
            # imgaug is not installed, so there is nothing to check.
            return

        keypoint_format = self.params.format
        if keypoint_format is None or keypoint_format == "xy":
            return

        for candidate in transforms:
            if not isinstance(candidate, DualIAATransform):
                continue
            warnings.warn(
                "{} transformation supports only 'xy' keypoints "
                "augmentation. You have '{}' keypoints format. Scale "
                "and angle WILL NOT BE transformed.".format(candidate.__class__.__name__, keypoint_format)
            )
            # A single warning is enough; stop at the first offending transform.
            return

    def filter(self, data: Sequence[Sequence], rows: int, cols: int) -> Sequence[Sequence]:
        """Drop keypoints outside the image when ``remove_invisible`` is set."""
        self.params: KeypointParams
        drop_outside = self.params.remove_invisible
        return filter_keypoints(data, rows, cols, remove_invisible=drop_outside)

    def check(self, data: Sequence[Sequence], rows: int, cols: int) -> None:
        """Validate every keypoint in ``data`` against the image size."""
        check_keypoints(data, rows, cols)

    def convert_from_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]:
        """Convert keypoints from the internal layout to ``params.format``."""
        options = self.params
        return convert_keypoints_from_albumentations(
            data,
            options.format,
            rows,
            cols,
            check_validity=options.remove_invisible,
            angle_in_degrees=options.angle_in_degrees,
        )

    def convert_to_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]:
        """Convert keypoints from ``params.format`` to the internal layout."""
        options = self.params
        return convert_keypoints_to_albumentations(
            data,
            options.format,
            rows,
            cols,
            check_validity=options.remove_invisible,
            angle_in_degrees=options.angle_in_degrees,
        )
def check_keypoint(kp: Sequence, rows: int, cols: int) -> None:
    """Check that a keypoint lies inside the image and has a normalised angle.

    Args:
        kp: keypoint whose first three items are ``x``, ``y`` and ``angle``;
            any further items are ignored.
        rows: image height; ``y`` must satisfy ``0 <= y < rows``.
        cols: image width; ``x`` must satisfy ``0 <= x < cols``.

    Raises:
        ValueError: if ``x`` or ``y`` is outside its half-open range, or if
            the angle is outside ``[0, 2 * PI)``.
    """
    for name, value, size in zip(["x", "y"], kp[:2], [cols, rows]):
        if not 0 <= value < size:
            # The upper bound is exclusive, so the message uses ")" to match
            # the check (a value equal to ``size`` is rejected).
            raise ValueError(
                "Expected {name} for keypoint {kp} "
                "to be in the range [0.0, {size}), got {value}.".format(kp=kp, name=name, value=value, size=size)
            )

    angle = kp[2]
    if not (0 <= angle < 2 * math.pi):
        raise ValueError("Keypoint angle must be in range [0, 2 * PI). Got: {angle}".format(angle=angle))
def check_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int) -> None:
    """Validate each keypoint in turn with ``check_keypoint``.

    Validation stops at the first invalid keypoint, whose ``ValueError``
    propagates to the caller.
    """
    for keypoint in keypoints:
        check_keypoint(keypoint, rows, cols)
def filter_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int, remove_invisible: bool) -> Sequence[Sequence]:
    """Optionally discard keypoints whose ``(x, y)`` is outside the image.

    When ``remove_invisible`` is false the input sequence is returned as is.
    Otherwise a new list is built containing only keypoints that were not
    rejected by the bounds test (``x`` against ``cols``, ``y`` against ``rows``).
    """
    if not remove_invisible:
        return keypoints

    def _is_visible(keypoint: Sequence) -> bool:
        x, y = keypoint[:2]
        # Written as a negated rejection test to match the original checks exactly.
        return not (x < 0 or x >= cols or y < 0 or y >= rows)

    return [keypoint for keypoint in keypoints if _is_visible(keypoint)]
def convert_keypoint_to_albumentations(
    keypoint: Sequence,
    source_format: str,
    rows: int,
    cols: int,
    check_validity: bool = False,
    angle_in_degrees: bool = True,
) -> Tuple:
    """Convert one keypoint from ``source_format`` to the internal layout.

    Args:
        keypoint: keypoint laid out according to ``source_format``; items
            beyond the format's fields are carried over unchanged.
        source_format: one of the entries of ``keypoint_formats``.
        rows: image height, used only when ``check_validity`` is true.
        cols: image width, used only when ``check_validity`` is true.
        check_validity: if true, the converted keypoint is validated with
            ``check_keypoint``.
        angle_in_degrees: if true, the input angle is converted from degrees
            to radians.

    Returns:
        ``(x, y, angle, scale)`` followed by any extra items. The angle is in
        radians, reduced with ``angle_to_2pi_range``. Angle and scale default
        to ``0.0`` when the source format does not provide them.

    Raises:
        ValueError: if ``source_format`` is not supported, or if validation
            is requested and fails.
    """
    if source_format not in keypoint_formats:
        # The message names the parameter that was actually wrong.
        raise ValueError("Unknown source_format {}. Supported formats are: {}".format(source_format, keypoint_formats))

    if source_format == "xy":
        (x, y), tail = keypoint[:2], tuple(keypoint[2:])
        a, s = 0.0, 0.0
    elif source_format == "yx":
        (y, x), tail = keypoint[:2], tuple(keypoint[2:])
        a, s = 0.0, 0.0
    elif source_format == "xya":
        (x, y, a), tail = keypoint[:3], tuple(keypoint[3:])
        s = 0.0
    elif source_format == "xys":
        (x, y, s), tail = keypoint[:3], tuple(keypoint[3:])
        a = 0.0
    elif source_format == "xyas":
        (x, y, a, s), tail = keypoint[:4], tuple(keypoint[4:])
    elif source_format == "xysa":
        (x, y, s, a), tail = keypoint[:4], tuple(keypoint[4:])
    else:
        # Defensive: only reachable if ``keypoint_formats`` gains an entry
        # that has no branch above.
        raise ValueError(f"Unsupported source format. Got {source_format}")

    if angle_in_degrees:
        a = math.radians(a)

    keypoint = (x, y, angle_to_2pi_range(a), s) + tail
    if check_validity:
        check_keypoint(keypoint, rows, cols)
    return keypoint
def convert_keypoint_from_albumentations(
    keypoint: Sequence,
    target_format: str,
    rows: int,
    cols: int,
    check_validity: bool = False,
    angle_in_degrees: bool = True,
) -> Tuple:
    """Convert one keypoint from the internal layout to ``target_format``.

    The input starts with ``(x, y, angle, scale)``; any further items are
    appended unchanged to the result. The angle is reduced with
    ``angle_to_2pi_range`` first, optionally validated together with the
    coordinates, and only then converted to degrees if requested.

    Raises:
        ValueError: if ``target_format`` is not supported, or if validation
            is requested and fails.
    """
    if target_format not in keypoint_formats:
        raise ValueError("Unknown target_format {}. Supported formats are: {}".format(target_format, keypoint_formats))

    head = keypoint[:4]
    extras = tuple(keypoint[4:])
    x, y, angle, scale = head

    angle = angle_to_2pi_range(angle)
    if check_validity:
        check_keypoint((x, y, angle, scale), rows, cols)
    if angle_in_degrees:
        angle = math.degrees(angle)

    # One entry per supported output layout.
    layouts: Dict[str, Tuple] = {
        "xy": (x, y),
        "yx": (y, x),
        "xya": (x, y, angle),
        "xys": (x, y, scale),
        "xyas": (x, y, angle, scale),
        "xysa": (x, y, scale, angle),
    }
    if target_format not in layouts:
        raise ValueError(f"Invalid target format. Got: {target_format}")
    return layouts[target_format] + extras
def convert_keypoints_to_albumentations(
    keypoints: Sequence[Sequence],
    source_format: str,
    rows: int,
    cols: int,
    check_validity: bool = False,
    angle_in_degrees: bool = True,
) -> List[Tuple]:
    """Apply ``convert_keypoint_to_albumentations`` to every keypoint, in order."""
    converted: List[Tuple] = []
    for keypoint in keypoints:
        converted.append(
            convert_keypoint_to_albumentations(keypoint, source_format, rows, cols, check_validity, angle_in_degrees)
        )
    return converted
def convert_keypoints_from_albumentations(
    keypoints: Sequence[Sequence],
    target_format: str,
    rows: int,
    cols: int,
    check_validity: bool = False,
    angle_in_degrees: bool = True,
) -> List[Tuple]:
    """Apply ``convert_keypoint_from_albumentations`` to every keypoint, in order."""
    converted: List[Tuple] = []
    for keypoint in keypoints:
        converted.append(
            convert_keypoint_from_albumentations(keypoint, target_format, rows, cols, check_validity, angle_in_degrees)
        )
    return converted