|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import json |
|
|
import typing |
|
|
import contextlib |
|
|
|
|
|
from anyio import Lock |
|
|
from functools import partial |
|
|
from typing import List, Optional, Union, Dict |
|
|
|
|
|
import llama_cpp |
|
|
|
|
|
import anyio |
|
|
from anyio.streams.memory import MemoryObjectSendStream |
|
|
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool |
|
|
from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body |
|
|
from fastapi.middleware import Middleware |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.security import HTTPBearer |
|
|
from sse_starlette.sse import EventSourceResponse |
|
|
from starlette_context.plugins import RequestIdPlugin |
|
|
from starlette_context.middleware import RawContextMiddleware |
|
|
|
|
|
from llama_cpp.server.model import ( |
|
|
LlamaProxy, |
|
|
) |
|
|
from llama_cpp.server.settings import ( |
|
|
ConfigFileSettings, |
|
|
Settings, |
|
|
ModelSettings, |
|
|
ServerSettings, |
|
|
) |
|
|
from llama_cpp.server.types import ( |
|
|
CreateCompletionRequest, |
|
|
CreateEmbeddingRequest, |
|
|
CreateChatCompletionRequest, |
|
|
ModelList, |
|
|
TokenizeInputRequest, |
|
|
TokenizeInputResponse, |
|
|
TokenizeInputCountResponse, |
|
|
DetokenizeInputRequest, |
|
|
DetokenizeInputResponse, |
|
|
) |
|
|
from llama_cpp.server.errors import RouteErrorHandler |
|
|
|
|
|
|
|
|
# All endpoints are registered on this router; RouteErrorHandler converts
# uncaught errors into structured error responses.
router = APIRouter(route_class=RouteErrorHandler)

# Process-wide server settings, installed once via set_server_settings().
_server_settings: Optional[ServerSettings] = None
|
|
|
|
|
|
|
|
def set_server_settings(server_settings: ServerSettings):
    """Install *server_settings* as the process-wide server configuration.

    Stored in the module-level ``_server_settings`` that the
    ``get_server_settings`` dependency yields to request handlers.
    """
    global _server_settings
    _server_settings = server_settings
|
|
|
|
|
|
|
|
def get_server_settings():
    """FastAPI dependency yielding the current ServerSettings (may be None)."""
    yield _server_settings
|
|
|
|
|
|
|
|
# Shared LlamaProxy (model manager); created once by set_llama_proxy().
_llama_proxy: Optional[LlamaProxy] = None

# Two-stage lock guarding _llama_proxy: requests queue on the outer lock and
# run while holding the inner one (see get_llama_proxy). This lets streaming
# handlers detect a waiting request via ``llama_outer_lock.locked()``.
llama_outer_lock = Lock()
llama_inner_lock = Lock()
|
|
|
|
|
|
|
|
def set_llama_proxy(model_settings: List[ModelSettings]):
    """Build the shared LlamaProxy from *model_settings* and install it."""
    global _llama_proxy
    _llama_proxy = LlamaProxy(models=model_settings)
|
|
|
|
|
|
|
|
async def get_llama_proxy():
    # Dependency yielding the shared LlamaProxy while holding the inner lock,
    # so only one request drives the llama instance at a time.
    #
    # Two locks are used so a *waiting* request is observable: a new request
    # holds the outer lock only until it owns the inner lock, and streaming
    # handlers poll ``llama_outer_lock.locked()`` to decide whether another
    # request is queued (see get_event_publisher's interrupt check).
    await llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        await llama_inner_lock.acquire()
        try:
            # Hand the outer lock back so the next request can queue on it
            # (and thereby signal that it is waiting) while we serve.
            llama_outer_lock.release()
            release_outer_lock = False
            yield _llama_proxy
        finally:
            llama_inner_lock.release()
    finally:
        # Release the outer lock only if we failed before handing it off.
        if release_outer_lock:
            llama_outer_lock.release()
|
|
|
|
|
|
|
|
# Optional factory producing the SSE ping payload; passed to every
# EventSourceResponse. create_app points it at empty payloads when
# ``disable_ping_events`` is configured.
_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
|
|
|
|
|
|
|
|
def set_ping_message_factory(factory: typing.Callable[[], bytes]):
    """Install the callable used to build SSE ping payloads."""
    global _ping_message_factory
    _ping_message_factory = factory
