maiChartGen / constrained_decode.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
15.5 kB
from __future__ import annotations
from dataclasses import dataclass
import torch
from grammar import ChartGrammarState, GrammarMode
from model import DIV_MAP, DUR_TOKEN, is_timeline_token
from grammar import NOTE_TOKENS
from tokenizer import (
BOS, EOS, PAD, RST, DUR, HLD_TO_ID, ID_TO_DIV, SIM_BEG, SIM_END, SLD_BEG,
SLD_END_TOKEN, SIM_COUNT_2, DUR_NUM_TO_ID, ID_TO_DUR_NUM, DUR_DEN_TO_ID,
ID_TO_DUR_DEN, CONFIG_TO_ID, ID_TO_CONFIG,
)
# Duration numerators / denominators that have a dedicated token in the
# tokenizer tables; these are the only values a snapped duration may use.
VALID_DUR_NUMS = list(DUR_NUM_TO_ID)
VALID_DUR_DENS = list(DUR_DEN_TO_ID)

# Default upper bound, in beats, applied to any hold/slide duration.
MAX_DURATION_BEATS = 4.0
def snap_duration(beat: int, subdiv: int, max_beats: float,
                  max_duration_beats: float = MAX_DURATION_BEATS) -> tuple[int, int]:
    """Snap/clamp duration params to common values that fit before track end.

    Args:
        beat: Requested duration numerator (values below 1 are treated as 1).
        subdiv: Requested duration denominator (values below 1 are treated as 1).
        max_beats: Beats still available for the duration.
        max_duration_beats: Hard cap on the duration, in beats.

    Returns:
        The ``(numerator, denominator)`` pair drawn from ``VALID_DUR_NUMS`` x
        ``VALID_DUR_DENS`` whose ratio is closest to ``beat / subdiv`` without
        exceeding the smaller of the two limits. ``(1, 64)`` is returned when
        nothing fits.
    """
    fallback = (1, 64)
    ceiling = min(max_beats, max_duration_beats)
    if ceiling <= 0:
        return fallback

    fitting = []
    for num in VALID_DUR_NUMS:
        for den in VALID_DUR_DENS:
            if num / den <= ceiling + 1e-6:
                fitting.append((num, den))
    if not fitting:
        return fallback

    target = max(float(beat), 1.0) / max(float(subdiv), 1.0)

    def _rank(pair):
        # Closest ratio first; ties go to the smaller denominator, then the
        # smaller numerator.
        return (abs((pair[0] / pair[1]) - target), pair[1], pair[0])

    return min(fitting, key=_rank)
def _config_duration_beats(spec: tuple) -> float:
if not spec:
return 0.0
kind = spec[0]
if kind in ("hld", "sld") and len(spec) >= 4:
return max(float(spec[-2]), 1.0) / max(float(spec[-1]), 1.0)
if kind == "pair" and len(spec) >= 7 and ("hld" in (spec[1], spec[3])):
return max(float(spec[-2]), 1.0) / max(float(spec[-1]), 1.0)
if kind == "multi" and len(spec) >= 5 and spec[1] == "hld":
return max(float(spec[-2]), 1.0) / max(float(spec[-1]), 1.0)
return 0.0
def _replace_config_duration_token(token_id: int, max_beats: float,
                                   max_duration_beats: float = MAX_DURATION_BEATS) -> int:
    """Swap a duration-bearing config token for one whose duration fits.

    Tokens that are not in ``ID_TO_CONFIG``, or whose spec carries no duration,
    are returned unchanged. Otherwise the trailing ``(num, den)`` of the spec
    is snapped with :func:`snap_duration` and the token for the adjusted spec
    is looked up in ``CONFIG_TO_ID``; ``RST`` is returned when no such token
    exists.
    """
    spec = ID_TO_CONFIG.get(token_id)
    if spec is None:
        return token_id
    if _config_duration_beats(spec) <= 0:
        return token_id

    num, den = snap_duration(int(spec[-2]), int(spec[-1]), max_beats, max_duration_beats)
    adjusted = (*spec[:-2], num, den)
    return CONFIG_TO_ID.get(adjusted, RST)
