# BioRAG — src/bio_rag/retriever.py
# Deployed to Hugging Face Spaces by aseelflihan (commit 2a2c039, "Deploy Bio-RAG").
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from rank_bm25 import BM25Okapi
@dataclass
class RetrievedPassage:
rank: int
score: float
qid: str
text: str
source_question: str
source_answer: str
authors: str = ""
year: str = ""
journal: str = ""
title: str = ""
class BioRetriever:
    """Hybrid sparse + dense retriever using Reciprocal Rank Fusion (RRF).

    Each query (or query variant) is run against both a BM25 index and a
    FAISS vector store; every document then accumulates ``1 / (rank + rrf_k)``
    per list it appears in, and the fused ranking is returned.
    """

    def __init__(
        self,
        vectorstore: FAISS,
        top_k: int = 10,
        rrf_k: int = 60,
        max_l2_distance: float = 250.0,
    ) -> None:
        """Build the retriever and its BM25 index.

        Args:
            vectorstore: FAISS store whose docstore holds every passage.
            top_k: candidates taken from each retriever per query, and the
                number of fused passages ultimately returned.
            rrf_k: RRF smoothing constant (60 is the conventional default;
                larger values flatten the influence of rank differences).
            max_l2_distance: dense hits with an L2 distance at or above this
                threshold are discarded as too dissimilar.
        """
        self.vectorstore = vectorstore
        self.top_k = top_k
        self.rrf_k = rrf_k
        self.max_l2_distance = max_l2_distance
        # Build the BM25 index once at construction time.
        # NOTE(review): this reads FAISS's private docstore dict — there is no
        # public "all documents" API, but this may break on a langchain upgrade.
        self._docs = list(self.vectorstore.docstore._dict.values())
        corpus = [doc.page_content.lower().split() for doc in self._docs]
        self.bm25 = BM25Okapi(corpus)

    @staticmethod
    def _doc_key(doc: Document) -> str:
        """Fusion key: qid plus a text prefix, so chunks of one qid stay distinct."""
        return doc.metadata.get("qid", "") + "_" + doc.page_content[:50]

    def retrieve(self, query_or_queries: str | List[str]) -> list[RetrievedPassage]:
        """Retrieve the fused top-k passages for a query or its expanded variants.

        Args:
            query_or_queries: a single query string, or a list of query
                variants whose result lists are all fused together.

        Returns:
            At most ``top_k`` RetrievedPassage objects, best (highest RRF
            score) first, with rank starting at 1.
        """
        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries

        # Accumulated RRF score and latest Document seen, both keyed by _doc_key.
        rrf_scores: Dict[str, float] = defaultdict(float)
        doc_store: Dict[str, Document] = {}

        for query in queries:
            # 1. Sparse retrieval (BM25) — rank all docs, keep the top_k.
            bm25_scores = self.bm25.get_scores(query.lower().split())
            sparse_top = sorted(
                range(len(bm25_scores)),
                key=lambda i: bm25_scores[i],
                reverse=True,
            )[: self.top_k]
            for rank, idx in enumerate(sparse_top, start=1):
                doc = self._docs[idx]
                key = self._doc_key(doc)
                doc_store[key] = doc
                rrf_scores[key] += 1.0 / (rank + self.rrf_k)

            # 2. Dense retrieval (FAISS). Scores are L2 distances (lower is
            # better); drop anything at or beyond the dissimilarity cutoff.
            dense_hits = self.vectorstore.similarity_search_with_score(query, k=self.top_k)
            dense_hits = [(doc, score) for doc, score in dense_hits if score < self.max_l2_distance]
            for rank, (doc, _score) in enumerate(dense_hits, start=1):
                key = self._doc_key(doc)
                doc_store[key] = doc
                rrf_scores[key] += 1.0 / (rank + self.rrf_k)

        # Fuse: order by accumulated RRF score, best first.
        ranked = sorted(rrf_scores.items(), key=lambda item: item[1], reverse=True)

        passages: list[RetrievedPassage] = []
        for position, (key, fused_score) in enumerate(ranked[: self.top_k], start=1):
            doc = doc_store[key]
            meta = doc.metadata
            passages.append(
                RetrievedPassage(
                    rank=position,
                    score=float(fused_score),
                    qid=str(meta.get("qid", "")),
                    text=doc.page_content,
                    source_question=str(meta.get("question", "")),
                    source_answer=str(meta.get("answer", "")),
                    # Fix: these citation fields exist on RetrievedPassage but
                    # were never populated. Fill them from metadata when present
                    # (identical "" values when the keys are absent).
                    authors=str(meta.get("authors", "")),
                    year=str(meta.get("year", "")),
                    journal=str(meta.get("journal", "")),
                    title=str(meta.get("title", "")),
                )
            )
        return passages