|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional, Tuple, List |
|
|
import random |
|
|
import math |
|
|
|
|
|
class RandAugment(nn.Module):
    """Randomly chain photometric augmentations on image tensors.

    Every call draws ``n`` operations uniformly at random (with replacement)
    from auto-contrast, equalize, solarize, color, contrast, brightness and
    sharpness, and applies them in sequence.  The operations clamp to, or
    are written around, an intensity range of ``[0, 1]``.  All randomness
    comes from Python's ``random`` module.

    Args:
        n: Number of operations applied per call.
        m: Magnitude setting.  It is stored on the module, but none of the
            operations read it; each one samples its own strength.
    """

    def __init__(self, n: int = 2, m: int = 10):
        super().__init__()
        self.n = n
        self.m = m

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``n`` randomly selected augmentations.

        Args:
            x: A 4-D tensor is treated as ``[B, C, H, W]``.  Any other rank
                gets a temporary leading batch dimension that is removed
                again before returning.

        Returns:
            The augmented tensor, with the same shape as ``x``.
        """
        is_batched = x.ndim == 4
        if not is_batched:
            x = x.unsqueeze(0)

        operations = (
            self._auto_contrast,
            self._equalize,
            self._solarize,
            self._color,
            self._contrast,
            self._brightness,
            self._sharpness,
        )
        for _ in range(self.n):
            x = random.choice(operations)(x)

        return x if is_batched else x.squeeze(0)

    @staticmethod
    def _sample_factor() -> float:
        """Draw a strength factor uniformly from ``[0.8, 1.2)``."""
        return 1.0 + (random.random() - 0.5) * 0.4

    def _auto_contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Stretch every channel of every image to span ``[0, 1]``.

        Channels that hold a single constant value have no contrast to
        stretch and are returned unchanged instead of collapsing to zero.
        """
        # amin/amax reduce over the spatial axes directly, so this also
        # works for non-contiguous (e.g. permuted) inputs.
        lo = x.amin(dim=(-2, -1), keepdim=True)
        hi = x.amax(dim=(-2, -1), keepdim=True)
        span = hi - lo
        is_flat = span <= 0
        safe_span = torch.where(is_flat, torch.ones_like(span), span)
        return torch.where(is_flat, x, (x - lo) / safe_span)

    def _equalize(self, x: torch.Tensor) -> torch.Tensor:
        """Histogram-equalize each channel of each image independently.

        Intensities are quantized to 256 levels with ``floor(x * 255)``
        (clamped to ``[0, 255]``), and every pixel is replaced by the
        normalized cumulative count of its level.
        """
        B, C, H, W = x.shape
        planes = B * C
        levels = (x * 255).long().clamp(0, 255).reshape(planes, -1)

        # Histogram and lookup share the same quantization, so a pixel's own
        # level is always included in its CDF value.  Offsetting each plane
        # by a multiple of 256 lets one bincount build all histograms.
        offsets = torch.arange(planes, device=x.device).unsqueeze(1) * 256
        hist = torch.bincount(
            (levels + offsets).reshape(-1), minlength=planes * 256
        ).view(planes, 256)

        cdf = hist.cumsum(dim=1).float()
        cdf = cdf / cdf[:, -1:]
        return cdf.gather(1, levels).reshape(B, C, H, W).to(x.dtype)

    def _solarize(self, x: torch.Tensor) -> torch.Tensor:
        """Invert every value at or above a threshold drawn from ``[0.3, 0.7]``."""
        threshold = random.uniform(0.3, 0.7)
        return torch.where(x < threshold, x, 1.0 - x)

    def _color(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each pixel's deviation from its across-channel mean."""
        factor = self._sample_factor()
        gray = x.mean(dim=1, keepdim=True)
        return torch.clamp(gray + factor * (x - gray), 0, 1)

    def _contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each pixel's deviation from the mean of its whole image."""
        factor = self._sample_factor()
        # Reducing over the named dims avoids .view(), which raises on
        # non-contiguous inputs.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _brightness(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply all intensities by a random factor and clamp to ``[0, 1]``."""
        factor = self._sample_factor()
        return torch.clamp(x * factor, 0, 1)

    def _sharpness(self, x: torch.Tensor) -> torch.Tensor:
        """Sharpen or soften by adding a scaled difference from a 3x3 box blur."""
        factor = self._sample_factor()

        kernel_size = 3
        # Exclude the zero padding from the average; counting it would darken
        # the blurred border and change the edges of perfectly flat images.
        blurred = F.avg_pool2d(
            x,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            count_include_pad=False,
        )
        return torch.clamp(x + (factor - 1.0) * (x - blurred), 0, 1)
|
|
|
|
|
class MixUp(nn.Module):
    """Blend every sample, and optionally its label, with a shuffled partner.

    Args:
        alpha: Both shape parameters of the Beta distribution the mixing
            weight is drawn from.  With ``alpha <= 0`` the weight is fixed
            at ``1.0``, i.e. samples are returned unmixed.
        num_classes: Width of the one-hot encoding used for class-index
            labels.  When ``None`` it is inferred from the labels seen so
            far and stored back on ``self.num_classes``; pass it explicitly
            whenever a batch may not contain the highest class.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        # Remember whether the class count was supplied by the caller, so an
        # inferred value can keep growing while a supplied one stays fixed.
        self._infer_num_classes = num_classes is None

    def _soft_targets(self, y: torch.Tensor) -> torch.Tensor:
        """Return ``y`` in a form that can be blended linearly.

        ``int64`` tensors and 1-D tensors of any non-floating dtype are
        treated as class indices and one-hot encoded as floats.  Everything
        else (e.g. soft labels, 1-D float targets) is returned unchanged.
        """
        is_index = y.dtype == torch.long or (
            y.ndim == 1 and not torch.is_floating_point(y)
        )
        if not is_index:
            return y

        y = y.long()  # F.one_hot only accepts int64 indices
        if self._infer_num_classes:
            # Track the largest label seen so far; caching only the first
            # batch's maximum makes F.one_hot raise on a later, larger label.
            seen = int(y.max().item()) + 1
            if self.num_classes is None or seen > self.num_classes:
                self.num_classes = seen
        return F.one_hot(y, num_classes=self.num_classes).float()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Mix a batch with a random permutation of itself.

        Args:
            x: Batch of inputs; the first dimension is the batch dimension.
            y: Optional labels, indexed along their first dimension in the
                same order as ``x``.

        Returns:
            ``(mixed_x, mixed_y, lambda_)`` where ``lambda_`` is the weight
            given to the un-permuted samples and ``mixed_y`` is ``None``
            when no labels were passed.
        """
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        index = torch.randperm(x.shape[0], device=x.device)
        mixed_x = lambda_ * x + (1 - lambda_) * x[index]

        mixed_y = None
        if y is not None:
            targets = self._soft_targets(y)
            mixed_y = lambda_ * targets + (1 - lambda_) * targets[index]

        return mixed_x, mixed_y, lambda_
|
|
|
|
|
class CutMix(nn.Module):
    """Paste a random rectangle from a shuffled partner into every sample.

    Args:
        alpha: Both shape parameters of the Beta distribution that sets the
            target share of each sample to keep.  With ``alpha <= 0`` the
            share is fixed at ``1.0`` and nothing is pasted.
        num_classes: Width of the one-hot encoding used for class-index
            labels.  When ``None`` it is inferred from the labels seen so
            far and stored back on ``self.num_classes``; pass it explicitly
            whenever a batch may not contain the highest class.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        # Remember whether the class count was supplied by the caller, so an
        # inferred value can keep growing while a supplied one stays fixed.
        self._infer_num_classes = num_classes is None

    def _rand_bbox(
        self,
        size: Tuple[int, ...],
        lambda_: float
    ) -> Tuple[int, int, int, int]:
        """Sample a box covering roughly ``1 - lambda_`` of the image area.

        Args:
            size: Tensor shape whose last two entries are height and width.
            lambda_: Target share of the image that stays untouched.

        Returns:
            ``(x1, y1, x2, y2)`` clipped to the image; the box is the
            half-open region ``[y1:y2, x1:x2]`` and may be empty.
        """
        W = size[-1]
        H = size[-2]
        cut_rat = math.sqrt(max(0.0, 1.0 - lambda_))
        half_w = int(W * cut_rat) // 2
        half_h = int(H * cut_rat) // 2

        # The centre must be a valid pixel index.  random.randint(0, W) is
        # inclusive and could put the centre one past the last column/row.
        cx = random.randrange(W) if W > 0 else 0
        cy = random.randrange(H) if H > 0 else 0

        bbx1 = min(max(cx - half_w, 0), W)
        bby1 = min(max(cy - half_h, 0), H)
        bbx2 = min(max(cx + half_w, 0), W)
        bby2 = min(max(cy + half_h, 0), H)
        return bbx1, bby1, bbx2, bby2

    def _soft_targets(self, y: torch.Tensor) -> torch.Tensor:
        """Return ``y`` in a form that can be blended linearly.

        ``int64`` tensors and 1-D tensors of any non-floating dtype are
        treated as class indices and one-hot encoded as floats.  Everything
        else (e.g. soft labels, 1-D float targets) is returned unchanged.
        """
        is_index = y.dtype == torch.long or (
            y.ndim == 1 and not torch.is_floating_point(y)
        )
        if not is_index:
            return y

        y = y.long()  # F.one_hot only accepts int64 indices
        if self._infer_num_classes:
            # Track the largest label seen so far; caching only the first
            # batch's maximum makes F.one_hot raise on a later, larger label.
            seen = int(y.max().item()) + 1
            if self.num_classes is None or seen > self.num_classes:
                self.num_classes = seen
        return F.one_hot(y, num_classes=self.num_classes).float()

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Apply CutMix to a batch.

        Args:
            x: Batch whose first dimension is the batch dimension and whose
                last two dimensions are height and width.
            y: Optional labels, indexed along their first dimension in the
                same order as ``x``.

        Returns:
            ``(mixed_x, mixed_y, lambda_)``.  ``mixed_x`` is a new tensor
            (the input is not modified), ``mixed_y`` is ``None`` when no
            labels were passed, and ``lambda_`` is the share of the image
            area that was left untouched after clipping the box.
        """
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        index = torch.randperm(x.shape[0], device=x.device)
        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lambda_)

        mixed_x = x.clone()
        # Ellipsis indexing keeps this valid for any number of dimensions
        # between the batch axis and the two spatial axes.
        mixed_x[..., bby1:bby2, bbx1:bbx2] = x[index, ..., bby1:bby2, bbx1:bbx2]

        # Re-derive the weight from the clipped box so the label mix matches
        # the pixels that were actually replaced.
        area = x.size(-2) * x.size(-1)
        if area > 0:
            lambda_ = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / area)
        else:
            lambda_ = 1.0

        mixed_y = None
        if y is not None:
            targets = self._soft_targets(y)
            mixed_y = lambda_ * targets + (1 - lambda_) * targets[index]

        return mixed_x, mixed_y, lambda_
