SongGen_mixed_pro / handler.py
harikrishnad1997's picture
Create handler.py
6999b68 verified
raw
history blame
3.41 kB
import base64
import io
import os
import tempfile
from typing import Dict, List, Any

import numpy as np
import soundfile as sf
import torch

from songgen import (
    VoiceBpeTokenizer,
    SongGenMixedForConditionalGeneration,
    SongGenProcessor
)
class EndpointHandler:
    """Inference-endpoint wrapper around the SongGen mixed-pro model.

    Loads the model and processor once at construction time; each request
    is served through ``__call__`` and returns generated audio as a
    base64-encoded WAV payload.
    """

    def __init__(self, path=""):
        """Load model and processor.

        Args:
            path: Local model directory or hub repo id; falls back to the
                published ``LiuZH-19/SongGen_mixed_pro`` checkpoint when empty.
        """
        # Prefer GPU when available; model weights and inputs live on this device.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_path = path or "LiuZH-19/SongGen_mixed_pro"
        print(f"Loading model from {self.model_path} on {self.device}")
        self.model = SongGenMixedForConditionalGeneration.from_pretrained(
            self.model_path,
            attn_implementation='sdpa'
        ).to(self.device)
        self.processor = SongGenProcessor(self.model_path, self.device)
        # Output sample rate is dictated by the model config, not the request.
        self.sampling_rate = self.model.config.sampling_rate
        print("Model and processor loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a song from a text description and lyrics.

        Args:
            data: Dictionary with the following keys:
                - text: Text description for music generation
                - lyrics: Lyrics for the song
                - ref_voice_base64: Base64 encoded reference voice audio (optional)
                - separate: Whether to separate vocal from reference (default: True)
                - do_sample: Whether to use sampling for generation (default: True)
                - generation_params: Additional parameters for generation (optional)

        Returns:
            Dictionary with ``audio_base64`` (WAV bytes, base64-encoded) and
            ``sampling_rate``.
        """
        # Extract params from the request.
        text = data.get("text", "")
        lyrics = data.get("lyrics", "")
        ref_voice_base64 = data.get("ref_voice_base64", None)
        separate = data.get("separate", True)
        do_sample = data.get("do_sample", True)
        generation_params = data.get("generation_params", {})

        ref_voice_path = None
        try:
            if ref_voice_base64:
                # Decode the reference audio into a *unique* temp file.
                # (A hard-coded /tmp/reference_audio.wav would be clobbered
                # by concurrent requests.)
                audio_bytes = base64.b64decode(ref_voice_base64)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    tmp.write(audio_bytes)
                    ref_voice_path = tmp.name

            # Process inputs (tokenization, optional vocal separation).
            model_inputs = self.processor(
                text=text,
                lyrics=lyrics,
                ref_voice_path=ref_voice_path,
                separate=separate
            )

            # Generate audio; inference only, so skip autograd bookkeeping.
            with torch.no_grad():
                generation = self.model.generate(
                    **model_inputs,
                    do_sample=do_sample,
                    **generation_params
                )
        finally:
            # Remove the temp file even when processing/generation raises;
            # cleaning up only on the success path leaks files on errors.
            if ref_voice_path and os.path.exists(ref_voice_path):
                os.remove(ref_voice_path)

        # Move to host memory and drop singleton dims for soundfile.
        audio_arr = generation.cpu().numpy().squeeze()

        # Encode as WAV in memory, then to base64 for the JSON response.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_arr, self.sampling_rate, format='WAV')
        audio_buffer.seek(0)
        audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')

        return {
            "audio_base64": audio_base64,
            "sampling_rate": self.sampling_rate
        }