import spaces  # Enables ZeroGPU on Hugging Face
import gradio as gr
import torch

from dataclasses import asdict
from transformers import AutoModelForCausalLM

from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops

from mido import MidiFile  # parse MIDI explicitly to avoid the 'str' .time error

from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

# Model choices
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"

# Model card (new pyharp API)
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
        "Use the controls to choose the model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"],
)

# Cache loaded models so repeated requests don't reload weights
_model_cache = {}


def load_amt_model(model_choice: str):
    """Loads and caches the AMT model inside the worker process (same behavior as the old app)."""
    if model_choice in _model_cache:
        return _model_cache[model_choice]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_choice == LARGE_MODEL:
        print(f"Loading {LARGE_MODEL} (low_cpu_mem_usage, fp16 on CUDA if available)...")
        model = AutoModelForCausalLM.from_pretrained(
            LARGE_MODEL,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
        ).to(device)
    else:
        print(f"Loading {model_choice} ...")
        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)

    _model_cache[model_choice] = model
    return model


def find_melody_track(mid, debug=False):
    """Heuristically picks the melody track: highest mean pitch, ties broken by note density."""
    track_stats = []
    for i, track in enumerate(mid.tracks):
        pitches, times = [], []
        current_time = 0
        for msg in track:
            current_time += getattr(msg, "time", 0)
            if msg.type == "note_on" and msg.velocity > 0:
                pitches.append(msg.note)
                times.append(current_time)
        if pitches:
            mean_pitch = sum(pitches) / len(pitches)
            span = (max(times) - min(times)) or 1  # avoid division by zero on one-note tracks
            density = len(pitches) / span
            polyphony = len(set(pitches)) / len(pitches)
            track_stats.append((i, mean_pitch, len(pitches), density, polyphony))

    if not track_stats:
        if debug:
            print("No notes detected in any track.")
        return 0

    # Highest mean pitch first; break ties with the densest track.
    melody_idx = sorted(track_stats, key=lambda x: (-x[1], -x[3]))[0][0]
    return melody_idx


def get_program_number(mid, track_index):
    """Returns the first program_change number on the given track, or None if there is none."""
    for msg in mid.tracks[track_index]:
        if msg.type == "program_change":
            return msg.program
    return None


def auto_extract_melody(mid, debug=False):
    """Splits the tokenized MIDI into (accompaniment events, melody events)."""
    events = midi_to_events(mid)
    melody_track = find_melody_track(mid, debug=debug)
    melody_program = get_program_number(mid, melody_track)
    if debug:
        print(f"Melody track: {melody_track} | Program: {melody_program}")

    if melody_program is not None:
        # extract_instruments removes the melody program's tokens from the
        # event stream and returns them separately.
        events, melody = extract_instruments(events, [melody_program])
    else:
        if debug:
            print("No program number found; using all events as melody.")
        melody = events

    return events, melody
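
# Example use of the helpers above (a sketch; "song.mid" is a hypothetical path):
#
#   mid = MidiFile("song.mid")
#   accompaniment_events, melody_events = auto_extract_melody(mid, debug=True)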

@spaces.GPU  # Core generation runs on the ZeroGPU worker
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """
    Generates accompaniment for the entire MIDI input, conditioned on the
    user-selected history length.

    Fix: parse the MIDI with mido.MidiFile before calling midi_to_events, to
    avoid the 'str' .time error raised when a raw path is passed in.
    """
    model = load_amt_model(model_choice)

    # Parse the MIDI correctly, then convert it to events.
    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")

    # Automatically detect and extract the melody.
    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        print("No melody detected; using all events.")
        melody = all_events

    # History portion: the first `history_length` seconds of the input.
    history = ops.clip(all_events, 0, history_length, clip_duration=False)

    # Total duration: the longer of the file's length and the tokenized events.
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))

    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=history,
        controls=melody,
        top_p=0.95,
        debug=False,
    )

    # Combine accompaniment + melody and clip to the total duration.
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True,
    )

    # Save the MIDI.
    output_path = "generated_accompaniment_huggingface.mid"
    mid_out = events_to_midi(output_events)
    mid_out.save(output_path)

    return output_path, None


# HARP process fn (JSON first)
def process_fn(input_midi_path, model_choice, history_length):
    """
    Returns (JSON, MIDI filepath) to satisfy the HARP client's expectation
    that the 0th output item is an object.
    """
    output_path, error_message = generate_accompaniment(
        input_midi_path,
        model_choice,
        float(history_length),
    )

    if error_message:
        # JSON first, then no file.
        return {"message": error_message}, None

    labels = LabelList()  # add label entries here if desired

    # JSON first, then the MIDI filepath.
    return asdict(labels), output_path


# Gradio + HARP UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")

    # Inputs
    input_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Input MIDI File",
        type="filepath",
    ).harp_required(True)

    model_dropdown = gr.Dropdown(
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
        label="Select AMT Model (Faster vs. Higher Quality)",
    )

    history_slider = gr.Slider(
        minimum=1,
        maximum=10,
        step=1,
        value=5,
        label="Select History Length (seconds)",
    )

    # Outputs (JSON first)
    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(
        file_types=[".mid", ".midi"],
        label="Generated MIDI Output",
        type="filepath",
    )

    # Build the HARP endpoint (new signature)
    _ = build_endpoint(
        model_card=model_card,
        input_components=[
            input_midi,
            model_dropdown,
            history_slider,
        ],
        output_components=[
            output_labels,  # JSON first
            output_midi,    # MIDI second
        ],
        process_fn=process_fn,
    )

# Launch the app
demo.launch(share=True, show_error=True, debug=True)
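
# Quick local smoke test (a sketch, not part of the HARP flow; assumes a file
# "input.mid" exists in the working directory, a hypothetical filename).
# Comment out demo.launch(...) above before running, since launch() blocks:
#
#   out_path, err = generate_accompaniment("input.mid", SMALL_MODEL, 5.0)
#   print("wrote:", out_path, "| error:", err)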