Spaces:
Running
Running
import torch
import torch.nn as nn
class UWE(nn.Module):
    """Per-time-slice loss over "negative" words, delegated to ``ETC.compute_loss``.

    For every time slice ``t`` the negatives are the vocabulary indices that
    (a) appear among the top ``neg_topk`` entries of at least one row of
    ``beta[t]`` and (b) have a zero entry in ``time_wordcount[t]``. Their word
    embeddings are passed, together with ``topic_embeddings[t]``, to
    ``ETC.compute_loss(..., all_neg=True)``.
    """

    def __init__(self, ETC, num_times, temperature, weight_UWE, neg_topk):
        """Store the configuration.

        Args:
            ETC: Object exposing ``compute_loss(anchor, negatives,
                temperature=..., all_neg=...)``.
            num_times: Expected size of the first dimension of
                ``time_wordcount`` in :meth:`forward`.
            temperature: Forwarded unchanged to ``ETC.compute_loss``.
            weight_UWE: Scalar multiplier applied to the averaged loss.
            neg_topk: ``k`` used for the top-k selection over ``beta``.
        """
        super().__init__()
        self.ETC = ETC
        self.weight_UWE = weight_UWE
        self.num_times = num_times
        self.temperature = temperature
        self.neg_topk = neg_topk

    def forward(self, time_wordcount, beta, topic_embeddings, word_embeddings):
        """Return the weighted loss averaged over time slices that have negatives.

        Args:
            time_wordcount: Tensor indexed by time slice first; entry
                ``[t, v] == 0`` marks word ``v`` as absent from slice ``t``.
            beta: Tensor whose last dimension is the vocabulary and whose
                first dimension is the time slice (see
                :meth:`get_topk_indices`).
            topic_embeddings: Indexed by time slice; ``topic_embeddings[t]``
                is ``K x D``.
            word_embeddings: Indexed by vocabulary id; the selected rows form
                a ``|V_neg| x D`` matrix.

        Returns:
            ``weight_UWE * (sum of per-slice losses) / (number of slices with
            at least one negative)``. When no slice has any negative, the
            Python float ``0.0`` is returned.

        Raises:
            AssertionError: If ``time_wordcount.shape[0] != num_times``.
        """
        assert self.num_times == time_wordcount.shape[0], (
            f"expected {self.num_times} time slices, "
            f"got {time_wordcount.shape[0]}"
        )

        topk_indices = self.get_topk_indices(beta)

        loss_UWE = 0.
        cnt_valid_times = 0

        for t in range(self.num_times):
            # Words that never occur in this time slice.
            zero_count_idx = torch.where(time_wordcount[t] == 0)[0]

            # Keep only those that are also a top-k word of some topic at t.
            # This is a tensor-level intersection: indices stay unique and in
            # ascending order, without a detour through Python lists/sets.
            # NOTE(review): the previous set-based intersection produced the
            # same indices in an unspecified order; this assumes compute_loss
            # does not depend on the order of the negatives -- confirm.
            time_topk_indices = topk_indices[t].to(zero_count_idx.device)
            neg_idx = zero_count_idx[torch.isin(zero_count_idx, time_topk_indices)]

            if neg_idx.numel() == 0:
                continue

            time_neg_WE = word_embeddings[neg_idx]

            # topic_embeddings[t]: K x D
            # word_embeddings[neg_idx]: |V_{neg}| x D
            loss_UWE = loss_UWE + self.ETC.compute_loss(
                topic_embeddings[t],
                time_neg_WE,
                temperature=self.temperature,
                all_neg=True,
            )
            cnt_valid_times += 1

        if cnt_valid_times > 0:
            # Average over contributing slices only, then apply the weight.
            loss_UWE = loss_UWE * (self.weight_UWE / cnt_valid_times)

        return loss_UWE

    def get_topk_indices(self, beta):
        """Return, per time slice, the flattened top-``neg_topk`` word indices.

        ``torch.topk`` over the last dimension yields ``T x K x neg_topk``
        indices; everything after the first dimension is flattened, giving
        ``T x (K * neg_topk)``. Indices may repeat across topics.
        """
        # topk_indices: T x K x neg_topk
        topk_indices = torch.topk(beta, k=self.neg_topk, dim=-1).indices
        topk_indices = torch.flatten(topk_indices, start_dim=1)
        return topk_indices