File size: 2,953 Bytes
ee438ef
 
34f2da9
ee438ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import logging
import os
import pickle

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder

# Default on-disk artifacts: the FAISS vector index, plus a pickled list of chunk
# dicts whose list position equals the vector ID stored in the index. The pickle
# gives fast ID -> chunk lookup; querying SQL would be an alternative.
INDEX_FILE, META_FILE = "navy_index.faiss", "navy_metadata.pkl"

class SearchEngine:
    """Two-stage semantic search: FAISS bi-encoder retrieval + cross-encoder re-ranking.

    Vectors are stored in a ``faiss.IndexIDMap`` over an inner-product flat index.
    Embeddings are L2-normalised, so inner product equals cosine similarity.
    Vector ID ``i`` corresponds to ``self.metadata[i]``; IDs are assigned
    sequentially from ``len(self.metadata)`` and never removed.
    """

    def __init__(self, index_file=None, meta_file=None):
        """Load models on CPU and load (or create) the on-disk index.

        Args:
            index_file: Path of the FAISS index file; defaults to module ``INDEX_FILE``.
            meta_file: Path of the pickled metadata list; defaults to module ``META_FILE``.
        """
        self.index_file = index_file if index_file is not None else INDEX_FILE
        self.meta_file = meta_file if meta_file is not None else META_FILE
        # Force CPU
        self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device="cpu")
        self.index = None
        self.metadata = []  # List of dicts: {'doc_id':..., 'text':..., optional 'source':...}
        self.load_index()

    def load_index(self):
        """Load index and metadata from disk; fall back to an empty index on failure.

        An unreadable pair, or a pair whose vector count disagrees with the
        metadata length, is logged and replaced with a fresh empty index.
        """
        if not (os.path.exists(self.index_file) and os.path.exists(self.meta_file)):
            self.reset_index()
            return
        try:
            index = faiss.read_index(self.index_file)
            # SECURITY NOTE: pickle.load can execute arbitrary code; only load
            # metadata files this application wrote itself.
            with open(self.meta_file, "rb") as f:
                metadata = pickle.load(f)
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to load search index; starting with an empty index")
            self.reset_index()
            return
        # IDs are 0..len(metadata)-1, so a consistent pair has ntotal == len(metadata).
        # A mismatch (e.g. a crash between the two writes in save()) would cause
        # IndexError in search() and duplicate IDs in add_features().
        if index.ntotal != len(metadata):
            logging.getLogger(__name__).warning(
                "Index has %d vectors but metadata has %d entries; resetting index",
                index.ntotal, len(metadata))
            self.reset_index()
            return
        self.index = index
        self.metadata = metadata

    def reset_index(self):
        """Replace the in-memory index and metadata with empty ones (does not save)."""
        # Derive the dimension from the model instead of hard-coding it; 384 is
        # the fallback the original code assumed (all-MiniLM-L6-v2).
        dim = self.bi_encoder.get_sentence_embedding_dimension() or 384
        self.index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
        self.metadata = []

    def add_features(self, chunks):
        """
        Embed chunks, add them to FAISS and persist index + metadata.

        chunks = [{'text':..., 'doc_id':..., optional 'source':...}]
        An empty list is a no-op.
        """
        if not chunks:
            return
        texts = [c["text"] for c in chunks]
        # FAISS requires C-contiguous float32 input; normalize_L2 works in place.
        embeddings = np.ascontiguousarray(self.bi_encoder.encode(texts), dtype="float32")
        faiss.normalize_L2(embeddings)

        start_id = len(self.metadata)
        ids = np.arange(start_id, start_id + len(chunks), dtype="int64")

        self.index.add_with_ids(embeddings, ids)
        # Shallow-copy so later mutation of the caller's dicts can't alter stored metadata.
        self.metadata.extend(dict(c) for c in chunks)
        self.save()

    def save(self):
        """Persist index and metadata, writing each via a temp file + os.replace."""
        tmp_index = self.index_file + ".tmp"
        tmp_meta = self.meta_file + ".tmp"
        faiss.write_index(self.index, tmp_index)
        with open(tmp_meta, "wb") as f:
            pickle.dump(self.metadata, f)
        # os.replace is atomic per file, so a crash never leaves a truncated file;
        # a crash between the two replaces is caught by load_index's count check.
        os.replace(tmp_index, self.index_file)
        os.replace(tmp_meta, self.meta_file)

    def search(self, query, top_k=5):
        """Return up to ``top_k`` hits for ``query``, ordered by cross-encoder score.

        Retrieves ``top_k * 10`` candidates (capped at the index size) by cosine
        similarity, then re-ranks them with the cross-encoder.

        Returns:
            List of dicts ``{'score': float, 'doc_id', 'source' (None if absent),
            'snippet'}``, best first. Empty if the index is empty or ``top_k <= 0``.
        """
        if self.index is None or self.index.ntotal == 0 or top_k <= 0:
            return []

        q_vec = np.ascontiguousarray(self.bi_encoder.encode([query]), dtype="float32")
        faiss.normalize_L2(q_vec)

        # 1. Retrieve candidate vectors; FAISS pads missing results with -1.
        _, indices = self.index.search(q_vec, min(self.index.ntotal, top_k * 10))
        hits = [self.metadata[idx] for idx in indices[0] if idx != -1]
        if not hits:
            return []

        # 2. Re-rank with the cross encoder. Scores are aligned with `hits`, not
        # with the raw FAISS slots, so -1 padding can never shift them.
        cross_scores = self.cross_encoder.predict([[query, meta["text"]] for meta in hits])

        # 3. Format results (plain float so the output is JSON-serialisable).
        results = [
            {
                "score": float(score),
                "doc_id": meta["doc_id"],
                "source": meta.get("source"),
                "snippet": meta["text"],
            }
            for meta, score in zip(hits, cross_scores)
        ]
        results.sort(key=lambda r: r["score"], reverse=True)
        return results[:top_k]