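# NOTE: the lines below are file-viewer page residue captured with the source,
# kept as comments so that the module remains valid Python.
# maiChartGen / inference.py
# Goldgom's picture
# Upload MaiGenerator model (epoch 10) and inference code
# 8061544
# Raw
# History Blame Contribute Delete
# 30.4 kB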
"""
Inference script for MaiGenerator — generate maimai charts from audio.

Usage:
    python inference.py --checkpoint checkpoints/best.pt --audio datasets/10/track.mp3
    python inference.py --checkpoint checkpoints/best.pt --song-id 10
    python inference.py --checkpoint checkpoints/best.pt --audio track.mp3 --bpm 173 --diff MASTER --level 12.4
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from model import MaiGenerator, BOS, EOS, DUR_TOKEN, AUDIO_FRAME_RATE
from Tokenizer.MaiTrackTokenizer import MaiTrackTokenizer
from tokenizer import (
MaiChartTokenizer, tokens_to_maitext, PAD, SIM_BEG, RST,
TAP_TO_ID, BRK_TO_ID, HLD_TO_ID, SLD_TO_ID, TCH_TO_ID,
)
from tokenizer import ID_TO_DUR_NUM, ID_TO_DUR_DEN, DUR_NUM_TO_ID, DUR_DEN_TO_ID
from tokenizer import load_config_vocab, ID_TO_CONFIG
from mai_parser import Difficulty, parse_maidata_file
from grammar import NOTE_TOKENS
from constrained_decode import (
DecodeState, snap_duration, clamp_duration_tokens, duration_report,
simultaneous_report, sanitize_sim_tokens,
)
def sample_token(logits: torch.Tensor, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.95,
                 repetition_penalty: float = 1.0,
                 prev_tokens: list[int] | None = None,
                 penalty_window: int = 50) -> int:
    """Sample one token id from a 1-D logit vector.

    Applies, in order: repetition penalty, temperature scaling, top-k
    filtering, nucleus (top-p) filtering, then multinomial sampling.

    Args:
        logits: 1-D tensor of unnormalised scores, one per vocabulary entry.
            It is NOT modified; sampling works on a private float32 copy.
        temperature: Softmax temperature (clamped to at least 1e-8).
        top_k: Keep only the ``top_k`` highest logits; ``<= 0`` disables.
        top_p: Keep the smallest prefix of the sorted distribution whose
            cumulative probability exceeds ``top_p``; ``>= 1.0`` disables.
        repetition_penalty: >1.0 penalizes repeated tokens; 1.0 = off.
        prev_tokens: Recent token history for repetition detection.
        penalty_window: How many recent tokens to check for repeats.

    Returns:
        The sampled token id as a Python int.
    """
    # Private float32 copy: leaves the caller's tensor untouched and keeps the
    # +/-1e9 clamp below representable even when the model runs in float16.
    logits = logits.detach().to(dtype=torch.float32, copy=True)
    vocab = logits.size(-1)

    # Repetition penalty: make recently used tokens less likely. Dividing a
    # negative logit would move it towards zero (i.e. boost the token), so
    # positive logits are divided and negative logits are multiplied.
    if repetition_penalty > 1.0 and prev_tokens:
        recent = sorted({int(t) for t in prev_tokens[-penalty_window:]
                         if 0 <= int(t) < vocab})
        if recent:
            idx = torch.tensor(recent, dtype=torch.long, device=logits.device)
            seen = logits[idx]
            logits[idx] = torch.where(seen > 0,
                                      seen / repetition_penalty,
                                      seen * repetition_penalty)

    logits = logits / max(temperature, 1e-8)
    # Masked (-inf) and NaN entries become a large finite negative so that the
    # softmax below can never produce NaN probabilities.
    logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

    if top_k > 0:
        k = min(top_k, vocab)
        kth_value = torch.topk(logits, k).values[-1]
        logits[logits < kth_value] = float("-inf")

    if top_p < 1.0:
        sorted_l, sorted_i = torch.sort(logits, descending=True)
        cum = torch.cumsum(F.softmax(sorted_l, dim=-1), dim=-1)
        drop = cum > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and always keep the single most likely token.
        drop[1:] = drop[:-1].clone()
        drop[0] = False
        sorted_l[drop] = float("-inf")
        # Undo the sort: write each sorted value back to its original index.
        logits = sorted_l.scatter(0, sorted_i, sorted_l)

    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1).item())
def _config_note_count(spec: tuple) -> int:
    """Return how many notes a config spec tuple stands for.

    An empty/None spec counts as 0, a "pair" as 2, and "multi" /
    "touch_multi" specs are counted from the tuple length (never less
    than 1). Every other kind is a single note.
    """
    if spec:
        tag = spec[0]
        if tag == "pair":
            return 2
        if tag == "touch_multi":
            # Everything after the tag is counted as one note each.
            return max(1, len(spec) - 1)
        if tag == "multi":
            # Hold multis skip four leading fields, other multis skip two
            # (presumably the extra two carry the hold duration - TODO confirm).
            header_len = 4 if spec[1] == "hld" else 2
            return max(1, len(spec) - header_len)
        return 1
    return 0
def _token_note_count(tok: int) -> int:
    """Return the number of notes a single chart token contributes.

    Tokens in NOTE_TOKENS count as one note, config tokens are counted via
    their spec in ID_TO_CONFIG, and anything else counts as zero.
    """
    if tok not in NOTE_TOKENS:
        config = ID_TO_CONFIG.get(tok)
        return 0 if config is None else _config_note_count(config)
    return 1
def _config_has_break(spec: tuple) -> bool:
    """Tell whether a config spec tuple contains a break ("brk") note.

    "brk" specs always do; a "pair" does when either member type (index 1
    or 3) is "brk"; a "multi" does when its note type (index 1) is "brk".
    """
    if not spec:
        return False
    tag = spec[0]
    if tag == "pair":
        first_is_break = spec[1] == "brk"
        return first_is_break or spec[3] == "brk"
    if tag == "multi":
        # Guard against a bare ("multi",) tuple before reading the note type.
        return len(spec) > 1 and spec[1] == "brk"
    return tag == "brk"
def _config_type_counts(spec: tuple) -> dict[str, int]:
    """Break a config spec tuple down into per-note-type counts.

    The result maps "tap" / "break" / "hold" / "slide" / "touch" to the
    number of notes of that type in the config. Empty or unrecognised specs
    yield an empty dict.
    """
    if not spec:
        return {}

    def label_of(typ) -> str:
        # Member types other than "brk"/"hld" are reported as taps.
        if typ == "brk":
            return "break"
        if typ == "hld":
            return "hold"
        return "tap"

    tag = spec[0]

    # Single-note kinds (touch_multi is also reported as exactly one touch).
    singles = (
        ("tap", "tap"),
        ("brk", "break"),
        ("hld", "hold"),
        ("sld", "slide"),
        ("tch", "touch"),
        ("touch_multi", "touch"),
    )
    for code, label in singles:
        if tag == code:
            return {label: 1}

    if tag == "pair":
        # The two member types sit at indices 1 and 3.
        tally: dict[str, int] = {}
        for member in (spec[1], spec[3]):
            label = label_of(member)
            tally[label] = tally.get(label, 0) + 1
        return tally

    if tag == "multi":
        note_type = spec[1]
        # Same field layout as the note counter: hold multis skip four
        # leading fields, the others skip two.
        header_len = 4 if note_type == "hld" else 2
        return {label_of(note_type): max(1, len(spec) - header_len)}

    return {}
def _token_has_break(tok: int) -> bool:
    """Tell whether a chart token is, or contains, a break note.

    True for ids listed in BRK_TO_ID and for config tokens whose spec
    contains a break.
    """
    if tok not in BRK_TO_ID.values():
        config = ID_TO_CONFIG.get(tok)
        return config is not None and _config_has_break(config)
    return True
def _break_ids(vocab_size: int, device: torch.device) -> torch.Tensor:
    """Collect every break-carrying token id into a long index tensor.

    Includes the plain break tokens from BRK_TO_ID plus all config tokens
    whose spec contains a break, restricted to ids inside ``[0, vocab_size)``.
    """
    candidates = set(BRK_TO_ID.values())
    for token_id, spec in ID_TO_CONFIG.items():
        if _config_has_break(spec):
            candidates.add(token_id)

    in_vocab = []
    for token_id in candidates:
        idx = int(token_id)
        if 0 <= idx < vocab_size:
            in_vocab.append(idx)
    return torch.tensor(in_vocab, dtype=torch.long, device=device)
def _density_note_ids(vocab_size: int, device: torch.device,
                      include_break: bool = False) -> torch.Tensor:
    """Return the ids of all note-producing tokens as a long index tensor.

    Covers NOTE_TOKENS and every config token id. Unless ``include_break``
    is set, tokens that carry a break are left out. Ids outside
    ``[0, vocab_size)`` are dropped.
    """
    candidates = set(NOTE_TOKENS).union(ID_TO_CONFIG.keys())
    if not include_break:
        candidates = {t for t in candidates if not _token_has_break(int(t))}

    in_vocab = []
    for token_id in candidates:
        idx = int(token_id)
        if 0 <= idx < vocab_size:
            in_vocab.append(idx)
    return torch.tensor(in_vocab, dtype=torch.long, device=device)
def _type_bias_vector(vocab_size: int, device: torch.device,
                      type_biases: dict[str, float]) -> torch.Tensor | None:
    """Build an additive per-token logit bias from per-note-type biases.

    Args:
        vocab_size: Length of the returned vector; ids outside it are skipped.
        device: Device on which the vector is created.
        type_biases: Bias per type name ("tap", "hold", "slide", "break",
            "touch", "rest", "sim"). Values with magnitude <= 1e-9 are ignored.

    Returns:
        A float32 vector of length ``vocab_size``, or None when no bias is
        active.
    """
    active: dict[str, float] = {}
    for name, raw in type_biases.items():
        amount = float(raw)
        if abs(amount) > 1e-9:
            active[name] = amount
    if not active:
        return None

    bias = torch.zeros(vocab_size, dtype=torch.float32, device=device)

    # Plain single-note tokens, one id table per type.
    plain_tables = (
        ("tap", TAP_TO_ID),
        ("hold", HLD_TO_ID),
        ("slide", SLD_TO_ID),
        ("break", BRK_TO_ID),
        ("touch", TCH_TO_ID),
    )
    for name, table in plain_tables:
        if name not in active:
            continue
        in_vocab = [int(t) for t in table.values() if 0 <= int(t) < vocab_size]
        if in_vocab:
            index = torch.tensor(in_vocab, dtype=torch.long, device=device)
            bias[index] += active[name]

    # Structural tokens that have a bias of their own.
    for name, token in (("rest", RST), ("sim", SIM_BEG)):
        if name in active and 0 <= token < vocab_size:
            bias[token] += active[name]

    # Config tokens receive each member type's bias scaled by its note count.
    for token_id, spec in ID_TO_CONFIG.items():
        idx = int(token_id)
        if idx < 0 or idx >= vocab_size:
            continue
        for name, count in _config_type_counts(spec).items():
            amount = active.get(name, 0.0)
            if amount:
                bias[idx] += amount * max(1, int(count))
    return bias
def _config_density_bias_vector(vocab_size: int, device: torch.device,
                                strength: float,
                                include_break: bool = False) -> torch.Tensor | None:
    """Build a logit bias that scales with the extra notes in a config token.

    Each config token gets ``strength * (note_count - 1)``, so single-note
    configs are unaffected. Break-carrying configs are skipped unless
    ``include_break`` is set.

    Returns:
        A float32 vector of length ``vocab_size``, or None when ``strength``
        is effectively zero.
    """
    scale = float(strength)
    if abs(scale) <= 1e-9:
        return None

    bias = torch.zeros(vocab_size, dtype=torch.float32, device=device)
    for token_id, spec in ID_TO_CONFIG.items():
        idx = int(token_id)
        if idx < 0 or idx >= vocab_size:
            continue
        if not include_break and _config_has_break(spec):
            continue
        surplus = _config_note_count(spec) - 1
        if surplus > 0:
            bias[idx] += scale * surplus
    return bias
def _auto_target_notes(level_value: float, duration_sec: float) -> int:
    """Estimate a target note count from the chart level and track length.

    A notes-per-second rate is linearly interpolated from a small level
    table (levels outside 10.0-15.0 use the nearest end value) and multiplied
    by the duration. The result is never below 1.
    """
    # Conservative maimai-ish density prior; only used when density bias is on.
    levels = [10.0, 12.0, 13.0, 14.0, 14.5, 15.0]
    rates = [2.2, 3.2, 4.0, 4.8, 5.3, 5.8]
    rate = float(np.interp(level_value, levels, rates))
    estimate = int(round(rate * duration_sec))
    return estimate if estimate > 1 else 1
@torch.inference_mode()
def generate_chart(
    model: MaiGenerator,
    audio_tokens: torch.Tensor,
    bpm: float,
    difficulty: int,
    level_value: float,
    genre: int = 0,
    max_tokens: int = 4096,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.95,
    eos_suppress: float = 1.5,
    min_division: int = 4,
    max_duration_beats: float = 4.0,
    max_sim: int = 2,
    sim_penalty: float = 1.0,
    repetition_penalty: float = 1.0,
    rest_streak_penalty: float = 0.0,
    rest_streak_threshold: int = 4,
    target_notes: int | None = None,
    density_bias: float = 0.0,
    break_penalty: float = 0.0,
    break_window: int = 64,
    max_break_ratio: float = 0.08,
    type_biases: dict[str, float] | None = None,
    config_density_bias: float = 0.0,
    config_density_include_break: bool = False,
) -> list[int]:
    """
    Autoregressive chart generation from audio (with KV-cache for speed).

    The audio is encoded once; chart tokens are then sampled one at a time
    through the KV-cached decoder while a DecodeState masks illegal tokens
    and several optional soft biases reshape the logits.

    Args:
        model: Trained MaiGenerator.
        audio_tokens: [T_aud] or [1, T_aud] audio token ids (EnCodec tokens).
            Truncated to the size of ``model.audio_pos_embed``.
        bpm: BPM value.
        difficulty: 0=BASIC..4=ReMASTER.
        level_value: e.g. 12.4.
        genre: Genre index.
        max_tokens: Max tokens to generate (BOS not included).
        temperature: Sampling temperature.
        top_k: Top-K filter.
        top_p: Nucleus filter.
        eos_suppress: Currently unused. EOS is always masked out of the
            logits and is only appended by the loop itself; the parameter
            is kept so existing callers keep working.
        min_division: Passed to DecodeState and clamp_duration_tokens.
        max_duration_beats: Passed to DecodeState and used when snapping
            sampled durations.
        max_sim: Passed to DecodeState and sanitize_sim_tokens.
        sim_penalty: Subtracted from the SIM_BEG logit whenever the decode
            state allows starting a simultaneous group.
        repetition_penalty: Forwarded to sample_token.
        rest_streak_penalty: Subtracted from the RST logit when the last
            ``rest_streak_threshold`` tokens are all RST; 0 disables.
        rest_streak_threshold: Window length for the rest-streak check.
        target_notes: Total-note target for density guidance. When None and
            ``density_bias > 0`` it is derived from level and duration.
        density_bias: Strength of the bonus given to note tokens (and taken
            from RST) while the generated note count lags the target.
        break_penalty: Strength of the penalty on break-carrying tokens when
            the break ratio exceeds ``max_break_ratio``; 0 disables.
        break_window: Number of recent tokens used for the recent break
            ratio; ``<= 0`` uses the whole history.
        max_break_ratio: Break ratio above which the penalty applies.
        type_biases: Per-type additive logit biases (see _type_bias_vector).
        config_density_bias: Per-extra-note bias for config tokens.
        config_density_include_break: Let that bias reach break configs too.

    Returns:
        List of chart token IDs starting with BOS, post-processed by
        sanitize_sim_tokens and clamp_duration_tokens. EOS is appended when
        the end of the track is reached or no legal token remains; if
        ``max_tokens`` is exhausted first, the sequence ends without EOS.
    """
    model.eval()
    device = next(model.parameters()).device
    param_dtype = next(model.parameters()).dtype

    if audio_tokens.dim() == 1:
        audio_tokens = audio_tokens.unsqueeze(0)
    audio_tokens = audio_tokens.to(device)

    # -- Encode audio (once) --
    # The positional embedding table bounds the usable audio length.
    max_len = model.audio_pos_embed.num_embeddings
    if audio_tokens.shape[1] > max_len:
        audio_tokens = audio_tokens[:, :max_len]
    T_aud = audio_tokens.shape[1]
    aud = model.audio_embed(audio_tokens)
    aud = aud + model.audio_pos_embed(
        torch.arange(T_aud, device=device).unsqueeze(0))

    # Downsample audio (same as training forward)
    if model.audio_downsample > 1:
        aud = aud.transpose(1, 2)
        aud = model.audio_down(aud)
        aud = aud.transpose(1, 2)
        T_aud = aud.shape[1]
    aud = model.dropout(aud)
    aud_pos = torch.arange(T_aud, device=device, dtype=torch.float32)
    for blk in model.audio_encoder:
        aud = blk(aud, positions=aud_pos)
    aud = model.enc_norm(aud)

    # -- Conditions --
    bpm_t = torch.tensor([[bpm]], device=device, dtype=param_dtype)
    diff_t = torch.tensor([[difficulty]], device=device)
    level_t = torch.tensor([[level_value]], device=device, dtype=param_dtype)
    genre_t = torch.tensor([[genre]], device=device)
    cond = (model.bpm_proj(bpm_t).unsqueeze(1) +
            model.diff_embed(diff_t.squeeze(-1)).unsqueeze(1) +
            model.level_proj(level_t).unsqueeze(1) +
            model.genre_embed(genre_t.squeeze(-1)).unsqueeze(1))
    diff_vec = model.diff_embed(diff_t.squeeze(-1))  # [1, d_model]

    # Total beats. The "/ 2" assumes two audio tokens per frame (main builds
    # the audio tokenizer with n_layers=2) - TODO confirm.
    audio_duration_sec = audio_tokens.shape[1] / 2 / AUDIO_FRAME_RATE
    total_beats = audio_duration_sec * bpm / 60.0
    # BPM floor shared by the start offset and the per-step time conversion.
    bpm_v = max(bpm, 30.0)
    start_offset_beats = 2.0 * bpm_v / 60.0
    print(f"Track: {audio_duration_sec:.0f}s = {total_beats:.0f} beats @ {bpm} BPM (start at beat {start_offset_beats:.1f})")

    if density_bias > 0 and target_notes is None:
        target_notes = _auto_target_notes(level_value, audio_duration_sec)
    if density_bias > 0 and target_notes:
        print(f"Density target: {target_notes} notes bias={density_bias}")

    # -- Init KV-cache (BOS plus up to max_tokens generated tokens) --
    model.init_kv_cache(batch_size=1, max_len=max_tokens + 1, enc_out=aud)

    # -- Prefill with BOS token --
    tokens = [BOS]
    tok_t = torch.tensor([[BOS]], device=device)
    emb = model.chart_embed(tok_t) + cond  # cond adds batch, tok_t is [1,1]
    chart_pos = torch.tensor([[0.0]], device=device, dtype=torch.float32)
    x = emb
    for blk in model.chart_decoder:
        x = blk(x, enc_out=aud, self_positions=chart_pos,
                diff_emb=diff_vec, use_cache=True)

    # -- Decoding state and precomputed bias vectors --
    decode_state = DecodeState(
        bpm=bpm,
        total_beats=total_beats,
        start_offset_beats=start_offset_beats,
        min_division=min_division,
        max_sim=max_sim,
        max_duration_beats=max_duration_beats,
    )
    note_bonus_ids = _density_note_ids(model.chart_vocab_size, device, include_break=False) if density_bias > 0 else None
    break_ids = _break_ids(model.chart_vocab_size, device) if break_penalty > 0 else None
    type_bias = _type_bias_vector(model.chart_vocab_size, device, type_biases or {})
    config_density_bias_vec = _config_density_bias_vector(
        model.chart_vocab_size, device, config_density_bias,
        include_break=config_density_include_break,
    )
    if type_bias is not None:
        print("Type bias:", ", ".join(f"{k}={v:+.2f}" for k, v in (type_biases or {}).items() if abs(v) > 1e-9))
    if config_density_bias_vec is not None:
        print(f"Config density bias: {config_density_bias:+.2f} "
              f"(include_break={config_density_include_break})")

    # Per-token (note count, break-note count), memoised so the break-ratio
    # window does not re-scan the vocabulary tables on every step.
    token_stats: dict[int, tuple[int, int]] = {}

    def _note_and_break_count(tok: int) -> tuple[int, int]:
        cached = token_stats.get(tok)
        if cached is None:
            notes = _token_note_count(tok)
            cached = (notes, notes if _token_has_break(tok) else 0)
            token_stats[tok] = cached
        return cached

    generated_notes = 0
    generated_breaks = 0
    pbar = tqdm(total=max_tokens, desc="Generating", unit="tok")
    try:
        # -- Autoregressive loop (KV-cached, one token at a time) --
        for _ in range(max_tokens):
            x_last = model.dec_norm(x[:, -1:, :])
            logits = model.output_head(x_last).squeeze(0).squeeze(0)

            # -- Constraints --
            # EOS is never sampled; the loop decides when the chart ends.
            logits[EOS] = float("-inf")
            if decode_state.current_beat >= total_beats:
                pbar.update(1); pbar.set_postfix(tokens=len(tokens)+1, beat=f"{decode_state.current_beat:.0f}/{total_beats:.0f}", end="track")
                tokens.append(EOS)
                break
            logits = decode_state.apply_logits_mask(logits)
            if type_bias is not None:
                logits = logits + type_bias.to(dtype=logits.dtype)
            if config_density_bias_vec is not None:
                logits = logits + config_density_bias_vec.to(dtype=logits.dtype)
            if decode_state.can_start_sim() and SIM_BEG < logits.numel():
                logits[SIM_BEG] -= sim_penalty
            if not torch.isfinite(logits).any():
                # Every token is masked out: stop instead of sampling garbage.
                pbar.update(1); pbar.set_postfix(tokens=len(tokens)+1, beat=f"{decode_state.current_beat:.0f}/{total_beats:.0f}", end="no_legal")
                tokens.append(EOS)
                break

            # -- Rest-streak penalty: if the last N tokens are all rests, suppress RST --
            if rest_streak_penalty > 0 and len(tokens) >= rest_streak_threshold:
                recent = tokens[-rest_streak_threshold:]
                if all(t == RST for t in recent):
                    logits[RST] -= rest_streak_penalty

            # -- Density guidance: boost notes while behind the expected count --
            if density_bias > 0 and target_notes and note_bonus_ids is not None and note_bonus_ids.numel() > 0:
                progress = max(0.0, decode_state.current_beat - start_offset_beats)
                denom = max(1e-6, total_beats - start_offset_beats)
                expected = float(target_notes) * min(1.0, progress / denom)
                deficit = max(0.0, expected - generated_notes) / max(expected, 1.0)
                if deficit > 0:
                    bonus = density_bias * deficit
                    logits[note_bonus_ids] += bonus
                    if RST < logits.numel():
                        logits[RST] -= bonus

            # -- Break-ratio penalty (global and recent-window ratios) --
            if break_penalty > 0 and break_ids is not None and break_ids.numel() > 0:
                recent = tokens[-break_window:] if break_window > 0 else tokens
                recent_notes = 0
                recent_breaks = 0
                for t in recent:
                    n_notes, n_breaks = _note_and_break_count(t)
                    recent_notes += n_notes
                    recent_breaks += n_breaks
                global_ratio = generated_breaks / max(generated_notes, 1)
                recent_ratio = recent_breaks / max(recent_notes, 1)
                excess = max(global_ratio - max_break_ratio, recent_ratio - max_break_ratio, 0.0)
                if excess > 0:
                    logits[break_ids] -= break_penalty * (1.0 + 8.0 * excess)

            next_tok = sample_token(logits, temperature, top_k, top_p,
                                    repetition_penalty=repetition_penalty,
                                    prev_tokens=tokens)
            tokens.append(next_tok)
            n_notes, n_breaks = _note_and_break_count(next_tok)
            generated_notes += n_notes
            generated_breaks += n_breaks

            # -- Snap duration --
            # A complete "<note> DUR num den" group was just finished: replace
            # the sampled numerator/denominator with the snapped values.
            if len(tokens) >= 4 and tokens[-4] not in (DUR_TOKEN, PAD) and tokens[-3] == DUR_TOKEN:
                raw_num = ID_TO_DUR_NUM.get(tokens[-2], tokens[-2])
                raw_den = ID_TO_DUR_DEN.get(next_tok, max(next_tok, 1))
                beat, subdiv = snap_duration(raw_num, raw_den,
                                             decode_state.remaining_duration_beats(),
                                             decode_state.max_duration_beats)
                tokens[-2], tokens[-1] = DUR_NUM_TO_ID[beat], DUR_DEN_TO_ID[subdiv]
                next_tok = tokens[-1]

            # -- Position of the new token --
            # Taken from the beat BEFORE the decode state consumes the token.
            time_sec = decode_state.current_beat * 60.0 / bpm_v
            new_pos = time_sec * AUDIO_FRAME_RATE / model.audio_downsample
            decode_state.step(next_tok)

            # -- Incremental decoder step (KV-cached) --
            tok_t = torch.tensor([[next_tok]], device=device)
            emb = model.chart_embed(tok_t) + cond
            chart_pos = torch.tensor([[new_pos]], device=device, dtype=torch.float32)
            x = emb
            for blk in model.chart_decoder:
                x = blk(x, enc_out=aud, self_positions=chart_pos,
                        diff_emb=diff_vec, use_cache=True)

            pbar.update(1)
            if len(tokens) % 200 == 0:
                pct = min(100, decode_state.current_beat / total_beats * 100)
                pbar.set_postfix(beat=f"{decode_state.current_beat:.0f}/{total_beats:.0f}",
                                 pct=f"{pct:.0f}%", div=f"{decode_state.div_value}")
    finally:
        # Always release the progress bar, even if a decoding step raises.
        pbar.close()

    tokens = sanitize_sim_tokens(tokens, max_sim=max_sim)
    return clamp_duration_tokens(tokens, total_beats, start_offset_beats, min_division)
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# CLI
# โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
# Maps the --diff choice names to the integer difficulty index that main()
# passes to generate_chart (0=BASIC .. 4=REMASTER).
DIFF_MAP = {"BASIC": 0, "ADVANCED": 1, "EXPERT": 2, "MASTER": 3, "REMASTER": 4}
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the inference script."""
    parser = argparse.ArgumentParser(description="Generate maimai chart from audio")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--audio", type=str, default=None,
                        help="Path to audio file (mp3/wav)")
    parser.add_argument("--song-id", type=str, default=None,
                        help="Song ID in datasets/ (uses datasets/<id>/track.mp3)")
    parser.add_argument("--bpm", type=float, default=None,
                        help="BPM (auto-detect from maidata.txt if --song-id)")
    # type=str.upper makes the choice case-insensitive ("master" == "MASTER").
    parser.add_argument("--diff", type=str.upper, default="MASTER",
                        choices=list(DIFF_MAP.keys()))
    parser.add_argument("--level", type=float, default=12.0,
                        help="Target level value")
    parser.add_argument("--genre", type=int, default=0)
    parser.add_argument("--max-tokens", type=int, default=4096)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--eos-suppress", type=float, default=1.5,
                        help="Subtract from EOS logit (0=off)")
    parser.add_argument("--min-division", type=int, default=4,
                        help="Minimum generated beat division; blocks div_1/div_2 by default")
    parser.add_argument("--max-duration-beats", type=float, default=4.0,
                        help="Maximum generated hold/slide duration in beats")
    parser.add_argument("--max-sim", type=int, default=2,
                        help="Maximum simultaneous press count")
    parser.add_argument("--sim-penalty", type=float, default=1.0,
                        help="Subtract from SIM_BEG logit before constrained sampling")
    parser.add_argument("--rep-penalty", type=float, default=1.0,
                        help="Repetition penalty (>1.0=penalize repeats, 1.0=off)")
    parser.add_argument("--rest-streak-penalty", type=float, default=0.0,
                        help="Penalize RST when last N tokens are all rests (try 2.0-5.0)")
    parser.add_argument("--rest-streak-threshold", type=int, default=4,
                        help="Consecutive RSTs before applying rest-streak-penalty")
    parser.add_argument("--target-notes", type=int, default=None,
                        help="Optional soft target note count for density-guided sampling")
    parser.add_argument("--density-bias", type=float, default=0.0,
                        help="Softly boost note/config tokens when generated density lags target")
    parser.add_argument("--break-penalty", type=float, default=1.5,
                        help="Type-level penalty when BREAK/绝赞 ratio is too high")
    parser.add_argument("--break-window", type=int, default=64,
                        help="Recent token window used by BREAK ratio penalty")
    parser.add_argument("--max-break-ratio", type=float, default=0.08,
                        help="Soft maximum BREAK/绝赞 ratio before penalty kicks in")
    parser.add_argument("--tap-bias", type=float, default=0.0,
                        help="Logit bias for TAP config/tokens (+ increases, - decreases)")
    parser.add_argument("--hold-bias", type=float, default=0.0,
                        help="Logit bias for HOLD config/tokens")
    parser.add_argument("--slide-bias", type=float, default=0.0,
                        help="Logit bias for SLIDE config/tokens")
    parser.add_argument("--break-bias", type=float, default=0.0,
                        help="Logit bias for BREAK/绝赞 config/tokens")
    parser.add_argument("--touch-bias", type=float, default=0.0,
                        help="Logit bias for TOUCH config/tokens")
    parser.add_argument("--rest-bias", type=float, default=0.0,
                        help="Logit bias for REST tokens")
    parser.add_argument("--sim-bias", type=float, default=0.0,
                        help="Logit bias for simultaneous group start")
    parser.add_argument("--config-density-bias", type=float, default=0.0,
                        help="Bias dense config tokens by extra notes per time slot")
    parser.add_argument("--config-density-include-break", action="store_true",
                        help="Allow config-density-bias to boost BREAK/绝赞 configs too")
    parser.add_argument("--output", type=str, default=None,
                        help="Output path for generated maidata.txt")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--precision", type=str, default="float32",
                        choices=["float32", "float16", "bfloat16"],
                        help="Model precision (float16/bfloat16 ≈2x faster on supported GPUs)")
    parser.add_argument("--compile", action="store_true",
                        help="Use torch.compile for faster execution (PyTorch 2.0+)")
    return parser


def _load_generator(checkpoint_path: str, device: torch.device) -> MaiGenerator:
    """Build a MaiGenerator from a checkpoint, tolerating layout differences."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = ckpt.get("config", {})
    if "config_vocab" in config:
        load_config_vocab(config["config_vocab"])
        print(f"Loaded config vocab: total={MaiChartTokenizer.vocab_size}")

    model = MaiGenerator(
        d_model=config.get("d_model", 512),
        enc_layers=config.get("enc_layers", 6),
        dec_layers=config.get("dec_layers", 8),
        heads=config.get("heads", 8),
        d_ff=config.get("d_ff", 2048),
        use_moe=config.get("use_moe", True),
        n_experts=config.get("n_experts", 6),
        moe_layers=config.get("moe_layers", None),
    ).to(device)

    # Load only tensors whose name and shape match the freshly built model,
    # and report name mismatches and shape mismatches separately.
    model_state = model.state_dict()
    ckpt_state = ckpt["model_state"]
    filtered_state = {
        k: v for k, v in ckpt_state.items()
        if k in model_state and tuple(v.shape) == tuple(model_state[k].shape)
    }
    unknown = [k for k in ckpt_state if k not in model_state]
    mismatched = len(ckpt_state) - len(filtered_state) - len(unknown)
    incompatible = model.load_state_dict(filtered_state, strict=False)
    if mismatched:
        print(f"Model warm-start: skipped {mismatched} shape-mismatched params")
    if incompatible.missing_keys:
        print(f"Model warm-start: initialized {len(incompatible.missing_keys)} new params")
    if unknown:
        print(f"Model warm-start: ignored {len(unknown)} checkpoint params")
    model.eval()
    print(f"Model loaded (epoch {ckpt.get('epoch', '?')})")
    return model


def main():
    """CLI entry point: load a checkpoint, generate a chart, report and save it."""
    args = _build_arg_parser().parse_args()

    # Fail fast on missing input before the (slow) checkpoint load.
    if not args.song_id and not args.audio:
        print("ERROR: need --audio or --song-id")
        return

    # Fall back to CPU only when CUDA was requested but is unavailable; an
    # explicit non-CUDA device is respected as given.
    cuda_missing = args.device.startswith("cuda") and not torch.cuda.is_available()
    device = torch.device("cpu" if cuda_missing else args.device)
    print(f"Device: {device}")

    # -- Load model --
    model = _load_generator(args.checkpoint, device)

    # -- Precision --
    if args.precision != "float32":
        dt = {"float16": torch.float16, "bfloat16": torch.bfloat16}[args.precision]
        model = model.to(dtype=dt)
        print(f"Precision: {args.precision}")

    # -- torch.compile --
    if args.compile:
        try:
            model = torch.compile(model, mode="reduce-overhead")
            print("Model compiled with torch.compile (reduce-overhead)")
        except Exception as e:
            # Best effort: compilation is an optimisation, not a requirement.
            print(f"torch.compile failed: {e}, continuing without compile")

    # -- Load / tokenize audio --
    audio_tok = MaiTrackTokenizer(n_layers=2, device=str(device))
    if args.song_id:
        song_dir = Path("datasets") / args.song_id
        audio_path = song_dir / "track.mp3"
        if not audio_path.exists():
            print(f"ERROR: {audio_path} not found")
            return
        print(f"Audio: {audio_path}")
        audio_tokens = audio_tok.encode(str(audio_path), add_bos=False, add_eos=False)
        # Auto-detect BPM from maidata
        if args.bpm is None:
            mdata = song_dir / "maidata.txt"
            if mdata.exists():
                song = parse_maidata_file(str(mdata))
                bpm = song.bpm if song.bpm > 0 else 150.0
                print(f"Auto-detected BPM: {bpm}")
            else:
                bpm = 150.0
        else:
            bpm = args.bpm
    else:
        print(f"Audio: {args.audio}")
        audio_tokens = audio_tok.encode(args.audio, add_bos=False, add_eos=False)
        bpm = args.bpm or 150.0
    print(f"Audio tokens: {len(audio_tokens)}")

    # -- Generate --
    difficulty = DIFF_MAP[args.diff.upper()]
    print(f"Params: BPM={bpm}, Diff={args.diff}({difficulty}), Level={args.level}")
    print(f"Generating (max {args.max_tokens} tokens)...")
    audio_t = torch.tensor([audio_tokens], dtype=torch.long)
    chart_tokens = generate_chart(
        model, audio_t, bpm=bpm, difficulty=difficulty,
        level_value=args.level, genre=args.genre,
        max_tokens=args.max_tokens,
        temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
        eos_suppress=args.eos_suppress,
        min_division=args.min_division,
        max_duration_beats=args.max_duration_beats,
        max_sim=args.max_sim,
        sim_penalty=args.sim_penalty,
        repetition_penalty=args.rep_penalty,
        rest_streak_penalty=args.rest_streak_penalty,
        rest_streak_threshold=args.rest_streak_threshold,
        target_notes=args.target_notes,
        density_bias=args.density_bias,
        break_penalty=args.break_penalty,
        break_window=args.break_window,
        max_break_ratio=args.max_break_ratio,
        type_biases={
            "tap": args.tap_bias,
            "hold": args.hold_bias,
            "slide": args.slide_bias,
            "break": args.break_bias,
            "touch": args.touch_bias,
            "rest": args.rest_bias,
            "sim": args.sim_bias,
        },
        config_density_bias=args.config_density_bias,
        config_density_include_break=args.config_density_include_break,
    )
    print(f"Generated {len(chart_tokens)} tokens")

    # generate_chart truncates the audio to the positional-embedding table,
    # so use the same effective length for the statistics below.
    used_audio_tokens = min(len(audio_tokens), model.audio_pos_embed.num_embeddings)
    audio_duration_sec = used_audio_tokens / 2 / AUDIO_FRAME_RATE
    total_beats = audio_duration_sec * bpm / 60.0
    start_offset_beats = 2.0 * max(bpm, 30.0) / 60.0
    dur_stats = duration_report(chart_tokens, total_beats, start_offset_beats, args.min_division)
    print(f"Duration stats: count={dur_stats['durations']} "
          f"max={dur_stats['max_beats']:.3f} beats overrun={dur_stats['overrun']}")
    sim_stats = simultaneous_report(chart_tokens, args.max_sim)
    print(f"Sim stats: groups={sim_stats['groups']} "
          f"max_count={sim_stats['max_count']} over_limit={sim_stats['over_limit']} "
          f"first_bad={sim_stats['first_bad']}")

    # -- Decode to human-readable --
    chart_tok = MaiChartTokenizer()
    print(f"\nTokens: {chart_tok.tokens_to_str(chart_tokens, 60)}")
    # Decode to chart
    decoded = chart_tok.decode(chart_tokens)
    print(f"\nChart stats: TAP={decoded.tap_count} HOLD={decoded.hold_count} "
          f"SLIDE={decoded.slide_count} BREAK={decoded.break_count} "
          f"TOUCH={decoded.touch_count} TOTAL={decoded.note_count}")
    print(f"Break ratio: {decoded.break_count / max(decoded.note_count, 1):.2%}")
    print(f"Density: {decoded.note_count / max(audio_duration_sec, 1e-6):.2f} notes/s, "
          f"{decoded.note_count / max(total_beats, 1e-6):.2f} notes/beat")

    # -- Save --
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Convert tokens to proper maidata chart text
        chart_text = tokens_to_maitext(chart_tokens, bpm)
        # NOTE(review): simai maidata conventionally numbers lv_1/inote_1 as
        # EASY and lv_5/inote_5 as MASTER, so difficulty+1 may land one slot
        # too low - confirm against mai_parser before changing.
        lines = ["&title=Generated Chart",
                 f"&wholebpm={bpm}",
                 "&artist=MaiGenerator",
                 f"&lv_{difficulty+1}={args.level}",
                 f"&inote_{difficulty+1}="]
        lines.append(chart_text)
        out_path.write_text("\n".join(lines).rstrip("\n"), encoding="utf-8")
        print(f"\nSaved to {out_path}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()