Spaces:
Sleeping
Sleeping
File size: 6,702 Bytes
b148e11 1b26622 b148e11 ad462a2 5838ea5 b148e11 ad462a2 b148e11 5838ea5 ad462a2 b148e11 ad462a2 b148e11 ad462a2 b148e11 5838ea5 1b26622 b148e11 5838ea5 b148e11 749e66d 5151779 749e66d b148e11 5151779 749e66d 5151779 749e66d 146e687 749e66d 146e687 749e66d 1b26622 749e66d 5838ea5 de6a347 749e66d 5151779 749e66d 5151779 749e66d 5838ea5 1b26622 5838ea5 1ebc467 de6a347 5838ea5 de6a347 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
"""
VR Music Generator - HuggingFace Spaces Version
Generates music from text descriptions using the text2midi AI model.
"""
import gradio as gr
import torch
import subprocess
import os
import sys
import pickle
import tempfile
import numpy as np
import scipy.io.wavfile as wavfile
from huggingface_hub import hf_hub_download
# Add text2midi model to path
# The Transformer class is imported from a cloned repo directory ("model.transformer_model"
# inside text2midi_repo), not from an installed package, so it must be on sys.path.
sys.path.insert(0, "text2midi_repo")
repo_id = "amaai-lab/text2midi"
# Detect device
# Prefer GPU when available; all tensors and the model are moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Global model variables
# Populated by load_text2midi_model(); stay None if loading fails, in which
# case generate_music() falls back to a simple midiutil scale.
text2midi_model = None
midi_tokenizer = None
text_tokenizer = None
def load_text2midi_model():
    """Download and initialize the text2midi transformer and both tokenizers.

    Populates the module-level ``text2midi_model``, ``midi_tokenizer`` and
    ``text_tokenizer`` globals.

    Returns:
        True when everything loaded, False when any step failed (the app then
        falls back to simple MIDI generation).
    """
    global text2midi_model, midi_tokenizer, text_tokenizer
    try:
        from model.transformer_model import Transformer
        from transformers import T5Tokenizer

        print("Loading text2midi model...")
        # Fetch the checkpoint and the REMI vocabulary pickle from the HF Hub.
        weights_file = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        vocab_file = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
        print(f"Model path: {weights_file}")
        print(f"Tokenizer path: {vocab_file}")
        # The MIDI tokenizer ships as a pickle; trusted source (the model repo).
        with open(vocab_file, "rb") as fh:
            midi_tokenizer = pickle.load(fh)
        vocab_size = len(midi_tokenizer)
        print(f"Vocab size: {vocab_size}")
        # Hyperparameters here presumably mirror the published text2midi
        # configuration -- they must match the checkpoint being loaded.
        text2midi_model = Transformer(
            vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device
        )
        state_dict = torch.load(weights_file, map_location=device, weights_only=True)
        text2midi_model.load_state_dict(state_dict)
        text2midi_model.to(device)
        text2midi_model.eval()
        # Prompts are encoded with FLAN-T5's tokenizer.
        text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        print("Text2midi model loaded successfully!")
        return True
    except Exception as exc:
        print(f"Warning: Could not load text2midi model: {exc}")
        import traceback
        traceback.print_exc()
        print("Falling back to simple MIDI generation...")
        return False
# Try to load the model
# True when the AI model is usable; when False, generate_music() uses the
# midiutil scale fallback instead.
MODEL_LOADED = load_text2midi_model()
def find_soundfont():
    """Locate a General MIDI SoundFont in well-known filesystem locations.

    Returns:
        The first existing candidate path (str), or None when no SoundFont
        is installed.
    """
    candidates = (
        "/usr/share/sounds/sf2/FluidR3_GM.sf2",
        "/usr/share/soundfonts/FluidR3_GM.sf2",
        "/usr/share/sounds/sf2/default-GM.sf2",
        "FluidR3_GM.sf2",
    )
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
# Resolved once at startup; None disables WAV rendering in midi_to_wav().
SOUNDFONT_PATH = find_soundfont()
print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}")
def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9):
    """Run the text2midi model on *prompt* and write the result to a MIDI file.

    Args:
        prompt: Natural-language description of the desired music.
        output_path: Destination filename for the generated .mid file.
        max_len: Maximum number of MIDI tokens to sample.
        temperature: Sampling temperature; higher means more random output.

    Returns:
        output_path, for caller convenience.
    """
    global text2midi_model, midi_tokenizer, text_tokenizer
    encoded = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
    ids = encoded.input_ids.to(device)
    mask = encoded.attention_mask.to(device)
    # Inference only -- no gradients needed.
    with torch.no_grad():
        tokens = text2midi_model.generate(ids, mask, max_len=max_len, temperature=temperature)
    # Decode the sampled token ids back into a MIDI object and save it.
    midi_object = midi_tokenizer.decode(tokens[0].tolist())
    midi_object.dump_midi(output_path)
    return output_path
