"""OpenAI-compatible FastAPI gateway for the chat service."""

import traceback
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional

import urllib3
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from rich.console import Console

from api.chat.chat_api import ChatConfig
from config.api_keys import APIKeyManager
from config.constants import (
    MODEL_MAPPING,
    ALLOWED_HOSTS,
    ENCRYPTION_KEY,
)
from config.models import ModelID
from managers.chat_manager import ChatManager
from service.handlers.response import ResponseHandler
from service.middleware.rate_limiter import RateLimiter
from service.utils.validation import validate_message_format
from utils.encrypt import encrypt
from utils.http import HTTPClient

# Suppress urllib3 TLS certificate warnings for outbound requests.
urllib3.disable_warnings()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize the shared singletons used by the request handlers.
    global http_client, console, chat_manager, rate_limiter
    console = Console()
    http_client = HTTPClient()
    chat_manager = ChatManager(queue_type="deque")
    rate_limiter = RateLimiter()

    # Seed the chat queue with one chat instance per configured API key.
    key_manager = APIKeyManager()
    api_keys = key_manager.list_keys()
    if not api_keys:
        console.print("[bold red]Warning: No API keys found in configuration[/]")
    else:
        for api_key in api_keys:
            await chat_manager.add_chat(api_key)
        console.print(f"[bold green]Initialized chat queue with {len(api_keys)} API keys[/]")

    yield

app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None)

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(TrustedHostMiddleware, allowed_hosts=ALLOWED_HOSTS)

class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request body."""

    model: str
    messages: List[dict]
    system_prompt: Optional[str] = None
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 4096
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = 50
    stream: Optional[bool] = False

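# Illustrative request payload for the model above; "my-model" is a
# hypothetical name standing in for a real MODEL_MAPPING key:
#
# {
#     "model": "my-model",
#     "messages": [{"role": "user", "content": "Hello"}],
#     "temperature": 0.7,
#     "stream": false
# }
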
class ChatCompletionResponse(BaseModel):
    """OpenAI-compatible chat completion response envelope."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[dict]
    usage: dict

def map_model_name_to_id(model_name: str) -> ModelID:
    """Map a public model name to the internal ModelID enum."""
    if model_name not in MODEL_MAPPING:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model: {model_name}. Available models: {', '.join(MODEL_MAPPING.keys())}",
        )
    return MODEL_MAPPING[model_name]

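# Example usage; the model name here is hypothetical, and the real names are
# the keys of MODEL_MAPPING in config.constants:
#
#   model_id = map_model_name_to_id("my-model")  # -> a ModelID member
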
@app.get("/")
async def root():
    # The root path is intentionally not served; probes get a 403.
    return ResponseHandler.create_error_response(403, "forbidden")

@app.middleware("http")
async def check_request(request: Request, call_next):
    # Currently a pass-through; request vetting (e.g. rate limiting) would
    # hook in here before a request reaches its route handler.
    return await call_next(request)

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, req: Request):
    model_id = map_model_name_to_id(request.model)
    messages = request.messages
    # May be populated by upstream middleware; defaults to None.
    user_name = getattr(req.state, "user_name", None)

    validate_message_format(messages)

    try:
        config = ChatConfig(
            model_id=model_id,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            top_k=request.top_k,
        )

        # Pull a chat instance from the pool seeded at startup.
        chat = await chat_manager.get_chat()
        if chat is None:
            raise HTTPException(
                status_code=503,
                detail="No available chat instances. Please try again later.",
            )

        if request.stream:
            return await ResponseHandler.generate_stream(
                f"chatcmpl-{uuid.uuid4()}", chat, messages, config, request.model
            )
        else:
            # Drain the async generator and accumulate the full completion.
            response_text = ""
            tokens = 0
            async for chunk, token_count in chat.send_message(
                messages=messages,
                config=config,
            ):
                if chunk:
                    response_text += chunk
                    tokens += token_count

            return ResponseHandler.build_chat_response(
                chunk_id=f"chatcmpl-{uuid.uuid4()}",
                model=request.model,
                response_text=response_text,
                # Rough proxy: character count of the serialized messages,
                # not a true token count.
                prompt_tokens=len(str(messages)),
                tokens=tokens,
            )

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) propagate instead
        # of being masked as a 500 by the generic handler below.
        raise
    except Exception as e:
        traceback.print_exc()
        # Encrypt the raw error so internals are not leaked to clients.
        encrypted_message = encrypt(str(e), ENCRYPTION_KEY)
        error_response = {
            "error": {
                "encrypted_message": encrypted_message,
                "type": "internal_server_error",
                "code": 500,
            }
        }
        return JSONResponse(status_code=500, content=error_response)

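# Illustrative invocation; host, port, and model name are assumptions, not
# part of this module:
#
#   curl -s http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-model", "messages": [{"role": "user", "content": "Hi"}]}'
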
# Static OpenAI-style metadata for every model exposed by this service.
MODEL_METADATA = {
    model_name: {
        "id": model_name,
        "object": "model",
        "created": 220899661,
        "owned_by": "Delamain",
    }
    for model_name in MODEL_MAPPING.keys()
}

@app.get("/v1/models")
async def list_models():
    models = list(MODEL_METADATA.values())
    return {
        "object": "list",
        "data": models,
    }

@app.get("/v1/models/{model_id}")
async def get_model(model_id: str):
    model = MODEL_METADATA.get(model_id)
    if not model:
        raise HTTPException(
            status_code=404,
            detail=f"Model not found. Available models: {', '.join(MODEL_METADATA.keys())}",
        )
    return model
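
# Minimal local entry point; a sketch assuming uvicorn as the ASGI server.
# Host and port are illustrative, not taken from this project's config.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)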