# MyChatbot / app/main.py
# Last commit by Parsa2025AI: "fastapi app" (8ea787a, verified)
import logging
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from app.rag import RAGPipeline
from app.llm import LLMClient
# Module-level logger. basicConfig gives the root logger an INFO-level
# console handler, but only if nothing has configured logging already.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
ALLOWED_ORIGINS = [
origin.strip()
for origin in os.getenv(
"ALLOWED_ORIGINS",
"http://localhost:3000,https://aprouhi.com"
).split(",")
if origin.strip()
]
# Process-wide singletons. They stay None until ``lifespan`` has built them;
# /health reports on that and /chat refuses to serve until both exist.
rag: RAGPipeline | None = None
llm: LLMClient | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared RAG pipeline and LLM client for the app's lifetime.

    Code before ``yield`` runs once at startup; code after it runs at shutdown.
    """
    global rag, llm

    logger.info("Initialising RAG pipeline and LLM client...")
    rag = RAGPipeline()
    llm = LLMClient()
    logger.info("Ready.")

    yield  # the application serves requests while suspended here

    logger.info("Shutting down.")
# The ASGI application. `lifespan` wires up the RAG pipeline and LLM client.
app = FastAPI(
    lifespan=lifespan,
    title="Parsa Rouhi — Chatbot API",
    version="1.0.0",
    description="Ask anything about Parsa's skills, projects, and experience.",
)

# CORS: only the configured origins may call this API from a browser.
# NOTE(review): no endpoint here reads cookies or auth headers, so
# allow_credentials=True may be unnecessary. It is also unsafe to combine
# with a "*" entry in ALLOWED_ORIGINS. Confirm what the frontend needs.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
    allow_credentials=True,
)
class Message(BaseModel):
    """One earlier turn of the conversation, passed to the LLM as history."""

    # Speaker of the turn; only "user" or "assistant" is accepted.
    role: str = Field(pattern="^(user|assistant)$")
    # NOTE(review): unlike ChatRequest.message, this has no length limit, so
    # the (up to 20) history turns can carry arbitrarily large text. Consider
    # adding a max_length.
    content: str
class ChatRequest(BaseModel):
    """Request body of POST /chat: the new user message plus earlier turns."""

    # The question being asked, 1 to 1000 characters long.
    message: str = Field(min_length=1, max_length=1000)
    # Earlier turns of the conversation; empty by default, at most 20 entries.
    history: list[Message] = Field(max_length=20, default_factory=list)
class ChatResponse(BaseModel):
    """Response body of POST /chat."""

    # The generated answer text.
    response: str
    # Number of retrieved context chunks, as estimated by the /chat handler.
    sources_retrieved: int
@app.get("/health")
async def health():
    """Report liveness plus whether startup has built the RAG pipeline and LLM client."""
    readiness = {"rag_ready": rag is not None, "llm_ready": llm is not None}
    return {"status": "ok", **readiness}
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Answer one chat turn: retrieve context for the message, then generate.

    Args:
        req: The new user message plus up to 20 earlier turns.

    Returns:
        ChatResponse with the generated answer and the number of retrieved
        context chunks.

    Raises:
        HTTPException: 503 while startup has not yet built ``rag``/``llm``;
            500 if retrieval or generation fails. The cause is logged
            server-side and is not returned to the client.
    """
    if rag is None or llm is None:
        raise HTTPException(status_code=503, detail="Service is still initialising. Please retry.")
    try:
        # retrieve() and generate() are plain (non-awaited) calls. Run directly
        # inside this coroutine, they would block the event loop, and with it
        # every other request including /health, until they return. Offload
        # them to Starlette's threadpool instead.
        # NOTE(review): requests can now run these calls concurrently from
        # worker threads. Confirm RAGPipeline and LLMClient are thread-safe.
        context = await run_in_threadpool(rag.retrieve, req.message)
        # NOTE(review): this assumes retrieve() joins chunks with "---" and
        # that no chunk contains "---" itself (e.g. a markdown rule).
        # Confirm in app.rag.
        chunks_count = context.count("---") + 1 if context else 0

        history = [m.model_dump() for m in req.history]
        answer = await run_in_threadpool(
            llm.generate,
            user_message=req.message,
            context=context,
            history=history,
        )
        return ChatResponse(response=answer, sources_retrieved=chunks_count)
    except Exception as e:
        logger.exception("Error during chat generation")
        # Don't echo str(e). Exception text from the RAG/LLM layers can expose
        # internal details (hosts, paths, provider error bodies) to any caller.
        raise HTTPException(
            status_code=500,
            detail="Failed to generate a response. Please try again.",
        ) from e