Spaces:
Sleeping
Sleeping
File size: 12,157 Bytes
6e103a7 bfca738 6e103a7 f69ffb2 6e103a7 f69ffb2 6e103a7 bfca738 f69ffb2 bfca738 d32b359 8a06370 bfca738 6e103a7 f69ffb2 6e103a7 bfca738 4300923 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 f69ffb2 bfca738 6e103a7 f69ffb2 6e103a7 bfca738 6e103a7 bfca738 f69ffb2 bfca738 f69ffb2 bfca738 f69ffb2 bfca738 f69ffb2 bfca738 f69ffb2 bfca738 f69ffb2 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 6e103a7 bfca738 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 | import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from groq import AsyncGroq
import json
import re
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder # Added CrossEncoder
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from typing import Any, List, Tuple
import asyncio
import torch
import time
# --- Configuration (can be overridden by the calling app) ---

# Chunking parameters. process_markdown_with_recursive_chunking measures
# length with len(), so these are character counts when passed there.
CHUNK_SIZE: int = 1000
CHUNK_OVERLAP: int = 200

# Retrieval sizes: hybrid search keeps INITIAL_K_CANDIDATES chunks for the
# reranker; TOP_K_CHUNKS is the final number of chunks to send to the LLM.
INITIAL_K_CANDIDATES: int = 20
TOP_K_CHUNKS: int = 10

# Model identifiers.
GROQ_MODEL_NAME: str = "openai/gpt-oss-20b"  # Groq model used for answer generation
HYDE_MODEL: str = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # Groq model used for HyDE passages
EMBEDDING_MODEL_NAME: str = "all-MiniLM-L6-v2"  # SentenceTransformer model for dense embeddings
# --- Hypothetical document generation (HyDE) and embedding client ---
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """
    Generate a hypothetical document (HyDE passage) for ``query`` via the Groq API.

    The prompt is generic and does not require prior knowledge of the document
    style. HyDE is an optional enhancement, so every failure is reported with
    ``print`` and turned into an empty string rather than raised.

    Args:
        query: The user question the passage should answer.
        groq_api_key: Groq API key; a falsy value skips generation entirely.

    Returns:
        The generated passage with surrounding whitespace removed, or ``""``
        when no key is given, the API call fails, or the model returns no text.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping hypothetical document generation.")
        return ""
    print(f"Starting HyDE generation for query: '{query}'...")
    prompt = (
        "You are a document writer. Your task is to write a brief passage as a section of a document "
        "that could answer the following question. The passage should use specific terminology and "
        "a formal tone, as if it were an excerpt from a larger document. Do not include the question, "
        "and do not add any conversational text. The goal is to create a concise, semantically rich text "
        "to guide a search engine to find similarly styled and detailed content.\n\n"
        f"Question: {query}\n\n"
        "Hypothetical Section:"
    )
    try:
        # The async context manager closes the client's HTTP resources on exit,
        # which the original code never did.
        async with AsyncGroq(api_key=groq_api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=HYDE_MODEL,
                temperature=0.7,
                max_tokens=500,
            )
        # message.content may be None; normalise so the declared ``str`` return
        # holds and whitespace-only output is treated as "no passage".
        hyde_doc = (chat_completion.choices[0].message.content or "").strip()
        print("Hypothetical document generated.")
        return hyde_doc
    except Exception as e:  # best-effort: HyDE failure must not break retrieval
        print(f"An error occurred during HyDE generation: {e}")
        return ""
class EmbeddingClient:
    """Turn text into embedding tensors with a locally loaded SentenceTransformer.

    The model is placed on ``"cuda"`` when ``torch.cuda.is_available()`` is
    true and on ``"cpu"`` otherwise; the chosen name is kept in ``self.device``.
    """

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        gpu_present = torch.cuda.is_available()
        self.device = "cuda" if gpu_present else "cpu"
        print(f"Using device: {self.device}")
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Sentence Transformer embedding client initialized ({model_name}) on {self.device}.")

    def get_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Encode ``texts`` and return the result as a tensor.

        An empty input returns ``torch.tensor([])`` without calling the model.
        """
        if texts:
            print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
            vectors = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
            print("Embeddings generated successfully.")
            return vectors
        return torch.tensor([])
