File size: 4,405 Bytes
4504581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import spaces 
import pickle
import subprocess
import torch
import torch.nn as nn
import gradio as gr
from dataclasses import asdict
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download

from transformer_model import Transformer  
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

# Model/artifacts from HF Hub
REPO_ID = "amaai-lab/text2midi"
# Fetch the three artifacts from the Hub repo above, in this order:
# model weights, pickled REMI vocabulary/tokenizer, and the soundfont.
# The soundfont is only read by save_wav (the optional WAV preview path).
MODEL_PATH, TOKENIZER_PATH, SOUNDFONT_PATH = (
    hf_hub_download(repo_id=REPO_ID, filename=artifact_name)
    for artifact_name in ("pytorch_model.bin", "vocab_remi.pkl", "soundfont.sf2")
)


# (Optional) MIDI -> WAV 
def save_wav(midi_path: str) -> str:
    """Ask FluidSynth to render a MIDI file to a WAV file beside it.

    The input is looked up as ``<dir>/<stem>.mid`` and the output is written
    to ``<dir>/<stem>.wav``, where ``<dir>`` and ``<stem>`` are the directory
    (``"."`` when empty) and extension-less base name of *midi_path*.

    Rendering is best-effort: neither a missing ``fluidsynth`` executable nor
    a non-zero exit status raises, so the returned path may not exist.

    Args:
        midi_path: Path whose directory and base name locate the MIDI file.

    Returns:
        The path FluidSynth was asked to write the WAV file to.
    """
    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")

    # Use an argument vector rather than a shell string so paths containing
    # spaces or shell metacharacters reach FluidSynth verbatim and are never
    # interpreted by a shell. Argument order mirrors the original command.
    cmd = [
        "fluidsynth",
        "-r", "16000",
        SOUNDFONT_PATH,
        "-g", "1.0",
        "--quiet",
        "--no-shell",
        midi_filepath,
        "-T", "wav",
        "-F", wav_filepath,
    ]
    try:
        # stdout is discarded (previously "> /dev/null"); stderr is left alone.
        subprocess.run(cmd, stdout=subprocess.DEVNULL, check=False)
    except OSError:
        # The executable could not be started (e.g. not installed). The
        # shell-based call carried on silently in that case, so keep doing so.
        pass
    return wav_filepath


# Core Text -> MIDI
def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    """Generate a MIDI file from a text prompt and return its path.

    On every call this loads the pickled REMI tokenizer, builds the
    ``Transformer`` and loads its weights from ``MODEL_PATH``, tokenizes the
    prompt with the ``google/flan-t5-base`` T5 tokenizer, runs
    ``model.generate`` and decodes the first generated sequence to MIDI.

    Args:
        prompt: Text description passed to the T5 tokenizer.
        temperature: Forwarded unchanged to ``model.generate``.
        max_len: Forwarded unchanged to ``model.generate`` as ``max_len``.

    Returns:
        Path of a newly created ``.mid`` file in the system temp directory.
        Each call gets its own file, so concurrent requests cannot overwrite
        each other's output. The file is not deleted by this function.
    """
    import tempfile  # local import: only needed for the per-call output file

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load REMI vocab/tokenizer (pickle dict used by the provided model).
    # NOTE(review): pickle.load can execute arbitrary code from the file; this
    # relies on the artifact behind TOKENIZER_PATH being trusted.
    with open(TOKENIZER_PATH, "rb") as f:
        r_tokenizer = pickle.load(f)

    vocab_size = len(r_tokenizer)
    model = Transformer(
        vocab_size,    # vocab size
        768,           # d_model
        8,             # nhead
        2048,          # dim_feedforward
        18,            # nlayers
        1024,          # max_seq_len
        False,         # use_rotary
        8,             # rotary_dim
        device=device  # device
    )
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0).to(device)
    attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0).to(device)

    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)

    # Only the first sequence of the batch is decoded.
    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)

    # A fixed "output.mid" in the working directory would be shared by every
    # request; reserve a unique file per call instead.
    fd, midi_path = tempfile.mkstemp(prefix="text2midi_", suffix=".mid")
    os.close(fd)  # dump_midi reopens the file by path
    try:
        generated_midi.dump_midi(midi_path)
    except Exception:
        # Do not leave an empty placeholder behind when writing fails.
        os.remove(midi_path)
        raise
    return midi_path

# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    """Run text-to-MIDI generation for the HARP endpoint.

    Args:
        prompt: Text prompt from the UI.
        temperature: Sampling temperature from the UI; coerced to ``float``.
        max_length: Maximum length from the UI; coerced to ``int``.

    Returns:
        A ``(json, midi_path)`` pair with the JSON payload first. On success
        the payload is the dict form of an empty ``LabelList`` and
        ``midi_path`` is the generated file. On any exception the payload is
        ``{"message": "Error: ..."}`` and ``midi_path`` is ``None``.
    """
    try:
        generated_path = generate_midi(prompt, float(temperature), int(max_length))
        # No labels yet; MidiLabel entries would be added here if metadata existed.
        return asdict(LabelList()), generated_path
    except Exception as exc:
        # Report the failure through the JSON output instead of raising.
        failure = {"message": f"Error: {exc}"}
        return failure, None

# HARP Model Card
# Metadata describing this model; handed to build_endpoint when the UI is wired up.
model_card = ModelCard(
    # Identity
    name="Text2MIDI (HARP)",
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    # What it does
    description="Generate MIDI from a text prompt using a transformer decoder conditioned on T5 embeddings.",
    tags=["text-to-music", "midi", "generation"],
)

# Gradio + HARP UI
# Build the page inside a Blocks context. Components are created in the order
# heading, prompt, temperature, max length, JSON output, file output.
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")

    # Inputs, listed in the same order as process_fn's parameters.
    # NOTE(review): harp_required is not a stock Gradio method; presumably
    # provided by pyharp to mark the field as mandatory. Confirm.
    prompt_in = gr.Textbox(label="Prompt").harp_required(True)
    temperature_in = gr.Slider(minimum=0.8, maximum=1.1, value=0.9, step=0.1, label="Temperature", interactive=True)
    maxlen_in = gr.Number(value=500, label="Max Length (tokens)", minimum=64, maximum=2000, step=64)

    # Outputs (JSON FIRST for HARP, then MIDI), matching process_fn's
    # (json, midi_path) return order.
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")

    # Build HARP endpoint: wire the model card, the three inputs and the two
    # outputs to process_fn. The return value is not used.
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn
    )


# Launch App. This runs at import time (there is no __main__ guard).
# share/show_error/debug are passed straight to Gradio's launch().
demo.launch(share=True, show_error=True, debug=True)