Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import List, Dict | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.documents import Document | |
| from rank_bm25 import BM25Okapi | |
@dataclass
class RetrievedPassage:
    """One passage produced by hybrid retrieval, with its fused rank and score.

    Bug fix: the ``@dataclass`` decorator was missing even though ``dataclass``
    is imported and the class is constructed with keyword arguments
    (``RetrievedPassage(rank=..., score=..., ...)``), which would raise
    ``TypeError`` without a generated ``__init__``.
    """

    rank: int             # 1-based position after reciprocal-rank fusion
    score: float          # fused RRF score (higher is better)
    qid: str              # question id taken from document metadata
    text: str             # passage content
    source_question: str  # originating question text from metadata
    source_answer: str    # originating answer text from metadata
    # Optional bibliographic fields; BioRetriever.retrieve leaves them at
    # their defaults.
    authors: str = ""
    year: str = ""
    journal: str = ""
    title: str = ""
class BioRetriever:
    """Hybrid retriever: BM25 (sparse) + FAISS (dense), fused with RRF.

    For each query, the top_k candidates from each retriever are combined
    with Reciprocal Rank Fusion and the globally best top_k passages are
    returned, best first.
    """

    # RRF smoothing constant (standard k=60 from Cormack et al.).
    _RRF_K = 60
    # Dense hits at L2 distance >= this are considered too dissimilar.
    _MAX_L2_DISTANCE = 250.0

    def __init__(self, vectorstore: FAISS, top_k: int = 10) -> None:
        """Index the documents already held by *vectorstore* for BM25 search.

        Args:
            vectorstore: FAISS store whose docstore supplies the corpus.
            top_k: candidates taken from each retriever per query, and the
                number of fused passages returned by :meth:`retrieve`.
        """
        self.vectorstore = vectorstore
        self.top_k = top_k
        # NOTE: relies on langchain FAISS's private docstore mapping; there
        # is no public API to enumerate every stored document.
        self._docs = list(self.vectorstore.docstore._dict.values())
        corpus = [doc.page_content.lower().split() for doc in self._docs]
        self.bm25 = BM25Okapi(corpus)

    @staticmethod
    def _doc_id(doc: Document) -> str:
        """Stable per-document key: qid plus a text prefix to keep chunks distinct."""
        return doc.metadata.get("qid", "") + "_" + doc.page_content[:50]

    def retrieve(self, query_or_queries: str | List[str]) -> list[RetrievedPassage]:
        """Retrieve fused top passages for one query or several expanded variants.

        Args:
            query_or_queries: a single query string, or a list of query
                variants whose rankings are all fused together.

        Returns:
            Up to ``top_k`` passages sorted by descending RRF score.
        """
        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries

        rrf_scores: Dict[str, float] = {}
        doc_store: Dict[str, Document] = {}

        def fuse(doc: Document, rank: int) -> None:
            # Accumulate one (query, retriever) ranking's RRF contribution.
            key = self._doc_id(doc)
            doc_store[key] = doc
            rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (rank + self._RRF_K)

        for query in queries:
            # 1. Sparse retrieval (BM25).
            tokenized_query = query.lower().split()
            bm25_scores = self.bm25.get_scores(tokenized_query)
            top_indices = sorted(
                range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True
            )[: self.top_k]
            for rank, idx in enumerate(top_indices, start=1):
                # Fix: skip zero-score hits — they share no terms with the
                # query and previously earned RRF credit merely by padding
                # the top-k list when few documents matched.
                if bm25_scores[idx] <= 0:
                    continue
                fuse(self._docs[idx], rank)

            # 2. Dense retrieval (FAISS), dropping clearly dissimilar hits.
            dense_hits = self.vectorstore.similarity_search_with_score(query, k=self.top_k)
            dense_hits = [(d, s) for d, s in dense_hits if s < self._MAX_L2_DISTANCE]
            for rank, (doc, _score) in enumerate(dense_hits, start=1):
                fuse(doc, rank)

        # Rank by fused RRF score, best first, and materialize passages.
        ranked = sorted(rrf_scores.items(), key=lambda item: item[1], reverse=True)
        passages: list[RetrievedPassage] = []
        for i, (key, score) in enumerate(ranked[: self.top_k], start=1):
            doc = doc_store[key]
            passages.append(
                RetrievedPassage(
                    rank=i,
                    score=float(score),
                    qid=str(doc.metadata.get("qid", "")),
                    text=doc.page_content,
                    source_question=str(doc.metadata.get("question", "")),
                    source_answer=str(doc.metadata.get("answer", "")),
                )
            )
        return passages