File size: 9,370 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Pocket TTS
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import traceback
import scipy.io.wavfile
import backend_pb2
import backend_pb2_grpc
import torch
from pocket_tts import TTSModel

import grpc

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False

def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service
    """
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
    
    def LoadModel(self, request, context):
        # Get device
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        # Normalize potential 'mpx' typo to 'mps'
        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr)
            device = "mps"
        
        # Validate mps availability if requested
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr)
            device = "cpu"

        self.device = device

        options = request.Options

        # empty dict
        self.options = {}

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)  # Split only on first colon
            # if value is a number, convert it to the appropriate type
            if is_float(value):
                value = float(value)
            elif is_int(value):
                value = int(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        # Default voice for caching
        self.default_voice_url = self.options.get("default_voice", None)
        self._voice_cache = {}

        try:
            print("Loading Pocket TTS model", file=sys.stderr)
            self.tts_model = TTSModel.load_model()
            print(f"Model loaded successfully. Sample rate: {self.tts_model.sample_rate}", file=sys.stderr)

            # Pre-load default voice if specified
            if self.default_voice_url:
                try:
                    print(f"Pre-loading default voice: {self.default_voice_url}", file=sys.stderr)
                    voice_state = self.tts_model.get_state_for_audio_prompt(self.default_voice_url)
                    self._voice_cache[self.default_voice_url] = voice_state
                    print("Default voice loaded successfully", file=sys.stderr)
                except Exception as e:
                    print(f"Warning: Failed to pre-load default voice: {e}", file=sys.stderr)

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def _get_voice_state(self, voice_input):
        """
        Get voice state from cache or load it.
        voice_input can be:
        - HuggingFace URL (e.g., hf://kyutai/tts-voices/alba-mackenna/casual.wav)
        - Local file path
        - None (use default)
        """
        # Use default if no voice specified
        if not voice_input:
            voice_input = self.default_voice_url

        if not voice_input:
            return None

        # Check cache first
        if voice_input in self._voice_cache:
            return self._voice_cache[voice_input]

        # Load voice state
        try:
            print(f"Loading voice from: {voice_input}", file=sys.stderr)
            voice_state = self.tts_model.get_state_for_audio_prompt(voice_input)
            self._voice_cache[voice_input] = voice_state
            return voice_state
        except Exception as e:
            print(f"Error loading voice from {voice_input}: {e}", file=sys.stderr)
            return None

    def TTS(self, request, context):
        try:
            # Determine voice input
            # Priority: request.voice > AudioPath (from ModelOptions) > default
            voice_input = None
            
            if request.voice:
                voice_input = request.voice
            elif hasattr(request, 'AudioPath') and request.AudioPath:
                # Use AudioPath as voice file
                if os.path.isabs(request.AudioPath):
                    voice_input = request.AudioPath
                elif hasattr(request, 'ModelFile') and request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    voice_input = os.path.join(model_file_base, request.AudioPath)
                elif hasattr(request, 'ModelPath') and request.ModelPath:
                    voice_input = os.path.join(request.ModelPath, request.AudioPath)
                else:
                    voice_input = request.AudioPath

            # Get voice state
            voice_state = self._get_voice_state(voice_input)
            if voice_state is None:
                return backend_pb2.Result(
                    success=False,
                    message=f"Voice not found or failed to load: {voice_input}. Please provide a valid voice URL or file path."
                )

            # Prepare text
            text = request.text.strip()

            if not text:
                return backend_pb2.Result(
                    success=False,
                    message="Text is empty"
                )

            print(f"Generating audio for text: {text[:50]}...", file=sys.stderr)

            # Generate audio
            audio = self.tts_model.generate_audio(voice_state, text)

            # Audio is a 1D torch tensor containing PCM data
            if audio is None or audio.numel() == 0:
                return backend_pb2.Result(
                    success=False,
                    message="No audio generated"
                )

            # Save audio to file
            output_path = request.dst
            if not output_path:
                output_path = "/tmp/pocket-tts-output.wav"

            # Ensure output directory exists
            output_dir = os.path.dirname(output_path)
            if output_dir and not os.path.exists(output_dir):
                os.makedirs(output_dir, exist_ok=True)

            # Convert torch tensor to numpy and save
            audio_numpy = audio.numpy()
            scipy.io.wavfile.write(output_path, self.tts_model.sample_rate, audio_numpy)
            print(f"Saved audio to {output_path}", file=sys.stderr)

        except Exception as err:
            print(f"Error in TTS: {err}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        
        return backend_pb2.Result(success=True)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)