# NOTE: HuggingFace Spaces status banner ("Spaces / Sleeping") stripped from the
# capture — it is page chrome, not part of the application code.
| """ | |
| VR Music Generator - HuggingFace Spaces Version | |
| Generates music from text descriptions using the text2midi AI model. | |
| """ | |
| import gradio as gr | |
| import torch | |
| import subprocess | |
| import os | |
| import sys | |
| import pickle | |
| import tempfile | |
| import numpy as np | |
| import scipy.io.wavfile as wavfile | |
| from huggingface_hub import hf_hub_download | |
# Make the vendored text2midi repo importable (model.transformer_model lives there).
sys.path.insert(0, "text2midi_repo")

# HuggingFace Hub repository holding the checkpoint and the REMI vocabulary pickle.
repo_id = "amaai-lab/text2midi"

# Prefer GPU when available; the model also runs (slowly) on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Global model state, populated by load_text2midi_model(); None until loaded.
text2midi_model = None
midi_tokenizer = None
text_tokenizer = None
def load_text2midi_model():
    """Download and initialize the text2midi model and its tokenizers.

    Populates the module-level globals ``text2midi_model``,
    ``midi_tokenizer`` and ``text_tokenizer``.

    Returns:
        bool: True when the model is ready; False when loading failed, in
        which case the app falls back to simple MIDI generation.
    """
    global text2midi_model, midi_tokenizer, text_tokenizer
    try:
        from model.transformer_model import Transformer
        from transformers import T5Tokenizer

        print("Loading text2midi model...")
        # Fetch the checkpoint and REMI vocabulary from the Hub (cached locally).
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
        print(f"Model path: {model_path}")
        print(f"Tokenizer path: {tokenizer_path}")

        # NOTE(review): pickle.load on a downloaded artifact — trusts the Hub repo.
        with open(tokenizer_path, "rb") as f:
            midi_tokenizer = pickle.load(f)
        vocab_size = len(midi_tokenizer)
        print(f"Vocab size: {vocab_size}")

        # Architecture hyper-parameters must match the published checkpoint.
        text2midi_model = Transformer(
            vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device
        )
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        text2midi_model.load_state_dict(state_dict)
        text2midi_model.to(device)
        text2midi_model.eval()

        # T5 tokenizer encodes the text prompt for the encoder side of the model.
        text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
        print("Text2midi model loaded successfully!")
        return True
    except Exception as e:
        # Any failure (network, missing repo, bad checkpoint) disables the model.
        print(f"Warning: Could not load text2midi model: {e}")
        import traceback
        traceback.print_exc()
        print("Falling back to simple MIDI generation...")
        return False
# Attempt to load the AI model once at import time; on failure the app uses
# the simple-scale fallback path in generate_music().
MODEL_LOADED = load_text2midi_model()
def find_soundfont():
    """Return the path of the first General-MIDI SoundFont found, or None.

    Probes a handful of conventional Linux install locations plus the
    current working directory.
    """
    candidates = (
        "/usr/share/sounds/sf2/FluidR3_GM.sf2",
        "/usr/share/soundfonts/FluidR3_GM.sf2",
        "/usr/share/sounds/sf2/default-GM.sf2",
        "FluidR3_GM.sf2",
    )
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)
# Located once at startup; when None, midi_to_wav() refuses to render and
# generate_music() returns None instead of audio.
SOUNDFONT_PATH = find_soundfont()
print(f"SoundFont: {SOUNDFONT_PATH or 'Not found'}")
def generate_midi_with_model(prompt: str, output_path: str, max_len: int = 512, temperature: float = 0.9):
    """Run the text2midi transformer on *prompt* and write a MIDI file.

    Args:
        prompt: Natural-language description of the desired music.
        output_path: Destination path for the generated .mid file.
        max_len: Maximum number of MIDI tokens to generate.
        temperature: Sampling temperature passed to the model.

    Returns:
        The ``output_path`` that was written.
    """
    # Encode the prompt with the T5 tokenizer, then move tensors to the model device.
    encoded = text_tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
    ids = encoded.input_ids.to(device)
    mask = encoded.attention_mask.to(device)

    # Autoregressive sampling; no gradients needed at inference time.
    with torch.no_grad():
        tokens = text2midi_model.generate(ids, mask, max_len=max_len, temperature=temperature)

    # Decode the token ids back into a MIDI object and dump it to disk.
    midi_obj = midi_tokenizer.decode(tokens[0].tolist())
    midi_obj.dump_midi(output_path)
    return output_path
