PercivalFletcher committed on
Commit
0266bc6
·
verified ·
1 Parent(s): 5f89477

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +243 -0
main.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File: main.py
2
+ # (Modified to load embedding model at startup and await async pipeline run)
3
+
4
+ import os
5
+ import tempfile
6
+ import asyncio
7
+ import time
8
+ from typing import List, Dict, Any
9
+ from urllib.parse import urlparse, unquote
10
+ from pathlib import Path
11
+
12
+ import httpx
13
+ from fastapi import FastAPI, HTTPException
14
+ from pydantic import BaseModel, HttpUrl
15
+ from groq import AsyncGroq
16
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
17
+ import torch # Import torch to check for CUDA availability
18
+
19
+ from dotenv import load_dotenv
20
+
21
+ load_dotenv()
22
+
23
+ # Import the Pipeline class from the previous file
24
+ from pipeline import Pipeline
25
+
26
# FastAPI application setup
# Single shared application instance; routes are registered on it below via
# decorators (/health and /hackrx/run).
app = FastAPI(
    title="Llama-Index RAG with Groq",
    description="An API to process a PDF from a URL and answer a list of questions using a Llama-Index RAG pipeline.",
)
31
+
32
# --- Pydantic Models for API Request and Response ---
class RunRequest(BaseModel):
    """Request body for /hackrx/run: one document URL plus the questions to answer."""
    # URL of the source document; the endpoint expects it to resolve to a PDF.
    documents: HttpUrl
    # Questions answered concurrently against the same document.
    questions: List[str]
36
+
37
class Answer(BaseModel):
    """A single question paired with the generated answer text."""
    question: str
    answer: str
40
+
41
class RunResponse(BaseModel):
    """Response body: one Answer per question plus timing diagnostics."""
    answers: List[Answer]
    # Total wall-clock seconds for the whole request.
    processing_time: float
    # Per-phase durations keyed by phase name (download_pdf, pipeline_setup, ...).
    step_timings: Dict[str, float]
45
+
46
# --- Global Configurations ---
# NOTE(review): "gsk_..." is a placeholder sentinel, not a real key. The
# run endpoint compares GROQ_API_KEY against this exact literal to detect a
# missing key, so the sentinel must not be changed in isolation.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "gsk_...")
GROQ_MODEL_NAME = "llama3-70b-8192"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Global variable to hold the initialized embedding model.
# None until startup_event populates it; the run endpoint refuses requests
# while it is still None.
embed_model_instance: HuggingFaceEmbedding | None = None

# Warn at import time when no real key was supplied via the environment.
if GROQ_API_KEY == "gsk_...":
    print("WARNING: GROQ_API_KEY is not set. Please set it in your environment or main.py.")
