File size: 7,408 Bytes
6d2a4fa
7166cfc
3caedfa
4ec0a5d
 
 
 
 
 
e0c2e8c
4ec0a5d
3caedfa
6d2a4fa
 
 
 
4ec0a5d
3caedfa
4ec0a5d
 
 
 
6d2a4fa
3caedfa
4ec0a5d
 
e0c2e8c
251510f
 
 
7599353
e0c2e8c
4ec0a5d
6d2a4fa
4ec0a5d
 
6d2a4fa
3caedfa
 
55b78cb
6d2a4fa
3caedfa
 
 
6d2a4fa
4ec0a5d
 
6d2a4fa
 
 
 
 
 
 
 
 
 
 
3caedfa
6d2a4fa
 
e0c2e8c
 
 
 
 
 
251510f
 
e0c2e8c
 
251510f
 
 
 
 
 
e0c2e8c
 
 
251510f
 
e0c2e8c
251510f
e0c2e8c
251510f
e0c2e8c
 
251510f
 
 
 
 
e0c2e8c
 
 
 
251510f
 
e0c2e8c
251510f
 
e0c2e8c
251510f
486dcc8
251510f
486dcc8
 
 
 
251510f
e0c2e8c
251510f
 
e0c2e8c
 
 
7166cfc
6d2a4fa
8186da5
6d2a4fa
 
 
 
 
 
3caedfa
6d2a4fa
251510f
4ec0a5d
e0c2e8c
 
 
 
 
4ec0a5d
6d2a4fa
0debc61
251510f
 
 
 
 
4ec0a5d
 
 
251510f
 
0debc61
 
4ec0a5d
 
 
 
6d2a4fa
4ec0a5d
 
 
 
 
 
 
6d2a4fa
 
 
 
4ec0a5d
6d2a4fa
4ec0a5d
 
3caedfa
8186da5
4ec0a5d
3caedfa
4ec0a5d
6d2a4fa
4ec0a5d
 
6d2a4fa
4ec0a5d
 
 
3caedfa
 
4ec0a5d
3caedfa
 
 
4ec0a5d
6d2a4fa
4ec0a5d
3caedfa
4ec0a5d
3caedfa
4ec0a5d
 
6d2a4fa
4ec0a5d
 
 
 
 
 
6d2a4fa
4ec0a5d
 
 
6d2a4fa
251510f
4ec0a5d
 
3caedfa
 
4ec0a5d
 
6d2a4fa
4ec0a5d
 
 
3caedfa
4ec0a5d
 
 
 
 
6d2a4fa
4ec0a5d
 
3caedfa
 
4ec0a5d
 
 
 
3caedfa
6d2a4fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import os
import spaces  # Enables ZeroGPU on Hugging Face
import gradio as gr
import torch
from dataclasses import asdict

from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops
from mido import MidiFile  # parse MIDI explicitly to avoid .time error

from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList


# Model Choices
# Hugging Face Hub checkpoint ids for the Anticipatory Music Transformer,
# in increasing order of size / quality (and inference cost).
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"


# Model Card (new pyharp)
# Metadata shown by the HARP client for this endpoint.
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the controls to choose model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"],
)


# Model cache
# Already-loaded models, keyed by checkpoint id, so each one loads at most once
# per worker process.
_model_cache = {}

def load_amt_model(model_choice: str):
    """Loads and caches the AMT model inside the worker process (same behavior as old app).

    Args:
        model_choice: Hub checkpoint id (one of SMALL/MEDIUM/LARGE_MODEL).

    Returns:
        The loaded ``AutoModelForCausalLM`` moved to CUDA when available.
    """
    cached = _model_cache.get(model_choice)
    if cached is not None:
        return cached

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if model_choice == LARGE_MODEL:
        # The large checkpoint is memory-hungry: stream weights in and use
        # fp16 when a GPU is present.
        print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
        model = AutoModelForCausalLM.from_pretrained(
            LARGE_MODEL,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device)
    else:
        print(f"Loading {model_choice} ...")
        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)

    _model_cache[model_choice] = model
    return model

def find_melody_program(mid, debug=False):
    """Heuristically pick the track index most likely to carry the melody.

    Each track with notes is scored by mean pitch and note density (notes per
    tick of its onset span); the highest-pitched track wins, with density as
    the tie-breaker. Returns 0 when no track contains any note-on events.
    """
    stats = []
    for idx, track in enumerate(mid.tracks):
        elapsed = 0
        pitches = []
        onsets = []
        for msg in track:
            elapsed += getattr(msg, "time", 0)
            if msg.type == "note_on" and msg.velocity > 0:
                pitches.append(msg.note)
                onsets.append(elapsed)
        if not pitches:
            continue
        avg_pitch = sum(pitches) / len(pitches)
        span = (max(onsets) - min(onsets)) or 1  # guard single-onset tracks
        notes_per_tick = len(pitches) / span
        distinct_ratio = len(set(pitches)) / len(pitches)
        stats.append((idx, avg_pitch, len(pitches), notes_per_tick, distinct_ratio))

    if not stats:
        if debug:
            print("No notes detected in any track.")
        return 0

    # Highest mean pitch first; density breaks ties (matches the original
    # sort on (-mean_pitch, -density)).
    best = max(stats, key=lambda s: (s[1], s[3]))
    return best[0]


