PercivalFletcher committed on
Commit
f44abf4
·
verified ·
1 Parent(s): 9bce7bd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -64
main.py CHANGED
@@ -30,7 +30,7 @@ from rag_utils import (
30
  process_markdown_with_manual_sections,
31
  generate_answer_with_groq,
32
  HybridSearchManager,
33
- EmbeddingClient, # This might not be needed directly in main.py, but good to have
34
  CHUNK_SIZE,
35
  CHUNK_OVERLAP,
36
  TOP_K_CHUNKS,
@@ -47,15 +47,13 @@ app = FastAPI(
47
  )
48
 
49
  # --- Global instance for the HybridSearchManager ---
50
- # This will be initialized on startup
51
  hybrid_search_manager: Optional[HybridSearchManager] = None
52
 
53
  @app.on_event("startup")
54
  async def startup_event():
55
  global hybrid_search_manager
56
- # Initialize the HybridSearchManager at startup
57
  hybrid_search_manager = HybridSearchManager()
58
- #initialize_llama_extract_agent() # From processing_utility
59
  print("Application startup complete. HybridSearchManager is ready.")
60
 
61
  # --- Groq API Key Setup ---
@@ -65,16 +63,9 @@ if GROQ_API_KEY == "NOT_FOUND":
65
  "WARNING: GROQ_API_KEY is using a placeholder or hardcoded value. Please set GROQ_API_KEY environment variable for production."
66
  )
67
 
68
- # --- Authorization Token Setup ---
69
- # EXPECTED_AUTH_TOKEN = os.getenv("AUTHORIZATION_TOKEN")
70
- # if not EXPECTED_AUTH_TOKEN:
71
- # print(
72
- # "WARNING: AUTHORIZATION_TOKEN environment variable is not set. Authorization will not work as expected."
73
- # )
74
-
75
  # --- Pydantic Models for Request and Response ---
76
  class RunRequest(BaseModel):
77
- documents: str # URL to the PDF document
78
  questions: List[str]
79
 
80
  class Answer(BaseModel):
@@ -82,33 +73,12 @@ class Answer(BaseModel):
82
 
83
  class RunResponse(BaseModel):
84
  answers: List[str]
85
- #processing_time: float
86
- #step_timings: dict # New field for detailed timings
87
-
88
- # --- Security Dependency ---
89
- security = HTTPBearer()
90
-
91
- # async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
92
- # """
93
- # Verifies the Bearer token in the Authorization header.
94
- # """
95
- # if not EXPECTED_AUTH_TOKEN:
96
- # raise HTTPException(
97
- # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
98
- # detail="Authorization token not configured on the server.",
99
- # )
100
- # if credentials.scheme != "Bearer" or credentials.credentials != EXPECTED_AUTH_TOKEN:
101
- # raise HTTPException(
102
- # status_code=status.HTTP_401_UNAUTHORIZED,
103
- # detail="Invalid or missing authentication token",
104
- # headers={"WWW-Authenticate": "Bearer"},
105
- # )
106
- # return True
107
 
108
  @app.post("/hackrx/run", response_model=RunResponse)
109
  async def run_rag_pipeline(
110
- request: RunRequest,
111
- # authorized: bool = Depends(verify_token)
112
  ):
