File size: 2,501 Bytes
548482b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3504e78
548482b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import asyncio
import logging
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from app.rag import RAGPipeline
from app.llm import LLMClient

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_allowed_origins(raw: str) -> list[str]:
    """Turn a comma-separated origin list into clean CORS origin strings.

    Each entry has surrounding whitespace removed and a trailing "/" stripped;
    entries that end up empty are dropped. The trailing-slash strip matters
    because browsers send the Origin header without one and CORSMiddleware
    compares origins by exact string equality, so "https://example.com/" would
    otherwise never match anything.

    Args:
        raw: Comma-separated origins, e.g. "http://localhost:3000,https://a.com".

    Returns:
        The cleaned origins, in their original order.
    """
    origins: list[str] = []
    for entry in raw.split(","):
        cleaned = entry.strip().rstrip("/")
        if cleaned:
            origins.append(cleaned)
    return origins


# Origins allowed to call the API from a browser. Override with the
# ALLOWED_ORIGINS environment variable (comma-separated).
ALLOWED_ORIGINS = parse_allowed_origins(
    os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,https://aprouhi.com")
)

# Module-level service handles. Both are None until `lifespan` has built them;
# /health reports whether they exist and /chat refuses to run without both.
rag: RAGPipeline | None = None
llm: LLMClient | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared RAG pipeline and LLM client when the app starts.

    Code before ``yield`` runs once at startup. The app serves requests while
    the generator is suspended at ``yield``. The code after it runs on a normal
    shutdown.
    """
    global rag, llm

    logger.info("Initialising RAG pipeline and LLM client...")
    rag = RAGPipeline()
    llm = LLMClient()
    logger.info("Ready.")

    yield  # application is live while suspended here

    logger.info("Shutting down.")


app = FastAPI(
    lifespan=lifespan,
    title="Parsa Rouhi — Chatbot API",
    version="1.0.0",
    description="Ask anything about Parsa's skills, projects, and experience.",
)

# Browser access is limited to the configured ALLOWED_ORIGINS. Only GET/POST
# (used by /health and /chat) plus the OPTIONS preflight are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
    allow_credentials=True,
)

class Message(BaseModel):
    """One earlier turn of the conversation, sent back by the client."""

    # Only "user" and "assistant" turns are accepted; any other role (e.g.
    # "system") fails validation.
    role: str = Field(..., pattern="^(user|assistant)$")
    # Bounded so that `history` cannot be used to bypass the 1000-character cap
    # on ChatRequest.message and push unbounded text into the prompt. The limit
    # is higher than that cap because assistant replies are echoed back here.
    # NOTE(review): confirm 4000 comfortably exceeds the longest LLM reply.
    content: str = Field(..., max_length=4000)


class ChatRequest(BaseModel):
    """Body of POST /chat: the new user question plus prior conversation turns."""

    message: str = Field(..., min_length=1, max_length=1000)
    # At most 20 prior turns; defaults to an empty conversation.
    history: list[Message] = Field(default_factory=list, max_length=20)


class ChatResponse(BaseModel):
    """Body returned by POST /chat."""

    # Text produced by the LLM client.
    response: str
    # Estimated number of retrieved context chunks used for the answer.
    sources_retrieved: int


@app.get("/health")
async def health():
    """Report liveness plus whether the lifespan hook has created each service."""
    rag_ready = rag is not None
    llm_ready = llm is not None
    return {"status": "ok", "rag_ready": rag_ready, "llm_ready": llm_ready}



@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Answer a user question using retrieved context and the LLM.

    Args:
        req: The new question plus prior conversation turns.

    Returns:
        ChatResponse with the generated answer and an estimate of how many
        context chunks were retrieved.

    Raises:
        HTTPException(503): the RAG pipeline or LLM client has not been created yet.
        HTTPException(500): retrieval or generation failed. The exception is
            logged with its traceback server-side, but only a generic message
            is returned so internal details do not leak to clients.
    """
    if rag is None or llm is None:
        raise HTTPException(status_code=503, detail="Service is still initialising. Please retry.")

    try:
        # rag.retrieve and llm.generate are synchronous (never awaited). Run
        # them in a worker thread so a slow retrieval or LLM round-trip does not
        # block the event loop and stall every other request, including /health.
        context = await asyncio.to_thread(rag.retrieve, req.message)
        # Estimate the chunk count from "---" separators in the joined context.
        # NOTE(review): assumes RAGPipeline joins chunks with "---" and chunk
        # text never contains it; confirm against app.rag.
        chunks_count = context.count("---") + 1 if context else 0

        history = [m.model_dump() for m in req.history]
        answer = await asyncio.to_thread(
            llm.generate,
            user_message=req.message,
            context=context,
            history=history,
        )

        return ChatResponse(response=answer, sources_retrieved=chunks_count)

    except Exception:
        logger.exception("Error during chat generation")
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred while generating the response.",
        )