File size: 2,602 Bytes
8f8716a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Data augmentations for SSL pretraining (SimCLR, DINO).
"""
import torch
import torch.nn.functional as F
from typing import Tuple


@torch.no_grad()
def random_time_crop(
    x: torch.Tensor, 
    ratio: Tuple[float, float] | float = (0.6, 0.9),
    *, 
    resize_back: bool = True,
    align_to: int | None = 40
) -> torch.Tensor:
    """
    Randomly crop a contiguous sub-sequence per sample, optionally resize back to original T.
    
    Args:
        x: (B, C, T)
        ratio: crop length ratio in [low, high] or a float
        resize_back: if True, linearly interpolate the cropped view back to length T
        align_to: if not None, crop length is rounded to a multiple of align_to (>= align_to)
    """
    assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}"
    B, C, T = x.shape
    dev = x.device

    def _sample_L() -> int:
        if isinstance(ratio, (tuple, list)):
            a, b = float(ratio[0]), float(ratio[1])
            r = torch.empty((), device=dev).uniform_(a, b).item()
        else:
            r = float(ratio)
        L = max(2, int(round(T * r)))
        if align_to and align_to > 1:
            L = max(align_to, int(round(L / align_to)) * align_to)
        return min(L, T)

    Ls = [_sample_L() for _ in range(B)]
    outs = []
    for b in range(B):
        L = Ls[b]
        max_start = max(0, T - L)
        s = int(torch.randint(0, max_start + 1, (1,), device=dev).item())
        v = x[b, :, s:s+L]  # (C, L)
        if resize_back and v.shape[-1] != T:
            v = F.interpolate(v[None], size=T, mode="linear", align_corners=False)[0]
        outs.append(v)
    return torch.stack(outs, dim=0)


@torch.no_grad()
def channel_dropout(
    x: torch.Tensor,
    drop_prob: float = 0.2,
    min_keep: int = 1
) -> torch.Tensor:
    """
    Drop entire channels to zero with probability drop_prob (per sample, per channel).
    Ensures at least `min_keep` channels remain active in each sample.

    Args:
        x: (B, C, T)
        drop_prob: probability to drop each channel (0 never drops; 1 drops every
            channel before the min_keep guarantee is applied)
        min_keep: minimum number of channels to keep per sample; effectively
            clamped to C. Values <= 0 disable the guarantee.

    Returns:
        Tensor with the same shape and dtype as x, where dropped channels are
        exactly zero.
    """
    assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}"
    B, C, _ = x.shape
    dev = x.device

    # Draw the uniforms in the default float dtype rather than x.dtype: torch.rand
    # does not support integer dtypes and is coarsely quantised in half/bfloat16.
    # `>=` makes P(keep) exactly 1 - drop_prob, so drop_prob=0 never drops.
    keep = torch.rand(B, C, 1, device=dev) >= drop_prob  # (B, C, 1) bool

    k = min(min_keep, C)  # cannot keep more channels than exist
    if k > 0:
        short = keep.sum(dim=1, keepdim=True) < k  # (B, 1, 1): sample kept too few
        # k distinct random channels per sample: indices of the top-k iid uniform scores.
        idx = torch.rand(B, C, device=dev).topk(k, dim=1).indices  # (B, k)
        forced = torch.zeros(B, C, device=dev).scatter_(1, idx, 1.0).bool()[..., None]
        # Only samples that fell short get their forced channels switched back on;
        # channels already kept stay kept.
        keep |= forced & short

    # masked_fill rather than multiplication so dropped channels are exactly 0
    # even if they contained inf/NaN.
    return x.masked_fill(~keep, 0)