"""Data augmentation modules for image, audio-spectrogram and video tensors."""

import math
import random
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


class RandAugment(nn.Module):
    """Apply ``n`` randomly chosen photometric augmentations to images.

    The operations are written for float images whose values lie in
    ``[0, 1]`` (most of them clamp their result to that range).

    Args:
        n: Number of augmentation operations applied per call. Operations
            are drawn with replacement.
        m: Magnitude parameter. It is stored on the module but none of the
            operations currently read it.
    """

    def __init__(self, n: int = 2, m: int = 10):
        super().__init__()
        self.n = n
        self.m = m

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``self.n`` random augmentations.

        Args:
            x: Image tensor of shape ``[B, C, H, W]`` or ``[C, H, W]``.

        Returns:
            The augmented tensor with the same number of dimensions as ``x``.
        """
        # All operations work on [B, C, H, W]; add a batch axis if needed.
        is_batched = x.ndim == 4
        if not is_batched:
            x = x.unsqueeze(0)

        augmentations = [
            self._auto_contrast,
            self._equalize,
            self._solarize,
            self._color,
            self._contrast,
            self._brightness,
            self._sharpness,
        ]
        for _ in range(self.n):
            x = random.choice(augmentations)(x)

        return x if is_batched else x.squeeze(0)

    def _auto_contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Stretch every (sample, channel) plane so its values span [0, 1]."""
        # amin/amax reduce over the spatial dims directly, so this also works
        # for non-contiguous inputs (crops, permuted layouts) where a
        # flattening ``view`` would raise.
        min_val = x.amin(dim=(2, 3), keepdim=True)
        max_val = x.amax(dim=(2, 3), keepdim=True)
        # The epsilon keeps constant planes from dividing by zero.
        return (x - min_val) / (max_val - min_val + 1e-8)

    def _equalize(self, x: torch.Tensor) -> torch.Tensor:
        """Histogram-equalize every (sample, channel) plane independently.

        Values are quantized to 256 levels; each pixel is replaced by the
        normalized cumulative count of its level.
        """
        B, C, H, W = x.shape
        # Quantize once and use the SAME indices for both the histogram and
        # the lookup, so a pixel's own bin is always included in its CDF.
        idx = (x * 255).long().clamp(0, 255).reshape(B * C, H * W)
        hist = torch.zeros(B * C, 256, dtype=torch.float32, device=x.device)
        hist.scatter_add_(1, idx, torch.ones_like(idx, dtype=torch.float32))
        cdf = hist.cumsum(dim=1)
        # Guard the normalizer so an empty plane cannot produce NaN.
        cdf = cdf / cdf[:, -1:].clamp_min(1.0)
        # The normalized CDF acts as the per-plane lookup table.
        return cdf.gather(1, idx).reshape(B, C, H, W).to(x.dtype)

    def _solarize(self, x: torch.Tensor) -> torch.Tensor:
        """Invert every value at or above a random threshold in [0.3, 0.7]."""
        threshold = random.uniform(0.3, 0.7)
        return torch.where(x < threshold, x, 1.0 - x)

    def _color(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the deviation from the per-pixel channel mean by 0.8-1.2."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        mean = x.mean(dim=1, keepdim=True)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the deviation from the per-sample global mean by 0.8-1.2."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        # One mean per sample over channels and pixels; reducing over dims
        # (instead of ``view``) keeps this valid for non-contiguous inputs.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _brightness(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply all values by a random factor in [0.8, 1.2]."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        return torch.clamp(x * factor, 0, 1)

    def _sharpness(self, x: torch.Tensor) -> torch.Tensor:
        """Sharpen or soften by blending the image with a blurred copy."""
        factor = 1.0 + (random.random() - 0.5) * 0.4
        # A 3x3 average pool stands in for a blur kernel.
        kernel_size = 3
        pad = kernel_size // 2
        blurred = F.avg_pool2d(
            x, kernel_size=kernel_size, stride=1, padding=pad
        )
        return torch.clamp(x + (factor - 1.0) * (x - blurred), 0, 1)


class _LabelMixingAugment(nn.Module):
    """Shared state and label handling for :class:`MixUp` and :class:`CutMix`.

    Args:
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            the mixing coefficient is drawn from. ``alpha <= 0`` disables
            mixing (the coefficient is fixed at 1.0).
        num_classes: Width of the one-hot encoding used for index labels.
            When ``None`` it is inferred from the labels seen so far;
            passing it explicitly is recommended so the label width is
            stable from the first batch on.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        # Only an inferred class count may grow; an explicit one is strict.
        self._infer_num_classes = num_classes is None

    def _sample_lambda(self) -> float:
        """Draw the mixing coefficient."""
        if self.alpha > 0:
            return random.betavariate(self.alpha, self.alpha)
        return 1.0

    def _resolve_num_classes(self, y: torch.Tensor) -> int:
        """Return the one-hot width to use for the index labels ``y``."""
        if self._infer_num_classes:
            seen = int(y.max().item()) + 1
            # Grow monotonically: a later batch may contain a larger label
            # than the first one, which a frozen value would reject.
            self.num_classes = max(self.num_classes or 0, seen)
        return self.num_classes

    def _mix_labels(
        self, y: torch.Tensor, index: torch.Tensor, lambda_: float
    ) -> torch.Tensor:
        """Blend ``y`` with ``y[index]`` using weight ``lambda_``.

        Index labels (``torch.long`` of any rank, or any 1-D integer tensor)
        are one-hot encoded first. Everything else (soft labels, 1-D float
        targets) is blended as is.
        """
        y_a = y
        y_b = y[index]
        is_index_labels = y.dtype == torch.long or (
            y.ndim == 1 and not y.is_floating_point()
        )
        if is_index_labels:
            num_classes = self._resolve_num_classes(y)
            # F.one_hot only accepts int64, hence the explicit cast.
            y_a = F.one_hot(y_a.long(), num_classes=num_classes).float()
            y_b = F.one_hot(y_b.long(), num_classes=num_classes).float()
        return lambda_ * y_a + (1 - lambda_) * y_b


class MixUp(_LabelMixingAugment):
    """MixUp: blend every sample with a randomly chosen partner sample."""

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Mix a batch.

        Args:
            x: Batch of inputs; dimension 0 is the batch dimension.
            y: Optional labels aligned with ``x`` along dimension 0.

        Returns:
            ``(mixed_x, mixed_y, lambda_)``. ``mixed_y`` is ``None`` when
            ``y`` is ``None``.
        """
        lambda_ = self._sample_lambda()
        index = torch.randperm(x.shape[0], device=x.device)

        mixed_x = lambda_ * x + (1 - lambda_) * x[index]
        mixed_y = None
        if y is not None:
            mixed_y = self._mix_labels(y, index, lambda_)

        return mixed_x, mixed_y, lambda_


class CutMix(_LabelMixingAugment):
    """CutMix: paste a random box from a partner sample into every sample."""

    def _rand_bbox(
        self,
        size: Tuple[int, ...],
        lambda_: float
    ) -> Tuple[int, int, int, int]:
        """Sample a box covering roughly ``1 - lambda_`` of the image area.

        Args:
            size: Tensor shape; the last two entries are read as H and W.
            lambda_: Fraction of the image that should be kept.

        Returns:
            ``(x1, y1, x2, y2)`` clipped to the image bounds.
        """
        W = size[-1]
        H = size[-2]

        cut_rat = math.sqrt(max(0.0, 1.0 - lambda_))
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        cx = random.randint(0, W)
        cy = random.randint(0, H)

        # Plain integer clamping; no tensors are needed for four scalars.
        bbx1 = min(max(cx - cut_w // 2, 0), W)
        bby1 = min(max(cy - cut_h // 2, 0), H)
        bbx2 = min(max(cx + cut_w // 2, 0), W)
        bby2 = min(max(cy + cut_h // 2, 0), H)

        return int(bbx1), int(bby1), int(bbx2), int(bby2)

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Apply CutMix to a batch.

        Args:
            x: Batch of inputs; dimension 0 is the batch dimension and the
                last two dimensions are the spatial ones (H, W).
            y: Optional labels aligned with ``x`` along dimension 0.

        Returns:
            ``(mixed_x, mixed_y, lambda_)`` where ``lambda_`` is recomputed
            from the area of the box that was actually pasted. The input
            tensor is not modified.
        """
        lambda_ = self._sample_lambda()
        index = torch.randperm(x.shape[0], device=x.device)

        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lambda_)

        x = x.clone()
        # Index the trailing dims with an ellipsis so the paste uses the same
        # (H, W) axes the box was sampled for, whatever the tensor rank.
        x[..., bby1:bby2, bbx1:bbx2] = x[index, ..., bby1:bby2, bbx1:bbx2]

        # Clipping can shrink the box, so derive lambda from the real area.
        H, W = x.size()[-2], x.size()[-1]
        lambda_ = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (H * W))

        mixed_y = None
        if y is not None:
            mixed_y = self._mix_labels(y, index, lambda_)

        return x, mixed_y, lambda_


class SpecAugment(nn.Module):
    """Zero out random frequency bands and time spans of a spectrogram.

    Args:
        freq_mask_param: Maximum width of one frequency mask.
        time_mask_param: Maximum width of one time mask.
        num_freq_masks: Number of frequency masks per call.
        num_time_masks: Number of time masks per call.
    """

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 100,
        num_freq_masks: int = 2,
        num_time_masks: int = 2
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Mask a batch of spectrograms.

        The same masks are applied to every sample in the batch.

        Args:
            spec: Tensor of shape ``[B, F, T]`` or ``[B, C, F, T]``.

        Returns:
            A masked copy with the same shape as ``spec``.
        """
        input_ndim = spec.ndim
        if input_ndim == 3:
            spec = spec.unsqueeze(1)  # [B, 1, F, T]

        # Named n_freq / n_time so the module-level ``F`` import
        # (torch.nn.functional) is not shadowed.
        _, _, n_freq, n_time = spec.shape
        spec = spec.clone()

        # Frequency masks.
        for _ in range(self.num_freq_masks):
            # The mask width can never exceed the number of bins.
            width = random.randint(0, min(self.freq_mask_param, n_freq))
            start = random.randint(0, max(0, n_freq - width))
            spec[:, :, start:start + width, :] = 0

        # Time masks.
        for _ in range(self.num_time_masks):
            # The mask width can never exceed the number of frames.
            width = random.randint(0, min(self.time_mask_param, n_time))
            start = random.randint(0, max(0, n_time - width))
            spec[:, :, :, start:start + width] = 0

        if input_ndim == 3:
            return spec.squeeze(1)
        return spec


class TemporalMasking(nn.Module):
    """Zero out a random subset of frames in every video of a batch.

    Args:
        mask_ratio: Fraction of frames to mask; the count is
            ``int(T * mask_ratio)``.
    """

    def __init__(self, mask_ratio: float = 0.15):
        super().__init__()
        self.mask_ratio = mask_ratio

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """Mask frames.

        Args:
            video: Tensor of shape ``[B, T, C, H, W]``.

        Returns:
            A masked copy, or ``video`` itself when no frame is masked.
        """
        B, T, C, H, W = video.shape
        num_mask = int(T * self.mask_ratio)

        if num_mask == 0:
            return video

        video = video.clone()
        for b in range(B):
            # Each sample gets its own random set of masked frame indices.
            mask_indices = torch.randperm(T)[:num_mask]
            video[b, mask_indices] = 0

        return video


class MultiModalAugmentation(nn.Module):
    """Unified augmentation front-end for image, audio and video batches.

    Args:
        image_aug: Enable :class:`RandAugment` for ``'image'`` data.
        audio_aug: Enable :class:`SpecAugment` for ``'audio'`` data.
        video_aug: Enable :class:`TemporalMasking` for ``'video'`` data.
        use_mixup: Enable :class:`MixUp`.
        use_cutmix: Enable :class:`CutMix` (only used for ``'image'`` data).
        num_classes: Forwarded to :class:`MixUp` and :class:`CutMix`.
    """

    def __init__(
        self,
        image_aug: bool = True,
        audio_aug: bool = True,
        video_aug: bool = True,
        use_mixup: bool = True,
        use_cutmix: bool = True,
        num_classes: Optional[int] = None
    ):
        super().__init__()

        self.image_aug = RandAugment() if image_aug else None
        self.audio_aug = SpecAugment() if audio_aug else None
        self.video_aug = TemporalMasking() if video_aug else None

        self.mixup = MixUp(num_classes=num_classes) if use_mixup else None
        self.cutmix = CutMix(num_classes=num_classes) if use_cutmix else None

    def forward(
        self,
        data: torch.Tensor,
        modality: str,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Augment one batch.

        The modality-specific augmentation runs whenever it is enabled,
        independent of ``self.training``. MixUp / CutMix run only in
        training mode and only when ``labels`` is given. An unrecognized
        ``modality`` skips the modality-specific step.

        Args:
            data: Input batch.
            modality: One of ``'image'``, ``'audio'``, ``'video'``.
            labels: Optional labels for the batch.

        Returns:
            ``(data, labels)``. When MixUp or CutMix was applied, ``labels``
            are the mixed labels they produce; otherwise ``labels`` is
            returned unchanged.
        """
        if modality == 'image' and self.image_aug is not None:
            data = self.image_aug(data)
        elif modality == 'audio' and self.audio_aug is not None:
            data = self.audio_aug(data)
        elif modality == 'video' and self.video_aug is not None:
            data = self.video_aug(data)

        if self.training and labels is not None:
            apply_mixup = False
            apply_cutmix = False

            p = random.random()

            if self.cutmix is not None and modality == 'image':
                # Images: CutMix half of the time, MixUp (if enabled)
                # the other half.
                if p < 0.5:
                    apply_cutmix = True
                elif self.mixup is not None:
                    apply_mixup = True
            elif self.mixup is not None:
                # Everything else: MixUp half of the time.
                if p < 0.5:
                    apply_mixup = True

            if apply_cutmix:
                data, labels, _ = self.cutmix(data, labels)
            elif apply_mixup:
                data, labels, _ = self.mixup(data, labels)

        return data, labels