|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Dict, Optional, Tuple, Union, Literal, List |
|
|
import math |
|
|
import copy |
|
|
|
|
|
class CLIPLoss(nn.Module):
    """CLIP-style bidirectional contrastive loss with a learnable logit scale.

    The scale is parameterized in log space (as in CLIP) and clamped at
    exp-time so the effective multiplier never exceeds ``max_temperature``.
    """

    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature

        # Learnable log-scale, initialized so exp(logit_scale) == 1/temperature.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the symmetric InfoNCE loss over a batch of paired features.

        Args:
            image_features: [B, D]
            text_features: [B, D]

        Returns:
            (total_loss, image-to-text loss, text-to-image loss)
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        # Cap the effective scale so training cannot blow the logits up.
        scale = self.logit_scale.exp().clamp(max=self.max_temperature)
        logits_i2t = scale * img @ txt.T
        logits_t2i = logits_i2t.T

        # Matching pairs sit on the diagonal.
        targets = torch.arange(img.shape[0], device=img.device)

        loss_i2t = F.cross_entropy(logits_i2t, targets)
        loss_t2i = F.cross_entropy(logits_t2i, targets)

        return (loss_i2t + loss_t2i) / 2, loss_i2t, loss_t2i
|
|
|
|
|
class SigLIPLoss(nn.Module):
    """Sigmoid-based pairwise contrastive loss (SigLIP).

    Every image/text pair in the batch contributes an independent binary
    term: diagonal pairs are positives (+1), all others negatives (-1).
    Temperature and bias are both learned.
    """

    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        super().__init__()
        # Temperature lives in log space so exp() keeps it positive.
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the per-row averaged sigmoid loss for [B, D] feature pairs."""
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)
        n = img.shape[0]

        pairwise = img @ txt.T * self.t_prime.exp() + self.b

        # +1 on the diagonal (positives), -1 everywhere else (negatives).
        signs = 2.0 * torch.eye(n, device=img.device) - torch.ones(n, n, device=img.device)

        return -F.logsigmoid(signs * pairwise).sum() / n
|
|
|
|
|
class InfoNCELoss(nn.Module):
    """InfoNCE loss supporting either explicit or in-batch negatives."""

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [B, D]
            positive_key: [B, D]
            negative_keys: [B, N, D] explicit negatives, or None to use the
                other in-batch positives as negatives.
        """
        q = F.normalize(query, dim=-1)
        pos = F.normalize(positive_key, dim=-1)

        if negative_keys is None:
            # In-batch negatives: row i's positive sits in column i.
            scores = q @ pos.T / self.temperature
            target = torch.arange(q.shape[0], dtype=torch.long, device=q.device)
        else:
            neg = F.normalize(negative_keys, dim=-1)
            # Per-sample positive similarity, then per-sample negative scores.
            pos_score = (q * pos).sum(dim=-1, keepdim=True) / self.temperature
            neg_score = torch.einsum('bd,bnd->bn', q, neg) / self.temperature
            # Column 0 holds the positive, so every target index is 0.
            scores = torch.cat([pos_score, neg_score], dim=1)
            target = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)

        return F.cross_entropy(scores, target)
