"""
MaiGenerator — Audio-conditioned autoregressive maimai chart generator.

Architecture: Encoder-Decoder Transformer with time-aligned RoPE.

    Audio tokens → AudioEncoder → audio_feat [T_aud, d]
                                      │  Cross-Attention
    Chart tokens → ChartDecoder ──────┘
                   (autoregressive,   + BPM/Difficulty/Level/Genre conditioning
                    causal mask)      + time-aligned RoPE positions

Key design: Chart RoPE uses audio frame indices (via BPM translation),
ensuring strong time alignment between music and generated notes.
"""
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import tokenizer as chart_tokenizer |
| from tokenizer import CONFIG_BASE |
|
|
| |
| |
| |
|
|
# Chart vocabulary size comes from the project tokenizer module.
CHART_VOCAB_SIZE = chart_tokenizer.VOCAB_SIZE
# NOTE(review): magic number — presumably the EnCodec codebook size plus a few
# special ids; confirm against the audio tokenizer.
AUDIO_VOCAB_SIZE = 2051
# Audio token frames per second; compute_chart_positions converts seconds to
# audio-frame indices with this rate.
AUDIO_FRAME_RATE = 75


# Special chart token ids.
BOS, EOS, PAD = 1, 2, 0


# Difficulty names; the `difficulty` input is an index into this list and is
# embedded with an nn.Embedding of size NUM_DIFFICULTIES.
DIFF_NAMES = ["BASIC", "ADVANCED", "EXPERT", "MASTER", "ReMASTER"]
NUM_DIFFICULTIES = len(DIFF_NAMES)


# Division tokens: token id -> division N. While division N is active,
# compute_chart_positions advances 4 / N beats per timeline token
# (the active division defaults to 4 before any division token is seen).
DIV_MAP = {5: 1, 6: 2, 7: 4, 8: 8, 9: 16, 10: 32,
           11: 48, 12: 64, 13: 128, 14: 192, 15: 384}
# Duration marker: it and the next two tokens do not advance chart time.
DUR_TOKEN = 17
# Rest token: a timeline token (advances time) with token type TYPE_REST.
RST_TOKEN = 16
# Note token ranges are half-open [BASE, END). Tap/break/hold/slide each span
# 8 ids; the button position id is (token - BASE) % 8.
TAP_BASE, TAP_END = 18, 26
BRK_BASE, BRK_END = 26, 34
HLD_BASE, HLD_END = 34, 42
SLD_BASE, SLD_END = 42, 50
# Slide group: SLD_BEG_TOKEN occupies one time slot; tokens after it up to
# SLD_END_TOKEN do not advance time.
SLD_BEG_TOKEN = 50
SLD_END_TOKEN = 51
# Simultaneous group: SIM_BEG_TOKEN occupies one time slot; the token after it
# is read as a count (body length = that value + 2) and SIM_END_TOKEN closes it.
SIM_BEG_TOKEN = 52
SIM_END_TOKEN = 53
# Touch notes: 41 ids, position id also taken as (token - TCH_BASE) % 8.
TCH_BASE, TCH_END = 54, 95


