maiChartGen / tokenizer.py
Goldgom's picture
Upload MaiGenerator model (epoch 10) and inference code
8061544
Raw
History Blame Contribute Delete
41.3 kB
"""
Maimai Chart Tokenizer โ€” rule-based bidirectional chart โ†” token conversion.
Design:
- BPM is NOT tokenized (computed separately by external program)
- Beat division (div_N) tokens control note granularity
- Each note event โ†’ 1~5 tokens, lossless round-trip
Vocabulary size: 256 (0-255), with room for expansion.
Usage:
from tokenizer import MaiChartTokenizer
tok = MaiChartTokenizer()
tokens = tok.encode(chart)
chart2 = tok.decode(tokens) # lossless
"""
from __future__ import annotations
from dataclasses import dataclass
import json
from typing import Optional
from mai_parser.models import Chart, TouchNote
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Vocabulary definition
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# --- Special tokens (0-4) ---
PAD = 0
BOS = 1
EOS = 2
SEP = 3
MASK = 4
_SPECIAL_END = 5
# --- Beat division tokens (5-16): {1} {2} {4} {8} {16} {32} {48} {64} {128} {192} {384} ---
_DIV_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 128, 192, 384]
DIV_BASE = _SPECIAL_END
DIV_TO_ID: dict[int, int] = {v: DIV_BASE + i for i, v in enumerate(_DIV_VALUES)}
ID_TO_DIV: dict[int, int] = {v: k for k, v in DIV_TO_ID.items()}
DIV_END = DIV_BASE + len(_DIV_VALUES)
# --- Rest token (17) ---
RST = DIV_END
_RST_END = RST + 1
# --- Duration marker (18): followed by 2 tokens [beat, subdiv] ---
DUR = _RST_END
_DUR_END = DUR + 1
# --- Tap tokens (19-26): tap_1 ~ tap_8 ---
TAP_BASE = _DUR_END
TAP_TO_ID = {i: TAP_BASE + i - 1 for i in range(1, 9)}
ID_TO_TAP = {v: k for k, v in TAP_TO_ID.items()}
TAP_END = TAP_BASE + 8
# --- Break tokens (27-34): brk_1 ~ brk_8 ---
BRK_BASE = TAP_END
BRK_TO_ID = {i: BRK_BASE + i - 1 for i in range(1, 9)}
ID_TO_BRK = {v: k for k, v in BRK_TO_ID.items()}
BRK_END = BRK_BASE + 8
# --- Hold tokens (35-42): hld_1 ~ hld_8 ---
HLD_BASE = BRK_END
HLD_TO_ID = {i: HLD_BASE + i - 1 for i in range(1, 9)}
ID_TO_HLD = {v: k for k, v in HLD_TO_ID.items()}
HLD_END = HLD_BASE + 8
# --- Slide waypoint tokens (43-50): sld_1 ~ sld_8 ---
SLD_BASE = HLD_END
SLD_TO_ID = {i: SLD_BASE + i - 1 for i in range(1, 9)}
ID_TO_SLD = {v: k for k, v in SLD_TO_ID.items()}
SLD_END = SLD_BASE + 8
# --- Slide control (51-52) ---
SLD_BEG = SLD_END # slide start, next token = point count
SLD_END_TOKEN = SLD_BEG + 1 # slide end marker
_SLD_CTRL_END = SLD_END_TOKEN + 1
# --- Simultaneous control (53-54) ---
SIM_BEG = _SLD_CTRL_END # simultaneous start, next token = note count
SIM_END = SIM_BEG + 1 # simultaneous end
_SIM_CTRL_END = SIM_END + 1
# --- Touch tokens (55-90): A1-A8, B1-B8, C1-C8, D1-D8, E1-E8 ---
TCH_BASE = _SIM_CTRL_END
_TOUCH_ZONES = ["A", "B", "C", "D", "E"]
_tch_map: dict[str, int] = {}
_idx = TCH_BASE
for zone in _TOUCH_ZONES:
for pos in range(1, 9):
_tch_map[f"{zone}{pos}"] = _idx
_idx += 1
# C (center touch without position)
_tch_map["C"] = _idx
_idx += 1
TCH_TO_ID = _tch_map
ID_TO_TCH = {v: k for k, v in TCH_TO_ID.items()}
TCH_END = _idx
# --- Parameter tokens (96-120): context-safe count/duration values ---
SIM_COUNT_2 = TCH_END
_SIM_COUNT_END = SIM_COUNT_2 + 1
_DUR_NUM_VALUES = [1, 2, 3, 4, 6, 8, 12, 16]
DUR_NUM_BASE = _SIM_COUNT_END
DUR_NUM_TO_ID = {v: DUR_NUM_BASE + i for i, v in enumerate(_DUR_NUM_VALUES)}
ID_TO_DUR_NUM = {v: k for k, v in DUR_NUM_TO_ID.items()}
DUR_NUM_END = DUR_NUM_BASE + len(_DUR_NUM_VALUES)
_DUR_DEN_VALUES = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64]
DUR_DEN_BASE = DUR_NUM_END
DUR_DEN_TO_ID = {v: DUR_DEN_BASE + i for i, v in enumerate(_DUR_DEN_VALUES)}
ID_TO_DUR_DEN = {v: k for k, v in DUR_DEN_TO_ID.items()}
DUR_DEN_END = DUR_DEN_BASE + len(_DUR_DEN_VALUES)
# --- Metadata/header/helper tokens (220-231): kept for backward compatibility ---
META_BPM = 220 # followed by BPM//2 as uint8
META_DIFF = 221 # followed by difficulty enum (0-4)
META_LEVEL = 222 # followed by level*10 as uint8
META_GENRE = 223 # followed by genre id
META_END = 224 # end of metadata header
SLD_MID = 229 # slide intermediate waypoint
HLD_ON = 230 # hold is ongoing at this beat position (informational)
SLD_ON = 231 # slide is ongoing at this beat position (informational)
# --- Whole time-slot configuration tokens (256+) ---
CONFIG_BASE = 256
def _duration_pairs(max_beats: float = 4.0) -> list[tuple[int, int]]:
    """Return every (num, den) pair from the duration vocabularies with num/den <= max_beats.

    Pairs are ordered numerator-major (outer loop over _DUR_NUM_VALUES, inner
    over _DUR_DEN_VALUES). A 1e-9 slack keeps ratios exactly on the limit.
    """
    ceiling = max_beats + 1e-9
    return [
        (num, den)
        for num in _DUR_NUM_VALUES
        for den in _DUR_DEN_VALUES
        if num / den <= ceiling
    ]
