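"""OpenedAI Speech API server.

Serves an OpenAI-compatible /v1/audio/speech endpoint. 'tts-1' voices are
rendered by the piper CLI; 'tts-1-hd' voices are rendered by coqui XTTS,
with ffmpeg transcoding the raw audio into the requested response format.
"""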
import argparse
import contextlib
import gc
import os
import queue
import re
import subprocess
import sys
import threading
import time
import yaml
import json

from fastapi.responses import StreamingResponse
from loguru import logger
from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError
from pydantic import BaseModel
import uvicorn

@contextlib.asynccontextmanager
async def lifespan(app):
    yield
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except Exception:
        pass

app = OpenAIStub(lifespan=lifespan)
xtts = None
args = None

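# Free the XTTS model: move it off the GPU, drop the global reference, and
# release cached CUDA memory.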
def unload_model():
    import torch, gc
    global xtts
    if xtts:
        logger.info("Unloading model")
        xtts.xtts.to('cpu')
        del xtts
        xtts = None
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

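# Wrapper around the coqui XTTS model: loads it on demand, serializes access
# with a lock, and (when unload_timer is set) unloads it again after that many
# seconds of inactivity.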
class xtts_wrapper:
    check_interval: int = 1

    def __init__(self, model_name, device, model_path=None, unload_timer=None):
        self.model_name = model_name
        self.unload_timer = unload_timer
        self.last_used = time.time()
        self.timer = None
        self.lock = threading.Lock()

        logger.info(f"Loading model {self.model_name} to {device}")

        if model_path is None:
            model_path = ModelManager().download_model(model_name)[0]

        config_path = os.path.join(model_path, 'config.json')
        config = XttsConfig()
        config.load_json(config_path)
        self.xtts = Xtts.init_from_config(config)
        self.xtts.load_checkpoint(config, checkpoint_dir=model_path, use_deepspeed=args.use_deepspeed)
        self.xtts = self.xtts.to(device=device)
        self.xtts.eval()

        if self.unload_timer:
            logger.info(f"Setting unload timer to {self.unload_timer} seconds")
            self.last_used = time.time()
            self.check_idle()

    def check_idle(self):
        with self.lock:
            if time.time() - self.last_used >= self.unload_timer:
                logger.info("Unloading TTS model due to inactivity")
                unload_model()
            else:
                # Used recently; check again in check_interval seconds.
                self.timer = threading.Timer(self.check_interval, self.check_idle)
                self.timer.daemon = True
                self.timer.start()

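    # Generator: runs XTTS inference_stream() and yields raw float32 PCM chunks
    # as bytes (the caller feeds them to ffmpeg as f32le mono at 24000 Hz),
    # refreshing last_used as it goes so the idle-unload timer stays accurate.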
    def tts(self, text, language, audio_path, **hf_generate_kwargs):
        with torch.no_grad():
            started = self.last_used = time.time()
            tokens = 0
            try:
                with self.lock:
                    logger.debug(f"generating [{language}]: {[text]}")

                    gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=audio_path)
                    pcm_stream = self.xtts.inference_stream(text, language, gpt_cond_latent, speaker_embedding, **hf_generate_kwargs)
                    self.last_used = time.time()

                while True:
                    with self.lock:
                        yield next(pcm_stream).cpu().numpy().tobytes()
                        self.last_used = time.time()
                        tokens += 1

            except StopIteration:
                pass

            finally:
                elapsed = time.time() - started
                logger.debug(f"Generated {tokens} tokens in {elapsed:.2f}s @ {tokens / elapsed:.2f} T/s")
                self.last_used = time.time()

def default_exists(filename: str):
    if not os.path.exists(filename):
        fpath, ext = os.path.splitext(filename)
        basename = os.path.basename(fpath)
        default = f"{basename}.default{ext}"

        logger.info(f"{filename} does not exist, setting defaults from {default}")

        with open(default, 'r', encoding='utf8') as from_file:
            with open(filename, 'w', encoding='utf8') as to_file:
                to_file.write(from_file.read())

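# Text clean-up before synthesis. config/pre_process_map.yaml is read as a list
# of [regex, replacement] pairs applied in order; an illustrative entry (not
# necessarily the shipped defaults) could be:
#   - ['&', ' and ']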
def preprocess(raw_input):
    default_exists('config/pre_process_map.yaml')
    with open('config/pre_process_map.yaml', 'r', encoding='utf8') as file:
        pre_process_map = yaml.safe_load(file)

    for a, b in pre_process_map:
        raw_input = re.sub(a, b, raw_input)

    raw_input = raw_input.strip()

    return raw_input

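# Voice lookup. config/voice_to_speaker.yaml is keyed by model ('tts-1' or
# 'tts-1-hd'), then by voice name; each entry carries at least 'model' and
# usually 'speaker' (plus optional 'model_path', 'language', 'speed', etc.).
# An illustrative fragment (values are examples, not the shipped defaults):
#   tts-1:
#     alloy:
#       model: voices/en_US-libritts_r-medium.onnx
#       speaker: 79
#   tts-1-hd:
#     alloy:
#       model: xtts
#       speaker: voices/alloy.wav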
def map_voice_to_speaker(voice: str, model: str):
    default_exists('config/voice_to_speaker.yaml')
    with open('config/voice_to_speaker.yaml', 'r', encoding='utf8') as file:
        voice_map = yaml.safe_load(file)

    try:
        return voice_map[model][voice]

    except KeyError as e:
        raise BadRequestError(f"Error loading voice: {voice}, KeyError: {e}", param='voice')

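# Request body for /v1/audio/speech, mirroring OpenAI's speech API. Supported
# response_format values are mp3, opus, aac, flac, wav and pcm; speed is applied
# downstream (piper's --length-scale, or XTTS speed plus an ffmpeg atempo filter).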
class GenerateSpeechRequest(BaseModel):
    model: str = "tts-1"
    input: str
    voice: str = "alloy"
    response_format: str = "mp3"
    speed: float = 1.0

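# Build the ffmpeg command line that transcodes raw engine output into the
# requested format. For example, ("mp3", input_format="s16le", sample_rate="22050")
# produces:
#   ffmpeg -loglevel error -f s16le -ar 22050 -ac 1 -i - -f mp3 -c:a libmp3lame -ab 64k
# (the trailing "-" output argument is appended by the caller).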
def build_ffmpeg_args(response_format, input_format, sample_rate):
    if input_format == 'WAV':
        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "WAV", "-i", "-"]
    else:
        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", input_format, "-ar", sample_rate, "-ac", "1", "-i", "-"]

    if response_format == "mp3":
        ffmpeg_args.extend(["-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k"])
    elif response_format == "opus":
        ffmpeg_args.extend(["-f", "ogg", "-c:a", "libopus"])
    elif response_format == "aac":
        ffmpeg_args.extend(["-f", "adts", "-c:a", "aac", "-ab", "64k"])
    elif response_format == "flac":
        ffmpeg_args.extend(["-f", "flac", "-c:a", "flac"])
    elif response_format == "wav":
        ffmpeg_args.extend(["-f", "wav", "-c:a", "pcm_s16le"])
    elif response_format == "pcm":
        ffmpeg_args.extend(["-f", "s16le", "-c:a", "pcm_s16le"])

    return ffmpeg_args

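# OpenAI-compatible speech endpoint. tts-1 requests (or any request when
# --xtts_device none) are rendered by the piper CLI; tts-1-hd requests stream
# from XTTS. An illustrative client call, assuming the default host and port:
#   curl http://localhost:8000/v1/audio/speech -H "Content-Type: application/json" \
#     -d '{"model": "tts-1", "input": "Hello world", "voice": "alloy"}' > speech.mp3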
@app.post("/v1/audio/speech", response_class=StreamingResponse)
async def generate_speech(request: GenerateSpeechRequest):
    global xtts, args
    if len(request.input) < 1:
        raise BadRequestError("Empty Input", param='input')

    input_text = preprocess(request.input)

    if len(input_text) < 1:
        raise BadRequestError("Input text empty after preprocess.", param='input')

    model = request.model
    voice = request.voice
    response_format = request.response_format.lower()
    speed = request.speed

    if response_format == "mp3":
        media_type = "audio/mpeg"
    elif response_format == "opus":
        media_type = "audio/ogg;codec=opus"
    elif response_format == "aac":
        media_type = "audio/aac"
    elif response_format == "flac":
        media_type = "audio/x-flac"
    elif response_format == "wav":
        media_type = "audio/wav"
    elif response_format == "pcm":
        if model == 'tts-1':
            media_type = "audio/pcm;rate=22050"
        elif model == 'tts-1-hd':
            media_type = "audio/pcm;rate=24000"
    else:
        raise BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')

    ffmpeg_args = None

    if model == 'tts-1' or args.xtts_device == 'none':
        voice_map = map_voice_to_speaker(voice, 'tts-1')
        try:
            piper_model = voice_map['model']

        except KeyError as e:
            raise ServiceUnavailableError(f"Configuration error: tts-1 voice '{voice}' is missing 'model:' setting. KeyError: {e}")

        speaker = voice_map.get('speaker', None)

        tts_args = ["piper", "--model", str(piper_model), "--data-dir", "voices", "--download-dir", "voices", "--output-raw"]
        if speaker:
            tts_args.extend(["--speaker", str(speaker)])
        if speed != 1.0:
            tts_args.extend(["--length-scale", f"{1.0/speed}"])

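        # piper writes raw 16-bit PCM to stdout; pipe it straight into ffmpeg,
        # which transcodes to the requested response_format.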
        tts_proc = subprocess.Popen(tts_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        tts_proc.stdin.write(bytearray(input_text.encode('utf-8')))
        tts_proc.stdin.close()

        # Read the sample rate from the piper voice config, defaulting to 22050 Hz
        try:
            with open(f"{piper_model}.json", 'r') as pvc_f:
                conf = json.load(pvc_f)
                sample_rate = str(conf['audio']['sample_rate'])

        except Exception:
            sample_rate = '22050'

        ffmpeg_args = build_ffmpeg_args(response_format, input_format="s16le", sample_rate=sample_rate)

        # Send the transcoded audio to stdout
        ffmpeg_args.extend(["-"])
        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_proc.stdout, stdout=subprocess.PIPE)

        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)

    elif model == 'tts-1-hd':
        voice_map = map_voice_to_speaker(voice, 'tts-1-hd')
        try:
            tts_model = voice_map.pop('model')
            speaker = voice_map.pop('speaker')

        except KeyError as e:
            raise ServiceUnavailableError(f"Configuration error: tts-1-hd voice '{voice}' is missing setting. KeyError: {e}")

        if xtts and xtts.model_name != tts_model:
            unload_model()

        tts_model_path = voice_map.pop('model_path', None)

        if xtts is None:
            xtts = xtts_wrapper(tts_model, device=args.xtts_device, model_path=tts_model_path, unload_timer=args.unload_timer)

        ffmpeg_args = build_ffmpeg_args(response_format, input_format="f32le", sample_rate="24000")

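        # The speed passed to XTTS is kept within roughly 0.5-1.0; anything
        # outside that range is approximated with ffmpeg's atempo filter instead.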
        speed = voice_map.pop('speed', speed)
        if speed < 0.5:
            speed = speed / 0.5
            ffmpeg_args.extend(["-af", "atempo=0.5"])
        if speed > 1.0:
            ffmpeg_args.extend(["-af", f"atempo={speed}"])
            speed = 1.0

        ffmpeg_args.extend(["-"])

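        # Pick the synthesis language: an explicit 'language' in the voice map
        # wins; otherwise auto-detect with langdetect and fall back to English
        # when detection fails or the language isn't supported by XTTS.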
        language = voice_map.pop('language', 'auto')
        if language == 'auto':
            try:
                language = detect(input_text)
                if language not in [
                    'en', 'es', 'fr', 'de', 'it', 'pt', 'pl', 'tr',
                    'ru', 'nl', 'cs', 'ar', 'zh-cn', 'hu', 'ko', 'ja', 'hi'
                ]:
                    logger.debug(f"Detected language {language} not supported, defaulting to en")
                    language = 'en'
                else:
                    logger.debug(f"Detected language: {language}")
            except Exception:
                language = 'en'
                logger.debug("Failed to detect language, defaulting to en")

        comment = voice_map.pop('comment', None)  # unused; just stripped from the kwargs

        hf_generate_kwargs = dict(
            speed=speed,
            **voice_map,
        )

        hf_generate_kwargs['enable_text_splitting'] = hf_generate_kwargs.get('enable_text_splitting', True)

        if hf_generate_kwargs['enable_text_splitting']:
            if language == 'zh-cn':
                split_lang = 'zh'
            else:
                split_lang = language
            all_text = split_sentence(input_text, split_lang, xtts.xtts.tokenizer.char_limits[split_lang])
        else:
            all_text = [input_text]

        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)

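        # Producer/consumer handoff: generator() pushes PCM chunks into in_q,
        # out_writer() drains them into ffmpeg's stdin, and any writer-side
        # exception is handed back to the producer through ex_q.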
        in_q = queue.Queue()
        ex_q = queue.Queue()

        def get_speaker_samples(samples: str) -> list[str]:
            if os.path.isfile(samples):
                audio_path = [samples]
            elif os.path.isdir(samples):
                audio_path = [os.path.join(samples, sample) for sample in os.listdir(samples) if os.path.isfile(os.path.join(samples, sample))]

                if len(audio_path) < 1:
                    logger.error(f"No files found: {samples}")
                    raise ServiceUnavailableError(f"Invalid path: {samples}")
            else:
                logger.error(f"Invalid path: {samples}")
                raise ServiceUnavailableError(f"Invalid path: {samples}")

            return audio_path

        def exception_check(exq: queue.Queue):
            try:
                e = exq.get_nowait()
            except queue.Empty:
                return

            raise e

        def generator():
            # Feed PCM chunks from XTTS into in_q until all of the text is spoken
            audio_path = get_speaker_samples(speaker)
            logger.debug(f"{voice} wav samples: {audio_path}")

            try:
                for text in all_text:
                    for chunk in xtts.tts(text=text, language=language, audio_path=audio_path, **hf_generate_kwargs):
                        exception_check(ex_q)
                        in_q.put(chunk)

            except BrokenPipeError:
                logger.info("Client disconnected - 'Broken pipe'")

            except Exception as e:
                logger.error(f"Exception: {repr(e)}")
                raise e

            finally:
                in_q.put(None)  # sentinel: tell out_writer we're done

        def out_writer():
            # Drain in_q into ffmpeg's stdin until the None sentinel arrives
            try:
                while True:
                    chunk = in_q.get()
                    if chunk is None:
                        break
                    ffmpeg_proc.stdin.write(chunk)

            except Exception as e:
                ex_q.put(e)  # let generator() know the writer failed
                ffmpeg_proc.kill()
                return

            finally:
                ffmpeg_proc.stdin.close()

        generator_worker = threading.Thread(target=generator, daemon=True)
        generator_worker.start()

        out_writer_worker = threading.Thread(target=out_writer, daemon=True)
        out_writer_worker.start()

        async def cleanup():
            # Runs after the response finishes (or the client disconnects): stop
            # ffmpeg; the daemon worker threads then exit as the pipes close.
            ffmpeg_proc.kill()

        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type, background=cleanup)

    else:
        raise BadRequestError("No such model, must be tts-1 or tts-1-hd.", param='model')

def auto_torch_device():
    try:
        import torch
        return 'cuda' if torch.cuda.is_available() else 'mps' if (torch.backends.mps.is_available() and torch.backends.mps.is_built()) else 'cpu'

    except Exception:
        return 'none'

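# Typical invocation (illustrative; substitute the actual filename this module
# is saved as):
#   python speech.py --xtts_device cuda --preload xtts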
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='OpenedAI Speech API Server',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--xtts_device', action='store', default=auto_torch_device(), help="Set the device for the xtts model. The special value of 'none' will use piper for all models.")
    parser.add_argument('--preload', action='store', default=None, help="Preload a model (Ex. 'xtts' or 'xtts_v2.0.2'). By default it's loaded on first use.")
    parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds, Ex. 900 for 15 minutes")
    parser.add_argument('--use-deepspeed', action='store_true', default=False, help="Use deepspeed with xtts (this option is unsupported)")
    parser.add_argument('--no-cache-speaker', action='store_true', default=False, help="Don't use the speaker wav embeddings cache")
    parser.add_argument('-P', '--port', action='store', default=8000, type=int, help="Server tcp port")
    parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 0.0.0.0")
    parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")

    args = parser.parse_args()

    default_exists('config/pre_process_map.yaml')
    default_exists('config/voice_to_speaker.yaml')

    logger.remove()
    logger.add(sink=sys.stderr, level=args.log_level)

    if args.xtts_device != "none":
        import torch
        from TTS.tts.configs.xtts_config import XttsConfig
        from TTS.tts.models.xtts import Xtts
        from TTS.utils.manage import ModelManager
        from TTS.tts.layers.xtts.tokenizer import split_sentence
        from langdetect import detect

        if args.preload:
            xtts = xtts_wrapper(args.preload, device=args.xtts_device, unload_timer=args.unload_timer)

    app.register_model('tts-1')
    app.register_model('tts-1-hd')

    uvicorn.run(app, host=args.host, port=args.port)