# file: retrieval.py
"""Hybrid (BM25 + dense) retrieval with HyDE query expansion, cross-encoder
reranking, and parent-document lookup."""

import asyncio
import time
from typing import Dict, List, Tuple

import numpy as np
import torch
from groq import AsyncGroq
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity

from embedding import EmbeddingClient

# --- Configuration ---
HYDE_MODEL = "llama-3.1-8b-instant"  # Groq model that writes HyDE passages
RERANKER_MODEL = 'BAAI/bge-reranker-base'
# Lighter alternative reranker: 'cross-encoder/ms-marco-MiniLM-L6-v2'
INITIAL_K_CANDIDATES = 20  # child chunks kept after hybrid search (reranker input)
TOP_K_CHUNKS = 10  # child chunks kept after reranking


async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a HyDE (hypothetical document) passage for ``query``.

    Asks a Groq-hosted LLM to write a short passage that answers the query.
    The passage is appended to the query for the dense-embedding step of
    ``Retriever._hybrid_search``.

    Args:
        query: The user question.
        groq_api_key: Groq API key. When empty, generation is skipped.

    Returns:
        The generated passage, or ``""`` when the key is missing, the API call
        fails, or the model returns no content. ``Retriever`` treats ``""`` as
        "no HyDE passage" and falls back to the bare query.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )
    try:
        # A new client is created per call; the context manager closes it (and
        # its HTTP connection pool) once the request completes.
        async with AsyncGroq(api_key=groq_api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=HYDE_MODEL,
                temperature=0.7,
                max_tokens=500,
            )
        # `content` can be None; normalize so the declared `str` return holds.
        return chat_completion.choices[0].message.content or ""
    except Exception as e:
        # Best-effort: HyDE is an optional enhancement, so any failure falls
        # back to plain-query retrieval rather than aborting the request.
        print(f"An error occurred during HyDE generation: {e}")
        return ""


def _tokenize(text: str) -> List[str]:
    """Lowercase and whitespace-split ``text`` into BM25 tokens.

    Shared by indexing and querying so both sides tokenize identically.
    ``str.split()`` with no argument splits on any whitespace run (spaces,
    newlines, tabs) and never produces empty tokens.
    """
    return text.lower().split()


class Retriever:
    """Hybrid search over child chunks with parent-document retrieval.

    Child chunks are scored with a 50/50 blend of BM25 and dense cosine
    similarity, reranked with a cross-encoder, and then mapped to their parent
    documents through ``metadata["parent_id"]`` and a docstore.
    """

    def __init__(self, embedding_client: EmbeddingClient):
        """Load the reranker; the search index is built later by index()."""
        self.embedding_client = embedding_client
        # Run the reranker on the same device as the embedding model.
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None  # BM25Okapi over child chunks, set by index()
        self.document_chunks: List[Document] = []
        self.chunk_embeddings = None  # dense embeddings of child chunks, set by index()
        # Placeholder; index() replaces it with the store holding parent docs.
        self.docstore = InMemoryStore()
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, child_documents: List[Document], docstore: InMemoryStore):
        """Build the BM25 and dense indexes over ``child_documents``.

        Args:
            child_documents: Child chunks to search. retrieve() reads
                ``metadata["parent_id"]`` from each one it returns.
            docstore: Store mapping parent IDs to parent ``Document`` objects.
        """
        self.document_chunks = child_documents
        self.docstore = docstore
        # Drop any previous index first, so stale BM25 scores / embeddings can
        # never be paired with the new (possibly empty) chunk list.
        self.bm25 = None
        self.chunk_embeddings = None

        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing child documents for retrieval...")
        self.bm25 = BM25Okapi([_tokenize(text) for text in corpus])
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Score all child chunks with a 50/50 blend of BM25 and cosine similarity.

        BM25 uses the bare query; the dense side embeds the query with the
        HyDE passage appended when one is given. Each score vector is
        min-max scaled (MinMaxScaler default range [0, 1]) before blending.

        Returns:
            Up to INITIAL_K_CANDIDATES ``(chunk_index, combined_score)`` pairs,
            highest score first.

        Raises:
            ValueError: If index() has not built an index.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query

        bm25_scores = self.bm25.get_scores(_tokenize(query))

        # NOTE(review): assumes create_embeddings returns a 2-D torch tensor
        # (one row per text), so this yields one similarity per chunk.
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(int(idx), float(combined_scores[idx])) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Rerank ``candidates`` with the cross-encoder and keep TOP_K_CHUNKS.

        Each candidate dict must have a ``"content"`` key; a ``"rerank_score"``
        key is added to it in place. The blocking ``predict`` call runs in a
        worker thread (asyncio.to_thread) so it does not stall the event loop.
        """
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score

        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Run hybrid search and reranking on child chunks, then return parents.

        Args:
            query: The user question.
            hyde_doc: Optional HyDE passage (``""`` to disable).

        Returns:
            One ``{"content", "metadata"}`` dict per distinct parent, ordered
            by the rank of that parent's best-ranked child. Parent IDs absent
            from the docstore are skipped.

        Raises:
            ValueError: If index() has not built an index.
            KeyError: If a reranked child has no ``metadata["parent_id"]``.
        """
        print(f"Retrieving documents for query: '{query}'")

        # 1. Hybrid search returns indices of the best CHILD chunks.
        candidate_info = self._hybrid_search(query, hyde_doc)
        child_candidates = [
            {
                "content": self.document_chunks[idx].page_content,
                "metadata": self.document_chunks[idx].metadata,
            }
            for idx, _score in candidate_info
        ]

        # 2. Rerank the CHILD chunks against the bare query.
        reranked_children = await self._rerank(query, child_candidates)

        # 3. Distinct parent IDs in rerank order (dict keys preserve insertion order).
        parent_ids = list(
            dict.fromkeys(doc["metadata"]["parent_id"] for doc in reranked_children)
        )

        # 4. Fetch the PARENT documents; misses come back as None and are dropped.
        parents = [doc for doc in self.docstore.mget(parent_ids) if doc is not None]

        # 5. Shape for the generation step.
        final_context = [
            {"content": doc.page_content, "metadata": doc.metadata} for doc in parents
        ]
        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context