# Ariyan-Pro's picture
# Deploy RAG Latency Optimization v1.0
# 04ab625
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import time
import psutil
import os
from typing import Optional, List
from datetime import datetime
from app.rag_naive import NaiveRAG
from app.rag_optimized import OptimizedRAG
from app.metrics import MetricsTracker
# ASGI application object that every route below is registered on.
app = FastAPI(
    title="RAG Latency Demo",
    description="CPU-Only Low-Latency RAG System",
)

# One tracker instance is handed to both pipelines, so the /metrics and
# /reset_metrics endpoints operate on the same object the pipelines receive.
metrics_tracker = MetricsTracker()
naive_rag = NaiveRAG(metrics_tracker)
optimized_rag = OptimizedRAG(metrics_tracker)
class QueryRequest(BaseModel):
    """Request body accepted by ``POST /query``."""

    # Question text; passed as the first argument to the pipeline's query().
    question: str
    # True selects the optimized pipeline, False the naive one.
    use_optimized: bool = True
    # Forwarded unchanged as the second argument to query().
    # NOTE(review): None presumably means "use the pipeline's own default" --
    # confirm against the pipeline implementations.
    top_k: Optional[int] = None
class QueryResponse(BaseModel):
    """Response body returned by ``POST /query``."""

    # Answer produced by the selected pipeline.
    answer: str
    # Wall-clock duration of the request handling, in milliseconds.
    latency_ms: float
    # Change in process RSS across the request, in megabytes.
    memory_mb: float
    # Chunk count reported by the pipeline alongside the answer.
    chunks_used: int
    # Which pipeline served the request: "optimized" or "naive".
    model: str
@app.get("/")
async def root():
    """Return a status banner together with a directory of the endpoints."""
    endpoint_directory = {
        "query": "POST /query",
        "metrics": "GET /metrics",
        "reset_metrics": "POST /reset_metrics",
    }
    return {
        "message": "RAG Latency Optimization System",
        "status": "running",
        "endpoints": endpoint_directory,
    }
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Answer a question with the selected RAG pipeline and report its cost.

    The handler times the request with ``time.perf_counter`` and samples the
    process RSS before and after the pipeline call. The measurement is
    recorded on the shared metrics tracker and echoed back in the response.

    Args:
        request: Question text, pipeline selector (``use_optimized``) and the
            optional ``top_k`` value forwarded to the pipeline.

    Returns:
        A ``QueryResponse`` with the answer, latency in milliseconds and RSS
        delta in megabytes (both rounded to two decimals), the chunk count
        reported by the pipeline, and the pipeline label.

    Raises:
        HTTPException: Any ``HTTPException`` raised while handling the
            request is propagated unchanged. Every other exception is
            converted to a 500 whose detail is ``str(exc)``.
    """
    start_time = time.perf_counter()
    process = psutil.Process(os.getpid())
    # RSS is reported in bytes; two divisions by 1024 convert it to MB.
    initial_memory = process.memory_info().rss / 1024 / 1024

    try:
        if request.use_optimized:
            rag, model = optimized_rag, "optimized"
        else:
            rag, model = naive_rag, "naive"

        # NOTE(review): query() is called synchronously inside an async
        # handler, so a slow pipeline call holds up the event loop for its
        # whole duration. Moving it to a worker thread needs a check that the
        # pipelines and the tracker tolerate concurrent use first.
        answer, chunks_used = rag.query(request.question, request.top_k)

        end_time = time.perf_counter()
        final_memory = process.memory_info().rss / 1024 / 1024

        latency_ms = (end_time - start_time) * 1000
        # RSS is a process-wide figure, so this delta is not isolated to this
        # request and can come out negative.
        memory_mb = final_memory - initial_memory

        # Store metrics (unrounded values; rounding is applied to the
        # response only).
        metrics_tracker.record_query(
            model=model,
            latency_ms=latency_ms,
            memory_mb=memory_mb,
            chunks_used=chunks_used,
            question_length=len(request.question),
        )

        return QueryResponse(
            answer=answer,
            latency_ms=round(latency_ms, 2),
            memory_mb=round(memory_mb, 2),
            chunks_used=chunks_used,
            model=model,
        )
    except HTTPException:
        # An HTTP error raised on purpose further down keeps its own status
        # code and detail instead of being rewritten as a generic 500.
        raise
    except Exception as e:
        # NOTE(review): detail exposes the raw exception text to the client;
        # kept for backward compatibility, consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/metrics")
async def get_metrics():
    """Serve the metrics tracker's current summary as a JSON response."""
    return JSONResponse(content=metrics_tracker.get_summary())
@app.post("/reset_metrics")
async def reset_metrics():
    """Reset the shared metrics tracker and acknowledge the request."""
    metrics_tracker.reset()
    acknowledgement = {"message": "Metrics reset successfully"}
    return acknowledgement
if __name__ == "__main__":
    # Direct execution: hand the app to uvicorn, bound to 0.0.0.0:8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)