PercivalFletcher committed on
Commit
84f4fa5
·
verified ·
1 Parent(s): cb1c1dd

Upload 7 files

Browse files
Files changed (7) hide show
  1. chunking_parent.py +79 -0
  2. document_processor.py +88 -0
  3. embedding.py +40 -0
  4. generation.py +57 -0
  5. main3.py +123 -0
  6. requirements.txt +217 -0
  7. retrieval_parent.py +155 -0
chunking_parent.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: chunking.py
2
+ import uuid
3
+ from typing import List, Tuple, Dict, Any
4
+ from langchain_core.documents import Document
5
+ from langchain.storage import InMemoryStore
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+
8
+ # --- Configuration for Parent-Child Splitting ---
9
+ # Parent chunks are the larger documents passed to the LLM for context.
10
+ PARENT_CHUNK_SIZE = 2000
11
+ PARENT_CHUNK_OVERLAP = 200
12
+
13
+ # Child chunks are the smaller, more granular documents used for retrieval.
14
+ CHILD_CHUNK_SIZE = 400
15
+ CHILD_CHUNK_OVERLAP = 100
16
+
17
def create_parent_child_chunks(
    full_text: str
) -> Tuple[List[Document], InMemoryStore, Dict[str, str]]:
    """
    Split *full_text* using the Parent Document strategy.

    Large "parent" chunks are stored in an in-memory docstore and later
    handed to the LLM as context; small "child" chunks derived from those
    parents are what the retriever actually searches over.

    Args:
        full_text: The entire text content of the document.

    Returns:
        A tuple of:
        - the list of small "child" Documents for the vector store,
        - an InMemoryStore mapping parent IDs to parent Documents,
        - a dict mapping each child ID to its parent's ID.
    """
    if not full_text:
        print("Warning: Input text for chunking is empty.")
        return [], InMemoryStore(), {}

    print("Creating parent and child chunks...")

    # Splitter for the large, context-bearing documents.
    coarse_splitter = RecursiveCharacterTextSplitter(
        chunk_size=PARENT_CHUNK_SIZE,
        chunk_overlap=PARENT_CHUNK_OVERLAP,
    )
    # Splitter for the small, retrieval-granularity documents.
    fine_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHILD_CHUNK_SIZE,
        chunk_overlap=CHILD_CHUNK_OVERLAP,
    )

    parent_documents = coarse_splitter.create_documents([full_text])

    # Mint an ID per parent and persist them all in one batch.
    parent_ids = [str(uuid.uuid4()) for _ in parent_documents]
    docstore = InMemoryStore()
    docstore.mset(list(zip(parent_ids, parent_documents)))

    child_documents: List[Document] = []
    child_to_parent_id_map: Dict[str, str] = {}

    # Derive children from each parent, recording the lineage both on the
    # child metadata and in the id map.
    for parent_id, parent_doc in zip(parent_ids, parent_documents):
        for child_doc in fine_splitter.split_documents([parent_doc]):
            child_id = str(uuid.uuid4())
            child_doc.metadata["parent_id"] = parent_id
            child_doc.metadata["child_id"] = child_id
            child_to_parent_id_map[child_id] = parent_id
            child_documents.append(child_doc)

    print(f"Created {len(parent_documents)} parent chunks and {len(child_documents)} child chunks.")
    return child_documents, docstore, child_to_parent_id_map
document_processor.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: document_processing.py
2
+
3
+ import os
4
+ import time
5
+ import httpx
6
+ from pathlib import Path
7
+ from urllib.parse import urlparse, unquote
8
+ from llama_index.readers.file import PyMuPDFReader
9
+ from llama_index.core import Document as LlamaDocument
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from pydantic import HttpUrl
12
+ from typing import List
13
+
14
+ # Define the batch size for parallel processing
15
+ BATCH_SIZE = 25
16
+
17
def _process_page_batch(documents_batch: List[LlamaDocument]) -> str:
    """Concatenate the text of a batch of loaded pages into one string,
    separating pages with a blank line."""
    page_texts = (page.get_content() for page in documents_batch)
    return "\n\n".join(page_texts)
