"""Training-time transforms for image/coordinate samples.

Each ``get_*_transform`` factory returns a callable that takes a sample
mapping with ``"image"`` and ``"point"`` entries and returns a new dict with
the augmented image and a randomly jittered point.
"""

import albumentations as A
import numpy as np
import torch
import torchvision.transforms as T
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch import ToTensorV2


def get_train_transform(
    resize_crop_size=256,
    mean=(0.4139, 0.4341, 0.3482, 0.5263),
    std=(0.0010, 0.0010, 0.0013, 0.0013),
):
    """Build the albumentations-based training transform.

    Args:
        resize_crop_size: Height and width of the random resized crop.
        mean: Per-channel mean passed to ``A.Normalize``.
        std: Per-channel standard deviation passed to ``A.Normalize``.

    Returns:
        A function mapping a sample (with ``"image"`` and ``"point"``
        entries) to ``{"image": ..., "point": ...}``.
    """
    # Defaults are tuples rather than lists so that the shared default
    # objects can never be mutated between calls.
    augmentation = A.Compose(
        [
            A.RandomResizedCrop(height=resize_crop_size, width=resize_crop_size),
            A.RandomBrightnessContrast(),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.GaussianBlur(),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ]
    )

    def transform(sample):
        # ``sample["image"]`` must expose ``.numpy()`` (e.g. a torch tensor).
        # The axes are permuted (1, 2, 0) before augmentation -- presumably
        # channels-first -> channels-last, which albumentations expects.
        image = sample["image"].numpy().transpose(1, 2, 0)
        point = sample["point"]

        image = augmentation(image=image)["image"]
        point = coordinate_jitter(point)

        return {"image": image, "point": point}

    return transform


def _s2_augmentation(resize_crop_size):
    """Return the torchvision pipeline shared by both S2 transforms."""
    return T.Compose(
        [
            T.RandomCrop(resize_crop_size),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.GaussianBlur(3),
        ]
    )


def get_s2_train_transform(resize_crop_size=256):
    """Build the torchvision-based training transform for S2 imagery.

    Args:
        resize_crop_size: Side length of the random crop.

    Returns:
        A function mapping a sample (with ``"image"`` and ``"point"``
        entries) to ``{"image": ..., "point": ...}``.
    """
    augmentation = _s2_augmentation(resize_crop_size)

    def transform(sample):
        # Scale raw values by 1/10000 (presumably reflectance scaling --
        # confirm against the dataset).
        image = sample["image"] / 10000.0
        point = sample["point"]

        # The division already produced a fresh array/tensor, so
        # ``as_tensor`` avoids a second copy (and the UserWarning that
        # ``torch.tensor`` emits when handed an existing tensor).
        image = torch.as_tensor(image)
        image = augmentation(image)
        point = coordinate_jitter(point)

        return {"image": image, "point": point}

    return transform


def get_pretrained_s2_train_transform(resize_crop_size=256):
    """Build the S2 training transform that also inserts a zero band.

    Identical to :func:`get_s2_train_transform` except that an all-zero
    channel is inserted at index 10 before augmentation.

    Args:
        resize_crop_size: Side length of the random crop.

    Returns:
        A function mapping a sample (with ``"image"`` and ``"point"``
        entries) to ``{"image": ..., "point": ...}``.
    """
    augmentation = _s2_augmentation(resize_crop_size)

    def transform(sample):
        image = sample["image"] / 10000.0
        point = sample["point"]

        # Insert an all-zero channel at index 10 along axis 0.
        # NOTE(review): presumably this stands in for a band (B10) that the
        # input lacks but the pretrained model expects -- confirm.
        B10 = np.zeros((1, *image.shape[1:]), dtype=image.dtype)
        image = np.concatenate([image[:10], B10, image[10:]], axis=0)

        # ``np.concatenate`` returned a fresh array; no extra copy needed.
        image = torch.as_tensor(image)
        image = augmentation(image)
        point = coordinate_jitter(point)

        return {"image": image, "point": point}

    return transform


def coordinate_jitter(
    point,
    radius=0.01,  # approximately 1 km
):
    """Return ``point`` shifted by a random offset scaled by ``radius``.

    Args:
        point: Coordinates to perturb; must expose ``.shape`` and support
            addition with a torch tensor.
        radius: Upper bound (exclusive) of the per-component offset.

    Returns:
        ``point`` plus an offset of the same shape.
    """
    # NOTE(review): ``torch.rand`` samples uniformly from [0, 1), so every
    # component is shifted by a value in [0, radius) -- the jitter is
    # one-sided, not centred on ``point``. Kept as-is to preserve behaviour.
    return point + torch.rand(point.shape) * radius