Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- chunking_parent.py +79 -0
- document_processor.py +88 -0
- embedding.py +40 -0
- generation.py +57 -0
- main3.py +123 -0
- requirements.txt +217 -0
- retrieval_parent.py +155 -0
chunking_parent.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file: chunking.py
|
| 2 |
+
import uuid
|
| 3 |
+
from typing import List, Tuple, Dict, Any
|
| 4 |
+
from langchain_core.documents import Document
|
| 5 |
+
from langchain.storage import InMemoryStore
|
| 6 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 7 |
+
|
| 8 |
+
# --- Configuration for Parent-Child Splitting ---
# Parent chunks are the larger documents handed to the LLM as context.
PARENT_CHUNK_SIZE = 2000
PARENT_CHUNK_OVERLAP = 200

# Child chunks are the smaller, more granular documents used for retrieval.
CHILD_CHUNK_SIZE = 400
CHILD_CHUNK_OVERLAP = 100


def create_parent_child_chunks(
    full_text: str
) -> Tuple[List[Document], InMemoryStore, Dict[str, str]]:
    """
    Implements the Parent Document strategy for chunking.

    The document is first split into large "parent" chunks, and each parent
    is then split again into small "child" chunks.  Children are what the
    vector store indexes (precise retrieval); parents are what the LLM
    ultimately reads (broad context).

    Args:
        full_text: The entire text content of the document.

    Returns:
        A tuple of:
        - the small "child" documents destined for the vector store,
        - an in-memory store mapping parent IDs to parent documents,
        - a dict mapping each child ID to its parent's ID.
    """
    if not full_text:
        print("Warning: Input text for chunking is empty.")
        return [], InMemoryStore(), {}

    print("Creating parent and child chunks...")

    # Large chunks: stored whole and later served as LLM context.
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=PARENT_CHUNK_SIZE,
        chunk_overlap=PARENT_CHUNK_OVERLAP,
    )
    # Small chunks: the granular units used for similarity search.
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHILD_CHUNK_SIZE,
        chunk_overlap=CHILD_CHUNK_OVERLAP,
    )

    parent_documents = parent_splitter.create_documents([full_text])

    # Register every parent under a fresh UUID so children can point back to it.
    parent_ids = [str(uuid.uuid4()) for _ in parent_documents]
    docstore = InMemoryStore()
    docstore.mset(list(zip(parent_ids, parent_documents)))

    child_documents: List[Document] = []
    child_to_parent_id_map: Dict[str, str] = {}

    # Derive children from each parent and stamp the linkage metadata.
    for parent_id, parent_doc in zip(parent_ids, parent_documents):
        for child in child_splitter.split_documents([parent_doc]):
            child_id = str(uuid.uuid4())
            child.metadata["parent_id"] = parent_id
            child.metadata["child_id"] = child_id
            child_to_parent_id_map[child_id] = parent_id
            child_documents.append(child)

    print(f"Created {len(parent_documents)} parent chunks and {len(child_documents)} child chunks.")
    return child_documents, docstore, child_to_parent_id_map
|
document_processor.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file: document_processing.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import httpx
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from urllib.parse import urlparse, unquote
|
| 8 |
+
from llama_index.readers.file import PyMuPDFReader
|
| 9 |
+
from llama_index.core import Document as LlamaDocument
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 11 |
+
from pydantic import HttpUrl
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
# Define the batch size for parallel processing:
# number of loaded pages handed to each worker task during text extraction.
BATCH_SIZE = 25
|
| 16 |
+
|
| 17 |
+
def _process_page_batch(documents_batch: List[LlamaDocument]) -> str:
    """
    Join the text content of a batch of LlamaIndex Document objects into one
    string, separating page contents with a blank line.
    """
    return "\n\n".join(d.get_content() for d in documents_batch)
