| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
|
|
| from grammar import ChartGrammarState, GrammarMode |
| from model import DIV_MAP, DUR_TOKEN, is_timeline_token |
| from grammar import NOTE_TOKENS |
| from tokenizer import ( |
| BOS, EOS, PAD, RST, DUR, HLD_TO_ID, ID_TO_DIV, SIM_BEG, SIM_END, SLD_BEG, |
| SLD_END_TOKEN, SIM_COUNT_2, DUR_NUM_TO_ID, ID_TO_DUR_NUM, DUR_DEN_TO_ID, |
| ID_TO_DUR_DEN, CONFIG_TO_ID, ID_TO_CONFIG, |
| ) |
|
|
|
|
# Duration numerators / denominators that have a dedicated token, in the
# tokenizer's table order.
VALID_DUR_NUMS = [*DUR_NUM_TO_ID]
VALID_DUR_DENS = [*DUR_DEN_TO_ID]
# Default upper bound, in beats, for any single hold/slide duration.
MAX_DURATION_BEATS = 4.0
|
|
|
|
def snap_duration(beat: int, subdiv: int, max_beats: float,
                  max_duration_beats: float = MAX_DURATION_BEATS) -> tuple[int, int]:
    """Snap/clamp duration params to common values that fit before track end.

    The requested duration is ``beat / subdiv`` beats (each term floored at 1).
    Among all ``(num, den)`` pairs from VALID_DUR_NUMS x VALID_DUR_DENS whose
    ratio is at most ``min(max_beats, max_duration_beats)`` (with a 1e-6
    tolerance), the one closest to the request is returned; ties prefer the
    smaller denominator, then the smaller numerator.

    Returns ``(1, 64)`` when the limit is non-positive or no pair fits.
    """
    fallback = (1, 64)
    ceiling = min(max_beats, max_duration_beats)
    if ceiling <= 0:
        return fallback

    fitting = [
        (num, den)
        for num in VALID_DUR_NUMS
        for den in VALID_DUR_DENS
        if num / den <= ceiling + 1e-6
    ]
    if not fitting:
        return fallback

    target = max(float(beat), 1.0) / max(float(subdiv), 1.0)

    def closeness(pair: tuple[int, int]) -> tuple[float, int, int]:
        # Primary: distance from the request; then simpler fractions first.
        num, den = pair
        return abs((num / den) - target), den, num

    return min(fitting, key=closeness)
|
|
|
|
| def _config_duration_beats(spec: tuple) -> float: |
| if not spec: |
| return 0.0 |
| kind = spec[0] |
| if kind in ("hld", "sld") and len(spec) >= 4: |
| return max(float(spec[-2]), 1.0) / max(float(spec[-1]), 1.0) |
| if kind == "pair" and len(spec) >= 7 and ("hld" in (spec[1], spec[3])): |
| return max(float(spec[-2]), 1.0) / max(float(spec[-1]), 1.0) |
| if kind == "multi" and len(spec) >= 5 and spec[1] == "hld": |
| return max(float(spec[-2]), 1.0) / max(float(spec[-1]), 1.0) |
| return 0.0 |
|
|
|
|
def _replace_config_duration_token(token_id: int, max_beats: float,
                                   max_duration_beats: float = MAX_DURATION_BEATS) -> int:
    """Return a config token whose encoded duration fits in the available beats.

    * Non-config tokens, and config specs without a duration, are returned
      unchanged.
    * Config tokens whose duration already fits within
      ``min(max_beats, max_duration_beats)`` (1e-6 tolerance) are returned
      unchanged.
    * Otherwise the trailing ``(numerator, denominator)`` of the spec is
      snapped with ``snap_duration`` and the token for the resulting spec is
      returned; if that spec has no token in ``CONFIG_TO_ID``, ``RST`` is
      returned instead (the note is dropped).
    """
    spec = ID_TO_CONFIG.get(token_id)
    if spec is None:
        return token_id
    duration = _config_duration_beats(spec)
    if duration <= 0:
        return token_id

    limit = min(max_beats, max_duration_beats)
    if duration <= limit + 1e-6:
        # Already fits.  Re-snapping would pick the lowest-denominator
        # equivalent (e.g. 2/8 -> 1/4), which may have no config token and
        # would then needlessly turn a valid note into RST.
        return token_id

    n, d = snap_duration(int(spec[-2]), int(spec[-1]), max_beats, max_duration_beats)
    new_spec = (*spec[:-2], n, d)
    return CONFIG_TO_ID.get(new_spec, RST)
