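# TinyDiT / vae.py
# Source-page header residue (presumably uploader "aniure", commit message
# "uploading", commit d069b0b), kept as comments so the module parses.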
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    One fused linear layer produces queries, keys and values side by side;
    the per-head results are concatenated and sent through an output
    projection.

    Args:
        n_heads: Number of attention heads.
        embd_dim: Embedding width; must be divisible by ``n_heads``.
        in_proj_bias: Whether the fused Q/K/V projection has a bias.
        out_proj_bias: Whether the output projection has a bias.
    """

    def __init__(self, n_heads, embd_dim, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.n_heads = n_heads
        # Fused projection: a single matmul yields Q, K and V concatenated.
        self.in_proj = nn.Linear(embd_dim, 3 * embd_dim, bias=in_proj_bias)
        self.out_proj = nn.Linear(embd_dim, embd_dim, bias=out_proj_bias)
        self.d_heads = embd_dim // n_heads
        assert self.d_heads * n_heads == embd_dim, "embed_dim must be divisible by num_heads"

    def forward(self, x, casual_mask=False):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor unpacked as ``(batch, seq_len, embd_dim)``.
            casual_mask: When true, position ``i`` only attends to positions
                ``<= i`` (causal masking; the parameter's spelling is kept
                for caller compatibility).

        Returns:
            Tensor of shape ``(batch, seq_len, embd_dim)``.
        """
        b, t, c = x.shape
        split = (b, t, self.n_heads, self.d_heads)

        # (b, t, 3c) -> three tensors of shape (b, n_heads, t, d_heads).
        q, k, v = (
            part.view(split).transpose(1, 2)
            for part in self.in_proj(x).chunk(3, dim=-1)
        )

        scores = q @ k.transpose(-1, -2)
        if casual_mask:
            # Strictly-upper-triangular entries are "future" positions; the
            # (t, t) mask broadcasts over the batch and head dimensions.
            future = torch.ones(t, t, dtype=torch.bool, device=scores.device).triu(1)
            scores.masked_fill_(future, -torch.inf)
        scores /= math.sqrt(self.d_heads)
        probs = F.softmax(scores, dim=-1)

        # Merge heads: (b, n_heads, t, d_heads) -> (b, t, c).
        merged = (probs @ v).transpose(1, 2).reshape((b, t, c))
        return self.out_proj(merged)
class AttentionBlock(nn.Module):
    """Residual single-head self-attention across the spatial positions of a feature map.

    Args:
        channels: Feature channels; used as the GroupNorm width (32 groups,
            so it must be divisible by 32) and as the attention embedding size.
    """

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.attention = SelfAttention(n_heads=1, embd_dim=channels)

    def forward(self, x):
        """Return ``attention(groupnorm(x)) + x`` for ``x`` unpacked as ``(n, c, h, w)``."""
        normed = self.groupnorm(x)
        n, c, h, w = normed.shape
        # Every pixel becomes a token: (n, c, h, w) -> (n, h*w, c).
        tokens = normed.view((n, c, h * w)).transpose(-1, -2)
        attended = self.attention(tokens)
        # Back to image layout: (n, h*w, c) -> (n, c, h, w).
        attended_map = attended.transpose(-1, -2).view((n, c, h, w))
        return attended_map + x
class Residual(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages plus a skip path.

    Spatial size is preserved (3x3 kernels, stride 1, padding 1). The skip
    path is the identity when the channel counts match and a 1x1
    convolution otherwise.

    Args:
        in_channels: Channels of the input; must be divisible by 32 (GroupNorm groups).
        out_channels: Channels of the output; must be divisible by 32.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # gn1 normalizes the block *input*, so it is sized by in_channels.
        # Sizing it by out_channels made every in_channels != out_channels
        # block fail at the first GroupNorm.
        self.gn1 = nn.GroupNorm(32, in_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.gn2 = nn.GroupNorm(32, out_channels)
        self.silu = nn.SiLU()
        if in_channels != out_channels:
            # 1x1 projection so the skip matches the main path's channel count.
            self.residual_layer = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        else:
            self.residual_layer = nn.Identity()

    def forward(self, x):
        """Return ``stage2(stage1(x)) + residual_layer(x)``, each stage GroupNorm -> SiLU -> conv."""
        h = self.conv1(self.silu(self.gn1(x)))
        # NOTE: the SiLU after gn2 was previously missing, leaving the second
        # stage without a non-linearity. Weights trained with the old forward
        # will produce different outputs.
        h = self.conv2(self.silu(self.gn2(h)))
        # Out-of-place add; x itself is never modified, so no clone is needed.
        return h + self.residual_layer(x)
class Encoder(nn.Module):
    """Convolutional VAE encoder mapping a 3-channel image to latent Gaussian parameters.

    Three stride-2 convolutions each halve the spatial size (rounding up);
    a self-attention block runs at the lowest resolution.

    Args:
        latent_channels: Number of channels in each of the ``mu`` and ``logvar`` maps.
    """

    def __init__(self, latent_channels):
        super().__init__()
        # Stem at full resolution, 64 channels.
        layers = [nn.Conv2d(3, 64, 3, padding=1), nn.SiLU()]
        layers += [Residual(64, 64), Residual(64, 64)]
        # 1/2 resolution, 128 channels.
        layers += [nn.Conv2d(64, 128, 3, 2, 1), Residual(128, 128), Residual(128, 128)]
        # 1/4 resolution, 256 channels.
        layers += [nn.Conv2d(128, 256, 3, 2, 1), Residual(256, 256), Residual(256, 256)]
        # 1/8 resolution bottleneck with attention between residual blocks.
        layers += [
            nn.Conv2d(256, 256, 3, 2, 1),
            Residual(256, 256),
            AttentionBlock(channels=256),
            Residual(256, 256),
            nn.GroupNorm(32, 256),
            nn.SiLU(),
        ]
        self.net = nn.Sequential(*layers)
        # Two parallel heads share the trunk's features.
        self.mu = nn.Conv2d(256, latent_channels, 3, padding=1)
        self.logvar = nn.Conv2d(256, latent_channels, 3, padding=1)

    def forward(self, x):
        """Return the tuple ``(mu, logvar)``, each with ``latent_channels`` channels."""
        features = self.net(x)
        return self.mu(features), self.logvar(features)
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping a latent map back to a 3-channel image.

    Three nearest-neighbour 2x upsamplings enlarge the latent 8x in each
    spatial dimension; the final ``Tanh`` bounds outputs to ``[-1, 1]``.

    Args:
        latent_channels: Number of channels of the input latent.
    """

    def __init__(self, latent_channels):
        super().__init__()
        stages = [
            # Bottleneck at latent resolution, with attention.
            nn.Conv2d(latent_channels, 256, 3, padding=1),
            Residual(256, 256),
            AttentionBlock(channels=256),
            Residual(256, 256),
        ]
        # Each pair: 2x upsample, channel-changing conv, then two residual blocks.
        for c_in, c_out in ((256, 256), (256, 128), (128, 64)):
            stages += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(c_in, c_out, 3, padding=1),
                Residual(c_out, c_out),
                Residual(c_out, c_out),
            ]
        # Output head: normalize, activate, project to 3 channels.
        stages += [
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Decode latent ``x`` into an image tensor."""
        return self.net(x)
class Vae(nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    Args:
        latent_channels: Channel count of the latent ``z`` shared by both halves.
    """

    def __init__(self, latent_channels):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

        ``logvar`` is clamped to ``[-30, 20]`` first, so ``sigma`` lies in
        ``[exp(-15), exp(10)]``.
        """
        sigma = torch.clamp(logvar, -30, 20).mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            ``(reconstruction, mu, logvar)``; ``logvar`` is the unclamped
            encoder output.
        """
        mu, logvar = self.encoder(x)
        return self.decoder(self.reparametrize(mu, logvar)), mu, logvar