Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from contextualized_topic_models.networks.inference_network import ( | |
| CombinedInferenceNetwork, | |
| ContextualInferenceNetwork, | |
| ) | |
class DecoderNetwork(nn.Module):
    """Decoder of a neural topic model with an attached inference network.

    The module owns:

    * ``inf_net`` - the inference (encoder) network that maps a document to
      the parameters of its posterior over topics,
    * ``beta`` - a learnable ``n_components x input_size`` topic-word weight
      matrix,
    * ``prior_mean`` / ``prior_variance`` - the per-topic prior, optionally
      learnable,
    * an optional linear label classifier on top of the topic mixture.
    """

    def __init__(
        self,
        input_size,
        bert_size,
        infnet,
        n_components=10,
        model_type="prodLDA",
        hidden_sizes=(100, 100),
        activation="softplus",
        dropout=0.2,
        learn_priors=True,
        label_size=0,
    ):
        """Initialize DecoderNetwork.

        Args
            input_size : int, dimension of input (second dimension of beta)
            bert_size : forwarded to the inference network (presumably the
                dimension of the contextual embedding - confirm against
                the inference network implementation)
            infnet : string, 'zeroshot' (ContextualInferenceNetwork) or
                'combined' (CombinedInferenceNetwork)
            n_components : int, number of topic components, (default 10)
            model_type : string, 'prodLDA' or 'LDA' (default 'prodLDA')
            hidden_sizes : tuple, length = n_layers, (default (100, 100))
            activation : string, 'softplus', 'relu', (default 'softplus')
            dropout : float >= 0, dropout probability applied to theta
                (default 0.2)
            learn_priors : bool, make priors learnable parameter
            label_size : int, if non-zero a linear classifier from
                n_components to label_size outputs is created (default 0)

        Raises
            AssertionError : if an argument fails the checks below
            ValueError : if infnet is neither 'zeroshot' nor 'combined'
        """
        super().__init__()

        assert isinstance(input_size, int), "input_size must by type int."
        assert (
            isinstance(n_components, int) and n_components > 0
        ), "n_components must be type int > 0."
        assert model_type in ["prodLDA", "LDA"], "model type must be 'prodLDA' or 'LDA'"
        assert isinstance(hidden_sizes, tuple), "hidden_sizes must be type tuple."
        assert activation in [
            "softplus",
            "relu",
        ], "activation must be 'softplus' or 'relu'."
        assert dropout >= 0, "dropout must be >= 0."

        self.input_size = input_size
        self.n_components = n_components
        self.model_type = model_type
        self.hidden_sizes = hidden_sizes
        self.activation = activation
        self.dropout = dropout
        self.learn_priors = learn_priors
        # Filled in by forward(); stays None until the first forward pass.
        self.topic_word_matrix = None

        if infnet == "zeroshot":
            self.inf_net = ContextualInferenceNetwork(
                input_size,
                bert_size,
                n_components,
                hidden_sizes,
                activation,
                label_size=label_size,
            )
        elif infnet == "combined":
            self.inf_net = CombinedInferenceNetwork(
                input_size,
                bert_size,
                n_components,
                hidden_sizes,
                activation,
                label_size=label_size,
            )
        else:
            # ValueError is a subclass of Exception, so callers that caught
            # the previous bare Exception keep working.
            raise ValueError(
                "Missing infnet parameter, options are zeroshot and combined"
            )

        if label_size != 0:
            self.label_classification = nn.Linear(n_components, label_size)

        use_cuda = torch.cuda.is_available()

        # init prior parameters
        # \mu_1k = log \alpha_k + 1/K \sum_i log \alpha_i;
        # \alpha = 1 \forall \alpha
        #
        # NOTE: when learn_priors is False the priors are stored as plain
        # tensor attributes (neither parameters nor buffers), so they are not
        # moved by Module.to() and are not part of the state_dict.
        topic_prior_mean = 0.0
        prior_mean = torch.tensor([topic_prior_mean] * n_components)
        if use_cuda:
            prior_mean = prior_mean.cuda()
        if self.learn_priors:
            prior_mean = nn.Parameter(prior_mean)
        self.prior_mean = prior_mean

        # \Sigma_1kk = 1 / \alpha_k (1 - 2/K) + 1/K^2 \sum_i 1 / \alpha_k;
        # \alpha = 1 \forall \alpha
        topic_prior_variance = 1.0 - (1.0 / self.n_components)
        prior_variance = torch.tensor([topic_prior_variance] * n_components)
        if use_cuda:
            prior_variance = prior_variance.cuda()
        if self.learn_priors:
            prior_variance = nn.Parameter(prior_variance)
        self.prior_variance = prior_variance

        # Topic-word weights: allocated uninitialized, then Xavier-initialized
        # once wrapped as a parameter.
        beta = torch.Tensor(n_components, input_size)
        if use_cuda:
            beta = beta.cuda()
        self.beta = nn.Parameter(beta)
        nn.init.xavier_uniform_(self.beta)

        self.beta_batchnorm = nn.BatchNorm1d(input_size, affine=False)

        # dropout on theta
        self.drop_theta = nn.Dropout(p=self.dropout)

    @staticmethod
    def reparameterize(mu, logvar):
        """Reparameterize the theta distribution.

        Returns ``mu + eps * exp(0.5 * logvar)`` where ``eps`` is standard
        normal noise with the shape of ``logvar``.

        This must be a static method: every caller invokes it as
        ``self.reparameterize(mu, logvar)``, which would otherwise bind the
        instance to ``mu`` and fail with a TypeError.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x, x_bert, labels=None):
        """Forward pass.

        Args
            x, x_bert, labels : forwarded unchanged to the inference network

        Returns
            tuple of (prior_mean, prior_variance, posterior_mu,
            posterior_sigma, posterior_log_sigma, word_dist,
            estimated_labels); estimated_labels is None when labels is None.
        """
        # batch_size x n_components
        posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert, labels)
        # reparameterize() treats the second output as a log-variance
        # (std = exp(0.5 * logvar)), so its exponential is the variance.
        posterior_sigma = torch.exp(posterior_log_sigma)

        # generate samples from theta
        theta = F.softmax(
            self.reparameterize(posterior_mu, posterior_log_sigma), dim=1
        )
        theta = self.drop_theta(theta)

        # prodLDA vs LDA
        if self.model_type == "prodLDA":
            # Mix the unnormalized topic-word weights first, then normalize
            # the mixture over the vocabulary.
            word_dist = F.softmax(
                self.beta_batchnorm(torch.matmul(theta, self.beta)), dim=1
            )
            # word_dist: batch_size x input_size
            self.topic_word_matrix = self.beta
        elif self.model_type == "LDA":
            # simplex constrain on Beta
            beta = F.softmax(self.beta_batchnorm(self.beta), dim=1)
            self.topic_word_matrix = beta
            word_dist = torch.matmul(theta, beta)
            # word_dist: batch_size x input_size
        else:
            raise NotImplementedError("Model Type Not Implemented")

        # classify labels
        # label_classification only exists when the module was built with
        # label_size != 0.
        estimated_labels = None
        if labels is not None:
            estimated_labels = self.label_classification(theta)

        return (
            self.prior_mean,
            self.prior_variance,
            posterior_mu,
            posterior_sigma,
            posterior_log_sigma,
            word_dist,
            estimated_labels,
        )

    def get_posterior(self, x, x_bert, labels=None):
        """Get posterior distribution.

        Returns the (posterior_mu, posterior_log_sigma) pair produced by the
        inference network for the given inputs.
        """
        # batch_size x n_components
        posterior_mu, posterior_log_sigma = self.inf_net(x, x_bert, labels)
        return posterior_mu, posterior_log_sigma

    def get_theta(self, x, x_bert, labels=None):
        """Draw one topic-mixture sample per document, without gradients.

        The posterior parameters are reparameterized once and passed through
        a softmax over dim=1; no dropout is applied here.
        """
        with torch.no_grad():
            # batch_size x n_components
            posterior_mu, posterior_log_sigma = self.get_posterior(
                x, x_bert, labels
            )
            # generate samples from theta
            theta = F.softmax(
                self.reparameterize(posterior_mu, posterior_log_sigma), dim=1
            )
            return theta

    def sample(self, posterior_mu, posterior_log_sigma, n_samples: int = 20):
        """Average ``n_samples`` topic-mixture draws, without gradients.

        Both inputs are repeated along a new leading dimension of size
        ``n_samples``, reparameterized, softmax-normalized over the last
        dimension and averaged over the sample dimension.
        """
        with torch.no_grad():
            posterior_mu = posterior_mu.unsqueeze(0).repeat(n_samples, 1, 1)
            posterior_log_sigma = posterior_log_sigma.unsqueeze(0).repeat(
                n_samples, 1, 1
            )
            # generate samples from theta
            theta = F.softmax(
                self.reparameterize(posterior_mu, posterior_log_sigma), dim=-1
            )
            return theta.mean(dim=0)