File size: 3,792 Bytes
4578e45
a9f7d79
4578e45
 
 
 
 
 
 
 
e69ad44
fd23462
 
 
84d0752
9360ee5
4578e45
 
2e151e1
4578e45
 
fd23462
4578e45
 
 
18e97fe
9360ee5
eb9bb80
18e97fe
 
 
 
 
 
 
 
 
 
11c3cbf
eb9bb80
a9f7d79
8f765e9
 
 
 
eb9bb80
 
8f765e9
 
 
 
 
 
 
 
 
 
eb9bb80
 
8f765e9
eb9bb80
 
 
 
 
 
 
 
 
8f765e9
eb9bb80
8f765e9
eb9bb80
 
 
d9f0ff3
eb9bb80
 
8f765e9
 
 
 
 
 
 
 
eb9bb80
 
 
 
 
 
 
 
 
 
 
33e42c6
8f765e9
eb9bb80
 
 
 
 
 
 
 
 
d9f0ff3
eb9bb80
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import spaces  # Enables ZeroGPU on Hugging Face
from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation import ops
from anticipation.tokenize import extract_instruments
import torch
from pyharp import *

# Model choices: Anticipatory Music Transformer checkpoints hosted on the
# Hugging Face Hub, in increasing size (speed vs. quality trade-off).
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"

# === Model Card ===
# Metadata displayed by the PyHARP host for this endpoint.
# NOTE(review): midi_in/midi_out presumably declare that the endpoint consumes
# and produces MIDI files — confirm against the pyharp ModelCard definition.
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description="Using Anticipatory Music Transformer (AMT) to generate accompaniment for a given MIDI file with selected melody.",
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"],
    midi_in=True,
    midi_out=True
)

# Per-process cache of loaded models, keyed by checkpoint name (see
# load_amt_model); avoids re-downloading/re-initializing a checkpoint.
model_cache = {}

def load_amt_model(model_choice):
    """Return the AMT model for ``model_choice``, loading it on first use.

    Loaded models are memoized in the module-level ``model_cache`` so each
    checkpoint is fetched and moved to the device at most once per worker
    process.
    """
    cached = model_cache.get(model_choice)
    if cached is not None:
        return cached

    # Prefer the GPU whenever one is visible to this process.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded = AutoModelForCausalLM.from_pretrained(model_choice).to(target_device)

    model_cache[model_choice] = loaded
    return loaded



@spaces.GPU
def generate_accompaniment(midi_file, model_choice, selected_midi_program, history_length):
    """Generate accompaniment for the entire MIDI input.

    The first ``history_length`` seconds of the non-melody events serve as
    conditioning context; the melody extracted via ``selected_midi_program``
    is supplied as anticipated controls for the remainder of the piece.
    Returns ``(output_midi_path, None)`` on success or ``(None, message)``
    when no melody events match the selected program.
    """
    model = load_amt_model(model_choice)

    # Tokenize the uploaded MIDI and measure its total length in seconds.
    all_events = midi_to_events(midi_file.name)
    total_time = round(ops.max_time(all_events, seconds=True))

    # Split out the melody track chosen by MIDI program number; the remainder
    # stays in all_events.
    all_events, melody = extract_instruments(all_events, [selected_midi_program])

    if not melody:
        return None, "⚠️ Please select a valid MIDI program that contains events."

    # Keep only the first history_length seconds as the generation prompt.
    prompt = ops.clip(all_events, 0, history_length, clip_duration=False)

    # Sample accompaniment from the end of the prompt through the full
    # remaining duration, conditioned on the melody as controls.
    accompaniment = generate(
        model,
        history_length,
        total_time,
        inputs=prompt,
        controls=melody,
        top_p=0.95,
        debug=False,
    )

    # Merge the sampled accompaniment back with the melody and trim to the
    # original piece duration.
    merged = ops.combine(accompaniment, melody)
    output_events = ops.clip(merged, 0, total_time, clip_duration=True)

    # Write the result to disk and hand the path back to the caller.
    output_midi = "generated_accompaniment_huggingface.mid"
    events_to_midi(output_events).save(output_midi)

    return output_midi, None


def process_fn(input_midi, model_choice, selected_midi_program, history_length):
    """HARP entry point: run AMT accompaniment generation on the input MIDI."""
    output_midi, error_message = generate_accompaniment(
        input_midi, model_choice, selected_midi_program, history_length
    )

    # Surface generation failures to the UI as an error payload.
    if error_message:
        return None, {"message": error_message}

    # No labels are produced; return an empty list alongside the MIDI path.
    return output_midi, LabelList()


# === Build HARP gradio endpoint ===
with gr.Blocks() as demo:
    # UI controls, in positional order matching process_fn's parameters
    # (after the MIDI input that HARP supplies).
    components = [
        gr.Dropdown(
            choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
            value=MEDIUM_MODEL,
            label="Select AMT Model (Faster vs. Higher Quality)"
        ),
        gr.Slider(0, 127, step=1, value=1, label="Select Melody Instrument (MIDI Program Number)"),
        gr.Slider(1, 10, step=1, value=5, label="Select History Length (seconds)")
    ]

    # Wrap the model card, controls, and processing function in a PyHARP
    # endpoint so HARP hosts can discover and call this app.
    app = build_endpoint(
        model_card=model_card,
        components=components,
        process_fn=process_fn
    )

# Launch the PyHARP app (public share link, errors shown in the UI).
demo.launch(share=True, show_error=True, debug=True)