113
  """
114
  Runs the RAG pipeline for a given PDF document (converted to Markdown internally)
@@ -118,11 +88,8 @@ async def run_rag_pipeline(
118
  questions = request.questions
119
  local_markdown_path = None
120
  step_timings = {}
121
-
122
  start_time_total = time.perf_counter()
123
-
124
  try:
125
- # Ensure the HybridSearchManager is initialized
126
  if hybrid_search_manager is None:
127
  raise HTTPException(
128
  status_code=500, detail="HybridSearchManager not initialized."
@@ -142,20 +109,6 @@ async def run_rag_pipeline(
142
  f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
143
  )
144
 
145
- # 2. Headings Generation: Extract headings JSON
146
- '''start_time = time.perf_counter()
147
- headings_json = extract_schema_from_file(local_markdown_path)
148
- if not headings_json or not headings_json.get("headings"):
149
- raise HTTPException(
150
- status_code=400,
151
- detail="Could not retrieve valid headings from the provided document.",
152
- )
153
- end_time = time.perf_counter()
154
- step_timings["headings_generation"] = end_time - start_time
155
- print(
156
- f"Headings Generation took {step_timings['headings_generation']:.2f} seconds."
157
- )'''
158
-
159
  headings_json = {"headings":["p"]}
160
 
161
  # 3. Chunk Generation: Process Markdown into chunks
@@ -178,7 +131,6 @@ async def run_rag_pipeline(
178
 
179
  # 4. Model Initialization and Embeddings Pre-computation
180
  start_time = time.perf_counter()
181
- # --- FIX: Await the async function call ---
182
  await hybrid_search_manager.initialize_models(processed_documents)
183
  end_time = time.perf_counter()
184
  step_timings["model_initialization"] = end_time - start_time
@@ -188,29 +140,36 @@ async def run_rag_pipeline(
188
 
189
  # 5. Concurrent Query Processing (Search and Generation)
190
  start_time_query_processing = time.perf_counter()
191
-
192
  # Search Phase
193
  batch_size = 3
194
  all_retrieved_results = []
 
195
  print(f"Starting concurrent search in batches of {batch_size}...")
196
-
197
  for i in range(0, len(questions), batch_size):
198
  current_batch_questions = questions[i : i + batch_size]
199
  print(
200
  f"Processing batch {i // batch_size + 1} with {len(current_batch_questions)} queries."
201
  )
202
-
203
- # --- FIX: Directly create a list of coroutines, no asyncio.to_thread needed here ---
204
  search_tasks = [
205
  hybrid_search_manager.perform_hybrid_search(
206
  question, TOP_K_CHUNKS
207
  )
208
  for question in current_batch_questions
209
  ]
210
- batch_results = await asyncio.gather(*search_tasks)
211
- all_retrieved_results.extend(batch_results)
 
 
 
 
212
 
213
  print("Search phase completed for all queries.")
 
 
 
 
214
 
215
  # Generation Phase
216
  print(f"Starting concurrent answer generation for {len(questions)} questions...")
@@ -230,7 +189,6 @@ async def run_rag_pipeline(
230
  generation_tasks.append(no_info_future)
231
 
232
  all_answer_texts = await asyncio.gather(*generation_tasks)
233
-
234
  end_time_query_processing = time.perf_counter()
235
  step_timings["query_processing"] = (
236
  end_time_query_processing - start_time_query_processing
@@ -241,12 +199,13 @@ async def run_rag_pipeline(
241
 
242
  end_time_total = time.perf_counter()
243
  total_processing_time = end_time_total - start_time_total
 
244
  print("All questions processed.")
245
-
246
  all_answers = [answer_text for answer_text in all_answer_texts]
247
 
248
  return RunResponse(
249
- answers=all_answers
 
250
  )
251
 
252
  except HTTPException as e:
@@ -259,4 +218,5 @@ async def run_rag_pipeline(
259
  finally:
260
  if local_markdown_path and os.path.exists(local_markdown_path):
261
  os.unlink(local_markdown_path)
262
- print(f"Cleaned up temporary markdown file: {local_markdown_path}")
 
 
30
  process_markdown_with_manual_sections,
31
  generate_answer_with_groq,
32
  HybridSearchManager,
33
+ EmbeddingClient,
34
  CHUNK_SIZE,
35
  CHUNK_OVERLAP,
36
  TOP_K_CHUNKS,
 
47
  )
48
 
49
  # --- Global instance for the HybridSearchManager ---
 
50
  hybrid_search_manager: Optional[HybridSearchManager] = None
51
 
52
  @app.on_event("startup")
53
  async def startup_event():
54
  global hybrid_search_manager
 
55
  hybrid_search_manager = HybridSearchManager()
56
+ #initialize_llama_extract_agent()
57
  print("Application startup complete. HybridSearchManager is ready.")
58
 
59
  # --- Groq API Key Setup ---
 
63
  "WARNING: GROQ_API_KEY is using a placeholder or hardcoded value. Please set GROQ_API_KEY environment variable for production."
64
  )
65
 
 
 
 
 
 
 
 
66
  # --- Pydantic Models for Request and Response ---
67
  class RunRequest(BaseModel):
68
+ documents: str
69
  questions: List[str]
70
 
71
  class Answer(BaseModel):
 
73
 
74
  class RunResponse(BaseModel):
75
  answers: List[str]
76
+ step_timings: Dict[str, float] # Added field for timing information
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  @app.post("/hackrx/run", response_model=RunResponse)
79
  async def run_rag_pipeline(
80
+ request: RunRequest
81
+ # authorized: bool = Depends(verify_token)):
82
  ):
83
  """
