File size: 5,386 Bytes
4504581 07e756d 4504581 07e756d 4504581 07e756d 4504581 32db598 3b314da 2d169b5 32db598 4504581 a9b503f 2d169b5 a9b503f 32db598 a9b503f 4504581 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import spaces
import pickle
import subprocess
import torch
import torch.nn as nn
import gradio as gr
from dataclasses import asdict
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download
from time import time_ns
from uuid import uuid4
from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# Model/artifacts from HF Hub
REPO_ID = "amaai-lab/text2midi"

# Fetch the artifacts in a fixed order: the model weights, the pickled REMI
# vocabulary, and the soundfont (optional, only needed if you later add WAV
# preview). Each name is bound to whatever hf_hub_download returns for it.
MODEL_PATH, TOKENIZER_PATH, SOUNDFONT_PATH = (
    hf_hub_download(repo_id=REPO_ID, filename=artifact_name)
    for artifact_name in ("pytorch_model.bin", "vocab_remi.pkl", "soundfont.sf2")
)
# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render a MIDI file to a WAV file next to it using the fluidsynth CLI.

    The input is looked up as ``<stem>.mid`` in the directory of *midi_path*
    (the current directory when *midi_path* has no directory part), and the
    output is requested as ``<stem>.wav`` in that same directory.

    Args:
        midi_path: Path whose directory and stem identify the MIDI file.

    Returns:
        The path fluidsynth was asked to write the WAV to. The exit status of
        fluidsynth is not checked, so the file may be missing if rendering
        failed or fluidsynth could not be started.
    """
    import logging  # local import keeps this block self-contained

    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")

    # Pass an argument list (no shell) so that paths containing spaces or
    # shell metacharacters reach fluidsynth as single arguments and can never
    # be interpreted as shell syntax.
    cmd = [
        "fluidsynth",
        "-r", "16000",
        SOUNDFONT_PATH,
        "-g", "1.0",
        "--quiet",
        "--no-shell",
        midi_filepath,
        "-T", "wav",
        "-F", wav_filepath,
    ]
    try:
        # stdout is discarded, matching the former "> /dev/null" redirect.
        subprocess.run(cmd, stdout=subprocess.DEVNULL, check=False)
    except OSError as exc:
        # With a shell, a missing binary only produced a non-zero exit status;
        # keep that best-effort behaviour instead of raising to the caller.
        logging.getLogger(__name__).warning(
            "fluidsynth could not be started: %s", exc
        )
    return wav_filepath
# Helpers
def _unique_path(ext: str) -> str:
    """Build a collision-resistant file path under /tmp ending in *ext*.

    The file name combines the current nanosecond timestamp with the first
    eight hex characters of a random UUID. Nothing is created on disk.
    """
    stamp = time_ns()
    token = uuid4().hex[:8]
    filename = f"t2m_{stamp}_{token}{ext}"
    return os.path.join("/tmp", filename)
# Core Text -> MIDI
# Per-process cache of loaded artifacts, so that only the first request pays
# for unpickling the vocabulary, building the network and reading its weights.
_T2M_CACHE = {}


def _get_remi_tokenizer():
    """Return the object unpickled from TOKENIZER_PATH, loading it at most once."""
    remi = _T2M_CACHE.get("remi_tokenizer")
    if remi is None:
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. TOKENIZER_PATH is a downloaded artifact, so this is only
        # as trustworthy as the repository it was fetched from.
        with open(TOKENIZER_PATH, "rb") as f:
            remi = pickle.load(f)
        _T2M_CACHE["remi_tokenizer"] = remi
    return remi


def _get_text_tokenizer():
    """Return the "google/flan-t5-base" T5Tokenizer, creating it at most once."""
    text_tokenizer = _T2M_CACHE.get("text_tokenizer")
    if text_tokenizer is None:
        text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _T2M_CACHE["text_tokenizer"] = text_tokenizer
    return text_tokenizer


def _get_model(vocab_size: int, device: str):
    """Return the eval-mode Transformer for *device*, building it once per device."""
    key = ("model", device)
    model = _T2M_CACHE.get(key)
    if model is None:
        model = Transformer(
            vocab_size,     # vocab size
            768,            # d_model
            8,              # nhead
            2048,           # dim_feedforward
            18,             # nlayers
            1024,           # max_seq_len
            False,          # use_rotary
            8,              # rotary_dim
            device=device,  # device
        )
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model.eval()
        _T2M_CACHE[key] = model
    return model


def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    """Generate a MIDI file from a text prompt and return its path.

    The prompt is tokenized with the T5 tokenizer, the Transformer generates a
    token sequence from it, and the REMI tokenizer decodes the first generated
    sequence into an object that is dumped to a fresh ``.mid`` file.

    The vocabulary, text tokenizer and model are loaded on the first call and
    reused afterwards (the model is cached per device).
    NOTE(review): the cached model is shared by every call in this process;
    confirm Transformer.generate keeps no per-call state on the module if
    requests can run concurrently.

    Args:
        prompt: Text description passed to the T5 tokenizer.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        max_len: Forwarded to ``model.generate`` as ``max_len``.

    Returns:
        Path of the MIDI file written by ``dump_midi``.
    """
    # Decided per call: CUDA when torch reports it available, otherwise CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load REMI vocab/tokenizer (pickle dict used by the provided model)
    r_tokenizer = _get_remi_tokenizer()
    model = _get_model(len(r_tokenizer), device)
    tokenizer = _get_text_tokenizer()

    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = nn.utils.rnn.pad_sequence(
        inputs.input_ids, batch_first=True, padding_value=0
    ).to(device)
    attention_mask = nn.utils.rnn.pad_sequence(
        inputs.attention_mask, batch_first=True, padding_value=0
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model.generate(
            input_ids, attention_mask, max_len=max_len, temperature=temperature
        )

    # Only the first sequence of the generated batch is decoded.
    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)
    midi_path = _unique_path(".mid")
    generated_midi.dump_midi(midi_path)
    return midi_path
# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    """Run text-to-MIDI generation for the HARP endpoint.

    Args:
        prompt: Text description of the desired music.
        temperature: Sampling temperature; coerced to ``float``.
        max_length: Generation length; coerced to ``int``.

    Returns:
        A ``(labels, midi_path)`` pair. On success ``labels`` is the dict form
        of an empty ``LabelList`` and ``midi_path`` is the generated file. On
        any exception ``labels`` is ``{"message": "Error: ..."}`` and
        ``midi_path`` is ``None``; the exception is logged, not raised.
    """
    import logging  # local import keeps this block self-contained

    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()  # add MidiLabel entries here if you have metadata
        return asdict(labels), midi_path
    except Exception as e:
        # Boundary handler: record the full traceback so failures can be
        # diagnosed from the logs, since only the message goes to the client.
        logging.getLogger(__name__).exception("text2midi generation failed")
        # On error: return JSON with error message, and no file
        return {"message": f"Error: {e}"}, None
# HARP Model Card
# Descriptive metadata for this model; handed to build_endpoint below.
model_card = ModelCard(
    tags=["text-to-music", "midi", "generation"],
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    name="Text2MIDI Generation",
    # Adjacent literals concatenate into one multi-line description string.
    description=(
        "Turn your musical ideas into playable MIDI notes. \n"
        "Input: Describe what you'd like to hear. For example: a gentle piano lullaby with soft strings. \n"
        "Output: This model will generate a matching MIDI sequence for playback or editing. \n"
        "Use the sliders to control the amount of creativity and length."
    ),
)
# Gradio + HARP UI
# Declares the input/output components and registers them, together with
# process_fn and model_card, through build_endpoint.
# NOTE(review): the indentation here was reconstructed; it assumes everything
# through the build_endpoint call lives inside the Blocks context — confirm.
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")
    # Inputs
    # Free-text prompt; marked as required via harp_required(True).
    prompt_in = gr.Textbox(
        label="Describe Your Music",
        info="Type a short phrase like 'calm piano with flowing arpeggios' ",
    ).harp_required(True)
    # Slider from 0.8 to 1.1 in steps of 0.1, defaulting to 0.9.
    temperature_in = gr.Slider(minimum=0.8, maximum=1.1, value=0.9, step=0.1, label="Creativity", info=(
        "Adjusts how much freedom the model takes while composing.\n"
        "Lower = safer and more predictable (structured), "
        "Higher = more varied and expressive."
    ), interactive=True)
    # Slider from 500 to 1500 in steps of 100, defaulting to 500.
    maxlen_in = gr.Slider(minimum=500, maximum=1500, step=100, value=500, label="Composition Length", info=(
        "Determines how long the generated piece is in musical tokens.\n"
        "Higher values produce longer phrases (roughly more measures of music).")
    )
    # Outputs (JSON FIRST for HARP, then MIDI)
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")
    # Build HARP endpoint
    # The component lists follow process_fn's parameter order (prompt,
    # temperature, max_length) and its return order (labels JSON, MIDI path);
    # presumably build_endpoint maps them positionally — verify against pyharp.
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn
    )
# Launch App
# Runs at module level, outside the Blocks context. The stray trailing "|"
# that followed this call was not valid Python and has been dropped.
demo.launch(share=True, show_error=True, debug=True)