Spaces:
Sleeping
Sleeping
File size: 2,953 Bytes
import faiss
import numpy as np
import os
import pickle
from sentence_transformers import SentenceTransformer, CrossEncoder
INDEX_FILE = "navy_index.faiss"
META_FILE = "navy_metadata.pkl" # We still use this for fast mapping, or we could query SQL.
class SearchEngine:
    """Two-stage semantic search over document chunks.

    Stage 1: bi-encoder embeddings + FAISS inner-product retrieval
    (embeddings are L2-normalized, so inner product == cosine similarity).
    Stage 2: cross-encoder re-ranking of the retrieved candidates.
    Index and metadata are persisted to INDEX_FILE / META_FILE.
    """

    def __init__(self):
        # Force CPU so the engine runs on machines without a GPU.
        self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device="cpu")
        self.index = None
        self.metadata = []  # List of dicts: {'doc_id':..., 'text':...}
        self.load_index()

    def load_index(self):
        """Load the FAISS index and metadata from disk, or create fresh ones.

        Falls back to an empty index when the files are absent or unreadable.
        """
        if os.path.exists(INDEX_FILE) and os.path.exists(META_FILE):
            try:
                self.index = faiss.read_index(INDEX_FILE)
                with open(META_FILE, "rb") as f:
                    self.metadata = pickle.load(f)
            except Exception:
                # Corrupt or incompatible files on disk: start over rather
                # than crash.  (Was a bare `except:`, which would also swallow
                # KeyboardInterrupt/SystemExit.)
                self.reset_index()
        else:
            self.reset_index()

    def reset_index(self):
        """Create an empty ID-mapped index (384 = all-MiniLM-L6-v2 dim)."""
        self.index = faiss.IndexIDMap(faiss.IndexFlatIP(384))
        self.metadata = []

    def add_features(self, chunks):
        """Embed chunks, add them to FAISS, and persist to disk.

        chunks = [{'text':..., 'doc_id':...}]
        FAISS ids are assigned sequentially so they double as positions
        into self.metadata.
        """
        if not chunks:
            # Guard: encode([]) / normalize_L2 on an empty batch is an error.
            return
        texts = [c["text"] for c in chunks]
        # FAISS requires contiguous float32; be explicit rather than rely on
        # the encoder's default dtype.
        embeddings = np.asarray(self.bi_encoder.encode(texts), dtype="float32")
        faiss.normalize_L2(embeddings)  # unit vectors -> IP == cosine
        start_id = len(self.metadata)
        ids = np.arange(start_id, start_id + len(chunks)).astype('int64')
        self.index.add_with_ids(embeddings, ids)
        self.metadata.extend(chunks)
        self.save()

    def save(self):
        """Persist the FAISS index and the metadata list to disk."""
        faiss.write_index(self.index, INDEX_FILE)
        with open(META_FILE, "wb") as f:
            pickle.dump(self.metadata, f)

    def search(self, query, top_k=5):
        """Return the top_k best chunks for `query`, re-ranked by the cross-encoder.

        Each result is {'score', 'doc_id', 'source', 'snippet'}, sorted by
        descending cross-encoder score.  Returns [] on an empty index or when
        retrieval yields nothing.
        """
        if not self.index or self.index.ntotal == 0:
            return []
        q_vec = np.asarray(self.bi_encoder.encode([query]), dtype="float32")
        faiss.normalize_L2(q_vec)
        # 1. Retrieve candidates — over-fetch 10x so the re-ranker has room.
        _, indices = self.index.search(q_vec, min(self.index.ntotal, top_k * 10))
        # Collect valid hits ONCE so metadata and scores stay aligned.
        # BUG FIX: the original indexed cross_scores by the raw FAISS position,
        # which misaligned (or IndexError'd) whenever FAISS returned -1 padding.
        hits = [self.metadata[idx] for idx in indices[0] if idx != -1]
        if not hits:
            return []
        # 2. Re-rank with the cross encoder on (query, passage) pairs.
        cross_scores = self.cross_encoder.predict([[query, h['text']] for h in hits])
        # 3. Format results.
        results = [
            {
                "score": score,
                "doc_id": meta["doc_id"],
                # NOTE(review): 'source' is not in the documented chunk schema
                # ({'doc_id','text'}) — assumes ingestion adds it; confirm.
                "source": meta["source"],
                "snippet": meta["text"],
            }
            for meta, score in zip(hits, cross_scores)
        ]
        # Sort by cross-encoder score, best first.
        return sorted(results, key=lambda x: x['score'], reverse=True)[:top_k]