maiChartGen / model.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
29.5 kB
"""
MaiGenerator โ€” Audio-conditioned autoregressive maimai chart generator.
Architecture: Encoder-Decoder Transformer with time-aligned RoPE.
Audio tokens โ†’ AudioEncoder โ†’ audio_feat [T_aud, d]
โ”‚ Cross-Attention
Chart tokens โ†’ ChartDecoder โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
(autoregressive, + BPM/Diff/Genre conditioning
causal mask) + time-aligned RoPE positions
Key design: Chart RoPE uses audio frame indices (via BPM translation),
ensuring strong time alignment between music and generated notes.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import tokenizer as chart_tokenizer
from tokenizer import CONFIG_BASE
# ═════════════════════════════════════════════════════════════════════════════
# Constants (aligned with tokenizers)
# ═════════════════════════════════════════════════════════════════════════════
# Vocabulary sizes. The chart vocabulary comes from the project tokenizer
# module; the audio vocabulary size is fixed here.
CHART_VOCAB_SIZE = chart_tokenizer.VOCAB_SIZE
AUDIO_VOCAB_SIZE = 2051 # MaiTrackTokenizer (2-layer EnCodec)
AUDIO_FRAME_RATE = 75 # EnCodec 24kHz / 320 stride
# Special chart token ids.
BOS, EOS, PAD = 1, 2, 0
# Difficulty names; the list index is the difficulty id fed to the model.
DIFF_NAMES = ["BASIC", "ADVANCED", "EXPERT", "MASTER", "ReMASTER"]
NUM_DIFFICULTIES = len(DIFF_NAMES)
# Beat division token → value
DIV_MAP = {5: 1, 6: 2, 7: 4, 8: 8, 9: 16, 10: 32,
11: 48, 12: 64, 13: 128, 14: 192, 15: 384}
DUR_TOKEN = 17 # [DUR] in chart vocab
RST_TOKEN = 16
# Note-token id ranges. Each (BASE, END) pair is used as the half-open
# interval [BASE, END) by is_timeline_token / token_structure_features.
TAP_BASE, TAP_END = 18, 26
BRK_BASE, BRK_END = 26, 34
HLD_BASE, HLD_END = 34, 42
SLD_BASE, SLD_END = 42, 50
# Group delimiters for slide and simultaneous-note groups.
SLD_BEG_TOKEN = 50
SLD_END_TOKEN = 51
SIM_BEG_TOKEN = 52
SIM_END_TOKEN = 53
TCH_BASE, TCH_END = 54, 95
# Token-type ids produced by token_structure_features and embedded by
# MaiGenerator.chart_type_embed.
TYPE_REST = 0
TYPE_TAP = 1
TYPE_HOLD = 2
TYPE_SLIDE = 3
TYPE_BREAK = 4
TYPE_TOUCH = 5
TYPE_CONTROL = 6
NUM_TOKEN_TYPES = 7
NUM_POSITIONS = 9 # 0-7 real positions, 8 = none/control
NUM_DIV_CLASSES = len(DIV_MAP)
def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
          dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor:
    """Thin wrapper over ``F.scaled_dot_product_attention``.

    Going through PyTorch's SDPA entry point lets the backend choose a fused
    (Flash / memory-efficient) kernel when one is available.

    Args:
        q, k, v: Query / key / value tensors, passed through unchanged.
        dropout_p: Attention dropout probability. Forced to 0 whenever
            autograd is disabled (e.g. under ``torch.no_grad()``).
        is_causal: Forwarded to SDPA to request causal masking.

    Returns:
        The attention output produced by SDPA.
    """
    if not torch.is_grad_enabled():
        # No gradient tracking means evaluation/inference: never drop attention.
        dropout_p = 0.0
    return F.scaled_dot_product_attention(
        q, k, v, dropout_p=dropout_p, is_causal=is_causal)
def is_timeline_token(tok: torch.Tensor) -> torch.Tensor:
    """Flag the tokens that occupy one chart time slot after decoding.

    A token counts when it is a rest, opens a SIM or SLD group, falls in one
    of the TAP / BRK / HLD / SLD / TCH id ranges, or has an id at or above
    ``CONFIG_BASE`` (imported from the tokenizer module).

    Args:
        tok: Integer tensor of chart token ids (any shape).

    Returns:
        Boolean tensor with the same shape as ``tok``.
    """
    def _within(lo: int, hi: int) -> torch.Tensor:
        # Half-open interval test: lo <= tok < hi.
        return (tok >= lo) & (tok < hi)

    note_like = (
        _within(TAP_BASE, TAP_END)
        | _within(BRK_BASE, BRK_END)
        | _within(HLD_BASE, HLD_END)
        | _within(SLD_BASE, SLD_END)
        | _within(TCH_BASE, TCH_END)
        | (tok >= CONFIG_BASE)
    )
    slot_marker = (
        (tok == RST_TOKEN)
        | (tok == SIM_BEG_TOKEN)
        | (tok == SLD_BEG_TOKEN)
    )
    return slot_marker | note_like
def token_structure_features(tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Derive per-token type ids and 0-based position ids for the embeddings.

    Every token starts as ``TYPE_CONTROL`` with the "none" position
    (``NUM_POSITIONS - 1``). Rest tokens get ``TYPE_REST``; tokens inside a
    note id range get that range's type and position ``(id - base) % 8``.

    Args:
        tokens: Integer tensor of chart token ids (any shape).

    Returns:
        ``(type_ids, position_ids)``, both shaped and typed like ``tokens``.
    """
    type_ids = torch.full_like(tokens, TYPE_CONTROL)
    position_ids = torch.full_like(tokens, NUM_POSITIONS - 1)

    type_ids = type_ids.masked_fill(tokens == RST_TOKEN, TYPE_REST)

    # (base, end, type) for each note family; the id ranges do not overlap.
    note_families = (
        (TAP_BASE, TAP_END, TYPE_TAP),
        (HLD_BASE, HLD_END, TYPE_HOLD),
        (SLD_BASE, SLD_END, TYPE_SLIDE),
        (BRK_BASE, BRK_END, TYPE_BREAK),
        (TCH_BASE, TCH_END, TYPE_TOUCH),
    )
    for base, end, family_type in note_families:
        in_family = (tokens >= base) & (tokens < end)
        type_ids = type_ids.masked_fill(in_family, family_type)
        lane = (tokens - base) % 8
        position_ids = torch.where(in_family, lane, position_ids)
    return type_ids, position_ids