23
+
24
async def ingest_and_parse_document(doc_url: HttpUrl) -> str:
    """
    Asynchronously downloads a document, saves it locally, and parses it to
    text using PyMuPDFReader with parallel per-batch extraction.

    Args:
        doc_url: The Pydantic-validated URL of the document to process.

    Returns:
        A single string containing the document's extracted text.

    Raises:
        httpx.HTTPStatusError: If the download returns an HTTP error status.
        ValueError: If parsing yields no content.
    """
    print(f"Initiating download from: {doc_url}")
    LOCAL_STORAGE_DIR = "data/"
    os.makedirs(LOCAL_STORAGE_DIR, exist_ok=True)

    try:
        # Asynchronously download the document
        async with httpx.AsyncClient() as client:
            response = await client.get(str(doc_url), timeout=30.0, follow_redirects=True)
            response.raise_for_status()
            doc_bytes = response.content
        print("Download successful.")

        # Determine a valid local filename from the URL path
        parsed_path = urlparse(str(doc_url)).path
        filename = unquote(os.path.basename(parsed_path)) or "downloaded_document.pdf"
        local_file_path = Path(os.path.join(LOCAL_STORAGE_DIR, filename))

        # Save the document locally
        with open(local_file_path, "wb") as f:
            f.write(doc_bytes)
        print(f"Document saved locally at: {local_file_path}")

        # Parse the document using LlamaIndex's PyMuPDFReader
        print("Parsing document with PyMuPDFReader...")
        loader = PyMuPDFReader()
        docs_from_loader = loader.load_data(file_path=local_file_path)

        # Parallelize the extraction of text from loaded pages.
        # BUGFIX: results must be joined in *submission* order. The previous
        # as_completed() loop appended batches in completion order, which
        # could scramble page order in the final text. Iterating the futures
        # list directly preserves the original page sequence.
        start_time = time.perf_counter()
        with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
            futures = [
                executor.submit(_process_page_batch, docs_from_loader[i:i + BATCH_SIZE])
                for i in range(0, len(docs_from_loader), BATCH_SIZE)
            ]
            all_extracted_texts = [future.result() for future in futures]

        doc_text = "\n\n".join(all_extracted_texts)
        elapsed_time = time.perf_counter() - start_time
        print(f"Time taken for parallel text extraction: {elapsed_time:.4f} seconds.")

        if not doc_text:
            raise ValueError("Document parsing yielded no content.")

        print(f"Parsing complete. Extracted {len(doc_text)} characters.")
        return doc_text

    except httpx.HTTPStatusError as e:
        print(f"Error downloading document: {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred during document processing: {e}")
        raise
embedding.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: embedding.py
2
+
3
+ import torch
4
+ from sentence_transformers import SentenceTransformer
5
+ from typing import List
6
+
7
# --- Configuration ---
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

