| """ |
| Maimai Chart Tokenizer — rule-based bidirectional chart ↔ token conversion. |
| |
| Design: |
| - BPM is NOT tokenized (computed separately by external program) |
| - Beat division (div_N) tokens control note granularity |
| - Each note event → 1~5 tokens, lossless round-trip |
| |
| Base vocabulary: IDs 0-255. Config tokens are allocated from ID 256 upward, so the total vocabulary size grows with the number of registered configs. |
| |
| Usage: |
| from tokenizer import MaiChartTokenizer |
| |
| tok = MaiChartTokenizer() |
| tokens = tok.encode(chart) |
| chart2 = tok.decode(tokens) # lossless |
| """ |
|
|
| from __future__ import annotations |
| from dataclasses import dataclass |
| import json |
| from typing import Optional |
|
|
| from mai_parser.models import Chart, TouchNote |
|
|
| |
| |
| |
|
|
| |
# Vocabulary layout.  Every group below is allocated directly after the
# previous one (each *_BASE is the previous group's *_END), so the order of
# these statements determines the numeric IDs.

# Special tokens.
PAD = 0
BOS = 1
EOS = 2
SEP = 3
MASK = 4
_SPECIAL_END = 5  # first ID after the special tokens


# Beat-division tokens: one ID per entry of _DIV_VALUES.
_DIV_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 128, 192, 384]
DIV_BASE = _SPECIAL_END
DIV_TO_ID: dict[int, int] = {v: DIV_BASE + i for i, v in enumerate(_DIV_VALUES)}
ID_TO_DIV: dict[int, int] = {v: k for k, v in DIV_TO_ID.items()}
DIV_END = DIV_BASE + len(_DIV_VALUES)


# Rest token (a single ID).
RST = DIV_END
_RST_END = RST + 1


# Duration marker (a single ID); followed by a numerator and a denominator token.
DUR = _RST_END
_DUR_END = DUR + 1


# Tap tokens for button positions 1..8.
TAP_BASE = _DUR_END
TAP_TO_ID = {i: TAP_BASE + i - 1 for i in range(1, 9)}
ID_TO_TAP = {v: k for k, v in TAP_TO_ID.items()}
TAP_END = TAP_BASE + 8


# Break tokens for button positions 1..8.
BRK_BASE = TAP_END
BRK_TO_ID = {i: BRK_BASE + i - 1 for i in range(1, 9)}
ID_TO_BRK = {v: k for k, v in BRK_TO_ID.items()}
BRK_END = BRK_BASE + 8


# Hold tokens for button positions 1..8.
HLD_BASE = BRK_END
HLD_TO_ID = {i: HLD_BASE + i - 1 for i in range(1, 9)}
ID_TO_HLD = {v: k for k, v in HLD_TO_ID.items()}
HLD_END = HLD_BASE + 8


# Slide waypoint tokens for button positions 1..8.
SLD_BASE = HLD_END
SLD_TO_ID = {i: SLD_BASE + i - 1 for i in range(1, 9)}
ID_TO_SLD = {v: k for k, v in SLD_TO_ID.items()}
SLD_END = SLD_BASE + 8


# Slide group delimiters.  NOTE: SLD_END above is the end of the slide ID
# range; the closing delimiter token is SLD_END_TOKEN.
SLD_BEG = SLD_END
SLD_END_TOKEN = SLD_BEG + 1
_SLD_CTRL_END = SLD_END_TOKEN + 1


# Simultaneous-note group delimiters.
SIM_BEG = _SLD_CTRL_END
SIM_END = SIM_BEG + 1
_SIM_CTRL_END = SIM_END + 1


# Touch tokens: zones A-E times positions 1..8, then the bare region "C".
TCH_BASE = _SIM_CTRL_END
_TOUCH_ZONES = ["A", "B", "C", "D", "E"]


_tch_map: dict[str, int] = {}
_idx = TCH_BASE
for zone in _TOUCH_ZONES:
    for pos in range(1, 9):
        _tch_map[f"{zone}{pos}"] = _idx
        _idx += 1

# The region name "C" without a position number gets its own ID.
_tch_map["C"] = _idx
_idx += 1


TCH_TO_ID = _tch_map
ID_TO_TCH = {v: k for k, v in TCH_TO_ID.items()}
TCH_END = _idx


# Count token written right after SIM_BEG (a single ID).
SIM_COUNT_2 = TCH_END
_SIM_COUNT_END = SIM_COUNT_2 + 1


# Duration numerator values that have their own token.
_DUR_NUM_VALUES = [1, 2, 3, 4, 6, 8, 12, 16]
DUR_NUM_BASE = _SIM_COUNT_END
DUR_NUM_TO_ID = {v: DUR_NUM_BASE + i for i, v in enumerate(_DUR_NUM_VALUES)}
ID_TO_DUR_NUM = {v: k for k, v in DUR_NUM_TO_ID.items()}
DUR_NUM_END = DUR_NUM_BASE + len(_DUR_NUM_VALUES)


# Duration denominator values that have their own token.
_DUR_DEN_VALUES = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64]
DUR_DEN_BASE = DUR_NUM_END
DUR_DEN_TO_ID = {v: DUR_DEN_BASE + i for i, v in enumerate(_DUR_DEN_VALUES)}
ID_TO_DUR_DEN = {v: k for k, v in DUR_DEN_TO_ID.items()}
DUR_DEN_END = DUR_DEN_BASE + len(_DUR_DEN_VALUES)


# Fixed (not sequentially allocated) IDs: metadata markers and the
# slide/hold helper tokens.
META_BPM = 220
META_DIFF = 221
META_LEVEL = 222
META_GENRE = 223
META_END = 224
SLD_MID = 229
HLD_ON = 230
SLD_ON = 231


