File size: 1,745 Bytes
1abfecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e9fa0a
 
 
 
1abfecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import random

import torch
import torchvision.transforms.functional as TF


def augment_triplet(img0: torch.Tensor, img1: torch.Tensor, gt: torch.Tensor, crop_size: int = 256) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Applies identical spatial augmentations to all three frames in a triplet.
    
    Ensures that the cloud motions remain physically consistent across the temporal sequence.
    
    Args:
        img0: Past frame tensor [C, H, W]
        img1: Future frame tensor [C, H, W]
        gt: Ground truth intermediate frame tensor [C, H, W]
        crop_size: Final output dimension for height and width.
        
    Returns:
        Tuple of augmented tensors (img0, img1, gt).
    """
    _, h, w = img0.shape
    if h < crop_size or w < crop_size:
        raise ValueError(
            f"Input smaller than crop size: {h}x{w}"
        )
    
    # Random Spatial Cropping
    top = random.randint(0, h - crop_size)
    left = random.randint(0, w - crop_size)
    
    img0 = TF.crop(img0, top, left, crop_size, crop_size)
    img1 = TF.crop(img1, top, left, crop_size, crop_size)
    gt = TF.crop(gt, top, left, crop_size, crop_size)
    
    # Random Horizontal Flip
    if random.random() > 0.5:
        img0 = TF.hflip(img0)
        img1 = TF.hflip(img1)
        gt = TF.hflip(gt)
        
    # Random Vertical Flip
    if random.random() > 0.5:
        img0 = TF.vflip(img0)
        img1 = TF.vflip(img1)
        gt = TF.vflip(gt)
        
    # Random Dihedral Transformations
    angles = [0, 90, 180, 270]
    angle = random.choice(angles)
    if angle > 0:
        img0 = TF.rotate(img0, angle)
        img1 = TF.rotate(img1, angle)
        gt = TF.rotate(gt, angle)
        
    return img0, img1, gt