|
|
|
|
@dataclass
class DecodeState:
    """Stateful validator for candidate next-token decoding.

    Mirrors a ``ChartGrammarState`` while tracking the timeline position in
    beats, so that candidates which would advance past ``total_beats`` or
    encode a duration longer than the space left can be rejected or masked.
    """

    bpm: float  # stored, but not read by this class
    total_beats: float  # end of the track, in beats
    start_offset_beats: float  # initial value of current_beat
    min_division: int = 4
    max_sim: int = 2
    max_duration_beats: float = MAX_DURATION_BEATS  # cap on any single duration

    def __post_init__(self) -> None:
        self.grammar = ChartGrammarState(min_division=self.min_division, max_sim=self.max_sim)
        self.current_beat = float(self.start_offset_beats)
        # One timeline step advances 4 / div_value beats.
        self.div_value = float(self.min_division)
        # Number of upcoming tokens that must not advance time (set to 2 by DUR_TOKEN).
        self.skip_beat = 0
        self.slide_group_active = False
        # Numerator captured in DUR_NUM mode, consumed in DUR_DEN mode.
        self.pending_dur_num: int | None = None
        # Beat at which the note whose duration is being decoded started.
        self.pending_duration_start_beat: float | None = None
        self.sim_start_beat: float | None = None
        self.slide_start_beat: float | None = None
        # Per-(device, vocab_size) tensor caches.  NOTE(review): the normal mask
        # includes grammar.allowed_div_tokens, so this assumes that set does not
        # change over the lifetime of this state — confirm.
        self._normal_mask_cache: dict[tuple[torch.device, int], torch.Tensor] = {}
        self._config_duration_cache: dict[tuple[torch.device, int], torch.Tensor] = {}

    def remaining_beats(self) -> float:
        """Beats between ``current_beat`` and ``total_beats`` (never negative)."""
        return max(0.0, self.total_beats - self.current_beat)

    def remaining_duration_beats(self) -> float:
        """Beats available to the duration being decoded.

        Measured from ``pending_duration_start_beat`` (or ``current_beat`` when
        no duration is pending), capped at ``max_duration_beats``, never negative.
        """
        start = self.pending_duration_start_beat
        if start is None:
            start = self.current_beat
        return max(0.0, min(self.total_beats - start, self.max_duration_beats))

    def allowed_tokens(self) -> set[int]:
        """Return grammar-allowed token ids that are also safe for the timeline.

        In DUR_NUM / DUR_DEN mode the choices are limited to values that fit in
        ``remaining_duration_beats()``.  If nothing survives (either because no
        value fits or because none of the fitting values is grammar-allowed),
        numerator 1 / denominator 64 is returned so decoding can continue.
        In other modes every grammar-allowed token is filtered through
        ``is_candidate_safe``.
        """
        allowed = set(self.grammar.allowed_tokens())
        remaining = self.remaining_duration_beats()

        if self.grammar.mode is GrammarMode.DUR_NUM:
            max_den = max(VALID_DUR_DENS)
            # A numerator is usable if even its shortest form (num / max_den) fits.
            nums = {DUR_NUM_TO_ID[b] for b in VALID_DUR_NUMS if b / max_den <= remaining + 1e-6}
            return (allowed & nums) or {DUR_NUM_TO_ID[1]}

        if self.grammar.mode is GrammarMode.DUR_DEN:
            beat_num = self.pending_dur_num or 1
            dens = {DUR_DEN_TO_ID[d] for d in VALID_DUR_DENS if beat_num / d <= remaining + 1e-6}
            return (allowed & dens) or {DUR_DEN_TO_ID[64]}

        return {tok for tok in allowed if self.is_candidate_safe(tok)}

    def can_start_sim(self) -> bool:
        """True when a SIM group may be opened now (NORMAL mode and max_sim >= 2)."""
        return self.grammar.mode is GrammarMode.NORMAL and self.max_sim >= 2

    def is_candidate_safe(self, tok: int) -> bool:
        """Timeline/duration safety check for ``tok``; only restrictive in NORMAL mode.

        Applies the same constraints as the NORMAL branch of
        ``apply_logits_mask``: no PAD/BOS/EOS, division tokens only if one step
        of that division still fits, hold/slide openers only if the smallest
        duration fits, config tokens only if their encoded duration fits in
        ``remaining_duration_beats()``, and time-advancing tokens only if one
        step still fits before ``total_beats``.
        """
        if self.grammar.mode is not GrammarMode.NORMAL:
            return True
        if tok in (PAD, BOS, EOS):
            return False
        if tok == SIM_BEG and self.max_sim < 2:
            return False
        if tok in ID_TO_DIV:
            next_div = float(ID_TO_DIV[tok])
            return self.remaining_beats() >= 4.0 / max(next_div, 1.0) - 1e-6
        if tok in HLD_TO_ID.values() or tok == SLD_BEG:
            if self.remaining_beats() < 1.0 / max(VALID_DUR_DENS):
                return False
        spec = ID_TO_CONFIG.get(tok)
        if spec is not None and _config_duration_beats(spec) > self.remaining_duration_beats() + 1e-6:
            # Matches the config-duration filter applied in apply_logits_mask.
            return False
        if self._would_advance_time(tok):
            step = 4.0 / max(self.div_value, 1.0)
            return self.current_beat + step <= self.total_beats + 1e-6
        return True

    def apply_logits_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with disallowed token ids set to ``-inf``.

        ``logits`` is indexed by token id; ``vocab_size`` is taken from
        ``logits.numel()`` (assumes a single 1-D row — TODO confirm at callers).
        """
        masked = torch.full_like(logits, float("-inf"))
        vocab_size = int(logits.numel())

        if self.grammar.mode is GrammarMode.NORMAL:
            mask = self._normal_mask(logits.device, vocab_size).clone()

            # If one more step of the current division would pass the end,
            # forbid every token that places something on the timeline.
            step = 4.0 / max(self.div_value, 1.0)
            if self.current_beat + step > self.total_beats + 1e-6:
                timeline_ids = [RST, SIM_BEG, SLD_BEG, *NOTE_TOKENS, *ID_TO_CONFIG.keys()]
                timeline_ids = [t for t in timeline_ids if 0 <= int(t) < vocab_size]
                if timeline_ids:
                    mask[torch.tensor(timeline_ids, dtype=torch.long, device=logits.device)] = False

            # Holds/slides need room for at least the smallest duration.
            remaining = self.remaining_duration_beats()
            if remaining < 1.0 / max(VALID_DUR_DENS):
                risky_ids = [SLD_BEG, *HLD_TO_ID.values()]
                risky_ids = [t for t in risky_ids if 0 <= int(t) < vocab_size]
                if risky_ids:
                    mask[torch.tensor(risky_ids, dtype=torch.long, device=logits.device)] = False

            # Config tokens with a built-in duration must fit as well
            # (non-config tokens have duration 0 and always pass).
            cfg_dur = self._config_duration_tensor(logits.device, vocab_size)
            mask &= cfg_dur <= remaining + 1e-6

            # Division tokens are only narrowed here, never re-enabled: the base
            # mask already excludes divisions the grammar does not allow.
            remaining_beats = self.remaining_beats()
            for div_tok, div_value in ID_TO_DIV.items():
                if 0 <= div_tok < vocab_size and remaining_beats < 4.0 / max(float(div_value), 1.0) - 1e-6:
                    mask[div_tok] = False

            masked[mask] = logits[mask]
            return masked

        allowed = self.allowed_tokens()
        if allowed:
            idx = torch.tensor([t for t in sorted(allowed) if 0 <= t < vocab_size],
                               dtype=torch.long, device=logits.device)
            if idx.numel() > 0:
                masked[idx] = logits[idx]
        return masked

    def _normal_mask(self, device: torch.device, vocab_size: int) -> torch.Tensor:
        """Cached boolean mask of token ids usable in NORMAL mode (callers must clone before mutating)."""
        key = (device, vocab_size)
        cached = self._normal_mask_cache.get(key)
        if cached is not None:
            return cached

        mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
        ids = {RST, SLD_BEG} | set(NOTE_TOKENS) | set(ID_TO_CONFIG.keys()) | set(self.grammar.allowed_div_tokens)
        if self.max_sim >= 2:
            ids.add(SIM_BEG)
        ids.discard(PAD)
        ids.discard(BOS)
        ids.discard(EOS)
        valid = [int(t) for t in ids if 0 <= int(t) < vocab_size]
        if valid:
            mask[torch.tensor(valid, dtype=torch.long, device=device)] = True
        self._normal_mask_cache[key] = mask
        return mask

    def _config_duration_tensor(self, device: torch.device, vocab_size: int) -> torch.Tensor:
        """Cached float32 tensor mapping token id -> encoded config duration (0 for none)."""
        key = (device, vocab_size)
        cached = self._config_duration_cache.get(key)
        if cached is not None:
            return cached

        dur = torch.zeros(vocab_size, dtype=torch.float32, device=device)
        for token_id, spec in ID_TO_CONFIG.items():
            if 0 <= token_id < vocab_size:
                d = _config_duration_beats(spec)
                if d > 0:
                    dur[token_id] = float(d)
        self._config_duration_cache[key] = dur
        return dur

    def step(self, tok: int) -> None:
        """Advance the decoder state by one token.

        Records the start beat of holds, slides and SIM groups so the duration
        that follows is measured from where the note began, tracks the DUR
        numerator/denominator, advances ``current_beat`` for timeline tokens,
        and finally advances the grammar.
        """
        mode_before = self.grammar.mode

        if mode_before is GrammarMode.NORMAL:
            if tok in HLD_TO_ID.values():
                self.pending_duration_start_beat = self.current_beat
            elif tok == SLD_BEG:
                self.slide_start_beat = self.current_beat
            elif tok == SIM_BEG:
                self.sim_start_beat = self.current_beat

        # A finished slide's duration is measured from the slide's start beat.
        if mode_before is GrammarMode.SLIDE_BODY and tok == SLD_END_TOKEN:
            self.pending_duration_start_beat = self.slide_start_beat
            self.slide_start_beat = None

        # For a SIM group with a hold, the duration is measured from the group start.
        if mode_before is GrammarMode.SIM_BODY and tok == SIM_END and self.grammar.sim_has_hold:
            self.pending_duration_start_beat = self.sim_start_beat
            self.sim_start_beat = None

        if mode_before is GrammarMode.DUR_NUM:
            self.pending_dur_num = ID_TO_DUR_NUM.get(int(tok), int(tok))
        elif mode_before is GrammarMode.DUR_DEN and self.pending_dur_num is not None:
            # Denominator consumed: the duration is complete.
            self.pending_dur_num = None
            self.pending_duration_start_beat = None

        # Tokens inside SIM/slide bodies and the two tokens after DUR_TOKEN do
        # not move the timeline; division tokens only change the step size.
        in_sim_body = self.grammar.mode in (GrammarMode.SIM_COUNT, GrammarMode.SIM_BODY)
        in_slide_body = self.slide_group_active and tok != SLD_BEG
        if self.skip_beat > 0:
            self.skip_beat -= 1
        elif tok == DUR_TOKEN:
            self.skip_beat = 2
        elif tok in ID_TO_DIV:
            self.div_value = float(ID_TO_DIV[tok])
        elif is_timeline_token(torch.tensor(tok)) and not in_sim_body and not in_slide_body:
            self.current_beat += 4.0 / max(self.div_value, 1.0)

        if tok == SLD_BEG:
            self.slide_group_active = True
        elif tok == SLD_END_TOKEN:
            self.slide_group_active = False

        self.grammar.step(tok)

    def _would_advance_time(self, tok: int) -> bool:
        """Whether stepping ``tok`` now would advance ``current_beat``.

        Mirrors the time-advance condition in ``step``; it does not special-case
        ``ID_TO_DIV`` tokens, which ``is_candidate_safe`` handles before calling this.
        """
        if self.skip_beat > 0 or tok == DUR_TOKEN:
            return False
        in_sim_body = self.grammar.mode in (GrammarMode.SIM_COUNT, GrammarMode.SIM_BODY)
        in_slide_body = self.slide_group_active and tok != SLD_BEG
        return bool(is_timeline_token(torch.tensor(tok)) and not in_sim_body and not in_slide_body)
|
|
|
|
def clamp_duration_tokens(tokens: list[int], total_beats: float,
                          start_offset_beats: float, min_division: int = 4) -> list[int]:
    """Clamp generated hold/slide durations so notes cannot extend past track end.

    The tokens are replayed through a ``DecodeState``:

    * each config token is passed through ``_replace_config_duration_token``
      with the state's ``remaining_duration_beats()``;
    * each ``DUR num den`` triplet is re-snapped with ``snap_duration`` against
      the same limit and re-encoded to token ids.

    Returns a new list; ``tokens`` itself is not modified.
    """
    clamped = list(tokens)
    decoder = DecodeState(bpm=150.0, total_beats=total_beats,
                          start_offset_beats=start_offset_beats,
                          min_division=min_division)
    n_tokens = len(clamped)
    pos = 0
    while pos < n_tokens:
        token = clamped[pos]

        if token in ID_TO_CONFIG:
            clamped[pos] = _replace_config_duration_token(
                token,
                decoder.remaining_duration_beats(),
                decoder.max_duration_beats,
            )
            decoder.step(clamped[pos])
            pos += 1
        elif token == DUR and pos + 2 < n_tokens:
            raw_num = clamped[pos + 1]
            raw_den = clamped[pos + 2]
            num, den = snap_duration(
                ID_TO_DUR_NUM.get(raw_num, raw_num),
                ID_TO_DUR_DEN.get(raw_den, raw_den),
                decoder.remaining_duration_beats(),
                decoder.max_duration_beats,
            )
            clamped[pos + 1] = DUR_NUM_TO_ID[num]
            clamped[pos + 2] = DUR_DEN_TO_ID[den]
            for part in (token, clamped[pos + 1], clamped[pos + 2]):
                decoder.step(part)
            pos += 3
        else:
            decoder.step(token)
            pos += 1

    return clamped
|
|
|
|
def duration_report(tokens: list[int], total_beats: float, start_offset_beats: float,
                    min_division: int = 4) -> dict[str, float | int]:
    """Summarise the durations encoded in a token sequence.

    Both duration-carrying config tokens and ``DUR num den`` triplets are
    counted.  The returned dict holds:

    * ``"durations"``: how many durations were seen,
    * ``"max_beats"``: the longest one, in beats,
    * ``"overrun"``: how many end after ``total_beats`` (1e-6 tolerance).

    Config durations are measured from the current beat; DUR durations from
    the pending note start when one is recorded, else the current beat.
    """
    tracker = DecodeState(bpm=150.0, total_beats=total_beats,
                          start_offset_beats=start_offset_beats,
                          min_division=min_division)
    end_limit = total_beats + 1e-6
    seen = 0
    longest = 0.0
    overruns = 0

    for pos, tok in enumerate(tokens):
        if tok in ID_TO_CONFIG:
            cfg_dur = _config_duration_beats(ID_TO_CONFIG[tok])
            if cfg_dur > 0:
                seen += 1
                longest = max(longest, cfg_dur)
                if tracker.current_beat + cfg_dur > end_limit:
                    overruns += 1

        if tok == DUR and pos + 2 < len(tokens):
            raw_num, raw_den = tokens[pos + 1], tokens[pos + 2]
            num = ID_TO_DUR_NUM.get(raw_num, raw_num)
            den = ID_TO_DUR_DEN.get(raw_den, raw_den)
            dur_beats = max(float(num), 1.0) / max(float(den), 1.0)
            seen += 1
            longest = max(longest, dur_beats)
            pending = tracker.pending_duration_start_beat
            origin = tracker.current_beat if pending is None else pending
            if origin + dur_beats > end_limit:
                overruns += 1

        tracker.step(tok)

    return {"durations": seen, "max_beats": longest, "overrun": overruns}
|
|
|
|
def simultaneous_report(tokens: list[int], max_sim: int = 2) -> dict[str, int]:
    """Count SIM groups and flag those wider than ``max_sim``.

    The token right after each ``SIM_BEG`` is read as the group width:
    ``SIM_COUNT_2`` means 2, any other token is taken as its integer value.
    Returns ``groups`` (number of groups), ``max_count`` (widest width seen),
    ``over_limit`` (groups wider than ``max_sim``) and ``first_bad`` (index of
    the first over-wide ``SIM_BEG``, or -1).
    """
    groups = 0
    widest = 0
    too_wide = 0
    first_bad = -1
    end = len(tokens)
    pos = 0
    while pos < end:
        if tokens[pos] != SIM_BEG or pos + 1 >= end:
            pos += 1
            continue

        count_tok = tokens[pos + 1]
        width = 2 if count_tok == SIM_COUNT_2 else int(count_tok)
        groups += 1
        widest = max(widest, width)
        if width > max_sim:
            too_wide += 1
            if first_bad < 0:
                first_bad = pos
        # Skip past SIM_BEG and its count token.
        pos += 2

    return {"groups": groups, "max_count": widest, "over_limit": too_wide, "first_bad": first_bad}
|
|
|
|
def sanitize_sim_tokens(tokens: list[int], max_sim: int = 2) -> list[int]:
    """Remove malformed or over-wide SIM groups before decode/export.

    Each ``SIM_BEG <count> ... SIM_END`` group is rebuilt from the note tokens
    in its body; non-note tokens inside the body are dropped, and the declared
    count is not trusted (only the actual notes matter):

    * two or more notes and ``max_sim >= 2``: the first two notes are
      re-emitted as a ``SIM_COUNT_2`` group, any further notes follow as
      standalone tokens;
    * otherwise only the first note (if any) is kept, outside a group.

    A body ends at ``SIM_END`` (consumed), ``EOS`` (kept in the output) or the
    end of the sequence.  A ``SIM_BEG`` with nothing after it, or followed
    directly by ``SIM_END`` or ``EOS``, is dropped without consuming the
    tokens that follow.
    """
    result: list[int] = []
    n = len(tokens)
    i = 0
    while i < n:
        tok = tokens[i]
        if tok != SIM_BEG:
            result.append(tok)
            i += 1
            continue

        if i + 1 >= n:
            # Trailing SIM_BEG with no count token: drop it.
            i += 1
            continue

        count_tok = tokens[i + 1]
        if count_tok == EOS:
            # No count and no body: drop SIM_BEG; EOS is emitted on the next pass.
            i += 1
            continue
        if count_tok == SIM_END:
            # Empty group: drop both markers.  Treating SIM_END as the count
            # would swallow the following notes up to the next SIM_END.
            i += 2
            continue

        # tokens[i + 1] is the count token; the body starts after it.
        j = i + 2
        body: list[int] = []
        while j < n and tokens[j] not in (SIM_END, EOS):
            if tokens[j] in NOTE_TOKENS:
                body.append(tokens[j])
            j += 1

        if len(body) >= 2 and max_sim >= 2:
            result.extend([SIM_BEG, SIM_COUNT_2, body[0], body[1], SIM_END])
            result.extend(body[2:])
        else:
            result.extend(body[:1])

        # Consume a closing SIM_END; leave EOS to be emitted as a normal token.
        i = j + 1 if j < n and tokens[j] == SIM_END else j
    return result
|
|