|
|
|
|
|
class ProjectionHead(nn.Module):
    """Pools a token sequence (optionally) and projects it into the joint
    embedding space through a two-layer MLP.

    Pooling applies only to 3-D inputs [B, T, D]; 2-D inputs pass straight
    through to the projection. With ``exclude_first_token`` set, the slot-0
    token is dropped before mean/max pooling (only if more than one token
    is present).
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        super().__init__()
        self.pooling_type = pooling_type
        self.exclude_first_token = exclude_first_token
        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 3:
            mode = self.pooling_type
            if mode == 'cls':
                # First token is taken as the sequence summary.
                x = x[:, 0, :]
            elif mode in ('mean', 'max'):
                tokens = x[:, 1:, :] if (self.exclude_first_token and x.shape[1] > 1) else x
                x = tokens.mean(dim=1) if mode == 'mean' else tokens.max(dim=1)[0]
            # 'none' keeps the full sequence; the MLP runs per token.
        return self.net(x)
|
|
|
|
|
class MultiModalContrastiveLoss(nn.Module):
    """Pairwise contrastive training objective across several modalities.

    Each modality gets its own ProjectionHead mapping raw encoder features
    into a shared ``embed_dim`` space; a contrastive loss is then computed
    for every requested modality pair.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        """
        Args:
            embed_dim: dimension of the shared embedding space.
            input_dims: encoder output dim — one int for all modalities, or a
                per-modality mapping (modalities missing from the mapping get
                no projector and are skipped at loss time).
            temperature: temperature for the 'clip' / InfoNCE losses.
            loss_type: 'clip', 'siglip'; any other value selects InfoNCE.
            modality_config: modality name -> pooling type. Defaults to CLS
                pooling for text/image and mean pooling for audio/video.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            # NOTE(review): unknown loss types silently fall back to InfoNCE;
            # consider raising ValueError for unrecognized values instead.
            self.loss_fn = InfoNCELoss(temperature)

        self.projectors = nn.ModuleDict()

        if modality_config is None:
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
        self.modality_config = modality_config

        for mod_name, pool_type in modality_config.items():
            # Resolve this modality's input dimension (removed a dead
            # `dim = 0` placeholder that was never read).
            if isinstance(input_dims, dict):
                dim = input_dims.get(mod_name)
                if dim is None:
                    # No input dim declared for this modality -> no projector.
                    continue
            else:
                dim = input_dims

            # Token sequences for text/image are assumed to carry a CLS-style
            # token in slot 0; drop it for mean/max pooling.
            exclude_first = mod_name in ('image', 'text') and pool_type in ('mean', 'max')

            self.projectors[mod_name] = ProjectionHead(
                input_dim=dim,
                embed_dim=embed_dim,
                pooling_type=pool_type,
                exclude_first_token=exclude_first
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Compute contrastive losses for the requested modality pairs.

        Args:
            features: modality name -> raw encoder features.
            modality_pairs: pairs to contrast. Defaults to every non-text
                modality paired with 'text'; returns {} if 'text' is absent.

        Returns:
            Mapping '{a}_{b}_loss' -> scalar loss tensor. Pairs whose
            features or projectors are missing are skipped silently.
        """
        if modality_pairs is None:
            if 'text' not in features:
                return {}
            modality_pairs = [(mod, 'text') for mod in features if mod != 'text']

        losses: Dict[str, torch.Tensor] = {}

        for mod_a, mod_b in modality_pairs:
            if mod_a not in features or mod_b not in features:
                continue
            if mod_a not in self.projectors or mod_b not in self.projectors:
                continue

            feat_a = self.projectors[mod_a](features[mod_a])
            feat_b = self.projectors[mod_b](features[mod_b])

            if self.loss_type == 'clip':
                # CLIPLoss also returns the two directional losses; keep
                # only the symmetric total here.
                loss, _, _ = self.loss_fn(feat_a, feat_b)
            else:
                loss = self.loss_fn(feat_a, feat_b)

            losses[f'{mod_a}_{mod_b}_loss'] = loss

        return losses
|
|
|
|
|
class MomentumEncoder(nn.Module):
    """Wraps an encoder with an EMA ("momentum") copy of itself, MoCo-style.

    The momentum copy is a frozen deepcopy of the encoder whose parameters
    track the live encoder via an exponential moving average; buffers
    (e.g. BatchNorm running stats) are mirrored directly.
    """

    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and freeze all of its parameters."""
        momentum_encoder = copy.deepcopy(encoder)
        for param in momentum_encoder.parameters():
            # The EMA copy is never updated by the optimizer.
            param.requires_grad = False
        return momentum_encoder

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA-update momentum parameters and mirror buffers from the live encoder."""
        for param_q, param_k in zip(
            self.encoder.parameters(),
            self.momentum_encoder.parameters()
        ):
            # k <- m * k + (1 - m) * q
            param_k.data.mul_(self.momentum).add_(param_q.data, alpha=1.0 - self.momentum)

        for buffer_q, buffer_k in zip(
            self.encoder.buffers(),
            self.momentum_encoder.buffers()
        ):
            buffer_k.data.copy_(buffer_q.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """Encode ``x`` with either the live or the momentum encoder.

        Bug fix: the EMA update now runs only in training mode. Previously
        every momentum forward mutated the EMA weights, so repeated
        evaluation/inference calls drifted the momentum encoder toward the
        live encoder without any training step having occurred.
        """
        if use_momentum:
            if self.training:
                self._update_momentum_encoder()
            # The EMA branch always runs in eval mode (e.g. fixed BN stats).
            self.momentum_encoder.eval()
            return self.momentum_encoder(x)
        return self.encoder(x)