Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| try: | |
| import torch.distributed.nn | |
| from torch import distributed as dist | |
| has_distributed = True | |
| except ImportError: | |
| has_distributed = False | |
| try: | |
| import horovod.torch as hvd | |
| except ImportError: | |
| hvd = None | |
def gather_features(
    image_features,
    text_features,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
):
    """Collect image/text features from every worker and concatenate them on dim 0.

    Args:
        image_features: this worker's image features.
        text_features: this worker's text features.
        local_loss: when True, the no-grad gather paths return the gathered
            tensors as-is; when False, this worker's slice is replaced by the
            local (grad-carrying) tensors so gradients still reach them.
        gather_with_grad: use differentiable collectives (horovod ``allgather``
            or ``torch.distributed.nn.all_gather``) instead of no-grad ones.
        rank: this worker's index; selects which slice is replaced.
        world_size: number of workers.
        use_horovod: use horovod instead of ``torch.distributed``.

    Returns:
        ``(all_image_features, all_text_features)``.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            return hvd.allgather(image_features), hvd.allgather(text_features)

        with torch.no_grad():
            img_all = hvd.allgather(image_features)
            txt_all = hvd.allgather(text_features)
        if local_loss:
            return img_all, txt_all

        # ensure grads for local rank when all_* features don't have a gradient.
        # chunk() assumes every worker contributed the same number of rows.
        img_parts = list(img_all.chunk(world_size, dim=0))
        txt_parts = list(txt_all.chunk(world_size, dim=0))
        img_parts[rank] = image_features
        txt_parts[rank] = text_features
        return torch.cat(img_parts, dim=0), torch.cat(txt_parts, dim=0)

    # torch.distributed path: gather tensors from all gpus
    if gather_with_grad:
        img_all = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        txt_all = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        return img_all, txt_all

    img_parts = [torch.zeros_like(image_features) for _ in range(world_size)]
    txt_parts = [torch.zeros_like(text_features) for _ in range(world_size)]
    dist.all_gather(img_parts, image_features)
    dist.all_gather(txt_parts, text_features)
    if not local_loss:
        # ensure grads for local rank when all_* features don't have a gradient
        img_parts[rank] = image_features
        txt_parts[rank] = text_features
    return torch.cat(img_parts, dim=0), torch.cat(txt_parts, dim=0)
class ClipLoss(nn.Module):
    """Symmetric image-text contrastive loss (CLIP / InfoNCE).

    With ``world_size > 1`` features are gathered from all workers via
    ``gather_features``. When ``local_loss`` is set, each worker scores only its
    own samples against the gathered batch; otherwise each worker scores the
    whole gathered batch.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state: label tensors keyed by device, reused while the number
        # of logits stays equal to prev_num_logits.
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target column index for each of ``num_logits`` rows.

        Row ``i`` targets column ``i``. With several workers and ``local_loss``
        the indices are shifted by ``num_logits * rank`` so they address this
        worker's block of the gathered columns. When ``cache_labels`` is set the
        tensor is cached per device and reused, so callers must not modify it
        in place.
        """
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        """Return ``(logits_per_image, logits_per_text)``.

        Single worker: local images vs local texts (and the reverse).
        Several workers with ``local_loss``: local rows vs all gathered columns.
        Several workers without ``local_loss``: the full gathered matrix and its
        transpose.
        """
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features,
                text_features,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
            )

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def _row_padding_mask(self, padding_mask, device):
        """Return a bool mask aligned with the logit rows produced by ``get_logits``."""
        # `~` on an integer tensor is a bitwise NOT (not a logical NOT), so
        # normalise to bool before the mask is ever inverted.
        padding_mask = padding_mask.to(device=device, dtype=torch.bool)
        if self.world_size <= 1 or self.local_loss:
            # Rows are this worker's local samples only.
            return padding_mask
        # Rows span the gathered global batch, so the mask must too.
        if self.use_horovod:
            # Sent as uint8 and converted back to bool after the gather.
            return hvd.allgather(padding_mask.to(torch.uint8)).to(torch.bool)
        gathered = [torch.empty_like(padding_mask) for _ in range(self.world_size)]
        dist.all_gather(gathered, padding_mask)
        return torch.cat(gathered, dim=0)

    def forward(self, image_features, text_features, logit_scale, output_dict=False, padding_mask=None):
        """Compute the symmetric contrastive loss.

        Args:
            image_features: this worker's image features.
            text_features: this worker's text features, row-aligned with
                ``image_features``.
            logit_scale: multiplier applied to the similarity logits.
            output_dict: return ``{"contrastive_loss": loss}`` instead of the
                bare loss tensor.
            padding_mask: optional per-sample mask for this worker's batch;
                truthy entries are real samples, falsy entries are padding.
                Padded rows are excluded from the cross-entropy average (they
                still appear as negative columns). Any dtype is accepted and
                converted to bool.

        Returns:
            Scalar loss tensor, or a dict when ``output_dict`` is True. If every
            row is padding, the loss is a zero still attached to the autograd
            graph rather than NaN.
        """
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        if padding_mask is None:
            total_loss = (
                F.cross_entropy(logits_per_image, labels) +
                F.cross_entropy(logits_per_text, labels)
            ) / 2
            return {"contrastive_loss": total_loss} if output_dict else total_loss

        row_mask = self._row_padding_mask(padding_mask, device)
        ignore_index = -100
        # Clone so the (possibly cached) ground-truth tensor is never mutated.
        labels_mod = labels.clone()
        labels_mod[~row_mask] = ignore_index

        if not bool(row_mask.any()):
            # cross_entropy would average over zero targets and return NaN,
            # which would poison gradient all-reduce; return a graph-connected
            # zero instead so backward() still works.
            total_loss = (logits_per_image.sum() + logits_per_text.sum()) * 0.0
        else:
            total_loss = (
                F.cross_entropy(logits_per_image, labels_mod, ignore_index=ignore_index) +
                F.cross_entropy(logits_per_text, labels_mod, ignore_index=ignore_index)
            ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss
class MultiPosConLossMM(nn.Module):
    """Multi-positive image-text contrastive loss, for batches in which several
    images can correspond to the same text. Refer to https://arxiv.org/abs/2306.00984

    Samples with equal ``labels`` are positives of one another; each row's target
    distribution is uniform over its positives. Samples labelled -1 are ignored.
    """

    def __init__(self, rank, world_size, temperature=0.1, w1=1.0, w2=1.0):
        """
        Args:
            rank: index of this worker.
            world_size: number of workers; features and labels are gathered from
                all of them when greater than 1.
            temperature: stored on the module but not used by ``forward``, which
                scales logits by its ``logit_scale`` argument instead.
            w1: weight for the image-only contrastive part (that part is
                currently fixed at 0, so this has no effect).
            w2: weight for the image-text part.
        """
        super(MultiPosConLossMM, self).__init__()
        self.temperature = temperature
        self.w1 = w1
        self.w2 = w2
        self.last_local_batch_size = None
        self.rank = rank
        self.world_size = world_size

    def compute_cross_entropy(self, p, q):
        """Row-averaged cross-entropy between target distributions ``p`` and logits ``q``."""
        q = F.log_softmax(q, dim=-1)
        loss = torch.sum(p * q, dim=-1)
        return - loss.mean()

    def stablize_logits(self, logits):
        """Subtract each row's (detached) maximum from ``logits``; not used by ``forward``."""
        logits_max, _ = torch.max(logits, dim=-1, keepdim=True)
        logits = logits - logits_max.detach()
        return logits

    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient.
        """
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(self.world_size)]
        dist.all_gather(tensors_gather, tensor, async_op=False)
        output = torch.cat(tensors_gather, dim=0)
        return output

    def forward(self, v_feats, t_feats, logit_scale, labels, output_dict=False):
        """Compute the multi-positive image-text loss.

        Args:
            v_feats: image features for this worker's batch (L2-normalised here).
            t_feats: text features, row-aligned with ``v_feats``.
            logit_scale: multiplier applied to the cosine similarities.
            labels: per-sample group ids, row-aligned with the features; equal
                ids are positives, -1 marks a sample to ignore.
            output_dict: return a dict with the individual terms.

        Returns:
            The weighted loss, or ``{'loss', 'image_loss', 'img_txt_loss'}`` when
            ``output_dict`` is True. If no sample is valid the image-text term is
            a graph-connected zero instead of NaN.
        """
        device = v_feats.device
        v_feats = F.normalize(v_feats, dim=-1)
        t_feats = F.normalize(t_feats, dim=-1)

        # ====== gather across workers ======
        if self.world_size > 1:
            # dist.nn.all_gather keeps gradients; the labels need none.
            all_v_feats = torch.cat(dist.nn.all_gather(v_feats), dim=0)
            all_t_feats = torch.cat(dist.nn.all_gather(t_feats), dim=0)
            all_labels = self.concat_all_gather(labels)
        else:
            all_v_feats = v_feats
            all_t_feats = t_feats
            all_labels = labels

        # ====== drop ignored samples (label == -1) ======
        valid_mask = (all_labels != -1)
        all_v_feats = all_v_feats[valid_mask]  # [N_valid, D]
        all_t_feats = all_t_feats[valid_mask]  # [N_valid, D]
        all_labels = all_labels[valid_mask]  # [N_valid]

        # ====== image-text logits ======
        logits_v = logit_scale * all_v_feats @ all_t_feats.T
        logits_t = logit_scale * all_t_feats @ all_v_feats.T

        if all_labels.numel() == 0:
            # compute_cross_entropy would take the mean of an empty tensor (NaN);
            # return a zero that stays attached to the graph instead.
            img_txt_loss = (logits_v.sum() + logits_t.sum()) * 0.0
        else:
            # Row i's target is uniform over every sample sharing its label.
            label_matrix = torch.eq(all_labels.view(-1, 1),
                                    all_labels).float().to(device)
            label_matrix = label_matrix / label_matrix.sum(1, keepdim=True).clamp(min=1.0)
            img_txt_loss = (self.compute_cross_entropy(label_matrix, logits_v)
                            + self.compute_cross_entropy(label_matrix, logits_t)) / 2

        # NOTE: the image-only supervised contrastive term from the paper is not
        # implemented; it is fixed at 0, so w1 currently has no effect.
        img_loss = 0
        loss = self.w1 * img_loss + self.w2 * img_txt_loss
        return {'loss': loss,
                'image_loss': img_loss,
                'img_txt_loss': img_txt_loss} if output_dict else loss
class CoCaLoss(ClipLoss):
    """Contrastive (CLIP) loss plus a token-level captioning loss, as in CoCa.

    Each term is multiplied by its own weight and the two are returned
    separately rather than summed.
    """

    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod,
        )
        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        # Target tokens equal to pad_id do not contribute to the caption loss.
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        """Return the weighted contrastive and caption losses.

        A falsy ``clip_loss_weight`` skips the contrastive computation entirely
        and reports a zero tensor on ``logits``' device instead.
        """
        if not self.clip_loss_weight:
            clip_loss = torch.tensor(0, device=logits.device)
        else:
            clip_loss = self.clip_loss_weight * super().forward(image_features, text_features, logit_scale)

        # CrossEntropyLoss wants the class axis at position 1, so the last
        # (class) axis of `logits` is moved there.
        token_logits = logits.permute(0, 2, 1)
        caption_loss = self.caption_loss(token_logits, labels) * self.caption_loss_weight

        if not output_dict:
            return clip_loss, caption_loss
        return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
class DistillClipLoss(ClipLoss):
    """CLIP contrastive loss plus a distillation term toward a teacher's logits."""

    def dist_loss(self, teacher_logits, student_logits):
        """Soft-target cross-entropy: teacher probabilities vs. student log-probabilities,
        summed over dim 1 and averaged over dim 0."""
        teacher_probs = teacher_logits.softmax(dim=1)
        student_log_probs = student_logits.log_softmax(dim=1)
        return -(teacher_probs * student_log_probs).sum(dim=1).mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        """Return ``(contrastive_loss, distill_loss)`` (or a dict when ``output_dict``).

        The ``dist_*`` arguments are the teacher's features and logit scale; the
        teacher logits are built with the same ``get_logits`` path as the student's.
        """
        student_img, student_txt = self.get_logits(image_features, text_features, logit_scale)
        teacher_img, teacher_txt = self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
        targets = self.get_ground_truth(image_features.device, student_img.shape[0])

        contrastive_loss = (F.cross_entropy(student_img, targets) + F.cross_entropy(student_txt, targets)) / 2
        distill_loss = (self.dist_loss(teacher_img, student_img) + self.dist_loss(teacher_txt, student_txt)) / 2

        if not output_dict:
            return contrastive_loss, distill_loss
        return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
    """Send ``tensor`` to ``to_rank`` while receiving a tensor from ``from_rank``.

    The receive buffer is ``zeros_like(tensor)``, so the peer is assumed to send
    a tensor of the same shape and dtype. Both point-to-point ops are posted
    together with ``batch_isend_irecv`` and waited on before returning. This
    variant has no autograd support; ``neighbour_exchange_with_grad`` wraps it
    differentiably.
    """
    received = torch.zeros_like(tensor)
    ops = [
        torch.distributed.P2POp(torch.distributed.isend, tensor, to_rank, group=group),
        torch.distributed.P2POp(torch.distributed.irecv, received, from_rank, group=group),
    ]
    for handle in torch.distributed.batch_isend_irecv(ops):
        handle.wait()
    return received
