import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, seq_len, embd_dim)`` tensor.

    Queries, keys and values come from one fused input projection; the
    per-head results are concatenated and passed through an output projection.

    Args:
        n_heads: Number of attention heads. Must divide ``embd_dim``.
        embd_dim: Size of the last (feature) dimension of the input.
        in_proj_bias: Whether the fused q/k/v projection has a bias.
        out_proj_bias: Whether the output projection has a bias.

    Raises:
        AssertionError: If ``embd_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, embd_dim, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.n_heads = n_heads
        self.d_heads = embd_dim // n_heads
        # Validate before allocating any parameters.
        assert self.d_heads * n_heads == embd_dim, "embed_dim must be divisible by num_heads"

        self.in_proj = nn.Linear(embd_dim, 3 * embd_dim, bias=in_proj_bias)
        self.out_proj = nn.Linear(embd_dim, embd_dim, bias=out_proj_bias)

    def forward(self, x, casual_mask=False):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embd_dim)``.
            casual_mask: If true, each position attends only to itself and
                earlier positions (strict upper triangle is masked out).

        Returns:
            Tensor of shape ``(batch, seq_len, embd_dim)``.
        """
        batch_size, seq_len, embd_dim = x.shape
        interim_shape = (batch_size, seq_len, self.n_heads, self.d_heads)

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (batch, seq, embd) -> (batch, heads, seq, d_heads)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # (batch, heads, seq, seq) raw attention scores.
        weight = q @ k.transpose(-1, -2)

        if casual_mask:
            # True above the main diagonal, i.e. at "future" positions.
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            weight.masked_fill_(mask, -torch.inf)

        # Masked entries stay -inf after scaling, so masking first is safe.
        weight /= math.sqrt(self.d_heads)
        weight = F.softmax(weight, dim=-1)

        # (batch, heads, seq, d_heads) -> (batch, seq, embd)
        output = weight @ v
        output = output.transpose(1, 2)
        output = output.reshape((batch_size, seq_len, embd_dim))

        return self.out_proj(output)


class AttentionBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalised, flattened so every pixel becomes one
    sequence element, attended, reshaped back, and added to the input.

    Args:
        channels: Number of channels of the feature map. ``nn.GroupNorm`` with
            32 groups requires this to be divisible by 32.
    """

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.attention = SelfAttention(n_heads=1, embd_dim=channels)

    def forward(self, x):
        """Return ``x + attention(groupnorm(x))`` for ``x`` of shape ``(n, c, h, w)``."""
        residual = x

        x = self.groupnorm(x)
        n, c, h, w = x.shape

        # (n, c, h, w) -> (n, h*w, c): pixels form the sequence, channels the features.
        x = x.view((n, c, h * w)).transpose(-1, -2)
        x = self.attention(x)
        # (n, h*w, c) -> (n, c, h, w)
        x = x.transpose(-1, -2).view((n, c, h, w))

        return x + residual


class Residual(nn.Module):
    """Pre-activation residual block: GN -> SiLU -> conv -> GN -> conv, plus shortcut.

    When ``in_channels != out_channels`` the shortcut is a 1x1 convolution that
    matches the channel count; otherwise it is the identity.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
            Both must be divisible by 32 (``nn.GroupNorm`` with 32 groups).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # gn1 normalises the block *input* (it runs before conv1), so it must
        # be sized by in_channels. Sizing it by out_channels made every
        # channel-changing block fail in forward(). Shapes are unchanged when
        # in_channels == out_channels, so existing checkpoints still load.
        self.gn1 = nn.GroupNorm(32, in_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.gn2 = nn.GroupNorm(32, out_channels)
        self.silu = nn.SiLU()

        if in_channels != out_channels:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        else:
            self.residual_layer = nn.Identity()

    def forward(self, x):
        """Map ``(n, in_channels, h, w)`` to ``(n, out_channels, h, w)``."""
        # No clone needed: nothing below writes into the input tensor.
        x_residual = x

        x = self.gn1(x)
        x = self.silu(x)
        x = self.conv1(x)

        x = self.gn2(x)
        # NOTE(review): there is no SiLU between gn2 and conv2, unlike the
        # first half of the block. Left as-is because adding it would change
        # the output of already-trained weights -- confirm whether intended.
        x = self.conv2(x)

        # In-place add targets the fresh conv2 output, never the caller's tensor.
        x += self.residual_layer(x_residual)
        return x


class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of the latent.

    Takes a 3-channel image and applies three stride-2 convolutions (each
    roughly halving height and width), ending with 256 feature channels that
    two separate 3x3 convolutions map to ``mu`` and ``logvar``.

    Args:
        latent_channels: Number of channels of ``mu`` and ``logvar``.
    """

    def __init__(self, latent_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.SiLU(),
            Residual(64, 64),
            Residual(64, 64),
            nn.Conv2d(64, 128, 3, 2, 1),  # downsample x2
            Residual(128, 128),
            Residual(128, 128),
            nn.Conv2d(128, 256, 3, 2, 1),  # downsample x2
            Residual(256, 256),
            Residual(256, 256),
            nn.Conv2d(256, 256, 3, 2, 1),  # downsample x2
            Residual(256, 256),
            AttentionBlock(channels=256),
            Residual(256, 256),
            nn.GroupNorm(32, 256),
            nn.SiLU(),
        )
        self.mu = nn.Conv2d(256, latent_channels, 3, padding=1)
        self.logvar = nn.Conv2d(256, latent_channels, 3, padding=1)

    def forward(self, x):
        """Return the tuple ``(mu, logvar)`` computed from image batch ``x``."""
        features = self.net(x)
        return self.mu(features), self.logvar(features)


class Decoder(nn.Module):
    """Convolutional decoder mapping a latent back to a 3-channel image.

    Applies three nearest-neighbour x2 upsamplings (8x overall) and ends with
    ``Tanh``, so outputs lie in ``[-1, 1]``.

    Args:
        latent_channels: Number of channels of the latent input.
    """

    def __init__(self, latent_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(latent_channels, 256, 3, padding=1),
            Residual(256, 256),
            AttentionBlock(channels=256),
            Residual(256, 256),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 256, 3, padding=1),
            Residual(256, 256),
            Residual(256, 256),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, padding=1),
            Residual(128, 128),
            Residual(128, 128),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 64, 3, padding=1),
            Residual(64, 64),
            Residual(64, 64),
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Decode latent ``x`` into an image tensor."""
        return self.net(x)


class Vae(nn.Module):
    """Variational autoencoder built from :class:`Encoder` and :class:`Decoder`.

    Args:
        latent_channels: Number of channels of the latent representation.
    """

    def __init__(self, latent_channels):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)``.

        ``logvar`` is clamped to ``[-30, 20]`` before exponentiation so that
        ``std = exp(0.5 * logvar)`` stays finite.
        """
        logvar = torch.clamp(logvar, -30, 20)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            The tuple ``(reconstruction, mu, logvar)``. ``logvar`` is the raw
            encoder output; the clamp in :meth:`reparametrize` is not applied
            to the returned value.
        """
        mu, logvar = self.encoder(x)
        z = self.reparametrize(mu, logvar)
        return self.decoder(z), mu, logvar