# TextToAudio / app.py
# Uploaded by kmaes with huggingface_hub (commit 1b26622, verified)
"""
VR Music Generator - HuggingFace Spaces Version
Generates music from text descriptions using the text2midi AI model.
"""
import gradio as gr
import torch
import subprocess
import os
import sys
import pickle
import tempfile
import numpy as np
import scipy.io.wavfile as wavfile
from huggingface_hub import hf_hub_download
# Add text2midi model to path
# The text2midi repo is expected to be vendored next to app.py; make its
# modules (model.transformer_model) importable.
sys.path.insert(0, "text2midi_repo")
# HuggingFace Hub repo hosting the pretrained checkpoint and REMI vocab.
repo_id = "amaai-lab/text2midi"
# Detect device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Global model variables
# Populated by load_text2midi_model(); remain None when loading fails.
text2midi_model = None
midi_tokenizer = None
text_tokenizer = None
def load_text2midi_model():
"""Load the text2midi model and tokenizers."""
global text2midi_model, midi_tokenizer, text_tokenizer
try:
from model.transformer_model import Transformer
from transformers import T5Tokenizer
print("Loading text2midi model...")
# Download model files
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
print(f"Model path: {model_path}")
print(f"Tokenizer path: {tokenizer_path}")
# Load MIDI tokenizer
with open(tokenizer_path, "rb") as f:
midi_tokenizer = pickle.load(f)
vocab_size = len(midi_tokenizer)
print(f"Vocab size: {vocab_size}")
# Initialize and load model
text2midi_model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
text2midi_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
text2midi_model.to(device)
text2midi_model.eval()
# Load T5 tokenizer for text encoding
text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
print("Text2midi model loaded successfully!")
return True
except Exception as e:
print(f"Warning: Could not load text2midi model: {e}")
import traceback
traceback.print_exc()
print("Falling back to simple MIDI generation...")
return False
# Try to load the model
# MODEL_LOADED gates generate_music(): True -> AI model, False -> midiutil fallback.
MODEL_LOADED = load_text2midi_model()
def find_soundfont(extra_paths=()):
    """Locate a General-MIDI SoundFont (.sf2) file for FluidSynth.

    Args:
        extra_paths: optional iterable of additional candidate paths,
            checked *before* the built-in system locations so callers
            can override the default search.  Defaults to empty, which
            preserves the original behavior.

    Returns:
        The first candidate path that exists on disk, or None.
    """
    common_paths = [
        "/usr/share/sounds/sf2/FluidR3_GM.sf2",
        "/usr/share/soundfonts/FluidR3_GM.sf2",
        "/usr/share/sounds/sf2/default-GM.sf2",
        "FluidR3_GM.sf2",  # relative to the working directory
    ]
    for path in (*extra_paths, *common_paths):
        if os.path.exists(path):
            return path
    return None
# Resolved once at startup; None disables WAV rendering in generate_music().
SOUNDFONT_PATH = find_soundfont()
print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}")
def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9):
    """Run the text2midi model on *prompt* and write a MIDI file to *output_path*.

    Requires load_text2midi_model() to have succeeded: reads the module
    globals ``text2midi_model``, ``midi_tokenizer`` and ``text_tokenizer``.
    Returns *output_path* for convenience.
    """
    global text2midi_model, midi_tokenizer, text_tokenizer
    # Encode the text prompt with the T5 tokenizer and move it to the device.
    encoded = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
    ids = encoded.input_ids.to(device)
    mask = encoded.attention_mask.to(device)
    # Autoregressive sampling; no gradients needed at inference time.
    with torch.no_grad():
        tokens = text2midi_model.generate(ids, mask, max_len=max_len, temperature=temperature)
    # Decode the generated token ids back into a MIDI object and save it.
    decoded_midi = midi_tokenizer.decode(tokens[0].tolist())
    decoded_midi.dump_midi(output_path)
    return output_path
