PercivalFletcher committed on
Commit
f69ffb2
·
verified ·
1 Parent(s): 6ba91ab

Update rag_utils.py

Browse files
Files changed (1) hide show
  1. rag_utils.py +61 -53
rag_utils.py CHANGED
@@ -8,46 +8,34 @@ from rank_bm25 import BM25Okapi
8
  from sentence_transformers import SentenceTransformer
9
  from sklearn.preprocessing import MinMaxScaler
10
  import numpy as np
11
- from typing import Any, List
12
  import asyncio
13
- import torch # Import torch for GPU operations
 
 
14
 
15
# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 5
GROQ_MODEL_NAME = "llama3-8b-8192"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # A good general-purpose embedding model


# --- Class for managing the Sentence Transformer model ---
class EmbeddingClient:
    """A client for generating text embeddings using a local, open-source model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        """Load the SentenceTransformer model, preferring CUDA when available."""
        gpu_available = torch.cuda.is_available()
        self.device = "cuda" if gpu_available else "cpu"
        print(f"Using device: {self.device}")
        # The model is placed on the selected device at load time.
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Sentence Transformer embedding client initialized ({model_name}) on {self.device}.")

    def get_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Encode *texts* into a tensor of embedding vectors.

        Returns an empty tensor when *texts* is empty; otherwise one embedding
        row per input string, on the client's device.
        """
        if not texts:
            return torch.tensor([])
        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        # encode() batches the inputs and handles device placement internally.
        vectors = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
        print("Embeddings generated successfully.")
        return vectors
@@ -56,88 +44,112 @@ class EmbeddingClient:
56
class HybridSearchManager:
    """
    Manages the initialization and execution of a hybrid search system
    combining BM25 and dense vector search, with GPU acceleration.
    """

    def __init__(self, embedding_model_name: str = EMBEDDING_MODEL_NAME):
        self.bm25_model = None
        self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
        self.document_chunks = []
        self.document_embeddings = None

    async def initialize_models(self, documents: list[Document]):
        """Build the BM25 index and precompute dense embeddings for *documents*."""
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to initialize. Skipping model setup.")
            return
        # Lexical index first (CPU-bound), then the dense embeddings.
        print("Initializing BM25 model...")
        tokenized_corpus = [text.split(" ") for text in corpus]
        self.bm25_model = BM25Okapi(tokenized_corpus)
        print("BM25 model initialized.")
        print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
        self.document_embeddings = self.embedding_client.get_embeddings(corpus)
        print("Document embeddings computed.")

    async def perform_hybrid_search(self, query: str, top_k: int) -> list[dict]:
        """
        Score the corpus with BM25 and cosine similarity over dense embeddings,
        min-max normalize both, average them 50/50, and return the top_k chunks.

        Raises ValueError if initialize_models has not been called.
        """
        if self.bm25_model is None or self.document_embeddings is None:
            raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
        print(f"Performing hybrid search for query: '{query}' (top_k={top_k})...")

        # Lexical scores for every chunk in the corpus.
        bm25_scores = self.bm25_model.get_scores(query.split(" "))

        # Dense scores: embed the query, then cosine similarity against the corpus.
        query_embedding = self.embedding_client.get_embeddings([query])
        from torch.nn.functional import cosine_similarity
        dense_scores = cosine_similarity(query_embedding, self.document_embeddings).cpu().numpy()

        if len(bm25_scores) == 0 or len(dense_scores) == 0:
            return []

        # Normalize each score set to [0, 1] so they can be blended fairly.
        normalizer = MinMaxScaler()
        bm25_norm = normalizer.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        dense_norm = normalizer.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * bm25_norm + 0.5 * dense_norm

        best_indices = np.argsort(combined_scores)[::-1][:top_k]
        results = [
            {
                "content": self.document_chunks[idx].page_content,
                "document_metadata": self.document_chunks[idx].metadata,
            }
            for idx in best_indices
        ]
        print(f"Retrieved {len(results)} top chunks using hybrid search.")
        return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- # --- Helper Functions (remain unchanged as they are not GPU-intensive) ---
