final
app.py CHANGED
@@ -1,5 +1,5 @@
 import os
-import spaces # Enables ZeroGPU on Hugging Face
+#import spaces # Enables ZeroGPU on Hugging Face
 import gradio as gr
 import torch
 from dataclasses import asdict

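Both ZeroGPU hooks are disabled in this commit: the `spaces` import above and the matching `@spaces.GPU` decorator further down are commented out, so the Space runs on its default hardware. For reference, a minimal sketch of re-enabling ZeroGPU (the `duration` value is illustrative, not taken from app.py):

```python
import spaces  # ZeroGPU helper available on Hugging Face Spaces

@spaces.GPU(duration=120)  # request a GPU for up to ~120 s per call (illustrative value)
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    ...  # unchanged body; the CUDA work must happen inside the decorated function
```
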
@@ -25,10 +25,10 @@ LARGE_MODEL = "stanford-crfm/music-large-800k"
 model_card = ModelCard(
     name="Anticipatory Music Transformer",
     description=(
-        "Generate musical accompaniment for your existing
-        "Input: a MIDI file
-        "Output: a new MIDI file with extended accompaniment
-        "Use the
+        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
+        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
+        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
+        "Use the sliders to choose model size and how much of the song is used as context."
     ),
     author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
     tags=["midi", "generation", "accompaniment"]

@@ -60,135 +60,59 @@ def load_amt_model(model_choice: str):
 return model

 def find_melody_program(mid, debug=False):
-    """
-    Automatically detect the melody track's program number from a MIDI file.
-    Uses a balanced heuristic: pitch + note count + temporal coverage.
-    """
     track_stats = []
-    total_duration = 0
-
     for i, track in enumerate(mid.tracks):
         pitches, times = [], []
         current_time = 0
-        current_program = None
-        track_note_count = 0
-
         for msg in track:
-            current_time += msg.time
-            if msg.type == "program_change":
-                current_program = msg.program
-                continue
-
-            # note_on event
-            if msg.velocity > 0:
+            current_time += getattr(msg, "time", 0)
+            if msg.type == "note_on" and msg.velocity > 0:
                 pitches.append(msg.note)
                 times.append(current_time)
-
-        # Skip empty or trivial tracks
-        if not pitches:
-            continue
-
-        # Compute duration for this track and update total_duration
-        track_duration = max(times) - min(times)
-        total_duration = max(total_duration, current_time)
-
-        mean_pitch = sum(pitches) / len(pitches)
-        polyphony = len(set(pitches)) / len(pitches)
-        coverage = track_duration / total_duration if total_duration > 0 else 0
-
-        track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, coverage))

     if not track_stats:
-        return None, False
-
-    if len(track_stats) == 1:
-        prog = track_stats[0][3]
         if debug:
-
-
-        else:
-            print(f"Single-track MIDI detected, using program {prog or 'None'}")
-        return prog, prog is not None
-
-    candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
-    has_valid_programs = len(candidates) > 0
-    if not candidates:
-        candidates = track_stats
-
-    if debug:
-        print(f"\nCandidates: {len(candidates)} tracks")
-
-    max_notes = max(t[2] for t in candidates)
-    max_pitch = max(t[1] for t in candidates)
-    min_pitch = min(t[1] for t in candidates)
-    pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
-
-    best_score = -1
-    best_program = None
-    best_track = None
-    best_pitch = None
-
-    for t in candidates:
-        idx, pitch, notes, prog, poly, coverage = t
-        pitch_norm = (pitch - min_pitch) / pitch_span
-        notes_norm = notes / max_notes
-
-        score *= 0.95
-        if 55 <= pitch <= 75:
-            score *= 1.1
-        if notes >= 30:
-            score *= 1.05
-        if coverage > 0.7:
-            score *= 1.15
-
-        if score > best_score:
-            best_score = score
-            best_program = prog
-            best_track = idx
-            best_pitch = pitch
-
-
+            print("No notes detected in any track.")
+        return 0
+
+    melody_idx = sorted(track_stats, key=lambda x: (-x[1], -x[3]))[0][0]
+
+    return melody_idx
+
+
+def get_program_number(mid, track_index):
+    for msg in mid.tracks[track_index]:
+        if msg.type == "program_change":
+            return msg.program
+    return None


 def auto_extract_melody(mid, debug=False):
-    """
-    Extract melody events from a MIDI object (already loaded via MidiFile).
-    Optimized to avoid re-reading the file from disk.
-    Returns: (all_events, melody_events)
-    """
     events = midi_to_events(mid)
-
-    if
+    melody_track = find_melody_program(mid, debug=debug)
+    melody_program = get_program_number(mid, melody_track)
+
+    if debug:
+        print(f"Melody Track: {melody_track} | Program: {melody_program}")
+
+    if melody_program is not None:
+        events, melody = extract_instruments(events, [melody_program])
     if debug:
-        print("
-
-        events, melody = extract_instruments(events, [melody_program])
-
-    if len(melody) == 0:
+        print(f"Extracted {len(melody)} melody events from program {melody_program}")
+    else:
         if debug:
-            print("No
-
-    if debug:
-        print(f"Extracted {len(melody)} melody events from program {melody_program}")
+            print("No program number found; using all events as melody.")
+        melody = events

     return events, melody

-@spaces.GPU
+#@spaces.GPU
 # Core generation
 def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
     """

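The 100-plus lines of scoring heuristics above are replaced by a much simpler rule: collect per-track statistics, pick the track with the highest mean pitch, and break ties by note density. The General MIDI program number is then looked up separately, so `find_melody_program` now returns a track index rather than a `(program, found)` pair. A usage sketch, assuming mido is installed (`example.mid` is an illustrative file name):

```python
from mido import MidiFile

mid = MidiFile("example.mid")
track_idx = find_melody_program(mid, debug=True)  # highest mean pitch wins; density breaks ties
program = get_program_number(mid, track_idx)      # None if the track never sends program_change
print(f"melody: track {track_idx}, program {program}")
```
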
@@ -199,7 +123,7 @@ def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):

     # Parse MIDI correctly, then convert to events
     mid = MidiFile(midi_path)
-
+    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

     # Automatically detect and extract melody
     all_events, melody = auto_extract_melody(mid, debug=True)

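The added print reports mido's `MidiFile.type`, which follows the Standard MIDI File spec: 0 is a single multi-channel track, 1 is multiple synchronous tracks, and 2 is multiple asynchronous tracks. The inline conditional labels everything non-zero as multi-track, which glosses over type 2; a slightly fuller sketch:

```python
# Hypothetical lookup table; mido itself only exposes the integer mid.type.
MIDI_TYPES = {0: "single track", 1: "synchronous multi-track", 2: "asynchronous multi-track"}
print(f"Loaded MIDI file: type {mid.type} ({MIDI_TYPES.get(mid.type, 'unknown')})")
```
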
@@ -207,16 +131,18 @@ def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
         print("No melody detected; using all events")
         melody = all_events

-    total_time = round(ops.max_time(all_events, seconds=True))
-
     # History portion
     history = ops.clip(all_events, 0, history_length, clip_duration=False)
+    start_time = ops.max_time(history, seconds=True)
+
+    mid_time = mid.length or 0
+    ops_time = ops.max_time(all_events, seconds=True)
+    total_time = round(max(mid_time, ops_time))

-    # Generate accompaniment for the remaining duration
     accompaniment = generate(
         model,
-        start_time=history_length,
-        end_time=total_time,
+        start_time=history_length,
+        end_time=total_time,
         inputs=history,
         controls=melody,
         top_p=0.95,

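Two things change here. First, `total_time` is now the larger of mido's `mid.length` (the file duration in seconds) and the last event time from `ops.max_time`, so generation no longer stops short when the tokenized event stream ends before the file does. Second, a `start_time` is computed from the clipped history, though the `generate` call below still passes the slider value `history_length`. Pulled out of the handler, the flow looks roughly like this (a sketch assuming the anticipation library's `ops` and `generate` as imported in app.py, with a 5-second history and an illustrative file name):

```python
mid = MidiFile("example.mid")
all_events, melody = auto_extract_melody(mid, debug=True)

history = ops.clip(all_events, 0, 5, clip_duration=False)  # first 5 s as context
total_time = round(max(mid.length or 0, ops.max_time(all_events, seconds=True)))

accompaniment = generate(model, start_time=5, end_time=total_time,
                         inputs=history, controls=melody, top_p=0.95)
```
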
@@ -273,16 +199,11 @@ with gr.Blocks() as demo:
         choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
         value=MEDIUM_MODEL,
         label="Select AMT Model (Faster vs. Higher Quality)"
-    ).set_info(
-        "Choose the model size: Smaller models generate faster but may be less detailed. \n larger models produce richer, more expressive accompaniment."
     )

     history_slider = gr.Slider(
         minimum=1, maximum=10, step=1, value=5,
-        label="
-    ).set_info(
-        "Controls how much of the beginning of your song is used as context for generation.\n "
-        "A longer history helps the model better understand the style and rhythm before extending the accompaniment."
+        label="Select History Length (seconds)"
     )

     # Outputs (JSON FIRST)
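The removed `.set_info(...)` calls would fail at runtime: Gradio's `Dropdown` and `Slider` have no such method. If the helper text is still wanted, the supported mechanism is the `info=` constructor argument; a sketch reusing the old wording:

```python
model_dropdown = gr.Dropdown(
    choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
    value=MEDIUM_MODEL,
    label="Select AMT Model (Faster vs. Higher Quality)",
    info="Smaller models generate faster but may be less detailed; "
         "larger models produce richer, more expressive accompaniment.",
)

history_slider = gr.Slider(
    minimum=1, maximum=10, step=1, value=5,
    label="Select History Length (seconds)",
    info="How much of the beginning of your song is used as context for generation; "
         "a longer history helps the model pick up style and rhythm.",
)
```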