| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single linear layer produces queries, keys and values; the attended
    values are merged back across heads and passed through an output
    projection.

    Args:
        n_heads: Number of attention heads; must be positive and divide
            ``embd_dim``.
        embd_dim: Size of the last dimension of the input and of the output.
        in_proj_bias: Whether the fused q/k/v projection has a bias term.
        out_proj_bias: Whether the output projection has a bias term.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``embd_dim`` evenly.
    """

    def __init__(self, n_heads, embd_dim, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, and a zero/negative head count must be rejected too
        # (a negative divisor would otherwise pass a product check).
        if n_heads <= 0 or embd_dim % n_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.n_heads = n_heads
        self.d_heads = embd_dim // n_heads
        self.in_proj = nn.Linear(embd_dim, 3 * embd_dim, bias=in_proj_bias)
        self.out_proj = nn.Linear(embd_dim, embd_dim, bias=out_proj_bias)

    def forward(self, x, casual_mask=False):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embd_dim)``.
            casual_mask: When true, each position may only attend to itself
                and to earlier positions. (The parameter name is kept as-is
                for backward compatibility with existing callers.)

        Returns:
            Tensor of shape ``(batch, seq_len, embd_dim)``.
        """
        batch_size, seq_len, embd_dim = x.shape
        per_head_shape = (batch_size, seq_len, self.n_heads, self.d_heads)

        # (batch, seq, 3 * embd) -> three tensors of (batch, heads, seq, d_heads)
        q, k, v = (
            part.view(per_head_shape).transpose(1, 2)
            for part in self.in_proj(x).chunk(3, dim=-1)
        )

        scores = q @ k.transpose(-1, -2)
        if casual_mask:
            # A single (seq, seq) mask broadcasts over batch and heads, so
            # there is no need to allocate one the size of ``scores``.
            future = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=scores.device
            ).triu(1)
            scores.masked_fill_(future, -torch.inf)
        scores /= math.sqrt(self.d_heads)
        attn = F.softmax(scores, dim=-1)

        # (batch, heads, seq, d_heads) -> (batch, seq, embd)
        merged = (attn @ v).transpose(1, 2).reshape((batch_size, seq_len, embd_dim))
        return self.out_proj(merged)
