""" Inference script for MaiGenerator — generate maimai charts from audio. Usage: python inference.py --checkpoint checkpoints/best.pt --audio datasets/10/track.mp3 python inference.py --checkpoint checkpoints/best.pt --song-id 10 python inference.py --checkpoint checkpoints/best.pt --audio track.mp3 --bpm 173 --diff MASTER --level 12.4 """ from __future__ import annotations import argparse import json from pathlib import Path import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm from model import MaiGenerator, BOS, EOS, DUR_TOKEN, AUDIO_FRAME_RATE from Tokenizer.MaiTrackTokenizer import MaiTrackTokenizer from tokenizer import ( MaiChartTokenizer, tokens_to_maitext, PAD, SIM_BEG, RST, TAP_TO_ID, BRK_TO_ID, HLD_TO_ID, SLD_TO_ID, TCH_TO_ID, ) from tokenizer import ID_TO_DUR_NUM, ID_TO_DUR_DEN, DUR_NUM_TO_ID, DUR_DEN_TO_ID from tokenizer import load_config_vocab, ID_TO_CONFIG from mai_parser import Difficulty, parse_maidata_file from grammar import NOTE_TOKENS from constrained_decode import ( DecodeState, snap_duration, clamp_duration_tokens, duration_report, simultaneous_report, sanitize_sim_tokens, ) def sample_token(logits: torch.Tensor, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0, prev_tokens: list[int] | None = None, penalty_window: int = 50) -> int: """Sample a token with temperature, top-k, top-p, and repetition penalty. Args: repetition_penalty: >1.0 penalizes repeated tokens; 1.0 = off. prev_tokens: Recent token history for repetition detection. penalty_window: How many recent tokens to check for repeats. """ # Repetition penalty: reduce logits for tokens that appear in recent history if repetition_penalty > 1.0 and prev_tokens: recent = set(prev_tokens[-penalty_window:]) for tok in recent: logits[tok] = logits[tok] / repetition_penalty logits = logits / max(temperature, 1e-8) logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9) if top_k > 0: k = min(top_k, logits.size(-1)) min_val = torch.topk(logits, k).values[-1] logits[logits < min_val] = float("-inf") if top_p < 1.0: sorted_l, sorted_i = torch.sort(logits, descending=True) cum = torch.cumsum(F.softmax(sorted_l, dim=-1), dim=-1) mask = cum > top_p mask[1:] = mask[:-1].clone() mask[0] = False sorted_l[mask] = float("-inf") logits = sorted_l.scatter(0, sorted_i, sorted_l) probs = F.softmax(logits, dim=-1) return torch.multinomial(probs, 1).item() def _config_note_count(spec: tuple) -> int: if not spec: return 0 kind = spec[0] if kind == "pair": return 2 if kind == "multi": return max(1, len(spec) - (4 if spec[1] == "hld" else 2)) if kind == "touch_multi": return max(1, len(spec) - 1) return 1 def _token_note_count(tok: int) -> int: if tok in NOTE_TOKENS: return 1 spec = ID_TO_CONFIG.get(tok) if spec is not None: return _config_note_count(spec) return 0 def _config_has_break(spec: tuple) -> bool: if not spec: return False if spec[0] == "brk": return True if spec[0] == "pair": return spec[1] == "brk" or spec[3] == "brk" if spec[0] == "multi": return len(spec) > 1 and spec[1] == "brk" return False def _config_type_counts(spec: tuple) -> dict[str, int]: if not spec: return {} kind = spec[0] if kind == "tap": return {"tap": 1} if kind == "brk": return {"break": 1} if kind == "hld": return {"hold": 1} if kind == "sld": return {"slide": 1} if kind in ("tch", "touch_multi"): return {"touch": 1} if kind == "pair": counts: dict[str, int] = {} for typ in (spec[1], spec[3]): key = "break" if typ == "brk" else "hold" if typ == "hld" else "tap" counts[key] = counts.get(key, 0) + 1 return counts if kind == "multi": typ = spec[1] key = "break" if typ == "brk" else "hold" if typ == "hld" else "tap" if typ == "hld": n = max(1, len(spec) - 4) else: n = max(1, len(spec) - 2) return {key: n} return {} def _token_has_break(tok: int) -> bool: if tok in BRK_TO_ID.values(): return True spec = ID_TO_CONFIG.get(tok) return spec is not None and _config_has_break(spec) def _break_ids(vocab_size: int, device: torch.device) -> torch.Tensor: ids = set(BRK_TO_ID.values()) ids.update(t for t, spec in ID_TO_CONFIG.items() if _config_has_break(spec)) valid = [int(t) for t in ids if 0 <= int(t) < vocab_size] return torch.tensor(valid, dtype=torch.long, device=device) def _density_note_ids(vocab_size: int, device: torch.device, include_break: bool = False) -> torch.Tensor: ids = set(NOTE_TOKENS) | set(ID_TO_CONFIG.keys()) if not include_break: ids = {t for t in ids if not _token_has_break(int(t))} valid = [int(t) for t in ids if 0 <= int(t) < vocab_size] return torch.tensor(valid, dtype=torch.long, device=device) def _type_bias_vector(vocab_size: int, device: torch.device, type_biases: dict[str, float]) -> torch.Tensor | None: active = {k: float(v) for k, v in type_biases.items() if abs(float(v)) > 1e-9} if not active: return None bias = torch.zeros(vocab_size, dtype=torch.float32, device=device) token_groups = [ ("tap", TAP_TO_ID.values()), ("hold", HLD_TO_ID.values()), ("slide", SLD_TO_ID.values()), ("break", BRK_TO_ID.values()), ("touch", TCH_TO_ID.values()), ] for name, ids in token_groups: value = active.get(name, 0.0) if value == 0.0: continue valid = [int(t) for t in ids if 0 <= int(t) < vocab_size] if valid: bias[torch.tensor(valid, dtype=torch.long, device=device)] += value if active.get("rest", 0.0) and 0 <= RST < vocab_size: bias[RST] += active["rest"] if active.get("sim", 0.0) and 0 <= SIM_BEG < vocab_size: bias[SIM_BEG] += active["sim"] for token_id, spec in ID_TO_CONFIG.items(): if not (0 <= int(token_id) < vocab_size): continue for name, count in _config_type_counts(spec).items(): value = active.get(name, 0.0) if value: bias[int(token_id)] += value * max(1, int(count)) return bias def _config_density_bias_vector(vocab_size: int, device: torch.device, strength: float, include_break: bool = False) -> torch.Tensor | None: if abs(float(strength)) <= 1e-9: return None bias = torch.zeros(vocab_size, dtype=torch.float32, device=device) for token_id, spec in ID_TO_CONFIG.items(): if not (0 <= int(token_id) < vocab_size): continue if not include_break and _config_has_break(spec): continue extra_notes = max(0, _config_note_count(spec) - 1) if extra_notes > 0: bias[int(token_id)] += float(strength) * extra_notes return bias def _auto_target_notes(level_value: float, duration_sec: float) -> int: # Conservative maimai-ish density prior. It is only used when density bias is enabled. notes_per_sec = float(np.interp( level_value, [10.0, 12.0, 13.0, 14.0, 14.5, 15.0], [2.2, 3.2, 4.0, 4.8, 5.3, 5.8], )) return max(1, int(round(notes_per_sec * duration_sec))) @torch.inference_mode() def generate_chart( model: MaiGenerator, audio_tokens: torch.Tensor, bpm: float, difficulty: int, level_value: float, genre: int = 0, max_tokens: int = 4096, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.95, eos_suppress: float = 1.5, min_division: int = 4, max_duration_beats: float = 4.0, max_sim: int = 2, sim_penalty: float = 1.0, repetition_penalty: float = 1.0, rest_streak_penalty: float = 0.0, rest_streak_threshold: int = 4, target_notes: int | None = None, density_bias: float = 0.0, break_penalty: float = 0.0, break_window: int = 64, max_break_ratio: float = 0.08, type_biases: dict[str, float] | None = None, config_density_bias: float = 0.0, config_density_include_break: bool = False, ) -> list[int]: """ Autoregressive chart generation from audio (with KV-cache for speed). Args: model: Trained MaiGenerator. audio_tokens: [1, T_aud] — EnCodec tokens. bpm: BPM value. difficulty: 0=BASIC..4=ReMASTER. level_value: e.g. 12.4. genre: Genre index. max_tokens: Max tokens to generate. temperature: Sampling temperature. top_k: Top-K filter. top_p: Nucleus filter. Returns: List of chart token IDs (including BOS, metadata header, and EOS). """ model.eval() device = next(model.parameters()).device param_dtype = next(model.parameters()).dtype if audio_tokens.dim() == 1: audio_tokens = audio_tokens.unsqueeze(0) audio_tokens = audio_tokens.to(device) # ── Encode audio (once) ── max_len = model.audio_pos_embed.num_embeddings # 32768 if audio_tokens.shape[1] > max_len: audio_tokens = audio_tokens[:, :max_len] T_aud = audio_tokens.shape[1] aud = model.audio_embed(audio_tokens) aud = aud + model.audio_pos_embed( torch.arange(T_aud, device=device).unsqueeze(0)) # Downsample audio (same as training forward) if model.audio_downsample > 1: aud = aud.transpose(1, 2) aud = model.audio_down(aud) aud = aud.transpose(1, 2) T_aud = aud.shape[1] aud = model.dropout(aud) aud_pos = torch.arange(T_aud, device=device, dtype=torch.float32) for blk in model.audio_encoder: aud = blk(aud, positions=aud_pos) aud = model.enc_norm(aud) # ── Conditions ── bpm_t = torch.tensor([[bpm]], device=device, dtype=param_dtype) diff_t = torch.tensor([[difficulty]], device=device) level_t = torch.tensor([[level_value]], device=device, dtype=param_dtype) genre_t = torch.tensor([[genre]], device=device) cond = (model.bpm_proj(bpm_t).unsqueeze(1) + model.diff_embed(diff_t.squeeze(-1)).unsqueeze(1) + model.level_proj(level_t).unsqueeze(1) + model.genre_embed(genre_t.squeeze(-1)).unsqueeze(1)) diff_vec = model.diff_embed(diff_t.squeeze(-1)) # [1, d_model] # Total beats audio_duration_sec = audio_tokens.shape[1] / 2 / AUDIO_FRAME_RATE total_beats = audio_duration_sec * bpm / 60.0 start_offset_beats = 2.0 * max(bpm, 30.0) / 60.0 print(f"Track: {audio_duration_sec:.0f}s = {total_beats:.0f} beats @ {bpm} BPM (start at beat {start_offset_beats:.1f})") if density_bias > 0 and target_notes is None: target_notes = _auto_target_notes(level_value, audio_duration_sec) if density_bias > 0 and target_notes: print(f"Density target: {target_notes} notes bias={density_bias}") # ── Init KV-cache ── model.init_kv_cache(batch_size=1, max_len=max_tokens + 1, enc_out=aud) # ── Prefill with BOS token ── tokens = [BOS] pos_cache = [0.0] tok_t = torch.tensor([[BOS]], device=device) emb = model.chart_embed(tok_t) + cond # cond adds batch, tok_t is [1,1] chart_pos = torch.tensor([[0.0]], device=device, dtype=torch.float32) x = emb for blk in model.chart_decoder: x = blk(x, enc_out=aud, self_positions=chart_pos, diff_emb=diff_vec, use_cache=True) # ── Autoregressive loop (KV-cached, one token at a time) ── decode_state = DecodeState( bpm=bpm, total_beats=total_beats, start_offset_beats=start_offset_beats, min_division=min_division, max_sim=max_sim, max_duration_beats=max_duration_beats, ) note_bonus_ids = _density_note_ids(model.chart_vocab_size, device, include_break=False) if density_bias > 0 else None break_ids = _break_ids(model.chart_vocab_size, device) if break_penalty > 0 else None type_bias = _type_bias_vector(model.chart_vocab_size, device, type_biases or {}) config_density_bias_vec = _config_density_bias_vector( model.chart_vocab_size, device, config_density_bias, include_break=config_density_include_break, ) if type_bias is not None: print("Type bias:", ", ".join(f"{k}={v:+.2f}" for k, v in (type_biases or {}).items() if abs(v) > 1e-9)) if config_density_bias_vec is not None: print(f"Config density bias: {config_density_bias:+.2f} " f"(include_break={config_density_include_break})") generated_notes = 0 generated_breaks = 0 pbar = tqdm(total=max_tokens, desc="Generating", unit="tok") for step_idx in range(max_tokens): x_last = model.dec_norm(x[:, -1:, :]) logits = model.output_head(x_last).squeeze(0).squeeze(0) # ── Constraints ── logits[EOS] = float("-inf") if decode_state.current_beat >= total_beats: pbar.update(1); pbar.set_postfix(tokens=len(tokens)+1, beat=f"{decode_state.current_beat:.0f}/{total_beats:.0f}", end="track") tokens.append(EOS) break logits = decode_state.apply_logits_mask(logits) if type_bias is not None: logits = logits + type_bias.to(dtype=logits.dtype) if config_density_bias_vec is not None: logits = logits + config_density_bias_vec.to(dtype=logits.dtype) if decode_state.can_start_sim() and SIM_BEG < logits.numel(): logits[SIM_BEG] -= sim_penalty if not torch.isfinite(logits).any(): pbar.update(1); pbar.set_postfix(tokens=len(tokens)+1, beat=f"{decode_state.current_beat:.0f}/{total_beats:.0f}", end="no_legal") tokens.append(EOS) break # ── Rest-streak penalty: if last N notes are all rests, suppress RST ── if rest_streak_penalty > 0 and len(tokens) >= rest_streak_threshold: recent = tokens[-rest_streak_threshold:] # Check if recent tokens are all RST (not counting structural tokens like DUR/SIM etc.) if all(t == RST for t in recent): logits[RST] -= rest_streak_penalty if density_bias > 0 and target_notes and note_bonus_ids is not None and note_bonus_ids.numel() > 0: progress = max(0.0, decode_state.current_beat - start_offset_beats) denom = max(1e-6, total_beats - start_offset_beats) expected = float(target_notes) * min(1.0, progress / denom) deficit = max(0.0, expected - generated_notes) / max(expected, 1.0) if deficit > 0: bonus = density_bias * deficit logits[note_bonus_ids] += bonus if RST < logits.numel(): logits[RST] -= bonus if break_penalty > 0 and break_ids is not None and break_ids.numel() > 0: recent = tokens[-break_window:] if break_window > 0 else tokens recent_notes = sum(_token_note_count(t) for t in recent) recent_breaks = sum(_token_note_count(t) for t in recent if _token_has_break(t)) global_ratio = generated_breaks / max(generated_notes, 1) recent_ratio = recent_breaks / max(recent_notes, 1) excess = max(global_ratio - max_break_ratio, recent_ratio - max_break_ratio, 0.0) if excess > 0: logits[break_ids] -= break_penalty * (1.0 + 8.0 * excess) next_tok = sample_token(logits, temperature, top_k, top_p, repetition_penalty=repetition_penalty, prev_tokens=tokens) tokens.append(next_tok) note_count = _token_note_count(next_tok) generated_notes += note_count if _token_has_break(next_tok): generated_breaks += note_count # ── Snap duration ── if len(tokens) >= 4 and tokens[-4] not in (DUR_TOKEN, PAD) and tokens[-3] == DUR_TOKEN: raw_num = ID_TO_DUR_NUM.get(tokens[-2], tokens[-2]) raw_den = ID_TO_DUR_DEN.get(next_tok, max(next_tok, 1)) beat, subdiv = snap_duration(raw_num, raw_den, decode_state.remaining_duration_beats(), decode_state.max_duration_beats) tokens[-2], tokens[-1] = DUR_NUM_TO_ID[beat], DUR_DEN_TO_ID[subdiv] next_tok = tokens[-1] # ── Record position ── bpm_v = max(bpm, 30.0) time_sec = decode_state.current_beat * 60.0 / bpm_v new_pos = time_sec * AUDIO_FRAME_RATE / model.audio_downsample pos_cache.append(new_pos) decode_state.step(next_tok) # ── Incremental decoder step (KV-cached) ── tok_t = torch.tensor([[next_tok]], device=device) emb = model.chart_embed(tok_t) + cond chart_pos = torch.tensor([[new_pos]], device=device, dtype=torch.float32) x = emb for blk in model.chart_decoder: x = blk(x, enc_out=aud, self_positions=chart_pos, diff_emb=diff_vec, use_cache=True) pbar.update(1) if len(tokens) % 200 == 0: pct = min(100, decode_state.current_beat / total_beats * 100) pbar.set_postfix(beat=f"{decode_state.current_beat:.0f}/{total_beats:.0f}", pct=f"{pct:.0f}%", div=f"{decode_state.div_value}") pbar.close() tokens = sanitize_sim_tokens(tokens, max_sim=max_sim) return clamp_duration_tokens(tokens, total_beats, start_offset_beats, min_division) # ═══════════════════════════════════════════════════════════════════════ # CLI # ═══════════════════════════════════════════════════════════════════════ DIFF_MAP = {"BASIC": 0, "ADVANCED": 1, "EXPERT": 2, "MASTER": 3, "REMASTER": 4} def main(): parser = argparse.ArgumentParser(description="Generate maimai chart from audio") parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint") parser.add_argument("--audio", type=str, default=None, help="Path to audio file (mp3/wav)") parser.add_argument("--song-id", type=str, default=None, help="Song ID in datasets/ (uses datasets//track.mp3)") parser.add_argument("--bpm", type=float, default=None, help="BPM (auto-detect from maidata.txt if --song-id)") parser.add_argument("--diff", type=str, default="MASTER", choices=list(DIFF_MAP.keys())) parser.add_argument("--level", type=float, default=12.0, help="Target level value") parser.add_argument("--genre", type=int, default=0) parser.add_argument("--max-tokens", type=int, default=4096) parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top-k", type=int, default=50) parser.add_argument("--top-p", type=float, default=0.95) parser.add_argument("--eos-suppress", type=float, default=1.5, help="Subtract from EOS logit (0=off)") parser.add_argument("--min-division", type=int, default=4, help="Minimum generated beat division; blocks div_1/div_2 by default") parser.add_argument("--max-duration-beats", type=float, default=4.0, help="Maximum generated hold/slide duration in beats") parser.add_argument("--max-sim", type=int, default=2, help="Maximum simultaneous press count") parser.add_argument("--sim-penalty", type=float, default=1.0, help="Subtract from SIM_BEG logit before constrained sampling") parser.add_argument("--rep-penalty", type=float, default=1.0, help="Repetition penalty (>1.0=penalize repeats, 1.0=off)") parser.add_argument("--rest-streak-penalty", type=float, default=0.0, help="Penalize RST when last N tokens are all rests (try 2.0-5.0)") parser.add_argument("--rest-streak-threshold", type=int, default=4, help="Consecutive RSTs before applying rest-streak-penalty") parser.add_argument("--target-notes", type=int, default=None, help="Optional soft target note count for density-guided sampling") parser.add_argument("--density-bias", type=float, default=0.0, help="Softly boost note/config tokens when generated density lags target") parser.add_argument("--break-penalty", type=float, default=1.5, help="Type-level penalty when BREAK/绝赞 ratio is too high") parser.add_argument("--break-window", type=int, default=64, help="Recent token window used by BREAK ratio penalty") parser.add_argument("--max-break-ratio", type=float, default=0.08, help="Soft maximum BREAK/绝赞 ratio before penalty kicks in") parser.add_argument("--tap-bias", type=float, default=0.0, help="Logit bias for TAP config/tokens (+ increases, - decreases)") parser.add_argument("--hold-bias", type=float, default=0.0, help="Logit bias for HOLD config/tokens") parser.add_argument("--slide-bias", type=float, default=0.0, help="Logit bias for SLIDE config/tokens") parser.add_argument("--break-bias", type=float, default=0.0, help="Logit bias for BREAK/绝赞 config/tokens") parser.add_argument("--touch-bias", type=float, default=0.0, help="Logit bias for TOUCH config/tokens") parser.add_argument("--rest-bias", type=float, default=0.0, help="Logit bias for REST tokens") parser.add_argument("--sim-bias", type=float, default=0.0, help="Logit bias for simultaneous group start") parser.add_argument("--config-density-bias", type=float, default=0.0, help="Bias dense config tokens by extra notes per time slot") parser.add_argument("--config-density-include-break", action="store_true", help="Allow config-density-bias to boost BREAK/绝赞 configs too") parser.add_argument("--output", type=str, default=None, help="Output path for generated maidata.txt") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--precision", type=str, default="float32", choices=["float32", "float16", "bfloat16"], help="Model precision (float16/bfloat16 ≈2x faster on supported GPUs)") parser.add_argument("--compile", action="store_true", help="Use torch.compile for faster execution (PyTorch 2.0+)") args = parser.parse_args() device = torch.device(args.device if torch.cuda.is_available() else "cpu") print(f"Device: {device}") # ── Load model ── ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False) config = ckpt.get("config", {}) if "config_vocab" in config: load_config_vocab(config["config_vocab"]) print(f"Loaded config vocab: total={MaiChartTokenizer.vocab_size}") model = MaiGenerator( d_model=config.get("d_model", 512), enc_layers=config.get("enc_layers", 6), dec_layers=config.get("dec_layers", 8), heads=config.get("heads", 8), d_ff=config.get("d_ff", 2048), use_moe=config.get("use_moe", True), n_experts=config.get("n_experts", 6), moe_layers=config.get("moe_layers", None), ).to(device) model_state = model.state_dict() filtered_state = { k: v for k, v in ckpt["model_state"].items() if k in model_state and tuple(v.shape) == tuple(model_state[k].shape) } skipped = len(ckpt["model_state"]) - len(filtered_state) incompatible = model.load_state_dict(filtered_state, strict=False) if skipped: print(f"Model warm-start: skipped {skipped} shape-mismatched params") if incompatible.missing_keys: print(f"Model warm-start: initialized {len(incompatible.missing_keys)} new params") if incompatible.unexpected_keys: print(f"Model warm-start: ignored {len(incompatible.unexpected_keys)} checkpoint params") model.eval() print(f"Model loaded (epoch {ckpt.get('epoch', '?')})") # ── Precision ── if args.precision != "float32": dt = {"float16": torch.float16, "bfloat16": torch.bfloat16}[args.precision] model = model.to(dtype=dt) print(f"Precision: {args.precision}") # ── torch.compile ── if args.compile: try: model = torch.compile(model, mode="reduce-overhead") print("Model compiled with torch.compile (reduce-overhead)") except Exception as e: print(f"torch.compile failed: {e}, continuing without compile") # ── Load / tokenize audio ── audio_tok = MaiTrackTokenizer(n_layers=2, device=str(device)) if args.song_id: audio_path = Path("datasets") / args.song_id / "track.mp3" if not audio_path.exists(): print(f"ERROR: {audio_path} not found") return print(f"Audio: {audio_path}") audio_tokens = audio_tok.encode(str(audio_path), add_bos=False, add_eos=False) # Auto-detect BPM from maidata if args.bpm is None: mdata = Path("datasets") / args.song_id / "maidata.txt" if mdata.exists(): song = parse_maidata_file(str(mdata)) bpm = song.bpm if song.bpm > 0 else 150.0 print(f"Auto-detected BPM: {bpm}") else: bpm = 150.0 else: bpm = args.bpm elif args.audio: print(f"Audio: {args.audio}") audio_tokens = audio_tok.encode(args.audio, add_bos=False, add_eos=False) bpm = args.bpm or 150.0 else: print("ERROR: need --audio or --song-id") return print(f"Audio tokens: {len(audio_tokens)}") # ── Generate ── difficulty = DIFF_MAP[args.diff.upper()] print(f"Params: BPM={bpm}, Diff={args.diff}({difficulty}), Level={args.level}") print(f"Generating (max {args.max_tokens} tokens)...") audio_t = torch.tensor([audio_tokens], dtype=torch.long) chart_tokens = generate_chart( model, audio_t, bpm=bpm, difficulty=difficulty, level_value=args.level, genre=args.genre, max_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, eos_suppress=args.eos_suppress, min_division=args.min_division, max_duration_beats=args.max_duration_beats, max_sim=args.max_sim, sim_penalty=args.sim_penalty, repetition_penalty=args.rep_penalty, rest_streak_penalty=args.rest_streak_penalty, rest_streak_threshold=args.rest_streak_threshold, target_notes=args.target_notes, density_bias=args.density_bias, break_penalty=args.break_penalty, break_window=args.break_window, max_break_ratio=args.max_break_ratio, type_biases={ "tap": args.tap_bias, "hold": args.hold_bias, "slide": args.slide_bias, "break": args.break_bias, "touch": args.touch_bias, "rest": args.rest_bias, "sim": args.sim_bias, }, config_density_bias=args.config_density_bias, config_density_include_break=args.config_density_include_break, ) print(f"Generated {len(chart_tokens)} tokens") audio_duration_sec = len(audio_tokens) / 2 / AUDIO_FRAME_RATE total_beats = audio_duration_sec * bpm / 60.0 start_offset_beats = 2.0 * max(bpm, 30.0) / 60.0 dur_stats = duration_report(chart_tokens, total_beats, start_offset_beats, args.min_division) print(f"Duration stats: count={dur_stats['durations']} " f"max={dur_stats['max_beats']:.3f} beats overrun={dur_stats['overrun']}") sim_stats = simultaneous_report(chart_tokens, args.max_sim) print(f"Sim stats: groups={sim_stats['groups']} " f"max_count={sim_stats['max_count']} over_limit={sim_stats['over_limit']} " f"first_bad={sim_stats['first_bad']}") # ── Decode to human-readable ── chart_tok = MaiChartTokenizer() print(f"\nTokens: {chart_tok.tokens_to_str(chart_tokens, 60)}") # Decode to chart decoded = chart_tok.decode(chart_tokens) print(f"\nChart stats: TAP={decoded.tap_count} HOLD={decoded.hold_count} " f"SLIDE={decoded.slide_count} BREAK={decoded.break_count} " f"TOUCH={decoded.touch_count} TOTAL={decoded.note_count}") print(f"Break ratio: {decoded.break_count / max(decoded.note_count, 1):.2%}") print(f"Density: {decoded.note_count / max(audio_duration_sec, 1e-6):.2f} notes/s, " f"{decoded.note_count / max(total_beats, 1e-6):.2f} notes/beat") # ── Save ── if args.output: out_path = Path(args.output) # Convert tokens → proper maidata chart text chart_text = tokens_to_maitext(chart_tokens, bpm) lines = [f"&title=Generated Chart", f"&wholebpm={bpm}", f"&artist=MaiGenerator", f"&lv_{difficulty+1}={args.level}", f"&inote_{difficulty+1}="] lines.append(chart_text) out_path.write_text("\n".join(lines).rstrip("\n"), encoding="utf-8") print(f"\nSaved to {out_path}") if __name__ == "__main__": main()