# MultiModal / contrastive_learning.py
# (Hugging Face file-page residue: uploaded by szxllm — "Update contrastive_learning.py", commit 28693e2, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Union, Literal, List
import math
import copy
class CLIPLoss(nn.Module):
    """Symmetric contrastive (InfoNCE) loss in the style of CLIP.

    A learnable inverse temperature is stored in log-space as
    ``logit_scale``; its exponential scales the cosine-similarity matrix
    and is clamped at ``max_temperature`` for numerical stability.
    """

    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature
        # Learn 1/temperature in log-space, as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the bidirectional contrastive loss.

        Args:
            image_features: [B, D] image embeddings (un-normalized).
            text_features: [B, D] text embeddings (un-normalized).

        Returns:
            Tuple of (total loss, image->text loss, text->image loss).
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        # Cap the effective scale so logits cannot blow up during training.
        scale = self.logit_scale.exp().clamp(max=self.max_temperature)
        logits_per_image = scale * img @ txt.T
        logits_per_text = logits_per_image.T

        # Matching pairs sit on the diagonal of the similarity matrix.
        targets = torch.arange(img.shape[0], device=img.device)

        loss_i2t = F.cross_entropy(logits_per_image, targets)
        loss_t2i = F.cross_entropy(logits_per_text, targets)
        total_loss = (loss_i2t + loss_t2i) / 2
        return total_loss, loss_i2t, loss_t2i
class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss (SigLIP).

    Every image/text pair is scored independently with a logistic loss:
    matching pairs (the diagonal) carry label +1, all others -1.
    """

    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        super().__init__()
        # Temperature is learned in log-space; the bias is learned directly.
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the sigmoid loss summed over all B*B pairs, divided by B."""
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)
        n = img.shape[0]

        # logits = exp(t') * cosine similarity + bias
        logits = img @ txt.T * self.t_prime.exp() + self.b

        # +1 on the diagonal (positives), -1 off the diagonal (negatives).
        pair_labels = 2 * torch.eye(n, device=img.device) - torch.ones(n, n, device=img.device)

        return -F.logsigmoid(pair_labels * logits).sum() / n
