import faiss
import pickle
from sentence_transformers import SentenceTransformer
from semantic_cache import SemanticCache
from fuzzy_cluster import load_gmm_model


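class QueryEngine:
    """Semantic search over a FAISS index, fronted by a semantic cache."""

    def __init__(self):
        """Load the embedding model, FAISS index, documents, GMM, and cache."""
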
        # Load embedding model
        self.model = SentenceTransformer("all-MiniLM-L6-v2")

        # Load FAISS index
        self.index = faiss.read_index("models/faiss_index.index")

        # Load documents (pickle can execute arbitrary code on load,
        # so this file must come from a trusted source)
        with open("models/documents.pkl", "rb") as f:
            self.documents = pickle.load(f)

        # Load clustering model
        self.gmm = load_gmm_model()

        # Initialize semantic cache
        self.cache = SemanticCache(similarity_threshold=0.75)

    def search(self, query, top_k=5):
        """Return up to top_k documents for query, using the cache when possible."""

        # Step 1: embed query
        query_embedding = self.model.encode([query])

        # Step 2: check semantic cache
        cached = self.cache.search(query_embedding)

        if cached:
            return {
                "query": query,
                "cache_hit": True,
                "matched_query": cached["matched_query"],
                "similarity_score": cached["similarity_score"],
                "results": cached["result"]
            }

        # Step 3: cluster detection
        cluster_probs = self.gmm.predict_proba(query_embedding)
        dominant_cluster = cluster_probs.argmax()

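        # Step 4: vector search
        distances, indices = self.index.search(query_embedding, top_k)

        # FAISS pads the result with -1 when fewer than top_k neighbors
        # are found; skip those rather than indexing self.documents with -1.
        results = [self.documents[i] for i in indices[0] if i != -1]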

        # Step 5: store in cache
        self.cache.add(query, query_embedding, results)

        return {
            "query": query,
            "cache_hit": False,
            "dominant_cluster": int(dominant_cluster),
            "results": results
        }

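    def cache_stats(self):
        """Return the statistics reported by the semantic cache."""
        return self.cache.stats()

    def clear_cache(self):
        """Remove every entry from the semantic cache."""
        self.cache.clear()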
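

# --- Illustrative sketch, not part of the original module ---
# search() returns two differently shaped dictionaries: cache hits carry
# "matched_query" and "similarity_score", while fresh lookups carry
# "dominant_cluster". The helper below is a minimal sketch of how a caller
# might handle both shapes. The name `describe_response` is hypothetical, and
# it assumes "similarity_score" is a number and "results" supports len().
def describe_response(response):
    """Return a one-line summary of a search() response dictionary."""
    n_results = len(response["results"])
    if response["cache_hit"]:
        # Cache hits report the earlier query they matched and its similarity.
        return (
            f"cache hit for {response['query']!r} "
            f"(matched {response['matched_query']!r}, "
            f"similarity {response['similarity_score']:.2f}, "
            f"{n_results} results)"
        )
    # Fresh lookups report the dominant GMM cluster instead.
    return (
        f"cache miss for {response['query']!r} "
        f"(cluster {response['dominant_cluster']}, {n_results} results)"
    )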