| """ |
Inference script for MaiGenerator — generate maimai charts from audio.
| |
| Usage: |
| python inference.py --checkpoint checkpoints/best.pt --audio datasets/10/track.mp3 |
| python inference.py --checkpoint checkpoints/best.pt --song-id 10 |
| python inference.py --checkpoint checkpoints/best.pt --audio track.mp3 --bpm 173 --diff MASTER --level 12.4 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from tqdm import tqdm |
|
|
| from model import MaiGenerator, BOS, EOS, DUR_TOKEN, AUDIO_FRAME_RATE |
| from Tokenizer.MaiTrackTokenizer import MaiTrackTokenizer |
| from tokenizer import ( |
| MaiChartTokenizer, tokens_to_maitext, PAD, SIM_BEG, RST, |
| TAP_TO_ID, BRK_TO_ID, HLD_TO_ID, SLD_TO_ID, TCH_TO_ID, |
| ) |
| from tokenizer import ID_TO_DUR_NUM, ID_TO_DUR_DEN, DUR_NUM_TO_ID, DUR_DEN_TO_ID |
| from tokenizer import load_config_vocab, ID_TO_CONFIG |
| from mai_parser import Difficulty, parse_maidata_file |
| from grammar import NOTE_TOKENS |
| from constrained_decode import ( |
| DecodeState, snap_duration, clamp_duration_tokens, duration_report, |
| simultaneous_report, sanitize_sim_tokens, |
| ) |
|
|
|
|
def sample_token(logits: torch.Tensor, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.95,
                 repetition_penalty: float = 1.0,
                 prev_tokens: list[int] | None = None,
                 penalty_window: int = 50) -> int:
    """Sample one token id from a 1-D logits vector.

    The steps are applied in order: repetition penalty, temperature scaling,
    top-k filtering and nucleus (top-p) filtering. A token is then drawn from
    the resulting softmax distribution.

    Args:
        logits: 1-D tensor of unnormalized scores over the vocabulary.
            Masked (illegal) entries may be ``-inf`` and are never sampled
            while at least one finite entry remains. The caller's tensor is
            not modified.
        temperature: Divisor for the logits (floored at 1e-8).
        top_k: Keep only the ``top_k`` highest logits; <= 0 disables.
        top_p: Keep the smallest set of highest-probability tokens whose
            cumulative probability exceeds ``top_p``; >= 1.0 disables.
        repetition_penalty: >1.0 penalizes repeated tokens; 1.0 = off.
            Uses the sign-aware (CTRL) rule: positive logits are divided by
            the penalty and negative logits are multiplied by it, so a
            repeated token always becomes *less* likely.
        prev_tokens: Recent token history for repetition detection.
        penalty_window: How many recent tokens to check for repeats.

    Returns:
        The sampled token id.
    """
    # Work on a float32 copy. This leaves the caller's tensor untouched and
    # keeps softmax/multinomial well-behaved when the model runs in fp16/bf16.
    logits = logits.detach().to(torch.float32).clone()
    vocab = logits.numel()

    if repetition_penalty > 1.0 and prev_tokens:
        recent = sorted({int(t) for t in prev_tokens[-penalty_window:]
                         if 0 <= int(t) < vocab})
        if recent:
            idx = torch.tensor(recent, dtype=torch.long, device=logits.device)
            vals = logits[idx]
            # Dividing a negative logit would *raise* its probability, so
            # negative logits are multiplied instead.
            logits[idx] = torch.where(vals > 0,
                                      vals / repetition_penalty,
                                      vals * repetition_penalty)

    logits = logits / max(temperature, 1e-8)
    # NaN must never be sampled, so it maps to -inf. Masked -inf entries stay
    # -inf: clamping them to a finite value lets them tie with legal tokens
    # once a tiny temperature scales legal logits to the same magnitude.
    logits = torch.nan_to_num(logits, nan=float("-inf"), posinf=1e9,
                              neginf=float("-inf"))
    if not torch.isfinite(logits).any():
        # Degenerate case: every token is masked. Sample uniformly rather
        # than produce NaN probabilities.
        logits = torch.zeros_like(logits)

    if top_k > 0:
        k = min(top_k, vocab)
        min_val = torch.topk(logits, k).values[-1]
        logits[logits < min_val] = float("-inf")

    if top_p < 1.0:
        sorted_l, sorted_i = torch.sort(logits, descending=True)
        cum = torch.cumsum(F.softmax(sorted_l, dim=-1), dim=-1)
        mask = cum > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the single most likely token.
        mask[1:] = mask[:-1].clone()
        mask[0] = False
        sorted_l[mask] = float("-inf")
        logits = torch.empty_like(sorted_l).scatter_(0, sorted_i, sorted_l)

    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1).item())
|
|
|
|
| def _config_note_count(spec: tuple) -> int: |
| if not spec: |
| return 0 |
| kind = spec[0] |
| if kind == "pair": |
| return 2 |
| if kind == "multi": |
| return max(1, len(spec) - (4 if spec[1] == "hld" else 2)) |
| if kind == "touch_multi": |
| return max(1, len(spec) - 1) |
| return 1 |
|
|
|
|
def _token_note_count(tok: int) -> int:
    """Return the number of notes one chart token contributes.

    Tokens in ``NOTE_TOKENS`` count as one note. Config tokens (keys of
    ``ID_TO_CONFIG``) count via ``_config_note_count``. Any other token
    counts as zero.
    """
    if tok not in NOTE_TOKENS:
        config_spec = ID_TO_CONFIG.get(tok)
        return _config_note_count(config_spec) if config_spec is not None else 0
    return 1
|
|
|
|
| def _config_has_break(spec: tuple) -> bool: |
| if not spec: |
| return False |
| if spec[0] == "brk": |
| return True |
| if spec[0] == "pair": |
| return spec[1] == "brk" or spec[3] == "brk" |
| if spec[0] == "multi": |
| return len(spec) > 1 and spec[1] == "brk" |
| return False |
|
|
|
|
| def _config_type_counts(spec: tuple) -> dict[str, int]: |
| if not spec: |
| return {} |
| kind = spec[0] |
| if kind == "tap": |
| return {"tap": 1} |
| if kind == "brk": |
| return {"break": 1} |
| if kind == "hld": |
| return {"hold": 1} |
| if kind == "sld": |
| return {"slide": 1} |
| if kind in ("tch", "touch_multi"): |
| return {"touch": 1} |
| if kind == "pair": |
| counts: dict[str, int] = {} |
| for typ in (spec[1], spec[3]): |
| key = "break" if typ == "brk" else "hold" if typ == "hld" else "tap" |
| counts[key] = counts.get(key, 0) + 1 |
| return counts |
| if kind == "multi": |
| typ = spec[1] |
| key = "break" if typ == "brk" else "hold" if typ == "hld" else "tap" |
| if typ == "hld": |
| n = max(1, len(spec) - 4) |
| else: |
| n = max(1, len(spec) - 2) |
| return {key: n} |
| return {} |
|
|
|
|
def _token_has_break(tok: int) -> bool:
    """Return True when ``tok`` carries a BREAK.

    That is the case when ``tok`` is one of the ``BRK_TO_ID`` note ids, or a
    config token whose spec contains a break (see ``_config_has_break``).
    """
    is_break_note = tok in BRK_TO_ID.values()
    if is_break_note:
        return True
    config_spec = ID_TO_CONFIG.get(tok)
    if config_spec is None:
        return False
    return _config_has_break(config_spec)
|
|
|
|
def _break_ids(vocab_size: int, device: torch.device) -> torch.Tensor:
    """Return a 1-D long tensor of all break-carrying token ids in the vocabulary.

    This covers BREAK note ids plus config tokens whose spec contains a
    break. Ids outside ``[0, vocab_size)`` are dropped.
    """
    candidates = set(BRK_TO_ID.values())
    for token_id, spec in ID_TO_CONFIG.items():
        if _config_has_break(spec):
            candidates.add(token_id)
    in_vocab = [int(tok) for tok in candidates if 0 <= int(tok) < vocab_size]
    return torch.tensor(in_vocab, dtype=torch.long, device=device)
|
|
|
|
def _density_note_ids(vocab_size: int, device: torch.device,
                      include_break: bool = False) -> torch.Tensor:
    """Return the token ids that density-guided sampling may boost.

    The candidates are all plain note tokens plus all config tokens. Unless
    ``include_break`` is set, break-carrying tokens are excluded. Ids outside
    ``[0, vocab_size)`` are dropped.
    """
    pool = set(NOTE_TOKENS) | set(ID_TO_CONFIG.keys())
    if not include_break:
        pool = {tok for tok in pool if not _token_has_break(int(tok))}
    return torch.tensor(
        [int(tok) for tok in pool if 0 <= int(tok) < vocab_size],
        dtype=torch.long,
        device=device,
    )
|
|
|
|
def _type_bias_vector(vocab_size: int, device: torch.device,
                      type_biases: dict[str, float]) -> torch.Tensor | None:
    """Build an additive logit-bias vector from per-type weights.

    ``type_biases`` may contain ``"tap"``, ``"hold"``, ``"slide"``,
    ``"break"``, ``"touch"``, ``"rest"`` and ``"sim"``. Weights whose
    magnitude is <= 1e-9 are ignored. If nothing remains, None is returned.

    How the weights are applied:

    - Plain note ids in the matching ``*_TO_ID`` table receive the weight
      once.
    - ``RST`` receives the ``"rest"`` weight and ``SIM_BEG`` the ``"sim"``
      weight.
    - Each config token receives ``weight * max(1, count)`` for every type in
      ``_config_type_counts(spec)``.

    Ids outside ``[0, vocab_size)`` are skipped.
    """
    active: dict[str, float] = {}
    for name, raw in type_biases.items():
        weight = float(raw)
        if abs(weight) > 1e-9:
            active[name] = weight
    if not active:
        return None

    bias = torch.zeros(vocab_size, dtype=torch.float32, device=device)

    note_tables = {
        "tap": TAP_TO_ID,
        "hold": HLD_TO_ID,
        "slide": SLD_TO_ID,
        "break": BRK_TO_ID,
        "touch": TCH_TO_ID,
    }
    for name, table in note_tables.items():
        weight = active.get(name, 0.0)
        if not weight:
            continue
        ids = [int(tok) for tok in table.values() if 0 <= int(tok) < vocab_size]
        if ids:
            bias[torch.tensor(ids, dtype=torch.long, device=device)] += weight

    for special_tok, name in ((RST, "rest"), (SIM_BEG, "sim")):
        if active.get(name, 0.0) and 0 <= special_tok < vocab_size:
            bias[special_tok] += active[name]

    # Config tokens are scaled by how many notes of each type they place.
    for token_id, spec in ID_TO_CONFIG.items():
        tid = int(token_id)
        if not 0 <= tid < vocab_size:
            continue
        for name, count in _config_type_counts(spec).items():
            weight = active.get(name, 0.0)
            if weight:
                bias[tid] += weight * max(1, int(count))
    return bias
|
|
|
|
def _config_density_bias_vector(vocab_size: int, device: torch.device,
                                strength: float,
                                include_break: bool = False) -> torch.Tensor | None:
    """Bias multi-note config tokens by how many *extra* notes they carry.

    Each config token receives ``strength * (note_count - 1)`` when that
    value is positive. Break-carrying configs are skipped unless
    ``include_break`` is set. Ids outside ``[0, vocab_size)`` are skipped.
    Returns None when ``|strength| <= 1e-9``.
    """
    if abs(float(strength)) <= 1e-9:
        return None
    scale = float(strength)
    bias = torch.zeros(vocab_size, dtype=torch.float32, device=device)
    for token_id, spec in ID_TO_CONFIG.items():
        tid = int(token_id)
        if not 0 <= tid < vocab_size:
            continue
        # The break check is only evaluated when breaks are excluded.
        if not include_break and _config_has_break(spec):
            continue
        surplus = _config_note_count(spec) - 1
        if surplus > 0:
            bias[tid] += scale * surplus
    return bias
|
|
|
|
| def _auto_target_notes(level_value: float, duration_sec: float) -> int: |
| |
| notes_per_sec = float(np.interp( |
| level_value, |
| [10.0, 12.0, 13.0, 14.0, 14.5, 15.0], |
| [2.2, 3.2, 4.0, 4.8, 5.3, 5.8], |
| )) |
| return max(1, int(round(notes_per_sec * duration_sec))) |
|
|
|
|
@torch.inference_mode()
def generate_chart(
    model: MaiGenerator,
    audio_tokens: torch.Tensor,
    bpm: float,
    difficulty: int,
    level_value: float,
    genre: int = 0,
    max_tokens: int = 4096,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.95,
    eos_suppress: float = 1.5,
    min_division: int = 4,
    max_duration_beats: float = 4.0,
    max_sim: int = 2,
    sim_penalty: float = 1.0,
    repetition_penalty: float = 1.0,
    rest_streak_penalty: float = 0.0,
    rest_streak_threshold: int = 4,
    target_notes: int | None = None,
    density_bias: float = 0.0,
    break_penalty: float = 0.0,
    break_window: int = 64,
    max_break_ratio: float = 0.08,
    type_biases: dict[str, float] | None = None,
    config_density_bias: float = 0.0,
    config_density_include_break: bool = False,
) -> list[int]:
    """
    Autoregressive chart generation from audio (with KV-cache for speed).

    The audio is encoded once. Chart tokens are then sampled one at a time
    through the decoder's KV cache. At each step the logits are masked by
    ``DecodeState`` and adjusted by the optional soft biases below, and
    ``sample_token`` draws the next token.

    Generation stops in three cases:

    - the decode position reaches the end of the track;
    - the mask leaves no legal token;
    - ``max_tokens`` steps have been taken.

    An EOS token is appended in every case.

    Args:
        model: Trained MaiGenerator.
        audio_tokens: [1, T_aud] (or [T_aud]) — EnCodec tokens. Truncated
            to ``model.audio_pos_embed.num_embeddings``.
        bpm: BPM value.
        difficulty: 0=BASIC..4=ReMASTER.
        level_value: e.g. 12.4.
        genre: Genre index.
        max_tokens: Max tokens to generate.
        temperature: Sampling temperature.
        top_k: Top-K filter.
        top_p: Nucleus filter.
        eos_suppress: Accepted for backward compatibility but unused. EOS is
            always masked and is only appended when generation stops.
        min_division: Passed to ``DecodeState`` and ``clamp_duration_tokens``.
        max_duration_beats: Passed to ``DecodeState`` (max hold/slide length).
        max_sim: Passed to ``DecodeState`` and ``sanitize_sim_tokens``.
        sim_penalty: Subtracted from the SIM_BEG logit whenever
            ``decode_state.can_start_sim()`` is true.
        repetition_penalty: Passed to ``sample_token``.
        rest_streak_penalty: Subtracted from the RST logit once the last
            ``rest_streak_threshold`` tokens are all RST (0 disables).
        rest_streak_threshold: Streak length that triggers the penalty.
        target_notes: Note-count target for ``density_bias``. When None it is
            estimated from ``level_value`` and the track length.
        density_bias: When > 0, boosts non-break note/config tokens (and
            lowers RST) in proportion to how far the running note count lags
            a target that grows linearly with beat progress.
        break_penalty: When > 0, penalizes break-carrying tokens once the
            global ratio or the recent ratio (last ``break_window`` tokens)
            exceeds ``max_break_ratio``.
        break_window: Token window for the recent break ratio (<= 0 uses the
            whole history).
        max_break_ratio: Soft ceiling on the break/note ratio.
        type_biases: Per-type additive logit biases (see ``_type_bias_vector``).
        config_density_bias: Bias per extra note on multi-note config tokens.
        config_density_include_break: Let ``config_density_bias`` boost
            break-carrying config tokens too.

    Returns:
        List of chart token IDs. It starts with BOS, and EOS is appended
        before post-processing by ``sanitize_sim_tokens`` and
        ``clamp_duration_tokens``.
    """
    model.eval()
    device = next(model.parameters()).device
    param_dtype = next(model.parameters()).dtype

    # ---- Encode the audio once ----
    if audio_tokens.dim() == 1:
        audio_tokens = audio_tokens.unsqueeze(0)
    audio_tokens = audio_tokens.to(device)

    # The audio position table has a fixed size; frames beyond it are dropped.
    max_len = model.audio_pos_embed.num_embeddings
    if audio_tokens.shape[1] > max_len:
        audio_tokens = audio_tokens[:, :max_len]
    T_aud = audio_tokens.shape[1]
    aud = model.audio_embed(audio_tokens)
    aud = aud + model.audio_pos_embed(
        torch.arange(T_aud, device=device).unsqueeze(0))

    if model.audio_downsample > 1:
        # The transposes put the time axis last for audio_down
        # (presumably a Conv1d — confirm in model.py).
        aud = model.audio_down(aud.transpose(1, 2)).transpose(1, 2)
        T_aud = aud.shape[1]

    aud = model.dropout(aud)
    aud_pos = torch.arange(T_aud, device=device, dtype=torch.float32)
    for blk in model.audio_encoder:
        aud = blk(aud, positions=aud_pos)
    aud = model.enc_norm(aud)

    # ---- Conditioning vector added to every chart-token embedding ----
    bpm_t = torch.tensor([[bpm]], device=device, dtype=param_dtype)
    diff_t = torch.tensor([[difficulty]], device=device)
    level_t = torch.tensor([[level_value]], device=device, dtype=param_dtype)
    genre_t = torch.tensor([[genre]], device=device)
    diff_vec = model.diff_embed(diff_t.squeeze(-1))
    cond = (model.bpm_proj(bpm_t).unsqueeze(1)
            + diff_vec.unsqueeze(1)
            + model.level_proj(level_t).unsqueeze(1)
            + model.genre_embed(genre_t.squeeze(-1)).unsqueeze(1))

    # ---- Track timing ----
    # bpm is floored at 30 for offsets and decoder positions.
    bpm_v = max(bpm, 30.0)
    # The /2 presumably reflects two interleaved codebook tokens per frame
    # (main() builds MaiTrackTokenizer with n_layers=2) — confirm.
    audio_duration_sec = audio_tokens.shape[1] / 2 / AUDIO_FRAME_RATE
    total_beats = audio_duration_sec * bpm / 60.0
    start_offset_beats = 2.0 * bpm_v / 60.0
    print(f"Track: {audio_duration_sec:.0f}s = {total_beats:.0f} beats @ {bpm} BPM (start at beat {start_offset_beats:.1f})")
    if density_bias > 0 and target_notes is None:
        target_notes = _auto_target_notes(level_value, audio_duration_sec)
    if density_bias > 0 and target_notes:
        print(f"Density target: {target_notes} notes bias={density_bias}")

    # One cache slot for BOS plus one per generated token.
    model.init_kv_cache(batch_size=1, max_len=max_tokens + 1, enc_out=aud)

    def _decode(tok: int, pos: float) -> torch.Tensor:
        """Feed one token at decoder position ``pos`` through the cached decoder."""
        h = model.chart_embed(torch.tensor([[tok]], device=device)) + cond
        chart_pos = torch.tensor([[pos]], device=device, dtype=torch.float32)
        for blk in model.chart_decoder:
            h = blk(h, enc_out=aud, self_positions=chart_pos,
                    diff_emb=diff_vec, use_cache=True)
        return h

    tokens = [BOS]
    x = _decode(BOS, 0.0)

    # ---- Constraint state and static bias vectors ----
    vocab_size = model.chart_vocab_size
    decode_state = DecodeState(
        bpm=bpm,
        total_beats=total_beats,
        start_offset_beats=start_offset_beats,
        min_division=min_division,
        max_sim=max_sim,
        max_duration_beats=max_duration_beats,
    )
    note_bonus_ids = _density_note_ids(vocab_size, device, include_break=False) if density_bias > 0 else None
    break_ids = _break_ids(vocab_size, device) if break_penalty > 0 else None
    type_bias = _type_bias_vector(vocab_size, device, type_biases or {})
    config_density_bias_vec = _config_density_bias_vector(
        vocab_size, device, config_density_bias,
        include_break=config_density_include_break,
    )
    if type_bias is not None:
        print("Type bias:", ", ".join(f"{k}={v:+.2f}" for k, v in (type_biases or {}).items() if abs(v) > 1e-9))
    if config_density_bias_vec is not None:
        print(f"Config density bias: {config_density_bias:+.2f} "
              f"(include_break={config_density_include_break})")
    generated_notes = 0
    generated_breaks = 0

    pbar = tqdm(total=max_tokens, desc="Generating", unit="tok")

    def _beat_progress() -> str:
        return f"{decode_state.current_beat:.0f}/{total_beats:.0f}"

    def _stop(reason: str) -> None:
        """Record an early stop on the progress bar and terminate with EOS."""
        pbar.update(1)
        pbar.set_postfix(tokens=len(tokens) + 1, beat=_beat_progress(), end=reason)
        tokens.append(EOS)

    for _ in range(max_tokens):
        x_last = model.dec_norm(x[:, -1:, :])
        logits = model.output_head(x_last).squeeze(0).squeeze(0)

        # EOS is never sampled. The chart ends only at the end of the track
        # or when nothing legal remains, so ``eos_suppress`` has no effect.
        logits[EOS] = float("-inf")

        if decode_state.current_beat >= total_beats:
            _stop("track")
            break

        logits = decode_state.apply_logits_mask(logits)
        if type_bias is not None:
            logits = logits + type_bias.to(dtype=logits.dtype)
        if config_density_bias_vec is not None:
            logits = logits + config_density_bias_vec.to(dtype=logits.dtype)
        if decode_state.can_start_sim() and SIM_BEG < logits.numel():
            logits[SIM_BEG] -= sim_penalty
        if not torch.isfinite(logits).any():
            _stop("no_legal")
            break

        # Discourage long runs of rests.
        if rest_streak_penalty > 0 and len(tokens) >= rest_streak_threshold:
            if all(t == RST for t in tokens[-rest_streak_threshold:]):
                logits[RST] -= rest_streak_penalty

        # Density guidance: push toward notes when behind a linear pace.
        if density_bias > 0 and target_notes and note_bonus_ids is not None and note_bonus_ids.numel() > 0:
            progress = max(0.0, decode_state.current_beat - start_offset_beats)
            denom = max(1e-6, total_beats - start_offset_beats)
            expected = float(target_notes) * min(1.0, progress / denom)
            deficit = max(0.0, expected - generated_notes) / max(expected, 1.0)
            if deficit > 0:
                bonus = density_bias * deficit
                logits[note_bonus_ids] += bonus
                if RST < logits.numel():
                    logits[RST] -= bonus

        # Break-ratio control, over the global history and a recent window.
        if break_penalty > 0 and break_ids is not None and break_ids.numel() > 0:
            recent = tokens[-break_window:] if break_window > 0 else tokens
            recent_notes = sum(_token_note_count(t) for t in recent)
            recent_breaks = sum(_token_note_count(t) for t in recent if _token_has_break(t))
            global_ratio = generated_breaks / max(generated_notes, 1)
            recent_ratio = recent_breaks / max(recent_notes, 1)
            excess = max(global_ratio - max_break_ratio, recent_ratio - max_break_ratio, 0.0)
            if excess > 0:
                logits[break_ids] -= break_penalty * (1.0 + 8.0 * excess)

        next_tok = sample_token(logits, temperature, top_k, top_p,
                                repetition_penalty=repetition_penalty,
                                prev_tokens=tokens)
        tokens.append(next_tok)
        note_count = _token_note_count(next_tok)
        generated_notes += note_count
        if _token_has_break(next_tok):
            generated_breaks += note_count

        # Once a full duration (DUR_TOKEN, numerator, denominator) has been
        # emitted after a note, snap it to a legal value in place.
        if len(tokens) >= 4 and tokens[-4] not in (DUR_TOKEN, PAD) and tokens[-3] == DUR_TOKEN:
            raw_num = ID_TO_DUR_NUM.get(tokens[-2], tokens[-2])
            raw_den = ID_TO_DUR_DEN.get(next_tok, max(next_tok, 1))
            beat, subdiv = snap_duration(raw_num, raw_den,
                                         decode_state.remaining_duration_beats(),
                                         decode_state.max_duration_beats)
            tokens[-2], tokens[-1] = DUR_NUM_TO_ID[beat], DUR_DEN_TO_ID[subdiv]
            next_tok = tokens[-1]

        # The decoder position is the audio-frame index of the current beat.
        # It is computed before decode_state.step advances the beat.
        time_sec = decode_state.current_beat * 60.0 / bpm_v
        new_pos = time_sec * AUDIO_FRAME_RATE / model.audio_downsample

        decode_state.step(next_tok)
        x = _decode(next_tok, new_pos)

        pbar.update(1)
        if len(tokens) % 200 == 0:
            pct = min(100, decode_state.current_beat / max(total_beats, 1e-9) * 100)
            pbar.set_postfix(beat=_beat_progress(),
                             pct=f"{pct:.0f}%", div=f"{decode_state.div_value}")
    else:
        # The token budget ran out without a natural stop: still terminate
        # the sequence with EOS so every exit path has the same shape.
        pbar.set_postfix(tokens=len(tokens) + 1, beat=_beat_progress(), end="max_tokens")
        tokens.append(EOS)
    pbar.close()

    tokens = sanitize_sim_tokens(tokens, max_sim=max_sim)
    return clamp_duration_tokens(tokens, total_beats, start_offset_beats, min_division)
|
|
|
|
| |
| |
| |
|
|
# CLI difficulty names mapped to the model's difficulty index
# (0=BASIC .. 4=REMASTER, matching generate_chart's ``difficulty`` argument).
DIFF_MAP = dict(BASIC=0, ADVANCED=1, EXPERT=2, MASTER=3, REMASTER=4)
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Generate maimai chart from audio")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--audio", type=str, default=None,
                        help="Path to audio file (mp3/wav)")
    parser.add_argument("--song-id", type=str, default=None,
                        help="Song ID in datasets/ (uses datasets/<id>/track.mp3)")
    parser.add_argument("--bpm", type=float, default=None,
                        help="BPM (auto-detect from maidata.txt if --song-id)")
    parser.add_argument("--diff", type=str, default="MASTER",
                        choices=list(DIFF_MAP.keys()))
    parser.add_argument("--level", type=float, default=12.0,
                        help="Target level value")
    parser.add_argument("--genre", type=int, default=0)
    parser.add_argument("--max-tokens", type=int, default=4096)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--eos-suppress", type=float, default=1.5,
                        help="Currently ignored: EOS is only emitted when the track ends "
                             "or no legal token remains")
    parser.add_argument("--min-division", type=int, default=4,
                        help="Minimum generated beat division; blocks div_1/div_2 by default")
    parser.add_argument("--max-duration-beats", type=float, default=4.0,
                        help="Maximum generated hold/slide duration in beats")
    parser.add_argument("--max-sim", type=int, default=2,
                        help="Maximum simultaneous press count")
    parser.add_argument("--sim-penalty", type=float, default=1.0,
                        help="Subtract from SIM_BEG logit before constrained sampling")
    parser.add_argument("--rep-penalty", type=float, default=1.0,
                        help="Repetition penalty (>1.0=penalize repeats, 1.0=off)")
    parser.add_argument("--rest-streak-penalty", type=float, default=0.0,
                        help="Penalize RST when last N tokens are all rests (try 2.0-5.0)")
    parser.add_argument("--rest-streak-threshold", type=int, default=4,
                        help="Consecutive RSTs before applying rest-streak-penalty")
    parser.add_argument("--target-notes", type=int, default=None,
                        help="Optional soft target note count for density-guided sampling")
    parser.add_argument("--density-bias", type=float, default=0.0,
                        help="Softly boost note/config tokens when generated density lags target")
    parser.add_argument("--break-penalty", type=float, default=1.5,
                        help="Type-level penalty when BREAK/绝赞 ratio is too high")
    parser.add_argument("--break-window", type=int, default=64,
                        help="Recent token window used by BREAK ratio penalty")
    parser.add_argument("--max-break-ratio", type=float, default=0.08,
                        help="Soft maximum BREAK/绝赞 ratio before penalty kicks in")
    parser.add_argument("--tap-bias", type=float, default=0.0,
                        help="Logit bias for TAP config/tokens (+ increases, - decreases)")
    parser.add_argument("--hold-bias", type=float, default=0.0,
                        help="Logit bias for HOLD config/tokens")
    parser.add_argument("--slide-bias", type=float, default=0.0,
                        help="Logit bias for SLIDE config/tokens")
    parser.add_argument("--break-bias", type=float, default=0.0,
                        help="Logit bias for BREAK/绝赞 config/tokens")
    parser.add_argument("--touch-bias", type=float, default=0.0,
                        help="Logit bias for TOUCH config/tokens")
    parser.add_argument("--rest-bias", type=float, default=0.0,
                        help="Logit bias for REST tokens")
    parser.add_argument("--sim-bias", type=float, default=0.0,
                        help="Logit bias for simultaneous group start")
    parser.add_argument("--config-density-bias", type=float, default=0.0,
                        help="Bias dense config tokens by extra notes per time slot")
    parser.add_argument("--config-density-include-break", action="store_true",
                        help="Allow config-density-bias to boost BREAK/绝赞 configs too")
    parser.add_argument("--output", type=str, default=None,
                        help="Output path for generated maidata.txt")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--precision", type=str, default="float32",
                        choices=["float32", "float16", "bfloat16"],
                        help="Model precision (float16/bfloat16 ≈2x faster on supported GPUs)")
    parser.add_argument("--compile", action="store_true",
                        help="Use torch.compile for faster execution (PyTorch 2.0+)")
    return parser


def _load_model(checkpoint: str, device: torch.device) -> MaiGenerator:
    """Build a MaiGenerator from ``checkpoint`` and load its compatible weights.

    If the checkpoint config carries ``config_vocab``, it is installed via
    ``load_config_vocab`` before the model is built. Checkpoint tensors whose
    name is unknown to the model, or whose shape differs, are skipped
    (warm-start), and a summary is printed.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint, map_location=device, weights_only=False)
    config = ckpt.get("config", {})
    if "config_vocab" in config:
        # Must happen before the model is built so the vocab size is correct.
        load_config_vocab(config["config_vocab"])
        print(f"Loaded config vocab: total={MaiChartTokenizer.vocab_size}")
    model = MaiGenerator(
        d_model=config.get("d_model", 512),
        enc_layers=config.get("enc_layers", 6),
        dec_layers=config.get("dec_layers", 8),
        heads=config.get("heads", 8),
        d_ff=config.get("d_ff", 2048),
        use_moe=config.get("use_moe", True),
        n_experts=config.get("n_experts", 6),
        moe_layers=config.get("moe_layers", None),
    ).to(device)
    model_state = model.state_dict()
    filtered_state = {
        k: v for k, v in ckpt["model_state"].items()
        if k in model_state and tuple(v.shape) == tuple(model_state[k].shape)
    }
    skipped = len(ckpt["model_state"]) - len(filtered_state)
    incompatible = model.load_state_dict(filtered_state, strict=False)
    if skipped:
        print(f"Model warm-start: skipped {skipped} unknown or shape-mismatched params")
    if incompatible.missing_keys:
        print(f"Model warm-start: initialized {len(incompatible.missing_keys)} new params")
    if incompatible.unexpected_keys:
        print(f"Model warm-start: ignored {len(incompatible.unexpected_keys)} checkpoint params")
    model.eval()
    print(f"Model loaded (epoch {ckpt.get('epoch', '?')})")
    return model