|
|
class AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised (32 groups), every spatial location is
    treated as one token with ``channels`` features, and the attention result
    is added back onto the un-normalised input.

    Args:
        channels: Number of channels of the input feature map.
    """

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.attention = SelfAttention(n_heads=1, embd_dim=channels)

    def forward(self, x):
        """Return ``x`` plus attention over its normalised spatial tokens.

        Args:
            x: 4-D tensor unpacked as ``(n, c, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.groupnorm(x)
        n, c, h, w = normed.shape

        # (n, c, h, w) -> (n, h * w, c): one token per spatial location.
        tokens = normed.view((n, c, h * w)).permute(0, 2, 1)
        attended = self.attention(tokens)

        # (n, h * w, c) -> (n, c, h, w), then the skip connection.
        feature_map = attended.permute(0, 2, 1).view((n, c, h, w))
        return feature_map + x
|
|
class Residual(nn.Module):
    """Two 3x3 convolutions with group normalisation and a skip connection.

    When the channel count changes, the skip path goes through a 1x1
    convolution; otherwise it is the identity.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # gn1 runs on the block input (before conv1), so it must normalise
        # ``in_channels`` channels; sizing it with ``out_channels`` fails as
        # soon as the two differ. Identical when they are equal.
        self.gn1 = nn.GroupNorm(32, in_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.gn2 = nn.GroupNorm(32, out_channels)
        self.silu = nn.SiLU()
        if in_channels != out_channels:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        else:
            self.residual_layer = nn.Identity()

    def forward(self, x):
        """Return the block output for a feature map ``x``.

        Args:
            x: Tensor with ``in_channels`` channels in dimension 1.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size.
        """
        # The input is never modified in place, so no defensive copy is
        # needed for the skip path.
        out = self.conv1(self.silu(self.gn1(x)))
        # NOTE(review): no SiLU is applied between gn2 and conv2, unlike the
        # first norm -> SiLU -> conv step. Kept as-is so existing weights
        # behave identically; confirm this is intended.
        out = self.conv2(self.gn2(out))
        return out + self.residual_layer(x)
|
|
class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of a latent.

    The backbone takes a 3-channel input, applies three stride-2
    convolutions (64 -> 128 -> 256 -> 256 channels) interleaved with
    residual blocks, and uses one attention block at the lowest resolution.
    Two separate 3x3 convolutions then map the 256-channel features to
    ``latent_channels`` channels each.

    Args:
        latent_channels: Number of channels of ``mu`` and ``logvar``.
    """

    def __init__(self, latent_channels):
        super().__init__()

        # Stem at full resolution.
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.SiLU(),
            Residual(64, 64),
            Residual(64, 64),
        ]

        # Two plain down-sampling stages: strided conv + two residual blocks.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            layers.append(Residual(c_out, c_out))
            layers.append(Residual(c_out, c_out))

        # Final down-sampling stage with attention between residual blocks.
        layers.extend(
            [
                nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
                Residual(256, 256),
                AttentionBlock(channels=256),
                Residual(256, 256),
            ]
        )

        # Output normalisation and activation.
        layers.extend([nn.GroupNorm(32, 256), nn.SiLU()])

        self.net = nn.Sequential(*layers)
        self.mu = nn.Conv2d(256, latent_channels, kernel_size=3, padding=1)
        self.logvar = nn.Conv2d(256, latent_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Encode ``x`` and return the ``(mu, logvar)`` pair.

        Args:
            x: Input tensor with 3 channels in dimension 1.

        Returns:
            Tuple ``(mu, logvar)``, each with ``latent_channels`` channels.
        """
        features = self.net(x)
        return self.mu(features), self.logvar(features)
|
|
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent tensor to a 3-channel output.

    The latent is lifted to 256 channels, processed by residual blocks with
    one attention block, then passed through three nearest-neighbour 2x
    up-sampling stages (256 -> 256 -> 128 -> 64 channels). A final
    normalisation, SiLU, 3x3 convolution and ``Tanh`` produce the output.

    Args:
        latent_channels: Number of channels of the latent input.
    """

    def __init__(self, latent_channels):
        super().__init__()

        # Latent resolution: lift to 256 channels and mix with attention.
        layers = [
            nn.Conv2d(latent_channels, 256, kernel_size=3, padding=1),
            Residual(256, 256),
            AttentionBlock(channels=256),
            Residual(256, 256),
        ]

        # Each stage doubles height/width, then refines with residual blocks.
        for c_in, c_out in ((256, 256), (256, 128), (128, 64)):
            layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(Residual(c_out, c_out))
            layers.append(Residual(c_out, c_out))

        # Output head; Tanh bounds the result to [-1, 1].
        layers.extend(
            [
                nn.GroupNorm(32, 64),
                nn.SiLU(),
                nn.Conv2d(64, 3, kernel_size=3, padding=1),
                nn.Tanh(),
            ]
        )

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Decode the latent tensor ``x`` by running it through the network."""
        decoded = self.net(x)
        return decoded
|
|
class Vae(nn.Module):
    """Variational auto-encoder combining ``Encoder`` and ``Decoder``.

    Args:
        latent_channels: Number of channels of the latent tensor, passed to
            both the encoder and the decoder.
    """

    def __init__(self, latent_channels):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

    def reparametrize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from a standard normal.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor; clamped to ``[-30, 20]`` before
                exponentiation to keep the standard deviation finite.

        Returns:
            A sample with the same shape as ``mu`` / ``logvar``.
        """
        bounded_logvar = torch.clamp(logvar, -30, 20)
        sigma = torch.exp(0.5 * bounded_logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``; ``logvar`` is the
            encoder's raw (unclamped) output.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparametrize(mu, logvar)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar
|
|
|
|