131
  def process_markdown_with_manual_sections(
132
  md_file_path: str,
133
  headings_json: dict,
134
  chunk_size: int,
135
  chunk_overlap: int):
136
- """
137
- Processes a markdown document from a file path by segmenting it based on
138
- provided section headings, and then recursively chunking each segment.
139
- Each chunk receives the corresponding section heading as metadata.
140
- """
141
  all_chunks_with_metadata = []
142
  full_text = ""
143
  if not os.path.exists(md_file_path):
@@ -168,7 +180,6 @@ def process_markdown_with_manual_sections(
168
  heading_positions = []
169
  for heading in heading_texts:
170
  pattern = re.compile(r'\s*'.join(re.escape(word) for word in heading.split()), re.IGNORECASE)
171
-
172
  match = pattern.search(full_text)
173
  if match:
174
  heading_positions.append({"heading_text": heading, "start_index": match.start()})
@@ -176,7 +187,6 @@ def process_markdown_with_manual_sections(
176
  print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")
177
  heading_positions.sort(key=lambda x: x["start_index"])
178
  segments_with_headings = []
179
-
180
  if heading_positions and heading_positions[0]["start_index"] > 0:
181
  preface_text = full_text[:heading_positions[0]["start_index"]].strip()
182
  if preface_text:
@@ -187,12 +197,10 @@ def process_markdown_with_manual_sections(
187
  for i, current_heading_info in enumerate(heading_positions):
188
  start_index = current_heading_info["start_index"]
189
  heading_text = current_heading_info["heading_text"]
190
-
191
  end_index = len(full_text)
192
  if i + 1 < len(heading_positions):
193
  end_index = heading_positions[i+1]["start_index"]
194
  section_content = full_text[start_index:end_index].strip()
195
-
196
  if section_content:
197
  segments_with_headings.append({
198
  "section_heading": heading_text,
 
8
  from sentence_transformers import SentenceTransformer
9
  from sklearn.preprocessing import MinMaxScaler
10
  import numpy as np
11
+ from typing import Any, List, Tuple
12
  import asyncio
13
+ import torch
14
+ import time
15
+ from flashrank import Ranker, RerankRequest # Import the FlashRank library
16
 
17
# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 5  # The final number of chunks to send to the LLM
# A larger number of initial candidates for reranking
INITIAL_K_CANDIDATES = 20
GROQ_MODEL_NAME = "llama3-8b-8192"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # A good general-purpose embedding model


# --- Class for managing the Sentence Transformer model ---
class EmbeddingClient:
    """A client for generating text embeddings using a local, open-source model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        """Initialize the SentenceTransformer model on GPU when one is present."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Sentence Transformer embedding client initialized ({model_name}) on {self.device}.")

    def get_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return one embedding per string in *texts* as a single tensor.

        An empty input list yields an empty tensor without touching the model.
        """
        if not texts:
            return torch.tensor([])
        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        # encode() performs batching and device transfers internally.
        embedded = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
        print("Embeddings generated successfully.")
        return embedded
 
44
class HybridSearchManager:
    """
    Manages the initialization and execution of a hybrid search system
    combining BM25, dense vector search, and a fast reranker.
    """

    def __init__(self, embedding_model_name: str = EMBEDDING_MODEL_NAME):
        self.bm25_model = None
        self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
        self.document_chunks = []        # populated by initialize_models
        self.document_embeddings = None  # tensor of corpus embeddings
        self.reranker = Ranker()  # Initialize the FlashRank reranker
        print("FlashRank reranker initialized.")

    async def initialize_models(self, documents: list[Document]):
        """Build the BM25 index and precompute dense embeddings for *documents*."""
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to initialize. Skipping model setup.")
            return
        print("Initializing BM25 model...")
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25_model = BM25Okapi(tokenized_corpus)
        print("BM25 model initialized.")
        print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
        self.document_embeddings = self.embedding_client.get_embeddings(corpus)
        print("Document embeddings computed.")

    async def perform_hybrid_search(self, query: str, top_k: int) -> Tuple[List[dict], float]:
        """
        Performs a hybrid search, then reranks the results, and returns the top chunks
        along with the time taken for reranking.

        Returns:
            (chunks, rerank_seconds) where each chunk dict has "content",
            "document_metadata" and "rerank_score" keys.
        Raises:
            ValueError: if initialize_models has not been called.
        """
        if self.bm25_model is None or self.document_embeddings is None:
            raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
        print(f"Performing hybrid search for query: '{query}' (top_k={top_k})...")

        # Lexical scores for every chunk.
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25_model.get_scores(tokenized_query)

        # Dense scores: embed the query, cosine similarity against the corpus.
        query_embedding = self.embedding_client.get_embeddings([query])
        from torch.nn.functional import cosine_similarity
        dense_scores = cosine_similarity(query_embedding, self.document_embeddings)
        dense_scores = dense_scores.cpu().numpy()

        if len(bm25_scores) == 0 or len(dense_scores) == 0:
            return [], 0.0

        # Min-max normalize each score set so they can be blended 50/50.
        scaler = MinMaxScaler()
        normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores

        # Take INITIAL_K_CANDIDATES chunks as the reranking pool.
        ranked_indices = np.argsort(combined_scores)[::-1]
        top_initial_indices = ranked_indices[:INITIAL_K_CANDIDATES]

        retrieved_results = [
            {
                "content": self.document_chunks[idx].page_content,
                "document_metadata": self.document_chunks[idx].metadata,
            }
            for idx in top_initial_indices
        ]
        if not retrieved_results:
            return [], 0.0

        print(f"Retrieved {len(retrieved_results)} initial chunks. Starting reranking...")

        # --- Reranking Step with Timing ---
        # Tag each passage with its index so reranked results can be mapped back
        # by id: robust to duplicate chunk text (overlapping chunks can repeat
        # content) and O(n) instead of the O(n^2) text-equality scan.
        passages = [{"id": i, "text": chunk["content"]} for i, chunk in enumerate(retrieved_results)]

        start_time_rerank = time.perf_counter()
        # Run the CPU-bound reranker off the event loop.
        reranked_results = await asyncio.to_thread(
            self.reranker.rerank, RerankRequest(query=query, passages=passages)
        )
        rerank_time = time.perf_counter() - start_time_rerank

        # Rebuild our document format in reranked order, attaching the score.
        final_chunks = []
        for res in reranked_results:
            original_chunk_data = retrieved_results[res["id"]]
            final_chunks.append({
                "content": original_chunk_data["content"],
                "document_metadata": original_chunk_data["document_metadata"],
                "rerank_score": res["score"],
            })

        top_chunks = final_chunks[:top_k]
        print(f"Reranking completed in {rerank_time:.4f} seconds. Retrieved {len(top_chunks)} top chunks.")
        return top_chunks, rerank_time
145
+
146
 
147
+ # --- Helper Functions (remain unchanged) ---
148
  def process_markdown_with_manual_sections(
149
  md_file_path: str,
150
  headings_json: dict,
151
  chunk_size: int,
152
  chunk_overlap: int):
 
 
 
 
 
153
  all_chunks_with_metadata = []
154
  full_text = ""
155
  if not os.path.exists(md_file_path):
 
180
  heading_positions = []
181
  for heading in heading_texts:
182
  pattern = re.compile(r'\s*'.join(re.escape(word) for word in heading.split()), re.IGNORECASE)
 
183
  match = pattern.search(full_text)
184
  if match:
185
  heading_positions.append({"heading_text": heading, "start_index": match.start()})
 
187
  print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")
188
  heading_positions.sort(key=lambda x: x["start_index"])
189
  segments_with_headings = []
 
190
  if heading_positions and heading_positions[0]["start_index"] > 0:
191
  preface_text = full_text[:heading_positions[0]["start_index"]].strip()
192
  if preface_text:
 
197
  for i, current_heading_info in enumerate(heading_positions):
198
  start_index = current_heading_info["start_index"]
199
  heading_text = current_heading_info["heading_text"]
 
200
  end_index = len(full_text)
201
  if i + 1 < len(heading_positions):
202
  end_index = heading_positions[i+1]["start_index"]
203
  section_content = full_text[start_index:end_index].strip()
 
204
  if section_content:
205
  segments_with_headings.append({
206
  "section_heading": heading_text,