# Source: Ariyan-Pro — "Deploy RAG Latency Optimization v1.0" (commit 04ab625)
"""
Naive RAG Implementation - Baseline for comparison.
No optimizations, no caching, brute-force everything.
"""
import time
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import sqlite3
from typing import List, Tuple, Optional
import hashlib
from pathlib import Path
import psutil
import os
from config import (
EMBEDDING_MODEL, DATA_DIR, FAISS_INDEX_PATH, DOCSTORE_PATH,
CHUNK_SIZE, TOP_K, MAX_TOKENS
)
class NaiveRAG:
    """Baseline naive RAG implementation with no optimizations.

    Per-query pipeline: embed the question (no caching), brute-force search
    a FAISS index, fetch the matching chunks from a SQLite docstore, then
    build a simulated answer by concatenating chunks. Timing and memory
    figures are printed and, when a ``metrics_tracker`` is supplied,
    recorded via its ``record_query`` method.
    """

    def __init__(self, metrics_tracker=None):
        # Optional collector with a record_query(...) method; see query().
        self.metrics_tracker = metrics_tracker
        self.embedder = None          # SentenceTransformer, set in initialize()
        self.faiss_index = None       # FAISS index, set in initialize()
        self.docstore_conn = None     # sqlite3 connection, set in initialize()
        self._initialized = False
        # Handle on our own process, used for RSS memory measurements.
        self.process = psutil.Process(os.getpid())

    def initialize(self):
        """Lazily load the embedder, FAISS index, and docstore connection.

        Idempotent: subsequent calls return immediately.
        """
        if self._initialized:
            return
        print("Initializing Naive RAG...")
        start_time = time.perf_counter()
        # Load embedding model (slow: loads model weights, may download).
        self.embedder = SentenceTransformer(EMBEDDING_MODEL)
        # Load FAISS index; stays None when the file is missing, in which
        # case _search_faiss raises ValueError on first use.
        if FAISS_INDEX_PATH.exists():
            self.faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
        # Connect to document store
        self.docstore_conn = sqlite3.connect(DOCSTORE_PATH)
        init_time = (time.perf_counter() - start_time) * 1000
        memory_mb = self.process.memory_info().rss / 1024 / 1024
        print(f"Naive RAG initialized in {init_time:.2f}ms, Memory: {memory_mb:.2f}MB")
        self._initialized = True

    def _get_chunks_by_ids(self, chunk_ids: List[int]) -> List[str]:
        """Retrieve chunk texts by ID, preserving the order of ``chunk_ids``.

        FIX: the original returned rows in whatever order SQLite emitted
        them, silently discarding the FAISS relevance ranking. Results are
        now re-ordered to match the requested IDs. Also guards against an
        empty ID list, which would otherwise build an invalid ``IN ()``
        clause, and closes the cursor explicitly.
        """
        if not chunk_ids:
            return []
        cursor = self.docstore_conn.cursor()
        try:
            placeholders = ','.join('?' for _ in chunk_ids)
            # Parameterized query: IDs are bound, never string-interpolated.
            query = f"SELECT id, chunk_text FROM chunks WHERE id IN ({placeholders})"
            cursor.execute(query, chunk_ids)
            text_by_id = {row[0]: row[1] for row in cursor.fetchall()}
        finally:
            cursor.close()
        # Re-emit in ranking order; skip IDs missing from the store.
        return [text_by_id[cid] for cid in chunk_ids if cid in text_by_id]

    def _search_faiss(self, query_embedding: np.ndarray, top_k: int = TOP_K) -> List[int]:
        """Brute-force FAISS search; return 1-based chunk IDs, best first.

        Raises:
            ValueError: if the index was never loaded (index file missing).
        """
        if self.faiss_index is None:
            raise ValueError("FAISS index not loaded")
        # FAISS expects float32 and a 2-D (n_queries, dim) array.
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        # FAISS pads with -1 when fewer than top_k hits exist; drop those.
        # DB IDs are 1-based while FAISS positions are 0-based, hence +1.
        return [int(idx + 1) for idx in indices[0] if idx >= 0]

    def _generate_response_naive(self, question: str, chunks: List[str]) -> str:
        """Naive response generation - just concatenate chunks.

        Placeholder for an LLM call: joins the first three chunks,
        truncates to 300 characters, and sleeps to simulate latency.
        """
        context = "\n\n".join(chunks[:3])  # Use only first 3 chunks
        response = f"Based on the documents:\n\n{context[:300]}..."
        # Simulate LLM processing time (100-300ms)
        time.sleep(0.2)
        return response

    def query(self, question: str, top_k: Optional[int] = None) -> Tuple[str, int]:
        """
        Process a query using naive RAG.

        Args:
            question: The user's question
            top_k: Number of chunks to retrieve (overrides default)

        Returns:
            Tuple of (answer, number of chunks used)
        """
        if not self._initialized:
            self.initialize()
        start_time = time.perf_counter()
        initial_memory = self.process.memory_info().rss / 1024 / 1024

        # Step 1: Embed query (no caching)
        embedding_start = time.perf_counter()
        query_embedding = self.embedder.encode([question])[0]
        embedding_time = (time.perf_counter() - embedding_start) * 1000

        # Step 2: Search FAISS (brute force)
        retrieval_start = time.perf_counter()
        k = top_k or TOP_K
        chunk_ids = self._search_faiss(query_embedding, k)
        retrieval_time = (time.perf_counter() - retrieval_start) * 1000

        # Step 3: Retrieve chunks (handles empty chunk_ids internally)
        chunks = self._get_chunks_by_ids(chunk_ids)

        # Step 4: Generate response (naive)
        generation_start = time.perf_counter()
        answer = self._generate_response_naive(question, chunks)
        generation_time = (time.perf_counter() - generation_start) * 1000

        total_time = (time.perf_counter() - start_time) * 1000
        final_memory = self.process.memory_info().rss / 1024 / 1024
        # NOTE: RSS delta, not absolute usage; can be negative after GC.
        memory_used = final_memory - initial_memory

        # Log metrics if tracker is available
        if self.metrics_tracker:
            self.metrics_tracker.record_query(
                model="naive",
                latency_ms=total_time,
                memory_mb=memory_used,
                chunks_used=len(chunks),
                question_length=len(question),
                embedding_time=embedding_time,
                retrieval_time=retrieval_time,
                generation_time=generation_time
            )
        print(f"[Naive RAG] Query: '{question[:50]}...'")
        print(f"  - Embedding: {embedding_time:.2f}ms")
        print(f"  - Retrieval: {retrieval_time:.2f}ms")
        print(f"  - Generation: {generation_time:.2f}ms")
        print(f"  - Total: {total_time:.2f}ms")
        print(f"  - Memory used: {memory_used:.2f}MB")
        print(f"  - Chunks used: {len(chunks)}")
        return answer, len(chunks)

    def close(self):
        """Clean up resources.

        Closes the docstore connection (if open), clears the handle so a
        second close() is a no-op, and marks the instance uninitialized so
        a later query() re-runs initialize().
        """
        if self.docstore_conn:
            self.docstore_conn.close()
            self.docstore_conn = None
        self._initialized = False