Commit ·
d171350
0
Parent(s):
Initial commit
Browse files- .gitignore +5 -0
- README.md +26 -0
- app.py +108 -0
- convert_checkpoint.py +57 -0
- midmid/__init__.py +1 -0
- midmid/audio_prep.py +16 -0
- midmid/beat_tracker.py +42 -0
- midmid/constraints.py +79 -0
- midmid/datatypes.py +23 -0
- midmid/inference.py +335 -0
- midmid/ini_writer.py +32 -0
- midmid/midi_writer.py +104 -0
- midmid/nn.py +222 -0
- midmid/offset.py +29 -0
- midmid/sections.py +90 -0
- midmid/tempo_map.py +80 -0
- pipeline.py +304 -0
- requirements.txt +12 -0
- visualizer.py +360 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.egg-info/
|
| 5 |
+
model_upload/
|
README.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Midmid - Guitar Hero Chart Generator
|
| 3 |
+
emoji: 🎸
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: "5.23.0"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: true
|
| 10 |
+
license: mit
|
| 11 |
+
hardware: zero-a10g
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Midmid — AI Guitar Hero Chart Generator
|
| 15 |
+
|
| 16 |
+
Upload a song, get a playable Guitar Hero chart. Powered by a 19M-parameter
|
| 17 |
+
masked-prediction transformer trained on thousands of community-charted songs.
|
| 18 |
+
|
| 19 |
+
**How it works:**
|
| 20 |
+
1. Upload an audio file (MP3, FLAC, OGG, WAV)
|
| 21 |
+
2. Enter song metadata (title, artist, etc.)
|
| 22 |
+
3. Hit Generate — the model analyzes beats, structure, and audio features,
|
| 23 |
+
then predicts note placements for all four difficulty levels
|
| 24 |
+
4. Preview the chart in-browser, then download the ready-to-play song package
|
| 25 |
+
|
| 26 |
+
The output folder drops straight into GHWT:DE's MODS directory.
|
app.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Midmid — AI Guitar Hero Chart Generator (Hugging Face Space)."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
# ZeroGPU: import spaces if available (no-op locally)
|
| 7 |
+
try:
|
| 8 |
+
import spaces
|
| 9 |
+
ON_ZEROGPU = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
ON_ZEROGPU = False
|
| 12 |
+
|
| 13 |
+
from pipeline import ensure_model, generate_chart
|
| 14 |
+
from visualizer import build_visualizer_html
|
| 15 |
+
|
| 16 |
+
# Pre-load model on CPU at startup
|
| 17 |
+
ensure_model()
|
| 18 |
+
|
| 19 |
+
PLACEHOLDER_HTML = """
|
| 20 |
+
<div style="font-family: system-ui, sans-serif; background: #111; border-radius: 12px;
|
| 21 |
+
padding: 60px 20px; text-align: center; color: #666; max-width: 900px; margin: 0 auto;">
|
| 22 |
+
<div style="font-size: 48px; margin-bottom: 12px;">🎸</div>
|
| 23 |
+
<div style="font-size: 16px;">Upload a song and hit Generate to see your chart here</div>
|
| 24 |
+
</div>
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _generate_wrapper(audio_path, title, artist, album, year, genre, progress=gr.Progress()):
    """Validate the form inputs, generate the chart, and build the preview.

    Args:
        audio_path: Path of the uploaded audio file (empty/None if missing).
        title: Song title; required.
        artist: Artist name; required.
        album: Optional album name.
        year: Optional release year, as text.
        genre: Optional genre; falls back to "rock" when blank.
        progress: Gradio progress tracker, forwarded to the pipeline as
            ``progress_cb``.

    Returns:
        ``(html, zip_path)``: the visualizer HTML and the path of the
        generated song package.

    Raises:
        gr.Error: If the audio file, title, or artist is missing.
    """
    if not audio_path:
        raise gr.Error("Please upload an audio file.")

    # Normalise once so validation and the pipeline see the same values.
    title = (title or "").strip()
    if not title:
        raise gr.Error("Song title is required.")

    artist = (artist or "").strip()
    if not artist:
        raise gr.Error("Artist name is required.")

    zip_path, chart_json = generate_chart(
        audio_path=audio_path,
        title=title,
        artist=artist,
        album=(album or "").strip(),
        year=(year or "").strip(),
        # Strip first, then fall back: a whitespace-only genre must also
        # become "rock" rather than an empty string.
        genre=(genre or "").strip() or "rock",
        progress_cb=progress,
    )

    html = build_visualizer_html(chart_json)
    return html, zip_path
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Apply ZeroGPU decorator if running on HF Spaces
|
| 52 |
+
if ON_ZEROGPU:
|
| 53 |
+
_generate_wrapper = spaces.GPU(duration=180)(_generate_wrapper)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# --- UI ---
|
| 57 |
+
with gr.Blocks(
|
| 58 |
+
title="Midmid — Guitar Hero Chart Generator",
|
| 59 |
+
theme=gr.themes.Base(primary_hue="purple", neutral_hue="gray"),
|
| 60 |
+
css="""
|
| 61 |
+
.gradio-container { max-width: 960px !important; }
|
| 62 |
+
#generate-btn { min-height: 48px; font-size: 16px; }
|
| 63 |
+
""",
|
| 64 |
+
) as demo:
|
| 65 |
+
gr.Markdown(
|
| 66 |
+
"# Midmid — AI Guitar Hero Chart Generator\n"
|
| 67 |
+
"Upload a song, get a playable chart with 4 difficulty levels. "
|
| 68 |
+
"Preview it here, then download the GHWT:DE-ready package."
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
with gr.Row():
|
| 72 |
+
with gr.Column(scale=1):
|
| 73 |
+
audio_input = gr.Audio(
|
| 74 |
+
label="Upload audio",
|
| 75 |
+
type="filepath",
|
| 76 |
+
sources=["upload"],
|
| 77 |
+
)
|
| 78 |
+
title_input = gr.Textbox(label="Song title *", placeholder="e.g. Through the Fire and Flames")
|
| 79 |
+
artist_input = gr.Textbox(label="Artist *", placeholder="e.g. DragonForce")
|
| 80 |
+
|
| 81 |
+
with gr.Row():
|
| 82 |
+
album_input = gr.Textbox(label="Album", placeholder="(optional)")
|
| 83 |
+
year_input = gr.Textbox(label="Year", placeholder="(optional)")
|
| 84 |
+
|
| 85 |
+
genre_input = gr.Textbox(label="Genre", placeholder="rock", value="rock")
|
| 86 |
+
|
| 87 |
+
generate_btn = gr.Button("Generate Chart", variant="primary", elem_id="generate-btn")
|
| 88 |
+
|
| 89 |
+
with gr.Column(scale=2):
|
| 90 |
+
viz_output = gr.HTML(value=PLACEHOLDER_HTML, label="Chart Preview")
|
| 91 |
+
zip_output = gr.File(label="Download song package (.zip)")
|
| 92 |
+
|
| 93 |
+
generate_btn.click(
|
| 94 |
+
fn=_generate_wrapper,
|
| 95 |
+
inputs=[audio_input, title_input, artist_input, album_input, year_input, genre_input],
|
| 96 |
+
outputs=[viz_output, zip_output],
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
gr.Markdown(
|
| 100 |
+
"---\n"
|
| 101 |
+
"*Charts generated by [Midmid](https://github.com/markury/midmid) — "
|
| 102 |
+
"a 19M-parameter masked transformer trained on community Guitar Hero charts. "
|
| 103 |
+
"Model: `markury/midmid3-19m-0326`*"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if __name__ == "__main__":
|
| 108 |
+
demo.launch()
|
convert_checkpoint.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert a midmid PyTorch checkpoint to safetensors + config.json.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
python convert_checkpoint.py path/to/best.pt --output-dir ./model_upload
|
| 5 |
+
|
| 6 |
+
This produces:
|
| 7 |
+
model_upload/model.safetensors (weights only, no pickle)
|
| 8 |
+
model_upload/config.json (model hyperparameters)
|
| 9 |
+
|
| 10 |
+
Then upload to HF:
|
| 11 |
+
huggingface-cli upload markury/midmid3-19m-0326 ./model_upload
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from safetensors.torch import save_file
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
    """Convert a midmid ``.pt`` checkpoint into ``config.json`` + safetensors.

    Reads the checkpoint path and ``--output-dir`` from the command line,
    writes ``config.json`` (from ``ckpt["config"]``) and
    ``model.safetensors`` (from ``ckpt["model_state_dict"]``) into the
    output directory, and prints a short summary plus the upload command.
    """
    parser = argparse.ArgumentParser(description="Convert midmid checkpoint to safetensors")
    parser.add_argument("checkpoint", type=Path, help="Path to .pt checkpoint")
    parser.add_argument("--output-dir", type=Path, default=Path("model_upload"),
                        help="Output directory (default: ./model_upload)")
    args = parser.parse_args()

    args.output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading checkpoint: {args.checkpoint}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only run this on checkpoints you
    # produced yourself or otherwise trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    # Save config
    config = ckpt["config"]
    config_path = args.output_dir / "config.json"
    # Explicit UTF-8 so the file does not depend on the platform locale.
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print(f"Config saved: {config_path}")
    print(f"  {json.dumps(config, indent=2)}")

    # Save weights as safetensors. save_file rejects non-contiguous tensors,
    # so normalise the layout first (a no-op for already-contiguous tensors).
    state_dict = {
        name: tensor.contiguous()
        for name, tensor in ckpt["model_state_dict"].items()
    }
    safetensors_path = args.output_dir / "model.safetensors"
    save_file(state_dict, str(safetensors_path))
    print(f"Weights saved: {safetensors_path}")

    # Summary
    n_params = sum(p.numel() for p in state_dict.values())
    print(f"  {n_params:,} parameters ({n_params / 1e6:.1f}M)")

    print("\nUpload to HF with:")
    print(f"  huggingface-cli upload markury/midmid3-19m-0326 {args.output_dir}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
|
| 57 |
+
main()
|
midmid/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Midmid — Guitar Hero chart generation core
|
midmid/audio_prep.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio preparation — silence prepend, format conversion."""
|
| 2 |
+
|
| 3 |
+
from pydub import AudioSegment
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def prepare_audio(
    audio_path: str,
    output_path: str,
    silence_duration_sec: float = 3.0,
    output_format: str = "ogg",
) -> None:
    """Write a copy of the audio with leading silence, in the target format.

    Args:
        audio_path: Source audio file to read.
        output_path: Destination file to write.
        silence_duration_sec: Length of the silent lead-in, in seconds.
        output_format: Format name passed to the exporter.
    """
    source = AudioSegment.from_file(audio_path)
    # AudioSegment.silent takes milliseconds; fractions of a ms are truncated.
    lead_in_ms = int(silence_duration_sec * 1000)
    lead_in = AudioSegment.silent(duration=lead_in_ms)
    (lead_in + source).export(output_path, format=output_format)
