|
|
from dataclasses import dataclass |
|
|
from typing import Tuple, Union, List, Dict |
|
|
from numpy import ndarray |
|
|
import numpy as np |
|
|
from abc import ABC, abstractmethod |
|
|
from scipy.spatial.transform import Rotation as R |
|
|
|
|
|
from .spec import ConfigSpec |
|
|
from .asset import Asset |
|
|
from .utils import axis_angle_to_matrix |
|
|
|
|
|
@dataclass(frozen=True)
class AugmentAffineConfig(ConfigSpec):
    '''
    Immutable settings read by AugmentAffine.

    Every field except `normalize_into` is optional in the raw config and
    falls back to a value that disables the corresponding random step.
    '''

    # (low, high) target range: AugmentAffine centers the asset's bounding box,
    # scales it so its largest extent equals high - low, and moves the center
    # to (low + high) / 2 on every axis.
    normalize_into: Tuple[float, float]

    # Chance (compared against np.random.rand()) of applying a random uniform scale.
    random_scale_p: float

    # (low, high) bounds passed to np.random.uniform for the random scale factor.
    random_scale: Tuple[float, float]

    # Chance (compared against np.random.rand()) of applying a random shift.
    random_shift_p: float

    # (low, high) bounds passed to np.random.uniform for each axis of the random shift.
    random_shift: Tuple[float, float]

    @classmethod
    def parse(cls, config) -> Union['AugmentAffineConfig', None]:
        '''
        Build the config from a mapping-like object, or return None when
        `config` is None.

        `config` must expose `normalize_into` as an attribute and support
        `.get(key, default)` for the optional entries. It is passed through
        `cls.check_keys` first (presumably inherited from ConfigSpec).
        '''
        if config is not None:
            cls.check_keys(config)
            options = dict(
                normalize_into=config.normalize_into,
                random_scale_p=config.get('random_scale_p', 0.),
                random_scale=config.get('random_scale', [1., 1.]),
                random_shift_p=config.get('random_shift_p', 0.),
                random_shift=config.get('random_shift', [0., 0.]),
            )
            return AugmentAffineConfig(**options)
        return None
|
|
|
|
|
@dataclass(frozen=True)
class AugmentConfig(ConfigSpec):
    '''
    Configuration for the final, lightweight augmentation of vertices,
    normals and bones that runs before sampling.
    '''

    # Settings for the affine augmentation; None means that stage is not created.
    augment_affine_config: Union[AugmentAffineConfig, None]

    @classmethod
    def parse(cls, config) -> 'AugmentConfig':
        '''
        Build the config from a mapping-like object supporting `.get`.

        The optional `augment_affine_config` entry is delegated to
        AugmentAffineConfig.parse, which yields None when the entry is absent.
        '''
        cls.check_keys(config)
        affine_section = config.get('augment_affine_config', None)
        affine = AugmentAffineConfig.parse(affine_section)
        return AugmentConfig(augment_affine_config=affine)
|
|
|
|
|
class Augment(ABC):
    '''
    Common interface for augmentations: subclasses must provide both a
    forward `transform` and an `inverse`.
    '''

    def __init__(self):
        '''No state is needed at this level.'''

    @abstractmethod
    def transform(self, asset: Asset, **kwargs):
        '''Apply the augmentation to `asset`; extra options arrive via `kwargs`.'''

    @abstractmethod
    def inverse(self, asset: Asset):
        '''Undo the augmentation on `asset`.'''