def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    """Exchange tensors with both ring neighbours at once.

    Sends ``tensor_to_left`` to ``left_rank`` and ``tensor_to_right`` to
    ``right_rank``, and receives one tensor from each of them.

    Returns:
        ``(tensor_from_right, tensor_from_left)``.
    """
    # What arrives from the left neighbour is what it sent to *its* right, so the
    # buffer is shaped like our own tensor_to_right (and symmetrically for the
    # right) — this assumes every rank sends matching shapes.
    from_left = torch.zeros_like(tensor_to_right)
    from_right = torch.zeros_like(tensor_to_left)

    p2p_op = torch.distributed.P2POp
    isend = torch.distributed.isend
    irecv = torch.distributed.irecv
    send_left = p2p_op(isend, tensor_to_left, left_rank, group=group)
    send_right = p2p_op(isend, tensor_to_right, right_rank, group=group)
    recv_left = p2p_op(irecv, from_left, left_rank, group=group)
    recv_right = p2p_op(irecv, from_right, right_rank, group=group)

    handles = torch.distributed.batch_isend_irecv([send_right, send_left, recv_right, recv_left])
    for handle in handles:
        handle.wait()
    return from_right, from_left
class NeighbourExchange(torch.autograd.Function):
    """Differentiable ``neighbour_exchange``.

    Forward sends ``tensor`` to ``to_rank`` and returns the tensor received from
    ``from_rank``. Backward routes the incoming gradient the opposite way (sent
    to ``from_rank``, received from ``to_rank``), so each gradient goes back to
    the rank that owns the corresponding tensor.
    """

    @staticmethod
    def forward(ctx, from_rank, to_rank, group, tensor):
        ctx.group = group
        ctx.from_rank = from_rank
        ctx.to_rank = to_rank
        return neighbour_exchange(from_rank, to_rank, tensor, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradients for the from_rank / to_rank / group arguments.
        return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)


