Parsa2025AI committed on
Commit
548482b
·
verified ·
1 Parent(s): 3dc8ae4
Files changed (1) hide show
  1. main.py +94 -0
main.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, Field
8
+
9
+ from app.rag import RAGPipeline
10
+ from app.llm import LLMClient
11
+
12
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_origins(raw: str) -> list[str]:
    """Turn a comma-separated origin list into clean origin strings.

    Each entry is stripped of surrounding whitespace and of any trailing
    slash; entries that end up empty are dropped. Trailing slashes are
    removed because the ``Origin`` header sent by browsers never carries
    one, so ``https://example.com/`` in the configuration would otherwise
    never match a real request.

    Args:
        raw: Comma-separated origins, e.g. ``"http://a.test, https://b.test/"``.

    Returns:
        The cleaned origins in their original order.
    """
    cleaned = (entry.strip().rstrip("/") for entry in raw.split(","))
    return [origin for origin in cleaned if origin]


# Browser origins permitted by the CORS middleware. Overridable through the
# ALLOWED_ORIGINS environment variable (comma-separated).
ALLOWED_ORIGINS = parse_origins(
    os.getenv(
        "ALLOWED_ORIGINS",
        "http://localhost:3000,https://aprouhi.com",
    )
)
24
+
25
# Shared service objects. Both stay ``None`` until ``lifespan`` has built
# them; the route handlers check for ``None`` before using them.
rag: RAGPipeline | None = None
llm: LLMClient | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the RAG pipeline and LLM client for the lifetime of the app.

    Everything before the ``yield`` runs when the context is entered and
    fills in the module-level ``rag`` and ``llm`` globals; the statement
    after the ``yield`` only logs that the context is being left.
    """
    global rag, llm

    logger.info("Initialising RAG pipeline and LLM client...")
    # Kept as two separate statements: if the second constructor fails,
    # ``rag`` has already been assigned.
    rag = RAGPipeline()
    llm = LLMClient()
    logger.info("Ready.")

    yield

    logger.info("Shutting down.")
37
+
38
+
39
# The ASGI application. ``lifespan`` is handed over so that the RAG pipeline
# and LLM client are constructed when the app's lifespan context is entered.
app = FastAPI(
    title="Parsa Rouhi — Chatbot API",
    description="Ask anything about Parsa's skills, projects, and experience.",
    version="1.0.0",
    lifespan=lifespan,
)

# Cross-origin access for the configured front-end origins only.
# NOTE(review): credentials are enabled here, so ALLOWED_ORIGINS should list
# explicit origins — confirm no deployment sets it to "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
53
+
54
# One earlier turn of the conversation, as sent by the client.
class Message(BaseModel):
    # Who produced the turn; the pattern admits only "user" or "assistant".
    role: str = Field(..., pattern="^(user|assistant)$")
    # Text of the turn.
    # NOTE(review): unlike ChatRequest.message, no length limit is applied
    # here — confirm that is intended.
    content: str
57
+
58
# Request body accepted by POST /chat.
class ChatRequest(BaseModel):
    # The new user message: required, between 1 and 1000 characters.
    message: str = Field(..., min_length=1, max_length=1000)
    # Earlier turns of the conversation; defaults to an empty list and is
    # limited to 20 entries.
    history: list[Message] = Field(default_factory=list, max_length=20)
61
+
62
# Response body returned by POST /chat.
class ChatResponse(BaseModel):
    # The answer text produced by the LLM client.
    response: str
    # Number of context chunks, as estimated by the /chat handler from the
    # "---" separators in the retrieved context.
    sources_retrieved: int
65
+
66
+
67
@app.get("/health")
async def health():
    # Always reports "ok" for the process itself; the two flags tell the
    # caller whether the module-level RAG pipeline and LLM client have been
    # created yet.
    rag_ready = rag is not None
    llm_ready = llm is not None
    return {"status": "ok", "rag_ready": rag_ready, "llm_ready": llm_ready}
70
+
71
+
72
# POST /chat — answer one user message.
#
# Flow: reject with 503 while the RAG pipeline / LLM client are not yet
# created, retrieve context for the message, pass message + context + prior
# turns to the LLM client, and return the answer together with a chunk count.
#
# Errors: an HTTPException raised while handling the request is passed through
# unchanged; any other exception is logged with its traceback and reported to
# the client as a generic 500, so internal details (paths, upstream error
# text, credentials embedded in messages) are not exposed in the response.
#
# NOTE(review): rag.retrieve() and llm.generate() are called without ``await``
# inside an async handler. If they block on I/O they will stall the event loop
# for other requests — consider a plain ``def`` handler or a threadpool once
# their thread-safety is confirmed.
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    if rag is None or llm is None:
        raise HTTPException(status_code=503, detail="Service is still initialising. Please retry.")

    try:
        # Retrieve relevant context
        context = rag.retrieve(req.message)
        # Chunks are counted by "---" occurrences: k separators -> k + 1
        # chunks, and an empty context counts as zero.
        chunks_count = context.count("---") + 1 if context else 0

        # Generate response
        history = [m.model_dump() for m in req.history]
        answer = llm.generate(
            user_message=req.message,
            context=context,
            history=history,
        )

        return ChatResponse(response=answer, sources_retrieved=chunks_count)

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        logger.exception("Error during chat generation")
        raise HTTPException(
            status_code=500,
            detail="Internal error while generating a response.",
        ) from e