Spaces:
Sleeping
Sleeping
| # file: main3.py | |
| import time | |
| import os | |
| import asyncio | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, HttpUrl | |
| from typing import List, Dict, Any | |
| from dotenv import load_dotenv | |
| from document_processor import ingest_and_parse_document | |
| from chunking_parent import create_parent_child_chunks | |
| from embedding import EmbeddingClient | |
| from retrieval_parent import Retriever, generate_hypothetical_document | |
| from generation import generate_answer | |
# Load variables from a local .env file (python-dotenv) before anything below
# reads the environment.
load_dotenv()

app = FastAPI(
    title="Modular RAG API",
    description=(
        "A modular API for Retrieval-Augmented Generation "
        "with Parent-Child Retrieval."
    ),
    version="2.2.2",  # Version updated
)

# API key passed to the HyDE and answer-generation calls in the handlers;
# this is None when the environment variable is not set.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Module-level instances created once at import time and reused by the
# request handlers defined in this file.
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)
| # --- Pydantic Models --- | |
class RunRequest(BaseModel):
    # Request body for the full RAG pipeline (consumed by run_rag_pipeline).

    # Location of the document to ingest; pydantic validates it as a URL.
    documents: HttpUrl

    # Questions to be answered against that document.
    questions: List[str]
class RunResponse(BaseModel):
    # Response body returned by run_rag_pipeline.

    # Generated answers; run_rag_pipeline builds one per submitted question,
    # in the same order as the questions.
    answers: List[str]
class TestRequest(BaseModel):
    # Request body for the chunking test (consumed by test_chunking_endpoint).

    # Location of the document to parse and chunk; pydantic validates it as a URL.
    documents: HttpUrl
| # --- NEW: Test Endpoint for Parent-Child Chunking --- | |
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests the parent-child chunking strategy.

    Parses the document at ``request.documents``, splits it into parent and
    child chunks, and returns both sets together with the time taken.

    Returns:
        A dict with ``total_time_seconds`` (parsing and chunking only; the
        serialization below is not timed), ``parent_chunk_count``,
        ``child_chunk_count``, ``parent_chunks`` and ``child_chunks``. Each
        chunk is a dict with ``page_content`` and ``metadata``.

    Raises:
        HTTPException: An HTTPException raised while processing is passed
            through unchanged; any other failure is reported as a 500 whose
            detail carries the original error text.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is actually registered on `app`.
    print("--- Running Parent-Child Chunking Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: Parse the document.
        markdown_content = await ingest_and_parse_document(request.documents)

        # Step 2: Create parent and child chunks. The third return value is
        # not needed here.
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        duration = time.perf_counter() - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        # Convert Document objects to plain dicts so the response is
        # JSON-serializable.
        child_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in child_documents
        ]

        # Fetch every parent document from the docstore; falsy entries
        # returned by mget are skipped.
        parent_docs = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in parent_docs
            if doc
        ]

        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except HTTPException:
        # Keep the status code of an HTTP error raised during processing
        # instead of rewriting it as a generic 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays in the server logs.
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during chunking test: {e}",
        ) from e
async def run_rag_pipeline(request: RunRequest):
    """
    Run the full RAG pipeline for one document and a list of questions.

    Stages, each timed and printed: parse the document, build parent-child
    chunks, index the child chunks in the module-level ``retriever``, then,
    concurrently across questions, generate hypothetical documents (HyDE),
    retrieve chunks, and generate the final answers.

    Returns:
        RunResponse: the generated answers, one per question, in question
        order.

    Raises:
        HTTPException: 400 when chunking yields no child documents; 500 for
            any other failure, with the original error text in the detail.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is actually registered on `app`.
    try:
        total_pipeline_start_time = time.perf_counter()
        timings = {}
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION (Parsing) ---
        parse_start = time.perf_counter()
        markdown_content = await ingest_and_parse_document(request.documents)
        timings["1_parsing"] = time.perf_counter() - parse_start
        print(f"Time taken for Parsing: {timings['1_parsing']:.4f} seconds.")

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        chunk_start = time.perf_counter()
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
        timings["2_chunking"] = time.perf_counter() - chunk_start
        print(f"Time taken for Parent-Child Chunking: {timings['2_chunking']:.4f} seconds.")

        if not child_documents:
            # Client-side problem: surfaced as a 400 by the dedicated
            # `except HTTPException` clause below.
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING (Embedding) ---
        # NOTE(review): this mutates the shared module-level retriever;
        # confirm behavior when requests overlap.
        index_start = time.perf_counter()
        retriever.index(child_documents, docstore)
        timings["3_indexing_and_embedding"] = time.perf_counter() - index_start
        print(f"Time taken for Indexing (incl. Embeddings): {timings['3_indexing_and_embedding']:.4f} seconds.")

        # --- CONCURRENT WORKFLOW ---
        # Step A: Concurrently generate hypothetical documents (HyDE)
        print("Generating hypothetical documents...")
        hyde_start = time.perf_counter()
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        timings["4_hyde_generation_total"] = time.perf_counter() - hyde_start
        print(f"Time taken for HyDE Generation (total): {timings['4_hyde_generation_total']:.4f} seconds.")

        # Step B: Concurrently retrieve relevant chunks. gather() returns
        # results in argument order, so zip() keeps each question paired
        # with its own HyDE document.
        print("Retrieving chunks...")
        retrieval_start = time.perf_counter()
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
        timings["5_retrieval_total"] = time.perf_counter() - retrieval_start
        print(f"Time taken for Retrieval (total): {timings['5_retrieval_total']:.4f} seconds.")

        # Step C: Concurrently generate final answers
        print("Generating final answers...")
        generation_start = time.perf_counter()
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)
        timings["6_answer_generation_total"] = time.perf_counter() - generation_start
        print(f"Time taken for Answer Generation (total): {timings['6_answer_generation_total']:.4f} seconds.")

        timings["total_pipeline_time"] = time.perf_counter() - total_pipeline_start_time
        print("\n--- RAG Pipeline Completed Successfully ---")
        print(f"--- Total Pipeline Time: {timings['total_pipeline_time']:.4f} seconds ---")
        print("--- Timing Breakdown ---")
        for stage, duration in timings.items():
            print(f"- {stage}: {duration:.4f} seconds")

        return RunResponse(answers=final_answers)
    except HTTPException:
        # HTTPException is an Exception subclass. Without this clause the
        # deliberate 400 raised above would be caught by the generic handler
        # and reported to the client as a 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback stays in the server logs.
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        ) from e