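# NOTE: the lines below are residue from the Hugging Face file-viewer page,
# kept as comments; they are not part of the program.
#   saumya-pailwan's picture
#   slider -> control
#   7599353 verified
#   raw
#   history blame
#   7.34 kB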
import os
import spaces # Enables ZeroGPU on Hugging Face
import gradio as gr
import torch
from dataclasses import asdict
from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops
from mido import MidiFile # parse MIDI explicitly to avoid .time error
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# Model choices: checkpoint identifiers offered in the UI dropdown and passed
# to AutoModelForCausalLM.from_pretrained by load_amt_model.
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"
# Model card (new pyharp): metadata handed to build_endpoint when the HARP
# endpoint is created.
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the controls to choose model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"]
)
# Model cache: maps a checkpoint identifier to its loaded model; filled by
# load_amt_model so each checkpoint is loaded at most once per process.
_model_cache = {}
def load_amt_model(model_choice: str):
    """Return the AMT checkpoint for ``model_choice``, loading it on first use.

    Loaded models are stored in the module-level ``_model_cache`` dict keyed
    by checkpoint identifier, so later calls with the same identifier return
    the instance that was already moved to the device.

    Args:
        model_choice: Checkpoint identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model, placed on CUDA when available and on CPU otherwise.
    """
    try:
        return _model_cache[model_choice]
    except KeyError:
        pass  # not loaded yet in this process

    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")

    if model_choice != LARGE_MODEL:
        print(f"Loading {model_choice} ...")
        load_kwargs = {}
    else:
        print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
        # The large checkpoint is loaded in half precision on GPU and with
        # low_cpu_mem_usage enabled; the other checkpoints use the defaults.
        load_kwargs = {
            "torch_dtype": torch.float16 if has_cuda else torch.float32,
            "low_cpu_mem_usage": True,
        }

    loaded = AutoModelForCausalLM.from_pretrained(model_choice, **load_kwargs).to(device)
    _model_cache[model_choice] = loaded
    return loaded
def find_melody_program(mid, debug=False):
    """Pick the index of the track most likely to carry the melody.

    Each track is scanned for sounding notes, i.e. ``note_on`` messages with
    a velocity above zero. Tracks that contain at least one such note are
    ranked by mean pitch, highest first. Ties are broken by note density
    (notes divided by the span between the first and last note onset, with a
    zero span treated as 1), and any remaining tie goes to the earliest track.

    Args:
        mid: Object exposing ``tracks``, a sequence of message sequences.
            Messages need a ``type`` attribute; ``time`` defaults to 0 when
            missing, and ``note``/``velocity`` are read on ``note_on`` only.
        debug: When true, print a notice if no track contains any note.

    Returns:
        The index of the selected track, or 0 when no track has notes.
    """
    candidates = []  # (track index, mean pitch, note density)
    for index, track in enumerate(mid.tracks):
        elapsed = 0
        onsets = []  # (pitch, accumulated time) for every sounding note
        for message in track:
            elapsed += getattr(message, "time", 0)
            if message.type != "note_on" or not message.velocity > 0:
                continue
            onsets.append((message.note, elapsed))

        if not onsets:
            continue

        note_count = len(onsets)
        mean_pitch = sum(pitch for pitch, _ in onsets) / note_count
        onset_times = [moment for _, moment in onsets]
        span = (max(onset_times) - min(onset_times)) or 1
        candidates.append((index, mean_pitch, note_count / span))

    if not candidates:
        if debug:
            print("No notes detected in any track.")
        return 0

    # min() returns the first minimal entry, so earlier tracks win full ties.
    best = min(candidates, key=lambda entry: (-entry[1], -entry[2]))
    return best[0]
def get_program_number(mid, track_index):
    """Return the program of the first ``program_change`` message in a track.

    Args:
        mid: Object exposing ``tracks``, indexable by ``track_index``.
        track_index: Position of the track to inspect.

    Returns:
        The ``program`` attribute of the first message whose ``type`` is
        ``"program_change"``, or ``None`` when the track has no such message.
    """
    programs = (
        message.program
        for message in mid.tracks[track_index]
        if message.type == "program_change"
    )
    return next(programs, None)
def auto_extract_melody(mid, debug=False):
    """Convert a MIDI object to events and separate out the melody.

    The melody track is chosen with ``find_melody_program`` and its program
    number is read with ``get_program_number``. When a program number exists,
    the events are passed through ``extract_instruments`` for that program
    and its two return values are handed back unchanged (presumably the
    remaining events and the extracted melody events; confirm against the
    anticipation library). When no program number exists, the full event
    list is returned in both positions.

    Args:
        mid: Parsed MIDI object accepted by ``midi_to_events``.
        debug: When true, print the selected track/program and the outcome.

    Returns:
        A ``(events, melody)`` pair.
    """
    all_events = midi_to_events(mid)
    melody_track = find_melody_program(mid, debug=debug)
    melody_program = get_program_number(mid, melody_track)
    if debug:
        print(f"Melody Track: {melody_track} | Program: {melody_program}")

    if melody_program is None:
        if debug:
            print("No program number found; using all events as melody.")
        # Same list object serves as both the event stream and the melody.
        return all_events, all_events

    remaining, melody = extract_instruments(all_events, [melody_program])
    if debug:
        print(f"Extracted {len(melody)} melody events from program {melody_program}")
    return remaining, melody
@spaces.GPU
# Core generation
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """Generate accompaniment for a whole MIDI file and save it as a new MIDI.

    The input is parsed with ``mido.MidiFile`` before conversion to events
    (passing the raw path string led to a 'str' ``.time`` error). The first
    ``history_length`` seconds of events are used as the prompt, the detected
    melody is used as the control sequence, and generation runs from
    ``history_length`` to the end of the piece.

    Args:
        midi_path: Path of the input MIDI file.
        model_choice: Checkpoint identifier understood by ``load_amt_model``.
        history_length: Number of seconds from the start of the piece that
            are given to the model as context.

    Returns:
        ``(output_path, None)``: the path of the written MIDI file and an
        error slot that is always ``None`` here; failures propagate as
        exceptions.
    """
    import tempfile  # local import: only needed for the output file

    model = load_amt_model(model_choice)

    # Parse MIDI correctly, then convert to events
    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

    # Automatically detect and extract melody
    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        print("No melody detected; using all events")
        melody = all_events

    # History portion: events from the first `history_length` seconds
    history = ops.clip(all_events, 0, history_length, clip_duration=False)

    # Use the longer of the two duration estimates, rounded to whole seconds.
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))

    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=history,
        controls=melody,
        top_p=0.95,
        debug=False
    )

    # Combine accompaniment + melody and clip
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True
    )

    # Save MIDI to a unique temporary file so that overlapping requests do
    # not overwrite each other's output (the path used to be a fixed name in
    # the working directory).
    mid_out = events_to_midi(output_events)
    with tempfile.NamedTemporaryFile(
        prefix="generated_accompaniment_", suffix=".mid", delete=False
    ) as handle:
        output_midi = handle.name
    mid_out.save(output_midi)
    return output_midi, None
# HARP process fn (JSON FIRST)
def process_fn(input_midi_path, model_choice, history_length):
    """Run accompaniment generation and shape the result for the HARP client.

    The HARP client expects the first returned item to be an object, so the
    JSON payload always comes first and the MIDI file path second.

    Args:
        input_midi_path: Path of the uploaded MIDI file.
        model_choice: Checkpoint identifier selected in the UI.
        history_length: Context length in seconds; converted to ``float``.

    Returns:
        ``(labels_dict, midi_path)`` on success, or
        ``({"message": error}, None)`` when generation reports an error.
    """
    midi_file, failure = generate_accompaniment(
        input_midi_path, model_choice, float(history_length)
    )
    if not failure:
        # JSON first (an empty label list; add entries if desired), then MIDI
        return asdict(LabelList()), midi_file
    # JSON first, then no file
    return {"message": failure}, None
# Gradio + HARP UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")

    # --- Inputs: MIDI upload (marked as required for HARP), model, history ---
    input_midi = gr.File(
        label="Input MIDI File",
        type="filepath",
        file_types=[".mid", ".midi"],
    ).harp_required(True)

    model_dropdown = gr.Dropdown(
        label="Select AMT Model (Faster vs. Higher Quality)",
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
    )

    history_slider = gr.Slider(
        label="Select History Length (seconds)",
        minimum=1,
        maximum=10,
        step=1,
        value=5,
    )

    # --- Outputs: the JSON component is created and listed before the file ---
    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(
        label="Generated MIDI Output",
        type="filepath",
        file_types=[".mid", ".midi"],
    )

    # --- HARP endpoint (new build_endpoint signature) ---
    # Component order must match process_fn's parameters and return values.
    _ = build_endpoint(
        model_card=model_card,
        input_components=[input_midi, model_dropdown, history_slider],
        output_components=[
            output_labels,  # JSON FIRST
            output_midi,  # MIDI SECOND
        ],
        process_fn=process_fn,
    )

# Launch App
demo.launch(share=True, show_error=True, debug=True)