"""
MaiGenerator — Audio-conditioned autoregressive maimai chart generator.

Architecture: Encoder-Decoder Transformer with time-aligned RoPE.

    Audio tokens → AudioEncoder → audio_feat [T_aud, d]
                                        │ Cross-Attention
    Chart tokens → ChartDecoder ────────┘ (autoregressive, causal mask)
      + BPM/Diff/Genre conditioning
      + time-aligned RoPE positions

Key design: Chart RoPE uses audio frame indices (via BPM translation),
ensuring strong time alignment between music and generated notes.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import tokenizer as chart_tokenizer
from tokenizer import CONFIG_BASE

# ═══════════════════════════════════════════════════════════════════════
# Constants (aligned with tokenizers)
# ═══════════════════════════════════════════════════════════════════════

CHART_VOCAB_SIZE = chart_tokenizer.VOCAB_SIZE
AUDIO_VOCAB_SIZE = 2051  # MaiTrackTokenizer (2-layer EnCodec)
AUDIO_FRAME_RATE = 75  # EnCodec 24kHz / 320 stride
BOS, EOS, PAD = 1, 2, 0
DIFF_NAMES = ["BASIC", "ADVANCED", "EXPERT", "MASTER", "ReMASTER"]
NUM_DIFFICULTIES = len(DIFF_NAMES)

# Beat division token → value
DIV_MAP = {5: 1, 6: 2, 7: 4, 8: 8, 9: 16, 10: 32, 11: 48, 12: 64,
           13: 128, 14: 192, 15: 384}
DUR_TOKEN = 17  # [DUR] in chart vocab
RST_TOKEN = 16
# Note-token id ranges are half-open: [BASE, END).
TAP_BASE, TAP_END = 18, 26
BRK_BASE, BRK_END = 26, 34
HLD_BASE, HLD_END = 34, 42
SLD_BASE, SLD_END = 42, 50
SLD_BEG_TOKEN = 50
SLD_END_TOKEN = 51
SIM_BEG_TOKEN = 52
SIM_END_TOKEN = 53
TCH_BASE, TCH_END = 54, 95

TYPE_REST = 0
TYPE_TAP = 1
TYPE_HOLD = 2
TYPE_SLIDE = 3
TYPE_BREAK = 4
TYPE_TOUCH = 5
TYPE_CONTROL = 6
NUM_TOKEN_TYPES = 7
NUM_POSITIONS = 9  # 0-7 real positions, 8 = none/control
NUM_DIV_CLASSES = len(DIV_MAP)


def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
          dropout_p: float = 0.0, is_causal: bool = False,
          attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run PyTorch SDPA so CUDA can pick Flash/mem-efficient attention kernels.

    Args:
        q, k, v: Attention inputs, forwarded unchanged to
            ``F.scaled_dot_product_attention``.
        dropout_p: Attention dropout probability. It is forced to 0 whenever
            autograd is disabled (e.g. under ``torch.no_grad()``).
        is_causal: Forwarded to SDPA.
        attn_mask: Optional mask forwarded to SDPA. Do not combine it with
            ``is_causal=True``; SDPA rejects that combination.

    Returns:
        The attention output produced by SDPA.
    """
    if not torch.is_grad_enabled():
        dropout_p = 0.0
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )


def is_timeline_token(tok: torch.Tensor) -> torch.Tensor:
    """Flag tokens that correspond to one chart time slot after decoding.

    A token qualifies when it is a rest, a SIM/SLD group opener, falls in one
    of the TAP/BRK/HLD/SLD/TCH id ranges, or has an id >= ``CONFIG_BASE``.

    Args:
        tok: Integer tensor of chart token ids (any shape).

    Returns:
        Boolean tensor with the same shape as ``tok``.
    """
    note_ranges = (
        (TAP_BASE, TAP_END),
        (BRK_BASE, BRK_END),
        (HLD_BASE, HLD_END),
        (SLD_BASE, SLD_END),
        (TCH_BASE, TCH_END),
    )
    flagged = (tok == RST_TOKEN) | (tok == SIM_BEG_TOKEN) | (tok == SLD_BEG_TOKEN)
    for lo, hi in note_ranges:
        flagged = flagged | ((tok >= lo) & (tok < hi))
    return flagged | (tok >= CONFIG_BASE)


