maiChartGen / grammar.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
6.18 kB
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, auto
import torch
from tokenizer import (
BOS, EOS, HLD_ON, MASK, META_BPM, META_END, PAD, RST, SEP, SLD_MID, SLD_ON,
BRK_TO_ID, DUR, HLD_TO_ID, ID_TO_DIV, SIM_BEG, SIM_END, SLD_BEG,
SLD_END_TOKEN, SLD_TO_ID, TAP_TO_ID, TCH_TO_ID,
SIM_COUNT_2, DUR_NUM_TO_ID, DUR_DEN_TO_ID,
CONFIG_BASE,
)
class GrammarMode(Enum):
    """Parser states of the chart token grammar.

    The numeric values are the ones ``auto()`` would assign (1, 2, 3, ...);
    only member identity is used by the grammar itself.
    """

    # Default state: no multi-token construct is open.
    NORMAL = 1
    # A simultaneous group was just opened; its count token comes next.
    SIM_COUNT = 2
    # Inside a simultaneous group, collecting notes until the group closes.
    SIM_BODY = 3
    # A hold or slide was emitted; the DUR marker must follow.
    NEED_DUR = 4
    # The DUR marker was seen; the first duration parameter comes next.
    DUR_NUM = 5
    # The first duration parameter was seen; the second one comes next.
    DUR_DEN = 6
    # Inside a slide, collecting slide points until the slide closes.
    SLIDE_BODY = 7
# Token-id sets per note family. The hold and slide-point families are named
# separately because the grammar treats them specially.
HOLD_TOKENS = set(HLD_TO_ID.values())
SLIDE_POINT_TOKENS = set(SLD_TO_ID.values())

# Every token id that denotes a note of any family.
NOTE_TOKENS = set().union(
    TAP_TO_ID.values(),
    BRK_TO_ID.values(),
    HOLD_TOKENS,
    SLIDE_POINT_TOKENS,
    TCH_TO_ID.values(),
)

# Ids in the inclusive range META_BPM..META_END.
META_TOKENS = set(range(META_BPM, META_END + 1))

# Tokens that are never offered to the sampler in NORMAL mode.
ENGINE_ONLY_TOKENS = META_TOKENS.union(
    (PAD, BOS, EOS, SEP, MASK, HLD_ON, SLD_ON, SLD_MID)
)

# Duration parameters are raw integer tokens rather than ordinary vocabulary
# tokens. Restrict them to the compact domain that real charts produce: if the
# whole uint8 range were allowed, generation could sample ids that look like
# special/control tokens, as well as absurdly long holds and slides.
VALID_DUR_NUMS = set(DUR_NUM_TO_ID.values())
VALID_DUR_DENS = set(DUR_DEN_TO_ID.values())
def config_tokens() -> set[int]:
    """Return a fresh set of the ids currently present in the config vocab.

    The lookup is done at call time on purpose: the config vocab can be
    extended after this module has been imported (when a checkpoint loads its
    learned config_vocab), so the result must not be frozen at import.
    """
    import tokenizer as live_tokenizer

    config_vocab = live_tokenizer.ID_TO_CONFIG
    return set(config_vocab.keys())
