import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from src.chat_service import chat_service
|
|
class ChatRequest(BaseModel):
    # Request body shared by the /chat and /stream endpoints.
    # (Documented with comments rather than a class docstring so the
    # generated JSON schema stays exactly as it was.)

    # Required user text; must be between 1 and 4096 characters long.
    message: str = Field(
        default=...,
        min_length=1,
        max_length=4096,
        description="User input message",
    )

    # Conversation key; falls back to "default" when the client omits it.
    # An explicit null is accepted because the field is Optional.
    thread_id: Optional[str] = Field(
        "default",
        description="Conversation ID for memory tracking",
    )
|
|
class ChatResponse(BaseModel):
    # Response body returned by the /chat endpoint.

    # Text produced for the user's message (required).
    response: str = Field(
        default=...,
        description="Assistant's response",
    )

    # Conversation key the reply belongs to (required).
    thread_id: str = Field(
        default=...,
        description="Conversation ID used for memory tracking",
    )

    # Filled in at construction time when not supplied. datetime.now() is
    # called without a tz argument, so the value is a naive local-time stamp.
    timestamp: datetime = Field(default_factory=datetime.now)
|
|
class HealthResponse(BaseModel):
    # Response body returned by the /health endpoint.

    # Free-form status label (required); the health handler sends "healthy".
    status: str = Field(
        default=...,
        description="Service status",
    )

    # Filled in at construction time when not supplied. datetime.now() is
    # called without a tz argument, so the value is a naive local-time stamp.
    timestamp: datetime = Field(default_factory=datetime.now)
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Announce application start-up and shutdown on stdout.

    Passed to ``FastAPI(lifespan=...)``: the code before ``yield`` runs
    before the context body is entered, the code after it runs once the
    body has finished. No resources are created or released here.

    Args:
        app: The application instance this lifespan is attached to (unused).
    """
    startup_notice = "Starting up the application..."
    shutdown_notice = "Shutting down the application..."

    print(startup_notice)
    yield
    print(shutdown_notice)
|
|
# Application instance; start-up/shutdown hooks are provided by `lifespan`.
app = FastAPI(
    title="9jaLingo RAG Chat API",
    description="RAG API for interacting with the 9jaLingo support chatbot",
    version="1.0.0",
    lifespan=lifespan,
)

# Browser origins allowed to call the API, taken from the comma-separated
# CORS_ALLOW_ORIGINS environment variable. When the variable is unset or
# empty, every origin is allowed.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials: that would
    # let any website issue cookie/authenticated requests on behalf of a
    # user. Credentials are therefore enabled only when the allowed origins
    # are listed explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
@app.get("/", tags=["Health"])
async def root():
    """Return a small descriptor of the service.

    Returns:
        A dict holding the API name, a fixed ``"ok"`` status, and the paths
        of the interactive docs and the health endpoint.
    """
    service_info = dict(
        name="9jaLingo RAG Chat API",
        status="ok",
        docs="/docs",
        health="/health",
    )
    return service_info
|
|
@app.get(
    "/health",
    response_model=HealthResponse,
    status_code=status.HTTP_200_OK,
    tags=["Health"],
)
async def health_check():
    """Report that the service process is up and able to answer requests.

    Returns:
        A ``HealthResponse`` with status ``"healthy"`` and the current time.

    Raises:
        HTTPException: 500 if building the response object fails.
    """
    try:
        report = HealthResponse(status="healthy", timestamp=datetime.now())
    except Exception as exc:
        failure = f"Service health check failed: {str(exc)}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure,
        )
    return report
|
|
@app.post(
    "/chat",
    response_model=ChatResponse,
    status_code=status.HTTP_200_OK,
    tags=["Chat"],
)
async def chat_endpoint(request: ChatRequest):
    """Answer one chat message and return the complete reply.

    Args:
        request: Validated request body; a missing or empty ``thread_id``
            is replaced by ``"default"``.

    Returns:
        A ``ChatResponse`` with the reply text, the thread id that was
        used, and the time the response object was built.

    Raises:
        HTTPException: 422 when a ``ValueError`` is raised while producing
            the reply or building the response; 500 for any other failure.
    """
    thread_id = request.thread_id or "default"
    try:
        # chat_service.chat is a plain synchronous call. Calling it directly
        # inside this coroutine would block the event loop (and every other
        # in-flight request) until it returns, so it runs in a worker thread.
        response = await run_in_threadpool(
            chat_service.chat, request.message, thread_id
        )
        return ChatResponse(
            response=response,
            thread_id=thread_id,
            timestamp=datetime.now(),
        )
    except ValueError as ve:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(ve),
        ) from ve
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing chat request: {str(e)}",
        ) from e
|
|
|
|
@app.post("/stream", tags=["Chat"])
async def stream_endpoint(request: ChatRequest):
    """Stream the reply to one chat message as plain text.

    The stream is opened and its first chunk fetched before the response is
    created. Without that step the iterable from ``chat_service.stream`` is
    only consumed after this handler has returned, so a failure while
    starting the stream could never reach the ``except`` clause below and
    be reported as a proper error response.

    Args:
        request: Validated request body; a missing or empty ``thread_id``
            is replaced by ``"default"``.

    Returns:
        A ``StreamingResponse`` (``text/plain``) yielding the chunks
        produced by ``chat_service.stream`` in order.

    Raises:
        HTTPException: 500 if the stream cannot be opened or fails while
            producing its first chunk. Errors raised after the first chunk
            occur once the response has started and cannot change the
            status code.
    """
    thread_id = request.thread_id or "default"
    exhausted = object()  # marks a stream that produced no chunks at all

    def open_stream():
        # Runs in a worker thread: neither creating the stream nor waiting
        # for its first chunk may block the event loop. next() gets a
        # default so an empty stream does not surface as StopIteration.
        chunks = iter(chat_service.stream(request.message, thread_id))
        return chunks, next(chunks, exhausted)

    try:
        chunks, first_chunk = await run_in_threadpool(open_stream)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error streaming chat response: {str(e)}",
        ) from e

    def generate():
        if first_chunk is exhausted:
            return
        # Replay the chunk consumed above, then hand over to the stream.
        yield first_chunk
        yield from chunks

    return StreamingResponse(generate(), media_type="text/plain")
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the app referenced by the import
    # string "main:app" on all interfaces, port 8000, with auto-reload
    # enabled and a single worker.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "workers": 1,
    }
    uvicorn.run("main:app", **server_options)