Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ETC(nn.Module):
    """Contrastive loss over a sequence of per-time-slice embedding matrices.

    Two terms are combined in ``forward``:

    * a *negative* term that pushes the rows of each time slice away from
      one another (``compute_loss(..., self_contrast=True)``), and
    * a *positive* term that pulls row ``k`` of slice ``t`` towards row ``k``
      of the (detached) slice ``t - 1``
      (``compute_loss(..., only_pos=True)``).

    Args:
        num_times: Number of time slices iterated over; must be >= 1.
        temperature: Divisor applied to the cosine-similarity matrix.
        weight_neg: Weight of the averaged negative term.
        weight_pos: Weight of the averaged positive term.
    """

    def __init__(self, num_times, temperature, weight_neg, weight_pos):
        super().__init__()
        self.num_times = num_times
        self.weight_neg = weight_neg
        self.weight_pos = weight_pos
        self.temperature = temperature

    def forward(self, topic_embeddings):
        """Return the weighted sum of the negative and positive terms.

        Args:
            topic_embeddings: Indexable by time slice ``t`` in
                ``range(num_times)``; each ``topic_embeddings[t]`` is
                expected to be a 2-D tensor whose rows are normalised along
                ``dim=1`` (presumably ``K x D`` -- confirm with the caller).

        Returns:
            Scalar tensor ``loss_neg + loss_pos``. With a single time slice
            there is no previous slice to align with, so only the negative
            term is returned.
        """
        loss_neg = 0.
        for t in range(self.num_times):
            loss_neg += self.compute_loss(
                topic_embeddings[t],
                topic_embeddings[t],
                self.temperature,
                self_contrast=True,
            )
        loss_neg *= self.weight_neg / self.num_times

        # Number of (t, t - 1) pairs. With one slice there are none, and the
        # original averaging would divide by zero.
        num_pairs = self.num_times - 1
        if num_pairs < 1:
            return loss_neg

        loss_pos = 0.
        for t in range(1, self.num_times):
            # The previous slice is detached so gradients only flow into the
            # current slice.
            loss_pos += self.compute_loss(
                topic_embeddings[t],
                topic_embeddings[t - 1].detach(),
                self.temperature,
                self_contrast=False,
                only_pos=True,
            )
        loss_pos *= self.weight_pos / num_pairs

        return loss_neg + loss_pos

    def compute_loss(self, anchor_feature, contrast_feature, temperature, self_contrast=False, only_pos=False, all_neg=False):
        """Compute one contrastive term between two embedding matrices.

        The mode is selected by the flags, checked in this order:

        1. ``self_contrast``: log of the summed exponentiated off-diagonal
           logits per row, summed and divided by the number of off-diagonal
           entries. With a single row that count is zero, so the result is
           not finite.
        2. ``only_pos``: negated mean of the diagonal logits.
        3. ``all_neg``: log of the summed exponentiated logits per row,
           summed and divided by the total number of logits.

        Args:
            anchor_feature: 2-D tensor; rows are L2-normalised along dim 1.
            contrast_feature: 2-D tensor; rows are L2-normalised along dim 1.
            temperature: Divisor applied to the similarity matrix.
            self_contrast: Select mode 1.
            only_pos: Select mode 2 (when ``self_contrast`` is falsy).
            all_neg: Select mode 3 (when both other flags are falsy).

        Returns:
            Scalar tensor with the selected loss term.

        Raises:
            ValueError: If none of the three flags is set.
        """
        if not (self_contrast or only_pos or all_neg):
            raise ValueError(
                "compute_loss requires one of self_contrast, only_pos or "
                "all_neg to be True."
            )

        # KxK cosine similarities scaled by the temperature.
        anchor_dot_contrast = torch.div(
            torch.matmul(
                F.normalize(anchor_feature, dim=1),
                F.normalize(contrast_feature, dim=1).T,
            ),
            temperature,
        )

        # Subtract the (detached) row-wise maximum for numerical stability
        # before exponentiating.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        if self_contrast:
            # Push rows of the same time slice away from each other: the
            # diagonal (self-similarity) is masked out.
            neg_mask = 1 - torch.eye(logits.shape[0], device=logits.device)
            sum_exp_logits = (torch.exp(logits) * neg_mask).sum(1)
            log_prob = -torch.log(sum_exp_logits + 1e-12)
            return -log_prob.sum() / neg_mask.sum()

        if only_pos:
            # Only the diagonal pairs (same row index in both inputs) count.
            pos_mask = torch.eye(logits.shape[0], device=logits.device)
            return -(logits * pos_mask).sum() / pos_mask.sum()

        # all_neg: every pair, including the diagonal, is a negative.
        sum_exp_logits = torch.exp(logits).sum(1)
        log_prob = -torch.log(sum_exp_logits + 1e-12)
        return -log_prob.sum() / (logits.shape[0] * logits.shape[1])