import os
import time
import uuid
from typing import Optional

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models

from src.embeddings import get_embedding_model

# Load secrets
load_dotenv()


class SemanticCache:
    """Question/answer cache backed by a Qdrant collection.

    Queries are embedded with the model returned by ``get_embedding_model()``
    and stored as points whose payload carries the question, the answer and
    a timestamp. A lookup returns a stored answer when the best match scores
    at or above ``threshold``.
    """

    def __init__(
        self,
        collection_name: str = "pro_rag_cache",
        threshold: float = 0.92,
        vector_size: int = 3072,
    ):
        """Connect to Qdrant and make sure the cache collection exists.

        Args:
            collection_name: Name of the Qdrant collection used as the cache.
            threshold: Minimum score of the best match for a lookup to count
                as a cache hit.
            vector_size: Vector dimension used when the collection has to be
                created. NOTE(review): presumably this must equal the output
                dimension of the embedding model -- confirm when changing it.
        """
        self.collection_name = collection_name

        # --- CONNECTION LOGIC ---
        # Use the cloud instance only when both URL and API key are set;
        # otherwise fall back to a local instance on the default port.
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_key = os.getenv("QDRANT_API_KEY")

        if qdrant_url and qdrant_key:
            print("☁️ [Cache] Connecting to Qdrant Cloud...")
            self.client = QdrantClient(url=qdrant_url, api_key=qdrant_key)
        else:
            print("🏠 [Cache] Connecting to Local Docker...")
            self.client = QdrantClient(url="http://localhost:6333")

        self.embedding_model = get_embedding_model()
        self.threshold = threshold

        # Initialize Cache Collection.
        # Best-effort: a failure here is reported but does not prevent the
        # object from being constructed.
        try:
            if not self.client.collection_exists(collection_name):
                print(f"⚙️ Initializing Semantic Cache: '{collection_name}'...")
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=models.VectorParams(
                        size=vector_size,
                        distance=models.Distance.COSINE,
                    ),
                )
        except Exception as e:
            print(f"⚠️ Cache Initialization Warning: {e}")

    def search_cache(self, query: str) -> Optional[str]:
        """Check whether a similar question has already been answered.

        Uses the client's ``query_points`` method to fetch the single
        closest cached entry.

        Args:
            query: The user question to look up.

        Returns:
            The cached answer when the best match scores at or above
            ``self.threshold``; ``None`` on a miss, on a hit whose payload
            has no ``"answer"``, or when any error occurs.
        """
        try:
            # 1. Embed the query
            vector = self.embedding_model.embed_query(query)

            # 2. Fetch the single nearest cached entry
            hits = self.client.query_points(
                collection_name=self.collection_name,
                query=vector,
                limit=1,
                with_payload=True,
            ).points

            # 3. Check threshold
            best_match = hits[0] if hits else None
            if best_match is not None and best_match.score >= self.threshold:
                answer = (best_match.payload or {}).get("answer")
                if answer is None:
                    # A stored point without an answer cannot serve a hit;
                    # report it and let the caller compute a fresh answer.
                    print(
                        "⚠️ Cache Search Error: cached entry "
                        f"{best_match.id} has no 'answer' in its payload"
                    )
                    return None
                print(f"⚡ CACHE HIT! (Similarity: {best_match.score:.4f})")
                return answer

            score = best_match.score if best_match is not None else 0
            print(f"🐢 CACHE MISS (Best match: {score:.4f})")
            return None

        except Exception as e:
            # Print error but don't crash the app
            print(f"⚠️ Cache Search Error: {e}")
            return None

    def add_to_cache(self, query: str, answer: str):
        """Save a query/answer pair to the cache collection.

        Failures are printed and swallowed so that a cache write problem
        never propagates to the caller.

        Args:
            query: The question text; it is embedded to form the vector.
            answer: The answer text to store alongside the question.
        """
        try:
            vector = self.embedding_model.embed_query(query)

            # A random UUID avoids the collisions a millisecond timestamp ID
            # has when two entries are written within the same millisecond
            # (an upsert with a duplicate ID overwrites the earlier point).
            point_id = str(uuid.uuid4())

            self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    models.PointStruct(
                        id=point_id,
                        vector=vector,
                        payload={
                            "question": query,
                            "answer": answer,
                            "timestamp": time.time(),
                        },
                    )
                ],
            )
        except Exception as e:
            print(f"⚠️ Failed to save to cache: {e}")