szxllm commited on
Commit
b17ba29
·
verified ·
1 Parent(s): 28693e2

Update data_augmentation.py

Browse files
Files changed (1) hide show
  1. data_augmentation.py +328 -365
data_augmentation.py CHANGED
@@ -1,366 +1,329 @@
1
- """
2
- 数据增强模块
3
- 针对不同模态的高级数据增强策略
4
- """
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from typing import Optional, Tuple, List
9
- import random
10
- import math
11
-
12
- class RandAugment(nn.Module):
13
- """RandAugment for images"""
14
- def __init__(self, n: int = 2, m: int = 10):
15
- super().__init__()
16
- self.n = n
17
- self.m = m
18
-
19
- def forward(self, x: torch.Tensor) -> torch.Tensor:
20
- """随机应用n个增强操作"""
21
- # 确保输入是 [B, C, H, W],如果是 [C, H, W] 则增加维度
22
- is_batched = x.ndim == 4
23
- if not is_batched:
24
- x = x.unsqueeze(0)
25
-
26
- augmentations = [
27
- self._auto_contrast,
28
- self._equalize,
29
- self._solarize,
30
- self._color,
31
- self._contrast,
32
- self._brightness,
33
- self._sharpness,
34
- ]
35
-
36
- # 这里的ops应该是每一轮随机选择,而不是固定
37
- for _ in range(self.n):
38
- aug = random.choice(augmentations)
39
- x = aug(x)
40
-
41
- if not is_batched:
42
- x = x.squeeze(0)
43
-
44
- return x
45
-
46
- def _auto_contrast(self, x: torch.Tensor) -> torch.Tensor:
47
- """自动对比度: 线性拉伸到 [0, 1]"""
48
- # 针对每个样本分别计算 min/max
49
- # x: [B, C, H, W]
50
- B, C, H, W = x.shape
51
- x_flat = x.view(B, C, -1)
52
- min_val = x_flat.min(dim=2, keepdim=True)[0].view(B, C, 1, 1)
53
- max_val = x_flat.max(dim=2, keepdim=True)[0].view(B, C, 1, 1)
54
- return (x - min_val) / (max_val - min_val + 1e-8)
55
-
56
- def _equalize(self, x: torch.Tensor) -> torch.Tensor:
57
- """直方图均衡化 (简化版:基于每个通道的CDF)"""
58
- # 这是一个计算密集型操作,PyTorch原生实现较复杂。
59
- # 这里实现一个基于排序的简化版本,模拟均衡化效果
60
- B, C, H, W = x.shape
61
- # 将像素值缩放到 [0, 255] 离散化以便计算直方图
62
- x_int = (x * 255).long().clamp(0, 255)
63
-
64
- out = torch.zeros_like(x)
65
-
66
- for b in range(B):
67
- for c in range(C):
68
- hist = torch.histc(x[b, c].float(), bins=256, min=0, max=1)
69
- cdf = hist.cumsum(0)
70
- cdf = cdf / cdf[-1] # 归一化
71
- # 使用cdf作为查找表
72
- out[b, c] = cdf[x_int[b, c]]
73
-
74
- return out
75
-
76
- def _solarize(self, x: torch.Tensor) -> torch.Tensor:
77
- """曝光"""
78
- threshold = random.uniform(0.3, 0.7)
79
- return torch.where(x < threshold, x, 1.0 - x)
80
-
81
- def _color(self, x: torch.Tensor) -> torch.Tensor:
82
- """颜色增强 (饱和度)"""
83
- factor = 1.0 + (random.random() - 0.5) * 0.4
84
- # RGB转灰度简单近似: mean over channels
85
- # x is [B, C, H, W], dim=1 is channels
86
- mean = x.mean(dim=1, keepdim=True)
87
- return torch.clamp(mean + factor * (x - mean), 0, 1)
88
-
89
- def _contrast(self, x: torch.Tensor) -> torch.Tensor:
90
- """对比度"""
91
- factor = 1.0 + (random.random() - 0.5) * 0.4
92
- # 计算整张图的均值,保留 Batch 维度
93
- # view(B, -1) -> mean(1) -> view(B, 1, 1, 1)
94
- mean = x.view(x.size(0), -1).mean(dim=1).view(-1, 1, 1, 1)
95
- return torch.clamp(mean + factor * (x - mean), 0, 1)
96
-
97
- def _brightness(self, x: torch.Tensor) -> torch.Tensor:
98
- """亮度"""
99
- factor = 1.0 + (random.random() - 0.5) * 0.4
100
- return torch.clamp(x * factor, 0, 1)
101
-
102
- def _sharpness(self, x: torch.Tensor) -> torch.Tensor:
103
- """锐化: 通过混合原图和高斯模糊图实现"""
104
- factor = 1.0 + (random.random() - 0.5) * 0.4
105
- # 使用 AvgPool 模拟模糊
106
- kernel_size = 3
107
- pad = kernel_size // 2
108
- blurred = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
109
- # 锐化公式: Original + alpha * (Original - Blurred)
110
- # 或者简单的混合: Blend(Original, Blurred, factor)
111
- # 这里使用 PIL 风格的锐化:
112
- # result = original * factor + blurred * (1 - factor)
113
- # 但要注意 factor>1 时是锐化,factor<1 是模糊
114
- # 更标准的锐化掩模: x + factor * (x - blurred)
115
- return torch.clamp(x + (factor - 1.0) * (x - blurred), 0, 1)
116
-
117
- class MixUp(nn.Module):
118
- """MixUp数据增强"""
119
- def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
120
- super().__init__()
121
- self.alpha = alpha
122
- self.num_classes = num_classes
123
-
124
- def forward(
125
- self,
126
- x: torch.Tensor,
127
- y: Optional[torch.Tensor] = None
128
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
129
-
130
- if self.alpha > 0:
131
- lambda_ = random.betavariate(self.alpha, self.alpha)
132
- else:
133
- lambda_ = 1.0
134
-
135
- batch_size = x.shape[0]
136
- index = torch.randperm(batch_size, device=x.device)
137
-
138
- mixed_x = lambda_ * x + (1 - lambda_) * x[index]
139
-
140
- mixed_y = None
141
- if y is not None:
142
- # 处理标签混合
143
- y_a = y
144
- y_b = y[index]
145
-
146
- # 检查标签是否需要 One-Hot 编码 (如果 y long 类型且维度不对)
147
- if y.dtype == torch.long or y.ndim == 1:
148
- if self.num_classes is None:
149
- # 如果未提供 num_classes,尝试推断 (可能有风险)
150
- self.num_classes = int(y.max().item()) + 1
151
-
152
- y_a = F.one_hot(y_a, num_classes=self.num_classes).float()
153
- y_b = F.one_hot(y_b, num_classes=self.num_classes).float()
154
-
155
- mixed_y = lambda_ * y_a + (1 - lambda_) * y_b
156
-
157
- return mixed_x, mixed_y, lambda_
158
-
159
- class CutMix(nn.Module):
160
- """CutMix数据增强"""
161
- def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
162
- super().__init__()
163
- self.alpha = alpha
164
- self.num_classes = num_classes
165
-
166
- def _rand_bbox(
167
- self,
168
- size: Tuple[int, ...],
169
- lambda_: float
170
- ) -> Tuple[int, int, int, int]:
171
- """生成随机bbox"""
172
- W = size[-1] # 兼容 [B, C, H, W]
173
- H = size[-2]
174
- cut_rat = math.sqrt(1.0 - lambda_)
175
- cut_w = int(W * cut_rat)
176
- cut_h = int(H * cut_rat)
177
-
178
- cx = random.randint(0, W)
179
- cy = random.randint(0, H)
180
-
181
- bbx1 = torch.tensor(cx - cut_w // 2, device='cpu').clamp(0, W).item()
182
- bby1 = torch.tensor(cy - cut_h // 2, device='cpu').clamp(0, H).item()
183
- bbx2 = torch.tensor(cx + cut_w // 2, device='cpu').clamp(0, W).item()
184
- bby2 = torch.tensor(cy + cut_h // 2, device='cpu').clamp(0, H).item()
185
-
186
- return int(bbx1), int(bby1), int(bbx2), int(bby2)
187
-
188
- def forward(
189
- self,
190
- x: torch.Tensor,
191
- y: Optional[torch.Tensor] = None
192
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
193
-
194
- if self.alpha > 0:
195
- lambda_ = random.betavariate(self.alpha, self.alpha)
196
- else:
197
- lambda_ = 1.0
198
-
199
- batch_size = x.shape[0]
200
- index = torch.randperm(batch_size, device=x.device)
201
-
202
- bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lambda_)
203
-
204
- # 克隆防止就地修改影响后续梯度计算 (虽然这里是输入数据处理,通常还好)
205
- x = x.clone()
206
- x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
207
-
208
- # 调整lambda为精确的像素比例
209
- # 注意: 原始代码中宽高的计算顺序可能有歧义,这里统一 H=size[-2], W=size[-1]
210
- H, W = x.size()[-2], x.size()[-1]
211
- lambda_ = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (H * W))
212
-
213
- mixed_y = None
214
- if y is not None:
215
- y_a = y
216
- y_b = y[index]
217
-
218
- if y.dtype == torch.long or y.ndim == 1:
219
- if self.num_classes is None:
220
- # 最好在初始化时传入 num_classes
221
- self.num_classes = int(y.max().item()) + 1
222
- y_a = F.one_hot(y_a, num_classes=self.num_classes).float()
223
- y_b = F.one_hot(y_b, num_classes=self.num_classes).float()
224
-
225
- mixed_y = lambda_ * y_a + (1 - lambda_) * y_b
226
-
227
- return x, mixed_y, lambda_
228
-
229
- class SpecAugment(nn.Module):
230
- """SpecAugment for audio spectrograms"""
231
- def __init__(
232
- self,
233
- freq_mask_param: int = 27,
234
- time_mask_param: int = 100,
235
- num_freq_masks: int = 2,
236
- num_time_masks: int = 2
237
- ):
238
- super().__init__()
239
- self.freq_mask_param = freq_mask_param
240
- self.time_mask_param = time_mask_param
241
- self.num_freq_masks = num_freq_masks
242
- self.num_time_masks = num_time_masks
243
-
244
- def forward(self, spec: torch.Tensor) -> torch.Tensor:
245
- """
246
- Args:
247
- spec: [B, F, T] or [B, C, F, T]
248
- """
249
- input_ndim = spec.ndim
250
- if input_ndim == 3:
251
- spec = spec.unsqueeze(1) # [B, 1, F, T]
252
-
253
- B, C, F, T = spec.shape
254
- spec = spec.clone()
255
-
256
- # 频率遮罩
257
- for _ in range(self.num_freq_masks):
258
- # 确保 mask 不超过 F
259
- f_param = min(self.freq_mask_param, F)
260
- f = random.randint(0, f_param)
261
- f0 = random.randint(0, max(0, F - f))
262
- spec[:, :, f0:f0+f, :] = 0
263
-
264
- # 时间遮罩
265
- for _ in range(self.num_time_masks):
266
- # 确保 mask 不超过 T
267
- t_param = min(self.time_mask_param, T)
268
- t = random.randint(0, t_param)
269
- t0 = random.randint(0, max(0, T - t))
270
- spec[:, :, :, t0:t0+t] = 0
271
-
272
- if input_ndim == 3:
273
- return spec.squeeze(1)
274
- return spec
275
-
276
- class TemporalMasking(nn.Module):
277
- """视频的时序遮罩"""
278
- def __init__(self, mask_ratio: float = 0.15):
279
- super().__init__()
280
- self.mask_ratio = mask_ratio
281
-
282
- def forward(self, video: torch.Tensor) -> torch.Tensor:
283
- """
284
- Args:
285
- video: [B, T, C, H, W]
286
- """
287
- B, T, C, H, W = video.shape
288
- num_mask = int(T * self.mask_ratio)
289
- if num_mask == 0:
290
- return video
291
-
292
- video = video.clone()
293
-
294
- for b in range(B):
295
- # 随机采样要遮罩的帧索引
296
- mask_indices = torch.randperm(T)[:num_mask]
297
- video[b, mask_indices] = 0
298
-
299
- return video
300
-
301
- class MultiModalAugmentation(nn.Module):
302
- """统一的多模态数据增强"""
303
- def __init__(
304
- self,
305
- image_aug: bool = True,
306
- audio_aug: bool = True,
307
- video_aug: bool = True,
308
- use_mixup: bool = True,
309
- use_cutmix: bool = True,
310
- num_classes: Optional[int] = None
311
- ):
312
- super().__init__()
313
- self.image_aug = RandAugment() if image_aug else None
314
- self.audio_aug = SpecAugment() if audio_aug else None
315
- self.video_aug = TemporalMasking() if video_aug else None
316
-
317
- self.mixup = MixUp(num_classes=num_classes) if use_mixup else None
318
- self.cutmix = CutMix(num_classes=num_classes) if use_cutmix else None
319
-
320
- def forward(
321
- self,
322
- data: torch.Tensor,
323
- modality: str,
324
- labels: Optional[torch.Tensor] = None
325
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
326
- """
327
- Args:
328
- data: 输入数据
329
- modality: 模态类型 ('image', 'audio', 'video')
330
- labels: 标签(可选)
331
- """
332
- # 1. 模态特定的增强 (Intra-sample augmentation)
333
- if modality == 'image' and self.image_aug is not None:
334
- data = self.image_aug(data)
335
- elif modality == 'audio' and self.audio_aug is not None:
336
- data = self.audio_aug(data)
337
- elif modality == 'video' and self.video_aug is not None:
338
- data = self.video_aug(data)
339
-
340
- # 2. 混合增强 (Inter-sample augmentation)
341
- if self.training and labels is not None:
342
- # 随机选择 MixUp 或 CutMix,或者都不选
343
- # 策略:如果有 CutMix 且是图片,50%概率 CutMix;否则看有没有 MixUp
344
-
345
- apply_mixup = False
346
- apply_cutmix = False
347
-
348
- p = random.random()
349
-
350
- # 简单的互斥逻辑:如果有CutMix且是图像,一半概率CutMix,一半概率MixUp(如果有)
351
- if self.cutmix is not None and modality == 'image':
352
- if p < 0.5:
353
- apply_cutmix = True
354
- elif self.mixup is not None:
355
- apply_mixup = True
356
- elif self.mixup is not None:
357
- # 非图像或无CutMix,则只考虑MixUp
358
- if p < 0.5: # 假设 50% 概率应用 MixUp
359
- apply_mixup = True
360
-
361
- if apply_cutmix:
362
- data, labels, _ = self.cutmix(data, labels)
363
- elif apply_mixup:
364
- data, labels, _ = self.mixup(data, labels)
365
-
366
  return data, labels
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, List
5
+ import random
6
+ import math
7
+
8
class RandAugment(nn.Module):
    """RandAugment for images (Cubuk et al., 2020, simplified).

    Applies ``n`` randomly chosen photometric operations per forward call.
    ``m`` is the magnitude (nominally 0-30): it linearly scales the strength
    of the factor-based ops (color/contrast/brightness/sharpness).
    ``m=10`` reproduces the previous fixed +/-0.2 strength, so the default
    behavior is unchanged; previously ``m`` was accepted but ignored.

    Expects float tensors in [0, 1], shaped [B, C, H, W] or [C, H, W].
    """

    def __init__(self, n: int = 2, m: int = 10):
        super().__init__()
        self.n = n  # number of ops applied per call
        self.m = m  # magnitude; 10 == legacy strength

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``n`` augmentations, re-drawn independently each round."""
        # Normalize to [B, C, H, W]; restore the original rank on exit.
        is_batched = x.ndim == 4
        if not is_batched:
            x = x.unsqueeze(0)

        augmentations = [
            self._auto_contrast,
            self._equalize,
            self._solarize,
            self._color,
            self._contrast,
            self._brightness,
            self._sharpness,
        ]

        # Ops are sampled fresh every round rather than fixed once.
        for _ in range(self.n):
            aug = random.choice(augmentations)
            x = aug(x)

        if not is_batched:
            x = x.squeeze(0)

        return x

    def _rand_factor(self) -> float:
        """Random strength factor; spread grows linearly with ``m``."""
        return 1.0 + (random.random() - 0.5) * 0.4 * (self.m / 10.0)

    def _auto_contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Auto-contrast: per-sample, per-channel linear stretch to [0, 1]."""
        B, C, H, W = x.shape
        x_flat = x.view(B, C, -1)
        min_val = x_flat.min(dim=2, keepdim=True)[0].view(B, C, 1, 1)
        max_val = x_flat.max(dim=2, keepdim=True)[0].view(B, C, 1, 1)
        # Epsilon guards constant channels (max == min).
        return (x - min_val) / (max_val - min_val + 1e-8)

    def _equalize(self, x: torch.Tensor) -> torch.Tensor:
        """Histogram equalization via a per-channel CDF lookup (simplified)."""
        B, C, H, W = x.shape
        # Discretize to [0, 255] so the CDF can serve as a lookup table.
        x_int = (x * 255).long().clamp(0, 255)

        out = torch.zeros_like(x)

        # NOTE: O(B*C) Python loop; acceptable for augmentation-sized batches.
        for b in range(B):
            for c in range(C):
                hist = torch.histc(x[b, c].float(), bins=256, min=0, max=1)
                cdf = hist.cumsum(0)
                cdf = cdf / cdf[-1]  # normalize; cdf[-1] == H*W > 0
                # Use the normalized CDF as the remapping table.
                out[b, c] = cdf[x_int[b, c]]

        return out

    def _solarize(self, x: torch.Tensor) -> torch.Tensor:
        """Solarize: invert pixels at or above a random threshold."""
        # Threshold range is fixed (independent of m), as before.
        threshold = random.uniform(0.3, 0.7)
        return torch.where(x < threshold, x, 1.0 - x)

    def _color(self, x: torch.Tensor) -> torch.Tensor:
        """Saturation jitter: blend each pixel with its channel mean."""
        factor = self._rand_factor()
        # Channel mean as a cheap grayscale approximation; x is [B, C, H, W].
        mean = x.mean(dim=1, keepdim=True)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _contrast(self, x: torch.Tensor) -> torch.Tensor:
        """Contrast jitter around each image's global mean."""
        factor = self._rand_factor()
        # Per-image mean, kept broadcastable: [B] -> [B, 1, 1, 1].
        mean = x.view(x.size(0), -1).mean(dim=1).view(-1, 1, 1, 1)
        return torch.clamp(mean + factor * (x - mean), 0, 1)

    def _brightness(self, x: torch.Tensor) -> torch.Tensor:
        """Brightness jitter: multiplicative scaling."""
        factor = self._rand_factor()
        return torch.clamp(x * factor, 0, 1)

    def _sharpness(self, x: torch.Tensor) -> torch.Tensor:
        """Unsharp masking: x + (factor - 1) * (x - blurred)."""
        factor = self._rand_factor()
        kernel_size = 3
        pad = kernel_size // 2
        # AvgPool as a cheap box blur; stride 1 + padding keeps spatial size.
        blurred = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
        return torch.clamp(x + (factor - 1.0) * (x - blurred), 0, 1)
96
class MixUp(nn.Module):
    """MixUp augmentation: convex combination of paired samples and labels.

    ``lambda ~ Beta(alpha, alpha)``; ``alpha <= 0`` disables mixing
    (lambda == 1). If ``num_classes`` is None it is inferred per batch
    from ``y.max() + 1`` — risky when a batch misses the top class, so
    prefer passing it explicitly.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes

    def _encode_targets(
        self, y_a: torch.Tensor, y_b: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """One-hot encode 1-D class-index targets; pass dense targets through.

        Fixes two defects in the old ``dtype == long or ndim == 1`` check:
        1-D float indices crashed ``F.one_hot`` (which requires long), and
        already-dense 2-D long targets were wrongly expanded to 3-D.
        """
        if y_a.ndim == 1:
            num_classes = self.num_classes
            if num_classes is None:
                # Per-batch inference only; not cached on self, so a small
                # first batch cannot pin an undersized class count.
                num_classes = int(y_a.max().item()) + 1
            y_a = F.one_hot(y_a.long(), num_classes=num_classes).float()
            y_b = F.one_hot(y_b.long(), num_classes=num_classes).float()
        return y_a, y_b

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Return ``(mixed_x, mixed_y or None, lambda)``."""
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        batch_size = x.shape[0]
        # Random pairing within the batch, on the input's device.
        index = torch.randperm(batch_size, device=x.device)

        mixed_x = lambda_ * x + (1 - lambda_) * x[index]

        mixed_y = None
        if y is not None:
            y_a, y_b = self._encode_targets(y, y[index])
            mixed_y = lambda_ * y_a + (1 - lambda_) * y_b

        return mixed_x, mixed_y, lambda_
135
class CutMix(nn.Module):
    """CutMix augmentation: paste a random rectangle from shuffled samples.

    ``lambda`` is re-derived from the exact pixel area of the pasted box,
    and labels are mixed with that corrected value.
    """

    def __init__(self, alpha: float = 1.0, num_classes: Optional[int] = None):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes

    def _rand_bbox(
        self,
        size: Tuple[int, ...],
        lambda_: float
    ) -> Tuple[int, int, int, int]:
        """Sample a box covering roughly ``1 - lambda_`` of the image area.

        ``size`` is indexed from the end ([..., H, W]) so any number of
        leading dims works. Returns (x1, y1, x2, y2) clamped to the image.
        """
        W = size[-1]
        H = size[-2]
        cut_rat = math.sqrt(1.0 - lambda_)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # Box center; overhang at the borders is handled by clamping.
        cx = random.randint(0, W)
        cy = random.randint(0, H)

        # Plain min/max clamping — the old tensor round-trips were overkill.
        bbx1 = max(0, min(cx - cut_w // 2, W))
        bby1 = max(0, min(cy - cut_h // 2, H))
        bbx2 = max(0, min(cx + cut_w // 2, W))
        bby2 = max(0, min(cy + cut_h // 2, H))

        return bbx1, bby1, bbx2, bby2

    def _encode_targets(
        self, y_a: torch.Tensor, y_b: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """One-hot encode 1-D class-index targets; pass dense targets through.

        The old ``dtype == long or ndim == 1`` test crashed on 1-D float
        indices (``F.one_hot`` needs long) and expanded 2-D long targets
        to 3-D. ``num_classes`` is inferred per batch when not provided.
        """
        if y_a.ndim == 1:
            num_classes = self.num_classes
            if num_classes is None:
                num_classes = int(y_a.max().item()) + 1  # per-batch inference
            y_a = F.one_hot(y_a.long(), num_classes=num_classes).float()
            y_b = F.one_hot(y_b.long(), num_classes=num_classes).float()
        return y_a, y_b

    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Return ``(mixed_x, mixed_y or None, area-corrected lambda)``."""
        if self.alpha > 0:
            lambda_ = random.betavariate(self.alpha, self.alpha)
        else:
            lambda_ = 1.0

        batch_size = x.shape[0]
        index = torch.randperm(batch_size, device=x.device)

        bbx1, bby1, bbx2, bby2 = self._rand_bbox(x.size(), lambda_)

        # Clone so the caller's tensor is not modified in place.
        x = x.clone()
        x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

        # Correct lambda to the exact pixel ratio of the surviving region.
        H, W = x.size(-2), x.size(-1)
        lambda_ = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (H * W))

        mixed_y = None
        if y is not None:
            y_a, y_b = self._encode_targets(y, y[index])
            mixed_y = lambda_ * y_a + (1 - lambda_) * y_b

        return x, mixed_y, lambda_
199
class SpecAugment(nn.Module):
    """SpecAugment: random frequency and time masking for spectrograms.

    Masking-only variant of Park et al., 2019 (no time warping). Each mask
    has width drawn uniformly in [0, param] and a uniformly random offset.
    """

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 100,
        num_freq_masks: int = 2,
        num_time_masks: int = 2
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param  # max width of one frequency mask
        self.time_mask_param = time_mask_param  # max width of one time mask
        self.num_freq_masks = num_freq_masks
        self.num_time_masks = num_time_masks

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Zero out random frequency/time bands.

        Args:
            spec: [B, F, T] or [B, C, F, T]

        Returns:
            Masked copy with the same shape as the input.
        """
        input_ndim = spec.ndim
        if input_ndim == 3:
            spec = spec.unsqueeze(1)  # -> [B, 1, F, T]

        # Descriptive names for the unpacked dims; the old
        # ``B, C, F, T = spec.shape`` shadowed the module-level alias
        # F (torch.nn.functional) inside this method.
        _, _, n_freq, n_time = spec.shape
        spec = spec.clone()  # never mask the caller's tensor in place

        # Frequency masks: width capped at the actual number of bins.
        for _ in range(self.num_freq_masks):
            f = random.randint(0, min(self.freq_mask_param, n_freq))
            f0 = random.randint(0, max(0, n_freq - f))
            spec[:, :, f0:f0 + f, :] = 0

        # Time masks: same scheme along the last axis.
        for _ in range(self.num_time_masks):
            t = random.randint(0, min(self.time_mask_param, n_time))
            t0 = random.randint(0, max(0, n_time - t))
            spec[:, :, :, t0:t0 + t] = 0

        if input_ndim == 3:
            return spec.squeeze(1)
        return spec
245
class TemporalMasking(nn.Module):
    """Randomly zeroes out a fixed fraction of frames in each video clip."""

    def __init__(self, mask_ratio: float = 0.15):
        super().__init__()
        self.mask_ratio = mask_ratio  # fraction of frames to blank per clip

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """Zero ``int(T * mask_ratio)`` randomly chosen frames per sample.

        Args:
            video: [B, T, C, H, W]

        Returns:
            A masked copy; the input tensor is left untouched. When the
            ratio rounds down to zero frames, the input is returned as-is.
        """
        batch_size, num_frames = video.shape[0], video.shape[1]
        frames_to_mask = int(num_frames * self.mask_ratio)
        if frames_to_mask == 0:
            return video

        masked = video.clone()
        # Independent frame selection for every sample in the batch.
        for sample_idx in range(batch_size):
            chosen = torch.randperm(num_frames)[:frames_to_mask]
            masked[sample_idx, chosen] = 0

        return masked
270
class MultiModalAugmentation(nn.Module):
    """Unified augmentation front-end for image / audio / video batches.

    Runs a modality-specific intra-sample augmentation first, then — in
    training mode, when labels are present — at most one inter-sample
    mixing strategy (CutMix or MixUp), selected by a single coin flip.
    """

    def __init__(
        self,
        image_aug: bool = True,
        audio_aug: bool = True,
        video_aug: bool = True,
        use_mixup: bool = True,
        use_cutmix: bool = True,
        num_classes: Optional[int] = None
    ):
        super().__init__()
        self.image_aug = RandAugment() if image_aug else None
        self.audio_aug = SpecAugment() if audio_aug else None
        self.video_aug = TemporalMasking() if video_aug else None

        self.mixup = MixUp(num_classes=num_classes) if use_mixup else None
        self.cutmix = CutMix(num_classes=num_classes) if use_cutmix else None

    def forward(
        self,
        data: torch.Tensor,
        modality: str,
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            data: input batch for the given modality
            modality: one of 'image', 'audio', 'video'
            labels: optional targets; required for MixUp/CutMix to run

        Returns:
            (augmented data, labels — possibly mixed, or None)
        """
        # Stage 1: per-sample, modality-specific augmentation.
        modality_aug = {
            'image': self.image_aug,
            'audio': self.audio_aug,
            'video': self.video_aug,
        }.get(modality)
        if modality_aug is not None:
            data = modality_aug(data)

        # Stage 2: inter-sample mixing (training mode with labels only).
        if self.training and labels is not None:
            coin = random.random()
            mixer = None

            if self.cutmix is not None and modality == 'image':
                # Images: heads -> CutMix, tails -> MixUp (when configured).
                mixer = self.cutmix if coin < 0.5 else self.mixup
            elif self.mixup is not None and coin < 0.5:
                # Non-image data (or no CutMix): MixUp half the time.
                mixer = self.mixup

            if mixer is not None:
                data, labels, _ = mixer(data, labels)

        return data, labels