import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional

from components import RMSNorm
from transformer import GroupedQueryAttention
from contrastive_learning import MultiModalContrastiveLoss


class CrossModalAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        qkv_bias: bool = True
    ):
        super().__init__()
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.norm_q = RMSNorm(dim)
        self.norm_k = RMSNorm(dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        B, T_q, D = query.shape
        T_k = key.shape[1]

        # Pre-normalize queries and keys
        query = self.norm_q(query)
        key = self.norm_k(key)

        # Project and reshape to [B, n_heads, T, head_dim]
        q = self.q_proj(query).view(B, T_q, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)

        # Use fused scaled-dot-product attention when available,
        # otherwise fall back to a manual implementation
        if hasattr(F, 'scaled_dot_product_attention'):
            dropout_p = self.attn_dropout.p if self.training else 0.0
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=False
            )
        else:
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            attn_output = attn_weights @ v

        # Reshape back and project the output
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T_q, D)
        output = self.resid_dropout(self.o_proj(attn_output))
        return output


class ModalityProjector(nn.Module):
    """Modality projector: maps features from each modality into a shared space."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        use_layer_norm: bool = True
    ):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        layers = []
        for i in range(num_layers):
            # Pick layer sizes so that a single-layer projector still maps
            # input_dim -> output_dim correctly
            in_dim = input_dim if i == 0 else hidden_dim
            out_dim = output_dim if i == num_layers - 1 else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim))
            if i < num_layers - 1:
                if use_layer_norm:
                    layers.append(RMSNorm(hidden_dim))
                layers.append(nn.GELU())

        self.projector = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projector(x)


class ModalityAdapter(nn.Module):
    """Modality adapter: learns a small per-modality bottleneck on top of shared weights."""

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        num_modalities: int = 4
    ):
        super().__init__()
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, bottleneck_dim),
                nn.GELU(),
                nn.Linear(bottleneck_dim, dim)
            )
            for _ in range(num_modalities)
        ])
        # Zero-init the output projection so each adapter starts as an identity residual
        for adapter in self.adapters:
            nn.init.zeros_(adapter[-1].weight)
            nn.init.zeros_(adapter[-1].bias)

    def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
        if modality_id >= len(self.adapters):
            return x
        return x + self.adapters[modality_id](x)
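# Minimal usage sketch for the blocks above (illustrative only; the batch
# size, sequence lengths, dim, and the _demo_* name are hypothetical, not
# values used elsewhere in this codebase).
def _demo_projection_and_attention() -> torch.Tensor:
    """Project hypothetical image features, adapt them, then cross-attend from text."""
    dim = 512
    image_feats = torch.randn(2, 64, dim)    # [B, T_k, D] context tokens
    text_feats = torch.randn(2, 16, dim)     # [B, T_q, D] query tokens
    projector = ModalityProjector(dim, dim)  # map into the shared space
    adapter = ModalityAdapter(dim, bottleneck_dim=64)
    attn = CrossModalAttention(dim=dim, n_heads=8)
    context = adapter(projector(image_feats), modality_id=0)
    return attn(text_feats, context, context)  # -> [2, 16, 512]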
class CrossModalFusionLayer(nn.Module):
    """Cross-modal fusion layer."""

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = True,
        adapter_dim: int = 64
    ):
        super().__init__()
        self.dim = dim
        self.use_adapter = use_adapter

        # Self-attention
        self.self_attn = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            attn_dropout=dropout
        )

        # Cross-modal attention
        self.cross_attn = CrossModalAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout
        )

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        # Normalization layers
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)

        # Modality adapter
        if use_adapter:
            self.adapter = ModalityAdapter(dim, adapter_dim)
        else:
            self.adapter = None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        modality_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        attn_out = self.self_attn(
            self.norm1(x),
            attention_mask=attention_mask
        )[0]  # take only the attention output, not the weights
        x = x + attn_out

        if context is not None:
            cross_attn_out = self.cross_attn(
                self.norm2(x),
                context,
                context,
                attention_mask=None
            )
            x = x + cross_attn_out

        # Feed-forward network
        x = x + self.ffn(self.norm3(x))

        # Modality adapter
        if self.use_adapter and modality_id is not None and self.adapter is not None:
            x = self.adapter(x, modality_id)

        return x


class PerceiverResampler(nn.Module):
    """Perceiver Resampler: compresses modality features into a fixed number of latent tokens."""

    def __init__(
        self,
        dim: int,
        depth: int = 6,
        num_latents: int = 64,
        n_heads: int = 16,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=False
            )
            for _ in range(depth)
        ])

        self.norm = RMSNorm(dim)

        # Initialize the latent queries
        nn.init.trunc_normal_(self.latents, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)

        # Refine the latents with stacked cross-attention layers
        for layer in self.layers:
            latents = layer(latents, context=x)

        return self.norm(latents)
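# Minimal sketch of the resampler (illustrative only; the shapes and the
# _demo_* name are hypothetical, and it assumes GroupedQueryAttention from
# `transformer` follows the call convention used by CrossModalFusionLayer):
# a variable-length feature sequence is compressed to a fixed 32 latent tokens.
def _demo_perceiver_resampler() -> torch.Tensor:
    resampler = PerceiverResampler(dim=512, depth=2, num_latents=32, n_heads=8)
    frame_feats = torch.randn(2, 300, 512)  # e.g. video frame features [B, T, D]
    return resampler(frame_feats)           # -> [2, 32, 512]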
class MultiModalFusionModule(nn.Module):
    """Multi-modal fusion module: ties all of the fusion strategies together."""

    def __init__(
        self,
        dim: int = 2048,
        num_fusion_layers: int = 4,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_perceiver: bool = True,
        num_latents: int = 64,
        use_contrastive: bool = True,
        contrastive_loss_type: str = 'siglip',
        contrastive_embed_dim: int = 512
    ):
        super().__init__()
        self.dim = dim
        self.use_perceiver = use_perceiver
        self.use_contrastive = use_contrastive

        # Per-modality projectors
        self.modality_projectors = nn.ModuleDict({
            'image': ModalityProjector(dim, dim),
            'audio': ModalityProjector(dim, dim),
            'video': ModalityProjector(dim, dim),
            'text': ModalityProjector(dim, dim)
        })

        # Cross-modal fusion layers
        self.fusion_layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=True
            )
            for _ in range(num_fusion_layers)
        ])

        # Perceiver Resampler
        if use_perceiver:
            self.perceiver = PerceiverResampler(
                dim=dim,
                depth=4,
                num_latents=num_latents,
                n_heads=n_heads,
                dropout=dropout
            )

        # Contrastive-learning module
        if use_contrastive:
            # Pooling type per modality; input dims are all `dim` here
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
            input_dims = {k: dim for k in modality_config.keys()}
            self.contrastive_module = MultiModalContrastiveLoss(
                embed_dim=contrastive_embed_dim,
                input_dims=input_dims,
                temperature=0.07,
                loss_type=contrastive_loss_type,
                modality_config=modality_config
            )

        self.final_norm = RMSNorm(dim)

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """Pool features to a single vector: [B, T, D] -> [B, D]."""
        if features.dim() == 3:
            return features.mean(dim=1)
        return features

    def forward(
        self,
        segments: List[Dict],
        compute_contrastive: bool = False
    ) -> Dict:
        # Split the input segments by modality
        modality_features = {}
        modality_ids = {}

        for seg in segments:
            mod_type = seg['type']
            mod_data = seg['data']
            mod_id = seg['modality_id']

            # Validate the tensor rank
            if mod_data.dim() != 3:
                raise ValueError(
                    f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
                    f"got shape {mod_data.shape}"
                )

            # Project into the shared space
            if mod_type in self.modality_projectors:
                projected = self.modality_projectors[mod_type](mod_data)
            else:
                projected = mod_data

            # Optionally compress non-text modalities with the Perceiver
            if self.use_perceiver and mod_type != 'text':
                projected = self.perceiver(projected)

            modality_features[mod_type] = projected
            modality_ids[mod_type] = mod_id

        # Cross-modal fusion
        fused_features = {}
        for mod_type, features in modality_features.items():
            # Build a context from every modality except the current one
            if len(modality_features) > 1:
                other_features = torch.cat([
                    f for k, f in modality_features.items() if k != mod_type
                ], dim=1)
            else:
                other_features = None

            # Pass through the fusion stack
            fused = features
            for layer in self.fusion_layers:
                fused = layer(
                    fused,
                    context=other_features,
                    modality_id=modality_ids[mod_type]
                )

            fused_features[mod_type] = self.final_norm(fused)

        # Compute contrastive losses if requested (pooling is delegated to
        # the contrastive module via its modality_config)
        contrastive_losses = {}
        if compute_contrastive and self.use_contrastive:
            pooled_features = fused_features

            # Pair every non-text modality with text
            modality_pairs = []
            if 'text' in pooled_features:
                for mod in pooled_features.keys():
                    if mod != 'text':
                        modality_pairs.append((mod, 'text'))

            # Invoke the contrastive module
            if modality_pairs:
                contrastive_losses = self.contrastive_module(
                    pooled_features,
                    modality_pairs=modality_pairs
                )

        # Concatenate all fused modality features along the sequence dimension
        fused_sequence = torch.cat(list(fused_features.values()), dim=1)

        return {
            'fused_features': fused_sequence,
            'modality_features': fused_features,
            'contrastive_losses': contrastive_losses
        }


class EarlyFusionModule(nn.Module):
    """Early fusion: merges modalities at a shallow layer."""

    def __init__(self, dim: int = 2048):
        super().__init__()
        self.fusion_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, segments: List[Dict]) -> torch.Tensor:
        """Naively concatenate all modalities along the sequence dimension."""
        all_features = [seg['data'] for seg in segments]
        fused = torch.cat(all_features, dim=1)
        fused = self.fusion_proj(fused)
        return self.norm(fused)


class LateFusionModule(nn.Module):
    """Late fusion: merges modalities only at a deep layer."""

    def __init__(
        self,
        dim: int = 2048,
        num_modalities: int = 4,
        fusion_method: str = 'concat'  # 'concat', 'attention', or 'average'
    ):
        super().__init__()
        self.fusion_method = fusion_method

        if fusion_method == 'concat':
            self.fusion_proj = nn.Linear(dim * num_modalities, dim)
        elif fusion_method == 'attention':
            self.attention_weights = nn.Linear(dim, 1)

        self.norm = RMSNorm(dim)

    def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
        if self.fusion_method == 'concat':
            # Pool each modality, concatenate, and project back to dim
            pooled = [x.mean(dim=1) for x in modality_outputs]
            fused = torch.cat(pooled, dim=-1)
            fused = self.fusion_proj(fused)
        elif self.fusion_method == 'attention':
            # Attention-weighted pooling over modalities
            stacked = torch.stack([x.mean(dim=1) for x in modality_outputs], dim=1)
            weights = F.softmax(self.attention_weights(stacked), dim=1)
            fused = (stacked * weights).sum(dim=1)
        else:  # average
            stacked = torch.stack([x.mean(dim=1) for x in modality_outputs], dim=1)
            fused = stacked.mean(dim=1)

        return self.norm(fused)
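# End-to-end usage sketch (illustrative only; the shapes, modality_id values,
# and the _demo_* name are hypothetical). Each segment is a dict with 'type',
# 'data', and 'modality_id' keys, as consumed by MultiModalFusionModule.forward.
# Contrastive learning is disabled here so the sketch does not depend on
# MultiModalContrastiveLoss.
def _demo_multimodal_fusion() -> Dict:
    fusion = MultiModalFusionModule(
        dim=512,
        num_fusion_layers=2,
        n_heads=8,
        use_perceiver=True,
        num_latents=16,
        use_contrastive=False
    )
    segments = [
        {'type': 'text', 'data': torch.randn(2, 32, 512), 'modality_id': 0},
        {'type': 'image', 'data': torch.randn(2, 196, 512), 'modality_id': 1},
    ]
    out = fusion(segments)
    # out['fused_features'] has shape [2, 32 + 16, 512]:
    # 32 text tokens plus 16 Perceiver latents for the image.
    return out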