File size: 3,412 Bytes
6999b68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import base64
import io
import os
import tempfile
from typing import Any, Dict, List

import numpy as np
import soundfile as sf
import torch

from songgen import (
    VoiceBpeTokenizer,
    SongGenMixedForConditionalGeneration,
    SongGenProcessor
)
class EndpointHandler:
    """Inference endpoint wrapper around the SongGen mixed-mode model.

    Loads the model and processor once at construction time, then serves
    text+lyrics -> song generation requests through ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load model and processor.

        Args:
            path: Local path or hub id of the checkpoint. Falls back to the
                public ``LiuZH-19/SongGen_mixed_pro`` model when empty.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_path = path or "LiuZH-19/SongGen_mixed_pro"
        print(f"Loading model from {self.model_path} on {self.device}")
        self.model = SongGenMixedForConditionalGeneration.from_pretrained(
            self.model_path,
            attn_implementation='sdpa'
        ).to(self.device)
        self.processor = SongGenProcessor(self.model_path, self.device)
        self.sampling_rate = self.model.config.sampling_rate
        print("Model and processor loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a song and return it as base64-encoded WAV audio.

        Args:
            data: Dictionary with the following keys:
                - text: Text description for music generation
                - lyrics: Lyrics for the song
                - ref_voice_base64: Base64 encoded reference voice audio (optional)
                - separate: Whether to separate vocal from reference (default: True)
                - do_sample: Whether to use sampling for generation (default: True)
                - generation_params: Additional parameters for generation (optional)

        Returns:
            Dictionary with ``audio_base64`` (WAV bytes, base64-encoded) and
            ``sampling_rate``.
        """
        # Extract params from the request
        text = data.get("text", "")
        lyrics = data.get("lyrics", "")
        ref_voice_base64 = data.get("ref_voice_base64", None)
        separate = data.get("separate", True)
        do_sample = data.get("do_sample", True)
        generation_params = data.get("generation_params", {})

        ref_voice_path = None
        try:
            if ref_voice_base64:
                # Decode the reference audio to a uniquely-named temp file.
                # (A fixed "/tmp/reference_audio.wav" path raced between
                # concurrent requests and was trivially predictable.)
                audio_bytes = base64.b64decode(ref_voice_base64)
                with tempfile.NamedTemporaryFile(
                    suffix=".wav", delete=False
                ) as tmp:
                    tmp.write(audio_bytes)
                    ref_voice_path = tmp.name

            # Process inputs
            model_inputs = self.processor(
                text=text,
                lyrics=lyrics,
                ref_voice_path=ref_voice_path,
                separate=separate
            )

            # Generate audio
            with torch.no_grad():
                generation = self.model.generate(
                    **model_inputs,
                    do_sample=do_sample,
                    **generation_params
                )
        finally:
            # Always remove the temp file — previously it leaked whenever
            # processing or generation raised.
            if ref_voice_path and os.path.exists(ref_voice_path):
                os.remove(ref_voice_path)

        # Convert to a 1-D float array on the CPU
        audio_arr = generation.cpu().numpy().squeeze()

        # Encode as WAV in memory and return base64 text
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_arr, self.sampling_rate, format='WAV')
        audio_buffer.seek(0)
        audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')

        return {
            "audio_base64": audio_base64,
            "sampling_rate": self.sampling_rate
        }