# Duration pairs (num, den) usable inside hold/slide config specs: num/den <= 4.0.
CONFIG_DURATIONS = _duration_pairs(4.0)
# Bidirectional spec <-> token-ID registry, filled by _add_config below.
CONFIG_TO_ID: dict[tuple, int] = {}
ID_TO_CONFIG: dict[int, tuple] = {}
def _normalize_config_spec(spec) -> tuple:
return tuple(tuple(x) if isinstance(x, list) else x for x in spec)
def _add_config(spec: tuple) -> int:
    """Register *spec* as a config token and return its ID; known specs return their existing ID.

    A new spec gets ID CONFIG_BASE + (number of registered specs). Once the
    debug-name table and the tokenizer class exist, they are kept in sync, as
    is the module-level VOCAB_SIZE.
    """
    global VOCAB_SIZE
    key = _normalize_config_spec(spec)
    known = CONFIG_TO_ID.get(key)
    if known is not None:
        return known
    new_id = CONFIG_BASE + len(CONFIG_TO_ID)
    CONFIG_TO_ID[key] = new_id
    ID_TO_CONFIG[new_id] = key
    module_ns = globals()
    # _TOKEN_NAMES and MaiChartTokenizer are defined later in the module, so
    # they do not exist yet while the built-in vocabulary is being built.
    if "_TOKEN_NAMES" in module_ns:
        _TOKEN_NAMES[new_id] = "cfg_" + "_".join(str(part) for part in key)
    VOCAB_SIZE = CONFIG_BASE + len(CONFIG_TO_ID)
    if "MaiChartTokenizer" in module_ns:
        MaiChartTokenizer.vocab_size = VOCAB_SIZE
    return new_id
def _build_config_vocab() -> None:
    """Register the built-in config vocabulary: single notes, note pairs and slides.

    _add_config hands out IDs in call order, so the loop order below defines
    the ID of every built-in config token.
    NOTE(review): reordering these loops renumbers the vocabulary, which
    presumably invalidates checkpoints trained on the old IDs.
    """
    # Single button/touch events.
    for pos in range(1, 9):
        _add_config(("tap", pos))
        _add_config(("brk", pos))
        for dur in CONFIG_DURATIONS:
            _add_config(("hld", pos, dur[0], dur[1]))
        # Nested inside the pos loop; _add_config is idempotent, so the touch
        # configs are registered once, right after pos=1's tap/brk/hld entries.
        for region in sorted(TCH_TO_ID):
            _add_config(("tch", region))
    # Two-note simultaneous button configurations. Holds share one duration.
    button_types = ("tap", "brk", "hld")
    for p1 in range(1, 9):
        for p2 in range(p1 + 1, 9):
            for t1 in button_types:
                for t2 in button_types:
                    # p1 < p2 always, so this comparison is decided by the type
                    # strings: only pairs with t1 <= t2 ("brk" < "hld" < "tap")
                    # are kept. NOTE(review): e.g. ("tap", p1, "brk", p2) is
                    # dropped while ("brk", p1, "tap", p2) is kept — confirm intended.
                    if (t1, p1) > (t2, p2):
                        continue
                    if "hld" in (t1, t2):
                        for dur in CONFIG_DURATIONS:
                            _add_config(("pair", t1, p1, t2, p2, dur[0], dur[1]))
                    else:
                        _add_config(("pair", t1, p1, t2, p2))
    # Common slide configurations: 2- and 3-point paths with duration.
    for a in range(1, 9):
        for b in range(1, 9):
            if b == a:
                continue
            for dur in CONFIG_DURATIONS:
                _add_config(("sld", a, b, dur[0], dur[1]))
            for c in range(1, 9):
                if c in (a, b):
                    continue
                for dur in CONFIG_DURATIONS:
                    _add_config(("sld", a, b, c, dur[0], dur[1]))
# Populate the built-in config tokens at import time.
_build_config_vocab()
# Total vocabulary = fixed base range [0, CONFIG_BASE) + registered config tokens.
# _add_config keeps this value updated as further configs are learned/loaded.
VOCAB_SIZE = CONFIG_BASE + len(CONFIG_TO_ID)
TOKENIZER_VERSION = 3
def export_config_vocab() -> list[list]:
    """Return all registered config specs as lists, ordered by token ID.

    Feeding the result back to load_config_vocab in a process with the same
    built-in vocabulary re-creates the same ID assignment.
    """
    ordered_specs = sorted(CONFIG_TO_ID, key=CONFIG_TO_ID.__getitem__)
    return [list(spec) for spec in ordered_specs]
def load_config_vocab(specs: list) -> None:
    """Register every spec in *specs*, in order; already-known specs keep their IDs."""
    for raw_spec in specs:
        as_tuple = tuple(raw_spec)
        _add_config(as_tuple)
