Commit 84d0752
Parent(s): b01bf50

enable model selection
app.py CHANGED
@@ -7,42 +7,75 @@ from anticipation.tokenize import extract_instruments
 import torch
 from pyharp import *
 
+# === Define AMT Model Checkpoints ===
+SMALL_MODEL = "stanford-crfm/music-small-800k"    # Faster inference, worse quality
+MEDIUM_MODEL = "stanford-crfm/music-medium-800k"  # Slower inference, better quality
+LARGE_MODEL = "stanford-crfm/music-large-800k"    # Slowest inference, best quality
+
 # Define the model card for PyHARP
 model_card = ModelCard(
     name="Anticipatory Music Transformer",
     description="Using Anticipatory Music Transformer (AMT) to generate accompaniment for a given MIDI file.",
     author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
     tags=["midi", "generation", "accompaniment"],
-    midi_in=True,  # PyHARP
+    midi_in=True,  # PyHARP automatically handles MIDI input
     midi_out=True
 )
 
-# Load
-
-
-
+# === Function to Load AMT Model Based on Selection ===
+def load_amt_model(model_choice):
+    """Loads the selected AMT model."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return AutoModelForCausalLM.from_pretrained(model_choice).to(device)
+
+# === Function to Detect the Melody Program Automatically ===
+def detect_melody_program(midi_file):
+    """Automatically detects the only MIDI program in the input file."""
+    events = midi_to_events(midi_file.name)
+    instrument_programs = list(ops.get_instruments(events).keys())
+
+    if len(instrument_programs) == 1:
+        return instrument_programs[0]    # Return the only available program
+    elif len(instrument_programs) > 1:
+        return min(instrument_programs)  # Pick the lowest-numbered program
+    else:
+        return 0  # Default to Acoustic Grand Piano if no program found
 
+# === Function to Generate Accompaniment ===
+def generate_accompaniment(midi_file, model_choice, start_time, end_time):
+    """Generates accompaniment using the selected AMT model."""
+    # Load selected AMT model
+    model = load_amt_model(model_choice)
 
-# Function to generate accompaniment
-def generate_accompaniment(midi_file, selected_midi_program, start_time, end_time):
     # Convert MIDI to events
     events = midi_to_events(midi_file.name)
+
+    # Automatically detect the melody program
+    melody_program = detect_melody_program(midi_file)
+
     # Clip events based on the selected time range
     clipped_events = ops.clip(events, start_time, end_time)
+
     # Normalize timeline (start at 0)
     clipped_events = ops.translate(clipped_events, -ops.min_time(clipped_events, seconds=False))
-
-    clipped_events, melody = extract_instruments(clipped_events, [selected_midi_program])
+
+    # Extract the melody instrument automatically
+    clipped_events, melody = extract_instruments(clipped_events, [melody_program])
+
     # Prepare history (first 5 seconds of the segment)
     history = ops.clip(clipped_events, 0, 5, clip_duration=False)
-
+
+    # Generate accompaniment using AMT
     accompaniment = generate(
         model, 0, end_time - start_time, inputs=history, controls=melody, top_p=0.98
     )
+
     # Normalize generated accompaniment
     accompaniment = ops.translate(accompaniment, -ops.min_time(accompaniment, seconds=False))
+
     # Combine accompaniment with melody
     output_events = ops.clip(ops.combine(accompaniment, melody), 0, end_time - start_time, clip_duration=True)
+
     # Convert back to MIDI
     output_midi = "generated_accompaniment.midi"
     mid = events_to_midi(output_events)
@@ -50,17 +83,20 @@ def generate_accompaniment(midi_file, selected_midi_program, start_time, end_time):
 
     return output_midi
 
-
-def process_fn(input_midi, selected_midi_program, start_time, end_time):
-
-    output_midi = generate_accompaniment(input_midi, selected_midi_program, start_time, end_time)
+# === PyHARP Process Function ===
+def process_fn(input_midi, model_choice, start_time, end_time):
+    """Processes the input and runs AMT with selected model."""
+    output_midi = generate_accompaniment(input_midi, model_choice, start_time, end_time)
     return output_midi, LabelList()
 
-
-# Build Gradio interface wrapped in PyHARP
+# === Build Gradio Interface with Model Selection ===
 with gr.Blocks() as demo:
     components = [
-        gr.
+        gr.Dropdown(
+            choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
+            value=MEDIUM_MODEL,
+            label="Select AMT Model (Faster vs. Higher Quality)"
+        ),
         gr.Slider(0, 30, step=1, label="Start Time (seconds)"),
         gr.Slider(0, 30, step=1, label="End Time (seconds)"),
     ]
@@ -72,4 +108,4 @@ with gr.Blocks() as demo:
 )
 
 demo.queue()
-demo.launch(share=True,
+demo.launch(share = True,show_error=True)
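
The hunks start at line 7, so the file's opening imports are never shown; only `from anticipation.tokenize import extract_instruments` is confirmed by the first hunk header. For reference, a sketch of what lines 1-6 presumably contain, inferred from the names the code uses (`gr`, `AutoModelForCausalLM`, `ops`, `generate`, `midi_to_events`, `events_to_midi`) — an assumption, not part of this commit:

import gradio as gr
from transformers import AutoModelForCausalLM   # assumed: used by load_amt_model()

from anticipation import ops                            # assumed: ops.clip / translate / combine / get_instruments
from anticipation.sample import generate                # assumed: the AMT sampling entry point
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.tokenize import extract_instruments   # confirmed by the hunk header

Similarly, the `)` closing the block above `demo.queue()` is unchanged context; in the usual PyHARP pattern it would close a call like `app = build_endpoint(model_card=model_card, components=components, process_fn=process_fn)`, but that wiring sits between the hunks and is not visible in this diff.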
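Note that the new `load_amt_model` re-instantiates the checkpoint on every request, so each use of the Dropdown pays the full `from_pretrained` cost. A minimal sketch of a memoized variant (a suggestion only; `load_amt_model_cached` is a hypothetical name, not in the commit):

from functools import lru_cache

import torch
from transformers import AutoModelForCausalLM

@lru_cache(maxsize=1)  # keeps only the most recently selected checkpoint resident
def load_amt_model_cached(model_choice: str):
    """Load the selected AMT checkpoint once and reuse it until the choice changes."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return AutoModelForCausalLM.from_pretrained(model_choice).to(device)

With `maxsize=1`, switching models evicts the previous entry, which keeps memory bounded to a single cached model at the cost of a reload on each switch.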