Spaces:
Sleeping
Sleeping
| # file: retrieval.py | |
| import time | |
| import asyncio | |
| import numpy as np | |
| import torch | |
| from groq import AsyncGroq | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import CrossEncoder | |
| from sklearn.preprocessing import MinMaxScaler | |
| from torch.nn.functional import cosine_similarity | |
| from typing import List, Dict, Tuple | |
| from langchain.storage import InMemoryStore | |
| from embedding import EmbeddingClient | |
| from langchain_core.documents import Document | |
# --- Configuration ---
# Groq model used to generate the hypothetical (HyDE) passage.
HYDE_MODEL = "llama-3.1-8b-instant"
# Cross-encoder checkpoint used to rerank candidate chunks.
RERANKER_MODEL = 'BAAI/bge-reranker-base'
#RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
# Number of child chunks pulled by hybrid search before reranking.
INITIAL_K_CANDIDATES = 20
# Number of top reranked chunks kept after the cross-encoder pass.
TOP_K_CHUNKS = 10
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a short hypothetical answer passage (HyDE) for *query*.

    Calls the Groq chat API with a fixed prompt template. Returns an
    empty string when no API key is configured or when the API call
    fails, so callers can fall back to plain retrieval.

    Args:
        query: The user's question.
        groq_api_key: Groq API key; falsy values skip generation.

    Returns:
        The generated passage, or "" on skip/failure.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""
    print(f"Starting HyDE generation for query: '{query}'...")
    hyde_prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )
    groq_client = AsyncGroq(api_key=groq_api_key)
    try:
        response = await groq_client.chat.completions.create(
            messages=[{"role": "user", "content": hyde_prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return response.choices[0].message.content
    except Exception as exc:
        # Best-effort: HyDE is an enhancement, not a hard requirement.
        print(f"An error occurred during HyDE generation: {exc}")
        return ""
class Retriever:
    """Manages hybrid search with parent-child retrieval.

    Child chunks are indexed with BM25 (sparse) and dense embeddings;
    after a cross-encoder rerank of the top candidates, the children's
    parent documents are fetched from the docstore and returned as the
    final context.
    """
    def __init__(self, embedding_client: EmbeddingClient):
        """Load the cross-encoder reranker and prepare empty index state.

        Args:
            embedding_client: Wrapper that creates dense embeddings and
                exposes the torch device the reranker should run on.
        """
        self.embedding_client = embedding_client
        # Run the reranker on the same device as the embedding model.
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None  # BM25Okapi index; built in index()
        self.document_chunks = []  # child Documents, position-aligned with both indexes
        self.chunk_embeddings = None  # dense embeddings for document_chunks
        self.docstore = InMemoryStore()  # parent store; replaced by index()
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")
    def index(self, child_documents: List[Document], docstore: InMemoryStore) -> None:
        """Builds the search index from child documents and stores parent documents.

        Args:
            child_documents: Small chunks to index for BM25 + dense search.
            docstore: Store mapping parent_id -> full parent Document,
                queried later in retrieve().
        """
        self.document_chunks = child_documents  # Store child docs for mapping
        self.docstore = docstore  # Store the parent documents
        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return
        print("Indexing child documents for retrieval...")
        # Simple whitespace tokenization feeds the sparse (BM25) index.
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")
    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Score all chunks with BM25 + dense similarity; return top candidates.

        The HyDE passage (if provided) is appended to the query for the
        dense side only; BM25 scores the raw query tokens.

        Returns:
            (chunk_index, combined_score) pairs for the best
            INITIAL_K_CANDIDATES chunks, highest score first.

        Raises:
            ValueError: If index() has not been called yet.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
        # Min-max normalize both score sets to [0, 1] so they can be averaged fairly.
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        # Equal-weight fusion of sparse and dense evidence.
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]
    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Rerank candidates with the cross-encoder; keep the TOP_K_CHUNKS best.

        Mutates each candidate dict in place by adding a 'rerank_score'
        key, then sorts descending by that score.
        """
        if not candidates:
            return []
        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]
        # predict() is blocking; run in a thread to keep the event loop free.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score
        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]
    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline and returns parent documents.

        Pipeline: hybrid search over child chunks -> cross-encoder rerank
        -> dedupe parent IDs -> fetch full parents from the docstore.

        Args:
            query: The user's question.
            hyde_doc: Optional HyDE passage ("" to disable enhancement).

        Returns:
            Dicts with 'content' and 'metadata' keys, one per unique
            parent document, ordered by the rank of their best child.
        """
        print(f"Retrieving documents for query: '{query}'")
        # 1. Hybrid search returns indices of the best CHILD documents
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        retrieved_child_docs = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
        } for idx, score in initial_candidates_info]
        # 2. Rerank the CHILD documents
        reranked_child_docs = await self._rerank(query, retrieved_child_docs)
        # 3. Get the unique parent IDs from the reranked child documents
        #    (list preserves rerank order; children of the same parent dedupe)
        parent_ids = []
        for doc in reranked_child_docs:
            parent_id = doc["metadata"]["parent_id"]
            if parent_id not in parent_ids:
                parent_ids.append(parent_id)
        # 4. Retrieve the full PARENT documents from the docstore
        retrieved_parents = self.docstore.mget(parent_ids)
        # Filter out any None results in case of a miss
        final_parent_docs = [doc for doc in retrieved_parents if doc is not None]
        # 5. Format for the generation step
        final_context = [{
            "content": doc.page_content,
            "metadata": doc.metadata
        } for doc in final_parent_docs]
        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context