import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .ETC import ETC
from .UWE import UWE
from .Encoder import MLPEncoder


class CFDTM(nn.Module):
    '''
        Modeling Dynamic Topics in Chain-Free Fashion by Evolution-Tracking
        Contrastive Learning and Unassociated Word Exclusion. ACL 2024 Findings.

        Xiaobao Wu, Xinshuai Dong, Liangming Pan, Thong Nguyen, Anh Tuan Luu.
    '''
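
    # Shape conventions inferred from this file: T = num_times, K = num_topics,
    # V = vocab_size, D = embed_size; B denotes the document batch size.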

    def __init__(self,
                 vocab_size,
                 train_time_wordfreq,
                 num_times,
                 pretrained_WE=None,
                 num_topics=50,
                 en_units=100,
                 temperature=0.1,
                 beta_temp=1.0,
                 weight_neg=1.0e+7,
                 weight_pos=1.0e+1,
                 weight_UWE=1.0e+3,
                 neg_topk=15,
                 dropout=0.,
                 embed_size=200
                 ):
        super().__init__()

        self.num_topics = num_topics
        self.beta_temp = beta_temp
        self.train_time_wordfreq = train_time_wordfreq

        self.encoder = MLPEncoder(vocab_size, num_topics, en_units, dropout)

        # Fixed logistic-normal prior approximating a symmetric Dirichlet(a)
        # prior via the Laplace approximation (as in Srivastava & Sutton, 2017).
        self.a = np.ones((1, num_topics)).astype(np.float32)
        self.mu2 = nn.Parameter(torch.as_tensor((np.log(self.a).T - np.mean(np.log(self.a), 1)).T))
        self.var2 = nn.Parameter(torch.as_tensor((((1.0 / self.a) * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(1.0 / self.a, 1)).T))
        self.mu2.requires_grad = False
        self.var2.requires_grad = False

        self.decoder_bn = nn.BatchNorm1d(vocab_size, affine=False)

        if pretrained_WE is None:
            self.word_embeddings = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size), std=0.1)
            self.word_embeddings = nn.Parameter(F.normalize(self.word_embeddings))
        else:
            self.word_embeddings = nn.Parameter(torch.from_numpy(pretrained_WE).float())

        # topic_embeddings: T x K x D, one K x D initialization repeated across time slices.
        self.topic_embeddings = nn.init.xavier_normal_(torch.zeros(num_topics, self.word_embeddings.shape[1])).repeat(num_times, 1, 1)
        self.topic_embeddings = nn.Parameter(self.topic_embeddings)

        self.ETC = ETC(num_times, temperature, weight_neg, weight_pos)
        self.UWE = UWE(self.ETC, num_times, temperature, weight_UWE, neg_topk)

    def get_beta(self):
        # Per-time-slice topic-word matrices (T x K x V), obtained by softmaxing
        # the negative embedding distances with temperature beta_temp.
        dist = self.pairwise_euclidean_dist(F.normalize(self.topic_embeddings, dim=-1), F.normalize(self.word_embeddings, dim=-1))
        beta = F.softmax(-dist / self.beta_temp, dim=1)
        return beta

    def pairwise_euclidean_dist(self, x, y):
        # Squared Euclidean distances via ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y;
        # with x of shape T x K x D and y of shape V x D this broadcasts to T x K x V.
        cost = torch.sum(x ** 2, dim=-1, keepdim=True) + torch.sum(y ** 2, dim=-1) - 2 * torch.matmul(x, y.t())
        return cost

    def get_theta(self, x, times=None):
        theta, mu, logvar = self.encoder(x)
        if self.training:
            return theta, mu, logvar
        return theta

    def get_KL(self, mu, logvar):
        # Closed-form KL divergence between the encoder posterior N(mu, var)
        # and the fixed prior N(mu2, var2), averaged over the batch.
        var = logvar.exp()
        var_division = var / self.var2
        diff = mu - self.mu2
        diff_term = diff * diff / self.var2
        logvar_division = self.var2.log() - logvar
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(dim=1) - self.num_topics)
        return KLD.mean()

    def get_NLL(self, theta, beta, x, recon_x=None):
        # Per-document negative log-likelihood of the bag-of-words input x
        # under the reconstructed word distribution.
        if recon_x is None:
            recon_x = self.decode(theta, beta)
        recon_loss = -(x * recon_x.log()).sum(dim=1)
        return recon_loss

    def decode(self, theta, beta):
        # theta: B x K; beta: B x K x V (each document's time slice), so bmm
        # yields a B x V word distribution after batch norm and softmax.
        d1 = F.softmax(self.decoder_bn(torch.bmm(theta.unsqueeze(1), beta).squeeze(1)), dim=-1)
        return d1
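
    # The total loss sums four terms: the KL of the document-topic posterior,
    # the reconstruction NLL, the evolution-tracking contrastive (ETC) loss,
    # and the unassociated word exclusion (UWE) loss.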
    def forward(self, x, times):
        loss = 0.

        theta, mu, logvar = self.get_theta(x)
        kl_theta = self.get_KL(mu, logvar)
        loss += kl_theta

        beta = self.get_beta()
        # Select each document's time-slice topic-word matrix: B x K x V.
        time_index_beta = beta[times]
        recon_x = self.decode(theta, time_index_beta)

        NLL = self.get_NLL(theta, time_index_beta, x, recon_x).mean()
        loss += NLL

        loss_ETC = self.ETC(self.topic_embeddings)
        loss += loss_ETC

        loss_UWE = self.UWE(self.train_time_wordfreq, beta, self.topic_embeddings, self.word_embeddings)
        loss += loss_UWE

        rst_dict = {
            'loss': loss,
        }

        return rst_dict
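

# A minimal smoke-test sketch (an addition, not from the original file). It
# assumes this module sits in its package next to ETC, UWE, and Encoder and is
# run as a module (e.g. `python -m <package>.CFDTM`) so the relative imports
# resolve; the T x V layout of train_time_wordfreq is likewise an assumption
# inferred from how UWE is called above.
if __name__ == "__main__":
    V, T, K, B = 2000, 4, 50, 16        # toy vocab, time slices, topics, batch
    wordfreq = torch.rand(T, V)         # assumed per-time-slice word frequencies
    model = CFDTM(vocab_size=V, train_time_wordfreq=wordfreq, num_times=T, num_topics=K)
    model.train()

    x = torch.rand(B, V)                # bag-of-words document vectors
    times = torch.randint(0, T, (B,))   # time-slice index of each document
    print(model(x, times)['loss'])      # a single scalar training loss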