# Source: AdhyaSuman — "Initial commit with Git LFS for large files" (commit 11c72a2)
import torch
from torch import nn
import torch.nn.functional as F
class DETM(nn.Module):
    """The Dynamic Embedded Topic Model (DETM).

    Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei, 2019.

    Each topic k has a time-indexed embedding alpha_k^(1..T) following a
    Gaussian random walk, and the prior mean of the per-document topic
    proportions, eta_t, also follows a random walk.  Inference is
    variational: mean-field over alpha (``get_alpha``), a structured
    LSTM-based amortization over eta (``get_eta``), and per-document
    amortization over theta (``get_theta``).
    """

    def __init__(self, vocab_size, num_times, train_size, train_time_wordfreq,
                 num_topics=50, train_WE=True, pretrained_WE=None, en_units=800,
                 eta_hidden_size=200, rho_size=300, enc_drop=0.0, eta_nlayers=3,
                 eta_dropout=0.0, delta=0.005, theta_act='relu', device='cpu'):
        """
        Args:
            vocab_size: vocabulary size V.
            num_times: number of time slices T.
            train_size: number of training documents; used in ``forward`` to
                scale the mini-batch ELBO terms up to the full corpus.
            train_time_wordfreq: (T, V) tensor of per-time-slice word
                frequencies, fed to the eta inference LSTM.
            num_topics: number of topics K.
            train_WE: learn word embeddings if True, otherwise use
                ``pretrained_WE`` frozen.
            pretrained_WE: (V, L) numpy array of pretrained word embeddings;
                required when ``train_WE`` is False.
            en_units: hidden width of the theta encoder MLP.
            eta_hidden_size: hidden size of the eta LSTM.
            rho_size: word/topic embedding dimensionality L.
            enc_drop: dropout rate on the theta encoder output.
            eta_nlayers: number of LSTM layers for eta inference.
            eta_dropout: dropout between stacked LSTM layers.
            delta: variance of the random-walk priors on alpha and eta.
            theta_act: activation name (see ``get_activation``).
            device: device on which temporary tensors are created.
        """
        super().__init__()
        ## hyperparameters
        self.num_topics = num_topics
        self.num_times = num_times
        self.vocab_size = vocab_size
        self.eta_hidden_size = eta_hidden_size
        self.rho_size = rho_size
        self.enc_drop = enc_drop
        self.eta_nlayers = eta_nlayers
        self.t_drop = nn.Dropout(enc_drop)
        self.eta_dropout = eta_dropout
        self.delta = delta
        self.train_WE = train_WE
        self.train_size = train_size
        self.rnn_inp = train_time_wordfreq
        self.device = device
        self.theta_act = self.get_activation(theta_act)

        ## word embedding matrix \rho
        if self.train_WE:
            # learned embeddings, parameterized as a bias-free linear map L -> V
            # so rho.weight is the (V, L) embedding matrix
            self.rho = nn.Linear(self.rho_size, self.vocab_size, bias=False)
        else:
            # frozen pretrained embeddings kept as a plain (V, L) tensor.
            # Fix: the original called nn.Embedding(pretrained_WE.size()), which
            # raises (numpy arrays have no .size() method and nn.Embedding takes
            # two ints) — and the intermediate Embedding module was discarded.
            self.rho = torch.from_numpy(pretrained_WE).float().to(self.device)

        ## variational parameters for the topic embeddings over time (alpha)
        ## ... alpha is K x T x L
        self.mu_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))
        self.logsigma_q_alpha = nn.Parameter(torch.randn(self.num_topics, self.num_times, self.rho_size))

        ## amortized variational distribution for \theta_{1:D}
        ## encoder input is [normalized bow ; eta_t], see get_theta
        self.q_theta = nn.Sequential(
            nn.Linear(self.vocab_size + self.num_topics, en_units),
            self.theta_act,
            nn.Linear(en_units, en_units),
            self.theta_act,
        )
        self.mu_q_theta = nn.Linear(en_units, self.num_topics, bias=True)
        self.logsigma_q_theta = nn.Linear(en_units, self.num_topics, bias=True)

        ## structured amortized variational distribution for \eta ... eta is T x K
        self.q_eta_map = nn.Linear(self.vocab_size, self.eta_hidden_size)
        self.q_eta = nn.LSTM(self.eta_hidden_size, self.eta_hidden_size, self.eta_nlayers, dropout=self.eta_dropout)
        self.mu_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)
        self.logsigma_q_eta = nn.Linear(self.eta_hidden_size + self.num_topics, self.num_topics, bias=True)

        # NOTE(review): defined with frozen scale but never used by any method
        # in this file — TODO confirm whether a caller relies on it
        self.decoder_bn = nn.BatchNorm1d(vocab_size)
        self.decoder_bn.weight.requires_grad = False

    def get_activation(self, act):
        """Map an activation name to its nn module; unknown names fall back to tanh."""
        activations = {
            'tanh': nn.Tanh(),
            'relu': nn.ReLU(),
            'softplus': nn.Softplus(),
            'rrelu': nn.RReLU(),
            'leakyrelu': nn.LeakyReLU(),
            'elu': nn.ELU(),
            'selu': nn.SELU(),
            'glu': nn.GLU(),
        }
        if act in activations:
            return activations[act]
        print('Defaulting to tanh activations...')
        return nn.Tanh()

    def reparameterize(self, mu, logvar):
        """Returns a sample from N(mu, exp(logvar)) via reparameterization.

        During evaluation the mean is returned (no sampling noise).
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        # in-place ops on the fresh noise tensor avoid an extra allocation
        return eps.mul_(std).add_(mu)

    def get_kl(self, q_mu, q_logsigma, p_mu=None, p_logsigma=None):
        """Returns KL( N(q_mu, q_logsigma) || N(p_mu, p_logsigma) ).

        With no prior arguments the prior defaults to a standard normal.
        Note: despite the name, *_logsigma is treated as log-variance here
        (it is exponentiated directly to a variance).
        """
        if p_mu is not None and p_logsigma is not None:
            sigma_q_sq = torch.exp(q_logsigma)
            sigma_p_sq = torch.exp(p_logsigma)
            # 1e-6 guards against division by a vanishing prior variance
            kl = (sigma_q_sq + (q_mu - p_mu) ** 2) / (sigma_p_sq + 1e-6)
            kl = kl - 1 + p_logsigma - q_logsigma
            kl = 0.5 * torch.sum(kl, dim=-1)
        else:
            kl = -0.5 * torch.sum(1 + q_logsigma - q_mu.pow(2) - q_logsigma.exp(), dim=-1)
        return kl

    def get_alpha(self):  ## mean field
        """Sample the topic-embedding trajectory alpha (T x K x L).

        Returns:
            alphas: (T, K, L) sampled topic embeddings.
            kl_alpha: scalar total KL to the random-walk prior
                alpha_t ~ N(alpha_{t-1}, delta), with a standard-normal
                prior at t = 0.
        """
        alphas = torch.zeros(self.num_times, self.num_topics, self.rho_size).to(self.device)
        kl_alpha = []
        alphas[0] = self.reparameterize(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :])
        # t = 0 prior is a standard normal (log-variance 0).
        # TODO: why logsigma_p_0 is zero?
        p_mu_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
        logsigma_p_0 = torch.zeros(self.num_topics, self.rho_size).to(self.device)
        kl_0 = self.get_kl(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :], p_mu_0, logsigma_p_0)
        kl_alpha.append(kl_0)
        for t in range(1, self.num_times):
            alphas[t] = self.reparameterize(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :])
            # random-walk prior centred on the previous sample, variance delta
            p_mu_t = alphas[t - 1]
            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics, self.rho_size).to(self.device))
            kl_t = self.get_kl(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :], p_mu_t, logsigma_p_t)
            kl_alpha.append(kl_t)
        # already a scalar after this sum (the original applied a redundant
        # second .sum() on the return)
        kl_alpha = torch.stack(kl_alpha).sum()
        return alphas, kl_alpha

    def get_eta(self, rnn_inp):  ## structured amortized inference
        """Sample the per-time-slice topic-proportion means eta (T x K).

        Args:
            rnn_inp: (T, V) per-time-slice word frequencies.

        Returns:
            etas: (T, K) sampled eta trajectory.
            kl_eta: scalar total KL to the random-walk prior
                eta_t ~ N(eta_{t-1}, delta), standard normal at t = 0.
        """
        inp = self.q_eta_map(rnn_inp).unsqueeze(1)  # (T, 1, H): batch of one sequence
        hidden = self.init_hidden()
        output, _ = self.q_eta(inp, hidden)
        # squeeze only the batch dim; a bare .squeeze() would also drop the
        # time dim when num_times == 1 and break output[t] indexing below
        output = output.squeeze(1)
        etas = torch.zeros(self.num_times, self.num_topics).to(self.device)
        kl_eta = []
        # t = 0: no previous eta, condition on zeros
        inp_0 = torch.cat([output[0], torch.zeros(self.num_topics,).to(self.device)], dim=0)
        mu_0 = self.mu_q_eta(inp_0)
        logsigma_0 = self.logsigma_q_eta(inp_0)
        etas[0] = self.reparameterize(mu_0, logsigma_0)
        p_mu_0 = torch.zeros(self.num_topics,).to(self.device)
        logsigma_p_0 = torch.zeros(self.num_topics,).to(self.device)
        kl_0 = self.get_kl(mu_0, logsigma_0, p_mu_0, logsigma_p_0)
        kl_eta.append(kl_0)
        for t in range(1, self.num_times):
            # condition on the LSTM state at t and the previous eta sample
            inp_t = torch.cat([output[t], etas[t - 1]], dim=0)
            mu_t = self.mu_q_eta(inp_t)
            logsigma_t = self.logsigma_q_eta(inp_t)
            etas[t] = self.reparameterize(mu_t, logsigma_t)
            p_mu_t = etas[t - 1]
            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics,).to(self.device))
            kl_t = self.get_kl(mu_t, logsigma_t, p_mu_t, logsigma_p_t)
            kl_eta.append(kl_t)
        kl_eta = torch.stack(kl_eta).sum()
        return etas, kl_eta

    def get_theta(self, bows, times, eta=None):  ## amortized inference
        """Returns the per-document topic proportions.

        Args:
            bows: (D, V) bag-of-words counts; assumed to have nonzero sums.
            times: (D,) long tensor of time-slice indices per document.
            eta: optional precomputed (T, K) eta; recomputed when None
                (the original only recomputed in eval mode and crashed with
                a NameError if called in training mode without eta).

        Returns:
            (theta, kl_theta) in training mode, theta alone in eval mode.
        """
        normalized_bows = bows / bows.sum(1, keepdim=True)
        if eta is None:
            eta, kl_eta = self.get_eta(self.rnn_inp)
        eta_td = eta[times]
        inp = torch.cat([normalized_bows, eta_td], dim=1)
        q_theta = self.q_theta(inp)
        if self.enc_drop > 0:
            q_theta = self.t_drop(q_theta)
        mu_theta = self.mu_q_theta(q_theta)
        logsigma_theta = self.logsigma_q_theta(q_theta)
        z = self.reparameterize(mu_theta, logsigma_theta)
        theta = F.softmax(z, dim=-1)
        # prior: theta_d ~ N(eta_{t_d}, I) in logit space
        kl_theta = self.get_kl(mu_theta, logsigma_theta, eta_td, torch.zeros(self.num_topics).to(self.device))
        if self.training:
            return theta, kl_theta
        return theta

    @property
    def word_embeddings(self):
        """The (V, L) word embedding matrix.

        Fix: when ``train_WE`` is False, ``self.rho`` is a plain tensor and
        the original unconditional ``self.rho.weight`` raised AttributeError.
        """
        return self.rho.weight if self.train_WE else self.rho

    @property
    def topic_embeddings(self):
        """A sampled (T, K, L) topic-embedding trajectory (KL discarded)."""
        alpha, _ = self.get_alpha()
        return alpha

    def get_beta(self, alpha=None):
        """Returns the topic-word matrix beta of shape T x K x V.

        Args:
            alpha: optional precomputed (T, K, L) topic embeddings;
                recomputed when None (the original only recomputed in eval
                mode and crashed if called in training mode without alpha).
        """
        if alpha is None:
            alpha, _ = self.get_alpha()
        if self.train_WE:
            # logits via the learnable linear map L -> V
            logit = self.rho(alpha.view(alpha.size(0) * alpha.size(1), self.rho_size))
        else:
            # logits as inner products with the frozen (V, L) embeddings
            tmp = alpha.view(alpha.size(0) * alpha.size(1), self.rho_size)
            logit = torch.mm(tmp, self.rho.permute(1, 0))
        logit = logit.view(alpha.size(0), alpha.size(1), -1)
        beta = F.softmax(logit, dim=-1)
        return beta

    def get_NLL(self, theta, beta, bows):
        """Per-document negative log-likelihood of the bag-of-words counts.

        Args:
            theta: (D, K) topic proportions.
            beta: (D, K, V) per-document topic-word distributions
                (already indexed by each document's time slice).
            bows: (D, V) counts.

        Returns:
            (D,) negative log-likelihoods.
        """
        theta = theta.unsqueeze(1)
        loglik = torch.bmm(theta, beta).squeeze(1)
        loglik = torch.log(loglik + 1e-12)  # avoid log(0)
        nll = -loglik * bows
        nll = nll.sum(-1)
        return nll

    def forward(self, bows, times):
        """Compute the negative ELBO for a mini-batch.

        Args:
            bows: (B, V) bag-of-words counts.
            times: (B,) long tensor of time-slice indices.

        Returns:
            dict with 'loss' (total), 'nll', 'kl_eta', 'kl_theta', 'kl_alpha'.
        """
        bsz = bows.size(0)
        # scale per-document terms up to the full corpus size
        coeff = self.train_size / bsz
        eta, kl_eta = self.get_eta(self.rnn_inp)
        theta, kl_theta = self.get_theta(bows, times, eta)
        kl_theta = kl_theta.sum() * coeff
        alpha, kl_alpha = self.get_alpha()
        beta = self.get_beta(alpha)
        beta = beta[times]
        nll = self.get_NLL(theta, beta, bows)
        nll = nll.sum() * coeff
        # the original built the dict first and then mutated loss in place
        # with `loss += kl_alpha`; the total is the same, stated explicitly
        loss = nll + kl_eta + kl_theta + kl_alpha
        return {
            'loss': loss,
            'nll': nll,
            'kl_eta': kl_eta,
            'kl_theta': kl_theta,
            'kl_alpha': kl_alpha,
        }

    def init_hidden(self):
        """Initializes the first hidden state of the RNN used as inference network for \\eta."""
        weight = next(self.parameters())  # borrow dtype/device for the new zeros
        nlayers = self.eta_nlayers
        nhid = self.eta_hidden_size
        return (weight.new_zeros(nlayers, 1, nhid), weight.new_zeros(nlayers, 1, nhid))