def midi_to_wav(midi_path: str, wav_path: str, sample_rate: int = 44100) -> bool:
    """Render a MIDI file to WAV with the FluidSynth command-line tool.

    Args:
        midi_path: path of the input MIDI file.
        wav_path: path where the rendered WAV is written.
        sample_rate: output sample rate in Hz (default 44100).

    Returns:
        True only when a SoundFont is available, FluidSynth exits cleanly
        and the output file exists; False otherwise (never raises).
    """
    if not SOUNDFONT_PATH:
        return False
    cmd = [
        "fluidsynth",
        "-ni",            # no shell, no MIDI-in: batch render only
        "-F", wav_path,   # fast-render to this file
        "-r", str(sample_rate),
        SOUNDFONT_PATH,
        midi_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
    except FileNotFoundError:
        # fluidsynth binary is not installed on this host.
        print("FluidSynth executable not found")
        return False
    except subprocess.TimeoutExpired:
        # Original code let this propagate; the bool contract says fail soft.
        print("FluidSynth timed out")
        return False
    if result.returncode != 0:
        print(f"FluidSynth error: {result.stderr}")
        return False
    return os.path.exists(wav_path)
def generate_music(prompt: str):
    """Generate music from a text prompt.

    Args:
        prompt: free-text music description; empty/whitespace-only (or None)
            prompts short-circuit to None.

    Returns:
        A ``(sample_rate, float32 numpy array)`` tuple suitable for
        ``gr.Audio(type="numpy")``, or None when the prompt is empty or any
        stage of the pipeline fails.
    """
    if not prompt or not prompt.strip():
        return None
    midi_path = None
    wav_path = None
    try:
        # Temporary files for the intermediate MIDI and the rendered WAV.
        midi_fd, midi_path = tempfile.mkstemp(suffix='.mid')
        os.close(midi_fd)
        wav_fd, wav_path = tempfile.mkstemp(suffix='.wav')
        os.close(wav_fd)
        # Generate MIDI (reduced max_len for faster testing)
        if MODEL_LOADED:
            generate_midi_with_model(prompt, midi_path, max_len=128, temperature=0.9)
        else:
            # Fallback: a C-major scale, one note per prompt word (max 8 notes).
            from midiutil import MIDIFile
            midi = MIDIFile(1)
            midi.addTempo(0, 0, 120)
            notes = [60, 62, 64, 65, 67, 69, 71, 72]
            for i, note in enumerate(notes[:min(len(prompt.split()), 8)]):
                midi.addNote(0, 0, note, i, 1, 100)
            with open(midi_path, "wb") as f:
                midi.writeFile(f)
        # Render to audio; bail out if there is no SoundFont or rendering failed.
        if SOUNDFONT_PATH and midi_to_wav(midi_path, wav_path):
            sample_rate, audio_data = wavfile.read(wav_path)
            # Normalize integer PCM to float32 in [-1, 1) for Gradio.
            if audio_data.dtype == np.int16:
                audio_data = audio_data.astype(np.float32) / 32768.0
            elif audio_data.dtype == np.int32:
                audio_data = audio_data.astype(np.float32) / 2147483648.0
            return (sample_rate, audio_data)
        return None
    except Exception:
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Best-effort cleanup; was a bare `except:` (swallowed KeyboardInterrupt
        # etc.) — only filesystem errors should be ignored here.
        for path in (midi_path, wav_path):
            if path and os.path.exists(path):
                try:
                    os.unlink(path)
                except OSError:
                    pass
# Create simple Gradio Interface
demo = gr.Interface(
    fn=generate_music,
    inputs=gr.Textbox(
        label="Music Prompt",
        placeholder="A cheerful pop song with piano and drums in C major",
        lines=2
    ),
    # generate_music returns (sample_rate, float32 array), hence type="numpy".
    outputs=gr.Audio(label="Generated Music", type="numpy"),
    title="VR Game Music Generator",
    description="Generate music from text descriptions using AI. Enter a prompt describing the music you want.",
    examples=[
        ["A cheerful pop song with piano and drums"],
        ["An energetic electronic trance track at 138 BPM"],
        ["A slow emotional classical piece with violin"],
        ["Epic cinematic soundtrack with dark atmosphere"],
    ],
    # Don't pre-run the examples at startup: generation is slow and depends
    # on the model/soundfont being available.
    cache_examples=False,
    # NOTE(review): allow_flagging was deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm against the pinned gradio version before changing.
    allow_flagging="never"
)
# Launch
# show_api=True exposes the API endpoint so external clients (e.g. the VR
# game) can call the generator programmatically.
demo.launch(show_api=True)