from typing import Dict, List, Any
import torch
import os
import tempfile
import numpy as np
import soundfile as sf
import base64
import io
from songgen import (
    VoiceBpeTokenizer,
    SongGenMixedForConditionalGeneration,
    SongGenProcessor,
)


class EndpointHandler:
    """Inference-endpoint wrapper around the SongGen mixed-pro model.

    Loads the model and processor once at construction so each request
    pays only the generation cost; requests and responses exchange audio
    as base64-encoded WAV bytes.
    """

    def __init__(self, path=""):
        """Load model and processor during initialization.

        Args:
            path: Local or hub path to the model; falls back to the
                public "LiuZH-19/SongGen_mixed_pro" checkpoint when empty.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_path = path or "LiuZH-19/SongGen_mixed_pro"
        print(f"Loading model from {self.model_path} on {self.device}")
        self.model = SongGenMixedForConditionalGeneration.from_pretrained(
            self.model_path,
            attn_implementation='sdpa'
        ).to(self.device)
        self.processor = SongGenProcessor(self.model_path, self.device)
        self.sampling_rate = self.model.config.sampling_rate
        print("Model and processor loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a song for one request and return it as base64 WAV.

        Args:
            data: Dictionary with the following keys:
                - text: Text description for music generation
                - lyrics: Lyrics for the song
                - ref_voice_base64: Base64 encoded reference voice audio (optional)
                - separate: Whether to separate vocal from reference (default: True)
                - do_sample: Whether to use sampling for generation (default: True)
                - generation_params: Additional parameters for generation (optional)

        Returns:
            Dictionary with "audio_base64" (base64-encoded WAV data) and
            "sampling_rate".
        """
        # Extract params from the request
        text = data.get("text", "")
        lyrics = data.get("lyrics", "")
        ref_voice_base64 = data.get("ref_voice_base64", None)
        separate = data.get("separate", True)
        do_sample = data.get("do_sample", True)
        generation_params = data.get("generation_params", {})

        # Handle reference audio if provided.  The processor expects a
        # filesystem path, so the decoded bytes are written to a uniquely
        # named temp file (a fixed /tmp path would let concurrent requests
        # clobber each other's reference audio).
        ref_voice_path = None
        if ref_voice_base64:
            audio_bytes = base64.b64decode(ref_voice_base64)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(audio_bytes)
                ref_voice_path = tmp.name

        try:
            # Process inputs
            model_inputs = self.processor(
                text=text,
                lyrics=lyrics,
                ref_voice_path=ref_voice_path,
                separate=separate
            )

            # Generate audio
            with torch.no_grad():
                generation = self.model.generate(
                    **model_inputs,
                    do_sample=do_sample,
                    **generation_params
                )

            # Convert to audio array
            audio_arr = generation.cpu().numpy().squeeze()

            # Save to in-memory WAV, then base64-encode for transport
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_arr, self.sampling_rate, format='WAV')
            audio_buffer.seek(0)
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
        finally:
            # Clean up the temp file even when processing or generation
            # raised — the original only removed it on the success path.
            if ref_voice_path and os.path.exists(ref_voice_path):
                os.remove(ref_voice_path)

        return {
            "audio_base64": audio_base64,
            "sampling_rate": self.sampling_rate
        }