class InfoNCELoss(nn.Module):
    """InfoNCE loss with either explicit per-query negatives or in-batch negatives."""

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute InfoNCE.

        Args:
            query: [B, D] anchor embeddings.
            positive_key: [B, D] positives, row-aligned with ``query``.
            negative_keys: optional [B, N, D] explicit negatives per query.
                When None, the other rows of ``positive_key`` act as
                in-batch negatives.
        """
        q = F.normalize(query, dim=-1)
        pos = F.normalize(positive_key, dim=-1)

        if negative_keys is None:
            # In-batch negatives: row i's positive lies on the diagonal.
            logits = q @ pos.T / self.temperature
            targets = torch.arange(q.shape[0], dtype=torch.long, device=q.device)
        else:
            pos_sim = (q * pos).sum(dim=-1) / self.temperature
            neg = F.normalize(negative_keys, dim=-1)
            # neg_sim: [B, N] similarity of each query to its own negatives.
            neg_sim = (q.unsqueeze(1) * neg).sum(dim=-1) / self.temperature
            # Column 0 holds the positive; columns 1..N are negatives.
            logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
            targets = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)

        return F.cross_entropy(logits, targets)
class ProjectionHead(nn.Module):
    """Two-layer MLP projector with optional sequence pooling.

    3-D inputs [B, S, D] are first reduced to [B, D] according to
    ``pooling_type``; 2-D inputs pass straight through to the MLP.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        super().__init__()
        self.pooling_type = pooling_type
        # When True, 'mean'/'max' pooling skips token 0 (e.g. a CLS token).
        self.exclude_first_token = exclude_first_token
        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def _pool(self, tokens: torch.Tensor) -> torch.Tensor:
        """Reduce [B, S, D] to [B, D] (or leave untouched for 'none')."""
        body = tokens
        if self.exclude_first_token and tokens.shape[1] > 1:
            # Drop the leading token for mean/max pooling only.
            body = tokens[:, 1:, :]
        if self.pooling_type == 'cls':
            return tokens[:, 0, :]
        if self.pooling_type == 'mean':
            return body.mean(dim=1)
        if self.pooling_type == 'max':
            return body.max(dim=1)[0]
        # 'none' (or any unrecognized value): keep the full sequence.
        return tokens

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool (when x is [B, S, D]) then project to embed_dim."""
        if x.dim() == 3:
            x = self._pool(x)
        return self.net(x)
class MultiModalContrastiveLoss(nn.Module):
    """Contrastive alignment across an arbitrary set of modalities.

    Each configured modality gets its own ``ProjectionHead``; pairs of
    projected features are then scored with a CLIP, SigLIP, or InfoNCE loss.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        # Pick the pairwise loss; anything other than 'clip'/'siglip'
        # falls back to plain InfoNCE.
        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            self.loss_fn = InfoNCELoss(temperature)

        self.projectors = nn.ModuleDict()
        if modality_config is None:
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
        self.modality_config = modality_config

        # One projection head per modality that has a known input width.
        for mod_name, pool_type in modality_config.items():
            if isinstance(input_dims, dict):
                dim = input_dims.get(mod_name)
                # No dimension supplied for this modality: skip it rather
                # than crash; its projector simply won't exist.
                if dim is None:
                    continue
            else:
                dim = input_dims
            # Token-sequence modalities carry a CLS token at index 0 here;
            # drop it when mean/max pooling.
            drop_cls = mod_name in ['image', 'text'] and pool_type in ['mean', 'max']
            self.projectors[mod_name] = ProjectionHead(
                input_dim=dim,
                embed_dim=embed_dim,
                pooling_type=pool_type,
                exclude_first_token=drop_cls
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Project the given features and return a dict of pairwise losses.

        Args:
            features: modality name -> raw feature tensor.
            modality_pairs: explicit (mod_a, mod_b) pairs to score. When
                None, every non-text modality is paired with 'text'; if no
                text features are present, an empty dict is returned.
        """
        if modality_pairs is None:
            if 'text' not in features:
                return {}
            modality_pairs = [(mod, 'text') for mod in features.keys() if mod != 'text']

        losses: Dict[str, torch.Tensor] = {}
        for mod_a, mod_b in modality_pairs:
            # Skip pairs we cannot score: missing features or missing projector.
            if mod_a not in features or mod_b not in features:
                continue
            if mod_a not in self.projectors or mod_b not in self.projectors:
                continue

            emb_a = self.projectors[mod_a](features[mod_a])
            emb_b = self.projectors[mod_b](features[mod_b])

            # CLIPLoss returns (total, i2t, t2i); keep only the total.
            if self.loss_type == 'clip':
                pair_loss, _, _ = self.loss_fn(emb_a, emb_b)
            else:
                pair_loss = self.loss_fn(emb_a, emb_b)
            losses[f'{mod_a}_{mod_b}_loss'] = pair_loss
        return losses
class MomentumEncoder(nn.Module):
    """Wraps an encoder with an EMA ("momentum") copy, MoCo-style.

    The momentum copy is a frozen deep copy of the online encoder; each
    momentum forward pass first refreshes it via an exponential moving
    average of the online parameters.
    """

    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and detach the copy from gradient tracking."""
        shadow = copy.deepcopy(encoder)
        for p in shadow.parameters():
            p.requires_grad = False
        return shadow

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA step: k <- m*k + (1-m)*q; buffers are copied verbatim."""
        m = self.momentum
        for q, k in zip(self.encoder.parameters(), self.momentum_encoder.parameters()):
            k.data.mul_(m).add_(q.data, alpha=1.0 - m)
        # Buffers (e.g. BatchNorm running stats) are not averaged, just synced.
        for q_buf, k_buf in zip(self.encoder.buffers(), self.momentum_encoder.buffers()):
            k_buf.data.copy_(q_buf.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """Encode ``x`` with either the online or the momentum encoder.

        When ``use_momentum`` is True the EMA update runs first, and the
        momentum encoder is evaluated in eval mode under no_grad.
        """
        if not use_momentum:
            return self.encoder(x)
        with torch.no_grad():
            self._update_momentum_encoder()
            # The momentum branch never trains: keep it in eval mode.
            self.momentum_encoder.eval()
            return self.momentum_encoder(x)