@dataclass
class ChartGrammarState:
    """Hard token grammar for chart generation.

    This keeps the sequence parseable. It intentionally does not enforce chart
    quality or muri rules; those belong in a later event/window validator.

    ``allowed_tokens`` reports which token ids are legal in the current state,
    ``apply_logits_mask`` applies that set to a logits tensor, and ``step``
    advances the state with the token that was actually chosen.
    """

    # Division tokens whose value is below this are never offered.
    min_division: int = 1
    # Largest simultaneous-group size; below 2 no group can be opened.
    max_sim: int = 2
    # Current parser state.
    mode: GrammarMode = GrammarMode.NORMAL
    # Number of notes the open simultaneous group must contain.
    sim_target: int = 0
    # Number of notes seen so far in the open simultaneous group.
    sim_seen: int = 0
    # Whether the open simultaneous group contains a hold (needs a duration).
    sim_has_hold: bool = False
    # Number of slide points seen in the open slide.
    slide_points: int = 0
    # Bookkeeping written by ``step``; not read inside this class.
    last_event_token: int | None = None
    can_take_duration: bool = False
    active_slide: bool = False
    # Always recomputed from ``min_division`` in ``__post_init__``.
    allowed_div_tokens: set[int] = field(default_factory=set)

    def __post_init__(self) -> None:
        """Derive the division tokens permitted by ``min_division``."""
        self.allowed_div_tokens = {
            div_id for div_id, div_value in ID_TO_DIV.items()
            if div_value >= self.min_division
        }

    def allowed_tokens(self) -> set[int]:
        """Return the set of token ids that are legal in the current mode.

        The returned set is always a fresh object, so callers may mutate it.
        It is empty only in ``SIM_COUNT`` mode when ``max_sim < 2`` (reachable
        only if ``SIM_BEG`` was forced through ``step``).
        """
        if self.mode is GrammarMode.SIM_COUNT:
            return {SIM_COUNT_2} if self.max_sim >= 2 else set()
        if self.mode is GrammarMode.SIM_BODY:
            # Once the group is full, the only legal move is to close it.
            if self.sim_seen >= self.sim_target:
                return {SIM_END}
            return NOTE_TOKENS - SLIDE_POINT_TOKENS
        if self.mode is GrammarMode.NEED_DUR:
            return {DUR}
        if self.mode is GrammarMode.DUR_NUM:
            # Copy so callers cannot mutate the module-level constant.
            return set(VALID_DUR_NUMS)
        if self.mode is GrammarMode.DUR_DEN:
            return set(VALID_DUR_DENS)
        if self.mode is GrammarMode.SLIDE_BODY:
            allowed = set(SLIDE_POINT_TOKENS)
            # A slide may only be closed after at least two points.
            if self.slide_points >= 2:
                allowed.add(SLD_END_TOKEN)
            return allowed

        # NORMAL mode.
        allowed = {RST, SLD_BEG} | NOTE_TOKENS | config_tokens() | self.allowed_div_tokens
        # Only offer SIM_BEG when a count token can follow it; otherwise the
        # sampler would land in SIM_COUNT mode with nothing legal to emit.
        if self.max_sim >= 2:
            allowed.add(SIM_BEG)
        return allowed - ENGINE_ONLY_TOKENS

    def apply_logits_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with disallowed tokens set to ``-inf``.

        Token ids index the last dimension, so a 1-D vector and tensors with
        leading batch dimensions are both handled. The input is not modified.
        If no token is allowed, every entry of the result is ``-inf``.
        """
        allowed = self.allowed_tokens()
        masked = torch.full_like(logits, float("-inf"))
        idx = torch.tensor(sorted(allowed), dtype=torch.long, device=logits.device)
        masked[..., idx] = logits[..., idx]
        return masked

    def step(self, token: int) -> None:
        """Advance the grammar state with the token that was emitted.

        The token is not validated against ``allowed_tokens``; tokens that do
        not match any rule for the current mode leave the state unchanged.
        """
        if self.mode is GrammarMode.SIM_COUNT:
            # Any other count token is read as a raw integer and clamped
            # to the range 1..max_sim.
            self.sim_target = 2 if token == SIM_COUNT_2 else min(max(int(token), 1), self.max_sim)
            self.sim_seen = 0
            self.sim_has_hold = False
            self.mode = GrammarMode.SIM_BODY
            return

        if self.mode is GrammarMode.SIM_BODY:
            if token == SIM_END:
                self.last_event_token = SIM_END
                # A group containing a hold must be followed by a duration.
                if self.sim_has_hold:
                    self.mode = GrammarMode.NEED_DUR
                    self.can_take_duration = True
                else:
                    self.mode = GrammarMode.NORMAL
                    self.can_take_duration = False
            else:
                if token in HOLD_TOKENS:
                    self.sim_has_hold = True
                self.sim_seen += 1
            return

        if self.mode is GrammarMode.NEED_DUR:
            # Anything other than DUR keeps waiting for it.
            if token == DUR:
                self.mode = GrammarMode.DUR_NUM
            return

        if self.mode is GrammarMode.DUR_NUM:
            self.mode = GrammarMode.DUR_DEN
            return

        if self.mode is GrammarMode.DUR_DEN:
            # Duration complete: back to NORMAL with the event bookkeeping cleared.
            self.mode = GrammarMode.NORMAL
            self.last_event_token = None
            self.can_take_duration = False
            return

        if self.mode is GrammarMode.SLIDE_BODY:
            if token == SLD_END_TOKEN:
                # A closed slide must be followed by a duration.
                self.mode = GrammarMode.NEED_DUR
                self.active_slide = False
                self.last_event_token = SLD_END_TOKEN
                self.can_take_duration = True
            elif token in SLIDE_POINT_TOKENS:
                self.slide_points += 1
            return

        # NORMAL mode.
        if token == SIM_BEG:
            self.mode = GrammarMode.SIM_COUNT
            self.last_event_token = SIM_BEG
            self.can_take_duration = False
        elif token == SLD_BEG:
            self.mode = GrammarMode.SLIDE_BODY
            self.slide_points = 0
            self.active_slide = True
            self.last_event_token = SLD_BEG
            self.can_take_duration = False
        elif token == DUR:
            # Never offered in NORMAL mode by allowed_tokens, but still
            # handled here if it is passed in.
            self.mode = GrammarMode.DUR_NUM
        elif token in ID_TO_DIV:
            # Division tokens do not change the grammar state.
            return
        elif token == RST or token in NOTE_TOKENS or token >= CONFIG_BASE:
            self.last_event_token = token
            # A standalone hold must be followed by a duration.
            if token in HOLD_TOKENS:
                self.mode = GrammarMode.NEED_DUR
                self.can_take_duration = True
            else:
                self.can_take_duration = False