def token_structure_features(tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return token type ids and 0-based position ids for structure-aware embeds.

    Tokens outside every note range (and other than the rest token) keep the
    defaults ``TYPE_CONTROL`` and position ``NUM_POSITIONS - 1``.

    Args:
        tokens: Integer tensor of chart token ids (any shape).

    Returns:
        ``(type_ids, position_ids)``, each shaped and typed like ``tokens``.
    """
    note_ranges = (
        (TAP_BASE, TAP_END, TYPE_TAP),
        (HLD_BASE, HLD_END, TYPE_HOLD),
        (SLD_BASE, SLD_END, TYPE_SLIDE),
        (BRK_BASE, BRK_END, TYPE_BREAK),
        (TCH_BASE, TCH_END, TYPE_TOUCH),
    )
    type_ids = torch.full_like(tokens, TYPE_CONTROL)
    pos_ids = torch.full_like(tokens, NUM_POSITIONS - 1)
    type_ids = type_ids.masked_fill(tokens == RST_TOKEN, TYPE_REST)
    for lo, hi, token_type in note_ranges:
        in_range = (tokens >= lo) & (tokens < hi)
        type_ids = type_ids.masked_fill(in_range, token_type)
        # Offsets inside a range are folded onto the 8 position slots.
        pos_ids = torch.where(in_range, (tokens - lo) % 8, pos_ids)
    return type_ids, pos_ids


# ═══════════════════════════════════════════════════════════════════════
# RoPE with custom positions
# ═══════════════════════════════════════════════════════════════════════

class RoPE(nn.Module):
    """Rotary Position Embedding supporting custom position indices."""

    def __init__(self, dim: int, base: float = 10000.0):
        """
        Args:
            dim: Size of the last dimension of the tensors to rotate. It must
                be a positive even number because features are rotated in
                (even, odd) pairs.
            base: Base of the geometric frequency progression.

        Raises:
            ValueError: If ``dim`` is not a positive even integer.
        """
        super().__init__()
        if dim <= 0 or dim % 2 != 0:
            raise ValueError(
                f"RoPE requires a positive even dim, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: rebuilt from `dim`/`base`, never stored in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x: torch.Tensor,
                positions: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Rotate ``x`` by angles derived from ``positions``.

        Args:
            x: [B, H, T, D] — query or key.
            positions: [T] or [B, T] — custom position indices. Defaults to
                ``0..T-1`` when omitted.

        Returns:
            Rotated tensor [B, H, T, D] with the dtype of ``x``.
        """
        _, _, seq_len, _ = x.shape
        if positions is None:
            positions = torch.arange(seq_len, device=x.device,
                                     dtype=torch.float32)
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
        positions = positions.unsqueeze(1).unsqueeze(-1)      # [B, 1, T, 1]
        angles = positions * self.inv_freq.to(x.device)       # [B, 1, T, D/2]
        sin, cos = angles.sin(), angles.cos()

        # Interleaved layout: feature pairs are (0, 1), (2, 3), ...
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = x_even * cos - x_odd * sin
        rotated[..., 1::2] = x_even * sin + x_odd * cos
        return rotated


# ═══════════════════════════════════════════════════════════════════════
# Onset Feature Injection (FiLM)
# ═══════════════════════════════════════════════════════════════════════

class OnsetFiLM(nn.Module):
    """Beat-prior injection: modulates encoder features at beat positions."""

    def __init__(self, d_model: int = 512):
        super().__init__()
        self.gamma = nn.Linear(1, d_model)
        self.beta = nn.Linear(1, d_model)

    def forward(self, enc_out, onset):
        """Apply a bounded scale and a small shift driven by the onset curve.

        Args:
            enc_out: [B, T_enc, D] encoder features.
            onset: [B, T_enc, 1] onset strength per encoder frame.

        Returns:
            [B, T_enc, D] modulated features.
        """
        scale = torch.tanh(self.gamma(onset)) * 0.5   # [B, T_enc, D], in (-0.5, 0.5)
        shift = self.beta(onset) * 0.1
        return enc_out * (1.0 + scale) + shift


# ═══════════════════════════════════════════════════════════════════════
# MoE FFN (Mixture of Experts for difficulty routing)
# ═══════════════════════════════════════════════════════════════════════

class MoEFFN(nn.Module):
    """MoE FFN: routes input through N experts weighted by difficulty.

    Routing is dense and per sample: every expert processes every token, and
    the softmax weights come from ``diff_emb`` only, not from the tokens.
    """

    def __init__(self, d_model=512, d_ff=2048, n_experts=6, dropout=0.1):
        super().__init__()
        self.n_experts = n_experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout),
            )
            for _ in range(n_experts)
        ])
        self.router = nn.Linear(d_model, n_experts)

    def forward(self, x, diff_emb):
        """Mix the expert outputs with difficulty-derived weights.

        Args:
            x: [B, T, d] token features.
            diff_emb: [B, d] difficulty embedding fed to the router.

        Returns:
            [B, T, d] weighted sum of all expert outputs.
        """
        weights = F.softmax(self.router(diff_emb), dim=-1)   # [B, N]
        mixed = 0
        for idx, expert in enumerate(self.experts):
            # [B, 1, 1] weight broadcast over the [B, T, d] expert output.
            mixed = mixed + weights[:, idx:idx + 1, None] * expert(x)
        return mixed
