|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
from components import RMSNorm |
|
|
from transformer import GroupedQueryAttention |
|
|
import math |
|
|
from contrastive_learning import MultiModalContrastiveLoss |
|
|
|
|
|
|
|
|
class CrossModalAttention(nn.Module):
    """Multi-head attention where queries come from one modality and
    keys/values from another.

    Queries and keys are RMS-normalized before projection; values are used
    as-is. Uses torch's fused ``scaled_dot_product_attention`` when available,
    otherwise falls back to an explicit softmax-attention implementation.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        qkv_bias: bool = True
    ):
        super().__init__()
        assert dim % n_heads == 0, f"dim {dim} must be divisible by n_heads {n_heads}"

        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        # Separate Q/K/V projections plus an output projection.
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.o_proj = nn.Linear(dim, dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Pre-normalize both input streams so their scales are comparable.
        self.norm_q = RMSNorm(dim)
        self.norm_k = RMSNorm(dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Attend from ``query`` [B, T_q, D] over ``key``/``value`` [B, T_k, D].

        ``attention_mask``, if given, is forwarded to SDPA (or added to the raw
        scores in the fallback path). Returns a [B, T_q, D] tensor.
        """
        batch, q_len, model_dim = query.shape
        kv_len = key.shape[1]

        query = self.norm_q(query)
        key = self.norm_k(key)

        def to_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            # [B, T, D] -> [B, H, T, head_dim]
            return t.view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.q_proj(query), q_len)
        k = to_heads(self.k_proj(key), kv_len)
        v = to_heads(self.v_proj(value), kv_len)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel; attention dropout only applies in training mode.
            attn = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                scores = scores + attention_mask
            probs = self.attn_dropout(F.softmax(scores, dim=-1))
            attn = torch.matmul(probs, v)

        merged = attn.transpose(1, 2).contiguous().view(batch, q_len, model_dim)
        return self.resid_dropout(self.o_proj(merged))
|
|
|
|
|
|
|
|
class ModalityProjector(nn.Module):
    """Modality projector - maps one modality's features into a shared space.

    Builds an MLP of ``num_layers`` Linear layers; every non-final Linear is
    followed by (optional) RMSNorm and GELU.

    Args:
        input_dim: feature size of the incoming modality.
        output_dim: feature size of the shared space.
        hidden_dim: width of intermediate layers; defaults to the mean of
            input and output dims.
        num_layers: number of Linear layers (>= 1).
        use_layer_norm: insert RMSNorm between intermediate layers.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        use_layer_norm: bool = True
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = (input_dim + output_dim) // 2

        layers = []
        for i in range(num_layers):
            # First layer consumes input_dim, last layer produces output_dim.
            # BUGFIX: the original checked `i == 0` before `i == num_layers - 1`,
            # so with num_layers == 1 it built Linear(input_dim, hidden_dim)
            # and never reached output_dim. Selecting both dims explicitly
            # handles num_layers == 1 (a single Linear(input_dim, output_dim))
            # and is identical to the original for num_layers >= 2.
            in_dim = input_dim if i == 0 else hidden_dim
            out_dim = output_dim if i == num_layers - 1 else hidden_dim
            layers.append(nn.Linear(in_dim, out_dim))

            if i < num_layers - 1:
                if use_layer_norm:
                    layers.append(RMSNorm(hidden_dim))
                layers.append(nn.GELU())

        self.projector = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` [..., input_dim] to [..., output_dim]."""
        return self.projector(x)
