import os
import json
import tempfile
import requests
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import List, Dict, Union, Any, Optional
from dotenv import load_dotenv
import asyncio
import httpx
import time
from urllib.parse import urlparse, unquote
import uuid
import re

# Import LangChain Document and text splitter
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from processing_utility import (
    extract_schema_from_file,
    process_document,
    download_and_parse_document_using_llama_index,
)

# Import the new classes and functions from rag_utils
from rag_utils import (
    process_markdown_with_recursive_chunking,
    generate_answer_with_groq,
    generate_hypothetical_document,
    HybridSearchManager,
    EmbeddingClient,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    TOP_K_CHUNKS,
    GROQ_MODEL_NAME,
)

load_dotenv()

# --- FastAPI App Initialization ---
app = FastAPI(
    title="HackRX RAG API",
    description="API for Retrieval-Augmented Generation from PDF documents.",
    version="1.0.0",
)

# --- Global instance for the HybridSearchManager ---
hybrid_search_manager: Optional[HybridSearchManager] = None

@app.on_event("startup")
async def startup_event():
    global hybrid_search_manager
    hybrid_search_manager = HybridSearchManager()
    # initialize_llama_extract_agent()
    print("Application startup complete. HybridSearchManager is ready.")

# --- Groq API Key Setup ---
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "NOT_FOUND")
if GROQ_API_KEY == "NOT_FOUND":
    print(
        "WARNING: GROQ_API_KEY is using a placeholder or hardcoded value. Please set GROQ_API_KEY environment variable for production."
    )
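
# Example .env entry (the value below is a placeholder, not a real credential):
#   GROQ_API_KEY=gsk_xxxxxxxxxxxxxxxxxxxx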

# --- Pydantic Models for Request and Response ---
class RunRequest(BaseModel):
    documents: str
    questions: List[str]

class Answer(BaseModel):
    answer: str

class RunResponse(BaseModel):
    answers: List[str]
    #step_timings: Dict[str, float]
    #hypothetical_documents: List[str]
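
# Illustrative request/response payloads for the /hackrx/run endpoint defined below
# (the URL and text values are made up for documentation; any reachable PDF URL works):
#
#   POST /hackrx/run
#   {
#       "documents": "https://example.com/policy.pdf",
#       "questions": ["What is the grace period for premium payment?"]
#   }
#
#   Response:
#   {
#       "answers": ["The grace period for premium payment is ..."]
#   }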

@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(
    request: RunRequest
):
    """
    Runs the RAG pipeline for a given PDF document (converted to Markdown internally)
    and a list of questions.
    """
    pdf_url = request.documents
    questions = request.questions
    local_markdown_path = None
    step_timings = {}
    start_time_total = time.perf_counter()
    try:
        if hybrid_search_manager is None:
            raise HTTPException(
                status_code=500, detail="HybridSearchManager not initialized."
            )
        if not questions:
            # Guard against an empty question list, which would otherwise cause a
            # division by zero when computing per-query average timings below.
            raise HTTPException(
                status_code=400, detail="At least one question must be provided."
            )

        # 1. Parsing: Download PDF and parse to Markdown
        start_time = time.perf_counter()
        markdown_content = await download_and_parse_document_using_llama_index(pdf_url)
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, encoding="utf-8", suffix=".md"
        ) as temp_md_file:
            temp_md_file.write(markdown_content)
            local_markdown_path = temp_md_file.name
        end_time = time.perf_counter()
        step_timings["parsing_to_markdown"] = end_time - start_time
        print(
            f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
        )

        # 2. Chunk Generation: Process Markdown into chunks
        start_time = time.perf_counter()
        processed_documents = process_markdown_with_recursive_chunking(
            local_markdown_path,
            CHUNK_SIZE,
            CHUNK_OVERLAP,
        )
        if not processed_documents:
            raise HTTPException(
                status_code=500, detail="Failed to process document into chunks."
            )
        end_time = time.perf_counter()
        step_timings["chunk_generation"] = end_time - start_time
        print(
            f"Chunk Generation took {step_timings['chunk_generation']:.2f} seconds."
        )

        # 3. Model Initialization and Embeddings Pre-computation
        start_time = time.perf_counter()
        await hybrid_search_manager.initialize_models(processed_documents)
        end_time = time.perf_counter()
        step_timings["model_initialization"] = end_time - start_time
        print(
            f"Model initialization took {step_timings['model_initialization']:.2f} seconds."
        )
        
        # --- NEW CONCURRENT WORKFLOW ---

        # 4. Concurrently generate all hypothetical documents
        start_time_hyde = time.perf_counter()
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        end_time_hyde = time.perf_counter()
        step_timings["hyde_generation_total_time"] = end_time_hyde - start_time_hyde
        step_timings["hyde_generation_avg_time_per_query"] = (end_time_hyde - start_time_hyde) / len(questions)

        # 5. Concurrently perform initial hybrid search to get candidates for ALL queries
        start_time_search = time.perf_counter()
        candidate_retrieval_tasks = [
            hybrid_search_manager.retrieve_candidates(q, hyde_doc)
            for q, hyde_doc in zip(questions, all_hyde_docs)
        ]
        all_candidates = await asyncio.gather(*candidate_retrieval_tasks)
        end_time_search = time.perf_counter()
        step_timings["candidate_retrieval_total_time"] = end_time_search - start_time_search
        
        # 6. Concurrently rerank the candidates for ALL queries
        start_time_rerank = time.perf_counter()
        rerank_tasks = [
            hybrid_search_manager.rerank_results(q, candidates, TOP_K_CHUNKS)
            for q, candidates in zip(questions, all_candidates)
        ]
        reranked_results_and_times = await asyncio.gather(*rerank_tasks)
        end_time_rerank = time.perf_counter()
        step_timings["reranking_total_time"] = end_time_rerank - start_time_rerank
        
        # Unpack reranked results and timings
        all_retrieved_results = [item[0] for item in reranked_results_and_times]
        all_rerank_times = [item[1] for item in reranked_results_and_times]
        step_timings["reranking_avg_time_per_query"] = (end_time_rerank - start_time_rerank) / len(questions)

        # 7. Concurrently generate final answers
        start_time_generation = time.perf_counter()
        generation_tasks = []
        for question, retrieved_results in zip(questions, all_retrieved_results):
            if retrieved_results:
                generation_tasks.append(
                    generate_answer_with_groq(
                        question, retrieved_results, GROQ_API_KEY
                    )
                )
            else:
                # No relevant chunks were retrieved; append a pre-resolved future so that
                # asyncio.gather still returns one answer per question, in order.
                no_info_future = asyncio.Future()
                no_info_future.set_result(
                    "No relevant information found in the document to answer this question."
                )
                generation_tasks.append(no_info_future)

        all_answer_texts = await asyncio.gather(*generation_tasks)
        end_time_generation = time.perf_counter()
        step_timings["generation_total_time"] = end_time_generation - start_time_generation
        step_timings["generation_avg_time_per_query"] = (end_time_generation - start_time_generation) / len(questions)

        end_time_total = time.perf_counter()
        total_processing_time = end_time_total - start_time_total
        step_timings["total_processing_time"] = total_processing_time
        print("All questions processed.")
        all_answers = list(all_answer_texts)

        return RunResponse(
            answers=all_answers,
            #step_timings=step_timings,
            #hypothetical_documents=all_hyde_docs
        )

    except HTTPException as e:
        raise e
    except Exception as e:
        print(f"An unhandled error occurred: {e}")
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        )
    finally:
        if local_markdown_path and os.path.exists(local_markdown_path):
            os.unlink(local_markdown_path)
            print(f"Cleaned up temporary markdown file: {local_markdown_path}")