Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| import time | |
| import psutil | |
| import os | |
| from typing import Optional, List | |
| from datetime import datetime | |
| from app.rag_naive import NaiveRAG | |
| from app.rag_optimized import OptimizedRAG | |
| from app.metrics import MetricsTracker | |
# Application object that the request handlers in this module attach to.
app = FastAPI(
    title="RAG Latency Demo",
    description="CPU-Only Low-Latency RAG System",
)

# Shared, module-level components. One MetricsTracker instance is handed to
# both RAG pipelines and is also used directly by the request handlers
# (record_query / get_summary / reset).
metrics_tracker = MetricsTracker()
naive_rag, optimized_rag = NaiveRAG(metrics_tracker), OptimizedRAG(metrics_tracker)
# Request body accepted by the query handler (process_query).
class QueryRequest(BaseModel):
    # Question text forwarded to the selected RAG pipeline's query().
    question: str
    # True -> optimized_rag handles the query, False -> naive_rag.
    use_optimized: bool = True
    # Passed through unchanged as the second argument to the pipeline's
    # query(); presumably None means "use the pipeline's own default" —
    # TODO confirm in app.rag_naive / app.rag_optimized.
    top_k: Optional[int] = None
# Response body returned by the query handler (process_query).
class QueryResponse(BaseModel):
    # Answer text returned by the selected pipeline.
    answer: str
    # Elapsed time measured with time.perf_counter by the handler, in
    # milliseconds (the handler rounds it to 2 decimals).
    latency_ms: float
    # Change in whole-process RSS across the query, in MB (bytes / 1024**2),
    # rounded to 2 decimals by the handler. Process-wide, so it may be
    # negative or include allocations unrelated to this request.
    memory_mb: float
    # Chunk count reported by the pipeline's query() return value.
    chunks_used: int
    # Which pipeline served the request: "optimized" or "naive".
    model: str
# Registered on the app; without a route decorator FastAPI never exposes it.
@app.get("/")
async def root():
    """Describe the service and list its HTTP endpoints.

    Returns:
        A static dict with a human-readable message, a status flag, and a
        mapping of endpoint name -> "METHOD /path".
    """
    return {
        "message": "RAG Latency Optimization System",
        "status": "running",
        "endpoints": {
            "query": "POST /query",
            "metrics": "GET /metrics",
            "reset_metrics": "POST /reset_metrics",
        },
    }
def _rss_mb(process):
    """Return the resident set size of *process* in MB (bytes / 1024**2)."""
    return process.memory_info().rss / 1024 / 1024


# Registered on the app; without a route decorator FastAPI never exposes it.
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Answer a question with the naive or optimized RAG pipeline and time it.

    The measured latency and RSS delta are recorded in ``metrics_tracker``
    and echoed back in the response.

    Args:
        request: Question text, pipeline selector (``use_optimized``) and
            optional ``top_k`` passed through to the pipeline.

    Returns:
        QueryResponse with the answer, latency in ms and RSS delta in MB
        (both rounded to 2 decimals), the chunk count and the pipeline name.

    Raises:
        HTTPException: status 500 carrying ``str(error)`` if the pipeline,
            metrics recording or response construction raises.
    """
    # NOTE(review): the pipelines are called synchronously (no await) inside
    # an async handler, so a long query blocks this worker's event loop and
    # queries are processed one at a time per worker. That keeps the timing
    # and RSS numbers free of interference from concurrent requests; only
    # switch to a plain `def` handler (threadpool) if the RAG objects are
    # known to be thread-safe.
    process = psutil.Process(os.getpid())
    initial_memory = _rss_mb(process)
    # Start the clock after the baseline RSS read so the reported latency
    # covers only the pipeline call; the end timestamp below is likewise
    # taken before the final RSS read.
    start_time = time.perf_counter()

    try:
        if request.use_optimized:
            answer, chunks_used = optimized_rag.query(request.question, request.top_k)
            model = "optimized"
        else:
            answer, chunks_used = naive_rag.query(request.question, request.top_k)
            model = "naive"

        end_time = time.perf_counter()
        latency_ms = (end_time - start_time) * 1000
        # Whole-process RSS delta: may be negative, and is only meaningful
        # while no other request runs concurrently in this process.
        memory_mb = _rss_mb(process) - initial_memory

        metrics_tracker.record_query(
            model=model,
            latency_ms=latency_ms,
            memory_mb=memory_mb,
            chunks_used=chunks_used,
            question_length=len(request.question),
        )

        return QueryResponse(
            answer=answer,
            latency_ms=round(latency_ms, 2),
            memory_mb=round(memory_mb, 2),
            chunks_used=chunks_used,
            model=model,
        )
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # (the HTTP client only sees the message) and chain the cause.
        import logging

        logging.getLogger(__name__).exception("Query processing failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Registered on the app; without a route decorator FastAPI never exposes it.
@app.get("/metrics")
async def get_metrics():
    """Return the aggregate query statistics collected so far.

    Returns:
        JSONResponse wrapping ``metrics_tracker.get_summary()``.
    """
    # NOTE(review): returning a JSONResponse directly bypasses FastAPI's
    # jsonable_encoder, so get_summary() must return plain JSON types —
    # confirm in app.metrics.
    return JSONResponse(content=metrics_tracker.get_summary())
# Registered on the app; without a route decorator FastAPI never exposes it.
@app.post("/reset_metrics")
async def reset_metrics():
    """Clear all recorded query metrics via ``metrics_tracker.reset()``.

    Returns:
        A dict with a confirmation message.
    """
    metrics_tracker.reset()
    return {"message": "Metrics reset successfully"}
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn, listening on
    # all IPv4 interfaces at port 8000.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)