@dataclass
class DecodeState:
    """Stateful validator for candidate next-token decoding.

    The state is advanced one token at a time through :meth:`step`. It keeps
    the running beat position and wraps a ``ChartGrammarState`` so that
    :meth:`allowed_tokens` / :meth:`apply_logits_mask` can reject tokens that
    are either ungrammatical or would run past ``total_beats``.

    Attributes:
        bpm: Stored as given; not read by any method of this class.
        total_beats: Beat position at which the track ends.
        start_offset_beats: Beat position at which decoding starts.
        min_division: Forwarded to the grammar and used as the initial
            division value.
        max_sim: Forwarded to the grammar; ``SIM_BEG`` is only offered when
            this is at least 2.
        max_duration_beats: Upper bound on any hold/slide duration, in beats.
    """

    bpm: float
    total_beats: float
    start_offset_beats: float
    min_division: int = 4
    max_sim: int = 2
    max_duration_beats: float = MAX_DURATION_BEATS

    def __post_init__(self) -> None:
        """Initialise the grammar and all derived, mutable decode state."""
        self.grammar = ChartGrammarState(min_division=self.min_division, max_sim=self.max_sim)
        self.current_beat = float(self.start_offset_beats)
        self.div_value = float(self.min_division)
        # Number of upcoming tokens that belong to a DUR triple and therefore
        # must not be interpreted as division/timeline tokens.
        self.skip_beat = 0
        self.slide_group_active = False
        self.pending_dur_num: int | None = None
        # Beat at which the note awaiting its duration started, if any.
        self.pending_duration_start_beat: float | None = None
        self.sim_start_beat: float | None = None
        self.slide_start_beat: float | None = None
        # Per-instance caches keyed by (device, vocab_size).
        self._normal_mask_cache: dict[tuple[torch.device, int], torch.Tensor] = {}
        self._config_duration_cache: dict[tuple[torch.device, int], torch.Tensor] = {}

    def remaining_beats(self) -> float:
        """Return the beats left between the current position and track end."""
        return max(0.0, self.total_beats - self.current_beat)

    def remaining_duration_beats(self) -> float:
        """Return the longest duration a note may take right now.

        Measured from the pending note's start beat when one is recorded,
        otherwise from the current beat, and capped at ``max_duration_beats``.
        """
        start = self.pending_duration_start_beat
        if start is None:
            start = self.current_beat
        return max(0.0, min(self.total_beats - start, self.max_duration_beats))

    def allowed_tokens(self) -> set[int]:
        """Return the grammar-allowed tokens that also respect the time limits.

        In ``DUR_NUM`` / ``DUR_DEN`` mode the result is narrowed to duration
        tokens that can still fit; if no value fits at all, the token for
        numerator ``1`` / denominator ``64`` is returned as a last resort.
        In every other mode each grammar-allowed token is filtered through
        :meth:`is_candidate_safe`.
        """
        allowed = set(self.grammar.allowed_tokens())
        remaining = self.remaining_duration_beats()
        if self.grammar.mode is GrammarMode.DUR_NUM:
            # A numerator is viable if it fits with the largest denominator.
            largest_den = max(VALID_DUR_DENS)
            nums = {DUR_NUM_TO_ID[b] for b in VALID_DUR_NUMS if b / largest_den <= remaining + 1e-6}
            return allowed & nums if nums else {DUR_NUM_TO_ID[1]}
        if self.grammar.mode is GrammarMode.DUR_DEN:
            beat_num = self.pending_dur_num or 1
            dens = {DUR_DEN_TO_ID[d] for d in VALID_DUR_DENS if beat_num / d <= remaining + 1e-6}
            return allowed & dens if dens else {DUR_DEN_TO_ID[64]}
        return {tok for tok in allowed if self.is_candidate_safe(tok)}

    def can_start_sim(self) -> bool:
        """Return True when a SIM group may be opened in the current state."""
        return self.grammar.mode is GrammarMode.NORMAL and self.max_sim >= 2

    def is_candidate_safe(self, tok: int) -> bool:
        """Return True if ``tok`` may be emitted without overrunning the track.

        Only ``NORMAL`` mode is restricted here; in any other grammar mode the
        token is accepted and the grammar alone decides.
        """
        # NOTE(review): unlike apply_logits_mask, this path does not compare
        # config-token durations against the remaining beats - confirm intent.
        if self.grammar.mode is not GrammarMode.NORMAL:
            return True
        if tok in (PAD, BOS, EOS):
            return False
        if tok == SIM_BEG and self.max_sim < 2:
            return False
        if tok in ID_TO_DIV:
            # A division change is only useful if one step at it still fits.
            next_div = float(ID_TO_DIV[tok])
            return self.remaining_beats() >= 4.0 / max(next_div, 1.0) - 1e-6
        if tok in HLD_TO_ID.values() or tok == SLD_BEG:
            # Holds/slides need room for at least the shortest duration.
            if self.remaining_beats() < 1.0 / max(VALID_DUR_DENS):
                return False
        if self._would_advance_time(tok):
            step = 4.0 / max(self.div_value, 1.0)
            return self.current_beat + step <= self.total_beats + 1e-6
        return True

    def apply_logits_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``logits`` with disallowed entries set to ``-inf``.

        ``logits`` is indexed directly by token id and its element count is
        used as the vocabulary size (assumes a 1-D tensor - TODO confirm with
        callers). ``NORMAL`` mode uses a cached boolean mask that is narrowed
        by the time limits; other modes fall back to :meth:`allowed_tokens`.
        """
        masked = torch.full_like(logits, float("-inf"))
        vocab_size = int(logits.numel())
        if self.grammar.mode is GrammarMode.NORMAL:
            # Clone: the cached mask must stay pristine for later calls.
            mask = self._normal_mask(logits.device, vocab_size).clone()
            step = 4.0 / max(self.div_value, 1.0)
            if self.current_beat + step > self.total_beats + 1e-6:
                # No room for another step: drop these tokens.
                timeline_ids = [RST, SIM_BEG, SLD_BEG, *NOTE_TOKENS, *ID_TO_CONFIG.keys()]
                timeline_ids = [t for t in timeline_ids if 0 <= int(t) < vocab_size]
                if timeline_ids:
                    mask[torch.tensor(timeline_ids, dtype=torch.long, device=logits.device)] = False
            remaining = self.remaining_duration_beats()
            if remaining < 1.0 / max(VALID_DUR_DENS):
                # Not even the shortest duration fits: no holds or slides.
                risky_ids = [SLD_BEG, *HLD_TO_ID.values()]
                risky_ids = [t for t in risky_ids if 0 <= int(t) < vocab_size]
                if risky_ids:
                    mask[torch.tensor(risky_ids, dtype=torch.long, device=logits.device)] = False
            cfg_dur = self._config_duration_tensor(logits.device, vocab_size)
            mask &= cfg_dur <= remaining + 1e-6
            beats_left = self.remaining_beats()
            for div_tok, div_value in ID_TO_DIV.items():
                if not 0 <= div_tok < vocab_size:
                    continue
                fits = beats_left >= 4.0 / max(float(div_value), 1.0) - 1e-6
                if not fits:
                    # Only ever clear entries: division tokens that the base
                    # mask excluded (not in grammar.allowed_div_tokens) must
                    # not be switched back on just because they would fit.
                    mask[div_tok] = False
            masked[mask] = logits[mask]
            return masked
        allowed = self.allowed_tokens()
        if allowed:
            idx = torch.tensor([t for t in sorted(allowed) if 0 <= t < vocab_size],
                               dtype=torch.long, device=logits.device)
            if idx.numel() > 0:
                masked[idx] = logits[idx]
        return masked

    def _normal_mask(self, device: torch.device, vocab_size: int) -> torch.Tensor:
        """Return the cached base mask of tokens offered in ``NORMAL`` mode.

        The returned tensor is the cache entry itself; callers must clone it
        before modifying it.
        """
        key = (device, vocab_size)
        cached = self._normal_mask_cache.get(key)
        if cached is not None:
            return cached
        mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
        ids = {RST, SLD_BEG} | set(NOTE_TOKENS) | set(ID_TO_CONFIG.keys()) | set(self.grammar.allowed_div_tokens)
        if self.max_sim >= 2:
            ids.add(SIM_BEG)
        ids.discard(PAD)
        ids.discard(BOS)
        ids.discard(EOS)
        valid = [int(t) for t in ids if 0 <= int(t) < vocab_size]
        if valid:
            mask[torch.tensor(valid, dtype=torch.long, device=device)] = True
        self._normal_mask_cache[key] = mask
        return mask

    def _config_duration_tensor(self, device: torch.device, vocab_size: int) -> torch.Tensor:
        """Return a cached per-token tensor of config durations in beats.

        Entries are ``0`` for tokens that are not config tokens or whose spec
        carries no duration.
        """
        key = (device, vocab_size)
        cached = self._config_duration_cache.get(key)
        if cached is not None:
            return cached
        dur = torch.zeros(vocab_size, dtype=torch.float32, device=device)
        for token_id, spec in ID_TO_CONFIG.items():
            if 0 <= token_id < vocab_size:
                d = _config_duration_beats(spec)
                if d > 0:
                    dur[token_id] = float(d)
        self._config_duration_cache[key] = dur
        return dur

    def step(self, tok: int) -> None:
        """Consume ``tok``: update beat/duration bookkeeping, then the grammar.

        All decisions are based on the grammar mode *before* the grammar
        itself is advanced, which happens as the very last action.
        """
        mode_before = self.grammar.mode
        if mode_before is GrammarMode.NORMAL:
            # Remember where a hold / slide / SIM group begins so that its
            # duration can later be measured from that beat.
            if tok in HLD_TO_ID.values():
                self.pending_duration_start_beat = self.current_beat
            elif tok == SLD_BEG:
                self.slide_start_beat = self.current_beat
            elif tok == SIM_BEG:
                self.sim_start_beat = self.current_beat
        if mode_before is GrammarMode.SLIDE_BODY and tok == SLD_END_TOKEN:
            self.pending_duration_start_beat = self.slide_start_beat
            self.slide_start_beat = None
        if mode_before is GrammarMode.SIM_BODY and tok == SIM_END and self.grammar.sim_has_hold:
            self.pending_duration_start_beat = self.sim_start_beat
            self.sim_start_beat = None
        if mode_before is GrammarMode.DUR_NUM:
            self.pending_dur_num = ID_TO_DUR_NUM.get(int(tok), int(tok))
        elif mode_before is GrammarMode.DUR_DEN and self.pending_dur_num is not None:
            # The duration pair is complete; nothing is derived from it here,
            # so just clear the pending bookkeeping.
            self.pending_dur_num = None
            self.pending_duration_start_beat = None
        in_sim_body = self.grammar.mode in (GrammarMode.SIM_COUNT, GrammarMode.SIM_BODY)
        in_slide_body = self.slide_group_active and tok != SLD_BEG
        if self.skip_beat > 0:
            self.skip_beat -= 1
        elif tok == DUR_TOKEN:
            # The next two tokens are the duration numerator and denominator.
            self.skip_beat = 2
        elif tok in ID_TO_DIV:
            self.div_value = float(ID_TO_DIV[tok])
        elif is_timeline_token(torch.tensor(tok)) and not in_sim_body and not in_slide_body:
            self.current_beat += 4.0 / max(self.div_value, 1.0)
        if tok == SLD_BEG:
            self.slide_group_active = True
        elif tok == SLD_END_TOKEN:
            self.slide_group_active = False
        self.grammar.step(tok)

    def _would_advance_time(self, tok: int) -> bool:
        """Return True if stepping ``tok`` now would move ``current_beat``."""
        if self.skip_beat > 0 or tok == DUR_TOKEN:
            return False
        in_sim_body = self.grammar.mode in (GrammarMode.SIM_COUNT, GrammarMode.SIM_BODY)
        in_slide_body = self.slide_group_active and tok != SLD_BEG
        return bool(is_timeline_token(torch.tensor(tok)) and not in_sim_body and not in_slide_body)
def clamp_duration_tokens(tokens: list[int], total_beats: float,
                          start_offset_beats: float, min_division: int = 4) -> list[int]:
    """Clamp generated hold/slide durations so notes cannot extend past track end.

    Replays ``tokens`` through a fresh :class:`DecodeState` and returns a new
    list in which

    * every config token is passed through
      :func:`_replace_config_duration_token`, and
    * every ``DUR, num, den`` triple has its numerator/denominator snapped by
      :func:`snap_duration` to the beats remaining at that point.

    The input list is not modified.
    """
    clamped = list(tokens)
    tracker = DecodeState(
        bpm=150.0,
        total_beats=total_beats,
        start_offset_beats=start_offset_beats,
        min_division=min_division,
    )
    count = len(clamped)
    pos = 0
    while pos < count:
        tok = clamped[pos]

        if tok in ID_TO_CONFIG:
            replacement = _replace_config_duration_token(
                tok,
                tracker.remaining_duration_beats(),
                tracker.max_duration_beats,
            )
            clamped[pos] = replacement
            tracker.step(replacement)
            pos += 1
        elif tok == DUR and pos + 2 < count:
            raw_num, raw_den = clamped[pos + 1], clamped[pos + 2]
            num, den = snap_duration(
                ID_TO_DUR_NUM.get(raw_num, raw_num),
                ID_TO_DUR_DEN.get(raw_den, raw_den),
                tracker.remaining_duration_beats(),
                tracker.max_duration_beats,
            )
            # Store the snapped values, then map them back to token ids.
            clamped[pos + 1], clamped[pos + 2] = num, den
            clamped[pos + 1] = DUR_NUM_TO_ID[num]
            clamped[pos + 2] = DUR_DEN_TO_ID[den]
            for consumed in (tok, clamped[pos + 1], clamped[pos + 2]):
                tracker.step(consumed)
            pos += 3
        else:
            tracker.step(tok)
            pos += 1
    return clamped
def duration_report(tokens: list[int], total_beats: float, start_offset_beats: float,
                    min_division: int = 4) -> dict[str, float | int]:
    """Summarise the durations found in ``tokens``.

    The sequence is replayed through a fresh :class:`DecodeState`. Both
    duration-bearing config tokens and ``DUR, num, den`` triples are counted.

    Returns:
        A dict with ``"durations"`` (how many were seen), ``"max_beats"``
        (the longest one, in beats) and ``"overrun"`` (how many would end
        after ``total_beats``).
    """
    tracker = DecodeState(
        bpm=150.0,
        total_beats=total_beats,
        start_offset_beats=start_offset_beats,
        min_division=min_division,
    )
    seen = 0
    longest = 0.0
    overruns = 0
    count = len(tokens)

    for pos, tok in enumerate(tokens):
        if tok in ID_TO_CONFIG:
            length = _config_duration_beats(ID_TO_CONFIG[tok])
            if length > 0:
                seen += 1
                longest = max(longest, length)
                if tracker.current_beat + length > total_beats + 1e-6:
                    overruns += 1

        if tok == DUR and pos + 2 < count:
            raw_num, raw_den = tokens[pos + 1], tokens[pos + 2]
            num = ID_TO_DUR_NUM.get(raw_num, raw_num)
            den = ID_TO_DUR_DEN.get(raw_den, raw_den)
            length = max(float(num), 1.0) / max(float(den), 1.0)
            seen += 1
            longest = max(longest, length)
            # Measure from the note that owns this duration when its start
            # beat is known, otherwise from the current position.
            origin = tracker.pending_duration_start_beat
            if origin is None:
                origin = tracker.current_beat
            if origin + length > total_beats + 1e-6:
                overruns += 1

        tracker.step(tok)

    return {"durations": seen, "max_beats": longest, "overrun": overruns}
def simultaneous_report(tokens: list[int], max_sim: int = 2) -> dict[str, int]:
    """Count SIM groups in ``tokens`` and flag those wider than ``max_sim``.

    The token following each ``SIM_BEG`` is read as the group's count:
    ``SIM_COUNT_2`` means 2, anything else is taken as its integer value.
    A ``SIM_BEG`` that is the final token is ignored.

    Returns:
        A dict with ``"groups"``, ``"max_count"``, ``"over_limit"`` and
        ``"first_bad"`` (index of the first over-wide ``SIM_BEG``, or ``-1``).
    """
    groups = 0
    widest = 0
    too_wide = 0
    first_offender = -1

    end = len(tokens)
    pos = 0
    while pos < end:
        if tokens[pos] != SIM_BEG or pos + 1 >= end:
            pos += 1
            continue

        marker = tokens[pos + 1]
        width = 2 if marker == SIM_COUNT_2 else int(marker)
        groups += 1
        widest = max(widest, width)
        if width > max_sim:
            too_wide += 1
            if first_offender < 0:
                first_offender = pos
        # Skip the SIM_BEG and its count token together.
        pos += 2

    return {
        "groups": groups,
        "max_count": widest,
        "over_limit": too_wide,
        "first_bad": first_offender,
    }
def sanitize_sim_tokens(tokens: list[int], max_sim: int = 2) -> list[int]:
    """Remove malformed or over-wide SIM groups before decode/export.

    Every ``SIM_BEG`` group is rebuilt from the note tokens found between its
    count slot and the next ``SIM_END`` (or ``EOS``); all other tokens inside
    the group are dropped.

    * With at least two notes and ``max_sim >= 2`` the group is re-emitted as
      ``SIM_BEG, SIM_COUNT_2, note, note, SIM_END`` and any further notes are
      appended after it as plain tokens.
    * Otherwise only the first note (if any) is kept, without a group.
    * A ``SIM_BEG`` that is the last token is discarded.

    Tokens outside SIM groups are copied through unchanged. The input list is
    not modified.
    """
    result: list[int] = []
    total = len(tokens)
    i = 0
    while i < total:
        tok = tokens[i]
        if tok != SIM_BEG:
            result.append(tok)
            i += 1
            continue
        if i + 1 >= total:
            # Dangling SIM_BEG with nothing after it.
            i += 1
            continue

        # tokens[i + 1] is the declared-count slot. It is skipped rather than
        # trusted: the group is always re-emitted with SIM_COUNT_2, so the
        # declared value cannot influence the output.
        j = i + 2
        while j < total and tokens[j] not in (SIM_END, EOS):
            j += 1
        body = [t for t in tokens[i + 2:j] if t in NOTE_TOKENS]

        if len(body) >= 2 and max_sim >= 2:
            result.extend([SIM_BEG, SIM_COUNT_2, body[0], body[1], SIM_END])
            result.extend(body[2:])
        else:
            result.extend(body[:1])

        # Consume the closing SIM_END; an EOS (or end of input) is left for
        # the outer loop to handle as an ordinary token.
        i = j + 1 if j < total and tokens[j] == SIM_END else j
    return result