AskTheHandbook / app.py
rishitpant's picture
Update limit
3cbcbf3
Raw
History Blame Contribute Delete
3.53 kB
import json
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
sys.path.insert(0, str(Path(__file__).parent / "src"))
from generate import Generator
from retrieve import USE_CHROMA_CLOUD
# Process-wide registry for objects built once at startup. The lifespan
# handler stores the RAG pipeline under the "generator" key and empties the
# dict again at shutdown; request handlers only read from it.
#
# The rate limiter is configured once, further down, right after
# get_client_ip (its key function). An earlier throwaway Limiter keyed on
# get_remote_address used to be created here and was immediately rebound.
_state: dict = {}
def get_client_ip(request: Request) -> str:
    """Return the string used to identify the caller for rate limiting.

    Resolution order:

    1. the first entry of the ``X-Forwarded-For`` header, when it is
       non-blank;
    2. the host of the directly connected peer (``request.client.host``);
    3. the literal ``"unknown"`` when neither is available.

    A blank first entry (header made of whitespace only, or starting with a
    comma) is treated as "no usable header" instead of being returned as an
    empty key, which would lump every such request into one shared bucket.
    """
    # NOTE(review): this header is supplied by the caller unless a trusted
    # proxy in front of the app overwrites it; if clients can reach the app
    # directly they can vary it to dodge the limit -- confirm the deployment.
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # The header can be a comma-separated chain (client, proxy1, ...);
        # the first entry is the original client.
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop
    if request.client and request.client.host:
        return request.client.host
    return "unknown"
# The single rate limiter used by the app: callers are bucketed by
# get_client_ip, with a default limit of five requests per minute.
limiter = Limiter(default_limits=["5/minute"], key_func=get_client_ip)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG pipeline before serving and drop it on shutdown.

    On entry a ``Generator`` is constructed and stored in ``_state`` under
    the ``"generator"`` key (the startup message describes this as loading
    embeddings, vectorstore, reranker and LLM client). Control is then
    yielded to the application. On exit ``_state`` is emptied.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    print("Loading RAG pipeline (embeddings, vectorstore, reranker, LLM client)...")
    _state["generator"] = Generator()
    print("✅ RAG pipeline ready.")
    try:
        yield
    finally:
        # An exception thrown into the generator at the yield point would
        # otherwise skip the cleanup and leave the pipeline registered.
        _state.clear()
# Application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(title="AskTheHandbook", lifespan=lifespan)

# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Rate-limiting wiring: register slowapi's stock handler for
# RateLimitExceeded and expose the limiter on the application state.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.state.limiter = limiter
def get_generator() -> Generator:
    """Return the pipeline registered at startup.

    Raises:
        HTTPException: 503 when no pipeline is registered in ``_state``
            (nothing stored yet, or the stored value is ``None``).
    """
    pipeline = _state.get("generator")
    if pipeline is not None:
        return pipeline
    raise HTTPException(status_code=503, detail="RAG pipeline is still starting up.")
# Request body accepted by both chat endpoints (/api/chat and
# /api/chat/stream).
class ChatRequest(BaseModel):
    # The user's question: required, between 1 and 2000 characters.
    question: str = Field(..., min_length=1, max_length=2000)
    # Passed through to Generator.answer as `top_n`; 1..10, defaults to 4.
    top_n: int = Field(default=4, ge=1, le=10)
# One entry of ChatResponse.sources.
# NOTE(review): presumably `source` names the originating document and
# `section` the part of it that was retrieved -- confirm against
# Generator.get_sources().
class Source(BaseModel):
    source: str
    section: str
# Response body of the non-streaming /api/chat endpoint.
class ChatResponse(BaseModel):
    # All generated tokens concatenated into one string.
    answer: str
    # Whatever the pipeline reports via get_sources() once the answer is done.
    sources: list[Source]
# Health probe. Always answers with status "ok", and additionally reports
# which Chroma backend is configured and whether the startup hook has
# already registered the pipeline.
@app.get("/api/health")
def health():
    backend = "cloud" if USE_CHROMA_CLOUD else "local"
    ready = "generator" in _state
    return {
        "status": "ok",
        "chroma_backend": backend,
        "pipeline_ready": ready,
    }
# Non-streaming chat endpoint: runs the pipeline to completion and returns
# the whole answer together with its sources. Limited to 5 calls per minute
# per client key. Responds 503 (via get_generator) while the pipeline is not
# registered yet.
# `request` is not used in the body; presumably the slowapi decorator needs
# it in the signature -- see the slowapi documentation.
@app.post("/api/chat", response_model=ChatResponse)
@limiter.limit("5/minute")
def chat(req: ChatRequest, request: Request):
    pipeline = get_generator()
    # Drain the token stream completely before asking for the sources.
    answer_text = "".join(pipeline.answer(req.question, top_n=req.top_n))
    cited = pipeline.get_sources()
    return ChatResponse(answer=answer_text, sources=cited)
# Streaming chat endpoint (text/event-stream). Each generated token is sent
# as `data: {"token": ...}`; after the last token one final event
# `data: {"done": true, "sources": [...]}` closes the stream. Limited to
# 5 calls per minute per client key. Responds 503 (via get_generator) while
# the pipeline is not registered yet.
# `request` is not used in the body; presumably the slowapi decorator needs
# it in the signature -- see the slowapi documentation.
@app.post("/api/chat/stream")
@limiter.limit("5/minute")
def chat_stream(req: ChatRequest, request: Request):
    pipeline = get_generator()

    def event_stream():
        # `count` ends up as the number of tokens emitted (0 for none).
        count = 0
        token_stream = pipeline.answer(req.question, top_n=req.top_n)
        for count, token in enumerate(token_stream, start=1):
            token_payload = json.dumps({"token": token})
            yield f"data: {token_payload}\n\n"
        print(f"[STREAM] loop finished, total tokens yielded: {count}", flush=True)
        # Sources are read only after the token stream is exhausted.
        final_payload = json.dumps({"done": True, "sources": pipeline.get_sources()})
        yield f"data: {final_payload}\n\n"

    # NOTE(review): these headers look intended to stop caches and buffering
    # proxies from holding back events -- confirm against the proxy in use.
    sse_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
        "Connection": "keep-alive",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
# Serve the bundled front-end from ./static (next to this file) at the site
# root, but only when that directory is present.
# NOTE(review): this mount is registered after the /api routes above;
# presumably that ordering keeps the catch-all "/" from shadowing them.
static_dir = Path(__file__).parent.joinpath("static")
if static_dir.exists():
    app.mount(
        "/",
        StaticFiles(directory=str(static_dir), html=True),
        name="static",
    )