nlomov's picture
Added all source code
ab33b80
Raw
History Blame Contribute Delete
6.9 kB
import torch
from torch import nn
from torch.nn import functional as F
from contextualized_topic_models.networks.inference_network import (
CombinedInferenceNetwork,
ContextualInferenceNetwork,
)
class DecoderNetwork(nn.Module):
    """Topic-model decoder (ProdLDA or LDA) paired with a contextual encoder.

    The encoder ``inf_net`` maps the bag-of-words input ``x`` and the
    contextual-embedding input ``x_bert`` to the mean and log-variance of a
    Gaussian posterior. Topic proportions ``theta`` are the softmax of a
    reparameterized sample from that Gaussian. The topic-word matrix ``beta``
    turns ``theta`` into a distribution over the ``input_size`` words.
    """

    def __init__(
        self,
        input_size,
        bert_size,
        infnet,
        n_components=10,
        model_type="prodLDA",
        hidden_sizes=(100, 100),
        activation="softplus",
        dropout=0.2,
        learn_priors=True,
        label_size=0,
    ):
        """
        Initialize DecoderNetwork.

        Args
            input_size : int, dimension of the bag-of-words input; also the
                size of the output word distribution
            bert_size : int, dimension of the contextual-embedding input,
                forwarded to the inference network
            infnet : string, 'zeroshot' (ContextualInferenceNetwork) or
                'combined' (CombinedInferenceNetwork)
            n_components : int, number of topic components, (default 10)
            model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA')
            hidden_sizes : tuple, length = n_layers, (default (100, 100))
            activation : string, 'softplus', 'relu', (default 'softplus')
            dropout : float, dropout probability applied to theta (default 0.2)
            learn_priors : bool, make priors learnable parameters; otherwise
                they are fixed (non-persistent) buffers
            label_size : int, if non-zero, add a linear layer mapping theta
                to ``label_size`` label scores (default 0)

        Raises
            AssertionError : on invalid argument types/values
            ValueError : if ``infnet`` is not 'zeroshot' or 'combined'
        """
        super().__init__()
        assert isinstance(input_size, int), "input_size must be type int."
        assert (
            isinstance(n_components, int) and n_components > 0
        ), "n_components must be type int > 0."
        assert model_type in ["prodLDA", "LDA"], "model type must be 'prodLDA' or 'LDA'"
        assert isinstance(hidden_sizes, tuple), "hidden_sizes must be type tuple."
        assert activation in [
            "softplus",
            "relu",
        ], "activation must be 'softplus' or 'relu'."
        assert dropout >= 0, "dropout must be >= 0."

        self.input_size = input_size
        self.n_components = n_components
        self.model_type = model_type
        self.hidden_sizes = hidden_sizes
        self.activation = activation
        self.dropout = dropout
        self.learn_priors = learn_priors
        # Set by forward(): beta (prodLDA) or the row-normalized beta (LDA).
        self.topic_word_matrix = None

        if infnet == "zeroshot":
            inf_net_cls = ContextualInferenceNetwork
        elif infnet == "combined":
            inf_net_cls = CombinedInferenceNetwork
        else:
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers keep working.
            raise ValueError(
                f"Invalid infnet parameter {infnet!r}, "
                "options are zeroshot and combined"
            )
        self.inf_net = inf_net_cls(
            input_size,
            bert_size,
            n_components,
            hidden_sizes,
            activation,
            label_size=label_size,
        )

        if label_size != 0:
            self.label_classification = nn.Linear(n_components, label_size)

        # init prior parameters
        # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i;
        # \alpha = 1 \forall \alpha
        topic_prior_mean = 0.0
        self._register_prior(
            "prior_mean",
            torch.full((n_components,), topic_prior_mean),
            learnable=learn_priors,
        )

        # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k;
        # \alpha = 1 \forall \alpha
        topic_prior_variance = 1.0 - (1.0 / self.n_components)
        self._register_prior(
            "prior_variance",
            torch.full((n_components,), topic_prior_variance),
            learnable=learn_priors,
        )

        # Topic-word matrix, shape (n_components, input_size).
        beta = torch.empty(n_components, input_size)
        if torch.cuda.is_available():
            beta = beta.cuda()
        self.beta = nn.Parameter(beta)
        nn.init.xavier_uniform_(self.beta)
        self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False)

        # dropout on theta
        self.drop_theta = nn.Dropout(p=self.dropout)

    def _register_prior(self, name, value, learnable):
        """Attach a prior tensor as a learnable parameter or a fixed buffer.

        Non-learnable priors are registered as *non-persistent* buffers so
        that ``module.to()`` / ``.cpu()`` / ``.double()`` move them along with
        the rest of the model, while ``state_dict()`` keys stay as before.
        """
        if torch.cuda.is_available():
            value = value.cuda()
        if learnable:
            self.register_parameter(name, nn.Parameter(value))
        else:
            self.register_buffer(name, value, persistent=False)

    @staticmethod
    def reparameterize(mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        ``logvar`` is interpreted as a log-variance (hence the 0.5 factor).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x, x_bert, labels=None):
        """Encode the inputs, sample topic proportions and decode words.

        Args
            x : bag-of-words input for the inference network
            x_bert : contextual-embedding input for the inference network
            labels : optional; passed to the inference network and, when
                given, triggers label classification from theta (requires
                ``label_size != 0`` at construction)

        Returns
            tuple of (prior_mean, prior_variance, posterior_mu,
            posterior_sigma, posterior_log_sigma, word_dist,
            estimated_labels). Here posterior_sigma is
            ``exp(posterior_log_sigma)``, word_dist has shape
            batch_size x input_size, and estimated_labels is None when
            ``labels`` is None.
        """
        # batch_size x n_components
        posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert, labels)
        posterior_sigma = torch.exp(posterior_log_sigma)

        # generate samples from theta (softmax of a Gaussian sample)
        theta = F.softmax(self.reparameterize(posterior_mu, posterior_log_sigma), dim=1)
        # In training mode dropout zeroes/rescales entries, so rows of theta
        # need not sum to 1 afterwards.
        theta = self.drop_theta(theta)

        # prodLDA vs LDA
        if self.model_type == "prodLDA":
            # (batch x n_components) @ (n_components x input_size); batchnorm
            # and softmax are applied to the mixed logits.
            word_dist = F.softmax(
                self.beta_batchnorm(torch.matmul(theta, self.beta)), dim=1
            )
            # word_dist: batch_size x input_size
            self.topic_word_matrix = self.beta
        elif self.model_type == "LDA":
            # simplex constraint on beta: each topic row becomes a
            # distribution over words before mixing.
            beta = F.softmax(self.beta_batchnorm(self.beta), dim=1)
            self.topic_word_matrix = beta
            word_dist = torch.matmul(theta, beta)
            # word_dist: batch_size x input_size
        else:
            raise NotImplementedError("Model Type Not Implemented")

        # classify labels
        estimated_labels = None
        if labels is not None:
            estimated_labels = self.label_classification(theta)

        return (
            self.prior_mean,
            self.prior_variance,
            posterior_mu,
            posterior_sigma,
            posterior_log_sigma,
            word_dist,
            estimated_labels,
        )

    def get_posterior(self, x, x_bert, labels=None):
        """Return ``(posterior_mu, posterior_log_sigma)`` from the inference network."""
        # batch_size x n_components
        posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert, labels)
        return posterior_mu, posterior_log_sigma

    def get_theta(self, x, x_bert, labels=None):
        """Return one stochastic sample of topic proportions, without gradients.

        The sample is ``softmax(mu + eps * std)``; it is not the posterior
        mean, and theta dropout is not applied.
        """
        with torch.no_grad():
            # batch_size x n_components
            posterior_mu, posterior_log_sigma = self.get_posterior(x, x_bert, labels)
            # generate samples from theta
            theta = F.softmax(
                self.reparameterize(posterior_mu, posterior_log_sigma), dim=1
            )
            return theta

    def sample(self, posterior_mu, posterior_log_sigma, n_samples: int = 20):
        """Monte-Carlo estimate of the expected topic proportions.

        Draws ``n_samples`` reparameterized samples from the posterior,
        applies softmax over the last axis and averages over the samples.
        Runs without gradients.

        Raises
            ValueError : if ``n_samples < 1`` (the mean would be NaN)
        """
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1.")
        with torch.no_grad():
            # expand() adds the leading sample axis as a view (no copy);
            # reparameterize() only reads its inputs, so views are safe.
            mu = posterior_mu.unsqueeze(0).expand(n_samples, -1, -1)
            log_sigma = posterior_log_sigma.unsqueeze(0).expand(n_samples, -1, -1)
            # generate samples from theta
            theta = F.softmax(self.reparameterize(mu, log_sigma), dim=-1)
            return theta.mean(dim=0)