PercivalFletcher committed on
Commit
bfca738
·
verified ·
1 Parent(s): 1484160

Update rag_utils.py

Browse files
Files changed (1) hide show
  1. rag_utils.py +105 -108
rag_utils.py CHANGED
@@ -5,25 +5,60 @@ from groq import AsyncGroq
5
  import json
6
  import re
7
  from rank_bm25 import BM25Okapi
8
- from sentence_transformers import SentenceTransformer
9
  from sklearn.preprocessing import MinMaxScaler
10
  import numpy as np
11
  from typing import Any, List, Tuple
12
  import asyncio
13
  import torch
14
  import time
15
- from flashrank import Ranker, RerankRequest # Import the FlashRank library
16
 
17
  # --- Configuration (can be overridden by the calling app) ---
18
  CHUNK_SIZE = 1000
19
  CHUNK_OVERLAP = 200
20
- TOP_K_CHUNKS = 5 # The final number of chunks to send to the LLM
21
  # A larger number of initial candidates for reranking
22
- INITIAL_K_CANDIDATES = 20
23
  GROQ_MODEL_NAME = "llama3-8b-8192"
24
- EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2" # A good general-purpose embedding model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # --- Class for managing the Sentence Transformer model ---
27
  class EmbeddingClient:
28
  """A client for generating text embeddings using a local, open-source model."""
29
  def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
@@ -51,8 +86,9 @@ class HybridSearchManager:
51
  self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
52
  self.document_chunks = []
53
  self.document_embeddings = None
54
- self.reranker = Ranker() # Initialize the FlashRank reranker
55
- print("FlashRank reranker initialized.")
 
56
 
57
  async def initialize_models(self, documents: list[Document]):
58
  self.document_chunks = documents
@@ -67,90 +103,85 @@ class HybridSearchManager:
67
  print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
68
  self.document_embeddings = self.embedding_client.get_embeddings(corpus)
69
  print("Document embeddings computed.")
70
-
71
- async def perform_hybrid_search(self, query: str, top_k: int) -> Tuple[List[dict], float]:
72
  """
73
- Performs a hybrid search, then reranks the results, and returns the top chunks
74
- along with the time taken for reranking.
75
  """
76
  if self.bm25_model is None or self.document_embeddings is None:
77
  raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
78
- print(f"Performing hybrid search for query: '{query}' (top_k={top_k})...")
79
 
80
- # Get a larger number of chunks for a better reranking pool
81
  tokenized_query = query.split(" ")
82
  bm25_scores = self.bm25_model.get_scores(tokenized_query)
83
-
84
- query_embedding = self.embedding_client.get_embeddings([query])
85
  from torch.nn.functional import cosine_similarity
86
  dense_scores = cosine_similarity(query_embedding, self.document_embeddings)
87
  dense_scores = dense_scores.cpu().numpy()
88
 
89
  if len(bm25_scores) == 0 or len(dense_scores) == 0:
90
- return [], 0.0
91
 
92
  scaler = MinMaxScaler()
93
  normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
94
  normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
95
  combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores
96
-
97
- # We now get `INITIAL_K_CANDIDATES` documents from the combined search
98
  ranked_indices = np.argsort(combined_scores)[::-1]
99
  top_initial_indices = ranked_indices[:INITIAL_K_CANDIDATES]
100
-
101
  retrieved_results = []
102
  for idx in top_initial_indices:
103
  doc = self.document_chunks[idx]
104
  retrieved_results.append({
105
  "content": doc.page_content,
106
- "document_metadata": doc.metadata
 
107
  })
108
-
109
- print(f"Retrieved {len(retrieved_results)} initial chunks. Starting reranking...")
110
-
111
- # --- Reranking Step with Timing ---
112
- start_time_rerank = time.perf_counter()
 
 
 
113
  if not retrieved_results:
114
- return [], 0.0
115
 
116
- # FlashRank expects a list of dictionaries with a "text" key
117
- passages = [{"text": chunk["content"]} for chunk in retrieved_results]
118
-
119
- # The reranker takes a query and a list of passages and returns a reranked list
120
- reranked_results = await asyncio.to_thread(
121
- self.reranker.rerank, RerankRequest(query=query, passages=passages)
122
  )
123
 
124
  end_time_rerank = time.perf_counter()
125
  rerank_time = end_time_rerank - start_time_rerank
126
-
127
- # Re-map the reranked results back to our original document format
 
 
128
  final_chunks = []