|
|
|
|
|
|
|
|
class ModalityAdapter(nn.Module):
    """Modality adapter - learns a small residual bottleneck per modality.

    Each adapter is Linear(dim -> bottleneck) -> GELU -> Linear(bottleneck -> dim),
    applied as a residual. The final Linear of every adapter is zero-initialized
    so each adapter starts out as the identity mapping.

    Args:
        dim: feature size of the shared space.
        bottleneck_dim: width of the adapter bottleneck.
        num_modalities: number of independent adapters.
    """

    def __init__(
        self,
        dim: int,
        bottleneck_dim: int = 64,
        num_modalities: int = 4
    ):
        super().__init__()
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, bottleneck_dim),
                nn.GELU(),
                nn.Linear(bottleneck_dim, dim)
            )
            for _ in range(num_modalities)
        ])
        # Zero-init the output projection: adapter(x) == 0 at initialization,
        # so x + adapter(x) == x until training perturbs the weights.
        for adapter in self.adapters:
            nn.init.zeros_(adapter[-1].weight)
            nn.init.zeros_(adapter[-1].bias)

    def forward(self, x: torch.Tensor, modality_id: int) -> torch.Tensor:
        """Apply the adapter for ``modality_id`` residually; unknown ids pass through.

        BUGFIX: the original only checked ``modality_id >= len(self.adapters)``,
        so a negative id fell through to Python's negative indexing and silently
        used the *last* adapter. Any out-of-range id (including negatives) is
        now a no-op.
        """
        if not 0 <= modality_id < len(self.adapters):
            return x
        return x + self.adapters[modality_id](x)
|
|
|
|
|
|
|
|
class CrossModalFusionLayer(nn.Module):
    """Cross-modal fusion layer.

    A pre-norm transformer-style block: self-attention, optional
    cross-attention over a context sequence, feed-forward, and an optional
    per-modality residual adapter.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = True,
        adapter_dim: int = 64
    ):
        super().__init__()
        self.dim = dim
        self.use_adapter = use_adapter

        # Within-modality attention.
        self.self_attn = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            attn_dropout=dropout
        )

        # Attention from this modality's tokens over the other modalities.
        self.cross_attn = CrossModalAttention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout
        )

        # Standard 4x-expansion MLP.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        # One pre-norm per sub-layer.
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)
        self.norm3 = RMSNorm(dim)

        self.adapter = ModalityAdapter(dim, adapter_dim) if use_adapter else None

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        modality_id: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the block to ``x`` [B, T, D].

        ``context`` [B, T_ctx, D], when given, enables the cross-attention
        sub-layer; ``modality_id`` selects the adapter (if enabled).
        """
        # 1) self-attention (GQA returns a tuple; element 0 is the output).
        x = x + self.self_attn(self.norm1(x), attention_mask=attention_mask)[0]

        # 2) cross-attention over the other modalities, when provided
        #    (no mask is applied on the context side).
        if context is not None:
            x = x + self.cross_attn(self.norm2(x), context, context, attention_mask=None)

        # 3) feed-forward.
        x = x + self.ffn(self.norm3(x))

        # 4) modality-specific residual adapter.
        if self.use_adapter and modality_id is not None and self.adapter is not None:
            x = self.adapter(x, modality_id)

        return x
