| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from enum import Enum, auto |
|
|
| import torch |
| from tokenizer import ( |
| BOS, EOS, HLD_ON, MASK, META_BPM, META_END, PAD, RST, SEP, SLD_MID, SLD_ON, |
| BRK_TO_ID, DUR, HLD_TO_ID, ID_TO_DIV, SIM_BEG, SIM_END, SLD_BEG, |
| SLD_END_TOKEN, SLD_TO_ID, TAP_TO_ID, TCH_TO_ID, |
| SIM_COUNT_2, DUR_NUM_TO_ID, DUR_DEN_TO_ID, |
| CONFIG_BASE, |
| ) |
|
|
|
|
class GrammarMode(Enum):
    """States of the chart token grammar tracked by ``ChartGrammarState``."""

    # Top-level event stream: notes, rests, config, divisions, group/slide openers.
    NORMAL = auto()
    # Just saw SIM_BEG; the next token gives the size of the simultaneous group.
    SIM_COUNT = auto()
    # Inside a simultaneous group, emitting notes until SIM_END.
    SIM_BODY = auto()
    # A hold (or a group containing one) or a slide just ended; DUR must follow.
    NEED_DUR = auto()
    # After DUR: expecting the duration numerator token.
    DUR_NUM = auto()
    # After the numerator: expecting the duration denominator token.
    DUR_DEN = auto()
    # Inside SLD_BEG ... SLD_END, emitting slide point tokens.
    SLIDE_BODY = auto()
|
|
|
|
# Union of the tap / break / hold / slide / touch token-id tables.
NOTE_TOKENS = set().union(
    TAP_TO_ID.values(),
    BRK_TO_ID.values(),
    HLD_TO_ID.values(),
    SLD_TO_ID.values(),
    TCH_TO_ID.values(),
)
# Hold tokens force a duration (NEED_DUR) after them.
HOLD_TOKENS = {*HLD_TO_ID.values()}
# Tokens legal inside SLD_BEG ... SLD_END.
SLIDE_POINT_TOKENS = {*SLD_TO_ID.values()}
# META_BPM .. META_END, both ends inclusive.
META_TOKENS = {*range(META_BPM, META_END + 1)}
# Tokens the sampler must never emit during generation.
ENGINE_ONLY_TOKENS = META_TOKENS | {PAD, BOS, EOS, SEP, MASK, HLD_ON, SLD_ON, SLD_MID}

# Legal duration numerator / denominator token ids.
VALID_DUR_NUMS = {*DUR_NUM_TO_ID.values()}
VALID_DUR_DENS = {*DUR_DEN_TO_ID.values()}
|
|
|
|
def config_tokens() -> set[int]:
    """Return the keys of ``tokenizer.ID_TO_CONFIG`` as a fresh set.

    The module attribute is read at call time rather than bound at import time,
    presumably so a table rebuilt after this module loads is picked up — confirm.
    """
    import tokenizer as _tok

    config_table = _tok.ID_TO_CONFIG
    return {config_id for config_id in config_table}
