"""HARP / Gradio app that generates MIDI from a text prompt with text2midi."""

import functools
import logging
import os
import pickle
import subprocess
import tempfile
from dataclasses import asdict

# Keep `spaces` ahead of torch: ZeroGPU requires it to be imported before CUDA is initialised.
import spaces
import torch
import torch.nn as nn
import gradio as gr
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download
from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

logger = logging.getLogger(__name__)

# Model/artifacts from HF Hub
REPO_ID = "amaai-lab/text2midi"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
TOKENIZER_PATH = hf_hub_download(repo_id=REPO_ID, filename="vocab_remi.pkl")
# Optional, only if you later add WAV preview:
SOUNDFONT_PATH = hf_hub_download(repo_id=REPO_ID, filename="soundfont.sf2")

# Text encoder whose tokenizer produces the conditioning input ids.
T5_TOKENIZER_NAME = "google/flan-t5-base"


# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render ``midi_path`` to a 16 kHz WAV file next to it using FluidSynth.

    Rendering is best-effort: FluidSynth's exit status is not checked. A
    missing ``fluidsynth`` binary is only logged, so the returned file may
    not exist.

    Args:
        midi_path: Path of the MIDI file to render (``.mid`` or ``.midi``).

    Returns:
        Path of the WAV file: same directory and stem as ``midi_path``.
    """
    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    wav_filepath = os.path.join(directory, f"{stem}.wav")
    # Pass an argument list instead of a shell string so that paths with
    # spaces or shell metacharacters reach FluidSynth verbatim.
    cmd = [
        "fluidsynth", "-r", "16000", SOUNDFONT_PATH, "-g", "1.0",
        "--quiet", "--no-shell",
        midi_path, "-T", "wav", "-F", wav_filepath,
    ]
    try:
        subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL)
    except OSError:
        # The shell-based version only printed "command not found" and
        # carried on; keep that best-effort behaviour, but log it.
        logger.warning("fluidsynth could not be started; WAV not rendered", exc_info=True)
    return wav_filepath


@functools.lru_cache(maxsize=1)
def _load_remi_tokenizer():
    """Load the REMI tokenizer pickle once per process."""
    # NOTE(security): pickle.load can execute arbitrary code. This is only
    # acceptable because the file comes from the model's own HF repo
    # (REPO_ID); never point TOKENIZER_PATH at untrusted data.
    with open(TOKENIZER_PATH, "rb") as f:
        return pickle.load(f)


@functools.lru_cache(maxsize=1)
def _load_text_tokenizer():
    """Load the T5 prompt tokenizer once per process."""
    return T5Tokenizer.from_pretrained(T5_TOKENIZER_NAME)


@functools.lru_cache(maxsize=1)
def _load_state_dict():
    """Read the model weights from disk once per process, onto the CPU.

    Keeping the cached copy on the CPU makes it independent of which device
    a request runs on. ``load_state_dict`` copies the tensors into the
    model's parameters wherever those live.
    """
    return torch.load(MODEL_PATH, map_location="cpu")


# Core Text -> MIDI
def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    """Generate a MIDI file from a text prompt.

    Args:
        prompt: Natural-language description of the music.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_len: Maximum number of tokens to generate.
            NOTE(review): the model is built with max_seq_len=1024, but the
            UI allows up to 2000. Confirm that ``Transformer.generate``
            handles larger values.

    Returns:
        Path to the written ``output.mid``. Each call gets its own temporary
        directory, so concurrent requests do not overwrite each other.
    """
    # Decided per call: under @spaces.GPU, CUDA is only visible inside the call.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    r_tokenizer = _load_remi_tokenizer()
    tokenizer = _load_text_tokenizer()

    # The model is built per call because its constructor is bound to `device`;
    # the expensive disk reads (weights, tokenizers) are cached above.
    model = Transformer(
        len(r_tokenizer),  # vocab size
        768,               # d_model
        8,                 # nhead
        2048,              # dim_feedforward
        18,                # nlayers
        1024,              # max_seq_len
        False,             # use_rotary
        8,                 # rotary_dim
        device=device,
    )
    model.load_state_dict(_load_state_dict())
    model.eval()

    # With return_tensors="pt", the tokenizer already returns batched (1, seq) tensors.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)

    generated_midi = r_tokenizer.decode(output[0].tolist())

    # A unique directory per request, under the system temp dir, which Gradio
    # can serve files from. These directories are not cleaned up here.
    out_dir = tempfile.mkdtemp(prefix="text2midi_")
    midi_path = os.path.join(out_dir, "output.mid")
    generated_midi.dump_midi(midi_path)
    return midi_path


# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    """HARP entry point: return ``(labels_json, midi_path)``.

    On failure it returns ``({"message": "Error: ..."}, None)`` instead of
    raising, so HARP can show the message. The traceback is logged.
    """
    if not prompt or not prompt.strip():
        return {"message": "Error: prompt is empty"}, None
    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()  # add MidiLabel entries here if you have metadata
    except Exception as e:  # top-level UI boundary: report instead of crashing
        logger.exception("text2midi generation failed")
        return {"message": f"Error: {e}"}, None
    return asdict(labels), midi_path


# HARP Model Card
model_card = ModelCard(
    name="Text2MIDI (HARP)",
    description="Generate MIDI from a text prompt using a transformer decoder conditioned on T5 embeddings.",
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    tags=["text-to-music", "midi", "generation"],
)

# Gradio + HARP UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")

    # Inputs
    prompt_in = gr.Textbox(label="Prompt").harp_required(True)
    temperature_in = gr.Slider(
        minimum=0.8, maximum=1.1, value=0.9, step=0.1,
        label="Temperature", interactive=True,
    )
    maxlen_in = gr.Number(value=500, label="Max Length (tokens)", minimum=64, maximum=2000, step=64)

    # Outputs (JSON FIRST for HARP, then MIDI)
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")

    # Build HARP endpoint
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn,
    )

# Launch App
demo.launch(share=True, show_error=True, debug=True)