|
|
|
|
|
|
|
|
class PerceiverResampler(nn.Module):
    """Perceiver resampler - compresses a variable-length modality sequence
    into a fixed number of latent tokens via repeated cross-attention."""

    def __init__(
        self,
        dim: int,
        depth: int = 6,
        num_latents: int = 64,
        n_heads: int = 16,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_latents = num_latents

        # Learned latent queries, shared across the batch.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # Fusion blocks without adapters: the latents self-attend and
        # cross-attend to the input sequence.
        self.layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=False
            )
            for _ in range(depth)
        ])

        self.norm = RMSNorm(dim)

        # Re-initialize the latents with a truncated normal (std=0.02).
        nn.init.trunc_normal_(self.latents, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compress ``x`` [B, T, D] into [B, num_latents, D]."""
        batch_size = x.shape[0]
        out = self.latents.unsqueeze(0).expand(batch_size, -1, -1)

        for block in self.layers:
            out = block(out, context=x)

        return self.norm(out)
|
|
|
|
|
|
|
|
class MultiModalFusionModule(nn.Module):
    """Multi-modal fusion module - integrates all fusion strategies.

    Per forward pass:
      1. each modality's features are projected into the shared space;
      2. non-text modalities are optionally compressed to a fixed number of
         latent tokens by a PerceiverResampler;
      3. each modality runs through the fusion-layer stack, cross-attending
         to the concatenation of all *other* modalities;
      4. optionally, contrastive losses are computed between text and each
         non-text modality.
    """

    def __init__(
        self,
        dim: int = 2048,
        num_fusion_layers: int = 4,
        n_heads: int = 16,
        dropout: float = 0.1,
        use_perceiver: bool = True,
        num_latents: int = 64,
        use_contrastive: bool = True,
        contrastive_loss_type: str = 'siglip',
        contrastive_embed_dim: int = 512
    ):
        super().__init__()
        self.dim = dim
        self.use_perceiver = use_perceiver
        self.use_contrastive = use_contrastive

        # One dim -> dim projector per supported modality.
        self.modality_projectors = nn.ModuleDict({
            'image': ModalityProjector(dim, dim),
            'audio': ModalityProjector(dim, dim),
            'video': ModalityProjector(dim, dim),
            'text': ModalityProjector(dim, dim)
        })

        # Shared stack of fusion layers with per-modality adapters enabled.
        self.fusion_layers = nn.ModuleList([
            CrossModalFusionLayer(
                dim=dim,
                n_heads=n_heads,
                dropout=dropout,
                use_adapter=True
            )
            for _ in range(num_fusion_layers)
        ])

        if use_perceiver:
            self.perceiver = PerceiverResampler(
                dim=dim,
                depth=4,  # NOTE: fixed depth, independent of num_fusion_layers
                num_latents=num_latents,
                n_heads=n_heads,
                dropout=dropout
            )

        if use_contrastive:
            # Per-modality pooling strategy handed to the contrastive module.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }

            input_dims = {k: dim for k in modality_config.keys()}

            self.contrastive_module = MultiModalContrastiveLoss(
                embed_dim=contrastive_embed_dim,
                input_dims=input_dims,
                temperature=0.07,
                loss_type=contrastive_loss_type,
                modality_config=modality_config
            )

        self.final_norm = RMSNorm(dim)

    def _pool_features(self, features: torch.Tensor) -> torch.Tensor:
        """Mean-pool features to a single vector: [B, T, D] -> [B, D].

        NOTE(review): currently unused — forward() passes un-pooled sequences
        to the contrastive module (see note there).
        """
        if features.dim() == 3:
            return features.mean(dim=1)
        return features

    def forward(
        self,
        segments: List[Dict],
        compute_contrastive: bool = False
    ) -> Dict:
        """Fuse a list of modality segments.

        Args:
            segments: each dict must carry 'type' (str, key into the projector
                dict), 'data' (a [B, T, D] tensor) and 'modality_id' (int,
                index into the fusion layers' adapters).
            compute_contrastive: also compute contrastive losses against text
                (only effective when built with use_contrastive=True).

        Returns:
            dict with 'fused_features' (all modalities concatenated along the
            sequence axis), 'modality_features' (per-modality tensors) and
            'contrastive_losses' (possibly empty dict).

        Raises:
            ValueError: if any segment's 'data' is not a 3-D tensor.
        """
        modality_features = {}
        modality_ids = {}

        # --- 1. project (and optionally resample) each input segment ---
        # NOTE(review): features are keyed by modality type, so two segments
        # of the same type would silently overwrite each other — confirm
        # callers pass at most one segment per modality.
        for seg in segments:
            mod_type = seg['type']
            mod_data = seg['data']
            mod_id = seg['modality_id']

            if mod_data.dim() != 3:
                raise ValueError(
                    f"Expected 3D tensor [B, T, D] for modality {mod_type}, "
                    f"got shape {mod_data.shape}"
                )

            # Unknown modality types pass through unprojected.
            if mod_type in self.modality_projectors:
                projected = self.modality_projectors[mod_type](mod_data)
            else:
                projected = mod_data

            # Compress every non-text modality to num_latents tokens.
            if self.use_perceiver and mod_type != 'text':
                projected = self.perceiver(projected)

            modality_features[mod_type] = projected
            modality_ids[mod_type] = mod_id

        # --- 2. fuse each modality against all the others ---
        fused_features = {}

        for mod_type, features in modality_features.items():
            # Context = concatenation of every *other* modality's features.
            if len(modality_features) > 1:
                other_features = torch.cat([
                    f for k, f in modality_features.items() if k != mod_type
                ], dim=1)
            else:
                other_features = None

            fused = features
            for layer in self.fusion_layers:
                fused = layer(
                    fused,
                    context=other_features,
                    modality_id=modality_ids[mod_type]
                )

            fused_features[mod_type] = self.final_norm(fused)

        # --- 3. optional contrastive objective against text ---
        contrastive_losses = {}
        if compute_contrastive and self.use_contrastive:
            # NOTE(review): despite the name, these are the *un-pooled*
            # [B, T, D] sequences (_pool_features is never applied here);
            # presumably MultiModalContrastiveLoss pools internally according
            # to modality_config ('cls'/'mean') — confirm.
            pooled_features = fused_features

            # Pair every non-text modality with text.
            modality_pairs = []
            if 'text' in pooled_features:
                for mod in pooled_features.keys():
                    if mod != 'text':
                        modality_pairs.append((mod, 'text'))

            if modality_pairs:
                contrastive_losses = self.contrastive_module(
                    pooled_features,
                    modality_pairs=modality_pairs
                )

        # Concatenate all fused modalities along the sequence axis.
        fused_sequence = torch.cat(list(fused_features.values()), dim=1)

        return {
            'fused_features': fused_sequence,
            'modality_features': fused_features,
            'contrastive_losses': contrastive_losses
        }
|
|
|
|
|
|
|
|
class EarlyFusionModule(nn.Module):
    """Early fusion - merges modalities at a shallow layer by simple sequence
    concatenation followed by a linear projection."""

    def __init__(self, dim: int = 2048):
        super().__init__()
        self.fusion_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, segments: List[Dict]) -> torch.Tensor:
        """Concatenate every segment's 'data' tensor along the time axis,
        project, and normalize."""
        fused = torch.cat([segment['data'] for segment in segments], dim=1)
        return self.norm(self.fusion_proj(fused))
|
|
|
|
|
|
|
|
class LateFusionModule(nn.Module):
    """Late fusion - each modality is processed independently and only the
    pooled representations are combined at the end.

    fusion_method:
      'concat'    - concatenate pooled vectors, project back to ``dim``
      'attention' - softmax-weighted sum of pooled vectors
      otherwise   - plain mean over modalities
    """

    def __init__(
        self,
        dim: int = 2048,
        num_modalities: int = 4,
        fusion_method: str = 'concat'
    ):
        super().__init__()
        self.fusion_method = fusion_method

        # Only build the parameters the chosen method actually needs.
        if fusion_method == 'concat':
            self.fusion_proj = nn.Linear(dim * num_modalities, dim)
        elif fusion_method == 'attention':
            self.attention_weights = nn.Linear(dim, 1)

        self.norm = RMSNorm(dim)

    def forward(self, modality_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Fuse a list of [B, T_i, D] sequences into a single [B, D] vector."""
        # Mean-pool each modality over its own time axis first.
        pooled = [feats.mean(dim=1) for feats in modality_outputs]

        if self.fusion_method == 'concat':
            fused = self.fusion_proj(torch.cat(pooled, dim=-1))
        elif self.fusion_method == 'attention':
            stacked = torch.stack(pooled, dim=1)  # [B, M, D]
            weights = F.softmax(self.attention_weights(stacked), dim=1)
            fused = (stacked * weights).sum(dim=1)
        else:
            # Unknown method: fall back to a uniform average.
            fused = torch.stack(pooled, dim=1).mean(dim=1)

        return self.norm(fused)