from argparse import ArgumentParser
from http import HTTPStatus
from typing import Annotated, Any
import ormsgpack
from baize.datastructures import ContentType
from kui.asgi import (
HTTPException,
HttpRequest,
JSONResponse,
request,
)
from loguru import logger
from pydantic import BaseModel
from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.utils.schema import ServeTTSRequest
from tools.server.inference import inference_wrapper as inference
def parse_args():
    """Parse the server's command-line options from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        The parsed options; every option has a default, so none is required.
    """
    # Registered in this order so the generated ``--help`` listing is stable.
    option_specs = (
        ("--mode", {"type": str, "choices": ["tts"], "default": "tts"}),
        ("--llama-checkpoint-path", {"type": str, "default": "checkpoints/s2-pro"}),
        (
            "--decoder-checkpoint-path",
            {"type": str, "default": "checkpoints/s2-pro/codec.pth"},
        ),
        ("--decoder-config-name", {"type": str, "default": "modded_dac_vq"}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--half", {"action": "store_true"}),
        ("--compile", {"action": "store_true"}),
        ("--max-text-length", {"type": int, "default": 0}),
        ("--listen", {"type": str, "default": "127.0.0.1:8080"}),
        ("--workers", {"type": int, "default": 1}),
        ("--api-key", {"type": str, "default": None}),
    )

    cli = ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
class MsgPackRequest(HttpRequest):
    """HTTP request whose ``data`` accepts msgpack, JSON, or multipart form bodies."""

    async def data(
        self,
    ) -> Annotated[
        Any,
        ContentType("application/msgpack"),
        ContentType("application/json"),
        ContentType("multipart/form-data"),
    ]:
        """Decode the request payload according to its content type.

        Returns
        -------
        Any
            The msgpack-unpacked body, the parsed JSON, or the parsed form,
            depending on which supported content type the request declares.

        Raises
        ------
        HTTPException
            With status 415 (unsupported media type) and an ``Accept`` header
            listing the supported types when the content type is none of them.
        """
        declared_type = self.content_type

        if declared_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if declared_type == "application/json":
            return await self.json

        if declared_type == "multipart/form-data":
            return await self.form

        # Nothing matched: reject and advertise what this endpoint can parse.
        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={
                "Accept": "application/msgpack, application/json, multipart/form-data"
            },
        )
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Yield the ``bytes`` chunks produced by ``inference`` for a request.

    Parameters
    ----------
    req : ServeTTSRequest
        The TTS request forwarded to ``inference``.
    engine : TTSInferenceEngine
        The inference engine forwarded to ``inference``.

    Yields
    ------
    bytes
        Every item from ``inference(req, engine)`` that is a ``bytes``
        instance; items of any other type are skipped.
    """
    # NOTE(review): ``inference`` is consumed with a plain ``for``, so each step
    # runs synchronously inside this coroutine — confirm that is acceptable.
    for chunk in inference(req, engine):
        # Progress trace goes through the module logger instead of a bare
        # print(), so it follows the configured log level and sinks.
        logger.debug("Got chunk")
        if isinstance(chunk, bytes):
            yield chunk
async def buffer_to_async_generator(buffer):
    """Wrap a single value in an async generator.

    Parameters
    ----------
    buffer : Any
        The value to emit.

    Yields
    ------
    Any
        ``buffer`` itself, exactly once.
    """
    yield buffer
# Audio format name -> MIME type sent for that format.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "flac": "audio/flac",
    "mp3": "audio/mpeg",
    "opus": "audio/ogg",
}


def get_content_type(audio_format):
    """Return the MIME type for an audio format name.

    Parameters
    ----------
    audio_format : str
        One of ``"wav"``, ``"flac"``, ``"mp3"`` or ``"opus"`` (exact match).

    Returns
    -------
    str
        The matching MIME type, or ``"application/octet-stream"`` for any
        other value.
    """
    return _AUDIO_MIME_TYPES.get(audio_format, "application/octet-stream")
def wants_json(req):
    """Decide whether the client asked for a JSON response.

    An explicit ``format`` query parameter wins; otherwise the ``Accept``
    header is inspected.

    Parameters
    ----------
    req : Request
        The request object; its ``query_params`` and ``headers`` are read.

    Returns
    -------
    bool
        True if the client wants a JSON response, False otherwise
    """
    requested = req.query_params.get("format", "").strip().lower()

    # A recognised ``format`` value settles the question on its own.
    if requested in ("json", "application/json"):
        return True
    if requested in ("msgpack", "application/msgpack"):
        return False

    # Otherwise fall back to the Accept header; any mention of msgpack wins.
    accept_header = req.headers.get("Accept", "").strip().lower()
    if "application/msgpack" in accept_header:
        return False
    return "application/json" in accept_header
def format_response(response: BaseModel, status_code=200):
    """
    Helper function to format responses consistently based on client preference.

    Parameters
    ----------
    response : BaseModel
        The response object to format
    status_code : int
        HTTP status code (default: 200)

    Returns
    -------
    Response
        A ``JSONResponse`` when the client wants JSON; otherwise a
        ``(body, status_code, headers)`` tuple carrying the msgpack-encoded
        model. If formatting fails, a JSON error response with status 500.
    """
    try:
        if wants_json(request):
            return JSONResponse(
                response.model_dump(mode="json"), status_code=status_code
            )

        return (
            ormsgpack.packb(
                response,
                option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
            ),
            status_code,
            {"Content-Type": "application/msgpack"},
        )
    except Exception as e:
        # loguru has no ``exc_info`` flag: a keyword argument only triggers
        # ``str.format`` on the message, which can itself raise when the
        # exception text contains braces. ``logger.exception`` records the
        # traceback, and passing ``e`` as an argument keeps its text out of
        # the format template.
        logger.exception("Error formatting response: {}", e)
        # Fallback to JSON response if formatting fails
        return JSONResponse(
            {"error": "Response formatting failed", "details": str(e)}, status_code=500
        )
|