# ApexRetriever / pipeline.py
# Author: QuantaSparkLabs — uploaded with huggingface_hub (commit 2d9ebee, verified).
import torch, numpy as np, faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
class ApexRetriever:
    """Two-stage retriever: dense bi-encoder recall, then cross-encoder reranking.

    Stage 1 embeds documents and queries with a SentenceTransformer bi-encoder
    using ``normalize_embeddings=True``, so the inner product computed by the
    exact FAISS ``IndexFlatIP`` equals cosine similarity. The nearest
    candidates are pulled from that index. Stage 2 scores each
    (query, candidate) pair with a CrossEncoder and keeps the best ``top_k``.
    """

    def __init__(self, model_dir="."):
        """Load both models from ``model_dir``.

        Args:
            model_dir: Directory containing a ``bi_encoder`` and a ``reranker``
                sub-directory. Both models are placed on CUDA when
                ``torch.cuda.is_available()``, otherwise on CPU.
        """
        # Decide the device once so both models always land on the same one.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.bi = SentenceTransformer(f"{model_dir}/bi_encoder", device=device)
        self.reranker = CrossEncoder(f"{model_dir}/reranker", device=device)
        self._index, self._documents = None, []

    def index_documents(self, documents):
        """Embed ``documents`` and build a fresh FAISS index over them.

        Any previously built index is replaced.

        Args:
            documents: Iterable of document strings. A copy is stored, so later
                mutation of the caller's sequence cannot desynchronise the
                stored documents from the index.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        docs = list(documents)
        if not docs:
            # An empty batch gives no embedding dimension to size the index with.
            raise ValueError("Cannot index an empty document collection.")
        emb = self.bi.encode(docs, normalize_embeddings=True, show_progress_bar=False)
        # FAISS expects C-contiguous float32 input.
        emb = np.ascontiguousarray(emb, dtype="float32")
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)
        # Commit state only after encoding and indexing succeeded. A failure
        # then leaves the previous index and documents consistent.
        self._index, self._documents = index, docs

    def retrieve(self, query, top_k=5, recall_k=100):
        """Return up to ``top_k`` documents most relevant to ``query``.

        Args:
            query: Query string.
            top_k: Number of documents to return after reranking. Values below
                1 yield an empty list.
            recall_k: Number of nearest neighbours fetched from the index before
                reranking. It is raised to at least ``top_k``, because the
                reranker can only return what was recalled. It is also capped
                at the number of indexed documents.

        Returns:
            Documents in descending cross-encoder score order. Ties keep their
            bi-encoder (recall) order.

        Raises:
            ValueError: If ``index_documents`` has not been called yet.
        """
        if self._index is None:
            raise ValueError("Index documents first.")
        if top_k < 1:
            return []
        k = min(max(recall_k, top_k), len(self._documents))
        q_emb = self.bi.encode(query, normalize_embeddings=True, show_progress_bar=False)
        q_emb = np.ascontiguousarray(q_emb, dtype="float32").reshape(1, -1)
        _, indices = self._index.search(q_emb, k)
        # FAISS marks missing results with index -1; never treat that as a
        # valid list position.
        candidates = [self._documents[i] for i in indices[0] if i >= 0]
        if not candidates:
            return []
        scores = self.reranker.predict([(query, d) for d in candidates])
        # Sort positions by score alone. Tuple sorting would break ties by
        # comparing the documents themselves, which is arbitrary. `sorted`
        # stays stable with reverse=True, so ties keep their recall order.
        order = sorted(range(len(candidates)), key=lambda j: scores[j], reverse=True)
        return [candidates[j] for j in order[:top_k]]