def save_config_vocab(path: str) -> None:
    """Write the ID-ordered config vocabulary (export_config_vocab) to *path* as UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(export_config_vocab(), handle, ensure_ascii=False)
def load_config_vocab_file(path: str) -> None:
    """Read a JSON list of config specs from *path* and register them via load_config_vocab."""
    with open(path, "r", encoding="utf-8") as handle:
        specs = json.load(handle)
    load_config_vocab(specs)
# --- Special tokens that start a multi-token group ---
# NOTE(review): SLD_END_TOKEN and SIM_END close a group rather than start one,
# and this set is not referenced elsewhere in the visible source — confirm its users.
_MULTI_TOKEN_STARTS = {DUR, SLD_BEG, SLD_END_TOKEN, SIM_BEG, SIM_END}
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Token name lookup (for debugging)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
_TOKEN_NAMES: dict[int, str] = {
PAD: "[PAD]",
BOS: "[BOS]",
EOS: "[EOS]",
SEP: "[SEP]",
MASK: "[MASK]",
RST: "[RST]",
DUR: "[DUR]",
SLD_BEG: "[SLD_BEG]",
SLD_END_TOKEN: "[SLD_END]",
SIM_BEG: "[SIM_BEG]",
SIM_END: "[SIM_END]",
META_BPM: "[META_BPM]",
META_DIFF: "[META_DIFF]",
META_LEVEL: "[META_LEVEL]",
META_GENRE: "[META_GENRE]",
META_END: "[META_END]",
SLD_MID: "[SLD_MID]",
HLD_ON: "[HLD_ON]",
SLD_ON: "[SLD_ON]",
}
for v, i in DIV_TO_ID.items():
_TOKEN_NAMES[i] = f"div_{v}"
for p, i in TAP_TO_ID.items():
_TOKEN_NAMES[i] = f"tap_{p}"
for p, i in BRK_TO_ID.items():
_TOKEN_NAMES[i] = f"brk_{p}"
for p, i in HLD_TO_ID.items():
_TOKEN_NAMES[i] = f"hld_{p}"
for p, i in SLD_TO_ID.items():
_TOKEN_NAMES[i] = f"sld_{p}"
for t, i in TCH_TO_ID.items():
_TOKEN_NAMES[i] = f"tch_{t}"
_TOKEN_NAMES[SIM_COUNT_2] = "sim_count_2"
for v, i in DUR_NUM_TO_ID.items():
_TOKEN_NAMES[i] = f"dur_num_{v}"
for v, i in DUR_DEN_TO_ID.items():
_TOKEN_NAMES[i] = f"dur_den_{v}"
for spec, i in CONFIG_TO_ID.items():
_TOKEN_NAMES[i] = "cfg_" + "_".join(str(x) for x in spec)
def token_name(token_id: int) -> str:
    """Return the debug name of *token_id*, or "<id>" when it has no registered name."""
    name = _TOKEN_NAMES.get(token_id)
    return f"<{token_id}>" if name is None else name
def _nearest(values: list[int], value: int) -> int:
return min(values, key=lambda x: abs(x - value))
def encode_duration_tokens(duration: tuple[int, int]) -> list[int]:
    """Encode a (num, den) duration as ``[DUR, dur_num_*, dur_den_*]``.

    Each component is clamped to >= 1 and snapped to the nearest value of its
    vocabulary (_DUR_NUM_VALUES / _DUR_DEN_VALUES).
    """
    num_raw, den_raw = duration[0], duration[1]
    num_value = _nearest(_DUR_NUM_VALUES, max(1, int(num_raw)))
    den_value = _nearest(_DUR_DEN_VALUES, max(1, int(den_raw)))
    return [DUR, DUR_NUM_TO_ID[num_value], DUR_DEN_TO_ID[den_value]]
def read_duration_tokens(tokens: list[int], start: int) -> Optional[tuple[int, int]]:
    """Read a ``DUR num den`` triple at tokens[start] and return (num, den), or None.

    None is returned when tokens[start] is not DUR or fewer than two tokens
    follow it. Parameter tokens (dur_num_*/dur_den_*) map directly to their
    values. Anything else is treated as a legacy raw integer: it is clamped
    and snapped onto the duration vocabularies.
    """
    if start + 2 >= len(tokens) or tokens[start] != DUR:
        return None
    num_tok, den_tok = tokens[start + 1], tokens[start + 2]
    if num_tok not in ID_TO_DUR_NUM or den_tok not in ID_TO_DUR_DEN:
        # Backward compatibility with old checkpoints/caches that used raw ints.
        legacy_num = _nearest(_DUR_NUM_VALUES, max(1, min(int(num_tok), 16)))
        legacy_den = _nearest(_DUR_DEN_VALUES, max(1, int(den_tok)))
        return legacy_num, legacy_den
    return ID_TO_DUR_NUM[num_tok], ID_TO_DUR_DEN[den_tok]
def make_sim_tokens(note_tokens: list[int]) -> list[int]:
    """Wrap simultaneous note tokens as ``SIM_BEG sim_count_2 a b SIM_END``.

    PAD/BOS/EOS are dropped first; with fewer than two tokens left, the
    filtered list is returned unwrapped. Only the first two tokens go inside
    the group. Any further tokens are appended after SIM_END without a
    wrapper, so MaiChartTokenizer.decode reads them back as separate notes.
    """
    kept = [tok for tok in note_tokens if tok not in (PAD, BOS, EOS)]
    if len(kept) < 2:
        return kept
    first, second, *rest = kept
    return [SIM_BEG, SIM_COUNT_2, first, second, SIM_END, *rest]
def _snap_config_duration(duration: tuple[int, int] | None) -> tuple[int, int]:
    """Snap *duration* onto a pair usable in config specs (always one of CONFIG_DURATIONS).

    A missing/empty duration becomes (1, 1). Otherwise each part is clamped
    to >= 1 and snapped to its vocabulary. A ratio above 4.0 is replaced by
    the CONFIG_DURATIONS pair whose ratio is closest to 4.0; on ties the
    smaller denominator wins.
    """
    if not duration:
        return (1, 1)
    num = _nearest(_DUR_NUM_VALUES, max(1, int(duration[0])))
    den = _nearest(_DUR_DEN_VALUES, max(1, int(duration[1])))
    if num / den <= 4.0:
        return num, den
    return min(CONFIG_DURATIONS, key=lambda pair: (abs(pair[0] / pair[1] - 4.0), pair[1]))
def config_token_for_note(note: TouchNote) -> int | None:
    """Return the existing config token that encodes *note* as one time slot, or None.

    Specs are only looked up here, never registered (learn_config_from_note
    does the registration). Besides the built-in single/pair/slide specs,
    this also finds the "touch_multi" and "multi" specs that
    learn_config_from_note / load_config_vocab add, so learned shapes are
    actually used by MaiChartTokenizer.encode. None is returned for rests,
    end markers and shapes without a registered spec.
    """
    if note.is_rest or note.is_end:
        return None
    if note.is_touch and note.touch_regions and not note.is_hold:
        regions = note.touch_regions
        if len(regions) == 1:
            cfg = CONFIG_TO_ID.get(("tch", regions[0]))
            if cfg is None:
                # learn_config_from_note stores unknown single regions as touch_multi.
                cfg = CONFIG_TO_ID.get(("touch_multi", regions[0]))
            return cfg
        cfg = CONFIG_TO_ID.get(("touch_multi", *sorted(regions)))
        if cfg is not None:
            return cfg
    if note.is_slide:
        path = note.slide_path or note.positions
        if len(path) >= 2:
            dur = _snap_config_duration(note.hold_duration)
            spec = ("sld", *path, dur[0], dur[1])
            if spec in CONFIG_TO_ID:
                return CONFIG_TO_ID[spec]
        return None
    if len(note.positions) == 1:
        pos = note.positions[0]
        if not (1 <= pos <= 8):
            return None
        if note.is_hold:
            dur = _snap_config_duration(note.hold_duration)
            return CONFIG_TO_ID.get(("hld", pos, dur[0], dur[1]))
        if note.is_break:
            return CONFIG_TO_ID.get(("brk", pos))
        return CONFIG_TO_ID.get(("tap", pos))
    typ = "hld" if note.is_hold else "brk" if note.is_break else "tap"
    if len(note.positions) == 2:
        p1, p2 = sorted(note.positions)
        if not (1 <= p1 <= 8 and 1 <= p2 <= 8):
            return None
        if typ == "hld":
            dur = _snap_config_duration(note.hold_duration)
            return CONFIG_TO_ID.get(("pair", "hld", p1, "hld", p2, dur[0], dur[1]))
        return CONFIG_TO_ID.get(("pair", typ, p1, typ, p2))
    if len(note.positions) > 2:
        # Same spec layout as learn_config_from_note (invalid positions dropped).
        positions = sorted(p for p in note.positions if 1 <= p <= 8)
        if len(positions) < 2:
            return None
        if note.is_hold:
            dur = _snap_config_duration(note.hold_duration)
            return CONFIG_TO_ID.get(("multi", typ, *positions, dur[0], dur[1]))
        return CONFIG_TO_ID.get(("multi", typ, *positions))
    return None
def learn_config_from_note(note: TouchNote) -> int | None:
    """Register a config token from a real chart note, preserving rare shapes.

    An existing token (config_token_for_note) is returned when there is one.
    Otherwise a "touch_multi", "sld" or "multi" spec is registered, depending
    on the note's kind. None is returned for rests/end markers, for slides
    with fewer than two points, and for button notes with fewer than two
    valid positions that have no existing token.
    """
    if note.is_rest or note.is_end:
        return None
    found = config_token_for_note(note)
    if found is not None:
        return found
    if note.is_touch and note.touch_regions and not note.is_hold:
        return _add_config(("touch_multi", *sorted(note.touch_regions)))
    if note.is_slide:
        path = note.slide_path or note.positions
        if len(path) < 2:
            return None
        snapped = _snap_config_duration(note.hold_duration)
        return _add_config(("sld", *path, snapped[0], snapped[1]))
    if len(note.positions) < 2:
        return None
    valid = sorted(p for p in note.positions if 1 <= p <= 8)
    if len(valid) < 2:
        return None
    kind = "hld" if note.is_hold else "brk" if note.is_break else "tap"
    spec = ("multi", kind, *valid)
    if note.is_hold:
        snapped = _snap_config_duration(note.hold_duration)
        spec += (snapped[0], snapped[1])
    if spec in CONFIG_TO_ID:
        return CONFIG_TO_ID[spec]
    if len(valid) == 2:
        # Two-note shapes are only represented by the built-in "pair" specs.
        return config_token_for_note(note)
    return _add_config(spec)
def learn_config_vocab_from_charts(charts) -> int:
    """Learn config specs from every note of every chart; return how many specs were new."""
    size_before = len(CONFIG_TO_ID)
    all_notes = (note for chart in charts for note in chart.notes)
    for note in all_notes:
        learn_config_from_note(note)
    return len(CONFIG_TO_ID) - size_before
def note_from_config_token(token_id: int, beat_div: int) -> TouchNote | None:
    """Rebuild a TouchNote at *beat_div* from a config token ID.

    Returns None when *token_id* is not a registered config token or its spec
    kind is unknown. Handled kinds: tap, brk, hld, tch, touch_multi, pair,
    sld and multi. Pairs with any hold become a hold; otherwise pairs with
    any break become a break, because TouchNote carries one flag per slot.
    """
    spec = ID_TO_CONFIG.get(token_id)
    if spec is None:
        return None
    note = TouchNote(beat_div=beat_div)
    kind, args = spec[0], spec[1:]
    if kind in ("tap", "brk", "hld"):
        note.positions = [int(args[0])]
        if kind == "brk":
            note.is_break = True
        elif kind == "hld":
            note.is_hold = True
            note.hold_duration = (int(args[1]), int(args[2]))
    elif kind == "tch":
        note.is_touch = True
        note.touch_regions = [str(args[0])]
    elif kind == "touch_multi":
        note.is_touch = True
        note.touch_regions = [str(region) for region in args]
        note.is_simultaneous = len(note.touch_regions) > 1
    elif kind == "pair":
        first_type, second_type = args[0], args[2]
        note.positions = [int(args[1]), int(args[3])]
        note.is_simultaneous = True
        if "hld" in (first_type, second_type):
            note.is_hold = True
            note.hold_duration = (int(args[4]), int(args[5]))
        elif "brk" in (first_type, second_type):
            note.is_break = True
    elif kind == "sld":
        *path, num, den = args
        note.positions = [int(point) for point in path]
        note.slide_path = list(note.positions)
        note.is_slide = True
        note.hold_duration = (int(num), int(den))
    elif kind == "multi":
        multi_type, rest = args[0], args[1:]
        if multi_type == "hld":
            *points, num, den = rest
            note.positions = [int(point) for point in points]
            note.is_hold = True
            note.hold_duration = (int(num), int(den))
        else:
            note.positions = [int(point) for point in rest]
            note.is_break = multi_type == "brk"
        note.is_simultaneous = len(note.positions) > 1
    else:
        return None
    return note
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Tokenizer class
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
class MaiChartTokenizer:
"""
Rule-based bidirectional tokenizer for maimai charts.
encode(chart) โ†’ list[int] # chart โ†’ tokens
decode(tokens) โ†’ Chart # tokens โ†’ chart (lossless)
"""
vocab_size: int = VOCAB_SIZE
pad_token_id: int = PAD
bos_token_id: int = BOS
eos_token_id: int = EOS
mask_token_id: int = MASK
# โ”€โ”€ Encode โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def encode(self, chart: Chart, add_bos: bool = True,
add_eos: bool = True) -> list[int]:
"""
Convert a Chart's notes into a token sequence.
Args:
chart: Parsed Chart from mai_parser.
add_bos: Prepend [BOS] token.
add_eos: Append [EOS] token.
Returns:
List of token IDs.
"""
tokens: list[int] = []
if add_bos:
tokens.append(BOS)
current_div = 4 # default beat division
for note in chart.notes:
# Update beat division if changed
if note.beat_div != current_div:
current_div = note.beat_div
div_id = DIV_TO_ID.get(current_div)
if div_id is not None:
tokens.append(div_id)
# Encode the note
tokens.extend(self._encode_note(note))
if add_eos:
tokens.append(EOS)
return tokens
def _encode_note(self, note: TouchNote) -> list[int]:
"""Encode a single TouchNote โ†’ list of token IDs."""
if note.is_end:
return [EOS]
if note.is_rest:
return [RST]
cfg = config_token_for_note(note)
if cfg is not None:
return [cfg]
# Touch note
if note.is_touch:
return self._encode_touch(note)
# Break note
if note.is_break:
result = []
for pos in note.positions:
if 1 <= pos <= 8:
result.append(BRK_TO_ID[pos])
return make_sim_tokens(result)
# Hold note
if note.is_hold:
result = []
for pos in note.positions:
if 1 <= pos <= 8:
result.append(HLD_TO_ID[pos])
result = make_sim_tokens(result)
if note.hold_duration:
result.extend(encode_duration_tokens(note.hold_duration))
return result
# Slide note
if note.is_slide:
result = []
# Collect all positions (slide path)
positions = list(note.positions)
if note.slide_path:
# Use slide_path if available (more accurate)
positions = note.slide_path
if len(positions) >= 2:
result.append(SLD_BEG)
result.append(len(positions))
for pos in positions:
if 1 <= pos <= 8:
result.append(SLD_TO_ID[pos])
result.append(SLD_END_TOKEN)
elif len(positions) == 1 and 1 <= positions[0] <= 8:
result.append(SLD_TO_ID[positions[0]])
if note.hold_duration:
result.extend(encode_duration_tokens(note.hold_duration))
return result
# Regular tap
if len(note.positions) > 1:
result = []
for pos in note.positions:
if 1 <= pos <= 8:
result.append(TAP_TO_ID[pos])
return make_sim_tokens(result)
# Single tap
for pos in note.positions:
if 1 <= pos <= 8:
return [TAP_TO_ID[pos]]
return [RST] # fallback
def _encode_touch(self, note: TouchNote) -> list[int]:
"""Encode a touch note."""
result = []
for region in note.touch_regions:
tid = TCH_TO_ID.get(region)
if tid is not None:
result.append(tid)
if len(result) > 1:
result = make_sim_tokens(result)
if note.is_hold and note.hold_duration:
result.extend(encode_duration_tokens(note.hold_duration))
return result if result else [RST]
# โ”€โ”€ Decode โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def decode(self, tokens: list[int]) -> Chart:
"""
Convert a token sequence back into a Chart.
Args:
tokens: List of token IDs (may include BOS/EOS).
Returns:
Reconstructed Chart (notes only; metadata not recoverable from tokens alone).
"""
notes: list[TouchNote] = []
current_div = 4
i = 0
while i < len(tokens):
tid = tokens[i]
# Skip BOS
if tid == BOS:
i += 1
continue
# End of sequence
if tid == EOS:
note = TouchNote(beat_div=current_div, raw="E")
note.is_end = True
notes.append(note)
i += 1
continue
# Beat division change
if tid in ID_TO_DIV:
current_div = ID_TO_DIV[tid]
i += 1
continue
# Rest
if tid == RST:
note = TouchNote(beat_div=current_div, raw="")
note.is_rest = True
notes.append(note)
i += 1
continue
cfg_note = note_from_config_token(tid, current_div)
if cfg_note is not None:
notes.append(cfg_note)
i += 1
continue
# Duration marker โ†’ read next 2 tokens as (beat, subdiv)
# (handled inline in note decoding below)
# Slide start
if tid == SLD_BEG:
i += 1
if i >= len(tokens):
break
if tokens[i] in ID_TO_SLD or tokens[i] in (SLD_MID, SLD_ON):
# Unfolded slide format used by training/inference:
# SLD_BEG [SLD_ON] sld_a [SLD_MID/SLD_ON sld_b ...] SLD_END [DUR beat subdiv]
positions = []
while i < len(tokens) and tokens[i] not in (SLD_END_TOKEN, DUR, EOS):
if tokens[i] in (SLD_MID, SLD_ON):
i += 1
continue
pt = tokens[i]
if pt in ID_TO_SLD:
positions.append(ID_TO_SLD[pt])
i += 1
continue
# Malformed slide: stop before consuming unrelated chart events.
break
if i < len(tokens) and tokens[i] == SLD_END_TOKEN:
i += 1
dur = self._read_dur(tokens, i)
if dur:
i += 3
note = TouchNote(beat_div=current_div, positions=positions)
note.is_slide = True
note.slide_path = list(positions)
if dur:
note.hold_duration = dur
notes.append(note)
continue
n_pts = tokens[i]
i += 1
positions = []
for _ in range(n_pts):
if i >= len(tokens):
break
pt = tokens[i]
if pt in ID_TO_SLD:
positions.append(ID_TO_SLD[pt])
i += 1
# Skip SLD_END
if i < len(tokens) and tokens[i] == SLD_END_TOKEN:
i += 1
# Check for optional duration
dur = self._read_dur(tokens, i)
if dur:
i += 3 # DUR + beat + subdiv
note = TouchNote(beat_div=current_div, positions=positions)
note.is_slide = True
note.slide_path = list(positions)
if dur:
note.hold_duration = dur
notes.append(note)
continue
# Simultaneous begin
if tid == SIM_BEG:
i += 1
if i >= len(tokens):
break
count_tok = tokens[i]
n_notes = 2 if count_tok == SIM_COUNT_2 else int(count_tok)
i += 1
sub_notes: list[TouchNote] = []
dur = None
while i < len(tokens) and tokens[i] not in (SIM_END, EOS):
sub_tid = tokens[i]
if sub_tid == DUR:
dur = self._read_dur(tokens, i)
break # DUR after SIM group
sub_note = self._decode_single_note(sub_tid, current_div)
if sub_note:
sub_notes.append(sub_note)
i += 1
if i < len(tokens) and tokens[i] == SIM_END:
i += 1
# Merge sub-notes into one simultaneous note
if sub_notes:
merged = sub_notes[0]
all_pos = []
has_hold = merged.is_hold
has_break = merged.is_break
is_touch = merged.is_touch
all_touch_regions = list(merged.touch_regions)
for sn in sub_notes:
all_pos.extend(sn.positions)
has_hold = has_hold or sn.is_hold
has_break = has_break or sn.is_break
is_touch = is_touch or sn.is_touch
all_touch_regions.extend(sn.touch_regions)
merged.positions = all_pos
merged.is_simultaneous = True
merged.touch_regions = all_touch_regions
merged.is_touch = is_touch
if dur:
merged.hold_duration = dur
# For touch holds, don't set is_hold
if not is_touch:
has_hold = True
merged.is_hold = has_hold and not is_touch
merged.is_break = has_break and not is_touch
notes.append(merged)
continue
# Duration marker (standalone, should not normally happen)
if tid == DUR:
i += 3 # skip DUR + 2 values
continue
# Slide end, SIM end (standalone โ€” skip)
if tid in (SLD_END_TOKEN, SIM_END):
i += 1
continue
# Single note token
note = self._decode_single_note(tid, current_div)
if note:
# Check if next token is DUR (for hold/slide duration)
dur = self._read_dur(tokens, i + 1)
if dur:
note.hold_duration = dur
# Only set is_hold if not already a slide/touch
if not note.is_slide and not note.is_touch and not note.is_break:
note.is_hold = True
i += 3 # skip DUR + beat + subdiv
notes.append(note)
i += 1
from mai_parser.models import Difficulty
chart = Chart(difficulty_index=0, difficulty=Difficulty.ReMASTER)
chart.notes = notes
chart.compute_stats()
return chart
def _decode_single_note(self, tid: int, beat_div: int) -> Optional[TouchNote]:
"""Decode a single note token (not part of a group)."""
note = TouchNote(beat_div=beat_div)
if tid in ID_TO_TAP:
note.positions = [ID_TO_TAP[tid]]
return note
if tid in ID_TO_BRK:
note.positions = [ID_TO_BRK[tid]]
note.is_break = True
return note
if tid in ID_TO_HLD:
note.positions = [ID_TO_HLD[tid]]
note.is_hold = True
return note
if tid in ID_TO_SLD:
note.positions = [ID_TO_SLD[tid]]
note.is_slide = True
return note
if tid in ID_TO_TCH:
region = ID_TO_TCH[tid]
note.is_touch = True
note.touch_regions = [region]
return note
return None
def _read_dur(self, tokens: list[int], start: int) -> Optional[tuple[int, int]]:
"""Try to read DUR beat subdiv from tokens[start:]. Returns (beat, subdiv) or None.
Clamps to reasonable ranges to filter out hallucinated durations."""
return read_duration_tokens(tokens, start)
# โ”€โ”€ Batch โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def encode_batch(self, charts: list[Chart], pad_to: Optional[int] = None,
add_bos: bool = True, add_eos: bool = True,
return_tensors: bool = False):
"""
Encode a batch of charts, padding to the same length.
Args:
charts: List of Chart objects.
pad_to: Pad all sequences to this length (auto-detect max if None).
add_bos: Prepend BOS.
add_eos: Append EOS.
return_tensors: If True, return torch.Tensor (requires torch).
Returns:
If return_tensors=False: (list[list[int]], list[int]) = (token_seqs, lengths)
If return_tensors=True: (Tensor[batch, max_len], Tensor[batch])
"""
seqs = [self.encode(c, add_bos=add_bos, add_eos=add_eos) for c in charts]
lengths = [len(s) for s in seqs]
max_len = max(lengths) if pad_to is None else pad_to
padded = []
for seq in seqs:
if len(seq) < max_len:
seq = seq + [PAD] * (max_len - len(seq))
padded.append(seq[:max_len])
if return_tensors:
try:
import torch
return torch.tensor(padded, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)
except ImportError:
raise ImportError("torch required for return_tensors=True")
return padded, lengths
# โ”€โ”€ Debug โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def tokens_to_str(self, tokens: list[int], max_show: int = 60) -> str:
"""Pretty-print a token sequence with context for raw parameter ids."""
parts = []
i = 0
shown = 0
while i < len(tokens) and shown < max_show:
tid = tokens[i]
if tid == DUR and i + 2 < len(tokens):
parts.append("[DUR]")
shown += 1
if shown < max_show:
parts.append(token_name(tokens[i + 1]))
shown += 1
if shown < max_show:
parts.append(token_name(tokens[i + 2]))
shown += 1
i += 3
continue
if tid == SIM_BEG and i + 1 < len(tokens):
parts.append("[SIM_BEG]")
shown += 1
if shown < max_show:
parts.append(token_name(tokens[i + 1]))
shown += 1
i += 2
continue
parts.append(token_name(tid))
shown += 1
i += 1
if i < len(tokens):
parts.append(f"... ({len(tokens) - i} more)")
return " ".join(parts)
def print_tokens(self, tokens: list[int], max_show: int = 60) -> None:
"""Print a token sequence."""
print(self.tokens_to_str(tokens, max_show))
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Metadata header builder
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def build_metadata_header(bpm: float, difficulty: int,
level_value: float, genre: int = 0) -> list[int]:
"""
Build a metadata header token sequence.
Format: [META_BPM] bpm_byte [META_DIFF] diff [META_LEVEL] level_byte [META_GENRE] genre [META_END]
This is prepended to chart tokens during training so the model
learns to associate metadata with chart style.
Args:
bpm: BPM value (e.g. 173.0)
difficulty: 0=BASIC..4=ReMASTER
level_value: e.g. 12.4
genre: Genre index
Returns:
List of token IDs.
"""
return [
META_BPM, int(bpm) // 2, # BPM 0-510 โ†’ 0-255
META_DIFF, difficulty,
META_LEVEL, min(255, int(level_value * 10)),
META_GENRE, genre,
META_END,
]
def encode_chart_with_header(chart: Chart, bpm: float, difficulty: int,
                             level_value: float, genre: int = 0) -> list[int]:
    """Encode *chart* as ``[BOS] + chart_tokens`` with grammar-friendly (unfolded) slides.

    No metadata header and no EOS are produced. bpm, difficulty, level_value
    and genre are accepted but unused here: metadata is passed to the model
    as separate condition inputs, not as chart tokens (the model learns
    difficulty from the diff_embed MoE routing). HLD_ON/SLD_ON and SLD_MID
    are inference context/helper tokens, not targets, and are not emitted.
    """
    body = MaiChartTokenizer().encode(chart, add_bos=False, add_eos=False)
    # add_eos=False skips the synthetic EOS, but a parsed chart may still end
    # with an end note (encoded as EOS). Trim only the trailing run: the same
    # ID (2) can appear mid-sequence as a raw count value, so filtering every
    # EOS would corrupt the stream.
    while body and body[-1] == EOS:
        body.pop()
    return [BOS] + unfold_slides(body)
def unfold_slides(tokens):
    """Unfold counted slides into grammar-friendly waypoint tokens.

    ``SLD_BEG n sld_a sld_b sld_c SLD_END`` → ``SLD_BEG sld_a sld_b sld_c SLD_END``

    A slide is unfolded only when 0 < n < 32 and the token right after the n
    waypoints really is SLD_END. Anything else (already-unfolded slides, a
    count that does not match the waypoints) is copied through unchanged, so
    no real event token (e.g. a following DUR) is overwritten.
    """
    result, i = [], 0
    while i < len(tokens):
        t = tokens[i]
        if t == SLD_BEG and i + 2 < len(tokens):
            n = tokens[i + 1]
            end = i + 2 + n  # index where SLD_END must sit
            if 0 < n < 32 and end < len(tokens) and tokens[end] == SLD_END_TOKEN:
                result.append(SLD_BEG)
                result.extend(tokens[i + 2:end])
                result.append(SLD_END_TOKEN)
                i = end + 1
                continue
        result.append(t)
        i += 1
    return result
def inject_ongoing_tokens(tokens: list[int]) -> list[int]:
    """Insert HLD_ON/SLD_ON markers at intermediate positions where a hold/slide is active.

    ``HLD_n DUR beat subdiv ...tokens...`` → HLD_ON is inserted before each
    following note-level token (outside DUR parameters) while the hold is
    active. The same applies to slides.
    These are informational — the model learns "a hold is ongoing here".
    During inference they are suppressed; the engine doesn't generate them.

    Durations are read with read_duration_tokens, so both dur_num_*/dur_den_*
    parameter tokens and legacy raw integers are understood. In counted
    slides (``SLD_BEG n ...``), the raw count n is copied through without
    being interpreted as a token. Holds/slides encoded as single config
    tokens are not tracked here.
    """
    result = []
    current_div = 4.0
    hold_beats = 0.0  # remaining beats of active hold
    slide_beats = 0.0  # remaining beats of active slide
    dur_skip = 0  # DUR parameter tokens still to copy verbatim
    i = 0
    while i < len(tokens):
        t = tokens[i]
        step = 4.0 / current_div
        # ── Inject ON tokens before note-level tokens ──
        _is_note = (t >= TAP_BASE and t != DUR) or t == RST or t in ID_TO_DIV
        if _is_note and dur_skip == 0:
            if hold_beats > 0:
                result.append(HLD_ON)
                hold_beats -= step
            if slide_beats > 0:
                result.append(SLD_ON)
                slide_beats -= step
        # ── Track hold/slide duration ──
        if dur_skip > 0:
            dur_skip -= 1
            result.append(t)
            i += 1
            continue
        if t in ID_TO_DIV:
            current_div = float(ID_TO_DIV[t])
        elif t == DUR:
            dur_skip = 2
        elif t in ID_TO_HLD:
            # ID_TO_HLD is keyed by token ID (HLD_TO_ID is keyed by position 1-8).
            dur = read_duration_tokens(tokens, i + 1)
            if dur is not None:
                # NOTE(review): treats the pair as (beats, subdivision). If it
                # follows maidata's [divisor:count] order (as _note_to_text
                # writes it), the length in quarter beats is 4 * dur[1] / dur[0].
                hold_beats = dur[0] / dur[1]
        elif t == SLD_BEG:
            # Find DUR after slide waypoints
            j = i + 2  # skip SLD_BEG + count / first waypoint
            while j < len(tokens) and tokens[j] != SLD_END_TOKEN and tokens[j] != DUR:
                j += 1
            if j < len(tokens) and tokens[j] == SLD_END_TOKEN:
                j += 1  # skip SLD_END
            dur = read_duration_tokens(tokens, j)
            if dur is not None:
                slide_beats = dur[0] / dur[1]
            nxt = tokens[i + 1] if i + 1 < len(tokens) else None
            if nxt is not None and nxt < TAP_BASE and nxt != DUR:
                # Counted form: the raw count can collide with special/div IDs
                # (e.g. 5 == div_1), so copy it without interpreting it.
                result.append(t)
                result.append(nxt)
                i += 2
                continue
        result.append(t)
        i += 1
    return result
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Chart โ†’ maidata text conversion
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
def notes_to_maitext(notes, bpm=150.0):
"""Convert TouchNote list back to maidata chart text.
Format: (173){4}1,2,3h[4:1],5b/8b,
"""
bpm_int = int(bpm)
current_div = 4
line = f"({bpm_int})"
measure = []
for note in notes:
if note.is_end:
if measure:
line += "{" + str(current_div) + "}" + ",".join(measure) + ","
return line + "\nE"
if note.beat_div != current_div:
if measure:
line += "{" + str(current_div) + "}" + ",".join(measure) + ","
measure = []
else:
line += "{" + str(current_div) + "},"
current_div = note.beat_div
measure.append(_note_to_text(note))
if measure:
line += "{" + str(current_div) + "}" + ",".join(measure) + ","
return line + "\nE"
def tokens_to_maitext(tokens, bpm=150.0):
    """Decode *tokens* with MaiChartTokenizer and render the notes as maidata.txt chart text."""
    decoded_chart = MaiChartTokenizer().decode(tokens)
    return notes_to_maitext(decoded_chart.notes, bpm)
def _note_to_text(note):
"""Single TouchNote โ†’ maidata text segment."""
if note.is_rest:
return ""
if note.is_touch:
# Normalize touch regions: C1..C8 โ†’ C, others keep (B7, E2, etc.)
regions = []
for r in note.touch_regions:
if r.startswith("C") and len(r) > 1:
regions.append("C")
else:
regions.append(r)
text = "/".join(regions)
if note.is_hold and note.hold_duration:
text += f"h[{note.hold_duration[0]}:{note.hold_duration[1]}]"
return text
if note.positions:
text = "/".join(str(p) for p in note.positions)
else:
return ""
if note.is_hold and note.hold_duration:
text += f"h[{note.hold_duration[0]}:{note.hold_duration[1]}]"
elif note.is_break:
text += "b"
elif note.is_slide and note.slide_path and len(note.slide_path) >= 2:
# Multi-segment slide: all use >
start = note.slide_path[0]
seg_start = 1
while seg_start < len(note.slide_path) and note.slide_path[seg_start] == start:
seg_start += 1
if seg_start >= len(note.slide_path):
return str(start)
text = str(start)
last_pos = start
for p in note.slide_path[seg_start:]:
if p != last_pos:
text += ">" + str(p)
last_pos = p
dur = note.hold_duration or (4, 1)
text += f"[{dur[0]}:{dur[1]}]"
elif note.is_slide and len(note.positions) >= 2:
text = str(note.positions[0]) + ">" + str(note.positions[1])
dur = note.hold_duration or (4, 1)
text += f"[{dur[0]}:{dur[1]}]"
if note.firework:
text += "x"
if note.is_star:
text += "*"
return text