|
|
|
|
|
class AugmentAffine(Augment):
    '''
    Affine augmentation: fits the asset into the configured range, then
    optionally applies a random uniform scale and a random translation.

    `transform` stores the composed 4x4 matrix in `self.trans_vertex` so that
    `inverse` can undo it afterwards.
    '''

    def __init__(self, config: AugmentAffineConfig):
        super().__init__()
        self.config = config

    def _apply(self, v: ndarray, trans: ndarray) -> ndarray:
        '''
        Apply the 4x4 affine matrix `trans` to the points in `v`.

        Points are treated as row vectors, hence the transposed 3x3 block;
        the translation column is added afterwards.
        '''
        return np.matmul(v, trans[:3, :3].transpose()) + trans[:3, 3]

    def transform(self, asset: Asset, **kwargs):
        '''
        Normalize `asset` in place and apply the optional random scale/shift.

        Updates `asset.vertices` and, when present, `asset.joints`,
        `asset.tails` and the last column of `asset.matrix_local` /
        `asset.pose_matrix`. `kwargs` is accepted for interface compatibility
        and is not used.
        '''
        # Bounding box over vertices, widened by joints when they exist.
        bound_min = asset.vertices.min(axis=0)
        bound_max = asset.vertices.max(axis=0)
        if asset.joints is not None:
            bound_min = np.minimum(bound_min, asset.joints.min(axis=0))
            bound_max = np.maximum(bound_max, asset.joints.max(axis=0))

        trans_vertex = np.eye(4, dtype=np.float32)

        # 1) Move the bounding-box center to the origin.
        trans_vertex = _trans_to_m(-(bound_max + bound_min)/2) @ trans_vertex

        # 2) Uniformly scale so the largest extent spans the target range.
        normalize_into = self.config.normalize_into
        extent_ratio = np.max((bound_max - bound_min) / (normalize_into[1] - normalize_into[0]))
        if extent_ratio == 0:
            # Degenerate bounding box (all points coincide): dividing by zero
            # would turn the whole matrix into inf/NaN, so skip the scaling.
            extent_ratio = 1.
        trans_vertex = _scale_to_m(1. / extent_ratio) @ trans_vertex

        # 3) Move the center to the middle of the target range.
        bias = (normalize_into[0] + normalize_into[1]) / 2
        trans_vertex = _trans_to_m(np.array([bias, bias, bias], dtype=np.float32)) @ trans_vertex

        # 4) Optional random uniform scale.
        if np.random.rand() < self.config.random_scale_p:
            random_scale_m = _scale_to_m(
                np.random.uniform(self.config.random_scale[0], self.config.random_scale[1])
            )
            trans_vertex = random_scale_m @ trans_vertex

        # 5) Optional random shift, drawn independently per axis.
        if np.random.rand() < self.config.random_shift_p:
            low, high = self.config.random_shift
            # dtype belongs to np.array; _trans_to_m only accepts the vector.
            offset = np.array(
                [np.random.uniform(low, high), np.random.uniform(low, high), np.random.uniform(low, high)],
                dtype=np.float32,
            )
            trans_vertex = _trans_to_m(offset) @ trans_vertex

        asset.vertices = self._apply(asset.vertices, trans_vertex)

        # Only the last column of each stacked matrix is transformed; the
        # slicing implies arrays of rank 3 whose second axis has length 4.
        if asset.matrix_local is not None:
            asset.matrix_local[:, :, 3:4] = trans_vertex @ asset.matrix_local[:, :, 3:4]
        if asset.pose_matrix is not None:
            asset.pose_matrix[:, :, 3:4] = trans_vertex @ asset.pose_matrix[:, :, 3:4]

        if asset.joints is not None:
            asset.joints = self._apply(asset.joints, trans_vertex)
        if asset.tails is not None:
            asset.tails = self._apply(asset.tails, trans_vertex)

        # Remembered so inverse() can undo exactly this transform.
        self.trans_vertex = trans_vertex

    def inverse(self, asset: Asset):
        '''
        Undo the most recent `transform` on vertices, joints and tails.

        Reads `self.trans_vertex`, so `transform` must have been called
        first. `matrix_local` and `pose_matrix` are not modified here.
        '''
        m = np.linalg.inv(self.trans_vertex)
        asset.vertices = self._apply(asset.vertices, m)
        if asset.joints is not None:
            asset.joints = self._apply(asset.joints, m)
        if asset.tails is not None:
            asset.tails = self._apply(asset.tails, m)
|
|
|
|
|
def _trans_to_m(v: ndarray):
    '''
    Return a 4x4 float32 homogeneous matrix that translates by `v`.

    `v` is written into the first three rows of the last column, so it must
    be assignable to a length-3 slice.
    '''
    translation = np.identity(4, dtype=np.float32)
    translation[:3, 3] = v
    return translation
|
|
|
|
|
def _scale_to_m(r: ndarray):
    '''
    Return a 4x4 float32 homogeneous matrix that scales all three axes by `r`.

    The homogeneous entry [3, 3] stays 1 so translations compose correctly.
    '''
    scaling = np.identity(4, dtype=np.float32)
    for axis in range(3):
        scaling[axis, axis] = r
    return scaling
|
|
|
|
|
def get_augments(config: AugmentConfig) -> Tuple[List[Augment], List[Augment]]:
    '''
    Instantiate the augmentations described by `config`.

    Returns a pair `(first_augments, second_augments)`. The first list is
    always empty here; the second holds an AugmentAffine when
    `config.augment_affine_config` is set, and is empty otherwise.
    '''
    affine_cfg = config.augment_affine_config
    first_augments = []
    second_augments = [] if affine_cfg is None else [AugmentAffine(config=affine_cfg)]
    return first_augments, second_augments