# First ID handed out to config tokens (see _add_config).
CONFIG_BASE = 256
|
|
|
|
def _duration_pairs(max_beats: float = 4.0) -> list[tuple[int, int]]:
    """List every (numerator, denominator) pair whose ratio is at most *max_beats*.

    Pairs are drawn from ``_DUR_NUM_VALUES`` x ``_DUR_DEN_VALUES`` and returned
    numerator-major, in the order those lists are declared.  A tiny epsilon is
    added to the limit so ratios equal to *max_beats* are kept despite float
    rounding.
    """
    limit = max_beats + 1e-9
    return [
        (num, den)
        for num in _DUR_NUM_VALUES
        for den in _DUR_DEN_VALUES
        if num / den <= limit
    ]
|
|
|
|
# Duration pairs (ratio <= 4.0) that built-in hold/slide config tokens may carry.
CONFIG_DURATIONS = _duration_pairs(4.0)
# Bidirectional config-spec <-> token-ID tables, filled in by _add_config.
CONFIG_TO_ID: dict[tuple, int] = {}
ID_TO_CONFIG: dict[int, tuple] = {}
|
|
|
|
def _normalize_config_spec(spec) -> tuple:
    """Return *spec* as a tuple, turning any list element into a tuple.

    Only the top level is converted; lists nested deeper are left untouched.
    This makes specs loaded from JSON (lists) usable as dictionary keys.
    """
    normalized = []
    for item in spec:
        if isinstance(item, list):
            item = tuple(item)
        normalized.append(item)
    return tuple(normalized)
|
|
|
|
def _add_config(spec: tuple) -> int:
    """Return the config-token ID for *spec*, registering the spec if it is new.

    A new spec receives the next free ID (``CONFIG_BASE`` plus the number of
    configs already registered).  Registration also refreshes ``VOCAB_SIZE``
    and, when they already exist at call time, the ``_TOKEN_NAMES`` table and
    ``MaiChartTokenizer.vocab_size``.  The existence checks are needed because
    this function runs during module import, before those names are defined.
    """
    global VOCAB_SIZE

    key = _normalize_config_spec(spec)
    known_id = CONFIG_TO_ID.get(key)
    if known_id is not None:
        return known_id

    new_id = CONFIG_BASE + len(CONFIG_TO_ID)
    CONFIG_TO_ID[key] = new_id
    ID_TO_CONFIG[new_id] = key

    module_names = globals()
    if "_TOKEN_NAMES" in module_names:
        _TOKEN_NAMES[new_id] = "cfg_" + "_".join(str(part) for part in key)

    VOCAB_SIZE = CONFIG_BASE + len(CONFIG_TO_ID)
    if "MaiChartTokenizer" in module_names:
        MaiChartTokenizer.vocab_size = VOCAB_SIZE
    return new_id
|
|
|
|
def _build_config_vocab() -> None:
    """Register the built-in config tokens.

    The registration order fixes the numeric IDs, so it must stay stable:

    1. per button position: tap, break, then one hold per config duration;
    2. single touch regions, in sorted name order;
    3. two-button pairs (first position lower than the second), with one
       entry per config duration when either side is a hold;
    4. two-point slides and, nested under each start/end pair, three-point
       slides with distinct positions, one entry per config duration.
    """
    # 1. Single-button notes.
    for button in range(1, 9):
        _add_config(("tap", button))
        _add_config(("brk", button))
        for num, den in CONFIG_DURATIONS:
            _add_config(("hld", button, num, den))

    # 2. Single touch regions.
    for region_name in sorted(TCH_TO_ID):
        _add_config(("tch", region_name))

    # 3. Two-button pairs.
    kinds = ("tap", "brk", "hld")
    for low in range(1, 9):
        for high in range(low + 1, 9):
            for kind_low in kinds:
                for kind_high in kinds:
                    # Skip combinations whose (type, position) order is reversed.
                    if (kind_low, low) > (kind_high, high):
                        continue
                    if "hld" not in (kind_low, kind_high):
                        _add_config(("pair", kind_low, low, kind_high, high))
                        continue
                    for num, den in CONFIG_DURATIONS:
                        _add_config(
                            ("pair", kind_low, low, kind_high, high, num, den)
                        )

    # 4. Slides with two or three waypoints.
    for start in range(1, 9):
        for end in range(1, 9):
            if end == start:
                continue
            for num, den in CONFIG_DURATIONS:
                _add_config(("sld", start, end, num, den))
            for third in range(1, 9):
                if third in (start, end):
                    continue
                for num, den in CONFIG_DURATIONS:
                    _add_config(("sld", start, end, third, num, den))
|
|
|
|
# Populate the built-in config vocabulary at import time.
_build_config_vocab()


# Total vocabulary size: every ID below CONFIG_BASE plus one ID per registered config.
VOCAB_SIZE = CONFIG_BASE + len(CONFIG_TO_ID)
TOKENIZER_VERSION = 3
|
|
|
|
def export_config_vocab() -> list[list]:
    """Return every registered config spec as a list, ordered by token ID."""
    by_id = sorted(CONFIG_TO_ID.items(), key=lambda entry: entry[1])
    exported = []
    for spec, _token_id in by_id:
        exported.append(list(spec))
    return exported
|
|
|
|
def load_config_vocab(specs: list) -> None:
    """Register each spec in *specs*, in order; already-known specs keep their ID."""
    for entry in specs:
        as_tuple = tuple(entry)
        _add_config(as_tuple)
|
|
|
|
def save_config_vocab(path: str) -> None:
    """Write the exported config vocabulary to *path* as UTF-8 JSON."""
    payload = export_config_vocab()
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)
|
|
|
|
def load_config_vocab_file(path: str) -> None:
    """Read a JSON config vocabulary from *path* and register its specs."""
    with open(path, "r", encoding="utf-8") as handle:
        specs = json.load(handle)
        load_config_vocab(specs)
