import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ETC import ETC
from .UWE import UWE
from .Encoder import MLPEncoder


class CFDTM(nn.Module):
    '''
    Modeling Dynamic Topics in Chain-Free Fashion by Evolution-Tracking
    Contrastive Learning and Unassociated Word Exclusion. ACL 2024 Findings.
    Xiaobao Wu, Xinshuai Dong, Liangming Pan, Thong Nguyen, Anh Tuan Luu.
    '''
    def __init__(self,
                 vocab_size,              # vocabulary size V
                 train_time_wordfreq,     # per-time-slice word frequencies, used by UWE
                 num_times,               # number of time slices T
                 pretrained_WE=None,      # optional pretrained word embeddings (V x D)
                 num_topics=50,           # number of topics K
                 en_units=100,            # hidden units of the MLP encoder
                 temperature=0.1,         # contrastive temperature for ETC/UWE
                 beta_temp=1.0,           # softmax temperature for the topic-word matrix
                 weight_neg=1.0e+7,       # weight of the ETC negative (repulsion) term
                 weight_pos=1.0e+1,       # weight of the ETC positive (tracking) term
                 weight_UWE=1.0e+3,       # weight of the UWE loss
                 neg_topk=15,             # top-k candidate words considered by UWE
                 dropout=0.,              # encoder dropout
                 embed_size=200           # word/topic embedding dimension D
                 ):
        super().__init__()
        self.num_topics = num_topics
        self.beta_temp = beta_temp
        self.train_time_wordfreq = train_time_wordfreq
        self.encoder = MLPEncoder(vocab_size, num_topics, en_units, dropout)
        # Laplace approximation to the Dirichlet prior Dir(alpha) with alpha = 1
        # (as in ProdLDA); the prior mean mu2 and variance var2 are fixed, not learned.
        self.a = 1 * np.ones((1, num_topics)).astype(np.float32)
        self.mu2 = nn.Parameter(torch.as_tensor((np.log(self.a).T - np.mean(np.log(self.a), 1)).T))
        self.var2 = nn.Parameter(torch.as_tensor((((1.0 / self.a) * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(1.0 / self.a, 1)).T))
        self.mu2.requires_grad = False
        self.var2.requires_grad = False
        self.decoder_bn = nn.BatchNorm1d(vocab_size, affine=False)
        # Word embeddings: truncated-normal init followed by L2 normalization when
        # no pretrained embeddings are given; otherwise load the pretrained matrix.
        if pretrained_WE is None:
            self.word_embeddings = nn.init.trunc_normal_(torch.empty(vocab_size, embed_size), std=0.1)
            self.word_embeddings = nn.Parameter(F.normalize(self.word_embeddings))
        else:
            self.word_embeddings = nn.Parameter(torch.from_numpy(pretrained_WE).float())
        # Topic embeddings, shape (num_times, num_topics, embed_size) = T x K x D:
        # one Xavier-initialized K x D matrix repeated across all T time slices.
        self.topic_embeddings = nn.init.xavier_normal_(torch.zeros(num_topics, self.word_embeddings.shape[1])).repeat(num_times, 1, 1)
        self.topic_embeddings = nn.Parameter(self.topic_embeddings)
        # Evolution-Tracking Contrastive learning and Unassociated Word Exclusion.
        self.ETC = ETC(num_times, temperature, weight_neg, weight_pos)
        self.UWE = UWE(self.ETC, num_times, temperature, weight_UWE, neg_topk)

    def get_beta(self):
        # Topic-word matrices for every time slice, shape (T, K, V): softmax over
        # the topic dimension of the negative embedding distances.
        dist = self.pairwise_euclidean_dist(F.normalize(self.topic_embeddings, dim=-1), F.normalize(self.word_embeddings, dim=-1))
        beta = F.softmax(-dist / self.beta_temp, dim=1)
        return beta
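
    # Note (illustrative, not in the original module): with beta of shape (T, K, V),
    # the top-10 words of topic k at time slice t can be read off via
    # beta[t, k].topk(10).indices, given an index-to-word vocabulary list.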

    def pairwise_euclidean_dist(self, x, y):
        # Squared Euclidean distance via ||x||^2 + ||y||^2 - 2 x.y, broadcast over
        # the leading time dimension of x: (T, K, D) x (V, D) -> (T, K, V).
        cost = torch.sum(x ** 2, axis=-1, keepdim=True) + torch.sum(y ** 2, axis=-1) - 2 * torch.matmul(x, y.t())
        return cost

    def get_theta(self, x, times=None):
        # Chain-free design: doc-topic distributions come from a single
        # time-agnostic encoder, so `times` is accepted but unused here.
        theta, mu, logvar = self.encoder(x)
        if self.training:
            return theta, mu, logvar
        return theta

    def get_KL(self, mu, logvar):
        # KL( N(mu, var) || N(mu2, var2) )
        #   = 0.5 * sum( var/var2 + (mu - mu2)^2/var2 + log(var2/var) - 1 )
        var = logvar.exp()
        var_division = var / self.var2
        diff = mu - self.mu2
        diff_term = diff * diff / self.var2
        logvar_division = self.var2.log() - logvar
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(axis=1) - self.num_topics)
        return KLD.mean()

    def get_NLL(self, theta, beta, x, recon_x=None):
        # Negative log-likelihood of the bag-of-words input x under the
        # reconstructed per-document word distribution.
        if recon_x is None:
            recon_x = self.decode(theta, beta)
        recon_loss = -(x * recon_x.log()).sum(axis=1)
        return recon_loss

    def decode(self, theta, beta):
        # theta: (B, K); beta: (B, K, V), time-indexed per document. The batched
        # matmul yields (B, V) word logits, then batch-norm and softmax over V.
        d1 = F.softmax(self.decoder_bn(torch.bmm(theta.unsqueeze(1), beta).squeeze(1)), dim=-1)
        return d1

    def forward(self, x, times):
        loss = 0.
        # ELBO terms: KL of the doc-topic posterior plus reconstruction NLL.
        theta, mu, logvar = self.get_theta(x)
        kl_theta = self.get_KL(mu, logvar)
        loss += kl_theta
        beta = self.get_beta()
        # Select each document's topic-word matrix by its time slice.
        time_index_beta = beta[times]
        recon_x = self.decode(theta, time_index_beta)
        NLL = self.get_NLL(theta, time_index_beta, x, recon_x)
        NLL = NLL.mean()
        loss += NLL
        # Evolution-tracking contrastive regularization over topic embeddings.
        loss_ETC = self.ETC(self.topic_embeddings)
        loss += loss_ETC
        # Unassociated word exclusion loss.
        loss_UWE = self.UWE(self.train_time_wordfreq, beta, self.topic_embeddings, self.word_embeddings)
        loss += loss_UWE
        rst_dict = {
            'loss': loss,
        }
        return rst_dict
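
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# shapes and dummy tensors below are assumptions; in practice `x` is a
# bag-of-words batch from a DataLoader and `times` holds each document's
# time-slice index, as in TopMost-style training loops.
#
#     V, T, K = 5000, 10, 50
#     wordfreq = torch.rand(T, V)            # assumed (T, V) word-frequency tensor
#     model = CFDTM(vocab_size=V, train_time_wordfreq=wordfreq, num_times=T,
#                   num_topics=K)
#     x = torch.rand(32, V)                  # dummy BoW batch
#     times = torch.randint(0, T, (32,))     # time slice of each document
#     model.train()
#     loss = model(x, times)['loss']
#     loss.backward()
# ---------------------------------------------------------------------------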