PercivalFletcher commited on
Commit
5ccaf15
·
verified ·
1 Parent(s): 147fdf0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +262 -0
main.py CHANGED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ import requests
5
+ from fastapi import FastAPI, HTTPException, Depends, status
6
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
+ from pydantic import BaseModel
8
+ from typing import List, Dict, Union, Any, Optional
9
+ from dotenv import load_dotenv
10
+ import asyncio
11
+ import httpx
12
+ import time
13
+ from urllib.parse import urlparse, unquote
14
+ import uuid
15
+ import re
16
+
17
+ # Import LangChain Document and text splitter
18
+ from langchain_core.documents import Document
19
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
20
+
21
+ from processing_utility import (
22
+ extract_schema_from_file,
23
+ #initialize_llama_extract_agent,
24
+ process_document,
25
+ download_and_parse_document_using_llama_index,
26
+ )
27
+
28
+ # Import the new classes and functions from rag_utils
29
+ from rag_utils import (
30
+ process_markdown_with_manual_sections,
31
+ generate_answer_with_groq,
32
+ HybridSearchManager,
33
+ EmbeddingClient, # This might not be needed directly in main.py, but good to have
34
+ CHUNK_SIZE,
35
+ CHUNK_OVERLAP,
36
+ TOP_K_CHUNKS,
37
+ GROQ_MODEL_NAME,
38
+ )
39
+
40
load_dotenv()

# --- FastAPI App Initialization ---
app = FastAPI(
    title="HackRX RAG API",
    description="API for Retrieval-Augmented Generation from PDF documents.",
    version="1.0.0",
)

# --- Shared HybridSearchManager instance ---
# Remains None until the startup event handler constructs it; route handlers
# must check for None before use.
hybrid_search_manager: Optional[HybridSearchManager] = None
52
+
53
@app.on_event("startup")
async def startup_event():
    """Construct the shared HybridSearchManager once when the app boots."""
    global hybrid_search_manager
    hybrid_search_manager = HybridSearchManager()
    # initialize_llama_extract_agent()  # from processing_utility; currently disabled
    print("Application startup complete. HybridSearchManager is ready.")
60
+
61
# --- Groq API Key Setup ---
# "NOT_FOUND" is kept as the missing-key sentinel so any downstream code that
# compares against it keeps working; the key is passed straight through to
# generate_answer_with_groq.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "NOT_FOUND")
if GROQ_API_KEY == "NOT_FOUND":
    # Fixed message: the previous text claimed a "placeholder or hardcoded
    # value" was in use, but this branch means the env var is simply unset.
    print(
        "WARNING: GROQ_API_KEY environment variable is not set. "
        "Groq-backed answer generation will fail until it is configured."
    )
67
+
68
+ # --- Authorization Token Setup ---
69
+ # EXPECTED_AUTH_TOKEN = os.getenv("AUTHORIZATION_TOKEN")
70
+ # if not EXPECTED_AUTH_TOKEN:
71
+ # print(
72
+ # "WARNING: AUTHORIZATION_TOKEN environment variable is not set. Authorization will not work as expected."
73
+ # )
74
+
75
# --- Pydantic Models for Request and Response ---
class RunRequest(BaseModel):
    """Request body for /hackrx/run: one document URL plus the questions to answer."""

    documents: str  # URL to the PDF document
    questions: List[str]  # questions to answer against that document
79
+
80
class Answer(BaseModel):
    """A single generated answer (not referenced by RunResponse in this module)."""

    answer: str
82
+
83
class RunResponse(BaseModel):
    """Response body for /hackrx/run: one answer string per input question."""

    answers: List[str]
    # processing_time and step_timings were removed from the public response;
    # timings are still printed server-side in run_rag_pipeline.
