Files changed (1) hide show
  1. handler.py +93 -0
handler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import io
import os
import tempfile
from typing import Any, Dict, List

import numpy as np
import soundfile as sf
import torch

from songgen import (
    VoiceBpeTokenizer,
    SongGenMixedForConditionalGeneration,
    SongGenProcessor
)
13
+
14
class EndpointHandler:
    """Inference endpoint handler for SongGen text+lyrics-to-music generation.

    Loads the SongGen mixed-mode model once at construction and serves
    generation requests via ``__call__``, returning the generated song as
    base64-encoded WAV bytes.
    """

    def __init__(self, path: str = ""):
        """Load model and processor.

        Args:
            path: Local path or hub id of the model. Falls back to the
                public ``LiuZH-19/SongGen_mixed_pro`` checkpoint when empty.
        """
        # Prefer the first CUDA device when available; CPU otherwise.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_path = path or "LiuZH-19/SongGen_mixed_pro"

        print(f"Loading model from {self.model_path} on {self.device}")
        self.model = SongGenMixedForConditionalGeneration.from_pretrained(
            self.model_path,
            attn_implementation='sdpa'
        ).to(self.device)

        self.processor = SongGenProcessor(self.model_path, self.device)
        # Output sampling rate is dictated by the model config.
        self.sampling_rate = self.model.config.sampling_rate
        print("Model and processor loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            data: Dictionary with the following keys:
                - text: Text description for music generation
                - lyrics: Lyrics for the song
                - ref_voice_base64: Base64 encoded reference voice audio (optional)
                - separate: Whether to separate vocal from reference (default: True)
                - do_sample: Whether to use sampling for generation (default: True)
                - generation_params: Additional parameters for generation (optional)

        Returns:
            Dictionary with audio data encoded in base64
            (keys: ``audio_base64``, ``sampling_rate``).
        """
        # Extract params from the request.
        text = data.get("text", "")
        lyrics = data.get("lyrics", "")
        ref_voice_base64 = data.get("ref_voice_base64", None)
        separate = data.get("separate", True)
        do_sample = data.get("do_sample", True)
        generation_params = data.get("generation_params", {})

        ref_voice_path = None
        try:
            if ref_voice_base64:
                # Decode base64 audio and persist it to a *unique* temp file.
                # A fixed path (e.g. /tmp/reference_audio.wav) would collide
                # between concurrent requests; NamedTemporaryFile avoids that.
                audio_bytes = base64.b64decode(ref_voice_base64)
                with tempfile.NamedTemporaryFile(
                    suffix=".wav", delete=False
                ) as tmp:
                    tmp.write(audio_bytes)
                    ref_voice_path = tmp.name

            # Build model inputs (tokenized text/lyrics, optional voice prompt).
            model_inputs = self.processor(
                text=text,
                lyrics=lyrics,
                ref_voice_path=ref_voice_path,
                separate=separate
            )

            # Generate audio; no gradients needed at inference time.
            with torch.no_grad():
                generation = self.model.generate(
                    **model_inputs,
                    do_sample=do_sample,
                    **generation_params
                )

            # Move to host memory and drop singleton dims for soundfile.
            audio_arr = generation.cpu().numpy().squeeze()

            # Serialize as WAV in memory and base64-encode for the response.
            audio_buffer = io.BytesIO()
            sf.write(audio_buffer, audio_arr, self.sampling_rate, format='WAV')
            audio_buffer.seek(0)
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')

            return {
                "audio_base64": audio_base64,
                "sampling_rate": self.sampling_rate
            }
        finally:
            # Always clean up the temp reference file, even when processing
            # or generation raises (the original only cleaned up on success).
            if ref_voice_path and os.path.exists(ref_voice_path):
                os.remove(ref_voice_path)