| | import os |
| | import time |
| | import json |
| | import logging |
| | from typing import AsyncGenerator |
| | from contextlib import asynccontextmanager |
| |
|
| | from fastapi import FastAPI, HTTPException, Request |
| | from fastapi.responses import StreamingResponse, JSONResponse |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from sse_starlette.sse import EventSourceResponse |
| |
|
| | from .models import ChatRequest, ChatResponse, ModelInfo, ErrorResponse |
| | from .llm_manager import LLMManager |
| |
|
| | |
# Root logging configuration for the service; per-module logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model manager. Populated by the lifespan handler on startup;
# remains None until then, so endpoints must check before use.
llm_manager: "LLMManager | None" = None
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start up and tear down the application's LLM manager.

    On startup, instantiate the global ``llm_manager`` and attempt to load
    the model; a load failure is logged but not fatal (the manager falls
    back to a mock implementation). On shutdown, just log.
    """
    global llm_manager

    logger.info("Starting up LLM API...")
    llm_manager = LLMManager()

    # Best-effort load: the service stays up even if the real model
    # cannot be loaded.
    model_loaded = await llm_manager.load_model()
    if not model_loaded:
        logger.warning("Failed to load model, using mock implementation")

    yield

    logger.info("Shutting down LLM API...")
| |
|
| |
|
| | |
# Application instance; the lifespan handler wires up the LLM manager.
app = FastAPI(
    title="LLM API - GPT Clone",
    description="A ChatGPT-like API with SSE streaming support using free LLM models",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully-open CORS policy — any origin, method, and header, with credentials.
# NOTE(review): wildcard origins with allow_credentials=True is unsafe for
# production; browsers also reject that combination — confirm intended scope.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| |
|
@app.get("/", response_model=dict)
async def root():
    """Return basic API metadata and a map of the main endpoints."""
    endpoint_map = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "message": "LLM API - GPT Clone",
        "version": "1.0.0",
        "description": "A ChatGPT-like API with SSE streaming support",
        "endpoints": endpoint_map,
    }
| |
|
| |
|
@app.get("/health", response_model=dict)
async def health_check():
    """Report service liveness plus the model manager's current state."""
    global llm_manager

    # Snapshot the global once; report defaults when startup hasn't run yet.
    manager = llm_manager
    model_loaded = manager.is_loaded if manager else False
    model_type = manager.model_type if manager else "none"

    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "model_type": model_type,
        "timestamp": int(time.time()),
    }
| |
|
| |
|
@app.get("/v1/models", response_model=dict)
async def list_models():
    """Return the available model(s) in OpenAI-style list format.

    Raises:
        HTTPException: 503 when the model manager has not been initialized.
    """
    global llm_manager

    if not llm_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")

    return {"object": "list", "data": [llm_manager.get_model_info()]}
| |
|
| |
|
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint with SSE streaming support.

    When ``request.stream`` is true, returns an SSE response that yields
    tokens as they are generated. Otherwise, drains the generator and
    returns a single aggregated ChatResponse.

    Raises:
        HTTPException: 503 if the manager is missing or the model is not
            loaded; 400 for an empty message list; 500 if generation
            reports an error or produces no output.
    """
    global llm_manager

    if not llm_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")

    if not llm_manager.is_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not request.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty")

    if request.stream:
        return EventSourceResponse(
            stream_chat_response(request), media_type="text/event-stream"
        )

    # Non-streaming: accumulate all delta content into one response body.
    full_response = ""
    # BUG FIX: the original referenced `chunk` after the loop, which raised
    # NameError whenever the generator yielded nothing. Track the last chunk
    # explicitly and fail with a clear 500 instead.
    last_chunk = None
    async for chunk in llm_manager.generate_stream(request):
        if "error" in chunk:
            raise HTTPException(status_code=500, detail=chunk["error"]["message"])

        last_chunk = chunk
        if "choices" in chunk and chunk["choices"]:
            choice = chunk["choices"][0]
            if "delta" in choice and "content" in choice["delta"]:
                full_response += choice["delta"]["content"]

    if last_chunk is None:
        raise HTTPException(status_code=500, detail="Model produced no output")

    # NOTE(review): whitespace-split word count stands in for real token
    # counts, and prompt_tokens mirrors the completion length — presumably a
    # deliberate placeholder; confirm before relying on `usage`.
    completion_tokens = len(full_response.split())
    return ChatResponse(
        id=last_chunk["id"],
        created=last_chunk["created"],
        model=last_chunk["model"],
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": full_response},
                "finish_reason": "stop",
            }
        ],
        usage={
            "prompt_tokens": completion_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": completion_tokens * 2,
        },
    )
| |
|
| |
|
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[dict, None]:
    """Yield SSE event dicts for a streaming chat completion.

    Each generated chunk becomes a ``message`` event; generator-reported
    errors and unexpected exceptions become a terminal ``error`` event.
    """
    global llm_manager

    try:
        async for piece in llm_manager.generate_stream(request):
            # Generator-level error: forward it and stop the stream.
            if "error" in piece:
                yield {"event": "error", "data": json.dumps(piece["error"])}
                return

            yield {"event": "message", "data": json.dumps(piece)}

            # Stop once the model signals completion.
            choices = piece.get("choices")
            if choices and choices[0].get("finish_reason") == "stop":
                break

    except Exception as e:
        logger.error(f"Error in stream_chat_response: {e}")
        error_payload = {"error": {"message": str(e), "type": "stream_error"}}
        yield {"event": "error", "data": json.dumps(error_payload)}
| |
|
| |
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render HTTPExceptions in this API's standard error envelope."""
    envelope = {
        "error": {
            "message": exc.detail,
            "type": "http_error",
            "code": exc.status_code,
        }
    }
    return JSONResponse(status_code=exc.status_code, content=envelope)
| |
|
| |
|
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception, return a generic 500 envelope."""
    logger.error(f"Unhandled exception: {exc}")
    envelope = {
        "error": {
            "message": "Internal server error",
            "type": "internal_error",
            "code": 500,
        }
    }
    return JSONResponse(status_code=500, content=envelope)
| |
|
| |
|
if __name__ == "__main__":
    # Local-only import: uvicorn is a dev-server dependency, not needed
    # when the app is run by an external ASGI server.
    import uvicorn

    # reload=True is a development setting; binds on all interfaces.
    uvicorn.run(
        "app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info"
    )
| |
|