|
|
| from transformers import PreTrainedModel
|
| import torch.nn as nn
|
| import torch
|
| import json
|
| import os
|
| from typing import Tuple, Optional, Literal
|
| from .configuration_vae import VAEConfig
|
| from huggingface_hub import hf_hub_download
|
|
|
|
|
class VAEModel(PreTrainedModel):
    """Convolutional variational autoencoder exposed as a ``PreTrainedModel``.

    Architecture (all sizes come from ``config``):

    * Encoder: ``config.encoder_layers`` blocks of
      ``Conv2d(k=4, s=2, p=1) -> BatchNorm2d -> ReLU`` with
      ``hidden_channels * 2**i`` output channels, followed by ``Flatten`` and
      two linear heads producing the posterior mean ``mu`` and log-variance
      ``logvar``.
    * Decoder: a linear projection back to the encoder's final feature map,
      a stack of ``ConvTranspose2d(k=4, s=2, p=1)`` layers that undo the
      encoder's downsampling, then an output head chosen by ``data_type``:
      ``'binary'`` -> sigmoid (Bernoulli likelihood); anything else -> two
      3x3 convs giving a per-pixel mean and log-variance (Gaussian).

    ``config.data_type`` may also be ``'auto'``: the first ``forward`` call
    then inspects the batch and replaces it with ``'binary'`` or
    ``'continuous'``.
    """

    config_class = VAEConfig

    def __init__(self, config):
        super().__init__(config)

        if config.encoder_layers < 1:
            raise ValueError(
                f"encoder_layers must be >= 1, got {config.encoder_layers}"
            )

        self.img_shape = (config.img_shape[0], config.img_shape[1], config.img_shape[2])
        self.latent_dim = config.latent_dim
        self.hidden_channels = config.hidden_channels
        self.encoder_layers = config.encoder_layers
        self.data_type = config.data_type

        C, H, W = self.img_shape

        # ---- Encoder -------------------------------------------------
        encoder_modules = []
        for i in range(self.encoder_layers):
            in_ch = C if i == 0 else self.hidden_channels * (2 ** (i - 1))
            out_ch = self.hidden_channels * (2 ** i)
            encoder_modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            encoder_modules.append(nn.BatchNorm2d(out_ch))
            encoder_modules.append(nn.ReLU())

        # NOTE: ``encoderD`` and ``encoder`` share the same layer objects, so
        # their parameters are tied and appear under both prefixes in the
        # state dict. Both are kept (and registered in this order) so existing
        # checkpoints still load with strict key matching.
        self.encoderD = nn.Sequential(*encoder_modules)

        self.enc_channels, spatial_sizes = self._probe_encoder_shapes(C, H, W)
        self.enc_H, self.enc_W = spatial_sizes[-1]
        self.flatEncoderDim = self.enc_channels * self.enc_H * self.enc_W

        self.encoder = nn.Sequential(*encoder_modules, nn.Flatten())
        self.fc_mu = nn.Linear(self.flatEncoderDim, self.latent_dim)
        self.fc_logvar = nn.Linear(self.flatEncoderDim, self.latent_dim)

        # ---- Decoder -------------------------------------------------
        self.decoder_input = nn.Linear(self.latent_dim, self.flatEncoderDim)

        decoder_modules = []
        in_ch = self.enc_channels
        for i in range(self.encoder_layers - 1, 0, -1):
            # NOTE(review): channels shrink linearly (hidden_channels * i)
            # instead of mirroring the encoder's geometric growth. Left as-is
            # because changing it alters parameter shapes and would break
            # previously saved checkpoints.
            out_ch = self.hidden_channels * i
            decoder_modules.append(
                nn.ConvTranspose2d(
                    in_ch, out_ch, kernel_size=4, stride=2, padding=1,
                    output_padding=self._mirror_padding(spatial_sizes[i + 1], spatial_sizes[i]),
                )
            )
            decoder_modules.append(nn.BatchNorm2d(out_ch))
            decoder_modules.append(nn.ReLU())
            in_ch = out_ch
        # Final layer restores the input resolution (H, W) and channel count C.
        decoder_modules.append(
            nn.ConvTranspose2d(
                in_ch, C, kernel_size=4, stride=2, padding=1,
                output_padding=self._mirror_padding(spatial_sizes[1], spatial_sizes[0]),
            )
        )
        self.decoder = nn.Sequential(*decoder_modules)

        self.decoder_bernoulli = nn.Sigmoid()
        self.decoder_gaussian_mean = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.decoder_gaussian_logvar = nn.Conv2d(C, C, kernel_size=3, padding=1)

        # Zero-mean / unit-std prior parameters. Non-persistent buffers follow
        # ``.to(device)`` without adding keys to the state dict, so existing
        # checkpoints still load.
        self.register_buffer("prior_mean", torch.zeros(self.latent_dim), persistent=False)
        self.register_buffer("prior_std", torch.ones(self.latent_dim), persistent=False)

    def _probe_encoder_shapes(self, C, H, W):
        """Return the encoder's output channels and the (H, W) size at every level.

        Pushes an all-zero dummy image through ``self.encoderD``. The pass
        runs in eval mode so BatchNorm neither updates its running statistics
        nor rejects a batch of one with a 1x1 feature map. The previous
        training flag is restored afterwards.

        Returns:
            ``(channels, sizes)``: ``sizes[0]`` is the input ``(H, W)``, and
            ``sizes[k]`` is the size after the k-th ``Conv2d``.
        """
        sizes = [(H, W)]
        was_training = self.encoderD.training
        self.encoderD.eval()
        try:
            with torch.no_grad():
                h = torch.zeros(1, C, H, W)
                for layer in self.encoderD:
                    h = layer(h)
                    if isinstance(layer, nn.Conv2d):
                        sizes.append((h.shape[2], h.shape[3]))
        finally:
            self.encoderD.train(was_training)
        return h.shape[1], sizes

    @staticmethod
    def _mirror_padding(src, dst):
        """Output padding that makes a k=4, s=2, p=1 transposed conv map ``src`` to ``dst``.

        Such a layer outputs ``2 * size + output_padding``. When the matching
        encoder conv floored an odd size, the lost row/column is restored
        with ``output_padding = 1``; otherwise the padding is 0.
        """
        return (dst[0] - 2 * src[0], dst[1] - 2 * src[1])

    def detect_data_type(self, x: torch.Tensor) -> str:
        """Pick a likelihood from the first 100 entries of ``x`` along dim 0.

        Returns ``'binary'`` only when every value in that slice is 0 or 1
        (the only targets valid for binary cross-entropy). Otherwise returns
        ``'continuous'``. The decision is printed.
        """
        unique_vals = torch.unique(x[0:100].flatten())
        is_binary = bool(((unique_vals == 0) | (unique_vals == 1)).all())
        if is_binary:
            print(f"Auto-detected: Binary data (unique values: {unique_vals.tolist()})")
            return 'binary'
        print(f"Auto-detected: Continuous data ({len(unique_vals)} unique values)")
        return 'continuous'

    def encode(self, x: torch.Tensor) -> tuple:
        """Encode ``x`` into the posterior parameters ``(mu, logvar)``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor, data_type: Optional[str] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Decode latents ``z`` into likelihood parameters.

        Args:
            z: Latent codes, one row per sample.
            data_type: Output head to use; defaults to ``self.data_type``.

        Returns:
            ``(probs, None)`` for ``'binary'``; otherwise ``(mean, logvar)``.
        """
        if data_type is None:
            data_type = self.data_type
        h = self.decoder_input(z)
        h = h.view(h.size(0), self.enc_channels, self.enc_H, self.enc_W)
        h = self.decoder(h)
        if data_type == 'binary':
            return self.decoder_bernoulli(h), None
        return self.decoder_gaussian_mean(h), self.decoder_gaussian_logvar(h)

    def sample_prior(self, num_samples: int) -> torch.Tensor:
        """Sample ``num_samples`` standard-normal latents on the model's device and dtype."""
        param = next(self.parameters())
        return torch.randn(num_samples, self.latent_dim, device=param.device, dtype=param.dtype)

    def forward(self, x: torch.Tensor, data_type: Optional[str] = None) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """Encode, sample and decode ``x``.

        Args:
            x: Input batch.
            data_type: Per-call override of the output head. ``None`` uses
                ``self.data_type``. ``'auto'`` runs ``detect_data_type`` and
                stores the result in ``self.data_type``.

        Returns:
            ``(recon, mu, logvar)``, where ``recon`` is the tuple from ``decode``.
        """
        if data_type is None:
            data_type = self.data_type
        if data_type == 'auto':
            data_type = self.detect_data_type(x)
            self.data_type = data_type
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, data_type)
        return recon_x, mu, logvar

    def reconstruction_loss(self, x: torch.Tensor, recon_output: Tuple[torch.Tensor, Optional[torch.Tensor]],
                            mu: torch.Tensor, logvar: torch.Tensor, data_type: Optional[str] = None
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(recon + kl, recon, kl)``, each summed over the whole batch.

        ``kl`` is the closed-form KL divergence from ``N(mu, exp(logvar))`` to
        ``N(0, I)``.

        The reconstruction term depends on ``data_type`` (default
        ``self.data_type``):

        * ``'binary'``: summed binary cross-entropy.
        * anything else: summed Gaussian negative log-likelihood.
        """
        if data_type is None:
            data_type = self.data_type

        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        if data_type == 'binary':
            probs = recon_output[0] if isinstance(recon_output, tuple) else recon_output
            recon_loss = nn.functional.binary_cross_entropy(probs, x, reduction='sum')
        else:
            mean, logvar_x = recon_output
            var_x = torch.exp(logvar_x)
            recon_loss = 0.5 * torch.sum(torch.log(2 * torch.pi * var_x) + (x - mean).pow(2) / var_x)
            # NOTE(review): a negative Gaussian NLL is legitimate when the
            # predicted variance is small. Swapping in a plain sum of squared
            # errors makes the objective discontinuous and gives ``logvar_x``
            # no gradient in that regime. Preserved as-is; confirm intent.
            if recon_loss < 0:
                recon_loss = torch.sum((x - mean).pow(2))

        return recon_loss + kl_loss, recon_loss, kl_loss

    def generate(self, num_samples: int = 1, z: Optional[torch.Tensor] = None):
        """Decode prior samples (or the given ``z``) and return the first decoder output.

        That is the probabilities for binary data, or the mean for continuous data.
        """
        if z is None:
            z = self.sample_prior(num_samples)
        recon_x = self.decode(z)
        if isinstance(recon_x, tuple):
            return recon_x[0]
        return recon_x

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a local directory or a Hugging Face Hub repo id.

        Both locations must provide ``config.json`` and ``customVAE_model2.pth``.
        The weights file may be a raw state dict or a dict holding one under
        ``'model_state_dict'``.

        Recognised kwargs:
            cache_dir, revision, token, force_download, local_files_only:
                forwarded to ``hf_hub_download`` when given (Hub repos only).
            map_location: passed to ``torch.load``; defaults to ``"cpu"`` so
                GPU-saved checkpoints load on CPU-only machines.
        ``model_args`` and any other kwargs are ignored.

        Returns:
            The model in eval mode, as with ``transformers`` models. Call
            ``.train()`` before fine-tuning.
        """
        weights_name = "customVAE_model2.pth"

        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, "config.json")
            model_path = os.path.join(pretrained_model_name_or_path, weights_name)
        else:
            hub_kwargs = {
                key: kwargs[key]
                for key in ("cache_dir", "revision", "token", "force_download", "local_files_only")
                if key in kwargs
            }
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename="config.json", **hub_kwargs
            )
            model_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weights_name, **hub_kwargs
            )

        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        config = cls.config_class(**config_dict)

        model = cls(config)

        # SECURITY: torch.load unpickles the file. Only load checkpoints from
        # trusted repositories/directories.
        state_dict = torch.load(model_path, map_location=kwargs.get("map_location", "cpu"))
        if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        model.load_state_dict(state_dict)
        model.eval()
        return model