File size: 3,412 Bytes
6999b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import base64
import io
import os
import tempfile
from typing import Any, Dict, List

import numpy as np
import soundfile as sf
import torch

from songgen import (
    SongGenMixedForConditionalGeneration,
    SongGenProcessor,
    VoiceBpeTokenizer,
)

class EndpointHandler:
    """Inference endpoint wrapper around the SongGen mixed model.

    The model and processor are loaded once at construction. Each call
    generates one audio clip and returns it as a base64-encoded WAV file.
    """

    def __init__(self, path=""):
        """Load the model and processor.

        Args:
            path: Model location passed to ``from_pretrained`` and
                ``SongGenProcessor``. Falls back to
                ``"LiuZH-19/SongGen_mixed_pro"`` when empty.
        """
        # Prefer the first GPU when one is visible; otherwise run on CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_path = path or "LiuZH-19/SongGen_mixed_pro"

        print(f"Loading model from {self.model_path} on {self.device}")
        self.model = SongGenMixedForConditionalGeneration.from_pretrained(
            self.model_path,
            attn_implementation='sdpa'
        ).to(self.device)

        self.processor = SongGenProcessor(self.model_path, self.device)
        # Sample rate used to encode the WAV and echoed back in every response.
        self.sampling_rate = self.model.config.sampling_rate
        print("Model and processor loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a song from a text description and lyrics.

        Args:
            data: Request payload with the following keys:
                - text: Text description for music generation (default: "").
                - lyrics: Lyrics for the song (default: "").
                - ref_voice_base64: Base64-encoded reference voice audio
                  (optional).
                - separate: Whether to separate vocal from reference
                  (default: True).
                - do_sample: Whether to use sampling for generation. A
                  top-level value wins over a ``do_sample`` entry inside
                  ``generation_params``; if neither is given, True is used.
                - generation_params: Extra keyword arguments forwarded to
                  ``model.generate`` (optional; ``None`` is treated as empty).

        Returns:
            ``{"audio_base64": str, "sampling_rate": self.sampling_rate}``,
            where ``audio_base64`` is a base64-encoded WAV file.

        Raises:
            binascii.Error: if ``ref_voice_base64`` is malformed base64
                (e.g. incorrect padding).
        """
        text = data.get("text", "")
        lyrics = data.get("lyrics", "")
        ref_voice_base64 = data.get("ref_voice_base64", None)
        separate = data.get("separate", True)

        # Copy so the caller's dict is never mutated, and tolerate an explicit
        # null. Resolve do_sample here so it reaches generate() exactly once.
        # Passing it both as a keyword and via **generation_params would raise
        # "got multiple values for keyword argument 'do_sample'".
        gen_kwargs = dict(data.get("generation_params") or {})
        gen_kwargs["do_sample"] = data.get(
            "do_sample", gen_kwargs.get("do_sample", True)
        )

        ref_voice_path = None
        try:
            if ref_voice_base64:
                # Decode before creating the file so bad input leaves nothing behind.
                audio_bytes = base64.b64decode(ref_voice_base64)
                # Use a unique per-request file. A fixed path would let concurrent
                # requests overwrite each other's reference audio.
                with tempfile.NamedTemporaryFile(
                    suffix=".wav", delete=False
                ) as tmp:
                    ref_voice_path = tmp.name  # record first so cleanup always sees it
                    tmp.write(audio_bytes)

            model_inputs = self.processor(
                text=text,
                lyrics=lyrics,
                ref_voice_path=ref_voice_path,
                separate=separate,
            )

            with torch.no_grad():
                generation = self.model.generate(**model_inputs, **gen_kwargs)
        finally:
            # Remove the temp file even when processing or generation fails.
            if ref_voice_path is not None:
                try:
                    os.remove(ref_voice_path)
                except FileNotFoundError:
                    pass

        # .numpy() cannot convert bfloat16, and soundfile cannot write float16.
        # Upcast those; float32 output is passed through unchanged.
        if generation.dtype in (torch.float16, torch.bfloat16):
            generation = generation.float()
        audio_arr = generation.cpu().numpy().squeeze()

        # Encode the waveform as an in-memory WAV file, then base64 for JSON transport.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_arr, self.sampling_rate, format='WAV')
        audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')

        return {
            "audio_base64": audio_base64,
            "sampling_rate": self.sampling_rate
        }