Spaces:
Paused
Paused
| import os | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from app.rag import RAGPipeline | |
| from app.llm import LLMClient | |
# Root logging at INFO so the startup/shutdown messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CORS origin allow-list, read from the comma-separated ALLOWED_ORIGINS
# environment variable (falling back to the defaults below). Whitespace around
# each entry is trimmed and empty entries (e.g. from a trailing comma) dropped.
_DEFAULT_ORIGINS = "http://localhost:3000,https://aprouhi.com"
_raw_origins = os.environ.get("ALLOWED_ORIGINS", _DEFAULT_ORIGINS)
ALLOWED_ORIGINS = list(filter(None, map(str.strip, _raw_origins.split(","))))
# Shared service objects. Both stay None until lifespan() constructs them at
# startup; request handlers check for None before using them.
rag: RAGPipeline | None
llm: LLMClient | None
rag = llm = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared RAG pipeline and LLM client for the app's lifetime.

    Passed to ``FastAPI(lifespan=...)``: the code before ``yield`` runs once at
    startup, the code after it once at shutdown. The module-level ``rag`` and
    ``llm`` globals are assigned here and reset to ``None`` on shutdown so the
    health check stops reporting them as ready.

    Args:
        app: the FastAPI application (unused; required by the lifespan protocol).
    """
    global rag, llm
    logger.info("Initialising RAG pipeline and LLM client...")
    rag = RAGPipeline()
    llm = LLMClient()
    logger.info("Ready.")
    try:
        yield
    finally:
        # finally: still log and release references if shutdown is triggered
        # by an exception or cancellation propagating into the context.
        logger.info("Shutting down.")
        rag = None
        llm = None
# CORS settings applied to every route. Origins come from ALLOWED_ORIGINS;
# credentials are allowed and any request header is accepted.
_CORS_OPTIONS = {
    "allow_origins": ALLOWED_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["*"],
}

app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Parsa Rouhi — Chatbot API",
    description="Ask anything about Parsa's skills, projects, and experience.",
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class Message(BaseModel):
    """One prior turn of the conversation, echoed back by the client."""

    # Only the two conversational roles pass validation; anything else
    # (e.g. a client-supplied "system" turn) is rejected.
    role: str = Field(..., pattern="^(user|assistant)$")
    # Bounded so a client cannot bypass ChatRequest.message's 1000-char cap by
    # stuffing arbitrary text into history turns that are forwarded to the LLM.
    # NOTE(review): 4000 is an assumed ceiling that must exceed the longest
    # reply LLMClient produces — confirm against its generation limits.
    content: str = Field(..., max_length=4000)


class ChatRequest(BaseModel):
    """Body of a chat request: the new user message plus optional history."""

    # Non-empty, at most 1000 characters.
    message: str = Field(..., min_length=1, max_length=1000)
    # At most 20 prior turns; defaults to a fresh empty list per request.
    history: list[Message] = Field(default_factory=list, max_length=20)


class ChatResponse(BaseModel):
    """Reply returned to the client."""

    response: str  # generated answer text
    sources_retrieved: int  # estimated number of retrieved context chunks
# NOTE(review): route path inferred from the handler name — confirm it matches
# what the frontend / platform health probe calls.
@app.get("/health")
async def health():
    """Report liveness and whether startup initialisation has completed.

    Returns:
        A dict with ``status`` (always ``"ok"`` when the process can answer),
        and ``rag_ready`` / ``llm_ready`` flags that are true once ``lifespan``
        has assigned the module-level ``rag`` / ``llm`` objects.
    """
    return {"status": "ok", "rag_ready": rag is not None, "llm_ready": llm is not None}
# NOTE(review): route path inferred from the handler name — confirm it matches
# the URL the frontend posts to.
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Answer a question using retrieved context plus the LLM.

    Args:
        req: the validated user message and up to 20 prior turns.

    Returns:
        ChatResponse with the generated answer and an estimated count of the
        context chunks that were retrieved.

    Raises:
        HTTPException(503): ``lifespan`` has not yet created ``rag`` / ``llm``.
        HTTPException(500): retrieval or generation failed. The underlying
            error is logged server-side but not echoed to the client, since its
            message may contain internal details.
    """
    import asyncio

    if rag is None or llm is None:
        raise HTTPException(status_code=503, detail="Service is still initialising. Please retry.")

    try:
        # retrieve() and generate() are synchronous calls; run them in worker
        # threads so a slow retrieval or LLM round-trip does not block the
        # event loop (and with it every other request, including /health).
        context = await asyncio.to_thread(rag.retrieve, req.message)
        # Heuristic: assumes RAGPipeline.retrieve joins chunks with a "---"
        # separator — TODO confirm. A chunk whose text itself contains "---"
        # would inflate this count.
        chunks_count = context.count("---") + 1 if context else 0

        history = [m.model_dump() for m in req.history]
        answer = await asyncio.to_thread(
            llm.generate,
            user_message=req.message,
            context=context,
            history=history,
        )
        return ChatResponse(response=answer, sources_retrieved=chunks_count)
    except Exception as e:
        logger.exception("Error during chat generation")
        raise HTTPException(
            status_code=500,
            detail="Failed to generate a response. Please try again.",
        ) from e