File size: 6,306 Bytes
84f4fa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60a0d93
f34b0f8
 
84f4fa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# file: retrieval.py

import time
import asyncio
import numpy as np
import torch
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple
from langchain.storage import InMemoryStore

from embedding import EmbeddingClient
from langchain_core.documents import Document

# --- Configuration ---
HYDE_MODEL = "llama-3.1-8b-instant"
RERANKER_MODEL = 'BAAI/bge-reranker-base'
#RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10 

async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a hypothetical answer passage for HyDE retrieval.

    Uses the Groq chat API (HYDE_MODEL) to write a short, formal passage
    that plausibly answers *query*; the caller embeds that passage to
    enhance dense retrieval.

    Args:
        query: The user's natural-language question.
        groq_api_key: API key for Groq. When empty/falsy, HyDE is skipped.

    Returns:
        The generated passage, or "" when the key is missing or the API
        call fails (retrieval then proceeds without HyDE).
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Best-effort: HyDE is an optional enhancement, so API failures
        # degrade to plain retrieval instead of crashing the pipeline.
        print(f"An error occurred during HyDE generation: {e}")
        return ""


class Retriever:
    """Hybrid (BM25 + dense) search over child chunks with parent-document expansion.

    Small child chunks are indexed for retrieval precision; at query time
    the top reranked children are mapped back via ``metadata['parent_id']``
    to their larger parent documents, which are returned as context.
    """

    def __init__(self, embedding_client: EmbeddingClient):
        """Load the cross-encoder reranker and prepare empty index state.

        Args:
            embedding_client: Client used to embed both the corpus and the
                query; its ``device`` is reused for the reranker.
        """
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None  # BM25Okapi over child chunks; populated by index()
        self.document_chunks: List[Document] = []  # child chunks, row-aligned with the indexes
        self.chunk_embeddings = None  # dense embeddings of the child chunks
        self.docstore = InMemoryStore()  # parent_id -> parent Document
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, child_documents: List[Document], docstore: InMemoryStore) -> None:
        """Build the search index from child documents and store parent documents.

        Args:
            child_documents: Chunks to index; each is expected to carry a
                ``parent_id`` in its metadata (used later by retrieve()).
            docstore: Store mapping parent_id -> full parent Document.
        """
        self.document_chunks = child_documents  # kept so indices map back to Documents
        self.docstore = docstore

        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing child documents for retrieval...")
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Score every child chunk with BM25 + dense similarity and return the top candidates.

        BM25 sees only the raw query tokens; the dense side embeds the
        query concatenated with the HyDE passage (when available). Both
        score vectors are min-max normalized and mixed 50/50.

        Returns:
            (chunk index, combined score) pairs for the top
            INITIAL_K_CANDIDATES chunks, best first.

        Raises:
            ValueError: If index() has not been called (or indexed nothing).
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        # Normalize each score family to [0, 1] before mixing so neither
        # dominates purely by scale.
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Rerank candidate chunks with the cross-encoder and keep the top TOP_K_CHUNKS.

        Mutates each candidate dict in place by adding a 'rerank_score'
        key. The blocking model call runs in a worker thread so the event
        loop stays responsive.
        """
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score

        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Run the full pipeline: hybrid search -> rerank -> parent lookup.

        Args:
            query: The user question.
            hyde_doc: Optional HyDE passage ("" disables query enhancement).

        Returns:
            {'content', 'metadata'} dicts for the unique parent documents
            of the top reranked child chunks, in rank order.
        """
        print(f"Retrieving documents for query: '{query}'")

        # 1. Hybrid search returns indices of the best CHILD chunks.
        initial_candidates_info = self._hybrid_search(query, hyde_doc)

        retrieved_child_docs = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
        } for idx, _score in initial_candidates_info]

        # 2. Rerank the CHILD chunks with the cross-encoder.
        reranked_child_docs = await self._rerank(query, retrieved_child_docs)

        # 3. Unique parent IDs in rank order; dict.fromkeys dedupes while
        #    preserving first-occurrence order (and avoids O(k^2) scans).
        parent_ids = list(dict.fromkeys(
            doc["metadata"]["parent_id"] for doc in reranked_child_docs
        ))

        # 4. Fetch the full PARENT documents; drop docstore misses (None).
        retrieved_parents = self.docstore.mget(parent_ids)
        final_parent_docs = [doc for doc in retrieved_parents if doc is not None]

        # 5. Shape for the generation step.
        final_context = [{
            "content": doc.page_content,
            "metadata": doc.metadata,
        } for doc in final_parent_docs]

        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context