Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from timm import create_model | |
| from model.config import load_config | |
| from .model_embedder import HybridEmbed | |
| config = load_config() | |
| class Encoder(nn.Module): | |
| def __init__(self, latent_dims=4): | |
| super(Encoder, self).__init__() | |
| self.features = nn.Sequential( | |
| nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(num_features=16), | |
| nn.LeakyReLU(), | |
| nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(num_features=32), | |
| nn.LeakyReLU(), | |
| nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(num_features=64), | |
| nn.LeakyReLU(), | |
| nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), | |
| nn.BatchNorm2d(num_features=128), | |
| nn.LeakyReLU() | |
| ) | |
| self.latent_dims = latent_dims | |
| self.fc1 = nn.Linear(128*14*14, 256) | |
| self.fc2 = nn.Linear(256, 128) | |
| self.mu = nn.Linear(128*14*14, self.latent_dims) | |
| self.var = nn.Linear(128*14*14, self.latent_dims) | |
| self.kl = 0 | |
| self.kl_weight = 0.5#0.00025 | |
| self.relu = nn.LeakyReLU() | |
| def reparameterize(self, x): | |
| # https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vanilla_vae.py | |
| std = torch.exp(0.5*self.mu(x)) | |
| eps = torch.randn_like(std) | |
| z = eps * std + self.mu(x) | |
| return z, std | |
| def forward(self, x): | |
| x = self.features(x) | |
| x = torch.flatten(x, start_dim=1) | |
| mu = self.mu(x) | |
| var = self.var(x) | |
| z,_ = self.reparameterize(x) | |
| self.kl = self.kl_weight*torch.mean(-0.5*torch.sum(1+var - mu**2 - var.exp(), dim=1), dim=0) | |
| return z | |
| class Decoder(nn.Module): | |
| def __init__(self, latent_dims=4): | |
| super(Decoder, self).__init__() | |
| self.features = nn.Sequential( | |
| nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2), | |
| nn.LeakyReLU(), | |
| nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2), | |
| nn.LeakyReLU(), | |
| nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2), | |
| nn.LeakyReLU(), | |
| nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2), | |
| nn.LeakyReLU() | |
| ) | |
| self.latent_dims = latent_dims | |
| self.unflatten = nn.Unflatten(dim=1, unflattened_size=(256, 7, 7)) | |
| def forward(self, x): | |
| x = self.unflatten(x) | |
| x = self.features(x) | |
| return x | |
| class GenConViTVAE(nn.Module): | |
| def __init__(self, config, pretrained=True): | |
| super(GenConViTVAE, self).__init__() | |
| self.latent_dims = config['model']['latent_dims'] | |
| self.encoder = Encoder(self.latent_dims) | |
| self.decoder = Decoder(self.latent_dims) | |
| self.embedder = create_model(config['model']['embedder'], pretrained=True) | |
| self.convnext_backbone = create_model(config['model']['backbone'], pretrained=True, num_classes=1000, drop_path_rate=0, head_init_scale=1.0) | |
| self.convnext_backbone.patch_embed = HybridEmbed(self.embedder, img_size=config['img_size'], embed_dim=768) | |
| self.num_feature = self.convnext_backbone.head.fc.out_features * 2 | |
| self.fc = nn.Linear(self.num_feature, self.num_feature//4) | |
| self.fc3 = nn.Linear(self.num_feature//2, self.num_feature//4) | |
| self.fc2 = nn.Linear(self.num_feature//4, config['num_classes']) | |
| self.relu = nn.ReLU() | |
| self.resize = transforms.Resize((224,224), antialias=True) | |
| def forward(self, x): | |
| z = self.encoder(x) | |
| x_hat = self.decoder(z) | |
| x1 = self.convnext_backbone(x) | |
| x2 = self.convnext_backbone(x_hat) | |
| x = torch.cat((x1,x2), dim=1) | |
| x = self.fc2(self.relu(self.fc(self.relu(x)))) | |
| return x, self.resize(x_hat) |