# --- Hybrid Search Manager Class ---
class HybridSearchManager:
    """
    Manages the initialization and execution of a hybrid search system
    combining BM25, dense vector search, and a fast cross-encoder reranker.

    Usage: call ``initialize_models`` once per corpus, then for each query
    ``retrieve_candidates`` followed by ``rerank_results``.
    """

    def __init__(self, embedding_model_name: str = EMBEDDING_MODEL_NAME):
        self.bm25_model = None
        self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
        self.document_chunks = []
        self.document_embeddings = None
        # MS MARCO MiniLM cross-encoder reranker, on the same device the
        # embedding client selected (both use torch.cuda.is_available()).
        self.reranker = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L6-v2',
            device=self.embedding_client.device,
        )
        print("ms-marco-MiniLM-L6-v2 Reranker initialized.")

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Lower-case ``text`` and split it into word tokens for BM25.

        Splitting on ``\\w+`` (rather than a single space) stops newlines and
        punctuation from gluing onto words, so "policy?" in a query matches
        "policy" in a chunk. Used for both corpus and query so they agree.
        """
        return re.findall(r"\w+", text.lower())

    async def initialize_models(self, documents: list[Document]):
        """Build the BM25 index and dense embeddings for ``documents``.

        State from any previous corpus is discarded first; with an empty
        ``documents`` list the manager is left uninitialized, so
        ``retrieve_candidates`` raises instead of indexing stale results.
        """
        self.document_chunks = documents
        # retrieve_candidates maps BM25/embedding rows back onto
        # self.document_chunks by position, so all three must describe the
        # same corpus.
        self.bm25_model = None
        self.document_embeddings = None
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to initialize. Skipping model setup.")
            return
        print("Initializing BM25 model...")
        # BM25Okapi divides by the average document length, so give token-less
        # chunks a placeholder token that no tokenized query can ever contain.
        tokenized_corpus = [self._tokenize(text) or [""] for text in corpus]
        self.bm25_model = BM25Okapi(tokenized_corpus)
        print("BM25 model initialized.")
        print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
        # Encoding the whole corpus is heavy compute; keep it off the event
        # loop, as rerank_results already does for the cross-encoder.
        self.document_embeddings = await asyncio.to_thread(
            self.embedding_client.get_embeddings, corpus
        )
        print("Document embeddings computed.")

    async def retrieve_candidates(self, query: str, hyde_doc: str) -> List[dict]:
        """
        Perform a HyDE-enhanced hybrid search and return up to
        ``INITIAL_K_CANDIDATES`` candidates, without reranking.

        BM25 scores the raw ``query``; the dense side embeds ``query`` joined
        with ``hyde_doc`` (when non-empty). Each score vector is min-max
        normalised and the two are averaged with equal weight.

        Returns:
            Dicts with ``content``, ``document_metadata`` and ``initial_score``,
            highest combined score first.

        Raises:
            ValueError: if ``initialize_models`` has not built the indexes.
        """
        if self.bm25_model is None or self.document_embeddings is None:
            raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")
        print(f"Performing hybrid search for candidate retrieval for query: '{query}'...")
        hyde_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        # Lexical matching uses only the user's words; HyDE text feeds the dense side.
        bm25_scores = self.bm25_model.get_scores(self._tokenize(query))
        query_embedding = self.embedding_client.get_embeddings([hyde_query])
        from torch.nn.functional import cosine_similarity
        # One similarity per stored chunk embedding.
        dense_scores = cosine_similarity(query_embedding, self.document_embeddings).cpu().numpy()
        if len(bm25_scores) == 0 or len(dense_scores) == 0:
            return []
        scaler = MinMaxScaler()  # fit_transform refits, so reuse is safe
        normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores
        top_initial_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        retrieved_results = [
            {
                "content": self.document_chunks[idx].page_content,
                "document_metadata": self.document_chunks[idx].metadata,
                "initial_score": combined_scores[idx],
            }
            for idx in top_initial_indices
        ]
        print(f"Retrieved {len(retrieved_results)} initial candidates for reranking.")
        return retrieved_results

    async def rerank_results(self, query: str, retrieved_results: List[dict], top_k: int) -> Tuple[List[dict], float]:
        """
        Rescore candidate chunks with the cross-encoder and keep the best ``top_k``.

        Returns:
            ``(final_chunks, rerank_time)``. ``final_chunks`` holds dicts with
            ``content``, ``document_metadata`` and ``rerank_score``, best first.
            ``rerank_time`` is the wall-clock seconds spent reranking. An empty
            candidate list yields ``([], 0.0)`` so callers can always unpack
            two values (previously a bare ``[]`` broke unpacking).
        """
        if not retrieved_results:
            return [], 0.0
        print(f"Reranking {len(retrieved_results)} candidates for query: '{query}'...")
        start_time_rerank = time.perf_counter()
        rerank_input = [[query, chunk["content"]] for chunk in retrieved_results]
        # CrossEncoder.predict is blocking compute; run it in a worker thread.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        rerank_time = time.perf_counter() - start_time_rerank
        scored_results = sorted(
            zip(retrieved_results, rerank_scores), key=lambda pair: pair[1], reverse=True
        )
        final_chunks = [
            {
                "content": res["content"],
                "document_metadata": res["document_metadata"],
                "rerank_score": score,
            }
            for res, score in scored_results[:top_k]
        ]
        print(f"Reranking completed in {rerank_time:.4f} seconds. Returning top {len(final_chunks)} chunks.")
        return final_chunks, rerank_time
# --- Helper functions: markdown chunking and Groq answer generation ---
def process_markdown_with_recursive_chunking(
    md_file_path: str,
    chunk_size: int,
    chunk_overlap: int) -> List[Document]:
    """
    Read a markdown file and split it into overlapping text chunks.

    Args:
        md_file_path: Path to the file. A non-``.md`` extension only prints a warning.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Overlap between consecutive chunks, in characters.

    Returns:
        One ``Document`` per chunk, each with
        ``metadata={"document_part": "Whole Document"}``. Returns ``[]`` (after
        printing a message) when the path is missing, is not a regular file,
        cannot be read or decoded as UTF-8, or contains only whitespace.
    """
    if not os.path.exists(md_file_path):
        print(f"Error: File not found at '{md_file_path}'")
        return []
    if not os.path.isfile(md_file_path):
        print(f"Error: Path '{md_file_path}' is not a file.")
        return []
    if not md_file_path.lower().endswith(".md"):
        print(f"Warning: File '{md_file_path}' does not have a .md extension.")
    try:
        with open(md_file_path, 'r', encoding='utf-8') as f:
            full_text = f.read()
    except (OSError, UnicodeDecodeError) as e:
        # Only I/O and decoding failures are expected here; anything else is a
        # bug and should surface instead of being silently turned into [].
        print(f"Error reading file '{md_file_path}': {e}")
        return []
    # A whitespace-only file has nothing to retrieve; report it as empty rather
    # than handing it to the splitter.
    if not full_text.strip():
        print("Input markdown file is empty.")
        return []
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    # A fresh metadata dict per chunk, so Documents never share mutable state.
    all_chunks = [
        Document(page_content=chunk, metadata={"document_part": "Whole Document"})
        for chunk in text_splitter.split_text(full_text)
    ]
    print(f"Created {len(all_chunks)} chunks from the entire document.")
    return all_chunks
async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
    """
    Generate an answer with the Groq API from the query and the retrieved chunks.

    Each chunk contributes its ``content`` plus the ``document_part`` and
    ``section_heading`` entries of its ``document_metadata`` (``"N/A"`` when
    absent) to a numbered context section of the prompt.

    Args:
        query: The user question.
        retrieved_results: Chunk dicts, e.g. as returned by the reranker.
        groq_api_key: Groq API key; a falsy value short-circuits with an error string.

    Returns:
        The model's answer text, ``""`` if the model returned no content, or a
        fixed error message when the key is missing or the API call fails.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."
    print("Generating answer with Groq API...")
    context_parts = []
    for i, res in enumerate(retrieved_results, start=1):
        content = res.get("content", "")
        # Tolerate an explicit None for metadata as well as a missing key.
        metadata = res.get("document_metadata") or {}
        section_heading = metadata.get("section_heading", "N/A")
        document_part = metadata.get("document_part", "N/A")
        context_parts.append(
            f"--- Context Chunk {i} ---\n"
            f"Document Part: {document_part}\n"
            f"Section Heading: {section_heading}\n"
            f"Content: {content}\n"
            "-------------------------"
        )
    context = "\n\n".join(context_parts)
    prompt = (
        "You are an expert on the provided document. Your task is to answer the user's question "
        "based only on the information given. Your answers should be brief, concise, and in a similar style to these examples:\n"
        "- Yes, outpatient consultations and diagnostic tests are covered, provided they are medically necessary and fall within the specified sub-limits under the plan.\n"
        "- The policy does not cover any expenses incurred during the first 30 days from the inception of the policy, except in the case of accidents.\n"
        "- Room rent is covered up to a single private AC room per day unless otherwise specified in the policy schedule.\n"
        "- Yes, the policy allows for mid-term inclusion of newly married spouses and newborn children, subject to notification and payment of additional premium within the stipulated time frame.\n"
        "Do not mention or refer to the document or the context in your final answer. If the information required to answer the question is not available in the provided context, state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )
    try:
        # The async context manager closes the client's HTTP resources on exit.
        async with AsyncGroq(api_key=groq_api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=GROQ_MODEL_NAME,
                temperature=0.7,
                max_tokens=500,
            )
        # message.content may be None; keep the declared ``str`` return type.
        answer = chat_completion.choices[0].message.content or ""
        print("Answer generated successfully.")
        return answer
    except Exception as e:  # top-level boundary: report and degrade gracefully
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."