import spaces  # enables ZeroGPU on Hugging Face
import gradio as gr
import torch

from dataclasses import asdict
from mido import MidiFile, tempo2bpm
from transformers import AutoModelForCausalLM

from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation.tokenize import extract_instruments
from anticipation import ops

from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

# ---------------------------------------------------------
# Model Choices
# ---------------------------------------------------------
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"

# ---------------------------------------------------------
# Model Card (for HARP)
# ---------------------------------------------------------
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description=(
        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
        "Input: a MIDI file with a short accompaniment followed by a melody line. "
        "Output: a new MIDI file with extended accompaniment matching the melody. "
        "Use the controls to choose the model size and how much of the song is used as context."
    ),
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"],
)

# ---------------------------------------------------------
# Model Cache Loader
# ---------------------------------------------------------
_model_cache = {}

def load_amt_model(model_choice: str):
    """Load and cache the AMT model inside the worker process."""
    if model_choice in _model_cache:
        return _model_cache[model_choice]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading {model_choice} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_choice,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)
    _model_cache[model_choice] = model
    return model
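
# Usage sketch (hypothetical warm-up call; the constants are defined above):
# the first call pays the download/load cost, later calls hit _model_cache.
#   model = load_amt_model(MEDIUM_MODEL)  # loads onto GPU (or CPU fallback)
#   model = load_amt_model(MEDIUM_MODEL)  # cache hit, returns immediately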

# ---------------------------------------------------------
# Melody Detection (Auto Program)
# ---------------------------------------------------------
def find_melody_program(mid, debug=False):
    """Detect the melody track's program number using pitch, density, and duration heuristics."""
    track_stats = []
    total_duration = 0
    for i, track in enumerate(mid.tracks):
        pitches, times = [], []
        current_time = 0
        current_program = None
        track_note_count = 0
        for msg in track:
            # Accumulate delta time for every message; skipping non-note
            # messages would let the note timestamps drift.
            current_time += msg.time
            if msg.type == "program_change":
                current_program = msg.program
            elif msg.type == "note_on" and msg.velocity > 0:
                pitches.append(msg.note)
                times.append(current_time)
                track_note_count += 1
                if track_note_count >= 100:  # sample at most 100 notes per track
                    break
        if not pitches:
            continue
        track_duration = max(times) - min(times)
        total_duration = max(total_duration, current_time)
        mean_pitch = sum(pitches) / len(pitches)
        polyphony = len(set(pitches)) / len(pitches)
        track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, track_duration))
    if not track_stats:
        return None, False
    # Compute coverage after all tracks are scanned, so every track is
    # measured against the same total duration (the running max inside the
    # loop would shortchange tracks scanned before the longest one).
    track_stats = [
        (i, pitch, notes, prog, poly, duration / total_duration if total_duration > 0 else 0)
        for (i, pitch, notes, prog, poly, duration) in track_stats
    ]
    if len(track_stats) == 1:
        prog = track_stats[0][3]
        if debug:
            print(f"Single-track MIDI detected; using program {prog if prog is not None else 'None'}")
        return prog, prog is not None
    candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
    has_valid_programs = len(candidates) > 0
    if not candidates:
        candidates = track_stats
    max_notes = max(t[2] for t in candidates)
    max_pitch = max(t[1] for t in candidates)
    min_pitch = min(t[1] for t in candidates)
    pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
    best_score = -1
    best_program = None
    for idx, pitch, notes, prog, poly, coverage in candidates:
        pitch_norm = (pitch - min_pitch) / pitch_span
        notes_norm = notes / max_notes
        score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)
        if poly < 0.15:
            score *= 0.95  # very repetitive pitch content looks like an ostinato
        if 55 <= pitch <= 75:
            score *= 1.1  # typical melody register
        if notes >= 30:
            score *= 1.05
        if coverage > 0.7:
            score *= 1.15  # melodies tend to span most of the piece
        if score > best_score:
            best_score = score
            best_program = prog
    return best_program, has_valid_programs
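
# Worked example of the score (hypothetical numbers): with pitch_norm = 1.0,
# notes_norm = 0.8, and coverage = 0.9, the base score is
#   (1.0 * 0.35) + (0.8 * 0.35) + (0.9 * 0.30) = 0.90;
# a mean pitch of 72 (*1.1), 80 sampled notes (*1.05), and coverage > 0.7
# (*1.15) lift it to roughly 1.20, so a high, busy, long-running track wins.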

def auto_extract_melody(mid, debug=False):
    """Extract melody events from a MIDI object (optimized for direct input)."""
    events = midi_to_events(mid)
    melody_program, has_valid_program = find_melody_program(mid, debug=debug)
    if not has_valid_program or melody_program is None or melody_program == 0:
        if debug:
            print("No valid program changes found; using all events as melody.")
        return events, events
    events, melody = extract_instruments(events, [melody_program])
    if len(melody) == 0:
        if debug:
            print("No melody events found for the detected program; reverting to all events.")
        return events, events
    if debug:
        print(f"Extracted {len(melody)} melody events from program {melody_program}")
    return events, melody
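
# Fallback behavior: when program detection fails, the same event list is
# returned for both values, so downstream code treats the whole file as the
# melody. Sketch (hypothetical file name, mirroring the call made below):
#   all_events, melody = auto_extract_melody(MidiFile("song.mid"), debug=True)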

# ---------------------------------------------------------
# Core Generation Logic
# ---------------------------------------------------------
@spaces.GPU
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
    """Generate accompaniment conditioned on context history and melody."""
    model = load_amt_model(model_choice)
    mid = MidiFile(midi_path)
    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
    all_events, melody = auto_extract_melody(mid, debug=True)
    if len(melody) == 0:
        melody = all_events
    mid_time = mid.length or 0
    ops_time = ops.max_time(all_events, seconds=True)
    total_time = round(max(mid_time, ops_time))
    # Everything before history_length is fixed context (all instruments);
    # the melody after it is the control sequence the model anticipates.
    history = ops.clip(all_events, 0, history_length, clip_duration=False)
    melody_future = ops.clip(melody, history_length, total_time, clip_duration=False)
    print(f"Generating from {history_length:.2f}s → {total_time:.2f}s "
          f"(duration = {total_time - history_length:.2f}s)")
    accompaniment = generate(
        model,
        start_time=history_length,
        end_time=total_time,
        inputs=history,
        controls=melody_future,
        top_p=0.95,
        debug=False,
    )
    output_events = ops.clip(
        ops.combine(accompaniment, melody),
        0,
        total_time,
        clip_duration=True,
    )
    output_midi = "generated_accompaniment_huggingface.mid"
    events_to_midi(output_events).save(output_midi)
    return output_midi, None
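
# Direct-call sketch (hypothetical file; the second return slot carries an
# error message for process_fn and is None on success):
#   path, err = generate_accompaniment("vamp.mid", MEDIUM_MODEL, 5.0)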

# ---------------------------------------------------------
# HARP process_fn: tempo-aware bar→seconds conversion
# ---------------------------------------------------------
def process_fn(input_midi_path, model_choice, history_length, use_bars):
    """Convert bars to seconds (tempo-aware) before generation."""
    if use_bars:
        BEATS_PER_BAR = 4
        bpm = 120  # fallback when the file has no tempo event
        try:
            mid = MidiFile(input_midi_path)
            found_tempo = False
            for tr in mid.tracks:
                for msg in tr:
                    if msg.type == "set_tempo":
                        bpm = round(tempo2bpm(msg.tempo))
                        found_tempo = True
                        break
                if found_tempo:
                    # Stop at the first tempo event; a bare inner break alone
                    # would keep scanning later tracks and overwrite bpm.
                    break
        except Exception:
            pass
        seconds_per_bar = (60.0 / bpm) * BEATS_PER_BAR
        history_length = history_length * seconds_per_bar
        print(f"[INFO] Converted to {history_length:.2f}s from bars @ {bpm} BPM")
    output_midi, error_message = generate_accompaniment(
        input_midi_path,
        model_choice,
        float(history_length),
    )
    if error_message:
        return {"message": error_message}, None
    return asdict(LabelList()), output_midi
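
# Worked conversion: 4 bars at 90 BPM gives (60 / 90) * 4 = 2.667 s per bar,
# so history_length = 4 * 2.667 ≈ 10.67 s of context.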

# ---------------------------------------------------------
# Gradio + HARP UI
# ---------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎼 Anticipatory Music Transformer")
    input_midi = gr.File(
        file_types=[".mid", ".midi"], label="Input MIDI File", type="filepath"
    ).harp_required(True)
    model_dropdown = gr.Dropdown(
        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
        value=MEDIUM_MODEL,
        label="Select AMT Model (Faster vs. Higher Quality)",
    )
    with gr.Row():
        history_slider = gr.Slider(
            minimum=1, maximum=10, step=1, value=5,
            label="Select History Length (seconds)",
        )
        use_bars = gr.Checkbox(
            value=False,
            label="Use Musical Bars Instead of Seconds",
            info="If enabled, context length is interpreted as bars based on the MIDI tempo.",
        )

    def get_midi_tempo(midi_path):
        try:
            mid = MidiFile(midi_path)
            for track in mid.tracks:
                for msg in track:
                    if msg.type == "set_tempo":
                        return round(tempo2bpm(msg.tempo))
        except Exception:
            pass
        return 120
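
    # Note: this returns the first set_tempo event found (or a 120 BPM
    # default); files with mid-song tempo changes are approximated by their
    # opening tempo.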

    BEATS_PER_BAR = 4

    def bars_to_seconds(bars, bpm, beats_per_bar=BEATS_PER_BAR):
        return bars * beats_per_bar * (60.0 / bpm)
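
    # Example: bars_to_seconds(4, 120) = 4 * 4 * 0.5 = 8.0 s, matching the
    # 4/4 assumption baked into BEATS_PER_BAR.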

    def toggle_label_and_range(use_bars):
        if use_bars:
            return gr.update(label="Select History Length (bars)", minimum=2, maximum=8, step=2, value=4)
        else:
            return gr.update(label="Select History Length (seconds)", minimum=1, maximum=10, step=1, value=5)

    def update_bar_label(history_value, midi_path, use_bars):
        if not use_bars:
            return gr.update(label="Select History Length (seconds)")
        bpm = get_midi_tempo(midi_path)
        approx_sec = bars_to_seconds(history_value, bpm)
        return gr.update(label=f"Select History Length ({history_value} bars ≈ {approx_sec:.1f}s @ {bpm} BPM)")

    use_bars.change(fn=toggle_label_and_range, inputs=use_bars, outputs=history_slider)
    history_slider.change(
        fn=update_bar_label,
        inputs=[history_slider, input_midi, use_bars],
        outputs=history_slider,
        queue=False,
    )
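
    # Wiring: toggling the checkbox re-ranges the slider for bars vs. seconds;
    # moving the slider re-labels it with the seconds equivalent (tempo read
    # from the uploaded file) whenever bars mode is active.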

    output_labels = gr.JSON(label="Labels / Metadata")
    output_midi = gr.File(file_types=[".mid", ".midi"], label="Generated MIDI Output", type="filepath")

    _ = build_endpoint(
        model_card=model_card,
        input_components=[input_midi, model_dropdown, history_slider, use_bars],
        output_components=[output_labels, output_midi],
        process_fn=process_fn,
    )

# ---------------------------------------------------------
# Launch App
# ---------------------------------------------------------
demo.launch(share=True, show_error=True, debug=True)