File size: 3,622 Bytes
aabd1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import time
import uuid

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models

from src.embeddings import get_embedding_model

# Load secrets
load_dotenv()

class SemanticCache:
    """Semantic answer cache backed by a Qdrant vector collection.

    Stores (question, answer) pairs keyed by the question's embedding and
    returns a previously stored answer when a new query is sufficiently
    similar to a cached question.
    """

    def __init__(
        self,
        collection_name: str = "pro_rag_cache",
        threshold: float = 0.92,
        vector_size: int = 3072,
    ):
        """
        Args:
            collection_name: Name of the Qdrant collection used as the cache.
            threshold: Minimum cosine similarity for a cache hit.
            vector_size: Dimensionality of the stored embedding vectors.
                Must match the model returned by get_embedding_model()
                (3072 by default — TODO confirm against the embedding model).
        """
        self.collection_name = collection_name

        # --- CONNECTION LOGIC ---
        # Prefer the cloud instance when both URL and API key are configured;
        # otherwise fall back to a local Docker instance.
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_key = os.getenv("QDRANT_API_KEY")

        if qdrant_url and qdrant_key:
            print(f"☁️ [Cache] Connecting to Qdrant Cloud...")
            self.client = QdrantClient(url=qdrant_url, api_key=qdrant_key)
        else:
            print(f"🏠 [Cache] Connecting to Local Docker...")
            self.client = QdrantClient(url="http://localhost:6333")

        self.embedding_model = get_embedding_model()
        self.threshold = threshold

        # Initialize the cache collection on first use.
        try:
            if not self.client.collection_exists(collection_name):
                print(f"⚙️ Initializing Semantic Cache: '{collection_name}'...")
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=models.VectorParams(
                        size=vector_size,
                        distance=models.Distance.COSINE
                    )
                )
        except Exception as e:
            # Non-fatal: the app can still run without the cache.
            print(f"⚠️ Cache Initialization Warning: {e}")

    def search_cache(self, query: str):
        """Return a cached answer for *query*, or None on a miss or error.

        Embeds the query, fetches the single nearest cached question via the
        modern 'query_points' method, and returns its stored answer when the
        similarity score meets the threshold.
        """
        try:
            # 1. Embed the query
            vector = self.embedding_model.embed_query(query)

            # 2. Search Qdrant Cache Collection (UPDATED METHOD)
            search_result = self.client.query_points(
                collection_name=self.collection_name,
                query=vector,
                limit=1,
                with_payload=True
            ).points

            # 3. Check Threshold
            if search_result:
                best_match = search_result[0]
                if best_match.score >= self.threshold:
                    print(f"⚡ CACHE HIT! (Similarity: {best_match.score:.4f})")
                    # .get() avoids a KeyError on malformed/legacy payloads;
                    # a payload without "answer" degrades to a miss-like None.
                    return best_match.payload.get("answer")

            score = search_result[0].score if search_result else 0
            print(f"🐢 CACHE MISS (Best match: {score:.4f})")
            return None

        except Exception as e:
            # Cache failures must never crash the app; treat them as a miss.
            print(f"⚠️ Cache Search Error: {e}")
            return None

    def add_to_cache(self, query: str, answer: str):
        """Store a (query, answer) pair in the cache.

        Best-effort: failures are logged and swallowed so callers never crash.
        """
        try:
            vector = self.embedding_model.embed_query(query)
            # Random UUIDs avoid the ID collisions that a millisecond
            # timestamp produces when two answers are cached within the
            # same millisecond (the later upsert would overwrite the earlier).
            point_id = str(uuid.uuid4())

            self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    models.PointStruct(
                        id=point_id,
                        vector=vector,
                        payload={
                            "question": query,
                            "answer": answer,
                            "timestamp": time.time()
                        }
                    )
                ]
            )
        except Exception as e:
            print(f"⚠️ Failed to save to cache: {e}")