|
| 23 |
+
|
| 24 |
+
async def ingest_and_parse_document(doc_url: HttpUrl) -> str:
    """
    Asynchronously downloads a document, saves it locally, and parses it to
    text using PyMuPDFReader, extracting page text in parallel.

    Args:
        doc_url: The Pydantic-validated URL of the document to process.

    Returns:
        A single string containing the document's extracted text, with page
        batches joined in their original page order.

    Raises:
        httpx.HTTPStatusError: If the download fails with an HTTP error status.
        ValueError: If parsing yields no content.
    """
    print(f"Initiating download from: {doc_url}")
    LOCAL_STORAGE_DIR = "data/"
    os.makedirs(LOCAL_STORAGE_DIR, exist_ok=True)

    try:
        # Asynchronously download the document
        async with httpx.AsyncClient() as client:
            response = await client.get(str(doc_url), timeout=30.0, follow_redirects=True)
            response.raise_for_status()
            doc_bytes = response.content
            print("Download successful.")

        # Determine a valid local filename from the URL path.
        parsed_path = urlparse(str(doc_url)).path
        filename = unquote(os.path.basename(parsed_path)) or "downloaded_document.pdf"
        local_file_path = Path(os.path.join(LOCAL_STORAGE_DIR, filename))

        # Save the document locally so PyMuPDFReader can open it.
        with open(local_file_path, "wb") as f:
            f.write(doc_bytes)
        print(f"Document saved locally at: {local_file_path}")

        # Parse the document using LlamaIndex's PyMuPDFReader
        print("Parsing document with PyMuPDFReader...")
        loader = PyMuPDFReader()
        docs_from_loader = loader.load_data(file_path=local_file_path)

        # Parallelize text extraction across page batches.
        # BUG FIX: the previous implementation collected results with
        # as_completed(), which yields futures in *completion* order, so page
        # batches could be joined out of order and scramble the document text.
        # executor.map() preserves submission (page) order while still running
        # the batches concurrently.
        start_time = time.perf_counter()
        batches = [
            docs_from_loader[i:i + BATCH_SIZE]
            for i in range(0, len(docs_from_loader), BATCH_SIZE)
        ]
        with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
            all_extracted_texts = list(executor.map(_process_page_batch, batches))

        doc_text = "\n\n".join(all_extracted_texts)
        elapsed_time = time.perf_counter() - start_time
        print(f"Time taken for parallel text extraction: {elapsed_time:.4f} seconds.")

        if not doc_text:
            raise ValueError("Document parsing yielded no content.")

        print(f"Parsing complete. Extracted {len(doc_text)} characters.")
        return doc_text

    except httpx.HTTPStatusError as e:
        print(f"Error downloading document: {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred during document processing: {e}")
        raise
|
embedding.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file: embedding.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
# --- Configuration ---
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

class EmbeddingClient:
    """A client for generating text embeddings using a local sentence transformer model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        # Prefer a GPU when one is visible; fall back to CPU otherwise.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")

    def create_embeddings(self, texts: List[str]) -> torch.Tensor:
        """
        Generate embeddings for a list of text chunks.

        Args:
            texts: The strings to embed.

        Returns:
            A torch.Tensor of embeddings; an empty tensor for empty input.
        """
        # Guard clause: nothing to embed.
        if not texts:
            return torch.tensor([])

        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        try:
            vectors = self.model.encode(
                texts, convert_to_tensor=True, show_progress_bar=False
            )
        except Exception as e:
            print(f"An error occurred during embedding generation: {e}")
            raise
        print("Embeddings generated successfully.")
        return vectors
|
generation.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file: generation.py
|
| 2 |
+
from groq import AsyncGroq
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
|
| 5 |
+
# --- Configuration ---
GROQ_MODEL_NAME = "llama3-8b-8192"

async def generate_answer(query: str, context_chunks: List[Dict], groq_api_key: str) -> str:
    """
    Generates a final answer using the Groq API based on the query and retrieved context.

    Args:
        query: The user's original question.
        context_chunks: A list of the most relevant, reranked document chunks;
            each dict must carry the chunk text under the "content" key.
        groq_api_key: The API key for the Groq service.

    Returns:
        A string containing the generated answer, or a human-readable error
        message when the key is missing, the context is empty, or the API fails.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set."
    if not context_chunks:
        return "I do not have enough information to answer this question based on the provided document."

    print("Generating final answer with Groq...")
    client = AsyncGroq(api_key=groq_api_key)

    # Format the context for the prompt
    context_str = "\n\n---\n\n".join(
        [f"Context Chunk:\n{chunk['content']}" for chunk in context_chunks]
    )

    # BUG FIX: the original adjacent string literals were concatenated with no
    # separator, fusing sentences together ("...direct answer.Do not mention...").
    # A trailing space on each segment keeps the instructions readable to the model.
    prompt = (
        "You are an expert Q&A system. Your task is to extract information with 100% accuracy from the provided text. Provide a brief and direct answer. "
        "Do not mention the context in your response. Answer *only* using the information from the provided document. "
        "Do not infer, add, or assume any information that is not explicitly written in the source text. If the answer is not in the document, state that the information is not available. "
        "When the question involves numbers, percentages, or monetary values, extract the exact figures from the text. "
        "Double-check that the value corresponds to the correct plan or condition mentioned in the question."
        "\n\n"
        f"CONTEXT:\n{context_str}\n\n"
        f"QUESTION:\n{query}\n\n"
        "ANSWER:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GROQ_MODEL_NAME,
            temperature=0.2,  # Lower temperature for more factual answers
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."
|
main3.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file: main.py
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
import asyncio
|
| 5 |
+
from fastapi import FastAPI, HTTPException
|
| 6 |
+
from pydantic import BaseModel, HttpUrl
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
from document_processor import ingest_and_parse_document
|
| 11 |
+
from chunking_parent import create_parent_child_chunks # <-- Using the parent-child function
|
| 12 |
+
from embedding import EmbeddingClient
|
| 13 |
+
from retrieval_parent import Retriever, generate_hypothetical_document
|
| 14 |
+
from generation import generate_answer
|
| 15 |
+
|
| 16 |
+
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
    version="2.2.1",
)

