# app/main.py — LLM Chat Interface for Hugging Face Spaces
# (Initial commit by surahj, c2f9396)
import os
import time
import json
import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette.sse import EventSourceResponse
from .models import ChatRequest, ChatResponse, ModelInfo, ErrorResponse
from .llm_manager import LLMManager
# Configure logging once at import time; INFO level keeps startup/shutdown
# and request-path messages visible without debug noise.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# Global LLM manager instance; None until lifespan() initializes it at startup.
# Annotation is quoted so it is not evaluated at runtime and stays 3.9-compatible;
# the original annotation (`LLMManager`) was wrong because the value starts as None.
llm_manager: "LLMManager | None" = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: build the LLM manager and load the model
    on startup, log on shutdown."""
    global llm_manager

    logger.info("Starting up LLM API...")
    llm_manager = LLMManager()
    # A failed load is non-fatal; the manager falls back to a mock implementation.
    if not await llm_manager.load_model():
        logger.warning("Failed to load model, using mock implementation")

    yield  # application runs here

    logger.info("Shutting down LLM API...")
# Create FastAPI app; `lifespan` wires model load/unload into startup/shutdown.
app = FastAPI(
    title="LLM API - GPT Clone",
    description="A ChatGPT-like API with SSE streaming support using free LLM models",
    version="1.0.0",
    lifespan=lifespan,
)
# Add CORS middleware so browser clients on other origins can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec — confirm intended deployment setup.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API information."""
    endpoints = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "health": "/health",
    }
    return {
        "message": "LLM API - GPT Clone",
        "version": "1.0.0",
        "description": "A ChatGPT-like API with SSE streaming support",
        "endpoints": endpoints,
    }
@app.get("/health", response_model=dict)
async def health_check():
    """Health check endpoint reporting model load status and type."""
    manager = llm_manager  # may still be None before lifespan() runs
    return {
        "status": "healthy",
        "model_loaded": manager.is_loaded if manager else False,
        "model_type": manager.model_type if manager else "none",
        "timestamp": int(time.time()),
    }
@app.get("/v1/models", response_model=dict)
async def list_models():
    """List available models in OpenAI-style list format."""
    if not llm_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")
    return {"object": "list", "data": [llm_manager.get_model_info()]}
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint with SSE streaming support.

    When ``request.stream`` is true, returns an SSE stream of token chunks;
    otherwise drains the token stream and returns one complete ChatResponse.

    Raises:
        HTTPException: 503 if the manager/model is unavailable, 400 for an
            empty message list, 500 on generation errors or empty output.
    """
    global llm_manager
    if not llm_manager:
        raise HTTPException(status_code=503, detail="Model manager not initialized")
    if not llm_manager.is_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Validate request
    if not request.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty")

    # Streaming path: hand the generator to the SSE response wrapper.
    if request.stream:
        return EventSourceResponse(
            stream_chat_response(request), media_type="text/event-stream"
        )

    # Non-streaming path: collect all tokens and return one response.
    full_response = ""
    last_chunk = None
    async for chunk in llm_manager.generate_stream(request):
        if "error" in chunk:
            raise HTTPException(status_code=500, detail=chunk["error"]["message"])
        if chunk.get("choices"):
            delta = chunk["choices"][0].get("delta", {})
            if "content" in delta:
                full_response += delta["content"]
        last_chunk = chunk

    # Bug fix: the original referenced `chunk` after the loop, which raised
    # NameError whenever the generator yielded nothing.
    if last_chunk is None:
        raise HTTPException(status_code=500, detail="Model produced no output")

    # Rough whitespace-token estimate, computed once (original split 3x);
    # prompt count mirrors the original's completion-based approximation.
    token_estimate = len(full_response.split())
    return ChatResponse(
        id=last_chunk["id"],
        created=last_chunk["created"],
        model=last_chunk["model"],
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": full_response},
                "finish_reason": "stop",
            }
        ],
        usage={
            "prompt_tokens": token_estimate,  # Rough estimate
            "completion_tokens": token_estimate,
            "total_tokens": token_estimate * 2,
        },
    )
async def stream_chat_response(request: ChatRequest) -> AsyncGenerator[dict, None]:
    """Yield chat-response token chunks as SSE event dicts.

    Emits ``message`` events per chunk; on a model error or an unexpected
    exception, emits a single ``error`` event and stops.
    """
    global llm_manager
    try:
        async for chunk in llm_manager.generate_stream(request):
            if "error" in chunk:
                # Surface the model's error to the client and end the stream.
                yield {"event": "error", "data": json.dumps(chunk["error"])}
                return
            yield {"event": "message", "data": json.dumps(chunk)}
            # Stop once the model signals completion.
            choices = chunk.get("choices")
            if choices and choices[0].get("finish_reason") == "stop":
                break
    except Exception as exc:
        logger.error(f"Error in stream_chat_response: {exc}")
        payload = {"error": {"message": str(exc), "type": "stream_error"}}
        yield {"event": "error", "data": json.dumps(payload)}
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render HTTPExceptions as a structured JSON error body."""
    error_body = {
        "error": {
            "message": exc.detail,
            "type": "http_error",
            "code": exc.status_code,
        }
    }
    return JSONResponse(status_code=exc.status_code, content=error_body)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the failure with traceback, return opaque 500.

    The client sees only a generic message; details stay in the server log.
    """
    # Bug fix: the original f-string log line discarded the traceback;
    # exc_info preserves it, and lazy %-formatting is the logging idiom.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": "Internal server error",
                "type": "internal_error",
                "code": 500,
            }
        },
    )
if __name__ == "__main__":
    # Local-import uvicorn so it is only required when running this file
    # directly (in deployment the server is started externally).
    import uvicorn

    # reload=True enables auto-restart on code changes — a development
    # setting; disable for production.
    uvicorn.run(
        "app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info"
    )