File size: 4,408 Bytes
3303abf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | import json
import multiprocessing
import os
import re
from argparse import Namespace
from threading import Lock
import pyrootutils
import uvicorn
from kui.asgi import (
Depends,
FactoryClass,
HTTPException,
HttpRoute,
Kui,
OpenAPI,
Routes,
)
from kui.cors import CORSConfig
from kui.openapi.specification import Info
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import routes
# Name of the environment variable used to hand the launcher's parsed CLI
# arguments (JSON-encoded) over to create_app().
ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"
class API(ExceptionHandler):
    """Build and hold the Kui ASGI application of the Fish Speech API.

    The constructor wires together the project routes (optionally protected
    by a bearer token), the OpenAPI routes, the exception handlers and the
    shared application state, and registers a startup hook that creates the
    ModelManager.
    """

    def __init__(self, args: Namespace | None = None):
        """Create the Kui application.

        Args:
            args: Parsed server options. When omitted (or falsy) they are
                obtained from ``parse_args()``.
        """
        # Local import: only needed for the token comparison in `verify`.
        import hmac

        self.args = args or parse_args()

        def api_auth(endpoint):
            """HTTP middleware that enforces the API key when one is set."""

            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                expected = self.args.api_key
                # Compare in constant time so that response timing does not
                # reveal how much of a guessed token is correct. Bytes are
                # used because compare_digest() rejects non-ASCII str input.
                token_ok = (
                    isinstance(token, str)
                    and isinstance(expected, str)
                    and hmac.compare_digest(
                        token.encode("utf-8"), expected.encode("utf-8")
                    )
                )
                if not token_ok:
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            # Without a configured API key every request is let through.
            if self.args.api_key is not None:
                return verify
            else:
                return passthrough

        self.routes = Routes(
            routes,  # keep existing routes
            http_middlewares=[api_auth],  # apply api_auth middleware
        )

        # OpenAPI configuration
        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes + self.openapi[1:],  # Remove the default route
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Associate the app with the model manager
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        """Startup hook: create the ModelManager and store it on the app state.

        Args:
            app: The Kui application whose ``state`` receives ``model_manager``.
        """
        # Make the ModelManager available to the views
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")
def create_app():
    """Application factory: build the API and return its Kui app.

    The server options are read as JSON from the environment variable named
    by ``ENV_ARGS_KEY``. If the variable is unset, empty, or cannot be turned
    into a Namespace, ``API`` is constructed with ``args=None`` instead (a
    warning is logged in the failure case).
    """
    serialized = os.environ.get(ENV_ARGS_KEY)
    if not serialized:
        return API(args=None).app

    try:
        restored = Namespace(**json.loads(serialized))
    except Exception as exc:
        logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")
        restored = None

    return API(args=restored).app
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.
# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
def parse_listen_address(listen: str) -> tuple[str, int]:
    """Split a listen string into ``(host, port)``.

    Accepts ``HOST:PORT`` and the bracketed IPv6 form ``[xxxx:xxxx::xxxx]:PORT``
    (the brackets are removed from the returned host). Anything else is split
    on its last colon, so a host that itself contains colons still parses.

    Args:
        listen: Address string to parse.

    Returns:
        The host as a string and the port as an int.

    Raises:
        ValueError: If there is no ``:PORT`` part or the port is not an integer.
    """
    # IPv6 address format is [xxxx:xxxx::xxxx]:port
    match = re.search(r"\[([^\]]+)\]:(\d+)$", listen)
    if match:
        host, port = match.groups()
    else:
        # Split on the LAST colon only; a plain split(":") fails to unpack
        # as soon as the host part contains a colon of its own.
        host, sep, port = listen.rpartition(":")
        if not sep:
            raise ValueError(
                f"Invalid listen address {listen!r}: expected HOST:PORT or [IPv6]:PORT"
            )

    try:
        return host, int(port)
    except ValueError as err:
        raise ValueError(
            f"Invalid listen address {listen!r}: port {port!r} is not an integer"
        ) from err


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    args = parse_args()
    # Publish the parsed arguments through the environment; create_app()
    # reads them back from the same variable.
    os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))

    host, port = parse_listen_address(args.listen)

    uvicorn.run(
        "tools.api_server:create_app",
        host=host,
        port=port,
        workers=args.workers,
        log_level="info",
        factory=True,
    )
|