NavyDevilDoc committed on
Commit
ee438ef
·
verified ·
1 Parent(s): f2d535c

Update src/search.py

Browse files
Files changed (1) hide show
  1. src/search.py +85 -0
src/search.py CHANGED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import pickle

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
5
+
6
+ INDEX_FILE = "navy_index.faiss"
7
+ META_FILE = "navy_metadata.pkl" # We still use this for fast mapping, or we could query SQL.
8
+
9
+ class SearchEngine:
10
+ def __init__(self):
11
+ # Force CPU
12
+ self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")
13
+ self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device="cpu")
14
+ self.index = None
15
+ self.metadata = [] # List of dicts: {'doc_id':..., 'text':...}
16
+ self.load_index()
17
+
18
+ def load_index(self):
19
+ if os.path.exists(INDEX_FILE) and os.path.exists(META_FILE):
20
+ try:
21
+ self.index = faiss.read_index(INDEX_FILE)
22
+ with open(META_FILE, "rb") as f: self.metadata = pickle.load(f)
23
+ except:
24
+ self.reset_index()
25
+ else:
26
+ self.reset_index()
27
+
28
+ def reset_index(self):
29
+ self.index = faiss.IndexIDMap(faiss.IndexFlatIP(384))
30
+ self.metadata = []
31
+
32
+ def add_features(self, chunks):
33
+ """
34
+ Embeds chunks and adds to FAISS.
35
+ chunks = [{'text':..., 'doc_id':...}]
36
+ """
37
+ texts = [c["text"] for c in chunks]
38
+ embeddings = self.bi_encoder.encode(texts)
39
+ faiss.normalize_L2(embeddings)
40
+
41
+ start_id = len(self.metadata)
42
+ ids = np.arange(start_id, start_id + len(chunks)).astype('int64')
43
+
44
+ self.index.add_with_ids(embeddings, ids)
45
+ self.metadata.extend(chunks)
46
+ self.save()
47
+
48
+ def save(self):
49
+ faiss.write_index(self.index, INDEX_FILE)
50
+ with open(META_FILE, "wb") as f: pickle.dump(self.metadata, f)
51
+
52
+ def search(self, query, top_k=5):
53
+ if not self.index or self.index.ntotal == 0: return []
54
+
55
+ q_vec = self.bi_encoder.encode([query])
56
+ faiss.normalize_L2(q_vec)
57
+
58
+ # 1. Retrieve Candidate Vectors
59
+ scores, indices = self.index.search(q_vec, min(self.index.ntotal, top_k * 10))
60
+
61
+ candidates = []
62
+ for i, idx in enumerate(indices[0]):
63
+ if idx != -1:
64
+ item = self.metadata[idx]
65
+ candidates.append([query, item['text']])
66
+
67
+ if not candidates: return []
68
+
69
+ # 2. Re-Rank with Cross Encoder
70
+ cross_scores = self.cross_encoder.predict(candidates)
71
+
72
+ # 3. Format Results
73
+ results = []
74
+ for i, idx in enumerate(indices[0]):
75
+ if idx != -1:
76
+ meta = self.metadata[idx]
77
+ results.append({
78
+ "score": cross_scores[i],
79
+ "doc_id": meta["doc_id"],
80
+ "source": meta["source"],
81
+ "snippet": meta["text"]
82
+ })
83
+
84
+ # Sort by Cross-Encoder Score
85
+ return sorted(results, key=lambda x: x['score'], reverse=True)[:top_k]