|
|
| |
# Control-token IDs: the duration marker and the slide / simultaneous group
# delimiters.  Not referenced anywhere else in the visible part of this file.
_MULTI_TOKEN_STARTS = {DUR, SLD_BEG, SLD_END_TOKEN, SIM_BEG, SIM_END}
|
|
| |
| |
| |
|
|
# Human-readable names for every known token ID (used by token_name).
_TOKEN_NAMES: dict[int, str] = {
    PAD: "[PAD]",
    BOS: "[BOS]",
    EOS: "[EOS]",
    SEP: "[SEP]",
    MASK: "[MASK]",
    RST: "[RST]",
    DUR: "[DUR]",
    SLD_BEG: "[SLD_BEG]",
    SLD_END_TOKEN: "[SLD_END]",
    SIM_BEG: "[SIM_BEG]",
    SIM_END: "[SIM_END]",
    META_BPM: "[META_BPM]",
    META_DIFF: "[META_DIFF]",
    META_LEVEL: "[META_LEVEL]",
    META_GENRE: "[META_GENRE]",
    META_END: "[META_END]",
    SLD_MID: "[SLD_MID]",
    HLD_ON: "[HLD_ON]",
    SLD_ON: "[SLD_ON]",
}
# Parameterised groups are named "<prefix>_<value>".
for _prefix, _table in (
    ("div", DIV_TO_ID),
    ("tap", TAP_TO_ID),
    ("brk", BRK_TO_ID),
    ("hld", HLD_TO_ID),
    ("sld", SLD_TO_ID),
    ("tch", TCH_TO_ID),
):
    for _value, _token_id in _table.items():
        _TOKEN_NAMES[_token_id] = f"{_prefix}_{_value}"
_TOKEN_NAMES[SIM_COUNT_2] = "sim_count_2"
for _prefix, _table in (("dur_num", DUR_NUM_TO_ID), ("dur_den", DUR_DEN_TO_ID)):
    for _value, _token_id in _table.items():
        _TOKEN_NAMES[_token_id] = f"{_prefix}_{_value}"
# Config tokens registered so far; later registrations are named by _add_config.
for _spec, _token_id in CONFIG_TO_ID.items():
    _TOKEN_NAMES[_token_id] = "cfg_" + "_".join(str(_part) for _part in _spec)
|
|
|
|
def token_name(token_id: int) -> str:
    """Return the readable name of *token_id*, or ``<id>`` when it has none."""
    try:
        return _TOKEN_NAMES[token_id]
    except KeyError:
        return f"<{token_id}>"
|
|
|
|
def _nearest(values: list[int], value: int) -> int:
    """Return the element of *values* closest to *value*.

    On a tie the element that appears first in *values* wins.
    """
    def distance(candidate: int) -> int:
        return abs(candidate - value)

    return min(values, key=distance)
|
|
|
|
def encode_duration_tokens(duration: tuple[int, int]) -> list[int]:
    """Encode *duration* as ``[DUR, numerator token, denominator token]``.

    Each component is truncated to int, raised to at least 1 and snapped to
    the nearest value that has its own token, so values outside the tables
    are approximated rather than stored exactly.
    """
    raw_num, raw_den = int(duration[0]), int(duration[1])
    snapped_num = _nearest(_DUR_NUM_VALUES, max(1, raw_num))
    snapped_den = _nearest(_DUR_DEN_VALUES, max(1, raw_den))
    return [DUR, DUR_NUM_TO_ID[snapped_num], DUR_DEN_TO_ID[snapped_den]]
|
|
|
|
def read_duration_tokens(tokens: list[int], start: int) -> Optional[tuple[int, int]]:
    """Read a ``DUR num den`` triple beginning at ``tokens[start]``.

    Returns ``None`` when fewer than three tokens remain or ``tokens[start]``
    is not ``DUR``.  When both following tokens are duration vocabulary IDs
    they are mapped back to their values.  Otherwise the two tokens are
    treated as raw numbers: the first is limited to 1..16, the second to at
    least 1, and both are snapped to the nearest supported value.
    """
    if len(tokens) <= start + 2:
        return None
    if tokens[start] != DUR:
        return None

    first, second = tokens[start + 1], tokens[start + 2]
    if first in ID_TO_DUR_NUM and second in ID_TO_DUR_DEN:
        return ID_TO_DUR_NUM[first], ID_TO_DUR_DEN[second]

    # Fallback for triples that do not use the duration vocabulary.
    limited_num = max(1, min(int(first), 16))
    limited_den = max(1, int(second))
    return _nearest(_DUR_NUM_VALUES, limited_num), _nearest(_DUR_DEN_VALUES, limited_den)
|
|
|
|
def make_sim_tokens(note_tokens: list[int]) -> list[int]:
    """Wrap the first two note tokens in a simultaneous group.

    PAD/BOS/EOS are dropped first.  With zero or one token left the filtered
    list is returned as is.  Otherwise the result is
    ``SIM_BEG SIM_COUNT_2 first second SIM_END`` followed by any remaining
    tokens, which stay outside the group.
    """
    kept = [tok for tok in note_tokens if tok not in (PAD, BOS, EOS)]
    if len(kept) <= 1:
        return kept
    first, second, *remainder = kept
    return [SIM_BEG, SIM_COUNT_2, first, second, SIM_END, *remainder]
|
|
|
|
def _snap_config_duration(duration: tuple[int, int] | None) -> tuple[int, int]:
    """Snap *duration* to a pair usable in a config spec.

    A missing duration becomes ``(1, 1)``.  Otherwise each component is
    snapped to the nearest supported value; if the resulting ratio exceeds
    4.0, the config duration whose ratio is closest to 4.0 is returned
    instead (ties broken by the smaller denominator).
    """
    if not duration:
        return (1, 1)

    num = _nearest(_DUR_NUM_VALUES, max(1, int(duration[0])))
    den = _nearest(_DUR_DEN_VALUES, max(1, int(duration[1])))
    if num / den <= 4.0:
        return num, den

    def closeness_to_cap(pair: tuple[int, int]) -> tuple[float, int]:
        return (abs((pair[0] / pair[1]) - 4.0), pair[1])

    return min(CONFIG_DURATIONS, key=closeness_to_cap)
