import traceback
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional

import urllib3
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from rich.console import Console

from api.chat.chat_api import ChatConfig
from config.api_keys import APIKeyManager
from config.constants import (
    MODEL_MAPPING,
    ALLOWED_HOSTS,
    ENCRYPTION_KEY,
)
from config.models import ModelID
from managers.chat_manager import ChatManager
from service.handlers.response import ResponseHandler
from service.middleware.rate_limiter import RateLimiter
from service.utils.validation import validate_message_format
from utils.encrypt import encrypt
from utils.http import HTTPClient

urllib3.disable_warnings()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize global state shared across request handlers
    global http_client, console, chat_manager, rate_limiter
    console = Console()
    http_client = HTTPClient()
    chat_manager = ChatManager(queue_type="deque")  # In-memory deque instead of Redis
    rate_limiter = RateLimiter()

    # Populate the chat queue with API keys from configuration
    key_manager = APIKeyManager()
    api_keys = key_manager.list_keys()
    if not api_keys:
        console.print("[bold red]Warning: No API keys found in configuration[/]")
    else:
        for api_key in api_keys:
            await chat_manager.add_chat(api_key)
        console.print(
            f"[bold green]Initialized chat queue with {len(api_keys)} API keys[/]"
        )
    yield
    # No cleanup needed for the in-memory queue


app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None)

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(TrustedHostMiddleware, allowed_hosts=ALLOWED_HOSTS)


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[dict]
    system_prompt: Optional[str] = None
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 4096
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = 50
    stream: Optional[bool] = False


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[dict]
    usage: dict


def map_model_name_to_id(model_name: str) -> ModelID:
    """Map a public model name to the internal ModelID enum."""
    if model_name not in MODEL_MAPPING:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unsupported model: {model_name}. "
                f"Available models: {', '.join(MODEL_MAPPING.keys())}"
            ),
        )
    return MODEL_MAPPING[model_name]


@app.get("/")
async def root():
    return ResponseHandler.create_error_response(403, "forbidden")


@app.middleware("http")
async def check_request(request: Request, call_next):
    # Rate limiting is currently disabled; uncomment to enforce it.
    # is_allowed, error_msg, status_code = await rate_limiter.authenticate_request(
    #     request
    # )
    # if not is_allowed:
    #     return ResponseHandler.create_error_response(status_code, error_msg)
    return await call_next(request)


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, req: Request):
    model_id = map_model_name_to_id(request.model)
    messages = request.messages
    user_name = getattr(req.state, "user_name", None)  # Set by upstream middleware; unused here
    validate_message_format(messages)
    try:
        config = ChatConfig(
            model_id=model_id,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            top_k=request.top_k,
        )
        # Borrow a chat instance from the queue
        chat = await chat_manager.get_chat()
        if chat is None:
            raise HTTPException(
                status_code=503,
                detail="No available chat instances. Please try again later.",
            )
        if request.stream:
            return await ResponseHandler.generate_stream(
                f"chatcmpl-{uuid.uuid4()}", chat, messages, config, request.model
            )
        else:
            response_text = ""
            tokens = 0
            async for chunk, token_count in chat.send_message(
                messages=messages,
                config=config,
            ):
                if chunk:
                    response_text += chunk
                    tokens += token_count
            return ResponseHandler.build_chat_response(
                chunk_id=f"chatcmpl-{uuid.uuid4()}",
                model=request.model,
                response_text=response_text,
                # Rough approximation: character count of the serialized
                # messages, not a true token count.
                prompt_tokens=len(str(messages)),
                tokens=tokens,
            )
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 503 above) so they are not
        # masked by the generic 500 handler below.
        raise
    except Exception as e:
        print(e)
        traceback.print_exc()
        encrypted_message = encrypt(str(e), ENCRYPTION_KEY)
        error_response = {
            "error": {
                "encrypted_message": encrypted_message,
                "type": "internal_server_error",
                "code": 500,
            }
        }
        return JSONResponse(status_code=500, content=error_response)


MODEL_METADATA = {
    model_name: {
        "id": model_name,
        "object": "model",
        "created": 220899661,
        "owned_by": "Delamain",
    }
    for model_name in MODEL_MAPPING.keys()
}


@app.get("/v1/models")
async def list_models():
    models = list(MODEL_METADATA.values())
    return {"object": "list", "data": models}


@app.get("/v1/models/{model_id}")
async def get_model(model_id: str):
    model = MODEL_METADATA.get(model_id)
    if not model:
        raise HTTPException(
            status_code=404,
            detail=f"Model not found. Available models: {', '.join(MODEL_METADATA.keys())}",
        )
    return model
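

# --- Illustrative entry point (not part of the original module) ---
# A minimal sketch for running the service locally, assuming `uvicorn` is
# installed; the host and port below are placeholder values, not taken from
# the project's configuration.
#
# Example request once running (hypothetical model name; adjust to whatever
# your MODEL_MAPPING actually contains):
#   curl -X POST http://localhost:8000/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)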