def _write_maidata(out_path: Path, chart_tokens: list[int], bpm: float,
                   difficulty: int, level: float) -> None:
    """Write ``chart_tokens`` to ``out_path`` as a single-chart maidata.txt."""
    # simai numbers chart slots 1=EASY, 2=BASIC, 3=ADVANCED, 4=EXPERT,
    # 5=MASTER, 6=Re:MASTER. DIFF_MAP starts at 0=BASIC, hence the +2 offset.
    slot = difficulty + 2
    lines = [
        "&title=Generated Chart",
        f"&wholebpm={bpm}",
        "&artist=MaiGenerator",
        f"&lv_{slot}={level}",
        f"&inote_{slot}=",
        tokens_to_maitext(chart_tokens, bpm),
    ]
    out_path.write_text("\n".join(lines).rstrip("\n"), encoding="utf-8")


def main():
    """CLI entry point: load a checkpoint, encode audio, generate, report, optionally save."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Resolve and validate the audio source before the (slow) model load.
    # Failures exit with a nonzero status.
    if args.song_id:
        audio_path = Path("datasets") / args.song_id / "track.mp3"
    elif args.audio:
        audio_path = Path(args.audio)
    else:
        parser.error("need --audio or --song-id")
    if not audio_path.exists():
        raise SystemExit(f"ERROR: {audio_path} not found")

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    model = _load_model(args.checkpoint, device)

    if args.precision != "float32":
        dt = {"float16": torch.float16, "bfloat16": torch.bfloat16}[args.precision]
        model = model.to(dtype=dt)
        print(f"Precision: {args.precision}")

    if args.compile:
        try:
            model = torch.compile(model, mode="reduce-overhead")
            print("Model compiled with torch.compile (reduce-overhead)")
        except Exception as e:  # best effort: fall back to eager execution
            print(f"torch.compile failed: {e}, continuing without compile")

    audio_tok = MaiTrackTokenizer(n_layers=2, device=str(device))
    print(f"Audio: {audio_path}")
    audio_tokens = audio_tok.encode(str(audio_path), add_bos=False, add_eos=False)

    # BPM: explicit --bpm wins. With --song-id, fall back to maidata.txt and
    # then 150. With --audio, fall back to 150.
    if args.song_id and args.bpm is None:
        bpm = 150.0
        mdata = Path("datasets") / args.song_id / "maidata.txt"
        if mdata.exists():
            song = parse_maidata_file(str(mdata))
            bpm = song.bpm if song.bpm > 0 else 150.0
            print(f"Auto-detected BPM: {bpm}")
    elif args.song_id:
        bpm = args.bpm
    else:
        bpm = args.bpm or 150.0

    print(f"Audio tokens: {len(audio_tokens)}")

    difficulty = DIFF_MAP[args.diff.upper()]
    print(f"Params: BPM={bpm}, Diff={args.diff}({difficulty}), Level={args.level}")
    print(f"Generating (max {args.max_tokens} tokens)...")

    audio_t = torch.tensor([audio_tokens], dtype=torch.long)
    chart_tokens = generate_chart(
        model, audio_t, bpm=bpm, difficulty=difficulty,
        level_value=args.level, genre=args.genre,
        max_tokens=args.max_tokens,
        temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
        eos_suppress=args.eos_suppress,
        min_division=args.min_division,
        max_duration_beats=args.max_duration_beats,
        max_sim=args.max_sim,
        sim_penalty=args.sim_penalty,
        repetition_penalty=args.rep_penalty,
        rest_streak_penalty=args.rest_streak_penalty,
        rest_streak_threshold=args.rest_streak_threshold,
        target_notes=args.target_notes,
        density_bias=args.density_bias,
        break_penalty=args.break_penalty,
        break_window=args.break_window,
        max_break_ratio=args.max_break_ratio,
        type_biases={
            "tap": args.tap_bias,
            "hold": args.hold_bias,
            "slide": args.slide_bias,
            "break": args.break_bias,
            "touch": args.touch_bias,
            "rest": args.rest_bias,
            "sim": args.sim_bias,
        },
        config_density_bias=args.config_density_bias,
        config_density_include_break=args.config_density_include_break,
    )

    print(f"Generated {len(chart_tokens)} tokens")
    # generate_chart truncates the audio to the model's audio position table.
    # Measure the same span so the beat-based stats describe what was generated.
    n_audio = min(len(audio_tokens), model.audio_pos_embed.num_embeddings)
    audio_duration_sec = n_audio / 2 / AUDIO_FRAME_RATE
    total_beats = audio_duration_sec * bpm / 60.0
    start_offset_beats = 2.0 * max(bpm, 30.0) / 60.0
    dur_stats = duration_report(chart_tokens, total_beats, start_offset_beats, args.min_division)
    print(f"Duration stats: count={dur_stats['durations']} "
          f"max={dur_stats['max_beats']:.3f} beats overrun={dur_stats['overrun']}")
    sim_stats = simultaneous_report(chart_tokens, args.max_sim)
    print(f"Sim stats: groups={sim_stats['groups']} "
          f"max_count={sim_stats['max_count']} over_limit={sim_stats['over_limit']} "
          f"first_bad={sim_stats['first_bad']}")

    chart_tok = MaiChartTokenizer()
    print(f"\nTokens: {chart_tok.tokens_to_str(chart_tokens, 60)}")

    decoded = chart_tok.decode(chart_tokens)
    print(f"\nChart stats: TAP={decoded.tap_count} HOLD={decoded.hold_count} "
          f"SLIDE={decoded.slide_count} BREAK={decoded.break_count} "
          f"TOUCH={decoded.touch_count} TOTAL={decoded.note_count}")
    print(f"Break ratio: {decoded.break_count / max(decoded.note_count, 1):.2%}")
    print(f"Density: {decoded.note_count / max(audio_duration_sec, 1e-6):.2f} notes/s, "
          f"{decoded.note_count / max(total_beats, 1e-6):.2f} notes/beat")

    if args.output:
        out_path = Path(args.output)
        _write_maidata(out_path, chart_tokens, bpm, difficulty, args.level)
        print(f"\nSaved to {out_path}")
|
|
|
if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 unless main
    # raises (including its own SystemExit).
    raise SystemExit(main())
|
|