|
|
|
|
def config_token_for_note(note: TouchNote) -> int | None:
    """Look up the already-registered config token that represents *note*.

    Returns ``None`` for rests, end markers and any note shape without a
    registered config.  Nothing is registered here; see
    ``learn_config_from_note`` for that.
    """
    if note.is_rest or note.is_end:
        return None

    # A single, non-hold touch region.
    if note.is_touch and len(note.touch_regions) == 1 and not note.is_hold:
        return CONFIG_TO_ID.get(("tch", note.touch_regions[0]))

    # Slides never fall through to the button cases below.
    if note.is_slide:
        waypoints = note.slide_path or note.positions
        if len(waypoints) < 2:
            return None
        num, den = _snap_config_duration(note.hold_duration)
        return CONFIG_TO_ID.get(("sld", *waypoints, num, den))

    button_count = len(note.positions)

    if button_count == 1:
        button = note.positions[0]
        if not (1 <= button <= 8):
            return None
        if note.is_hold:
            num, den = _snap_config_duration(note.hold_duration)
            return CONFIG_TO_ID.get(("hld", button, num, den))
        kind = "brk" if note.is_break else "tap"
        return CONFIG_TO_ID.get((kind, button))

    if button_count == 2:
        low, high = sorted(note.positions)
        if not (1 <= low <= 8 and 1 <= high <= 8):
            return None
        # Both buttons share one type; hold takes precedence over break.
        if note.is_hold:
            num, den = _snap_config_duration(note.hold_duration)
            return CONFIG_TO_ID.get(("pair", "hld", low, "hld", high, num, den))
        kind = "brk" if note.is_break else "tap"
        return CONFIG_TO_ID.get(("pair", kind, low, kind, high))

    return None
|
|
|
|
def learn_config_from_note(note: TouchNote) -> int | None:
    """Return a config token for *note*, registering a new one for rare shapes.

    Shapes already covered by ``config_token_for_note`` reuse that token.
    Otherwise touch groups, slides and groups of valid button positions get a
    freshly registered spec.  Rests, end markers and shapes that fit none of
    these return ``None``.
    """
    if note.is_rest or note.is_end:
        return None

    already_known = config_token_for_note(note)
    if already_known is not None:
        return already_known

    # Non-hold touches: one spec per sorted set of regions.
    if note.is_touch and note.touch_regions and not note.is_hold:
        return _add_config(("touch_multi", *sorted(note.touch_regions)))

    if note.is_slide:
        waypoints = note.slide_path or note.positions
        if len(waypoints) < 2:
            return None
        num, den = _snap_config_duration(note.hold_duration)
        return _add_config(("sld", *waypoints, num, den))

    if len(note.positions) < 2:
        return None

    # Only positions on the 1..8 button ring take part in a "multi" spec.
    buttons = sorted(p for p in note.positions if 1 <= p <= 8)
    if len(buttons) < 2:
        return None

    kind = "hld" if note.is_hold else "brk" if note.is_break else "tap"
    if note.is_hold:
        num, den = _snap_config_duration(note.hold_duration)
        multi_spec = ("multi", kind, *buttons, num, den)
    else:
        multi_spec = ("multi", kind, *buttons)

    if multi_spec in CONFIG_TO_ID:
        return CONFIG_TO_ID[multi_spec]
    if len(buttons) == 2:
        # Two valid buttons are the "pair" case; never register them as "multi".
        return config_token_for_note(note)
    return _add_config(multi_spec)
|
|
|
|
def learn_config_vocab_from_charts(charts) -> int:
    """Learn config tokens from every note of *charts*.

    Returns the number of config specs that were newly registered.
    """
    size_before = len(CONFIG_TO_ID)
    for one_chart in charts:
        for one_note in one_chart.notes:
            learn_config_from_note(one_note)
    size_after = len(CONFIG_TO_ID)
    return size_after - size_before
|
|
|
|
def note_from_config_token(token_id: int, beat_div: int) -> TouchNote | None:
    """Build the note described by config token *token_id*.

    Returns ``None`` when the ID is not a registered config token or its spec
    has an unknown kind.  The note's ``beat_div`` is set to *beat_div*.
    """
    spec = ID_TO_CONFIG.get(token_id)
    if spec is None:
        return None

    note = TouchNote(beat_div=beat_div)
    kind, args = spec[0], spec[1:]

    if kind in ("tap", "brk", "hld"):
        note.positions = [int(args[0])]
        if kind == "brk":
            note.is_break = True
        elif kind == "hld":
            note.is_hold = True
            note.hold_duration = (int(args[1]), int(args[2]))
        return note

    if kind == "tch":
        note.is_touch = True
        note.touch_regions = [str(args[0])]
        return note

    if kind == "touch_multi":
        note.is_touch = True
        note.touch_regions = [str(region) for region in args]
        note.is_simultaneous = len(note.touch_regions) > 1
        return note

    if kind == "pair":
        first_kind, second_kind = args[0], args[2]
        note.positions = [int(args[1]), int(args[3])]
        note.is_simultaneous = True
        # The note carries a single type for both buttons: hold wins over break.
        if "hld" in (first_kind, second_kind):
            note.is_hold = True
            note.hold_duration = (int(args[4]), int(args[5]))
        elif "brk" in (first_kind, second_kind):
            note.is_break = True
        return note

    if kind == "sld":
        *waypoints, num, den = args
        note.positions = [int(p) for p in waypoints]
        note.slide_path = list(note.positions)
        note.is_slide = True
        note.hold_duration = (int(num), int(den))
        return note

    if kind == "multi":
        button_kind = args[0]
        if button_kind == "hld":
            *buttons, num, den = args[1:]
            note.positions = [int(p) for p in buttons]
            note.is_hold = True
            note.hold_duration = (int(num), int(den))
        else:
            note.positions = [int(p) for p in args[1:]]
            note.is_break = button_kind == "brk"
        note.is_simultaneous = len(note.positions) > 1
        return note

    return None
