"""FastAPI app serving model-listing and chat-completion endpoints for a HuggingFace-hosted LLM."""

import argparse
import sys

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

# Project-local modules: prompt composer, response streamer, and logger.
from messagers.message_composer import MessageComposer
from networks.message_streamer import MessageStreamer
from utils.logger import logger

class ChatAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        # Models exposed at GET /models (and /v1/models).
        self.available_models = [
            {
                "id": "mixtral-8x7b",
                "description": "Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
            },
        ]
        return self.available_models

    class ChatCompletionsPostItem(BaseModel):
        # Request body for POST /chat/completions, mirroring the OpenAI chat format.
        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: float = Field(
            default=0.01,
            description="(float) Temperature",
        )
        max_tokens: int = Field(
            default=32000,
            description="(int) Max tokens",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        # Merge the role/content messages into a single prompt, then stream
        # the model output back to the client as server-sent events.
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        return EventSourceResponse(
            streamer.chat(
                prompt=composer.merged_str,
                temperature=item.temperature,
                max_new_tokens=item.max_tokens,
                stream=item.stream,
                yield_output=True,
            ),
            media_type="text/event-stream",
        )

    def setup_routes(self):
        # Register each endpoint under both the bare path and the OpenAI-style "/v1" prefix.
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available models",
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
            )(self.chat_completions)

class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])

app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    if args.dev:
        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
    else:
        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
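
# Example request (a sketch, assuming the default host/port above; the exact
# shape of the streamed chunks depends on MessageStreamer):
#   curl -N http://127.0.0.1:23333/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "mixtral-8x7b", "messages": [{"role": "user", "content": "Hello, who are you?"}], "stream": true}'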