Spaces:
Sleeping
Sleeping
import os
import pickle

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder

# On-disk persistence: the FAISS vector index plus a pickled metadata side-car.
INDEX_FILE = "navy_index.faiss"
META_FILE = "navy_metadata.pkl"  # fast id -> chunk mapping; could also query SQL
class SearchEngine:
    """Two-stage semantic search over chunked documents.

    A bi-encoder (MiniLM, 384-dim) embeds text for fast FAISS candidate
    retrieval; a cross-encoder then re-ranks the candidates for precision.
    Index and metadata are persisted to INDEX_FILE / META_FILE.
    """

    def __init__(self):
        # Force CPU so the engine runs on machines without a GPU.
        self.bi_encoder = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device="cpu")
        self.index = None
        self.metadata = []  # list of dicts: {'doc_id': ..., 'text': ...}
        self.load_index()

    def load_index(self):
        """Load the persisted index and metadata, or start with an empty index."""
        if os.path.exists(INDEX_FILE) and os.path.exists(META_FILE):
            try:
                self.index = faiss.read_index(INDEX_FILE)
                with open(META_FILE, "rb") as f:
                    self.metadata = pickle.load(f)
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Catch only plausible load failures
            # (corrupt or truncated files) and fall back to a fresh index.
            except (OSError, EOFError, pickle.UnpicklingError, RuntimeError):
                self.reset_index()
        else:
            self.reset_index()

    def reset_index(self):
        """Create an empty inner-product index (384 = MiniLM embedding size)."""
        # IndexIDMap lets us address vectors by our own metadata positions.
        self.index = faiss.IndexIDMap(faiss.IndexFlatIP(384))
        self.metadata = []

    def add_features(self, chunks):
        """Embed chunks, add them to FAISS, and persist.

        chunks: list of dicts like {'text': ..., 'doc_id': ...}.
        """
        # Guard: encode()/normalize_L2 fail on empty input.
        if not chunks:
            return
        texts = [c["text"] for c in chunks]
        embeddings = self.bi_encoder.encode(texts)
        # L2-normalize so inner product equals cosine similarity.
        faiss.normalize_L2(embeddings)
        # Vector ids continue from the current metadata length, so an id is
        # also the chunk's position in self.metadata.
        start_id = len(self.metadata)
        ids = np.arange(start_id, start_id + len(chunks)).astype('int64')
        self.index.add_with_ids(embeddings, ids)
        self.metadata.extend(chunks)
        self.save()

    def save(self):
        """Persist the FAISS index and the metadata side-car to disk."""
        faiss.write_index(self.index, INDEX_FILE)
        with open(META_FILE, "wb") as f:
            pickle.dump(self.metadata, f)

    def search(self, query, top_k=5):
        """Return up to top_k results re-ranked by the cross-encoder.

        Each result: {'score', 'doc_id', 'source', 'snippet'},
        sorted by cross-encoder score, best first.
        """
        if not self.index or self.index.ntotal == 0:
            return []
        q_vec = self.bi_encoder.encode([query])
        faiss.normalize_L2(q_vec)

        # 1. Retrieve candidates (over-fetch so the re-ranker has room to work).
        _scores, indices = self.index.search(q_vec, min(self.index.ntotal, top_k * 10))

        # BUG FIX: the original indexed cross_scores by position in indices[0]
        # while the candidate list skipped -1 padding entries, so scores could
        # misalign with results. Build hits once and keep everything in lockstep.
        hits = [self.metadata[idx] for idx in indices[0] if idx != -1]
        if not hits:
            return []

        # 2. Re-rank with the cross-encoder on (query, passage) pairs.
        pairs = [[query, item['text']] for item in hits]
        cross_scores = self.cross_encoder.predict(pairs)

        # 3. Format results; zip guarantees score/metadata alignment.
        results = [
            {
                "score": score,
                "doc_id": meta["doc_id"],
                # add_features' documented schema only guarantees doc_id/text;
                # use .get so a chunk without 'source' doesn't raise KeyError.
                "source": meta.get("source"),
                "snippet": meta["text"],
            }
            for meta, score in zip(hits, cross_scores)
        ]

        # Sort by cross-encoder score, best first, and truncate to top_k.
        return sorted(results, key=lambda x: x['score'], reverse=True)[:top_k]