def midi_to_wav(midi_path: str, wav_path: str, sample_rate: int = 44100) -> bool:
    """Render a MIDI file to WAV with the FluidSynth command-line tool.

    Args:
        midi_path: Input .mid file to render.
        wav_path: Output .wav file to write.
        sample_rate: Output sample rate in Hz.

    Returns:
        True when the WAV file was produced; False otherwise (no SoundFont
        available, fluidsynth missing, rendering failed, or timed out).
    """
    if not SOUNDFONT_PATH:
        return False
    try:
        result = subprocess.run([
            "fluidsynth",
            "-ni",                    # non-interactive
            "-F", wav_path,           # render to file instead of audio device
            "-r", str(sample_rate),
            SOUNDFONT_PATH,
            midi_path,
        ], capture_output=True, text=True, timeout=120)
    except FileNotFoundError:
        # BUGFIX: previously an uncaught crash when fluidsynth isn't installed.
        print("FluidSynth executable not found")
        return False
    except subprocess.TimeoutExpired:
        # BUGFIX: previously an uncaught crash when rendering exceeded 120 s.
        print("FluidSynth timed out")
        return False
    if result.returncode != 0:
        print(f"FluidSynth error: {result.stderr}")
        return False
    return os.path.exists(wav_path)
def generate_music(prompt: str):
    """Generate music from a text prompt.

    Args:
        prompt: Natural-language description of the desired music.

    Returns:
        A ``(sample_rate, float32 numpy array)`` tuple suitable for
        ``gr.Audio(type="numpy")``, or None on empty input or any failure.
    """
    if not prompt or not prompt.strip():
        return None
    midi_path = None
    wav_path = None
    try:
        # mkstemp returns an open fd; close it right away — only the path is needed.
        midi_fd, midi_path = tempfile.mkstemp(suffix='.mid')
        os.close(midi_fd)
        wav_fd, wav_path = tempfile.mkstemp(suffix='.wav')
        os.close(wav_fd)

        if MODEL_LOADED:
            # Reduced max_len keeps generation fast (original comment: for testing).
            generate_midi_with_model(prompt, midi_path, max_len=128, temperature=0.9)
        else:
            # Fallback: a C-major scale, one note per prompt word (capped at 8).
            from midiutil import MIDIFile
            midi = MIDIFile(1)
            midi.addTempo(0, 0, 120)
            notes = [60, 62, 64, 65, 67, 69, 71, 72]
            for i, note in enumerate(notes[:min(len(prompt.split()), 8)]):
                midi.addNote(0, 0, note, i, 1, 100)
            with open(midi_path, "wb") as f:
                midi.writeFile(f)

        # Render to audio; without a SoundFont there is nothing playable to return.
        if SOUNDFONT_PATH and midi_to_wav(midi_path, wav_path):
            sample_rate, audio_data = wavfile.read(wav_path)
            # Normalize integer PCM to float32 in [-1, 1) for Gradio compatibility.
            if audio_data.dtype == np.int16:
                audio_data = audio_data.astype(np.float32) / 32768.0
            elif audio_data.dtype == np.int32:
                audio_data = audio_data.astype(np.float32) / 2147483648.0
            return (sample_rate, audio_data)
        return None
    except Exception:
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Best-effort temp-file cleanup. BUGFIX: the bare `except:` clauses here
        # previously swallowed everything (including KeyboardInterrupt); only
        # filesystem errors are ignored now.
        for path in (midi_path, wav_path):
            if path and os.path.exists(path):
                try:
                    os.unlink(path)
                except OSError:
                    pass
# Minimal Gradio UI: one text prompt in, one numpy audio clip out.
demo = gr.Interface(
    fn=generate_music,
    inputs=gr.Textbox(
        label="Music Prompt",
        placeholder="A cheerful pop song with piano and drums in C major",
        lines=2
    ),
    outputs=gr.Audio(label="Generated Music", type="numpy"),
    title="VR Game Music Generator",
    description="Generate music from text descriptions using AI. Enter a prompt describing the music you want.",
    examples=[
        ["A cheerful pop song with piano and drums"],
        ["An energetic electronic trance track at 138 BPM"],
        ["A slow emotional classical piece with violin"],
        ["Epic cinematic soundtrack with dark atmosphere"],
    ],
    # Running the examples at build time would trigger model inference — skip.
    cache_examples=False,
    # NOTE(review): deprecated in Gradio 4.x (renamed to flagging_mode) — confirm
    # the pinned Gradio version before changing.
    allow_flagging="never"
)

# Launch the app; show_api=True exposes the auto-generated programmatic API docs.
demo.launch(show_api=True)