File size: 2,829 Bytes
04ab625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import time
import psutil
import os
from typing import Optional, List
from datetime import datetime

from app.rag_naive import NaiveRAG
from app.rag_optimized import OptimizedRAG
from app.metrics import MetricsTracker

# FastAPI application instance; title/description surface in the OpenAPI docs.
app = FastAPI(title="RAG Latency Demo", 
              description="CPU-Only Low-Latency RAG System")

# Initialize components
# A single MetricsTracker is shared by both pipelines so /metrics can
# compare naive vs. optimized latency/memory side by side.
metrics_tracker = MetricsTracker()
naive_rag = NaiveRAG(metrics_tracker)
optimized_rag = OptimizedRAG(metrics_tracker)

class QueryRequest(BaseModel):
    """Request body for POST /query."""

    question: str                    # user question passed verbatim to the RAG pipeline
    use_optimized: bool = True       # False selects the naive pipeline instead
    top_k: Optional[int] = None      # retrieval depth; None lets the pipeline pick its default

class QueryResponse(BaseModel):
    """Response body for POST /query."""

    answer: str          # generated answer text
    latency_ms: float    # end-to-end wall-clock time, rounded to 2 decimals
    memory_mb: float     # RSS delta during the request (may be negative)
    chunks_used: int     # number of retrieved chunks the pipeline reported
    model: str           # "optimized" or "naive"

@app.get("/")
async def root():
    """Service banner: name, status, and a map of the available endpoints."""
    endpoints = {
        "query": "POST /query",
        "metrics": "GET /metrics",
        "reset_metrics": "POST /reset_metrics",
    }
    return {
        "message": "RAG Latency Optimization System",
        "status": "running",
        "endpoints": endpoints,
    }

@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """Answer a question via the selected RAG pipeline and record metrics.

    Routes to the optimized or naive pipeline based on
    ``request.use_optimized``, measures wall-clock latency and the RSS
    delta across the call, records both in the shared MetricsTracker,
    and returns the answer with its measurements.

    Raises:
        HTTPException: 500 carrying the underlying error message if the
            pipeline fails. HTTPExceptions raised inside the pipelines
            are re-raised unchanged rather than wrapped in a 500.
    """
    start_time = time.perf_counter()
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / 1024 / 1024  # MB

    try:
        if request.use_optimized:
            answer, chunks_used = optimized_rag.query(request.question, request.top_k)
            model = "optimized"
        else:
            answer, chunks_used = naive_rag.query(request.question, request.top_k)
            model = "naive"

        end_time = time.perf_counter()
        final_memory = process.memory_info().rss / 1024 / 1024

        latency_ms = (end_time - start_time) * 1000
        # NOTE: RSS delta can be negative (allocator/GC effects); reported as-is.
        memory_mb = final_memory - initial_memory

        # Store metrics
        metrics_tracker.record_query(
            model=model,
            latency_ms=latency_ms,
            memory_mb=memory_mb,
            chunks_used=chunks_used,
            question_length=len(request.question)
        )

        return QueryResponse(
            answer=answer,
            latency_ms=round(latency_ms, 2),
            memory_mb=round(memory_mb, 2),
            chunks_used=chunks_used,
            model=model
        )

    except HTTPException:
        # Bug fix: the broad handler below used to swallow deliberate HTTP
        # errors (e.g. a 404/422 from a pipeline) and re-wrap them as 500s.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/metrics")
async def get_metrics():
    """Return the aggregated query metrics as a JSON response."""
    return JSONResponse(content=metrics_tracker.get_summary())

@app.post("/reset_metrics")
async def reset_metrics():
    """Clear all recorded query metrics and confirm the reset."""
    confirmation = {"message": "Metrics reset successfully"}
    metrics_tracker.reset()
    return confirmation

# Dev entry point: run the app directly with uvicorn on all interfaces.
# In production the module is typically served via `uvicorn main:app` instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)