# LowLevelMusicFeatureContrastiveModel / timeseries_vqvae_transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class TransformerVQVAE(nn.Module):
    def __init__(self, input_shape=(1, 128, 216), num_codebook_vectors=512,
                 codebook_dim=64, num_layers=4, num_heads=8, hidden_dim=256):
        super(TransformerVQVAE, self).__init__()
        self.input_shape = input_shape
        self.num_codebook_vectors = num_codebook_vectors
        self.codebook_dim = codebook_dim

        # Encoder: three stride-2 convolutions downsample height and width by 8
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, codebook_dim, kernel_size=3, stride=2, padding=1)
        )

        # Transformer layers over the flattened spatial grid of latent vectors
        encoder_layer = nn.TransformerEncoderLayer(d_model=codebook_dim, nhead=num_heads,
                                                   dim_feedforward=hidden_dim)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Vector Quantizer
        self.vq = VectorQuantizer(num_codebook_vectors, codebook_dim)

        # Decoder: three stride-2 transposed convolutions upsample back by 8
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(codebook_dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )
    def encode(self, x):
        # Encode the input: (B, 1, H, W) -> (B, codebook_dim, H//8, W//8)
        z = self.encoder(x)
        # Reshape for the transformer: (H//8, W//8, B, C) -> (H//8 * W//8, B, C),
        # i.e. (sequence, batch, d_model) as expected when batch_first=False
        z = z.permute(2, 3, 0, 1).contiguous()
        z = z.view(-1, z.shape[2], z.shape[3])
        # Apply transformer
        z = self.transformer(z)
        # Reshape back to (B, codebook_dim, H//8, W//8)
        z = z.view(x.shape[2] // 8, x.shape[3] // 8, x.shape[0], self.codebook_dim)
        z = z.permute(2, 3, 0, 1).contiguous()
        return z
    def decode(self, z):
        # Decode the latent representation
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        z_q, indices, vq_loss = self.vq(z)
        x_recon = self.decode(z_q)
        return x_recon, indices, vq_loss
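
# Data flow through TransformerVQVAE (a sketch, assuming the default input_shape
# of (1, 128, 216), e.g. a single-channel time-frequency representation):
#   x:        (B, 1, 128, 216)
#   encoder   -> (B, codebook_dim, 16, 27)   three stride-2 convs, spatial /8
#   transformer over 16 * 27 = 432 tokens of dimension codebook_dim
#   vq        -> quantized latents of the same shape, plus codebook indices
#   decoder   -> (B, 1, 128, 216)            three transposed convs, spatial *8
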
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook of learnable embedding vectors, initialised uniformly
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)
    def forward(self, z):
        # Reshape z -> (batch, height, width, channel) and flatten to (B*H*W, embedding_dim)
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.embedding_dim)

        # Squared distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2*z*e
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        # Find the closest codebook entry for each latent vector
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings, device=z.device)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # Quantize (one-hot @ codebook) and unflatten back to (B, H, W, C)
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)

        # Codebook loss pulls embeddings towards encoder outputs;
        # commitment loss keeps encoder outputs close to their chosen embeddings
        e_latent_loss = F.mse_loss(z_q.detach(), z)
        q_latent_loss = F.mse_loss(z_q, z.detach())
        vq_loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: copy gradients from z_q back to z
        z_q = z + (z_q - z).detach()

        # Reshape back to the original (B, C, H, W) layout
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, min_encoding_indices, vq_loss
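

if __name__ == "__main__":
    # Minimal smoke test / usage sketch, not part of the original training code.
    # The batch size, dummy input, and optimizer settings below are assumptions
    # chosen only to illustrate how the pieces fit together.
    model = TransformerVQVAE().to(device)

    # Dummy batch matching the default input_shape (1, 128, 216)
    x = torch.randn(2, 1, 128, 216, device=device)

    x_recon, indices, vq_loss = model(x)
    print("reconstruction:", x_recon.shape)   # expected: (2, 1, 128, 216)
    print("code indices:  ", indices.shape)   # expected: (2 * 16 * 27, 1)
    print("vq loss:       ", vq_loss.item())

    # One hypothetical training step: reconstruction loss plus VQ loss
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss = F.mse_loss(x_recon, x) + vq_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("total loss:    ", loss.item())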