|
|
|
|
| |
| |
| |
|
|
class MaiChartTokenizer:
    """Rule-based bidirectional tokenizer for maimai charts.

    ``encode(chart)`` turns a chart's notes into a list of token IDs and
    ``decode(tokens)`` rebuilds a ``Chart`` from such a list.  Durations are
    snapped to the supported numerator/denominator values when encoding.
    """

    # Mirrors the module-level VOCAB_SIZE; _add_config keeps it up to date.
    vocab_size: int = VOCAB_SIZE
    pad_token_id: int = PAD
    bos_token_id: int = BOS
    eos_token_id: int = EOS
    mask_token_id: int = MASK

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------

    def encode(self, chart: Chart, add_bos: bool = True,
               add_eos: bool = True) -> list[int]:
        """Convert a chart's notes into a token sequence.

        Args:
            chart: Parsed chart whose ``notes`` are encoded in order.
            add_bos: Prepend ``[BOS]``.
            add_eos: Append ``[EOS]``.

        Returns:
            List of token IDs.  A division token is written whenever a note's
            ``beat_div`` differs from the current one (initially 4); a
            division with no token of its own is tracked but not written.
        """
        tokens: list[int] = []
        if add_bos:
            tokens.append(BOS)

        current_div = 4
        for note in chart.notes:
            if note.beat_div != current_div:
                current_div = note.beat_div
                div_id = DIV_TO_ID.get(current_div)
                if div_id is not None:
                    tokens.append(div_id)
            tokens.extend(self._encode_note(note))

        if add_eos:
            tokens.append(EOS)
        return tokens

    @staticmethod
    def _button_tokens(positions, table: dict[int, int]) -> list[int]:
        """Map every position in 1..8 through *table*, skipping the others."""
        return [table[pos] for pos in positions if 1 <= pos <= 8]

    def _encode_note(self, note: TouchNote) -> list[int]:
        """Encode a single note as one or more token IDs.

        A registered config token is preferred.  Otherwise the note is written
        with the per-type tokens.  A note that has no encodable position is
        written as ``[RST]`` so that it still produces a token.
        """
        if note.is_end:
            return [EOS]
        if note.is_rest:
            return [RST]

        cfg = config_token_for_note(note)
        if cfg is not None:
            return [cfg]

        if note.is_touch:
            return self._encode_touch(note)

        if note.is_break:
            return make_sim_tokens(self._button_tokens(note.positions, BRK_TO_ID)) or [RST]

        if note.is_hold:
            hold_ids = self._button_tokens(note.positions, HLD_TO_ID)
            if not hold_ids:
                return [RST]
            result = make_sim_tokens(hold_ids)
            if note.hold_duration:
                result.extend(encode_duration_tokens(note.hold_duration))
            return result

        if note.is_slide:
            return self._encode_slide(note)

        # Plain taps (one or several buttons).
        return make_sim_tokens(self._button_tokens(note.positions, TAP_TO_ID)) or [RST]

    def _encode_slide(self, note: TouchNote) -> list[int]:
        """Encode a slide as ``SLD_BEG count waypoints... SLD_END [DUR n d]``.

        The count token is the number of waypoint tokens actually written, so
        it stays consistent when out-of-range positions are skipped.  A slide
        with a single position is written as that one waypoint token.
        """
        positions = note.slide_path if note.slide_path else list(note.positions)
        waypoint_ids = self._button_tokens(positions, SLD_TO_ID)
        if not waypoint_ids:
            return [RST]

        if len(positions) >= 2:
            result = [SLD_BEG, len(waypoint_ids), *waypoint_ids, SLD_END_TOKEN]
        else:
            result = list(waypoint_ids)

        if note.hold_duration:
            result.extend(encode_duration_tokens(note.hold_duration))
        return result

    def _encode_touch(self, note: TouchNote) -> list[int]:
        """Encode a touch note; unknown regions are skipped, none left gives ``[RST]``."""
        result = []
        for region in note.touch_regions:
            tid = TCH_TO_ID.get(region)
            if tid is not None:
                result.append(tid)
        if len(result) > 1:
            result = make_sim_tokens(result)
        if note.is_hold and note.hold_duration:
            result.extend(encode_duration_tokens(note.hold_duration))
        return result if result else [RST]

    # ------------------------------------------------------------------
    # Decoding
    # ------------------------------------------------------------------

    def decode(self, tokens: list[int]) -> Chart:
        """Convert a token sequence back into a chart.

        Args:
            tokens: List of token IDs (may include BOS/EOS).

        Returns:
            A ``Chart`` holding the reconstructed notes.  Every ``[EOS]``
            becomes an end-marker note.  Tokens that are not recognised in
            their context are skipped.
        """
        notes: list[TouchNote] = []
        current_div = 4
        i = 0

        while i < len(tokens):
            tid = tokens[i]

            if tid == BOS:
                i += 1
                continue

            if tid == EOS:
                note = TouchNote(beat_div=current_div, raw="E")
                note.is_end = True
                notes.append(note)
                i += 1
                continue

            if tid in ID_TO_DIV:
                current_div = ID_TO_DIV[tid]
                i += 1
                continue

            if tid == RST:
                note = TouchNote(beat_div=current_div, raw="")
                note.is_rest = True
                notes.append(note)
                i += 1
                continue

            cfg_note = note_from_config_token(tid, current_div)
            if cfg_note is not None:
                notes.append(cfg_note)
                i += 1
                continue

            if tid == SLD_BEG:
                i += 1
                if i >= len(tokens):
                    break
                if tokens[i] in ID_TO_SLD or tokens[i] in (SLD_MID, SLD_ON):
                    # Form without a count: waypoints run up to SLD_END.
                    positions, i = self._read_unfolded_waypoints(tokens, i)
                else:
                    # Form with a count token in front of the waypoints.
                    positions, i = self._read_counted_waypoints(tokens, i)
                note, i = self._finish_slide(tokens, i, positions, current_div)
                notes.append(note)
                continue

            if tid == SIM_BEG:
                i += 1
                if i >= len(tokens):
                    break
                i += 1  # skip the count token that follows SIM_BEG
                sub_notes: list[TouchNote] = []
                dur = None
                while i < len(tokens) and tokens[i] not in (SIM_END, EOS):
                    sub_tid = tokens[i]
                    if sub_tid == DUR:
                        # A duration inside the group ends it; the DUR triple
                        # itself is skipped by the top-level DUR branch.
                        dur = self._read_dur(tokens, i)
                        break
                    sub_note = self._decode_single_note(sub_tid, current_div)
                    if sub_note:
                        sub_notes.append(sub_note)
                    i += 1
                if i < len(tokens) and tokens[i] == SIM_END:
                    i += 1
                    if dur is None:
                        # The encoder writes the shared duration right after
                        # SIM_END; attach it to the group.
                        dur = self._read_dur(tokens, i)
                        if dur:
                            i += 3
                if sub_notes:
                    notes.append(self._merge_sim_group(sub_notes, dur))
                continue

            if tid == DUR:
                # A duration triple with no note in front of it.
                i += 3
                continue

            if tid in (SLD_END_TOKEN, SIM_END):
                i += 1
                continue

            note = self._decode_single_note(tid, current_div)
            if note:
                dur = self._read_dur(tokens, i + 1)
                if dur:
                    note.hold_duration = dur
                    # A duration on a plain button note makes it a hold.
                    if not note.is_slide and not note.is_touch and not note.is_break:
                        note.is_hold = True
                    i += 3
                notes.append(note)
            i += 1

        from mai_parser.models import Difficulty
        chart = Chart(difficulty_index=0, difficulty=Difficulty.ReMASTER)
        chart.notes = notes
        chart.compute_stats()
        return chart

    @staticmethod
    def _read_unfolded_waypoints(tokens: list[int], i: int) -> tuple[list[int], int]:
        """Collect slide waypoints from ``tokens[i:]`` up to SLD_END, DUR or EOS.

        SLD_MID/SLD_ON tokens are skipped.  Reading stops at the first other
        token that is not a waypoint.  Returns ``(positions, next_index)``.
        """
        positions: list[int] = []
        while i < len(tokens) and tokens[i] not in (SLD_END_TOKEN, DUR, EOS):
            tok = tokens[i]
            if tok in (SLD_MID, SLD_ON):
                i += 1
            elif tok in ID_TO_SLD:
                positions.append(ID_TO_SLD[tok])
                i += 1
            else:
                break
        return positions, i

    @staticmethod
    def _read_counted_waypoints(tokens: list[int], i: int) -> tuple[list[int], int]:
        """Read a count token at ``tokens[i]`` and up to that many waypoints.

        Reading stops early at the first token that is not a waypoint, so a
        count larger than the number of waypoints does not consume the tokens
        after them.  Returns ``(positions, next_index)``.
        """
        n_pts = tokens[i]
        i += 1
        positions: list[int] = []
        for _ in range(n_pts):
            if i >= len(tokens) or tokens[i] not in ID_TO_SLD:
                break
            positions.append(ID_TO_SLD[tokens[i]])
            i += 1
        return positions, i

    def _finish_slide(self, tokens: list[int], i: int, positions: list[int],
                      current_div: int) -> tuple[TouchNote, int]:
        """Consume an optional SLD_END and DUR triple and build the slide note."""
        if i < len(tokens) and tokens[i] == SLD_END_TOKEN:
            i += 1
        dur = self._read_dur(tokens, i)
        if dur:
            i += 3
        note = TouchNote(beat_div=current_div, positions=positions)
        note.is_slide = True
        note.slide_path = list(positions)
        if dur:
            note.hold_duration = dur
        return note, i

    @staticmethod
    def _merge_sim_group(sub_notes: list[TouchNote],
                         dur: Optional[tuple[int, int]]) -> TouchNote:
        """Merge the notes of one simultaneous group into the first of them.

        Positions and touch regions are concatenated in order, each note
        contributing exactly once.  A group that contains a touch note is
        never marked hold or break.  A duration on a non-touch group marks it
        as a hold.
        """
        merged = sub_notes[0]
        all_pos: list[int] = []
        all_touch_regions: list[str] = []
        has_hold = has_break = is_touch = False
        for sn in sub_notes:
            all_pos.extend(sn.positions)
            all_touch_regions.extend(sn.touch_regions)
            has_hold = has_hold or sn.is_hold
            has_break = has_break or sn.is_break
            is_touch = is_touch or sn.is_touch

        merged.positions = all_pos
        merged.is_simultaneous = True
        merged.touch_regions = all_touch_regions
        merged.is_touch = is_touch
        if dur:
            merged.hold_duration = dur
            if not is_touch:
                has_hold = True
        merged.is_hold = has_hold and not is_touch
        merged.is_break = has_break and not is_touch
        return merged

    def _decode_single_note(self, tid: int, beat_div: int) -> Optional[TouchNote]:
        """Decode one tap/break/hold/slide/touch token; ``None`` for any other ID."""
        note = TouchNote(beat_div=beat_div)

        if tid in ID_TO_TAP:
            note.positions = [ID_TO_TAP[tid]]
            return note

        if tid in ID_TO_BRK:
            note.positions = [ID_TO_BRK[tid]]
            note.is_break = True
            return note

        if tid in ID_TO_HLD:
            note.positions = [ID_TO_HLD[tid]]
            note.is_hold = True
            return note

        if tid in ID_TO_SLD:
            note.positions = [ID_TO_SLD[tid]]
            note.is_slide = True
            return note

        if tid in ID_TO_TCH:
            note.is_touch = True
            note.touch_regions = [ID_TO_TCH[tid]]
            return note

        return None

    def _read_dur(self, tokens: list[int], start: int) -> Optional[tuple[int, int]]:
        """Read a ``DUR num den`` triple at ``tokens[start]``; ``None`` if absent.

        Thin wrapper around ``read_duration_tokens``.
        """
        return read_duration_tokens(tokens, start)

    # ------------------------------------------------------------------
    # Batching
    # ------------------------------------------------------------------

    def encode_batch(self, charts: list[Chart], pad_to: Optional[int] = None,
                     add_bos: bool = True, add_eos: bool = True,
                     return_tensors: bool = False):
        """Encode a batch of charts, padding (or truncating) to one length.

        Args:
            charts: List of Chart objects; may be empty.
            pad_to: Target length; the longest sequence is used when ``None``.
            add_bos: Prepend BOS.
            add_eos: Append EOS.
            return_tensors: If True, return ``torch.Tensor`` objects.

        Returns:
            ``(token_seqs, lengths)`` as lists, or as two long tensors when
            ``return_tensors`` is True.  ``lengths`` holds the lengths before
            padding/truncation, so an entry can exceed ``pad_to``.

        Raises:
            ImportError: ``return_tensors`` is True and torch is not installed.
        """
        seqs = [self.encode(c, add_bos=add_bos, add_eos=add_eos) for c in charts]
        lengths = [len(s) for s in seqs]
        # default=0 keeps an empty batch from raising ValueError.
        max_len = max(lengths, default=0) if pad_to is None else pad_to

        # A negative pad count yields no padding; the slice truncates long rows.
        padded = [(seq + [PAD] * (max_len - len(seq)))[:max_len] for seq in seqs]

        if return_tensors:
            try:
                import torch
            except ImportError as err:
                raise ImportError("torch required for return_tensors=True") from err
            return (torch.tensor(padded, dtype=torch.long),
                    torch.tensor(lengths, dtype=torch.long))

        return padded, lengths

    # ------------------------------------------------------------------
    # Debug helpers
    # ------------------------------------------------------------------

    def tokens_to_str(self, tokens: list[int], max_show: int = 60) -> str:
        """Render up to *max_show* tokens as space-separated names.

        A ``DUR`` triple and a ``SIM_BEG`` + count pair are consumed together.
        A trailing ``... (N more)`` is added when tokens are left over.
        """
        parts = []
        i = 0
        shown = 0
        while i < len(tokens) and shown < max_show:
            tid = tokens[i]
            if tid == DUR and i + 2 < len(tokens):
                parts.append("[DUR]")
                shown += 1
                if shown < max_show:
                    parts.append(token_name(tokens[i + 1]))
                    shown += 1
                if shown < max_show:
                    parts.append(token_name(tokens[i + 2]))
                    shown += 1
                i += 3
                continue
            if tid == SIM_BEG and i + 1 < len(tokens):
                parts.append("[SIM_BEG]")
                shown += 1
                if shown < max_show:
                    parts.append(token_name(tokens[i + 1]))
                    shown += 1
                i += 2
                continue
            parts.append(token_name(tid))
            shown += 1
            i += 1
        if i < len(tokens):
            parts.append(f"... ({len(tokens) - i} more)")
        return " ".join(parts)

    def print_tokens(self, tokens: list[int], max_show: int = 60) -> None:
        """Print the rendering produced by ``tokens_to_str``."""
        print(self.tokens_to_str(tokens, max_show))
