import asyncio
import io
import json
import logging
import os
import traceback
from threading import Thread

import speech_recognition as sr
import uvicorn
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse

import extensions.openai.completions as OAIcompletions
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event

from .typing import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    DecodeRequest,
    DecodeResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    EncodeRequest,
    EncodeResponse,
    LoadLorasRequest,
    LoadModelRequest,
    LogitsRequest,
    LogitsResponse,
    LoraListResponse,
    ModelInfoResponse,
    ModelListResponse,
    TokenCountResponse,
    to_dict
)

params = {
    'embedding_device': 'cpu',
    'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
    'sd_webui_url': '',
    'debug': 0
}


# Allow only one streaming generation at a time
streaming_semaphore = asyncio.Semaphore(1)


def verify_api_key(authorization: str = Header(None)) -> None:
    expected_api_key = shared.args.api_key
    if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
        raise HTTPException(status_code=401, detail="Unauthorized")


def verify_admin_key(authorization: str = Header(None)) -> None:
    expected_api_key = shared.args.admin_key
    if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
        raise HTTPException(status_code=401, detail="Unauthorized")


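# Example authentication header, as checked by verify_api_key/verify_admin_key
# above (a sketch; assumes the server was started with an API key of MY_KEY):
#
#   Authorization: Bearer MY_KEY

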
app = FastAPI()
check_key = [Depends(verify_api_key)]
check_admin_key = [Depends(verify_admin_key)]

# Allow cross-origin requests from any origin, with any method and headers
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


| | @app.options("/", dependencies=check_key) |
| | async def options_route(): |
| | return JSONResponse(content="OK") |
| |
|
| |
|
@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: CompletionRequest):
    path = request.url.path
    is_legacy = "/generate" in path

    if request_data.stream:
        async def generator():
            async with streaming_semaphore:
                response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
                for resp in response:
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())

    else:
        response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
        return JSONResponse(response)


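# Example request for /v1/completions above (a sketch; assumes the default
# port 5000 and that the payload fields mirror the OpenAI completions API):
#
#   curl http://127.0.0.1:5000/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_tokens": 50, "stream": false}'

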
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
    path = request.url.path
    is_legacy = "/generate" in path

    if request_data.stream:
        async def generator():
            async with streaming_semaphore:
                response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
                for resp in response:
                    disconnected = await request.is_disconnected()
                    if disconnected:
                        break

                    yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())

    else:
        response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
        return JSONResponse(response)


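# Example request for /v1/chat/completions above (a sketch; assumes the
# default port 5000 and an OpenAI-style messages payload):
#
#   curl http://127.0.0.1:5000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 50}'

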
| | @app.get("/v1/models", dependencies=check_key) |
| | @app.get("/v1/models/{model}", dependencies=check_key) |
| | async def handle_models(request: Request): |
| | path = request.url.path |
| | is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models' |
| |
|
| | if is_list: |
| | response = OAImodels.list_dummy_models() |
| | else: |
| | model_name = path[len('/v1/models/'):] |
| | response = OAImodels.model_info_dict(model_name) |
| |
|
| | return JSONResponse(response) |
| |
|
| |
|
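# Example requests for /v1/models above (a sketch; assumes the default port
# 5000): GET /v1/models returns the model list, while GET
# /v1/models/some-model returns info for that single model.

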
@app.get('/v1/billing/usage', dependencies=check_key)
def handle_billing_usage():
    '''
    Ex. /v1/billing/usage?start_date=2023-05-01&end_date=2023-05-31
    '''
    return JSONResponse(content={"total_usage": 0})


@app.post('/v1/audio/transcriptions', dependencies=check_key)
async def handle_audio_transcription(request: Request):
    r = sr.Recognizer()

    form = await request.form()
    audio_file = await form["file"].read()
    audio_data = AudioSegment.from_file(io.BytesIO(audio_file))

    # Convert the AudioSegment to raw PCM data
    raw_data = audio_data.raw_data

    # Create an AudioData object for speech_recognition
    audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
    whisper_language = form.get('language', None)
    whisper_model = form.get('model', 'tiny')

    transcription = {"text": ""}

    try:
        transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
    except sr.UnknownValueError:
        print("Whisper could not understand audio")
        transcription["text"] = "Whisper could not understand audio UnknownValueError"
    except sr.RequestError as e:
        print("Could not request results from Whisper", e)
        transcription["text"] = "Whisper could not understand audio RequestError"

    return JSONResponse(content=transcription)


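# Example request for /v1/audio/transcriptions above (a sketch; assumes the
# default port 5000; "file", "language", and "model" are the form fields read
# by the handler):
#
#   curl http://127.0.0.1:5000/v1/audio/transcriptions \
#     -F "file=@speech.wav" \
#     -F "model=tiny" \
#     -F "language=en"

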
@app.post('/v1/images/generations', dependencies=check_key)
async def handle_image_generation(request: Request):
    if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")

    body = await request.json()
    prompt = body['prompt']
    size = body.get('size', '1024x1024')
    response_format = body.get('response_format', 'url')
    n = body.get('n', 1)

    response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
    return JSONResponse(response)


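# Example request for /v1/images/generations above (a sketch; assumes the
# default port 5000 and a running SD WebUI reachable via SD_WEBUI_URL; the
# fields shown are the ones read by the handler):
#
#   curl http://127.0.0.1:5000/v1/images/generations \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "a red apple", "size": "512x512", "response_format": "url", "n": 1}'

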
| | @app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key) |
| | async def handle_embeddings(request: Request, request_data: EmbeddingsRequest): |
| | input = request_data.input |
| | if not input: |
| | raise HTTPException(status_code=400, detail="Missing required argument input") |
| |
|
| | if type(input) is str: |
| | input = [input] |
| |
|
| | response = OAIembeddings.embeddings(input, request_data.encoding_format) |
| | return JSONResponse(response) |
| |
|
| |
|
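# Example request for /v1/embeddings above (a sketch; assumes the default
# port 5000; "input" may be a single string or a list of strings):
#
#   curl http://127.0.0.1:5000/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["first sentence", "second sentence"]}'

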
| | @app.post("/v1/moderations", dependencies=check_key) |
| | async def handle_moderations(request: Request): |
| | body = await request.json() |
| | input = body["input"] |
| | if not input: |
| | raise HTTPException(status_code=400, detail="Missing required argument input") |
| |
|
| | response = OAImoderations.moderations(input) |
| | return JSONResponse(response) |
| |
|
| |
|
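# Example request for /v1/moderations above (a sketch; assumes the default
# port 5000):
#
#   curl http://127.0.0.1:5000/v1/moderations \
#     -H "Content-Type: application/json" \
#     -d '{"input": "some text to check"}'

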
| | @app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key) |
| | async def handle_token_encode(request_data: EncodeRequest): |
| | response = token_encode(request_data.text) |
| | return JSONResponse(response) |
| |
|
| |
|
| | @app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key) |
| | async def handle_token_decode(request_data: DecodeRequest): |
| | response = token_decode(request_data.tokens) |
| | return JSONResponse(response) |
| |
|
| |
|
| | @app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key) |
| | async def handle_token_count(request_data: EncodeRequest): |
| | response = token_count(request_data.text) |
| | return JSONResponse(response) |
| |
|
| |
|
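# Example requests for the three token endpoints above (a sketch; assumes the
# default port 5000; the token ids shown are illustrative and depend on the
# loaded model's tokenizer):
#
#   curl http://127.0.0.1:5000/v1/internal/encode \
#     -H "Content-Type: application/json" -d '{"text": "Hello world"}'
#   curl http://127.0.0.1:5000/v1/internal/decode \
#     -H "Content-Type: application/json" -d '{"tokens": [15496, 995]}'
#   curl http://127.0.0.1:5000/v1/internal/token-count \
#     -H "Content-Type: application/json" -d '{"text": "Hello world"}'

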
| | @app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key) |
| | async def handle_logits(request_data: LogitsRequest): |
| | ''' |
| | Given a prompt, returns the top 50 most likely logits as a dict. |
| | The keys are the tokens, and the values are the probabilities. |
| | ''' |
| | response = OAIlogits._get_next_logits(to_dict(request_data)) |
| | return JSONResponse(response) |
| |
|
| |
|
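# Example request for /v1/internal/logits above (a sketch; assumes the default
# port 5000 and that LogitsRequest accepts a "prompt" field, per the docstring):
#
#   curl http://127.0.0.1:5000/v1/internal/logits \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "The capital of France is"}'

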
| | @app.post("/v1/internal/stop-generation", dependencies=check_key) |
| | async def handle_stop_generation(request: Request): |
| | stop_everything_event() |
| | return JSONResponse(content="OK") |
| |
|
| |
|
| | @app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key) |
| | async def handle_model_info(): |
| | payload = OAImodels.get_current_model_info() |
| | return JSONResponse(content=payload) |
| |
|
| |
|
| | @app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key) |
| | async def handle_list_models(): |
| | payload = OAImodels.list_models() |
| | return JSONResponse(content=payload) |
| |
|
| |
|
| | @app.post("/v1/internal/model/load", dependencies=check_admin_key) |
| | async def handle_load_model(request_data: LoadModelRequest): |
| | ''' |
| | This endpoint is experimental and may change in the future. |
| | |
| | The "args" parameter can be used to modify flags like "--load-in-4bit" |
| | or "--n-gpu-layers" before loading a model. Example: |
| | |
| | ``` |
| | "args": { |
| | "load_in_4bit": true, |
| | "n_gpu_layers": 12 |
| | } |
| | ``` |
| | |
| | Note that those settings will remain after loading the model. So you |
| | may need to change them back to load a second model. |
| | |
| | The "settings" parameter is also a dict but with keys for the |
| | shared.settings object. It can be used to modify the default instruction |
| | template like this: |
| | |
| | ``` |
| | "settings": { |
| | "instruction_template": "Alpaca" |
| | } |
| | ``` |
| | ''' |
| |
|
| | try: |
| | OAImodels._load_model(to_dict(request_data)) |
| | return JSONResponse(content="OK") |
| | except: |
| | traceback.print_exc() |
| | return HTTPException(status_code=400, detail="Failed to load the model.") |
| |
|
| |
|
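# Example request for /v1/internal/model/load above (a sketch; assumes the
# default port 5000, an admin key of MY_ADMIN_KEY, and that LoadModelRequest
# takes a "model_name" field alongside the documented "args" and "settings"):
#
#   curl http://127.0.0.1:5000/v1/internal/model/load \
#     -H "Authorization: Bearer MY_ADMIN_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model_name": "my-model", "args": {"n_gpu_layers": 12}}'

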
| | @app.post("/v1/internal/model/unload", dependencies=check_admin_key) |
| | async def handle_unload_model(): |
| | unload_model() |
| |
|
| |
|
| | @app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key) |
| | async def handle_list_loras(): |
| | response = OAImodels.list_loras() |
| | return JSONResponse(content=response) |
| |
|
| |
|
| | @app.post("/v1/internal/lora/load", dependencies=check_admin_key) |
| | async def handle_load_loras(request_data: LoadLorasRequest): |
| | try: |
| | OAImodels.load_loras(request_data.lora_names) |
| | return JSONResponse(content="OK") |
| | except: |
| | traceback.print_exc() |
| | return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).") |
| |
|
| |
|
| | @app.post("/v1/internal/lora/unload", dependencies=check_admin_key) |
| | async def handle_unload_loras(): |
| | OAImodels.unload_all_loras() |
| | return JSONResponse(content="OK") |
| |
|
| |
|
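# Example requests for the LoRA endpoints above (a sketch; assumes the default
# port 5000 and an admin key of MY_ADMIN_KEY; "lora_names" is the field read
# from LoadLorasRequest):
#
#   curl http://127.0.0.1:5000/v1/internal/lora/load \
#     -H "Authorization: Bearer MY_ADMIN_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"lora_names": ["my-lora"]}'
#   curl -X POST http://127.0.0.1:5000/v1/internal/lora/unload \
#     -H "Authorization: Bearer MY_ADMIN_KEY"

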
def run_server():
    server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))

    ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
    ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)

    if shared.args.public_api:
        def on_start(public_url: str):
            logger.info(f'OpenAI-compatible API URL:\n\n{public_url}\n')

        _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
    else:
        if ssl_keyfile and ssl_certfile:
            logger.info(f'OpenAI-compatible API URL:\n\nhttps://{server_addr}:{port}\n')
        else:
            logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n')

    if shared.args.api_key:
        if not shared.args.admin_key:
            shared.args.admin_key = shared.args.api_key

        logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')

    if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
        logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')

    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)


def setup():
    if shared.args.nowebui:
        run_server()
    else:
        Thread(target=run_server, daemon=True).start()