# Module-level singletons, created once at import time and reused across
# requests: loading the embedding model per-request would be prohibitively
# slow.  NOTE(review): module state means concurrent requests share one
# retriever index — confirm single-document-at-a-time usage is intended.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")  # may be None; guarded downstream
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)
|
| 27 |
+
|
| 28 |
+
# --- Pydantic Models ---
|
| 29 |
+
class RunRequest(BaseModel):
    """Request body for /hackrx/run: one source document plus the questions to answer."""
    document_url: HttpUrl  # validated URL of the document to ingest
    questions: List[str]   # questions to answer against that document
|
| 32 |
+
|
| 33 |
+
class RunResponse(BaseModel):
    """Response body for /hackrx/run: answers in the same order as the questions."""
    answers: List[str]  # one generated answer per input question
|
| 35 |
+
|
| 36 |
+
class TestRequest(BaseModel):
    """Request body for the /test/* endpoints: just the document to process."""
    document_url: HttpUrl  # validated URL of the document under test
|
| 38 |
+
|
| 39 |
+
# --- Test Endpoint for Parent-Child Chunking ---
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Exercises the parent-child chunking strategy in isolation.

    Downloads and parses the document, runs the chunker, and reports both
    chunk sets plus the wall-clock time the two stages took.
    """
    print("--- Running Parent-Child Chunking Test ---")
    start_time = time.perf_counter()

    try:
        # Stage 1: raw text out of the document.
        markdown_content = await ingest_and_parse_document(request.document_url)

        # Stage 2: split into parent and child chunks.
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        duration = time.perf_counter() - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        def _serialize(doc):
            # Document objects are not JSON-serializable; flatten them.
            return {"page_content": doc.page_content, "metadata": doc.metadata}

        child_chunk_results = [_serialize(doc) for doc in child_documents]

        # Pull the parents back out of the in-memory store (skip any misses).
        parents = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [_serialize(doc) for doc in parents if doc]

        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during chunking test: {str(e)}")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Full RAG pipeline: ingest -> parent-child chunk -> index, then fan out
    HyDE generation, retrieval, and answer generation concurrently across
    all questions.

    Returns:
        RunResponse with one answer per question, in question order
        (asyncio.gather preserves input order).
    """
    try:
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION ---
        markdown_content = await ingest_and_parse_document(request.document_url)

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        if not child_documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING ---
        retriever.index(child_documents, docstore)

        # --- CONCURRENT WORKFLOW ---
        # One hypothetical (HyDE) document per question, generated in parallel.
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)

        # Parent-child retrieval for every (question, HyDE doc) pair.
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)

        # Final answers, again fanned out concurrently.
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)

        print("--- RAG Pipeline Completed Successfully ---")
        return RunResponse(answers=final_answers)

    except HTTPException:
        # BUG FIX: the deliberate 400 raised above was previously caught by
        # the generic handler below and re-raised as a 500.  Let HTTP errors
        # propagate with their intended status codes.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        )
