from typing import List, Sequence

import faiss
import numpy as np
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer


class ApexRetriever:
    """Two-stage text retriever: bi-encoder recall, then cross-encoder rerank.

    Documents are embedded with a ``SentenceTransformer`` and stored in a flat
    FAISS inner-product index. Because the embeddings are L2-normalised, the
    inner product is the cosine similarity. At query time the nearest
    ``recall_k`` documents are fetched from the index and re-scored pairwise
    against the query by a ``CrossEncoder``.
    """

    def __init__(self, model_dir: str = ".") -> None:
        """Load both models from sub-directories of *model_dir*.

        Args:
            model_dir: Directory containing ``bi_encoder`` and ``reranker``
                model folders.
        """
        # Resolve the device once so both models always land on the same one.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.bi = SentenceTransformer(f"{model_dir}/bi_encoder", device=device)
        self.reranker = CrossEncoder(f"{model_dir}/reranker", device=device)
        self._index = None  # FAISS index; None until index_documents() runs
        self._documents: List[str] = []  # row i of the index <-> _documents[i]

    def index_documents(self, documents: Sequence[str]) -> None:
        """Embed *documents* and replace the current index with them.

        Args:
            documents: Texts to make searchable. A private copy is kept, so
                later changes to the caller's sequence do not affect lookups.

        Raises:
            ValueError: If *documents* is empty.
        """
        # Copy: FAISS returns row ids, which must keep matching our own list
        # even if the caller mutates theirs afterwards.
        docs = list(documents)
        if not docs:
            raise ValueError("documents must contain at least one entry.")

        emb = self.bi.encode(
            docs, normalize_embeddings=True, show_progress_bar=False
        )
        emb = np.ascontiguousarray(emb, dtype=np.float32)

        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)

        # Commit both attributes together, only after everything succeeded,
        # so a failure above never leaves the index and documents out of sync.
        self._index, self._documents = index, docs

    def retrieve(
        self, query: str, top_k: int = 5, recall_k: int = 100
    ) -> List[str]:
        """Return the documents most relevant to *query*, best first.

        Args:
            query: The search text.
            top_k: Maximum number of documents to return.
            recall_k: Number of nearest neighbours pulled from the index and
                passed to the reranker. Only these candidates can be returned,
                so the result never has more than ``recall_k`` entries.

        Returns:
            Up to ``top_k`` documents ordered by descending reranker score.
            Candidates with equal scores keep the order the index returned
            them in. An empty list is returned when ``top_k`` or ``recall_k``
            is not positive.

        Raises:
            ValueError: If no documents have been indexed yet.
        """
        if self._index is None:
            raise ValueError("Index documents first.")
        if top_k <= 0 or recall_k <= 0:
            return []

        q_emb = self.bi.encode(query, normalize_embeddings=True)
        q_emb = np.ascontiguousarray(q_emb, dtype=np.float32)

        # Never ask the index for more neighbours than it holds.
        k = min(recall_k, len(self._documents))
        _, indices = self._index.search(np.expand_dims(q_emb, 0), k)

        # FAISS pads missing neighbours with -1; skip those so they cannot
        # alias the last document through negative indexing.
        candidates = [self._documents[i] for i in indices[0] if i >= 0]
        if not candidates:
            return []

        scores = self.reranker.predict([(query, doc) for doc in candidates])

        # Sort on the score alone: comparing whole (score, doc) tuples would
        # fall back to comparing documents whenever two scores tie.
        ranked = sorted(
            zip(scores, candidates), key=lambda pair: pair[0], reverse=True
        )
        return [doc for _, doc in ranked[:top_k]]