class EmbeddingClient:
    """A client for generating text embeddings using a local sentence transformer model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        # Prefer the GPU when one is available; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")

    def create_embeddings(self, texts: List[str]) -> torch.Tensor:
        """
        Embed a batch of text chunks.

        Args:
            texts: The strings to embed.

        Returns:
            A torch.Tensor of embeddings; an empty tensor for empty input.
        """
        if not texts:
            return torch.tensor([])

        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        try:
            result = self.model.encode(
                texts, convert_to_tensor=True, show_progress_bar=False
            )
            print("Embeddings generated successfully.")
            return result
        except Exception as err:
            print(f"An error occurred during embedding generation: {err}")
            raise
generation.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: generation.py
2
+ from groq import AsyncGroq
3
+ from typing import List, Dict
4
+
5
+ # --- Configuration ---
6
+ GROQ_MODEL_NAME = "llama3-8b-8192"
7
+
8
async def generate_answer(query: str, context_chunks: List[Dict], groq_api_key: str) -> str:
    """
    Generates a final answer using the Groq API based on the query and retrieved context.

    Args:
        query: The user's original question.
        context_chunks: A list of the most relevant, reranked document chunks;
            each dict must contain a "content" key.
        groq_api_key: The API key for the Groq service.

    Returns:
        The generated answer, or a human-readable message when the key is
        missing, the context is empty, or the API call fails.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set."
    if not context_chunks:
        return "I do not have enough information to answer this question based on the provided document."

    print("Generating final answer with Groq...")
    client = AsyncGroq(api_key=groq_api_key)

    # Format the context for the prompt
    context_str = "\n\n---\n\n".join(
        [f"Context Chunk:\n{chunk['content']}" for chunk in context_chunks]
    )

    # BUGFIX: the adjacent string literals were previously concatenated
    # without separating spaces (producing "...answer.Do not mention..."),
    # garbling the instructions; each sentence now ends with a space.
    prompt = (
        "You are an expert Q&A system. Your task is to extract information with 100% accuracy from the provided text. Provide a brief and direct answer. "
        "Do not mention the context in your response. Answer *only* using the information from the provided document. "
        "Do not infer, add, or assume any information that is not explicitly written in the source text. If the answer is not in the document, state that the information is not available. "
        "When the question involves numbers, percentages, or monetary values, extract the exact figures from the text. "
        "Double-check that the value corresponds to the correct plan or condition mentioned in the question."
        "\n\n"
        f"CONTEXT:\n{context_str}\n\n"
        f"QUESTION:\n{query}\n\n"
        "ANSWER:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GROQ_MODEL_NAME,
            temperature=0.2,  # Lower temperature for more factual answers
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."
main3.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: main.py
2
+ import time
3
+ import os
4
+ import asyncio
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel, HttpUrl
7
+ from typing import List, Dict, Any
8
+ from dotenv import load_dotenv
9
+
10
+ from document_processor import ingest_and_parse_document
11
+ from chunking_parent import create_parent_child_chunks # <-- Using the parent-child function
12
+ from embedding import EmbeddingClient
13
+ from retrieval_parent import Retriever, generate_hypothetical_document
14
+ from generation import generate_answer
15
+
16
+ load_dotenv()
17
+
18
+ app = FastAPI(
19
+ title="Modular RAG API",
20
+ description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
21
+ version="2.2.1", # Updated version
22
+ )
23
+
24
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
25
+ embedding_client = EmbeddingClient()
26
+ retriever = Retriever(embedding_client=embedding_client)
27
+
28
# --- Pydantic Models ---
class RunRequest(BaseModel):
    """Request body for /hackrx/run: one document URL plus the questions to answer."""
    # Validated HTTP(S) URL of the document to ingest.
    document_url: HttpUrl
    # Questions to answer against the document.
    questions: List[str]

class RunResponse(BaseModel):
    """Response body for /hackrx/run."""
    # One answer per input question, in the same order as the request.
    answers: List[str]

class TestRequest(BaseModel):
    """Request body for the /test/* endpoints: only a document URL."""
    document_url: HttpUrl
38
+
39
+ # --- NEW: Test Endpoint for Parent-Child Chunking ---
40
+ @app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
41
+ async def test_chunking_endpoint(request: TestRequest):
42
+ """
43
+ Tests the parent-child chunking strategy.
44
+ Returns parent chunks, child chunks, and the time taken.
45
+ """
46
+ print("--- Running Parent-Child Chunking Test ---")
47
+ start_time = time.perf_counter()
48
+
49
+ try:
50
+ # Step 1: Parse the document to get raw text
51
+ markdown_content = await ingest_and_parse_document(request.document_url)
52
+
53
+ # Step 2: Create parent and child chunks
54
+ child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
55
+
56
+ end_time = time.perf_counter()
57
+ duration = end_time - start_time
58
+ print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")
59
+
60
+ # Convert Document objects to a JSON-serializable list for the response
61
+ child_chunk_results = [
62
+ {"page_content": doc.page_content, "metadata": doc.metadata}
63
+ for doc in child_documents
64
+ ]
65
+
66
+ # Retrieve parent documents from the in-memory store
67
+ parent_docs = docstore.mget(list(docstore.store.keys()))
68
+ parent_chunk_results = [
69
+ {"page_content": doc.page_content, "metadata": doc.metadata}
70
+ for doc in parent_docs if doc
71
+ ]
72
+
73
+ return {
74
+ "total_time_seconds": duration,
75
+ "parent_chunk_count": len(parent_chunk_results),
76
+ "child_chunk_count": len(child_chunk_results),
77
+ "parent_chunks": parent_chunk_results,
78
+ "child_chunks": child_chunk_results,
79
+ }
80
+ except Exception as e:
81
+ raise HTTPException(status_code=500, detail=f"An error occurred during chunking test: {str(e)}")
82
+
83
+
84
+ @app.post("/hackrx/run", response_model=RunResponse)
85
+ async def run_rag_pipeline(request: RunRequest):
86
+ try:
87
+ print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")
88
+
89
+ # --- STAGE 1: DOCUMENT INGESTION ---
90
+ markdown_content = await ingest_and_parse_document(request.document_url)
91
+
92
+ # --- STAGE 2: PARENT-CHILD CHUNKING ---
93
+ child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
94
+
95
+ if not child_documents:
96
+ raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")
97
+
98
+ # --- STAGE 3: INDEXING ---
99
+ retriever.index(child_documents, docstore)
100
+
101
+ # --- CONCURRENT WORKFLOW ---
102
+ hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
103
+ all_hyde_docs = await asyncio.gather(*hyde_tasks)
104
+
105
+ retrieval_tasks = [
106
+ retriever.retrieve(q, hyde_doc)
107
+ for q, hyde_doc in zip(request.questions, all_hyde_docs)
108
+ ]
109
+ all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
110
+
111
+ answer_tasks = [
112
+ generate_answer(q, chunks, GROQ_API_KEY)
113
+ for q, chunks in zip(request.questions, all_retrieved_chunks)
114
+ ]
115
+ final_answers = await asyncio.gather(*answer_tasks)
116
+
117
+ print("--- RAG Pipeline Completed Successfully ---")
118
+ return RunResponse(answers=final_answers)
119
+
120
+ except Exception as e:
121
+ raise HTTPException(
122
+ status_code=500, detail=f"An internal server error occurred: {str(e)}"
123
+ )
requirements.txt ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.10.0
2
+ aiofiles==24.1.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.12.15
5
+ aiosignal==1.4.0
6
+ aiosqlite==0.21.0
7
+ annotated-types==0.7.0
8
+ antlr4-python3-runtime==4.9.3
9
+ anyio==4.9.0
10
+ asyncio==3.4.3
11
+ attrs==25.3.0
12
+ backoff==2.2.1
13
+ banks==2.2.0
14
+ beautifulsoup4==4.13.4
15
+ bitarray==3.6.0
16
+ blinker==1.9.0
17
+ cachetools==5.5.2
18
+ catalogue==2.0.10
19
+ certifi==2025.8.3
20
+ cffi==1.17.1
21
+ charset-normalizer==3.4.2
22
+ click==8.2.1
23
+ colbert-ai==0.2.21
24
+ colorama==0.4.6
25
+ coloredlogs==15.0.1
26
+ contourpy==1.3.3
27
+ cryptography==45.0.6
28
+ cycler==0.12.1
29
+ dataclasses-json==0.6.7
30
+ datasets==4.0.0
31
+ defusedxml==0.7.1
32
+ deprecated==1.2.18
33
+ dill==0.3.8
34
+ dirtyjson==1.0.8
35
+ distro==1.9.0
36
+ dotenv==0.9.9
37
+ effdet==0.4.1
38
+ emoji==2.14.1
39
+ et-xmlfile==2.0.0
40
+ faiss-cpu==1.11.0.post1
41
+ fast-pytorch-kmeans==0.2.2
42
+ fastapi==0.116.1
43
+ filelock==3.18.0
44
+ filetype==1.2.0
45
+ flashrank==0.2.10
46
+ flask==3.1.1
47
+ flatbuffers==25.2.10
48
+ fonttools==4.59.0
49
+ frozenlist==1.7.0
50
+ fsspec==2025.3.0
51
+ git-python==1.0.3
52
+ gitdb==4.0.12
53
+ gitpython==3.1.45
54
+ google-api-core==2.25.1
55
+ google-auth==2.40.3
56
+ google-cloud-vision==3.10.2
57
+ googleapis-common-protos==1.70.0
58
+ greenlet==3.2.3
59
+ griffe==1.9.0
60
+ groq==0.30.0
61
+ grpcio==1.74.0
62
+ grpcio-status==1.74.0
63
+ h11==0.16.0
64
+ hf-xet==1.1.5
65
+ html5lib==1.1
66
+ httpcore==1.0.9
67
+ httptools==0.6.4
68
+ httpx==0.28.1
69
+ httpx-sse==0.4.1
70
+ huggingface-hub==0.34.3
71
+ humanfriendly==10.0
72
+ idna==3.10
73
+ itsdangerous==2.2.0
74
+ jinja2==3.1.6
75
+ jiter==0.10.0
76
+ joblib==1.5.1
77
+ jsonpatch==1.33
78
+ jsonpointer==3.0.0
79
+ kiwisolver==1.4.8
80
+ langchain==0.3.27
81
+ langchain-community==0.3.27
82
+ langchain-core==0.3.72
83
+ langchain-pymupdf4llm==0.4.1
84
+ langchain-text-splitters==0.3.9
85
+ langdetect==1.0.9
86
+ langsmith==0.4.10
87
+ llama-cloud==0.1.35
88
+ llama-cloud-services==0.6.53
89
+ llama-index==0.13.0
90
+ llama-index-cli==0.5.0
91
+ llama-index-core==0.13.0
92
+ llama-index-embeddings-openai==0.5.0
93
+ llama-index-indices-managed-llama-cloud==0.9.0
94
+ llama-index-instrumentation==0.4.0
95
+ llama-index-llms-openai==0.5.1
96
+ llama-index-readers-file==0.5.0
97
+ llama-index-readers-llama-parse==0.5.0
98
+ llama-index-workflows==1.2.0
99
+ llama-parse==0.6.53
100
+ lxml==6.0.0
101
+ markdown==3.8.2
102
+ markupsafe==3.0.2
103
+ marshmallow==3.26.1
104
+ matplotlib==3.10.5
105
+ mpmath==1.3.0
106
+ msoffcrypto-tool==5.4.2
107
+ multidict==6.6.3
108
+ multiprocess==0.70.16
109
+ mypy-extensions==1.1.0
110
+ nest-asyncio==1.6.0
111
+ networkx==3.5
112
+ ninja==1.11.1.4
113
+ nltk==3.9.1
114
+ numpy==2.3.2
115
+ olefile==0.47
116
+ omegaconf==2.3.0
117
+ onnx==1.18.0
118
+ onnxruntime==1.22.1
119
+ openai==1.99.3
120
+ opencv-python==4.11.0.86
121
+ openpyxl==3.1.5
122
+ orjson==3.11.1
123
+ packaging==25.0
124
+ pandas==2.2.3
125
+ pdf2image==1.17.0
126
+ pdfminer-six==20250506
127
+ pi-heif==1.1.0
128
+ pikepdf==9.10.2
129
+ pillow==11.3.0
130
+ pinecone-client==6.0.0
131
+ pinecone-plugin-interface==0.0.7
132
+ platformdirs==4.3.8
133
+ propcache==0.3.2
134
+ proto-plus==1.26.1
135
+ protobuf==6.31.1
136
+ psutil==7.0.0
137
+ pyarrow==21.0.0
138
+ pyasn1==0.6.1
139
+ pyasn1-modules==0.4.2
140
+ pycocotools==2.0.10
141
+ pycparser==2.22
142
+ pydantic==2.11.7
143
+ pydantic-core==2.33.2
144
+ pydantic-settings==2.10.1
145
+ pymupdf==1.26.3
146
+ pymupdf4llm==0.0.27
147
+ pypandoc==1.15
148
+ pyparsing==3.2.3
149
+ pypdf==5.9.0
150
+ pypdf2==3.0.1
151
+ pypdfium2==4.30.0
152
+ pytesseract==0.3.13
153
+ python-dateutil==2.9.0.post0
154
+ python-docx==1.2.0
155
+ python-dotenv==1.1.1
156
+ python-iso639==2025.2.18
157
+ python-magic==0.4.27
158
+ python-multipart==0.0.20
159
+ python-oxmsg==0.0.2
160
+ python-pptx==1.0.2
161
+ pytz==2025.2
162
+ pyyaml==6.0.2
163
+ ragatouille==0.0.9.post2
164
+ rank-bm25==0.2.2
165
+ rapidfuzz==3.13.0
166
+ regex==2025.7.34
167
+ requests==2.32.4
168
+ requests-toolbelt==1.0.0
169
+ rsa==4.9.1
170
+ safetensors==0.5.3
171
+ scikit-learn==1.7.1
172
+ scipy==1.16.1
173
+ sentence-transformers==5.0.0
174
+ setuptools==80.9.0
175
+ six==1.17.0
176
+ smmap==5.0.2
177
+ sniffio==1.3.1
178
+ soupsieve==2.7
179
+ sqlalchemy==2.0.42
180
+ srsly==2.5.1
181
+ starlette==0.47.2
182
+ striprtf==0.0.26
183
+ sympy==1.14.0
184
+ tenacity==9.1.2
185
+ tesseract==0.1.3
186
+ threadpoolctl==3.6.0
187
+ tiktoken==0.9.0
188
+ timm==1.0.19
189
+ tokenizers==0.21.4
190
+ torch==2.7.1
191
+ torchvision==0.22.1
192
+ tqdm==4.67.1
193
+ transformers==4.49.0
194
+ typing==3.7.4.3
195
+ typing-extensions==4.14.1
196
+ typing-inspect==0.9.0
197
+ typing-inspection==0.4.1
198
+ tzdata==2025.2
199
+ ujson==5.10.0
200
+ unstructured==0.18.11
201
+ unstructured-client==0.42.2
202
+ unstructured-inference==1.0.5
203
+ unstructured-pytesseract==0.3.15
204
+ urllib3==2.5.0
205
+ uvicorn==0.35.0
206
+ uvloop==0.21.0
207
+ voyager==2.1.0
208
+ watchfiles==1.1.0
209
+ webencodings==0.5.1
210
+ websockets==15.0.1
211
+ werkzeug==3.1.3
212
+ wrapt==1.17.2
213
+ xlrd==2.0.2
214
+ xlsxwriter==3.2.5
215
+ xxhash==3.5.0
216
+ yarl==1.20.1
217
+ zstandard==0.23.0
retrieval_parent.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: retrieval.py
2
+
3
+ import time
4
+ import asyncio
5
+ import numpy as np
6
+ import torch
7
+ from groq import AsyncGroq
8
+ from rank_bm25 import BM25Okapi
9
+ from sentence_transformers import CrossEncoder
10
+ from sklearn.preprocessing import MinMaxScaler
11
+ from torch.nn.functional import cosine_similarity
12
+ from typing import List, Dict, Tuple
13
+ from langchain.storage import InMemoryStore
14
+
15
+ from embedding import EmbeddingClient
16
+ from langchain_core.documents import Document
17
+
18
+ # --- Configuration ---
19
+ HYDE_MODEL = "llama3-8b-8192"
20
+ RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
21
+ INITIAL_K_CANDIDATES = 20
22
+ TOP_K_CHUNKS = 10
23
+
24
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a short hypothetical answer passage (HyDE) for *query*.

    Returns an empty string when no API key is configured or the Groq call
    fails, so callers can fall back to plain-query retrieval.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")

    hyde_prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )
    groq_client = AsyncGroq(api_key=groq_api_key)

    try:
        completion = await groq_client.chat.completions.create(
            messages=[{"role": "user", "content": hyde_prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
    except Exception as exc:
        print(f"An error occurred during HyDE generation: {exc}")
        return ""
    return completion.choices[0].message.content
51
+
52
+
53
class Retriever:
    """Manages hybrid (BM25 + dense cosine) search over child chunks with
    cross-encoder reranking, resolving results to their parent documents."""

    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        # The reranker runs on the same device the embedding client selected.
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None  # BM25 index over child chunks; built in index()
        self.document_chunks = []  # child Documents; positions align with index scores
        self.chunk_embeddings = None  # dense embeddings of the child chunks
        self.docstore = InMemoryStore()  # parent-id -> parent Document; replaced in index()
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, child_documents: List[Document], docstore: InMemoryStore):
        """Builds the search index from child documents and stores parent documents."""
        self.document_chunks = child_documents  # keep child docs for index -> doc mapping
        self.docstore = docstore  # parent documents, keyed by parent_id

        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing child documents for retrieval...")
        # Simple whitespace tokenization; BM25Okapi only needs token lists.
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Score every child chunk with an equal-weight mix of normalized BM25
        and dense cosine similarity; return (chunk_index, combined_score) for
        the top INITIAL_K_CANDIDATES, best first.

        Raises:
            ValueError: If index() has not been called yet.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        # NOTE: the HyDE passage only augments the dense leg; BM25 scores the
        # raw query tokens.
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        # Normalize each score vector to [0, 1] before the 50/50 blend so
        # neither signal dominates on raw scale.
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        # argsort ascending, reversed -> highest combined score first.
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]


    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Rerank candidate chunks with the cross-encoder and return the top
        TOP_K_CHUNKS, each annotated in place with a 'rerank_score'."""
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        # CrossEncoder.predict is synchronous; run it in a worker thread so
        # the event loop stays responsive.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score

        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]


    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline and returns parent documents."""
        print(f"Retrieving documents for query: '{query}'")

        # 1. Hybrid search returns indices of the best CHILD documents
        initial_candidates_info = self._hybrid_search(query, hyde_doc)

        retrieved_child_docs = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
        } for idx, score in initial_candidates_info]

        # 2. Rerank the CHILD documents
        reranked_child_docs = await self._rerank(query, retrieved_child_docs)

        # 3. Get the unique parent IDs from the reranked child documents
        # (insertion order preserved, so the best-ranked child decides
        # the order of its parent).
        parent_ids = []
        for doc in reranked_child_docs:
            parent_id = doc["metadata"]["parent_id"]
            if parent_id not in parent_ids:
                parent_ids.append(parent_id)

        # 4. Retrieve the full PARENT documents from the docstore
        retrieved_parents = self.docstore.mget(parent_ids)

        # Filter out any None results in case of a miss
        final_parent_docs = [doc for doc in retrieved_parents if doc is not None]

        # 5. Format for the generation step
        final_context = [{
            "content": doc.page_content,
            "metadata": doc.metadata
        } for doc in final_parent_docs]

        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context