| from argparse import ArgumentParser |
| from http import HTTPStatus |
| from typing import Annotated, Any |
|
|
| import ormsgpack |
| from baize.datastructures import ContentType |
| from kui.asgi import HTTPException, HttpRequest |
|
|
| from fish_speech.inference_engine import TTSInferenceEngine |
| from fish_speech.utils.schema import ServeTTSRequest |
| from tools.server.inference import inference_wrapper as inference |
|
|
|
|
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default) ``argparse``
            reads ``sys.argv[1:]``, which is the behavior existing callers
            rely on.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option defined
        below (dashes in option names become underscores).
    """
    parser = ArgumentParser()

    # Operating mode and optional model loading.
    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")

    # Checkpoint locations and decoder configuration.
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/openaudio-s1-mini",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")

    # Runtime / execution flags.
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)

    # HTTP server settings.
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--api-key", type=str, default=None)

    # Passing None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
|
|
|
class MsgPackRequest(HttpRequest):
    """Request type whose payload may arrive as MessagePack or as JSON."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request payload according to its content type.

        Returns:
            The object produced by ``ormsgpack.unpackb`` for
            ``application/msgpack`` requests, or the awaited ``self.json``
            for ``application/json`` requests.

        Raises:
            HTTPException: 415 Unsupported Media Type, carrying an ``Accept``
                header that lists both supported media types, when the
                content type is neither of the two.
        """
        if self.content_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if self.content_type == "application/json":
            return await self.json

        # Any other media type is rejected; advertise what is accepted.
        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack, application/json"},
        )
|
|
|
|
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Expose ``inference(req, engine)`` as an async generator of ``bytes``.

    Every item produced by ``inference`` that is a ``bytes`` instance is
    yielded in order; items of any other type are skipped.
    """
    # NOTE(review): the wrapped generator is consumed with a plain ``for``
    # loop, so each step runs synchronously inside this coroutine — confirm
    # that this is acceptable for the event loop serving the request.
    for piece in inference(req, engine):
        if not isinstance(piece, bytes):
            continue
        yield piece
|
|
|
|
async def buffer_to_async_generator(buffer):
    """Wrap a single value in an async generator.

    The generator yields ``buffer`` exactly once, unchanged, and then
    finishes.
    """
    yield buffer
|
|
|
|
def get_content_type(audio_format):
    """Return the MIME type string for an audio format name.

    ``"wav"``, ``"flac"`` and ``"mp3"`` map to ``"audio/wav"``,
    ``"audio/flac"`` and ``"audio/mpeg"`` respectively; any other value
    yields ``"application/octet-stream"``.
    """
    known_types = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    # Compare with ``==`` in order, so matching is exact and case-sensitive
    # and unhashable inputs simply fall through to the default.
    for format_name, mime_type in known_types:
        if audio_format == format_name:
            return mime_type
    return "application/octet-stream"
|
|