""" VR Music Generator - HuggingFace Spaces Version Generates music from text descriptions using the text2midi AI model. """ import gradio as gr import torch import subprocess import os import sys import pickle import tempfile import numpy as np import scipy.io.wavfile as wavfile from huggingface_hub import hf_hub_download # Add text2midi model to path sys.path.insert(0, "text2midi_repo") repo_id = "amaai-lab/text2midi" # Detect device device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"Using device: {device}") # Global model variables text2midi_model = None midi_tokenizer = None text_tokenizer = None def load_text2midi_model(): """Load the text2midi model and tokenizers.""" global text2midi_model, midi_tokenizer, text_tokenizer try: from model.transformer_model import Transformer from transformers import T5Tokenizer print("Loading text2midi model...") # Download model files model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin") tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl") print(f"Model path: {model_path}") print(f"Tokenizer path: {tokenizer_path}") # Load MIDI tokenizer with open(tokenizer_path, "rb") as f: midi_tokenizer = pickle.load(f) vocab_size = len(midi_tokenizer) print(f"Vocab size: {vocab_size}") # Initialize and load model text2midi_model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device) text2midi_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True)) text2midi_model.to(device) text2midi_model.eval() # Load T5 tokenizer for text encoding text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base") print("Text2midi model loaded successfully!") return True except Exception as e: print(f"Warning: Could not load text2midi model: {e}") import traceback traceback.print_exc() print("Falling back to simple MIDI generation...") return False # Try to load the model MODEL_LOADED = load_text2midi_model() def find_soundfont(): """Find a SoundFont file on the system.""" common_paths = [ "/usr/share/sounds/sf2/FluidR3_GM.sf2", "/usr/share/soundfonts/FluidR3_GM.sf2", "/usr/share/sounds/sf2/default-GM.sf2", "FluidR3_GM.sf2", ] for path in common_paths: if os.path.exists(path): return path return None SOUNDFONT_PATH = find_soundfont() print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}") def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9): """Generate MIDI using the text2midi model.""" global text2midi_model, midi_tokenizer, text_tokenizer inputs = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True) input_ids = inputs.input_ids.to(device) attention_mask = inputs.attention_mask.to(device) with torch.no_grad(): output = text2midi_model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature) output_list = output[0].tolist() generated_midi = midi_tokenizer.decode(output_list) generated_midi.dump_midi(output_path) return output_path def midi_to_wav(midi_path: str, wav_path: str, sample_rate: int = 44100) -> bool: """Convert MIDI to WAV using FluidSynth.""" if not SOUNDFONT_PATH: return False result = subprocess.run([ "fluidsynth", "-ni", "-F", wav_path, "-r", str(sample_rate), SOUNDFONT_PATH, midi_path, ], capture_output=True, text=True, timeout=120) if result.returncode != 0: print(f"FluidSynth error: {result.stderr}") return False return os.path.exists(wav_path) def generate_music(prompt: str): """Generate music from text prompt. Returns audio as numpy array.""" if not prompt or not prompt.strip(): return None midi_path = None wav_path = None try: # Create temporary files midi_fd, midi_path = tempfile.mkstemp(suffix='.mid') os.close(midi_fd) wav_fd, wav_path = tempfile.mkstemp(suffix='.wav') os.close(wav_fd) # Generate MIDI (reduced max_len for faster testing) if MODEL_LOADED: generate_midi_with_model(prompt, midi_path, max_len=128, temperature=0.9) else: from midiutil import MIDIFile midi = MIDIFile(1) midi.addTempo(0, 0, 120) notes = [60, 62, 64, 65, 67, 69, 71, 72] for i, note in enumerate(notes[:min(len(prompt.split()), 8)]): midi.addNote(0, 0, note, i, 1, 100) with open(midi_path, "wb") as f: midi.writeFile(f) # Convert to WAV if SOUNDFONT_PATH and midi_to_wav(midi_path, wav_path): # Read WAV file and return as numpy array sample_rate, audio_data = wavfile.read(wav_path) # Convert to float32 for Gradio compatibility if audio_data.dtype == np.int16: audio_data = audio_data.astype(np.float32) / 32768.0 elif audio_data.dtype == np.int32: audio_data = audio_data.astype(np.float32) / 2147483648.0 return (sample_rate, audio_data) else: return None except Exception as e: import traceback traceback.print_exc() return None finally: # Clean up temp files if midi_path and os.path.exists(midi_path): try: os.unlink(midi_path) except: pass if wav_path and os.path.exists(wav_path): try: os.unlink(wav_path) except: pass # Create simple Gradio Interface demo = gr.Interface( fn=generate_music, inputs=gr.Textbox( label="Music Prompt", placeholder="A cheerful pop song with piano and drums in C major", lines=2 ), outputs=gr.Audio(label="Generated Music", type="numpy"), title="VR Game Music Generator", description="Generate music from text descriptions using AI. Enter a prompt describing the music you want.", examples=[ ["A cheerful pop song with piano and drums"], ["An energetic electronic trance track at 138 BPM"], ["A slow emotional classical piece with violin"], ["Epic cinematic soundtrack with dark atmosphere"], ], cache_examples=False, allow_flagging="never" ) # Launch demo.launch(show_api=True)