# file: main3.py
import time
import os
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any
from dotenv import load_dotenv
from document_processor import ingest_and_parse_document
from chunking_parent import create_parent_child_chunks
from embedding import EmbeddingClient
from retrieval_parent import Retriever, generate_hypothetical_document
from generation import generate_answer
# Load secrets (e.g. GROQ_API_KEY) from a local .env file, if present.
load_dotenv()

app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
    version="2.2.2",  # Version updated
)

# Groq API key used for HyDE and answer generation; may be None if unset.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Shared, module-level clients constructed once at startup and reused
# across requests (each /hackrx/run call re-indexes via `retriever.index`).
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)
# --- Pydantic Models ---
class RunRequest(BaseModel):
    """Request body for /hackrx/run: one document URL plus the questions to answer."""

    # URL of the source document to ingest (validated as an HTTP(S) URL).
    documents: HttpUrl
    # Natural-language questions to be answered against the document.
    questions: List[str]
class RunResponse(BaseModel):
    """Response body for /hackrx/run: one answer per question, in request order."""

    # Generated answers, aligned positionally with RunRequest.questions.
    answers: List[str]
class TestRequest(BaseModel):
    """Request body for /test/chunk: just the document URL to chunk."""

    # URL of the document to parse and chunk (validated as an HTTP(S) URL).
    documents: HttpUrl
# --- NEW: Test Endpoint for Parent-Child Chunking ---
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests the parent-child chunking strategy on a single document.

    Parses the document at ``request.documents``, runs the parent-child
    splitter, and returns both chunk sets along with the elapsed time for
    parsing + chunking (result serialization is not included in the timing).

    Raises:
        HTTPException: 500 with the underlying error message if parsing or
            chunking fails; any explicitly raised HTTPException propagates
            with its original status code.
    """
    print("--- Running Parent-Child Chunking Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: Parse the document to get raw text
        markdown_content = await ingest_and_parse_document(request.documents)
        # Step 2: Create parent and child chunks
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")
        # Convert Document objects to a JSON-serializable list for the response
        child_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in child_documents
        ]
        # Retrieve parent documents from the in-memory store; mget can return
        # None entries, so filter them out below.
        parent_docs = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in parent_docs if doc
        ]
        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except HTTPException:
        # Let explicit HTTP errors pass through with their intended status
        # code instead of being re-wrapped as a generic 500 below.
        raise
    except Exception as e:
        # Chain the cause (`from e`) so server logs keep the real traceback.
        raise HTTPException(status_code=500, detail=f"An error occurred during chunking test: {str(e)}") from e
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Runs the full RAG pipeline for one document and a batch of questions.

    Sequential stages: parse -> parent-child chunk -> index/embed. Then, for
    all questions concurrently (``asyncio.gather``): HyDE generation ->
    retrieval -> answer generation. Per-stage wall-clock timings are printed
    for profiling.

    Raises:
        HTTPException: 400 if the document yields no chunks; 500 for any
            other failure, with the original error message in the detail.
    """
    try:
        total_pipeline_start_time = time.perf_counter()
        timings = {}
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION (Parsing) ---
        parse_start = time.perf_counter()
        markdown_content = await ingest_and_parse_document(request.documents)
        timings["1_parsing"] = time.perf_counter() - parse_start
        print(f"Time taken for Parsing: {timings['1_parsing']:.4f} seconds.")

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        chunk_start = time.perf_counter()
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
        timings["2_chunking"] = time.perf_counter() - chunk_start
        print(f"Time taken for Parent-Child Chunking: {timings['2_chunking']:.4f} seconds.")

        if not child_documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING (Embedding) ---
        index_start = time.perf_counter()
        retriever.index(child_documents, docstore)
        timings["3_indexing_and_embedding"] = time.perf_counter() - index_start
        print(f"Time taken for Indexing (incl. Embeddings): {timings['3_indexing_and_embedding']:.4f} seconds.")

        # --- CONCURRENT WORKFLOW ---
        # Step A: Concurrently generate hypothetical documents (HyDE)
        print("Generating hypothetical documents...")
        hyde_start = time.perf_counter()
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        timings["4_hyde_generation_total"] = time.perf_counter() - hyde_start
        print(f"Time taken for HyDE Generation (total): {timings['4_hyde_generation_total']:.4f} seconds.")

        # Step B: Concurrently retrieve relevant chunks (zip keeps each
        # question paired with its own hypothetical document)
        print("Retrieving chunks...")
        retrieval_start = time.perf_counter()
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
        timings["5_retrieval_total"] = time.perf_counter() - retrieval_start
        print(f"Time taken for Retrieval (total): {timings['5_retrieval_total']:.4f} seconds.")

        # Step C: Concurrently generate final answers
        print("Generating final answers...")
        generation_start = time.perf_counter()
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)
        timings["6_answer_generation_total"] = time.perf_counter() - generation_start
        print(f"Time taken for Answer Generation (total): {timings['6_answer_generation_total']:.4f} seconds.")

        timings["total_pipeline_time"] = time.perf_counter() - total_pipeline_start_time
        print("\n--- RAG Pipeline Completed Successfully ---")
        print(f"--- Total Pipeline Time: {timings['total_pipeline_time']:.4f} seconds ---")
        print("--- Timing Breakdown ---")
        for stage, duration in timings.items():
            print(f"- {stage}: {duration:.4f} seconds")

        return RunResponse(answers=final_answers)
    except HTTPException:
        # BUG FIX: the 400 raised above was previously swallowed by the
        # generic handler and re-wrapped as a 500. Re-raise HTTP errors
        # unchanged so clients see the intended status code.
        raise
    except Exception as e:
        # Chain the cause (`from e`) so server logs keep the real traceback.
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        ) from e