import os
import time
from datetime import datetime
from typing import Optional, List

import psutil
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from app.rag_naive import NaiveRAG
from app.rag_optimized import OptimizedRAG
from app.metrics import MetricsTracker

app = FastAPI(title="RAG Latency Demo", description="CPU-Only Low-Latency RAG System")

# Initialize components
# One tracker instance is shared by both pipelines and by the endpoints below.
metrics_tracker = MetricsTracker()
naive_rag = NaiveRAG(metrics_tracker)
optimized_rag = OptimizedRAG(metrics_tracker)


# Request body accepted by POST /query.
class QueryRequest(BaseModel):
    # Question text forwarded unchanged to the selected pipeline.
    question: str
    # True selects optimized_rag, False selects naive_rag.
    use_optimized: bool = True
    # Forwarded as the second argument of the pipeline's query();
    # presumably None lets the pipeline pick its own default -- TODO confirm.
    top_k: Optional[int] = None


# Response body returned by POST /query.
class QueryResponse(BaseModel):
    answer: str
    # Wall-clock handler time in milliseconds, rounded to 2 decimals.
    latency_ms: float
    # Change in process RSS across the request, in MB (can be negative).
    memory_mb: float
    chunks_used: int
    # Either "optimized" or "naive".
    model: str


def _rss_mb(process):
    # Resident set size of the given psutil process, converted to megabytes.
    return process.memory_info().rss / 1024 / 1024


# Service banner: reports status and lists the available endpoints.
@app.get("/")
async def root():
    return {
        "message": "RAG Latency Optimization System",
        "status": "running",
        "endpoints": {
            "query": "POST /query",
            "metrics": "GET /metrics",
            "reset_metrics": "POST /reset_metrics",
        },
    }


# Answer one question with the selected pipeline, measuring latency and the
# RSS delta around the call, recording both in metrics_tracker, and returning
# them alongside the answer.
#
# Raises HTTPException(500) carrying str(error) for any unexpected failure;
# an HTTPException raised inside the handler body is passed through untouched.
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    start_time = time.perf_counter()
    process = psutil.Process(os.getpid())
    initial_memory = _rss_mb(process)

    if request.use_optimized:
        rag, model = optimized_rag, "optimized"
    else:
        rag, model = naive_rag, "naive"

    try:
        answer, chunks_used = rag.query(request.question, request.top_k)

        end_time = time.perf_counter()
        final_memory = _rss_mb(process)

        latency_ms = (end_time - start_time) * 1000
        memory_mb = final_memory - initial_memory

        # Store metrics (unrounded; only the response values are rounded).
        metrics_tracker.record_query(
            model=model,
            latency_ms=latency_ms,
            memory_mb=memory_mb,
            chunks_used=chunks_used,
            question_length=len(request.question),
        )

        return QueryResponse(
            answer=answer,
            latency_ms=round(latency_ms, 2),
            memory_mb=round(memory_mb, 2),
            chunks_used=chunks_used,
            model=model,
        )
    except HTTPException:
        # HTTPException is itself an Exception subclass; without this clause a
        # deliberate HTTP error would be rewritten as a generic 500 below.
        raise
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
# GET /metrics: wrap whatever metrics_tracker.get_summary() returns in a
# JSONResponse.
@app.get("/metrics")
async def get_metrics():
    return JSONResponse(content=metrics_tracker.get_summary())


# POST /reset_metrics: call metrics_tracker.reset() and acknowledge with a
# fixed confirmation message.
@app.post("/reset_metrics")
async def reset_metrics():
    metrics_tracker.reset()
    confirmation = {"message": "Metrics reset successfully"}
    return confirmation


if __name__ == "__main__":
    # uvicorn is imported only when the module is executed as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")