"""text2midi HARP Gradio Space — generate MIDI files from text prompts."""
import os
import spaces
import pickle
import subprocess
import torch
import torch.nn as nn
import gradio as gr
from dataclasses import asdict
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download
from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# Model/artifacts from HF Hub (downloaded at import time; hf_hub_download
# caches locally, so repeated startups don't re-fetch).
REPO_ID = "amaai-lab/text2midi"
# Trained transformer weights (a torch state_dict).
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
# Pickled REMI vocabulary/tokenizer object used to decode generated tokens.
TOKENIZER_PATH = hf_hub_download(repo_id=REPO_ID, filename="vocab_remi.pkl")
# Optional, only if you later add WAV preview: SoundFont for fluidsynth.
SOUNDFONT_PATH = hf_hub_download(repo_id=REPO_ID, filename="soundfont.sf2")
# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render a MIDI file to a 16 kHz WAV next to it using fluidsynth.

    Args:
        midi_path: Path to the source ``.mid`` file.

    Returns:
        Path of the expected ``.wav`` output. Rendering is best-effort:
        if fluidsynth is missing or fails, the path is returned anyway
        (matching the original ``check=False`` behavior).
    """
    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")
    # Argument list + shell=False: safe even when paths contain spaces or
    # shell metacharacters (the old interpolated shell string would break,
    # and was an injection hazard).
    cmd = [
        "fluidsynth",
        "-r", "16000",
        SOUNDFONT_PATH,
        "-g", "1.0",
        "--quiet", "--no-shell",
        midi_filepath,
        "-T", "wav",
        "-F", wav_filepath,
    ]
    try:
        # DEVNULL replaces the old "> /dev/null" shell redirection.
        subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL)
    except FileNotFoundError:
        # fluidsynth binary not installed — keep best-effort semantics.
        pass
    return wav_filepath
# Core Text -> MIDI
# Cache for call-invariant resources (REMI vocab, model weights, T5
# tokenizer), keyed by device so CPU/GPU dispatch both work. Previously
# all three were reloaded from disk on every request.
_GEN_RESOURCES: dict = {}


def _load_resources(device: str):
    """Load (once per device) and return (remi_tokenizer, model, t5_tokenizer)."""
    if device not in _GEN_RESOURCES:
        # Load REMI vocab/tokenizer (pickle dict used by the provided model).
        # NOTE(review): pickle.load on a downloaded artifact — assumes the
        # amaai-lab/text2midi repo is trusted.
        with open(TOKENIZER_PATH, "rb") as f:
            r_tokenizer = pickle.load(f)
        model = Transformer(
            len(r_tokenizer),  # vocab size
            768,               # d_model
            8,                 # nhead
            2048,              # dim_feedforward
            18,                # nlayers
            1024,              # max_seq_len
            False,             # use_rotary
            8,                 # rotary_dim
            device=device,
        )
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model.eval()
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _GEN_RESOURCES[device] = (r_tokenizer, model, t5_tokenizer)
    return _GEN_RESOURCES[device]


def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    """Generate a MIDI file from a text prompt.

    Args:
        prompt: Natural-language description of the desired music.
        temperature: Sampling temperature passed to ``model.generate``.
        max_len: Maximum number of tokens to generate.

    Returns:
        Path to the written MIDI file ("output.mid" in the CWD).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    r_tokenizer, model, tokenizer = _load_resources(device)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    # The tokenizer already returns padded (1, seq_len) tensors; the former
    # pad_sequence round-trip was a no-op and has been removed.
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)
    midi_path = "output.mid"
    generated_midi.dump_midi(midi_path)
    return midi_path
# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    """HARP entry point: returns ``(labels_json, midi_filepath)``.

    On failure, returns an error-message dict and ``None`` for the file so
    the UI surfaces the problem instead of crashing.
    """
    try:
        midi_file = generate_midi(prompt, float(temperature), int(max_length))
        # No per-note metadata yet — add MidiLabel entries here when available.
        metadata = asdict(LabelList())
        return metadata, midi_file
    except Exception as e:
        # On error: return JSON with error message, and no file
        return {"message": f"Error: {e}"}, None
# HARP Model Card — metadata shown by HARP clients for this endpoint.
model_card = ModelCard(
    name="Text2MIDI (HARP)",
    description="Generate MIDI from a text prompt using a transformer decoder conditioned on T5 embeddings.",
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    tags=["text-to-music", "midi", "generation"]
)
# Gradio + HARP UI — component wiring; build_endpoint connects the
# inputs/outputs to process_fn for both the web UI and HARP clients.
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")
    # Inputs (harp_required marks the prompt as mandatory for HARP hosts)
    prompt_in = gr.Textbox(label="Prompt").harp_required(True)
    temperature_in = gr.Slider(minimum=0.8, maximum=1.1, value=0.9, step=0.1, label="Temperature", interactive=True)
    maxlen_in = gr.Number(value=500, label="Max Length (tokens)", minimum=64, maximum=2000, step=64)
    # Outputs (JSON FIRST for HARP, then MIDI) — order must match process_fn's return
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")
    # Build HARP endpoint
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn
    )
# Launch App (share=True exposes a public gradio.live URL)
demo.launch(share=True, show_error=True, debug=True)