"""
MaiChartTokenizer — VQ-VAE based learned tokenizer for maimai charts.

Two modes:
    1. Rule-based: use tokenizer.MaiChartTokenizer (vocab-based, deterministic)
    2. Learned (VQ-VAE): this module — encoder → codebook → decoder

Architecture:
    note_features (35-dim) → Encoder MLP → latent → VectorQuantizer → discrete codes
                                                          ↓
                                  Decoder MLP → reconstructed note_features

Usage:
    from Tokenizer.MaiChartTokenizer import MaiChartTokenizer

    model = MaiChartTokenizer(input_dim=35, hidden_dim=256, codebook_size=512, latent_dim=64)
    tokens, vq_loss, z_q = model.encode(features)   # → discrete token sequence
    reconstructed = model.decode(tokens)            # → back to features
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def _build_mlp(in_dim: int, hidden_dim: int, out_dim: int, num_layers: int) -> nn.Sequential:
    """Build a ``num_layers``-deep MLP with ReLU between (not after) the linear layers."""
    layers: list[nn.Module] = []
    current = in_dim
    for i in range(num_layers):
        is_last = i == num_layers - 1
        width = out_dim if is_last else hidden_dim
        layers.append(nn.Linear(current, width))
        if not is_last:
            layers.append(nn.ReLU())
        current = width
    return nn.Sequential(*layers)


class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer (VQ-VAE codebook).

    Maps continuous latent vectors to the nearest codebook entry.
    Returns the quantized vectors, the discrete indices and the VQ loss
    (codebook loss + beta * commitment loss) for training.
    """

    def __init__(self, codebook_size: int, latent_dim: int, beta: float = 0.25):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.beta = beta

        # Codebook: [codebook_size, latent_dim]
        self.embedding = nn.Embedding(codebook_size, latent_dim)
        bound = 1.0 / codebook_size
        nn.init.uniform_(self.embedding.weight, -bound, bound)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize ``z`` against the codebook.

        Args:
            z: [..., latent_dim] — typically [B, T, latent_dim] or [B*T, latent_dim].

        Returns:
            z_q: Quantized vectors (same shape as z), with straight-through gradients.
            indices: Codebook indices, shape ``z.shape[:-1]``.
            loss: Scalar VQ loss = codebook loss + beta * commitment loss.
        """
        original_shape = z.shape
        z_flat = z.reshape(-1, original_shape[-1])

        # The nearest-neighbour search is non-differentiable (argmin), so no
        # autograd graph is needed for the distance matrix.
        with torch.no_grad():
            codebook = self.embedding.weight
            # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z @ e.T   → [N, codebook_size]
            z_sq = (z_flat ** 2).sum(dim=1, keepdim=True)
            e_sq = (codebook ** 2).sum(dim=1)
            dist = z_sq + e_sq.unsqueeze(0) - 2 * z_flat @ codebook.T
            indices = dist.argmin(dim=1)  # [N]

        quantized = self.embedding(indices)  # [N, latent_dim]

        # Codebook loss moves the selected codes toward the (frozen) encoder output.
        codebook_loss = F.mse_loss(quantized, z_flat.detach())
        # Commitment loss keeps the encoder output close to the (frozen) codes;
        # beta weights this term, not the codebook term.
        commitment_loss = F.mse_loss(z_flat, quantized.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: forward value is the code, gradient goes to z.
        z_q = z_flat + (quantized - z_flat).detach()

        z_q = z_q.reshape(original_shape)
        indices = indices.reshape(original_shape[:-1])
        return z_q, indices, vq_loss

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Convert discrete indices back to codebook vectors."""
        return self.embedding(indices)


class MaiChartTokenizer(nn.Module):
    """
    VQ-VAE Tokenizer: encodes chart note features → discrete token sequence.

    The rule-based tokenizer (tokenizer.py) provides the baseline vocabulary.
    This learned tokenizer can be trained end-to-end to discover optimal tokens.

    Args:
        input_dim: Feature dimension per note (from mai_parser note_to_vector).
        hidden_dim: Hidden dimension of encoder/decoder MLP.
        codebook_size: Number of discrete codes (vocabulary size).
        latent_dim: Dimension of each codebook vector.
        num_layers: Number of MLP layers in encoder/decoder.
    """

    def __init__(
        self,
        input_dim: int = 35,
        hidden_dim: int = 256,
        codebook_size: int = 512,
        latent_dim: int = 64,
        num_layers: int = 2,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim

        # ── Encoder: note features → latent ──
        self.encoder = _build_mlp(input_dim, hidden_dim, latent_dim, num_layers)

        # ── Pre-quantization projection ──
        self.pre_vq = nn.Linear(latent_dim, latent_dim)

        # ── Vector Quantizer ──
        self.vq = VectorQuantizer(codebook_size, latent_dim)

        # ── Decoder: latent → note features ──
        self.decoder = _build_mlp(latent_dim, hidden_dim, input_dim, num_layers)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode chart features to a discrete token sequence.

        Args:
            x: [B, T, input_dim] — batch of chart feature sequences.

        Returns:
            tokens: [B, T] — discrete token IDs.
            vq_loss: scalar VQ loss (for training).
            z_q: [B, T, latent_dim] — quantized latents with straight-through
                gradients, suitable as decoder input during training.
        """
        z = self.pre_vq(self.encoder(x))  # [B, T, latent_dim]
        z_q, indices, vq_loss = self.vq(z)
        return indices, vq_loss, z_q

    def decode(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Decode a token sequence back to chart features.

        Args:
            tokens: [B, T] — discrete token IDs.

        Returns:
            x_recon: [B, T, input_dim] — reconstructed note features.
        """
        return self.decoder(self.vq.decode_indices(tokens))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode → quantize → decode.

        Args:
            x: [B, T, input_dim]

        Returns:
            x_recon: [B, T, input_dim] reconstructed features.
            tokens: [B, T] discrete token IDs.
            vq_loss: scalar loss.
        """
        tokens, vq_loss, z_q = self.encode(x)
        # Decode from z_q (not from tokens) so gradients reach the encoder.
        x_recon = self.decoder(z_q)
        return x_recon, tokens, vq_loss

    @property
    def vocab_size(self) -> int:
        """Number of discrete codes in the codebook."""
        return self.codebook_size