84
  Runs the RAG pipeline for a given PDF document (converted to Markdown internally)
 
88
  questions = request.questions
89
  local_markdown_path = None
90
  step_timings = {}
 
91
  start_time_total = time.perf_counter()
 
92
  try:
 
93
  if hybrid_search_manager is None:
94
  raise HTTPException(
95
  status_code=500, detail="HybridSearchManager not initialized."
 
109
  f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
110
  )
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  headings_json = {"headings":["p"]}
113
 
114
  # 3. Chunk Generation: Process Markdown into chunks
 
131
 
132
  # 4. Model Initialization and Embeddings Pre-computation
133
  start_time = time.perf_counter()
 
134
  await hybrid_search_manager.initialize_models(processed_documents)
135
  end_time = time.perf_counter()
136
  step_timings["model_initialization"] = end_time - start_time
 
140
 
141
  # 5. Concurrent Query Processing (Search and Generation)
142
  start_time_query_processing = time.perf_counter()
143
+
144
  # Search Phase
145
  batch_size = 3
146
  all_retrieved_results = []
147
+ all_rerank_times = []
148
  print(f"Starting concurrent search in batches of {batch_size}...")
 
149
  for i in range(0, len(questions), batch_size):
150
  current_batch_questions = questions[i : i + batch_size]
151
  print(
152
  f"Processing batch {i // batch_size + 1} with {len(current_batch_questions)} queries."
153
  )
154
+ # The search method now returns a tuple of results and rerank time
 
155
  search_tasks = [
156
  hybrid_search_manager.perform_hybrid_search(
157
  question, TOP_K_CHUNKS
158
  )
159
  for question in current_batch_questions
160
  ]
161
+ batch_results_and_times = await asyncio.gather(*search_tasks)
162
+
163
+ # Unpack results and timings
164
+ for results, rerank_time in batch_results_and_times:
165
+ all_retrieved_results.append(results)
166
+ all_rerank_times.append(rerank_time)
167
 
168
  print("Search phase completed for all queries.")
169
+
170
+ # Add the total reranking time to the step timings
171
+ step_timings["reranking_total_time"] = sum(all_rerank_times)
172
+ step_timings["reranking_avg_time_per_query"] = sum(all_rerank_times) / len(all_rerank_times)
173
 
174
  # Generation Phase
175
  print(f"Starting concurrent answer generation for {len(questions)} questions...")
 
189
  generation_tasks.append(no_info_future)
190
 
191
  all_answer_texts = await asyncio.gather(*generation_tasks)
 
192
  end_time_query_processing = time.perf_counter()
193
  step_timings["query_processing"] = (
194
  end_time_query_processing - start_time_query_processing
 
199
 
200
  end_time_total = time.perf_counter()
201
  total_processing_time = end_time_total - start_time_total
202
+ step_timings["total_processing_time"] = total_processing_time
203
  print("All questions processed.")
 
204
  all_answers = [answer_text for answer_text in all_answer_texts]
205
 
206
  return RunResponse(
207
+ answers=all_answers,
208
+ step_timings=step_timings
209
  )
210
 
211
  except HTTPException as e:
 
218
  finally:
219
  if local_markdown_path and os.path.exists(local_markdown_path):
220
  os.unlink(local_markdown_path)
221
+ print(f"Cleaned up temporary markdown file: {local_markdown_path}")
222
+