| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
def get_similarity_matrix(
    image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
    """Return the pairwise dot-product similarity matrix.

    Entry (i, j) is the dot product of image embedding i with text
    embedding j. No normalization is applied here — if cosine similarity
    is intended, callers must pass unit-norm features.
    """
    return torch.matmul(image_features, text_features.transpose(0, 1))
def contrastive_loss(logits, dim):
    """InfoNCE-style cross-entropy where matched pairs sit on the diagonal.

    Softmax is taken along `dim` (0 = per column, 1 = per row); the mean
    negative log-probability of the diagonal entries is returned.
    """
    log_probs = F.log_softmax(logits, dim=dim)
    matched_log_probs = torch.diagonal(log_probs)
    return -matched_log_probs.mean()
def contrastive_sigmoid_loss(logits):
    """Sigmoid (SigLIP-style) pairwise loss over a square similarity matrix.

    Each entry is treated as an independent binary classification:
    diagonal entries (matched pairs) have target 1, off-diagonal entries
    target 0. Returns the mean binary cross-entropy over all entries.
    """
    # Build targets on the same device/dtype as the logits; the previous
    # torch.eye(len(logits)) defaulted to CPU/float32 and crashed (device
    # mismatch) for CUDA or half-precision inputs.
    targets = torch.eye(len(logits), device=logits.device, dtype=logits.dtype)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
class CLIPLoss(nn.Module):
    """Symmetric InfoNCE (CLIP) loss over an image-text similarity matrix.

    A single learnable pre-sigmoid temperature scales the logits; the loss
    averages the text-to-image and image-to-text cross-entropies, both of
    which use the diagonal of the matrix as the positive targets.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Learned in pre-sigmoid space; sigmoid() keeps the effective
        # temperature inside (0, 1).
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        temperature = self.logit_temperature.sigmoid()
        scaled = similarity_matrix / temperature
        # dim=0: softmax over images for each text; dim=1: over texts for
        # each image. Matched pairs are the diagonal entries.
        text_to_image = -torch.diag(F.log_softmax(scaled, dim=0)).mean()
        image_to_text = -torch.diag(F.log_softmax(scaled, dim=1)).mean()
        return 0.5 * (text_to_image + image_to_text)
class CyCLIPLoss(nn.Module):
    """CLIP loss plus CyCLIP cyclic-consistency regularizers.

    Adds two MSE penalties to the symmetric contrastive objective:
    one pushing the similarity matrix toward symmetry, and one pulling
    the intra-image and intra-text similarity structures together.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Pre-sigmoid temperature; sigmoid() maps it into (0, 1).
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Fixed regularizer weights for the two cyclic-consistency terms.
        self.lambda_1: float = 1.0
        self.lambda_2: float = 1.0

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        temperature = self.logit_temperature.sigmoid()
        scaled = similarity_matrix / temperature
        text_to_image = -torch.diag(F.log_softmax(scaled, dim=0)).mean()
        image_to_text = -torch.diag(F.log_softmax(scaled, dim=1)).mean()
        clip_term = 0.5 * (text_to_image + image_to_text)
        # Cross-modal consistency: S should equal S^T.
        symmetry_term = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # In-modal consistency: image-image vs text-text similarity structure.
        modality_term = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )
        return clip_term + self.lambda_1 * symmetry_term + self.lambda_2 * modality_term
class SigLIPLoss(nn.Module):
    """Sigmoid (SigLIP-style) pairwise loss with a learnable temperature.

    Every entry of the similarity matrix is scored independently with
    binary cross-entropy: diagonal entries (matched pairs) target 1,
    off-diagonal entries target 0.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Pre-sigmoid temperature; sigmoid() maps it into (0, 1).
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        temperature = self.logit_temperature.sigmoid()
        logits = similarity_matrix / temperature
        # Build targets on the logits' device/dtype; a bare torch.eye(n)
        # defaults to CPU/float32 and crashes for CUDA or half inputs.
        targets = torch.eye(len(logits), device=logits.device, dtype=logits.dtype)
        return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
class CySigLIPLoss(nn.Module):
    """Sigmoid pairwise loss plus CyCLIP-style cyclic-consistency terms.

    Combines the SigLIP binary cross-entropy objective (diagonal = 1,
    off-diagonal = 0) with two MSE regularizers: symmetry of the
    similarity matrix and agreement between intra-image and intra-text
    similarity structure.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Pre-sigmoid temperature; sigmoid() maps it into (0, 1).
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Fixed regularizer weights for the two consistency terms.
        self.lambda_1: float = 1.0
        self.lambda_2: float = 1.0

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        temperature = self.logit_temperature.sigmoid()
        logits = similarity_matrix / temperature
        # Build targets on the logits' device/dtype; a bare torch.eye(n)
        # defaults to CPU/float32 and crashes for CUDA or half inputs.
        targets = torch.eye(len(logits), device=logits.device, dtype=logits.dtype)
        pair_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
        # Cross-modal consistency: S should equal S^T.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # In-modal consistency: image-image vs text-text similarity structure.
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )
        return pair_loss + self.lambda_1 * symmetry_loss + self.lambda_2 * modality_difference_loss
def get_loss(loss_type: str) -> nn.Module:
    """Build the contrastive loss module named by `loss_type`.

    Valid values: "clip", "cyclip", "sigmoid", "cyclic_sigmoid".

    Raises:
        ValueError: if `loss_type` is not one of the valid values.
    """
    # Map to classes, not instances: the old code eagerly constructed all
    # four modules (each allocating an nn.Parameter) only to return one.
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIPLoss,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    try:
        return loss_classes[loss_type]()
    except KeyError:
        # Include the offending value so misconfigurations are diagnosable.
        raise ValueError(f"Invalid loss type: {loss_type!r}") from None