129
- for res in reranked_results:
130
- # Find the original chunk based on the text
131
- original_chunk_data = next(
132
- (c for c in retrieved_results if c["content"] == res["text"]),
133
- None
134
- )
135
- if original_chunk_data:
136
- final_chunks.append({
137
- "content": original_chunk_data["content"],
138
- "document_metadata": original_chunk_data["document_metadata"],
139
- "rerank_score": res["score"]
140
- })
141
-
142
- # Return the top_k reranked chunks and the timing information
143
- print(f"Reranking completed in {rerank_time:.4f} seconds. Retrieved {len(final_chunks[:top_k])} top chunks.")
144
- return final_chunks[:top_k], rerank_time
145
 
 
 
146
 
147
- # --- Helper Functions (remain unchanged) ---
148
- def process_markdown_with_manual_sections(
149
  md_file_path: str,
150
- headings_json: dict,
151
  chunk_size: int,
152
- chunk_overlap: int):
153
- all_chunks_with_metadata = []
154
  full_text = ""
155
  if not os.path.exists(md_file_path):
156
  print(f"Error: File not found at '{md_file_path}'")
@@ -169,57 +200,22 @@ def process_markdown_with_manual_sections(
169
  if not full_text:
170
  print("Input markdown file is empty.")
171
  return []
 
172
  text_splitter = RecursiveCharacterTextSplitter(
173
  chunk_size=chunk_size,
174
  chunk_overlap=chunk_overlap,
175
  length_function=len,
176
  is_separator_regex=False,
177
  )
178
- heading_texts = headings_json.get("headings", [])
179
- print(f"Identified headings for segmentation: {heading_texts}")
180
- heading_positions = []
181
- for heading in heading_texts:
182
- pattern = re.compile(r'\s*'.join(re.escape(word) for word in heading.split()), re.IGNORECASE)
183
- match = pattern.search(full_text)
184
- if match:
185
- heading_positions.append({"heading_text": heading, "start_index": match.start()})
186
- else:
187
- print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")
188
- heading_positions.sort(key=lambda x: x["start_index"])
189
- segments_with_headings = []
190
- if heading_positions and heading_positions[0]["start_index"] > 0:
191
- preface_text = full_text[:heading_positions[0]["start_index"]].strip()
192
- if preface_text:
193
- segments_with_headings.append({
194
- "section_heading": "Document Start/Preface",
195
- "section_text": preface_text
196
- })
197
- for i, current_heading_info in enumerate(heading_positions):
198
- start_index = current_heading_info["start_index"]
199
- heading_text = current_heading_info["heading_text"]
200
- end_index = len(full_text)
201
- if i + 1 < len(heading_positions):
202
- end_index = heading_positions[i+1]["start_index"]
203
- section_content = full_text[start_index:end_index].strip()
204
- if section_content:
205
- segments_with_headings.append({
206
- "section_heading": heading_text,
207
- "section_text": section_content
208
- })
209
- print(f"Created {len(segments_with_headings)} segments based on provided headings.")
210
- for segment in segments_with_headings:
211
- section_heading = segment["section_heading"]
212
- section_text = segment["section_text"]
213
- if section_text:
214
- chunks = text_splitter.split_text(section_text)
215
- for chunk in chunks:
216
- metadata = {
217
- "document_part": "Section",
218
- "section_heading": section_heading,
219
- }
220
- all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata))
221
- print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.")
222
- return all_chunks_with_metadata
223
 
224
  async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
