| import random |
| from typing import Dict, Tuple |
|
|
| import torch |
| import torchvision.transforms.v2.functional as F |
| from einops import rearrange |
|
|
|
|
| class FlipAndRotateSpace(object): |
| """ |
| For now, lets have no parameters |
| Choose 1 of 8 transformations and apply it to space_time_x and space_x |
| """ |
|
|
| def __init__(self, enabled: bool): |
| self.enabled = enabled |
| self.transformations = [ |
| self.no_transform, |
| self.rotate_90, |
| self.rotate_180, |
| self.rotate_270, |
| self.hflip, |
| self.vflip, |
| self.hflip_rotate_90, |
| self.vflip_rotate_90, |
| ] |
|
|
| def no_transform(self, x): |
| return x |
|
|
| def rotate_90(self, x): |
| return F.rotate(x, 90) |
|
|
| def rotate_180(self, x): |
| return F.rotate(x, 180) |
|
|
| def rotate_270(self, x): |
| return F.rotate(x, 270) |
|
|
| def hflip(self, x): |
| return F.hflip(x) |
|
|
| def vflip(self, x): |
| return F.vflip(x) |
|
|
| def hflip_rotate_90(self, x): |
| return F.hflip(F.rotate(x, 90)) |
|
|
| def vflip_rotate_90(self, x): |
| return F.vflip(F.rotate(x, 90)) |
|
|
| def apply( |
| self, |
| space_time_x: torch.Tensor, |
| space_x: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if not self.enabled: |
| return space_time_x, space_x |
|
|
| space_time_x = rearrange( |
| space_time_x.float(), "b h w t c -> b t c h w" |
| ) |
| space_x = rearrange(space_x.float(), "b h w c -> b c h w") |
|
|
| transformation = random.choice(self.transformations) |
|
|
| space_time_x = rearrange( |
| transformation(space_time_x), "b t c h w -> b h w t c" |
| ) |
| space_x = rearrange(transformation(space_x), "b c h w -> b h w c") |
|
|
| return space_time_x.half(), space_x.half() |
|
|
|
|
| class Augmentation(object): |
| def __init__(self, aug_config: Dict): |
| self.flip_and_rotate = FlipAndRotateSpace(enabled=aug_config.get("flip+rotate", False)) |
|
|
| def apply( |
| self, |
| space_time_x: torch.Tensor, |
| space_x: torch.Tensor, |
| time_x: torch.Tensor, |
| static_x: torch.Tensor, |
| months: torch.Tensor, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| space_time_x, space_x = self.flip_and_rotate.apply(space_time_x, space_x) |
|
|
| return space_time_x, space_x, time_x, static_x, months |
|
|