# Pro-RAG-Level1 — src/cache.py
# (Hugging Face upload metadata: alihaiderscholar, "Upload 19 files", commit aabd1d8, verified)
import os
import time
import uuid

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models

from src.embeddings import get_embedding_model
# Load secrets
load_dotenv()
class SemanticCache:
    """Semantic query cache backed by a Qdrant vector collection.

    Stores (question, answer) pairs keyed by the question's embedding so
    that sufficiently similar future questions can be answered from the
    cache instead of re-running the full RAG pipeline.
    """

    def __init__(self, collection_name: str = "pro_rag_cache"):
        """Connect to Qdrant and ensure the cache collection exists.

        Args:
            collection_name: Name of the Qdrant collection used as the cache.
        """
        self.collection_name = collection_name

        # --- CONNECTION LOGIC ---
        # Prefer Qdrant Cloud when both URL and API key are present in the
        # environment; otherwise fall back to a local Docker instance.
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_key = os.getenv("QDRANT_API_KEY")
        if qdrant_url and qdrant_key:
            print("☁️ [Cache] Connecting to Qdrant Cloud...")
            self.client = QdrantClient(url=qdrant_url, api_key=qdrant_key)
        else:
            print("🏠 [Cache] Connecting to Local Docker...")
            self.client = QdrantClient(url="http://localhost:6333")

        self.embedding_model = get_embedding_model()
        # Minimum cosine similarity for a cached answer to count as a hit.
        self.threshold = 0.92

        # Initialize the cache collection on first use.
        try:
            if not self.client.collection_exists(collection_name):
                print(f"⚙️ Initializing Semantic Cache: '{collection_name}'...")
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=models.VectorParams(
                        size=3072,  # must match the embedding model's output dim
                        distance=models.Distance.COSINE,
                    ),
                )
        except Exception as e:
            # Best-effort: a cache that fails to initialize must not
            # prevent the application from starting.
            print(f"⚠️ Cache Initialization Warning: {e}")

    def search_cache(self, query: str):
        """Return the cached answer for a semantically similar question.

        Uses the modern 'query_points' method and accepts the nearest
        neighbour only if its similarity clears ``self.threshold``.

        Args:
            query: The user's question.

        Returns:
            The cached answer string on a hit, otherwise None (also None
            on any backend error — the cache must never crash the app).
        """
        try:
            # 1. Embed the query
            vector = self.embedding_model.embed_query(query)

            # 2. Search Qdrant cache collection for the single closest match.
            points = self.client.query_points(
                collection_name=self.collection_name,
                query=vector,
                limit=1,
                with_payload=True,
            ).points

            # 3. Accept the match only if it clears the similarity threshold.
            if points and points[0].score >= self.threshold:
                print(f"⚡ CACHE HIT! (Similarity: {points[0].score:.4f})")
                return points[0].payload["answer"]

            best_score = points[0].score if points else 0
            print(f"🐢 CACHE MISS (Best match: {best_score:.4f})")
            return None
        except Exception as e:
            # Print error but don't crash the app
            print(f"⚠️ Cache Search Error: {e}")
            return None

    def add_to_cache(self, query: str, answer: str):
        """Save the Query + Answer pair into the cache collection.

        Args:
            query: The question that was asked.
            answer: The answer produced for it.
        """
        try:
            vector = self.embedding_model.embed_query(query)
            # BUGFIX: the previous id, int(time.time() * 1000), collided
            # when two entries were written within the same millisecond and
            # the upsert silently overwrote the earlier entry. A random
            # UUID (a valid Qdrant point id) is collision-free.
            point_id = str(uuid.uuid4())
            self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    models.PointStruct(
                        id=point_id,
                        vector=vector,
                        payload={
                            "question": query,
                            "answer": answer,
                            "timestamp": time.time(),
                        },
                    )
                ],
            )
        except Exception as e:
            # Best-effort write: never let cache persistence break a request.
            print(f"⚠️ Failed to save to cache: {e}")