|
|
|
|
| |
| |
| |
|
|
def build_metadata_header(bpm: float, difficulty: int,
                          level_value: float, genre: int = 0) -> list[int]:
    """Build the metadata header token sequence.

    Layout:
        [META_BPM] bpm//2 [META_DIFF] difficulty [META_LEVEL] level*10
        [META_GENRE] genre [META_END]

    Args:
        bpm: BPM value (e.g. 173.0); truncated to int, then halved.
        difficulty: Difficulty index, written unchanged.
        level_value: Chart level (e.g. 12.4); written as ``int(level * 10)``,
            capped at 255.
        genre: Genre index, written unchanged.

    Returns:
        List of nine token IDs / raw values.
    """
    bpm_value = int(bpm) // 2
    level_tenths = int(level_value * 10)
    level_value_capped = min(255, level_tenths)

    header = [META_BPM, bpm_value]
    header += [META_DIFF, difficulty]
    header += [META_LEVEL, level_value_capped]
    header += [META_GENRE, genre]
    header.append(META_END)
    return header
|
|
|
|
def encode_chart_with_header(chart: Chart, bpm: float, difficulty: int,
                             level_value: float, genre: int = 0) -> list[int]:
    """Encode *chart* as ``[BOS] + chart tokens`` with unfolded slides.

    Despite the name, no metadata header is written: ``bpm``, ``difficulty``,
    ``level_value`` and ``genre`` are accepted but not used here.  Trailing
    ``[EOS]`` tokens are stripped, and slides are rewritten by
    ``unfold_slides``.

    Returns:
        ``[BOS]`` followed by the chart tokens.
    """
    body = MaiChartTokenizer().encode(chart, add_bos=False, add_eos=False)

    # Find where the run of trailing EOS tokens (from end-marker notes) begins.
    keep = len(body)
    while keep > 0 and body[keep - 1] == EOS:
        keep -= 1

    return [BOS, *unfold_slides(body[:keep])]