|
|
|
|
@dataclass
class ChartGrammarState:
    """Hard token grammar for chart generation.

    This keeps the sequence parseable. It intentionally does not enforce chart
    quality or muri rules; those belong in a later event/window validator.

    Typical loop: call ``apply_logits_mask`` (or ``allowed_tokens``) before
    sampling each token, then pass the sampled token to ``step``.
    """

    # Division tokens whose division value is below this are never offered.
    min_division: int = 1
    # Largest simultaneous-group size; SIM_BEG is only offered when this is >= 2.
    max_sim: int = 2
    # Current grammar state.
    mode: GrammarMode = GrammarMode.NORMAL
    # Number of notes expected in the current simultaneous group.
    sim_target: int = 0
    # Notes emitted so far in the current simultaneous group.
    sim_seen: int = 0
    # Whether the current group contained a hold (then a duration must follow).
    sim_has_hold: bool = False
    # Points emitted in the current slide; SLD_END is allowed once this is >= 2.
    slide_points: int = 0
    # Most recent event-level token; reset to None after a duration completes.
    last_event_token: int | None = None
    # True when entering NEED_DUR; not read by this class.
    can_take_duration: bool = False
    # True between SLD_BEG and SLD_END; not read by this class.
    active_slide: bool = False
    # Recomputed from min_division in __post_init__ (a passed value is overwritten).
    allowed_div_tokens: set[int] = field(default_factory=set)

    def __post_init__(self) -> None:
        """Derive the division tokens permitted by ``min_division``."""
        self.allowed_div_tokens = {
            div_id for div_id, div_value in ID_TO_DIV.items()
            if div_value >= self.min_division
        }

    def allowed_tokens(self) -> set[int]:
        """Return the token ids that are legal as the next token.

        Always returns a fresh set, so callers may mutate the result without
        affecting the module-level token tables.
        """
        if self.mode is GrammarMode.SIM_COUNT:
            return {SIM_COUNT_2} if self.max_sim >= 2 else set()

        if self.mode is GrammarMode.SIM_BODY:
            if self.sim_seen >= self.sim_target:
                return {SIM_END}
            # Slide points only belong inside SLD_BEG ... SLD_END; engine-only
            # tokens are stripped here just as in NORMAL mode.
            return NOTE_TOKENS - SLIDE_POINT_TOKENS - ENGINE_ONLY_TOKENS

        if self.mode is GrammarMode.NEED_DUR:
            return {DUR}

        if self.mode is GrammarMode.DUR_NUM:
            return set(VALID_DUR_NUMS)

        if self.mode is GrammarMode.DUR_DEN:
            return set(VALID_DUR_DENS)

        if self.mode is GrammarMode.SLIDE_BODY:
            allowed = set(SLIDE_POINT_TOKENS)
            if self.slide_points >= 2:
                allowed.add(SLD_END_TOKEN)
            return allowed

        # NORMAL mode.
        openers = {RST, SLD_BEG}
        # SIM_COUNT_2 is the only count token SIM_COUNT offers, and only when
        # max_sim >= 2. Offering SIM_BEG otherwise would lead to a state with
        # no legal token (every logit masked to -inf).
        if self.max_sim >= 2:
            openers.add(SIM_BEG)
        # NOTE(review): NOTE_TOKENS includes SLIDE_POINT_TOKENS, so bare slide
        # points are offered outside a slide here — confirm this is intended.
        allowed = openers | NOTE_TOKENS | config_tokens() | self.allowed_div_tokens
        return allowed - ENGINE_ONLY_TOKENS

    def apply_logits_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with every disallowed token set to -inf.

        The vocabulary is the last dimension, so a single ``(vocab,)`` vector
        and a batched ``(..., vocab)`` tensor are both supported; the same
        grammar state is applied to every row. ``logits`` is not modified.
        """
        allowed = self.allowed_tokens()
        masked = torch.full_like(logits, float("-inf"))
        if not allowed:
            return masked
        idx = torch.tensor(sorted(allowed), dtype=torch.long, device=logits.device)
        masked[..., idx] = logits[..., idx]
        return masked

    def step(self, token: int) -> None:
        """Advance the grammar state after ``token`` has been emitted.

        ``token`` is not validated. A token outside ``allowed_tokens()`` may be
        ignored (e.g. a non-DUR token in NEED_DUR) or may still advance the
        state (any token in DUR_NUM / DUR_DEN).
        """
        if self.mode is GrammarMode.SIM_COUNT:
            # SIM_COUNT_2 means a two-note group; any other token is read as a
            # raw count, clamped to [1, max_sim].
            self.sim_target = 2 if token == SIM_COUNT_2 else min(max(int(token), 1), self.max_sim)
            self.sim_seen = 0
            self.sim_has_hold = False
            self.mode = GrammarMode.SIM_BODY
            return

        if self.mode is GrammarMode.SIM_BODY:
            if token == SIM_END:
                self.last_event_token = SIM_END
                # A group containing a hold must be followed by a duration.
                if self.sim_has_hold:
                    self.mode = GrammarMode.NEED_DUR
                    self.can_take_duration = True
                else:
                    self.mode = GrammarMode.NORMAL
                    self.can_take_duration = False
            else:
                if token in HOLD_TOKENS:
                    self.sim_has_hold = True
                self.sim_seen += 1
            return

        if self.mode is GrammarMode.NEED_DUR:
            if token == DUR:
                self.mode = GrammarMode.DUR_NUM
            return

        if self.mode is GrammarMode.DUR_NUM:
            self.mode = GrammarMode.DUR_DEN
            return

        if self.mode is GrammarMode.DUR_DEN:
            # Duration complete: back to the top-level stream.
            self.mode = GrammarMode.NORMAL
            self.last_event_token = None
            self.can_take_duration = False
            return

        if self.mode is GrammarMode.SLIDE_BODY:
            if token == SLD_END_TOKEN:
                self.mode = GrammarMode.NEED_DUR
                self.active_slide = False
                self.last_event_token = SLD_END_TOKEN
                self.can_take_duration = True
            elif token in SLIDE_POINT_TOKENS:
                self.slide_points += 1
            return

        # NORMAL mode.
        if token == SIM_BEG:
            self.mode = GrammarMode.SIM_COUNT
            self.last_event_token = SIM_BEG
            self.can_take_duration = False
        elif token == SLD_BEG:
            self.mode = GrammarMode.SLIDE_BODY
            self.slide_points = 0
            self.active_slide = True
            self.last_event_token = SLD_BEG
            self.can_take_duration = False
        elif token == DUR:
            # allowed_tokens() never offers DUR in NORMAL mode; this branch only
            # matters when tokens are forced (e.g. replaying a given sequence).
            self.mode = GrammarMode.DUR_NUM
        elif token in ID_TO_DIV:
            # Division tokens do not change grammar state.
            return
        elif token == RST or token in NOTE_TOKENS or token >= CONFIG_BASE:
            # Ids >= CONFIG_BASE are presumably config tokens — confirm.
            self.last_event_token = token
            if token in HOLD_TOKENS:
                self.mode = GrammarMode.NEED_DUR
                self.can_take_duration = True
            else:
                self.can_take_duration = False
|
|