# ═══════════════════════════════════════════════════════════════════════
# Encoder Block
# ═══════════════════════════════════════════════════════════════════════

class EncoderBlock(nn.Module):
    """Pre-LN encoder block: RoPE self-attention followed by a GELU FFN.

    Attention goes through ``F.scaled_dot_product_attention`` so PyTorch can
    select a fused kernel when one is available.
    """

    def __init__(self, d_model: int = 512, heads: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1):
        """
        Raises:
            ValueError: If ``d_model`` is not divisible by ``heads``.
        """
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = d_model // heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                positions: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply self-attention and the FFN, each with a residual connection.

        Args:
            x: [B, T, D] input features.
            positions: Optional RoPE positions, [T] or [B, T]. RoPE is skipped
                when omitted.
            mask: Optional attention mask handed to SDPA, so it must be
                broadcastable to [B, heads, T, T]. A non-boolean mask is
                converted with ``.to(torch.bool)``: non-zero entries mean
                "may attend". Additive float masks (0 / -inf) are therefore
                NOT supported here.

        Returns:
            [B, T, D] output features.
        """
        B, T, D = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, T, self.heads, self.head_dim).transpose(1, 2)

        normed = self.norm1(x)
        Q = split_heads(self.q_proj(normed))
        K = split_heads(self.k_proj(normed))
        V = split_heads(self.v_proj(normed))

        # RoPE time-aligned positions.
        if positions is not None:
            Q = self.rope(Q, positions)
            K = self.rope(K, positions)

        drop_p = self.dropout.p if self.training else 0.0
        if mask is None:
            attn_out = _sdpa(Q, K, V, drop_p)
        else:
            attn_mask = mask if mask.dtype == torch.bool else mask.to(torch.bool)
            # Same rule as `_sdpa`: no attention dropout when autograd is off.
            attn_out = F.scaled_dot_product_attention(
                Q, K, V, attn_mask=attn_mask,
                dropout_p=drop_p if torch.is_grad_enabled() else 0.0,
            )
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, D)

        x = x + self.dropout(self.out_proj(attn_out))
        return x + self.ffn(self.norm2(x))


# ═══════════════════════════════════════════════════════════════════════
# Decoder Block
# ═══════════════════════════════════════════════════════════════════════

class DecoderBlock(nn.Module):
    """Pre-LN decoder: Causal Self-Attn + Cross-Attn + FFN.

    Supports KV-cache for fast autoregressive inference.
    """

    def __init__(self, d_model: int = 512, heads: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1,
                 use_moe: bool = False, n_experts: int = 6):
        """
        Raises:
            ValueError: If ``d_model`` is not divisible by ``heads``.
        """
        super().__init__()
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = d_model // heads

        # Self-attn
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rope = RoPE(self.head_dim)

        # Cross-attn
        self.cross_q = nn.Linear(d_model, d_model)
        self.cross_k = nn.Linear(d_model, d_model)
        self.cross_v = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)

        # FFN (standard or MoE)
        if use_moe:
            self.ffn = MoEFFN(d_model, d_ff, n_experts, dropout)
            self.is_moe = True
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout),
            )
            self.is_moe = False

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _init_self_kv_cache(self, batch_size: int, max_len: int, device,
                            dtype: Optional[torch.dtype] = None):
        """Pre-allocate the self-attention KV cache and reset its length.

        Args:
            batch_size: Batch size of the upcoming generation.
            max_len: Maximum number of decoder tokens the cache can hold.
            device: Device for the cache tensors.
            dtype: Cache dtype. ``None`` keeps PyTorch's default dtype; pass
                the model's compute dtype for half-precision inference.
        """
        shape = (batch_size, self.heads, max_len, self.head_dim)
        self._self_k_cache = torch.zeros(shape, device=device, dtype=dtype)
        self._self_v_cache = torch.zeros(shape, device=device, dtype=dtype)
        self._self_cache_len = 0

    def _init_cross_kv_cache(self, enc_out: torch.Tensor):
        """Precompute cross-attention K, V from encoder output (fixed during generation).

        ``enc_out`` is projected exactly as given: it is not normalised here
        (``norm2`` is applied to the decoder hidden state only) and no
        OnsetFiLM modulation is applied. A caller that wants the cached path
        to match the FiLM-modulated path of ``forward`` must pass already
        modulated features.
        """
        B, T_enc, _ = enc_out.shape
        self._cross_k_cache = (self.cross_k(enc_out)
                               .view(B, T_enc, self.heads, self.head_dim)
                               .transpose(1, 2))
        self._cross_v_cache = (self.cross_v(enc_out)
                               .view(B, T_enc, self.heads, self.head_dim)
                               .transpose(1, 2))

    def forward(self, x: torch.Tensor, enc_out: torch.Tensor,
                self_positions: Optional[torch.Tensor] = None,
                diff_emb: Optional[torch.Tensor] = None,
                onset_film: Optional[nn.Module] = None,
                onset_kv: Optional[torch.Tensor] = None,
                use_cache: bool = False) -> torch.Tensor:
        """Run causal self-attention, cross-attention and the FFN.

        Args:
            x: [B, T_dec, D] decoder input for the tokens processed by this call.
            enc_out: [B, T_enc, D] encoder output. Unused for cross-attention
                when ``use_cache`` is set and a cross cache exists.
            self_positions: Optional RoPE positions for the ``T_dec`` tokens
                in ``x`` ([T_dec] or [B, T_dec]).
            diff_emb: [B, D] difficulty embedding; required by MoE blocks.
            onset_film: Optional module called as ``onset_film(enc_out,
                onset_kv)`` before the cross-attention K/V projection
                (uncached path only).
            onset_kv: Onset curve passed to ``onset_film``.
            use_cache: Append this call's K/V to the pre-allocated cache and
                attend over everything cached so far.

        Returns:
            [B, T_dec, D] block output.

        Raises:
            RuntimeError: If the self-attention cache would overflow.
            ValueError: If this is an MoE block and ``diff_emb`` is missing.
        """
        B, T_dec, D = x.shape

        def split_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            return t.view(B, length, self.heads, self.head_dim).transpose(1, 2)

        def merge_heads(t: torch.Tensor) -> torch.Tensor:
            return t.transpose(1, 2).contiguous().view(B, T_dec, D)

        drop_p = self.dropout.p if self.training else 0.0

        # ── Causal Self-Attn ──
        normed = self.norm1(x)
        Q = split_heads(self.q_proj(normed), T_dec)
        K = split_heads(self.k_proj(normed), T_dec)
        V = split_heads(self.v_proj(normed), T_dec)

        # RoPE time-aligned positions; keys are rotated before being cached.
        if self_positions is not None:
            Q = self.rope(Q, self_positions)
            K = self.rope(K, self_positions)

        causal = True
        attn_mask = None
        if use_cache and hasattr(self, '_self_k_cache'):
            start = self._self_cache_len
            end = start + T_dec
            capacity = self._self_k_cache.shape[2]
            if end > capacity:
                raise RuntimeError(
                    f"KV cache overflow: need {end} positions but the cache "
                    f"was allocated for {capacity}")
            self._self_k_cache[:, :, start:end] = K
            self._self_v_cache[:, :, start:end] = V
            # `.to` is a no-op when the cache already has the query dtype.
            K = self._self_k_cache[:, :, :end].to(Q.dtype)
            V = self._self_v_cache[:, :, :end].to(Q.dtype)
            self._self_cache_len = end

            if T_dec == 1:
                # A single new query may see every cached key.
                causal = False
            elif start > 0:
                # Multi-token continuation: query i sits at absolute index
                # start + i and may only see keys up to that index.
                # SDPA's `is_causal` aligns top-left, so build the mask.
                causal = False
                q_idx = torch.arange(start, end, device=x.device).unsqueeze(1)
                k_idx = torch.arange(end, device=x.device).unsqueeze(0)
                attn_mask = k_idx <= q_idx            # [T_dec, end], True = attend
            # start == 0 with T_dec > 1 is a square prefill: plain causal.

        if attn_mask is None:
            attn = _sdpa(Q, K, V, drop_p, is_causal=causal)
        else:
            attn = F.scaled_dot_product_attention(
                Q, K, V, attn_mask=attn_mask,
                dropout_p=drop_p if torch.is_grad_enabled() else 0.0,
            )
        x = x + self.dropout(self.out_proj(merge_heads(attn)))

        # ── Cross-Attn (with optional OnsetFiLM modulation) ──
        Qc = split_heads(self.cross_q(self.norm2(x)), T_dec)
        if use_cache and hasattr(self, '_cross_k_cache'):
            Kc = self._cross_k_cache
            Vc = self._cross_v_cache
        else:
            memory = enc_out
            if onset_film is not None and onset_kv is not None:
                memory = onset_film(memory, onset_kv)
            T_enc = memory.shape[1]
            Kc = split_heads(self.cross_k(memory), T_enc)
            Vc = split_heads(self.cross_v(memory), T_enc)
        attn_c = _sdpa(Qc, Kc, Vc, drop_p)
        x = x + self.dropout(self.cross_out(merge_heads(attn_c)))

        # ── FFN ──
        normed = self.norm3(x)
        if self.is_moe:
            if diff_emb is None:
                raise ValueError("diff_emb is required for MoE decoder blocks")
            return x + self.ffn(normed, diff_emb)
        return x + self.ffn(normed)


# ═══════════════════════════════════════════════════════════════════════
# MaiGenerator
# ═══════════════════════════════════════════════════════════════════════

class MaiGenerator(nn.Module):
    """Audio-conditioned autoregressive maimai chart generator.

    Input:
        audio_tokens: [B, T_aud]   — EnCodec tokens
        chart_tokens: [B, T_chart] — chart token sequence (train: full[:-1])
        bpm:          [B, 1]       — BPM value
        difficulty:   [B, 1]       — difficulty enum (0..4)
        level_value:  [B, 1]       — numeric level (e.g. 12.4)
        genre:        [B, 1]       — genre index (optional)

    Output:
        logits: [B, T_chart, chart_vocab] — next-token prediction
    """

    def __init__(self, d_model: int = 512, enc_layers: int = 6,
                 dec_layers: int = 8, heads: int = 8, d_ff: int = 2048,
                 dropout: float = 0.1, chart_vocab: int | None = None,
                 audio_vocab: int = AUDIO_VOCAB_SIZE, num_genres: int = 16,
                 max_audio_len: int = 32768, audio_downsample: int = 8,
                 use_moe: bool = True, n_experts: int = 6,
                 moe_layers: Optional[list[int]] = None):
        """Build embeddings, the audio encoder, the chart decoder and heads.

        Args:
            chart_vocab: Chart vocabulary size; ``None`` uses
                ``chart_tokenizer.VOCAB_SIZE``.
            max_audio_len: Size of the learned audio position table, i.e. the
                longest audio token sequence (before downsampling) accepted.
            audio_downsample: Kernel size and stride of the audio Conv1d;
                values <= 1 disable downsampling.
            use_moe: When False, no decoder layer uses MoE and ``moe_layers``
                is ignored.
            moe_layers: 0-based indices of decoder layers that use MoE
                (e.g. ``[4, 5, 6, 7]`` with 8 layers); the other layers use a
                shared FFN. ``None`` = all decoder layers use MoE. Indices
                >= ``dec_layers`` match no layer.
        """
        super().__init__()
        if chart_vocab is None:
            chart_vocab = chart_tokenizer.VOCAB_SIZE
        self.d_model = d_model
        self.chart_vocab_size = chart_vocab
        self.audio_downsample = audio_downsample

        if not use_moe:
            moe_layers = []  # honour the explicit opt-out
        elif moe_layers is None:
            moe_layers = list(range(dec_layers))  # all MoE

        # Embeddings
        self.audio_embed = nn.Embedding(audio_vocab, d_model)
        self.chart_embed = nn.Embedding(chart_vocab, d_model)
        self.chart_type_embed = nn.Embedding(NUM_TOKEN_TYPES, d_model)
        self.chart_pos_embed = nn.Embedding(NUM_POSITIONS, d_model)

        # Onset feature injection
        self.onset_film = OnsetFiLM(d_model)

        # Audio downsampling (strided Conv1d + GELU to reduce seq len)
        if audio_downsample > 1:
            self.audio_down = nn.Sequential(
                nn.Conv1d(d_model, d_model, kernel_size=audio_downsample,
                          stride=audio_downsample, padding=0),
                nn.GELU(),
            )
        else:
            self.audio_down = nn.Identity()
        self.audio_pos_embed = nn.Embedding(max_audio_len, d_model)

        # Conditions
        self.bpm_proj = nn.Sequential(
            nn.Linear(1, d_model), nn.SiLU(), nn.Linear(d_model, d_model))
        self.diff_embed = nn.Embedding(NUM_DIFFICULTIES, d_model)
        self.level_proj = nn.Sequential(
            nn.Linear(1, d_model), nn.SiLU(), nn.Linear(d_model, d_model))
        self.genre_embed = nn.Embedding(num_genres, d_model)

        # Encoder / Decoder (hybrid: shared FFN + MoE layers)
        self.audio_encoder = nn.ModuleList([
            EncoderBlock(d_model, heads, d_ff, dropout)
            for _ in range(enc_layers)])

        self.moe_layers = set(moe_layers)
        # Count the blocks that are really built as MoE so the summary stays
        # correct when `moe_layers` holds out-of-range indices.
        n_moe = sum(1 for i in range(dec_layers) if i in self.moe_layers)
        n_shared = dec_layers - n_moe
        print(f"Decoder: {n_shared} shared + {n_moe} MoE ×{n_experts} experts")

        self.chart_decoder = nn.ModuleList([
            DecoderBlock(d_model, heads, d_ff, dropout,
                         use_moe=(i in self.moe_layers), n_experts=n_experts)
            for i in range(dec_layers)])

        self.output_head = nn.Linear(d_model, chart_vocab)
        self.presence_head = nn.Linear(d_model, 2)
        self.type_head = nn.Linear(d_model, NUM_TOKEN_TYPES)
        self.position_head = nn.Linear(d_model, 8)
        self.division_head = nn.Linear(d_model, NUM_DIV_CLASSES)
        self.sim_head = nn.Linear(d_model, 2)
        self.duration_head = nn.Linear(d_model, 2)
        self.enc_norm = nn.LayerNorm(d_model)
        self.dec_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear (Xavier, gain 0.5, zero bias) and Embedding (N(0, 0.02))."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def init_kv_cache(self, batch_size: int, max_len: int,
                      enc_out: torch.Tensor,
                      onset_curve: Optional[torch.Tensor] = None):
        """Initialize KV caches for all decoder blocks. Call before incremental generation.

        Args:
            batch_size: Batch size (usually 1 for inference).
            max_len: Maximum generation length for pre-allocation.
            enc_out: Encoder output [B, T_enc, D] for cross-attention cache.
            onset_curve: Optional onset curve, [B, T_enc] or [B, T_enc, 1].
                When given, ``enc_out`` is passed through ``self.onset_film``
                before the cross-attention K/V are cached, which matches what
                ``forward`` feeds the decoder. Leave it ``None`` if
                ``enc_out`` is already modulated (or to keep the previous
                behaviour of caching ``enc_out`` as-is).
        """
        if onset_curve is not None:
            if onset_curve.dim() == 2:
                onset_curve = onset_curve.unsqueeze(-1)
            enc_out = self.onset_film(enc_out, onset_curve.to(enc_out.dtype))
        for blk in self.chart_decoder:
            blk._init_self_kv_cache(batch_size, max_len, enc_out.device,
                                    dtype=enc_out.dtype)
            blk._init_cross_kv_cache(enc_out)

    # ── Time-aligned positions (core) ──────────────────────────────

    @staticmethod
    def compute_chart_positions(chart_tokens: torch.Tensor,
                                bpm: torch.Tensor,
                                downsample: int = 4) -> torch.Tensor:
        """Compute downsampled audio-frame positions for chart tokens.

        The sequence is scanned left to right while tracking the current beat
        division (starting at 4) and the elapsed beats. Each token receives
        the position of the beat reached *before* it is applied.

        Args:
            chart_tokens: [B, T]
            bpm: [B, 1] (any shape with B elements); clamped to >= 30.
            downsample: Audio downsampling factor.

        Returns:
            positions: [B, T] — downsampled frame indices (float).
        """
        B, T = chart_tokens.shape
        device = chart_tokens.device
        bpm_v = bpm.reshape(B).float().clamp(min=30.0)

        # Per-token lookups that do not depend on the scan state are computed
        # once for the whole sequence instead of once per step.
        div_ids = torch.tensor(list(DIV_MAP.keys()), device=device)
        div_vals = torch.tensor([float(v) for v in DIV_MAP.values()],
                                device=device)
        div_match = chart_tokens.unsqueeze(-1) == div_ids      # [B, T, K]
        sets_div = div_match.any(dim=-1)                        # [B, T]
        # Division ids are distinct, so at most one term is non-zero.
        div_at = (div_match.to(div_vals.dtype) * div_vals).sum(dim=-1)

        timeline = is_timeline_token(chart_tokens)
        dur_tok = chart_tokens == DUR_TOKEN
        sim_beg = chart_tokens == SIM_BEG_TOKEN
        sim_end = chart_tokens == SIM_END_TOKEN
        sld_beg = chart_tokens == SLD_BEG_TOKEN
        sld_end = chart_tokens == SLD_END_TOKEN

        div_values = torch.full((B,), 4.0, device=device)
        positions = torch.zeros(B, T, device=device)
        current_beat = torch.zeros(B, device=device)
        dur_param_skip = torch.zeros(B, dtype=torch.long, device=device)
        sim_skip = torch.zeros(B, dtype=torch.long, device=device)
        slide_active = torch.zeros(B, dtype=torch.bool, device=device)

        for i in range(T):
            # Update beat division
            div_values = torch.where(sets_div[:, i], div_at[:, i], div_values)
            beat_per_token = 4.0 / div_values

            # Record position: beat → seconds → audio frame → downsampled
            time_sec = current_beat * 60.0 / bpm_v
            positions[:, i] = time_sec * AUDIO_FRAME_RATE / downsample

            # Advance beat only for decoded timeline slots. SIM/SLD groups
            # occupy one slot at their begin token; their contents are structural.
            opens_slide = sld_beg[:, i]
            in_group_body = (sim_skip > 0) | (slide_active & ~opens_slide)
            advances_time = (timeline[:, i]
                             & ~(dur_param_skip > 0)
                             & ~dur_tok[:, i]
                             & ~in_group_body)
            current_beat = torch.where(
                advances_time, current_beat + beat_per_token, current_beat)

            # SIM body length is read from the token that follows SIM_BEG;
            # otherwise the remaining count ticks down. SIM_END clears it.
            ticked = torch.clamp(sim_skip - 1, min=0)
            if i + 1 < T:
                body_len = torch.clamp(chart_tokens[:, i + 1] + 2, min=0)
                sim_skip = torch.where(sim_beg[:, i], body_len, ticked)
            else:
                sim_skip = ticked
            sim_skip = sim_skip.masked_fill(sim_end[:, i], 0)

            slide_active = (slide_active | opens_slide) & ~sld_end[:, i]

            # The two tokens after DUR are its parameters, not timeline slots.
            dur_param_skip = torch.where(
                dur_tok[:, i],
                torch.full_like(dur_param_skip, 2),
                torch.clamp(dur_param_skip - 1, min=0))

        return positions

    # ── Forward ────────────────────────────────────────────────────

    def forward(self, audio_tokens: torch.Tensor,
                chart_tokens: torch.Tensor,
                bpm: torch.Tensor,
                difficulty: torch.Tensor,
                level_value: torch.Tensor,
                genre: Optional[torch.Tensor] = None,
                onset_curve: Optional[torch.Tensor] = None,
                return_aux: bool = False,
                return_hidden: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
        """Training forward pass (teacher forcing).

        Args:
            audio_tokens: [B, T_aud] audio token ids.
            chart_tokens: [B, T_chart] chart token ids.
            bpm, difficulty, level_value, genre: Per-sample conditions, each
                shaped [B, 1] or [B].
            onset_curve: Optional onset strength, [B, T] or [B, T, 1]. It is
                linearly resampled to the encoder length when needed. When
                omitted, a curve is derived from frame-to-frame changes of
                the downsampled audio embeddings, normalised per sample.
            return_aux: Also return the auxiliary head outputs.
            return_hidden: Return the final decoder hidden state under
                ``"hidden"`` instead of token logits.

        Returns:
            ``logits`` [B, T_chart, chart_vocab] by default; otherwise a dict
            with ``"token"`` or ``"hidden"`` plus, with ``return_aux``, the
            ``presence`` / ``type`` / ``position`` / ``division`` / ``sim`` /
            ``duration`` head outputs.

        Raises:
            ValueError: If the audio is longer than the position table.
        """
        B, T_chart = chart_tokens.shape
        device = chart_tokens.device

        # ── Encode audio ──
        T_aud = audio_tokens.shape[1]
        max_audio_len = self.audio_pos_embed.num_embeddings
        if T_aud > max_audio_len:
            raise ValueError(
                f"audio length {T_aud} exceeds max_audio_len {max_audio_len}")
        aud = self.audio_embed(audio_tokens)
        aud = aud + self.audio_pos_embed(
            torch.arange(T_aud, device=device)).unsqueeze(0)

        # Downsample: [B, T, D] → [B, T//stride, D]
        if self.audio_downsample > 1:
            aud = self.audio_down(aud.transpose(1, 2)).transpose(1, 2)
        T_aud = aud.shape[1]

        if onset_curve is None:
            delta = torch.zeros(B, T_aud, device=device, dtype=aud.dtype)
            if T_aud > 1:
                delta[:, 1:] = (aud[:, 1:] - aud[:, :-1]).pow(2).mean(dim=-1).sqrt()
            denom = delta.amax(dim=1, keepdim=True).clamp_min(1e-6)
            onset_curve = (delta / denom).unsqueeze(-1)
        else:
            if onset_curve.dim() == 2:
                onset_curve = onset_curve.unsqueeze(-1)
            # Externally computed curves may arrive as float64/integers.
            onset_curve = onset_curve.to(device=device, dtype=aud.dtype)
        if onset_curve.shape[1] != T_aud:
            onset_curve = F.interpolate(
                onset_curve.transpose(1, 2), size=T_aud,
                mode="linear", align_corners=False).transpose(1, 2)

        aud = self.dropout(aud)
        aud_pos = torch.arange(T_aud, device=device, dtype=torch.float32)
        for blk in self.audio_encoder:
            aud = blk(aud, positions=aud_pos)
        aud = self.enc_norm(aud)

        # Diff vector: used for decoder MoE routing and as a condition.
        diff_vec = self.diff_embed(difficulty.reshape(B))      # [B, d_model]

        # Embed chart + structural token features + conditions
        token_type, token_pos = token_structure_features(chart_tokens)
        emb = (self.chart_embed(chart_tokens)
               + self.chart_type_embed(token_type)
               + self.chart_pos_embed(token_pos))
        # Conditions follow the embedding dtype so a non-float32 model is not
        # silently promoted to float32.
        bpm_emb = self.bpm_proj(bpm.reshape(B, 1).to(emb.dtype)).unsqueeze(1)
        level_emb = self.level_proj(
            level_value.reshape(B, 1).to(emb.dtype)).unsqueeze(1)
        emb = emb + bpm_emb + diff_vec.unsqueeze(1) + level_emb
        if genre is not None:
            emb = emb + self.genre_embed(genre.reshape(B)).unsqueeze(1)

        # ── Time-aligned positions ──
        chart_pos = self.compute_chart_positions(
            chart_tokens, bpm, self.audio_downsample)

        # ── Decode ──
        # The FiLM-modulated memory is identical for every decoder layer, so
        # it is computed once here instead of once per block.
        memory = self.onset_film(aud, onset_curve)
        x = emb
        for blk in self.chart_decoder:
            x = blk(x, enc_out=memory, self_positions=chart_pos,
                    diff_emb=diff_vec)
        x = self.dec_norm(x)

        def aux_outputs() -> dict[str, torch.Tensor]:
            return {
                "presence": self.presence_head(x),
                "type": self.type_head(x),
                "position": self.position_head(x),
                "division": self.division_head(x),
                "sim": self.sim_head(x),
                "duration": self.duration_head(x),
            }

        if return_hidden:
            result = {"hidden": x}
            if return_aux:
                result.update(aux_outputs())
            return result

        token_logits = self.output_head(x)
        if not return_aux:
            return token_logits
        return {"token": token_logits, **aux_outputs()}

    @property
    def device(self) -> torch.device:
        """Device of the first model parameter."""
        return next(self.parameters()).device
self.bpm_proj(bpm.float()).unsqueeze(1) diff_emb = self.diff_embed(difficulty.squeeze(-1)).unsqueeze(1) level_emb = self.level_proj(level_value.float()).unsqueeze(1) genre_emb = torch.zeros(B, 1, self.d_model, device=device) if genre is not None: genre_emb = self.genre_embed(genre.squeeze(-1)).unsqueeze(1) emb = emb + bpm_emb + diff_emb + level_emb + genre_emb # ── Time-aligned positions ── chart_pos = self.compute_chart_positions(chart_tokens, bpm, self.audio_downsample) # ── Decode ── x = emb onset_kv = onset_curve if onset_curve is not None else None for blk in self.chart_decoder: x = blk(x, enc_out=aud, self_positions=chart_pos, diff_emb=diff_vec, onset_film=self.onset_film, onset_kv=onset_kv) x = self.dec_norm(x) if return_hidden: result = {"hidden": x} if return_aux: result.update({ "presence": self.presence_head(x), "type": self.type_head(x), "position": self.position_head(x), "division": self.division_head(x), "sim": self.sim_head(x), "duration": self.duration_head(x), }) return result token_logits = self.output_head(x) if not return_aux: return token_logits return { "token": token_logits, "presence": self.presence_head(x), "type": self.type_head(x), "position": self.position_head(x), "division": self.division_head(x), "sim": self.sim_head(x), "duration": self.duration_head(x), } @property def device(self) -> torch.device: return next(self.parameters()).device @property def device(self) -> torch.device: return next(self.parameters()).device