def midi_to_wav(midi_path: str, wav_path: str, sample_rate: int = 44100, soundfont: str = None) -> bool:
    """Render a MIDI file to WAV with the FluidSynth command-line tool.

    Args:
        midi_path: Input .mid file.
        wav_path: Output .wav file to write.
        sample_rate: Output sample rate in Hz.
        soundfont: SoundFont (.sf2) file to render with; defaults to the
            module-level SOUNDFONT_PATH discovered at startup.

    Returns:
        True when FluidSynth succeeded and the WAV file exists; False when no
        SoundFont is available, fluidsynth is not installed, it exits
        non-zero, or it exceeds the 120s timeout.
    """
    if soundfont is None:
        soundfont = SOUNDFONT_PATH
    if not soundfont:
        return False
    try:
        result = subprocess.run([
            "fluidsynth",
            "-ni",  # presumably -n (no MIDI driver) + -i (non-interactive) -- confirm against fluidsynth docs
            "-F", wav_path,
            "-r", str(sample_rate),
            soundfont,
            midi_path,
        ], capture_output=True, text=True, timeout=120)
    except FileNotFoundError:
        # fluidsynth binary not on PATH.
        print("FluidSynth not found on PATH")
        return False
    except subprocess.TimeoutExpired:
        # Previously propagated out of this function despite timeout= being set.
        print("FluidSynth timed out")
        return False
    if result.returncode != 0:
        print(f"FluidSynth error: {result.stderr}")
        return False
    return os.path.exists(wav_path)
def _remove_quietly(path):
    """Delete *path* if it exists, ignoring filesystem errors."""
    if path and os.path.exists(path):
        try:
            os.unlink(path)
        except OSError:
            pass


def generate_music(prompt: str):
    """Generate music from a text prompt.

    Args:
        prompt: Free-text description of the desired music. Empty/blank
            prompts produce no output.

    Returns:
        A ``(sample_rate, float32 ndarray)`` tuple suitable for a Gradio
        numpy Audio component, or None when the prompt is empty, no
        SoundFont is available, or generation fails.
    """
    if not prompt or not prompt.strip():
        return None
    midi_path = None
    wav_path = None
    try:
        # mkstemp returns an open fd; close it immediately -- only the path
        # is needed, the generators reopen the files themselves.
        midi_fd, midi_path = tempfile.mkstemp(suffix='.mid')
        os.close(midi_fd)
        wav_fd, wav_path = tempfile.mkstemp(suffix='.wav')
        os.close(wav_fd)
        # Generate MIDI (reduced max_len for faster testing)
        if MODEL_LOADED:
            generate_midi_with_model(prompt, midi_path, max_len=128, temperature=0.9)
        else:
            # Fallback: a C-major scale, one note per prompt word (max 8).
            from midiutil import MIDIFile
            midi = MIDIFile(1)
            midi.addTempo(0, 0, 120)
            notes = [60, 62, 64, 65, 67, 69, 71, 72]
            for i, note in enumerate(notes[:min(len(prompt.split()), 8)]):
                midi.addNote(0, 0, note, i, 1, 100)
            with open(midi_path, "wb") as f:
                midi.writeFile(f)
        # Convert to WAV and hand raw samples back to Gradio.
        if SOUNDFONT_PATH and midi_to_wav(midi_path, wav_path):
            sample_rate, audio_data = wavfile.read(wav_path)
            # Normalize integer PCM to float32 in [-1, 1) for Gradio.
            if audio_data.dtype == np.int16:
                audio_data = audio_data.astype(np.float32) / 32768.0
            elif audio_data.dtype == np.int32:
                audio_data = audio_data.astype(np.float32) / 2147483648.0
            return (sample_rate, audio_data)
        return None
    except Exception:
        # Gradio renders None as "no audio"; log the failure for the Space logs.
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Always clean up temp files. The previous bare `except:` clauses here
        # would also swallow KeyboardInterrupt/SystemExit.
        _remove_quietly(midi_path)
        _remove_quietly(wav_path)
# Create simple Gradio Interface
# One text box in, one audio player out; generate_music returns
# (sample_rate, ndarray) which matches type="numpy" below.
demo = gr.Interface(
    fn=generate_music,
    inputs=gr.Textbox(
        label="Music Prompt",
        placeholder="A cheerful pop song with piano and drums in C major",
        lines=2
    ),
    outputs=gr.Audio(label="Generated Music", type="numpy"),
    title="VR Game Music Generator",
    description="Generate music from text descriptions using AI. Enter a prompt describing the music you want.",
    examples=[
        ["A cheerful pop song with piano and drums"],
        ["An energetic electronic trance track at 138 BPM"],
        ["A slow emotional classical piece with violin"],
        ["Epic cinematic soundtrack with dark atmosphere"],
    ],
    # Caching would run the (slow) model once per example at startup.
    cache_examples=False,
    # NOTE(review): allow_flagging is deprecated in newer Gradio versions
    # (replaced by flagging_mode) -- confirm the pinned Gradio version.
    allow_flagging="never"
)
# Launch
# show_api=True exposes the REST endpoint for programmatic access.
demo.launch(show_api=True)
|