def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
    """Autograd-aware exchange: send ``tensor`` to ``to_rank``, return what ``from_rank`` sent."""
    return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
class NeighbourExchangeBidir(torch.autograd.Function):
    """Differentiable ``neighbour_exchange_bidir``.

    Forward returns ``(tensor_from_right, tensor_from_left)``. Backward calls the
    same exchange with the ranks swapped, so the gradient of what came from the
    right is sent back to the right (and likewise for the left), and the
    gradients received line up with ``(tensor_to_left, tensor_to_right)``.
    """

    @staticmethod
    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
        ctx.group = group
        ctx.left_rank = left_rank
        ctx.right_rank = right_rank
        return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # No gradients for the left_rank / right_rank / group arguments.
        return (None, None, None) + \
            NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)


def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    """Autograd-aware bidirectional exchange; returns ``(tensor_from_right, tensor_from_left)``."""
    return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
class SigLipLoss(nn.Module):
    """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343

    @article{zhai2023sigmoid,
      title={Sigmoid loss for language image pre-training},
      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
      journal={arXiv preprint arXiv:2303.15343},
      year={2023}
    }
    """

    def __init__(
            self,
            cache_labels: bool = False,
            rank: int = 0,
            world_size: int = 1,
            dist_impl: Optional[str] = None,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        # default to bidir exchange for now, this will likely change
        self.dist_impl = dist_impl or 'bidir'
        assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        """Return a ``(num_logits, num_logits)`` matrix of pairwise +1/-1 targets.

        The diagonal (matching pairs) is +1 and every other entry -1; with
        ``negative_only`` every entry is -1.
        """
        targets = torch.full((num_logits, num_logits), -1, device=device, dtype=dtype)
        if negative_only:
            return targets
        return targets.fill_diagonal_(1)

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
        """Scaled image-vs-text similarity logits, optionally shifted by ``logit_bias``."""
        logits = logit_scale * image_features @ text_features.T
        if logit_bias is None:
            return logits
        logits += logit_bias
        return logits

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        """Sum of ``-log(sigmoid(target * logit))`` over all pairs, divided by the image batch size."""
        batch_size = image_features.shape[0]
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
        targets = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            batch_size,
            negative_only=negative_only,
        )
        return -F.logsigmoid(targets * logits).sum() / batch_size

    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
        """Compute the SigLIP loss for this worker's images.

        The local image/text pairs supply the positives. With ``world_size > 1``,
        texts from every other worker are pulled in (using the strategy named by
        ``dist_impl``) and scored as negatives only.
        """
        loss = self._loss(image_features, text_features, logit_scale, logit_bias)

        def negatives(other_text):
            # Pairings with another worker's texts are always negatives.
            return self._loss(image_features, other_text, logit_scale, logit_bias, negative_only=True)

        if self.world_size > 1:
            impl = self.dist_impl
            right = (self.rank + 1) % self.world_size
            left = (self.rank - 1 + self.world_size) % self.world_size

            if impl == 'bidir':
                to_left = to_right = text_features
                num_bidir, remainder = divmod(self.world_size - 1, 2)
                for _ in range(num_bidir):
                    received = neighbour_exchange_bidir_with_grad(left, right, to_left, to_right)
                    for feats in received:
                        loss += negatives(feats)
                    # Keep passing each tensor onward in the direction it travelled.
                    to_left, to_right = received
                if remainder:
                    received = neighbour_exchange_with_grad(left, right, to_right)
                    loss += negatives(received)
            elif impl == "shift":
                to_right = text_features
                for _ in range(self.world_size - 1):
                    from_left = neighbour_exchange_with_grad(left, right, to_right)
                    loss += negatives(from_left)
                    to_right = from_left
            elif impl == "reduce":
                for i in range(self.world_size):
                    # Only rank i contributes non-zero features, so the sum is rank i's texts.
                    text_from_other = torch.distributed.nn.all_reduce(
                        text_features * (self.rank == i),
                        torch.distributed.ReduceOp.SUM,
                    )
                    loss += float(i != self.rank) * negatives(text_from_other)
            elif impl == "gather":
                all_text = torch.distributed.nn.all_gather(text_features)
                for i in range(self.world_size):
                    loss += float(i != self.rank) * negatives(all_text[i])
            else:
                assert False

        return {"contrastive_loss": loss} if output_dict else loss