import sys
import os

sys.path.insert(0, os.path.dirname(__file__))

import torch
from torch import nn
import torch.nn.functional as F

from config import *
from attention import SelfCrossAttn


class VAEResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output; falls back to ``in_channels``
            when not given (or falsy). When it differs from ``in_channels``
            the skip path is a bias-free 1x1 conv so both branches can be
            summed; otherwise the skip path is the identity.

    The 3x3 convs use ``padding=1`` with stride 1, so spatial size is kept.
    ``vae_group_size`` (from ``config``) must divide both channel counts,
    as ``nn.GroupNorm`` requires.
    """

    def __init__(self, in_channels, out_channels=None):
        out_channels = out_channels or in_channels
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(vae_group_size, in_channels),
            # In-place is safe here: it overwrites GroupNorm's fresh output,
            # never the block input ``x`` that the skip path still reads.
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.GroupNorm(vae_group_size, out_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1, bias=False)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``block(x) + skip(x)``."""
        return self.block(x) + self.skip(x)


class VAE(nn.Module):
    """Convolutional VAE with a spatial (2-D) latent.

    Encoder: a 3-channel image goes through three stride-2 convs, so each
    spatial side shrinks by 8x for sizes divisible by 8. For example,
    (B, 3, 256, 256) becomes 512 channels at 32x32. A 1x1 conv then produces
    per-position mean and log-variance, each (B, vae_latent_channels, H/8, W/8).

    Decoder: a 1x1 conv lifts the latent back to 512 channels. Three 2x
    nearest-neighbour upsamples then restore the spatial size, and a final
    Sigmoid puts the reconstruction in [0, 1]. Input sides that are not
    multiples of 8 will not round-trip to the same size.

    NOTE(review): because the output is Sigmoid-bounded, inputs are
    presumably expected in [0, 1]; confirm against the data pipeline.
    NOTE(review): ``SelfCrossAttn(512, heads=8, cross=False)`` is assumed to
    map (B, 512, h, w) -> (B, 512, h, w); it is defined in ``attention``.

    The order of modules inside each ``nn.Sequential`` determines the
    ``state_dict`` keys (``encoder_conv.<idx>...``). Do not reorder layers
    without migrating existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Encoder: (B, 3, H, W) -> (B, 512, H/8, W/8)
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, padding=1, bias=False),     # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 128, 3, 2, padding=1, bias=False),   # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Conv2d(128, 256, 3, 2, padding=1, bias=False),  # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Conv2d(256, 512, 3, 2, padding=1, bias=False),  # (B, 512, H/8, W/8)
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            # Bottleneck at H/8 x W/8: residual -> self-attention -> residual.
            VAEResBlock(512),
            SelfCrossAttn(512, heads=8, cross=False),
            VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
        )

        # Per-position mean and log-variance, concatenated along channels:
        # (B, 2 * vae_latent_channels, H/8, W/8), split in _encode_moments.
        self.to_latent = nn.Conv2d(512, 2 * vae_latent_channels, kernel_size=1)
        self.from_latent = nn.Conv2d(vae_latent_channels, 512, kernel_size=1)

        # Decoder: (B, 512, H/8, W/8) -> (B, 3, H, W), values in [0, 1].
        self.decoder_conv = nn.Sequential(
            VAEResBlock(512),
            SelfCrossAttn(512, heads=8, cross=False),
            VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(512, 256, 3, padding=1, bias=False),     # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(256, 128, 3, padding=1, bias=False),     # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(128, 32, 3, padding=1, bias=False),      # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 3, 3, 1, padding=1),                 # (B, 3, H, W)
            nn.Sigmoid(),
        )

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        Sampling this way keeps ``z`` differentiable w.r.t. ``mu`` and
        ``logvar``. The noise is drawn with ``torch.randn_like``, so it
        matches the shape, dtype and device of ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _encode_moments(self, x):
        """Encode ``x`` and split the result into ``(mu, logvar)`` on dim 1."""
        h = self.to_latent(self.encoder_conv(x))
        mu, logvar = torch.chunk(h, 2, dim=1)
        return mu, logvar

    def forward(self, x):
        """Encode, sample a latent, and decode.

        Args:
            x: Image batch, (B, 3, H, W).

        Returns:
            ``(recon, mu, logvar)``:
            - ``recon`` is the decoded sample, in [0, 1].
            - ``mu`` and ``logvar`` are each (B, vae_latent_channels, H/8, W/8)
              for H, W divisible by 8.
        """
        mu, logvar = self._encode_moments(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode_latent_to_img(z)
        return recon, mu, logvar

    def encode_img_to_latent(self, x, *, sample=True):
        """Map an image batch to a latent.

        Args:
            x: Image batch, (B, 3, H, W).
            sample: If True (default, the original behaviour), return a
                reparameterized random sample. If False, return the mean
                ``mu``, which is deterministic.

        Returns:
            Latent of shape (B, vae_latent_channels, H/8, W/8) for H, W
            divisible by 8.
        """
        mu, logvar = self._encode_moments(x)
        if not sample:
            return mu
        return self.reparameterize(mu, logvar)

    def decode_latent_to_img(self, z):
        """Decode a latent (B, vae_latent_channels, h, w) to (B, 3, 8h, 8w) in [0, 1]."""
        return self.decoder_conv(self.from_latent(z))