import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Union, Literal, List
import math
import copy


class CLIPLoss(nn.Module):
    """CLIP-style symmetric contrastive loss.

    Maintains a learnable inverse temperature (``logit_scale``) stored in log
    space, matching the original CLIP formulation.
    """

    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        """
        Args:
            temperature: Initial softmax temperature; ``logit_scale`` is
                initialised to ``log(1 / temperature)``.
            max_temperature: Upper clamp for ``exp(logit_scale)``, preventing
                numerically unstable logits as the scale is learned.
        """
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature
        # Learnable log inverse-temperature (scalar parameter).
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the bidirectional cross-entropy contrastive loss.

        Args:
            image_features: [B, D] image embeddings (unnormalised).
            text_features: [B, D] text embeddings (unnormalised).

        Returns:
            Tuple of (total_loss, image-to-text loss, text-to-image loss).
        """
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        # Clamp the exponentiated scale to avoid numerical instability.
        logit_scale = self.logit_scale.exp().clamp(max=self.max_temperature)
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        # Labels: matched pairs sit on the diagonal.
        batch_size = image_features.shape[0]
        labels = torch.arange(batch_size, device=image_features.device)

        # Symmetric cross-entropy in both retrieval directions.
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        total_loss = (loss_i2t + loss_t2i) / 2

        return total_loss, loss_i2t, loss_t2i


class SigLIPLoss(nn.Module):
    """Sigmoid-based pairwise contrastive loss (SigLIP, Zhai et al. 2023).

    Unlike the softmax CLIP loss, every (image, text) pair contributes an
    independent binary term, so no global normalisation over the batch is
    needed.
    """

    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        """
        Args:
            init_temperature: Initial temperature; stored in log space so the
                effective scale ``exp(t_prime)`` stays positive.
            init_bias: Initial additive bias on the logits (large negative to
                start close to "all pairs are negatives").
        """
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Compute the SigLIP loss for a batch of paired embeddings.

        Args:
            image_features: [B, D] image embeddings (unnormalised).
            text_features: [B, D] text embeddings (unnormalised).

        Returns:
            Scalar loss, averaged over the batch dimension.
        """
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        batch_size = image_features.shape[0]

        # logits = exp(t') * (x @ y^T) + b
        logits = image_features @ text_features.T * self.t_prime.exp() + self.b

        # Targets: +1 on the diagonal (matched pairs), -1 everywhere else.
        labels = -torch.ones(batch_size, batch_size, device=image_features.device)
        labels += 2 * torch.eye(batch_size, device=image_features.device)

        # Pairwise binary log-sigmoid loss, normalised by batch size.
        loss = -F.logsigmoid(labels * logits).sum() / batch_size
        return loss


class InfoNCELoss(nn.Module):
    """InfoNCE loss with either explicit or in-batch negatives."""

    def __init__(self, temperature: float = 0.07):
        """
        Args:
            temperature: Softmax temperature dividing all similarities.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute InfoNCE for a batch of (query, positive) pairs.

        Args:
            query: [B, D] query embeddings.
            positive_key: [B, D] matching positive embeddings.
            negative_keys: Optional [B, N, D] per-query negatives. When None,
                the other in-batch positives serve as negatives.

        Returns:
            Scalar cross-entropy loss.
        """
        query = F.normalize(query, dim=-1)
        positive_key = F.normalize(positive_key, dim=-1)

        if negative_keys is not None:
            # Explicit negatives: one positive per query at column 0.
            pos_sim = (query * positive_key).sum(dim=-1) / self.temperature

            negative_keys = F.normalize(negative_keys, dim=-1)
            # neg_sim: [B, N]
            neg_sim = (query.unsqueeze(1) * negative_keys).sum(dim=-1) / self.temperature

            # [B, 1 + N]; the positive always occupies index 0.
            logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
            labels = torch.zeros(query.shape[0], dtype=torch.long, device=query.device)
        else:
            # In-batch negatives: full similarity matrix, diagonal positives.
            logits = query @ positive_key.T / self.temperature
            labels = torch.arange(query.shape[0], dtype=torch.long, device=query.device)

        loss = F.cross_entropy(logits, labels)
        return loss


class ProjectionHead(nn.Module):
    """Two-layer MLP projector with optional sequence pooling.

    3D inputs [B, Seq, D] are pooled to [B, D] according to ``pooling_type``
    before projection; 2D inputs pass straight through to the MLP.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        """
        Args:
            input_dim: Dimensionality of the incoming features.
            embed_dim: Dimensionality of the projected embedding.
            pooling_type: How to reduce a [B, Seq, D] input to [B, D].
            exclude_first_token: For 'mean'/'max' pooling, drop the first
                token (e.g. a CLS token) before pooling.
        """
        super().__init__()
        self.pooling_type = pooling_type
        self.exclude_first_token = exclude_first_token
        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool (if 3D) and project.

        Args:
            x: [B, D] or [B, Seq, D] features.

        Returns:
            Projected features; [B, embed_dim] after pooling, or the input
            shape with the last dim projected when ``pooling_type='none'``.
        """
        # Reduce 3D sequences [B, Seq, D] -> [B, D] before the MLP.
        if x.dim() == 3:
            if self.pooling_type == 'cls':
                x = x[:, 0, :]
            elif self.pooling_type == 'mean':
                if self.exclude_first_token and x.shape[1] > 1:
                    x = x[:, 1:, :].mean(dim=1)
                else:
                    x = x.mean(dim=1)
            elif self.pooling_type == 'max':
                if self.exclude_first_token and x.shape[1] > 1:
                    x = x[:, 1:, :].max(dim=1)[0]
                else:
                    x = x.max(dim=1)[0]
            elif self.pooling_type == 'none':
                # Keep the full sequence; Linear applies on the last dim.
                pass

        return self.net(x)