225
  """
@@ -244,12 +240,13 @@ async def generate_answer_with_groq(query: str, retrieved_results: list[dict], g
244
  )
245
  context = "\n\n".join(context_parts)
246
  prompt = (
247
- f"You are a specialized document analyzer assistant. Your task is to answer the user's question "
248
- f"solely based on the provided context. Pay close attention to the section heading and document part "
249
- f"for each context chunk. Ensure your answer incorporates all relevant details, including any legal nuances "
250
- f"and conditions found in the context, and is concise, limited to one or two sentences. "
251
- f"Do not explicitly mention the retrieved chunks. If the answer cannot be found in the provided context, "
252
- f"clearly state that you do not have enough information.\n\n"
 
253
  f"Context:\n{context}\n\n"
254
  f"Question: {query}\n\n"
255
  f"Answer:"
@@ -271,4 +268,4 @@ async def generate_answer_with_groq(query: str, retrieved_results: list[dict], g
271
  return answer
272
  except Exception as e:
273
  print(f"An error occurred during Groq API call: {e}")
274
- return "Could not generate an answer due to an API error."
 
5
  import json
6
  import re
7
  from rank_bm25 import BM25Okapi
8
+ from sentence_transformers import SentenceTransformer, CrossEncoder # Added CrossEncoder
9
  from sklearn.preprocessing import MinMaxScaler
10
  import numpy as np
11
  from typing import Any, List, Tuple
12
  import asyncio
13
  import torch
14
  import time
 
15
 
16
  # --- Configuration (can be overridden by the calling app) ---
17
  CHUNK_SIZE = 1000
18
  CHUNK_OVERLAP = 200
19
+ TOP_K_CHUNKS = 10 # The final number of chunks to send to the LLM
20
  # A larger number of initial candidates for reranking
21
+ INITIAL_K_CANDIDATES = 20
22
  GROQ_MODEL_NAME = "llama3-8b-8192"
23
+ HYDE_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
24
+ EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
25
+
26
+ # --- Hypothetical Document Generation and EmbeddingClient remain unchanged ---
27
+ async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
28
+ """
29
+ Generates a hypothetical document using the Groq API.
30
+ This prompt is generic and does not require prior knowledge of the document style.
31
+ """
32
+ if not groq_api_key:
33
+ print("Groq API key not set. Skipping hypothetical document generation.")
34
+ return ""
35
+
36
+ print(f"Starting HyDE generation for query: '{query}'...")
37
+ client = AsyncGroq(api_key=groq_api_key)
38
+ prompt = (
39
+ f"You are a document writer. Your task is to write a brief passage as a section of a document "
40
+ f"that could answer the following question. The passage should use specific terminology and "
41
+ f"a formal tone, as if it were an excerpt from a larger document. Do not include the question, "
42
+ f"and do not add any conversational text. The goal is to create a concise, semantically rich text "
43
+ f"to guide a search engine to find similarly styled and detailed content.\n\n"
44
+ f"Question: {query}\n\n"
45
+ f"Hypothetical Section:"
46
+ )
47
+
48
+ try:
49
+ chat_completion = await client.chat.completions.create(
50
+ messages=[{"role": "user", "content": prompt}],
51
+ model=HYDE_MODEL,
52
+ temperature=0.7,
53
+ max_tokens=500,
54
+ )
55
+ hyde_doc = chat_completion.choices[0].message.content
56
+ print("Hypothetical document generated.")
57
+ return hyde_doc
58
+ except Exception as e:
59
+ print(f"An error occurred during HyDE generation: {e}")
60
+ return ""
61
 
 
62
  class EmbeddingClient:
63
  """A client for generating text embeddings using a local, open-source model."""
64
  def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
 
86
  self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
87
  self.document_chunks = []
88
  self.document_embeddings = None
89
+ # Initialize BGE reranker model
90
+ self.reranker = CrossEncoder('BAAI/bge-reranker-base', device='cuda' if torch.cuda.is_available() else 'cpu')
91
+ print("BGE Reranker initialized.")
92
 
93
  async def initialize_models(self, documents: list[Document]):
94
  self.document_chunks = documents
 
103
  print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
104
  self.document_embeddings = self.embedding_client.get_embeddings(corpus)
105
  print("Document embeddings computed.")
106
+
107
+ async def retrieve_candidates(self, query: str, hyde_doc: str) -> List[dict]:
108
  """
