from __future__ import annotations

from dataclasses import dataclass

import torch

from grammar import ChartGrammarState, GrammarMode
from model import DIV_MAP, DUR_TOKEN, is_timeline_token
from grammar import NOTE_TOKENS
from tokenizer import (
    BOS,
    EOS,
    PAD,
    RST,
    DUR,
    HLD_TO_ID,
    ID_TO_DIV,
    SIM_BEG,
    SIM_END,
    SLD_BEG,
    SLD_END_TOKEN,
    SIM_COUNT_2,
    DUR_NUM_TO_ID,
    ID_TO_DUR_NUM,
    DUR_DEN_TO_ID,
    ID_TO_DUR_DEN,
    CONFIG_TO_ID,
    ID_TO_CONFIG,
)

# Duration numerators / denominators that have a dedicated token id.
VALID_DUR_NUMS = list(DUR_NUM_TO_ID.keys())
VALID_DUR_DENS = list(DUR_DEN_TO_ID.keys())
# Default upper bound (in beats) applied to any single duration.
MAX_DURATION_BEATS = 4.0

# Tolerance used for every floating-point beat comparison in this module.
_EPS = 1e-6


def _duration_ratio(num: float, den: float) -> float:
    """Return ``num / den`` with both operands floored at 1.0."""
    return max(float(num), 1.0) / max(float(den), 1.0)


def _step_beats(div_value: float) -> float:
    """Return the beat advance of one timeline step at division ``div_value``."""
    return 4.0 / max(float(div_value), 1.0)


def snap_duration(
    beat: int,
    subdiv: int,
    max_beats: float,
    max_duration_beats: float = MAX_DURATION_BEATS,
) -> tuple[int, int]:
    """Snap/clamp duration params to common values that fit before track end.

    Args:
        beat: Requested duration numerator.
        subdiv: Requested duration denominator.
        max_beats: Beats still available for the duration.
        max_duration_beats: Hard cap on any duration, in beats.

    Returns:
        The ``(numerator, denominator)`` pair drawn from ``VALID_DUR_NUMS`` x
        ``VALID_DUR_DENS`` whose ratio does not exceed
        ``min(max_beats, max_duration_beats)`` and is closest to
        ``beat / subdiv``. Ties prefer the smaller denominator, then the
        smaller numerator. Falls back to ``(1, 64)`` when nothing fits.
    """
    limit = min(max_beats, max_duration_beats)
    if limit <= 0:
        return 1, 64
    candidates = [
        (num, den)
        for num in VALID_DUR_NUMS
        for den in VALID_DUR_DENS
        if num / den <= limit + _EPS
    ]
    if not candidates:
        return 1, 64
    raw = _duration_ratio(beat, subdiv)
    return min(
        candidates,
        key=lambda nd: (abs(nd[0] / nd[1] - raw), nd[1], nd[0]),
    )


def _config_duration_beats(spec: tuple) -> float:
    """Return the duration in beats encoded in a config spec, or 0.0 if none.

    A spec is treated as carrying a duration (its last two items being the
    numerator and denominator) when it is:

    * ``("hld" | "sld", ...)`` with at least 4 items,
    * ``("pair", ...)`` with at least 7 items and ``"hld"`` at index 1 or 3,
    * ``("multi", "hld", ...)`` with at least 5 items.
    """
    if not spec:
        return 0.0
    kind = spec[0]
    size = len(spec)
    has_duration = (
        (kind in ("hld", "sld") and size >= 4)
        or (kind == "pair" and size >= 7 and "hld" in (spec[1], spec[3]))
        or (kind == "multi" and size >= 5 and spec[1] == "hld")
    )
    if not has_duration:
        return 0.0
    return _duration_ratio(spec[-2], spec[-1])


def _replace_config_duration_token(
    token_id: int,
    max_beats: float,
    max_duration_beats: float = MAX_DURATION_BEATS,
) -> int:
    """Return a config token whose duration has been snapped to fit.

    Tokens that are not config tokens, or whose spec carries no duration, are
    returned unchanged. If the snapped spec has no id in ``CONFIG_TO_ID`` the
    token is replaced by ``RST``.
    """
    spec = ID_TO_CONFIG.get(token_id)
    if spec is None or _config_duration_beats(spec) <= 0:
        return token_id
    num, den = snap_duration(
        int(spec[-2]), int(spec[-1]), max_beats, max_duration_beats
    )
    new_spec = tuple(list(spec[:-2]) + [num, den])
    return CONFIG_TO_ID.get(new_spec, RST)


