import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from typing import List, Tuple
import asyncio
import torch
from torch.nn.functional import cosine_similarity
import time

# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 10  # The final number of chunks to send to the LLM
# A larger number of initial candidates for reranking
INITIAL_K_CANDIDATES = 20
GROQ_MODEL_NAME = "openai/gpt-oss-20b"
HYDE_MODEL = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
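
# A calling app can override the settings above before building the index,
# e.g. (hypothetical snippet; `rag_pipeline` is a placeholder module name):
#
#     import rag_pipeline
#     rag_pipeline.TOP_K_CHUNKS = 5           # tighter context window
#     rag_pipeline.INITIAL_K_CANDIDATES = 50  # wider pool for the reranker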

# --- Hypothetical Document Generation (HyDE) and Embedding Client ---
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """
    Generates a hypothetical document using the Groq API.
    This prompt is generic and does not require prior knowledge of the document style.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping hypothetical document generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"You are a document writer. Your task is to write a brief passage as a section of a document "
        f"that could answer the following question. The passage should use specific terminology and "
        f"a formal tone, as if it were an excerpt from a larger document. Do not include the question, "
        f"and do not add any conversational text. The goal is to create a concise, semantically rich text "
        f"to guide a search engine to find similarly styled and detailed content.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Section:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        hyde_doc = chat_completion.choices[0].message.content
        print("Hypothetical document generated.")
        return hyde_doc
    except Exception as e:
        print(f"An error occurred during HyDE generation: {e}")
        return ""

class EmbeddingClient:
    """A client for generating text embeddings using a local, open-source model."""
    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Sentence Transformer embedding client initialized ({model_name}) on {self.device}.")

    def get_embeddings(self, texts: List[str]) -> torch.Tensor:
        if not texts:
            return torch.tensor([])
        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
        print("Embeddings generated successfully.")
        return embeddings
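
# Example sketch: embed two chunks and compare them directly with
# torch.nn.functional.cosine_similarity (imported at the top of this file).
#
#     client = EmbeddingClient()
#     embs = client.get_embeddings(["room rent limits", "ICU charges"])
#     sim = cosine_similarity(embs[0].unsqueeze(0), embs[1].unsqueeze(0))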

# --- Hybrid Search Manager Class ---
class HybridSearchManager:
    """
    Manages the initialization and execution of a hybrid search system
    combining BM25, dense vector search, and a fast reranker.
    """
    def __init__(self, embedding_model_name: str = EMBEDDING_MODEL_NAME):
        self.bm25_model = None
        self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
        self.document_chunks = []
        self.document_embeddings = None
        # Initialize the cross-encoder reranker (ms-marco-MiniLM-L6-v2)
        self.reranker = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L6-v2',
            device='cuda' if torch.cuda.is_available() else 'cpu',
        )
        print("ms-marco-MiniLM-L6-v2 Reranker initialized.")

    async def initialize_models(self, documents: list[Document]):
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to initialize. Skipping model setup.")
            return
        print("Initializing BM25 model...")
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25_model = BM25Okapi(tokenized_corpus)
        print("BM25 model initialized.")
        print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
        self.document_embeddings = self.embedding_client.get_embeddings(corpus)
        print("Document embeddings computed.")

    async def retrieve_candidates(self, query: str, hyde_doc: str) -> List[dict]:
        """
        Performs a HyDE-enhanced hybrid search to retrieve initial candidates
        without reranking.
        """
        if self.bm25_model is None or self.document_embeddings is None:
            raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
        print(f"Performing hybrid search for candidate retrieval for query: '{query}'...")

        # Dense retrieval uses the HyDE-augmented text, while BM25 keeps the
        # raw query so lexical matching is not diluted by generated prose.
        hyde_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25_model.get_scores(tokenized_query)
        query_embedding = self.embedding_client.get_embeddings([hyde_query])
        dense_scores = cosine_similarity(query_embedding, self.document_embeddings)
        dense_scores = dense_scores.cpu().numpy()

        if len(bm25_scores) == 0 or len(dense_scores) == 0:
            return []

        # Min-max normalize each score list to [0, 1] so lexical and dense
        # scores are comparable, then blend them with equal weights.
        scaler = MinMaxScaler()
        normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores

        ranked_indices = np.argsort(combined_scores)[::-1]
        top_initial_indices = ranked_indices[:INITIAL_K_CANDIDATES]

        retrieved_results = []
        for idx in top_initial_indices:
            doc = self.document_chunks[idx]
            retrieved_results.append({
                "content": doc.page_content,
                "document_metadata": doc.metadata,
                "initial_score": combined_scores[idx] # Optionally store the initial score
            })

        print(f"Retrieved {len(retrieved_results)} initial candidates for reranking.")
        return retrieved_results

    async def rerank_results(self, query: str, retrieved_results: List[dict], top_k: int) -> Tuple[List[dict], float]:
        """
        Reranks the retrieved candidates with the cross-encoder and returns
        the top_k chunks along with the time spent reranking, in seconds.
        """
        if not retrieved_results:
            return []

        print(f"Reranking {len(retrieved_results)} candidates for query: '{query}'...")
        start_time_rerank = time.perf_counter()

        rerank_input = [[query, chunk["content"]] for chunk in retrieved_results]
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        
        end_time_rerank = time.perf_counter()
        rerank_time = end_time_rerank - start_time_rerank

        scored_results = list(zip(retrieved_results, rerank_scores))
        scored_results.sort(key=lambda x: x[1], reverse=True)

        final_chunks = []
        for res, score in scored_results[:top_k]:
            final_chunks.append({
                "content": res["content"],
                "document_metadata": res["document_metadata"],
                "rerank_score": score
            })

        print(f"Reranking completed in {rerank_time:.4f} seconds. Returning top {len(final_chunks)} chunks.")
        return final_chunks, rerank_time
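
# Typical retrieve-then-rerank wiring (a sketch; `chunks`, `query`, and
# `api_key` are assumed to come from the calling app):
#
#     manager = HybridSearchManager()
#     await manager.initialize_models(chunks)
#     hyde_doc = await generate_hypothetical_document(query, api_key)
#     candidates = await manager.retrieve_candidates(query, hyde_doc)
#     top_chunks, rerank_seconds = await manager.rerank_results(
#         query, candidates, top_k=TOP_K_CHUNKS
#     )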

# --- Helper functions: markdown chunking and Groq answer generation ---
def process_markdown_with_recursive_chunking(
    md_file_path: str,
    chunk_size: int,
    chunk_overlap: int,
) -> List[Document]:
    all_chunks = []
    full_text = ""
    if not os.path.exists(md_file_path):
        print(f"Error: File not found at '{md_file_path}'")
        return []
    if not os.path.isfile(md_file_path):
        print(f"Error: Path '{md_file_path}' is not a file.")
        return []
    if not md_file_path.lower().endswith(".md"):
        print(f"Warning: File '{md_file_path}' does not have a .md extension.")
    try:
        with open(md_file_path, 'r', encoding='utf-8') as f:
            full_text = f.read()
    except Exception as e:
        print(f"Error reading file '{md_file_path}': {e}")
        return []
    if not full_text:
        print("Input markdown file is empty.")
        return []

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )

    chunks = text_splitter.split_text(full_text)

    for chunk in chunks:
        all_chunks.append(Document(page_content=chunk, metadata={"document_part": "Whole Document"}))

    print(f"Created {len(all_chunks)} chunks from the entire document.")
    return all_chunks
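
# Example sketch ("policy.md" is a placeholder path):
#
#     chunks = process_markdown_with_recursive_chunking(
#         "policy.md", chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
#     )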


async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
    """
    Generates an answer using the Groq API based on the query and retrieved chunks' content.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."
    print("Generating answer with Groq API...")
    client = AsyncGroq(api_key=groq_api_key)
    context_parts = []
    for i, res in enumerate(retrieved_results):
        content = res.get("content", "")
        metadata = res.get("document_metadata", {})
        section_heading = metadata.get("section_heading", "N/A")
        document_part = metadata.get("document_part", "N/A")
        context_parts.append(
            f"--- Context Chunk {i+1} ---\n"
            f"Document Part: {document_part}\n"
            f"Section Heading: {section_heading}\n"
            f"Content: {content}\n"
            f"-------------------------"
        )
    context = "\n\n".join(context_parts)
    prompt = (
        f"You are an expert on the provided document. Your task is to answer the user's question "
        f"based only on the information given. Your answers should be brief, concise, and in a similar style to these examples:\n"
        f"- Yes, outpatient consultations and diagnostic tests are covered, provided they are medically necessary and fall within the specified sub-limits under the plan.\n"
        f"- The policy does not cover any expenses incurred during the first 30 days from the inception of the policy, except in the case of accidents.\n"
        f"- Room rent is covered up to a single private AC room per day unless otherwise specified in the policy schedule.\n"
        f"- Yes, the policy allows for mid-term inclusion of newly married spouses and newborn children, subject to notification and payment of additional premium within the stipulated time frame.\n"
        f"Do not mention or refer to the document or the context in your final answer. If the information required to answer the question is not available in the provided context, state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )
    try:
        chat_completion = await client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."