Spaces:
Sleeping
Sleeping
| import sys, os | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from config import * | |
| from attention import SelfCrossAttn | |
class VAEResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a shortcut.

    The shortcut is an identity when the channel count is unchanged and a
    bias-free 1x1 convolution otherwise. All convolutions use stride 1 with
    padding that keeps the spatial size, so only the channel count can change.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced. Any falsy value
            (``None`` or ``0``) means "same as ``in_channels``".
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        if not out_channels:
            out_channels = in_channels

        # Stage 1 maps in_channels -> out_channels, stage 2 keeps out_channels.
        # NOTE: vae_group_size is passed as GroupNorm's first argument, i.e.
        # the *number of groups*, so it must divide both channel counts.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.GroupNorm(vae_group_size, stage_in))
            stages.append(nn.SiLU(inplace=True))
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False)
            )
        self.block = nn.Sequential(*stages)

        # Project the shortcut only when the channel counts disagree.
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``block(x) + skip(x)`` for a ``(B, in_channels, H, W)`` input."""
        residual = self.block(x)
        shortcut = self.skip(x)
        return residual + shortcut
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The encoder downsamples the input by a factor of 8 (three stride-2
    convolutions) into a 512-channel feature map, which a 1x1 convolution
    turns into per-position mean and log-variance maps with
    ``vae_latent_channels`` channels each. The decoder mirrors this with three
    nearest-neighbour x2 upsampling stages and ends in a sigmoid, so
    reconstructions lie in (0, 1).

    ``vae_group_size`` (used as the GroupNorm group count) and
    ``vae_latent_channels`` are module-level names expected to be supplied by
    the ``config`` star-import.
    """

    def __init__(self):
        super().__init__()
        # Encoder. Shapes are for an input of shape (B, 3, H, W); each
        # stride-2 conv halves H and W (rounding up for odd sizes).
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, padding=1, bias=False),  # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 128, 3, 2, padding=1, bias=False),  # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Conv2d(128, 256, 3, 2, padding=1, bias=False),  # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Conv2d(256, 512, 3, 2, padding=1, bias=False),  # (B, 512, H/8, W/8)
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            # Bottleneck: residual block, attention, residual block.
            # NOTE(review): SelfCrossAttn is defined in attention.py; with
            # cross=False it is presumably plain self-attention that keeps the
            # (B, 512, H/8, W/8) shape -- confirm against that module.
            VAEResBlock(512), SelfCrossAttn(512, heads=8, cross=False), VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
        )
        # Channel-wise mean and log-variance, stacked along dim 1:
        # (B, 2 * vae_latent_channels, H/8, W/8).
        self.to_latent = nn.Conv2d(512, 2 * vae_latent_channels, kernel_size=1)
        # Lifts a sampled latent back to the 512-channel decoder width.
        self.from_latent = nn.Conv2d(vae_latent_channels, 512, kernel_size=1)
        # Decoder: bottleneck, then three (upsample x2 -> conv -> norm -> SiLU)
        # stages, then a biased 3x3 conv to RGB and a sigmoid.
        self.decoder_conv = nn.Sequential(
            VAEResBlock(512), SelfCrossAttn(512, heads=8, cross=False), VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(512, 256, 3, padding=1, bias=False),
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(256, 128, 3, padding=1, bias=False),
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(128, 32, 3, padding=1, bias=False),
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 3, 3, 1, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Declared as a static method because it uses no instance state; this
        makes the ``self.reparameterize(mu, logvar)`` calls below valid (as a
        plain function in the class body they raised ``TypeError``) while
        ``VAE.reparameterize(mu, logvar)`` keeps working.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor, same shape as ``mu``.

        Returns:
            A tensor of the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _encode_stats(self, x):
        """Run the encoder and return the ``(mu, logvar)`` latent statistics."""
        h = self.to_latent(self.encoder_conv(x))
        # First half of the channels is the mean, second half the log-variance.
        mu, logvar = torch.chunk(h, 2, dim=1)
        return mu, logvar

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: Image batch with 3 channels, shape ``(B, 3, H, W)``.

        Returns:
            ``(recon, mu, logvar)``: the decoder output and the latent
            statistics, each of the latter with ``vae_latent_channels``
            channels.
        """
        mu, logvar = self._encode_stats(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode_latent_to_img(z)
        return recon, mu, logvar

    def encode_img_to_latent(self, x):
        """Return a latent sampled from the posterior of ``x``.

        The result is stochastic: it is a reparameterized sample, not the mean.
        """
        mu, logvar = self._encode_stats(x)
        return self.reparameterize(mu, logvar)

    def decode_latent_to_img(self, z):
        """Decode a latent with ``vae_latent_channels`` channels to an image in (0, 1)."""
        h = self.from_latent(z)
        return self.decoder_conv(h)