|
|
|
|
def unfold_slides(tokens):
    """Drop the count token from counted slide groups.

    ``SLD_BEG n sld_a sld_b sld_c SLD_END`` becomes
    ``SLD_BEG sld_a sld_b sld_c SLD_END``.  A group is rewritten only when the
    token after ``SLD_BEG`` is a count in 1..31 and the sequence is long
    enough to hold that many waypoints plus one closing token.  The token in
    the closing position is replaced by ``SLD_END`` without being inspected.
    Everything else is copied through unchanged.
    """
    total = len(tokens)
    unfolded = []
    cursor = 0
    while cursor < total:
        current = tokens[cursor]
        if current == SLD_BEG and cursor + 2 < total:
            count = tokens[cursor + 1]
            first_waypoint = cursor + 2
            if 0 < count < 32 and first_waypoint + count < total:
                unfolded.append(SLD_BEG)
                unfolded.extend(tokens[first_waypoint:first_waypoint + count])
                unfolded.append(SLD_END_TOKEN)
                # Skip SLD_BEG, the count, the waypoints and the closing token.
                cursor = first_waypoint + count + 1
                continue
        unfolded.append(current)
        cursor += 1
    return unfolded
|
|
|
|
def _duration_ratio_at(tokens: list[int], dur_index: int):
    """Return the num/den ratio of the ``DUR`` triple at *dur_index*, or None.

    Duration vocabulary IDs (as written by ``encode_duration_tokens``) are
    mapped back to their values.  If either token is not such an ID, both are
    used as raw numbers, with the denominator raised to at least 1.
    """
    if dur_index + 2 >= len(tokens) or tokens[dur_index] != DUR:
        return None
    num_tok = tokens[dur_index + 1]
    den_tok = tokens[dur_index + 2]
    if num_tok in ID_TO_DUR_NUM and den_tok in ID_TO_DUR_DEN:
        return ID_TO_DUR_NUM[num_tok] / ID_TO_DUR_DEN[den_tok]
    return num_tok / max(den_tok, 1)


