Spaces:
Running
Running
| import json | |
| import sys | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| from slowapi import Limiter, _rate_limit_exceeded_handler | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| sys.path.insert(0, str(Path(__file__).parent / "src")) | |
| from generate import Generator | |
| from retrieve import USE_CHROMA_CLOUD | |
# Process-wide registry for objects that are built once at startup.
# The lifespan handler stores the RAG pipeline under the "generator" key and
# empties the dict again on shutdown; get_generator() reads it back.
_state: dict = {}

# The rate limiter is intentionally NOT created here: `limiter` is defined once,
# right after get_client_ip(), which is the key function it uses. A second
# instance keyed on get_remote_address would be rebound before anything could
# read it, so it would only be dead weight.
def get_client_ip(request: Request) -> str:
    """Return the string used to identify the caller for rate limiting.

    Resolution order:
      1. the first entry of the ``x-forwarded-for`` header, when that entry
         is non-empty;
      2. ``request.client.host``, when the connection exposes one;
      3. the literal ``"unknown"``.

    Args:
        request: Incoming request; only ``headers`` and ``client`` are read.

    Returns:
        A non-empty identifier for the caller.

    NOTE(review): ``x-forwarded-for`` is taken at face value. This presumably
    relies on a fronting proxy setting it - confirm for the deployment, since
    a directly reachable app would let clients choose their own key.
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # The header can be a comma-separated chain (client, proxy1, proxy2, ...);
        # the first entry is the original client.
        first_hop = forwarded.split(",", 1)[0].strip()
        # A blank or malformed header (e.g. "   " or ", 10.0.0.1") must not
        # yield "" as the key: every such caller would share one bucket.
        # Fall through to the socket peer instead.
        if first_hop:
            return first_hop
    if request.client and request.client.host:
        return request.client.host
    return "unknown"
# Shared rate limiter: callers are bucketed by get_client_ip() and the default
# limit is "5/minute".
limiter = Limiter(default_limits=["5/minute"], key_func=get_client_ip)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG pipeline on startup and release it on shutdown.

    Passed to ``FastAPI(lifespan=...)``, which expects an async context
    manager factory - hence the ``asynccontextmanager`` decorator.

    Args:
        app: The application being started; not used by the body.

    Yields:
        Nothing. While suspended at the ``yield`` the pipeline is available
        under ``_state["generator"]``.
    """
    print("Loading RAG pipeline (embeddings, vectorstore, reranker, LLM client)...")
    _state["generator"] = Generator()
    print("✅ RAG pipeline ready.")
    try:
        yield
    finally:
        # Runs on normal shutdown and also when an exception is thrown into
        # the context, so the pipeline reference is always dropped.
        _state.clear()
# Application object; `lifespan` populates and clears `_state` around the
# server's lifetime.
app = FastAPI(lifespan=lifespan, title="AskTheHandbook")

# Rate limiting: register slowapi's handler for RateLimitExceeded and expose
# the shared limiter on the application state.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.state.limiter = limiter

# CORS: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def get_generator() -> Generator:
    """Return the pipeline stored in ``_state["generator"]``.

    Raises:
        HTTPException: 503 when no pipeline is stored (the key is missing or
            holds ``None``).
    """
    pipeline = _state.get("generator")
    if pipeline is not None:
        return pipeline
    raise HTTPException(status_code=503, detail="RAG pipeline is still starting up.")
# Request body accepted by chat() and chat_stream().
# (Documented with comments rather than a docstring so the generated schema
# is left exactly as it was.)
class ChatRequest(BaseModel):
    # The user's question; required, 1 to 2000 characters.
    question: str = Field(default=..., min_length=1, max_length=2000)
    # Forwarded as `top_n` to Generator.answer(); 1 to 10, defaults to 4.
    top_n: int = Field(default=4, ge=1, le=10)
# One entry of ChatResponse.sources.
# NOTE(review): presumably identifies the document and the section a retrieved
# passage came from - confirm against Generator.get_sources().
class Source(BaseModel):
    source: str
    section: str
# Body returned by chat(): the full answer text plus the sources reported by
# the generator for that answer.
class ChatResponse(BaseModel):
    # All generated tokens joined into a single string.
    answer: str
    # Whatever Generator.get_sources() returned, validated as Source items.
    sources: list[Source]
def health():
    """Report service status.

    Returns:
        A dict with ``status`` (always ``"ok"``), ``chroma_backend``
        (``"cloud"`` or ``"local"`` depending on ``USE_CHROMA_CLOUD``) and
        ``pipeline_ready`` (whether ``_state`` currently holds a generator).

    NOTE(review): no route decorator is visible on this function in this
    view - confirm it is actually registered on ``app``.
    """
    backend = "cloud" if USE_CHROMA_CLOUD else "local"
    ready = "generator" in _state
    return {"status": "ok", "chroma_backend": backend, "pipeline_ready": ready}
def chat(req: ChatRequest, request: Request):
    """Answer a question in one shot and return the complete text.

    Args:
        req: Validated request body (question and ``top_n``).
        request: The raw request; not used by the body.
            NOTE(review): presumably required by a rate-limit decorator that
            is not visible in this view - confirm.

    Returns:
        ChatResponse with the joined answer and the generator's sources.

    Raises:
        HTTPException: 503 from get_generator() while the pipeline is absent.

    NOTE(review): the generator is a single shared object and the sources are
    read from it after answering; if it keeps per-call state, concurrent
    requests could observe each other's sources - confirm.
    """
    pipeline = get_generator()
    # Drain the token stream completely before asking for the sources.
    answer_text = "".join(pipeline.answer(req.question, top_n=req.top_n))
    cited = pipeline.get_sources()
    return ChatResponse(answer=answer_text, sources=cited)
def chat_stream(req: ChatRequest, request: Request):
    """Answer a question as a Server-Sent Events stream.

    Each generated token is sent as ``data: {"token": ...}``; once the
    generator is exhausted a final ``data: {"done": true, "sources": ...}``
    event is sent.

    Args:
        req: Validated request body (question and ``top_n``).
        request: The raw request; not used by the body.
            NOTE(review): presumably required by a rate-limit decorator that
            is not visible in this view - confirm.

    Returns:
        StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: 503 from get_generator(), raised before any streaming
            starts.
    """
    gen = get_generator()

    def sse_frame(payload):
        # One SSE event: a "data:" line followed by a blank line.
        return f"data: {json.dumps(payload)}\n\n"

    def event_stream():
        emitted = 0  # stays 0 when the generator produces nothing
        for emitted, token in enumerate(gen.answer(req.question, top_n=req.top_n), start=1):
            yield sse_frame({"token": token})
        print(f"[STREAM] loop finished, total tokens yielded: {emitted}", flush=True)
        # Sources are fetched only after the token loop has finished.
        yield sse_frame({"done": True, "sources": gen.get_sources()})

    response_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
        "Connection": "keep-alive",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=response_headers,
    )
# Serve the bundled front-end, when a "static" directory sits next to this
# file, from the site root. With html=True, StaticFiles serves index.html for
# directory requests.
static_dir = Path(__file__).parent.joinpath("static")
if static_dir.exists():
    app.mount(
        "/",
        StaticFiles(directory=str(static_dir), html=True),
        name="static",
    )