szxllm commited on
Commit
9223e06
·
verified ·
1 Parent(s): 958b4f3

Update multimodel_fusion.py

Browse files
Files changed (1) hide show
  1. multimodel_fusion.py +474 -521
multimodel_fusion.py CHANGED
@@ -1,522 +1,475 @@
1
- """
2
- 跨模态融合模块 - SOTA级别
3
- 支持深度跨模态交互、对比学习、模态对齐
4
- 修复版本:解决了所有接口不匹配和潜在bug
5
- """
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from typing import Dict, List, Optional, Tuple, Union
10
- from components import RMSNorm
11
- from transformer import GroupedQueryAttention
12
- import math
13
- from contrastive_learning import MultiModalContrastiveLoss
14
-
15
-
16
- class CrossModalAttention(nn.Module):
17
- """跨模态注意力 - 允许不同模态之间的信息交互"""
18
- def __init__(
19
- self,
20
- dim: int,
21
- n_heads: int = 16,
22
- dropout: float = 0.1,
23
- qkv_bias: bool = True
24
- ):
25
- super().__init__()
26
- self.dim = dim
27
- self.n_heads = n_heads
28
- self.head_dim = dim // n_heads
29
- self.scale = self.head_dim ** -0.5
30
-
31
- assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"
32
-
33
- self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
34
- self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
35
- self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
36
- self.o_proj = nn.Linear(dim, dim)
37
-
38
- self.attn_dropout = nn.Dropout(dropout)
39
- self.resid_dropout = nn.Dropout(dropout)
40
-
41
- self.norm_q = RMSNorm(dim)
42
- self.norm_k = RMSNorm(dim)
43
-
44
- def forward(
45
- self,
46
- query: torch.Tensor,
47
- key: torch.Tensor,
48
- value: torch.Tensor,
49
- attention_mask: Optional[torch.Tensor] = None
50
- ) -> torch.Tensor:
51
- """
52
- Args:
53
- query: [B, T_q, D] - 查询模态
54
- key: [B, T_k, D] - 键模态
55
- value: [B, T_v, D] - 值模态 (通常与key相同)
56
- """
57
- B, T_q, D = query.shape
58
- T_k = key.shape[1]
59
-
60
- # 归一化
61
- query = self.norm_q(query)
62
- key = self.norm_k(key)
63
-
64
- # 投影并重塑
65
- q = self.q_proj(query).view(B, T_q, self.n_heads, self.head_dim).transpose(1, 2)
66
- k = self.k_proj(key).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
67
- v = self.v_proj(value).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
68
-
69
- # 使用Flash Attention或手动实现
70
- if hasattr(F, 'scaled_dot_product_attention'):
71
- dropout_p = self.attn_dropout.p if self.training else 0.0
72
- attn_output = F.scaled_dot_product_attention(
73
- q, k, v,
74
- attn_mask=attention_mask,
75
- dropout_p=dropout_p,
76
- is_causal=False
77
- )
78
- else:
79
- attn_scores = (q @ k.transpose(-2, -1)) * self.scale
80
- if attention_mask is not None:
81
- attn_scores = attn_scores + attention_mask
82
- attn_weights = F.softmax(attn_scores, dim=-1)
83
- attn_weights = self.attn_dropout(attn_weights)
84
- attn_output = attn_weights @ v
85
-
86
- # 重塑并投影输出
87
- attn_output = attn_output.transpose(1, 2).contiguous().view(B, T_q, D)
88
- output = self.resid_dropout(self.o_proj(attn_output))
89
-
90
- return output
91
-
92
-
93
- class ModalityProjector(nn.Module):
94
- """模态投影器 - 将不同模态投影到统一空间"""
95
- def __init__(
96
- self,
97
- input_dim: int,
98
- output_dim: int,
99
- hidden_dim: Optional[int] = None,
100
- num_layers: int = 2,
101
- use_layer_norm: bool = True
102
- ):
103
- super().__init__()
104
- if hidden_dim is None:
105
- hidden_dim = (input_dim + output_dim) // 2
106
-
107
- layers = []
108
- for i in range(num_layers):
109
- if i == 0:
110
- layers.append(nn.Linear(input_dim, hidden_dim))
111
- elif i == num_layers - 1:
112
- layers.append(nn.Linear(hidden_dim, output_dim))
113
- else:
114
- layers.append(nn.Linear(hidden_dim, hidden_dim))
115
-
116
- if i < num_layers - 1:
117
- if use_layer_norm:
118
- layers.append(RMSNorm(hidden_dim))
119
- layers.append(nn.GELU())
120
-
121
- self.projector = nn.Sequential(*layers)
122
-
123
- def forward(self, x: torch.Tensor) -> torch.Tensor:
124
- return self.projector(x)
125
-
126
-
127
- class ModalityAdapter(nn.Module):
128
- """模态适配器 - 为每个模态学习特定的适配参数"""
129
- def __init__(
130
- self,
131
- dim: int,
132
- bottleneck_dim: int = 64,
133
- num_modalities: int = 4
134
- ):
135
- super().__init__()
136
- self.adapters = nn.ModuleList([
137
- nn.Sequential(
138
- nn.Linear(dim, bottleneck_dim),
139
- nn.GELU(),
140
- nn.Linear(bottleneck_dim, dim)
141
- )
142
- for _ in range(num_modalities)
143
- ])
144
- # 初始化为零,确保开始时是恒等映射
145
- for adapter in self.adapters:
146
- nn.init.zeros_(adapter[-1].weight)
147
- nn.init.zeros_(adapter[-1].bias)
148
-
149
- def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
150
- if modality_id >= len(self.adapters):
151
- return x
152
- return x + self.adapters[modality_id](x)
153
-
154
-
155
- class CrossModalFusionLayer(nn.Module):
156
- """跨模态融合层"""
157
- def __init__(
158
- self,
159
- dim: int,
160
- n_heads: int = 16,
161
- dropout: float = 0.1,
162
- use_adapter: bool = True,
163
- adapter_dim: int = 64
164
- ):
165
- super().__init__()
166
- self.dim = dim
167
- self.use_adapter = use_adapter
168
-
169
- # 自注意力
170
- self.self_attn = GroupedQueryAttention(
171
- dim=dim,
172
- n_heads=n_heads,
173
- dropout=dropout,
174
- attn_dropout=dropout
175
- )
176
-
177
- # 跨模态注意力
178
- self.cross_attn = CrossModalAttention(
179
- dim=dim,
180
- n_heads=n_heads,
181
- dropout=dropout
182
- )
183
-
184
- # 前馈网络
185
- self.ffn = nn.Sequential(
186
- nn.Linear(dim, dim * 4),
187
- nn.GELU(),
188
- nn.Dropout(dropout),
189
- nn.Linear(dim * 4, dim),
190
- nn.Dropout(dropout)
191
- )
192
-
193
- # 归一化层
194
- self.norm1 = RMSNorm(dim)
195
- self.norm2 = RMSNorm(dim)
196
- self.norm3 = RMSNorm(dim)
197
-
198
- # 模态适配器
199
- if use_adapter:
200
- self.adapter = ModalityAdapter(dim, adapter_dim)
201
- else:
202
- self.adapter = None
203
-
204
- def forward(
205
- self,
206
- x: torch.Tensor,
207
- context: Optional[torch.Tensor] = None,
208
- modality_id: Optional[int] = None,
209
- attention_mask: Optional[torch.Tensor] = None
210
- ) -> torch.Tensor:
211
- """
212
- Args:
213
- x: 当前模态特征 [B, T, D]
214
- context: 其他模态的上下文 [B, T_ctx, D]
215
- modality_id: 模态ID(用于adapter)
216
- attention_mask: 注意力掩码
217
- """
218
- # 自注意力 - 返回 (output, present_kv, attention_weights)
219
- attn_out = self.self_attn(
220
- self.norm1(x),
221
- attention_mask=attention_mask
222
- )[0] # 只取输出
223
- x = x + attn_out
224
-
225
- # 跨模态注意力(如果有上下文)
226
- if context is not None:
227
- cross_attn_out = self.cross_attn(
228
- self.norm2(x),
229
- context,
230
- context,
231
- attention_mask=None
232
- )
233
- x = x + cross_attn_out
234
-
235
- # 前馈网络
236
- x = x + self.ffn(self.norm3(x))
237
-
238
- # 模态适配器
239
- if self.use_adapter and modality_id is not None and self.adapter is not None:
240
- x = self.adapter(x, modality_id)
241
-
242
- return x
243
-
244
-
245
- class PerceiverResampler(nn.Module):
246
- """Perceiver Resampler - 压缩模态特征到固定数量的tokens"""
247
- def __init__(
248
- self,
249
- dim: int,
250
- depth: int = 6,
251
- num_latents: int = 64,
252
- n_heads: int = 16,
253
- dropout: float = 0.0
254
- ):
255
- super().__init__()
256
- self.num_latents = num_latents
257
- self.latents = nn.Parameter(torch.randn(num_latents, dim))
258
-
259
- self.layers = nn.ModuleList([
260
- CrossModalFusionLayer(
261
- dim=dim,
262
- n_heads=n_heads,
263
- dropout=dropout,
264
- use_adapter=False
265
- )
266
- for _ in range(depth)
267
- ])
268
-
269
- self.norm = RMSNorm(dim)
270
-
271
- # 初始化latents
272
- nn.init.trunc_normal_(self.latents, std=0.02)
273
-
274
- def forward(self, x: torch.Tensor) -> torch.Tensor:
275
- """
276
- Args:
277
- x: [B, T, D] - 输入特征
278
- Returns:
279
- [B, num_latents, D] - 压缩后的特征
280
- """
281
- B = x.shape[0]
282
- latents = self.latents.unsqueeze(0).expand(B, -1, -1)
283
-
284
- # 通过多层交叉注意力处理
285
- for layer in self.layers:
286
- latents = layer(latents, context=x)
287
-
288
- return self.norm(latents)
289
-
290
-
291
- class MultiModalFusionModule(nn.Module):
292
- """多模态融合模块 - 整合所有融合策略"""
293
- def __init__(
294
- self,
295
- dim: int = 2048,
296
- num_fusion_layers: int = 4,
297
- n_heads: int = 16,
298
- dropout: float = 0.1,
299
- use_perceiver: bool = True,
300
- num_latents: int = 64,
301
- use_contrastive: bool = True,
302
- contrastive_loss_type: str = 'siglip',
303
- contrastive_embed_dim: int = 512
304
- ):
305
- super().__init__()
306
- self.dim = dim
307
- self.use_perceiver = use_perceiver
308
- self.use_contrastive = use_contrastive
309
-
310
- # 模态投影器
311
- self.modality_projectors = nn.ModuleDict({
312
- 'image': ModalityProjector(dim, dim),
313
- 'audio': ModalityProjector(dim, dim),
314
- 'video': ModalityProjector(dim, dim),
315
- 'text': ModalityProjector(dim, dim)
316
- })
317
-
318
- # 跨模态融合层
319
- self.fusion_layers = nn.ModuleList([
320
- CrossModalFusionLayer(
321
- dim=dim,
322
- n_heads=n_heads,
323
- dropout=dropout,
324
- use_adapter=True
325
- )
326
- for _ in range(num_fusion_layers)
327
- ])
328
-
329
- # Perceiver Resampler
330
- if use_perceiver:
331
- self.perceiver = PerceiverResampler(
332
- dim=dim,
333
- depth=4,
334
- num_latents=num_latents,
335
- n_heads=n_heads,
336
- dropout=dropout
337
- )
338
-
339
- # 对比学习模块
340
- if use_contrastive:
341
- # 定义每个模态的输入维度和池化类型
342
- modality_config = {
343
- 'text': 'cls',
344
- 'image': 'cls',
345
- 'audio': 'mean',
346
- 'video': 'mean'
347
- }
348
-
349
- input_dims = {k: dim for k in modality_config.keys()}
350
-
351
- self.contrastive_module = MultiModalContrastiveLoss(
352
- embed_dim=contrastive_embed_dim,
353
- input_dims=input_dims,
354
- temperature=0.07,
355
- loss_type=contrastive_loss_type,
356
- modality_config=modality_config
357
- )
358
-
359
- self.final_norm = RMSNorm(dim)
360
-
361
- def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
362
- """池化特征到单一向量 [B, T, D] -> [B, D]"""
363
- if features.dim() == 3:
364
- return features.mean(dim=1)
365
- return features
366
-
367
- def forward(
368
- self,
369
- segments: List[Dict],
370
- compute_contrastive: bool = False
371
- ) -> Dict:
372
- """
373
- Args:
374
- segments: 列表,每个元素包含 {'type', 'data', 'modality_id'}
375
- - type: str, 模态类型 ('image', 'audio', 'video', 'text')
376
- - data: Tensor [B, T, D], 模态数据
377
- - modality_id: int, 模态ID (0-3)
378
- compute_contrastive: 是否计算对比学习损失
379
-
380
- Returns:
381
- Dict containing:
382
- - fused_features: 融合后的特征序列
383
- - modality_features: 各模态的特征字典
384
- - contrastive_losses: 对比学习损失字典
385
- """
386
- # 分离不同模态
387
- modality_features = {}
388
- modality_ids = {}
389
-
390
- for seg in segments:
391
- mod_type = seg['type']
392
- mod_data = seg['data']
393
- mod_id = seg['modality_id']
394
-
395
- # 检查数据维度
396
- if mod_data.dim() != 3:
397
- raise ValueError(
398
- f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
399
- f"got shape {mod_data.shape}"
400
- )
401
-
402
- # 投影到统一空间
403
- if mod_type in self.modality_projectors:
404
- projected = self.modality_projectors[mod_type](mod_data)
405
- else:
406
- projected = mod_data
407
-
408
- # 使用Perceiver压缩(可选,非text模态)
409
- if self.use_perceiver and mod_type != 'text':
410
- projected = self.perceiver(projected)
411
-
412
- modality_features[mod_type] = projected
413
- modality_ids[mod_type] = mod_id
414
-
415
- # 跨模态融合
416
- fused_features = {}
417
-
418
- for mod_type, features in modality_features.items():
419
- # 创建不包含当前模态的上下文
420
- if len(modality_features) > 1:
421
- other_features = torch.cat([
422
- f for k, f in modality_features.items() if k != mod_type
423
- ], dim=1)
424
- else:
425
- other_features = None
426
-
427
- # 通过融合层
428
- fused = features
429
- for layer in self.fusion_layers:
430
- fused = layer(
431
- fused,
432
- context=other_features,
433
- modality_id=modality_ids[mod_type]
434
- )
435
-
436
- fused_features[mod_type] = self.final_norm(fused)
437
-
438
- # 计算对比学习损失(如果需要)
439
- contrastive_losses = {}
440
- if compute_contrastive and self.use_contrastive:
441
- # 准备特征字典 - 保持3D格式供投影头处理
442
- pooled_features = fused_features # 不池化,让ProjectionHead处理
443
-
444
- # 定义需要对比的模态对
445
- modality_pairs = []
446
- if 'text' in pooled_features:
447
- for mod in pooled_features.keys():
448
- if mod != 'text':
449
- modality_pairs.append((mod, 'text'))
450
-
451
- # 调用对比学习模块
452
- if modality_pairs:
453
- contrastive_losses = self.contrastive_module(
454
- pooled_features,
455
- modality_pairs=modality_pairs
456
- )
457
-
458
- # 拼接所有融合后的特征
459
- fused_sequence = torch.cat(list(fused_features.values()), dim=1)
460
-
461
- return {
462
- 'fused_features': fused_sequence,
463
- 'modality_features': fused_features,
464
- 'contrastive_losses': contrastive_losses
465
- }
466
-
467
-
468
- class EarlyFusionModule(nn.Module):
469
- """早期融合 - 在浅层就融合模态"""
470
- def __init__(self, dim: int = 2048):
471
- super().__init__()
472
- self.fusion_proj = nn.Linear(dim, dim)
473
- self.norm = RMSNorm(dim)
474
-
475
- def forward(self, segments: List[Dict]) -> torch.Tensor:
476
- """简单拼接所有模态"""
477
- all_features = [seg['data'] for seg in segments]
478
- fused = torch.cat(all_features, dim=1)
479
- fused = self.fusion_proj(fused)
480
- return self.norm(fused)
481
-
482
-
483
- class LateFusionModule(nn.Module):
484
- """晚期融合 - 在深层才融合模态"""
485
- def __init__(
486
- self,
487
- dim: int = 2048,
488
- num_modalities: int = 4,
489
- fusion_method: str = 'concat' # 'concat', 'attention', 'average'
490
- ):
491
- super().__init__()
492
- self.fusion_method = fusion_method
493
-
494
- if fusion_method == 'concat':
495
- self.fusion_proj = nn.Linear(dim * num_modalities, dim)
496
- elif fusion_method == 'attention':
497
- self.attention_weights = nn.Linear(dim, 1)
498
-
499
- self.norm = RMSNorm(dim)
500
-
501
- def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
502
- """
503
- Args:
504
- modality_outputs: 每个模态独立处理后的输出列表 [B, T, D]
505
- """
506
- if self.fusion_method == 'concat':
507
- # 拼接并投影
508
- pooled = [x.mean(dim=1) for x in modality_outputs]
509
- fused = torch.cat(pooled, dim=-1)
510
- fused = self.fusion_proj(fused)
511
-
512
- elif self.fusion_method == 'attention':
513
- # 注意力加权
514
- stacked = torch.stack([x.mean(dim=1) for x in modality_outputs], dim=1)
515
- weights = F.softmax(self.attention_weights(stacked), dim=1)
516
- fused = (stacked * weights).sum(dim=1)
517
-
518
- else: # average
519
- stacked = torch.stack([x.mean(dim=1) for x in modality_outputs], dim=1)
520
- fused = stacked.mean(dim=1)
521
-
522
  return self.norm(fused)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+ from components import RMSNorm
