import gradio as gr
import spaces  # Enables ZeroGPU on Hugging Face
from transformers import AutoModelForCausalLM
from anticipation.sample import generate
from anticipation.convert import events_to_midi, midi_to_events
from anticipation import ops
from anticipation.tokenize import extract_instruments
import torch
from pyharp import *
from safetensors.torch import load_file
import os
# Model choices
SMALL_MODEL = "stanford-crfm/music-small-800k"
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
LARGE_MODEL = "stanford-crfm/music-large-800k"
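# Larger checkpoints generally produce higher-quality accompaniment but load and sample more slowly.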
# === Model Card ===
model_card = ModelCard(
    name="Anticipatory Music Transformer",
    description="Uses the Anticipatory Music Transformer (AMT) to generate accompaniment for a given MIDI file, conditioned on a selected melody instrument.",
    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
    tags=["midi", "generation", "accompaniment"],
    midi_in=True,
    midi_out=True
)
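# Cache of loaded models, keyed by checkpoint name, so repeated requests in the same worker reuse the weights.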
model_cache = {}
def load_amt_model(model_choice):
    """Loads and caches the AMT model inside the worker process."""
    if model_choice in model_cache:
        return model_cache[model_choice]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if model_choice == LARGE_MODEL:
        # Large model uses safetensors
        model_dir = "./tmp_music_large"
        os.makedirs(model_dir, exist_ok=True)
        print(f"Loading {LARGE_MODEL} from safetensors format...")
        model = AutoModelForCausalLM.from_pretrained(
            LARGE_MODEL,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
    else:
        # Small and medium use standard PyTorch .bin format
        print(f"Loading {model_choice} from standard format...")
        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)

    model_cache[model_choice] = model
    return model
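# Illustrative local check (hypothetical, not part of the Space): a second call with the
# same checkpoint should return the cached instance rather than reloading the weights.
#   m1 = load_amt_model(MEDIUM_MODEL)
#   m2 = load_amt_model(MEDIUM_MODEL)
#   assert m1 is m2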
def generate_accompaniment(midi_file, model_choice, selected_midi_program, history_length):
    """Generates accompaniment for the entire MIDI input, conditioned on the user-selected history length."""
    model = load_amt_model(model_choice)
    events = midi_to_events(midi_file.name)
    total_time = round(ops.max_time(events, seconds=True))

    # Extract melody line using the selected MIDI program number
    events, melody = extract_instruments(events, [selected_midi_program])
    if not melody:
        return None, "⚠️ Please select a valid MIDI program that contains events."

    history = ops.clip(events, 0, history_length, clip_duration=False)

    # Generate accompaniment for the remaining duration
    accompaniment = generate(
        model,
        history_length,  # Start generating after user-defined history length
        total_time,      # Generate for the full remaining duration
        inputs=history,
        controls=melody,
        top_p=0.95,
        debug=False
    )

    # Combine the accompaniment with the melody
    output_events = ops.clip(ops.combine(accompaniment, melody), 0, total_time, clip_duration=True)

    # Convert back to MIDI
    output_midi = "generated_accompaniment_huggingface.mid"
    mid = events_to_midi(output_events)
    mid.save(output_midi)
    return output_midi, None
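# Minimal local sketch (hypothetical file path, outside the HARP UI). Gradio passes a file
# object exposing a .name attribute, so a simple stand-in works for quick testing:
#   from types import SimpleNamespace
#   midi_path, err = generate_accompaniment(
#       SimpleNamespace(name="example_input.mid"), MEDIUM_MODEL,
#       selected_midi_program=0, history_length=5)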
def process_fn(input_midi, model_choice, selected_midi_program, history_length):
    """Processes the input and runs AMT to generate accompaniment for the full MIDI file."""
    output_midi, error_message = generate_accompaniment(input_midi, model_choice, selected_midi_program, history_length)
    if error_message:
        return None, {"message": error_message}
    output_labels = LabelList()
    return output_midi, output_labels
# === Build HARP Gradio endpoint ===
with gr.Blocks() as demo:
    components = [
        gr.Dropdown(
            choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
            value=MEDIUM_MODEL,
            label="Select AMT Model (Faster vs. Higher Quality)"
        ),
        gr.Slider(0, 127, step=1, value=1, label="Select Melody Instrument (MIDI Program Number)"),
        gr.Slider(1, 10, step=1, value=5, label="Select History Length (seconds)")
    ]

    # Wrap in PyHARP
    app = build_endpoint(
        model_card=model_card,
        components=components,
        process_fn=process_fn
    )

# Launch PyHARP App
demo.launch(share=True, show_error=True, debug=True)