import os
import time
import json
import logging
from typing import AsyncGenerator, Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse

from .models import ChatRequest, ChatResponse, ModelInfo, ErrorResponse
from .llm_manager import LLMManager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global LLM manager instance, populated during startup in `lifespan`.
# Fix: annotated Optional — it is None until the lifespan handler runs.
llm_manager: Optional[LLMManager] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: load the model on startup, log on shutdown."""
    global llm_manager

    # Startup
    logger.info("Starting up LLM API...")
    llm_manager = LLMManager()

    # Load the model. On failure the app still starts; /health will report
    # model_loaded=False and the manager falls back to a mock implementation.
    success = await llm_manager.load_model()
    if not success:
        logger.warning("Failed to load model, using mock implementation")

    yield

    # Shutdown
    logger.info("Shutting down LLM API...")


# Create FastAPI app
app = FastAPI(
    title="LLM API - GPT Clone",
    description="A ChatGPT-like API with SSE streaming support using free LLM models",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API information and a map of available endpoints."""
    return {
        "message": "LLM API - GPT Clone",
        "version": "1.0.0",
        "description": "A ChatGPT-like API with SSE streaming support",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "models": "/v1/models",
            "health": "/health",
        },
    }


@app.get("/health", response_model=dict)
async def health_check():
    """Health check endpoint reporting model load state and type."""
    # Read-only access to the module-level manager; no `global` needed.
    return {
        "status": "healthy",
        "model_loaded": llm_manager.is_loaded if llm_manager else False,
        "model_type": llm_manager.model_type if llm_manager else "none",
        "timestamp": int(time.time()),
    }
@app.get("/v1/models", response_model=dict)
async def list_models():
    """List available models in an OpenAI-style list envelope."""
    if not llm_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")

    model_info = llm_manager.get_model_info()
    return {"object": "list", "data": [model_info]}


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint with SSE streaming support.

    When ``request.stream`` is true, returns an SSE stream of chunks;
    otherwise all streamed deltas are collected into one ChatResponse.

    Raises:
        HTTPException: 503 if the manager/model is not ready, 400 on an
            empty message list, 500 on a generation error or empty stream.
    """
    if not llm_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")

    if not llm_manager.is_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate request
    if not request.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty")

    # Check if streaming is requested
    if request.stream:
        return EventSourceResponse(
            stream_chat_response(request), media_type="text/event-stream"
        )

    # Non-streaming response (collect all tokens and return at once).
    # Bug fix: the original referenced `chunk` after the loop, which raises
    # NameError when the generator yields nothing — track the last chunk
    # explicitly and fail with a clear 500 on an empty stream.
    full_response = ""
    last_chunk = None
    async for chunk in llm_manager.generate_stream(request):
        if "error" in chunk:
            raise HTTPException(status_code=500, detail=chunk["error"]["message"])
        last_chunk = chunk
        if "choices" in chunk and chunk["choices"]:
            choice = chunk["choices"][0]
            if "delta" in choice and "content" in choice["delta"]:
                full_response += choice["delta"]["content"]

    if last_chunk is None:
        raise HTTPException(status_code=500, detail="Model produced no output")

    # Rough whitespace-based token estimate (hoisted from three duplicate calls).
    token_estimate = len(full_response.split())

    # Return complete response
    return ChatResponse(
        id=last_chunk["id"],
        created=last_chunk["created"],
        model=last_chunk["model"],
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": full_response},
                "finish_reason": "stop",
            }
        ],
        usage={
            "prompt_tokens": token_estimate,  # Rough estimate
            "completion_tokens": token_estimate,
            "total_tokens": token_estimate * 2,
        },
    )


async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[dict, None]:
    """Stream chat response tokens as SSE event dicts.

    Yields ``{"event": ..., "data": <json string>}`` mappings consumed by
    EventSourceResponse; errors are surfaced as a terminal "error" event
    rather than raised, so the SSE connection closes cleanly.
    """
    try:
        async for chunk in llm_manager.generate_stream(request):
            if "error" in chunk:
                # Send error as SSE event and stop streaming.
                yield {"event": "error", "data": json.dumps(chunk["error"])}
                return

            # Send chunk as SSE event
            yield {"event": "message", "data": json.dumps(chunk)}

            # Check if this is the final chunk
            if (
                chunk.get("choices")
                and chunk["choices"][0].get("finish_reason") == "stop"
            ):
                break
    except Exception as e:
        # Use lazy %-formatting and keep the traceback for debugging.
        logger.error("Error in stream_chat_response: %s", e, exc_info=True)
        yield {
            "event": "error",
            "data": json.dumps({"error": {"message": str(e), "type": "stream_error"}}),
        }


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render HTTPExceptions in the API's uniform error envelope."""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "message": exc.detail,
                "type": "http_error",
                "code": exc.status_code,
            }
        },
    )


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception, return an opaque 500 envelope."""
    logger.error("Unhandled exception: %s", exc, exc_info=True)
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": "Internal server error",
                "type": "internal_error",
                "code": 500,
            }
        },
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info"
    )