class MultiModalContrastiveLoss(nn.Module):
    """Contrastive loss across multiple modalities with per-modality projectors.

    Each configured modality gets its own :class:`ProjectionHead`; pairs of
    modalities are then compared with a CLIP, SigLIP, or InfoNCE loss.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        """
        Args:
            embed_dim: Shared embedding dimension after projection.
            input_dims: Either one dim for all modalities, or a per-modality
                mapping. Modalities missing from the mapping are skipped.
            temperature: Temperature for the CLIP/InfoNCE losses.
            loss_type: 'clip', 'siglip', or anything else for InfoNCE.
            modality_config: Mapping modality name -> pooling type. Defaults
                to CLS pooling for text/image and mean pooling for audio/video.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            self.loss_fn = InfoNCELoss(temperature)

        self.projectors = nn.ModuleDict()
        if modality_config is None:
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
        self.modality_config = modality_config

        # Build one projection head per configured modality.
        for mod_name, pool_type in modality_config.items():
            if isinstance(input_dims, dict):
                dim = input_dims.get(mod_name)
                # No dim given for this modality: skip it instead of crashing.
                if dim is None:
                    continue
            else:
                dim = input_dims

            # For token-sequence modalities, drop the leading CLS token before
            # mean/max pooling so it does not skew the pooled statistics.
            exclude_first = (
                mod_name in ['image', 'text'] and pool_type in ['mean', 'max']
            )

            self.projectors[mod_name] = ProjectionHead(
                input_dim=dim,
                embed_dim=embed_dim,
                pooling_type=pool_type,
                exclude_first_token=exclude_first
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Compute contrastive losses for the requested modality pairs.

        Args:
            features: Mapping modality name -> raw feature tensor.
            modality_pairs: Explicit (mod_a, mod_b) pairs to contrast. When
                None, every non-text modality is paired against 'text'.

        Returns:
            Mapping '<mod_a>_<mod_b>_loss' -> scalar loss tensor. Empty when
            no valid pairs exist.
        """
        # Default pairing: contrast every non-text modality against text.
        if modality_pairs is None:
            if 'text' in features:
                modality_pairs = [
                    (mod, 'text') for mod in features.keys() if mod != 'text'
                ]
            else:
                return {}

        losses = {}
        for mod_a, mod_b in modality_pairs:
            if mod_a not in features or mod_b not in features:
                continue
            if mod_a not in self.projectors or mod_b not in self.projectors:
                # Modality has no projector (e.g. its input dim was missing).
                continue

            feat_a = self.projectors[mod_a](features[mod_a])
            feat_b = self.projectors[mod_b](features[mod_b])

            loss_key = f'{mod_a}_{mod_b}_loss'
            if self.loss_type == 'clip':
                # CLIPLoss returns (total, i2t, t2i); keep only the total.
                loss, _, _ = self.loss_fn(feat_a, feat_b)
            else:
                loss = self.loss_fn(feat_a, feat_b)
            losses[loss_key] = loss

        return losses


class MomentumEncoder(nn.Module):
    """Wraps an encoder with a frozen EMA (momentum) copy, MoCo-style."""

    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        """
        Args:
            encoder: The online encoder to wrap (trained normally).
            momentum: EMA coefficient m; the momentum weights follow
                ``k = m * k + (1 - m) * q``.
        """
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and freeze all of its parameters."""
        momentum_encoder = copy.deepcopy(encoder)
        for param in momentum_encoder.parameters():
            param.requires_grad = False
        return momentum_encoder

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA-update momentum parameters and copy buffers from the online encoder."""
        for param_q, param_k in zip(
            self.encoder.parameters(), self.momentum_encoder.parameters()
        ):
            # EMA update: k = m * k + (1 - m) * q
            param_k.data.mul_(self.momentum).add_(param_q.data, alpha=1.0 - self.momentum)

        # Buffers (e.g. BatchNorm running stats) are copied verbatim.
        for buffer_q, buffer_k in zip(
            self.encoder.buffers(), self.momentum_encoder.buffers()
        ):
            buffer_k.data.copy_(buffer_q.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """Encode ``x`` with either the online or the momentum encoder.

        Args:
            x: Input tensor for the wrapped encoder.
            use_momentum: When True, first EMA-update the momentum encoder,
                then encode with it under no_grad (targets carry no gradient).

        Returns:
            The encoder output.
        """
        if not use_momentum:
            return self.encoder(x)

        with torch.no_grad():
            self._update_momentum_encoder()
            # The momentum branch always runs in eval mode so stochastic
            # layers (dropout, BN) produce stable target features.
            self.momentum_encoder.eval()
            return self.momentum_encoder(x)