# file: main3.py
"""FastAPI entry point for a modular RAG (Retrieval-Augmented Generation) API.

Endpoints:
- POST /test/chunk  : diagnostic endpoint that runs only the parent-child
  chunking stage and returns the resulting chunks plus timing.
- POST /hackrx/run  : the full pipeline — parse -> chunk -> index -> HyDE ->
  retrieve -> answer — with per-stage timings printed to stdout.
"""

import asyncio
import os
import time
from typing import Any, Dict, List

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl

from chunking_parent import create_parent_child_chunks
from document_processor import ingest_and_parse_document
from embedding import EmbeddingClient
from generation import generate_answer
from retrieval_parent import Retriever, generate_hypothetical_document

load_dotenv()

app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
    version="2.2.2",  # Version updated
)

# Read once at import time; downstream HyDE/answer-generation calls receive
# this key directly. NOTE(review): no startup check — if unset, failures
# surface later inside the pipeline rather than at boot.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Shared per-process singletons: one embedding client, one retriever.
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)


# --- Pydantic Models ---
class RunRequest(BaseModel):
    documents: HttpUrl
    questions: List[str]


class RunResponse(BaseModel):
    answers: List[str]


class TestRequest(BaseModel):
    documents: HttpUrl


# --- NEW: Test Endpoint for Parent-Child Chunking ---
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests the parent-child chunking strategy.
    Returns parent chunks, child chunks, and the time taken.
    """
    print("--- Running Parent-Child Chunking Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: Parse the document to get raw text
        markdown_content = await ingest_and_parse_document(request.documents)

        # Step 2: Create parent and child chunks
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        # Convert Document objects to a JSON-serializable list for the response
        child_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in child_documents
        ]

        # Retrieve parent documents from the in-memory store; mget may yield
        # None for missing keys, so filter those out.
        parent_docs = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in parent_docs
            if doc
        ]

        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during chunking test: {str(e)}",
        ) from e


@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """Run the full RAG pipeline over one document and a list of questions.

    Stages: parse the document, build parent-child chunks, index/embed the
    child chunks, then concurrently (per question) generate HyDE documents,
    retrieve relevant chunks, and generate final answers.

    Raises:
        HTTPException 400: if the document yields no chunks.
        HTTPException 500: on any other internal failure.
    """
    try:
        total_pipeline_start_time = time.perf_counter()
        timings = {}
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION (Parsing) ---
        parse_start = time.perf_counter()
        markdown_content = await ingest_and_parse_document(request.documents)
        timings["1_parsing"] = time.perf_counter() - parse_start
        print(f"Time taken for Parsing: {timings['1_parsing']:.4f} seconds.")

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        chunk_start = time.perf_counter()
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
        timings["2_chunking"] = time.perf_counter() - chunk_start
        print(f"Time taken for Parent-Child Chunking: {timings['2_chunking']:.4f} seconds.")

        if not child_documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING (Embedding) ---
        index_start = time.perf_counter()
        retriever.index(child_documents, docstore)
        timings["3_indexing_and_embedding"] = time.perf_counter() - index_start
        print(f"Time taken for Indexing (incl. Embeddings): {timings['3_indexing_and_embedding']:.4f} seconds.")

        # --- CONCURRENT WORKFLOW ---
        # Step A: Concurrently generate hypothetical documents (HyDE)
        print("Generating hypothetical documents...")
        hyde_start = time.perf_counter()
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        timings["4_hyde_generation_total"] = time.perf_counter() - hyde_start
        print(f"Time taken for HyDE Generation (total): {timings['4_hyde_generation_total']:.4f} seconds.")

        # Step B: Concurrently retrieve relevant chunks
        print("Retrieving chunks...")
        retrieval_start = time.perf_counter()
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
        timings["5_retrieval_total"] = time.perf_counter() - retrieval_start
        print(f"Time taken for Retrieval (total): {timings['5_retrieval_total']:.4f} seconds.")

        # Step C: Concurrently generate final answers
        print("Generating final answers...")
        generation_start = time.perf_counter()
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)
        timings["6_answer_generation_total"] = time.perf_counter() - generation_start
        print(f"Time taken for Answer Generation (total): {timings['6_answer_generation_total']:.4f} seconds.")

        timings["total_pipeline_time"] = time.perf_counter() - total_pipeline_start_time

        print("\n--- RAG Pipeline Completed Successfully ---")
        print(f"--- Total Pipeline Time: {timings['total_pipeline_time']:.4f} seconds ---")
        print("--- Timing Breakdown ---")
        for stage, duration in timings.items():
            print(f"- {stage}: {duration:.4f} seconds")

        return RunResponse(answers=final_answers)
    except HTTPException:
        # BUGFIX: re-raise deliberate HTTP errors (e.g. the 400 above) so they
        # are not masked as 500s by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        ) from e