6
+ from transformer import GroupedQueryAttention
7
+ import math
8
+ from contrastive_learning import MultiModalContrastiveLoss
9
+
10
+
11
class CrossModalAttention(nn.Module):
    """Cross-modal attention: a query modality attends over a key/value modality.

    The query and key streams are RMS-normalized before projection; the value
    stream is projected as-is.  Uses PyTorch's fused scaled-dot-product
    attention kernel when available, otherwise a manual softmax path.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        qkv_bias: bool = True
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.norm_q = RMSNorm(dim)
        self.norm_k = RMSNorm(dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from `query` [B, T_q, D] over `key`/`value` [B, T_k, D]."""
        batch, q_len, width = query.shape
        kv_len = key.shape[1]

        # Pre-norm the query and key streams only.
        query = self.norm_q(query)
        key = self.norm_k(key)

        def split_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            # [B, T, D] -> [B, H, T, D/H]
            return t.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(query), q_len)
        k = split_heads(self.k_proj(key), kv_len)
        v = split_heads(self.v_proj(value), kv_len)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel; attention dropout is applied only in training mode.
            ctx = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Manual fallback: scaled scores, additive mask, softmax, dropout.
            scores = (q @ k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                scores = scores + attention_mask
            probs = self.attn_dropout(F.softmax(scores, dim=-1))
            ctx = probs @ v

        # Merge heads back together and apply the output projection.
        ctx = ctx.transpose(1, 2).contiguous().view(batch, q_len, width)
        return self.resid_dropout(self.o_proj(ctx))
79
+
80
+
81
class ModalityProjector(nn.Module):
    """Modality projector — maps one modality's features into a shared space.

    Builds an MLP of `num_layers` linear layers; every hidden layer is
    followed by optional RMSNorm and a GELU.  `hidden_dim` defaults to the
    mean of the input and output widths.

    Args:
        input_dim: width of the incoming features.
        output_dim: width of the shared space.
        hidden_dim: hidden width; defaults to (input_dim + output_dim) // 2.
        num_layers: number of linear layers (>= 1).
        use_layer_norm: insert RMSNorm after each hidden layer.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        use_layer_norm: bool = True
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        layers = []
        for i in range(num_layers):
            is_first = i == 0
            is_last = i == num_layers - 1
            # BUGFIX: with num_layers == 1 the old code hit the `i == 0`
            # branch and produced Linear(input_dim, hidden_dim), so the
            # projector never reached output_dim.  A layer that is both
            # first and last must map input_dim -> output_dim directly.
            in_dim = input_dim if is_first else hidden_dim
            out_dim = output_dim if is_last else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim))

            if not is_last:
                if use_layer_norm:
                    layers.append(RMSNorm(hidden_dim))
                layers.append(nn.GELU())

        self.projector = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` [..., input_dim] -> [..., output_dim]."""
        return self.projector(x)
113
+
114
+
115
class ModalityAdapter(nn.Module):
    """Per-modality residual bottleneck adapters.

    Each modality owns a down-project / GELU / up-project adapter whose
    output is added to the input.  The up-projection is zero-initialized,
    so every adapter starts out as an exact identity mapping.

    Args:
        dim: feature width of the inputs.
        bottleneck_dim: width of the adapter bottleneck.
        num_modalities: number of independent adapters to allocate.
    """

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        num_modalities: int = 4
    ):
        super().__init__()
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, bottleneck_dim),
                nn.GELU(),
                nn.Linear(bottleneck_dim, dim)
            )
            for _ in range(num_modalities)
        ])
        # Zero-init the final projection so training starts from identity.
        for adapter in self.adapters:
            nn.init.zeros_(adapter[-1].weight)
            nn.init.zeros_(adapter[-1].bias)

    def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
        """Apply the adapter for `modality_id`; unknown ids are a no-op.

        BUGFIX: a negative modality_id previously indexed from the end of
        the adapter list (Python negative indexing) and silently applied
        the wrong adapter; negative ids are now treated as unknown.
        """
        if not 0 <= modality_id < len(self.adapters):
            return x
        return x + self.adapters[modality_id](x)
