| import torch | |
class Augment_RGB_torch:
    """Eight rotation/flip augmentations acting on the last two tensor dims.

    ``transform0`` .. ``transform3`` rotate by 0, 1, 2 and 3 quarter turns
    (``torch.rot90`` with ``dims=[-1, -2]``, i.e. rotating from dim -1
    towards dim -2).  ``transform4`` .. ``transform7`` apply the same
    rotations and then reverse the result along dim -2.
    """

    def __init__(self):
        """Nothing to set up; the transforms hold no state."""

    def transform0(self, torch_tensor):
        """Identity: hand back the input tensor untouched."""
        return torch_tensor

    def transform1(self, torch_tensor):
        """Rotate by one quarter turn over dims ``[-1, -2]``."""
        return torch_tensor.rot90(1, [-1, -2])

    def transform2(self, torch_tensor):
        """Rotate by two quarter turns over dims ``[-1, -2]``."""
        return torch_tensor.rot90(2, [-1, -2])

    def transform3(self, torch_tensor):
        """Rotate by three quarter turns over dims ``[-1, -2]``."""
        return torch_tensor.rot90(3, [-1, -2])

    def transform4(self, torch_tensor):
        """Reverse the tensor along dim -2 (no rotation)."""
        return torch.flip(torch_tensor, dims=[-2])

    def transform5(self, torch_tensor):
        """Rotate by one quarter turn, then reverse along dim -2."""
        rotated = torch_tensor.rot90(1, [-1, -2])
        return torch.flip(rotated, dims=[-2])

    def transform6(self, torch_tensor):
        """Rotate by two quarter turns, then reverse along dim -2."""
        rotated = torch_tensor.rot90(2, [-1, -2])
        return torch.flip(rotated, dims=[-2])

    def transform7(self, torch_tensor):
        """Rotate by three quarter turns, then reverse along dim -2."""
        rotated = torch_tensor.rot90(3, [-1, -2])
        return torch.flip(rotated, dims=[-2])
class MixUp_AUG:
    """MixUp augmentation applied jointly to a ground-truth / noisy batch pair.

    Each sample is blended with another sample of the same batch (chosen by
    a random permutation) using a per-sample weight drawn from Beta(1.2, 1.2).
    The same permutation and the same weights are used for both tensors, so
    the ground-truth/noisy correspondence is preserved.
    """

    def __init__(self):
        # Distribution the per-sample mixing weights are drawn from.
        self.dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2]))

    def aug(self, rgb_gt, rgb_noisy):
        """Blend every sample of the batch with a randomly chosen partner.

        Args:
            rgb_gt: Ground-truth batch; dim 0 is the batch dimension.  The
                weights are shaped ``(bs, 1, 1, 1)``, so a 4-D input is
                expected (presumably ``(B, C, H, W)`` -- confirm with caller).
            rgb_noisy: Degraded batch with the same batch size, assumed to
                live on the same device as ``rgb_gt``.

        Returns:
            Tuple ``(rgb_gt, rgb_noisy)`` of the mixed tensors.
        """
        bs = rgb_gt.size(0)
        indices = torch.randperm(bs)
        rgb_gt2 = rgb_gt[indices]
        rgb_noisy2 = rgb_noisy[indices]

        # Follow the input's device instead of hard-coding ``.cuda()``, so the
        # augmentation also works on CPU-only hosts and non-default GPUs.
        lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).to(rgb_gt.device)

        rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt2
        rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy2

        return rgb_gt, rgb_noisy