File size: 6,702 Bytes
b148e11
 
 
 
 
 
 
 
 
 
 
1b26622
 
b148e11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad462a2
5838ea5
b148e11
ad462a2
b148e11
 
 
5838ea5
 
 
 
 
ad462a2
b148e11
 
ad462a2
 
b148e11
ad462a2
b148e11
5838ea5
1b26622
b148e11
5838ea5
b148e11
749e66d
5151779
749e66d
b148e11
5151779
749e66d
 
 
5151779
 
749e66d
146e687
749e66d
146e687
749e66d
 
 
 
 
 
 
 
 
 
 
 
1b26622
 
 
 
 
 
 
 
749e66d
 
5838ea5
de6a347
 
 
 
 
749e66d
5151779
749e66d
 
 
 
 
5151779
 
 
 
 
749e66d
 
5838ea5
 
 
 
 
 
 
1b26622
5838ea5
 
 
 
 
 
 
 
1ebc467
de6a347
5838ea5
 
 
de6a347
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
VR Music Generator - HuggingFace Spaces Version
Generates music from text descriptions using the text2midi AI model.
"""
import gradio as gr
import torch
import subprocess
import os
import sys
import pickle
import tempfile
import numpy as np
import scipy.io.wavfile as wavfile
from huggingface_hub import hf_hub_download

# Make the vendored text2midi repo importable (provides model.transformer_model).
sys.path.insert(0, "text2midi_repo")

# HuggingFace Hub repo holding the pretrained text2midi weights and vocab.
repo_id = "amaai-lab/text2midi"

# Detect device: prefer GPU when available, otherwise run on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Global model variables, populated by load_text2midi_model() at import time.
text2midi_model = None  # Transformer instance, or None if loading failed
midi_tokenizer = None   # REMI tokenizer unpickled from vocab_remi.pkl
text_tokenizer = None   # T5 tokenizer used to encode the text prompt

def load_text2midi_model():
    """Download and initialize the text2midi model and both tokenizers.

    Populates the module-level ``text2midi_model``, ``midi_tokenizer`` and
    ``text_tokenizer`` globals.

    Returns:
        bool: True on success; False on any failure, in which case the app
        falls back to simple procedural MIDI generation.
    """
    global text2midi_model, midi_tokenizer, text_tokenizer

    try:
        from model.transformer_model import Transformer
        from transformers import T5Tokenizer

        print("Loading text2midi model...")

        # Fetch the pretrained weights and REMI vocab from the HF Hub
        # (cached locally after the first download).
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")

        print(f"Model path: {model_path}")
        print(f"Tokenizer path: {tokenizer_path}")

        # NOTE(review): pickle.load on a downloaded artifact executes arbitrary
        # code if the repo is ever compromised; acceptable here only because
        # the repo_id is pinned to a trusted publisher.
        with open(tokenizer_path, "rb") as f:
            midi_tokenizer = pickle.load(f)

        vocab_size = len(midi_tokenizer)
        print(f"Vocab size: {vocab_size}")

        # Architecture hyperparameters matching the published checkpoint.
        d_model = 768        # model/embedding width
        n_heads = 8          # attention heads
        d_ff = 2048          # feed-forward width
        n_layers = 18        # decoder layers
        max_seq_len = 1024   # maximum token sequence length
        flag_arg = False     # 7th positional arg -- TODO confirm meaning against Transformer signature
        extra_arg = 8        # 8th positional arg -- TODO confirm meaning against Transformer signature
        text2midi_model = Transformer(
            vocab_size, d_model, n_heads, d_ff, n_layers, max_seq_len,
            flag_arg, extra_arg, device=device,
        )
        # weights_only=True restricts torch.load to tensor data (no pickled code).
        text2midi_model.load_state_dict(
            torch.load(model_path, map_location=device, weights_only=True)
        )
        text2midi_model.to(device)
        text2midi_model.eval()

        # T5 tokenizer used to encode the text prompt for the model.
        text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

        print("Text2midi model loaded successfully!")
        return True

    except Exception as e:
        # Broad catch is deliberate: any failure (network, missing file,
        # incompatible checkpoint) downgrades to the fallback generator
        # instead of crashing the Space at import time.
        print(f"Warning: Could not load text2midi model: {e}")
        import traceback
        traceback.print_exc()
        print("Falling back to simple MIDI generation...")
        return False

# Attempt to load the model once at import time; on failure the app
# falls back to simple procedural MIDI generation (see generate_music).
MODEL_LOADED = load_text2midi_model()

def find_soundfont():
    """Return the first General MIDI SoundFont found on disk, or None.

    Probes a few well-known Linux package install locations, then the
    current working directory.
    """
    candidates = (
        "/usr/share/sounds/sf2/FluidR3_GM.sf2",
        "/usr/share/soundfonts/FluidR3_GM.sf2",
        "/usr/share/sounds/sf2/default-GM.sf2",
        "FluidR3_GM.sf2",
    )
    return next((sf2 for sf2 in candidates if os.path.exists(sf2)), None)

# Resolve the SoundFont once at import time; None disables WAV rendering.
SOUNDFONT_PATH = find_soundfont()
print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}")

def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9):
    """Run the text2midi model on ``prompt`` and write a MIDI file.

    The prompt is encoded with the T5 tokenizer, passed through the
    transformer's autoregressive ``generate``, and the resulting token
    sequence is decoded back into MIDI by the REMI tokenizer.

    Returns:
        str: ``output_path``, for caller convenience.
    """
    encoded = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
    ids = encoded.input_ids.to(device)
    mask = encoded.attention_mask.to(device)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        tokens = text2midi_model.generate(ids, mask, max_len=max_len, temperature=temperature)

    midi_obj = midi_tokenizer.decode(tokens[0].tolist())
    midi_obj.dump_midi(output_path)
    return output_path

def midi_to_wav(midi_path: str, wav_path: str, sample_rate: int = 44100) -> bool:
    """Render a MIDI file to WAV using the fluidsynth CLI.

    Returns:
        bool: True only when rendering succeeded and the WAV file exists.
        False when no SoundFont is available, fluidsynth is not installed,
        the render times out, or fluidsynth exits non-zero.
    """
    if not SOUNDFONT_PATH:
        return False

    # List-form argv (shell=False) -- no shell injection risk from paths.
    try:
        result = subprocess.run([
            "fluidsynth",
            "-ni",                   # non-interactive
            "-F", wav_path,          # render to file instead of audio device
            "-r", str(sample_rate),
            SOUNDFONT_PATH,
            midi_path,
        ], capture_output=True, text=True, timeout=120)
    except FileNotFoundError:
        # fluidsynth binary is not on PATH; honor the bool contract
        # instead of letting the exception escape.
        print("FluidSynth not found on PATH")
        return False
    except subprocess.TimeoutExpired:
        # A render should take seconds; give up rather than hang the UI.
        print("FluidSynth timed out")
        return False

    if result.returncode != 0:
        print(f"FluidSynth error: {result.stderr}")
        return False

    return os.path.exists(wav_path)

def _write_fallback_midi(prompt: str, midi_path: str) -> None:
    """Write a simple C-major scale MIDI file (used when the model is unavailable).

    Emits one quarter note per word in the prompt, capped at eight notes.
    """
    from midiutil import MIDIFile
    midi = MIDIFile(1)
    midi.addTempo(0, 0, 120)
    notes = [60, 62, 64, 65, 67, 69, 71, 72]  # C4..C5 major scale
    for i, note in enumerate(notes[:min(len(prompt.split()), 8)]):
        midi.addNote(0, 0, note, i, 1, 100)
    with open(midi_path, "wb") as f:
        midi.writeFile(f)

def _remove_quietly(path) -> None:
    """Best-effort delete of a temp file; ignores filesystem errors only."""
    if path and os.path.exists(path):
        try:
            os.unlink(path)
        except OSError:
            pass

def generate_music(prompt: str):
    """Generate music from a text prompt.

    Returns:
        tuple[int, np.ndarray] | None: ``(sample_rate, float32 samples)``
        suitable for a Gradio Audio output, or None on empty input or any
        failure (model error, missing SoundFont, render failure).
    """
    if not prompt or not prompt.strip():
        return None

    midi_path = None
    wav_path = None

    try:
        # mkstemp returns an open descriptor; close it immediately so the
        # MIDI writer / fluidsynth can reopen the path on all platforms.
        midi_fd, midi_path = tempfile.mkstemp(suffix='.mid')
        os.close(midi_fd)
        wav_fd, wav_path = tempfile.mkstemp(suffix='.wav')
        os.close(wav_fd)

        # Generate MIDI (reduced max_len keeps generation fast).
        if MODEL_LOADED:
            generate_midi_with_model(prompt, midi_path, max_len=128, temperature=0.9)
        else:
            _write_fallback_midi(prompt, midi_path)

        # Render to WAV and hand the samples to Gradio as float32 in [-1, 1).
        if SOUNDFONT_PATH and midi_to_wav(midi_path, wav_path):
            sample_rate, audio_data = wavfile.read(wav_path)
            if audio_data.dtype == np.int16:
                audio_data = audio_data.astype(np.float32) / 32768.0
            elif audio_data.dtype == np.int32:
                audio_data = audio_data.astype(np.float32) / 2147483648.0
            return (sample_rate, audio_data)
        return None

    except Exception:
        # Surface the full traceback in the Space logs but never crash the UI.
        import traceback
        traceback.print_exc()
        return None

    finally:
        _remove_quietly(midi_path)
        _remove_quietly(wav_path)

# Create simple Gradio Interface: free-text prompt in, rendered audio out.
demo = gr.Interface(
    fn=generate_music,
    inputs=gr.Textbox(
        label="Music Prompt",
        placeholder="A cheerful pop song with piano and drums in C major",
        lines=2
    ),
    # generate_music returns (sample_rate, float32 ndarray), hence type="numpy".
    outputs=gr.Audio(label="Generated Music", type="numpy"),
    title="VR Game Music Generator",
    description="Generate music from text descriptions using AI. Enter a prompt describing the music you want.",
    examples=[
        ["A cheerful pop song with piano and drums"],
        ["An energetic electronic trance track at 138 BPM"],
        ["A slow emotional classical piece with violin"],
        ["Epic cinematic soundtrack with dark atmosphere"],
    ],
    # Do not pre-render examples at build time -- generation is slow.
    cache_examples=False,
    allow_flagging="never"
)

# Launch with the REST API endpoint exposed (used by the VR client).
demo.launch(show_api=True)