# policy-analysis/utils/retrieve_n_rerank.py
# (Hugging Face page residue preserved: uploaded by "kaburia", commit 4a2e120 "utils path")
# load the encoded text and vectorstore
from utils.encoding_input import encode_text
from utils.loading_embeddings import get_vectorstore
from sentence_transformers import CrossEncoder
import numpy as np
import faiss
def search_vectorstore(encoded_text, vectorstore, k=5, with_score=False):
    """
    Vector similarity search with optional distance/score return.

    Args:
        encoded_text (np.ndarray | list): 1-D query vector.
        vectorstore (langchain.vectorstores.faiss.FAISS): store to search.
        k (int): number of neighbors to retrieve.
        with_score (bool): when True, return (doc, distance) tuples.

    Returns:
        list: docs, or (doc, distance) tuples when with_score is True.
        May contain fewer than k entries if the index is small.
    """
    q = np.asarray(encoded_text, dtype="float32").reshape(1, -1)
    # ---- Use raw FAISS for full control and consistent behavior -----
    index = vectorstore.index  # faiss.Index
    distances, idxs = index.search(q, k)  # each of shape (1, k)
    # FAISS pads the result with index -1 when the store holds fewer
    # than k vectors; the original code crashed on those slots with a
    # KeyError in index_to_docstore_id. Skip them instead.
    results = []
    for dist, i in zip(distances[0], idxs[0]):
        if i == -1:
            continue
        doc = vectorstore.docstore.search(vectorstore.index_to_docstore_id[int(i)])
        results.append((doc, dist))
    # Return with or without scores
    return results if with_score else [doc for doc, _ in results]
# Per-process cache so repeated calls do not reload model weights from disk.
_CROSS_ENCODER_CACHE = {}

def rerank_cross_encoder(query_text, docs, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", top_m=20, min_score=None):
    """
    Rerank docs against query_text with a sentence-transformers cross-encoder.

    Args:
        query_text (str): the user query.
        docs (list): Document objects; each must expose .page_content.
        model_name (str): cross-encoder checkpoint to score pairs with.
        top_m (int): maximum number of results to return.
        min_score (float | None): if set, drop results scoring below it.

    Returns:
        list[tuple]: up to top_m (Document, score) tuples, best first.
    """
    # Nothing to rank; also avoids calling predict() on an empty batch.
    if not docs:
        return []
    # Reuse a previously loaded model instead of re-instantiating per call.
    ce = _CROSS_ENCODER_CACHE.get(model_name)
    if ce is None:
        ce = _CROSS_ENCODER_CACHE.setdefault(model_name, CrossEncoder(model_name))
    # Create pairs of (query_text, document_content)
    pairs = [(query_text, doc.page_content) for doc in docs]
    scores = ce.predict(pairs)  # higher is better
    # Pair original documents with their scores and sort best-first
    scored_documents = sorted(zip(docs, scores.tolist()), key=lambda x: x[1], reverse=True)
    # Apply minimum score filter if specified
    if min_score is not None:
        scored_documents = [r for r in scored_documents if r[1] >= min_score]
    # Return the top_m reranked (Document, score) tuples
    return scored_documents[:top_m]
# retrieval and reranking function
def retrieve_and_rerank(query_text, vectorstore, k=50,
                        rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
                        top_m=20, min_score=None,
                        only_docs=True):
    """
    Two-stage retrieval: vector search for k candidates, then
    cross-encoder reranking of the candidates.

    Args:
        query_text (str): the user query.
        vectorstore: FAISS-backed store accepted by search_vectorstore.
        k (int): number of first-stage candidates to retrieve.
        rerank_model (str): cross-encoder checkpoint for reranking.
        top_m (int): maximum number of reranked results to return.
        min_score (float | None): drop reranked results scoring below it.
        only_docs (bool): when True return Documents only, else
            (Document, score) tuples.

    Returns:
        list: reranked Documents (or tuples); empty list on no hits.
    """
    # Step 1: Encode the query text
    encoded_query = encode_text(query_text)
    # Step 2: Retrieve relevant documents from the vectorstore
    retrieved = search_vectorstore(encoded_query, vectorstore, k=k)
    # Check for emptiness BEFORE peeking at retrieved[0]; the original
    # indexed first and raised IndexError when nothing was retrieved.
    if not retrieved:
        return []
    # search_vectorstore may return (doc, score) tuples; keep docs only
    if isinstance(retrieved[0], tuple):
        retrieved = [doc for doc, _ in retrieved]
    # Step 3: Rerank the retrieved documents
    reranked_docs = rerank_cross_encoder(query_text, retrieved, model_name=rerank_model, top_m=top_m, min_score=min_score)
    # If only_docs is True, return just the documents
    if only_docs:
        return [doc for doc, _ in reranked_docs]
    # Otherwise, return the reranked documents with their scores
    return reranked_docs