Spaces:
Sleeping
Sleeping
Update rag_utils.py
Browse files- rag_utils.py +105 -108
rag_utils.py
CHANGED
|
@@ -5,25 +5,60 @@ from groq import AsyncGroq
|
|
| 5 |
import json
|
| 6 |
import re
|
| 7 |
from rank_bm25 import BM25Okapi
|
| 8 |
-
from sentence_transformers import SentenceTransformer
|
| 9 |
from sklearn.preprocessing import MinMaxScaler
|
| 10 |
import numpy as np
|
| 11 |
from typing import Any, List, Tuple
|
| 12 |
import asyncio
|
| 13 |
import torch
|
| 14 |
import time
|
| 15 |
-
from flashrank import Ranker, RerankRequest # Import the FlashRank library
|
| 16 |
|
| 17 |
# --- Configuration (can be overridden by the calling app) ---
|
| 18 |
CHUNK_SIZE = 1000
|
| 19 |
CHUNK_OVERLAP = 200
|
| 20 |
-
TOP_K_CHUNKS =
|
| 21 |
# A larger number of initial candidates for reranking
|
| 22 |
-
INITIAL_K_CANDIDATES = 20
|
| 23 |
GROQ_MODEL_NAME = "llama3-8b-8192"
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
-
# --- Class for managing the Sentence Transformer model ---
|
| 27 |
class EmbeddingClient:
|
| 28 |
"""A client for generating text embeddings using a local, open-source model."""
|
| 29 |
def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
|
|
@@ -51,8 +86,9 @@ class HybridSearchManager:
|
|
| 51 |
self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
|
| 52 |
self.document_chunks = []
|
| 53 |
self.document_embeddings = None
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
async def initialize_models(self, documents: list[Document]):
|
| 58 |
self.document_chunks = documents
|
|
@@ -67,90 +103,85 @@ class HybridSearchManager:
|
|
| 67 |
print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
|
| 68 |
self.document_embeddings = self.embedding_client.get_embeddings(corpus)
|
| 69 |
print("Document embeddings computed.")
|
| 70 |
-
|
| 71 |
-
async def
|
| 72 |
"""
|
| 73 |
-
Performs a hybrid search
|
| 74 |
-
|
| 75 |
"""
|
| 76 |
if self.bm25_model is None or self.document_embeddings is None:
|
| 77 |
raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
|
| 78 |
-
print(f"Performing hybrid search for query: '{query}'
|
| 79 |
|
| 80 |
-
|
| 81 |
tokenized_query = query.split(" ")
|
| 82 |
bm25_scores = self.bm25_model.get_scores(tokenized_query)
|
| 83 |
-
|
| 84 |
-
query_embedding = self.embedding_client.get_embeddings([query])
|
| 85 |
from torch.nn.functional import cosine_similarity
|
| 86 |
dense_scores = cosine_similarity(query_embedding, self.document_embeddings)
|
| 87 |
dense_scores = dense_scores.cpu().numpy()
|
| 88 |
|
| 89 |
if len(bm25_scores) == 0 or len(dense_scores) == 0:
|
| 90 |
-
return []
|
| 91 |
|
| 92 |
scaler = MinMaxScaler()
|
| 93 |
normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
|
| 94 |
normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
|
| 95 |
combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores
|
| 96 |
-
|
| 97 |
-
# We now get `INITIAL_K_CANDIDATES` documents from the combined search
|
| 98 |
ranked_indices = np.argsort(combined_scores)[::-1]
|
| 99 |
top_initial_indices = ranked_indices[:INITIAL_K_CANDIDATES]
|
| 100 |
-
|
| 101 |
retrieved_results = []
|
| 102 |
for idx in top_initial_indices:
|
| 103 |
doc = self.document_chunks[idx]
|
| 104 |
retrieved_results.append({
|
| 105 |
"content": doc.page_content,
|
| 106 |
-
"document_metadata": doc.metadata
|
|
|
|
| 107 |
})
|
| 108 |
-
|
| 109 |
-
print(f"Retrieved {len(retrieved_results)} initial
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
| 113 |
if not retrieved_results:
|
| 114 |
-
return []
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
self.reranker.
|
| 122 |
)
|
| 123 |
|
| 124 |
end_time_rerank = time.perf_counter()
|
| 125 |
rerank_time = end_time_rerank - start_time_rerank
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
| 128 |
final_chunks = []
|
| 129 |
-
for res in
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
)
|
| 135 |
-
if original_chunk_data:
|
| 136 |
-
final_chunks.append({
|
| 137 |
-
"content": original_chunk_data["content"],
|
| 138 |
-
"document_metadata": original_chunk_data["document_metadata"],
|
| 139 |
-
"rerank_score": res["score"]
|
| 140 |
-
})
|
| 141 |
-
|
| 142 |
-
# Return the top_k reranked chunks and the timing information
|
| 143 |
-
print(f"Reranking completed in {rerank_time:.4f} seconds. Retrieved {len(final_chunks[:top_k])} top chunks.")
|
| 144 |
-
return final_chunks[:top_k], rerank_time
|
| 145 |
|
|
|
|
|
|
|
| 146 |
|
| 147 |
-
# ---
|
| 148 |
-
def
|
| 149 |
md_file_path: str,
|
| 150 |
-
headings_json: dict,
|
| 151 |
chunk_size: int,
|
| 152 |
-
chunk_overlap: int):
|
| 153 |
-
|
| 154 |
full_text = ""
|
| 155 |
if not os.path.exists(md_file_path):
|
| 156 |
print(f"Error: File not found at '{md_file_path}'")
|
|
@@ -169,57 +200,22 @@ def process_markdown_with_manual_sections(
|
|
| 169 |
if not full_text:
|
| 170 |
print("Input markdown file is empty.")
|
| 171 |
return []
|
|
|
|
| 172 |
text_splitter = RecursiveCharacterTextSplitter(
|
| 173 |
chunk_size=chunk_size,
|
| 174 |
chunk_overlap=chunk_overlap,
|
| 175 |
length_function=len,
|
| 176 |
is_separator_regex=False,
|
| 177 |
)
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
for
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")
|
| 188 |
-
heading_positions.sort(key=lambda x: x["start_index"])
|
| 189 |
-
segments_with_headings = []
|
| 190 |
-
if heading_positions and heading_positions[0]["start_index"] > 0:
|
| 191 |
-
preface_text = full_text[:heading_positions[0]["start_index"]].strip()
|
| 192 |
-
if preface_text:
|
| 193 |
-
segments_with_headings.append({
|
| 194 |
-
"section_heading": "Document Start/Preface",
|
| 195 |
-
"section_text": preface_text
|
| 196 |
-
})
|
| 197 |
-
for i, current_heading_info in enumerate(heading_positions):
|
| 198 |
-
start_index = current_heading_info["start_index"]
|
| 199 |
-
heading_text = current_heading_info["heading_text"]
|
| 200 |
-
end_index = len(full_text)
|
| 201 |
-
if i + 1 < len(heading_positions):
|
| 202 |
-
end_index = heading_positions[i+1]["start_index"]
|
| 203 |
-
section_content = full_text[start_index:end_index].strip()
|
| 204 |
-
if section_content:
|
| 205 |
-
segments_with_headings.append({
|
| 206 |
-
"section_heading": heading_text,
|
| 207 |
-
"section_text": section_content
|
| 208 |
-
})
|
| 209 |
-
print(f"Created {len(segments_with_headings)} segments based on provided headings.")
|
| 210 |
-
for segment in segments_with_headings:
|
| 211 |
-
section_heading = segment["section_heading"]
|
| 212 |
-
section_text = segment["section_text"]
|
| 213 |
-
if section_text:
|
| 214 |
-
chunks = text_splitter.split_text(section_text)
|
| 215 |
-
for chunk in chunks:
|
| 216 |
-
metadata = {
|
| 217 |
-
"document_part": "Section",
|
| 218 |
-
"section_heading": section_heading,
|
| 219 |
-
}
|
| 220 |
-
all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata))
|
| 221 |
-
print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.")
|
| 222 |
-
return all_chunks_with_metadata
|
| 223 |
|
| 224 |
async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
|
| 225 |
"""
|
|
@@ -244,12 +240,13 @@ async def generate_answer_with_groq(query: str, retrieved_results: list[dict], g
|
|
| 244 |
)
|
| 245 |
context = "\n\n".join(context_parts)
|
| 246 |
prompt = (
|
| 247 |
-
f"You are
|
| 248 |
-
f"
|
| 249 |
-
f"
|
| 250 |
-
f"
|
| 251 |
-
f"
|
| 252 |
-
f"
|
|
|
|
| 253 |
f"Context:\n{context}\n\n"
|
| 254 |
f"Question: {query}\n\n"
|
| 255 |
f"Answer:"
|
|
@@ -271,4 +268,4 @@ async def generate_answer_with_groq(query: str, retrieved_results: list[dict], g
|
|
| 271 |
return answer
|
| 272 |
except Exception as e:
|
| 273 |
print(f"An error occurred during Groq API call: {e}")
|
| 274 |
-
return "Could not generate an answer due to an API error."
|
|
|
|
| 5 |
import json
|
| 6 |
import re
|
| 7 |
from rank_bm25 import BM25Okapi
|
| 8 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder # Added CrossEncoder
|
| 9 |
from sklearn.preprocessing import MinMaxScaler
|
| 10 |
import numpy as np
|
| 11 |
from typing import Any, List, Tuple
|
| 12 |
import asyncio
|
| 13 |
import torch
|
| 14 |
import time
|
|
|
|
| 15 |
|
| 16 |
# --- Configuration (can be overridden by the calling app) ---
|
| 17 |
CHUNK_SIZE = 1000
|
| 18 |
CHUNK_OVERLAP = 200
|
| 19 |
+
TOP_K_CHUNKS = 10 # The final number of chunks to send to the LLM
|
| 20 |
# A larger number of initial candidates for reranking
|
| 21 |
+
INITIAL_K_CANDIDATES = 20
|
| 22 |
GROQ_MODEL_NAME = "llama3-8b-8192"
|
| 23 |
+
HYDE_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
|
| 24 |
+
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
|
| 25 |
+
|
| 26 |
+
# --- Hypothetical Document Generation and EmbeddingClient remain unchanged ---
|
| 27 |
+
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a hypothetical answer passage via the Groq API (HyDE).

    The synthetic passage imitates the tone of a formal document section so
    that its embedding lands near genuinely relevant chunks during dense
    retrieval.  Returns an empty string when no API key is supplied or the
    API call fails, so callers can fall back to plain-query search.
    """
    # HyDE is strictly best-effort: without credentials we skip it silently.
    if not groq_api_key:
        print("Groq API key not set. Skipping hypothetical document generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    groq_client = AsyncGroq(api_key=groq_api_key)
    hyde_prompt = (
        f"You are a document writer. Your task is to write a brief passage as a section of a document "
        f"that could answer the following question. The passage should use specific terminology and "
        f"a formal tone, as if it were an excerpt from a larger document. Do not include the question, "
        f"and do not add any conversational text. The goal is to create a concise, semantically rich text "
        f"to guide a search engine to find similarly styled and detailed content.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Section:"
    )

    try:
        completion = await groq_client.chat.completions.create(
            messages=[{"role": "user", "content": hyde_prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        hyde_text = completion.choices[0].message.content
        print("Hypothetical document generated.")
        return hyde_text
    except Exception as e:
        # A failed HyDE call only degrades retrieval quality, never aborts it.
        print(f"An error occurred during HyDE generation: {e}")
        return ""
|
| 61 |
|
|
|
|
| 62 |
class EmbeddingClient:
|
| 63 |
"""A client for generating text embeddings using a local, open-source model."""
|
| 64 |
def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
|
|
|
|
| 86 |
self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
|
| 87 |
self.document_chunks = []
|
| 88 |
self.document_embeddings = None
|
| 89 |
+
# Initialize BGE reranker model
|
| 90 |
+
self.reranker = CrossEncoder('BAAI/bge-reranker-base', device='cuda' if torch.cuda.is_available() else 'cpu')
|
| 91 |
+
print("BGE Reranker initialized.")
|
| 92 |
|
| 93 |
async def initialize_models(self, documents: list[Document]):
|
| 94 |
self.document_chunks = documents
|
|
|
|
| 103 |
print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
|
| 104 |
self.document_embeddings = self.embedding_client.get_embeddings(corpus)
|
| 105 |
print("Document embeddings computed.")
|
| 106 |
+
|
| 107 |
+
async def retrieve_candidates(self, query: str, hyde_doc: str) -> List[dict]:
    """Run a HyDE-enhanced hybrid (BM25 + dense) search and return the top
    INITIAL_K_CANDIDATES chunks, without reranking.

    Sparse (BM25) scores are computed from the raw query; dense scores are
    computed from the query concatenated with the hypothetical document when
    one is supplied.  Both score sets are min-max normalised and blended
    50/50 before ranking.

    Raises:
        ValueError: If initialize_models() has not been called yet.
    """
    if self.bm25_model is None or self.document_embeddings is None:
        raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
    print(f"Performing hybrid search for candidate retrieval for query: '{query}'...")

    # Dense retrieval embeds query + HyDE passage; BM25 stays on the raw query.
    hyde_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
    sparse_scores = self.bm25_model.get_scores(query.split(" "))

    query_vector = self.embedding_client.get_embeddings([hyde_query])
    from torch.nn.functional import cosine_similarity
    semantic_scores = cosine_similarity(query_vector, self.document_embeddings).cpu().numpy()

    if len(sparse_scores) == 0 or len(semantic_scores) == 0:
        return []

    # Put both score sets on a 0..1 scale so the 50/50 blend is meaningful.
    scaler = MinMaxScaler()
    sparse_norm = scaler.fit_transform(sparse_scores.reshape(-1, 1)).flatten()
    dense_norm = scaler.fit_transform(semantic_scores.reshape(-1, 1)).flatten()
    fused_scores = 0.5 * sparse_norm + 0.5 * dense_norm

    # Take a generous candidate pool; the reranker narrows it down later.
    top_indices = np.argsort(fused_scores)[::-1][:INITIAL_K_CANDIDATES]

    retrieved_results = [
        {
            "content": self.document_chunks[i].page_content,
            "document_metadata": self.document_chunks[i].metadata,
            "initial_score": fused_scores[i],  # fused score kept for inspection
        }
        for i in top_indices
    ]

    print(f"Retrieved {len(retrieved_results)} initial candidates for reranking.")
    return retrieved_results
|
| 146 |
+
|
| 147 |
+
async def rerank_results(self, query: str, retrieved_results: List[dict], top_k: int) -> Tuple[List[dict], float]:
    """Rerank candidate chunks with the cross-encoder and keep the best top_k.

    Args:
        query: The user query the candidates were retrieved for.
        retrieved_results: Candidate dicts carrying "content" and
            "document_metadata" keys (as produced by retrieve_candidates).
        top_k: Number of highest-scoring chunks to return.

    Returns:
        A (chunks, seconds) tuple: the top_k chunks, each augmented with a
        "rerank_score", and the wall-clock time spent reranking.
    """
    if not retrieved_results:
        # Bug fix: previously returned a bare [] here while the normal path
        # returns a (chunks, time) tuple, which broke callers that unpack
        # the result.  Always return the pair.
        return [], 0.0

    print(f"Reranking {len(retrieved_results)} candidates for query: '{query}'...")
    start_time_rerank = time.perf_counter()

    # CrossEncoder.predict is a blocking call; run it in a worker thread so
    # the event loop stays responsive.
    rerank_input = [[query, chunk["content"]] for chunk in retrieved_results]
    rerank_scores = await asyncio.to_thread(
        self.reranker.predict, rerank_input, show_progress_bar=False
    )

    rerank_time = time.perf_counter() - start_time_rerank

    # Pair each candidate with its cross-encoder score, best first.
    scored_results = sorted(
        zip(retrieved_results, rerank_scores), key=lambda pair: pair[1], reverse=True
    )

    final_chunks = [
        {
            "content": res["content"],
            "document_metadata": res["document_metadata"],
            "rerank_score": score,
        }
        for res, score in scored_results[:top_k]
    ]

    print(f"Reranking completed in {rerank_time:.4f} seconds. Returning top {len(final_chunks)} chunks.")
    return final_chunks, rerank_time
|
| 178 |
|
| 179 |
+
# --- Other helper functions (process_markdown_with_recursive_chunking, generate_answer_with_groq) remain unchanged ---
|
| 180 |
+
def process_markdown_with_recursive_chunking(
|
| 181 |
md_file_path: str,
|
|
|
|
| 182 |
chunk_size: int,
|
| 183 |
+
chunk_overlap: int) -> List[Document]:
|
| 184 |
+
all_chunks = []
|
| 185 |
full_text = ""
|
| 186 |
if not os.path.exists(md_file_path):
|
| 187 |
print(f"Error: File not found at '{md_file_path}'")
|
|
|
|
| 200 |
if not full_text:
|
| 201 |
print("Input markdown file is empty.")
|
| 202 |
return []
|
| 203 |
+
|
| 204 |
text_splitter = RecursiveCharacterTextSplitter(
|
| 205 |
chunk_size=chunk_size,
|
| 206 |
chunk_overlap=chunk_overlap,
|
| 207 |
length_function=len,
|
| 208 |
is_separator_regex=False,
|
| 209 |
)
|
| 210 |
+
|
| 211 |
+
chunks = text_splitter.split_text(full_text)
|
| 212 |
+
|
| 213 |
+
for chunk in chunks:
|
| 214 |
+
all_chunks.append(Document(page_content=chunk, metadata={"document_part": "Whole Document"}))
|
| 215 |
+
|
| 216 |
+
print(f"Created {len(all_chunks)} chunks from the entire document.")
|
| 217 |
+
return all_chunks
|
| 218 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
|
| 221 |
"""
|
|
|
|
| 240 |
)
|
| 241 |
context = "\n\n".join(context_parts)
|
| 242 |
prompt = (
|
| 243 |
+
f"You are an expert on the provided document. Your task is to answer the user's question "
|
| 244 |
+
f"based only on the information given. Your answers should be brief, concise, and in a similar style to these examples:\n"
|
| 245 |
+
f"- Yes, outpatient consultations and diagnostic tests are covered, provided they are medically necessary and fall within the specified sub-limits under the plan.\n"
|
| 246 |
+
f"- The policy does not cover any expenses incurred during the first 30 days from the inception of the policy, except in the case of accidents.\n"
|
| 247 |
+
f"- Room rent is covered up to a single private AC room per day unless otherwise specified in the policy schedule.\n"
|
| 248 |
+
f"- Yes, the policy allows for mid-term inclusion of newly married spouses and newborn children, subject to notification and payment of additional premium within the stipulated time frame.\n"
|
| 249 |
+
f"Do not mention or refer to the document or the context in your final answer. If the information required to answer the question is not available in the provided context, state that you do not have enough information.\n\n"
|
| 250 |
f"Context:\n{context}\n\n"
|
| 251 |
f"Question: {query}\n\n"
|
| 252 |
f"Answer:"
|
|
|
|
| 268 |
return answer
|
| 269 |
except Exception as e:
|
| 270 |
print(f"An error occurred during Groq API call: {e}")
|
| 271 |
+
return "Could not generate an answer due to an API error."
|