| """
|
| 跨模态融合模块 - SOTA级别
|
| 支持深度跨模态交互、对比学习、模态对齐
|
| 修复版本:解决了所有接口不匹配和潜在bug
|
| """
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Dict, List, Optional, Tuple, Union
|
| from components import RMSNorm
|
| from transformer import GroupedQueryAttention
|
| import math
|
| from contrastive_learning import MultiModalContrastiveLoss
|
|
|
|
|
class CrossModalAttention(nn.Module):
    """Cross-modal attention: lets one modality (query) attend over another (key/value).

    Query and key are pre-normalized with RMSNorm, then standard multi-head
    attention runs, preferring PyTorch's fused ``scaled_dot_product_attention``
    when available and falling back to an explicit softmax otherwise.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        qkv_bias: bool = True
    ):
        """
        Args:
            dim: Embedding dimension; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            dropout: Dropout rate for attention weights and the output projection.
            qkv_bias: Whether the Q/K/V projections carry bias terms.
        """
        super().__init__()
        # Validate up front so a bad configuration fails with a clear message.
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.norm_q = RMSNorm(dim)
        self.norm_k = RMSNorm(dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [B, T_q, D] - query modality.
            key:   [B, T_k, D] - key modality.
            value: [B, T_v, D] - value modality (usually identical to key;
                T_v must equal T_k).
            attention_mask: Optional mask; boolean (True = attend) or additive float.

        Returns:
            [B, T_q, D] attended features.
        """
        B, T_q, D = query.shape
        T_k = key.shape[1]
        T_v = value.shape[1]
        # Fix: the original reshaped `value` using T_k, silently assuming key and
        # value share a length. Fail loudly instead of raising a cryptic view error.
        if T_v != T_k:
            raise ValueError(f"key length {T_k} does not match value length {T_v}")

        query = self.norm_q(query)
        key = self.norm_k(key)

        q = self.q_proj(query).view(B, T_q, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(B, T_v, self.n_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel; dropout is only applied during training.
            dropout_p = self.attn_dropout.p if self.training else 0.0
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=False
            )
        else:
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                # Fix: mirror SDPA mask semantics in the fallback — boolean masks
                # mark positions to keep, float masks are added to the logits.
                if attention_mask.dtype == torch.bool:
                    attn_scores = attn_scores.masked_fill(~attention_mask, float('-inf'))
                else:
                    attn_scores = attn_scores + attention_mask
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_output = attn_weights @ v

        # Merge heads back: [B, H, T_q, hd] -> [B, T_q, D].
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T_q, D)
        output = self.resid_dropout(self.o_proj(attn_output))

        return output
