from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional
import uvicorn
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from src.chat_service import chat_service
class ChatRequest(BaseModel):
    # Request body accepted by the chat endpoints.
    # NOTE: documented with comments rather than a class docstring because a
    # docstring on a Pydantic model is copied into the generated schema.

    # Required user text; must be between 1 and 4096 characters long.
    message: str = Field(
        ...,
        min_length=1,
        max_length=4096,
        description="User input message",
    )

    # Optional conversation identifier; falls back to "default" when omitted.
    thread_id: Optional[str] = Field(
        "default",
        description="Conversation ID for memory tracking",
    )
class ChatResponse(BaseModel):
    # Response body returned by the non-streaming chat endpoint.
    # NOTE: documented with comments rather than a class docstring because a
    # docstring on a Pydantic model is copied into the generated schema.

    # Required text of the assistant's reply.
    response: str = Field(
        ...,
        description="Assistant's response",
    )

    # Required identifier of the conversation the reply belongs to.
    thread_id: str = Field(
        ...,
        description="Conversation ID used for memory tracking",
    )

    # Filled in with the current local time when the caller does not pass one.
    timestamp: datetime = Field(
        default_factory=datetime.now,
    )
class HealthResponse(BaseModel):
    # Response body returned by the health-check endpoint.
    # NOTE: documented with comments rather than a class docstring because a
    # docstring on a Pydantic model is copied into the generated schema.

    # Required status label for the service.
    status: str = Field(
        ...,
        description="Service status",
    )

    # Filled in with the current local time when the caller does not pass one.
    timestamp: datetime = Field(
        default_factory=datetime.now,
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Print a notice when the application starts and again when it stops.

    Registered on the FastAPI instance through its ``lifespan`` argument:
    the code before ``yield`` runs at startup, the code after it at shutdown.
    The ``app`` argument is accepted but not used.
    """
    startup_notice = "Starting up the application..."
    shutdown_notice = "Shutting down the application..."

    print(startup_notice)
    yield
    print(shutdown_notice)
# The ASGI application object. The metadata below is what FastAPI publishes
# in its generated API documentation; `lifespan` wires in the startup and
# shutdown hook.
app = FastAPI(
    lifespan=lifespan,
    title="9jaLingo RAG Chat API",
    version="1.0.0",
    description="RAG API for interacting with the 9jaLingo support chatbot",
)
# CORS policy: the API is open to any origin, method and header.
#
# Credentials (cookies / Authorization headers) are deliberately NOT allowed.
# Wildcard origins combined with allow_credentials=True is an invalid CORS
# configuration per the FastAPI docs, and in practice it lets any website
# issue credentialed requests on behalf of a visitor. If credentialed
# cross-origin requests are ever needed, replace the wildcards with explicit
# origin/method/header lists and only then re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", tags=["Health"])
async def root():
    # Landing payload: identifies the service and points at the docs and
    # health-check paths. Documented with comments rather than a docstring
    # because FastAPI publishes route docstrings in the OpenAPI schema.
    service_info = dict(
        name="9jaLingo RAG Chat API",
        status="ok",
        docs="/docs",
        health="/health",
    )
    return service_info
@app.get(
    "/health",
    response_model=HealthResponse,
    status_code=status.HTTP_200_OK,
    tags=["Health"],
)
async def health_check():
    """
    Endpoint to check if the service is running.
    Returns a 200 OK response if the service is healthy.
    """
    # The docstring above is kept verbatim: FastAPI publishes it as the
    # operation description in the generated API docs.
    try:
        report = HealthResponse(status="healthy", timestamp=datetime.now())
    except Exception as e:
        # Any failure while building the report is surfaced as a 500.
        failure_detail = f"Service health check failed: {str(e)}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_detail,
        )
    else:
        return report
@app.post("/chat",
          response_model=ChatResponse,
          status_code=status.HTTP_200_OK,
          tags=["Chat"])
async def chat_endpoint(request: ChatRequest):
    """
    Send a message to the assistant and receive the complete reply.

    The reply is associated with `thread_id`; when it is omitted or empty,
    the `"default"` thread is used. Responds with 422 if the request is
    rejected with a value error and 500 for any other failure.
    """
    # Imported here so the handler stays self-contained with respect to the
    # module-level import block.
    from fastapi.concurrency import run_in_threadpool

    try:
        thread_id = request.thread_id or "default"
        # chat_service.chat is a plain synchronous call. Invoking it directly
        # inside this coroutine would stall the event loop (and every other
        # in-flight request) until it returns, so it is handed to the worker
        # thread pool instead.
        response = await run_in_threadpool(
            chat_service.chat, request.message, thread_id
        )
        return ChatResponse(
            response=response,
            thread_id=thread_id,
            timestamp=datetime.now()
        )
    except ValueError as ve:
        # Chain the original exception so the root cause stays in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(ve)
        ) from ve
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing chat request: {str(e)}"
        ) from e
@app.post("/stream", tags=["Chat"])
async def stream_endpoint(request: ChatRequest):
    """
    Send a message to the assistant and receive the reply as a plain-text stream.

    The reply is associated with `thread_id`; when it is omitted or empty,
    the `"default"` thread is used. Responds with 500 if the stream cannot
    be started.
    """
    # Imported here so the handler stays self-contained with respect to the
    # module-level import block.
    from fastapi.concurrency import run_in_threadpool

    thread_id = request.thread_id or "default"
    # Sentinel marking a stream that finished without producing any chunk.
    stream_exhausted = object()

    def open_stream():
        # Create the iterator AND pull its first chunk. A lazily evaluated
        # stream does no work until it is first advanced, so this is the
        # earliest point at which a failure to start can be observed.
        chunks = iter(chat_service.stream(request.message, thread_id))
        return chunks, next(chunks, stream_exhausted)

    try:
        # Runs in the worker thread pool so the event loop is not blocked
        # while waiting for the first chunk. Doing this BEFORE the response
        # is created is what lets start-up errors become a real HTTP 500;
        # once StreamingResponse has begun sending, the status line is
        # already on the wire and can no longer be changed.
        chunks, first_chunk = await run_in_threadpool(open_stream)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error streaming chat response: {str(e)}"
        ) from e

    def generate():
        # Replay the chunk consumed above, then forward the remainder.
        if first_chunk is stream_exhausted:
            return
        yield first_chunk
        yield from chunks

    return StreamingResponse(generate(), media_type="text/plain")
if __name__ == "__main__":
    # Script entry point: serve the app referenced by the import string
    # "main:app" on all interfaces, port 8000, with auto-reload enabled and
    # a single worker.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, workers=1)