|
requirements.txt
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.10.0
|
| 2 |
+
aiofiles==24.1.0
|
| 3 |
+
aiohappyeyeballs==2.6.1
|
| 4 |
+
aiohttp==3.12.15
|
| 5 |
+
aiosignal==1.4.0
|
| 6 |
+
aiosqlite==0.21.0
|
| 7 |
+
annotated-types==0.7.0
|
| 8 |
+
antlr4-python3-runtime==4.9.3
|
| 9 |
+
anyio==4.9.0
|
| 10 |
+
asyncio==3.4.3
|
| 11 |
+
attrs==25.3.0
|
| 12 |
+
backoff==2.2.1
|
| 13 |
+
banks==2.2.0
|
| 14 |
+
beautifulsoup4==4.13.4
|
| 15 |
+
bitarray==3.6.0
|
| 16 |
+
blinker==1.9.0
|
| 17 |
+
cachetools==5.5.2
|
| 18 |
+
catalogue==2.0.10
|
| 19 |
+
certifi==2025.8.3
|
| 20 |
+
cffi==1.17.1
|
| 21 |
+
charset-normalizer==3.4.2
|
| 22 |
+
click==8.2.1
|
| 23 |
+
colbert-ai==0.2.21
|
| 24 |
+
colorama==0.4.6
|
| 25 |
+
coloredlogs==15.0.1
|
| 26 |
+
contourpy==1.3.3
|
| 27 |
+
cryptography==45.0.6
|
| 28 |
+
cycler==0.12.1
|
| 29 |
+
dataclasses-json==0.6.7
|
| 30 |
+
datasets==4.0.0
|
| 31 |
+
defusedxml==0.7.1
|
| 32 |
+
deprecated==1.2.18
|
| 33 |
+
dill==0.3.8
|
| 34 |
+
dirtyjson==1.0.8
|
| 35 |
+
distro==1.9.0
|
| 36 |
+
dotenv==0.9.9
|
| 37 |
+
effdet==0.4.1
|
| 38 |
+
emoji==2.14.1
|
| 39 |
+
et-xmlfile==2.0.0
|
| 40 |
+
faiss-cpu==1.11.0.post1
|
| 41 |
+
fast-pytorch-kmeans==0.2.2
|
| 42 |
+
fastapi==0.116.1
|
| 43 |
+
filelock==3.18.0
|
| 44 |
+
filetype==1.2.0
|
| 45 |
+
flashrank==0.2.10
|
| 46 |
+
flask==3.1.1
|
| 47 |
+
flatbuffers==25.2.10
|
| 48 |
+
fonttools==4.59.0
|
| 49 |
+
frozenlist==1.7.0
|
| 50 |
+
fsspec==2025.3.0
|
| 51 |
+
git-python==1.0.3
|
| 52 |
+
gitdb==4.0.12
|
| 53 |
+
gitpython==3.1.45
|
| 54 |
+
google-api-core==2.25.1
|
| 55 |
+
google-auth==2.40.3
|
| 56 |
+
google-cloud-vision==3.10.2
|
| 57 |
+
googleapis-common-protos==1.70.0
|
| 58 |
+
greenlet==3.2.3
|
| 59 |
+
griffe==1.9.0
|
| 60 |
+
groq==0.30.0
|
| 61 |
+
grpcio==1.74.0
|
| 62 |
+
grpcio-status==1.74.0
|
| 63 |
+
h11==0.16.0
|
| 64 |
+
hf-xet==1.1.5
|
| 65 |
+
html5lib==1.1
|
| 66 |
+
httpcore==1.0.9
|
| 67 |
+
httptools==0.6.4
|
| 68 |
+
httpx==0.28.1
|
| 69 |
+
httpx-sse==0.4.1
|
| 70 |
+
huggingface-hub==0.34.3
|
| 71 |
+
humanfriendly==10.0
|
| 72 |
+
idna==3.10
|
| 73 |
+
itsdangerous==2.2.0
|
| 74 |
+
jinja2==3.1.6
|
| 75 |
+
jiter==0.10.0
|
| 76 |
+
joblib==1.5.1
|
| 77 |
+
jsonpatch==1.33
|
| 78 |
+
jsonpointer==3.0.0
|
| 79 |
+
kiwisolver==1.4.8
|
| 80 |
+
langchain==0.3.27
|
| 81 |
+
langchain-community==0.3.27
|
| 82 |
+
langchain-core==0.3.72
|
| 83 |
+
langchain-pymupdf4llm==0.4.1
|
| 84 |
+
langchain-text-splitters==0.3.9
|
| 85 |
+
langdetect==1.0.9
|
| 86 |
+
langsmith==0.4.10
|
| 87 |
+
llama-cloud==0.1.35
|
| 88 |
+
llama-cloud-services==0.6.53
|
| 89 |
+
llama-index==0.13.0
|
| 90 |
+
llama-index-cli==0.5.0
|
| 91 |
+
llama-index-core==0.13.0
|
| 92 |
+
llama-index-embeddings-openai==0.5.0
|
| 93 |
+
llama-index-indices-managed-llama-cloud==0.9.0
|
| 94 |
+
llama-index-instrumentation==0.4.0
|
| 95 |
+
llama-index-llms-openai==0.5.1
|
| 96 |
+
llama-index-readers-file==0.5.0
|
| 97 |
+
llama-index-readers-llama-parse==0.5.0
|
| 98 |
+
llama-index-workflows==1.2.0
|
| 99 |
+
llama-parse==0.6.53
|
| 100 |
+
lxml==6.0.0
|
| 101 |
+
markdown==3.8.2
|
| 102 |
+
markupsafe==3.0.2
|
| 103 |
+
marshmallow==3.26.1
|
| 104 |
+
matplotlib==3.10.5
|
| 105 |
+
mpmath==1.3.0
|
| 106 |
+
msoffcrypto-tool==5.4.2
|
| 107 |
+
multidict==6.6.3
|
| 108 |
+
multiprocess==0.70.16
|
| 109 |
+
mypy-extensions==1.1.0
|
| 110 |
+
nest-asyncio==1.6.0
|
| 111 |
+
networkx==3.5
|
| 112 |
+
ninja==1.11.1.4
|
| 113 |
+
nltk==3.9.1
|
| 114 |
+
numpy==2.3.2
|
| 115 |
+
olefile==0.47
|
| 116 |
+
omegaconf==2.3.0
|
| 117 |
+
onnx==1.18.0
|
| 118 |
+
onnxruntime==1.22.1
|
| 119 |
+
openai==1.99.3
|
| 120 |
+
opencv-python==4.11.0.86
|
| 121 |
+
openpyxl==3.1.5
|
| 122 |
+
orjson==3.11.1
|
| 123 |
+
packaging==25.0
|
| 124 |
+
pandas==2.2.3
|
| 125 |
+
pdf2image==1.17.0
|
| 126 |
+
pdfminer-six==20250506
|
| 127 |
+
pi-heif==1.1.0
|
| 128 |
+
pikepdf==9.10.2
|
| 129 |
+
pillow==11.3.0
|
| 130 |
+
pinecone-client==6.0.0
|
| 131 |
+
pinecone-plugin-interface==0.0.7
|
| 132 |
+
platformdirs==4.3.8
|
| 133 |
+
propcache==0.3.2
|
| 134 |
+
proto-plus==1.26.1
|
| 135 |
+
protobuf==6.31.1
|
| 136 |
+
psutil==7.0.0
|
| 137 |
+
pyarrow==21.0.0
|
| 138 |
+
pyasn1==0.6.1
|
| 139 |
+
pyasn1-modules==0.4.2
|
| 140 |
+
pycocotools==2.0.10
|
| 141 |
+
pycparser==2.22
|
| 142 |
+
pydantic==2.11.7
|
| 143 |
+
pydantic-core==2.33.2
|
| 144 |
+
pydantic-settings==2.10.1
|
| 145 |
+
pymupdf==1.26.3
|
| 146 |
+
pymupdf4llm==0.0.27
|
| 147 |
+
pypandoc==1.15
|
| 148 |
+
pyparsing==3.2.3
|
| 149 |
+
pypdf==5.9.0
|
| 150 |
+
pypdf2==3.0.1
|
| 151 |
+
pypdfium2==4.30.0
|
| 152 |
+
pytesseract==0.3.13
|
| 153 |
+
python-dateutil==2.9.0.post0
|
| 154 |
+
python-docx==1.2.0
|
| 155 |
+
python-dotenv==1.1.1
|
| 156 |
+
python-iso639==2025.2.18
|
| 157 |
+
python-magic==0.4.27
|
| 158 |
+
python-multipart==0.0.20
|
| 159 |
+
python-oxmsg==0.0.2
|
| 160 |
+
python-pptx==1.0.2
|
| 161 |
+
pytz==2025.2
|
| 162 |
+
pyyaml==6.0.2
|
| 163 |
+
ragatouille==0.0.9.post2
|
| 164 |
+
rank-bm25==0.2.2
|
| 165 |
+
rapidfuzz==3.13.0
|
| 166 |
+
regex==2025.7.34
|
| 167 |
+
requests==2.32.4
|
| 168 |
+
requests-toolbelt==1.0.0
|
| 169 |
+
rsa==4.9.1
|
| 170 |
+
safetensors==0.5.3
|
| 171 |
+
scikit-learn==1.7.1
|
| 172 |
+
scipy==1.16.1
|
| 173 |
+
sentence-transformers==5.0.0
|
| 174 |
+
setuptools==80.9.0
|
| 175 |
+
six==1.17.0
|
| 176 |
+
smmap==5.0.2
|
| 177 |
+
sniffio==1.3.1
|
| 178 |
+
soupsieve==2.7
|
| 179 |
+
sqlalchemy==2.0.42
|
| 180 |
+
srsly==2.5.1
|
| 181 |
+
starlette==0.47.2
|
| 182 |
+
striprtf==0.0.26
|
| 183 |
+
sympy==1.14.0
|
| 184 |
+
tenacity==9.1.2
|
| 185 |
+
tesseract==0.1.3
|
| 186 |
+
threadpoolctl==3.6.0
|
| 187 |
+
tiktoken==0.9.0
|
| 188 |
+
timm==1.0.19
|
| 189 |
+
tokenizers==0.21.4
|
| 190 |
+
torch==2.7.1
|
| 191 |
+
torchvision==0.22.1
|
| 192 |
+
tqdm==4.67.1
|
| 193 |
+
transformers==4.49.0
|
| 194 |
+
typing==3.7.4.3
|
| 195 |
+
typing-extensions==4.14.1
|
| 196 |
+
typing-inspect==0.9.0
|
| 197 |
+
typing-inspection==0.4.1
|
| 198 |
+
tzdata==2025.2
|
| 199 |
+
ujson==5.10.0
|
| 200 |
+
unstructured==0.18.11
|
| 201 |
+
unstructured-client==0.42.2
|
| 202 |
+
unstructured-inference==1.0.5
|
| 203 |
+
unstructured-pytesseract==0.3.15
|
| 204 |
+
urllib3==2.5.0
|
| 205 |
+
uvicorn==0.35.0
|
| 206 |
+
uvloop==0.21.0
|
| 207 |
+
voyager==2.1.0
|
| 208 |
+
watchfiles==1.1.0
|
| 209 |
+
webencodings==0.5.1
|
| 210 |
+
websockets==15.0.1
|
| 211 |
+
werkzeug==3.1.3
|
| 212 |
+
wrapt==1.17.2
|
| 213 |
+
xlrd==2.0.2
|
| 214 |
+
xlsxwriter==3.2.5
|
| 215 |
+
xxhash==3.5.0
|
| 216 |
+
yarl==1.20.1
|
| 217 |
+
zstandard==0.23.0
|
retrieval_parent.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file: retrieval.py
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
import asyncio
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from groq import AsyncGroq
|
| 8 |
+
from rank_bm25 import BM25Okapi
|
| 9 |
+
from sentence_transformers import CrossEncoder
|
| 10 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 11 |
+
from torch.nn.functional import cosine_similarity
|
| 12 |
+
from typing import List, Dict, Tuple
|
| 13 |
+
from langchain.storage import InMemoryStore
|
| 14 |
+
|
| 15 |
+
from embedding import EmbeddingClient
|
| 16 |
+
from langchain_core.documents import Document
|
| 17 |
+
|
| 18 |
+
# --- Configuration ---
HYDE_MODEL = "llama3-8b-8192"
RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10

