"""Hard token grammar used to constrain chart generation to parseable sequences."""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto

import torch

from tokenizer import (
    BOS,
    EOS,
    HLD_ON,
    MASK,
    META_BPM,
    META_END,
    PAD,
    RST,
    SEP,
    SLD_MID,
    SLD_ON,
    BRK_TO_ID,
    DUR,
    HLD_TO_ID,
    ID_TO_DIV,
    SIM_BEG,
    SIM_END,
    SLD_BEG,
    SLD_END_TOKEN,
    SLD_TO_ID,
    TAP_TO_ID,
    TCH_TO_ID,
    SIM_COUNT_2,
    DUR_NUM_TO_ID,
    DUR_DEN_TO_ID,
    CONFIG_BASE,
)


class GrammarMode(Enum):
    """Parser modes of the chart grammar state machine."""

    NORMAL = auto()
    SIM_COUNT = auto()
    SIM_BODY = auto()
    NEED_DUR = auto()
    DUR_NUM = auto()
    DUR_DEN = auto()
    SLIDE_BODY = auto()


NOTE_TOKENS = (
    set(TAP_TO_ID.values())
    | set(BRK_TO_ID.values())
    | set(HLD_TO_ID.values())
    | set(SLD_TO_ID.values())
    | set(TCH_TO_ID.values())
)
HOLD_TOKENS = set(HLD_TO_ID.values())
SLIDE_POINT_TOKENS = set(SLD_TO_ID.values())
META_TOKENS = set(range(META_BPM, META_END + 1))
ENGINE_ONLY_TOKENS = {PAD, BOS, EOS, SEP, MASK, HLD_ON, SLD_ON, SLD_MID} | META_TOKENS

# Duration parameters are raw integer tokens, not normal vocabulary tokens.
# Keep them in the same compact domain produced by real charts; allowing the
# whole uint8 range lets generation sample special/control-looking ids and
# absurdly long holds/slides.
VALID_DUR_NUMS = set(DUR_NUM_TO_ID.values())
VALID_DUR_DENS = set(DUR_DEN_TO_ID.values())


def config_tokens() -> set[int]:
    """Return the ids currently registered in ``tokenizer.ID_TO_CONFIG``."""
    # Config vocab can be extended after importing this module when a checkpoint
    # loads its learned config_vocab. Keep this dynamic instead of freezing at import.
    import tokenizer as chart_tokenizer

    return set(chart_tokenizer.ID_TO_CONFIG.keys())


