| import re |
| from threading import Lock |
|
|
| import pyrootutils |
| import uvicorn |
| from kui.asgi import ( |
| Depends, |
| FactoryClass, |
| HTTPException, |
| HttpRoute, |
| Kui, |
| OpenAPI, |
| Routes, |
| ) |
| from kui.cors import CORSConfig |
| from kui.openapi.specification import Info |
| from kui.security import bearer_auth |
| from loguru import logger |
| from typing_extensions import Annotated |
|
|
| pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
|
|
| from tools.server.api_utils import MsgPackRequest, parse_args |
| from tools.server.exception_handler import ExceptionHandler |
| from tools.server.model_manager import ModelManager |
| from tools.server.views import routes |
|
|
|
|
class API(ExceptionHandler):
    """Fish Speech HTTP API server built on Kui.

    Parses command-line arguments and wires together a ``Kui`` application:
    - the view routes, optionally behind bearer-token auth;
    - the OpenAPI routes;
    - exception handlers, the MsgPack-aware request factory and CORS.

    Model loading is deferred to the application's startup hook.
    """

    def __init__(self):
        # Local import so the module-level import block stays untouched.
        import secrets

        self.args = parse_args()

        def api_auth(endpoint):
            """HTTP middleware enforcing ``self.args.api_key`` when one is set.

            Returns a wrapper that either checks the bearer token before
            calling *endpoint*, or (when no API key is configured) calls it
            unconditionally.
            """

            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                # Constant-time comparison so response timing does not reveal
                # how many leading characters of the key were correct. Bytes
                # are compared because compare_digest rejects non-ASCII str
                # with TypeError, which would otherwise surface as a 500.
                if not secrets.compare_digest(
                    token.encode("utf-8"), self.args.api_key.encode("utf-8")
                ):
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            return verify if self.args.api_key is not None else passthrough

        self.routes = Routes(
            routes,
            http_middlewares=[api_auth],
        )

        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        self.app = Kui(
            # NOTE(review): the first OpenAPI route is deliberately dropped;
            # the reason is not visible here — confirm before changing.
            routes=self.routes + self.openapi[1:],
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Shared state read by the views. The lock presumably serializes model
        # access across requests — TODO confirm against tools.server.views.
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        """Startup hook: build the ModelManager and store it on ``app.state``.

        All model settings (mode, device, precision, compilation, ASR, and
        checkpoint/config locations) come from the parsed CLI arguments.
        """
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            asr_enabled=self.args.load_asr_model,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| api = API() |
|
|
| |
| match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen) |
| if match: |
| host, port = match.groups() |
| else: |
| host, port = api.args.listen.split(":") |
|
|
| uvicorn.run( |
| api.app, |
| host=host, |
| port=int(port), |
| workers=api.args.workers, |
| log_level="info", |
| ) |
|
|