def inject_ongoing_tokens(tokens: list[int]) -> list[int]:
    """Insert HLD_ON/SLD_ON markers while a hold or slide is still running.

    A hold token followed by a ``DUR`` triple starts a hold, and a slide group
    followed by a ``DUR`` triple starts a slide.  While the remaining length
    is positive, a marker is written in front of every following note, rest
    or division token, and the remaining length shrinks by ``4 / current_div``
    each time.  The two tokens after ``DUR`` never get a marker.

    Args:
        tokens: Token sequence using the per-type hold/slide tokens.

    Returns:
        A new list with the markers inserted; *tokens* is not modified.
    """
    result = []
    current_div = 4.0
    hold_beats = 0.0
    slide_beats = 0.0
    dur_skip = 0  # how many DUR parameter tokens are still to be copied
    i = 0

    while i < len(tokens):
        t = tokens[i]
        step = 4.0 / current_div

        # Positions that advance an ongoing hold/slide.
        is_note = (t >= TAP_BASE and t != DUR) or t == RST or t in ID_TO_DIV
        if is_note and dur_skip == 0:
            if hold_beats > 0:
                result.append(HLD_ON)
                hold_beats -= step
            if slide_beats > 0:
                result.append(SLD_ON)
                slide_beats -= step

        # Copy the parameters of a DUR triple through untouched.
        if dur_skip > 0:
            dur_skip -= 1
            result.append(t)
            i += 1
            continue

        if t in ID_TO_DIV:
            current_div = float(ID_TO_DIV.get(t, current_div))
        elif t == DUR:
            dur_skip = 2
        elif t in ID_TO_HLD:
            # Hold token IDs are the keys of ID_TO_HLD (HLD_TO_ID is keyed by
            # button position, not by token ID).
            beats = _duration_ratio_at(tokens, i + 1)
            if beats is not None:
                hold_beats = beats
        elif t == SLD_BEG:
            # Scan past the group to its duration, if any.
            j = i + 2
            while j < len(tokens) and tokens[j] != SLD_END_TOKEN and tokens[j] != DUR:
                j += 1
            if j < len(tokens) and tokens[j] == SLD_END_TOKEN:
                j += 1
            beats = _duration_ratio_at(tokens, j)
            if beats is not None:
                slide_beats = beats

        result.append(t)
        i += 1

    return result
|
|
|
|
| |
| |
| |
|
|
def notes_to_maitext(notes, bpm=150.0):
    """Render a list of notes as maidata chart text.

    Example output: ``(173){4}1,2,3h[4:1],5b/8b,`` followed by a line ``E``.

    Notes are grouped into runs that share one beat division; each run is
    written as ``{div}`` plus its comma-terminated note texts.  A division
    change with no pending notes writes ``{old_div},``.  Output stops at the
    first end-marker note.
    """
    pieces = [f"({int(bpm)})"]
    active_div = 4
    pending = []

    for note in notes:
        if note.is_end:
            break
        if note.beat_div != active_div:
            # With nothing pending the join is empty, which gives "{div},".
            pieces.append("{" + str(active_div) + "}" + ",".join(pending) + ",")
            pending = []
            active_div = note.beat_div
        pending.append(_note_to_text(note))

    if pending:
        pieces.append("{" + str(active_div) + "}" + ",".join(pending) + ",")
    return "".join(pieces) + "\nE"
|
|
|
|
def tokens_to_maitext(tokens, bpm=150.0):
    """Decode *tokens* and render the resulting notes as maidata chart text."""
    decoded_chart = MaiChartTokenizer().decode(tokens)
    return notes_to_maitext(decoded_chart.notes, bpm)
|
|
|
|
def _note_to_text(note):
    """Render a single note as a maidata text segment.

    Rests and notes without positions give an empty string.  Simultaneous
    buttons or regions are joined with ``/``, and a hold (``h[a:b]``) or break
    (``b``) suffix is written on every one of them, e.g. ``5b/8b``.  Slides
    are written as ``start>p>q[a:b]``, with ``(4, 1)`` used when the note has
    no duration.
    """
    if note.is_rest:
        return ""

    hold_suffix = ""
    if note.is_hold and note.hold_duration:
        hold_suffix = f"h[{note.hold_duration[0]}:{note.hold_duration[1]}]"

    if note.is_touch:
        # Numbered C regions ("C1", "C2", ...) are written as plain "C".
        regions = [
            "C" if region.startswith("C") and len(region) > 1 else region
            for region in note.touch_regions
        ]
        if hold_suffix and regions:
            return "/".join(region + hold_suffix for region in regions)
        return "/".join(regions) + hold_suffix

    if not note.positions:
        return ""

    if hold_suffix:
        text = "/".join(f"{p}{hold_suffix}" for p in note.positions)
    elif note.is_break:
        text = "/".join(f"{p}b" for p in note.positions)
    elif note.is_slide and note.slide_path and len(note.slide_path) >= 2:
        path = note.slide_path
        start = path[0]
        # Skip leading repeats of the start position.
        first_move = 1
        while first_move < len(path) and path[first_move] == start:
            first_move += 1
        if first_move >= len(path):
            return str(start)
        parts = [str(start)]
        last_pos = start
        for p in path[first_move:]:
            if p != last_pos:  # collapse consecutive duplicates
                parts.append(str(p))
                last_pos = p
        dur = note.hold_duration or (4, 1)
        text = ">".join(parts) + f"[{dur[0]}:{dur[1]}]"
    elif note.is_slide and len(note.positions) >= 2:
        dur = note.hold_duration or (4, 1)
        text = f"{note.positions[0]}>{note.positions[1]}[{dur[0]}:{dur[1]}]"
    else:
        text = "/".join(str(p) for p in note.positions)

    if note.firework:
        text += "x"
    if note.is_star:
        text += "*"
    return text
|
|