async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """
    HyDE step: ask the LLM for a short hypothetical passage that answers the
    query.  The passage is later embedded alongside the query to improve
    dense retrieval.  Returns "" when the key is missing or the call fails.
    """
    # Without a key we simply skip HyDE; retrieval falls back to the raw query.
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        "Write a brief, formal passage that answers the following question. "
        "Use specific terminology as if it were from a larger document. "
        "Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        "Hypothetical Passage:"
    )
    request_messages = [{"role": "user", "content": prompt}]

    try:
        chat_completion = await client.chat.completions.create(
            messages=request_messages,
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Best-effort: a failed HyDE call degrades retrieval, never aborts it.
        print(f"An error occurred during HyDE generation: {e}")
        return ""
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Retriever:
    """Manages hybrid (BM25 + dense) search over child chunks with
    parent-document lookup to assemble the final LLM context.

    Lifecycle: construct once, call index() per document, then retrieve()
    per query.  NOTE(review): index() replaces instance state in place, so
    interleaved requests for different documents would race — confirm
    single-document-at-a-time usage is intended.
    """

    def __init__(self, embedding_client: EmbeddingClient):
        # Shared embedding client; its device also hosts the reranker.
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None  # BM25Okapi lexical index, built by index()
        self.document_chunks = []  # child Documents; positions align with index rows
        self.chunk_embeddings = None  # dense embeddings of the child corpus
        self.docstore = InMemoryStore()  # parent_id -> parent Document
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, child_documents: List[Document], docstore: InMemoryStore):
        """Builds the search index from child documents and stores parent documents."""
        self.document_chunks = child_documents  # Store child docs for index -> doc mapping
        self.docstore = docstore  # Store the parent documents for later lookup

        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing child documents for retrieval...")
        # Naive whitespace tokenization for BM25; embeddings handle semantics.
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Score all child chunks and return (chunk index, combined score)
        pairs for the top INITIAL_K_CANDIDATES.

        Lexical and dense scores are min-max normalized independently and
        averaged 50/50.  NOTE(review): BM25 scores the *raw* query while the
        dense side embeds query + HyDE passage — confirm this asymmetry is
        intentional.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        # Dense search uses the HyDE-enhanced query when one is available.
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        # Bring both score distributions onto [0, 1] before mixing.
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        # Best-first: argsort ascending, then reverse, then truncate.
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]


    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Cross-encoder rerank of candidate child chunks; returns the top
        TOP_K_CHUNKS dicts sorted by descending 'rerank_score' (added
        in place to each candidate)."""
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        # predict() is synchronous/CPU-or-GPU-bound; run it off the event loop.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score

        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]


    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline and returns parent documents.

        Child chunks are searched and reranked for precision; their parents
        are what actually gets returned as context for generation.
        """
        print(f"Retrieving documents for query: '{query}'")

        # 1. Hybrid search returns indices of the best CHILD documents
        initial_candidates_info = self._hybrid_search(query, hyde_doc)

        retrieved_child_docs = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
        } for idx, score in initial_candidates_info]

        # 2. Rerank the CHILD documents
        reranked_child_docs = await self._rerank(query, retrieved_child_docs)

        # 3. Get the unique parent IDs from the reranked child documents,
        #    preserving rerank order (first occurrence wins).
        parent_ids = []
        for doc in reranked_child_docs:
            parent_id = doc["metadata"]["parent_id"]
            if parent_id not in parent_ids:
                parent_ids.append(parent_id)

        # 4. Retrieve the full PARENT documents from the docstore
        retrieved_parents = self.docstore.mget(parent_ids)

        # Filter out any None results in case of a docstore miss.
        final_parent_docs = [doc for doc in retrieved_parents if doc is not None]

        # 5. Format for the generation step
        final_context = [{
            "content": doc.page_content,
            "metadata": doc.metadata
        } for doc in final_parent_docs]

        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context
|