def get_program_number(mid, track_index):
    """Return the program of the first program_change on the track, or None."""
    return next(
        (msg.program for msg in mid.tracks[track_index] if msg.type == "program_change"),
        None,
    )


def auto_extract_melody(mid, debug=False):
    """Split a parsed MIDI file into (accompaniment events, melody events).

    The melody track is chosen heuristically via find_melody_program; its
    program number (if any) is then used to pull the melody out of the full
    event stream with extract_instruments. When no program number is found,
    the full event list doubles as the melody.
    """
    events = midi_to_events(mid)
    melody_track = find_melody_program(mid, debug=debug)
    melody_program = get_program_number(mid, melody_track)

    if debug:
        print(f"Melody Track: {melody_track} | Program: {melody_program}")

    if melody_program is not None:
        all_events_copy = events.copy()
        events, melody = extract_instruments(events, [melody_program])

        # NOTE(review): this loop re-adds melody-program events back into the
        # accompaniment stream. It only fires if individual events expose a
        # .program attribute — if midi_to_events yields plain token ints (as
        # the anticipation library's tokenized format suggests), hasattr is
        # always False and this is a no-op. Confirm intent.
        for e in all_events_copy:
            if hasattr(e, "program") and e.program == melody_program:
                events.append(e)
    else:
        if debug:
            print("No program number found; using all events as melody.")
        melody = events

    return events, melody

@spaces.GPU
# Core generation 
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """
    Generates accompaniment for the entire MIDI input, conditioned on user-selected history length.

    Args:
        midi_path: filesystem path to the input MIDI file.
        model_choice: Hub checkpoint id of the AMT model to use.
        history_length: seconds of the input kept as fixed context; generation
            starts at this time.

    Returns:
        (output_midi_path, error_message) — error_message is always None on
        this path; the tuple shape matches what process_fn expects.

    FIX: parse MIDI with mido.MidiFile before midi_to_events to avoid 'str' .time error.
    """
    model = load_amt_model(model_choice)

    # Parse MIDI correctly, then convert to events
    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

    # Automatically detect and extract melody
    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        print("No melody detected; using all events")
        melody = all_events

    # History portion: events inside the context window, passed to the model
    # untouched. (Removed an unused ops.max_time(history) computation here.)
    history = ops.clip(all_events, 0, history_length, clip_duration=False)

    # Generate up to whichever is longer: the file's wall-clock length or the
    # latest event time (they can disagree for truncated/odd files).
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))

    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=history,
        controls=melody,
        top_p=0.95,
        debug=False
    )

    # Combine accompaniment + melody and clip
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True
    )

    # Save MIDI 
    output_midi = "generated_accompaniment_huggingface.mid"
    mid_out = events_to_midi(output_events)
    mid_out.save(output_midi)

    return output_midi, None


# HARP process fn (JSON FIRST)
def process_fn(input_midi_path, model_choice, history_length):
    """
    Returns (JSON, MIDI filepath) to satisfy HARP client's expectation that the 0th item is an object.
    """
    midi_out, err = generate_accompaniment(
        input_midi_path, model_choice, float(history_length)
    )

    if err:
        # JSON first, then no file
        return {"message": err}, None

    # JSON first, then MIDI filepath; add label entries to LabelList if desired
    return asdict(LabelList()), midi_out

# Gradio + HARP UI
# Builds the Gradio app and registers it as a HARP endpoint. Component order
# must match process_fn's argument list and its (JSON, MIDI) return tuple.
with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")

    # Inputs
    # NOTE(review): .harp_required looks like a pyharp extension on Gradio
    # components — confirm it exists in the pinned pyharp version.
    input_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Input MIDI File",
        type="filepath",
    ).harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
        label="Select AMT Model (Faster vs. Higher Quality)"
    )

    # Seconds of context forwarded as generate_accompaniment's history_length.
    history_slider = gr.Slider(
        minimum=1, maximum=10, step=1, value=5,
        label="Select History Length (seconds)"
    )

    # Outputs (JSON FIRST)
    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Generated MIDI Output",
        type="filepath",
    )

    # Build HARP endpoint (new signature)
    _ = build_endpoint(
        model_card=model_card,
        input_components=[
            input_midi,
            model_dropdown,
            history_slider
        ],
        output_components=[
            output_labels,  # JSON FIRST
            output_midi     # MIDI SECOND
        ],
        process_fn=process_fn
    )

# Launch App
demo.launch(share=True, show_error=True, debug=True)