| import os |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
class MVAEConfig(PretrainedConfig):
    """Configuration for the multi-view VAE (:class:`MVAE`).

    Stores the per-view dimension tables and shared latent size; all values
    are forwarded verbatim to the model at construction time.

    Args:
        prosoro_type: Identifier of the sensor/geometry variant
            (default ``"cylinder"``); also used by ``MVAE.from_pretrained``
            to pick a checkpoint subdirectory.
        x_dim_dict: Per-view input dimensions (indexed by view).
        h1_dim_dict: Per-view first hidden-layer dimensions.
        h2_dim_dict: Per-view second hidden-layer dimensions.
        z_dim: Size of the shared latent space.
        layer_norm: Whether layer normalization is requested.
        use_activation: Name of the activation function requested.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "mvae"

    def __init__(
        self,
        prosoro_type="cylinder",
        x_dim_dict=None,
        h1_dim_dict=None,
        h2_dim_dict=None,
        z_dim=32,
        layer_norm=False,
        use_activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Record every model hyper-parameter on the config instance.
        for attr_name, attr_value in (
            ("prosoro_type", prosoro_type),
            ("x_dim_dict", x_dim_dict),
            ("h1_dim_dict", h1_dim_dict),
            ("h2_dim_dict", h2_dim_dict),
            ("z_dim", z_dim),
            ("layer_norm", layer_norm),
            ("use_activation", use_activation),
        ):
            setattr(self, attr_name, attr_value)
|
|
|
|
class MVAE(PreTrainedModel):
    """Multi-view VAE: one encoder/decoder pair per view, sharing a latent z.

    For each view ``i`` the model holds an MLP encoder
    (``x_dim -> h1_dim -> h2_dim``), latent heads ``fc_mu_i`` / ``fc_var_i``
    (``h2_dim -> z_dim``), and an MLP decoder (``z_dim -> h2_dim -> h1_dim ->
    x_dim``).  ``forward`` encodes the input with view 0's encoder and
    decodes the sampled latent with every view's decoder.

    Fix vs. the previous revision: every module (and ``self.model`` itself)
    was constructed twice with the second assignment silently replacing the
    first, and the encode/decode helpers ran each sub-network twice and
    discarded the first pass.  The duplicates allocated dead parameters and
    doubled forward compute without affecting outputs; they are removed.
    """

    config_class = MVAEConfig

    def __init__(self, config: MVAEConfig):
        """Build the per-view encoder/decoder/latent-head modules.

        Args:
            config: An :class:`MVAEConfig` carrying the dimension tables
                (``x_dim_dict``/``h1_dim_dict``/``h2_dim_dict``), the latent
                size ``z_dim``, and ``prosoro_type``.
        """
        super().__init__(config)

        # Default keeps old checkpoints (without the field) loadable.
        self.prosoro_type = getattr(config, "prosoro_type", "cylinder")
        # NOTE(review): the config calls these "*_dim_dict" but they are
        # indexed 0..len-1 below, i.e. used as sequences — confirm the
        # intended container type with the config producer.
        self.x_dim_list = config.x_dim_dict
        self.h1_dim_list = config.h1_dim_dict
        self.h2_dim_list = config.h2_dim_dict
        self.z_dim = config.z_dim

        self.model = nn.ModuleDict()
        for i in range(len(self.x_dim_list)):
            # Encoder i: x -> h1 -> h2.
            self.model[f"encoder_{i}"] = nn.Sequential(
                nn.Linear(self.x_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.h2_dim_list[i]),
            )
            # Decoder i: z -> h2 -> h1 -> x.
            self.model[f"decoder_{i}"] = nn.Sequential(
                nn.Linear(self.z_dim, self.h2_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h2_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.x_dim_list[i]),
            )
            # Latent heads i: h2 -> z (mean and log-variance).
            self.model[f"fc_mu_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)
            self.model[f"fc_var_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)

    def sample(self, mu, var):
        """Reparameterized sample from N(mu, exp(var)).

        ``var`` is treated as a log-variance: ``std = exp(var / 2)``.

        Args:
            mu: Latent mean tensor.
            var: Latent log-variance tensor (same shape as ``mu``).

        Returns:
            Tuple ``(p, q, z)``: the standard-normal prior ``p``, the
            posterior ``q = N(mu, std)``, and a differentiable sample
            ``z = q.rsample()``.
        """
        std = torch.exp(var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def x_to_z_encoder(self, x, input_index):
        """Encode view ``input_index``'s observation into a latent sample.

        Args:
            x: Input tensor for view ``input_index``.
            input_index: Which view's encoder and latent heads to use.

        Returns:
            A stochastic latent sample ``z`` (via :meth:`sample`).
        """
        h = self.model[f"encoder_{input_index}"](x)
        mu = self.model[f"fc_mu_{input_index}"](h)
        var = self.model[f"fc_var_{input_index}"](h)
        _, _, z = self.sample(mu, var)
        return z

    def z_to_x_decoder(self, z, output_index):
        """Decode latent ``z`` into view ``output_index``'s observation space."""
        return self.model[f"decoder_{output_index}"](z)

    def forward(self, x):
        """Encode ``x`` with view 0's encoder and decode to every other view.

        Args:
            x: Input tensor matching view 0's ``x_dim``.

        Returns:
            List of reconstructions for views ``1..N-1``.
        """
        x_hat_list = []
        for i in range(len(self.x_dim_list)):
            # NOTE(review): the encoder index is hard-coded to 0 and a fresh
            # stochastic z is drawn on every iteration, so each view's
            # reconstruction uses an independent latent sample — confirm this
            # is intended rather than encoding once and reusing z.
            z = self.x_to_z_encoder(x, 0)
            x_hat_list.append(self.z_to_x_decoder(z, i))
        # NOTE(review): view 0's own reconstruction is computed then
        # discarded by the slice — confirm with callers.
        return x_hat_list[1:]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load weights from the ``prosoro_type`` subdirectory of the repo/path.

        Resolves ``prosoro_type`` from an explicit ``config`` kwarg or, when
        absent, from the config stored at ``pretrained_model_name_or_path``,
        then delegates to ``PreTrainedModel.from_pretrained`` with
        ``"<path>/<prosoro_type>"``.

        Args:
            pretrained_model_name_or_path: Base checkpoint path or hub id.
            *args, **kwargs: Forwarded to the parent implementation.
        """
        config = kwargs.get("config", None)

        if config is None:
            # Local import mirrors the original; avoids a hard module-level
            # dependency on AutoConfig.
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

        # NOTE(review): if the config lacks ``prosoro_type`` this builds a
        # literal ".../None" path — confirm all shipped configs set it.
        prosoro_type = getattr(config, "prosoro_type", None)
        pretrained_model_name_or_path = pretrained_model_name_or_path + f"/{prosoro_type}"

        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
|