# ═════════════════════════════════════════════════════════════════════════════
# RoPE with custom positions
# ═════════════════════════════════════════════════════════════════════════════
class RoPE(nn.Module):
    """Rotary position embedding that accepts arbitrary (float) positions."""

    def __init__(self, dim: int, base: float = 10000.0):
        """
        Args:
            dim: Size of the last (per-head) dimension being rotated.
            base: Base of the geometric frequency progression.
        """
        super().__init__()
        self.dim = dim
        exponents = torch.arange(0, dim, 2).float() / dim
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(self, x: torch.Tensor,
                positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Rotate consecutive (even, odd) feature pairs of ``x``.

        Args:
            x: [B, H, T, D] — query or key.
            positions: [T] or [B, T] — custom position indices. Defaults to
                ``0 .. T-1`` when omitted.

        Returns:
            Rotated tensor [B, H, T, D].
        """
        _, _, seq_len, _ = x.shape
        if positions is None:
            positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)

        # [B, 1, T, 1] * [D/2] -> [B, 1, T, D/2]
        theta = positions[:, None, ..., None] * self.inv_freq.to(x.device)
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)

        even = x[..., 0::2]
        odd = x[..., 1::2]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = even * cos_t - odd * sin_t
        rotated[..., 1::2] = even * sin_t + odd * cos_t
        return rotated
# ═════════════════════════════════════════════════════════════════════════════
# Onset Feature Injection (FiLM)
# ═════════════════════════════════════════════════════════════════════════════
class OnsetFiLM(nn.Module):
    """FiLM-style modulation of encoder features by a scalar onset curve."""

    def __init__(self, d_model: int = 512):
        super().__init__()
        # Each maps the scalar onset value to one coefficient per channel.
        self.gamma = nn.Linear(1, d_model)
        self.beta = nn.Linear(1, d_model)

    def forward(self, enc_out, onset):
        """Apply ``enc_out * (1 + 0.5 * tanh(gamma(onset))) + 0.1 * beta(onset)``.

        Args:
            enc_out: [B, T_enc, D] encoder features.
            onset: [B, T_enc, 1] onset strength per encoder frame.

        Returns:
            Modulated features, [B, T_enc, D].
        """
        # tanh bounds the multiplicative gain to the range (0.5, 1.5).
        scale = 1.0 + torch.tanh(self.gamma(onset)) * 0.5
        shift = self.beta(onset) * 0.1
        return enc_out * scale + shift
# ═════════════════════════════════════════════════════════════════════════════
# MoE FFN (Mixture of Experts for difficulty routing)
# ═════════════════════════════════════════════════════════════════════════════
class MoEFFN(nn.Module):
    """Dense mixture-of-experts FFN gated by the difficulty embedding.

    Every expert processes every token; their outputs are blended with
    per-sample softmax weights computed from ``diff_emb``.
    """

    def __init__(self, d_model=512, d_ff=2048, n_experts=6, dropout=0.1):
        super().__init__()
        self.n_experts = n_experts

        def _make_expert() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(d_ff, d_model), nn.Dropout(dropout),
            )

        self.experts = nn.ModuleList([_make_expert() for _ in range(n_experts)])
        self.router = nn.Linear(d_model, n_experts)

    def forward(self, x, diff_emb):
        """Blend the expert outputs.

        Args:
            x: [B, T, d] token features.
            diff_emb: [B, d] difficulty embedding used for routing.

        Returns:
            [B, T, d] weighted sum of all expert outputs.
        """
        gate = F.softmax(self.router(diff_emb), dim=-1)  # [B, N]
        mixed = 0
        for idx in range(self.n_experts):
            # gate[:, idx:idx+1, None] is [B, 1, 1] and broadcasts over T and d.
            mixed = mixed + gate[:, idx:idx + 1, None] * self.experts[idx](x)
        return mixed
# ═════════════════════════════════════════════════════════════════════════════
# Encoder Block
# ═════════════════════════════════════════════════════════════════════════════
class EncoderBlock(nn.Module):
    """Pre-LN encoder layer: self-attention followed by a feed-forward net.

    Attention runs through PyTorch SDPA, which may dispatch to fused
    (Flash / memory-efficient) kernels depending on backend and arguments.
    """

    def __init__(self, d_model: int = 512, heads: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.heads = heads
        self.head_dim = d_model // heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                positions: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one encoder layer.

        Args:
            x: [B, T, D] input features.
            positions: Optional RoPE positions ([T] or [B, T]). RoPE is
                skipped entirely when this is None.
            mask: Optional attention mask handed to SDPA as ``attn_mask``.
                Non-boolean masks are cast to bool first.

        Returns:
            [B, T, D] output features.
        """
        batch, seq_len, width = x.shape

        def heads_first(t: torch.Tensor) -> torch.Tensor:
            # [B, T, D] -> [B, H, T, head_dim]
            return t.view(batch, seq_len, self.heads, self.head_dim).transpose(1, 2)

        normed = self.norm1(x)
        Q = heads_first(self.q_proj(normed))
        K = heads_first(self.k_proj(normed))
        V = heads_first(self.v_proj(normed))

        # RoPE time-aligned positions.
        if positions is not None:
            Q = self.rope(Q, positions)
            K = self.rope(K, positions)

        attn_dropout = self.dropout.p if self.training else 0.0
        if mask is None:
            context = _sdpa(Q, K, V, attn_dropout)
        else:
            bool_mask = mask if mask.dtype == torch.bool else mask.to(torch.bool)
            context = F.scaled_dot_product_attention(
                Q, K, V,
                attn_mask=bool_mask,
                dropout_p=attn_dropout,
            )

        # [B, H, T, head_dim] -> [B, T, D]
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, width)
        x = x + self.dropout(self.out_proj(context))
        return x + self.ffn(self.norm2(x))
