""" Maimai Chart Tokenizer — rule-based bidirectional chart ↔ token conversion. Design: - BPM is NOT tokenized (computed separately by external program) - Beat division (div_N) tokens control note granularity - Each note event → 1~5 tokens, lossless round-trip Vocabulary size: 256 (0-255), with room for expansion. Usage: from tokenizer import MaiChartTokenizer tok = MaiChartTokenizer() tokens = tok.encode(chart) chart2 = tok.decode(tokens) # lossless """ from __future__ import annotations from dataclasses import dataclass import json from typing import Optional from mai_parser.models import Chart, TouchNote # ═══════════════════════════════════════════════════════════════════════ # Vocabulary definition # ═══════════════════════════════════════════════════════════════════════ # --- Special tokens (0-4) --- PAD = 0 BOS = 1 EOS = 2 SEP = 3 MASK = 4 _SPECIAL_END = 5 # --- Beat division tokens (5-16): {1} {2} {4} {8} {16} {32} {48} {64} {128} {192} {384} --- _DIV_VALUES = [1, 2, 4, 8, 16, 32, 48, 64, 128, 192, 384] DIV_BASE = _SPECIAL_END DIV_TO_ID: dict[int, int] = {v: DIV_BASE + i for i, v in enumerate(_DIV_VALUES)} ID_TO_DIV: dict[int, int] = {v: k for k, v in DIV_TO_ID.items()} DIV_END = DIV_BASE + len(_DIV_VALUES) # --- Rest token (17) --- RST = DIV_END _RST_END = RST + 1 # --- Duration marker (18): followed by 2 tokens [beat, subdiv] --- DUR = _RST_END _DUR_END = DUR + 1 # --- Tap tokens (19-26): tap_1 ~ tap_8 --- TAP_BASE = _DUR_END TAP_TO_ID = {i: TAP_BASE + i - 1 for i in range(1, 9)} ID_TO_TAP = {v: k for k, v in TAP_TO_ID.items()} TAP_END = TAP_BASE + 8 # --- Break tokens (27-34): brk_1 ~ brk_8 --- BRK_BASE = TAP_END BRK_TO_ID = {i: BRK_BASE + i - 1 for i in range(1, 9)} ID_TO_BRK = {v: k for k, v in BRK_TO_ID.items()} BRK_END = BRK_BASE + 8 # --- Hold tokens (35-42): hld_1 ~ hld_8 --- HLD_BASE = BRK_END HLD_TO_ID = {i: HLD_BASE + i - 1 for i in range(1, 9)} ID_TO_HLD = {v: k for k, v in HLD_TO_ID.items()} HLD_END = HLD_BASE + 8 # --- Slide waypoint tokens (43-50): sld_1 ~ sld_8 --- SLD_BASE = HLD_END SLD_TO_ID = {i: SLD_BASE + i - 1 for i in range(1, 9)} ID_TO_SLD = {v: k for k, v in SLD_TO_ID.items()} SLD_END = SLD_BASE + 8 # --- Slide control (51-52) --- SLD_BEG = SLD_END # slide start, next token = point count SLD_END_TOKEN = SLD_BEG + 1 # slide end marker _SLD_CTRL_END = SLD_END_TOKEN + 1 # --- Simultaneous control (53-54) --- SIM_BEG = _SLD_CTRL_END # simultaneous start, next token = note count SIM_END = SIM_BEG + 1 # simultaneous end _SIM_CTRL_END = SIM_END + 1 # --- Touch tokens (55-90): A1-A8, B1-B8, C1-C8, D1-D8, E1-E8 --- TCH_BASE = _SIM_CTRL_END _TOUCH_ZONES = ["A", "B", "C", "D", "E"] _tch_map: dict[str, int] = {} _idx = TCH_BASE for zone in _TOUCH_ZONES: for pos in range(1, 9): _tch_map[f"{zone}{pos}"] = _idx _idx += 1 # C (center touch without position) _tch_map["C"] = _idx _idx += 1 TCH_TO_ID = _tch_map ID_TO_TCH = {v: k for k, v in TCH_TO_ID.items()} TCH_END = _idx # --- Parameter tokens (96-120): context-safe count/duration values --- SIM_COUNT_2 = TCH_END _SIM_COUNT_END = SIM_COUNT_2 + 1 _DUR_NUM_VALUES = [1, 2, 3, 4, 6, 8, 12, 16] DUR_NUM_BASE = _SIM_COUNT_END DUR_NUM_TO_ID = {v: DUR_NUM_BASE + i for i, v in enumerate(_DUR_NUM_VALUES)} ID_TO_DUR_NUM = {v: k for k, v in DUR_NUM_TO_ID.items()} DUR_NUM_END = DUR_NUM_BASE + len(_DUR_NUM_VALUES) _DUR_DEN_VALUES = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64] DUR_DEN_BASE = DUR_NUM_END DUR_DEN_TO_ID = {v: DUR_DEN_BASE + i for i, v in enumerate(_DUR_DEN_VALUES)} ID_TO_DUR_DEN = {v: k for k, v in DUR_DEN_TO_ID.items()} DUR_DEN_END = DUR_DEN_BASE + len(_DUR_DEN_VALUES) # --- Metadata/header/helper tokens (220-231): kept for backward compatibility --- META_BPM = 220 # followed by BPM//2 as uint8 META_DIFF = 221 # followed by difficulty enum (0-4) META_LEVEL = 222 # followed by level*10 as uint8 META_GENRE = 223 # followed by genre id META_END = 224 # end of metadata header SLD_MID = 229 # slide intermediate waypoint HLD_ON = 230 # hold is ongoing at this beat position (informational) SLD_ON = 231 # slide is ongoing at this beat position (informational) # --- Whole time-slot configuration tokens (256+) --- CONFIG_BASE = 256 def _duration_pairs(max_beats: float = 4.0) -> list[tuple[int, int]]: pairs: list[tuple[int, int]] = [] for n in _DUR_NUM_VALUES: for d in _DUR_DEN_VALUES: if n / d <= max_beats + 1e-9: pairs.append((n, d)) return pairs CONFIG_DURATIONS = _duration_pairs(4.0) CONFIG_TO_ID: dict[tuple, int] = {} ID_TO_CONFIG: dict[int, tuple] = {} def _normalize_config_spec(spec) -> tuple: return tuple(tuple(x) if isinstance(x, list) else x for x in spec) def _add_config(spec: tuple) -> int: spec = _normalize_config_spec(spec) if spec in CONFIG_TO_ID: return CONFIG_TO_ID[spec] idx = CONFIG_BASE + len(CONFIG_TO_ID) CONFIG_TO_ID[spec] = idx ID_TO_CONFIG[idx] = spec if "_TOKEN_NAMES" in globals(): _TOKEN_NAMES[idx] = "cfg_" + "_".join(str(x) for x in spec) global VOCAB_SIZE VOCAB_SIZE = CONFIG_BASE + len(CONFIG_TO_ID) if "MaiChartTokenizer" in globals(): MaiChartTokenizer.vocab_size = VOCAB_SIZE return idx def _build_config_vocab() -> None: # Single button/touch events. for pos in range(1, 9): _add_config(("tap", pos)) _add_config(("brk", pos)) for dur in CONFIG_DURATIONS: _add_config(("hld", pos, dur[0], dur[1])) for region in sorted(TCH_TO_ID): _add_config(("tch", region)) # Two-note simultaneous button configurations. Holds share one duration. button_types = ("tap", "brk", "hld") for p1 in range(1, 9): for p2 in range(p1 + 1, 9): for t1 in button_types: for t2 in button_types: if (t1, p1) > (t2, p2): continue if "hld" in (t1, t2): for dur in CONFIG_DURATIONS: _add_config(("pair", t1, p1, t2, p2, dur[0], dur[1])) else: _add_config(("pair", t1, p1, t2, p2)) # Common slide configurations: 2- and 3-point paths with duration. for a in range(1, 9): for b in range(1, 9): if b == a: continue for dur in CONFIG_DURATIONS: _add_config(("sld", a, b, dur[0], dur[1])) for c in range(1, 9): if c in (a, b): continue for dur in CONFIG_DURATIONS: _add_config(("sld", a, b, c, dur[0], dur[1])) _build_config_vocab() VOCAB_SIZE = CONFIG_BASE + len(CONFIG_TO_ID) TOKENIZER_VERSION = 3 def export_config_vocab() -> list[list]: return [list(spec) for spec, _ in sorted(CONFIG_TO_ID.items(), key=lambda x: x[1])] def load_config_vocab(specs: list) -> None: for spec in specs: _add_config(tuple(spec)) def save_config_vocab(path: str) -> None: with open(path, "w", encoding="utf-8") as f: json.dump(export_config_vocab(), f, ensure_ascii=False) def load_config_vocab_file(path: str) -> None: with open(path, "r", encoding="utf-8") as f: load_config_vocab(json.load(f)) # --- Special tokens that start a multi-token group --- _MULTI_TOKEN_STARTS = {DUR, SLD_BEG, SLD_END_TOKEN, SIM_BEG, SIM_END} # ═══════════════════════════════════════════════════════════════════════ # Token name lookup (for debugging) # ═══════════════════════════════════════════════════════════════════════ _TOKEN_NAMES: dict[int, str] = { PAD: "[PAD]", BOS: "[BOS]", EOS: "[EOS]", SEP: "[SEP]", MASK: "[MASK]", RST: "[RST]", DUR: "[DUR]", SLD_BEG: "[SLD_BEG]", SLD_END_TOKEN: "[SLD_END]", SIM_BEG: "[SIM_BEG]", SIM_END: "[SIM_END]", META_BPM: "[META_BPM]", META_DIFF: "[META_DIFF]", META_LEVEL: "[META_LEVEL]", META_GENRE: "[META_GENRE]", META_END: "[META_END]", SLD_MID: "[SLD_MID]", HLD_ON: "[HLD_ON]", SLD_ON: "[SLD_ON]", } for v, i in DIV_TO_ID.items(): _TOKEN_NAMES[i] = f"div_{v}" for p, i in TAP_TO_ID.items(): _TOKEN_NAMES[i] = f"tap_{p}" for p, i in BRK_TO_ID.items(): _TOKEN_NAMES[i] = f"brk_{p}" for p, i in HLD_TO_ID.items(): _TOKEN_NAMES[i] = f"hld_{p}" for p, i in SLD_TO_ID.items(): _TOKEN_NAMES[i] = f"sld_{p}" for t, i in TCH_TO_ID.items(): _TOKEN_NAMES[i] = f"tch_{t}" _TOKEN_NAMES[SIM_COUNT_2] = "sim_count_2" for v, i in DUR_NUM_TO_ID.items(): _TOKEN_NAMES[i] = f"dur_num_{v}" for v, i in DUR_DEN_TO_ID.items(): _TOKEN_NAMES[i] = f"dur_den_{v}" for spec, i in CONFIG_TO_ID.items(): _TOKEN_NAMES[i] = "cfg_" + "_".join(str(x) for x in spec) def token_name(token_id: int) -> str: """Human-readable name for a token ID.""" return _TOKEN_NAMES.get(token_id, f"<{token_id}>") def _nearest(values: list[int], value: int) -> int: return min(values, key=lambda x: abs(x - value)) def encode_duration_tokens(duration: tuple[int, int]) -> list[int]: beat = _nearest(_DUR_NUM_VALUES, max(1, int(duration[0]))) den = _nearest(_DUR_DEN_VALUES, max(1, int(duration[1]))) return [DUR, DUR_NUM_TO_ID[beat], DUR_DEN_TO_ID[den]] def read_duration_tokens(tokens: list[int], start: int) -> Optional[tuple[int, int]]: if start + 2 >= len(tokens) or tokens[start] != DUR: return None num_tok = tokens[start + 1] den_tok = tokens[start + 2] if num_tok in ID_TO_DUR_NUM and den_tok in ID_TO_DUR_DEN: return ID_TO_DUR_NUM[num_tok], ID_TO_DUR_DEN[den_tok] # Backward compatibility with old checkpoints/caches that used raw ints. beat = _nearest(_DUR_NUM_VALUES, max(1, min(int(num_tok), 16))) den = _nearest(_DUR_DEN_VALUES, max(1, int(den_tok))) return beat, den def make_sim_tokens(note_tokens: list[int]) -> list[int]: note_tokens = [t for t in note_tokens if t not in (PAD, BOS, EOS)] if len(note_tokens) <= 1: return note_tokens result: list[int] = [] result.extend([SIM_BEG, SIM_COUNT_2, note_tokens[0], note_tokens[1], SIM_END]) result.extend(note_tokens[2:]) return result def _snap_config_duration(duration: tuple[int, int] | None) -> tuple[int, int]: if not duration: return (1, 1) n = _nearest(_DUR_NUM_VALUES, max(1, int(duration[0]))) d = _nearest(_DUR_DEN_VALUES, max(1, int(duration[1]))) if n / d > 4.0: return min(CONFIG_DURATIONS, key=lambda x: (abs((x[0] / x[1]) - 4.0), x[1])) return n, d def config_token_for_note(note: TouchNote) -> int | None: if note.is_rest or note.is_end: return None if note.is_touch and len(note.touch_regions) == 1 and not note.is_hold: return CONFIG_TO_ID.get(("tch", note.touch_regions[0])) if note.is_slide: path = note.slide_path or note.positions if len(path) >= 2: dur = _snap_config_duration(note.hold_duration) spec = ("sld", *path, dur[0], dur[1]) if spec in CONFIG_TO_ID: return CONFIG_TO_ID[spec] return None if len(note.positions) == 1: pos = note.positions[0] if not (1 <= pos <= 8): return None if note.is_hold: dur = _snap_config_duration(note.hold_duration) return CONFIG_TO_ID.get(("hld", pos, dur[0], dur[1])) if note.is_break: return CONFIG_TO_ID.get(("brk", pos)) return CONFIG_TO_ID.get(("tap", pos)) if len(note.positions) == 2: p1, p2 = sorted(note.positions) if not (1 <= p1 <= 8 and 1 <= p2 <= 8): return None typ = "hld" if note.is_hold else "brk" if note.is_break else "tap" if typ == "hld": dur = _snap_config_duration(note.hold_duration) return CONFIG_TO_ID.get(("pair", "hld", p1, "hld", p2, dur[0], dur[1])) return CONFIG_TO_ID.get(("pair", typ, p1, typ, p2)) return None def learn_config_from_note(note: TouchNote) -> int | None: """Register a config token from a real chart note, preserving rare shapes.""" if note.is_rest or note.is_end: return None existing = config_token_for_note(note) if existing is not None: return existing if note.is_touch and note.touch_regions and not note.is_hold: return _add_config(("touch_multi", *sorted(note.touch_regions))) if note.is_slide: path = note.slide_path or note.positions if len(path) >= 2: dur = _snap_config_duration(note.hold_duration) return _add_config(("sld", *path, dur[0], dur[1])) return None if len(note.positions) >= 2: positions = sorted(p for p in note.positions if 1 <= p <= 8) if len(positions) < 2: return None typ = "hld" if note.is_hold else "brk" if note.is_break else "tap" if note.is_hold: dur = _snap_config_duration(note.hold_duration) spec = ("multi", typ, *positions, dur[0], dur[1]) else: spec = ("multi", typ, *positions) if spec in CONFIG_TO_ID: return CONFIG_TO_ID[spec] if len(positions) == 2: return config_token_for_note(note) return _add_config(spec) return None def learn_config_vocab_from_charts(charts) -> int: before = len(CONFIG_TO_ID) for chart in charts: for note in chart.notes: learn_config_from_note(note) return len(CONFIG_TO_ID) - before def note_from_config_token(token_id: int, beat_div: int) -> TouchNote | None: spec = ID_TO_CONFIG.get(token_id) if spec is None: return None note = TouchNote(beat_div=beat_div) kind = spec[0] if kind == "tap": note.positions = [int(spec[1])] elif kind == "brk": note.positions = [int(spec[1])] note.is_break = True elif kind == "hld": note.positions = [int(spec[1])] note.is_hold = True note.hold_duration = (int(spec[2]), int(spec[3])) elif kind == "tch": note.is_touch = True note.touch_regions = [str(spec[1])] elif kind == "touch_multi": note.is_touch = True note.touch_regions = [str(x) for x in spec[1:]] note.is_simultaneous = len(note.touch_regions) > 1 elif kind == "pair": t1, p1, t2, p2 = spec[1], int(spec[2]), spec[3], int(spec[4]) note.positions = [p1, p2] note.is_simultaneous = True if t1 == "hld" or t2 == "hld": note.is_hold = True note.hold_duration = (int(spec[5]), int(spec[6])) elif t1 == "brk" or t2 == "brk": note.is_break = True elif kind == "sld": *path, n, d = spec[1:] note.positions = [int(x) for x in path] note.slide_path = list(note.positions) note.is_slide = True note.hold_duration = (int(n), int(d)) elif kind == "multi": typ = spec[1] if typ == "hld": *positions, n, d = spec[2:] note.positions = [int(x) for x in positions] note.is_hold = True note.hold_duration = (int(n), int(d)) else: note.positions = [int(x) for x in spec[2:]] note.is_break = typ == "brk" note.is_simultaneous = len(note.positions) > 1 else: return None return note # ═══════════════════════════════════════════════════════════════════════ # Tokenizer class # ═══════════════════════════════════════════════════════════════════════ class MaiChartTokenizer: """ Rule-based bidirectional tokenizer for maimai charts. encode(chart) → list[int] # chart → tokens decode(tokens) → Chart # tokens → chart (lossless) """ vocab_size: int = VOCAB_SIZE pad_token_id: int = PAD bos_token_id: int = BOS eos_token_id: int = EOS mask_token_id: int = MASK # ── Encode ────────────────────────────────────────────────────── def encode(self, chart: Chart, add_bos: bool = True, add_eos: bool = True) -> list[int]: """ Convert a Chart's notes into a token sequence. Args: chart: Parsed Chart from mai_parser. add_bos: Prepend [BOS] token. add_eos: Append [EOS] token. Returns: List of token IDs. """ tokens: list[int] = [] if add_bos: tokens.append(BOS) current_div = 4 # default beat division for note in chart.notes: # Update beat division if changed if note.beat_div != current_div: current_div = note.beat_div div_id = DIV_TO_ID.get(current_div) if div_id is not None: tokens.append(div_id) # Encode the note tokens.extend(self._encode_note(note)) if add_eos: tokens.append(EOS) return tokens def _encode_note(self, note: TouchNote) -> list[int]: """Encode a single TouchNote → list of token IDs.""" if note.is_end: return [EOS] if note.is_rest: return [RST] cfg = config_token_for_note(note) if cfg is not None: return [cfg] # Touch note if note.is_touch: return self._encode_touch(note) # Break note if note.is_break: result = [] for pos in note.positions: if 1 <= pos <= 8: result.append(BRK_TO_ID[pos]) return make_sim_tokens(result) # Hold note if note.is_hold: result = [] for pos in note.positions: if 1 <= pos <= 8: result.append(HLD_TO_ID[pos]) result = make_sim_tokens(result) if note.hold_duration: result.extend(encode_duration_tokens(note.hold_duration)) return result # Slide note if note.is_slide: result = [] # Collect all positions (slide path) positions = list(note.positions) if note.slide_path: # Use slide_path if available (more accurate) positions = note.slide_path if len(positions) >= 2: result.append(SLD_BEG) result.append(len(positions)) for pos in positions: if 1 <= pos <= 8: result.append(SLD_TO_ID[pos]) result.append(SLD_END_TOKEN) elif len(positions) == 1 and 1 <= positions[0] <= 8: result.append(SLD_TO_ID[positions[0]]) if note.hold_duration: result.extend(encode_duration_tokens(note.hold_duration)) return result # Regular tap if len(note.positions) > 1: result = [] for pos in note.positions: if 1 <= pos <= 8: result.append(TAP_TO_ID[pos]) return make_sim_tokens(result) # Single tap for pos in note.positions: if 1 <= pos <= 8: return [TAP_TO_ID[pos]] return [RST] # fallback def _encode_touch(self, note: TouchNote) -> list[int]: """Encode a touch note.""" result = [] for region in note.touch_regions: tid = TCH_TO_ID.get(region) if tid is not None: result.append(tid) if len(result) > 1: result = make_sim_tokens(result) if note.is_hold and note.hold_duration: result.extend(encode_duration_tokens(note.hold_duration)) return result if result else [RST] # ── Decode ────────────────────────────────────────────────────── def decode(self, tokens: list[int]) -> Chart: """ Convert a token sequence back into a Chart. Args: tokens: List of token IDs (may include BOS/EOS). Returns: Reconstructed Chart (notes only; metadata not recoverable from tokens alone). """ notes: list[TouchNote] = [] current_div = 4 i = 0 while i < len(tokens): tid = tokens[i] # Skip BOS if tid == BOS: i += 1 continue # End of sequence if tid == EOS: note = TouchNote(beat_div=current_div, raw="E") note.is_end = True notes.append(note) i += 1 continue # Beat division change if tid in ID_TO_DIV: current_div = ID_TO_DIV[tid] i += 1 continue # Rest if tid == RST: note = TouchNote(beat_div=current_div, raw="") note.is_rest = True notes.append(note) i += 1 continue cfg_note = note_from_config_token(tid, current_div) if cfg_note is not None: notes.append(cfg_note) i += 1 continue # Duration marker → read next 2 tokens as (beat, subdiv) # (handled inline in note decoding below) # Slide start if tid == SLD_BEG: i += 1 if i >= len(tokens): break if tokens[i] in ID_TO_SLD or tokens[i] in (SLD_MID, SLD_ON): # Unfolded slide format used by training/inference: # SLD_BEG [SLD_ON] sld_a [SLD_MID/SLD_ON sld_b ...] SLD_END [DUR beat subdiv] positions = [] while i < len(tokens) and tokens[i] not in (SLD_END_TOKEN, DUR, EOS): if tokens[i] in (SLD_MID, SLD_ON): i += 1 continue pt = tokens[i] if pt in ID_TO_SLD: positions.append(ID_TO_SLD[pt]) i += 1 continue # Malformed slide: stop before consuming unrelated chart events. break if i < len(tokens) and tokens[i] == SLD_END_TOKEN: i += 1 dur = self._read_dur(tokens, i) if dur: i += 3 note = TouchNote(beat_div=current_div, positions=positions) note.is_slide = True note.slide_path = list(positions) if dur: note.hold_duration = dur notes.append(note) continue n_pts = tokens[i] i += 1 positions = [] for _ in range(n_pts): if i >= len(tokens): break pt = tokens[i] if pt in ID_TO_SLD: positions.append(ID_TO_SLD[pt]) i += 1 # Skip SLD_END if i < len(tokens) and tokens[i] == SLD_END_TOKEN: i += 1 # Check for optional duration dur = self._read_dur(tokens, i) if dur: i += 3 # DUR + beat + subdiv note = TouchNote(beat_div=current_div, positions=positions) note.is_slide = True note.slide_path = list(positions) if dur: note.hold_duration = dur notes.append(note) continue # Simultaneous begin if tid == SIM_BEG: i += 1 if i >= len(tokens): break count_tok = tokens[i] n_notes = 2 if count_tok == SIM_COUNT_2 else int(count_tok) i += 1 sub_notes: list[TouchNote] = [] dur = None while i < len(tokens) and tokens[i] not in (SIM_END, EOS): sub_tid = tokens[i] if sub_tid == DUR: dur = self._read_dur(tokens, i) break # DUR after SIM group sub_note = self._decode_single_note(sub_tid, current_div) if sub_note: sub_notes.append(sub_note) i += 1 if i < len(tokens) and tokens[i] == SIM_END: i += 1 # Merge sub-notes into one simultaneous note if sub_notes: merged = sub_notes[0] all_pos = [] has_hold = merged.is_hold has_break = merged.is_break is_touch = merged.is_touch all_touch_regions = list(merged.touch_regions) for sn in sub_notes: all_pos.extend(sn.positions) has_hold = has_hold or sn.is_hold has_break = has_break or sn.is_break is_touch = is_touch or sn.is_touch all_touch_regions.extend(sn.touch_regions) merged.positions = all_pos merged.is_simultaneous = True merged.touch_regions = all_touch_regions merged.is_touch = is_touch if dur: merged.hold_duration = dur # For touch holds, don't set is_hold if not is_touch: has_hold = True merged.is_hold = has_hold and not is_touch merged.is_break = has_break and not is_touch notes.append(merged) continue # Duration marker (standalone, should not normally happen) if tid == DUR: i += 3 # skip DUR + 2 values continue # Slide end, SIM end (standalone — skip) if tid in (SLD_END_TOKEN, SIM_END): i += 1 continue # Single note token note = self._decode_single_note(tid, current_div) if note: # Check if next token is DUR (for hold/slide duration) dur = self._read_dur(tokens, i + 1) if dur: note.hold_duration = dur # Only set is_hold if not already a slide/touch if not note.is_slide and not note.is_touch and not note.is_break: note.is_hold = True i += 3 # skip DUR + beat + subdiv notes.append(note) i += 1 from mai_parser.models import Difficulty chart = Chart(difficulty_index=0, difficulty=Difficulty.ReMASTER) chart.notes = notes chart.compute_stats() return chart def _decode_single_note(self, tid: int, beat_div: int) -> Optional[TouchNote]: """Decode a single note token (not part of a group).""" note = TouchNote(beat_div=beat_div) if tid in ID_TO_TAP: note.positions = [ID_TO_TAP[tid]] return note if tid in ID_TO_BRK: note.positions = [ID_TO_BRK[tid]] note.is_break = True return note if tid in ID_TO_HLD: note.positions = [ID_TO_HLD[tid]] note.is_hold = True return note if tid in ID_TO_SLD: note.positions = [ID_TO_SLD[tid]] note.is_slide = True return note if tid in ID_TO_TCH: region = ID_TO_TCH[tid] note.is_touch = True note.touch_regions = [region] return note return None def _read_dur(self, tokens: list[int], start: int) -> Optional[tuple[int, int]]: """Try to read DUR beat subdiv from tokens[start:]. Returns (beat, subdiv) or None. Clamps to reasonable ranges to filter out hallucinated durations.""" return read_duration_tokens(tokens, start) # ── Batch ─────────────────────────────────────────────────────── def encode_batch(self, charts: list[Chart], pad_to: Optional[int] = None, add_bos: bool = True, add_eos: bool = True, return_tensors: bool = False): """ Encode a batch of charts, padding to the same length. Args: charts: List of Chart objects. pad_to: Pad all sequences to this length (auto-detect max if None). add_bos: Prepend BOS. add_eos: Append EOS. return_tensors: If True, return torch.Tensor (requires torch). Returns: If return_tensors=False: (list[list[int]], list[int]) = (token_seqs, lengths) If return_tensors=True: (Tensor[batch, max_len], Tensor[batch]) """ seqs = [self.encode(c, add_bos=add_bos, add_eos=add_eos) for c in charts] lengths = [len(s) for s in seqs] max_len = max(lengths) if pad_to is None else pad_to padded = [] for seq in seqs: if len(seq) < max_len: seq = seq + [PAD] * (max_len - len(seq)) padded.append(seq[:max_len]) if return_tensors: try: import torch return torch.tensor(padded, dtype=torch.long), torch.tensor(lengths, dtype=torch.long) except ImportError: raise ImportError("torch required for return_tensors=True") return padded, lengths # ── Debug ─────────────────────────────────────────────────────── def tokens_to_str(self, tokens: list[int], max_show: int = 60) -> str: """Pretty-print a token sequence with context for raw parameter ids.""" parts = [] i = 0 shown = 0 while i < len(tokens) and shown < max_show: tid = tokens[i] if tid == DUR and i + 2 < len(tokens): parts.append("[DUR]") shown += 1 if shown < max_show: parts.append(token_name(tokens[i + 1])) shown += 1 if shown < max_show: parts.append(token_name(tokens[i + 2])) shown += 1 i += 3 continue if tid == SIM_BEG and i + 1 < len(tokens): parts.append("[SIM_BEG]") shown += 1 if shown < max_show: parts.append(token_name(tokens[i + 1])) shown += 1 i += 2 continue parts.append(token_name(tid)) shown += 1 i += 1 if i < len(tokens): parts.append(f"... ({len(tokens) - i} more)") return " ".join(parts) def print_tokens(self, tokens: list[int], max_show: int = 60) -> None: """Print a token sequence.""" print(self.tokens_to_str(tokens, max_show)) # ═══════════════════════════════════════════════════════════════════════ # Metadata header builder # ═══════════════════════════════════════════════════════════════════════ def build_metadata_header(bpm: float, difficulty: int, level_value: float, genre: int = 0) -> list[int]: """ Build a metadata header token sequence. Format: [META_BPM] bpm_byte [META_DIFF] diff [META_LEVEL] level_byte [META_GENRE] genre [META_END] This is prepended to chart tokens during training so the model learns to associate metadata with chart style. Args: bpm: BPM value (e.g. 173.0) difficulty: 0=BASIC..4=ReMASTER level_value: e.g. 12.4 genre: Genre index Returns: List of token IDs. """ return [ META_BPM, int(bpm) // 2, # BPM 0-510 → 0-255 META_DIFF, difficulty, META_LEVEL, min(255, int(level_value * 10)), META_GENRE, genre, META_END, ] def encode_chart_with_header(chart: Chart, bpm: float, difficulty: int, level_value: float, genre: int = 0) -> list[int]: """Encode chart with grammar-friendly slides (no metadata header, no EOS). Metadata (BPM, difficulty, level, genre) is passed as separate condition inputs to the model — NOT as chart tokens. The model learns difficulty from the diff_embed MoE routing, not from token-level metadata. HLD_ON/SLD_ON and SLD_MID are inference context/helper tokens, not targets. Returns: [BOS] + chart_tokens """ tok = MaiChartTokenizer() chart_tokens = tok.encode(chart, add_bos=False, add_eos=False) # add_eos=False avoids appending a synthetic EOS, but parsed charts may # contain a terminal end note. Strip only terminal EOS tokens: raw numeric # values 1/2 are also used as SIM counts, so removing all EOS ids corrupts # simultaneous groups. while chart_tokens and chart_tokens[-1] == EOS: chart_tokens.pop() chart_tokens = unfold_slides(chart_tokens) return [BOS] + chart_tokens def unfold_slides(tokens): """Unfold multi-segment slides into grammar-friendly waypoint tokens. SLD_BEG n sld_a sld_b sld_c SLD_END → SLD_BEG sld_a sld_b sld_c SLD_END """ result, i = [], 0 while i < len(tokens): t = tokens[i] if t == SLD_BEG and i + 2 < len(tokens): n = tokens[i + 1] if 0 < n < 32 and i + 2 + n < len(tokens): pts = tokens[i + 2 : i + 2 + n] result.append(SLD_BEG) result.extend(pts) result.append(SLD_END_TOKEN) i += 2 + n + 1; continue result.append(t); i += 1 return result def inject_ongoing_tokens(tokens: list[int]) -> list[int]: """Insert HLD_ON/SLD_ON markers at intermediate positions where a hold/slide is active. HLD_n DUR beat subdiv ...tokens... → HLD_ON inserted at each non-DUR position while the hold is active. Same for slides. These are informational — the model learns "a hold is ongoing here". During inference they are suppressed; the engine doesn't generate them. """ result = [] current_div = 4.0 hold_beats = 0.0 # remaining beats of active hold slide_beats = 0.0 # remaining beats of active slide dur_skip = 0 # skip DUR parameter tokens i = 0 while i < len(tokens): t = tokens[i] step = 4.0 / current_div # ── Inject ON tokens before note-level tokens ── _is_note = (t >= TAP_BASE and t != DUR) or t == RST or t in ID_TO_DIV if _is_note and dur_skip == 0: if hold_beats > 0: result.append(HLD_ON) hold_beats -= step if slide_beats > 0: result.append(SLD_ON) slide_beats -= step # ── Track hold/slide duration ── if dur_skip > 0: dur_skip -= 1 result.append(t); i += 1 continue if t in ID_TO_DIV: current_div = float(ID_TO_DIV.get(t, current_div)) elif t == DUR: dur_skip = 2 elif t in HLD_TO_ID: # Check if followed by DUR if i + 3 < len(tokens) and tokens[i + 1] == DUR: beat = tokens[i + 2] subdiv = max(tokens[i + 3], 1) hold_beats = beat / subdiv elif t == SLD_BEG: # Find DUR after slide waypoints j = i + 2 # skip SLD_BEG + count while j < len(tokens) and tokens[j] != SLD_END_TOKEN and tokens[j] != DUR: j += 1 if j < len(tokens) and tokens[j] == SLD_END_TOKEN: j += 1 # skip SLD_END if j + 2 < len(tokens) and tokens[j] == DUR: beat = tokens[j + 1] subdiv = max(tokens[j + 2], 1) slide_beats = beat / subdiv result.append(t) i += 1 return result # ═══════════════════════════════════════════════════════════════════════ # Chart → maidata text conversion # ═══════════════════════════════════════════════════════════════════════ def notes_to_maitext(notes, bpm=150.0): """Convert TouchNote list back to maidata chart text. Format: (173){4}1,2,3h[4:1],5b/8b, """ bpm_int = int(bpm) current_div = 4 line = f"({bpm_int})" measure = [] for note in notes: if note.is_end: if measure: line += "{" + str(current_div) + "}" + ",".join(measure) + "," return line + "\nE" if note.beat_div != current_div: if measure: line += "{" + str(current_div) + "}" + ",".join(measure) + "," measure = [] else: line += "{" + str(current_div) + "}," current_div = note.beat_div measure.append(_note_to_text(note)) if measure: line += "{" + str(current_div) + "}" + ",".join(measure) + "," return line + "\nE" def tokens_to_maitext(tokens, bpm=150.0): """Token sequence → maidata.txt chart text.""" tok = MaiChartTokenizer() chart = tok.decode(tokens) return notes_to_maitext(chart.notes, bpm) def _note_to_text(note): """Single TouchNote → maidata text segment.""" if note.is_rest: return "" if note.is_touch: # Normalize touch regions: C1..C8 → C, others keep (B7, E2, etc.) regions = [] for r in note.touch_regions: if r.startswith("C") and len(r) > 1: regions.append("C") else: regions.append(r) text = "/".join(regions) if note.is_hold and note.hold_duration: text += f"h[{note.hold_duration[0]}:{note.hold_duration[1]}]" return text if note.positions: text = "/".join(str(p) for p in note.positions) else: return "" if note.is_hold and note.hold_duration: text += f"h[{note.hold_duration[0]}:{note.hold_duration[1]}]" elif note.is_break: text += "b" elif note.is_slide and note.slide_path and len(note.slide_path) >= 2: # Multi-segment slide: all use > start = note.slide_path[0] seg_start = 1 while seg_start < len(note.slide_path) and note.slide_path[seg_start] == start: seg_start += 1 if seg_start >= len(note.slide_path): return str(start) text = str(start) last_pos = start for p in note.slide_path[seg_start:]: if p != last_pos: text += ">" + str(p) last_pos = p dur = note.hold_duration or (4, 1) text += f"[{dur[0]}:{dur[1]}]" elif note.is_slide and len(note.positions) >= 2: text = str(note.positions[0]) + ">" + str(note.positions[1]) dur = note.hold_duration or (4, 1) text += f"[{dur[0]}:{dur[1]}]" if note.firework: text += "x" if note.is_star: text += "*" return text