File size: 1,635 Bytes
11c72a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn


class UWE(nn.Module):
    """Loss over words that topics rank highly but that are absent from a time slice.

    For each time slice ``t`` the candidate negatives are the words that are
    among the ``neg_topk`` largest entries of at least one topic's row of
    ``beta[t]`` AND have a count of zero in ``time_wordcount[t]``. Their
    embeddings are passed, together with ``topic_embeddings[t]``, to
    ``ETC.compute_loss(..., all_neg=True)``. The per-slice results are summed,
    divided by the number of slices that had at least one candidate, and
    multiplied by ``weight_UWE``.

    NOTE(review): the intended effect (presumably pushing topic embeddings
    away from these words) is determined by ``ETC.compute_loss``, which is
    defined elsewhere — confirm against its implementation.
    """

    def __init__(self, ETC, num_times, temperature, weight_UWE, neg_topk):
        """Store the configuration used by :meth:`forward`.

        Args:
            ETC: Object providing
                ``compute_loss(anchor, contrast, temperature=..., all_neg=...)``;
                only that method is called by this module.
            num_times: Expected number of time slices; checked against
                ``time_wordcount.shape[0]`` in :meth:`forward`.
            temperature: Forwarded unchanged to ``ETC.compute_loss``.
            weight_UWE: Multiplier applied to the slice-averaged loss.
            neg_topk: Number of top words per topic considered as candidate
                negatives (clamped to the vocabulary size).
        """
        super().__init__()

        self.ETC = ETC
        self.weight_UWE = weight_UWE
        self.num_times = num_times
        self.temperature = temperature
        self.neg_topk = neg_topk

    def forward(self, time_wordcount, beta, topic_embeddings, word_embeddings):
        """Compute the weighted, slice-averaged loss.

        Args:
            time_wordcount: Per-slice word counts; row ``t`` is compared
                element-wise with zero. Assumed ``(T, V)`` — TODO confirm.
            beta: Topic-word weights. Top-k is taken along the last axis and
                the result is flattened after axis 0, so presumably
                ``(T, K, V)``.
            topic_embeddings: Indexed by slice; ``topic_embeddings[t]`` is K x D.
            word_embeddings: V x D matrix indexed by word id.

        Returns:
            The sum of ``ETC.compute_loss`` over slices with at least one
            candidate negative, scaled by ``weight_UWE / n_valid_slices``.
            If no slice has candidates, a zero scalar tensor with
            ``word_embeddings``' dtype and device.

        Raises:
            AssertionError: if ``time_wordcount.shape[0] != num_times``.
        """
        assert self.num_times == time_wordcount.shape[0], (
            f"expected {self.num_times} time slices, got {time_wordcount.shape[0]}"
        )

        # Keep all index/mask work on the device of time_wordcount so no
        # per-slice host round-trip is needed.
        topk_indices = self.get_topk_indices(beta).to(time_wordcount.device)
        zero_count = time_wordcount == 0

        loss_UWE = 0.
        cnt_valid_times = 0
        for t in range(self.num_times):
            # Candidate words: top-ranked by some topic at t AND unseen at t.
            # Marking a mask also collapses words shared by several topics.
            is_top = torch.zeros_like(zero_count[t])
            is_top[topk_indices[t]] = True
            neg_idx = torch.nonzero(is_top & zero_count[t], as_tuple=True)[0]

            if neg_idx.numel() == 0:
                continue

            # topic_embeddings[t]: K x D
            # word_embeddings[neg_idx]: |V_{neg}| x D
            # NOTE(review): neg_idx is in ascending word-id order; this assumes
            # compute_loss does not depend on the order of its negatives.
            time_neg_WE = word_embeddings[neg_idx]
            loss_UWE = loss_UWE + self.ETC.compute_loss(
                topic_embeddings[t], time_neg_WE, temperature=self.temperature, all_neg=True
            )
            cnt_valid_times += 1

        if cnt_valid_times == 0:
            # Always return a tensor so callers can use .item(), torch.stack, etc.
            return word_embeddings.new_zeros(())

        return loss_UWE * (self.weight_UWE / cnt_valid_times)

    def get_topk_indices(self, beta):
        """Return each topic's top-``neg_topk`` word indices, flattened per time slice.

        Top-k is taken over the last axis of ``beta`` and every axis after the
        first is flattened, so a ``(T, K, V)`` input yields ``(T, K * k)`` with
        ``k = min(neg_topk, V)``. Clamping ``k`` makes a ``neg_topk`` larger
        than the vocabulary select every word instead of raising in
        ``torch.topk``.
        """
        k = min(self.neg_topk, beta.shape[-1])
        topk_indices = torch.topk(beta, k=k, dim=-1).indices
        return torch.flatten(topk_indices, start_dim=1)