Spaces:
Running
Running
File size: 1,745 Bytes
import random
import torch
import torchvision.transforms.functional as TF
def _apply_spatial_ops(
    frame: torch.Tensor,
    top: int,
    left: int,
    size: int,
    flip_h: bool,
    flip_v: bool,
    quarter_turns: int,
) -> torch.Tensor:
    """Crop, flip and rotate one frame using exact index operations only."""
    out = frame[..., top:top + size, left:left + size]
    if flip_h:
        out = out.flip(-1)
    if flip_v:
        out = out.flip(-2)
    if quarter_turns:
        # The crop is square, so a quarter-turn keeps the output shape.
        # rot90 only re-indexes pixels (no resampling); positive k rotates
        # counter-clockwise in the (H, W) plane.
        out = torch.rot90(out, quarter_turns, dims=(-2, -1)).contiguous()
    return out


def augment_triplet(
    img0: torch.Tensor,
    img1: torch.Tensor,
    gt: torch.Tensor,
    crop_size: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply identical spatial augmentations to all three frames in a triplet.

    Ensures that the cloud motions remain physically consistent across the
    temporal sequence: one random square crop, one horizontal-flip decision,
    one vertical-flip decision and one right-angle rotation are drawn, and the
    same choices are applied to every frame.

    Args:
        img0: Past frame tensor [C, H, W] (extra leading dims are accepted).
        img1: Future frame tensor with the same H and W as ``img0``.
        gt: Ground truth intermediate frame tensor with the same H and W as
            ``img0``.
        crop_size: Final output dimension for height and width.

    Returns:
        Tuple of augmented tensors (img0, img1, gt), each with spatial size
        ``crop_size`` x ``crop_size``.

    Raises:
        ValueError: If ``img0`` is smaller than ``crop_size`` in either spatial
            dimension, or if ``img1`` / ``gt`` do not share ``img0``'s spatial
            size (a shared crop window would otherwise be misaligned).
    """
    h, w = img0.shape[-2:]
    if h < crop_size or w < crop_size:
        raise ValueError(
            f"Input smaller than crop size: {h}x{w}"
        )
    for label, frame in (("img1", img1), ("gt", gt)):
        if tuple(frame.shape[-2:]) != (h, w):
            raise ValueError(
                f"{label} spatial size {tuple(frame.shape[-2:])} does not "
                f"match img0 spatial size {(h, w)}"
            )

    # Draw every random decision once, in a fixed order, so that all three
    # frames receive exactly the same transformation.
    top = random.randint(0, h - crop_size)
    left = random.randint(0, w - crop_size)
    flip_h = random.random() > 0.5
    flip_v = random.random() > 0.5
    # Number of counter-clockwise quarter turns: 0, 90, 180 or 270 degrees.
    quarter_turns = random.choice((0, 1, 2, 3))

    img0, img1, gt = (
        _apply_spatial_ops(
            frame, top, left, crop_size, flip_h, flip_v, quarter_turns
        )
        for frame in (img0, img1, gt)
    )
    return img0, img1, gt
|