|
|
|
|
|
class SpecAugment(nn.Module):
    """Zero out random frequency bands and time spans of a spectrogram.

    Args:
        freq_mask_param: Largest width of a single frequency mask; capped at
            the number of frequency bins of the input.
        time_mask_param: Largest width of a single time mask; capped at the
            number of time steps of the input.
        num_freq_masks: How many frequency masks are drawn per call.
        num_time_masks: How many time masks are drawn per call.
    """

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 100,
        num_freq_masks: int = 2,
        num_time_masks: int = 2
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Mask a batch of spectrograms.

        The same masks are applied to every sample and channel in the batch.
        Mask widths and offsets are drawn with Python's ``random`` module.

        Args:
            spec: [B, F, T] or [B, C, F, T]

        Returns:
            A masked copy with the same shape as ``spec``; the input is not
            modified.

        Raises:
            ValueError: If ``spec`` is neither 3-D nor 4-D.
        """
        if spec.ndim not in (3, 4):
            raise ValueError(
                f"expected a [B, F, T] or [B, C, F, T] tensor, got {spec.ndim} dimensions"
            )

        added_channel = spec.ndim == 3
        if added_channel:
            spec = spec.unsqueeze(1)

        # Named n_freq / n_time so the module-level `F` alias for
        # torch.nn.functional is not shadowed inside this method.
        n_freq, n_time = spec.shape[-2], spec.shape[-1]
        spec = spec.clone()

        max_freq_width = min(self.freq_mask_param, n_freq)
        for _ in range(self.num_freq_masks):
            width = random.randint(0, max_freq_width)
            start = random.randint(0, max(0, n_freq - width))
            spec[:, :, start:start + width, :] = 0

        max_time_width = min(self.time_mask_param, n_time)
        for _ in range(self.num_time_masks):
            width = random.randint(0, max_time_width)
            start = random.randint(0, max(0, n_time - width))
            spec[:, :, :, start:start + width] = 0

        return spec.squeeze(1) if added_channel else spec
|
|
|
|
|
class TemporalMasking(nn.Module):
    """Temporal masking for video: zero out randomly chosen frames.

    Args:
        mask_ratio: Fraction of the ``T`` frames to blank in each sample;
            the frame count is ``int(T * mask_ratio)``.
    """

    def __init__(self, mask_ratio: float = 0.15):
        super().__init__()
        self.mask_ratio = mask_ratio

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """Blank a random subset of frames independently for every sample.

        Args:
            video: [B, T, C, H, W]

        Returns:
            A tensor of the same shape in which ``int(T * mask_ratio)``
            frames per sample are set to zero.  When that count is not
            positive the input tensor itself is returned.
        """
        B, T = video.shape[0], video.shape[1]
        num_mask = int(T * self.mask_ratio)
        if num_mask <= 0:
            return video

        # Rank one random score per frame and mask the num_mask lowest ranks.
        # This picks a uniformly random frame subset per sample without a
        # Python loop, and creates the indices on the video's own device.
        scores = torch.rand(B, T, device=video.device)
        ranks = scores.argsort(dim=1).argsort(dim=1)
        drop = ranks < num_mask

        # Trailing singleton axes let the [B, T] mask broadcast over each frame.
        drop = drop.view(B, T, *([1] * (video.ndim - 2)))
        return video.masked_fill(drop, 0)
|
|
|
|
|
class MultiModalAugmentation(nn.Module):
    """Unified data augmentation for image, audio and video inputs.

    Wraps one modality-specific augmenter per modality plus optional
    MixUp / CutMix label mixing.  All augmentation is training-only: in
    eval mode ``forward`` returns its inputs untouched.

    Args:
        image_aug: Enable ``RandAugment`` for ``'image'`` inputs.
        audio_aug: Enable ``SpecAugment`` for ``'audio'`` inputs.
        video_aug: Enable ``TemporalMasking`` for ``'video'`` inputs.
        use_mixup: Enable ``MixUp``.
        use_cutmix: Enable ``CutMix`` (only ever applied to images).
        num_classes: Forwarded to ``MixUp`` and ``CutMix``.
    """

    def __init__(
        self,
        image_aug: bool = True,
        audio_aug: bool = True,
        video_aug: bool = True,
        use_mixup: bool = True,
        use_cutmix: bool = True,
        num_classes: Optional[int] = None
    ):
        super().__init__()
        self.image_aug = RandAugment() if image_aug else None
        self.audio_aug = SpecAugment() if audio_aug else None
        self.video_aug = TemporalMasking() if video_aug else None

        self.mixup = MixUp(num_classes=num_classes) if use_mixup else None
        self.cutmix = CutMix(num_classes=num_classes) if use_cutmix else None

    def _pick_mixer(self, modality: str) -> Optional[nn.Module]:
        """Choose the label-mixing module for this call, or ``None``.

        Images with CutMix enabled get CutMix half of the time and MixUp
        (when enabled) otherwise.  In every other case MixUp, when enabled,
        is applied half of the time.
        """
        p = random.random()
        if self.cutmix is not None and modality == 'image':
            return self.cutmix if p < 0.5 else self.mixup
        if self.mixup is not None and p < 0.5:
            return self.mixup
        return None

    def forward(
        self,
        data: torch.Tensor,
        modality: str,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Augment one batch.

        Args:
            data: Input data.
            modality: Modality type ('image', 'audio', 'video').  Any other
                value skips the modality-specific augmentation.
            labels: Labels (optional).  Label mixing is only attempted when
                labels are given.

        Returns:
            ``(data, labels)``.  ``labels`` is whatever the chosen mixer
            returned when mixing was applied, and the original ``labels``
            object otherwise.
        """
        # Random augmentation must not perturb evaluation inputs; this puts
        # the modality-specific step under the same training-only gate that
        # already guarded label mixing.
        if not self.training:
            return data, labels

        augmenter = {
            'image': self.image_aug,
            'audio': self.audio_aug,
            'video': self.video_aug,
        }.get(modality)
        if augmenter is not None:
            data = augmenter(data)

        if labels is not None:
            mixer = self._pick_mixer(modality)
            if mixer is not None:
                data, labels, _ = mixer(data, labels)

        return data, labels