109
+ Performs a HyDE-enhanced hybrid search to retrieve initial candidates
110
+ without reranking.
111
  """
112
  if self.bm25_model is None or self.document_embeddings is None:
113
  raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
114
+ print(f"Performing hybrid search for candidate retrieval for query: '{query}'...")
115
 
116
+ hyde_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
117
  tokenized_query = query.split(" ")
118
  bm25_scores = self.bm25_model.get_scores(tokenized_query)
119
+ query_embedding = self.embedding_client.get_embeddings([hyde_query])
 
120
  from torch.nn.functional import cosine_similarity
121
  dense_scores = cosine_similarity(query_embedding, self.document_embeddings)
122
  dense_scores = dense_scores.cpu().numpy()
123
 
124
  if len(bm25_scores) == 0 or len(dense_scores) == 0:
125
+ return []
126
 
127
  scaler = MinMaxScaler()
128
  normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
129
  normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
130
  combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores
131
+
 
132
  ranked_indices = np.argsort(combined_scores)[::-1]
133
  top_initial_indices = ranked_indices[:INITIAL_K_CANDIDATES]
134
+
135
  retrieved_results = []
136
  for idx in top_initial_indices:
137
  doc = self.document_chunks[idx]
138
  retrieved_results.append({
139
  "content": doc.page_content,
140
+ "document_metadata": doc.metadata,
141
+ "initial_score": combined_scores[idx] # Optionally store the initial score
142
  })
143
+
144
+ print(f"Retrieved {len(retrieved_results)} initial candidates for reranking.")
145
+ return retrieved_results
146
+
147
+ async def rerank_results(self, query: str, retrieved_results: List[dict], top_k: int) -> List[dict]:
148
+ """
149
+ Performs reranking on a list of retrieved candidate documents.
150
+ """
151
  if not retrieved_results:
152
+ return []
153
 
154
+ print(f"Reranking {len(retrieved_results)} candidates for query: '{query}'...")
155
+ start_time_rerank = time.perf_counter()
156
+
157
+ rerank_input = [[query, chunk["content"]] for chunk in retrieved_results]
158
+ rerank_scores = await asyncio.to_thread(
159
+ self.reranker.predict, rerank_input, show_progress_bar=False
160
  )
161
 
162
  end_time_rerank = time.perf_counter()
163
  rerank_time = end_time_rerank - start_time_rerank
164
+
165
+ scored_results = list(zip(retrieved_results, rerank_scores))
166
+ scored_results.sort(key=lambda x: x[1], reverse=True)
167
+
168
  final_chunks = []
169
+ for res, score in scored_results[:top_k]:
170
+ final_chunks.append({
171
+ "content": res["content"],
172
+ "document_metadata": res["document_metadata"],
173
+ "rerank_score": score
174
+ })
 
 
 
 
 
 
 
 
 
 
175
 
176
+ print(f"Reranking completed in {rerank_time:.4f} seconds. Returning top {len(final_chunks)} chunks.")
177
+ return final_chunks, rerank_time
178
 
179
+ # --- Other helper functions (process_markdown_with_recursive_chunking, generate_answer_with_groq) remain unchanged ---
180
+ def process_markdown_with_recursive_chunking(
181
  md_file_path: str,
 
182
  chunk_size: int,
183
+ chunk_overlap: int) -> List[Document]:
184
+ all_chunks = []
185
  full_text = ""
186
  if not os.path.exists(md_file_path):
187
  print(f"Error: File not found at '{md_file_path}'")
 
200
  if not full_text:
201
  print("Input markdown file is empty.")
202
  return []
203
+
204
  text_splitter = RecursiveCharacterTextSplitter(
205
  chunk_size=chunk_size,
206
  chunk_overlap=chunk_overlap,
207
  length_function=len,
208
  is_separator_regex=False,
209
  )
210
+
211
+ chunks = text_splitter.split_text(full_text)
212
+
213
+ for chunk in chunks:
214
+ all_chunks.append(Document(page_content=chunk, metadata={"document_part": "Whole Document"}))
215
+
216
+ print(f"Created {len(all_chunks)} chunks from the entire document.")
217
+ return all_chunks
218
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
221
  """
 
240
  )
241
  context = "\n\n".join(context_parts)
242
  prompt = (
243
+ f"You are an expert on the provided document. Your task is to answer the user's question "
244
+ f"based only on the information given. Your answers should be brief, concise, and in a similar style to these examples:\n"
245
+ f"- Yes, outpatient consultations and diagnostic tests are covered, provided they are medically necessary and fall within the specified sub-limits under the plan.\n"
246
+ f"- The policy does not cover any expenses incurred during the first 30 days from the inception of the policy, except in the case of accidents.\n"
247
+ f"- Room rent is covered up to a single private AC room per day unless otherwise specified in the policy schedule.\n"
248
+ f"- Yes, the policy allows for mid-term inclusion of newly married spouses and newborn children, subject to notification and payment of additional premium within the stipulated time frame.\n"
249
+ f"Do not mention or refer to the document or the context in your final answer. If the information required to answer the question is not available in the provided context, state that you do not have enough information.\n\n"
250
  f"Context:\n{context}\n\n"
251
  f"Question: {query}\n\n"
252
  f"Answer:"
 
268
  return answer
269
  except Exception as e:
270
  print(f"An error occurred during Groq API call: {e}")
271
+ return "Could not generate an answer due to an API error."