# ═════════════════════════════════════════════════════════════════════════════
# Decoder Block
# ═════════════════════════════════════════════════════════════════════════════
class DecoderBlock(nn.Module):
    """Pre-LN decoder layer: causal self-attention, cross-attention, FFN.

    Supports a KV cache for incremental autoregressive inference. Self-attention
    is causal in every mode: full-sequence calls, single-token cached steps and
    multi-token cached chunks (prefill).
    """

    def __init__(self, d_model: int = 512, heads: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1,
                 use_moe: bool = False, n_experts: int = 6):
        super().__init__()
        self.heads = heads
        self.head_dim = d_model // heads
        # Self-attn
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)
        # Cross-attn
        self.cross_q = nn.Linear(d_model, d_model)
        self.cross_k = nn.Linear(d_model, d_model)
        self.cross_v = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)
        # FFN (standard or MoE)
        if use_moe:
            self.ffn = MoEFFN(d_model, d_ff, n_experts, dropout)
            self.is_moe = True
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(d_ff, d_model), nn.Dropout(dropout),
            )
            self.is_moe = False
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape [B, T, D] into [B, H, T, head_dim]."""
        return t.view(t.shape[0], t.shape[1], self.heads, self.head_dim).transpose(1, 2)

    def _init_self_kv_cache(self, batch_size: int, max_len: int, device,
                            dtype: Optional[torch.dtype] = None):
        """Pre-allocate the self-attention KV cache and reset its fill length.

        Args:
            batch_size: Batch size of the upcoming generation.
            max_len: Maximum number of decoder positions the cache can hold.
            device: Device on which to allocate the cache.
            dtype: Cache dtype. None uses the default dtype; ``forward``
                converts the cache if the keys turn out to use another dtype.
        """
        shape = (batch_size, self.heads, max_len, self.head_dim)
        self._self_k_cache = torch.zeros(shape, device=device, dtype=dtype)
        self._self_v_cache = torch.zeros(shape, device=device, dtype=dtype)
        self._self_cache_len = 0

    def _init_cross_kv_cache(self, enc_out: torch.Tensor):
        """Precompute cross-attention K, V from ``enc_out`` (fixed during generation).

        ``enc_out`` is projected exactly as given: no normalisation and no
        OnsetFiLM modulation is applied here.
        NOTE(review): the non-cached path in ``forward`` can apply ``onset_film``
        before projecting; callers relying on that should pass an already
        modulated ``enc_out`` here — confirm against the inference code.
        """
        self._cross_k_cache = self._split_heads(self.cross_k(enc_out))
        self._cross_v_cache = self._split_heads(self.cross_v(enc_out))

    def forward(self, x, enc_out, self_positions=None, diff_emb=None,
                onset_film=None, onset_kv=None,
                use_cache: bool = False):
        """Run one decoder layer.

        Args:
            x: [B, T_dec, D] decoder hidden states. In cache mode only the
                new positions are passed.
            enc_out: [B, T_enc, D] encoder output; unused for cross-attention
                K/V when the cross cache is active.
            self_positions: Optional RoPE positions for the rows of ``x``
                ([T_dec] or [B, T_dec]). RoPE is skipped when None.
            diff_emb: [B, D] difficulty embedding; required by MoE blocks.
            onset_film: Optional module called as ``onset_film(enc_out, onset_kv)``
                before the cross-attention K/V projections (non-cached path only).
            onset_kv: Onset curve handed to ``onset_film``.
            use_cache: Use the caches created by ``_init_self_kv_cache`` /
                ``_init_cross_kv_cache`` when they exist.

        Returns:
            [B, T_dec, D] updated hidden states.

        Raises:
            RuntimeError: If the self-attention cache would overflow.
            ValueError: If this is an MoE block and ``diff_emb`` is None.
        """
        B, T_dec, D = x.shape
        drop_p = self.dropout.p if self.training else 0.0

        # ── Causal self-attention ──
        residual = x
        x_norm = self.norm1(x)
        Q = self._split_heads(self.q_proj(x_norm))
        K = self._split_heads(self.k_proj(x_norm))
        V = self._split_heads(self.v_proj(x_norm))

        # RoPE time-aligned positions (positions of the rows present in x).
        if self_positions is not None:
            Q = self.rope(Q, self_positions)
            K = self.rope(K, self_positions)

        cache_active = use_cache and hasattr(self, '_self_k_cache')
        past_len = 0
        if cache_active:
            past_len = self._self_cache_len
            total_len = past_len + T_dec
            capacity = self._self_k_cache.shape[2]
            if total_len > capacity:
                raise RuntimeError(
                    f"self-attention KV cache overflow: need {total_len} "
                    f"positions but the cache holds {capacity}")
            if self._self_k_cache.dtype != K.dtype:
                # Keep the cache in the dtype of the keys so Q, K and V agree
                # (e.g. for a half-precision model).
                self._self_k_cache = self._self_k_cache.to(K.dtype)
                self._self_v_cache = self._self_v_cache.to(V.dtype)
            # Append the new K, V and attend over everything cached so far.
            self._self_k_cache[:, :, past_len:total_len] = K
            self._self_v_cache[:, :, past_len:total_len] = V
            K = self._self_k_cache[:, :, :total_len]
            V = self._self_v_cache[:, :, :total_len]
            self._self_cache_len = total_len

        if not cache_active or past_len == 0:
            # Queries and keys cover the same positions: plain causal mask.
            attn = _sdpa(Q, K, V, drop_p, is_causal=True)
        elif T_dec == 1:
            # A single new token may attend to the whole cached prefix.
            attn = _sdpa(Q, K, V, drop_p)
        else:
            # Multi-token chunk appended to a non-empty cache. SDPA's
            # is_causal aligns the mask to the top-left corner, which is wrong
            # when keys outnumber queries, so build the offset mask explicitly:
            # query at absolute position p may see keys 0..p.
            q_abs = torch.arange(past_len, past_len + T_dec, device=Q.device).unsqueeze(1)
            k_abs = torch.arange(past_len + T_dec, device=Q.device).unsqueeze(0)
            attn = F.scaled_dot_product_attention(
                Q, K, V,
                attn_mask=(k_abs <= q_abs),
                dropout_p=drop_p if torch.is_grad_enabled() else 0.0,
            )
        x = residual + self.dropout(self.out_proj(
            attn.transpose(1, 2).contiguous().view(B, T_dec, D)))

        # ── Cross-attention (with optional OnsetFiLM modulation) ──
        residual = x
        x_norm = self.norm2(x)
        Qc = self._split_heads(self.cross_q(x_norm))
        if use_cache and hasattr(self, '_cross_k_cache'):
            Kc = self._cross_k_cache
            Vc = self._cross_v_cache
        else:
            _enc = enc_out
            if onset_film is not None and onset_kv is not None:
                _enc = onset_film(_enc, onset_kv)
            Kc = self._split_heads(self.cross_k(_enc))
            Vc = self._split_heads(self.cross_v(_enc))
        attn_c = _sdpa(Qc, Kc, Vc, drop_p)
        x = residual + self.dropout(self.cross_out(
            attn_c.transpose(1, 2).contiguous().view(B, T_dec, D)))

        # ── FFN ──
        residual = x
        if self.is_moe:
            if diff_emb is None:
                raise ValueError("diff_emb is required when the block uses an MoE FFN")
            x = residual + self.ffn(self.norm3(x), diff_emb)
        else:
            x = residual + self.ffn(self.norm3(x))
        return x
# ═════════════════════════════════════════════════════════════════════════════
# MaiGenerator
# ═════════════════════════════════════════════════════════════════════════════
class MaiGenerator(nn.Module):
    """Audio-conditioned autoregressive maimai chart generator.

    Input:
        audio_tokens: [B, T_aud]   — EnCodec tokens
        chart_tokens: [B, T_chart] — chart token sequence (train: full[:-1])
        bpm:          [B, 1]       — BPM value
        difficulty:   [B, 1]       — difficulty enum (0..4)
        level_value:  [B, 1]       — numeric level (e.g. 12.4)
        genre:        [B, 1]       — genre index (optional)
    Output:
        logits: [B, T_chart, chart_vocab] — next-token prediction
    """

    def __init__(self, d_model: int = 512, enc_layers: int = 6,
                 dec_layers: int = 8, heads: int = 8, d_ff: int = 2048,
                 dropout: float = 0.1, chart_vocab: int | None = None,
                 audio_vocab: int = AUDIO_VOCAB_SIZE,
                 num_genres: int = 16, max_audio_len: int = 32768,
                 audio_downsample: int = 8, use_moe: bool = True,
                 n_experts: int = 6, moe_layers: Optional[list] = None):
        """Build the encoder, decoder, embeddings and output heads.

        Args:
            d_model: Model width.
            enc_layers: Number of audio encoder blocks.
            dec_layers: Number of chart decoder blocks.
            heads: Attention heads per block.
            d_ff: Hidden width of the feed-forward networks.
            dropout: Dropout probability.
            chart_vocab: Chart vocabulary size; None reads it from the
                tokenizer module.
            audio_vocab: Audio token vocabulary size.
            num_genres: Size of the genre embedding table.
            max_audio_len: Size of the learned audio position table, i.e. the
                maximum number of (pre-downsampling) audio tokens.
            audio_downsample: Stride of the Conv1d applied to embedded audio;
                1 disables downsampling.
            use_moe: NOTE(review): accepted but not consulted — which layers
                use MoE is decided by ``moe_layers`` alone. Confirm intent.
            n_experts: Experts per MoE feed-forward network.
            moe_layers: Indices of decoder layers that use MoE; the others use
                a shared FFN. None = all decoder layers use MoE.
        """
        super().__init__()
        if chart_vocab is None:
            chart_vocab = chart_tokenizer.VOCAB_SIZE
        self.d_model = d_model
        self.chart_vocab_size = chart_vocab
        self.audio_downsample = audio_downsample
        if moe_layers is None:
            moe_layers = list(range(dec_layers))  # all MoE
        # Embeddings
        self.audio_embed = nn.Embedding(audio_vocab, d_model)
        self.chart_embed = nn.Embedding(chart_vocab, d_model)
        self.chart_type_embed = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.chart_pos_embed = nn.Embedding(NUM_POSITIONS, d_model)
        # Onset feature injection
        self.onset_film = OnsetFiLM(d_model)
        # Audio downsampling (strided Conv1d + GELU to reduce sequence length)
        if audio_downsample > 1:
            self.audio_down = nn.Sequential(
                nn.Conv1d(d_model, d_model, kernel_size=audio_downsample,
                          stride=audio_downsample, padding=0),
                nn.GELU(),
            )
        else:
            self.audio_down = nn.Identity()
        self.audio_pos_embed = nn.Embedding(max_audio_len, d_model)
        # Conditions
        self.bpm_proj = nn.Sequential(nn.Linear(1, d_model), nn.SiLU(),
                                      nn.Linear(d_model, d_model))
        self.diff_embed = nn.Embedding(NUM_DIFFICULTIES, d_model)
        self.level_proj = nn.Sequential(nn.Linear(1, d_model), nn.SiLU(),
                                        nn.Linear(d_model, d_model))
        self.genre_embed = nn.Embedding(num_genres, d_model)
        # Encoder / Decoder (hybrid: shared FFN + MoE layers)
        self.audio_encoder = nn.ModuleList([
            EncoderBlock(d_model, heads, d_ff, dropout) for _ in range(enc_layers)])
        self.moe_layers = set(moe_layers)
        # Count only indices that name an existing decoder layer, so the
        # summary below matches what is actually built.
        n_moe = sum(1 for i in range(dec_layers) if i in self.moe_layers)
        n_shared = dec_layers - n_moe
        print(f"Decoder: {n_shared} shared + {n_moe} MoE ×{n_experts} experts")
        self.chart_decoder = nn.ModuleList([
            DecoderBlock(d_model, heads, d_ff, dropout,
                         use_moe=(i in self.moe_layers), n_experts=n_experts)
            for i in range(dec_layers)])
        # Main next-token head plus auxiliary prediction heads.
        self.output_head = nn.Linear(d_model, chart_vocab)
        self.presence_head = nn.Linear(d_model, 2)
        self.type_head = nn.Linear(d_model, NUM_TOKEN_TYPES)
        self.position_head = nn.Linear(d_model, 8)
        self.division_head = nn.Linear(d_model, NUM_DIV_CLASSES)
        self.sim_head = nn.Linear(d_model, 2)
        self.duration_head = nn.Linear(d_model, 2)
        self.enc_norm = nn.LayerNorm(d_model)
        self.dec_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.5) for Linear layers, N(0, 0.02) for embeddings."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def init_kv_cache(self, batch_size: int, max_len: int, enc_out: torch.Tensor):
        """Initialize KV caches for all decoder blocks. Call before incremental generation.

        Args:
            batch_size: Batch size (usually 1 for inference).
            max_len: Maximum generation length for pre-allocation.
            enc_out: Encoder output [B, T_enc, D] for cross-attention cache.
        """
        for blk in self.chart_decoder:
            blk._init_self_kv_cache(batch_size, max_len, enc_out.device)
            blk._init_cross_kv_cache(enc_out)

    # ── Time-aligned positions (core) ──────────────────────────────────────
    @staticmethod
    def compute_chart_positions(chart_tokens: torch.Tensor,
                                bpm: torch.Tensor,
                                downsample: int = 4) -> torch.Tensor:
        """Compute downsampled audio-frame positions for chart tokens.

        Walks the sequence left to right, tracking the current beat division
        and the elapsed beats, and records for every token the audio frame
        (divided by ``downsample``) at which its time slot starts.

        Args:
            chart_tokens: [B, T]
            bpm: [B, 1] (clamped to at least 30).
            downsample: Audio downsampling factor.

        Returns:
            positions: [B, T] — downsampled frame indices (float).
        """
        B, T = chart_tokens.shape
        device = chart_tokens.device
        bpm_v = bpm.reshape(B).float().clamp(min=30.0)

        # Stateless per-token facts are computed once for the whole sequence
        # rather than rebuilt at every step of the loop below.
        div_ids = torch.tensor(list(DIV_MAP.keys()), device=device)
        div_vals = torch.tensor([float(v) for v in DIV_MAP.values()], device=device)
        div_hit = chart_tokens.unsqueeze(-1) == div_ids                 # [B, T, N]
        is_div = div_hit.any(dim=-1)                                    # [B, T]
        div_at = (div_hit.to(div_vals.dtype) * div_vals).sum(dim=-1)    # [B, T]
        timeline = is_timeline_token(chart_tokens)                      # [B, T]

        div_values = torch.full((B,), 4.0, device=device)
        positions = torch.zeros(B, T, device=device)
        current_beat = torch.zeros(B, device=device)
        dur_param_skip = torch.zeros(B, dtype=torch.long, device=device)
        sim_skip = torch.zeros(B, dtype=torch.long, device=device)
        slide_active = torch.zeros(B, dtype=torch.bool, device=device)

        for i in range(T):
            tok = chart_tokens[:, i]

            # Update beat division when this token is a division token.
            div_values = torch.where(is_div[:, i], div_at[:, i], div_values)
            beat_per_token = 4.0 / div_values

            # Record position: beat → seconds → audio frame → downsampled
            time_sec = current_beat * 60.0 / bpm_v
            positions[:, i] = time_sec * AUDIO_FRAME_RATE / downsample

            # Advance beat only for decoded timeline slots. SIM/SLD groups
            # occupy one slot at their begin token; their contents are structural.
            is_dur_param = dur_param_skip > 0
            is_dur = (tok == DUR_TOKEN)
            in_sim_body = sim_skip > 0
            in_slide_body = slide_active & (tok != SLD_BEG_TOKEN)
            group_body = in_sim_body | in_slide_body
            advances_time = timeline[:, i] & ~is_dur_param & ~is_dur & ~group_body
            current_beat = torch.where(advances_time,
                                       current_beat + beat_per_token,
                                       current_beat)

            # SIM group bookkeeping: after SIM_BEG the skip counter is seeded
            # from the following token's id (+2); SIM_END always clears it.
            is_sim_beg = tok == SIM_BEG_TOKEN
            is_sim_end = tok == SIM_END_TOKEN
            count_after_sim_beg = is_sim_beg & (i + 1 < T)
            next_tok = chart_tokens[:, i + 1] if i + 1 < T else torch.zeros_like(tok)
            sim_skip = torch.where(count_after_sim_beg,
                                   torch.clamp(next_tok + 2, min=0),
                                   torch.clamp(sim_skip - 1, min=0))
            sim_skip = torch.where(is_sim_end, torch.zeros_like(sim_skip), sim_skip)

            # Slide group is active from SLD_BEG until SLD_END.
            slide_active = torch.where(tok == SLD_BEG_TOKEN,
                                       torch.ones_like(slide_active), slide_active)
            slide_active = torch.where(tok == SLD_END_TOKEN,
                                       torch.zeros_like(slide_active), slide_active)

            # The two tokens following [DUR] are parameters, not time slots.
            dur_param_skip = torch.where(is_dur,
                                         torch.full_like(dur_param_skip, 2),
                                         torch.clamp(dur_param_skip - 1, min=0))
        return positions

    def _aux_logits(self, hidden: torch.Tensor) -> dict[str, torch.Tensor]:
        """Apply every auxiliary head to the decoder hidden states."""
        return {
            "presence": self.presence_head(hidden),
            "type": self.type_head(hidden),
            "position": self.position_head(hidden),
            "division": self.division_head(hidden),
            "sim": self.sim_head(hidden),
            "duration": self.duration_head(hidden),
        }

    # ── Forward ────────────────────────────────────────────────────────────
    def forward(self, audio_tokens: torch.Tensor,
                chart_tokens: torch.Tensor, bpm: torch.Tensor,
                difficulty: torch.Tensor, level_value: torch.Tensor,
                genre: Optional[torch.Tensor] = None,
                onset_curve: Optional[torch.Tensor] = None,
                return_aux: bool = False,
                return_hidden: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
        """Training forward pass (teacher forcing).

        Args:
            audio_tokens: [B, T_aud] audio token ids.
            chart_tokens: [B, T_chart] chart token ids.
            bpm: [B, 1] BPM values.
            difficulty: [B, 1] (or [B]) difficulty ids.
            level_value: [B, 1] numeric level.
            genre: Optional [B, 1] (or [B]) genre ids; no genre term is added
                when None.
            onset_curve: Optional onset strength, [B, T] or [B, T, 1]. It is
                linearly resampled to the encoder length when needed. When
                None, a curve is derived from frame-to-frame changes of the
                downsampled audio features.
            return_aux: Also return the auxiliary head outputs.
            return_hidden: Return the normalised decoder hidden states instead
                of token logits.

        Returns:
            ``[B, T_chart, chart_vocab]`` token logits by default; a dict with
            ``"token"`` plus auxiliary outputs when ``return_aux`` is set; or a
            dict with ``"hidden"`` (plus auxiliary outputs if requested) when
            ``return_hidden`` is set.

        Raises:
            ValueError: If ``audio_tokens`` is longer than the audio position table.
        """
        B, _ = chart_tokens.shape
        device = chart_tokens.device

        # ── Encode audio ──
        T_aud = audio_tokens.shape[1]
        max_audio_len = self.audio_pos_embed.num_embeddings
        if T_aud > max_audio_len:
            raise ValueError(
                f"audio_tokens has {T_aud} frames but the model supports at "
                f"most {max_audio_len}")
        aud = self.audio_embed(audio_tokens)
        # [T_aud, D] position table rows broadcast over the batch.
        aud = aud + self.audio_pos_embed(torch.arange(T_aud, device=device))

        # Downsample: [B, T, D] → [B, T//stride, D]
        if self.audio_downsample > 1:
            aud = aud.transpose(1, 2)   # [B, D, T]
            aud = self.audio_down(aud)  # [B, D, T']
            aud = aud.transpose(1, 2)   # [B, T', D]
            T_aud = aud.shape[1]

        if onset_curve is None:
            # Fallback onset proxy: RMS change between consecutive frames,
            # normalised by each sample's maximum.
            delta = torch.zeros(B, T_aud, device=device, dtype=aud.dtype)
            if T_aud > 1:
                delta[:, 1:] = (aud[:, 1:] - aud[:, :-1]).pow(2).mean(dim=-1).sqrt()
            denom = delta.amax(dim=1, keepdim=True).clamp_min(1e-6)
            onset_curve = (delta / denom).unsqueeze(-1)
        elif onset_curve.dim() == 2:
            onset_curve = onset_curve.unsqueeze(-1)
        if onset_curve.shape[1] != T_aud:
            onset_curve = F.interpolate(onset_curve.transpose(1, 2), size=T_aud,
                                        mode="linear", align_corners=False).transpose(1, 2)
        # Match the feature dtype so the FiLM projections accept the curve.
        onset_curve = onset_curve.to(aud.dtype)

        aud = self.dropout(aud)
        aud_pos = torch.arange(T_aud, device=device, dtype=torch.float32)
        for blk in self.audio_encoder:
            aud = blk(aud, positions=aud_pos)
        aud = self.enc_norm(aud)

        # Difficulty embedding: used both as a conditioning term and as the
        # routing input of the decoder MoE layers.
        diff_vec = self.diff_embed(difficulty.reshape(B))  # [B, d_model]

        # Embed chart + structural token features + conditions
        token_type, token_pos = token_structure_features(chart_tokens)
        emb = (self.chart_embed(chart_tokens) +
               self.chart_type_embed(token_type) +
               self.chart_pos_embed(token_pos))
        bpm_emb = self.bpm_proj(bpm.float().reshape(B, 1)).unsqueeze(1)
        level_emb = self.level_proj(level_value.float().reshape(B, 1)).unsqueeze(1)
        emb = emb + bpm_emb + diff_vec.unsqueeze(1) + level_emb
        if genre is not None:
            emb = emb + self.genre_embed(genre.reshape(B)).unsqueeze(1)

        # ── Time-aligned positions ──
        chart_pos = self.compute_chart_positions(chart_tokens, bpm,
                                                 self.audio_downsample)

        # ── Decode ──
        # The onset modulation depends only on the encoder output and the
        # onset curve, so apply it once instead of once per decoder layer.
        cross_src = self.onset_film(aud, onset_curve)
        x = emb
        for blk in self.chart_decoder:
            x = blk(x, enc_out=cross_src, self_positions=chart_pos,
                    diff_emb=diff_vec)
        x = self.dec_norm(x)

        if return_hidden:
            result = {"hidden": x}
            if return_aux:
                result.update(self._aux_logits(x))
            return result

        token_logits = self.output_head(x)
        if not return_aux:
            return token_logits
        return {"token": token_logits, **self._aux_logits(x)}

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter."""
        return next(self.parameters()).device