@dataclass
class ChartGrammarState:
    """Hard token grammar for chart generation.

    This keeps the sequence parseable. It intentionally does not enforce chart
    quality or muri rules; those belong in a later event/window validator.

    Typical use: call :meth:`apply_logits_mask` before sampling, then feed the
    sampled token to :meth:`step` to advance the state machine.
    """

    # Division tokens whose ID_TO_DIV value is below this are never offered.
    min_division: int = 1
    # Simultaneous groups are only offered when this is at least 2.
    max_sim: int = 2
    # Current mode of the state machine.
    mode: GrammarMode = GrammarMode.NORMAL
    # Number of notes expected / seen so far in the current SIM group.
    sim_target: int = 0
    sim_seen: int = 0
    # True once the current SIM group contained a hold token; a duration is
    # then required after SIM_END.
    sim_has_hold: bool = False
    # Number of slide point tokens seen in the current slide body.
    slide_points: int = 0
    # Last event token passed to step(); reset to None once a duration completes.
    last_event_token: int | None = None
    # True while the last event is waiting for its duration.
    can_take_duration: bool = False
    # True between SLD_BEG and SLD_END_TOKEN.
    active_slide: bool = False
    # Recomputed in __post_init__ from min_division; a value passed to the
    # constructor is overwritten.
    allowed_div_tokens: set[int] = field(default_factory=set)

    def __post_init__(self) -> None:
        """Derive the permitted division tokens from ``min_division``."""
        self.allowed_div_tokens = {
            div_id
            for div_id, div_value in ID_TO_DIV.items()
            if div_value >= self.min_division
        }

    def allowed_tokens(self) -> set[int]:
        """Return the token ids that may legally follow in the current mode.

        The returned set is always a fresh object, so callers may mutate it
        without affecting module-level constants or this state.
        """
        if self.mode is GrammarMode.SIM_COUNT:
            return {SIM_COUNT_2} if self.max_sim >= 2 else set()

        if self.mode is GrammarMode.SIM_BODY:
            if self.sim_seen >= self.sim_target:
                return {SIM_END}
            return NOTE_TOKENS - SLIDE_POINT_TOKENS

        if self.mode is GrammarMode.NEED_DUR:
            return {DUR}

        if self.mode is GrammarMode.DUR_NUM:
            return set(VALID_DUR_NUMS)

        if self.mode is GrammarMode.DUR_DEN:
            return set(VALID_DUR_DENS)

        if self.mode is GrammarMode.SLIDE_BODY:
            allowed = set(SLIDE_POINT_TOKENS)
            # A slide may only be closed once it has at least two points.
            if self.slide_points >= 2:
                allowed.add(SLD_END_TOKEN)
            return allowed

        # NORMAL mode.
        allowed = {RST, SLD_BEG} | NOTE_TOKENS | config_tokens() | self.allowed_div_tokens
        # Only open a SIM group when SIM_COUNT mode can actually continue;
        # otherwise SIM_BEG would lead to an empty allowed set (dead end).
        if self.max_sim >= 2:
            allowed.add(SIM_BEG)
        return allowed - ENGINE_ONLY_TOKENS

    def apply_logits_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with every disallowed token set to ``-inf``.

        The vocabulary is taken to be the last dimension, so both a 1-D
        ``(vocab,)`` tensor and batched ``(..., vocab)`` tensors are supported.
        If no token is allowed, every entry of the result is ``-inf``.
        """
        allowed = self.allowed_tokens()
        masked = torch.full_like(logits, float("-inf"))
        idx = torch.tensor(sorted(allowed), dtype=torch.long, device=logits.device)
        # Index the last axis so leading batch dimensions are preserved.
        masked[..., idx] = logits[..., idx]
        return masked

    def step(self, token: int) -> None:
        """Advance the grammar state by consuming ``token``.

        The method does not validate ``token`` against :meth:`allowed_tokens`;
        tokens that have no meaning in the current mode leave the mode unchanged.
        """
        if self.mode is GrammarMode.SIM_COUNT:
            # SIM_COUNT_2 means two notes; any other value is clamped to [1, max_sim].
            self.sim_target = (
                2 if token == SIM_COUNT_2 else min(max(int(token), 1), self.max_sim)
            )
            self.sim_seen = 0
            self.sim_has_hold = False
            self.mode = GrammarMode.SIM_BODY
            return

        if self.mode is GrammarMode.SIM_BODY:
            if token == SIM_END:
                self.last_event_token = SIM_END
                # A group containing a hold must be followed by a duration.
                if self.sim_has_hold:
                    self.mode = GrammarMode.NEED_DUR
                    self.can_take_duration = True
                else:
                    self.mode = GrammarMode.NORMAL
                    self.can_take_duration = False
            else:
                if token in HOLD_TOKENS:
                    self.sim_has_hold = True
                self.sim_seen += 1
            return

        if self.mode is GrammarMode.NEED_DUR:
            # Any token other than DUR is ignored; the state keeps waiting.
            if token == DUR:
                self.mode = GrammarMode.DUR_NUM
            return

        if self.mode is GrammarMode.DUR_NUM:
            self.mode = GrammarMode.DUR_DEN
            return

        if self.mode is GrammarMode.DUR_DEN:
            # Duration fraction complete: back to NORMAL with no pending event.
            self.mode = GrammarMode.NORMAL
            self.last_event_token = None
            self.can_take_duration = False
            return

        if self.mode is GrammarMode.SLIDE_BODY:
            if token == SLD_END_TOKEN:
                self.mode = GrammarMode.NEED_DUR
                self.active_slide = False
                self.last_event_token = SLD_END_TOKEN
                self.can_take_duration = True
            elif token in SLIDE_POINT_TOKENS:
                self.slide_points += 1
            return

        # NORMAL mode.
        if token == SIM_BEG:
            self.mode = GrammarMode.SIM_COUNT
            self.last_event_token = SIM_BEG
            self.can_take_duration = False
        elif token == SLD_BEG:
            self.mode = GrammarMode.SLIDE_BODY
            self.slide_points = 0
            self.active_slide = True
            self.last_event_token = SLD_BEG
            self.can_take_duration = False
        elif token == DUR:
            # A bare DUR is tolerated and starts a duration fraction.
            self.mode = GrammarMode.DUR_NUM
        elif token in ID_TO_DIV:
            # Division tokens do not change the grammar state.
            return
        elif token == RST or token in NOTE_TOKENS or token >= CONFIG_BASE:
            # NOTE(review): ids >= CONFIG_BASE are presumably config tokens —
            # confirm against the tokenizer's id layout.
            self.last_event_token = token
            if token in HOLD_TOKENS:
                # A hold outside a SIM group needs its duration immediately.
                self.mode = GrammarMode.NEED_DUR
                self.can_take_duration = True
            else:
                self.can_take_duration = False