140
+
141
+
142
class CrossModalFusionLayer(nn.Module):
    """One pre-norm fusion block.

    Sublayers, each residual: intra-modality self-attention, optional
    cross-attention into another modality's context, a feed-forward
    network, and an optional per-modality adapter.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = True,
        adapter_dim: int = 64
    ):
        super().__init__()
        self.dim = dim
        self.use_adapter = use_adapter

        # Intra-modality self-attention.
        self.self_attn = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            attn_dropout=dropout
        )

        # Inter-modality cross-attention.
        self.cross_attn = CrossModalAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout
        )

        # Position-wise feed-forward with 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        # One pre-norm per sublayer.
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)

        # Optional per-modality adapter.
        self.adapter = ModalityAdapter(dim, adapter_dim) if use_adapter else None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        modality_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Fuse `x` [B, T, D] with optional cross-modal `context` [B, T_ctx, D]."""
        # Self-attention; GroupedQueryAttention returns a tuple, keep output only.
        self_out = self.self_attn(self.norm1(x), attention_mask=attention_mask)[0]
        x = x + self_out

        # Cross-attention only when another modality's context is supplied.
        if context is not None:
            x = x + self.cross_attn(self.norm2(x), context, context, attention_mask=None)

        # Feed-forward.
        x = x + self.ffn(self.norm3(x))

        # Per-modality adapter (skipped when disabled or no id given).
        if self.use_adapter and modality_id is not None and self.adapter is not None:
            x = self.adapter(x, modality_id)

        return x