|
|
|
|
|
|
|
|
def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
    model_settings: List[ModelSettings] | None = None,
) -> FastAPI:
    """Build and return the FastAPI application.

    Configuration precedence:
    1. A config file named by the ``CONFIG_FILE`` environment variable
       (JSON, or YAML when the name ends in ``.yaml``/``.yml``).
    2. Explicit ``server_settings`` + ``model_settings`` arguments.
    3. ``settings`` (or a fresh default ``Settings()``), from which both
       server and model settings are derived.

    Raises:
        ValueError: if ``CONFIG_FILE`` names a file that does not exist.
    """
    config_file = os.environ.get("CONFIG_FILE", None)
    if config_file is not None:
        if not os.path.exists(config_file):
            raise ValueError(f"Config file {config_file} not found!")
        with open(config_file, "rb") as f:
            # YAML is parsed then round-tripped through JSON so the same
            # pydantic JSON validator handles both formats.
            if config_file.endswith(".yaml") or config_file.endswith(".yml"):
                import yaml

                config_file_settings = ConfigFileSettings.model_validate_json(
                    json.dumps(yaml.safe_load(f))
                )
            else:
                config_file_settings = ConfigFileSettings.model_validate_json(f.read())
            server_settings = ServerSettings.model_validate(config_file_settings)
            model_settings = config_file_settings.models

    # Fall back to deriving both settings objects from a single Settings.
    if server_settings is None and model_settings is None:
        if settings is None:
            settings = Settings()
        server_settings = ServerSettings.model_validate(settings)
        model_settings = [ModelSettings.model_validate(settings)]

    assert (
        server_settings is not None and model_settings is not None
    ), "server_settings and model_settings must be provided together"

    set_server_settings(server_settings)
    # Attach a request id to every request via starlette-context.
    middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
    app = FastAPI(
        middleware=middleware,
        title="🦙 llama.cpp Python API",
        version=llama_cpp.__version__,
        root_path=server_settings.root_path,
    )
    # NOTE(review): CORS is wide open (any origin, with credentials) —
    # confirm this is intended for production deployments.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    assert model_settings is not None
    set_llama_proxy(model_settings=model_settings)

    if server_settings.disable_ping_events:
        # Empty ping payloads effectively disable SSE ping events.
        set_ping_message_factory(lambda: bytes())

    return app
|
|
|
|
|
|
|
|
def prepare_request_resources(
    body: CreateCompletionRequest | CreateChatCompletionRequest,
    llama_proxy: LlamaProxy,
    body_model: str | None,
    kwargs,
) -> llama_cpp.Llama:
    """Resolve the llama instance for *body_model* and fold request options
    into *kwargs*.

    Mutates *kwargs* in place (``logit_bias``, ``grammar``,
    ``logits_processor``) and returns the llama instance to invoke.

    Raises:
        HTTPException: 503 when no llama proxy is available.
    """
    if llama_proxy is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is not available",
        )

    llama = llama_proxy(body_model)

    if body.logit_bias is not None:
        # Token-text keys must first be mapped to input ids.
        if body.logit_bias_type == "tokens":
            kwargs["logit_bias"] = _logit_bias_tokens_to_input_ids(
                llama, body.logit_bias
            )
        else:
            kwargs["logit_bias"] = body.logit_bias

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        # Suppress EOS until at least min_tokens have been generated.
        min_tokens_processor = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = min_tokens_processor
        else:
            kwargs["logits_processor"].extend(min_tokens_processor)

    return llama
