File size: 6,473 Bytes
d4b4587
84f4fa5
 
 
 
 
 
 
 
 
d4b4587
84f4fa5
 
 
 
 
 
 
 
 
d4b4587
84f4fa5
 
 
 
 
 
 
 
53d71b0
84f4fa5
 
 
 
 
d4b4587
5c24a84
84f4fa5
 
 
 
 
 
 
 
 
 
 
 
 
5c24a84
84f4fa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4b4587
 
84f4fa5
 
d4b4587
 
5c24a84
d4b4587
 
84f4fa5
 
d4b4587
84f4fa5
d4b4587
 
84f4fa5
 
 
 
d4b4587
 
84f4fa5
d4b4587
 
84f4fa5
 
d4b4587
 
 
84f4fa5
 
d4b4587
 
84f4fa5
d4b4587
 
 
84f4fa5
 
 
 
 
d4b4587
 
84f4fa5
d4b4587
 
 
84f4fa5
 
 
 
 
d4b4587
 
 
 
 
 
 
 
 
84f4fa5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# file: main3.py
import time
import os
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any
from dotenv import load_dotenv

from document_processor import ingest_and_parse_document
from chunking_parent import create_parent_child_chunks
from embedding import EmbeddingClient
from retrieval_parent import Retriever, generate_hypothetical_document
from generation import generate_answer

load_dotenv()

# --- Application setup ---
# Single FastAPI instance exposing the RAG endpoints below.
app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
    version="2.2.2", # Version updated
)

# Groq API key read at import time; None if the env var is unset, in which
# case downstream LLM calls will fail per-request rather than at startup.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
# Module-level singletons shared across requests: one embedding client and
# one retriever (re-indexed on each /hackrx/run call).
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)

# --- Pydantic Models ---
class RunRequest(BaseModel):
    """Request body for /hackrx/run: a document URL plus questions to answer against it."""
    documents: HttpUrl  # URL of the single source document to ingest
    questions: List[str]  # questions answered concurrently against the document

class RunResponse(BaseModel):
    """Response body for /hackrx/run: answers ordered to match the request's questions."""
    answers: List[str]  # one generated answer per input question, same order

class TestRequest(BaseModel):
    """Request body for /test/chunk: only the document URL is needed."""
    documents: HttpUrl  # URL of the document whose chunking is being tested

# --- NEW: Test Endpoint for Parent-Child Chunking ---
# --- NEW: Test Endpoint for Parent-Child Chunking ---
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Exercise the parent-child chunking strategy in isolation.

    Parses the document at the given URL, builds parent and child chunks,
    and returns both chunk sets together with the elapsed wall time.
    """
    print("--- Running Parent-Child Chunking Test ---")
    started = time.perf_counter()

    try:
        # Parse the remote document into markdown, then derive the
        # parent/child chunk hierarchy from it.
        markdown_content = await ingest_and_parse_document(request.documents)
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        # Timing covers parsing + chunking only, not serialization below.
        duration = time.perf_counter() - started
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        # Flatten Document objects into JSON-serializable dicts.
        child_chunk_results = [
            {"page_content": child.page_content, "metadata": child.metadata}
            for child in child_documents
        ]

        # Pull every parent back out of the in-memory docstore; mget may
        # yield None for missing keys, so filter those out.
        parent_docs = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [
            {"page_content": parent.page_content, "metadata": parent.metadata}
            for parent in parent_docs if parent
        ]

        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during chunking test: {str(e)}")


@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Run the full RAG pipeline for one document URL and a list of questions.

    Stages: parsing -> parent-child chunking -> indexing/embedding, followed
    by three concurrent fan-outs across the questions (HyDE generation,
    retrieval, answer generation). Per-stage wall-clock timings are printed.

    Raises:
        HTTPException(400): if the document yields no chunks.
        HTTPException(500): for any other unexpected failure.
    """
    try:
        total_pipeline_start_time = time.perf_counter()
        timings = {}
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION (Parsing) ---
        parse_start = time.perf_counter()
        markdown_content = await ingest_and_parse_document(request.documents)
        timings["1_parsing"] = time.perf_counter() - parse_start
        print(f"Time taken for Parsing: {timings['1_parsing']:.4f} seconds.")

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        chunk_start = time.perf_counter()
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
        timings["2_chunking"] = time.perf_counter() - chunk_start
        print(f"Time taken for Parent-Child Chunking: {timings['2_chunking']:.4f} seconds.")

        if not child_documents:
            # Client-side problem (unparseable/empty document) -> 400.
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING (Embedding) ---
        index_start = time.perf_counter()
        retriever.index(child_documents, docstore)
        timings["3_indexing_and_embedding"] = time.perf_counter() - index_start
        print(f"Time taken for Indexing (incl. Embeddings): {timings['3_indexing_and_embedding']:.4f} seconds.")

        # --- CONCURRENT WORKFLOW ---
        # Step A: Concurrently generate hypothetical documents (HyDE),
        # one per question, to improve retrieval recall.
        print("Generating hypothetical documents...")
        hyde_start = time.perf_counter()
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        timings["4_hyde_generation_total"] = time.perf_counter() - hyde_start
        print(f"Time taken for HyDE Generation (total): {timings['4_hyde_generation_total']:.4f} seconds.")

        # Step B: Concurrently retrieve relevant chunks; zip keeps each
        # question paired with its own hypothetical document.
        print("Retrieving chunks...")
        retrieval_start = time.perf_counter()
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
        timings["5_retrieval_total"] = time.perf_counter() - retrieval_start
        print(f"Time taken for Retrieval (total): {timings['5_retrieval_total']:.4f} seconds.")

        # Step C: Concurrently generate final answers from each question's
        # retrieved context; gather preserves question order.
        print("Generating final answers...")
        generation_start = time.perf_counter()
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)
        timings["6_answer_generation_total"] = time.perf_counter() - generation_start
        print(f"Time taken for Answer Generation (total): {timings['6_answer_generation_total']:.4f} seconds.")

        timings["total_pipeline_time"] = time.perf_counter() - total_pipeline_start_time
        print("\n--- RAG Pipeline Completed Successfully ---")
        print(f"--- Total Pipeline Time: {timings['total_pipeline_time']:.4f} seconds ---")
        print("--- Timing Breakdown ---")
        for stage, duration in timings.items():
            print(f"- {stage}: {duration:.4f} seconds")

        return RunResponse(answers=final_answers)

    except HTTPException:
        # BUG FIX: HTTPException subclasses Exception, so the generic handler
        # below used to catch our intentional 400 and re-raise it as a 500.
        # Propagate deliberate HTTP errors with their original status/detail.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        )