szxllm commited on
Commit
6328772
·
verified ·
1 Parent(s): 6419b37

Update encoders.py

Browse files
Files changed (1) hide show
  1. encoders.py +515 -558
encoders.py CHANGED
@@ -1,559 +1,516 @@
1
- """
2
- 改进的多模态编码器 - SOTA级别(修复版)
3
- 集成最新的视觉、音频、视频编码技术
4
- """
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from typing import Tuple, Optional
9
- from components import RMSNorm, SwiGLU
10
- from transformer import OptimizedTransformerBlock
11
- import math
12
-
13
- class LayerScale(nn.Module):
14
- """LayerScale - 改进训练稳定性"""
15
- def __init__(self, dim: int, init_values: float = 1e-5):
16
- super().__init__()
17
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
18
-
19
- def forward(self, x: torch.Tensor) -> torch.Tensor:
20
- return x * self.gamma
21
-
22
- class StochasticDepth(nn.Module):
23
- """随机深度 - Drop Path"""
24
- def __init__(self, drop_prob: float = 0.0):
25
- super().__init__()
26
- self.drop_prob = drop_prob
27
-
28
- def forward(self, x: torch.Tensor) -> torch.Tensor:
29
- if not self.training or self.drop_prob == 0.0:
30
- return x
31
-
32
- keep_prob = 1 - self.drop_prob
33
- shape = (x.shape[0],) + (1,) * (x.ndim - 1)
34
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
35
- random_tensor.floor_()
36
- return x.div(keep_prob) * random_tensor
37
-
38
- class ImprovedPatchEmbedding(nn.Module):
39
- """改进的图像分块嵌入 - 支持重叠patch和多尺度"""
40
- def __init__(
41
- self,
42
- patch_size: int = 14,
43
- in_channels: int = 3,
44
- embed_dim: int = 2048,
45
- overlap: int = 0
46
- ):
47
- super().__init__()
48
- self.patch_size = patch_size
49
- stride = patch_size - overlap
50
- self.proj = nn.Conv2d(
51
- in_channels,
52
- embed_dim,
53
- kernel_size=patch_size,
54
- stride=stride,
55
- padding=overlap // 2
56
- )
57
-
58
- self.norm = RMSNorm(embed_dim)
59
-
60
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
61
- B, C, H, W = x.shape
62
- x = self.proj(x)
63
- grid_size = (x.shape[2], x.shape[3])
64
- x = x.flatten(2).transpose(1, 2)
65
- x = self.norm(x)
66
- return x, grid_size
67
-
68
- class ImprovedVisionBlock(nn.Module):
69
- """改进的Vision Transformer Block"""
70
- def __init__(
71
- self,
72
- dim: int,
73
- n_heads: int,
74
- dropout: float = 0.0,
75
- drop_path: float = 0.0,
76
- use_adapter: bool = False,
77
- adapter_dim: int = 64,
78
- use_layer_scale: bool = True,
79
- layer_scale_init: float = 1e-5
80
- ):
81
- super().__init__()
82
- self.norm1 = RMSNorm(dim)
83
- self.attn = nn.MultiheadAttention(
84
- dim, n_heads, dropout=dropout, batch_first=True
85
- )
86
-
87
- self.norm2 = RMSNorm(dim)
88
- self.mlp = nn.Sequential(
89
- nn.Linear(dim, dim * 4),
90
- nn.GELU(),
91
- nn.Dropout(dropout),
92
- nn.Linear(dim * 4, dim),
93
- nn.Dropout(dropout)
94
- )
95
-
96
- self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()
97
-
98
- if use_layer_scale:
99
- self.ls1 = LayerScale(dim, layer_scale_init)
100
- self.ls2 = LayerScale(dim, layer_scale_init)
101
- else:
102
- self.ls1 = nn.Identity()
103
- self.ls2 = nn.Identity()
104
-
105
- # 修复:使用简单的adapter实现,避免外部依赖
106
- if use_adapter:
107
- self.adapter = nn.Sequential(
108
- nn.Linear(dim, adapter_dim),
109
- nn.GELU(),
110
- nn.Linear(adapter_dim, dim)
111
- )
112
- else:
113
- self.adapter = None
114
-
115
- def forward(self, x: torch.Tensor) -> torch.Tensor:
116
- # 注意力
117
- normx = self.norm1(x)
118
- attn_out, _ = self.attn(normx, normx, normx)
119
- x = x + self.drop_path(self.ls1(attn_out))
120
-
121
- # MLP
122
- x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
123
-
124
- # Adapter
125
- if self.adapter is not None:
126
- x = x + self.adapter(x)
127
-
128
- return x
129
-
130
- class ImprovedVisionTransformer(nn.Module):
131
- """
132
- 改进的视觉Transformer
133
- - LayerScale
134
- - Stochastic Depth
135
- - 改进的位置编码
136
- - 可选的Register tokens
137
- """
138
- def __init__(
139
- self,
140
- img_size: int = 224,
141
- patch_size: int = 14,
142
- in_channels: int = 3,
143
- embed_dim: int = 2048,
144
- depth: int = 24,
145
- n_heads: int = 16,
146
- dropout: float = 0.0,
147
- drop_path_rate: float = 0.1,
148
- use_register_tokens: bool = True,
149
- num_register_tokens: int = 4,
150
- use_adapter: bool = False,
151
- adapter_dim: int = 64,
152
- use_layer_scale: bool = True,
153
- layer_scale_init: float = 1e-5
154
- ):
155
- super().__init__()
156
- self.patch_size = patch_size
157
- self.embed_dim = embed_dim
158
- self.use_register_tokens = use_register_tokens
159
- self.num_register_tokens = num_register_tokens if use_register_tokens else 0
160
-
161
- # Patch embedding
162
- self.patch_embed = ImprovedPatchEmbedding(
163
- patch_size, in_channels, embed_dim, overlap=0
164
- )
165
-
166
- self.pretrain_img_size = img_size
167
- n_patches_pretrain = (img_size // patch_size) ** 2
168
-
169
- # CLS token
170
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
171
-
172
- # Register tokens (DINOv2启发)
173
- if use_register_tokens:
174
- self.register_tokens = nn.Parameter(
175
- torch.zeros(1, num_register_tokens, embed_dim)
176
- )
177
-
178
- # 修复:位置编码总数 = 1(CLS) + n_patches + register_tokens
179
- total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
180
- self.pos_embed = nn.Parameter(
181
- torch.zeros(1, total_tokens, embed_dim)
182
- )
183
- self.pos_drop = nn.Dropout(dropout)
184
-
185
- # Stochastic depth
186
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
187
-
188
- # Transformer blocks
189
- self.blocks = nn.ModuleList([
190
- ImprovedVisionBlock(
191
- embed_dim,
192
- n_heads,
193
- dropout,
194
- drop_path=dpr[i],
195
- use_adapter=use_adapter,
196
- adapter_dim=adapter_dim,
197
- use_layer_scale=use_layer_scale,
198
- layer_scale_init=layer_scale_init
199
- )
200
- for i in range(depth)
201
- ])
202
-
203
- self.norm = RMSNorm(embed_dim)
204
- self._init_weights()
205
-
206
- def _init_weights(self):
207
- nn.init.trunc_normal_(self.cls_token, std=0.02)
208
- nn.init.trunc_normal_(self.pos_embed, std=0.02)
209
- if self.use_register_tokens:
210
- nn.init.trunc_normal_(self.register_tokens, std=0.02)
211
-
212
- self.apply(self._init_module_weights)
213
-
214
- def _init_module_weights(self, m):
215
- if isinstance(m, nn.Linear):
216
- nn.init.trunc_normal_(m.weight, std=0.02)
217
- if m.bias is not None:
218
- nn.init.zeros_(m.bias)
219
- elif isinstance(m, nn.Conv2d):
220
- nn.init.trunc_normal_(m.weight, std=0.02)
221
- if m.bias is not None:
222
- nn.init.zeros_(m.bias)
223
- elif isinstance(m, RMSNorm):
224
- if hasattr(m, 'weight') and m.weight is not None:
225
- nn.init.ones_(m.weight)
226
-
227
- def _interpolate_pos_encoding(
228
- self,
229
- patch_tokens: torch.Tensor,
230
- grid_size: Tuple[int, int]
231
- ) -> torch.Tensor:
232
- """
233
- 修复:改进的位置编码插值
234
- 只对patch位置编码进行插值,CLS和register token位置编码保持不变
235
- """
236
- pretrain_grid_h = self.pretrain_img_size // self.patch_size
237
- pretrain_grid_w = pretrain_grid_h
238
-
239
- # 如果尺寸匹配,直接返回原始位置编码
240
- if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
241
- return self.pos_embed
242
-
243
- # 分离不同部分的位置编码
244
- # pos_embed结构: [CLS(1), register_tokens(n), patches(H*W)]
245
- num_extra_tokens = 1 + self.num_register_tokens
246
- cls_register_pos = self.pos_embed[:, :num_extra_tokens, :] # [1, 1+n, dim]
247
- patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :] # [1, H*W, dim]
248
-
249
- # 2D插值patch位置编码
250
- patch_pos_embed = patch_pos_embed.reshape(
251
- 1, pretrain_grid_h, pretrain_grid_w, -1
252
- ).permute(0, 3, 1, 2)
253
-
254
- patch_pos_embed = F.interpolate(
255
- patch_pos_embed,
256
- size=grid_size,
257
- mode='bicubic',
258
- align_corners=False
259
- )
260
-
261
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
262
-
263
- # 拼接回去
264
- return torch.cat([cls_register_pos, patch_pos_embed], dim=1)
265
-
266
- def forward(self, x: torch.Tensor) -> torch.Tensor:
267
- B = x.shape[0]
268
-
269
- # Patch embedding
270
- x, grid_size = self.patch_embed(x)
271
-
272
- # 添加CLS token
273
- cls_tokens = self.cls_token.expand(B, -1, -1)
274
-
275
- # 修复:正确组装tokens序列
276
- if self.use_register_tokens:
277
- register_tokens = self.register_tokens.expand(B, -1, -1)
278
- # 顺序: [CLS, register_tokens, patches]
279
- x = torch.cat([cls_tokens, register_tokens, x], dim=1)
280
- else:
281
- x = torch.cat([cls_tokens, x], dim=1)
282
-
283
- # 位置编码(插值以适应不同尺寸)
284
- pos_embed = self._interpolate_pos_encoding(x, grid_size)
285
- x = self.pos_drop(x + pos_embed)
286
-
287
- # Transformer blocks
288
- for block in self.blocks:
289
- x = block(x)
290
-
291
- x = self.norm(x)
292
-
293
- # 返回所有tokens(调用者可以选择使用CLS token或全局池化)
294
- return x
295
-
296
- class ImprovedAudioEncoder(nn.Module):
297
- """
298
- 改进的音频编码器
299
- - 时序建模
300
- - 频率建模
301
- - 双流架构
302
- """
303
- def __init__(
304
- self,
305
- n_mels: int = 128,
306
- target_length: int = 1024,
307
- embed_dim: int = 2048,
308
- depth: int = 12,
309
- n_heads: int = 16,
310
- patch_size: int = 16,
311
- dropout: float = 0.1,
312
- use_adapter: bool = False,
313
- adapter_dim: int = 64,
314
- use_dual_stream: bool = True
315
- ):
316
- super().__init__()
317
- self.use_dual_stream = use_dual_stream
318
- self.patch_size = patch_size
319
-
320
- # 主编码器
321
- self.patch_embed = nn.Conv2d(
322
- 1, embed_dim, kernel_size=patch_size, stride=patch_size
323
- )
324
-
325
- # 修复:计算实际的patch数量
326
- self.n_patches_h = n_mels // patch_size
327
- self.n_patches_w = target_length // patch_size
328
- n_patches = self.n_patches_h * self.n_patches_w
329
-
330
- self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
331
- self.pos_drop = nn.Dropout(dropout)
332
-
333
- # Transformer blocks
334
- self.blocks = nn.ModuleList([
335
- OptimizedTransformerBlock(
336
- embed_dim, n_heads, None, None, dropout,
337
- use_adapter=use_adapter, adapter_dim=adapter_dim
338
- )
339
- for _ in range(depth)
340
- ])
341
-
342
- # 双流:时间流和频率流
343
- if use_dual_stream:
344
- # 修复:使用正确的池化维度
345
- self.temporal_pool = nn.AdaptiveAvgPool1d(1)
346
- self.frequency_pool = nn.AdaptiveAvgPool1d(1)
347
-
348
- self.temporal_proj = nn.Linear(embed_dim, embed_dim)
349
- self.frequency_proj = nn.Linear(embed_dim, embed_dim)
350
-
351
- self.fusion = nn.Linear(embed_dim * 2, embed_dim)
352
-
353
- self.norm = RMSNorm(embed_dim)
354
- self._init_weights()
355
-
356
- def _init_weights(self):
357
- nn.init.trunc_normal_(self.pos_embed, std=0.02)
358
- self.apply(self._init_module_weights)
359
-
360
- def _init_module_weights(self, m):
361
- if isinstance(m, nn.Linear):
362
- nn.init.trunc_normal_(m.weight, std=0.02)
363
- if m.bias is not None:
364
- nn.init.zeros_(m.bias)
365
- elif isinstance(m, nn.Conv2d):
366
- nn.init.trunc_normal_(m.weight, std=0.02)
367
- if m.bias is not None:
368
- nn.init.zeros_(m.bias)
369
-
370
- def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
371
- if mel_spec.ndim == 3:
372
- mel_spec = mel_spec.unsqueeze(1)
373
-
374
- # Patch embedding
375
- x = self.patch_embed(mel_spec) # [B, C, H, W]
376
- x = x.flatten(2).transpose(1, 2) # [B, H*W, C]
377
- x = self.pos_drop(x + self.pos_embed)
378
-
379
- # Transformer encoding
380
- for block in self.blocks:
381
- x, _, _ = block(x)
382
-
383
- x = self.norm(x)
384
-
385
- # 修复:双流处理
386
- if self.use_dual_stream:
387
- B, N, C = x.shape
388
-
389
- # 重塑为2D网格
390
- x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)
391
-
392
- # 时间流:沿频率维度池化(保留时间)
393
- temporal = x_2d.mean(dim=2) # [B, C, W]
394
- temporal = self.temporal_pool(temporal).squeeze(-1) # [B, C]
395
- temporal = self.temporal_proj(temporal).unsqueeze(1) # [B, 1, C]
396
-
397
- # 频率流:沿时间维度池化(保留频率)
398
- frequency = x_2d.mean(dim=3) # [B, C, H]
399
- frequency = self.frequency_pool(frequency).squeeze(-1) # [B, C]
400
- frequency = self.frequency_proj(frequency).unsqueeze(1) # [B, 1, C]
401
-
402
- # 融合
403
- x = self.fusion(torch.cat([temporal, frequency], dim=-1))
404
- else:
405
- # 简单全局平均池化
406
- x = x.mean(dim=1, keepdim=True)
407
-
408
- return x
409
-
410
- class ImprovedVideoEncoder(nn.Module):
411
- """
412
- 改进的视频编码器
413
- - 因果时序建模
414
- - 时空分离注意力
415
- - 可选的3D卷积
416
- """
417
- def __init__(
418
- self,
419
- img_size: int = 224,
420
- patch_size: int = 14,
421
- in_channels: int = 3,
422
- embed_dim: int = 2048,
423
- spatial_depth: int = 12,
424
- temporal_depth: int = 4,
425
- n_heads: int = 16,
426
- num_frames: int = 16,
427
- dropout: float = 0.1,
428
- use_adapter: bool = False,
429
- adapter_dim: int = 64,
430
- use_3d_conv: bool = False
431
- ):
432
- super().__init__()
433
- self.num_frames = num_frames
434
- self.use_3d_conv = use_3d_conv
435
- self.patch_size = patch_size
436
- self.img_size = img_size
437
-
438
- if use_3d_conv:
439
- # 3D卷积处理时空信息
440
- self.patch_embed = nn.Conv3d(
441
- in_channels,
442
- embed_dim,
443
- kernel_size=(2, patch_size, patch_size),
444
- stride=(2, patch_size, patch_size)
445
- )
446
- # 修复:计算3D卷积后的尺寸
447
- self.n_temporal_patches = num_frames // 2
448
- self.n_spatial_patches = (img_size // patch_size) ** 2
449
- else:
450
- # 2D卷积 + 时序建模
451
- self.patch_embed = ImprovedPatchEmbedding(
452
- patch_size, in_channels, embed_dim
453
- )
454
- self.n_spatial_patches = (img_size // patch_size) ** 2
455
-
456
- # 空间位置编码
457
- self.spatial_pos_embed = nn.Parameter(
458
- torch.zeros(1, self.n_spatial_patches, embed_dim)
459
- )
460
- self.spatial_pos_drop = nn.Dropout(dropout)
461
-
462
- # 空间编码器
463
- self.spatial_blocks = nn.ModuleList([
464
- OptimizedTransformerBlock(
465
- embed_dim, n_heads, None, None, dropout,
466
- use_adapter=use_adapter, adapter_dim=adapter_dim
467
- )
468
- for _ in range(spatial_depth)
469
- ])
470
-
471
- # 时间位置编码
472
- if use_3d_conv:
473
- self.temporal_pos_embed = nn.Parameter(
474
- torch.zeros(1, self.n_temporal_patches, embed_dim)
475
- )
476
- else:
477
- self.temporal_pos_embed = nn.Parameter(
478
- torch.zeros(1, num_frames, embed_dim)
479
- )
480
- self.temporal_pos_drop = nn.Dropout(dropout)
481
-
482
- # 时序编码器
483
- self.temporal_blocks = nn.ModuleList([
484
- OptimizedTransformerBlock(
485
- embed_dim, n_heads, None, None, dropout,
486
- use_adapter=use_adapter, adapter_dim=adapter_dim
487
- )
488
- for _ in range(temporal_depth)
489
- ])
490
-
491
- self.norm = RMSNorm(embed_dim)
492
- self._init_weights()
493
-
494
- def _init_weights(self):
495
- nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)
496
- nn.init.trunc_normal_(self.temporal_pos_embed, std=0.02)
497
- self.apply(self._init_module_weights)
498
-
499
- def _init_module_weights(self, m):
500
- if isinstance(m, nn.Linear):
501
- nn.init.trunc_normal_(m.weight, std=0.02)
502
- if m.bias is not None:
503
- nn.init.zeros_(m.bias)
504
- elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
505
- nn.init.trunc_normal_(m.weight, std=0.02)
506
- if m.bias is not None:
507
- nn.init.zeros_(m.bias)
508
-
509
- def forward(self, x: torch.Tensor) -> torch.Tensor:
510
- B, T, C, H, W = x.shape
511
-
512
- if self.use_3d_conv:
513
- # 修复:3D卷积路径
514
- x = x.transpose(1, 2) # [B, C, T, H, W]
515
- x = self.patch_embed(x) # [B, embed_dim, T', H', W']
516
-
517
- # 重塑: [B, D, T', H'*W'] -> [B, T', H'*W', D]
518
- B, D, T_new, H_new, W_new = x.shape
519
- x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1) # [B, T', H'*W', D]
520
-
521
- # 空间位置编码(每帧独立)
522
- x = x + self.spatial_pos_embed.unsqueeze(1)
523
-
524
- # 逐帧空间编码
525
- x_flat = x.reshape(B * T_new, -1, D)
526
- for block in self.spatial_blocks:
527
- x_flat, _, _ = block(x_flat)
528
-
529
- # 重塑回时序维度
530
- x = x_flat.view(B, T_new, -1, D)
531
-
532
- # 修复:时序聚合 - 使用平均池化而非取第一个token
533
- x = x.mean(dim=2) # [B, T', D]
534
-
535
- else:
536
- # 2D卷积 + 分离时空建模
537
- x_flat = x.view(B * T, C, H, W)
538
- x_patched, grid_size = self.patch_embed(x_flat)
539
-
540
- # 空间位置编码
541
- x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)
542
-
543
- # 空间编码
544
- for block in self.spatial_blocks:
545
- x_patched, _, _ = block(x_patched)
546
-
547
- # 修复:时序聚合 - 全局平均池化而非仅mean(dim=2)
548
- _, N, D = x_patched.shape
549
- x_spatial = x_patched.view(B, T, N, D)
550
- x = x_spatial.mean(dim=2) # [B, T, D] - 对每帧的所有patch取平均
551
-
552
- # 时序位置编码
553
- x = self.temporal_pos_drop(x + self.temporal_pos_embed)
554
-
555
- # 时序编码
556
- for block in self.temporal_blocks:
557
- x, _, _ = block(x)
558
-
559
  return self.norm(x)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple, Optional
