| """ |
| MaiChartTokenizer โ VQ-VAE based learned tokenizer for maimai charts. |
| |
| Two modes: |
| 1. Rule-based: use tokenizer.MaiChartTokenizer (vocab-based, deterministic) |
| 2. Learned (VQ-VAE): this module โ encoder โ codebook โ decoder |
| |
| Architecture: |
| note_features (35-dim) โ Encoder MLP โ latent โ VectorQuantizer โ discrete codes |
| โ |
| Decoder MLP โ reconstructed note_features |
| |
| Usage: |
| from Tokenizer.MaiChartTokenizer import MaiChartTokenizer |
| |
| model = MaiChartTokenizer(input_dim=35, hidden_dim=256, codebook_size=512, latent_dim=64) |
| tokens, commitment_loss = model.encode(features) # โ discrete token sequence |
| reconstructed = model.decode(tokens) # โ back to features |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer (VQ-VAE codebook).

    Maps continuous latent vectors to the nearest codebook entry (squared
    Euclidean distance) and returns the discrete indices together with the
    VQ training loss. The quantized output uses a straight-through estimator,
    so gradients of downstream losses flow back to the encoder unchanged.

    Args:
        codebook_size: Number of codebook entries (K).
        latent_dim: Dimension of each codebook vector (D).
        beta: Weight of the commitment term in the VQ loss.
    """

    def __init__(self, codebook_size: int, latent_dim: int, beta: float = 0.25):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.beta = beta

        # Codebook of K vectors of size D, initialised uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(codebook_size, latent_dim)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize ``z`` against the codebook.

        Args:
            z: [..., latent_dim], e.g. [B, T, latent_dim] or [B*T, latent_dim].

        Returns:
            z_q: Quantized vectors, same shape as ``z``. Values come from the
                codebook; gradients pass straight through to ``z``.
            indices: Codebook indices with shape ``z.shape[:-1]``
                (e.g. [B, T] or [B*T]).
            loss: Scalar VQ loss = codebook_loss + beta * commitment_loss.
        """
        lead_shape = z.shape[:-1]
        z_flat = z.reshape(-1, z.shape[-1])

        # Pairwise squared distances ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e,
        # computed for every (vector, codebook entry) pair at once.
        codebook = self.embedding.weight
        z_sq = (z_flat ** 2).sum(dim=1, keepdim=True)
        e_sq = (codebook ** 2).sum(dim=1)
        dist = z_sq + e_sq.unsqueeze(0) - 2 * z_flat @ codebook.T

        indices = dist.argmin(dim=1)
        z_q = self.embedding(indices)

        # Codebook loss: pulls the selected codes toward the (frozen) encoder output.
        # Commitment loss: keeps the encoder output close to its (frozen) code.
        # beta weights only the commitment term.
        codebook_loss = F.mse_loss(z_q, z_flat.detach())
        commitment_loss = F.mse_loss(z_flat, z_q.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: forward value is z_q, backward is identity w.r.t. z.
        z_q = z_flat + (z_q - z_flat).detach()

        return z_q.reshape(z.shape), indices.reshape(lead_shape), vq_loss

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Convert discrete indices back to their codebook vectors ([..., latent_dim])."""
        return self.embedding(indices)
|
|
|
|
class MaiChartTokenizer(nn.Module):
    """
    VQ-VAE Tokenizer: encodes chart note features → discrete token sequence.

    The rule-based tokenizer (tokenizer.py) provides the baseline vocabulary.
    This learned tokenizer can be trained end-to-end to discover optimal tokens.

    Args:
        input_dim: Feature dimension per note (from mai_parser note_to_vector).
        hidden_dim: Hidden dimension of encoder/decoder MLP.
        codebook_size: Number of discrete codes (vocabulary size).
        latent_dim: Dimension of each codebook vector.
        num_layers: Number of Linear layers in each of encoder/decoder (>= 1).
        beta: Commitment-loss weight passed to the VectorQuantizer.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(
        self,
        input_dim: int = 35,
        hidden_dim: int = 256,
        codebook_size: int = 512,
        latent_dim: int = 64,
        num_layers: int = 2,
        beta: float = 0.25,
    ):
        super().__init__()
        if num_layers < 1:
            # With zero layers encoder/decoder would be identities and the
            # model would only fail later, at forward time, with a shape error.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.num_layers = num_layers

        # Encoder: input_dim → (hidden_dim …) → latent_dim
        self.encoder = self._build_mlp(input_dim, hidden_dim, latent_dim, num_layers)
        # Extra linear projection applied to the encoder output before quantization.
        self.pre_vq = nn.Linear(latent_dim, latent_dim)
        self.vq = VectorQuantizer(codebook_size, latent_dim, beta=beta)
        # Decoder mirrors the encoder: latent_dim → (hidden_dim …) → input_dim
        self.decoder = self._build_mlp(latent_dim, hidden_dim, input_dim, num_layers)

    @staticmethod
    def _build_mlp(in_dim: int, hidden_dim: int, out_dim: int, num_layers: int) -> nn.Sequential:
        """Build ``num_layers`` Linear layers with ReLU between them (none after the last)."""
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        layers: list[nn.Module] = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            if i > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        return nn.Sequential(*layers)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode chart features to a discrete token sequence.

        Args:
            x: [B, T, input_dim] — batch of chart feature sequences.

        Returns:
            tokens: [B, T] — discrete token IDs.
            vq_loss: scalar loss returned by the quantizer (for training).
            z_q: [B, T, latent_dim] — quantized latents (straight-through),
                as consumed by ``forward``.
        """
        z = self.pre_vq(self.encoder(x))
        z_q, indices, vq_loss = self.vq(z)
        return indices, vq_loss, z_q

    def decode(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Decode a token sequence back to chart features.

        Args:
            tokens: [B, T] — discrete token IDs.

        Returns:
            x_recon: [B, T, input_dim] — reconstructed note features.
        """
        z_q = self.vq.decode_indices(tokens)
        return self.decoder(z_q)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Full forward pass: encode → quantize → decode.

        Decodes from the straight-through ``z_q`` (not from the token IDs) so
        the reconstruction loss can backpropagate into the encoder.

        Args:
            x: [B, T, input_dim]

        Returns:
            x_recon: [B, T, input_dim] reconstructed features.
            tokens: [B, T] discrete token IDs.
            vq_loss: scalar quantizer loss.
        """
        tokens, vq_loss, z_q = self.encode(x)
        x_recon = self.decoder(z_q)
        return x_recon, tokens, vq_loss

    @property
    def vocab_size(self) -> int:
        """Number of discrete tokens (equal to the codebook size)."""
        return self.codebook_size