Spaces:
Sleeping
Sleeping
File size: 2,829 Bytes
04ab625 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import time
import psutil
import os
from typing import Optional, List
from datetime import datetime
from app.rag_naive import NaiveRAG
from app.rag_optimized import OptimizedRAG
from app.metrics import MetricsTracker
# ASGI application exposing the naive-vs-optimized RAG comparison endpoints.
app = FastAPI(title="RAG Latency Demo",
description="CPU-Only Low-Latency RAG System")
# Initialize components
# Single shared tracker: both pipelines record per-query stats into it.
metrics_tracker = MetricsTracker()
# Baseline pipeline, selected when a request sets use_optimized=False.
naive_rag = NaiveRAG(metrics_tracker)
# Optimized pipeline, the default path for POST /query.
optimized_rag = OptimizedRAG(metrics_tracker)
class QueryRequest(BaseModel):
    """Request payload for POST /query."""

    question: str  # user question to answer
    use_optimized: bool = True  # True -> OptimizedRAG, False -> NaiveRAG
    top_k: Optional[int] = None  # optional retrieval-depth override; pipeline default when None
class QueryResponse(BaseModel):
    """Response payload for POST /query."""

    answer: str  # generated answer text
    latency_ms: float  # wall-clock time for the query, milliseconds (rounded to 2 dp)
    memory_mb: float  # RSS delta across the query, MB (rounded to 2 dp)
    chunks_used: int  # number of retrieved chunks fed to the model
    model: str  # which pipeline ran: "optimized" or "naive"
@app.get("/")
async def root():
    """Service index: reports liveness and lists the available endpoints."""
    endpoint_map = {
        "query": "POST /query",
        "metrics": "GET /metrics",
        "reset_metrics": "POST /reset_metrics",
    }
    return {
        "message": "RAG Latency Optimization System",
        "status": "running",
        "endpoints": endpoint_map,
    }
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Run one RAG query, record its latency/memory metrics, and return the answer.

    Routes to the optimized or naive pipeline according to
    ``request.use_optimized``, times only the pipeline call, records the
    measurements in the shared ``metrics_tracker``, and returns them alongside
    the generated answer.

    Raises:
        HTTPException: 500 wrapping any failure inside the RAG pipeline.
    """
    process = psutil.Process(os.getpid())
    # Take the memory baseline BEFORE starting the timer so the psutil call
    # does not inflate the reported latency.
    initial_memory = process.memory_info().rss / 1024 / 1024  # MB
    start_time = time.perf_counter()
    try:
        # Keep the try body minimal: only the pipeline call should map to a 500;
        # failures in metrics recording or response building are separate bugs.
        if request.use_optimized:
            answer, chunks_used = optimized_rag.query(request.question, request.top_k)
            model = "optimized"
        else:
            answer, chunks_used = naive_rag.query(request.question, request.top_k)
            model = "naive"
    except Exception as e:
        # Chain the cause so the original traceback survives for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    latency_ms = (time.perf_counter() - start_time) * 1000
    # NOTE(review): RSS delta is a noisy proxy for per-query memory — allocator
    # reuse can make it zero or negative; treat as indicative only.
    memory_mb = process.memory_info().rss / 1024 / 1024 - initial_memory
    # Store metrics
    metrics_tracker.record_query(
        model=model,
        latency_ms=latency_ms,
        memory_mb=memory_mb,
        chunks_used=chunks_used,
        question_length=len(request.question),
    )
    return QueryResponse(
        answer=answer,
        latency_ms=round(latency_ms, 2),
        memory_mb=round(memory_mb, 2),
        chunks_used=chunks_used,
        model=model,
    )
@app.get("/metrics")
async def get_metrics():
    """Return the aggregated per-pipeline query metrics as JSON."""
    summary = metrics_tracker.get_summary()
    return JSONResponse(content=summary)
@app.post("/reset_metrics")
async def reset_metrics():
    """Clear every recorded metric and acknowledge the reset."""
    metrics_tracker.reset()
    return {"message": "Metrics reset successfully"}
if __name__ == "__main__":
    # Local development entry point; production deployments should launch the
    # app through an external ASGI server instead.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|