import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .ETC import ETC
from .UWE import UWE
from .Encoder import MLPEncoder


class CFDTM(nn.Module):
    '''
    Modeling Dynamic Topics in Chain-Free Fashion by Evolution-Tracking Contrastive Learning and Unassociated Word Exclusion. ACL 2024 Findings

    Xiaobao Wu, Xinshuai Dong, Liangming Pan, Thong Nguyen, Anh Tuan Luu.
    '''

    def __init__(self,
                 vocab_size,
                 train_time_wordfreq,
                 num_times,
                 pretrained_WE=None,
                 num_topics=50,
                 en_units=100,
                 temperature=0.1,
                 beta_temp=1.0,
                 weight_neg=1.0e+7,
                 weight_pos=1.0e+1,
                 weight_UWE=1.0e+3,
                 neg_topk=15,
                 dropout=0.,
                 embed_size=200
                ):
        super().__init__()

        self.num_topics = num_topics
        self.beta_temp = beta_temp
        self.train_time_wordfreq = train_time_wordfreq
        self.encoder = MLPEncoder(vocab_size, num_topics, en_units, dropout)

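        # Fixed logistic-normal prior over theta: Laplace approximation to a
        # symmetric Dirichlet(1) prior (as in ProdLDA); mu2 and var2 are not trained.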
        self.a = np.ones((1, num_topics), dtype=np.float32)
        self.mu2 = nn.Parameter(torch.as_tensor((np.log(self.a).T - np.mean(np.log(self.a), 1)).T))
        self.var2 = nn.Parameter(torch.as_tensor((((1.0 / self.a) * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(1.0 / self.a, 1)).T))

        self.mu2.requires_grad = False
        self.var2.requires_grad = False

        self.decoder_bn = nn.BatchNorm1d(vocab_size, affine=False)

        if pretrained_WE is None:
            # Randomly initialize word embeddings and L2-normalize them.
            self.word_embeddings = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size), std=0.1)
            self.word_embeddings = nn.Parameter(F.normalize(self.word_embeddings))
        else:
            self.word_embeddings = nn.Parameter(torch.from_numpy(pretrained_WE).float())

        # topic_embeddings: TxKxD
        self.topic_embeddings = nn.init.xavier_normal_(torch.zeros(num_topics, self.word_embeddings.shape[1])).repeat(num_times, 1, 1)
        self.topic_embeddings = nn.Parameter(self.topic_embeddings)

        self.ETC = ETC(num_times, temperature, weight_neg, weight_pos)
        self.UWE = UWE(self.ETC, num_times, temperature, weight_UWE, neg_topk)

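    # beta has shape (num_times, num_topics, vocab_size); it is computed from
    # distances between normalized topic and word embeddings, with a softmax
    # over the topic dimension.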
    def get_beta(self):
        dist = self.pairwise_euclidean_dist(F.normalize(self.topic_embeddings, dim=-1), F.normalize(self.word_embeddings, dim=-1))
        beta = F.softmax(-dist / self.beta_temp, dim=1)

        return beta

    def pairwise_euclidean_dist(self, x, y):
        # Squared Euclidean distance via ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y.
        cost = torch.sum(x ** 2, dim=-1, keepdim=True) + torch.sum(y ** 2, dim=-1) - 2 * torch.matmul(x, y.t())
        return cost

    def get_theta(self, x, times=None):
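        # NOTE: `times` is unused here; the same encoder is shared across time slices.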
        theta, mu, logvar = self.encoder(x)
        if self.training:
            return theta, mu, logvar

        return theta

    def get_KL(self, mu, logvar):
        # Closed-form KL divergence between the posterior N(mu, exp(logvar))
        # and the fixed prior N(mu2, var2), averaged over the batch.
        var = logvar.exp()
        var_division = var / self.var2
        diff = mu - self.mu2
        diff_term = diff * diff / self.var2
        logvar_division = self.var2.log() - logvar
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(dim=1) - self.num_topics)

        return KLD.mean()

    def get_NLL(self, theta, beta, x, recon_x=None):
        # Per-document reconstruction loss: negative log-likelihood of the
        # Bag-of-Words input under the reconstructed word distribution.
        if recon_x is None:
            recon_x = self.decode(theta, beta)
        recon_loss = -(x * recon_x.log()).sum(dim=1)

        return recon_loss

    def decode(self, theta, beta):
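        # theta: (B, K); beta: (B, K, V), one topic-word matrix per document's
        # time slice. bmm yields (B, V) logits, batch-normalized then softmaxed.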
        d1 = F.softmax(self.decoder_bn(torch.bmm(theta.unsqueeze(1), beta).squeeze(1)), dim=-1)
        return d1

    def forward(self, x, times):
        loss = 0.

        theta, mu, logvar = self.get_theta(x)
        kl_theta = self.get_KL(mu, logvar)
        loss += kl_theta

        # Reconstruction: each document is decoded with the topic-word
        # matrix of its own time slice.
        beta = self.get_beta()
        time_index_beta = beta[times]
        recon_x = self.decode(theta, time_index_beta)
        NLL = self.get_NLL(theta, time_index_beta, x, recon_x).mean()
        loss += NLL

        # Evolution-Tracking Contrastive loss on the topic embeddings (ETC).
        loss_ETC = self.ETC(self.topic_embeddings)
        loss += loss_ETC

        # Unassociated Word Exclusion loss (UWE).
        loss_UWE = self.UWE(self.train_time_wordfreq, beta, self.topic_embeddings, self.word_embeddings)
        loss += loss_UWE

        rst_dict = {
            'loss': loss,
        }

        return rst_dict
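

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: `train_time_wordfreq` is a (num_times, vocab_size) tensor of
# per-slice word frequencies, and inputs are Bag-of-Words vectors; random
# tensors stand in for a real corpus. Because of the relative imports above,
# run this as a module (e.g. `python -m <package>.CFDTM`, package path is
# hypothetical) rather than as a script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    vocab_size, num_times, batch_size = 2000, 10, 32

    # Random stand-ins for a real dataset.
    time_wordfreq = torch.rand(num_times, vocab_size)
    x = torch.rand(batch_size, vocab_size)
    times = torch.randint(0, num_times, (batch_size,))

    model = CFDTM(vocab_size, time_wordfreq, num_times)
    model.train()
    print(model(x, times)['loss'])

    # After training, beta[t] is the topic-word matrix of time slice t.
    model.eval()
    with torch.no_grad():
        beta = model.get_beta()  # (num_times, num_topics, vocab_size)
    print(beta.shape)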