File size: 5,386 Bytes
4504581
 
 
 
 
 
 
 
 
 
07e756d
 
4504581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07e756d
 
 
 
 
4504581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07e756d
4504581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32db598
 
3b314da
 
 
2d169b5
32db598
4504581
 
 
 
 
 
 
 
 
a9b503f
2d169b5
 
a9b503f
 
32db598
a9b503f
 
 
 
 
 
 
 
4504581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import pickle
import subprocess
import tempfile
import traceback
import warnings
from dataclasses import asdict
from functools import lru_cache
from time import time_ns
from uuid import uuid4

# Keep `spaces` ahead of torch: HF ZeroGPU expects it to be imported before CUDA is initialised.
import spaces
import torch
import torch.nn as nn
import gradio as gr
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download

from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList

# Hugging Face Hub repo that hosts every artifact this app needs.
REPO_ID = "amaai-lab/text2midi"

# Resolve each artifact to a local file path at import time (same order as before):
#   - pytorch_model.bin : Transformer weights loaded by generate_midi
#   - vocab_remi.pkl    : pickled REMI vocab/tokenizer used to decode tokens to MIDI
#   - soundfont.sf2     : only needed by the optional save_wav preview
MODEL_PATH, TOKENIZER_PATH, SOUNDFONT_PATH = (
    hf_hub_download(repo_id=REPO_ID, filename=artifact)
    for artifact in ("pytorch_model.bin", "vocab_remi.pkl", "soundfont.sf2")
)


# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render a MIDI file to a 16 kHz WAV next to it with the fluidsynth CLI.

    Best-effort, as before: if fluidsynth is missing or exits non-zero a
    RuntimeWarning is emitted instead of raising, and the WAV path is still
    returned (the file may therefore not exist).

    Args:
        midi_path: Path to the MIDI file. Paths ending in ``.mid``/``.midi``
            (any case) are rendered as given; any other path is mapped to a
            sibling ``<stem>.mid``, matching the previous behaviour.

    Returns:
        ``<dir>/<stem>.wav`` where ``dir``/``stem`` come from ``midi_path``.
    """
    directory = os.path.dirname(midi_path) or "."
    stem, ext = os.path.splitext(os.path.basename(midi_path))
    if ext.lower() in (".mid", ".midi"):
        # Use the real file; rebuilding "<stem>.mid" broke '.midi'/'.MID' inputs.
        midi_filepath = midi_path
    else:
        midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")

    # argv list + no shell: paths containing spaces or shell metacharacters are
    # passed verbatim instead of being re-parsed by /bin/sh.
    cmd = [
        "fluidsynth",
        "-r", "16000",
        "-g", "1.0",
        "--quiet",
        "--no-shell",
        "-T", "wav",
        "-F", wav_filepath,
        SOUNDFONT_PATH,
        midi_filepath,
    ]
    try:
        # stdout discarded (was `> /dev/null`); stderr still reaches the log.
        result = subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL)
    except OSError as err:  # e.g. the fluidsynth binary is not installed
        warnings.warn(f"Could not run fluidsynth: {err}", RuntimeWarning)
        return wav_filepath
    if result.returncode != 0:
        warnings.warn(
            f"fluidsynth exited with status {result.returncode}; "
            f"{wav_filepath} may be missing or incomplete",
            RuntimeWarning,
        )
    return wav_filepath

# Helpers
def _unique_path(ext: str) -> str:
    """Create a unique file path in /tmp to avoid naming collisions."""
    return os.path.join("/tmp", f"t2m_{time_ns()}_{uuid4().hex[:8]}{ext}")


# Core Text -> MIDI
@lru_cache(maxsize=1)
def _load_remi_tokenizer():
    """Load the pickled REMI vocab/tokenizer once and reuse it across requests."""
    # NOTE(review): pickle.load can execute arbitrary code; acceptable only
    # because TOKENIZER_PATH is downloaded from the fixed REPO_ID.
    with open(TOKENIZER_PATH, "rb") as f:
        return pickle.load(f)


@lru_cache(maxsize=1)
def _load_text_tokenizer():
    """Load the FLAN-T5 prompt tokenizer once and reuse it across requests."""
    return T5Tokenizer.from_pretrained("google/flan-t5-base")


def _build_model(vocab_size: int, device: str):
    """Construct the text2midi Transformer on `device`, load weights, set eval mode."""
    model = Transformer(
        vocab_size,    # vocab size
        768,           # d_model
        8,             # nhead
        2048,          # dim_feedforward
        18,            # nlayers
        1024,          # max_seq_len
        False,         # use_rotary
        8,             # rotary_dim
        device=device  # device
    )
    # NOTE(review): torch.load unpickles MODEL_PATH; consider weights_only=True
    # (torch >= 1.13) if the checkpoint is confirmed to be a plain state dict.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model


def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    """Generate a MIDI file from a natural-language prompt and return its path.

    Args:
        prompt: Text description of the desired music.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        max_len: Forwarded to ``model.generate`` as ``max_len``.
            NOTE(review): the model is built with max_seq_len=1024 while the UI
            allows up to 1500 — confirm how ``generate`` handles longer requests.

    Returns:
        Path of the written ``.mid`` file (a fresh name from ``_unique_path``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenizers are device-independent, so they are cached after the first call.
    r_tokenizer = _load_remi_tokenizer()
    # The model is still built per call: this runs under @spaces.GPU, and keeping a
    # CUDA-resident model across calls may not be safe once the GPU is released —
    # TODO confirm against the ZeroGPU docs before caching it too.
    model = _build_model(len(r_tokenizer), device)

    tokenizer = _load_text_tokenizer()
    # padding=True already yields rectangular (batch, seq_len) tensors, so no extra
    # pad_sequence pass is needed before moving them to the device.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)

    # Decode the first generated token sequence back into a MIDI score object.
    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)

    midi_path = _unique_path(".mid")
    generated_midi.dump_midi(midi_path)
    return midi_path

# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    """HARP/Gradio handler that turns a prompt into a MIDI file.

    Args:
        prompt: Text description from the UI.
        temperature: Slider value, coerced to ``float``.
        max_length: Slider value, coerced to ``int`` and passed as ``max_len``.

    Returns:
        ``(labels_json, midi_path)`` in the order HARP expects (JSON first).
        On any failure: ``({"message": "Error: <exc>"}, None)``.
    """
    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()  # add MidiLabel entries here if you have metadata
        return asdict(labels), midi_path
    except Exception as e:  # top-level request boundary: report, don't crash
        # The UI only receives the message; keep the full traceback in the
        # server log so failures are diagnosable.
        traceback.print_exc()
        return {"message": f"Error: {e}"}, None

# HARP Model Card: the metadata HARP clients display for this endpoint.
model_card = ModelCard(
    name="Text2MIDI Generation",
    author=(
        "Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, "
        "Simon Colton, Dorien Herremans"
    ),
    tags=["text-to-music", "midi", "generation"],
    # User-facing summary: one line each for purpose, input, output, controls.
    description=(
        "Turn your musical ideas into playable MIDI notes. \n"
        "Input: Describe what you'd like to hear. For example: a gentle piano lullaby with soft strings. \n"
        "Output: This model will generate a matching MIDI sequence for playback or editing. \n"
        "Use the sliders to control the amount of creativity and length."
    ),
)

# Gradio + HARP UI
# Components are created inside the Blocks context in display order, so the
# statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")

    # Inputs — created in the same order as process_fn's parameters
    # (prompt, temperature, max_length) and listed that way in input_components.
    # `.harp_required(True)` is a pyharp extension; presumably it marks the
    # prompt as mandatory in HARP clients — confirm against pyharp.
    prompt_in = gr.Textbox(
        label="Describe Your Music",
        info="Type a short phrase like 'calm piano with flowing arpeggios' ",
        ).harp_required(True)
    # Sampling-temperature slider; its value is passed to process_fn as `temperature`.
    temperature_in = gr.Slider(minimum=0.8, maximum=1.1, value=0.9, step=0.1, label="Creativity", info=(
        "Adjusts how much freedom the model takes while composing.\n"
        "Lower = safer and more predictable (structured), "
        "Higher = more varied and expressive."
    ), interactive=True)

    # Token-budget slider; its value is passed to process_fn as `max_length`.
    # NOTE(review): upper bound 1500 exceeds the model's max_seq_len (1024) — confirm intended.
    maxlen_in = gr.Slider(minimum=500, maximum=1500, step=100, value=500, label="Composition Length", info=(
        "Determines how long the generated piece is in musical tokens.\n"
        "Higher values produce longer phrases (roughly more measures of music).")
    )

    # Outputs (JSON FIRST for HARP, then MIDI) — must match process_fn's return order.
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")

    # Build HARP endpoint wiring the components above to process_fn.
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn
    )


# Launch App
# Guarded so that importing this module (e.g. from tests or other code) does not
# start a blocking server; running the file as a script launches exactly as before.
if __name__ == "__main__":
    demo.launch(share=True, show_error=True, debug=True)