markury commited on
Commit
d171350
·
0 Parent(s):

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.egg-info/
5
+ model_upload/
README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Midmid - Guitar Hero Chart Generator
3
+ emoji: 🎸
4
+ colorFrom: purple
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: "5.23.0"
8
+ app_file: app.py
9
+ pinned: true
10
+ license: mit
11
+ hardware: zero-a10g
12
+ ---
13
+
14
+ # Midmid — AI Guitar Hero Chart Generator
15
+
16
+ Upload a song, get a playable Guitar Hero chart. Powered by a 19M-parameter
17
+ masked-prediction transformer trained on thousands of community-charted songs.
18
+
19
+ **How it works:**
20
+ 1. Upload an audio file (MP3, FLAC, OGG, WAV)
21
+ 2. Enter song metadata (title, artist, etc.)
22
+ 3. Hit Generate — the model analyzes beats, structure, and audio features,
23
+ then predicts note placements for all four difficulty levels
24
+ 4. Preview the chart in-browser, then download the ready-to-play song package
25
+
26
+ The output folder drops straight into GHWT:DE's MODS directory.
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Midmid — AI Guitar Hero Chart Generator (Hugging Face Space)."""
2
+
3
+ import os
4
+ import gradio as gr
5
+
6
+ # ZeroGPU: import spaces if available (no-op locally)
7
+ try:
8
+ import spaces
9
+ ON_ZEROGPU = True
10
+ except ImportError:
11
+ ON_ZEROGPU = False
12
+
13
+ from pipeline import ensure_model, generate_chart
14
+ from visualizer import build_visualizer_html
15
+
16
+ # Pre-load model on CPU at startup
17
+ ensure_model()
18
+
19
+ PLACEHOLDER_HTML = """
20
+ <div style="font-family: system-ui, sans-serif; background: #111; border-radius: 12px;
21
+ padding: 60px 20px; text-align: center; color: #666; max-width: 900px; margin: 0 auto;">
22
+ <div style="font-size: 48px; margin-bottom: 12px;">🎸</div>
23
+ <div style="font-size: 16px;">Upload a song and hit Generate to see your chart here</div>
24
+ </div>
25
+ """
26
+
27
+
28
+ def _generate_wrapper(audio_path, title, artist, album, year, genre, progress=gr.Progress()):
29
+ """Gradio-facing wrapper with validation and progress."""
30
+ if not audio_path:
31
+ raise gr.Error("Please upload an audio file.")
32
+ if not title or not title.strip():
33
+ raise gr.Error("Song title is required.")
34
+ if not artist or not artist.strip():
35
+ raise gr.Error("Artist name is required.")
36
+
37
+ zip_path, chart_json = generate_chart(
38
+ audio_path=audio_path,
39
+ title=title.strip(),
40
+ artist=artist.strip(),
41
+ album=album.strip() if album else "",
42
+ year=year.strip() if year else "",
43
+ genre=genre.strip() if genre else "rock",
44
+ progress_cb=progress,
45
+ )
46
+
47
+ html = build_visualizer_html(chart_json)
48
+ return html, zip_path
49
+
50
+
51
+ # Apply ZeroGPU decorator if running on HF Spaces
52
+ if ON_ZEROGPU:
53
+ _generate_wrapper = spaces.GPU(duration=180)(_generate_wrapper)
54
+
55
+
56
+ # --- UI ---
57
+ with gr.Blocks(
58
+ title="Midmid — Guitar Hero Chart Generator",
59
+ theme=gr.themes.Base(primary_hue="purple", neutral_hue="gray"),
60
+ css="""
61
+ .gradio-container { max-width: 960px !important; }
62
+ #generate-btn { min-height: 48px; font-size: 16px; }
63
+ """,
64
+ ) as demo:
65
+ gr.Markdown(
66
+ "# Midmid — AI Guitar Hero Chart Generator\n"
67
+ "Upload a song, get a playable chart with 4 difficulty levels. "
68
+ "Preview it here, then download the GHWT:DE-ready package."
69
+ )
70
+
71
+ with gr.Row():
72
+ with gr.Column(scale=1):
73
+ audio_input = gr.Audio(
74
+ label="Upload audio",
75
+ type="filepath",
76
+ sources=["upload"],
77
+ )
78
+ title_input = gr.Textbox(label="Song title *", placeholder="e.g. Through the Fire and Flames")
79
+ artist_input = gr.Textbox(label="Artist *", placeholder="e.g. DragonForce")
80
+
81
+ with gr.Row():
82
+ album_input = gr.Textbox(label="Album", placeholder="(optional)")
83
+ year_input = gr.Textbox(label="Year", placeholder="(optional)")
84
+
85
+ genre_input = gr.Textbox(label="Genre", placeholder="rock", value="rock")
86
+
87
+ generate_btn = gr.Button("Generate Chart", variant="primary", elem_id="generate-btn")
88
+
89
+ with gr.Column(scale=2):
90
+ viz_output = gr.HTML(value=PLACEHOLDER_HTML, label="Chart Preview")
91
+ zip_output = gr.File(label="Download song package (.zip)")
92
+
93
+ generate_btn.click(
94
+ fn=_generate_wrapper,
95
+ inputs=[audio_input, title_input, artist_input, album_input, year_input, genre_input],
96
+ outputs=[viz_output, zip_output],
97
+ )
98
+
99
+ gr.Markdown(
100
+ "---\n"
101
+ "*Charts generated by [Midmid](https://github.com/markury/midmid) — "
102
+ "a 19M-parameter masked transformer trained on community Guitar Hero charts. "
103
+ "Model: `markury/midmid3-19m-0326`*"
104
+ )
105
+
106
+
107
+ if __name__ == "__main__":
108
+ demo.launch()
convert_checkpoint.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert a midmid PyTorch checkpoint to safetensors + config.json.
2
+
3
+ Usage:
4
+ python convert_checkpoint.py path/to/best.pt --output-dir ./model_upload
5
+
6
+ This produces:
7
+ model_upload/model.safetensors (weights only, no pickle)
8
+ model_upload/config.json (model hyperparameters)
9
+
10
+ Then upload to HF:
11
+ huggingface-cli upload markury/midmid3-19m-0326 ./model_upload
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import torch
19
+ from safetensors.torch import save_file
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser(description="Convert midmid checkpoint to safetensors")
24
+ parser.add_argument("checkpoint", type=Path, help="Path to .pt checkpoint")
25
+ parser.add_argument("--output-dir", type=Path, default=Path("model_upload"),
26
+ help="Output directory (default: ./model_upload)")
27
+ args = parser.parse_args()
28
+
29
+ args.output_dir.mkdir(parents=True, exist_ok=True)
30
+
31
+ print(f"Loading checkpoint: {args.checkpoint}")
32
+ ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
33
+
34
+ # Save config
35
+ config = ckpt["config"]
36
+ config_path = args.output_dir / "config.json"
37
+ with open(config_path, "w") as f:
38
+ json.dump(config, f, indent=2)
39
+ print(f"Config saved: {config_path}")
40
+ print(f" {json.dumps(config, indent=2)}")
41
+
42
+ # Save weights as safetensors
43
+ state_dict = ckpt["model_state_dict"]
44
+ safetensors_path = args.output_dir / "model.safetensors"
45
+ save_file(state_dict, str(safetensors_path))
46
+ print(f"Weights saved: {safetensors_path}")
47
+
48
+ # Summary
49
+ n_params = sum(p.numel() for p in state_dict.values())
50
+ print(f" {n_params:,} parameters ({n_params / 1e6:.1f}M)")
51
+
52
+ print(f"\nUpload to HF with:")
53
+ print(f" huggingface-cli upload markury/midmid3-19m-0326 {args.output_dir}")
54
+
55
+
56
+ if __name__ == "__main__":
57
+ main()
midmid/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Midmid — Guitar Hero chart generation core
midmid/audio_prep.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio preparation — silence prepend, format conversion."""
2
+
3
+ from pydub import AudioSegment
4
+
5
+
6
+ def prepare_audio(
7
+ audio_path: str,
8
+ output_path: str,
9
+ silence_duration_sec: float = 3.0,
10
+ output_format: str = "ogg",
11
+ ) -> None:
12
+ """Prepend silence to audio and export in the target format."""
13
+ audio = AudioSegment.from_file(audio_path)
14
+ silence = AudioSegment.silent(duration=int(silence_duration_sec * 1000))
15
+ prepared = silence + audio
16
+ prepared.export(output_path, format=output_format)
midmid/beat_tracker.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Beat and downbeat tracking via beat_this (CPJKU)."""
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+ from beat_this.inference import File2Beats
7
+
8
+
9
+ @dataclass
10
+ class BeatData:
11
+ beats: np.ndarray
12
+ downbeats: np.ndarray
13
+ beat_numbers: np.ndarray
14
+
15
+
16
+ def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
17
+ """Run beat and downbeat tracking on an audio file."""
18
+ processor = File2Beats(checkpoint_path="final0", device=device)
19
+ beats, downbeats = processor(audio_path)
20
+
21
+ beat_numbers = _assign_beat_numbers(beats, downbeats)
22
+
23
+ return BeatData(
24
+ beats=np.asarray(beats),
25
+ downbeats=np.asarray(downbeats),
26
+ beat_numbers=beat_numbers,
27
+ )
28
+
29
+
30
+ def _assign_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray:
31
+ beats = np.asarray(beats)
32
+ downbeats_set = set(np.round(downbeats, 6))
33
+ numbers = np.zeros(len(beats), dtype=int)
34
+ beat_num = 1
35
+
36
+ for i, t in enumerate(beats):
37
+ if round(float(t), 6) in downbeats_set:
38
+ beat_num = 1
39
+ numbers[i] = beat_num
40
+ beat_num += 1
41
+
42
+ return numbers
midmid/constraints.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Difficulty constraints — enforce per-difficulty fret and chord rules."""
2
+
3
+ from midmid.datatypes import NoteEvent
4
+
5
+ ALLOWED_FRETS = {
6
+ "easy": {0, 1, 2},
7
+ "medium": {0, 1, 2, 3},
8
+ "hard": {0, 1, 2, 3, 4},
9
+ "expert": {0, 1, 2, 3, 4},
10
+ }
11
+
12
+ MAX_CHORD_SIZE = {
13
+ "easy": 1,
14
+ "medium": 2,
15
+ "hard": 3,
16
+ "expert": 5,
17
+ }
18
+
19
+ MIN_NOTE_SPACING = {
20
+ "easy": 192,
21
+ "medium": 96,
22
+ "hard": 48,
23
+ "expert": 0,
24
+ }
25
+
26
+
27
+ def enforce_constraints(
28
+ notes: list[NoteEvent], difficulty: str, resolution: int = 192,
29
+ ) -> list[NoteEvent]:
30
+ allowed = ALLOWED_FRETS.get(difficulty, {0, 1, 2, 3, 4})
31
+ max_chord = MAX_CHORD_SIZE.get(difficulty, 5)
32
+ min_spacing = MIN_NOTE_SPACING.get(difficulty, 0)
33
+
34
+ result = []
35
+ for note in notes:
36
+ filtered = note.fret_set & allowed
37
+ if not filtered:
38
+ for fret in sorted(note.fret_set):
39
+ closest = min(allowed, key=lambda a: abs(a - fret))
40
+ filtered.add(closest)
41
+ break
42
+ if not filtered:
43
+ continue
44
+
45
+ if len(filtered) > max_chord:
46
+ filtered = set(sorted(filtered)[:max_chord])
47
+
48
+ if min_spacing > 0 and result:
49
+ if note.tick - result[-1].tick < min_spacing:
50
+ continue
51
+
52
+ if result and result[-1].sustain_ticks > 0:
53
+ prev_end = result[-1].tick + result[-1].sustain_ticks
54
+ if note.tick < prev_end:
55
+ continue
56
+
57
+ result.append(NoteEvent(
58
+ tick=note.tick,
59
+ fret_set=filtered,
60
+ sustain_ticks=note.sustain_ticks,
61
+ is_hopo=note.is_hopo,
62
+ ))
63
+
64
+ sixteenth = resolution // 4
65
+ if len(result) >= 2 and result[-2].sustain_ticks > 0:
66
+ prev = result[-2]
67
+ max_sustain = note.tick - prev.tick - sixteenth
68
+ max_sustain = (max_sustain // sixteenth) * sixteenth
69
+ if max_sustain < sixteenth:
70
+ max_sustain = 0
71
+ if prev.sustain_ticks > max_sustain:
72
+ result[-2] = NoteEvent(
73
+ tick=prev.tick,
74
+ fret_set=prev.fret_set,
75
+ sustain_ticks=max_sustain,
76
+ is_hopo=prev.is_hopo,
77
+ )
78
+
79
+ return result
midmid/datatypes.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared data types used across the pipeline."""
2
+
3
+ from dataclasses import dataclass, field
4
+
5
+
6
+ @dataclass
7
+ class NoteEvent:
8
+ """A single note or chord at a specific tick position."""
9
+ tick: int
10
+ fret_set: set # {0, 1, 2, 3, 4} where 0=Green, 4=Orange
11
+ sustain_ticks: int = 0
12
+ is_hopo: bool = False
13
+
14
+
15
+ @dataclass
16
+ class ChartData:
17
+ """Complete chart data ready for MIDI serialization."""
18
+ resolution: int = 192 # ticks per quarter note
19
+ tempo_events: list = field(default_factory=lambda: [(0, 120.0)])
20
+ time_signatures: list = field(default_factory=lambda: [(0, 4, 4)])
21
+ sections: list = field(default_factory=list) # [(tick, label), ...]
22
+ notes: dict = field(default_factory=dict) # {"expert": [NoteEvent, ...], ...}
23
+ beats: list = field(default_factory=list) # [(tick, is_downbeat), ...]
midmid/inference.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio encoding and iterative unmasking inference.
2
+
3
+ Adapted from midmid/prediction/model.py for standalone use.
4
+ Device management is caller-controlled (for ZeroGPU compatibility).
5
+ """
6
+
7
+ import itertools as _it
8
+ import json
9
+ import math
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+
16
+ from midmid.nn import (
17
+ ChartMaskPredictor, ChartMaskPredictorConfig,
18
+ MASK_TOKEN, SILENCE_TOKEN,
19
+ )
20
+ from midmid.datatypes import NoteEvent
21
+
22
+ MERT_MODEL_ID = "m-a-p/MERT-v1-95M"
23
+
24
+ DIFF_ID = {"easy": 0, "medium": 1, "hard": 2, "expert": 3}
25
+
26
+ # Class ID -> fret tuple
27
+ _CLASS_TO_FRETS: list[tuple[int, ...]] = []
28
+ for _r in range(1, 6):
29
+ _CLASS_TO_FRETS.extend(_it.combinations(range(5), _r))
30
+ _CLASS_TO_FRETS.append((7,)) # class 31 = open
31
+
32
+ # Sustain bucket center values in beats
33
+ _BUCKET_BEATS = [0.0, 1.0, 2.0, 4.0, 8.0, 16.0]
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Model loading (safetensors from HF Hub)
38
+ # ---------------------------------------------------------------------------
39
+
40
+ def load_model_from_hub(
41
+ repo_id: str = "markury/midmid3-19m-0326",
42
+ device: str = "cpu",
43
+ ) -> ChartMaskPredictor:
44
+ """Download and load model from HuggingFace Hub (safetensors)."""
45
+ from huggingface_hub import hf_hub_download
46
+ from safetensors.torch import load_file
47
+
48
+ config_path = hf_hub_download(repo_id, "config.json")
49
+ weights_path = hf_hub_download(repo_id, "model.safetensors")
50
+
51
+ with open(config_path) as f:
52
+ config_dict = json.load(f)
53
+
54
+ config = ChartMaskPredictorConfig(**config_dict)
55
+ model = ChartMaskPredictor(config)
56
+
57
+ state_dict = load_file(weights_path, device=device)
58
+ model.load_state_dict(state_dict)
59
+ model.to(device)
60
+ model.eval()
61
+ return model
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # MERT audio encoding (lazy-loaded)
66
+ # ---------------------------------------------------------------------------
67
+
68
+ _mert_model = None
69
+ _mert_processor = None
70
+ _mert_frame_rate = None
71
+
72
+
73
+ def _ensure_mert(device: torch.device):
74
+ """Load MERT model and processor on first use."""
75
+ global _mert_model, _mert_processor, _mert_frame_rate
76
+ if _mert_model is not None:
77
+ # Move to correct device if needed
78
+ if next(_mert_model.parameters()).device != device:
79
+ _mert_model.to(device)
80
+ return
81
+
82
+ from transformers import AutoModel, Wav2Vec2FeatureExtractor
83
+
84
+ print(f"Loading MERT ({MERT_MODEL_ID}) ...")
85
+ _mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(
86
+ MERT_MODEL_ID, trust_remote_code=True,
87
+ )
88
+ _mert_model = AutoModel.from_pretrained(MERT_MODEL_ID, trust_remote_code=True)
89
+ _mert_model.to(device)
90
+ _mert_model.eval()
91
+
92
+ # Compute frame rate dynamically
93
+ sr = _mert_processor.sampling_rate
94
+ test_wav = np.zeros(sr, dtype=np.float32)
95
+ inputs = _mert_processor(test_wav, sampling_rate=sr, return_tensors="pt")
96
+ inputs = {k: v.to(device) for k, v in inputs.items()}
97
+ with torch.no_grad():
98
+ out = _mert_model(**inputs, output_hidden_states=False)
99
+ _mert_frame_rate = float(out.last_hidden_state.shape[1])
100
+ print(f" MERT frame rate: {_mert_frame_rate:.2f} Hz")
101
+
102
+
103
+ def move_models_to_device(device: torch.device):
104
+ """Move all cached models to the specified device (for ZeroGPU)."""
105
+ global _mert_model
106
+ if _mert_model is not None:
107
+ _mert_model.to(device)
108
+
109
+
110
+ @torch.no_grad()
111
+ def encode_audio_mert(
112
+ audio_path: str,
113
+ device: torch.device,
114
+ chunk_sec: float = 60.0,
115
+ ) -> tuple[torch.Tensor, float]:
116
+ """Encode audio with MERT, return (embeddings, frame_rate)."""
117
+ import librosa
118
+ _ensure_mert(device)
119
+
120
+ sr = _mert_processor.sampling_rate
121
+ wav, _ = librosa.load(audio_path, sr=sr, mono=True)
122
+
123
+ chunk_samples = int(chunk_sec * sr)
124
+ overlap_sec = 5.0
125
+ overlap_samples = int(overlap_sec * sr)
126
+ stride_samples = chunk_samples - overlap_samples
127
+
128
+ if len(wav) <= chunk_samples:
129
+ inputs = _mert_processor(wav, sampling_rate=sr, return_tensors="pt")
130
+ inputs = {k: v.to(device) for k, v in inputs.items()}
131
+ out = _mert_model(**inputs, output_hidden_states=False)
132
+ return out.last_hidden_state.squeeze(0).cpu(), _mert_frame_rate
133
+
134
+ # Chunked processing for long audio
135
+ all_emb = []
136
+ pos = 0
137
+ idx = 0
138
+ while pos < len(wav):
139
+ end = min(pos + chunk_samples, len(wav))
140
+ chunk = wav[pos:end]
141
+ min_len = chunk_samples // 4
142
+ if len(chunk) < min_len:
143
+ chunk = np.pad(chunk, (0, min_len - len(chunk)))
144
+
145
+ inputs = _mert_processor(chunk, sampling_rate=sr, return_tensors="pt")
146
+ inputs = {k: v.to(device) for k, v in inputs.items()}
147
+ out = _mert_model(**inputs, output_hidden_states=False)
148
+ emb = out.last_hidden_state.squeeze(0)
149
+
150
+ n = emb.shape[0]
151
+ fps = n / (len(chunk) / sr)
152
+ half_overlap = int(round((overlap_sec / 2) * fps))
153
+
154
+ if idx == 0:
155
+ keep = n - half_overlap if end < len(wav) else n
156
+ all_emb.append(emb[:keep].cpu())
157
+ elif end >= len(wav):
158
+ all_emb.append(emb[half_overlap:].cpu())
159
+ else:
160
+ keep = int(round((len(chunk) / sr - overlap_sec) * fps))
161
+ all_emb.append(emb[half_overlap:half_overlap + keep].cpu())
162
+
163
+ pos += stride_samples
164
+ idx += 1
165
+
166
+ return torch.cat(all_emb, dim=0), _mert_frame_rate
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # Grid helpers
171
+ # ---------------------------------------------------------------------------
172
+
173
+ def _build_16th_grid(fretbars):
174
+ """Build 16th-note timestamps (ms) from beat positions."""
175
+ if len(fretbars) < 2:
176
+ return list(fretbars)
177
+ positions = []
178
+ for i in range(len(fretbars) - 1):
179
+ start = fretbars[i]
180
+ interval = fretbars[i + 1] - start
181
+ for sub in range(4):
182
+ positions.append(start + sub * interval / 4.0)
183
+ positions.append(fretbars[-1])
184
+ return positions
185
+
186
+
187
+ def _get_local_beat_ms(grid_idx, fretbars):
188
+ beat_idx = min(grid_idx // 4, len(fretbars) - 2)
189
+ beat_idx = max(0, beat_idx)
190
+ if beat_idx + 1 < len(fretbars):
191
+ return fretbars[beat_idx + 1] - fretbars[beat_idx]
192
+ return 500.0
193
+
194
+
195
+ # ---------------------------------------------------------------------------
196
+ # Main inference
197
+ # ---------------------------------------------------------------------------
198
+
199
+ @torch.no_grad()
200
+ def predict_notes(
201
+ audio_path: str,
202
+ model: ChartMaskPredictor,
203
+ beat_times: list[float],
204
+ difficulty: str = "expert",
205
+ device: torch.device = None,
206
+ num_steps: int = 12,
207
+ temperature: float = 0.9,
208
+ ) -> list[NoteEvent]:
209
+ """MaskGIT-style iterative unmasking inference."""
210
+ if device is None:
211
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
212
+ dev = device
213
+ model.to(dev)
214
+ model.eval()
215
+
216
+ fretbars = [t * 1000.0 for t in beat_times]
217
+ if len(fretbars) < 2:
218
+ return []
219
+
220
+ # MERT embeddings
221
+ embeddings, frame_rate = encode_audio_mert(audio_path, dev)
222
+
223
+ # Build grid and sample MERT frames with windowing
224
+ grid_times = _build_16th_grid(fretbars)
225
+ num_positions = len(grid_times)
226
+ max_frame = embeddings.shape[0] - 1
227
+ frame_indices = torch.tensor(
228
+ [min(int(round(t / 1000.0 * frame_rate)), max_frame)
229
+ for t in grid_times], dtype=torch.long,
230
+ )
231
+
232
+ window = 2
233
+ if window > 0 and max_frame >= window * 2:
234
+ padded = torch.nn.functional.pad(
235
+ embeddings.unsqueeze(0), (0, 0, window, window), mode="replicate",
236
+ ).squeeze(0)
237
+ shifted = frame_indices + window
238
+ stacked = torch.stack(
239
+ [padded[shifted + d] for d in range(-window, window + 1)], dim=0,
240
+ )
241
+ grid_emb = stacked.mean(dim=0)
242
+ else:
243
+ grid_emb = embeddings[frame_indices]
244
+
245
+ # Compute and concat audio features if model expects them
246
+ if model.config.audio_dim > grid_emb.shape[-1]:
247
+ import librosa as _lr
248
+ wav, _ = _lr.load(audio_path, sr=24000, mono=True)
249
+ hop = 320
250
+ onset = _lr.onset.onset_strength(y=wav, sr=24000, hop_length=hop)
251
+ rms_arr = _lr.feature.rms(y=wav, hop_length=hop)[0]
252
+ centroid = _lr.feature.spectral_centroid(y=wav, sr=24000, hop_length=hop)[0]
253
+
254
+ def _norm(x):
255
+ mn, mx = x.min(), x.max()
256
+ return (x - mn) / max(mx - mn, 1e-8)
257
+
258
+ onset, rms_arr, centroid = _norm(onset), _norm(rms_arr), _norm(centroid)
259
+ af_rate = 24000 / hop
260
+ af_max = len(onset) - 1
261
+ af_indices = [min(int(round(t / 1000.0 * af_rate)), af_max) for t in grid_times]
262
+ af_tensor = torch.tensor(
263
+ [[onset[i], rms_arr[i], centroid[i]] for i in af_indices],
264
+ dtype=torch.float32,
265
+ )
266
+ grid_emb = torch.cat([grid_emb, af_tensor], dim=-1)
267
+
268
+ audio_features = grid_emb.unsqueeze(0).to(dev)
269
+
270
+ diff_id = DIFF_ID.get(difficulty, 3)
271
+ diff_tensor = torch.tensor([diff_id], dtype=torch.long, device=dev)
272
+ padding_mask = torch.ones(1, num_positions, dtype=torch.bool, device=dev)
273
+
274
+ # Start fully masked
275
+ chart_tokens = torch.full(
276
+ (1, num_positions), MASK_TOKEN, dtype=torch.long, device=dev,
277
+ )
278
+
279
+ # Cosine unmasking schedule
280
+ schedule = []
281
+ for step in range(num_steps):
282
+ r_prev = math.cos(math.pi / 2 * step / num_steps)
283
+ r_next = math.cos(math.pi / 2 * (step + 1) / num_steps)
284
+ n_unmask = max(1, int((r_prev - r_next) * num_positions))
285
+ schedule.append(n_unmask)
286
+
287
+ # Iterative unmasking
288
+ for step in range(num_steps):
289
+ outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
290
+ token_logits = outputs["token_logits"].squeeze(0)
291
+
292
+ is_masked = (chart_tokens.squeeze(0) == MASK_TOKEN)
293
+ masked_indices = is_masked.nonzero(as_tuple=True)[0]
294
+
295
+ if len(masked_indices) == 0:
296
+ break
297
+
298
+ probs = torch.softmax(token_logits / temperature, dim=-1)
299
+ sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
300
+
301
+ n_unmask = min(schedule[step], len(masked_indices))
302
+ perm = torch.randperm(len(masked_indices), device=dev)
303
+ unmask_idx = masked_indices[perm[:n_unmask]]
304
+ chart_tokens[0, unmask_idx] = sampled[unmask_idx]
305
+
306
+ # Final pass for sustain predictions
307
+ outputs = model(audio_features, chart_tokens, diff_tensor, padding_mask)
308
+ sustain_prob = outputs["sustain_logits"].squeeze(0).squeeze(-1).sigmoid()
309
+ dur_pred = outputs["duration_logits"].squeeze(0).argmax(dim=-1)
310
+
311
+ # Convert tokens to NoteEvents
312
+ tokens = chart_tokens.squeeze(0).cpu()
313
+ notes = []
314
+ for i in range(num_positions):
315
+ tok = tokens[i].item()
316
+ if tok >= SILENCE_TOKEN or tok < 0:
317
+ continue
318
+
319
+ fret_set = set(_CLASS_TO_FRETS[tok])
320
+ if not fret_set:
321
+ continue
322
+
323
+ sustain_ticks = 0
324
+ if sustain_prob[i] >= 0.5:
325
+ bucket = dur_pred[i].item()
326
+ beat_ms = _get_local_beat_ms(i, fretbars)
327
+ sustain_ticks = _BUCKET_BEATS[bucket] * beat_ms
328
+
329
+ notes.append(NoteEvent(
330
+ tick=i,
331
+ fret_set=fret_set,
332
+ sustain_ticks=sustain_ticks,
333
+ ))
334
+
335
+ return notes
midmid/ini_writer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate song.ini metadata for GHWT:DE."""
2
+
3
+ from pathlib import Path
4
+
5
+
6
+ def write_ini(
7
+ output_path: str,
8
+ title: str = "Unknown Song",
9
+ artist: str = "Unknown Artist",
10
+ album: str = "",
11
+ genre: str = "rock",
12
+ year: str = "2024",
13
+ charter: str = "Midmid",
14
+ diff_guitar: int = 0,
15
+ preview_start_time: int = 30000,
16
+ song_length: int = 0,
17
+ ) -> None:
18
+ lines = [
19
+ "[Song]",
20
+ f"name = {title}",
21
+ f"artist = {artist}",
22
+ f"album = {album}",
23
+ f"genre = {genre}",
24
+ f"year = {year}",
25
+ f"charter = {charter}",
26
+ f"diff_guitar = {diff_guitar}",
27
+ f"preview_start_time = {preview_start_time}",
28
+ ]
29
+ if song_length > 0:
30
+ lines.append(f"song_length = {song_length}")
31
+
32
+ Path(output_path).write_text("\n".join(lines) + "\n", encoding="utf-8")
midmid/midi_writer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MIDI serialization: write ChartData to a GH-format .mid file."""
2
+
3
+ import mido
4
+
5
+ from midmid.datatypes import ChartData, NoteEvent
6
+
7
+ DIFFICULTY_OFFSETS = {"easy": 60, "medium": 72, "hard": 84, "expert": 96}
8
+ HOPO_NOTE = {"easy": 65, "medium": 77, "hard": 89, "expert": 101}
9
+ NOTE_VELOCITY = 100
10
+
11
+
12
+ def write_midi(chart: ChartData, output_path: str) -> None:
13
+ mid = mido.MidiFile(ticks_per_beat=chart.resolution)
14
+
15
+ mid.tracks.append(_build_tempo_track(chart))
16
+ mid.tracks.append(_build_events_track(chart))
17
+ mid.tracks.append(_build_guitar_track(chart))
18
+ if chart.beats:
19
+ mid.tracks.append(_build_beat_track(chart))
20
+
21
+ mid.save(output_path)
22
+
23
+
24
+ def _build_tempo_track(chart):
25
+ track = mido.MidiTrack()
26
+ events = []
27
+
28
+ for tick, bpm in chart.tempo_events:
29
+ events.append((tick, mido.MetaMessage(
30
+ "set_tempo", tempo=mido.bpm2tempo(bpm), time=0)))
31
+
32
+ for tick, num, den in chart.time_signatures:
33
+ events.append((tick, mido.MetaMessage(
34
+ "time_signature", numerator=num, denominator=den, time=0)))
35
+
36
+ _write_sorted_events(track, events)
37
+ return track
38
+
39
+
40
+ def _build_events_track(chart):
41
+ track = mido.MidiTrack()
42
+ track.append(mido.MetaMessage("track_name", name="EVENTS", time=0))
43
+
44
+ events = []
45
+ for tick, label in chart.sections:
46
+ events.append((tick, mido.MetaMessage(
47
+ "text", text=f"[section {label}]", time=0)))
48
+
49
+ _write_sorted_events(track, events)
50
+ return track
51
+
52
+
53
+ def _build_guitar_track(chart):
54
+ track = mido.MidiTrack()
55
+ track.append(mido.MetaMessage("track_name", name="PART GUITAR", time=0))
56
+
57
+ events = []
58
+ for difficulty, offset in DIFFICULTY_OFFSETS.items():
59
+ if difficulty not in chart.notes:
60
+ continue
61
+
62
+ for note in chart.notes[difficulty]:
63
+ for fret in note.fret_set:
64
+ midi_note = offset + fret
65
+ events.append((note.tick, mido.Message(
66
+ "note_on", note=midi_note, velocity=NOTE_VELOCITY, time=0)))
67
+ off_tick = note.tick + max(note.sustain_ticks, 1)
68
+ events.append((off_tick, mido.Message(
69
+ "note_off", note=midi_note, velocity=0, time=0)))
70
+
71
+ if note.is_hopo:
72
+ hopo_note = HOPO_NOTE[difficulty]
73
+ events.append((note.tick, mido.Message(
74
+ "note_on", note=hopo_note, velocity=NOTE_VELOCITY, time=0)))
75
+ events.append((note.tick + 1, mido.Message(
76
+ "note_off", note=hopo_note, velocity=0, time=0)))
77
+
78
+ _write_sorted_events(track, events)
79
+ return track
80
+
81
+
82
+ def _build_beat_track(chart):
83
+ track = mido.MidiTrack()
84
+ track.append(mido.MetaMessage("track_name", name="BEAT", time=0))
85
+
86
+ events = []
87
+ for tick, is_downbeat in chart.beats:
88
+ midi_note = 12 if is_downbeat else 13
89
+ events.append((tick, mido.Message(
90
+ "note_on", note=midi_note, velocity=NOTE_VELOCITY, time=0)))
91
+ events.append((tick + 1, mido.Message(
92
+ "note_off", note=midi_note, velocity=0, time=0)))
93
+
94
+ _write_sorted_events(track, events)
95
+ return track
96
+
97
+
98
+ def _write_sorted_events(track, events):
99
+ events.sort(key=lambda e: e[0])
100
+ prev_tick = 0
101
+ for abs_tick, msg in events:
102
+ msg.time = abs_tick - prev_tick
103
+ track.append(msg)
104
+ prev_tick = abs_tick
midmid/nn.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chart prediction model architecture.
2
+
3
+ FiLM-conditioned masked transformer for Guitar Hero chart generation.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Utility layers
16
+ # ---------------------------------------------------------------------------
17
+
18
+ def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
19
+ x_glu, x_linear = x[..., ::2], x[..., 1::2]
20
+ x_glu = x_glu.clamp(max=limit)
21
+ x_linear = x_linear.clamp(min=-limit, max=limit)
22
+ return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)
23
+
24
+
25
+ class RMSNorm(nn.Module):
26
+ def __init__(self, dim: int, eps: float = 1e-5):
27
+ super().__init__()
28
+ self.eps = eps
29
+ self.scale = nn.Parameter(torch.ones(dim))
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ t = x.float()
33
+ t = t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + self.eps)
34
+ return (t * self.scale).to(x.dtype)
35
+
36
+
37
+ class FeedForward(nn.Module):
38
+ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
39
+ super().__init__()
40
+ self.linear1 = nn.Linear(d_model, d_ff, bias=False)
41
+ self.linear_out = nn.Linear(d_ff // 2, d_model, bias=False)
42
+ self.dropout = nn.Dropout(dropout)
43
+
44
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
45
+ return self.linear_out(self.dropout(swiglu(self.linear1(x))))
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Rotary position embeddings
50
+ # ---------------------------------------------------------------------------
51
+
52
+ def apply_rotary_emb(
53
+ x: torch.Tensor, dim: int, base: float = 10000.0,
54
+ ) -> torch.Tensor:
55
+ """Apply RoPE to a tensor of shape [B, heads, T, head_dim]."""
56
+ seq_len = x.size(2)
57
+ device, dtype = x.device, x.dtype
58
+ theta = base ** (-torch.arange(0, dim, 2, device=device, dtype=dtype) / dim)
59
+ positions = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1)
60
+ angles = positions * theta.unsqueeze(0)
61
+ sin, cos = angles.sin(), angles.cos()
62
+ sin = sin.unsqueeze(0).unsqueeze(0)
63
+ cos = cos.unsqueeze(0).unsqueeze(0)
64
+ x1 = x[..., : dim // 2]
65
+ x2 = x[..., dim // 2 : dim]
66
+ return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # Bidirectional multi-head self-attention
71
+ # ---------------------------------------------------------------------------
72
+
73
+ class BidirectionalAttention(nn.Module):
74
+ def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
75
+ rope_base: float = 10000.0):
76
+ super().__init__()
77
+ self.d_model = d_model
78
+ self.n_heads = n_heads
79
+ self.d_k = d_model // n_heads
80
+ self.rope_base = rope_base
81
+
82
+ self.w_q = nn.Linear(d_model, d_model, bias=False)
83
+ self.w_k = nn.Linear(d_model, d_model, bias=False)
84
+ self.w_v = nn.Linear(d_model, d_model, bias=False)
85
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
86
+ self.dropout = nn.Dropout(dropout)
87
+
88
+ def forward(self, x: torch.Tensor,
89
+ attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
90
+ B, T, _ = x.shape
91
+ Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
92
+ K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
93
+ V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
94
+
95
+ Q = apply_rotary_emb(Q, dim=self.d_k, base=self.rope_base)
96
+ K = apply_rotary_emb(K, dim=self.d_k, base=self.rope_base)
97
+
98
+ sdpa_mask = None
99
+ if attn_mask is not None:
100
+ sdpa_mask = attn_mask[:, None, None, :].bool()
101
+
102
+ out = F.scaled_dot_product_attention(
103
+ Q, K, V, attn_mask=sdpa_mask,
104
+ dropout_p=self.dropout.p if self.training else 0.0,
105
+ is_causal=False,
106
+ )
107
+ out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
108
+ return self.out_proj(out)
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # FiLM-conditioned encoder block
113
+ # ---------------------------------------------------------------------------
114
+
115
+ class FiLMEncoderBlock(nn.Module):
116
+ """Encoder block with FiLM difficulty conditioning.
117
+
118
+ After the feedforward, the output is modulated:
119
+ h = (1 + gamma) * h + beta
120
+ where gamma, beta are derived from the difficulty embedding.
121
+ """
122
+
123
+ def __init__(self, d_model: int, d_ff: int, n_heads: int,
124
+ dropout: float = 0.1, rope_base: float = 10000.0):
125
+ super().__init__()
126
+ self.norm1 = RMSNorm(d_model)
127
+ self.attn = BidirectionalAttention(d_model, n_heads, dropout, rope_base)
128
+ self.norm2 = RMSNorm(d_model)
129
+ self.ff = FeedForward(d_model, d_ff, dropout)
130
+ self.dropout = nn.Dropout(dropout)
131
+
132
+ self.film_proj = nn.Linear(d_model, d_model * 2)
133
+ nn.init.zeros_(self.film_proj.weight)
134
+ nn.init.zeros_(self.film_proj.bias)
135
+
136
+ def forward(self, x: torch.Tensor, diff_emb: torch.Tensor,
137
+ attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
138
+ x = x + self.dropout(self.attn(self.norm1(x), attn_mask))
139
+ h = self.ff(self.norm2(x))
140
+
141
+ film = self.film_proj(diff_emb).unsqueeze(1)
142
+ gamma, beta = film.chunk(2, dim=-1)
143
+ h = (1 + gamma) * h + beta
144
+
145
+ x = x + self.dropout(h)
146
+ return x
147
+
148
+
149
+ # ---------------------------------------------------------------------------
150
+ # Constants
151
+ # ---------------------------------------------------------------------------
152
+
153
+ SILENCE_TOKEN = 32
154
+ MASK_TOKEN = 33
155
+ VOCAB_SIZE = 34
156
+ NUM_SUSTAIN_BUCKETS = 6
157
+
158
+
159
+ # ---------------------------------------------------------------------------
160
+ # Main model
161
+ # ---------------------------------------------------------------------------
162
+
163
+ class ChartMaskPredictor(nn.Module):
164
+ """Masked prediction chart model (v3).
165
+
166
+ Token vocabulary: 0-31 fret combos, 32 silence, 33 MASK.
167
+ """
168
+
169
+ def __init__(self, config: "ChartMaskPredictorConfig"):
170
+ super().__init__()
171
+ self.config = config
172
+ d = config.d_model
173
+
174
+ self.audio_projection = nn.Linear(config.audio_dim, d, bias=False)
175
+ self.chart_embedding = nn.Embedding(VOCAB_SIZE, d)
176
+ self.input_dropout = nn.Dropout(config.dropout)
177
+ self.difficulty_embedding = nn.Embedding(4, d)
178
+
179
+ self.layers = nn.ModuleList([
180
+ FiLMEncoderBlock(
181
+ d_model=d, d_ff=config.d_ff, n_heads=config.n_heads,
182
+ dropout=config.dropout, rope_base=config.rope_base,
183
+ )
184
+ for _ in range(config.n_layers)
185
+ ])
186
+
187
+ self.final_norm = RMSNorm(d)
188
+ self.token_head = nn.Linear(d, VOCAB_SIZE - 1) # 33 classes (no MASK)
189
+ self.sustain_head = nn.Linear(d, 1)
190
+ self.duration_head = nn.Linear(d, NUM_SUSTAIN_BUCKETS)
191
+
192
+ def forward(self, audio_features: torch.Tensor, chart_tokens: torch.Tensor,
193
+ difficulty: torch.Tensor,
194
+ padding_mask: Optional[torch.Tensor] = None) -> dict[str, torch.Tensor]:
195
+ audio = self.audio_projection(audio_features)
196
+ chart = self.chart_embedding(chart_tokens)
197
+ x = audio + chart
198
+ x = self.input_dropout(x)
199
+
200
+ diff_emb = self.difficulty_embedding(difficulty)
201
+
202
+ for layer in self.layers:
203
+ x = layer(x, diff_emb, attn_mask=padding_mask)
204
+
205
+ x = self.final_norm(x)
206
+
207
+ return {
208
+ "token_logits": self.token_head(x),
209
+ "sustain_logits": self.sustain_head(x),
210
+ "duration_logits": self.duration_head(x),
211
+ }
212
+
213
+
214
+ @dataclass
215
+ class ChartMaskPredictorConfig:
216
+ audio_dim: int = 771
217
+ d_model: int = 512
218
+ n_heads: int = 8
219
+ n_layers: int = 6
220
+ d_ff: int = 2048
221
+ dropout: float = 0.15
222
+ rope_base: float = 10000.0
midmid/offset.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Offset / silence duration calculation."""
2
+
3
+ from midmid.beat_tracker import BeatData
4
+
5
+
6
+ def calculate_offset(
7
+ beat_data: BeatData,
8
+ bpm: float,
9
+ beats_per_measure: int = 4,
10
+ min_lead_in: float = 2.0,
11
+ ) -> float:
12
+ """Calculate silence duration to prepend to the audio."""
13
+ if len(beat_data.downbeats) == 0:
14
+ return min_lead_in
15
+
16
+ first_downbeat = float(beat_data.downbeats[0])
17
+ measure_duration = beats_per_measure * 60.0 / bpm
18
+
19
+ n = 1
20
+ while n * measure_duration < min_lead_in:
21
+ n += 1
22
+
23
+ silence = n * measure_duration - first_downbeat
24
+
25
+ while silence < 0:
26
+ n += 1
27
+ silence = n * measure_duration - first_downbeat
28
+
29
+ return silence
midmid/sections.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Structural segmentation (intro, verse, chorus, etc.)."""
2
+
3
+ import numpy as np
4
+ import librosa
5
+
6
+
7
+ def detect_sections(
8
+ audio_path: str,
9
+ min_section_duration: float = 8.0,
10
+ ) -> list[tuple[float, str]]:
11
+ """Detect structural sections in an audio file."""
12
+ y, sr = librosa.load(audio_path, sr=22050, mono=True)
13
+ duration = len(y) / sr
14
+
15
+ mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
16
+
17
+ n_frames = mfcc.shape[1]
18
+ k = max(2, min(n_frames - 1, int(duration / 25)))
19
+ bounds = librosa.segment.agglomerative(mfcc, k=k)
20
+ bound_times = librosa.frames_to_time(bounds, sr=sr)
21
+
22
+ if len(bound_times) == 0 or bound_times[0] > 0.5:
23
+ bound_times = np.concatenate([[0.0], bound_times])
24
+
25
+ bound_times = _merge_short_segments(bound_times, duration, min_section_duration)
26
+ labels = _assign_labels(y, sr, bound_times, duration)
27
+
28
+ return list(zip(bound_times.tolist(), labels))
29
+
30
+
31
+ def _merge_short_segments(bounds, duration, min_dur):
32
+ merged = [bounds[0]]
33
+ for t in bounds[1:]:
34
+ if t - merged[-1] >= min_dur:
35
+ merged.append(t)
36
+ return np.array(merged)
37
+
38
+
39
+ def _assign_labels(y, sr, bound_times, duration):
40
+ n = len(bound_times)
41
+ if n == 0:
42
+ return []
43
+ if n == 1:
44
+ return ["Intro"]
45
+
46
+ segment_features = []
47
+ for i in range(n):
48
+ start_sample = int(bound_times[i] * sr)
49
+ end_sample = int(bound_times[i + 1] * sr) if i + 1 < n else len(y)
50
+ seg = y[start_sample:end_sample]
51
+ if len(seg) < sr // 4:
52
+ segment_features.append(np.zeros(13))
53
+ else:
54
+ mfcc = librosa.feature.mfcc(y=seg, sr=sr, n_mfcc=13)
55
+ segment_features.append(np.mean(mfcc, axis=1))
56
+
57
+ labels = ["Intro"]
58
+ letter_idx = 0
59
+ letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
60
+ assigned = {}
61
+
62
+ for i in range(1, n):
63
+ best_sim = -1
64
+ best_j = -1
65
+ for j in range(i):
66
+ sim = _cosine_sim(segment_features[i], segment_features[j])
67
+ if sim > best_sim:
68
+ best_sim = sim
69
+ best_j = j
70
+
71
+ if best_sim > 0.85 and best_j in assigned:
72
+ labels.append(f"Section {assigned[best_j]}")
73
+ else:
74
+ letter = letters[letter_idx % len(letters)]
75
+ letter_idx += 1
76
+ assigned[i] = letter
77
+ labels.append(f"Section {letter}")
78
+
79
+ if best_j not in assigned and best_j > 0:
80
+ assigned[best_j] = labels[best_j].split()[-1] if " " in labels[best_j] else "A"
81
+
82
+ return labels
83
+
84
+
85
+ def _cosine_sim(a, b):
86
+ norm_a = np.linalg.norm(a)
87
+ norm_b = np.linalg.norm(b)
88
+ if norm_a == 0 or norm_b == 0:
89
+ return 0.0
90
+ return float(np.dot(a, b) / (norm_a * norm_b))
midmid/tempo_map.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tempo map derivation from beat tracker output."""
2
+
3
+ import numpy as np
4
+
5
+ from midmid.beat_tracker import BeatData
6
+
7
+
8
+ def derive_tempo_map(
9
+ beat_data: BeatData, change_threshold: float = 0.08,
10
+ ) -> list[tuple[float, float]]:
11
+ """Derive a tempo map from beat data.
12
+
13
+ Returns list of (time_seconds, bpm) tuples, sorted by time.
14
+ """
15
+ beats = beat_data.beats
16
+ if len(beats) < 2:
17
+ return [(0.0, 120.0)]
18
+
19
+ intervals = np.diff(beats)
20
+ bpms = 60.0 / intervals
21
+
22
+ median_bpm = np.median(bpms)
23
+ valid = (bpms > median_bpm * 0.6) & (bpms < median_bpm * 1.6)
24
+ if not np.any(valid):
25
+ return [(0.0, float(median_bpm))]
26
+
27
+ valid_bpms = bpms[valid]
28
+ if np.std(valid_bpms) / np.mean(valid_bpms) < change_threshold:
29
+ avg_bpm = float(np.mean(valid_bpms))
30
+ return [(0.0, _round_bpm(avg_bpm))]
31
+
32
+ tempo_map = []
33
+ current_bpm = float(bpms[0]) if valid[0] else float(median_bpm)
34
+ tempo_map.append((0.0, _round_bpm(current_bpm)))
35
+
36
+ window = 4
37
+ for i in range(window, len(bpms) - window + 1, window):
38
+ chunk = bpms[i : i + window]
39
+ chunk_valid = chunk[(chunk > median_bpm * 0.6) & (chunk < median_bpm * 1.6)]
40
+ if len(chunk_valid) == 0:
41
+ continue
42
+ local_bpm = float(np.mean(chunk_valid))
43
+ if abs(local_bpm - current_bpm) / current_bpm > change_threshold:
44
+ current_bpm = local_bpm
45
+ tempo_map.append((float(beats[i]), _round_bpm(current_bpm)))
46
+
47
+ return tempo_map
48
+
49
+
50
+ def get_median_bpm(beat_data: BeatData) -> float:
51
+ if len(beat_data.beats) < 2:
52
+ return 120.0
53
+ intervals = np.diff(beat_data.beats)
54
+ bpms = 60.0 / intervals
55
+ return float(_round_bpm(np.median(bpms)))
56
+
57
+
58
+ def estimate_time_signature(beat_data: BeatData) -> int:
59
+ if len(beat_data.downbeats) < 2:
60
+ return 4
61
+
62
+ beats = beat_data.beats
63
+ downbeats = beat_data.downbeats
64
+
65
+ counts = []
66
+ for i in range(len(downbeats) - 1):
67
+ start, end = downbeats[i], downbeats[i + 1]
68
+ n = np.sum((beats >= start) & (beats < end))
69
+ if 2 <= n <= 7:
70
+ counts.append(n)
71
+
72
+ if not counts:
73
+ return 4
74
+
75
+ values, freq = np.unique(counts, return_counts=True)
76
+ return int(values[np.argmax(freq)])
77
+
78
+
79
+ def _round_bpm(bpm: float) -> float:
80
+ return round(float(bpm), 2)
pipeline.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generation pipeline — callable from Gradio, ZeroGPU-compatible.
2
+
3
+ Wraps the full audio→chart pipeline into a single function that returns
4
+ a zip file path and chart JSON for the visualizer.
5
+ """
6
+
7
+ import base64
8
+ import json
9
+ import os
10
+ import shutil
11
+ import tempfile
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from midmid.beat_tracker import track_beats
19
+ from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
20
+ from midmid.offset import calculate_offset
21
+ from midmid.sections import detect_sections
22
+ from midmid.constraints import enforce_constraints
23
+ from midmid.datatypes import ChartData, NoteEvent
24
+ from midmid.inference import load_model_from_hub, predict_notes, move_models_to_device
25
+ from midmid.midi_writer import write_midi
26
+ from midmid.audio_prep import prepare_audio
27
+ from midmid.ini_writer import write_ini
28
+
29
+ RESOLUTION = 192
30
+ MODEL_REPO = "markury/midmid3-19m-0326"
31
+
32
+ # Loaded once at startup (on CPU)
33
+ _chart_model = None
34
+
35
+
36
+ def ensure_model():
37
+ """Pre-load model on CPU (called at app startup)."""
38
+ global _chart_model
39
+ if _chart_model is None:
40
+ print("Loading chart model from HF Hub...")
41
+ _chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
42
+ print("Chart model loaded.")
43
+ return _chart_model
44
+
45
+
46
+ def generate_chart(
47
+ audio_path: str,
48
+ title: str,
49
+ artist: str,
50
+ album: str = "",
51
+ year: str = "",
52
+ genre: str = "rock",
53
+ temperature: float = 0.8,
54
+ num_steps: int = 12,
55
+ progress_cb=None,
56
+ ) -> tuple[str, dict]:
57
+ """Run the full generation pipeline.
58
+
59
+ Args:
60
+ audio_path: Path to uploaded audio file.
61
+ title: Song title.
62
+ artist: Artist name.
63
+ album: Album name (optional).
64
+ year: Release year (optional).
65
+ genre: Genre string (optional).
66
+ temperature: Sampling temperature.
67
+ num_steps: Unmasking steps.
68
+ progress_cb: Optional callable(step, total, message) for progress.
69
+
70
+ Returns:
71
+ (zip_path, chart_json) where chart_json has the data for the visualizer.
72
+ """
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ model = ensure_model()
75
+ model.to(device)
76
+ move_models_to_device(device)
77
+
78
+ if not year:
79
+ year = str(datetime.now().year)
80
+
81
+ # Create temp output dir
82
+ tmp_dir = tempfile.mkdtemp(prefix="midmid_")
83
+ song_dir = Path(tmp_dir) / f"{title} - {artist}"
84
+ song_dir.mkdir(parents=True, exist_ok=True)
85
+
86
+ def _progress(step, total, msg):
87
+ if progress_cb:
88
+ progress_cb(step / total, desc=msg)
89
+
90
+ # --- Stage 1: Audio analysis ---
91
+ _progress(0, 8, "Tracking beats...")
92
+ beat_data = track_beats(audio_path, device=str(device))
93
+
94
+ _progress(1, 8, "Analyzing tempo...")
95
+ tempo_map = derive_tempo_map(beat_data)
96
+ bpm = get_median_bpm(beat_data)
97
+ time_sig = estimate_time_signature(beat_data)
98
+ offset_sec = calculate_offset(beat_data, bpm, beats_per_measure=time_sig)
99
+
100
+ _progress(2, 8, "Detecting sections...")
101
+ raw_sections = detect_sections(audio_path)
102
+
103
+ # --- Stage 2: Note prediction ---
104
+ beat_times = list(beat_data.beats)
105
+ difficulties = ["expert", "hard", "medium", "easy"]
106
+ all_notes = {}
107
+
108
+ for i, diff_name in enumerate(difficulties):
109
+ _progress(3 + i * 0.2, 8, f"Generating {diff_name} chart...")
110
+ raw_notes = predict_notes(
111
+ audio_path=audio_path,
112
+ model=model,
113
+ beat_times=beat_times,
114
+ difficulty=diff_name,
115
+ device=device,
116
+ temperature=temperature,
117
+ num_steps=num_steps,
118
+ )
119
+
120
+ notes = _grid_to_musical_ticks(raw_notes, beat_times, offset_sec, bpm, RESOLUTION)
121
+ notes = enforce_constraints(notes, diff_name, RESOLUTION)
122
+
123
+ last_beat_sec = float(beat_data.beats[-1]) if len(beat_data.beats) > 0 else 0
124
+ last_beat_tick = int(round((last_beat_sec + offset_sec) * bpm / 60.0 * RESOLUTION))
125
+ notes = [n for n in notes if n.tick <= last_beat_tick]
126
+
127
+ all_notes[diff_name] = notes
128
+
129
+ # Fill missing difficulties
130
+ required = ["expert", "hard", "medium", "easy"]
131
+ for diff in required:
132
+ if diff not in all_notes:
133
+ for fallback in required:
134
+ if fallback in all_notes:
135
+ all_notes[diff] = all_notes[fallback]
136
+ break
137
+
138
+ # --- Stage 3: Assembly ---
139
+ _progress(5, 8, "Building chart...")
140
+ tempo_events = _tempo_map_to_ticks(tempo_map, offset_sec, bpm, RESOLUTION)
141
+ section_events = _sections_to_ticks(raw_sections, tempo_map, offset_sec, RESOLUTION)
142
+
143
+ all_ticks = [n.tick for ns in all_notes.values() for n in ns]
144
+ last_tick = max(all_ticks) + RESOLUTION * time_sig if all_ticks else RESOLUTION * time_sig * 4
145
+ beat_markers = _build_beat_markers(last_tick, RESOLUTION, time_sig)
146
+
147
+ chart = ChartData(
148
+ resolution=RESOLUTION,
149
+ tempo_events=tempo_events,
150
+ time_signatures=[(0, time_sig, 4)],
151
+ sections=section_events,
152
+ notes=all_notes,
153
+ beats=beat_markers,
154
+ )
155
+
156
+ # --- Stage 4: Write outputs ---
157
+ _progress(6, 8, "Writing MIDI...")
158
+ write_midi(chart, str(song_dir / "notes.mid"))
159
+
160
+ _progress(7, 8, "Preparing audio...")
161
+ prepare_audio(
162
+ audio_path=audio_path,
163
+ output_path=str(song_dir / "song.ogg"),
164
+ silence_duration_sec=offset_sec,
165
+ )
166
+
167
+ write_ini(
168
+ output_path=str(song_dir / "song.ini"),
169
+ title=title,
170
+ artist=artist,
171
+ album=album,
172
+ genre=genre,
173
+ year=year,
174
+ )
175
+
176
+ # --- Zip it ---
177
+ zip_base = Path(tmp_dir) / f"{title} - {artist}"
178
+ zip_path = shutil.make_archive(str(zip_base), "zip", tmp_dir, song_dir.name)
179
+
180
+ # --- Build chart JSON for the visualizer ---
181
+ chart_json = _build_chart_json(
182
+ chart, bpm, offset_sec, audio_path, str(song_dir / "song.ogg"),
183
+ )
184
+
185
+ _progress(8, 8, "Done!")
186
+ return zip_path, chart_json
187
+
188
+
189
+ def _build_chart_json(chart, bpm, offset_sec, original_audio_path, prepared_audio_path):
190
+ """Build JSON payload for the client-side visualizer."""
191
+ # Encode prepared audio as base64 for the HTML player
192
+ with open(prepared_audio_path, "rb") as f:
193
+ audio_b64 = base64.b64encode(f.read()).decode("ascii")
194
+
195
+ notes_json = {}
196
+ for diff, note_list in chart.notes.items():
197
+ notes_json[diff] = [
198
+ {
199
+ "tick": n.tick,
200
+ "frets": sorted(n.fret_set),
201
+ "sustain": n.sustain_ticks,
202
+ "hopo": n.is_hopo,
203
+ }
204
+ for n in note_list
205
+ ]
206
+
207
+ return {
208
+ "resolution": chart.resolution,
209
+ "bpm": bpm,
210
+ "offset_sec": offset_sec,
211
+ "tempo_events": [{"tick": t, "bpm": b} for t, b in chart.tempo_events],
212
+ "time_signatures": [{"tick": t, "num": n, "den": d} for t, n, d in chart.time_signatures],
213
+ "sections": [{"tick": t, "label": l} for t, l in chart.sections],
214
+ "beats": [{"tick": t, "downbeat": d} for t, d in chart.beats],
215
+ "notes": notes_json,
216
+ "audio_b64": audio_b64,
217
+ "audio_format": "ogg",
218
+ }
219
+
220
+
221
+ # ---------------------------------------------------------------------------
222
+ # Grid index -> musical tick conversion (from generate.py)
223
+ # ---------------------------------------------------------------------------
224
+
225
+ def _grid_to_musical_ticks(notes, beat_times, offset_sec, bpm, resolution):
226
+ if len(beat_times) < 2:
227
+ return notes
228
+
229
+ sixteenth = resolution // 4
230
+
231
+ fretbars_ms = [t * 1000.0 for t in beat_times]
232
+ grid_times_ms = []
233
+ for i in range(len(fretbars_ms) - 1):
234
+ start = fretbars_ms[i]
235
+ interval = fretbars_ms[i + 1] - start
236
+ for sub in range(4):
237
+ grid_times_ms.append(start + sub * interval / 4.0)
238
+ grid_times_ms.append(fretbars_ms[-1])
239
+
240
+ result = []
241
+ for note in notes:
242
+ grid_idx = note.tick
243
+ if grid_idx < 0 or grid_idx >= len(grid_times_ms):
244
+ continue
245
+
246
+ time_sec = grid_times_ms[grid_idx] / 1000.0 + offset_sec
247
+ tick = round(time_sec * bpm / 60.0 * resolution)
248
+ tick = round(tick / sixteenth) * sixteenth
249
+ tick = max(0, tick)
250
+
251
+ sustain_ticks = 0
252
+ if note.sustain_ticks > 0:
253
+ sustain_sec = note.sustain_ticks / 1000.0
254
+ raw = sustain_sec * bpm / 60.0 * resolution
255
+ sustain_ticks = max(sixteenth, round(raw / sixteenth) * sixteenth)
256
+
257
+ result.append(NoteEvent(
258
+ tick=tick,
259
+ fret_set=note.fret_set,
260
+ sustain_ticks=sustain_ticks,
261
+ is_hopo=note.is_hopo,
262
+ ))
263
+
264
+ return result
265
+
266
+
267
+ def _tempo_map_to_ticks(tempo_map, offset_sec, bpm, resolution):
268
+ events = []
269
+ for i, (time_sec, bpm_val) in enumerate(tempo_map):
270
+ if i == 0:
271
+ events.append((0, bpm_val))
272
+ else:
273
+ adjusted_time = time_sec + offset_sec
274
+ prev_time = tempo_map[i - 1][0] + offset_sec if i > 0 else 0
275
+ dt_sec = adjusted_time - prev_time
276
+ prev_tick = events[-1][0]
277
+ prev_bpm = events[-1][1]
278
+ tick = prev_tick + int(round(dt_sec * prev_bpm / 60.0 * resolution))
279
+ events.append((tick, bpm_val))
280
+ return events
281
+
282
+
283
+ def _sections_to_ticks(sections, tempo_map, offset_sec, resolution):
284
+ if not tempo_map:
285
+ return []
286
+ result = []
287
+ bpm = tempo_map[0][1]
288
+ for time_sec, label in sections:
289
+ adjusted = time_sec + offset_sec
290
+ tick = int(round(adjusted * bpm / 60.0 * resolution))
291
+ tick = max(0, tick)
292
+ result.append((tick, label))
293
+ return result
294
+
295
+
296
+ def _build_beat_markers(last_tick, resolution, beats_per_measure):
297
+ beats = []
298
+ tick = 0
299
+ beat_in_measure = 0
300
+ while tick <= last_tick:
301
+ beats.append((tick, beat_in_measure == 0))
302
+ beat_in_measure = (beat_in_measure + 1) % beats_per_measure
303
+ tick += resolution
304
+ return beats
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.0
2
+ spaces
3
+ safetensors
4
+ transformers<5
5
+ huggingface-hub
6
+ mido
7
+ beat_this @ git+https://github.com/CPJKU/beat_this.git
8
+ librosa
9
+ pydub
10
+ numpy
11
+ scipy
12
+ tqdm
visualizer.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build the HTML/JS/CSS for the chart visualizer.
2
+
3
+ Returns a self-contained HTML string that Gradio embeds via gr.HTML().
4
+ The chart data + audio are injected as a JSON blob.
5
+ """
6
+
7
+ import json
8
+
9
+
10
+ def build_visualizer_html(chart_json: dict) -> str:
11
+ """Return a self-contained HTML string for the chart visualizer."""
12
+ data_json = json.dumps(chart_json, separators=(",", ":"))
13
+ return TEMPLATE.replace("__CHART_DATA__", data_json)
14
+
15
+
16
+ TEMPLATE = r"""
17
+ <div id="midmid-viz" style="font-family: system-ui, -apple-system, sans-serif; background: #111; border-radius: 12px; overflow: hidden; max-width: 900px; margin: 0 auto;">
18
+
19
+ <!-- Controls bar -->
20
+ <div style="display:flex; align-items:center; gap:12px; padding:10px 16px; background:#1a1a1a; border-bottom:1px solid #333;">
21
+ <button id="viz-play" style="background:none; border:none; color:#fff; font-size:22px; cursor:pointer; padding:4px 8px;" title="Play/Pause">&#9654;</button>
22
+ <div id="viz-time" style="color:#aaa; font-size:13px; min-width:80px;">0:00 / 0:00</div>
23
+ <div style="flex:1; position:relative; height:6px; background:#333; border-radius:3px; cursor:pointer;" id="viz-seekbar">
24
+ <div id="viz-seekfill" style="height:100%; background:#7c3aed; border-radius:3px; width:0%; pointer-events:none;"></div>
25
+ </div>
26
+ <select id="viz-diff" style="background:#222; color:#fff; border:1px solid #444; border-radius:4px; padding:2px 6px; font-size:13px;">
27
+ <option value="expert">Expert</option>
28
+ <option value="hard">Hard</option>
29
+ <option value="medium">Medium</option>
30
+ <option value="easy">Easy</option>
31
+ </select>
32
+ </div>
33
+
34
+ <!-- Canvas -->
35
+ <canvas id="viz-canvas" style="width:100%; display:block;"></canvas>
36
+
37
+ <!-- Section labels row -->
38
+ <div id="viz-sections" style="padding:6px 16px 10px; background:#1a1a1a; border-top:1px solid #333; color:#888; font-size:11px; min-height:20px; white-space:nowrap; overflow:hidden; text-overflow:ellipsis;"></div>
39
+
40
+ <script>
41
+ (function() {
42
+ const DATA = __CHART_DATA__;
43
+
44
+ // --- Constants ---
45
+ const FRET_COLORS = ['#22c55e','#ef4444','#eab308','#3b82f6','#f97316']; // G R Y B O
46
+ const FRET_GLOW = ['#4ade80','#f87171','#facc15','#60a5fa','#fb923c'];
47
+ const LANE_COUNT = 5;
48
+ const NOTE_RADIUS = 14;
49
+ const LANE_WIDTH = 48;
50
+ const HIGHWAY_WIDTH = LANE_COUNT * LANE_WIDTH;
51
+ const CANVAS_PAD_LEFT = 80;
52
+ const CANVAS_PAD_RIGHT = 20;
53
+ const RES = DATA.resolution; // 192
54
+
55
+ // Timing: convert tick to seconds using tempo events
56
+ const tempoMap = DATA.tempo_events.map(e => ({tick: e.tick, bpm: e.bpm}));
57
+
58
+ function tickToSec(tick) {
59
+ let sec = 0;
60
+ let prevTick = 0;
61
+ let bpm = tempoMap[0].bpm;
62
+ for (let i = 1; i < tempoMap.length; i++) {
63
+ if (tempoMap[i].tick > tick) break;
64
+ sec += (tempoMap[i].tick - prevTick) / RES * 60.0 / bpm;
65
+ prevTick = tempoMap[i].tick;
66
+ bpm = tempoMap[i].bpm;
67
+ }
68
+ sec += (tick - prevTick) / RES * 60.0 / bpm;
69
+ return sec;
70
+ }
71
+
72
+ function secToTick(sec) {
73
+ let accSec = 0;
74
+ let prevTick = 0;
75
+ let bpm = tempoMap[0].bpm;
76
+ for (let i = 1; i < tempoMap.length; i++) {
77
+ const dt = (tempoMap[i].tick - prevTick) / RES * 60.0 / bpm;
78
+ if (accSec + dt > sec) break;
79
+ accSec += dt;
80
+ prevTick = tempoMap[i].tick;
81
+ bpm = tempoMap[i].bpm;
82
+ }
83
+ return prevTick + (sec - accSec) * bpm / 60.0 * RES;
84
+ }
85
+
86
+ // --- Audio setup ---
87
+ const audio = new Audio();
88
+ audio.src = 'data:audio/' + DATA.audio_format + ';base64,' + DATA.audio_b64;
89
+ audio.preload = 'auto';
90
+
91
+ // --- Canvas setup ---
92
+ const canvas = document.getElementById('viz-canvas');
93
+ const ctx = canvas.getContext('2d');
94
+ let W, H, pxPerSec;
95
+ const VISIBLE_SEC = 8; // seconds visible on screen
96
+
97
+ function resize() {
98
+ const container = canvas.parentElement;
99
+ W = container.clientWidth;
100
+ H = 360;
101
+ canvas.width = W * devicePixelRatio;
102
+ canvas.height = H * devicePixelRatio;
103
+ canvas.style.height = H + 'px';
104
+ ctx.setTransform(devicePixelRatio, 0, 0, devicePixelRatio, 0, 0);
105
+ pxPerSec = (W - CANVAS_PAD_LEFT - CANVAS_PAD_RIGHT) / VISIBLE_SEC;
106
+ }
107
+ resize();
108
+ new ResizeObserver(resize).observe(canvas.parentElement);
109
+
110
+ // --- State ---
111
+ let currentDiff = 'expert';
112
+ let playing = false;
113
+
114
+ // Precompute note positions in seconds
115
+ function buildNoteCache(diff) {
116
+ return (DATA.notes[diff] || []).map(n => ({
117
+ sec: tickToSec(n.tick),
118
+ frets: n.frets,
119
+ sustainSec: n.sustain > 0 ? tickToSec(n.tick + n.sustain) - tickToSec(n.tick) : 0,
120
+ hopo: n.hopo,
121
+ }));
122
+ }
123
+
124
+ let noteCache = buildNoteCache(currentDiff);
125
+
126
+ // Precompute beats in seconds
127
+ const beatCache = DATA.beats.map(b => ({
128
+ sec: tickToSec(b.tick),
129
+ downbeat: b.downbeat,
130
+ }));
131
+
132
+ // Sections in seconds
133
+ const sectionCache = DATA.sections.map(s => ({
134
+ sec: tickToSec(s.tick),
135
+ label: s.label,
136
+ }));
137
+
138
+ // Total duration
139
+ let totalDuration = 0;
140
+ audio.addEventListener('loadedmetadata', () => {
141
+ totalDuration = audio.duration;
142
+ });
143
+ // Fallback: estimate from last note
144
+ const allNoteSecs = Object.values(DATA.notes).flat().map(n => tickToSec(n.tick + (n.sustain || 0)));
145
+ const estimatedDuration = allNoteSecs.length ? Math.max(...allNoteSecs) + 5 : 120;
146
+
147
+ function getDuration() { return totalDuration || estimatedDuration; }
148
+
149
+ // --- Controls ---
150
+ const playBtn = document.getElementById('viz-play');
151
+ const timeDiv = document.getElementById('viz-time');
152
+ const seekBar = document.getElementById('viz-seekbar');
153
+ const seekFill = document.getElementById('viz-seekfill');
154
+ const diffSelect = document.getElementById('viz-diff');
155
+ const sectionsDiv = document.getElementById('viz-sections');
156
+
157
+ playBtn.addEventListener('click', () => {
158
+ if (playing) {
159
+ audio.pause();
160
+ playing = false;
161
+ playBtn.textContent = '\u25B6';
162
+ } else {
163
+ audio.play();
164
+ playing = true;
165
+ playBtn.textContent = '\u23F8';
166
+ }
167
+ });
168
+
169
+ seekBar.addEventListener('click', (e) => {
170
+ const rect = seekBar.getBoundingClientRect();
171
+ const frac = (e.clientX - rect.left) / rect.width;
172
+ audio.currentTime = frac * getDuration();
173
+ });
174
+
175
+ diffSelect.addEventListener('change', () => {
176
+ currentDiff = diffSelect.value;
177
+ noteCache = buildNoteCache(currentDiff);
178
+ });
179
+
180
+ audio.addEventListener('ended', () => {
181
+ playing = false;
182
+ playBtn.textContent = '\u25B6';
183
+ });
184
+
185
+ function formatTime(s) {
186
+ const m = Math.floor(s / 60);
187
+ const sec = Math.floor(s % 60);
188
+ return m + ':' + (sec < 10 ? '0' : '') + sec;
189
+ }
190
+
191
+ // --- Rendering ---
192
+ function draw() {
193
+ const t = audio.currentTime || 0;
194
+ const dur = getDuration();
195
+
196
+ // Update controls
197
+ seekFill.style.width = (t / dur * 100) + '%';
198
+ timeDiv.textContent = formatTime(t) + ' / ' + formatTime(dur);
199
+
200
+ // Update section label
201
+ let currentSection = '';
202
+ for (let i = sectionCache.length - 1; i >= 0; i--) {
203
+ if (sectionCache[i].sec <= t) { currentSection = sectionCache[i].label; break; }
204
+ }
205
+ sectionsDiv.textContent = currentSection;
206
+
207
+ // Clear
208
+ ctx.fillStyle = '#111';
209
+ ctx.fillRect(0, 0, W, H);
210
+
211
+ // The highway: current time is at the left edge + small offset
212
+ const playheadX = CANVAS_PAD_LEFT + 40;
213
+ const secToX = (sec) => playheadX + (sec - t) * pxPerSec;
214
+ const viewStart = t - 1;
215
+ const viewEnd = t + VISIBLE_SEC + 1;
216
+
217
+ // Draw lane backgrounds
218
+ const laneTop = 20;
219
+ const laneBottom = H - 20;
220
+ const laneHeight = laneBottom - laneTop;
221
+ const highwayLeft = playheadX - 20;
222
+
223
+ // Subtle lane separators
224
+ for (let i = 0; i <= LANE_COUNT; i++) {
225
+ const y = laneTop + (laneHeight / LANE_COUNT) * i;
226
+ ctx.strokeStyle = '#2a2a2a';
227
+ ctx.lineWidth = 1;
228
+ ctx.beginPath();
229
+ ctx.moveTo(CANVAS_PAD_LEFT - 10, y);
230
+ ctx.lineTo(W - CANVAS_PAD_RIGHT, y);
231
+ ctx.stroke();
232
+ }
233
+
234
+ // Draw beat lines (vertical)
235
+ for (const beat of beatCache) {
236
+ if (beat.sec < viewStart || beat.sec > viewEnd) continue;
237
+ const x = secToX(beat.sec);
238
+ ctx.strokeStyle = beat.downbeat ? '#444' : '#222';
239
+ ctx.lineWidth = beat.downbeat ? 1.5 : 0.5;
240
+ ctx.beginPath();
241
+ ctx.moveTo(x, laneTop);
242
+ ctx.lineTo(x, laneBottom);
243
+ ctx.stroke();
244
+
245
+ // Measure number for downbeats
246
+ if (beat.downbeat && x > CANVAS_PAD_LEFT) {
247
+ ctx.fillStyle = '#555';
248
+ ctx.font = '9px system-ui';
249
+ ctx.fillText('|', x - 2, laneTop - 4);
250
+ }
251
+ }
252
+
253
+ // Draw section boundaries
254
+ for (const sec of sectionCache) {
255
+ if (sec.sec < viewStart || sec.sec > viewEnd) continue;
256
+ const x = secToX(sec.sec);
257
+ ctx.strokeStyle = '#7c3aed55';
258
+ ctx.lineWidth = 2;
259
+ ctx.beginPath();
260
+ ctx.moveTo(x, laneTop);
261
+ ctx.lineTo(x, laneBottom);
262
+ ctx.stroke();
263
+ ctx.fillStyle = '#7c3aed';
264
+ ctx.font = '10px system-ui';
265
+ ctx.fillText(sec.label, x + 4, laneTop - 4);
266
+ }
267
+
268
+ // Draw playhead
269
+ ctx.strokeStyle = '#fff';
270
+ ctx.lineWidth = 2;
271
+ ctx.beginPath();
272
+ ctx.moveTo(playheadX, laneTop - 2);
273
+ ctx.lineTo(playheadX, laneBottom + 2);
274
+ ctx.stroke();
275
+
276
+ // Draw notes
277
+ const laneH = laneHeight / LANE_COUNT;
278
+
279
+ for (const note of noteCache) {
280
+ if (note.sec + note.sustainSec < viewStart || note.sec > viewEnd) continue;
281
+
282
+ const x = secToX(note.sec);
283
+
284
+ for (const fret of note.frets) {
285
+ if (fret > 4) continue; // skip open chord marker
286
+ const laneY = laneTop + fret * laneH + laneH / 2;
287
+ const color = FRET_COLORS[fret];
288
+ const glow = FRET_GLOW[fret];
289
+
290
+ // Draw sustain tail first (behind note)
291
+ if (note.sustainSec > 0) {
292
+ const endX = secToX(note.sec + note.sustainSec);
293
+ ctx.fillStyle = color + '55';
294
+ ctx.fillRect(x, laneY - 4, endX - x, 8);
295
+ ctx.fillStyle = color + '99';
296
+ ctx.fillRect(x, laneY - 2, endX - x, 4);
297
+ }
298
+
299
+ // Note circle
300
+ const isPast = note.sec < t;
301
+ ctx.beginPath();
302
+ ctx.arc(x, laneY, NOTE_RADIUS - 2, 0, Math.PI * 2);
303
+
304
+ if (isPast) {
305
+ ctx.fillStyle = color + '44';
306
+ ctx.fill();
307
+ ctx.strokeStyle = color + '66';
308
+ ctx.lineWidth = 1.5;
309
+ ctx.stroke();
310
+ } else {
311
+ // Glow for upcoming notes near playhead
312
+ const dist = note.sec - t;
313
+ if (dist < 0.3) {
314
+ ctx.shadowColor = glow;
315
+ ctx.shadowBlur = 12;
316
+ }
317
+ ctx.fillStyle = color;
318
+ ctx.fill();
319
+ ctx.shadowBlur = 0;
320
+ ctx.strokeStyle = '#fff';
321
+ ctx.lineWidth = 2;
322
+ ctx.stroke();
323
+
324
+ // HOPO = open center
325
+ if (note.hopo) {
326
+ ctx.beginPath();
327
+ ctx.arc(x, laneY, NOTE_RADIUS - 6, 0, Math.PI * 2);
328
+ ctx.fillStyle = '#111';
329
+ ctx.fill();
330
+ }
331
+ }
332
+ }
333
+ }
334
+
335
+ // Fret labels on the left
336
+ const fretNames = ['Green', 'Red', 'Yellow', 'Blue', 'Orange'];
337
+ const fretAbbrev = ['G', 'R', 'Y', 'B', 'O'];
338
+ ctx.font = 'bold 13px system-ui';
339
+ for (let i = 0; i < LANE_COUNT; i++) {
340
+ const y = laneTop + i * laneH + laneH / 2;
341
+ ctx.fillStyle = FRET_COLORS[i];
342
+ ctx.textAlign = 'right';
343
+ ctx.fillText(fretAbbrev[i], CANVAS_PAD_LEFT - 20, y + 5);
344
+ }
345
+ ctx.textAlign = 'left';
346
+
347
+ // Note count overlay
348
+ const noteCount = noteCache.length;
349
+ ctx.fillStyle = '#666';
350
+ ctx.font = '11px system-ui';
351
+ ctx.fillText(noteCount + ' notes (' + currentDiff + ')', W - CANVAS_PAD_RIGHT - 140, laneTop - 4);
352
+
353
+ requestAnimationFrame(draw);
354
+ }
355
+
356
+ requestAnimationFrame(draw);
357
+ })();
358
+ </script>
359
+ </div>
360
+ """