|
|
|
|
|
import asyncio |
|
|
from concurrent import futures |
|
|
import argparse |
|
|
import signal |
|
|
import sys |
|
|
import os |
|
|
import shutil |
|
|
import glob |
|
|
from typing import List |
|
|
import time |
|
|
import tempfile |
|
|
|
|
|
import backend_pb2 |
|
|
import backend_pb2_grpc |
|
|
|
|
|
import grpc |
|
|
from mlx_audio.tts.utils import load_model |
|
|
import soundfile as sf |
|
|
import numpy as np |
|
|
import uuid |
|
|
|
|
|
def is_float(s):
    """Return True if ``s`` can be converted with ``float()``, else False.

    Accepts anything ``float()`` accepts (e.g. ``"1.5"``, ``"3"``, ``" 2 "``,
    ``"inf"``, ``"nan"``). Non-string, non-numeric inputs such as ``None``
    yield False instead of leaking a ``TypeError``.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True
|
|
def is_int(s):
    """Return True if ``s`` can be converted with ``int()``, else False.

    ``"42"`` and ``" 7 "`` qualify; ``"4.2"``, ``"abc"``, and ``None`` do not.
    A ``TypeError`` from non-string input is reported as False rather than
    propagating.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        return False
    return True
|
|
|
|
|
# Number of seconds in one day (24 h * 60 min * 60 s).
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Worker-thread count for the gRPC server's thread pool. It can be overridden
# with the PYTHON_GRPC_MAX_WORKERS environment variable and defaults to 1.
MAX_WORKERS = int(os.getenv("PYTHON_GRPC_MAX_WORKERS", "1"))
|
|
|
|
|
|
|
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.

    State set by LoadModel and read by the other RPCs:
      - self.options:    dict of parsed "key:value" load options
      - self.tts_model:  the object returned by mlx_audio's load_model
      - self.model_path: the model identifier passed to LoadModel
    """

    # Sample rate used when neither the generated segment nor the model
    # reports one. This is the value the backend previously hard-coded.
    DEFAULT_SAMPLE_RATE = 24000

    # Human-readable language names (and the raw single-letter codes) mapped
    # to the one-letter language codes passed to Kokoro as `lang_code`.
    _LANGUAGE_CODES = {
        "american_english": "a",
        "british_english": "b",
        "spanish": "e",
        "french": "f",
        "hindi": "h",
        "italian": "i",
        "portuguese": "p",
        "japanese": "j",
        "mandarin_chinese": "z",
        "a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
    }

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: A reply whose message is b"OK".
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Parses load options and loads a TTS model using MLX-Audio.

        Options arrive as "key:value" strings. Entries without a colon are
        ignored. Values are converted to int, float, or bool when possible,
        and otherwise kept as stripped strings.

        NOTE(review): load_model is called synchronously, so it blocks the
        event loop while the model loads.

        Args:
            request: The load model request (reads .Model and .Options).
            context: The gRPC context.

        Returns:
            backend_pb2.Result: success=True on success; otherwise success=False
            with the error text in `message`.
        """
        try:
            print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            self.options = self._parse_options(request.Options)
            print(f"Options: {self.options}", file=sys.stderr)

            try:
                self.tts_model = load_model(request.Model)
                self.model_path = request.Model
                print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
            except Exception as model_err:
                print(f"Error loading TTS model: {model_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")

        except Exception as err:
            print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")

        print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)

    @staticmethod
    def _parse_options(raw_options):
        """Turn an iterable of "key:value" strings into a dict of typed values.

        The string is split on the first ':' only, so values may themselves
        contain colons. Keys and values are whitespace-stripped.
        """
        parsed = {}
        for opt in raw_options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            key = key.strip()
            value = value.strip()
            # Try int before float. Every integer string also parses as a
            # float, so a float-first order would never produce an int.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ("true", "false"):
                value = value.lower() == "true"
            parsed[key] = value
        return parsed

    def TTS(self, request, context):
        """
        Generates TTS audio from text using MLX-Audio and writes it to a WAV file.

        The output path is request.dst when set. Otherwise it is
        "tts_<uuid4>.wav", a path relative to the process working directory.

        Args:
            request: A TTSRequest object containing text, model, destination, voice, and language.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            A Result object indicating success or failure.
        """
        try:
            if getattr(self, 'tts_model', None) is None:
                return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")

            print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)

            speed_value = self._handle_speed_parameter(request, self.model_path)
            lang_code = self._map_language_code(request.language, request.voice)

            gen_params = {
                "text": request.text,
                "speed": speed_value,
                "verbose": False,
            }
            if request.voice and request.voice.strip():
                gen_params["voice"] = request.voice
            self._add_model_specific_params(gen_params, lang_code)

            print(f"Generation parameters: {gen_params}", file=sys.stderr)

            try:
                cat_audio, sample_rate = self._synthesize(gen_params)
            except Exception as gen_err:
                print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")

            if cat_audio is None:
                print("No audio segments generated", file=sys.stderr)
                return backend_pb2.Result(success=False, message="No audio generated")

            output_path = request.dst or f"tts_{uuid.uuid4()}.wav"

            try:
                sf.write(output_path, cat_audio, sample_rate)
                print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)

                if not os.path.exists(output_path):
                    print(f"File was not created at {output_path}", file=sys.stderr)
                    return backend_pb2.Result(success=False, message="Failed to create audio file")

                file_size = os.path.getsize(output_path)
                if file_size == 0:
                    print("File was created but is empty", file=sys.stderr)
                    return backend_pb2.Result(success=False, message="Generated audio file is empty")

                print(f"Audio file size: {file_size} bytes", file=sys.stderr)

            except Exception as write_err:
                print(f"Error writing audio file: {write_err}", file=sys.stderr)
                return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")

            return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")

        except Exception as e:
            print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")

    async def Predict(self, request, context):
        """
        Generates TTS audio based on the given prompt using MLX-Audio TTS.
        This is a fallback method for compatibility with the Predict endpoint.

        The audio itself is not returned or saved. The reply only summarises
        the duration and sample rate. Voice and language come from the
        LoadModel options ("voice", "lang_code"), mirroring how TTS derives
        them from the request.

        Args:
            request: The predict request (reads .Prompt).
            context: The gRPC context; a status code is set on failure.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        try:
            if getattr(self, 'tts_model', None) is None:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                context.set_details("TTS model not loaded. Please call LoadModel first.")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            if not request.Prompt:
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details("Prompt is required for TTS generation")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            options = getattr(self, 'options', {})
            speed_value = self._handle_speed_parameter(request, self.model_path)

            # Same rule as TTS: an explicit language wins; otherwise the
            # language is derived from the voice. Non-string option values
            # (e.g. parsed as bool/number) are ignored.
            voice_opt = options.get('voice')
            lang_opt = options.get('lang_code')
            lang_code = self._map_language_code(
                lang_opt if isinstance(lang_opt, str) else None,
                voice_opt if isinstance(voice_opt, str) else None,
            )

            gen_params = {
                "text": request.Prompt,
                "speed": speed_value,
                "verbose": False,
            }
            if 'voice' in options:
                gen_params["voice"] = options['voice']
            self._add_model_specific_params(gen_params, lang_code)

            print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)

            try:
                cat_audio, sample_rate = self._synthesize(gen_params)
            except Exception as gen_err:
                print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details(f"TTS generation failed: {gen_err}")
                return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

            if cat_audio is None:
                print("No audio segments generated", file=sys.stderr)
                return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))

            duration = len(cat_audio) / sample_rate
            response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: {sample_rate}Hz"
            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"TTS generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    def _add_model_specific_params(self, gen_params, lang_code):
        """Add model-family-specific generation arguments to ``gen_params`` in place.

        The model family is detected by substring match on self.model_path.
        Kokoro models receive ``lang_code``. Spark models receive fixed
        ``pitch=1.0`` and ``gender="female"``.

        Returns:
            dict: The same ``gen_params`` object, for convenience.
        """
        model_name = self.model_path.lower()
        if "kokoro" in model_name:
            gen_params["lang_code"] = lang_code
        if "spark" in model_name:
            gen_params["pitch"] = 1.0
            gen_params["gender"] = "female"
        return gen_params

    def _synthesize(self, gen_params):
        """Run the model and join all generated segments into one array.

        The iteration over the generate() result happens here, so errors
        raised while producing segments reach the caller's generation-error
        handler. This matters if generate() is lazy, e.g. a generator
        (unverified).

        Returns:
            tuple: ``(audio, sample_rate)``. ``audio`` is None when no
            segments were produced; otherwise it is the segments' ``.audio``
            arrays concatenated along axis 0.
        """
        segments = list(self.tts_model.generate(**gen_params))
        if not segments:
            return None, self.DEFAULT_SAMPLE_RATE
        cat_audio = np.concatenate([segment.audio for segment in segments], axis=0)
        return cat_audio, self._resolve_sample_rate(segments[0])

    def _resolve_sample_rate(self, segment):
        """Pick the output sample rate for generated audio.

        Prefers a positive ``sample_rate`` attribute on the generated segment,
        then on the model, and finally falls back to DEFAULT_SAMPLE_RATE.
        Different MLX-Audio models produce different rates, so a single
        hard-coded value would mis-label some models' audio.

        NOTE(review): this assumes MLX-Audio exposes ``sample_rate`` on
        generation results and/or models. The fallback keeps the previous
        24000 Hz behaviour when it does not.
        """
        for source in (segment, self.tts_model):
            try:
                rate = int(getattr(source, "sample_rate", 0) or 0)
            except (TypeError, ValueError):
                continue
            if rate > 0:
                return rate
        return self.DEFAULT_SAMPLE_RATE

    def _handle_speed_parameter(self, request, model_path):
        """
        Resolve the generation speed from the "speed" load option.

        Args:
            request: The incoming request. It is not read; the parameter is
                kept for interface compatibility.
            model_path: The model path, used to detect Spark models.

        Returns:
            float: For Spark models, one of 0.0/0.5/1.0/1.5/2.0. Named levels
            "very_low"..."very_high" map onto these, and any other value
            becomes 1.0. For other models, the option value is used if it lies
            in [0.5, 2.0]; otherwise 1.0 is returned. NaN and non-numeric
            values also give 1.0.
        """
        speed = getattr(self, 'options', {}).get('speed', 1.0)

        if "spark" in model_path.lower():
            speed_map = {
                "very_low": 0.0,
                "low": 0.5,
                "moderate": 1.0,
                "high": 1.5,
                "very_high": 2.0,
            }
            if isinstance(speed, str) and speed in speed_map:
                return speed_map[speed]
            try:
                speed_value = float(speed)
            except (TypeError, ValueError):
                return 1.0
            # Spark accepts only the discrete levels listed in speed_map.
            return speed_value if speed_value in speed_map.values() else 1.0

        try:
            speed_value = float(speed)
        except (TypeError, ValueError):
            return 1.0
        # The chained comparison is False for NaN, so NaN falls back to 1.0.
        return speed_value if 0.5 <= speed_value <= 2.0 else 1.0

    def _map_language_code(self, language, voice):
        """
        Map a language name or code to the single-letter code used as `lang_code`.

        Args:
            language: The language parameter (a name such as "spanish" or a
                one-letter code). It is matched case-insensitively after
                stripping whitespace.
            voice: The voice parameter. It is used only when `language` is
                empty.

        Returns:
            str: The mapped code. With no language, this is the first
            non-space character of `voice`, or "a" if `voice` is empty or
            blank. Unknown languages are returned unchanged.
        """
        if not language:
            # Without an explicit language, use the voice id's first letter.
            # This presumes Kokoro-style ids where that letter encodes the
            # language, e.g. "af_heart" -> "a".
            voice = voice.strip() if voice else ""
            return voice[0] if voice else "a"

        return self._LANGUAGE_CODES.get(language.strip().lower(), language)

    def _build_generation_params(self, request, default_speed=1.0):
        """
        Build generation parameters from request attributes and options for MLX-Audio TTS.

        Precedence for 'speed': default_speed, then a value derived from
        request.Temperature (if > 0), then the "speed" option, then any of
        "temp"/"temperature"/"top_p" options that are numeric and in
        [0.5, 2.0]. Later sources win. Non-numeric values for those options
        are ignored rather than raising.

        Args:
            request: The gRPC request.
            default_speed: Default speed if not specified.

        Returns:
            dict: Generation parameters with keys 'speed', 'voice', 'lang_code'.
        """
        generation_params = {
            'speed': default_speed,
            'voice': 'af_heart',
            'lang_code': 'a',
        }

        if hasattr(request, 'Temperature') and request.Temperature > 0:
            # Linear mapping from temperature to speed: 0.5 -> 1.0.
            generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5

        options = getattr(self, 'options', None)
        if options:
            for key in ('speed', 'voice', 'lang_code'):
                if key in options:
                    generation_params[key] = options[key]

            for option_key in ('temp', 'temperature', 'top_p'):
                if option_key not in options:
                    continue
                try:
                    speed_val = float(options[option_key])
                except (TypeError, ValueError):
                    continue
                if 0.5 <= speed_val <= 2.0:
                    generation_params['speed'] = speed_val

        return generation_params
|
|
|
|
|
async def serve(address):
    """Run the MLX-Audio TTS gRPC server on ``address`` until it terminates.

    Sync handlers run on a thread pool of MAX_WORKERS threads, and message
    size limits are set to 50 MiB. SIGINT and SIGTERM trigger a graceful stop
    with a 5-second grace period.
    """
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ],
    )

    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    # get_running_loop() is the supported way to reach the current loop from
    # inside a coroutine.
    loop = asyncio.get_running_loop()

    # Hold strong references to shutdown tasks. The event loop keeps only weak
    # references, so an unreferenced task can be garbage-collected mid-stop.
    shutdown_tasks = set()

    def _request_shutdown():
        task = asyncio.ensure_future(server.stop(5))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _request_shutdown)

    await server.start()
    print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)

    await server.wait_for_termination()
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the bind address and run the async server.
    cli = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
    cli.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
    bind_address = cli.parse_args().addr
    asyncio.run(serve(bind_address))
|
|
|