Update apis/chat_api.py

apis/chat_api.py  CHANGED  (+34 -119)
@@ -20,16 +20,14 @@ from mocks.stream_chat_mocker import stream_chat_mock
 class ChatAPIApp:
     def __init__(self):
         self.app = FastAPI(
-            docs_url=
+            docs_url=None,  # Hide Swagger UI
+            redoc_url=None,  # Hide ReDoc UI
             title="HuggingFace LLM API",
-            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
             version="1.0",
         )
         self.setup_routes()

     def get_available_models(self):
-        # https://platform.openai.com/docs/api-reference/models/list
-        # ANCHOR[id=available-models]: Available models
         self.available_models = {
             "object": "list",
             "data": [

@@ -63,58 +61,25 @@
             HTTPBearer(auto_error=False)
         ),
     ):
-        api_key = None
-        if credentials:
-            api_key = credentials.credentials
-        else:
-            api_key = os.getenv("HF_TOKEN")
-
-        if api_key:
-            if api_key.startswith("hf_"):
-                return api_key
-            else:
-                logger.warn(f"Invalid HF Token!")
-        else:
-            logger.warn("Not provide HF Token!")
+        api_key = os.getenv("HF_TOKEN") if credentials is None else credentials.credentials
+        if api_key and api_key.startswith("hf_"):
+            return api_key
+        logger.warn("Invalid or missing HF Token!")
         return None

     class ChatCompletionsPostItem(BaseModel):
-        model: str = Field(
-            default="mixtral-8x7b",
-            description="(str) `mixtral-8x7b`",
-        )
-        messages: list = Field(
-            default=[{"role": "user", "content": "Hello, who are you?"}],
-            description="(list) Messages",
-        )
-        temperature: Union[float, None] = Field(
-            default=0.5,
-            description="(float) Temperature",
-        )
-        top_p: Union[float, None] = Field(
-            default=0.95,
-            description="(float) top p",
-        )
-        max_tokens: Union[int, None] = Field(
-            default=-1,
-            description="(int) Max tokens",
-        )
-        use_cache: bool = Field(
-            default=False,
-            description="(bool) Use cache",
-        )
-        stream: bool = Field(
-            default=True,
-            description="(bool) Stream",
-        )
-
-    def chat_completions(
-        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
-    ):
+        model: str = Field(default="mixtral-8x7b", description="(str) `mixtral-8x7b`")
+        messages: list = Field(default=[{"role": "user", "content": "Hello, who are you?"}], description="(list) Messages")
+        temperature: Union[float, None] = Field(default=0.5, description="(float) Temperature")
+        top_p: Union[float, None] = Field(default=0.95, description="(float) top p")
+        max_tokens: Union[int, None] = Field(default=-1, description="(int) Max tokens")
+        use_cache: bool = Field(default=False, description="(bool) Use cache")
+        stream: bool = Field(default=True, description="(bool) Stream")
+
+    def chat_completions(self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)):
         streamer = MessageStreamer(model=item.model)
         composer = MessageComposer(model=item.model)
         composer.merge(messages=item.messages)
-        # streamer.chat = stream_chat_mock

         stream_response = streamer.chat_response(
             prompt=composer.merged_str,

@@ -124,80 +89,36 @@
             api_key=api_key,
             use_cache=item.use_cache,
         )
-        if item.stream:
-            event_source_response = EventSourceResponse(
-                streamer.chat_return_generator(stream_response),
-                media_type="text/event-stream",
-                ping=2000,
-                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
-            )
-            return event_source_response
-        else:
-            data_response = streamer.chat_return_dict(stream_response)
-            return data_response
+        return EventSourceResponse(
+            streamer.chat_return_generator(stream_response),
+            media_type="text/event-stream",
+            ping=2000,
+            ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
+        ) if item.stream else streamer.chat_return_dict(stream_response)

     def get_readme(self):
         readme_path = Path(__file__).parents[1] / "README.md"
         with open(readme_path, "r", encoding="utf-8") as rf:
-            readme_str = rf.read()
-        readme_html = markdown2.markdown(
-            readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
-        )
-        return readme_html
+            return markdown2.markdown(rf.read(), extras=["table", "fenced-code-blocks", "highlightjs-lang"])

     def setup_routes(self):
+        self.app.get("/", summary="Root endpoint", include_in_schema=False)(lambda: "Hello World!")  # Root route
+
         for prefix in ["", "/v1", "/api", "/api/v1"]:
-            if prefix == "/api/v1":
-                include_in_schema = True
-            else:
-                include_in_schema = False
-
-            self.app.get(
-                prefix + "/models",
-                summary="Get available models",
-                include_in_schema=include_in_schema,
-            )(self.get_available_models)
-
-            self.app.post(
-                prefix + "/chat/completions",
-                summary="Chat completions in conversation session",
-                include_in_schema=include_in_schema,
-            )(self.chat_completions)
-        self.app.get(
-            "/readme",
-            summary="README of HF LLM API",
-            response_class=HTMLResponse,
-            include_in_schema=False,
-        )(self.get_readme)
+            include_in_schema = prefix == "/api/v1"
+
+            self.app.get(prefix + "/models", summary="Get available models", include_in_schema=include_in_schema)(self.get_available_models)
+            self.app.post(prefix + "/chat/completions", summary="Chat completions in conversation session", include_in_schema=include_in_schema)(self.chat_completions)
+
+        self.app.get("/readme", summary="README of HF LLM API", response_class=HTMLResponse, include_in_schema=False)(self.get_readme)


 class ArgParser(argparse.ArgumentParser):
     def __init__(self, *args, **kwargs):
         super(ArgParser, self).__init__(*args, **kwargs)
-
-        self.add_argument(
-            "-s",
-            "--server",
-            type=str,
-            default="0.0.0.0",
-            help="Server IP for HF LLM Chat API",
-        )
-        self.add_argument(
-            "-p",
-            "--port",
-            type=int,
-            default=23333,
-            help="Server Port for HF LLM Chat API",
-        )
-
-        self.add_argument(
-            "-d",
-            "--dev",
-            default=False,
-            action="store_true",
-            help="Run in dev mode",
-        )
-
+        self.add_argument("-s", "--server", type=str, default="0.0.0.0", help="Server IP for HF LLM Chat API")
+        self.add_argument("-p", "--port", type=int, default=23333, help="Server Port for HF LLM Chat API")
+        self.add_argument("-d", "--dev", default=False, action="store_true", help="Run in dev mode")
         self.args = self.parse_args(sys.argv[1:])


@@ -205,10 +126,4 @@ app = ChatAPIApp().app

 if __name__ == "__main__":
     args = ArgParser().args
-    if args.dev:
-        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
-    else:
-        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)
-
-    # python -m apis.chat_api  # [Docker] on product mode
-    # python -m apis.chat_api -d  # [Dev] on develop mode
+    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)
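The condensed ChatCompletionsPostItem fields and the /chat/completions route above define the full request shape. Below is a minimal client sketch against a locally running server; the default port 23333 comes from ArgParser, while the requests dependency and the SSE printing are illustrative assumptions, not part of this commit.

# Minimal client sketch (illustrative; assumes the server is running locally
# on the default port 23333 and that `requests` is installed).
import os

import requests

payload = {
    # Field defaults from ChatCompletionsPostItem:
    "model": "mixtral-8x7b",
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "temperature": 0.5,
    "top_p": 0.95,
    "max_tokens": -1,
    "use_cache": False,
    "stream": True,
}
# Send a Bearer token if one is available; otherwise the server falls back
# to its own HF_TOKEN environment variable.
token = os.getenv("HF_TOKEN", "")
headers = {"Authorization": f"Bearer {token}"} if token else {}

with requests.post(
    "http://127.0.0.1:23333/api/v1/chat/completions",
    json=payload,
    headers=headers,
    stream=True,
) as resp:
    # With stream=True the response is text/event-stream: content chunks
    # arrive on "data:" lines, interleaved with empty comment pings.
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))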
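The rewritten extract_api_key collapses the old nested branches into one conditional expression plus a single "hf_" prefix check; it still only logs on failure, so unauthenticated requests reach chat_completions with api_key=None. A standalone mirror of that logic (the helper name and the stand-in credentials object are hypothetical) behaves like this:

import os
from types import SimpleNamespace

def resolve_hf_token(credentials=None):
    # Hypothetical standalone mirror of the new extract_api_key logic:
    # prefer Bearer credentials, fall back to the HF_TOKEN env var.
    api_key = os.getenv("HF_TOKEN") if credentials is None else credentials.credentials
    # Only tokens with the "hf_" prefix are accepted.
    if api_key and api_key.startswith("hf_"):
        return api_key
    return None

# A well-formed token passes through; anything else resolves to None.
assert resolve_hf_token(SimpleNamespace(credentials="hf_example")) == "hf_example"
assert resolve_hf_token(SimpleNamespace(credentials="not-a-token")) is None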
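Route registration still loops over four prefixes, but include_in_schema = prefix == "/api/v1" keeps only the /api/v1 pair in the OpenAPI schema, and docs_url=None / redoc_url=None unmount both interactive docs UIs. A quick sanity check with FastAPI's test client (assuming it and its httpx dependency are installed) could look like this:

# Schema sanity check (illustrative; imports the module-level app).
from fastapi.testclient import TestClient

from apis.chat_api import app

client = TestClient(app)
schema = client.get("/openapi.json").json()

# Only the /api/v1 routes should appear; "", "/v1" and "/api" still serve
# requests but stay out of the schema, and Swagger UI is gone entirely.
print(sorted(schema["paths"].keys()))
assert client.get("/docs").status_code == 404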
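Finally, folding the old `if args.dev:` branch into reload=args.dev keeps both launch modes from the deleted trailing comments working unchanged:

# python -m apis.chat_api      # production: args.dev is False -> reload=False
# python -m apis.chat_api -d   # development: args.dev is True -> reload=True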