|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from functools import reduce |
|
|
|
|
|
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
|
|
class TransformerVQVAE(nn.Module):
    """VQ-VAE whose latent grid is contextualised by a Transformer encoder.

    Pipeline (see ``forward``):

    - A 3-layer strided conv encoder maps each spatial size n -> ceil(n / 2)
      per layer.
    - A ``nn.TransformerEncoder`` attends across all latent grid cells.
    - ``self.vq`` quantises each latent vector.
    - A 3-layer transposed-conv decoder doubles the spatial size per layer.

    Args:
        input_shape: ``(channels, height, width)`` of one sample. Only
            ``channels`` is used; it sizes the first encoder conv and the
            last decoder conv. The spatial size is taken from the actual
            input at runtime.
        num_codebook_vectors: number of entries in the VQ codebook.
        codebook_dim: channel width of the latent grid. It is also the
            transformer's ``d_model`` and the size of each codebook vector.
            It must be divisible by ``num_heads``.
        num_layers: number of stacked transformer encoder layers.
        num_heads: attention heads per transformer layer.
        hidden_dim: feed-forward width inside each transformer layer.
    """

    def __init__(self, input_shape=(1, 128, 216), num_codebook_vectors=512,
                 codebook_dim=64, num_layers=4, num_heads=8, hidden_dim=256):
        super().__init__()

        self.input_shape = input_shape
        self.num_codebook_vectors = num_codebook_vectors
        self.codebook_dim = codebook_dim
        in_channels = input_shape[0]

        # Each conv (k=3, s=2, p=1) maps a spatial size n to ceil(n / 2).
        # The latent grid is therefore ceil(H / 8) x ceil(W / 8).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, codebook_dim, kernel_size=3, stride=2, padding=1),
        )

        # batch_first is left at its default (False), so the transformer
        # consumes (seq_len, batch, codebook_dim).
        # NOTE(review): no positional encoding is added to the latent grid.
        # Self-attention itself is therefore permutation-equivariant over
        # grid cells.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=codebook_dim, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.vq = VectorQuantizer(num_codebook_vectors, codebook_dim)

        # Each transposed conv (k=3, s=2, p=1, output_padding=1) maps n to 2n.
        # The output is therefore 8 * latent size, which can exceed the input
        # size (``forward`` crops it).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(codebook_dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh(),
        )

    def encode(self, x):
        """Encode ``x`` into a transformer-contextualised latent grid.

        Args:
            x: image batch of shape ``(B, C_in, H, W)``.

        Returns:
            Contiguous tensor of shape
            ``(B, codebook_dim, ceil(H / 8), ceil(W / 8))``.
        """
        z = self.encoder(x)
        # Take the grid size from the encoder output itself. ``H // 8``
        # would be wrong (and crash the reshape) whenever H or W is not a
        # multiple of 8.
        batch, channels, height, width = z.shape

        # (B, C, H', W') -> (H'*W', B, C): one sequence element per grid
        # cell, in row-major (row, col) order.
        seq = z.flatten(2).permute(2, 0, 1).contiguous()
        seq = self.transformer(seq)

        # (H'*W', B, C) -> (B, C, H', W'), undoing the flatten above.
        return seq.permute(1, 2, 0).reshape(batch, channels, height, width).contiguous()

    def decode(self, z):
        """Decode a (quantised) latent grid back to image space.

        The output spatial size is 8x the latent grid size.
        """
        return self.decoder(z)

    def forward(self, x):
        """Reconstruct ``x`` through the quantised bottleneck.

        Returns:
            A tuple ``(x_recon, indices, vq_loss)``:

            - ``x_recon``: cropped to ``x``'s spatial size.
            - ``indices``: the codebook indices returned by ``self.vq``.
            - ``vq_loss``: the loss returned by ``self.vq``.
        """
        z = self.encode(x)
        z_q, indices, vq_loss = self.vq(z)
        x_recon = self.decode(z_q)
        # The decoder emits 8*ceil(H/8) x 8*ceil(W/8). Trim the extra
        # bottom rows / right columns so the reconstruction lines up with
        # ``x``. This is a no-op when H and W are multiples of 8.
        x_recon = x_recon[..., :x.shape[-2], :x.shape[-1]]
        return x_recon, indices, vq_loss
|
|
|
|
|
class VectorQuantizer(nn.Module): |
|
|
def __init__(self, num_embeddings, embedding_dim, commitment_cost = 0.25): |
|
|
super(VectorQuantizer, self).__init__() |
|
|
self.num_embeddings = num_embeddings |
|
|
self.embedding_dim = embedding_dim |
|
|
self.commitment_cost = commitment_cost |
|
|
self.embedding = nn.Embedding(num_embeddings, embedding_dim) |
|
|
self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings) |
|
|
|
|
|
def forward(self, z): |
|
|
|
|
|
z = z.permute(0, 2, 3, 1).contiguous() |
|
|
z_flattened = z.view(-1, self.embedding_dim) |
|
|
|
|
|
|
|
|
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ |
|
|
torch.sum(self.embedding.weight**2, dim=1) - \ |
|
|
2 * torch.matmul(z_flattened, self.embedding.weight.t()) |
|
|
|
|
|
|
|
|
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) |
|
|
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(z.device) |
|
|
min_encodings.scatter_(1, min_encoding_indices, 1) |
|
|
|
|
|
|
|
|
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) |
|
|
|
|
|
|
|
|
e_latent_loss = F.mse_loss(z_q.detach(), z) |
|
|
q_latent_loss = F.mse_loss(z_q, z.detach()) |
|
|
vq_loss = q_latent_loss + self.commitment_cost * e_latent_loss |
|
|
|
|
|
|
|
|
z_q = z + (z_q - z).detach() |
|
|
|
|
|
|
|
|
z_q = z_q.permute(0, 3, 1, 2).contiguous() |
|
|
|
|
|
return z_q, min_encoding_indices, vq_loss |