87
+
88
# --- Security Dependency ---
# Bearer-token scheme. NOTE(review): the verify_token dependency that would
# consume this is commented out below, and the route's Depends is commented
# out too — requests to /hackrx/run are currently unauthenticated.
security = HTTPBearer()
90
+
91
+ # async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
92
+ # """
93
+ # Verifies the Bearer token in the Authorization header.
94
+ # """
95
+ # if not EXPECTED_AUTH_TOKEN:
96
+ # raise HTTPException(
97
+ # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
98
+ # detail="Authorization token not configured on the server.",
99
+ # )
100
+ # if credentials.scheme != "Bearer" or credentials.credentials != EXPECTED_AUTH_TOKEN:
101
+ # raise HTTPException(
102
+ # status_code=status.HTTP_401_UNAUTHORIZED,
103
+ # detail="Invalid or missing authentication token",
104
+ # headers={"WWW-Authenticate": "Bearer"},
105
+ # )
106
+ # return True
107
+
108
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(
    request: RunRequest,
    # authorized: bool = Depends(verify_token)
):
    """
    Run the full RAG pipeline for a document URL and a list of questions.

    Steps:
      1. Download the document and parse it to Markdown (temp file on disk).
      2. Split the Markdown into overlapping chunks.
      3. Initialize the hybrid search models / embeddings over the chunks.
      4. For each question: hybrid search (batched), then Groq answer generation.

    Returns:
        RunResponse with one answer string per question, in input order.

    Raises:
        HTTPException: 500 if the search manager is missing, chunking fails,
            or an unexpected error occurs.
    """
    pdf_url = request.documents
    questions = request.questions
    local_markdown_path = None
    step_timings = {}

    start_time_total = time.perf_counter()

    try:
        # The manager is created in the startup handler; fail fast if absent.
        if hybrid_search_manager is None:
            raise HTTPException(
                status_code=500, detail="HybridSearchManager not initialized."
            )

        # 1. Parsing: download the document and convert it to Markdown.
        start_time = time.perf_counter()
        markdown_content = await download_and_parse_document_using_llama_index(pdf_url)
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, encoding="utf-8", suffix=".md"
        ) as temp_md_file:
            temp_md_file.write(markdown_content)
            local_markdown_path = temp_md_file.name
        step_timings["parsing_to_markdown"] = time.perf_counter() - start_time
        print(
            f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
        )

        # 2. Headings: LLM-based schema extraction is disabled; a single
        # catch-all "p" heading makes the chunker treat the whole document
        # as one section.
        headings_json = {"headings": ["p"]}

        # 3. Chunk generation: split the Markdown into retrievable chunks.
        start_time = time.perf_counter()
        processed_documents = process_markdown_with_manual_sections(
            local_markdown_path,
            headings_json,
            CHUNK_SIZE,
            CHUNK_OVERLAP,
        )
        if not processed_documents:
            raise HTTPException(
                status_code=500, detail="Failed to process document into chunks."
            )
        step_timings["chunk_generation"] = time.perf_counter() - start_time
        print(
            f"Chunk Generation took {step_timings['chunk_generation']:.2f} seconds."
        )

        # 4. Model initialization and embeddings pre-computation (async).
        start_time = time.perf_counter()
        await hybrid_search_manager.initialize_models(processed_documents)
        step_timings["model_initialization"] = time.perf_counter() - start_time
        print(
            f"Model initialization took {step_timings['model_initialization']:.2f} seconds."
        )

        # 5. Concurrent query processing (search, then generation).
        start_time_query_processing = time.perf_counter()

        # Search phase: cap concurrency by running queries in small batches.
        batch_size = 3
        all_retrieved_results = []
        print(f"Starting concurrent search in batches of {batch_size}...")
        for i in range(0, len(questions), batch_size):
            current_batch_questions = questions[i : i + batch_size]
            print(
                f"Processing batch {i // batch_size + 1} with {len(current_batch_questions)} queries."
            )
            search_tasks = [
                hybrid_search_manager.perform_hybrid_search(question, TOP_K_CHUNKS)
                for question in current_batch_questions
            ]
            batch_results = await asyncio.gather(*search_tasks)
            all_retrieved_results.extend(batch_results)

        print("Search phase completed for all queries.")

        # Generation phase: one Groq call per question with retrieved context.
        print(f"Starting concurrent answer generation for {len(questions)} questions...")
        generation_tasks = []
        for question, retrieved_results in zip(questions, all_retrieved_results):
            if retrieved_results:
                generation_tasks.append(
                    generate_answer_with_groq(question, retrieved_results, GROQ_API_KEY)
                )
            else:
                # No retrieval hits: resolve immediately with a canned answer.
                # asyncio.sleep(0, result=...) is an awaitable gather() accepts,
                # replacing the deprecated loop-less asyncio.Future() pattern.
                generation_tasks.append(
                    asyncio.sleep(
                        0,
                        result="No relevant information found in the document to answer this question.",
                    )
                )

        all_answer_texts = await asyncio.gather(*generation_tasks)

        step_timings["query_processing"] = (
            time.perf_counter() - start_time_query_processing
        )
        print(
            f"Total query processing took {step_timings['query_processing']:.2f} seconds."
        )

        # Previously this total was computed and never reported; include it.
        total_processing_time = time.perf_counter() - start_time_total
        print(f"All questions processed in {total_processing_time:.2f} seconds total.")

        # gather() preserves input order, so answers align with questions.
        return RunResponse(answers=list(all_answer_texts))

    except HTTPException:
        # Re-raise unchanged (bare raise keeps the original traceback) so
        # FastAPI returns the intended status code.
        raise
    except Exception as e:
        print(f"An unhandled error occurred: {e}")
        # NOTE(review): echoing the exception in `detail` can leak internals;
        # consider a generic message in production.
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        )
    finally:
        # Always remove the temporary markdown file, success or failure.
        if local_markdown_path and os.path.exists(local_markdown_path):
            os.unlink(local_markdown_path)
            print(f"Cleaned up temporary markdown file: {local_markdown_path}")