# MultiModal/multimodel_fusion.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
from components import RMSNorm
from transformer import GroupedQueryAttention
import math
from contrastive_learning import MultiModalContrastiveLoss
class CrossModalAttention(nn.Module):
def __init__(
self,
dim: int,
n_heads: int = 16,
dropout: float = 0.1,
qkv_bias: bool = True
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.o_proj = nn.Linear(dim, dim)
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
self.norm_q = RMSNorm(dim)
self.norm_k = RMSNorm(dim)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
B, T_q, D = query.shape
T_k = key.shape[1]
        # Pre-normalize queries and keys
query = self.norm_q(query)
key = self.norm_k(key)
        # Project and reshape to [B, n_heads, T, head_dim]
q = self.q_proj(query).view(B, T_q, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
        # Use PyTorch's fused attention (FlashAttention when available) or a manual fallback
if hasattr(F, 'scaled_dot_product_attention'):
dropout_p = self.attn_dropout.p if self.training else 0.0
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attention_mask,
dropout_p=dropout_p,
is_causal=False
)
else:
attn_scores = (q @ k.transpose(-2, -1)) * self.scale
if attention_mask is not None:
attn_scores = attn_scores + attention_mask
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
attn_output = attn_weights @ v
        # Merge heads and project the output
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T_q, D)
output = self.resid_dropout(self.o_proj(attn_output))
return output
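# Illustrative usage sketch (not part of the module API): the shapes below are
# arbitrary stand-ins. The output always keeps the query sequence's shape.
def _demo_cross_modal_attention():
    attn = CrossModalAttention(dim=512, n_heads=8)
    query = torch.randn(2, 32, 512)     # [B, T_q, D], e.g. text tokens
    context = torch.randn(2, 196, 512)  # [B, T_k, D], e.g. image patches
    out = attn(query, context, context)
    assert out.shape == (2, 32, 512)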
class ModalityProjector(nn.Module):
"""模态投影器 - 将不同模态投影到统一空间"""
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dim: Optional[int] = None,
num_layers: int = 2,
use_layer_norm: bool = True
):
super().__init__()
if hidden_dim is None:
hidden_dim = (input_dim + output_dim) // 2
        layers = []
        for i in range(num_layers):
            # First layer consumes input_dim; last layer emits output_dim
            in_dim = input_dim if i == 0 else hidden_dim
            out_dim = output_dim if i == num_layers - 1 else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim))
            if i < num_layers - 1:
                if use_layer_norm:
                    layers.append(RMSNorm(hidden_dim))
                layers.append(nn.GELU())
        self.projector = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
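# Illustrative usage sketch: projecting hypothetical 768-dim ViT patch
# features into a 2048-dim shared space with the default 2-layer MLP.
def _demo_modality_projector():
    proj = ModalityProjector(input_dim=768, output_dim=2048)
    image_feats = torch.randn(2, 196, 768)  # [B, T, input_dim]
    shared = proj(image_feats)
    assert shared.shape == (2, 196, 2048)   # [B, T, output_dim]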
class ModalityAdapter(nn.Module):
"""模态适配器 - 为每个模态学习特定的适配参数"""
def __init__(
self,
dim: int,
bottleneck_dim: int = 64,
num_modalities: int = 4
):
super().__init__()
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Linear(dim, bottleneck_dim),
nn.GELU(),
nn.Linear(bottleneck_dim, dim)
)
for _ in range(num_modalities)
])
for adapter in self.adapters:
nn.init.zeros_(adapter[-1].weight)
nn.init.zeros_(adapter[-1].bias)
def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
if modality_id >= len(self.adapters):
return x
return x + self.adapters[modality_id](x)
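# Illustrative check: the adapter's output layer is zero-initialized above, so
# a freshly built adapter is an exact identity on the residual path.
def _demo_modality_adapter():
    adapter = ModalityAdapter(dim=512, bottleneck_dim=64, num_modalities=4)
    x = torch.randn(2, 10, 512)
    assert torch.allclose(adapter(x, modality_id=0), x)  # zero residual at init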
class CrossModalFusionLayer(nn.Module):
"""跨模态融合层"""
def __init__(
self,
dim: int,
n_heads: int = 16,
dropout: float = 0.1,
use_adapter: bool = True,
adapter_dim: int = 64
):
super().__init__()
self.dim = dim
self.use_adapter = use_adapter
        # Self-attention
self.self_attn = GroupedQueryAttention(
dim=dim,
n_heads=n_heads,
dropout=dropout,
attn_dropout=dropout
)
        # Cross-modal attention
self.cross_attn = CrossModalAttention(
dim=dim,
n_heads=n_heads,
dropout=dropout
)
        # Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim),
nn.Dropout(dropout)
)
        # Normalization layers (applied pre-attention / pre-FFN)
self.norm1 = RMSNorm(dim)
self.norm2 = RMSNorm(dim)
self.norm3 = RMSNorm(dim)
        # Optional per-modality adapter
if use_adapter:
self.adapter = ModalityAdapter(dim, adapter_dim)
else:
self.adapter = None
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
modality_id: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
attn_out = self.self_attn(
self.norm1(x),
attention_mask=attention_mask
        )[0]  # GroupedQueryAttention returns a tuple; keep only the attention output
x = x + attn_out
if context is not None:
cross_attn_out = self.cross_attn(
self.norm2(x),
context,
context,
attention_mask=None
)
x = x + cross_attn_out
        # Feed-forward network
x = x + self.ffn(self.norm3(x))
        # Per-modality adapter
if self.use_adapter and modality_id is not None and self.adapter is not None:
x = self.adapter(x, modality_id)
return x
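# Illustrative usage sketch: one fusion layer updating text tokens against an
# image context. This assumes GroupedQueryAttention (imported above) accepts
# these constructor arguments and returns a tuple, as forward above expects.
def _demo_fusion_layer():
    layer = CrossModalFusionLayer(dim=512, n_heads=8)
    text = torch.randn(2, 32, 512)
    image_ctx = torch.randn(2, 196, 512)
    out = layer(text, context=image_ctx, modality_id=3)
    assert out.shape == (2, 32, 512)  # the query's shape is preserved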
class PerceiverResampler(nn.Module):
"""Perceiver Resampler - 压缩模态特征到固定数量的tokens"""
def __init__(
self,
dim: int,
depth: int = 6,
num_latents: int = 64,
n_heads: int = 16,
dropout: float = 0.0
):
super().__init__()
self.num_latents = num_latents
self.latents = nn.Parameter(torch.randn(num_latents, dim))
self.layers = nn.ModuleList([
CrossModalFusionLayer(
dim=dim,
n_heads=n_heads,
dropout=dropout,
use_adapter=False
)
for _ in range(depth)
])
self.norm = RMSNorm(dim)
        # Initialize the latent queries
nn.init.trunc_normal_(self.latents, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B = x.shape[0]
latents = self.latents.unsqueeze(0).expand(B, -1, -1)
        # Refine the latents through stacked cross-attention layers
for layer in self.layers:
latents = layer(latents, context=x)
return self.norm(latents)
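# Illustrative usage sketch: the resampler maps any input length to a fixed
# number of latent tokens, so downstream cost is independent of input length.
def _demo_perceiver_resampler():
    resampler = PerceiverResampler(dim=512, depth=2, num_latents=64, n_heads=8)
    video_feats = torch.randn(2, 300, 512)  # 300 frames' worth of features
    latents = resampler(video_feats)
    assert latents.shape == (2, 64, 512)    # fixed-size output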
class MultiModalFusionModule(nn.Module):
"""多模态融合模块 - 整合所有融合策略"""
def __init__(
self,
dim: int = 2048,
num_fusion_layers: int = 4,
n_heads: int = 16,
dropout: float = 0.1,
use_perceiver: bool = True,
num_latents: int = 64,
use_contrastive: bool = True,
contrastive_loss_type: str = 'siglip',
contrastive_embed_dim: int = 512
):
super().__init__()
self.dim = dim
self.use_perceiver = use_perceiver
self.use_contrastive = use_contrastive
        # Per-modality projectors into the shared space
self.modality_projectors = nn.ModuleDict({
'image': ModalityProjector(dim, dim),
'audio': ModalityProjector(dim, dim),
'video': ModalityProjector(dim, dim),
'text': ModalityProjector(dim, dim)
})
        # Cross-modal fusion layers
self.fusion_layers = nn.ModuleList([
CrossModalFusionLayer(
dim=dim,
n_heads=n_heads,
dropout=dropout,
use_adapter=True
)
for _ in range(num_fusion_layers)
])
# Perceiver Resampler
if use_perceiver:
self.perceiver = PerceiverResampler(
dim=dim,
depth=4,
num_latents=num_latents,
n_heads=n_heads,
dropout=dropout
)
        # Contrastive learning module
if use_contrastive:
            # Pooling strategy per modality; all modalities share the same input dim here
modality_config = {
'text': 'cls',
'image': 'cls',
'audio': 'mean',
'video': 'mean'
}
input_dims = {k: dim for k in modality_config.keys()}
self.contrastive_module = MultiModalContrastiveLoss(
embed_dim=contrastive_embed_dim,
input_dims=input_dims,
temperature=0.07,
loss_type=contrastive_loss_type,
modality_config=modality_config
)
self.final_norm = RMSNorm(dim)
def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
"""池化特征到单一向量 [B, T, D] -> [B, D]"""
if features.dim() == 3:
return features.mean(dim=1)
return features
def forward(
self,
segments: List[Dict],
compute_contrastive: bool = False
) -> Dict:
        # Group segments by modality (assumes at most one segment per modality type)
modality_features = {}
modality_ids = {}
for seg in segments:
mod_type = seg['type']
mod_data = seg['data']
mod_id = seg['modality_id']
            # Validate feature dimensionality
if mod_data.dim() != 3:
raise ValueError(
f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
f"got shape {mod_data.shape}"
)
            # Project into the shared space
if mod_type in self.modality_projectors:
projected = self.modality_projectors[mod_type](mod_data)
else:
projected = mod_data
            # Optionally compress with the Perceiver (non-text modalities only)
if self.use_perceiver and mod_type != 'text':
projected = self.perceiver(projected)
modality_features[mod_type] = projected
modality_ids[mod_type] = mod_id
        # Cross-modal fusion
fused_features = {}
for mod_type, features in modality_features.items():
            # Build a context from all modalities except the current one
if len(modality_features) > 1:
other_features = torch.cat([
f for k, f in modality_features.items() if k != mod_type
], dim=1)
else:
other_features = None
            # Run through the fusion stack
fused = features
for layer in self.fusion_layers:
fused = layer(
fused,
context=other_features,
modality_id=modality_ids[mod_type]
)
fused_features[mod_type] = self.final_norm(fused)
        # Compute contrastive losses if requested
contrastive_losses = {}
if compute_contrastive and self.use_contrastive:
            pooled_features = fused_features  # full sequences; pooling is handled inside the contrastive module per modality_config
            # Pair every non-text modality with text for contrast
modality_pairs = []
if 'text' in pooled_features:
for mod in pooled_features.keys():
if mod != 'text':
modality_pairs.append((mod, 'text'))
            # Invoke the contrastive module
if modality_pairs:
contrastive_losses = self.contrastive_module(
pooled_features,
modality_pairs=modality_pairs
)
        # Concatenate all fused modality features into one sequence
fused_sequence = torch.cat(list(fused_features.values()), dim=1)
return {
'fused_features': fused_sequence,
'modality_features': fused_features,
'contrastive_losses': contrastive_losses
}
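# Illustrative usage sketch: fusing an image/text pair. Each segment carries
# pre-extracted [B, T, D] features plus its modality type and adapter id; the
# feature tensors here are random stand-ins. Contrastive loss is disabled so
# the sketch does not depend on MultiModalContrastiveLoss.
def _demo_multimodal_fusion():
    module = MultiModalFusionModule(dim=512, num_fusion_layers=2, n_heads=8,
                                    num_latents=16, use_contrastive=False)
    segments = [
        {'type': 'image', 'data': torch.randn(2, 196, 512), 'modality_id': 0},
        {'type': 'text',  'data': torch.randn(2, 32, 512),  'modality_id': 3},
    ]
    out = module(segments)
    # image tokens were resampled to 16 latents; text kept its 32 tokens
    assert out['fused_features'].shape == (2, 16 + 32, 512)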
class EarlyFusionModule(nn.Module):
"""早期融合 - 在浅层就融合模态"""
def __init__(self, dim: int = 2048):
super().__init__()
self.fusion_proj = nn.Linear(dim, dim)
self.norm = RMSNorm(dim)
def forward(self, segments: List[Dict]) -> torch.Tensor:
"""简单拼接所有模态"""
all_features = [seg['data'] for seg in segments]
fused = torch.cat(all_features, dim=1)
fused = self.fusion_proj(fused)
return self.norm(fused)
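# Illustrative usage sketch: early fusion concatenates along the sequence
# dimension, so all segments must already share the same feature dim.
def _demo_early_fusion():
    fusion = EarlyFusionModule(dim=512)
    segments = [{'type': 'image', 'data': torch.randn(2, 16, 512)},
                {'type': 'text', 'data': torch.randn(2, 8, 512)}]
    assert fusion(segments).shape == (2, 24, 512)  # 16 + 8 tokens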
class LateFusionModule(nn.Module):
"""晚期融合 - 在深层才融合模态"""
def __init__(
self,
dim: int = 2048,
num_modalities: int = 4,
fusion_method: str = 'concat' # 'concat', 'attention', 'average'
):
super().__init__()
self.fusion_method = fusion_method
if fusion_method == 'concat':
self.fusion_proj = nn.Linear(dim * num_modalities, dim)
elif fusion_method == 'attention':
self.attention_weights = nn.Linear(dim, 1)
self.norm = RMSNorm(dim)
def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
if self.fusion_method == 'concat':
            # Pool each modality, concatenate, then project back to dim
pooled = [x.mean(dim=1) for x in modality_outputs]
fused = torch.cat(pooled, dim=-1)
fused = self.fusion_proj(fused)
elif self.fusion_method == 'attention':
            # Attention-weighted average over modalities
stacked = torch.stack([x.mean(dim=1) for x in modality_outputs], dim=1)
weights = F.softmax(self.attention_weights(stacked), dim=1)
fused = (stacked * weights).sum(dim=1)
else: # average
stacked = torch.stack([x.mean(dim=1) for x in modality_outputs], dim=1)
fused = stacked.mean(dim=1)
return self.norm(fused)
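# Illustrative usage sketch: all three late-fusion methods pool each modality
# over time and return a single [B, D] vector per batch element.
def _demo_late_fusion():
    outputs = [torch.randn(2, 10, 512), torch.randn(2, 20, 512)]
    for method in ('concat', 'attention', 'average'):
        fusion = LateFusionModule(dim=512, num_modalities=2, fusion_method=method)
        assert fusion(outputs).shape == (2, 512)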