|
|
|
|
|
|
|
|
async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream[typing.Any],
    body: CreateCompletionRequest | CreateChatCompletionRequest,
    body_model: str | None,
    llama_call,
    kwargs,
):
    """Run a streaming llama call in a worker thread and publish its chunks.

    Used as the ``data_sender_callable`` of an EventSourceResponse: each chunk
    is sent on *inner_send_chan* as a JSON SSE ``data`` payload, terminated by
    a ``[DONE]`` sentinel.
    """
    # get_server_settings is a generator-style dependency; pull its single
    # yielded value directly.
    server_settings = next(get_server_settings())
    interrupt_requests = (
        server_settings.interrupt_requests if server_settings else False
    )
    # Hold the llama proxy for the lifetime of the stream.
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
        async with inner_send_chan:
            try:
                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
                async for chunk in iterate_in_threadpool(iterator):
                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
                    # Stop generating if the client went away.
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
                    # When interrupt_requests is enabled, a waiting request
                    # holds the outer lock (see get_llama_proxy); end this
                    # stream early so the queued request can proceed.
                    if interrupt_requests and llama_outer_lock.locked():
                        await inner_send_chan.send(dict(data="[DONE]"))
                        raise anyio.get_cancelled_exc_class()()
                await inner_send_chan.send(dict(data="[DONE]"))
            except anyio.get_cancelled_exc_class() as e:
                print("disconnected")
                # Shielded so the log line still runs while we are cancelled.
                with anyio.move_on_after(1, shield=True):
                    print(
                        f"Disconnected from client (via refresh/close) {request.client}"
                    )
                raise e
|
|
|
|
|
|
|
|
def _logit_bias_tokens_to_input_ids( |
|
|
llama: llama_cpp.Llama, |
|
|
logit_bias: Dict[str, float], |
|
|
) -> Dict[str, float]: |
|
|
to_bias: Dict[str, float] = {} |
|
|
for token, score in logit_bias.items(): |
|
|
token = token.encode("utf-8") |
|
|
for input_id in llama.tokenize(token, add_bos=False, special=True): |
|
|
to_bias[str(input_id)] = score |
|
|
return to_bias |
|
|
|
|
|
|
|
|
|
|
|
# Bearer-token security scheme. auto_error=False so requests without an
# Authorization header still reach authenticate(), which decides based on
# whether an api_key is configured.
bearer_scheme = HTTPBearer(auto_error=False)
|
|
|
|
|
|
|
|
async def authenticate(
    settings: Settings = Depends(get_server_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    """Bearer-token authentication dependency.

    Every request is allowed when no ``api_key`` is configured; otherwise
    the presented bearer credentials must match the configured key exactly.
    """
    # No key configured -> authentication is disabled.
    if settings.api_key is None:
        return True

    # Missing or mismatching credentials are rejected outright.
    if not authorization or authorization.credentials != settings.api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )

    return authorization.credentials
|
|
|
|
|
|
|
|
# OpenAPI tag grouping the OpenAI-compatible endpoints.
openai_v1_tag = "OpenAI V1"
|
|
|
|
|
|
|
|
@router.post(
    "/v1/completions",
    summary="Completion",
    dependencies=[Depends(authenticate)],
    response_model=Union[
        llama_cpp.CreateCompletionResponse,
        str,
    ],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
@router.post(
    "/v1/engines/copilot-codex/completions",
    include_in_schema=False,
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
) -> llama_cpp.Completion:
    """OpenAI-compatible text completion endpoint (plus the hidden
    copilot-codex alias route).

    Returns a complete response, or an SSE stream when ``stream=True``.
    """
    # llama.cpp generates from a single prompt; a list is accepted for OpenAI
    # API compatibility but may hold at most one entry. Reject oversized lists
    # with 400 instead of the previous `assert`, which surfaced as a 500 and
    # is stripped entirely under `python -O`.
    if isinstance(body.prompt, list):
        if len(body.prompt) > 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="At most one prompt is supported",
            )
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    # The copilot-codex alias always maps to the "copilot-codex" model.
    body_model = (
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
        else "copilot-codex"
    )

    # Fields consumed by the server layer (or unsupported) are not forwarded
    # to llama_cpp.
    exclude = {
        "n",
        "best_of",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

    # Streaming: hand generation off to get_event_publisher, which owns the
    # llama proxy for the lifetime of the SSE stream.
    if kwargs.get("stream", False):
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                body=body,
                body_model=body_model,
                llama_call=llama_cpp.Llama.__call__,
                kwargs=kwargs,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )

    # Non-streaming: hold the llama proxy for the whole call.
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

        # Skip the (expensive) llm call if the client already went away.
        if await request.is_disconnected():
            print(
                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client closed request",
            )

        return await run_in_threadpool(llama, **kwargs)
|
|
|
|
|
|
|
|
@router.post(
    "/v1/embeddings",
    summary="Embedding",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
    """Create embeddings for the request input using the requested model."""
    model = llama_proxy(request.model)
    params = request.model_dump(exclude={"user"})
    return await run_in_threadpool(model.create_embedding, **params)
|
|
|
|
|
|
|
|
@router.post(
    "/v1/chat/completions",
    summary="Chat",
    dependencies=[Depends(authenticate)],
    response_model=Union[llama_cpp.ChatCompletion, str],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {
                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
                            }
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True"
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",
                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest = Body(
        openapi_examples={
            "normal": {
                "summary": "Chat Completion",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                },
            },
            "json_mode": {
                "summary": "JSON Mode",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Who won the world series in 2020"},
                    ],
                    "response_format": {"type": "json_object"},
                },
            },
            "tool_calling": {
                "summary": "Tool Calling",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Extract Jason is 30 years old."},
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "User",
                                "description": "User record",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "name": {"type": "string"},
                                        "age": {"type": "number"},
                                    },
                                    "required": ["name", "age"],
                                },
                            },
                        }
                    ],
                    "tool_choice": {
                        "type": "function",
                        "function": {
                            "name": "User",
                        },
                    },
                },
            },
            "logprobs": {
                "summary": "Logprobs",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                    "logprobs": True,
                    "top_logprobs": 10,
                },
            },
        }
    ),
) -> llama_cpp.ChatCompletion:
    """OpenAI-compatible chat completion endpoint.

    Returns a complete ChatCompletion, or an SSE stream of chunks when
    ``stream=True`` is set in the request body.
    """
    body_model = body.model
    # Fields consumed by the server layer (or unsupported) are not forwarded
    # to llama_cpp.
    exclude = {
        "n",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

    # Streaming: hand generation off to get_event_publisher, which owns the
    # llama proxy for the lifetime of the SSE stream.
    if kwargs.get("stream", False):
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                body=body,
                body_model=body_model,
                llama_call=llama_cpp.Llama.create_chat_completion,
                kwargs=kwargs,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )

    # Non-streaming: hold the llama proxy for the duration of the call.
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

        # Skip the (expensive) llm call if the client already went away.
        if await request.is_disconnected():
            print(
                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client closed request",
            )

        return await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
