maiChartGen / Tokenizer /MaiChartTokenizer.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
6.86 kB
"""
MaiChartTokenizer — VQ-VAE based learned tokenizer for maimai charts.
Two modes:
1. Rule-based: use tokenizer.MaiChartTokenizer (vocab-based, deterministic)
2. Learned (VQ-VAE): this module — encoder → codebook → decoder
Architecture:
note_features (35-dim) → Encoder MLP → latent → VectorQuantizer → discrete codes
                                                      ↓
                            Decoder MLP → reconstructed note_features
Usage:
from Tokenizer.MaiChartTokenizer import MaiChartTokenizer
model = MaiChartTokenizer(input_dim=35, hidden_dim=256, codebook_size=512, latent_dim=64)
tokens, vq_loss, z_q = model.encode(features)  # → discrete tokens, VQ loss, quantized latents
reconstructed = model.decode(tokens)  # → back to features
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
    """
    Vector-quantization layer (the VQ-VAE codebook).

    Maps every continuous latent vector to its nearest codebook entry
    (squared Euclidean distance) and returns the quantized vectors, the
    selected codebook indices and the VQ training loss.

    Args:
        codebook_size: Number of entries in the codebook.
        latent_dim: Dimension of each codebook vector.
        beta: Weight of the commitment term in the VQ loss.
    """

    def __init__(self, codebook_size: int, latent_dim: int, beta: float = 0.25):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.beta = beta
        # Codebook: [codebook_size, latent_dim]
        self.embedding = nn.Embedding(codebook_size, latent_dim)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize latent vectors against the codebook.

        Args:
            z: [..., latent_dim], e.g. [B, T, latent_dim] or [B*T, latent_dim].
               Any number of leading dimensions is accepted.

        Returns:
            z_q: Quantized vectors, same shape as ``z``. Built with the
                straight-through estimator, so gradients flowing into ``z_q``
                are passed unchanged to ``z``.
            indices: Codebook indices with the shape of ``z`` minus its last
                dimension ([B, T] or [B*T]).
            loss: Scalar VQ loss, ``codebook_loss + beta * commitment_loss``.
        """
        # Collapse all leading dimensions so each row is one latent vector.
        lead_shape = z.shape[:-1]
        z_flat = z.reshape(-1, z.shape[-1])
        codebook = self.embedding.weight

        # Squared distances to every codebook entry: [N, codebook_size]
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z @ e.T
        z_sq = (z_flat ** 2).sum(dim=1, keepdim=True)
        e_sq = (codebook ** 2).sum(dim=1)
        dist = z_sq + e_sq.unsqueeze(0) - 2 * z_flat @ codebook.T

        # Nearest codebook entry
        indices = dist.argmin(dim=1)  # [N]
        z_q = self.embedding(indices)  # [N, latent_dim]

        # Codebook term: the encoder output is detached, so only the codebook
        # receives gradient and is pulled towards the encoder outputs.
        codebook_loss = F.mse_loss(z_q, z_flat.detach())
        # Commitment term: the codebook is detached, so only the encoder
        # receives gradient. This is the term that beta scales.
        commitment_loss = F.mse_loss(z_flat, z_q.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: forward value is z_q, backward is identity to z.
        z_q = z_flat + (z_q - z_flat).detach()

        # Restore the caller's leading dimensions.
        return z_q.reshape(z.shape), indices.reshape(lead_shape), vq_loss

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up codebook vectors for discrete indices ([...] -> [..., latent_dim])."""
        return self.embedding(indices)
class MaiChartTokenizer(nn.Module):
    """
    VQ-VAE tokenizer: encodes chart note features into a discrete token sequence.

    The rule-based tokenizer (tokenizer.py) provides the baseline vocabulary.
    This learned tokenizer can be trained end-to-end to discover optimal tokens.

    Args:
        input_dim: Feature dimension per note (from mai_parser note_to_vector).
        hidden_dim: Hidden dimension of encoder/decoder MLP.
        codebook_size: Number of discrete codes (vocabulary size).
        latent_dim: Dimension of each codebook vector.
        num_layers: Number of Linear layers in each of the encoder and decoder.
    """

    def __init__(
        self,
        input_dim: int = 35,
        hidden_dim: int = 256,
        codebook_size: int = 512,
        latent_dim: int = 64,
        num_layers: int = 2,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim

        # Sub-modules are created in a fixed order (encoder, pre_vq, vq,
        # decoder) so that seeded parameter initialisation is reproducible.
        # Encoder: note features -> latent
        self.encoder = self._build_mlp(input_dim, hidden_dim, latent_dim, num_layers)
        # Pre-quantization projection
        self.pre_vq = nn.Linear(latent_dim, latent_dim)
        # Vector quantizer
        self.vq = VectorQuantizer(codebook_size, latent_dim)
        # Decoder: latent -> note features
        self.decoder = self._build_mlp(latent_dim, hidden_dim, input_dim, num_layers)

    @staticmethod
    def _build_mlp(in_dim: int, hidden_dim: int, out_dim: int, num_layers: int) -> nn.Sequential:
        """Build ``num_layers`` Linear layers with ReLU between them (none after the last)."""
        layers: list[nn.Module] = []
        width = in_dim
        for i in range(num_layers):
            is_last = i == num_layers - 1
            next_width = out_dim if is_last else hidden_dim
            layers.append(nn.Linear(width, next_width))
            if not is_last:
                layers.append(nn.ReLU())
            width = next_width
        return nn.Sequential(*layers)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode chart features to a discrete token sequence.

        Args:
            x: [B, T, input_dim] - batch of chart feature sequences.

        Returns:
            tokens: [B, T] - discrete token IDs.
            vq_loss: scalar loss returned by the vector quantizer (for training).
            z_q: [B, T, latent_dim] - quantized latents as returned by the
                vector quantizer; ``forward`` feeds these to the decoder.
        """
        z = self.encoder(x)  # [B, T, latent_dim]
        z = self.pre_vq(z)
        z_q, indices, vq_loss = self.vq(z)
        return indices, vq_loss, z_q

    def decode(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Decode a token sequence back to chart features.

        Args:
            tokens: [B, T] - discrete token IDs.

        Returns:
            x_recon: [B, T, input_dim] - reconstructed note features.
        """
        z_q = self.vq.decode_indices(tokens)
        return self.decoder(z_q)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode -> quantize -> decode.

        Args:
            x: [B, T, input_dim]

        Returns:
            x_recon: [B, T, input_dim] reconstructed features.
            tokens: [B, T] discrete token IDs.
            vq_loss: scalar loss.
        """
        tokens, vq_loss, z_q = self.encode(x)
        x_recon = self.decoder(z_q)
        return x_recon, tokens, vq_loss

    @property
    def vocab_size(self) -> int:
        """Number of discrete tokens, equal to ``codebook_size``."""
        return self.codebook_size