|
|
|
|
|
class ModalityProjector(nn.Module):
    """Modality projector - maps one modality's features into a unified space.

    An MLP of ``num_layers`` linear layers; every layer except the last is
    followed by optional RMSNorm and a GELU activation.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        use_layer_norm: bool = True
    ):
        """
        Args:
            input_dim: Dimension of the incoming features.
            output_dim: Dimension of the unified space.
            hidden_dim: Width of intermediate layers; defaults to the mean of
                ``input_dim`` and ``output_dim``.
            num_layers: Number of linear layers (must be >= 1).
            use_layer_norm: Insert RMSNorm between hidden layers.

        Raises:
            ValueError: If ``num_layers`` is less than 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        # Fix: the original branched on `i == 0` before `i == num_layers - 1`,
        # so with num_layers == 1 the single layer mapped input_dim -> hidden_dim
        # instead of input_dim -> output_dim. Tracking the running width handles
        # every depth uniformly.
        layers = []
        in_dim = input_dim
        for i in range(num_layers):
            out_dim = output_dim if i == num_layers - 1 else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim))
            if i < num_layers - 1:
                if use_layer_norm:
                    layers.append(RMSNorm(out_dim))
                layers.append(nn.GELU())
            in_dim = out_dim

        self.projector = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` ([..., input_dim]) to [..., output_dim]."""
        return self.projector(x)
|
|
|
|
|
class ModalityAdapter(nn.Module):
    """Modality adapter - learns a small residual bottleneck per modality.

    Each adapter's up-projection is zero-initialized, so every adapter starts
    out as the identity mapping (the residual branch contributes nothing).
    """

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        num_modalities: int = 4
    ):
        """
        Args:
            dim: Feature dimension of the inputs.
            bottleneck_dim: Width of the adapter bottleneck.
            num_modalities: Number of per-modality adapters to create.
        """
        super().__init__()
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, bottleneck_dim),
                nn.GELU(),
                nn.Linear(bottleneck_dim, dim)
            )
            for _ in range(num_modalities)
        ])

        # Zero-init the up-projection so adapters are a no-op at start of training.
        for adapter in self.adapters:
            nn.init.zeros_(adapter[-1].weight)
            nn.init.zeros_(adapter[-1].bias)

    def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
        """Apply the residual adapter for ``modality_id``; unknown ids pass through.

        Fix: the original only guarded the upper bound, so a negative id
        silently selected an adapter from the end of the list via Python's
        negative indexing.
        """
        if not 0 <= modality_id < len(self.adapters):
            return x
        return x + self.adapters[modality_id](x)
|
|
|
|
|
class CrossModalFusionLayer(nn.Module):
    """Cross-modal fusion layer.

    A pre-norm transformer-style layer: self-attention over the current
    modality, optional cross-attention into another modality's context, a
    feed-forward block, and an optional per-modality residual adapter.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = True,
        adapter_dim: int = 64
    ):
        super().__init__()
        self.dim = dim
        self.use_adapter = use_adapter

        # Intra-modal mixing.
        self.self_attn = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            attn_dropout=dropout
        )

        # Inter-modal mixing: query = this modality, key/value = context.
        self.cross_attn = CrossModalAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout
        )

        # Position-wise feed-forward with 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)

        self.adapter = ModalityAdapter(dim, adapter_dim) if use_adapter else None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        modality_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Current-modality features [B, T, D].
            context: Other modalities' context [B, T_ctx, D]; skipped when None.
            modality_id: Modality id for the adapter; skipped when None.
            attention_mask: Mask applied to the self-attention step only.
        """
        # Self-attention sub-layer. GroupedQueryAttention returns a tuple;
        # element 0 is the attended output.
        x = x + self.self_attn(self.norm1(x), attention_mask=attention_mask)[0]

        # Cross-attention sub-layer, only when another modality is present.
        if context is not None:
            x = x + self.cross_attn(self.norm2(x), context, context, attention_mask=None)

        # Feed-forward sub-layer.
        x = x + self.ffn(self.norm3(x))

        # Optional per-modality residual adapter.
        if self.use_adapter and modality_id is not None and self.adapter is not None:
            x = self.adapter(x, modality_id)

        return x
|
|
|
|
|
class PerceiverResampler(nn.Module):
    """Perceiver Resampler.

    Compresses a variable-length feature sequence into a fixed set of learned
    latent tokens by repeatedly cross-attending the latents over the input.
    """

    def __init__(
        self,
        dim: int,
        depth: int = 6,
        num_latents: int = 64,
        n_heads: int = 16,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_latents = num_latents
        # Learned latent queries, shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.layers = nn.ModuleList(
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=False
            )
            for _ in range(depth)
        )

        self.norm = RMSNorm(dim)

        # Re-initialize latents with a small std for stable early training.
        nn.init.trunc_normal_(self.latents, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] input features.
        Returns:
            [B, num_latents, D] compressed features.
        """
        batch_size = x.shape[0]
        # Broadcast the shared latents across the batch.
        queries = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        for block in self.layers:
            queries = block(queries, context=x)

        return self.norm(queries)
|
|
|
|
|
class MultiModalFusionModule(nn.Module):
    """Multi-modal fusion module - integrates all fusion strategies.

    Per modality: project into the shared space, optionally compress non-text
    modalities with a Perceiver resampler, then run stacked cross-modal fusion
    layers in which each modality cross-attends to the concatenation of all
    other modalities. Can additionally compute contrastive losses between
    text and each non-text modality.
    """

    def __init__(
        self,
        dim: int = 2048,
        num_fusion_layers: int = 4,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_perceiver: bool = True,
        num_latents: int = 64,
        use_contrastive: bool = True,
        contrastive_loss_type: str = 'siglip',
        contrastive_embed_dim: int = 512
    ):
        """
        Args:
            dim: Shared embedding dimension.
            num_fusion_layers: Number of stacked CrossModalFusionLayers.
            n_heads: Attention heads used throughout.
            dropout: Dropout rate for the fusion layers.
            use_perceiver: Compress non-text modalities to ``num_latents`` tokens.
            num_latents: Latent count for the Perceiver resampler.
            use_contrastive: Build the contrastive-loss module.
            contrastive_loss_type: Loss variant forwarded to MultiModalContrastiveLoss.
            contrastive_embed_dim: Embedding dim inside the contrastive module.
        """
        super().__init__()
        self.dim = dim
        self.use_perceiver = use_perceiver
        self.use_contrastive = use_contrastive

        # One projector per supported modality, all mapping dim -> dim.
        self.modality_projectors = nn.ModuleDict({
            'image': ModalityProjector(dim, dim),
            'audio': ModalityProjector(dim, dim),
            'video': ModalityProjector(dim, dim),
            'text': ModalityProjector(dim, dim)
        })

        # Stacked fusion layers; adapters let each layer specialize per modality.
        self.fusion_layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=True
            )
            for _ in range(num_fusion_layers)
        ])

        if use_perceiver:
            self.perceiver = PerceiverResampler(
                dim=dim,
                depth=4,
                num_latents=num_latents,
                n_heads=n_heads,
                dropout=dropout
            )

        if use_contrastive:
            # Per-modality pooling strategy used by the contrastive module
            # ('cls' = first token, 'mean' = average) — presumably; the exact
            # semantics live in MultiModalContrastiveLoss. TODO confirm.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }

            input_dims = {k: dim for k in modality_config.keys()}

            self.contrastive_module = MultiModalContrastiveLoss(
                embed_dim=contrastive_embed_dim,
                input_dims=input_dims,
                temperature=0.07,
                loss_type=contrastive_loss_type,
                modality_config=modality_config
            )

        self.final_norm = RMSNorm(dim)

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """Mean-pool features to a single vector: [B, T, D] -> [B, D].

        NOTE(review): not called anywhere in this class — forward() hands the
        un-pooled sequences to the contrastive module, which appears to pool
        internally via ``modality_config``. Confirm and remove if truly dead.
        """
        if features.dim() == 3:
            return features.mean(dim=1)
        return features

    def forward(
        self,
        segments: List[Dict],
        compute_contrastive: bool = False
    ) -> Dict:
        """
        Args:
            segments: List of dicts, each containing {'type', 'data', 'modality_id'}:
                - type: str, modality type ('image', 'audio', 'video', 'text')
                - data: Tensor [B, T, D], modality features
                - modality_id: int, modality id (0-3)
            compute_contrastive: Whether to compute contrastive-learning losses.

        Returns:
            Dict containing:
                - fused_features: fused feature sequence (all modalities concatenated)
                - modality_features: dict of per-modality fused features
                - contrastive_losses: dict of contrastive losses (empty when disabled)

        Raises:
            ValueError: If any segment's data is not a 3D [B, T, D] tensor.
        """
        modality_features = {}
        modality_ids = {}

        # Stage 1: project (and optionally compress) each segment.
        # NOTE(review): segments sharing the same 'type' overwrite each other in
        # these dicts — confirm callers pass at most one segment per modality.
        for seg in segments:
            mod_type = seg['type']
            mod_data = seg['data']
            mod_id = seg['modality_id']

            if mod_data.dim() != 3:
                raise ValueError(
                    f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
                    f"got shape {mod_data.shape}"
                )

            # Unknown modality types bypass projection unchanged.
            if mod_type in self.modality_projectors:
                projected = self.modality_projectors[mod_type](mod_data)
            else:
                projected = mod_data

            # Text keeps its full sequence; other modalities are resampled down
            # to a fixed number of latent tokens.
            if self.use_perceiver and mod_type != 'text':
                projected = self.perceiver(projected)

            modality_features[mod_type] = projected
            modality_ids[mod_type] = mod_id

        # Stage 2: each modality cross-attends to all the others.
        fused_features = {}

        for mod_type, features in modality_features.items():
            # Context = concatenation of every other modality's features along time.
            if len(modality_features) > 1:
                other_features = torch.cat([
                    f for k, f in modality_features.items() if k != mod_type
                ], dim=1)
            else:
                other_features = None

            fused = features
            for layer in self.fusion_layers:
                fused = layer(
                    fused,
                    context=other_features,
                    modality_id=modality_ids[mod_type]
                )

            fused_features[mod_type] = self.final_norm(fused)

        # Stage 3: optional contrastive losses between text and every other modality.
        contrastive_losses = {}
        if compute_contrastive and self.use_contrastive:
            # Sequences are passed un-pooled; pooling is presumably handled inside
            # the contrastive module (see modality_config in __init__) — TODO confirm.
            pooled_features = fused_features

            # Pair every non-text modality with text.
            modality_pairs = []
            if 'text' in pooled_features:
                for mod in pooled_features.keys():
                    if mod != 'text':
                        modality_pairs.append((mod, 'text'))

            if modality_pairs:
                contrastive_losses = self.contrastive_module(
                    pooled_features,
                    modality_pairs=modality_pairs
                )

        # Final output: all fused modality sequences concatenated along time.
        fused_sequence = torch.cat(list(fused_features.values()), dim=1)

        return {
            'fused_features': fused_sequence,
            'modality_features': fused_features,
            'contrastive_losses': contrastive_losses
        }
|
|
|
|
|
class EarlyFusionModule(nn.Module):
    """Early fusion - merges modalities at a shallow stage by concatenation."""

    def __init__(self, dim: int = 2048):
        super().__init__()
        self.fusion_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, segments: List[Dict]) -> torch.Tensor:
        """Concatenate every segment's features along time, project, and normalize.

        Args:
            segments: List of dicts, each carrying a 'data' tensor [B, T_i, D].
        Returns:
            [B, sum(T_i), D] fused sequence.
        """
        concatenated = torch.cat([segment['data'] for segment in segments], dim=1)
        return self.norm(self.fusion_proj(concatenated))
|
|
|
|
|
class LateFusionModule(nn.Module):
    """Late fusion - modalities are processed independently and merged at the end.

    Strategies:
      - 'concat':    mean-pool each modality, concatenate, project back to ``dim``.
      - 'attention': mean-pool each modality, blend with learned softmax weights.
      - otherwise:   mean-pool each modality, then average across modalities.
    """

    def __init__(
        self,
        dim: int = 2048,
        num_modalities: int = 4,
        fusion_method: str = 'concat'
    ):
        super().__init__()
        self.fusion_method = fusion_method

        # Only the chosen strategy gets its parameters.
        if fusion_method == 'concat':
            self.fusion_proj = nn.Linear(dim * num_modalities, dim)
        elif fusion_method == 'attention':
            self.attention_weights = nn.Linear(dim, 1)

        self.norm = RMSNorm(dim)

    def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            modality_outputs: List of independently-processed modality outputs,
                each [B, T, D].
        Returns:
            [B, D] fused representation.
        """
        # Every strategy starts from a mean-pooled vector per modality.
        pooled = [out.mean(dim=1) for out in modality_outputs]

        if self.fusion_method == 'concat':
            fused = self.fusion_proj(torch.cat(pooled, dim=-1))
        elif self.fusion_method == 'attention':
            stacked = torch.stack(pooled, dim=1)
            weights = F.softmax(self.attention_weights(stacked), dim=1)
            fused = (stacked * weights).sum(dim=1)
        else:
            # Fallback: plain average across modalities.
            fused = torch.stack(pooled, dim=1).mean(dim=1)

        return self.norm(fused)