5
+ from components import RMSNorm, SwiGLU
6
+ from transformer import OptimizedTransformerBlock
7
+ import math
8
+
9
+ class LayerScale(nn.Module):
10
+ def __init__(self, dim: int, init_values: float = 1e-5):
11
+ super().__init__()
12
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return x * self.gamma
16
+
17
+ class StochasticDepth(nn.Module):
18
+ def __init__(self, drop_prob: float = 0.0):
19
+ super().__init__()
20
+ self.drop_prob = drop_prob
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ if not self.training or self.drop_prob == 0.0:
24
+ return x
25
+
26
+ keep_prob = 1 - self.drop_prob
27
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
28
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
29
+ random_tensor.floor_()
30
+ return x.div(keep_prob) * random_tensor
31
+
32
+ class ImprovedPatchEmbedding(nn.Module):
33
+ def __init__(
34
+ self,
35
+ patch_size: int = 14,
36
+ in_channels: int = 3,
37
+ embed_dim: int = 2048,
38
+ overlap: int = 0
39
+ ):
40
+ super().__init__()
41
+ self.patch_size = patch_size
42
+ stride = patch_size - overlap
43
+ self.proj = nn.Conv2d(
44
+ in_channels,
45
+ embed_dim,
46
+ kernel_size=patch_size,
47
+ stride=stride,
48
+ padding=overlap // 2
49
+ )
50
+
51
+ self.norm = RMSNorm(embed_dim)
52
+
53
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
54
+ B, C, H, W = x.shape
55
+ x = self.proj(x)
56
+ grid_size = (x.shape[2], x.shape[3])
57
+ x = x.flatten(2).transpose(1, 2)
58
+ x = self.norm(x)
59
+ return x, grid_size
60
+
61
+ class ImprovedVisionBlock(nn.Module):
62
+ def __init__(
63
+ self,
64
+ dim: int,
65
+ n_heads: int,
66
+ dropout: float = 0.0,
67
+ drop_path: float = 0.0,
68
+ use_adapter: bool = False,
69
+ adapter_dim: int = 64,
70
+ use_layer_scale: bool = True,
71
+ layer_scale_init: float = 1e-5
72
+ ):
73
+ super().__init__()
74
+ self.norm1 = RMSNorm(dim)
75
+ self.attn = nn.MultiheadAttention(
76
+ dim, n_heads, dropout=dropout, batch_first=True
77
+ )
78
+
79
+ self.norm2 = RMSNorm(dim)
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(dim, dim * 4),
82
+ nn.GELU(),
83
+ nn.Dropout(dropout),
84
+ nn.Linear(dim * 4, dim),
85
+ nn.Dropout(dropout)
86
+ )
87
+
88
+ self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()
89
+
90
+ if use_layer_scale:
91
+ self.ls1 = LayerScale(dim, layer_scale_init)
92
+ self.ls2 = LayerScale(dim, layer_scale_init)
93
+ else:
94
+ self.ls1 = nn.Identity()
95
+ self.ls2 = nn.Identity()
96
+
97
+ if use_adapter:
98
+ self.adapter = nn.Sequential(
99
+ nn.Linear(dim, adapter_dim),
100
+ nn.GELU(),
101
+ nn.Linear(adapter_dim, dim)
102
+ )
103
+ else:
104
+ self.adapter = None
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ # 注意力
108
+ normx = self.norm1(x)
109
+ attn_out, _ = self.attn(normx, normx, normx)
110
+ x = x + self.drop_path(self.ls1(attn_out))
111
+
112
+ # MLP
113
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
114
+
115
+ # Adapter
116
+ if self.adapter is not None:
117
+ x = x + self.adapter(x)
118
+
119
+ return x
120
+
121
+ class ImprovedVisionTransformer(nn.Module):
122
+ def __init__(
123
+ self,
124
+ img_size: int = 224,
125
+ patch_size: int = 14,
126
+ in_channels: int = 3,
127
+ embed_dim: int = 2048,
128
+ depth: int = 24,
129
+ n_heads: int = 16,
130
+ dropout: float = 0.0,
131
+ drop_path_rate: float = 0.1,
132
+ use_register_tokens: bool = True,
133
+ num_register_tokens: int = 4,
134
+ use_adapter: bool = False,
135
+ adapter_dim: int = 64,
136
+ use_layer_scale: bool = True,
137
+ layer_scale_init: float = 1e-5
138
+ ):
139
+ super().__init__()
140
+ self.patch_size = patch_size
141
+ self.embed_dim = embed_dim
142
+ self.use_register_tokens = use_register_tokens
143
+ self.num_register_tokens = num_register_tokens if use_register_tokens else 0
144
+
145
+ # Patch embedding
146
+ self.patch_embed = ImprovedPatchEmbedding(
147
+ patch_size, in_channels, embed_dim, overlap=0
148
+ )
149
+
150
+ self.pretrain_img_size = img_size
151
+ n_patches_pretrain = (img_size // patch_size) ** 2
152
+
153
+ # CLS token
154
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
155
+
156
+ # Register tokens
157
+ if use_register_tokens:
158
+ self.register_tokens = nn.Parameter(
159
+ torch.zeros(1, num_register_tokens, embed_dim)
160
+ )
161
+
162
+ total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
163
+ self.pos_embed = nn.Parameter(
164
+ torch.zeros(1, total_tokens, embed_dim)
165
+ )
166
+ self.pos_drop = nn.Dropout(dropout)
167
+
168
+ # Stochastic depth
169
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
170
+
171
+ # Transformer blocks
172
+ self.blocks = nn.ModuleList([
173
+ ImprovedVisionBlock(
174
+ embed_dim,
175
+ n_heads,
176
+ dropout,
177
+ drop_path=dpr[i],
178
+ use_adapter=use_adapter,
179
+ adapter_dim=adapter_dim,
180
+ use_layer_scale=use_layer_scale,
181
+ layer_scale_init=layer_scale_init
182
+ )
183
+ for i in range(depth)
184
+ ])
185
+
186
+ self.norm = RMSNorm(embed_dim)
187
+ self._init_weights()
188
+
189
+ def _init_weights(self):
190
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
191
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
192
+ if self.use_register_tokens:
193
+ nn.init.trunc_normal_(self.register_tokens, std=0.02)
194
+
195
+ self.apply(self._init_module_weights)
196
+
197
+ def _init_module_weights(self, m):
198
+ if isinstance(m, nn.Linear):
199
+ nn.init.trunc_normal_(m.weight, std=0.02)
200
+ if m.bias is not None:
201
+ nn.init.zeros_(m.bias)
202
+ elif isinstance(m, nn.Conv2d):
203
+ nn.init.trunc_normal_(m.weight, std=0.02)
204
+ if m.bias is not None:
205
+ nn.init.zeros_(m.bias)
206
+ elif isinstance(m, RMSNorm):
207
+ if hasattr(m, 'weight') and m.weight is not None:
208
+ nn.init.ones_(m.weight)
209
+
210
+ def _interpolate_pos_encoding(
211
+ self,
212
+ patch_tokens: torch.Tensor,
213
+ grid_size: Tuple[int, int]
214
+ ) -> torch.Tensor:
215
+ pretrain_grid_h = self.pretrain_img_size // self.patch_size
216
+ pretrain_grid_w = pretrain_grid_h
217
+
218
+ # 如果尺寸匹配,直接返回原始位置编码
219
+ if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
220
+ return self.pos_embed
221
+
222
+ # 分离不同部分的位置编码
223
+ # pos_embed结构: [CLS(1), register_tokens(n), patches(H*W)]
224
+ num_extra_tokens = 1 + self.num_register_tokens
225
+ cls_register_pos = self.pos_embed[:, :num_extra_tokens, :] # [1, 1+n, dim]
226
+ patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :] # [1, H*W, dim]
227
+
228
+ # 2D插值patch位置编码
229
+ patch_pos_embed = patch_pos_embed.reshape(
230
+ 1, pretrain_grid_h, pretrain_grid_w, -1
231
+ ).permute(0, 3, 1, 2)
232
+
233
+ patch_pos_embed = F.interpolate(
234
+ patch_pos_embed,
235
+ size=grid_size,
236
+ mode='bicubic',
237
+ align_corners=False
238
+ )
239
+
240
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
241
+
242
+ # 拼接回去
243
+ return torch.cat([cls_register_pos, patch_pos_embed], dim=1)
244
+
245
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
246
+ B = x.shape[0]
247
+
248
+ # Patch embedding
249
+ x, grid_size = self.patch_embed(x)
250
+
251
+ # 添加CLS token
252
+ cls_tokens = self.cls_token.expand(B, -1, -1)
253
+
254
+ if self.use_register_tokens:
255
+ register_tokens = self.register_tokens.expand(B, -1, -1)
256
+ # 顺序: [CLS, register_tokens, patches]
257
+ x = torch.cat([cls_tokens, register_tokens, x], dim=1)
258
+ else:
259
+ x = torch.cat([cls_tokens, x], dim=1)
260
+
261
+ # 位置编码
262
+ pos_embed = self._interpolate_pos_encoding(x, grid_size)
263
+ x = self.pos_drop(x + pos_embed)
264
+
265
+ # Transformer blocks
266
+ for block in self.blocks:
267
+ x = block(x)
268
+
269
+ x = self.norm(x)
270
+
271
+ return x
272
+
273
+ class ImprovedAudioEncoder(nn.Module):
274
+ def __init__(
275
+ self,
276
+ n_mels: int = 128,
277
+ target_length: int = 1024,
278
+ embed_dim: int = 2048,
279
+ depth: int = 12,
280
+ n_heads: int = 16,
281
+ patch_size: int = 16,
282
+ dropout: float = 0.1,
283
+ use_adapter: bool = False,
284
+ adapter_dim: int = 64,
285
+ use_dual_stream: bool = True
286
+ ):
287
+ super().__init__()
288
+ self.use_dual_stream = use_dual_stream
289
+ self.patch_size = patch_size
290
+
291
+ # 主编码器
292
+ self.patch_embed = nn.Conv2d(
293
+ 1, embed_dim, kernel_size=patch_size, stride=patch_size
294
+ )
295
+
296
+ self.n_patches_h = n_mels // patch_size
297
+ self.n_patches_w = target_length // patch_size
298
+ n_patches = self.n_patches_h * self.n_patches_w
299
+
300
+ self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
301
+ self.pos_drop = nn.Dropout(dropout)
302
+
303
+ # Transformer blocks
304
+ self.blocks = nn.ModuleList([
305
+ OptimizedTransformerBlock(
306
+ embed_dim, n_heads, None, None, dropout,
307
+ use_adapter=use_adapter, adapter_dim=adapter_dim
308
+ )
309
+ for _ in range(depth)
310
+ ])
311
+
312
+ # 双流:时间流和频率流
313
+ if use_dual_stream:
314
+ self.temporal_pool = nn.AdaptiveAvgPool1d(1)
315
+ self.frequency_pool = nn.AdaptiveAvgPool1d(1)
316
+
317
+ self.temporal_proj = nn.Linear(embed_dim, embed_dim)
318
+ self.frequency_proj = nn.Linear(embed_dim, embed_dim)
319
+
320
+ self.fusion = nn.Linear(embed_dim * 2, embed_dim)
321
+
322
+ self.norm = RMSNorm(embed_dim)
323
+ self._init_weights()
324
+
325
+ def _init_weights(self):
326
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
327
+ self.apply(self._init_module_weights)
328
+
329
+ def _init_module_weights(self, m):
330
+ if isinstance(m, nn.Linear):
331
+ nn.init.trunc_normal_(m.weight, std=0.02)
332
+ if m.bias is not None:
333
+ nn.init.zeros_(m.bias)
334
+ elif isinstance(m, nn.Conv2d):
335
+ nn.init.trunc_normal_(m.weight, std=0.02)
336
+ if m.bias is not None:
337
+ nn.init.zeros_(m.bias)
338
+
339
+ def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
340
+ if mel_spec.ndim == 3:
341
+ mel_spec = mel_spec.unsqueeze(1)
342
+
343
+ # Patch embedding
344
+ x = self.patch_embed(mel_spec) # [B, C, H, W]
345
+ x = x.flatten(2).transpose(1, 2) # [B, H*W, C]
346
+ x = self.pos_drop(x + self.pos_embed)
347
+
348
+ # Transformer encoding
349
+ for block in self.blocks:
350
+ x, _, _ = block(x)
351
+
352
+ x = self.norm(x)
353
+
354
+ if self.use_dual_stream:
355
+ B, N, C = x.shape
356
+
357
+ # 重塑为2D网格
358
+ x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)
359
+
360
+ # 时间流:沿频率维度池化(保留时间)
361
+ temporal = x_2d.mean(dim=2) # [B, C, W]
362
+ temporal = self.temporal_pool(temporal).squeeze(-1) # [B, C]
363
+ temporal = self.temporal_proj(temporal).unsqueeze(1) # [B, 1, C]
364
+
365
+ # 频率流:沿时间维度池化(保留频率)
366
+ frequency = x_2d.mean(dim=3) # [B, C, H]
367
+ frequency = self.frequency_pool(frequency).squeeze(-1) # [B, C]
368
+ frequency = self.frequency_proj(frequency).unsqueeze(1) # [B, 1, C]
369
+
370
+ # 融合
371
+ x = self.fusion(torch.cat([temporal, frequency], dim=-1))
372
+ else:
373
+ # 简单全局平均池化
374
+ x = x.mean(dim=1, keepdim=True)
375
+
376
+ return x
377
+
378
+ class ImprovedVideoEncoder(nn.Module):
379
+ def __init__(
380
+ self,
381
+ img_size: int = 224,
382
+ patch_size: int = 14,
383
+ in_channels: int = 3,
384
+ embed_dim: int = 2048,
385
+ spatial_depth: int = 12,
386
+ temporal_depth: int = 4,
387
+ n_heads: int = 16,
388
+ num_frames: int = 16,
389
+ dropout: float = 0.1,
390
+ use_adapter: bool = False,
391
+ adapter_dim: int = 64,
392
+ use_3d_conv: bool = False
393
+ ):
394
+ super().__init__()
395
+ self.num_frames = num_frames
396
+ self.use_3d_conv = use_3d_conv
397
+ self.patch_size = patch_size
398
+ self.img_size = img_size
399
+
400
+ if use_3d_conv:
401
+ # 3D卷积处理时空信息
402
+ self.patch_embed = nn.Conv3d(
403
+ in_channels,
404
+ embed_dim,
405
+ kernel_size=(2, patch_size, patch_size),
406
+ stride=(2, patch_size, patch_size)
407
+ )
408
+ self.n_temporal_patches = num_frames // 2
409
+ self.n_spatial_patches = (img_size // patch_size) ** 2
410
+ else:
411
+ # 2D卷积 + 时序建模
412
+ self.patch_embed = ImprovedPatchEmbedding(
413
+ patch_size, in_channels, embed_dim
414
+ )
415
+ self.n_spatial_patches = (img_size // patch_size) ** 2
416
+
417
+ # 空间位置编码
418
+ self.spatial_pos_embed = nn.Parameter(
419
+ torch.zeros(1, self.n_spatial_patches, embed_dim)
420
+ )
421
+ self.spatial_pos_drop = nn.Dropout(dropout)
422
+
423
+ # 空间编码器
424
+ self.spatial_blocks = nn.ModuleList([
425
+ OptimizedTransformerBlock(
426
+ embed_dim, n_heads, None, None, dropout,
427
+ use_adapter=use_adapter, adapter_dim=adapter_dim
428
+ )
429
+ for _ in range(spatial_depth)
430
+ ])
431
+
432
+ # 时间位置编码
433
+ if use_3d_conv:
434
+ self.temporal_pos_embed = nn.Parameter(
435
+ torch.zeros(1, self.n_temporal_patches, embed_dim)
436
+ )
437
+ else:
438
+ self.temporal_pos_embed = nn.Parameter(
439
+ torch.zeros(1, num_frames, embed_dim)
440
+ )
441
+ self.temporal_pos_drop = nn.Dropout(dropout)
442
+
443
+ # 时序编码器
444
+ self.temporal_blocks = nn.ModuleList([
445
+ OptimizedTransformerBlock(
446
+ embed_dim, n_heads, None, None, dropout,
447
+ use_adapter=use_adapter, adapter_dim=adapter_dim
448
+ )
449
+ for _ in range(temporal_depth)
450
+ ])
451
+
452
+ self.norm = RMSNorm(embed_dim)
453
+ self._init_weights()
454
+
455
+ def _init_weights(self):
456
+ nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)
457
+ nn.init.trunc_normal_(self.temporal_pos_embed, std=0.02)
458
+ self.apply(self._init_module_weights)
459
+
460
+ def _init_module_weights(self, m):
461
+ if isinstance(m, nn.Linear):
462
+ nn.init.trunc_normal_(m.weight, std=0.02)
463
+ if m.bias is not None:
464
+ nn.init.zeros_(m.bias)
465
+ elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
466
+ nn.init.trunc_normal_(m.weight, std=0.02)
467
+ if m.bias is not None:
468
+ nn.init.zeros_(m.bias)
469
+
470
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
471
+ B, T, C, H, W = x.shape
472
+
473
+ if self.use_3d_conv:
474
+ x = x.transpose(1, 2) # [B, C, T, H, W]
475
+ x = self.patch_embed(x) # [B, embed_dim, T', H', W']
476
+
477
+ # 重塑: [B, D, T', H'*W'] -> [B, T', H'*W', D]
478
+ B, D, T_new, H_new, W_new = x.shape
479
+ x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1) # [B, T', H'*W', D]
480
+
481
+ # 空间位置编码(每帧独立)
482
+ x = x + self.spatial_pos_embed.unsqueeze(1)
483
+
484
+ # 逐帧空间编码
485
+ x_flat = x.reshape(B * T_new, -1, D)
486
+ for block in self.spatial_blocks:
487
+ x_flat, _, _ = block(x_flat)
488
+
489
+ # 重塑回时序维度
490
+ x = x_flat.view(B, T_new, -1, D)
491
+ x = x.mean(dim=2) # [B, T', D]
492
+
493
+ else:
494
+ # 2D卷积 + 分离时空建模
495
+ x_flat = x.view(B * T, C, H, W)
496
+ x_patched, grid_size = self.patch_embed(x_flat)
497
+
498
+ # 空间位置编码
499
+ x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)
500
+
501
+ # 空间编码
502
+ for block in self.spatial_blocks:
503
+ x_patched, _, _ = block(x_patched)
504
+
505
+ _, N, D = x_patched.shape
506
+ x_spatial = x_patched.view(B, T, N, D)
507
+ x = x_spatial.mean(dim=2)
508
+
509
+ # 时序位置编码
510
+ x = self.temporal_pos_drop(x + self.temporal_pos_embed)
511
+
512
+ # 时序编码
513
+ for block in self.temporal_blocks:
514
+ x, _, _ = block(x)
515
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  return self.norm(x)