Rohan3's picture
Updated: VAE, UNet, config, text embeddings, model and main
a625e96
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import torch
from torch import nn
import torch.nn.functional as F
from config import *
from attention import SelfCrossAttn
class VAEResBlock(nn.Module):
    """Pre-activation residual block: two 3x3 convolutions plus a shortcut.

    Main path: GroupNorm -> SiLU -> 3x3 conv -> GroupNorm -> SiLU -> 3x3 conv.
    Both 3x3 convolutions use padding 1 and the default stride, so the
    spatial size of the input is preserved. The shortcut is the identity when
    the channel count does not change and a bias-free 1x1 convolution
    otherwise.

    ``vae_group_size`` (from ``config``) is passed to ``nn.GroupNorm`` as its
    first positional argument, i.e. the number of groups.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the result. A falsy value (``None`` or
            ``0``) means "same as ``in_channels``".
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        if not out_channels:
            out_channels = in_channels

        # The layer order fixes the nn.Sequential indices that appear in
        # state-dict keys, so it must stay: norm, act, conv, norm, act, conv.
        main_path = [
            nn.GroupNorm(vae_group_size, in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.GroupNorm(vae_group_size, out_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        ]
        self.block = nn.Sequential(*main_path)

        # Only project the shortcut when the channel count actually changes.
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        """Return the sum of the main path and the shortcut applied to ``x``."""
        transformed = self.block(x)
        shortcut = self.skip(x)
        return transformed + shortcut
class VAE(nn.Module):
    """Convolutional variational autoencoder with a 2-D spatial latent.

    Encoder: one stride-1 and three stride-2 3x3 convolutions
    (3 -> 32 -> 128 -> 256 -> 512 channels), each followed by GroupNorm and
    SiLU, then a residual / attention / residual bottleneck. For a
    ``(B, 3, H, W)`` input this gives a ``(B, 512, H/8, W/8)`` feature map
    (exact when ``H`` and ``W`` are multiples of 8).

    Latent: a 1x1 convolution predicts ``2 * vae_latent_channels`` channels,
    split in half along the channel axis into ``mu`` and ``logvar``.

    Decoder: a 1x1 convolution back to 512 channels, the same bottleneck,
    three nearest-neighbour x2 upsamplings with 3x3 convolutions
    (512 -> 256 -> 128 -> 32), and a final 3x3 convolution to 3 channels
    followed by a sigmoid, so reconstructions lie in (0, 1).

    ``vae_group_size`` and ``vae_latent_channels`` come from ``config``.
    ``SelfCrossAttn(..., cross=False)`` is defined in ``attention``;
    NOTE(review): presumably plain self-attention over spatial positions --
    confirm against that module.
    """

    # Bounds applied to ``logvar`` before it is exponentiated in
    # ``reparameterize``; they keep ``std`` finite for extreme encoder outputs.
    _LOGVAR_MIN = -30.0
    _LOGVAR_MAX = 20.0

    def __init__(self):
        super().__init__()

        # Encoder. Shape comments assume a (B, 3, H, W) input.
        # Submodule order is kept as-is: the Sequential indices are part of
        # the state-dict keys.
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, padding=1, bias=False),     # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 128, 3, 2, padding=1, bias=False),   # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Conv2d(128, 256, 3, 2, padding=1, bias=False),  # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Conv2d(256, 512, 3, 2, padding=1, bias=False),  # (B, 512, H/8, W/8)
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            # Bottleneck at the lowest resolution.
            VAEResBlock(512), SelfCrossAttn(512, heads=8, cross=False), VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
        )

        # Per-position mu and logvar, stacked along the channel axis:
        # (B, 2 * vae_latent_channels, H/8, W/8).
        self.to_latent = nn.Conv2d(512, 2 * vae_latent_channels, kernel_size=1)
        # Lifts a sampled latent (B, vae_latent_channels, H/8, W/8) back to
        # 512 channels for the decoder.
        self.from_latent = nn.Conv2d(vae_latent_channels, 512, kernel_size=1)

        # Decoder: mirrors the encoder, upsampling x2 three times.
        self.decoder_conv = nn.Sequential(
            VAEResBlock(512), SelfCrossAttn(512, heads=8, cross=False), VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(512, 256, 3, padding=1, bias=False),     # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(256, 128, 3, padding=1, bias=False),     # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(128, 32, 3, padding=1, bias=False),      # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 3, 3, 1, padding=1),                 # (B, 3, H, W)
            nn.Sigmoid(),
        )

    @staticmethod
    def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mu: Mean of the diagonal Gaussian.
            logvar: Log-variance of the diagonal Gaussian, same shape as ``mu``.

        Returns:
            A sample with the shape of ``mu``.
        """
        # exp() of an unbounded log-variance overflows to inf, which would
        # turn the sample (and everything decoded from it) into inf/NaN.
        # Values inside the bounds are unaffected.
        logvar = torch.clamp(logvar, VAE._LOGVAR_MIN, VAE._LOGVAR_MAX)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _encode_moments(self, x: torch.Tensor):
        """Run the encoder and return the ``(mu, logvar)`` pair for ``x``."""
        h = self.encoder_conv(x)   # (B, 512, H/8, W/8)
        h = self.to_latent(h)      # (B, 2 * vae_latent_channels, H/8, W/8)
        mu, logvar = torch.chunk(h, 2, dim=1)
        return mu, logvar

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: Image batch with 3 channels, ``(B, 3, H, W)``.

        Returns:
            ``(recon, mu, logvar)``: the reconstruction and the (unclamped)
            posterior parameters predicted by the encoder.
        """
        mu, logvar = self._encode_moments(x)
        z = self.reparameterize(mu, logvar)  # (B, vae_latent_channels, H/8, W/8)
        recon = self.decode_latent_to_img(z)
        return recon, mu, logvar

    def encode_img_to_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return one sampled latent (stochastic on every call)."""
        mu, logvar = self._encode_moments(x)
        return self.reparameterize(mu, logvar)

    def decode_latent_to_img(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent with ``vae_latent_channels`` channels into an image."""
        h = self.from_latent(z)
        return self.decoder_conv(h)