|
|
|
|
|
|
|
@router.get(
    "/v1/models",
    summary="Models",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def get_models(
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
    """List every model alias currently served, in OpenAI list format."""
    entries = []
    for model_alias in llama_proxy:
        entries.append(
            {
                "id": model_alias,
                "object": "model",
                "owned_by": "me",
                "permissions": [],
            }
        )
    return {"object": "list", "data": entries}
|
|
|
|
|
|
|
|
# OpenAPI tag grouping the non-OpenAI helper endpoints.
extras_tag = "Extras"
|
|
|
|
|
|
|
|
@router.post(
    "/extras/tokenize",
    summary="Tokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def tokenize(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
    """Tokenize *body.input* with the requested model's tokenizer."""
    encoded = body.input.encode("utf-8")
    token_ids = llama_proxy(body.model).tokenize(encoded, special=True)
    return TokenizeInputResponse(tokens=token_ids)
|
|
|
|
|
|
|
|
@router.post(
    "/extras/tokenize/count",
    summary="Tokenize Count",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def count_query_tokens(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
    """Return only the token count for *body.input* under the given model."""
    encoded = body.input.encode("utf-8")
    token_ids = llama_proxy(body.model).tokenize(encoded, special=True)
    return TokenizeInputCountResponse(count=len(token_ids))
|
|
|
|
|
|
|
|
@router.post(
    "/extras/detokenize",
    summary="Detokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def detokenize(
    body: DetokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
    """Decode *body.tokens* back into UTF-8 text with the given model."""
    raw = llama_proxy(body.model).detokenize(body.tokens)
    return DetokenizeInputResponse(text=raw.decode("utf-8"))
|
|
|