# ParentHackRx / main3.py
# Last commit by PercivalFletcher: "Update main3.py" (5c24a84, verified)
# file: main3.py
import time
import os
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any
from dotenv import load_dotenv
from document_processor import ingest_and_parse_document
from chunking_parent import create_parent_child_chunks
from embedding import EmbeddingClient
from retrieval_parent import Retriever, generate_hypothetical_document
from generation import generate_answer
# Load variables from a .env file into os.environ before any configuration below reads them.
load_dotenv()

app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
    version="2.2.2",
)

# API key forwarded to HyDE generation and answer generation; None if the variable is unset.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Process-wide singletons: every request handled by this worker shares these objects.
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)
# --- Pydantic Models ---
class RunRequest(BaseModel):
    """Request body for ``/hackrx/run``: one document URL and the questions to answer about it."""

    # URL of the document to ingest, parse, chunk and index.
    documents: HttpUrl
    # Questions answered independently against that single document.
    questions: List[str]
class RunResponse(BaseModel):
    """Response body for ``/hackrx/run``."""

    # One answer per question, in the same order as RunRequest.questions.
    answers: List[str]
class TestRequest(BaseModel):
    """Request body for ``/test/chunk``: the document whose chunking should be inspected."""

    # URL of the document to parse and split into parent/child chunks.
    documents: HttpUrl
def _document_to_dict(doc) -> Dict[str, Any]:
    """Return a JSON-serializable view of a chunk: its ``page_content`` and ``metadata``."""
    return {"page_content": doc.page_content, "metadata": doc.metadata}


# --- Test endpoint for parent-child chunking ---
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests the parent-child chunking strategy.
    Returns parent chunks, child chunks, and the time taken.
    """
    # Doc notes (kept as comments so the OpenAPI description above is unchanged):
    #   - total_time_seconds covers parsing + chunking only, not serialization.
    #   - Raises HTTPException(500) wrapping any unexpected error; an
    #     HTTPException raised by the helpers is propagated unchanged.
    print("--- Running Parent-Child Chunking Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: parse the document into markdown text.
        markdown_content = await ingest_and_parse_document(request.documents)
        # Step 2: split into child chunks (returned directly) and parent chunks
        # (held in the returned docstore).
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        duration = time.perf_counter() - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        child_chunk_results = [_document_to_dict(doc) for doc in child_documents]

        # Fetch every stored parent; mget can presumably return None for
        # missing keys, so falsy entries are skipped.
        parent_docs = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [_document_to_dict(doc) for doc in parent_docs if doc]

        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except HTTPException:
        # Keep deliberate HTTP errors (status and detail) instead of masking them as 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An error occurred during chunking test: {str(e)}"
        ) from e
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """Answer every question in the request against the document at its URL.

    Stages: parse the document, split it into parent/child chunks, index the
    child chunks in the shared module-level ``retriever``. Then, concurrently
    for all questions: generate a HyDE hypothetical document, retrieve chunks,
    and generate the final answer. Per-stage timings are printed to stdout.

    Returns a RunResponse whose ``answers`` line up with ``request.questions``.

    Raises HTTPException 400 if the document yields no chunks; any other
    HTTPException is propagated unchanged; every other error becomes a 500.
    """
    try:
        total_pipeline_start_time = time.perf_counter()
        timings = {}
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION (Parsing) ---
        parse_start = time.perf_counter()
        markdown_content = await ingest_and_parse_document(request.documents)
        timings["1_parsing"] = time.perf_counter() - parse_start
        print(f"Time taken for Parsing: {timings['1_parsing']:.4f} seconds.")

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        chunk_start = time.perf_counter()
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
        timings["2_chunking"] = time.perf_counter() - chunk_start
        print(f"Time taken for Parent-Child Chunking: {timings['2_chunking']:.4f} seconds.")

        if not child_documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING (Embedding) ---
        # NOTE(review): `retriever` is a module-level singleton re-indexed on
        # every request, and this coroutine yields at the awaits below before
        # retrieving. If Retriever.index replaces its state (confirm), two
        # concurrent requests can retrieve from each other's document.
        # Consider a per-request Retriever or an asyncio.Lock.
        # NOTE(review): index() is not awaited; if it is synchronous (assumed),
        # it blocks the event loop while embeddings are computed.
        index_start = time.perf_counter()
        retriever.index(child_documents, docstore)
        timings["3_indexing_and_embedding"] = time.perf_counter() - index_start
        print(f"Time taken for Indexing (incl. Embeddings): {timings['3_indexing_and_embedding']:.4f} seconds.")

        # --- CONCURRENT WORKFLOW ---
        # asyncio.gather returns results in input order, so zip() below keeps
        # questions, HyDE docs, retrieved chunks and answers aligned.

        # Step A: Concurrently generate hypothetical documents (HyDE)
        print("Generating hypothetical documents...")
        hyde_start = time.perf_counter()
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        timings["4_hyde_generation_total"] = time.perf_counter() - hyde_start
        print(f"Time taken for HyDE Generation (total): {timings['4_hyde_generation_total']:.4f} seconds.")

        # Step B: Concurrently retrieve relevant chunks
        print("Retrieving chunks...")
        retrieval_start = time.perf_counter()
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
        timings["5_retrieval_total"] = time.perf_counter() - retrieval_start
        print(f"Time taken for Retrieval (total): {timings['5_retrieval_total']:.4f} seconds.")

        # Step C: Concurrently generate final answers
        print("Generating final answers...")
        generation_start = time.perf_counter()
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)
        timings["6_answer_generation_total"] = time.perf_counter() - generation_start
        print(f"Time taken for Answer Generation (total): {timings['6_answer_generation_total']:.4f} seconds.")

        timings["total_pipeline_time"] = time.perf_counter() - total_pipeline_start_time
        print("\n--- RAG Pipeline Completed Successfully ---")
        print(f"--- Total Pipeline Time: {timings['total_pipeline_time']:.4f} seconds ---")
        print("--- Timing Breakdown ---")
        for stage, duration in timings.items():
            print(f"- {stage}: {duration:.4f} seconds")

        return RunResponse(answers=final_answers)
    except HTTPException:
        # Without this clause the 400 above (or any HTTPException from a helper)
        # would be caught by the generic handler and rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        ) from e