File size: 13,924 Bytes
750bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for VoxCPM
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import traceback
import numpy as np
import soundfile as sf
from voxcpm import VoxCPM

import backend_pb2
import backend_pb2_grpc
import torch

import grpc

def is_float(s):
    """Return True if ``float(s)`` accepts the string, False if it raises ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    else:
        return True

def is_int(s):
    """Return True if ``int(s)`` accepts the string, False if it raises ValueError."""
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True

# Number of seconds in one day (24 h * 60 min * 60 s)
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Worker count for the gRPC thread pool: read from the PYTHON_GRPC_MAX_WORKERS
# environment variable when it is set, otherwise 1
MAX_WORKERS = int(os.getenv('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service.

    State written by LoadModel and read by TTS / TTSStream:
      * self.device  - device name detected at load time (only used for logging
                       here; it is not passed to VoxCPM.from_pretrained)
      * self.options - dict built from the request's "name:value" option strings
      * self.model   - object returned by VoxCPM.from_pretrained
    """

    # Options holding free text. Their value is stored verbatim so that a
    # transcript such as "2024" or "true" is not turned into a number/bool.
    _TEXT_OPTIONS = frozenset({"prompt_text"})

    def Health(self, request, context):
        """Liveness probe: always reply with the bytes b"OK"."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    @staticmethod
    def _parse_option_value(value):
        """Convert an option string to int, float or bool when it looks like one.

        The int test must come before the float test: every string accepted by
        int() is also accepted by float(), so testing float first would store
        "10" as 10.0 and make the int branch unreachable.
        Strings that match none of the forms are returned unchanged.
        """
        if is_int(value):
            return int(value)
        if is_float(value):
            return float(value)
        lowered = value.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        return value

    @staticmethod
    def _model_base_dirs(request):
        """Return the directories against which relative audio paths are resolved.

        Order: directory of request.ModelFile, then request.ModelPath. Fields
        that are missing from the request type or empty are skipped.
        """
        base_dirs = []
        model_file = getattr(request, 'ModelFile', None)
        if model_file:
            base_dirs.append(os.path.dirname(model_file))
        model_path = getattr(request, 'ModelPath', None)
        if model_path:
            base_dirs.append(model_path)
        return base_dirs

    @classmethod
    def _resolve_prompt_wav(cls, request):
        """Pick the reference audio file used for voice cloning, or None.

        request.AudioPath wins over request.voice when both are set.
        An absolute AudioPath is used as is; a relative one is joined to the
        first available base directory (no existence check is made).
        request.voice is only used if it names an existing file, either
        directly or under one of the base directories, tried in order.
        """
        base_dirs = cls._model_base_dirs(request)

        audio_path = getattr(request, 'AudioPath', None)
        if audio_path:
            if os.path.isabs(audio_path) or not base_dirs:
                return audio_path
            return os.path.join(base_dirs[0], audio_path)

        voice = getattr(request, 'voice', None)
        if voice:
            candidates = [voice] + [os.path.join(base, voice) for base in base_dirs]
            for candidate in candidates:
                if os.path.exists(candidate):
                    return candidate
        return None

    def _generate_kwargs(self, request):
        """Build the keyword arguments shared by generate() and generate_streaming().

        Values come from self.options with the defaults shown below. Numeric
        settings are coerced to the type of their default so that an option
        given as "10" or "10.0" both yield an integer step count.

        Raises:
            ValueError / TypeError: if a numeric option cannot be converted.
        """
        options = self.options
        return {
            "text": request.text.strip(),
            "prompt_wav_path": self._resolve_prompt_wav(request),
            "prompt_text": options.get("prompt_text"),
            "cfg_value": float(options.get("cfg_value", 2.0)),
            "inference_timesteps": int(options.get("inference_timesteps", 10)),
            "normalize": options.get("normalize", False),
            "denoise": options.get("denoise", False),
            "retry_badcase": options.get("retry_badcase", True),
            "retry_badcase_max_times": int(options.get("retry_badcase_max_times", 3)),
            "retry_badcase_ratio_threshold": float(options.get("retry_badcase_ratio_threshold", 6.0)),
        }

    def LoadModel(self, request, context):
        """Detect the device, parse the load options and load the VoxCPM model.

        Reads request.CUDA, request.Options (strings of the form
        "optname:optvalue") and request.Model (defaults to
        "openbmb/VoxCPM1.5" when empty).

        Returns:
            backend_pb2.Result with success=True once the model is loaded, or
            success=False with a message when CUDA was requested but is
            unavailable or when loading raised an exception.
        """
        # Get device
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
        # MPS takes precedence whenever torch reports it as available.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        if not cuda_available and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        self.device = device

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        self.options = {}
        for opt in request.Options:
            # Split only on the first colon so values may contain colons.
            key, separator, value = opt.partition(":")
            if not separator:
                continue
            if key not in self._TEXT_OPTIONS:
                value = self._parse_option_value(value)
            self.options[key] = value

        # Get model path from request
        model_path = request.Model or "openbmb/VoxCPM1.5"

        try:
            print(f"Loading model from {model_path}", file=sys.stderr)
            self.model = VoxCPM.from_pretrained(model_path)
            print(f"Model loaded successfully on device: {self.device}", file=sys.stderr)
        except Exception as err:
            # Log the failure server-side as well; the Result only carries a summary.
            print(f"Error in LoadModel: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        """Synthesize request.text and write the audio to request.dst.

        When the "streaming" option is truthy the audio is produced with
        generate_streaming() and the chunks are concatenated; otherwise
        generate() is called once. The file is written with soundfile using
        the sample rate read from self.model.tts_model.

        Returns:
            backend_pb2.Result with success=True, or success=False and an
            error message if any step raised.
        """
        try:
            generate_kwargs = self._generate_kwargs(request)
            use_streaming = self.options.get("streaming", False)

            cfg_value = generate_kwargs["cfg_value"]
            inference_timesteps = generate_kwargs["inference_timesteps"]
            print(f"Generating audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, streaming: {use_streaming}", file=sys.stderr)

            # Generate audio
            if use_streaming:
                chunks = list(self.model.generate_streaming(**generate_kwargs))
                if not chunks:
                    # np.concatenate([]) would fail with an unrelated-looking message.
                    raise RuntimeError("streaming generation produced no audio chunks")
                wav = np.concatenate(chunks)
            else:
                wav = self.model.generate(**generate_kwargs)

            # Get sample rate from model
            sample_rate = self.model.tts_model.sample_rate

            # Save output
            sf.write(request.dst, wav, sample_rate)
            print(f"Saved output to {request.dst}", file=sys.stderr)

        except Exception as err:
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(success=True)

    def TTSStream(self, request, context):
        """Synthesize request.text and stream the audio back chunk by chunk.

        Yields:
            First a backend_pb2.Reply whose ``message`` is the UTF-8 JSON
            document {"sample_rate": <int>}; then one Reply per generated
            chunk whose ``audio`` holds the chunk scaled by 32767, clipped to
            the int16 range and serialized with ndarray.tobytes().
            If anything raises, a final Reply with ``message`` set to
            "Error: <exception>" is yielded instead of propagating.
        """
        import json

        try:
            generate_kwargs = self._generate_kwargs(request)

            # Get sample rate from model (sent to the client before any audio)
            sample_rate = self.model.tts_model.sample_rate

            cfg_value = generate_kwargs["cfg_value"]
            inference_timesteps = generate_kwargs["inference_timesteps"]
            print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr)

            # Send the sample rate as the first message, JSON-encoded in the
            # message field, e.g. {"sample_rate": 16000}
            sample_rate_info = json.dumps({"sample_rate": int(sample_rate)})
            yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8'))

            # Stream audio chunks
            for chunk in self.model.generate_streaming(**generate_kwargs):
                # Scale to int16 PCM, clipping so out-of-range samples saturate
                # instead of wrapping around on the cast.
                chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16)
                yield backend_pb2.Reply(audio=chunk_int16.tobytes())

        except Exception as err:
            print(f"Error in TTSStream: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Yield an error reply
            yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8'))

def serve(address):
    """Start the gRPC backend on *address* and keep the process alive.

    The server uses a thread pool of MAX_WORKERS workers and 50MB message
    limits. SIGINT and SIGTERM stop the server immediately and exit with
    status 0; otherwise the function sleeps in a loop and never returns.
    """
    size_limit = 50 * 1024 * 1024  # 50MB
    channel_options = [
        ('grpc.max_message_length', size_limit),
        ('grpc.max_send_message_length', size_limit),
        ('grpc.max_receive_message_length', size_limit),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    grpc_server = grpc.server(worker_pool, options=channel_options)
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), grpc_server)
    grpc_server.add_insecure_port(address)
    grpc_server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # Stop without a grace period, then terminate the process.
        print("Received termination signal. Shutting down...")
        grpc_server.stop(0)
        sys.exit(0)

    # Install the same handler for SIGINT and SIGTERM
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, signal_handler)

    # Block the main thread; the gRPC workers run in the thread pool.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        grpc_server.stop(0)

if __name__ == "__main__":
    # Script entry point: read the bind address from the command line and serve.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    cli_args = cli.parse_args()

    serve(cli_args.addr)