# file: retrieval.py
import time
import asyncio
import numpy as np
import torch
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple
from langchain.storage import InMemoryStore
from embedding import EmbeddingClient
from langchain_core.documents import Document
# --- Configuration ---
HYDE_MODEL = "llama-3.1-8b-instant"  # Groq chat model used to generate the HyDE passage
RERANKER_MODEL = 'BAAI/bge-reranker-base'  # cross-encoder checkpoint used for reranking
#RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'  # lighter-weight alternative reranker
INITIAL_K_CANDIDATES = 20  # candidate child chunks taken from hybrid search, pre-rerank
TOP_K_CHUNKS = 10  # child chunks kept after cross-encoder reranking
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a hypothetical answer passage (HyDE) for *query* via Groq.

    Returns an empty string when no API key is configured or when the
    LLM call fails for any reason (best-effort: retrieval still works
    without the HyDE enhancement).
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""
    print(f"Starting HyDE generation for query: '{query}'...")
    llm = AsyncGroq(api_key=groq_api_key)
    hyde_prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )
    try:
        response = await llm.chat.completions.create(
            messages=[{"role": "user", "content": hyde_prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
    except Exception as e:
        # Best-effort: swallow API errors and fall back to plain retrieval.
        print(f"An error occurred during HyDE generation: {e}")
        return ""
    return response.choices[0].message.content
class Retriever:
    """Manages hybrid (BM25 + dense) search with parent-child retrieval.

    Child documents are indexed for search; after cross-encoder reranking,
    the full parent documents are fetched from the docstore and returned
    as context for the generation step.
    """

    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        # Run the cross-encoder on the same device as the embedding model.
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None                 # BM25Okapi index over child chunks
        self.document_chunks = []        # child Documents, aligned with the indexes
        self.chunk_embeddings = None     # dense embeddings aligned with document_chunks
        self.docstore = InMemoryStore()  # parent_id -> parent Document
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, child_documents: List[Document], docstore: InMemoryStore) -> None:
        """Build the BM25 and dense indexes from child documents and keep the
        docstore that maps parent_id -> parent Document."""
        self.document_chunks = child_documents
        self.docstore = docstore
        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return
        print("Indexing child documents for retrieval...")
        # NOTE(review): whitespace split is a naive tokenizer — acceptable for
        # BM25 here, but a real tokenizer would improve lexical matching.
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Return (chunk_index, combined_score) pairs for the top
        INITIAL_K_CANDIDATES child chunks.

        BM25 scores the raw query; dense cosine similarity scores the
        HyDE-enhanced query. Both score vectors are min-max normalized
        and blended 50/50.

        Raises:
            ValueError: if index() has not been called yet.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        # BM25 deliberately uses the raw query: the HyDE passage would add
        # noise terms to lexical matching.
        bm25_scores = self.bm25.get_scores(query.split(" "))
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
        # fit_transform refits per score vector, so one scaler instance is fine.
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Score candidates with the cross-encoder and return the top
        TOP_K_CHUNKS, sorted by descending rerank score."""
        if not candidates:
            return []
        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]
        # predict() is compute-bound; run it in a thread so the event loop
        # stays responsive.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score
        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Full retrieval pipeline: hybrid search over child chunks, rerank,
        then swap reranked children for their full parent documents.

        Returns a list of {"content", "metadata"} dicts for the generation step.
        """
        print(f"Retrieving documents for query: '{query}'")
        # 1. Hybrid search returns indices of the best CHILD documents.
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        retrieved_child_docs = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
        } for idx, _score in initial_candidates_info]
        # 2. Rerank the CHILD documents.
        reranked_child_docs = await self._rerank(query, retrieved_child_docs)
        # 3. Unique parent IDs in rerank order. Children without a parent_id
        #    are skipped instead of raising KeyError; dict.fromkeys dedups in
        #    O(n) while preserving order.
        parent_ids = list(dict.fromkeys(
            pid for doc in reranked_child_docs
            if (pid := doc["metadata"].get("parent_id")) is not None
        ))
        # 4. Fetch the full PARENT documents; mget returns None on a miss.
        retrieved_parents = self.docstore.mget(parent_ids)
        final_parent_docs = [doc for doc in retrieved_parents if doc is not None]
        # 5. Format for the generation step.
        final_context = [{
            "content": doc.page_content,
            "metadata": doc.metadata,
        } for doc in final_parent_docs]
        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context