|
|
import base64
import io
import os
import tempfile
from typing import Dict, List, Any

import numpy as np
import soundfile as sf
import torch

from songgen import (
    VoiceBpeTokenizer,
    SongGenMixedForConditionalGeneration,
    SongGenProcessor,
)
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for SongGen music generation.

    Loads the SongGen mixed-mode model once at startup and serves
    text+lyrics -> song requests, returning the generated audio as a
    base64-encoded WAV payload.
    """

    def __init__(self, path=""):
        """Load model and processor.

        Args:
            path: Local model directory or hub repo id. Falls back to the
                public "LiuZH-19/SongGen_mixed_pro" checkpoint when empty.
        """
        # Prefer GPU when available; CPU generation works but is very slow.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_path = path or "LiuZH-19/SongGen_mixed_pro"

        print(f"Loading model from {self.model_path} on {self.device}")
        self.model = SongGenMixedForConditionalGeneration.from_pretrained(
            self.model_path,
            attn_implementation='sdpa'
        ).to(self.device)

        self.processor = SongGenProcessor(self.model_path, self.device)
        # Output sample rate comes from the model config so the WAV header
        # always matches what the codec decodes to.
        self.sampling_rate = self.model.config.sampling_rate
        print("Model and processor loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data: Dictionary with the following keys:
                - text: Text description for music generation
                - lyrics: Lyrics for the song
                - ref_voice_base64: Base64 encoded reference voice audio (optional)
                - separate: Whether to separate vocal from reference (default: True)
                - do_sample: Whether to use sampling for generation (default: True)
                - generation_params: Additional parameters for generation (optional)

        Returns:
            Dictionary with audio data encoded in base64
        """
        text = data.get("text", "")
        lyrics = data.get("lyrics", "")
        ref_voice_base64 = data.get("ref_voice_base64", None)
        separate = data.get("separate", True)
        do_sample = data.get("do_sample", True)
        generation_params = data.get("generation_params", {})

        ref_voice_path = None
        try:
            if ref_voice_base64:
                # Write the reference audio to a per-request temp file.
                # A fixed path such as /tmp/reference_audio.wav would race
                # when the endpoint handles concurrent requests.
                audio_bytes = base64.b64decode(ref_voice_base64)
                with tempfile.NamedTemporaryFile(
                    suffix=".wav", delete=False
                ) as tmp:
                    tmp.write(audio_bytes)
                    ref_voice_path = tmp.name

            model_inputs = self.processor(
                text=text,
                lyrics=lyrics,
                ref_voice_path=ref_voice_path,
                separate=separate
            )

            # Inference only — no gradients needed.
            with torch.no_grad():
                generation = self.model.generate(
                    **model_inputs,
                    do_sample=do_sample,
                    **generation_params
                )

            audio_arr = generation.cpu().numpy().squeeze()

            # Encode the waveform as an in-memory WAV, then base64 for JSON.
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_arr, self.sampling_rate, format='WAV')
            audio_buffer.seek(0)
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')

            return {
                "audio_base64": audio_base64,
                "sampling_rate": self.sampling_rate
            }
        finally:
            # Clean up the temp file even when processing or generation
            # raised; the original only removed it on the success path.
            if ref_voice_path and os.path.exists(ref_voice_path):
                os.remove(ref_voice_path)