Spaces:
Running
Running
| from torch import Tensor | |
| import torch.distributed as dist | |
| import torch | |
| import torch.nn.functional as F | |
| class SimpleContrastiveLoss: | |
| def __init__(self, temperature: float = 0.02): | |
| self.temperature = temperature | |
| def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor: | |
| if target is None: | |
| target_per_qry = y.size(0) // x.size(0) | |
| target = torch.arange( | |
| 0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long) | |
| logits = torch.matmul(x, y.transpose(0, 1)) | |
| loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction) | |
| return loss | |
| class DistributedContrastiveLoss(SimpleContrastiveLoss): | |
| def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02): | |
| assert dist.is_initialized(), "Distributed training has not been properly initialized." | |
| super().__init__() | |
| self.word_size = dist.get_world_size() | |
| self.rank = dist.get_rank() | |
| self.scale_loss = scale_loss | |
| self.temperature = temperature | |
| def __call__(self, x: Tensor, y: Tensor, **kwargs): | |
| dist_x = self.gather_tensor(x) | |
| dist_y = self.gather_tensor(y) | |
| loss = super().__call__(dist_x, dist_y, **kwargs) | |
| if self.scale_loss: | |
| loss = loss * self.word_size | |
| return loss | |
| def gather_tensor(self, t): | |
| gathered = [torch.empty_like(t) for _ in range(self.word_size)] | |
| dist.all_gather(gathered, t) | |
| gathered[self.rank] = t | |
| return torch.cat(gathered, dim=0) | |
| class HardNegativeContrastiveLoss: | |
| def __init__(self, temperature: float = 0.02): | |
| self.temperature = temperature | |
| def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, reduction: str = 'mean') -> Tensor: | |
| # x: query embeddings | |
| # y: positive embeddings | |
| # z: negative embeddings (optional) | |
| if z is None: | |
| target_per_qry = y.size(0) // x.size(0) | |
| target = torch.arange( | |
| 0, x.size(0) * target_per_qry, target_per_qry, | |
| device=x.device, dtype=torch.long) | |
| logits = torch.matmul(x, y.transpose(0, 1)) | |
| loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction) | |
| return loss | |
| pos_logits = torch.matmul(x, y.transpose(0, 1)) | |
| neg_logits = torch.matmul(x, z.transpose(0, 1)) | |
| logits = torch.cat([pos_logits, neg_logits], dim=1) | |
| target = torch.arange(x.size(0), device=x.device) | |
| loss = F.cross_entropy(logits / self.temperature, target, reduction=reduction) | |
| return loss | |
| class DistributedHardNegativeContrastiveLoss(HardNegativeContrastiveLoss): | |
| def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02): | |
| assert dist.is_initialized(), "Distributed training has not been properly initialized." | |
| super().__init__(temperature=temperature) | |
| self.world_size = dist.get_world_size() | |
| self.rank = dist.get_rank() | |
| self.scale_loss = scale_loss | |
| def __call__(self, x: Tensor, y: Tensor, z: Tensor = None, **kwargs): | |
| dist_x = self.gather_tensor(x) | |
| dist_y = self.gather_tensor(y) | |
| dist_z = self.gather_tensor(z) if z is not None else None | |
| loss = super().__call__(dist_x, dist_y, dist_z, **kwargs) | |
| if self.scale_loss: | |
| loss = loss * self.world_size | |
| return loss | |
| def gather_tensor(self, t): | |
| if t is None: | |
| return None | |
| gathered = [torch.empty_like(t) for _ in range(self.world_size)] | |
| dist.all_gather(gathered, t) | |
| gathered[self.rank] = t | |
| return torch.cat(gathered, dim=0) | |