@dataclass
class DecodeState:
    """Stateful validator for candidate next-token decoding.

    Wraps a ``ChartGrammarState`` and additionally tracks the current beat
    position so that tokens which would run past ``total_beats`` can be
    rejected or masked out of the logits.
    """

    bpm: float  # NOTE(review): not read anywhere in this file.
    total_beats: float
    start_offset_beats: float
    min_division: int = 4
    max_sim: int = 2
    max_duration_beats: float = MAX_DURATION_BEATS

    def __post_init__(self) -> None:
        """Initialise the grammar and the beat-tracking bookkeeping."""
        self.grammar = ChartGrammarState(
            min_division=self.min_division, max_sim=self.max_sim
        )
        self.current_beat = float(self.start_offset_beats)
        self.div_value = float(self.min_division)
        # Number of upcoming tokens that must not move the timeline (the two
        # duration operands that follow DUR_TOKEN).
        self.skip_beat = 0
        self.slide_group_active = False
        self.pending_dur_num: int | None = None
        # Beat at which the note awaiting a duration started.
        self.pending_duration_start_beat: float | None = None
        self.sim_start_beat: float | None = None
        self.slide_start_beat: float | None = None
        # Per-(device, vocab_size) caches for the tensors built below.
        self._normal_mask_cache: dict[tuple[torch.device, int], torch.Tensor] = {}
        self._config_duration_cache: dict[tuple[torch.device, int], torch.Tensor] = {}

    def remaining_beats(self) -> float:
        """Return the beats left between ``current_beat`` and the track end."""
        return max(0.0, self.total_beats - self.current_beat)

    def remaining_duration_beats(self) -> float:
        """Return the longest duration that still fits, in beats.

        Measured from the pending duration's start beat when one is recorded,
        otherwise from ``current_beat``; capped at ``max_duration_beats``.
        """
        start = self.pending_duration_start_beat
        if start is None:
            start = self.current_beat
        return max(0.0, min(self.total_beats - start, self.max_duration_beats))

    def allowed_tokens(self) -> set[int]:
        """Return the grammar-allowed tokens that also respect the track end.

        In ``DUR_NUM`` / ``DUR_DEN`` mode the grammar's set is intersected with
        the numerators / denominators that can still fit; when none can fit,
        the id for numerator 1 / denominator 64 is returned as a fallback.
        In every other mode the set is filtered through ``is_candidate_safe``.
        """
        allowed = set(self.grammar.allowed_tokens())
        remaining = self.remaining_duration_beats()
        mode = self.grammar.mode
        if mode is GrammarMode.DUR_NUM:
            finest_den = max(VALID_DUR_DENS)
            nums = {
                DUR_NUM_TO_ID[num]
                for num in VALID_DUR_NUMS
                if num / finest_den <= remaining + _EPS
            }
            return allowed & nums if nums else {DUR_NUM_TO_ID[1]}
        if mode is GrammarMode.DUR_DEN:
            beat_num = self.pending_dur_num or 1
            dens = {
                DUR_DEN_TO_ID[den]
                for den in VALID_DUR_DENS
                if beat_num / den <= remaining + _EPS
            }
            return allowed & dens if dens else {DUR_DEN_TO_ID[64]}
        return {tok for tok in allowed if self.is_candidate_safe(tok)}

    def can_start_sim(self) -> bool:
        """Return True when the grammar is in NORMAL mode and ``max_sim >= 2``."""
        return self.grammar.mode is GrammarMode.NORMAL and self.max_sim >= 2

    def is_candidate_safe(self, tok: int) -> bool:
        """Return whether ``tok`` may be emitted without passing the track end.

        Outside NORMAL mode every token is considered safe (the grammar alone
        decides). In NORMAL mode PAD/BOS/EOS are always rejected.
        """
        if self.grammar.mode is not GrammarMode.NORMAL:
            return True
        if tok in (PAD, BOS, EOS):
            return False
        if tok == SIM_BEG and self.max_sim < 2:
            return False
        if tok in ID_TO_DIV:
            # A division change is only useful if one step at it still fits.
            return self.remaining_beats() >= _step_beats(ID_TO_DIV[tok]) - _EPS
        if tok in HLD_TO_ID.values() or tok == SLD_BEG:
            # Holds/slides need room for at least the shortest duration.
            if self.remaining_beats() < 1.0 / max(VALID_DUR_DENS):
                return False
        if self._would_advance_time(tok):
            step = _step_beats(self.div_value)
            return self.current_beat + step <= self.total_beats + _EPS
        return True

    def apply_logits_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with disallowed entries set to ``-inf``.

        Args:
            logits: Next-token scores. The vocabulary size is taken from
                ``logits.numel()`` and the tensor is indexed along its first
                axis — assumes a 1-D tensor, TODO confirm against callers.

        Returns:
            A tensor shaped like ``logits``. NORMAL mode uses a cached boolean
            mask; other modes fall back to ``allowed_tokens``.
        """
        masked = torch.full_like(logits, float("-inf"))
        vocab_size = int(logits.numel())
        device = logits.device

        if self.grammar.mode is GrammarMode.NORMAL:
            # Clone: the cached base mask must not be mutated.
            mask = self._normal_mask(device, vocab_size).clone()

            step = _step_beats(self.div_value)
            if self.current_beat + step > self.total_beats + _EPS:
                # No room for another timeline step: drop every such token.
                timeline_ids = [
                    RST, SIM_BEG, SLD_BEG, *NOTE_TOKENS, *ID_TO_CONFIG.keys()
                ]
                timeline_ids = [t for t in timeline_ids if 0 <= int(t) < vocab_size]
                if timeline_ids:
                    mask[
                        torch.tensor(timeline_ids, dtype=torch.long, device=device)
                    ] = False

            remaining = self.remaining_duration_beats()
            if remaining < 1.0 / max(VALID_DUR_DENS):
                # Not even the shortest duration fits: no holds or slides.
                risky_ids = [SLD_BEG, *HLD_TO_ID.values()]
                risky_ids = [t for t in risky_ids if 0 <= int(t) < vocab_size]
                if risky_ids:
                    mask[
                        torch.tensor(risky_ids, dtype=torch.long, device=device)
                    ] = False

            # Drop config tokens whose built-in duration is too long.
            cfg_dur = self._config_duration_tensor(device, vocab_size)
            mask &= cfg_dur <= remaining + _EPS

            # Disable division tokens whose step no longer fits. This must
            # only ever clear entries: assigning the comparison result would
            # re-enable division tokens the base mask excluded because they
            # are not in ``grammar.allowed_div_tokens``.
            beats_left = self.remaining_beats()
            for div_tok, div_value in ID_TO_DIV.items():
                if not 0 <= div_tok < vocab_size:
                    continue
                if beats_left < _step_beats(div_value) - _EPS:
                    mask[div_tok] = False

            masked[mask] = logits[mask]
            return masked

        allowed = self.allowed_tokens()
        if allowed:
            idx = torch.tensor(
                [t for t in sorted(allowed) if 0 <= t < vocab_size],
                dtype=torch.long,
                device=device,
            )
            if idx.numel() > 0:
                masked[idx] = logits[idx]
        return masked

    def _normal_mask(self, device: torch.device, vocab_size: int) -> torch.Tensor:
        """Return the cached base mask of tokens selectable in NORMAL mode.

        The mask enables RST, SLD_BEG, the note tokens, the config tokens, the
        grammar's allowed division tokens and (when ``max_sim >= 2``) SIM_BEG,
        minus PAD/BOS/EOS. The returned tensor is the cached object itself.
        """
        key = (device, vocab_size)
        cached = self._normal_mask_cache.get(key)
        if cached is not None:
            return cached
        mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
        ids = (
            {RST, SLD_BEG}
            | set(NOTE_TOKENS)
            | set(ID_TO_CONFIG.keys())
            | set(self.grammar.allowed_div_tokens)
        )
        if self.max_sim >= 2:
            ids.add(SIM_BEG)
        ids -= {PAD, BOS, EOS}
        valid = [int(t) for t in ids if 0 <= int(t) < vocab_size]
        if valid:
            mask[torch.tensor(valid, dtype=torch.long, device=device)] = True
        self._normal_mask_cache[key] = mask
        return mask

    def _config_duration_tensor(
        self, device: torch.device, vocab_size: int
    ) -> torch.Tensor:
        """Return a cached float tensor of per-token config durations.

        Entry ``i`` is the duration in beats of config token ``i``; it is 0 for
        tokens that are not config tokens or carry no duration.
        """
        key = (device, vocab_size)
        cached = self._config_duration_cache.get(key)
        if cached is not None:
            return cached
        dur = torch.zeros(vocab_size, dtype=torch.float32, device=device)
        for token_id, spec in ID_TO_CONFIG.items():
            if 0 <= token_id < vocab_size:
                beats = _config_duration_beats(spec)
                if beats > 0:
                    dur[token_id] = float(beats)
        self._config_duration_cache[key] = dur
        return dur

    def step(self, tok: int) -> None:
        """Consume ``tok``: update beat bookkeeping, then advance the grammar."""
        mode_before = self.grammar.mode

        # Remember where a hold / slide / sim group starts so that its
        # duration can later be measured from that beat.
        if mode_before is GrammarMode.NORMAL:
            if tok in HLD_TO_ID.values():
                self.pending_duration_start_beat = self.current_beat
            elif tok == SLD_BEG:
                self.slide_start_beat = self.current_beat
            elif tok == SIM_BEG:
                self.sim_start_beat = self.current_beat
        if mode_before is GrammarMode.SLIDE_BODY and tok == SLD_END_TOKEN:
            self.pending_duration_start_beat = self.slide_start_beat
            self.slide_start_beat = None
        if (
            mode_before is GrammarMode.SIM_BODY
            and tok == SIM_END
            and self.grammar.sim_has_hold
        ):
            self.pending_duration_start_beat = self.sim_start_beat
            self.sim_start_beat = None

        if mode_before is GrammarMode.DUR_NUM:
            self.pending_dur_num = ID_TO_DUR_NUM.get(int(tok), int(tok))
        elif mode_before is GrammarMode.DUR_DEN and self.pending_dur_num is not None:
            # The duration is complete; clear the pending bookkeeping.
            self.pending_dur_num = None
            self.pending_duration_start_beat = None

        # Both flags are evaluated before the grammar consumes ``tok``.
        in_sim_body = self.grammar.mode in (
            GrammarMode.SIM_COUNT,
            GrammarMode.SIM_BODY,
        )
        in_slide_body = self.slide_group_active and tok != SLD_BEG

        if self.skip_beat > 0:
            self.skip_beat -= 1
        elif tok == DUR_TOKEN:
            # The next two tokens are the duration operands.
            self.skip_beat = 2
        elif tok in ID_TO_DIV:
            self.div_value = float(ID_TO_DIV[tok])
        elif (
            is_timeline_token(torch.tensor(tok))
            and not in_sim_body
            and not in_slide_body
        ):
            self.current_beat += _step_beats(self.div_value)

        if tok == SLD_BEG:
            self.slide_group_active = True
        elif tok == SLD_END_TOKEN:
            self.slide_group_active = False

        self.grammar.step(tok)

    def _would_advance_time(self, tok: int) -> bool:
        """Return whether ``step(tok)`` would move ``current_beat`` forward.

        NOTE(review): unlike ``step``, division tokens are not excluded here;
        the only caller in this file handles them before calling this method.
        """
        if self.skip_beat > 0 or tok == DUR_TOKEN:
            return False
        in_sim_body = self.grammar.mode in (
            GrammarMode.SIM_COUNT,
            GrammarMode.SIM_BODY,
        )
        in_slide_body = self.slide_group_active and tok != SLD_BEG
        return bool(
            is_timeline_token(torch.tensor(tok))
            and not in_sim_body
            and not in_slide_body
        )


def clamp_duration_tokens(
    tokens: list[int],
    total_beats: float,
    start_offset_beats: float,
    min_division: int = 4,
) -> list[int]:
    """Clamp generated hold/slide durations so notes cannot extend past track end.

    Replays ``tokens`` through a fresh ``DecodeState``. Config tokens have
    their built-in duration snapped; each ``DUR`` token followed by at least
    two more tokens has those two operands snapped and re-encoded.

    Returns:
        A new list; ``tokens`` itself is not modified.
    """
    result = list(tokens)
    state = DecodeState(
        bpm=150.0,
        total_beats=total_beats,
        start_offset_beats=start_offset_beats,
        min_division=min_division,
    )
    i = 0
    while i < len(result):
        tok = result[i]

        if tok in ID_TO_CONFIG:
            result[i] = _replace_config_duration_token(
                tok,
                state.remaining_duration_beats(),
                state.max_duration_beats,
            )
            state.step(result[i])
            i += 1
            continue

        # NOTE(review): matched against tokenizer.DUR while DecodeState.step
        # checks model.DUR_TOKEN — presumably the same id; confirm.
        if tok == DUR and i + 2 < len(result):
            num, den = snap_duration(
                ID_TO_DUR_NUM.get(result[i + 1], result[i + 1]),
                ID_TO_DUR_DEN.get(result[i + 2], result[i + 2]),
                state.remaining_duration_beats(),
                state.max_duration_beats,
            )
            result[i + 1] = DUR_NUM_TO_ID[num]
            result[i + 2] = DUR_DEN_TO_ID[den]
            state.step(tok)
            state.step(result[i + 1])
            state.step(result[i + 2])
            i += 3
            continue

        state.step(tok)
        i += 1
    return result


def duration_report(
    tokens: list[int],
    total_beats: float,
    start_offset_beats: float,
    min_division: int = 4,
) -> dict[str, float | int]:
    """Summarise the durations found in ``tokens``.

    Returns:
        A dict with ``"durations"`` (how many durations were seen),
        ``"max_beats"`` (the longest one) and ``"overrun"`` (how many end
        after ``total_beats``).
    """
    state = DecodeState(
        bpm=150.0,
        total_beats=total_beats,
        start_offset_beats=start_offset_beats,
        min_division=min_division,
    )
    n_dur = 0
    max_dur = 0.0
    overrun = 0
    for i, tok in enumerate(tokens):
        if tok in ID_TO_CONFIG:
            dur = _config_duration_beats(ID_TO_CONFIG[tok])
            if dur > 0:
                n_dur += 1
                max_dur = max(max_dur, dur)
                if state.current_beat + dur > total_beats + _EPS:
                    overrun += 1
        if tok == DUR and i + 2 < len(tokens):
            num = ID_TO_DUR_NUM.get(tokens[i + 1], tokens[i + 1])
            den = ID_TO_DUR_DEN.get(tokens[i + 2], tokens[i + 2])
            dur = _duration_ratio(num, den)
            n_dur += 1
            max_dur = max(max_dur, dur)
            # Measure from the start of the note the duration belongs to.
            start = state.pending_duration_start_beat
            if start is None:
                start = state.current_beat
            if start + dur > total_beats + _EPS:
                overrun += 1
        state.step(tok)
    return {"durations": n_dur, "max_beats": max_dur, "overrun": overrun}


def simultaneous_report(tokens: list[int], max_sim: int = 2) -> dict[str, int]:
    """Summarise the SIM groups declared in ``tokens``.

    The token right after each ``SIM_BEG`` is read as the declared count
    (``SIM_COUNT_2`` means 2, anything else is taken as ``int(token)``).

    Returns:
        A dict with ``"groups"``, ``"max_count"``, ``"over_limit"`` (groups
        declaring more than ``max_sim``) and ``"first_bad"`` (index of the
        first such ``SIM_BEG``, or -1).
    """
    n_sim = 0
    max_count = 0
    over_limit = 0
    first_bad = -1
    i = 0
    while i < len(tokens):
        if tokens[i] == SIM_BEG and i + 1 < len(tokens):
            declared = tokens[i + 1]
            count = 2 if declared == SIM_COUNT_2 else int(declared)
            n_sim += 1
            max_count = max(max_count, count)
            if count > max_sim:
                over_limit += 1
                if first_bad < 0:
                    first_bad = i
            i += 2
            continue
        i += 1
    return {
        "groups": n_sim,
        "max_count": max_count,
        "over_limit": over_limit,
        "first_bad": first_bad,
    }


def sanitize_sim_tokens(tokens: list[int], max_sim: int = 2) -> list[int]:
    """Remove malformed or over-wide SIM groups before decode/export.

    For every ``SIM_BEG``: the following token (the declared count) is
    skipped, then note tokens are collected up to the next ``SIM_END`` or
    ``EOS`` (or the end of input); other tokens in that span are dropped.

    * With at least two notes and ``max_sim >= 2`` the first two are emitted
      as a well-formed 2-note group and any further notes follow it as plain
      notes.
    * Otherwise only the first collected note (if any) is kept.

    A ``SIM_BEG`` that is the last token is dropped. A terminating ``SIM_END``
    is consumed; a terminating ``EOS`` is left for normal processing.
    """
    result: list[int] = []
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok != SIM_BEG:
            result.append(tok)
            i += 1
            continue
        if i + 1 >= len(tokens):
            # Dangling SIM_BEG with nothing after it.
            i += 1
            continue

        # tokens[i + 1] is the declared count; the rebuilt group always
        # declares SIM_COUNT_2, so its value does not influence the output.
        j = i + 2
        body: list[int] = []
        while j < len(tokens) and tokens[j] not in (SIM_END, EOS):
            if tokens[j] in NOTE_TOKENS:
                body.append(tokens[j])
            j += 1

        if len(body) >= 2 and max_sim >= 2:
            result.extend([SIM_BEG, SIM_COUNT_2, body[0], body[1], SIM_END])
            result.extend(body[2:])
        else:
            result.extend(body[:1])

        i = j + 1 if j < len(tokens) and tokens[j] == SIM_END else j
    return result