ParentHackRx / retrieval_parent.py
PercivalFletcher's picture
Update retrieval_parent.py
f34b0f8 verified
# file: retrieval.py
import time
import asyncio
import numpy as np
import torch
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple
from langchain.storage import InMemoryStore
from embedding import EmbeddingClient
from langchain_core.documents import Document
# --- Configuration ---
# Groq chat model that writes the HyDE (hypothetical document) passage.
HYDE_MODEL = "llama-3.1-8b-instant"
# Cross-encoder used to rerank hybrid-search candidates.
# Lighter alternative: "cross-encoder/ms-marco-MiniLM-L6-v2".
RERANKER_MODEL = "BAAI/bge-reranker-base"
# How many child chunks survive the hybrid (BM25 + dense) search stage.
INITIAL_K_CANDIDATES = 20
# How many child chunks survive cross-encoder reranking.
TOP_K_CHUNKS = 10
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a HyDE (hypothetical document) passage that answers *query*.

    The Groq chat model ``HYDE_MODEL`` is asked for a short, formal passage
    written as if taken from a larger document. The caller can embed it
    together with the query to strengthen dense retrieval.

    Args:
        query: The user question.
        groq_api_key: Groq API key. When falsy, generation is skipped.

    Returns:
        The generated passage. Returns ``""`` when the key is missing, the API
        call fails, or the model returns no content. Generation is
        best-effort: errors are printed rather than raised so retrieval can
        continue without HyDE.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""
    print(f"Starting HyDE generation for query: '{query}'...")
    prompt = (
        "Write a brief, formal passage that answers the following question. "
        "Use specific terminology as if it were from a larger document. "
        "Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        "Hypothetical Passage:"
    )
    try:
        # Use the client as an async context manager so its HTTP connection
        # pool is closed right after the call instead of lingering until GC.
        async with AsyncGroq(api_key=groq_api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=HYDE_MODEL,
                temperature=0.7,
                max_tokens=500,
            )
        # The SDK types `content` as Optional[str]; keep the `str` contract.
        return chat_completion.choices[0].message.content or ""
    except Exception as e:
        # Deliberate best-effort boundary: HyDE is optional, so log and fall back.
        print(f"An error occurred during HyDE generation: {e}")
        return ""
class Retriever:
    """Hybrid (BM25 + dense) search over child chunks with parent-document lookup.

    Child chunks are scored and reranked. The final result is the set of parent
    documents those children reference through ``metadata["parent_id"]``,
    fetched from the docstore passed to :meth:`index`.
    """

    def __init__(self, embedding_client: EmbeddingClient):
        """Store the embedding client and load the reranker on its device."""
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None  # BM25Okapi over tokenized child chunks; built by index()
        self.document_chunks = []  # child Documents, positionally aligned with both indexes
        self.chunk_embeddings = None  # dense embeddings of the child chunks; built by index()
        self.docstore = InMemoryStore()  # parent documents keyed by parent id; replaced by index()
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Lower-case *text* and split it on runs of whitespace for BM25.

        Indexing and querying both use this helper, so they always tokenize the
        same way. ``str.split()`` (unlike ``split(" ")``) also breaks on newlines
        and tabs and never produces empty tokens.
        """
        return text.lower().split()

    def index(self, child_documents: List[Document], docstore: InMemoryStore):
        """Build the BM25 and dense indexes over *child_documents*.

        Args:
            child_documents: Child chunks to search. :meth:`retrieve` maps each
                one back to its parent through ``metadata["parent_id"]``.
            docstore: Store holding the parent documents, keyed by parent id.
                It is kept by reference, not copied.

        Re-indexing replaces any previous index. An empty *child_documents*
        leaves the retriever un-indexed, so a later :meth:`retrieve` raises
        ``ValueError`` rather than returning indices into a stale corpus.
        """
        self.document_chunks = child_documents  # positional map: index -> child Document
        self.docstore = docstore
        # Clear the previous index first. Otherwise an empty re-index would leave
        # the old BM25/embeddings pointing at chunks no longer in document_chunks.
        self.bm25 = None
        self.chunk_embeddings = None
        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return
        print("Indexing child documents for retrieval...")
        # A whitespace-only chunk would yield no tokens. If every chunk did,
        # BM25Okapi's idf averaging divides by zero. The "" placeholder never
        # matches a query token, so it keeps each document non-empty harmlessly.
        tokenized_corpus = [self._tokenize(text) or [""] for text in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Score every child chunk with a 50/50 blend of BM25 and dense similarity.

        BM25 uses only the raw *query*. The dense side embeds the query with the
        HyDE passage appended when one is given. Both score vectors are min-max
        scaled before blending, so neither dominates because of its raw scale.

        Returns:
            Up to ``INITIAL_K_CANDIDATES`` ``(chunk_index, combined_score)``
            pairs, best first, as plain Python ``int``/``float``.

        Raises:
            ValueError: If :meth:`index` has not built an index.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        bm25_scores = self.bm25.get_scores(self._tokenize(query))
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        # Presumably a (1, dim) query tensor against (n_chunks, dim) chunk
        # tensors, broadcast to one similarity per chunk. TODO: confirm shapes.
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(int(idx), float(combined_scores[idx])) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Rerank *candidates* with the cross-encoder and keep the top ``TOP_K_CHUNKS``.

        Each candidate dict needs a ``"content"`` key. Every candidate gets a
        float ``"rerank_score"``. *candidates* is sorted in place, highest score
        first, before it is truncated.
        """
        if not candidates:
            return []
        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]
        # predict() is blocking; run it in a worker thread to keep the event loop responsive.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        for candidate, score in zip(candidates, rerank_scores):
            candidate["rerank_score"] = float(score)
        candidates.sort(key=lambda c: c["rerank_score"], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str = "") -> List[Dict]:
        """Hybrid-search and rerank child chunks, then return their parent documents.

        Args:
            query: The user question.
            hyde_doc: Optional HyDE passage that enriches the dense query.

        Returns:
            One ``{"content", "metadata"}`` dict per distinct parent, ordered by
            the rank of that parent's best-scoring child. Children without a
            ``parent_id`` are skipped, as are parents missing from the docstore.

        Raises:
            ValueError: If the retriever has not been indexed.
        """
        print(f"Retrieving documents for query: '{query}'")
        # 1. Hybrid search ranks CHILD chunks by index.
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        retrieved_child_docs = [
            {
                "content": self.document_chunks[idx].page_content,
                "metadata": self.document_chunks[idx].metadata,
            }
            for idx, _score in initial_candidates_info
        ]
        # 2. Rerank the CHILD chunks with the cross-encoder.
        reranked_child_docs = await self._rerank(query, retrieved_child_docs)
        # 3. Distinct parent ids in rank order (dict keys preserve insertion order).
        ordered_ids = dict.fromkeys(
            doc["metadata"].get("parent_id") for doc in reranked_child_docs
        )
        parent_ids = [pid for pid in ordered_ids if pid is not None]
        # 4. Fetch the full PARENT documents; mget yields None for unknown ids.
        retrieved_parents = self.docstore.mget(parent_ids)
        # 5. Format for the generation step.
        final_context = [
            {"content": doc.page_content, "metadata": doc.metadata}
            for doc in retrieved_parents
            if doc is not None
        ]
        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context