import os
import spaces # Enables ZeroGPU on Hugging Face
import gradio as gr
import torch
from dataclasses import asdict
from mido import MidiFile, tempo2bpm
from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# ---------------------------------------------------------
# Model Choices
# ---------------------------------------------------------
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"
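# Checkpoints from the Anticipatory Music Transformer release on the Hugging
# Face Hub; larger models are slower to load and sample but higher quality.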
# ---------------------------------------------------------
# Model Card (for HARP)
# ---------------------------------------------------------
model_card = ModelCard(
name="Anticipatory Music Transformer",
description=(
"Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
"Input: a MIDI file with a short accompaniment followed by a melody line. "
"Output: a new MIDI file with extended accompaniment matching the melody. "
"Use the sliders to choose model size and how much of the song is used as context."
),
author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
tags=["midi", "generation", "accompaniment"]
)
# ---------------------------------------------------------
# Model Cache Loader
# ---------------------------------------------------------
_model_cache = {}
def load_amt_model(model_choice: str):
"""Loads and caches the AMT model inside the worker process."""
if model_choice in _model_cache:
return _model_cache[model_choice]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading {model_choice} ...")
model = AutoModelForCausalLM.from_pretrained(
model_choice,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True
).to(device)
_model_cache[model_choice] = model
return model
# ---------------------------------------------------------
# Melody Detection (Auto Program)
# ---------------------------------------------------------
def find_melody_program(mid, debug=False):
"""Detect melody track’s program number using pitch, density, and duration heuristics."""
    track_stats = []  # (index, mean_pitch, n_notes, program, polyphony, coverage)
    total_duration = 0
    raw_stats = []
    for i, track in enumerate(mid.tracks):
        pitches, times = [], []
        current_time = 0
        current_program = None
        for msg in track:
            # Accumulate delta time for every message (including note_off and
            # meta events) so the note timestamps stay accurate.
            current_time += msg.time
            if msg.type == "program_change":
                current_program = msg.program
            elif msg.type == "note_on" and msg.velocity > 0:
                pitches.append(msg.note)
                times.append(current_time)
                if len(pitches) >= 100:
                    break  # sample at most 100 notes per track
        if not pitches:
            continue
        track_duration = max(times) - min(times)
        total_duration = max(total_duration, current_time)
        mean_pitch = sum(pitches) / len(pitches)
        polyphony = len(set(pitches)) / len(pitches)
        raw_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, track_duration))
    # Compute coverage against the final song duration, so tracks scanned early
    # are not measured against a partial running maximum.
    for i, mean_pitch, n_notes, prog, poly, track_duration in raw_stats:
        coverage = track_duration / total_duration if total_duration > 0 else 0
        track_stats.append((i, mean_pitch, n_notes, prog, poly, coverage))
if not track_stats:
return None, False
    if len(track_stats) == 1:
        prog = track_stats[0][3]
        if debug:
            print(f"Single-track MIDI detected; using program {prog if prog is not None else 'None'}")
        return prog, prog is not None
candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
has_valid_programs = len(candidates) > 0
if not candidates:
candidates = track_stats
max_notes = max(t[2] for t in candidates)
max_pitch = max(t[1] for t in candidates)
min_pitch = min(t[1] for t in candidates)
pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
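    # Heuristic score: a weighted blend of normalized mean pitch, note count,
    # and time coverage (0.35 / 0.35 / 0.30), with multiplicative bonuses for
    # melody-like traits (lead register 55-75, sustained activity) and a small
    # penalty for very low pitch variety.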
    best_score = -1
    best_program = None
    for _, pitch, notes, prog, poly, coverage in candidates:
        pitch_norm = (pitch - min_pitch) / pitch_span
        notes_norm = notes / max_notes
        score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)
        if poly < 0.15:
            score *= 0.95
        if 55 <= pitch <= 75:
            score *= 1.1
        if notes >= 30:
            score *= 1.05
        if coverage > 0.7:
            score *= 1.15
        if score > best_score:
            best_score = score
            best_program = prog
    return best_program, has_valid_programs
def auto_extract_melody(mid, debug=False):
"""Extract melody events from MIDI object (optimized for direct input)."""
events = midi_to_events(mid)
melody_program, has_valid_program = find_melody_program(mid, debug=debug)
if not has_valid_program or melody_program is None or melody_program == 0:
if debug:
print("No valid program changes found; using all events as melody.")
return events, events
events, melody = extract_instruments(events, [melody_program])
if len(melody) == 0:
if debug:
print("No melody events found for program — reverting to all events.")
return events, events
if debug:
print(f"Extracted {len(melody)} melody events from program {melody_program}")
return events, melody
# ---------------------------------------------------------
# Core Generation Logic
# ---------------------------------------------------------
@spaces.GPU
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
"""Generate accompaniment conditioned on context history and melody."""
model = load_amt_model(model_choice)
mid = MidiFile(midi_path)
print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
all_events, melody = auto_extract_melody(mid, debug=True)
if len(melody) == 0:
melody = all_events
mid_time = mid.length or 0
ops_time = ops.max_time(all_events, seconds=True)
total_time = round(max(mid_time, ops_time))
melody_history = ops.clip(all_events, 0, history_length, clip_duration=False)
melody_future = ops.clip(melody, history_length, total_time, clip_duration=False)
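    # In the anticipation API, "inputs" are the events the model continues from
    # (the history window) and "controls" are future events it conditions on,
    # here the melody it must accompany.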
    print(f"Generating from {history_length:.2f}s → {total_time:.2f}s "
          f"(duration = {total_time - history_length:.2f}s)")
    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=melody_history,
        controls=melody_future,
        top_p=0.95,
        debug=False
    )
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True
    )
output_midi = "generated_accompaniment_huggingface.mid"
events_to_midi(output_events).save(output_midi)
return output_midi, None
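# Example call (outside HARP), assuming a local file "vamp.mid":
#   out_path, err = generate_accompaniment("vamp.mid", MEDIUM_MODEL, history_length=5.0)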
# ---------------------------------------------------------
# HARP process_fn — with tempo-aware bar→seconds conversion
# ---------------------------------------------------------
def process_fn(input_midi_path, model_choice, history_length, use_bars):
    """Convert bars to seconds (tempo-aware) before generation."""
    if use_bars:
        BEATS_PER_BAR = 4
        bpm = 120
        try:
            mid = MidiFile(input_midi_path)
            # Use the first set_tempo event in the file; fall back to 120 BPM.
            tempo_msg = next(
                (msg for tr in mid.tracks for msg in tr if msg.type == "set_tempo"),
                None
            )
            if tempo_msg is not None:
                bpm = round(tempo2bpm(tempo_msg.tempo))
        except Exception:
            pass
        # One bar = BEATS_PER_BAR beats, each lasting 60/bpm seconds.
        seconds_per_bar = (60.0 / bpm) * BEATS_PER_BAR
        history_length = history_length * seconds_per_bar
        print(f"[INFO] Converted to {history_length:.2f}s from bars @ {bpm} BPM")
output_midi, error_message = generate_accompaniment(
input_midi_path,
model_choice,
float(history_length)
)
if error_message:
return {"message": error_message}, None
return asdict(LabelList()), output_midi
# ---------------------------------------------------------
# Gradio + HARP UI
# ---------------------------------------------------------
with gr.Blocks() as demo:
gr.Markdown("## 🎼 Anticipatory Music Transformer")
input_midi = gr.File(file_types=[".mid", ".midi"], label="Input MIDI File", type="filepath").harp_required(True)
model_dropdown = gr.Dropdown(
choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
value=MEDIUM_MODEL,
label="Select AMT Model (Faster vs. Higher Quality)"
)
with gr.Row():
history_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Select History Length (seconds)")
use_bars = gr.Checkbox(
value=False,
label="Use Musical Bars Instead of Seconds",
info="If enabled, context length is interpreted as bars based on the MIDI tempo."
)
def get_midi_tempo(midi_path):
try:
mid = MidiFile(midi_path)
for track in mid.tracks:
for msg in track:
if msg.type == "set_tempo":
return round(tempo2bpm(msg.tempo))
except Exception:
pass
return 120
BEATS_PER_BAR = 4
def bars_to_seconds(bars, bpm, beats_per_bar=BEATS_PER_BAR):
return bars * beats_per_bar * (60.0 / bpm)
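    # e.g. 4 bars at 100 BPM in 4/4: 4 * 4 * (60.0 / 100) = 9.6 seconds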
def toggle_label_and_range(use_bars):
if use_bars:
return gr.update(label="Select History Length (bars)", minimum=2, maximum=8, step=2, value=4)
else:
return gr.update(label="Select History Length (seconds)", minimum=1, maximum=10, step=1, value=5)
def update_bar_label(history_value, midi_path, use_bars):
if not use_bars:
return gr.update(label="Select History Length (seconds)")
bpm = get_midi_tempo(midi_path)
approx_sec = bars_to_seconds(history_value, bpm)
return gr.update(label=f"Select History Length ({history_value} bars ≈ {approx_sec:.1f}s @ {bpm} BPM)")
use_bars.change(fn=toggle_label_and_range, inputs=use_bars, outputs=history_slider)
history_slider.change(fn=update_bar_label, inputs=[history_slider, input_midi, use_bars],
outputs=history_slider, queue=False)
output_labels = gr.JSON(label="Labels / Metadata")
output_midi = gr.File(file_types=[".mid", ".midi"], label="Generated MIDI Output", type="filepath")
_ = build_endpoint(
model_card=model_card,
input_components=[input_midi, model_dropdown, history_slider, use_bars],
output_components=[output_labels, output_midi],
process_fn=process_fn
)
# ---------------------------------------------------------
# Launch App
# ---------------------------------------------------------
demo.launch(share=True, show_error=True, debug=True)