221
+
222
+
223
class PerceiverResampler(nn.Module):
    """Perceiver-style resampler.

    Compresses a variable-length feature sequence into a fixed set of
    learned latent tokens by repeatedly cross-attending the latents over
    the input sequence.
    """

    def __init__(
        self,
        dim: int,
        depth: int = 6,
        num_latents: int = 64,
        n_heads: int = 16,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.layers = nn.ModuleList(
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=False
            )
            for _ in range(depth)
        )

        self.norm = RMSNorm(dim)

        # Small-std truncated-normal init for the learned latents.
        nn.init.trunc_normal_(self.latents, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compress `x` [B, T, D] into [B, num_latents, D]."""
        batch = x.shape[0]
        out = self.latents.unsqueeze(0).expand(batch, -1, -1)

        # Latents repeatedly cross-attend over the full input sequence.
        for block in self.layers:
            out = block(out, context=x)

        return self.norm(out)
261
+
262
+
263
class MultiModalFusionModule(nn.Module):
    """Multi-modal fusion module — orchestrates all fusion strategies.

    Pipeline per modality: project into the shared space, optionally
    compress non-text modalities with a Perceiver resampler, then run the
    stacked cross-modal fusion layers, where each modality attends over the
    concatenation of every other modality's features.  Can additionally
    compute contrastive losses pairing each non-text modality with text.
    """
    def __init__(
        self,
        dim: int = 2048,
        num_fusion_layers: int = 4,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_perceiver: bool = True,
        num_latents: int = 64,
        use_contrastive: bool = True,
        contrastive_loss_type: str = 'siglip',
        contrastive_embed_dim: int = 512
    ):
        super().__init__()
        self.dim = dim
        self.use_perceiver = use_perceiver
        self.use_contrastive = use_contrastive

        # One projector per supported modality (all dim -> dim here).
        self.modality_projectors = nn.ModuleDict({
            'image': ModalityProjector(dim, dim),
            'audio': ModalityProjector(dim, dim),
            'video': ModalityProjector(dim, dim),
            'text': ModalityProjector(dim, dim)
        })

        # Stack of cross-modal fusion layers (adapters enabled).
        self.fusion_layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=True
            )
            for _ in range(num_fusion_layers)
        ])

        # Perceiver resampler: compresses non-text modalities to num_latents tokens.
        if use_perceiver:
            self.perceiver = PerceiverResampler(
                dim=dim,
                depth=4,
                num_latents=num_latents,
                n_heads=n_heads,
                dropout=dropout
            )

        # Contrastive-learning module.
        if use_contrastive:
            # Per-modality pooling mode, consumed by MultiModalContrastiveLoss.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }

            input_dims = {k: dim for k in modality_config.keys()}

            self.contrastive_module = MultiModalContrastiveLoss(
                embed_dim=contrastive_embed_dim,
                input_dims=input_dims,
                temperature=0.07,
                loss_type=contrastive_loss_type,
                modality_config=modality_config
            )

        self.final_norm = RMSNorm(dim)

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """Mean-pool a sequence [B, T, D] -> [B, D]; pass non-3D input through."""
        if features.dim() == 3:
            return features.mean(dim=1)
        return features

    def forward(
        self,
        segments: List[Dict],
        compute_contrastive: bool = False
    ) -> Dict:
        """Fuse a list of modality segments.

        Args:
            segments: list of dicts, each with 'type' (modality key),
                'data' (a [B, T, D] tensor) and 'modality_id' (int adapter
                index).  NOTE(review): a later segment with the same 'type'
                overwrites an earlier one — confirm callers send at most one
                segment per modality.
            compute_contrastive: also compute contrastive losses vs. text.

        Returns:
            dict with:
                - 'fused_features': all modalities concatenated on the
                  sequence axis [B, sum(T_i), D]
                - 'modality_features': per-modality fused features
                - 'contrastive_losses': losses from the contrastive module
                  (empty dict when not computed)
        """
        # Bucket incoming segments by modality type.
        modality_features = {}
        modality_ids = {}

        for seg in segments:
            mod_type = seg['type']
            mod_data = seg['data']
            mod_id = seg['modality_id']

            # Validate dimensionality early for a clear error message.
            if mod_data.dim() != 3:
                raise ValueError(
                    f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
                    f"got shape {mod_data.shape}"
                )

            # Project into the shared space (unknown types pass through unchanged).
            if mod_type in self.modality_projectors:
                projected = self.modality_projectors[mod_type](mod_data)
            else:
                projected = mod_data

            # Optional Perceiver compression for non-text modalities.
            if self.use_perceiver and mod_type != 'text':
                projected = self.perceiver(projected)

            modality_features[mod_type] = projected
            modality_ids[mod_type] = mod_id

        # Cross-modal fusion.
        fused_features = {}

        for mod_type, features in modality_features.items():
            # Context = concatenation of every OTHER modality's features.
            if len(modality_features) > 1:
                other_features = torch.cat([
                    f for k, f in modality_features.items() if k != mod_type
                ], dim=1)
            else:
                other_features = None

            # Run this modality through the full fusion stack.
            fused = features
            for layer in self.fusion_layers:
                fused = layer(
                    fused,
                    context=other_features,
                    modality_id=modality_ids[mod_type]
                )

            fused_features[mod_type] = self.final_norm(fused)

        # Contrastive losses (only when requested and enabled).
        contrastive_losses = {}
        if compute_contrastive and self.use_contrastive:
            # Features stay 3D: pooling is delegated to the contrastive
            # module's projection heads.
            pooled_features = fused_features

            # Pair every non-text modality with text.
            modality_pairs = []
            if 'text' in pooled_features:
                for mod in pooled_features.keys():
                    if mod != 'text':
                        modality_pairs.append((mod, 'text'))

            # Delegate loss computation to the contrastive module.
            if modality_pairs:
                contrastive_losses = self.contrastive_module(
                    pooled_features,
                    modality_pairs=modality_pairs
                )

        # Concatenate all fused modalities along the sequence axis.
        fused_sequence = torch.cat(list(fused_features.values()), dim=1)

        return {
            'fused_features': fused_sequence,
            'modality_features': fused_features,
            'contrastive_losses': contrastive_losses
        }
423
+
424
+
425
class EarlyFusionModule(nn.Module):
    """Early fusion: concatenate every modality's sequence up front, then
    apply a shared linear projection followed by RMS normalization."""

    def __init__(self, dim: int = 2048):
        super().__init__()
        self.fusion_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, segments: List[Dict]) -> torch.Tensor:
        """Concatenate each segment's 'data' [B, T_i, D] on the sequence axis
        and project; returns [B, sum(T_i), D]."""
        joined = torch.cat([seg['data'] for seg in segments], dim=1)
        return self.norm(self.fusion_proj(joined))
438
+
439
+
440
class LateFusionModule(nn.Module):
    """Late fusion: each modality is processed independently upstream and
    the mean-pooled results are merged here.

    Supported `fusion_method` values: 'concat' (pool, concatenate along
    features, project back to `dim`), 'attention' (learned softmax weights
    over the modality axis), and anything else (plain average).
    """

    def __init__(
        self,
        dim: int = 2048,
        num_modalities: int = 4,
        fusion_method: str = 'concat'  # 'concat', 'attention', 'average'
    ):
        super().__init__()
        self.fusion_method = fusion_method

        if fusion_method == 'concat':
            self.fusion_proj = nn.Linear(dim * num_modalities, dim)
        elif fusion_method == 'attention':
            self.attention_weights = nn.Linear(dim, 1)

        self.norm = RMSNorm(dim)

    def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Merge per-modality outputs (each [B, T, D]) into one [B, D] vector."""
        # Every branch starts from mean-pooled [B, D] modality vectors.
        pooled = [feats.mean(dim=1) for feats in modality_outputs]

        if self.fusion_method == 'concat':
            # Concatenate along the feature axis, then project back to dim.
            merged = self.fusion_proj(torch.cat(pooled, dim=-1))
        elif self.fusion_method == 'attention':
            # Softmax-weighted sum over the modality axis.
            stacked = torch.stack(pooled, dim=1)
            weights = F.softmax(self.attention_weights(stacked), dim=1)
            merged = (stacked * weights).sum(dim=1)
        else:  # plain average over modalities
            merged = torch.stack(pooled, dim=1).mean(dim=1)

        return self.norm(merged)