|
midmid/beat_tracker.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Beat and downbeat tracking via beat_this (CPJKU)."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from beat_this.inference import File2Beats
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class BeatData:
    """Result of beat tracking for one audio file."""

    # Beat positions as returned by the tracker
    # (presumably seconds; confirm against beat_this).
    beats: np.ndarray
    # Downbeat positions, in the same unit as ``beats``.
    downbeats: np.ndarray
    # Per-beat counter that restarts at 1 on every downbeat.
    beat_numbers: np.ndarray
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
    """Track beats and downbeats in an audio file with beat_this.

    Args:
        audio_path: Audio file to analyse.
        device: Device string handed to ``File2Beats``.

    Returns:
        A ``BeatData`` holding the beat array, the downbeat array, and the
        per-beat numbering derived from them.
    """
    tracker = File2Beats(checkpoint_path="final0", device=device)
    beat_times, downbeat_times = tracker(audio_path)

    numbering = _assign_beat_numbers(beat_times, downbeat_times)
    return BeatData(
        beats=np.asarray(beat_times),
        downbeats=np.asarray(downbeat_times),
        beat_numbers=numbering,
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _assign_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray:
|
| 31 |
+
beats = np.asarray(beats)
|
| 32 |
+
downbeats_set = set(np.round(downbeats, 6))
|
| 33 |
+
numbers = np.zeros(len(beats), dtype=int)
|
| 34 |
+
beat_num = 1
|
| 35 |
+
|
| 36 |
+
for i, t in enumerate(beats):
|
| 37 |
+
if round(float(t), 6) in downbeats_set:
|
| 38 |
+
beat_num = 1
|
| 39 |
+
numbers[i] = beat_num
|
| 40 |
+
beat_num += 1
|
| 41 |
+
|
| 42 |
+
return numbers
|
midmid/constraints.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Difficulty constraints — enforce per-difficulty fret and chord rules."""
|
| 2 |
+
|
| 3 |
+
from midmid.datatypes import NoteEvent
|
| 4 |
+
|
| 5 |
+
# Frets each difficulty may use (0 is the lowest fret).
ALLOWED_FRETS = {
    "easy": set(range(3)),
    "medium": set(range(4)),
    "hard": set(range(5)),
    "expert": set(range(5)),
}

# Largest number of simultaneous frets allowed in one chord.
MAX_CHORD_SIZE = dict(easy=1, medium=2, hard=3, expert=5)

# Minimum tick distance between consecutive kept notes (0 disables the check).
MIN_NOTE_SPACING = dict(easy=192, medium=96, hard=48, expert=0)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def enforce_constraints(
    notes: list[NoteEvent], difficulty: str, resolution: int = 192,
) -> list[NoteEvent]:
    """Filter and adjust notes so they obey the rules of one difficulty.

    For each note, in input order:
      * frets outside the difficulty's allowed set are removed; if none
        remain, the lowest original fret is mapped to the nearest allowed one;
      * chords are cut down to the lowest ``max_chord`` frets;
      * the note is dropped if it is closer than the minimum spacing to the
        previously kept note, or starts inside that note's sustain;
      * after keeping a note, the previous kept note's sustain is shortened
        to end at least a sixteenth before it, snapped down to a sixteenth.

    Args:
        notes: Candidate notes (read in the given order).
        difficulty: Key into the constraint tables; unknown names get the
            most permissive settings.
        resolution: Ticks per quarter note; a sixteenth is ``resolution // 4``.

    Returns:
        A new list of ``NoteEvent`` objects; the input notes are not modified.
    """
    allowed = ALLOWED_FRETS.get(difficulty, {0, 1, 2, 3, 4})
    max_chord = MAX_CHORD_SIZE.get(difficulty, 5)
    min_spacing = MIN_NOTE_SPACING.get(difficulty, 0)
    sixteenth = resolution // 4

    kept = []
    for note in notes:
        frets = note.fret_set & allowed
        if not frets and note.fret_set:
            # Nothing playable: remap only the lowest fret of the chord.
            lowest = min(note.fret_set)
            frets.add(min(allowed, key=lambda candidate: abs(candidate - lowest)))
        if not frets:
            continue

        if len(frets) > max_chord:
            frets = set(sorted(frets)[:max_chord])

        if kept:
            last = kept[-1]
            too_close = min_spacing > 0 and note.tick - last.tick < min_spacing
            if too_close:
                continue
            inside_sustain = (
                last.sustain_ticks > 0
                and note.tick < last.tick + last.sustain_ticks
            )
            if inside_sustain:
                continue

        kept.append(NoteEvent(
            tick=note.tick,
            fret_set=frets,
            sustain_ticks=note.sustain_ticks,
            is_hopo=note.is_hopo,
        ))

        # Leave a gap between the previous sustain and the note just added.
        if len(kept) >= 2 and kept[-2].sustain_ticks > 0:
            prev = kept[-2]
            room = note.tick - prev.tick - sixteenth
            room = (room // sixteenth) * sixteenth
            if room < sixteenth:
                room = 0
            if prev.sustain_ticks > room:
                kept[-2] = NoteEvent(
                    tick=prev.tick,
                    fret_set=prev.fret_set,
                    sustain_ticks=room,
                    is_hopo=prev.is_hopo,
                )

    return kept
|
midmid/datatypes.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared data types used across the pipeline."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class NoteEvent:
    """One note or chord placed at a tick position."""

    # Position of the note.
    tick: int
    # Frets pressed together; 0=Green ... 4=Orange.
    fret_set: set[int]
    # Length of the held sustain; 0 means no sustain.
    sustain_ticks: int = 0
    # Whether the note is a hammer-on/pull-off.
    is_hopo: bool = False
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class ChartData:
    """Everything needed to serialize one chart to MIDI."""

    # Ticks per quarter note.
    resolution: int = 192
    # Tempo changes; defaults to a single (0, 120.0) entry.
    tempo_events: list = field(default_factory=lambda: [(0, 120.0)])
    # Time-signature changes; defaults to a single (0, 4, 4) entry.
    time_signatures: list = field(default_factory=lambda: [(0, 4, 4)])
    # Section markers: [(tick, label), ...]
    sections: list = field(default_factory=list)
    # Notes per difficulty: {"expert": [NoteEvent, ...], ...}
    notes: dict = field(default_factory=dict)
    # Beat markers: [(tick, is_downbeat), ...]
    beats: list = field(default_factory=list)
|
midmid/inference.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio encoding and iterative unmasking inference.
|
| 2 |
+
|
| 3 |
+
Adapted from midmid/prediction/model.py for standalone use.
|
| 4 |
+
Device management is caller-controlled (for ZeroGPU compatibility).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import itertools as _it
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from midmid.nn import (
|
| 17 |
+
ChartMaskPredictor, ChartMaskPredictorConfig,
|
| 18 |
+
MASK_TOKEN, SILENCE_TOKEN,
|
| 19 |
+
)
|
| 20 |
+
from midmid.datatypes import NoteEvent
|
| 21 |
+
|
| 22 |
+
MERT_MODEL_ID = "m-a-p/MERT-v1-95M"

# Difficulty name -> integer id fed to the model.
DIFF_ID = dict(easy=0, medium=1, hard=2, expert=3)

# Class ID -> fret tuple: every non-empty combination of frets 0-4, ordered
# by chord size, followed by one extra class for the open note.
_CLASS_TO_FRETS: list[tuple[int, ...]] = [
    combo
    for size in range(1, 6)
    for combo in _it.combinations(range(5), size)
]
_CLASS_TO_FRETS.append((7,))  # class 31 = open

# Sustain bucket center values in beats
_BUCKET_BEATS = [0.0, 1.0, 2.0, 4.0, 8.0, 16.0]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Model loading (safetensors from HF Hub)
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
def load_model_from_hub(
    repo_id: str = "markury/midmid3-19m-0326",
    device: str = "cpu",
) -> ChartMaskPredictor:
    """Download the model from the Hugging Face Hub and load it for inference.

    Fetches ``config.json`` and ``model.safetensors`` from ``repo_id``,
    builds the model from the config, loads the weights, moves the model to
    ``device``, and switches it to eval mode.

    Args:
        repo_id: Hub repository holding the config and weights.
        device: Device the weights are loaded onto.

    Returns:
        The loaded ``ChartMaskPredictor`` in eval mode.
    """
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    config_path = hf_hub_download(repo_id, "config.json")
    weights_path = hf_hub_download(repo_id, "model.safetensors")

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(config_path, encoding="utf-8") as f:
        config_dict = json.load(f)

    config = ChartMaskPredictorConfig(**config_dict)
    model = ChartMaskPredictor(config)

    state_dict = load_file(weights_path, device=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# MERT audio encoding (lazy-loaded)
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
_mert_model = None
|
| 69 |
+
_mert_processor = None
|
| 70 |
+
_mert_frame_rate = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _ensure_mert(device: torch.device):
    """Make sure the MERT encoder is loaded and sits on ``device``.

    On first use this loads the feature extractor and model into the module
    globals and measures the encoder's frame rate. On later calls it only
    moves the cached model when it is on a different device.
    """
    global _mert_model, _mert_processor, _mert_frame_rate

    if _mert_model is not None:
        # Already cached: relocate only if needed.
        current_device = next(_mert_model.parameters()).device
        if current_device != device:
            _mert_model.to(device)
        return

    from transformers import AutoModel, Wav2Vec2FeatureExtractor

    print(f"Loading MERT ({MERT_MODEL_ID}) ...")
    _mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
        MERT_MODEL_ID, trust_remote_code=True,
    )
    _mert_model = AutoModel.from_pretrained(MERT_MODEL_ID, trust_remote_code=True)
    _mert_model.to(device)
    _mert_model.eval()

    # Measure the frame rate: encode exactly one second of silence and count
    # the output frames.
    sample_rate = _mert_processor.sampling_rate
    one_second = np.zeros(sample_rate, dtype=np.float32)
    batch = _mert_processor(one_second, sampling_rate=sample_rate, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        probe = _mert_model(**batch, output_hidden_states=False)
    _mert_frame_rate = float(probe.last_hidden_state.shape[1])
    print(f"  MERT frame rate: {_mert_frame_rate:.2f} Hz")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def move_models_to_device(device: torch.device):
    """Relocate the cached MERT model to ``device`` (used for ZeroGPU).

    Does nothing when MERT has not been loaded yet.
    """
    if _mert_model is None:
        return
    _mert_model.to(device)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
def encode_audio_mert(
    audio_path: str,
    device: torch.device,
    chunk_sec: float = 60.0,
) -> tuple[torch.Tensor, float]:
    """Encode an audio file with MERT and return ``(embeddings, frame_rate)``.

    Audio no longer than ``chunk_sec`` is encoded in one pass. Longer audio
    is encoded in chunks that overlap by 5 seconds; half of each overlap is
    discarded on either side of a seam before the chunks are concatenated.

    Args:
        audio_path: Audio file to load (resampled to the MERT sampling rate,
            mono).
        device: Device the encoder runs on.
        chunk_sec: Chunk length in seconds; must exceed the 5 s overlap.

    Returns:
        The frame embeddings on the CPU and the cached MERT frame rate.

    Raises:
        ValueError: If ``chunk_sec`` does not exceed the overlap, which
            would otherwise make the chunk loop never advance.
    """
    import librosa
    _ensure_mert(device)

    sr = _mert_processor.sampling_rate
    wav, _ = librosa.load(audio_path, sr=sr, mono=True)

    chunk_samples = int(chunk_sec * sr)
    overlap_sec = 5.0
    overlap_samples = int(overlap_sec * sr)
    stride_samples = chunk_samples - overlap_samples
    if stride_samples <= 0:
        raise ValueError(
            f"chunk_sec ({chunk_sec}) must be greater than the "
            f"{overlap_sec} s chunk overlap"
        )

    if len(wav) <= chunk_samples:
        inputs = _mert_processor(wav, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = _mert_model(**inputs, output_hidden_states=False)
        return out.last_hidden_state.squeeze(0).cpu(), _mert_frame_rate

    # Chunked processing for long audio
    all_emb = []
    pos = 0
    idx = 0
    while pos < len(wav):
        end = min(pos + chunk_samples, len(wav))
        is_last = end >= len(wav)
        chunk = wav[pos:end]
        # Very short tails are zero-padded before encoding.
        min_len = chunk_samples // 4
        if len(chunk) < min_len:
            chunk = np.pad(chunk, (0, min_len - len(chunk)))

        inputs = _mert_processor(chunk, sampling_rate=sr, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = _mert_model(**inputs, output_hidden_states=False)
        emb = out.last_hidden_state.squeeze(0)

        n = emb.shape[0]
        fps = n / (len(chunk) / sr)
        half_overlap = int(round((overlap_sec / 2) * fps))

        if idx == 0:
            keep = n if is_last else n - half_overlap
            all_emb.append(emb[:keep].cpu())
        elif is_last:
            all_emb.append(emb[half_overlap:].cpu())
        else:
            keep = int(round((len(chunk) / sr - overlap_sec) * fps))
            all_emb.append(emb[half_overlap:half_overlap + keep].cpu())

        # Stop once a chunk has reached the end of the audio. Without this,
        # a final chunk longer than the stride leaves pos < len(wav) and an
        # extra padded tail chunk gets encoded and appended.
        if is_last:
            break

        pos += stride_samples
        idx += 1

    return torch.cat(all_emb, dim=0), _mert_frame_rate
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# ---------------------------------------------------------------------------
|
| 170 |
+
# Grid helpers
|
| 171 |
+
# ---------------------------------------------------------------------------
|
| 172 |
+
|
| 173 |
+
def _build_16th_grid(fretbars):
|
| 174 |
+
"""Build 16th-note timestamps (ms) from beat positions."""
|
| 175 |
+
if len(fretbars) < 2:
|
| 176 |
+
return list(fretbars)
|
| 177 |
+
positions = []
|
| 178 |
+
for i in range(len(fretbars) - 1):
|
| 179 |
+
start = fretbars[i]
|
| 180 |
+
interval = fretbars[i + 1] - start
|
| 181 |
+
for sub in range(4):
|
| 182 |
+
positions.append(start + sub * interval / 4.0)
|
| 183 |
+
positions.append(fretbars[-1])
|
| 184 |
+
return positions
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _get_local_beat_ms(grid_idx, fretbars):
|
| 188 |
+
beat_idx = min(grid_idx // 4, len(fretbars) - 2)
|
| 189 |
+
beat_idx = max(0, beat_idx)
|
| 190 |
+
if beat_idx + 1 < len(fretbars):
|
| 191 |
+
return fretbars[beat_idx + 1] - fretbars[beat_idx]
|
| 192 |
+
return 500.0
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Main inference
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
@torch.no_grad()
def predict_notes(
    audio_path: str,
    model: ChartMaskPredictor,
    beat_times: list[float],
    difficulty: str = "expert",
    device: Optional[torch.device] = None,
    num_steps: int = 12,
    temperature: float = 0.9,
) -> list[NoteEvent]:
    """Predict notes with MaskGIT-style iterative unmasking.

    The audio is encoded with MERT, sampled onto a 16th-note grid built from
    the beats, and the model fills in a fully masked token sequence over
    ``num_steps`` rounds following a cosine schedule.

    Args:
        audio_path: Audio file to chart.
        model: The chart predictor; moved to ``device`` and put in eval mode.
        beat_times: Beat times (multiplied by 1000 to get milliseconds).
        difficulty: Difficulty name; unknown names are treated as expert.
        device: Device to run on; defaults to CUDA when available, else CPU.
        num_steps: Number of unmasking rounds.
        temperature: Divisor applied to the token logits before sampling.

    Returns:
        One ``NoteEvent`` per grid position that received a note class.
        ``tick`` is the grid index, and ``sustain_ticks`` is the bucket
        length in beats times the local beat length in ms. Both are
        presumably converted to real ticks by the caller (TODO confirm).
        Returns ``[]`` when fewer than two beats are given.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dev = device
    model.to(dev)
    model.eval()

    fretbars = [t * 1000.0 for t in beat_times]
    if len(fretbars) < 2:
        return []

    # MERT embeddings
    embeddings, frame_rate = encode_audio_mert(audio_path, dev)

    # Build grid and sample MERT frames with windowing
    grid_times = _build_16th_grid(fretbars)
    num_positions = len(grid_times)
    max_frame = embeddings.shape[0] - 1
    frame_indices = torch.tensor(
        [min(int(round(t / 1000.0 * frame_rate)), max_frame)
         for t in grid_times], dtype=torch.long,
    )

    # Average each grid frame with its +/- `window` neighbours; the sequence
    # is edge-padded so frames near either end still get a full window.
    window = 2
    if window > 0 and max_frame >= window * 2:
        padded = torch.nn.functional.pad(
            embeddings.unsqueeze(0), (0, 0, window, window), mode="replicate",
        ).squeeze(0)
        shifted = frame_indices + window
        stacked = torch.stack(
            [padded[shifted + d] for d in range(-window, window + 1)], dim=0,
        )
        grid_emb = stacked.mean(dim=0)
    else:
        grid_emb = embeddings[frame_indices]

    # Compute and concat audio features if model expects them
    if model.config.audio_dim > grid_emb.shape[-1]:
        import librosa as _lr
        wav, _ = _lr.load(audio_path, sr=24000, mono=True)
        hop = 320
        onset = _lr.onset.onset_strength(y=wav, sr=24000, hop_length=hop)
        rms_arr = _lr.feature.rms(y=wav, hop_length=hop)[0]
        centroid = _lr.feature.spectral_centroid(y=wav, sr=24000, hop_length=hop)[0]

        def _norm(x):
            # Min-max scale to [0, 1]; the floor avoids dividing by zero.
            mn, mx = x.min(), x.max()
            return (x - mn) / max(mx - mn, 1e-8)

        onset, rms_arr, centroid = _norm(onset), _norm(rms_arr), _norm(centroid)
        af_rate = 24000 / hop
        af_max = len(onset) - 1
        af_indices = [min(int(round(t / 1000.0 * af_rate)), af_max) for t in grid_times]
        af_tensor = torch.tensor(
            [[onset[i], rms_arr[i], centroid[i]] for i in af_indices],
            dtype=torch.float32,
        )
        grid_emb = torch.cat([grid_emb, af_tensor], dim=-1)

    audio_features = grid_emb.unsqueeze(0).to(dev)

    diff_id = DIFF_ID.get(difficulty, 3)
    diff_tensor = torch.tensor([diff_id], dtype=torch.long, device=dev)
    padding_mask = torch.ones(1, num_positions, dtype=torch.bool, device=dev)

    # Start fully masked
    chart_tokens = torch.full(
        (1, num_positions), MASK_TOKEN, dtype=torch.long, device=dev,
    )

    # Cosine unmasking schedule
    schedule = []
    for step in range(num_steps):
        r_prev = math.cos(math.pi / 2 * step / num_steps)
        r_next = math.cos(math.pi / 2 * (step + 1) / num_steps)
        n_unmask = max(1, int((r_prev - r_next) * num_positions))
        schedule.append(n_unmask)

    # Iterative unmasking
    for step in range(num_steps):
        outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
        token_logits = outputs["token_logits"].squeeze(0)

        is_masked = (chart_tokens.squeeze(0) == MASK_TOKEN)
        masked_indices = is_masked.nonzero(as_tuple=True)[0]

        if len(masked_indices) == 0:
            break

        probs = torch.softmax(token_logits / temperature, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        if step == num_steps - 1:
            # The schedule truncates each step's share with int(), so its
            # total can fall short of num_positions. Unmask everything that
            # is left on the final step so no position stays masked and is
            # silently dropped below.
            n_unmask = len(masked_indices)
        else:
            n_unmask = min(schedule[step], len(masked_indices))
        perm = torch.randperm(len(masked_indices), device=dev)
        unmask_idx = masked_indices[perm[:n_unmask]]
        chart_tokens[0, unmask_idx] = sampled[unmask_idx]

    # Final pass for sustain predictions
    outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
    sustain_prob = outputs["sustain_logits"].squeeze(0).squeeze(-1).sigmoid()
    dur_pred = outputs["duration_logits"].squeeze(0).argmax(dim=-1)

    # Convert tokens to NoteEvents
    tokens = chart_tokens.squeeze(0).cpu()
    notes = []
    for i in range(num_positions):
        tok = tokens[i].item()
        # Silence and any higher token id carry no note.
        if tok >= SILENCE_TOKEN or tok < 0:
            continue

        fret_set = set(_CLASS_TO_FRETS[tok])
        if not fret_set:
            continue

        sustain_ticks = 0
        if sustain_prob[i] >= 0.5:
            bucket = dur_pred[i].item()
            beat_ms = _get_local_beat_ms(i, fretbars)
            sustain_ticks = _BUCKET_BEATS[bucket] * beat_ms

        notes.append(NoteEvent(
            tick=i,
            fret_set=fret_set,
            sustain_ticks=sustain_ticks,
        ))

    return notes
|
midmid/ini_writer.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate song.ini metadata for GHWT:DE."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def write_ini(
    output_path: str,
    title: str = "Unknown Song",
    artist: str = "Unknown Artist",
    album: str = "",
    genre: str = "rock",
    year: str = "2024",
    charter: str = "Midmid",
    diff_guitar: int = 0,
    preview_start_time: int = 30000,
    song_length: int = 0,
) -> None:
    """Write a ``song.ini`` file containing a single ``[Song]`` section.

    Text values are collapsed onto one line before writing: the metadata may
    come from free-form user input, and an embedded line break would otherwise
    start a new ``key = value`` entry in the generated file.

    Args:
        output_path: Destination file; overwritten if it exists.
        title: Written as ``name``.
        artist: Written as ``artist``.
        album: Written as ``album``.
        genre: Written as ``genre``.
        year: Written as ``year``.
        charter: Written as ``charter``.
        diff_guitar: Written as ``diff_guitar``.
        preview_start_time: Written as ``preview_start_time``
            (presumably milliseconds -- confirm against the game's format).
        song_length: Written as ``song_length`` only when greater than zero.
    """

    def one_line(value) -> str:
        # splitlines() drops every kind of line break; re-join with spaces.
        return " ".join(str(value).splitlines())

    lines = [
        "[Song]",
        f"name = {one_line(title)}",
        f"artist = {one_line(artist)}",
        f"album = {one_line(album)}",
        f"genre = {one_line(genre)}",
        f"year = {one_line(year)}",
        f"charter = {one_line(charter)}",
        f"diff_guitar = {diff_guitar}",
        f"preview_start_time = {preview_start_time}",
    ]
    if song_length > 0:
        lines.append(f"song_length = {song_length}")

    Path(output_path).write_text("\n".join(lines) + "\n", encoding="utf-8")
|
midmid/midi_writer.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MIDI serialization: write ChartData to a GH-format .mid file."""
|
| 2 |
+
|
| 3 |
+
import mido
|
| 4 |
+
|
| 5 |
+
from midmid.datatypes import ChartData, NoteEvent
|
| 6 |
+
|
| 7 |
+
# Base MIDI note per difficulty; a fret index is added to this base when the
# guitar track is built. Insertion order is the order difficulties are emitted.
DIFFICULTY_OFFSETS = {
    "easy": 60,
    "medium": 72,
    "hard": 84,
    "expert": 96,
}

# Marker note written alongside any note flagged ``is_hopo``, per difficulty.
HOPO_NOTE = {
    "easy": 65,
    "medium": 77,
    "hard": 89,
    "expert": 101,
}

# Velocity used for every note_on message this module writes.
NOTE_VELOCITY = 100
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def write_midi(chart: ChartData, output_path: str) -> None:
    """Serialize ``chart`` to a MIDI file at ``output_path``.

    Tracks are added in this order: tempo, EVENTS, PART GUITAR, and a BEAT
    track only when ``chart.beats`` is non-empty. ``chart.resolution`` becomes
    the file's ticks-per-beat.
    """
    mid = mido.MidiFile(ticks_per_beat=chart.resolution)

    mid.tracks.extend(
        build(chart)
        for build in (_build_tempo_track, _build_events_track, _build_guitar_track)
    )
    if chart.beats:
        mid.tracks.append(_build_beat_track(chart))

    mid.save(output_path)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _build_tempo_track(chart):
    """Build the tempo track from ``chart.tempo_events`` and ``chart.time_signatures``.

    ``tempo_events`` yields ``(tick, bpm)`` pairs and ``time_signatures``
    yields ``(tick, numerator, denominator)`` triples. Tempo events are queued
    before time signatures, so they come first when both share a tick.
    """
    tempo_changes = [
        (tick, mido.MetaMessage("set_tempo", tempo=mido.bpm2tempo(bpm), time=0))
        for tick, bpm in chart.tempo_events
    ]
    signature_changes = [
        (tick, mido.MetaMessage(
            "time_signature", numerator=num, denominator=den, time=0))
        for tick, num, den in chart.time_signatures
    ]

    track = mido.MidiTrack()
    _write_sorted_events(track, tempo_changes + signature_changes)
    return track
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_events_track(chart):
    """Build the ``EVENTS`` track: one ``[section <label>]`` text event per section.

    ``chart.sections`` yields ``(tick, label)`` pairs.
    """
    track = mido.MidiTrack()
    track.append(mido.MetaMessage("track_name", name="EVENTS", time=0))

    markers = [
        (tick, mido.MetaMessage("text", text=f"[section {label}]", time=0))
        for tick, label in chart.sections
    ]
    _write_sorted_events(track, markers)
    return track
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _build_guitar_track(chart):
    """Build the ``PART GUITAR`` track from ``chart.notes``.

    For every difficulty present in ``chart.notes`` each fret of each note
    becomes a note_on/note_off pair at pitch ``DIFFICULTY_OFFSETS[difficulty] +
    fret``. The note lasts ``max(sustain_ticks, 1)`` ticks, so a non-sustained
    note still has length one. Notes flagged ``is_hopo`` also get a one-tick
    marker note at ``HOPO_NOTE[difficulty]``.
    """
    track = mido.MidiTrack()
    track.append(mido.MetaMessage("track_name", name="PART GUITAR", time=0))

    timeline = []

    def add_note(start, end, pitch):
        # Queue a matching on/off pair; times are absolute ticks for now.
        timeline.append((start, mido.Message(
            "note_on", note=pitch, velocity=NOTE_VELOCITY, time=0)))
        timeline.append((end, mido.Message(
            "note_off", note=pitch, velocity=0, time=0)))

    for difficulty, base in DIFFICULTY_OFFSETS.items():
        if difficulty not in chart.notes:
            continue

        for note in chart.notes[difficulty]:
            for fret in note.fret_set:
                add_note(
                    note.tick,
                    note.tick + max(note.sustain_ticks, 1),
                    base + fret,
                )

            if note.is_hopo:
                add_note(note.tick, note.tick + 1, HOPO_NOTE[difficulty])

    _write_sorted_events(track, timeline)
    return track
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _build_beat_track(chart):
    """Build the ``BEAT`` track: a one-tick note per entry of ``chart.beats``.

    ``chart.beats`` yields ``(tick, is_downbeat)`` pairs; downbeats use MIDI
    note 12 and all other beats use note 13.
    """
    track = mido.MidiTrack()
    track.append(mido.MetaMessage("track_name", name="BEAT", time=0))

    pulses = []
    for tick, is_downbeat in chart.beats:
        pitch = 12 if is_downbeat else 13
        pulses.extend((
            (tick, mido.Message(
                "note_on", note=pitch, velocity=NOTE_VELOCITY, time=0)),
            (tick + 1, mido.Message(
                "note_off", note=pitch, velocity=0, time=0)),
        ))

    _write_sorted_events(track, pulses)
    return track
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _write_sorted_events(track, events):
    """Append ``(abs_tick, message)`` pairs to ``track`` in time order.

    Each message's ``time`` attribute is overwritten with the delta in ticks
    from the previously written event (the first delta is measured from 0).

    Events sharing a tick are ordered ``note_off`` first, then everything else
    in its original relative order. Closing notes before opening new ones keeps
    a note that starts exactly where a same-pitch note ends from being cut
    short, regardless of the order the caller queued them in. The caller's
    ``events`` list is left unmodified.

    Args:
        track: Object with an ``append`` method (a ``mido.MidiTrack`` here).
        events: Iterable of ``(abs_tick, message)`` pairs.
    """

    def sort_key(item):
        abs_tick, msg = item
        # False sorts before True, so note_off wins ties at the same tick.
        return (abs_tick, getattr(msg, "type", None) != "note_off")

    prev_tick = 0
    for abs_tick, msg in sorted(events, key=sort_key):
        msg.time = abs_tick - prev_tick
        track.append(msg)
        prev_tick = abs_tick
|
midmid/nn.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chart prediction model architecture.
|
| 2 |
+
|
| 3 |
+
FiLM-conditioned masked transformer for Guitar Hero chart generation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
# Utility layers
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
    """Gated activation over interleaved channels of the last dimension.

    Even-indexed channels form the gate and odd-indexed channels the linear
    branch, so the output has half as many channels as ``x``. The gate is
    clamped from above at ``limit``; the linear branch is clamped to
    ``[-limit, limit]``. The result is
    ``gate * sigmoid(alpha * gate) * (linear + 1)``.
    """
    gate = x[..., 0::2].clamp(max=limit)
    linear = x[..., 1::2].clamp(min=-limit, max=limit)
    activated = gate * torch.sigmoid(alpha * gate)
    return activated * (linear + 1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialized to one.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32 and return it in the input dtype."""
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x32 * inv_rms
        return (normed * self.scale).to(x.dtype)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FeedForward(nn.Module):
    """Position-wise feed-forward layer with a SwiGLU activation.

    ``linear1`` expands to ``d_ff`` channels, ``swiglu`` halves them (it pairs
    even and odd channels), and ``linear_out`` projects the ``d_ff // 2``
    channels back to ``d_model``. Dropout is applied between the activation and
    the output projection.

    Raises:
        ValueError: If ``d_ff`` is odd; the even/odd channel split would then
            produce mismatched halves and fail inside ``forward``.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        if d_ff % 2 != 0:
            raise ValueError(f"d_ff must be even for SwiGLU, got {d_ff}")
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        self.linear_out = nn.Linear(d_ff // 2, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> SwiGLU -> dropout -> project to ``x``."""
        return self.linear_out(self.dropout(swiglu(self.linear1(x))))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# Rotary position embeddings
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
def apply_rotary_emb(
    x: torch.Tensor, dim: int, base: float = 10000.0,
) -> torch.Tensor:
    """Apply RoPE to a tensor of shape [B, heads, T, head_dim].

    The first ``dim`` channels are rotated in split-half layout: channel ``i``
    of ``x[..., :dim // 2]`` is paired with channel ``i`` of
    ``x[..., dim // 2:dim]``. Any channels beyond ``dim`` are passed through
    unchanged, so the output always has the same shape as ``x``.

    Angles are computed in at least float32 regardless of ``x.dtype``. Half
    precision types cannot represent every integer position (bfloat16 already
    rounds 257 to 256), which would give neighbouring tokens the same rotation.

    Args:
        x: Input of shape [B, heads, T, head_dim].
        dim: Number of leading channels to rotate; expected to be even.
        base: Frequency base of the rotary embedding.

    Returns:
        Tensor with the dtype and shape of ``x``.
    """
    seq_len = x.size(2)
    device, dtype = x.device, x.dtype
    # float32 for fp16/bf16/fp32 inputs, float64 stays float64.
    angle_dtype = torch.promote_types(dtype, torch.float32)

    theta = base ** (
        -torch.arange(0, dim, 2, device=device, dtype=angle_dtype) / dim
    )
    positions = torch.arange(seq_len, device=device, dtype=angle_dtype).unsqueeze(1)
    angles = positions * theta.unsqueeze(0)  # [T, dim // 2]

    # Broadcast over batch and heads; cast back so the output keeps x's dtype.
    sin = angles.sin().to(dtype)[None, None]
    cos = angles.cos().to(dtype)[None, None]

    half = dim // 2
    x1 = x[..., :half]
    x2 = x[..., half:dim]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    if dim < x.size(-1):
        rotated = torch.cat([rotated, x[..., dim:]], dim=-1)
    return rotated
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# Bidirectional multi-head self-attention
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
class BidirectionalAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and no causal mask.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``, or if the
            per-head dimension is odd (the rotary embedding pairs channels).
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 rope_base: float = 10000.0):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if (d_model // n_heads) % 2 != 0:
            raise ValueError(
                f"head dimension ({d_model // n_heads}) must be even for RoPE")

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.rope_base = rope_base

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x`` of shape [B, T, d_model].

        Args:
            x: Input sequence.
            attn_mask: Optional per-key mask, expected as [B, T]. It is cast
                to bool and broadcast over heads and queries; following the
                boolean-mask convention of ``scaled_dot_product_attention``,
                True marks keys that may be attended to.

        Returns:
            Tensor of shape [B, T, d_model].
        """
        B, T, _ = x.shape
        # [B, T, d_model] -> [B, heads, T, d_k]
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Positions are encoded by rotating queries and keys only.
        Q = apply_rotary_emb(Q, dim=self.d_k, base=self.rope_base)
        K = apply_rotary_emb(K, dim=self.d_k, base=self.rope_base)

        sdpa_mask = None
        if attn_mask is not None:
            sdpa_mask = attn_mask[:, None, None, :].bool()

        out = F.scaled_dot_product_attention(
            Q, K, V, attn_mask=sdpa_mask,
            # Attention dropout only while training.
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=False,
        )
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out_proj(out)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# FiLM-conditioned encoder block
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class FiLMEncoderBlock(nn.Module):
    """Pre-norm encoder block whose feed-forward branch is FiLM-modulated.

    The feed-forward output ``h`` is rescaled and shifted before the residual
    add: ``h = (1 + gamma) * h + beta``, where ``gamma`` and ``beta`` are the
    two halves of a linear projection of the difficulty embedding. That
    projection starts at zero, so a fresh block applies no modulation.
    """

    def __init__(self, d_model: int, d_ff: int, n_heads: int,
                 dropout: float = 0.1, rope_base: float = 10000.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = BidirectionalAttention(d_model, n_heads, dropout, rope_base)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

        # One projection yields both gamma and beta (d_model channels each).
        self.film_proj = nn.Linear(d_model, d_model * 2)
        nn.init.zeros_(self.film_proj.weight)
        nn.init.zeros_(self.film_proj.bias)

    def forward(self, x: torch.Tensor, diff_emb: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run attention and the FiLM-modulated feed-forward, each with a residual.

        ``diff_emb`` gets a singleton axis inserted at position 1 so the same
        modulation is broadcast along that axis of the hidden states.
        """
        attended = self.attn(self.norm1(x), attn_mask)
        x = x + self.dropout(attended)

        hidden = self.ff(self.norm2(x))
        gamma, beta = self.film_proj(diff_emb).unsqueeze(1).chunk(2, dim=-1)
        hidden = (1 + gamma) * hidden + beta

        return x + self.dropout(hidden)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# Constants
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
# Token ids 0-31 are fret combinations; the two ids after them are special.
SILENCE_TOKEN = 32
MASK_TOKEN = 33
# Total number of token ids, MASK included.
VOCAB_SIZE = MASK_TOKEN + 1
# Number of classes predicted by the duration head.
NUM_SUSTAIN_BUCKETS = 6
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ---------------------------------------------------------------------------
|
| 160 |
+
# Main model
|
| 161 |
+
# ---------------------------------------------------------------------------
|
| 162 |
+
|
| 163 |
+
class ChartMaskPredictor(nn.Module):
    """Masked prediction chart model (v3).

    Token vocabulary: 0-31 fret combos, 32 silence, 33 MASK. Projected audio
    features and chart-token embeddings are summed, passed through a stack of
    FiLM-conditioned encoder blocks, and read out by three heads.
    """

    def __init__(self, config: "ChartMaskPredictorConfig"):
        super().__init__()
        self.config = config
        width = config.d_model

        self.audio_projection = nn.Linear(config.audio_dim, width, bias=False)
        self.chart_embedding = nn.Embedding(VOCAB_SIZE, width)
        self.input_dropout = nn.Dropout(config.dropout)
        # Four rows -- presumably one per difficulty level; the index mapping
        # is chosen by the caller.
        self.difficulty_embedding = nn.Embedding(4, width)

        blocks = []
        for _ in range(config.n_layers):
            blocks.append(FiLMEncoderBlock(
                d_model=width, d_ff=config.d_ff, n_heads=config.n_heads,
                dropout=config.dropout, rope_base=config.rope_base,
            ))
        self.layers = nn.ModuleList(blocks)

        self.final_norm = RMSNorm(width)
        self.token_head = nn.Linear(width, VOCAB_SIZE - 1)  # 33 classes (no MASK)
        self.sustain_head = nn.Linear(width, 1)
        self.duration_head = nn.Linear(width, NUM_SUSTAIN_BUCKETS)

    def forward(self, audio_features: torch.Tensor, chart_tokens: torch.Tensor,
                difficulty: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> dict[str, torch.Tensor]:
        """Predict token, sustain and duration logits for every position.

        Args:
            audio_features: Per-position audio features whose last dimension
                is ``config.audio_dim``.
            chart_tokens: Integer token ids, one per position.
            difficulty: Integer difficulty index per sequence.
            padding_mask: Optional mask forwarded to every block as its
                attention mask.

        Returns:
            Dict with ``token_logits`` (``VOCAB_SIZE - 1`` classes),
            ``sustain_logits`` (one logit) and ``duration_logits``
            (``NUM_SUSTAIN_BUCKETS`` classes) per position.
        """
        audio = self.audio_projection(audio_features)
        chart = self.chart_embedding(chart_tokens)
        hidden = self.input_dropout(audio + chart)

        diff_emb = self.difficulty_embedding(difficulty)
        for block in self.layers:
            hidden = block(hidden, diff_emb, attn_mask=padding_mask)

        hidden = self.final_norm(hidden)
        return {
            "token_logits": self.token_head(hidden),
            "sustain_logits": self.sustain_head(hidden),
            "duration_logits": self.duration_head(hidden),
        }
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@dataclass
class ChartMaskPredictorConfig:
    """Hyperparameters consumed by ``ChartMaskPredictor``."""

    # Size of the last dimension of the audio features fed to the model.
    audio_dim: int = 771
    # Width of the hidden states.
    d_model: int = 512
    # Attention heads per encoder block.
    n_heads: int = 8
    # Number of encoder blocks.
    n_layers: int = 6
    # Feed-forward expansion width, before the SwiGLU halving.
    d_ff: int = 2048
    # Dropout probability used throughout the model.
    dropout: float = 0.15
    # Frequency base of the rotary position embedding.
    rope_base: float = 10000.0
|
midmid/offset.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Offset / silence duration calculation."""
|
| 2 |
+
|
| 3 |
+
from midmid.beat_tracker import BeatData
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def calculate_offset(
    beat_data: BeatData,
    bpm: float,
    beats_per_measure: int = 4,
    min_lead_in: float = 2.0,
) -> float:
    """Calculate silence duration to prepend to the audio.

    The result is ``n * measure_duration - first_downbeat`` for the smallest
    whole number of measures ``n >= 1`` that makes ``n * measure_duration`` at
    least ``min_lead_in`` and the result non-negative. After prepending it,
    the first downbeat sits exactly ``n`` measures into the audio.

    Args:
        beat_data: Object exposing ``downbeats`` (times in seconds).
        bpm: Tempo used to size a measure; must be positive and finite.
        beats_per_measure: Beats in one measure; must be positive.
        min_lead_in: Minimum lead-in in seconds. Returned as-is when there
            are no downbeats.

    Returns:
        Silence duration in seconds.

    Raises:
        ValueError: If ``bpm`` or ``beats_per_measure`` does not give a
            positive, finite measure length. Previously such input either
            divided by zero or looped forever.
    """
    if len(beat_data.downbeats) == 0:
        return min_lead_in

    # The comparisons are written so that NaN is rejected as well.
    if not (bpm > 0 and beats_per_measure > 0):
        raise ValueError(
            f"bpm and beats_per_measure must be positive, "
            f"got bpm={bpm!r}, beats_per_measure={beats_per_measure!r}")
    measure_duration = beats_per_measure * 60.0 / bpm
    if not 0 < measure_duration < float("inf"):
        raise ValueError(
            f"measure duration must be positive and finite, got {measure_duration!r}")

    first_downbeat = float(beat_data.downbeats[0])

    # Smallest measure count that covers the minimum lead-in ...
    n = 1
    while n * measure_duration < min_lead_in:
        n += 1

    silence = n * measure_duration - first_downbeat

    # ... grown further if the first downbeat lies beyond it.
    while silence < 0:
        n += 1
        silence = n * measure_duration - first_downbeat

    return silence
|
midmid/sections.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Structural segmentation (intro, verse, chorus, etc.)."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import librosa
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def detect_sections(
    audio_path: str,
    min_section_duration: float = 8.0,
) -> list[tuple[float, str]]:
    """Detect structural sections in an audio file.

    The audio is loaded as 22050 Hz mono, segmented by agglomerative
    clustering of 13 MFCCs, and short segments are merged before labelling.

    Returns:
        ``(start_time_seconds, label)`` pairs, one per section.
    """
    y, sr = librosa.load(audio_path, sr=22050, mono=True)
    duration = len(y) / sr

    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)

    # Aim for one segment per 25 s of audio: at least 2, and fewer than the
    # number of frames.
    n_frames = mfcc.shape[1]
    k = min(n_frames - 1, int(duration / 25))
    if k < 2:
        k = 2

    boundary_frames = librosa.segment.agglomerative(mfcc, k=k)
    bound_times = librosa.frames_to_time(boundary_frames, sr=sr)

    # Add a boundary at 0.0 unless the first one is within the first 0.5 s.
    if len(bound_times) == 0 or bound_times[0] > 0.5:
        bound_times = np.concatenate([[0.0], bound_times])

    bound_times = _merge_short_segments(bound_times, duration, min_section_duration)
    labels = _assign_labels(y, sr, bound_times, duration)

    return list(zip(bound_times.tolist(), labels))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _merge_short_segments(bounds, duration, min_dur):
    """Drop boundaries that would create a segment shorter than ``min_dur``.

    A boundary is kept only if it lies at least ``min_dur`` after the
    previously kept one. Trailing boundaries are then removed while the final
    segment (last boundary to ``duration``) is shorter than ``min_dur``. The
    first boundary is always kept.

    Args:
        bounds: Ascending boundary times in seconds.
        duration: Total audio length in seconds.
        min_dur: Minimum allowed segment length in seconds.

    Returns:
        ``np.ndarray`` of the surviving boundary times; empty for empty input.
    """
    if len(bounds) == 0:
        return np.array([])

    merged = [bounds[0]]
    for t in bounds[1:]:
        if t - merged[-1] >= min_dur:
            merged.append(t)

    # Fold a too-short tail into the segment before it.
    while len(merged) > 1 and duration - merged[-1] < min_dur:
        merged.pop()

    return np.array(merged)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _assign_labels(y, sr, bound_times, duration):
    """Label each segment ``"Intro"`` or ``"Section <letter>"``.

    The first segment is always ``"Intro"``. Every later segment is compared,
    by cosine similarity of mean MFCC vectors, with all earlier segments. If
    its best match scores above 0.85 and already has a letter it reuses that
    letter; otherwise it takes the next unused letter (wrapping after Z).
    Every lettered segment is recorded, so a third repetition that best
    matches the second one still receives the shared letter. The intro has no
    letter, so segments resembling it always start a new one.

    Args:
        y: Audio samples.
        sr: Sample rate of ``y``.
        bound_times: Segment start times in seconds.
        duration: Unused; kept for interface compatibility.

    Returns:
        One label per entry of ``bound_times``.
    """
    n = len(bound_times)
    if n == 0:
        return []
    if n == 1:
        return ["Intro"]

    segment_features = []
    for i in range(n):
        start_sample = int(bound_times[i] * sr)
        end_sample = int(bound_times[i + 1] * sr) if i + 1 < n else len(y)
        seg = y[start_sample:end_sample]
        if len(seg) < sr // 4:
            # Under a quarter second: use a zero vector, for which
            # _cosine_sim returns 0.0, so the segment never matches.
            segment_features.append(np.zeros(13))
        else:
            mfcc = librosa.feature.mfcc(y=seg, sr=sr, n_mfcc=13)
            segment_features.append(np.mean(mfcc, axis=1))

    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    labels = ["Intro"]
    assigned = {}  # segment index -> letter; index 0 (the intro) is never a key
    letter_idx = 0

    for i in range(1, n):
        best_sim = -1
        best_j = -1
        for j in range(i):
            sim = _cosine_sim(segment_features[i], segment_features[j])
            if sim > best_sim:
                best_sim = sim
                best_j = j

        if best_sim > 0.85 and best_j in assigned:
            letter = assigned[best_j]
        else:
            letter = letters[letter_idx % len(letters)]
            letter_idx += 1

        assigned[i] = letter
        labels.append(f"Section {letter}")

    return labels
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _cosine_sim(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Returns 0.0 when either vector has zero norm.
    """
    magnitudes = (np.linalg.norm(a), np.linalg.norm(b))
    if 0 in magnitudes:
        return 0.0
    return float(np.dot(a, b) / (magnitudes[0] * magnitudes[1]))
|
midmid/tempo_map.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tempo map derivation from beat tracker output."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from midmid.beat_tracker import BeatData
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def derive_tempo_map(
    beat_data: BeatData, change_threshold: float = 0.08,
) -> list[tuple[float, float]]:
    """Derive a tempo map from beat data.

    Per-beat BPM values outside (0.6x, 1.6x) of the median are treated as
    outliers. If the remaining values vary little (coefficient of variation
    below ``change_threshold``) a single averaged tempo is returned. Otherwise
    the beats are scanned in windows of four intervals, and a new entry is
    emitted whenever the window's mean BPM differs from the current tempo by
    more than ``change_threshold`` (relative).

    Returns list of (time_seconds, bpm) tuples, sorted by time. Fewer than two
    beats yields ``[(0.0, 120.0)]``.
    """
    beats = beat_data.beats
    if len(beats) < 2:
        return [(0.0, 120.0)]

    bpms = 60.0 / np.diff(beats)
    median_bpm = np.median(bpms)
    low, high = median_bpm * 0.6, median_bpm * 1.6

    def plausible(values):
        # Boolean mask of values strictly inside the outlier band.
        return (values > low) & (values < high)

    valid = plausible(bpms)
    if not valid.any():
        return [(0.0, float(median_bpm))]

    steady = bpms[valid]
    if np.std(steady) / np.mean(steady) < change_threshold:
        return [(0.0, _round_bpm(float(np.mean(steady))))]

    # The starting tempo comes from the first interval unless it is an outlier.
    current_bpm = float(bpms[0]) if valid[0] else float(median_bpm)
    tempo_map = [(0.0, _round_bpm(current_bpm))]

    window = 4
    for start in range(window, len(bpms) - window + 1, window):
        chunk = bpms[start:start + window]
        usable = chunk[plausible(chunk)]
        if len(usable) == 0:
            continue
        local_bpm = float(np.mean(usable))
        if abs(local_bpm - current_bpm) / current_bpm > change_threshold:
            current_bpm = local_bpm
            tempo_map.append((float(beats[start]), _round_bpm(current_bpm)))

    return tempo_map
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_median_bpm(beat_data: BeatData) -> float:
    """Return the median beat-to-beat BPM rounded to two decimals.

    Falls back to 120.0 when there are fewer than two beats.
    """
    beats = beat_data.beats
    if len(beats) < 2:
        return 120.0
    per_beat_bpm = 60.0 / np.diff(beats)
    return float(_round_bpm(np.median(per_beat_bpm)))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def estimate_time_signature(beat_data: BeatData) -> int:
    """Estimate beats per measure as the most common beat count between downbeats.

    Only measures containing 2 to 7 beats are counted. Returns 4 when there are
    fewer than two downbeats or no measure qualifies; ties go to the smaller
    count.
    """
    if len(beat_data.downbeats) < 2:
        return 4

    beats = beat_data.beats
    downbeats = beat_data.downbeats

    # Beats falling in each half-open interval [downbeat, next downbeat).
    per_measure = [
        np.sum((beats >= start) & (beats < end))
        for start, end in zip(downbeats[:-1], downbeats[1:])
    ]
    counts = [count for count in per_measure if 2 <= count <= 7]

    if not counts:
        return 4

    values, freq = np.unique(counts, return_counts=True)
    return int(values[np.argmax(freq)])
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _round_bpm(bpm: float) -> float:
    """Return ``bpm`` as a Python float rounded to two decimal places."""
    value = float(bpm)
    return round(value, 2)
|
pipeline.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generation pipeline — callable from Gradio, ZeroGPU-compatible.
|
| 2 |
+
|
| 3 |
+
Wraps the full audio→chart pipeline into a single function that returns
|
| 4 |
+
a zip file path and chart JSON for the visualizer.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import base64
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import shutil
|
| 11 |
+
import tempfile
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from midmid.beat_tracker import track_beats
|
| 19 |
+
from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
|
| 20 |
+
from midmid.offset import calculate_offset
|
| 21 |
+
from midmid.sections import detect_sections
|
| 22 |
+
from midmid.constraints import enforce_constraints
|
| 23 |
+
from midmid.datatypes import ChartData, NoteEvent
|
| 24 |
+
from midmid.inference import load_model_from_hub, predict_notes, move_models_to_device
|
| 25 |
+
from midmid.midi_writer import write_midi
|
| 26 |
+
from midmid.audio_prep import prepare_audio
|
| 27 |
+
from midmid.ini_writer import write_ini
|
| 28 |
+
|
| 29 |
+
# Chart resolution handed to the tick conversion and constraint helpers.
RESOLUTION = 192
# Hugging Face Hub repository the chart model is loaded from.
MODEL_REPO = "markury/midmid3-19m-0326"

# Process-wide model cache: filled once (on CPU) by ensure_model().
_chart_model = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def ensure_model():
    """Return the cached chart model, loading it onto the CPU on first use.

    Intended to be called at app startup so later requests reuse the
    module-level ``_chart_model``.
    """
    global _chart_model
    if _chart_model is not None:
        return _chart_model

    print("Loading chart model from HF Hub...")
    _chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
    print("Chart model loaded.")
    return _chart_model
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _safe_folder_name(title, artist):
    """Return a single, filesystem-safe path component for the song folder."""
    import re  # local import: the file's import block is outside this chunk

    raw = f"{title} - {artist}"
    # Path separators, characters reserved on Windows, and control characters
    # would either create nested directories or make the name invalid.
    cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", raw).strip(" .")
    return cleaned or "song"


def generate_chart(
    audio_path: str,
    title: str,
    artist: str,
    album: str = "",
    year: str = "",
    genre: str = "rock",
    temperature: float = 0.8,
    num_steps: int = 12,
    progress_cb=None,
) -> tuple[str, dict]:
    """Run the full generation pipeline.

    Args:
        audio_path: Path to uploaded audio file.
        title: Song title.
        artist: Artist name.
        album: Album name (optional).
        year: Release year (optional); defaults to the current year when empty.
        genre: Genre string (optional).
        temperature: Sampling temperature.
        num_steps: Unmasking steps.
        progress_cb: Optional callable invoked as
            ``progress_cb(fraction, desc=message)`` with ``fraction`` in [0, 1].

    Returns:
        (zip_path, chart_json) where chart_json has the data for the visualizer.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ensure_model()
    model.to(device)
    move_models_to_device(device)

    if not year:
        year = str(datetime.now().year)

    # Create temp output dir. The folder name is sanitised because title and
    # artist are user-supplied: a "/" (e.g. "AC/DC") would otherwise create a
    # nested directory and break the archive step, and ".." could escape
    # tmp_dir altogether.
    tmp_dir = tempfile.mkdtemp(prefix="midmid_")
    song_dir = Path(tmp_dir) / _safe_folder_name(title, artist)
    song_dir.mkdir(parents=True, exist_ok=True)

    def _progress(step, total, msg):
        if progress_cb:
            progress_cb(step / total, desc=msg)

    # --- Stage 1: Audio analysis ---
    _progress(0, 8, "Tracking beats...")
    beat_data = track_beats(audio_path, device=str(device))

    _progress(1, 8, "Analyzing tempo...")
    tempo_map = derive_tempo_map(beat_data)
    bpm = get_median_bpm(beat_data)
    time_sig = estimate_time_signature(beat_data)
    offset_sec = calculate_offset(beat_data, bpm, beats_per_measure=time_sig)

    _progress(2, 8, "Detecting sections...")
    raw_sections = detect_sections(audio_path)

    # --- Stage 2: Note prediction ---
    beat_times = list(beat_data.beats)
    difficulties = ["expert", "hard", "medium", "easy"]
    all_notes = {}

    # Notes after the last tracked beat are dropped. The cutoff depends only
    # on the beat data, so compute it once instead of once per difficulty.
    last_beat_sec = float(beat_times[-1]) if beat_times else 0
    last_beat_tick = int(round((last_beat_sec + offset_sec) * bpm / 60.0 * RESOLUTION))

    for i, diff_name in enumerate(difficulties):
        _progress(3 + i * 0.2, 8, f"Generating {diff_name} chart...")
        raw_notes = predict_notes(
            audio_path=audio_path,
            model=model,
            beat_times=beat_times,
            difficulty=diff_name,
            device=device,
            temperature=temperature,
            num_steps=num_steps,
        )

        notes = _grid_to_musical_ticks(raw_notes, beat_times, offset_sec, bpm, RESOLUTION)
        notes = enforce_constraints(notes, diff_name, RESOLUTION)
        notes = [n for n in notes if n.tick <= last_beat_tick]

        all_notes[diff_name] = notes

    # Fill missing difficulties (defensive: the loop above currently always
    # produces an entry for every difficulty).
    for diff in difficulties:
        if diff not in all_notes:
            for fallback in difficulties:
                if fallback in all_notes:
                    all_notes[diff] = all_notes[fallback]
                    break

    # --- Stage 3: Assembly ---
    _progress(5, 8, "Building chart...")
    tempo_events = _tempo_map_to_ticks(tempo_map, offset_sec, bpm, RESOLUTION)
    section_events = _sections_to_ticks(raw_sections, tempo_map, offset_sec, RESOLUTION)

    all_ticks = [n.tick for ns in all_notes.values() for n in ns]
    # Extend one measure past the last note, or four measures for an empty chart.
    last_tick = max(all_ticks) + RESOLUTION * time_sig if all_ticks else RESOLUTION * time_sig * 4
    beat_markers = _build_beat_markers(last_tick, RESOLUTION, time_sig)

    chart = ChartData(
        resolution=RESOLUTION,
        tempo_events=tempo_events,
        time_signatures=[(0, time_sig, 4)],
        sections=section_events,
        notes=all_notes,
        beats=beat_markers,
    )

    # --- Stage 4: Write outputs ---
    _progress(6, 8, "Writing MIDI...")
    write_midi(chart, str(song_dir / "notes.mid"))

    _progress(7, 8, "Preparing audio...")
    prepare_audio(
        audio_path=audio_path,
        output_path=str(song_dir / "song.ogg"),
        silence_duration_sec=offset_sec,
    )

    write_ini(
        output_path=str(song_dir / "song.ini"),
        title=title,
        artist=artist,
        album=album,
        genre=genre,
        year=year,
    )

    # --- Zip it ---
    # The archive sits next to the song folder and contains that folder as its
    # single top-level entry.
    zip_base = Path(tmp_dir) / song_dir.name
    zip_path = shutil.make_archive(str(zip_base), "zip", tmp_dir, song_dir.name)

    # --- Build chart JSON for the visualizer ---
    chart_json = _build_chart_json(
        chart, bpm, offset_sec, audio_path, str(song_dir / "song.ogg"),
    )

    _progress(8, 8, "Done!")
    return zip_path, chart_json
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _build_chart_json(chart, bpm, offset_sec, original_audio_path, prepared_audio_path):
    """Assemble the JSON-serialisable payload consumed by the browser visualizer.

    The prepared audio file is read in full and embedded as a base64 string so
    the HTML player needs no separate download. ``original_audio_path`` is
    accepted but not used.
    """
    def _note_to_dict(note):
        return {
            "tick": note.tick,
            "frets": sorted(note.fret_set),
            "sustain": note.sustain_ticks,
            "hopo": note.is_hopo,
        }

    # Encode prepared audio as base64 for the HTML player
    with open(prepared_audio_path, "rb") as audio_file:
        raw_audio = audio_file.read()
    encoded_audio = base64.b64encode(raw_audio).decode("ascii")

    notes_by_difficulty = {
        difficulty: [_note_to_dict(note) for note in note_list]
        for difficulty, note_list in chart.notes.items()
    }

    payload = {}
    payload["resolution"] = chart.resolution
    payload["bpm"] = bpm
    payload["offset_sec"] = offset_sec
    payload["tempo_events"] = [
        {"tick": tick, "bpm": tempo} for tick, tempo in chart.tempo_events
    ]
    payload["time_signatures"] = [
        {"tick": tick, "num": num, "den": den}
        for tick, num, den in chart.time_signatures
    ]
    payload["sections"] = [
        {"tick": tick, "label": label} for tick, label in chart.sections
    ]
    payload["beats"] = [
        {"tick": tick, "downbeat": is_downbeat} for tick, is_downbeat in chart.beats
    ]
    payload["notes"] = notes_by_difficulty
    payload["audio_b64"] = encoded_audio
    payload["audio_format"] = "ogg"
    return payload
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
# Grid index -> musical tick conversion (from generate.py)
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
|
| 225 |
+
def _grid_to_musical_ticks(notes, beat_times, offset_sec, bpm, resolution):
    """Convert notes indexed on the beat-subdivision grid into chart ticks.

    Each interval between consecutive beats is split into four equal grid
    slots (plus one final slot on the last beat). A note's incoming ``tick``
    is treated as an index into that grid; its time is shifted by
    ``offset_sec``, converted to ticks at the constant ``bpm`` and snapped to
    the nearest sixteenth. A positive ``sustain_ticks`` is read as
    milliseconds and converted to at least one sixteenth.

    Notes whose index falls outside the grid are dropped. With fewer than two
    beats no grid can be built and ``notes`` is returned unchanged.
    """
    if len(beat_times) < 2:
        return notes

    step = resolution // 4  # length of a sixteenth note in ticks

    beats_ms = [seconds * 1000.0 for seconds in beat_times]
    slot_times_ms = [
        start + sub * (following - start) / 4.0
        for start, following in zip(beats_ms, beats_ms[1:])
        for sub in range(4)
    ]
    slot_times_ms.append(beats_ms[-1])
    slot_count = len(slot_times_ms)

    converted = []
    for note in notes:
        slot = note.tick
        if not (0 <= slot < slot_count):
            continue

        shifted_sec = slot_times_ms[slot] / 1000.0 + offset_sec
        raw_tick = round(shifted_sec * bpm / 60.0 * resolution)
        snapped_tick = max(0, round(raw_tick / step) * step)

        if note.sustain_ticks > 0:
            held_sec = note.sustain_ticks / 1000.0
            held_ticks = held_sec * bpm / 60.0 * resolution
            sustain = max(step, round(held_ticks / step) * step)
        else:
            sustain = 0

        converted.append(NoteEvent(
            tick=snapped_tick,
            fret_set=note.fret_set,
            sustain_ticks=sustain,
            is_hopo=note.is_hopo,
        ))

    return converted
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _tempo_map_to_ticks(tempo_map, offset_sec, bpm, resolution):
    """Turn ``(time_sec, bpm)`` tempo changes into ``(tick, bpm)`` events.

    The first entry is always placed at tick 0. Every later entry is placed
    relative to the previous one: the elapsed seconds between the two entries
    are converted to ticks using the previous entry's tempo. The ``bpm``
    argument is accepted but not used.
    """
    tick_events = []
    previous_time = None
    for time_sec, bpm_val in tempo_map:
        if not tick_events:
            tick_events.append((0, bpm_val))
        else:
            last_tick, last_bpm = tick_events[-1]
            # Both endpoints carry the same offset, as in the seconds domain.
            elapsed = (time_sec + offset_sec) - (previous_time + offset_sec)
            delta_ticks = int(round(elapsed * last_bpm / 60.0 * resolution))
            tick_events.append((last_tick + delta_ticks, bpm_val))
        previous_time = time_sec
    return tick_events
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _sections_to_ticks(sections, tempo_map, offset_sec, resolution):
    """Map ``(time_sec, label)`` sections to ``(tick, label)`` pairs.

    Times are shifted by ``offset_sec`` and converted at a single constant
    tempo: the bpm of the first ``tempo_map`` entry. Negative results are
    clamped to tick 0. Returns an empty list when ``tempo_map`` is empty.
    """
    if not tempo_map:
        return []

    first_bpm = tempo_map[0][1]

    def _to_tick(time_sec):
        shifted = time_sec + offset_sec
        return max(0, int(round(shifted * first_bpm / 60.0 * resolution)))

    return [(_to_tick(time_sec), label) for time_sec, label in sections]
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _build_beat_markers(last_tick, resolution, beats_per_measure):
    """Return ``(tick, is_downbeat)`` markers from tick 0 through ``last_tick``.

    One marker is emitted every ``resolution`` ticks; every
    ``beats_per_measure``-th marker, starting with the first, is flagged as a
    downbeat.
    """
    markers = []
    position = 0
    beat_index = 0
    while position <= last_tick:
        markers.append((position, beat_index % beats_per_measure == 0))
        beat_index += 1
        position += resolution
    return markers
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.0
|
| 2 |
+
spaces
|
| 3 |
+
safetensors
|
| 4 |
+
transformers<5
|
| 5 |
+
huggingface-hub
|
| 6 |
+
mido
|
| 7 |
+
beat_this @ git+https://github.com/CPJKU/beat_this.git
|
| 8 |
+
librosa
|
| 9 |
+
pydub
|
| 10 |
+
numpy
|
| 11 |
+
scipy
|
| 12 |
+
tqdm
|
visualizer.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build the HTML/JS/CSS for the chart visualizer.
|
| 2 |
+
|
| 3 |
+
Returns a self-contained HTML string that Gradio embeds via gr.HTML().
|
| 4 |
+
The chart data + audio are injected as a JSON blob.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def build_visualizer_html(chart_json: dict) -> str:
    """Return a self-contained HTML string for the chart visualizer.

    Args:
        chart_json: JSON-serialisable chart payload; it is substituted for the
            ``__CHART_DATA__`` placeholder inside ``TEMPLATE``.

    Returns:
        The template with the payload embedded as a JavaScript object literal.
    """
    data_json = json.dumps(chart_json, separators=(",", ":"))
    # The payload lands inside an inline <script>. A string value containing
    # "</script>" or "<!--" would end or corrupt the script element, so the
    # HTML-significant characters are emitted as \uXXXX escapes. These
    # characters can only occur inside JSON strings, where the escapes decode
    # back to the same text.
    data_json = (
        data_json
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )
    return TEMPLATE.replace("__CHART_DATA__", data_json)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# HTML/CSS/JS for the visualizer widget. build_visualizer_html() substitutes
# the chart payload for the single data placeholder in the script below.
# This is a raw string, so JavaScript escapes such as \u25B6 reach the browser
# unchanged.
TEMPLATE = r"""
<div id="midmid-viz" style="font-family: system-ui, -apple-system, sans-serif; background: #111; border-radius: 12px; overflow: hidden; max-width: 900px; margin: 0 auto;">

  <!-- Controls bar -->
  <div style="display:flex; align-items:center; gap:12px; padding:10px 16px; background:#1a1a1a; border-bottom:1px solid #333;">
    <button id="viz-play" style="background:none; border:none; color:#fff; font-size:22px; cursor:pointer; padding:4px 8px;" title="Play/Pause">▶</button>
    <div id="viz-time" style="color:#aaa; font-size:13px; min-width:80px;">0:00 / 0:00</div>
    <div style="flex:1; position:relative; height:6px; background:#333; border-radius:3px; cursor:pointer;" id="viz-seekbar">
      <div id="viz-seekfill" style="height:100%; background:#7c3aed; border-radius:3px; width:0%; pointer-events:none;"></div>
    </div>
    <select id="viz-diff" style="background:#222; color:#fff; border:1px solid #444; border-radius:4px; padding:2px 6px; font-size:13px;">
      <option value="expert">Expert</option>
      <option value="hard">Hard</option>
      <option value="medium">Medium</option>
      <option value="easy">Easy</option>
    </select>
  </div>

  <!-- Canvas -->
  <canvas id="viz-canvas" style="width:100%; display:block;"></canvas>

  <!-- Section labels row -->
  <div id="viz-sections" style="padding:6px 16px 10px; background:#1a1a1a; border-top:1px solid #333; color:#888; font-size:11px; min-height:20px; white-space:nowrap; overflow:hidden; text-overflow:ellipsis;"></div>

  <script>
  (function() {
    const DATA = __CHART_DATA__;

    // --- Constants ---
    const FRET_COLORS = ['#22c55e','#ef4444','#eab308','#3b82f6','#f97316']; // G R Y B O
    const FRET_GLOW = ['#4ade80','#f87171','#facc15','#60a5fa','#fb923c'];
    const LANE_COUNT = 5;
    const NOTE_RADIUS = 14;
    const CANVAS_PAD_LEFT = 80;
    const CANVAS_PAD_RIGHT = 20;
    const RES = DATA.resolution; // 192

    // Timing: convert tick to seconds using tempo events
    const tempoMap = DATA.tempo_events.map(e => ({tick: e.tick, bpm: e.bpm}));
    // An empty tempo list would make tickToSec dereference undefined; fall
    // back to the payload's overall bpm.
    if (tempoMap.length === 0) {
      tempoMap.push({tick: 0, bpm: DATA.bpm || 120});
    }

    function tickToSec(tick) {
      let sec = 0;
      let prevTick = 0;
      let bpm = tempoMap[0].bpm;
      for (let i = 1; i < tempoMap.length; i++) {
        if (tempoMap[i].tick > tick) break;
        sec += (tempoMap[i].tick - prevTick) / RES * 60.0 / bpm;
        prevTick = tempoMap[i].tick;
        bpm = tempoMap[i].bpm;
      }
      sec += (tick - prevTick) / RES * 60.0 / bpm;
      return sec;
    }

    // --- Audio setup ---
    const audio = new Audio();
    audio.src = 'data:audio/' + DATA.audio_format + ';base64,' + DATA.audio_b64;
    audio.preload = 'auto';

    // --- Canvas setup ---
    const canvas = document.getElementById('viz-canvas');
    const ctx = canvas.getContext('2d');
    let W, H, pxPerSec;
    const VISIBLE_SEC = 8; // seconds visible on screen

    function resize() {
      const container = canvas.parentElement;
      if (!container) return; // canvas already detached
      W = container.clientWidth;
      H = 360;
      canvas.width = W * devicePixelRatio;
      canvas.height = H * devicePixelRatio;
      canvas.style.height = H + 'px';
      ctx.setTransform(devicePixelRatio, 0, 0, devicePixelRatio, 0, 0);
      pxPerSec = (W - CANVAS_PAD_LEFT - CANVAS_PAD_RIGHT) / VISIBLE_SEC;
    }
    resize();
    const resizeObserver = new ResizeObserver(resize);
    resizeObserver.observe(canvas.parentElement);

    // --- State ---
    let currentDiff = 'expert';
    let playing = false;

    // Precompute note positions in seconds
    function buildNoteCache(diff) {
      return (DATA.notes[diff] || []).map(n => ({
        sec: tickToSec(n.tick),
        frets: n.frets,
        sustainSec: n.sustain > 0 ? tickToSec(n.tick + n.sustain) - tickToSec(n.tick) : 0,
        hopo: n.hopo,
      }));
    }

    let noteCache = buildNoteCache(currentDiff);

    // Precompute beats in seconds
    const beatCache = DATA.beats.map(b => ({
      sec: tickToSec(b.tick),
      downbeat: b.downbeat,
    }));

    // Sections in seconds
    const sectionCache = DATA.sections.map(s => ({
      sec: tickToSec(s.tick),
      label: s.label,
    }));

    // Total duration
    let totalDuration = 0;
    audio.addEventListener('loadedmetadata', () => {
      // The reported duration can be Infinity or NaN; keep the note-based
      // estimate in that case.
      if (isFinite(audio.duration) && audio.duration > 0) {
        totalDuration = audio.duration;
      }
    });
    // Fallback: estimate from last note
    const allNoteSecs = Object.values(DATA.notes).flat().map(n => tickToSec(n.tick + (n.sustain || 0)));
    const estimatedDuration = allNoteSecs.length ? Math.max(...allNoteSecs) + 5 : 120;

    function getDuration() { return totalDuration || estimatedDuration; }

    // --- Controls ---
    const playBtn = document.getElementById('viz-play');
    const timeDiv = document.getElementById('viz-time');
    const seekBar = document.getElementById('viz-seekbar');
    const seekFill = document.getElementById('viz-seekfill');
    const diffSelect = document.getElementById('viz-diff');
    const sectionsDiv = document.getElementById('viz-sections');

    function showPaused() {
      playing = false;
      playBtn.textContent = '\u25B6';
    }

    playBtn.addEventListener('click', () => {
      if (playing) {
        audio.pause();
        showPaused();
      } else {
        playing = true;
        playBtn.textContent = '\u23F8';
        // play() returns a promise that rejects when playback cannot start
        // (autoplay policy, decode error); restore the paused UI then.
        const started = audio.play();
        if (started && typeof started.catch === 'function') {
          started.catch(showPaused);
        }
      }
    });

    seekBar.addEventListener('click', (e) => {
      const rect = seekBar.getBoundingClientRect();
      if (rect.width <= 0) return;
      const frac = Math.min(1, Math.max(0, (e.clientX - rect.left) / rect.width));
      audio.currentTime = frac * getDuration();
    });

    diffSelect.addEventListener('change', () => {
      currentDiff = diffSelect.value;
      noteCache = buildNoteCache(currentDiff);
    });

    audio.addEventListener('ended', showPaused);

    function formatTime(s) {
      const m = Math.floor(s / 60);
      const sec = Math.floor(s % 60);
      return m + ':' + (sec < 10 ? '0' : '') + sec;
    }

    // --- Rendering ---
    function draw() {
      // Once the widget has been removed from the page (e.g. replaced by a
      // newly generated chart), stop the animation loop and the audio
      // instead of running forever in the background.
      if (!canvas.isConnected) {
        audio.pause();
        resizeObserver.disconnect();
        return;
      }

      const t = audio.currentTime || 0;
      const dur = getDuration();

      // Update controls
      seekFill.style.width = (t / dur * 100) + '%';
      timeDiv.textContent = formatTime(t) + ' / ' + formatTime(dur);

      // Update section label
      let currentSection = '';
      for (let i = sectionCache.length - 1; i >= 0; i--) {
        if (sectionCache[i].sec <= t) { currentSection = sectionCache[i].label; break; }
      }
      sectionsDiv.textContent = currentSection;

      // Clear
      ctx.fillStyle = '#111';
      ctx.fillRect(0, 0, W, H);

      // The highway: current time is at the left edge + small offset
      const playheadX = CANVAS_PAD_LEFT + 40;
      const secToX = (sec) => playheadX + (sec - t) * pxPerSec;
      const viewStart = t - 1;
      const viewEnd = t + VISIBLE_SEC + 1;

      // Draw lane backgrounds
      const laneTop = 20;
      const laneBottom = H - 20;
      const laneHeight = laneBottom - laneTop;

      // Subtle lane separators
      for (let i = 0; i <= LANE_COUNT; i++) {
        const y = laneTop + (laneHeight / LANE_COUNT) * i;
        ctx.strokeStyle = '#2a2a2a';
        ctx.lineWidth = 1;
        ctx.beginPath();
        ctx.moveTo(CANVAS_PAD_LEFT - 10, y);
        ctx.lineTo(W - CANVAS_PAD_RIGHT, y);
        ctx.stroke();
      }

      // Draw beat lines (vertical)
      for (const beat of beatCache) {
        if (beat.sec < viewStart || beat.sec > viewEnd) continue;
        const x = secToX(beat.sec);
        ctx.strokeStyle = beat.downbeat ? '#444' : '#222';
        ctx.lineWidth = beat.downbeat ? 1.5 : 0.5;
        ctx.beginPath();
        ctx.moveTo(x, laneTop);
        ctx.lineTo(x, laneBottom);
        ctx.stroke();

        // Tick mark above downbeats
        if (beat.downbeat && x > CANVAS_PAD_LEFT) {
          ctx.fillStyle = '#555';
          ctx.font = '9px system-ui';
          ctx.fillText('|', x - 2, laneTop - 4);
        }
      }

      // Draw section boundaries
      for (const sec of sectionCache) {
        if (sec.sec < viewStart || sec.sec > viewEnd) continue;
        const x = secToX(sec.sec);
        ctx.strokeStyle = '#7c3aed55';
        ctx.lineWidth = 2;
        ctx.beginPath();
        ctx.moveTo(x, laneTop);
        ctx.lineTo(x, laneBottom);
        ctx.stroke();
        ctx.fillStyle = '#7c3aed';
        ctx.font = '10px system-ui';
        ctx.fillText(sec.label, x + 4, laneTop - 4);
      }

      // Draw playhead
      ctx.strokeStyle = '#fff';
      ctx.lineWidth = 2;
      ctx.beginPath();
      ctx.moveTo(playheadX, laneTop - 2);
      ctx.lineTo(playheadX, laneBottom + 2);
      ctx.stroke();

      // Draw notes
      const laneH = laneHeight / LANE_COUNT;

      for (const note of noteCache) {
        if (note.sec + note.sustainSec < viewStart || note.sec > viewEnd) continue;

        const x = secToX(note.sec);

        for (const fret of note.frets) {
          if (fret > 4) continue; // skip open chord marker
          const laneY = laneTop + fret * laneH + laneH / 2;
          const color = FRET_COLORS[fret];
          const glow = FRET_GLOW[fret];

          // Draw sustain tail first (behind note)
          if (note.sustainSec > 0) {
            const endX = secToX(note.sec + note.sustainSec);
            ctx.fillStyle = color + '55';
            ctx.fillRect(x, laneY - 4, endX - x, 8);
            ctx.fillStyle = color + '99';
            ctx.fillRect(x, laneY - 2, endX - x, 4);
          }

          // Note circle
          const isPast = note.sec < t;
          ctx.beginPath();
          ctx.arc(x, laneY, NOTE_RADIUS - 2, 0, Math.PI * 2);

          if (isPast) {
            ctx.fillStyle = color + '44';
            ctx.fill();
            ctx.strokeStyle = color + '66';
            ctx.lineWidth = 1.5;
            ctx.stroke();
          } else {
            // Glow for upcoming notes near playhead
            const dist = note.sec - t;
            if (dist < 0.3) {
              ctx.shadowColor = glow;
              ctx.shadowBlur = 12;
            }
            ctx.fillStyle = color;
            ctx.fill();
            ctx.shadowBlur = 0;
            ctx.strokeStyle = '#fff';
            ctx.lineWidth = 2;
            ctx.stroke();

            // HOPO = open center
            if (note.hopo) {
              ctx.beginPath();
              ctx.arc(x, laneY, NOTE_RADIUS - 6, 0, Math.PI * 2);
              ctx.fillStyle = '#111';
              ctx.fill();
            }
          }
        }
      }

      // Fret labels on the left
      const fretAbbrev = ['G', 'R', 'Y', 'B', 'O'];
      ctx.font = 'bold 13px system-ui';
      for (let i = 0; i < LANE_COUNT; i++) {
        const y = laneTop + i * laneH + laneH / 2;
        ctx.fillStyle = FRET_COLORS[i];
        ctx.textAlign = 'right';
        ctx.fillText(fretAbbrev[i], CANVAS_PAD_LEFT - 20, y + 5);
      }
      ctx.textAlign = 'left';

      // Note count overlay
      const noteCount = noteCache.length;
      ctx.fillStyle = '#666';
      ctx.font = '11px system-ui';
      ctx.fillText(noteCount + ' notes (' + currentDiff + ')', W - CANVAS_PAD_RIGHT - 140, laneTop - 4);

      requestAnimationFrame(draw);
    }

    requestAnimationFrame(draw);
  })();
  </script>
</div>
"""
|