56
+
57
+ @app.on_event("startup")
58
+ async def startup_event():
59
+ """
60
+ Loads the embedding model once when the application starts.
61
+ This prevents re-loading it on every API call.
62
+ """
63
+ global embed_model_instance
64
+ print(f"Loading embedding model '{EMBEDDING_MODEL_NAME}' at startup...")
65
+ # Check for GPU availability and use it if possible
66
+ # Assuming 16GB VRAM, a standard device check is sufficient
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+ print(f"Using device: {device}")
69
+ embed_model_instance = await asyncio.to_thread(HuggingFaceEmbedding, model_name=EMBEDDING_MODEL_NAME, device=device)
70
+ print("Embedding model loaded successfully.")
71
+
72
# --- Async Groq Generation Function ---
async def generate_answer_with_groq(query: str, retrieved_results: List[dict], groq_api_key: str) -> str:
    """
    Generates an answer using the Groq API based on the query and retrieved chunks' content.

    Args:
        query: The user's question.
        retrieved_results: Retrieved chunks; each dict is expected to carry a
            "content" string and an optional "document_metadata" dict.
        groq_api_key: API key for the Groq service.

    Returns:
        The model's answer text, or a human-readable error string on failure.
        This function never raises to its caller.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."

    # Build a numbered context section from the retrieved chunks.
    context_parts = []
    for i, res in enumerate(retrieved_results):
        content = res.get("content", "")
        metadata = res.get("document_metadata", {})

        # Fall back to the file name (then "N/A") when no section heading exists.
        section_heading = metadata.get("section_heading", metadata.get("file_name", "N/A"))

        context_parts.append(
            f"--- Context Chunk {i+1} ---\n"
            f"Document Part: {section_heading}\n"
            f"Content: {content}\n"
            f"-------------------------"
        )
    context = "\n\n".join(context_parts)

    prompt = (
        f"You are a specialized document analyzer assistant. Your task is to answer the user's question "
        f"solely based on the provided context. If the answer cannot be found in the provided context, "
        f"clearly state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    # FIX: the original constructed AsyncGroq per call and never closed it,
    # leaking the underlying HTTP connection pool. Using the client as an
    # async context manager guarantees it is closed after the request.
    try:
        async with AsyncGroq(api_key=groq_api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=GROQ_MODEL_NAME,
                temperature=0.7,
                max_tokens=500,
            )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Best-effort contract: report the failure instead of raising.
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."
123
+
124
+
125
# --- FastAPI Endpoint ---
@app.get("/health", tags=["Monitoring"])
async def health_check():
    """Liveness probe: always reports that the service is up."""
    payload = {"status": "ok"}
    return payload
129
+
130
+ @app.post("/hackrx/run", response_model=RunResponse)
131
+ async def run_rag_pipeline(request: RunRequest):
132
+ """
133
+ Runs the RAG pipeline for a given PDF document URL and a list of questions.
134
+ The PDF is downloaded, processed, and then the questions are answered.
135
+ """
136
+ pdf_url = request.documents
137
+ questions = request.questions
138
+ local_pdf_path = None
139
+ step_timings = {}
140
+
141
+ start_time_total = time.perf_counter()
142
+
143
+ if not embed_model_instance:
144
+ raise HTTPException(
145
+ status_code=500,
146
+ detail="Embedding model not loaded. Application startup failed."
147
+ )
148
+
149
+ if not GROQ_API_KEY or GROQ_API_KEY == "gsk_...":
150
+ raise HTTPException(
151
+ status_code=500,
152
+ detail="Groq API key is not configured. Please set the GROQ_API_KEY environment variable."
153
+ )
154
+
155
+ try:
156
+ # 1. Download PDF
157
+ start_time = time.perf_counter()
158
+ async with httpx.AsyncClient() as client:
159
+ try:
160
+ response = await client.get(str(pdf_url), timeout=30.0, follow_redirects=True)
161
+ response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
162
+ doc_bytes = response.content
163
+ print("Download successful.")
164
+ except httpx.HTTPStatusError as e:
165
+ raise HTTPException(status_code=e.response.status_code, detail=f"HTTP error downloading PDF: {e.response.status_code} - {e.response.text}")
166
+ except httpx.RequestError as e:
167
+ raise HTTPException(status_code=400, detail=f"Network error downloading PDF: {e}")
168
+ except Exception as e:
169
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred during download: {e}")
170
+
171
+ # Determine a temporary local filename
172
+ parsed_path = urlparse(str(pdf_url)).path
173
+ filename = unquote(os.path.basename(parsed_path))
174
+ if not filename or not filename.lower().endswith(".pdf"):
175
+ # If the URL doesn't provide a valid PDF filename, create a generic one.
176
+ filename = "downloaded_document.pdf"
177
+
178
+ # Use tempfile to create a secure temporary file
179
+ with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf_file:
180
+ temp_pdf_file.write(doc_bytes)
181
+ local_pdf_path = temp_pdf_file.name
182
+
183
+ end_time = time.perf_counter()
184
+ step_timings["download_pdf"] = end_time - start_time
185
+ print(f"PDF download took {step_timings['download_pdf']:.2f} seconds.")
186
+
187
+ # 2. Initialize and Run the Pipeline (Parsing, Node Creation, Embeddings)
188
+ start_time = time.perf_counter()
189
+ # The Pipeline's run() method is now async, so await it directly
190
+ pipeline = Pipeline(groq_api_key=GROQ_API_KEY, pdf_path=local_pdf_path, embed_model=embed_model_instance)
191
+ await pipeline.run() # Changed from asyncio.to_thread(pipeline.run)
192
+ end_time = time.perf_counter()
193
+ step_timings["pipeline_setup"] = end_time - start_time
194
+ print(f"Pipeline setup took {step_timings['pipeline_setup']:.2f} seconds.")
195
+
196
+ # 3. Concurrent Retrieval Phase
197
+ start_time_retrieval = time.perf_counter()
198
+ print(f"\nStarting concurrent retrieval for {len(questions)} questions...")
199
+
200
+ retrieval_tasks = [asyncio.to_thread(pipeline.retrieve_nodes, q) for q in questions]
201
+ all_retrieved_results = await asyncio.gather(*retrieval_tasks)
202
+
203
+ end_time_retrieval = time.perf_counter()
204
+ step_timings["retrieval"] = end_time_retrieval - start_time_retrieval
205
+ print(f"Retrieval phase completed in {step_timings['retrieval']:.2f} seconds.")
206
+
207
+ # 4. Concurrent Generation Phase
208
+ start_time_generation = time.perf_counter()
209
+ print(f"\nStarting concurrent answer generation for {len(questions)} questions...")
210
+
211
+ generation_tasks = [
212
+ generate_answer_with_groq(q, retrieved_results, GROQ_API_KEY)
213
+ for q, retrieved_results in zip(questions, all_retrieved_results)
214
+ ]
215
+
216
+ all_answer_texts = await asyncio.gather(*generation_tasks)
217
+
218
+ end_time_generation = time.perf_counter()
219
+ step_timings["generation"] = end_time_generation - start_time_generation
220
+ print(f"Generation phase completed in {step_timings['generation']:.2f} seconds.")
221
+
222
+ end_time_total = time.perf_counter()
223
+ total_processing_time = end_time_total - start_time_total
224
+
225
+ answers = [Answer(question=q, answer=a) for q, a in zip(questions, all_answer_texts)]
226
+
227
+ return RunResponse(
228
+ answers=answers,
229
+ processing_time=total_processing_time,
230
+ step_timings=step_timings,
231
+ )
232
+
233
+ except HTTPException as e:
234
+ raise e
235
+ except Exception as e:
236
+ print(f"An unhandled error occurred: {e}")
237
+ raise HTTPException(
238
+ status_code=500, detail=f"An internal server error occurred: {e}"
239
+ )
240
+ finally:
241
+ if local_pdf_path and os.path.exists(local_pdf_path):
242
+ os.unlink(local_pdf_path)
243
+ print(f"Cleaned up temporary PDF file: {local_pdf_path}")