# ApexRetriever-Pro / pipeline.py
# Uploaded by QuantaSparkLabs using huggingface_hub (commit 029ba3a, verified).
import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
from transformers import pipeline
# =========================================================
# EMBEDDING HELPERS
# =========================================================
def embed_texts(model, texts):
    """Encode a batch of texts into normalised embeddings.

    Every text is prefixed with the retrieval instruction string before the
    whole batch is handed to ``model.encode``. Normalisation and NumPy output
    are requested, and the progress bar is suppressed.

    Args:
        model: Encoder exposing a SentenceTransformer-style ``encode`` method.
        texts: Iterable of texts to embed.

    Returns:
        Whatever ``model.encode`` returns for the prefixed batch.
    """
    # NOTE(review): this is the instruction BGE-style models document for
    # *queries*; passages are usually embedded without it. Confirm that the
    # bi-encoder really expects it on documents too.
    instruction = "Represent this sentence for searching relevant passages: "
    prefixed = [f"{instruction}{text}" for text in texts]
    return model.encode(
        prefixed,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
def embed_query(model, query):
    """Encode a single query string into a normalised embedding.

    The query gets the same retrieval instruction prefix that ``embed_texts``
    applies. Like ``embed_texts``, this disables the encoder's progress bar;
    otherwise the library's default display logic would run on every query.

    Args:
        model: Encoder exposing a SentenceTransformer-style ``encode`` method.
        query: Query text.

    Returns:
        ``model.encode`` output for the single prefixed string (NumPy output
        is requested via ``convert_to_numpy=True``).
    """
    return model.encode(
        f"Represent this sentence for searching relevant passages: {query}",
        normalize_embeddings=True,
        convert_to_numpy=True,
        show_progress_bar=False,
    )
# =========================================================
# MMR DIVERSITY FILTER
# =========================================================
def mmr(
    query_embedding,
    candidate_embeddings,
    candidate_docs,
    top_k=10,
    lambda_param=0.7
):
    """Select up to ``top_k`` documents by Maximal Marginal Relevance.

    Greedy selection: the first pick is the candidate with the highest dot
    product with the query. Every later pick maximises

        lambda_param * sim(query, d) - (1 - lambda_param) * max_s sim(d, s)

    over unselected candidates ``d``, where ``s`` ranges over the documents
    already selected and ``sim`` is the dot product (cosine similarity when
    embeddings are L2-normalised). Ties go to the lowest index.

    Args:
        query_embedding: 1-D query vector.
        candidate_embeddings: 2-D array whose row ``i`` embeds
            ``candidate_docs[i]``.
        candidate_docs: Sequence of candidate documents.
        top_k: Maximum number of documents to return; ``<= 0`` returns ``[]``.
        lambda_param: Relevance/diversity trade-off (1 = relevance only,
            0 = diversity only).

    Returns:
        List of selected documents, in selection order.

    Raises:
        ValueError: If the number of embeddings differs from the number of
            documents.
    """
    n_candidates = len(candidate_docs)
    k = min(top_k, n_candidates)
    # Also covers the empty pool, where np.argmax would raise.
    if k <= 0:
        return []

    embeddings = np.asarray(candidate_embeddings)
    if len(embeddings) != n_candidates:
        raise ValueError(
            f"got {len(embeddings)} embeddings for {n_candidates} documents"
        )

    relevance = embeddings @ np.asarray(query_embedding)
    # Running max similarity of each candidate to the selected set. It is
    # updated incrementally, so each step costs one matrix-vector product
    # instead of re-scanning every selected document.
    redundancy = np.full(n_candidates, -np.inf)
    available = np.ones(n_candidates, dtype=bool)

    chosen = [int(np.argmax(relevance))]
    available[chosen[0]] = False
    while len(chosen) < k:
        redundancy = np.maximum(redundancy, embeddings @ embeddings[chosen[-1]])
        scores = lambda_param * relevance - (1 - lambda_param) * redundancy
        # Already-selected candidates must never be picked again.
        scores = np.where(available, scores, -np.inf)
        best = int(np.argmax(scores))
        chosen.append(best)
        available[best] = False

    return [candidate_docs[i] for i in chosen]
# =========================================================
# APEX RETRIEVER PRO
# =========================================================
class ApexRetrieverPro:
    """Hybrid retriever: BM25 + dense (FAISS) recall, MMR, cross-encoder rerank.

    Model weights are loaded from sub-directories of ``model_dir``:
    ``bi_encoder`` (SentenceTransformer), ``reranker`` (CrossEncoder) and
    ``flan_t5`` (transformers text2text pipeline). The FLAN-T5 ``generator``
    is not used by the methods of this class.
    """

    def __init__(self, model_dir="."):
        # Prefer the GPU whenever torch can see one.
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        self.bi = SentenceTransformer(
            f"{model_dir}/bi_encoder",
            device=self.device
        )
        self.reranker = CrossEncoder(
            f"{model_dir}/reranker",
            device=self.device
        )
        # =====================================================
        # LOAD FLAN-T5
        # =====================================================
        # transformers pipelines take a device ordinal; -1 selects the CPU.
        self.generator = pipeline(
            "text2text-generation",
            model=f"{model_dir}/flan_t5",
            device=0 if self.device == "cuda" else -1
        )
        print("✅ FLAN-T5 loaded")
        # Populated by index_documents().
        self.documents = []
        self.bm25 = None
        self.index = None
        self.embeddings = None
        print("✅ ApexRetrieverPro loaded")

    # =====================================================
    # INDEX DOCUMENTS
    # =====================================================
    def index_documents(self, documents):
        """Build the BM25 and FAISS indexes over ``documents``.

        Replaces any previously indexed corpus. Instance state is only
        updated once every index has been built, so a failure (e.g. while
        embedding) leaves the previous corpus usable and consistent.

        Args:
            documents: Iterable of document strings. It is copied into a list,
                so later mutation by the caller cannot desynchronise the
                indexes.

        Raises:
            ValueError: If ``documents`` is empty.
        """
        documents = list(documents)
        if not documents:
            raise ValueError("index_documents() needs at least one document.")

        # Same lowercase/whitespace tokenisation as the query side of retrieve().
        tokenized = [doc.lower().split() for doc in documents]
        bm25 = BM25Okapi(tokenized)

        # FAISS requires float32; inner product on normalised vectors = cosine.
        embeddings = embed_texts(self.bi, documents).astype("float32")
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)

        self.documents = documents
        self.bm25 = bm25
        self.embeddings = embeddings
        self.index = index
        print(f"📚 Indexed {len(documents)} documents")

    # =====================================================
    # RETRIEVE
    # =====================================================
    def retrieve(
        self,
        query,
        top_k=5,
        bm25_k=25,
        dense_k=25,
        mmr_k=15,
        mmr_lambda=0.7,
    ):
        """Return the ``top_k`` most relevant indexed documents for ``query``.

        Pipeline:
        1. BM25 and dense recall are merged into one de-duplicated pool.
        2. The pool is diversified with MMR.
        3. The result is reranked with the cross-encoder.

        Args:
            query: Query string.
            top_k: Number of documents to return; ``<= 0`` returns ``[]``.
            bm25_k: Maximum number of lexical (BM25) candidates.
            dense_k: Maximum number of dense (FAISS) candidates.
            mmr_k: Size of the MMR-diversified pool given to the reranker.
                It is raised to ``top_k`` when smaller, so the result can
                always fill ``top_k``.
            mmr_lambda: MMR relevance/diversity trade-off (1 = relevance only).

        Returns:
            List of document strings, best first, of length ``<= top_k``.

        Raises:
            ValueError: If ``index_documents()`` has not been called.
        """
        if self.bm25 is None:
            raise ValueError(
                "Call index_documents() first."
            )
        if top_k <= 0:
            return []

        # -------------------------------------------------
        # STAGE 1 — BM25
        # -------------------------------------------------
        bm25_scores = np.asarray(self.bm25.get_scores(query.lower().split()))
        # A score of exactly 0 means the document adds no lexical evidence
        # (typically it shares no term with the query). argsort orders such
        # ties arbitrarily, and MMR would then favour them as "diverse", so
        # they are excluded before taking the top bm25_k.
        matched = np.flatnonzero(bm25_scores != 0)
        bm25_ids = matched[np.argsort(bm25_scores[matched])[::-1][:bm25_k]]

        # -------------------------------------------------
        # STAGE 2 — DENSE RETRIEVAL
        # -------------------------------------------------
        q_emb = embed_query(
            self.bi,
            query
        ).astype("float32").reshape(1, -1)
        n_dense = min(dense_k, len(self.documents))
        if n_dense > 0:
            _, dense_hits = self.index.search(q_emb, n_dense)
            # FAISS pads missing results with -1; drop them defensively.
            dense_ids = dense_hits[0][dense_hits[0] >= 0]
        else:
            dense_ids = np.empty(0, dtype=np.int64)

        # -------------------------------------------------
        # HYBRID MERGE
        # -------------------------------------------------
        # dict keeps first-seen order: BM25 hits first, then new dense hits.
        candidate_ids = list(
            dict.fromkeys(int(i) for i in (*bm25_ids, *dense_ids))
        )
        if not candidate_ids:
            return []
        candidates = [self.documents[i] for i in candidate_ids]
        candidate_embeddings = self.embeddings[candidate_ids]

        # -------------------------------------------------
        # STAGE 3 — MMR DIVERSITY
        # -------------------------------------------------
        candidates = mmr(
            q_emb.flatten(),
            candidate_embeddings,
            candidates,
            top_k=max(mmr_k, top_k),
            lambda_param=mmr_lambda,
        )

        # -------------------------------------------------
        # STAGE 4 — CROSS ENCODER
        # -------------------------------------------------
        scores = self.reranker.predict(
            [(query, doc) for doc in candidates],
            show_progress_bar=False
        )
        # Sort on the score alone. sorted() is stable, so equal scores keep
        # MMR order instead of falling back to comparing document strings.
        order = sorted(
            range(len(candidates)),
            key=lambda i: scores[i],
            reverse=True,
        )

        # -------------------------------------------------
        # STAGE 5 — FINAL RETURN
        # -------------------------------------------------
        return [candidates[i] for i in order[:top_k]]