# Token type ids used by token_structure_features and the type_head.
TYPE_REST = 0
TYPE_TAP = 1
TYPE_HOLD = 2
TYPE_SLIDE = 3
TYPE_BREAK = 4
TYPE_TOUCH = 5
TYPE_CONTROL = 6
NUM_TOKEN_TYPES = 7
# 8 button positions plus one sentinel id (NUM_POSITIONS - 1) for tokens
# without a position.
NUM_POSITIONS = 9
# One class per division token; size of the division_head output.
NUM_DIV_CLASSES = len(DIV_MAP)
|
|
|
|
| def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, |
| dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor: |
| """Use PyTorch SDPA so CUDA can pick Flash/mem-efficient attention kernels.""" |
| return F.scaled_dot_product_attention( |
| q, k, v, |
| dropout_p=dropout_p if torch.is_grad_enabled() else 0.0, |
| is_causal=is_causal, |
| ) |
|
|
|
|
def is_timeline_token(tok: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of tokens that occupy one chart time slot.

    A token is a timeline slot when it is a rest, the opening token of a
    simultaneous or slide group, a token inside one of the note ranges
    (tap / break / hold / slide / touch), or any id >= ``CONFIG_BASE``.
    """
    note_ranges = (
        (TAP_BASE, TAP_END),
        (BRK_BASE, BRK_END),
        (HLD_BASE, HLD_END),
        (SLD_BASE, SLD_END),
        (TCH_BASE, TCH_END),
    )
    slot = tok >= CONFIG_BASE
    for lo, hi in note_ranges:
        slot = slot | ((tok >= lo) & (tok < hi))
    for marker in (RST_TOKEN, SIM_BEG_TOKEN, SLD_BEG_TOKEN):
        slot = slot | (tok == marker)
    return slot
|
|
|
|
def token_structure_features(tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Map chart tokens to (type id, position id) for structure-aware embeddings.

    Every token starts as ``TYPE_CONTROL`` with the sentinel position
    ``NUM_POSITIONS - 1``. Rest tokens become ``TYPE_REST`` and keep the
    sentinel position. A token inside a note range gets that range's type and
    the 0-based position ``(token - range_start) % 8``.

    Returns:
        ``(typ, pos)``: two integer tensors shaped like ``tokens``.
    """
    type_ids = torch.full_like(tokens, TYPE_CONTROL).masked_fill(
        tokens == RST_TOKEN, TYPE_REST)
    pos_ids = torch.full_like(tokens, NUM_POSITIONS - 1)

    note_families = (
        (TYPE_TAP, TAP_BASE, TAP_END),
        (TYPE_HOLD, HLD_BASE, HLD_END),
        (TYPE_SLIDE, SLD_BASE, SLD_END),
        (TYPE_BREAK, BRK_BASE, BRK_END),
        (TYPE_TOUCH, TCH_BASE, TCH_END),
    )
    for family, lo, hi in note_families:
        in_family = (tokens >= lo) & (tokens < hi)
        type_ids = type_ids.masked_fill(in_family, family)
        pos_ids = torch.where(in_family, (tokens - lo) % 8, pos_ids)
    return type_ids, pos_ids
|
|
|
|
| |
| |
| |
|
|
class RoPE(nn.Module):
    """Rotary position embedding that accepts arbitrary (even fractional) positions.

    Channel pairs (0, 1), (2, 3), ... of each vector are rotated by the angle
    ``position * inv_freq[pair]``. Because positions are explicit, callers can
    pass time-aligned indices instead of ``0..T-1``.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents),
                             persistent=False)

    def forward(self, x: torch.Tensor,
                positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Rotate ``x`` according to ``positions``.

        Args:
            x: [B, H, T, D] query or key tensor.
            positions: [T] or [B, T] position indices; ``0..T-1`` when None.

        Returns:
            Rotated tensor with the same shape and dtype as ``x``.
        """
        _, _, seq_len, _ = x.shape

        if positions is None:
            positions = torch.arange(seq_len, device=x.device,
                                     dtype=torch.float32)
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
        # [B|1, T] -> [B|1, 1, T, 1] so it broadcasts over heads and pairs.
        theta = positions.unsqueeze(1).unsqueeze(-1) * self.inv_freq.to(x.device)
        cos, sin = theta.cos(), theta.sin()

        even, odd = x[..., 0::2], x[..., 1::2]
        rotated_pairs = torch.stack(
            (even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        # Interleave back to [e0', o0', e1', o1', ...] in x's dtype.
        return rotated_pairs.flatten(-2).to(x.dtype)
|
|
|
|
| |
| |
| |
|
|
class OnsetFiLM(nn.Module):
    """FiLM modulation of encoder features by a per-frame onset signal.

    Computes ``enc_out * (1 + 0.5 * tanh(gamma(onset))) + 0.1 * beta(onset)``,
    a bounded multiplicative gain plus a small additive shift. It is intended
    as a beat prior (assumed: larger onset values mark likely beats).
    """

    def __init__(self, d_model: int = 512):
        super().__init__()
        self.gamma = nn.Linear(1, d_model)
        self.beta = nn.Linear(1, d_model)

    def forward(self, enc_out, onset):
        """enc_out: [B, T_enc, D], onset: [B, T_enc, 1] -> [B, T_enc, D]."""
        scale = 1.0 + torch.tanh(self.gamma(onset)) * 0.5
        shift = self.beta(onset) * 0.1
        return enc_out * scale + shift
|
|
|
|
| |
| |
| |
|
|
class MoEFFN(nn.Module):
    """Mixture-of-experts feed-forward block gated by a difficulty embedding.

    A linear router maps ``diff_emb`` to softmax weights over ``n_experts``
    FFNs. Every expert runs on every token, and the outputs are blended with
    one weight per example (dense soft routing, no top-k).
    """

    def __init__(self, d_model=512, d_ff=2048, n_experts=6, dropout=0.1):
        super().__init__()
        self.n_experts = n_experts

        def make_expert():
            return nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout),
            )

        self.experts = nn.ModuleList(make_expert() for _ in range(n_experts))
        self.router = nn.Linear(d_model, n_experts)

    def forward(self, x, diff_emb):
        """x: [B, T, d], diff_emb: [B, d] -> [B, T, d]."""
        gate = F.softmax(self.router(diff_emb), dim=-1)  # [B, n_experts]
        blended = 0
        for idx in range(self.n_experts):
            # [B] -> [B, 1, 1] so the weight broadcasts over tokens and features.
            blended = blended + gate[:, idx, None, None] * self.experts[idx](x)
        return blended
|
|
|
|
| |
| |
| |
|
|
class EncoderBlock(nn.Module):
    """Pre-LayerNorm encoder layer: multi-head self-attention followed by an FFN.

    Attention runs through PyTorch's ``scaled_dot_product_attention`` so fused
    CUDA kernels (Flash / memory-efficient) can be chosen when applicable.
    RoPE is applied to queries and keys only when ``positions`` is given.
    """

    def __init__(self, d_model: int = 512, heads: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.heads = heads
        self.head_dim = d_model // heads

        # Attention projections; RoPE rotates the per-head query/key vectors.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _heads(self, t: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """[B, T, H*Dh] -> [B, H, T, Dh]."""
        return t.view(batch, length, self.heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor,
                positions: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply one encoder layer.

        Args:
            x: [B, T, D] input features.
            positions: RoPE positions ([T] or [B, T]); no rotation when None.
            mask: Optional SDPA attention mask, cast to bool if it is not bool.

        Returns:
            [B, T, D] output features.
        """
        batch, length, width = x.shape
        p_attn = self.dropout.p if self.training else 0.0

        h = self.norm1(x)
        q = self._heads(self.q_proj(h), batch, length)
        k = self._heads(self.k_proj(h), batch, length)
        v = self._heads(self.v_proj(h), batch, length)
        if positions is not None:
            q = self.rope(q, positions)
            k = self.rope(k, positions)

        if mask is None:
            ctx = _sdpa(q, k, v, p_attn)
        else:
            # SDPA bool masks mean True = "may attend". Non-bool masks are cast,
            # so only 0/1-style keep masks behave as intended here.
            # NOTE(review): an additive float mask (0 / -inf) would be inverted
            # by this cast — confirm what callers pass.
            keep = mask if mask.dtype == torch.bool else mask.to(torch.bool)
            ctx = F.scaled_dot_product_attention(
                q, k, v, attn_mask=keep, dropout_p=p_attn)

        merged = ctx.transpose(1, 2).contiguous().view(batch, length, width)
        x = x + self.dropout(self.out_proj(merged))
        return x + self.ffn(self.norm2(x))
|
|
|
|
| |
| |
| |
|
|
class DecoderBlock(nn.Module):
    """Pre-LN decoder layer: causal self-attention + cross-attention + FFN.

    Supports KV-caching for fast autoregressive inference. Call
    ``_init_self_kv_cache`` and ``_init_cross_kv_cache`` before generation and
    pass ``use_cache=True`` to ``forward``. Each cached call appends the new
    keys/values to the self-attention buffer and reuses the precomputed
    cross-attention keys/values.
    """

    def __init__(self, d_model: int = 512, heads: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1,
                 use_moe: bool = False, n_experts: int = 6):
        super().__init__()
        self.heads = heads
        self.head_dim = d_model // heads

        # Causal self-attention projections (RoPE on Q/K).
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        self.cross_q = nn.Linear(d_model, d_model)
        self.cross_k = nn.Linear(d_model, d_model)
        self.cross_v = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)

        # Feed-forward: difficulty-routed MoE or a plain shared FFN.
        if use_moe:
            self.ffn = MoEFFN(d_model, d_ff, n_experts, dropout)
            self.is_moe = True
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(d_ff, d_model), nn.Dropout(dropout),
            )
            self.is_moe = False

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """[B, T, H*Dh] -> [B, H, T, Dh]."""
        b, t_len, _ = t.shape
        return t.view(b, t_len, self.heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _merge_heads(t: torch.Tensor) -> torch.Tensor:
        """[B, H, T, Dh] -> [B, T, H*Dh]."""
        b, h, t_len, dh = t.shape
        return t.transpose(1, 2).contiguous().view(b, t_len, h * dh)

    @staticmethod
    def _causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          dropout_p: float) -> torch.Tensor:
        """Causal SDPA where the queries are the *last* ``q.shape[2]`` key positions.

        This is bottom-right alignment: query ``i`` may attend to keys
        ``0 .. T_k - T_q + i``. It equals SDPA's ``is_causal`` when
        ``T_q == T_k`` (training / cache prefill from empty). It is also what
        cached decoding needs when earlier keys come from the KV cache. SDPA's
        own ``is_causal`` is top-left aligned and would be wrong there.
        Dropout is disabled when grad is off, matching ``_sdpa``.
        """
        if not torch.is_grad_enabled():
            dropout_p = 0.0
        t_q, t_k = q.shape[2], k.shape[2]
        if t_q == 1:
            # A single (newest) query may see every cached key and itself.
            return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        if t_q == t_k:
            return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p,
                                                  is_causal=True)
        keep = torch.ones(t_q, t_k, dtype=torch.bool, device=q.device).tril(
            diagonal=t_k - t_q)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=keep,
                                              dropout_p=dropout_p)

    def _init_self_kv_cache(self, batch_size: int, max_len: int, device,
                            dtype=None):
        """Pre-allocate the self-attention K/V cache for ``max_len`` positions.

        Args:
            batch_size: Batch size of the upcoming generation.
            max_len: Number of positions to pre-allocate.
            device: Device for the cache tensors.
            dtype: Cache dtype. Defaults to the projection weight dtype, so
                half-precision models get half-precision caches.
        """
        if dtype is None:
            dtype = self.k_proj.weight.dtype
        shape = (batch_size, self.heads, max_len, self.head_dim)
        self._self_k_cache = torch.zeros(shape, device=device, dtype=dtype)
        self._self_v_cache = torch.zeros(shape, device=device, dtype=dtype)
        self._self_cache_len = 0

    def _init_cross_kv_cache(self, enc_out: torch.Tensor,
                             onset_film=None, onset_kv=None):
        """Precompute cross-attention K, V from the encoder output.

        ``enc_out`` is the encoder output, not pre-normed (``norm2`` is for the
        decoder hidden state). If both ``onset_film`` and ``onset_kv`` are
        given, the same FiLM modulation the uncached ``forward`` path applies
        is applied here, so the cached keys/values match training.
        """
        if onset_film is not None and onset_kv is not None:
            enc_out = onset_film(enc_out, onset_kv)
        self._cross_k_cache = self._split_heads(self.cross_k(enc_out))
        self._cross_v_cache = self._split_heads(self.cross_v(enc_out))

    def forward(self, x, enc_out, self_positions=None, diff_emb=None,
                onset_film=None, onset_kv=None,
                use_cache: bool = False):
        """Run one decoder layer.

        Args:
            x: Decoder hidden states [B, T_dec, D].
            enc_out: Encoder output [B, T_enc, D]. It is not used for K/V when
                ``use_cache`` is True and a cross-attention cache exists.
            self_positions: RoPE positions for the ``T_dec`` new tokens
                ([T_dec] or [B, T_dec]); no rotation when None.
            diff_emb: Difficulty embedding [B, D]; required by MoE layers.
            onset_film, onset_kv: Optional FiLM module and onset curve applied
                to ``enc_out`` before the cross-attention K/V projections.
            use_cache: Append K/V to the self-attention cache and reuse the
                cross-attention cache, when they have been initialised.

        Returns:
            Updated hidden states [B, T_dec, D].

        Raises:
            RuntimeError: If the self-attention cache would overflow.
            TypeError: If this is an MoE layer and ``diff_emb`` is None.
        """
        drop_p = self.dropout.p if self.training else 0.0

        # 1) Causal self-attention.
        residual = x
        h = self.norm1(x)
        Q = self._split_heads(self.q_proj(h))
        K = self._split_heads(self.k_proj(h))
        V = self._split_heads(self.v_proj(h))
        if self_positions is not None:
            Q = self.rope(Q, self_positions)
            K = self.rope(K, self_positions)

        if use_cache and hasattr(self, '_self_k_cache'):
            start = self._self_cache_len
            end = start + Q.shape[2]
            capacity = self._self_k_cache.shape[2]
            if end > capacity:
                raise RuntimeError(
                    f"self-attention KV cache overflow: {end} positions needed, "
                    f"{capacity} allocated (increase max_len in init_kv_cache)")
            self._self_k_cache[:, :, start:end] = K
            self._self_v_cache[:, :, start:end] = V
            K = self._self_k_cache[:, :, :end]
            V = self._self_v_cache[:, :, :end]
            self._self_cache_len = end

        attn = self._causal_attention(Q, K, V, drop_p)
        x = residual + self.dropout(self.out_proj(self._merge_heads(attn)))

        # 2) Cross-attention over the (optionally onset-modulated) encoder output.
        residual = x
        Qc = self._split_heads(self.cross_q(self.norm2(x)))
        if use_cache and hasattr(self, '_cross_k_cache'):
            Kc, Vc = self._cross_k_cache, self._cross_v_cache
        else:
            src = enc_out
            if onset_film is not None and onset_kv is not None:
                src = onset_film(src, onset_kv)
            Kc = self._split_heads(self.cross_k(src))
            Vc = self._split_heads(self.cross_v(src))
        attn_c = _sdpa(Qc, Kc, Vc, drop_p)
        x = residual + self.dropout(self.cross_out(self._merge_heads(attn_c)))

        # 3) Feed-forward (difficulty-routed when MoE).
        residual = x
        h = self.norm3(x)
        if self.is_moe:
            if diff_emb is None:
                raise TypeError("MoE DecoderBlock.forward requires diff_emb")
            return residual + self.ffn(h, diff_emb)
        return residual + self.ffn(h)
|
|
|
|
| |
| |
| |
|
|
class MaiGenerator(nn.Module):
    """Audio-conditioned autoregressive maimai chart generator.

    Input:
        audio_tokens: [B, T_aud]    — EnCodec tokens
        chart_tokens: [B, T_chart]  — chart token sequence (train: full[:-1])
        bpm:          [B, 1]        — BPM value
        difficulty:   [B, 1]        — difficulty enum (0..4)
        level_value:  [B, 1]        — numeric level (e.g. 12.4)
        genre:        [B, 1]        — genre index (optional)

    Output:
        logits: [B, T_chart, chart_vocab] — next-token prediction
    """

    def __init__(self, d_model: int = 512, enc_layers: int = 6,
                 dec_layers: int = 8, heads: int = 8, d_ff: int = 2048,
                 dropout: float = 0.1, chart_vocab: int | None = None,
                 audio_vocab: int = AUDIO_VOCAB_SIZE,
                 num_genres: int = 16, max_audio_len: int = 32768,
                 audio_downsample: int = 8, use_moe: bool = True,
                 n_experts: int = 6, moe_layers: list[int] | None = None):
        """Build the encoder-decoder.

        Args:
            chart_vocab: Chart vocabulary size; defaults to the tokenizer's.
            max_audio_len: Maximum number of (pre-downsampling) audio tokens.
            audio_downsample: Conv1d stride applied to audio embeddings.
            use_moe: Only consulted when ``moe_layers`` is None. True makes every
                decoder layer an MoE layer; False makes none of them MoE.
            moe_layers: Indices of decoder layers that use MoE (e.g. [4, 5, 6, 7]).
                Other decoder layers use a shared FFN. Takes precedence over
                ``use_moe``.
        """
        super().__init__()
        if chart_vocab is None:
            chart_vocab = chart_tokenizer.VOCAB_SIZE
        self.d_model = d_model
        self.chart_vocab_size = chart_vocab
        self.audio_downsample = audio_downsample

        if moe_layers is None:
            moe_layers = list(range(dec_layers)) if use_moe else []

        # Token embeddings (chart tokens also get type and position embeddings).
        self.audio_embed = nn.Embedding(audio_vocab, d_model)
        self.chart_embed = nn.Embedding(chart_vocab, d_model)
        self.chart_type_embed = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.chart_pos_embed = nn.Embedding(NUM_POSITIONS, d_model)

        # Onset-driven FiLM applied to encoder output before cross-attention.
        self.onset_film = OnsetFiLM(d_model)

        # Strided conv shortens the audio sequence by `audio_downsample`.
        if audio_downsample > 1:
            self.audio_down = nn.Sequential(
                nn.Conv1d(d_model, d_model, kernel_size=audio_downsample,
                          stride=audio_downsample, padding=0),
                nn.GELU(),
            )
        else:
            self.audio_down = nn.Identity()

        self.audio_pos_embed = nn.Embedding(max_audio_len, d_model)

        # Global conditioning, added to every chart-token embedding.
        self.bpm_proj = nn.Sequential(nn.Linear(1, d_model), nn.SiLU(),
                                      nn.Linear(d_model, d_model))
        self.diff_embed = nn.Embedding(NUM_DIFFICULTIES, d_model)
        self.level_proj = nn.Sequential(nn.Linear(1, d_model), nn.SiLU(),
                                        nn.Linear(d_model, d_model))
        self.genre_embed = nn.Embedding(num_genres, d_model)

        self.audio_encoder = nn.ModuleList([
            EncoderBlock(d_model, heads, d_ff, dropout) for _ in range(enc_layers)])
        self.moe_layers = set(moe_layers)
        # Count only indices that actually exist in the decoder.
        n_moe = sum(1 for i in range(dec_layers) if i in self.moe_layers)
        n_shared = dec_layers - n_moe
        print(f"Decoder: {n_shared} shared + {n_moe} MoE ×{n_experts} experts")
        self.chart_decoder = nn.ModuleList([
            DecoderBlock(d_model, heads, d_ff, dropout,
                         use_moe=(i in self.moe_layers), n_experts=n_experts)
            for i in range(dec_layers)])

        self.output_head = nn.Linear(d_model, chart_vocab)
        self.presence_head = nn.Linear(d_model, 2)
        self.type_head = nn.Linear(d_model, NUM_TOKEN_TYPES)
        self.position_head = nn.Linear(d_model, 8)
        self.division_head = nn.Linear(d_model, NUM_DIV_CLASSES)
        self.sim_head = nn.Linear(d_model, 2)
        self.duration_head = nn.Linear(d_model, 2)
        self.enc_norm = nn.LayerNorm(d_model)
        self.dec_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def init_kv_cache(self, batch_size: int, max_len: int, enc_out: torch.Tensor,
                      onset_kv: Optional[torch.Tensor] = None):
        """Initialize KV caches for all decoder blocks. Call before incremental generation.

        Args:
            batch_size: Batch size (usually 1 for inference).
            max_len: Maximum generation length for pre-allocation.
            enc_out: Encoder output [B, T_enc, D] for cross-attention cache.
            onset_kv: Optional onset curve ([B, T] or [B, T, 1]). The training
                forward pass modulates the encoder output with ``onset_film``
                before cross-attention. Pass the same curve here so the cached
                cross-attention keys/values match. It is linearly resampled to
                T_enc if needed. When None, ``enc_out`` is cached unmodified.
        """
        cross_src = enc_out
        if onset_kv is not None:
            onset_kv = self._align_onset(onset_kv, enc_out.shape[1])
            cross_src = self.onset_film(enc_out, onset_kv)
        for blk in self.chart_decoder:
            blk._init_self_kv_cache(batch_size, max_len, enc_out.device)
            blk._init_cross_kv_cache(cross_src)

    @staticmethod
    def compute_chart_positions(chart_tokens: torch.Tensor,
                                bpm: torch.Tensor,
                                downsample: int = 4) -> torch.Tensor:
        """Compute downsampled audio-frame positions for chart tokens.

        Each token's position is the chart time at which it starts, in units of
        downsampled audio frames. Timeline tokens advance time by 4 / N beats
        (N = active division). Duration parameters, slide-group bodies and
        simultaneous-group bodies do not advance time.

        Args:
            chart_tokens: [B, T]
            bpm: [B, 1] (clamped to >= 30).
            downsample: Audio downsampling factor.

        Returns:
            positions: [B, T] — downsampled frame indices (float).
        """
        B, T = chart_tokens.shape
        device = chart_tokens.device
        bpm_v = bpm.view(B).float().clamp(min=30.0)

        # Division-token lookup tables, built once outside the token loop.
        max_div_id = max(DIV_MAP)
        div_lut = torch.zeros(max_div_id + 1, device=device)
        div_known = torch.zeros(max_div_id + 1, dtype=torch.bool, device=device)
        for div_id, div_val in DIV_MAP.items():
            div_lut[div_id] = float(div_val)
            div_known[div_id] = True

        div_values = torch.full((B,), 4.0, device=device)
        positions = torch.zeros(B, T, device=device)
        current_beat = torch.zeros(B, device=device)
        dur_param_skip = torch.zeros(B, dtype=torch.long, device=device)
        sim_skip = torch.zeros(B, dtype=torch.long, device=device)
        slide_active = torch.zeros(B, dtype=torch.bool, device=device)

        for i in range(T):
            tok = chart_tokens[:, i]

            # A division token switches the active division.
            safe_tok = tok.clamp(min=0, max=max_div_id)
            is_div = div_known[safe_tok] & (safe_tok == tok)
            div_values = torch.where(is_div, div_lut[safe_tok], div_values)
            beat_per_token = 4.0 / div_values

            # Position = start time of this token, in downsampled frames.
            time_sec = current_beat * 60.0 / bpm_v
            positions[:, i] = time_sec * AUDIO_FRAME_RATE / downsample

            # Only "top-level" timeline tokens advance time.
            is_dur_param = dur_param_skip > 0
            is_dur = (tok == DUR_TOKEN)
            in_sim_body = sim_skip > 0
            in_slide_body = slide_active & (tok != SLD_BEG_TOKEN)
            group_body = in_sim_body | in_slide_body
            advances_time = is_timeline_token(tok) & ~is_dur_param & ~is_dur & ~group_body
            current_beat = torch.where(advances_time,
                                       current_beat + beat_per_token,
                                       current_beat)

            # Simultaneous group: the token after SIM_BEG gives the body length.
            is_sim_beg = tok == SIM_BEG_TOKEN
            is_sim_end = tok == SIM_END_TOKEN
            count_after_sim_beg = is_sim_beg & (i + 1 < T)
            next_tok = chart_tokens[:, i + 1] if i + 1 < T else torch.zeros_like(tok)
            sim_skip = torch.where(count_after_sim_beg,
                                   torch.clamp(next_tok + 2, min=0),
                                   torch.clamp(sim_skip - 1, min=0))
            sim_skip = torch.where(is_sim_end, torch.zeros_like(sim_skip), sim_skip)

            # Slide group state.
            slide_active = torch.where(tok == SLD_BEG_TOKEN, torch.ones_like(slide_active), slide_active)
            slide_active = torch.where(tok == SLD_END_TOKEN, torch.zeros_like(slide_active), slide_active)

            # A duration token is followed by two parameter tokens.
            dur_param_skip = torch.where(is_dur,
                                         torch.full_like(dur_param_skip, 2),
                                         torch.clamp(dur_param_skip - 1, min=0))

        return positions

    @staticmethod
    def _align_onset(onset: torch.Tensor, length: int) -> torch.Tensor:
        """Return ``onset`` as [B, length, 1], linearly resampling if its length differs."""
        if onset.dim() == 2:
            onset = onset.unsqueeze(-1)
        if onset.shape[1] != length:
            onset = F.interpolate(onset.transpose(1, 2), size=length,
                                  mode="linear", align_corners=False).transpose(1, 2)
        return onset

    @staticmethod
    def _onset_from_features(aud: torch.Tensor) -> torch.Tensor:
        """Derive an onset curve [B, T, 1] from frame-to-frame feature change.

        Uses the RMS difference between consecutive frames, normalised per
        example to a maximum of 1. Frame 0 gets 0.
        """
        B, T, _ = aud.shape
        delta = torch.zeros(B, T, device=aud.device, dtype=aud.dtype)
        if T > 1:
            delta[:, 1:] = (aud[:, 1:] - aud[:, :-1]).pow(2).mean(dim=-1).sqrt()
        denom = delta.amax(dim=1, keepdim=True).clamp_min(1e-6)
        return (delta / denom).unsqueeze(-1)

    def _encode_audio(self, audio_tokens: torch.Tensor,
                      onset_curve: Optional[torch.Tensor]):
        """Embed, downsample and encode audio tokens.

        Returns:
            ``(enc_out, onset)``: enc_out [B, T_enc, d] and onset [B, T_enc, 1],
            where T_enc is the post-downsampling length.

        Raises:
            ValueError: If there are more audio tokens than ``max_audio_len``.
        """
        B, T_aud = audio_tokens.shape[0], audio_tokens.shape[1]
        max_len = self.audio_pos_embed.num_embeddings
        if T_aud > max_len:
            # Checked explicitly: an out-of-range embedding index is an opaque
            # device-side assert on CUDA.
            raise ValueError(
                f"audio sequence length {T_aud} exceeds max_audio_len={max_len}")
        device = audio_tokens.device

        aud = self.audio_embed(audio_tokens)
        aud = aud + self.audio_pos_embed(
            torch.arange(T_aud, device=device).unsqueeze(0).expand(B, -1))
        if self.audio_downsample > 1:
            aud = self.audio_down(aud.transpose(1, 2)).transpose(1, 2)
        T_enc = aud.shape[1]

        if onset_curve is None:
            onset_curve = self._onset_from_features(aud)
        else:
            onset_curve = self._align_onset(onset_curve, T_enc)

        aud = self.dropout(aud)
        enc_pos = torch.arange(T_enc, device=device, dtype=torch.float32)
        for blk in self.audio_encoder:
            aud = blk(aud, positions=enc_pos)
        return self.enc_norm(aud), onset_curve

    def _embed_chart(self, chart_tokens: torch.Tensor, bpm: torch.Tensor,
                     diff_vec: torch.Tensor, level_value: torch.Tensor,
                     genre: Optional[torch.Tensor]) -> torch.Tensor:
        """Chart token + structure embeddings plus global conditioning [B, T, d]."""
        token_type, token_pos = token_structure_features(chart_tokens)
        emb = (self.chart_embed(chart_tokens) +
               self.chart_type_embed(token_type) +
               self.chart_pos_embed(token_pos))
        emb = (emb
               + self.bpm_proj(bpm.float()).unsqueeze(1)
               + diff_vec.unsqueeze(1)
               + self.level_proj(level_value.float()).unsqueeze(1))
        if genre is not None:
            emb = emb + self.genre_embed(genre.squeeze(-1)).unsqueeze(1)
        return emb

    def _aux_outputs(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Auxiliary structure-prediction heads on the final hidden state."""
        return {
            "presence": self.presence_head(x),
            "type": self.type_head(x),
            "position": self.position_head(x),
            "division": self.division_head(x),
            "sim": self.sim_head(x),
            "duration": self.duration_head(x),
        }

    def forward(self, audio_tokens: torch.Tensor,
                chart_tokens: torch.Tensor, bpm: torch.Tensor,
                difficulty: torch.Tensor, level_value: torch.Tensor,
                genre: Optional[torch.Tensor] = None,
                onset_curve: Optional[torch.Tensor] = None,
                return_aux: bool = False,
                return_hidden: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
        """Training forward pass (teacher forcing).

        Args:
            onset_curve: Optional onset curve ([B, T] or [B, T, 1]), resampled
                to the encoder length. When None, it is derived from
                frame-to-frame changes of the downsampled audio embeddings.
            return_aux: Also return the auxiliary head outputs.
            return_hidden: Return the final hidden state under "hidden" instead
                of token logits (plus auxiliary heads if ``return_aux``).

        Returns:
            By default, logits [B, T_chart, chart_vocab]. With ``return_aux`` it
            returns a dict with "token" plus the auxiliary heads. With
            ``return_hidden`` it returns a dict with "hidden" (and auxiliary
            heads when requested).
        """
        aud, onset_kv = self._encode_audio(audio_tokens, onset_curve)

        # Difficulty embedding: added to token embeddings and routes MoE layers.
        diff_vec = self.diff_embed(difficulty.squeeze(-1))
        emb = self._embed_chart(chart_tokens, bpm, diff_vec, level_value, genre)

        # Time-aligned RoPE positions, in downsampled audio-frame units.
        chart_pos = self.compute_chart_positions(chart_tokens, bpm,
                                                 self.audio_downsample)

        x = emb
        for blk in self.chart_decoder:
            x = blk(x, enc_out=aud, self_positions=chart_pos,
                    diff_emb=diff_vec, onset_film=self.onset_film, onset_kv=onset_kv)
        x = self.dec_norm(x)

        if return_hidden:
            result = {"hidden": x}
            if return_aux:
                result.update(self._aux_outputs(x))
            return result

        token_logits = self.output_head(x)
        if not return_aux:
            return token_logits
        return {"token": token_logits, **self._aux_outputs(x)}

    @property
    def device(self) -> torch.device:
        """Device of the model parameters."""
        return next(self.parameters()).device
|
|