import os
import json
import tempfile
import requests
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import List, Dict, Union, Any, Optional
from dotenv import load_dotenv
import asyncio
import httpx
import time
from urllib.parse import urlparse, unquote
import uuid
import re
# Import LangChain Document and text splitter
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from processing_utility import (
extract_schema_from_file,
process_document,
download_and_parse_document_using_llama_index,
)
# Import the new classes and functions from rag_utils
from rag_utils import (
process_markdown_with_recursive_chunking,
generate_answer_with_groq,
generate_hypothetical_document,
HybridSearchManager,
EmbeddingClient,
CHUNK_SIZE,
CHUNK_OVERLAP,
TOP_K_CHUNKS,
GROQ_MODEL_NAME,
)
load_dotenv()
# --- FastAPI App Initialization ---
app = FastAPI(
title="HackRX RAG API",
description="API for Retrieval-Augmented Generation from PDF documents.",
version="1.0.0",
)
# --- Global instance for the HybridSearchManager ---
hybrid_search_manager: Optional[HybridSearchManager] = None
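# The manager is created once at startup (rather than per request) so that any
# heavyweight resources it holds -- e.g. whatever backs initialize_models and
# rerank_results in rag_utils -- are loaded a single time for the life of the process.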
@app.on_event("startup")
async def startup_event():
global hybrid_search_manager
hybrid_search_manager = HybridSearchManager()
#initialize_llama_extract_agent()
print("Application startup complete. HybridSearchManager is ready.")
# --- Groq API Key Setup ---
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "NOT_FOUND")
if GROQ_API_KEY == "NOT_FOUND":
print(
"WARNING: GROQ_API_KEY is using a placeholder or hardcoded value. Please set GROQ_API_KEY environment variable for production."
)
# --- Pydantic Models for Request and Response ---
class RunRequest(BaseModel):
documents: str
questions: List[str]
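# Illustrative request body (values are examples only):
# {
#     "documents": "https://example.com/some-policy.pdf",
#     "questions": ["What is the waiting period for pre-existing conditions?"]
# }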
class Answer(BaseModel):
answer: str
class RunResponse(BaseModel):
answers: List[str]
#step_timings: Dict[str, float]
#hypothetical_documents: List[str]
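# Illustrative response body: {"answers": ["<answer to question 1>", "..."]}
# Answers come back in the same order as the incoming questions, since
# asyncio.gather preserves input order throughout the pipeline below.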
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(
request: RunRequest
):
"""
Runs the RAG pipeline for a given PDF document (converted to Markdown internally)
and a list of questions.
"""
pdf_url = request.documents
questions = request.questions
local_markdown_path = None
step_timings = {}
start_time_total = time.perf_counter()
try:
if hybrid_search_manager is None:
raise HTTPException(
status_code=500, detail="HybridSearchManager not initialized."
)
# 1. Parsing: Download PDF and parse to Markdown
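        # The parsed Markdown is written to a named temp file because the chunking
        # step below takes a file path; the file is deleted in the finally block.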
start_time = time.perf_counter()
markdown_content = await download_and_parse_document_using_llama_index(pdf_url)
with tempfile.NamedTemporaryFile(
mode="w", delete=False, encoding="utf-8", suffix=".md"
) as temp_md_file:
temp_md_file.write(markdown_content)
local_markdown_path = temp_md_file.name
end_time = time.perf_counter()
step_timings["parsing_to_markdown"] = end_time - start_time
print(
f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
)
# 2. Chunk Generation: Process Markdown into chunks
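        # CHUNK_SIZE and CHUNK_OVERLAP come from rag_utils; the splitting itself
        # happens in process_markdown_with_recursive_chunking (recursive character
        # splitting, judging by the name).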
start_time = time.perf_counter()
processed_documents = process_markdown_with_recursive_chunking(
local_markdown_path,
CHUNK_SIZE,
CHUNK_OVERLAP,
)
if not processed_documents:
raise HTTPException(
status_code=500, detail="Failed to process document into chunks."
)
end_time = time.perf_counter()
step_timings["chunk_generation"] = end_time - start_time
print(
f"Chunk Generation took {step_timings['chunk_generation']:.2f} seconds."
)
# 3. Model Initialization and Embeddings Pre-computation
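        # initialize_models receives the chunked documents up front, presumably so
        # the search indexes / chunk embeddings are built once and reused for every
        # question below.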
start_time = time.perf_counter()
await hybrid_search_manager.initialize_models(processed_documents)
end_time = time.perf_counter()
step_timings["model_initialization"] = end_time - start_time
print(
f"Model initialization took {step_timings['model_initialization']:.2f} seconds."
)
        # --- Concurrent per-question workflow ---
# 4. Concurrently generate all hypothetical documents
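        # HyDE-style step: for each question, an LLM drafts a hypothetical answer
        # document, which is then used alongside the question during retrieval.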
start_time_hyde = time.perf_counter()
hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in questions]
all_hyde_docs = await asyncio.gather(*hyde_tasks)
end_time_hyde = time.perf_counter()
step_timings["hyde_generation_total_time"] = end_time_hyde - start_time_hyde
step_timings["hyde_generation_avg_time_per_query"] = (end_time_hyde - start_time_hyde) / len(questions)
# 5. Concurrently perform initial hybrid search to get candidates for ALL queries
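        # Each retrieval task gets both the original question and its hypothetical
        # document; the candidate set is narrowed to TOP_K_CHUNKS by the reranking
        # step that follows.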
start_time_search = time.perf_counter()
candidate_retrieval_tasks = [
hybrid_search_manager.retrieve_candidates(q, hyde_doc)
for q, hyde_doc in zip(questions, all_hyde_docs)
]
all_candidates = await asyncio.gather(*candidate_retrieval_tasks)
end_time_search = time.perf_counter()
step_timings["candidate_retrieval_total_time"] = end_time_search - start_time_search
# 6. Concurrently rerank the candidates for ALL queries
start_time_rerank = time.perf_counter()
rerank_tasks = [
hybrid_search_manager.rerank_results(q, candidates, TOP_K_CHUNKS)
for q, candidates in zip(questions, all_candidates)
]
reranked_results_and_times = await asyncio.gather(*rerank_tasks)
end_time_rerank = time.perf_counter()
step_timings["reranking_total_time"] = end_time_rerank - start_time_rerank
# Unpack reranked results and timings
all_retrieved_results = [item[0] for item in reranked_results_and_times]
all_rerank_times = [item[1] for item in reranked_results_and_times]
step_timings["reranking_avg_time_per_query"] = (end_time_rerank - start_time_rerank) / len(questions)
# 7. Concurrently generate final answers
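        # Questions with an empty candidate list get a pre-resolved Future holding a
        # fallback message, so asyncio.gather can await real coroutines and fallbacks
        # uniformly and keep the answers aligned with the questions.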
start_time_generation = time.perf_counter()
generation_tasks = []
for question, retrieved_results in zip(questions, all_retrieved_results):
if retrieved_results:
generation_tasks.append(
generate_answer_with_groq(
question, retrieved_results, GROQ_API_KEY
)
)
else:
no_info_future = asyncio.Future()
no_info_future.set_result(
"No relevant information found in the document to answer this question."
)
generation_tasks.append(no_info_future)
all_answer_texts = await asyncio.gather(*generation_tasks)
end_time_generation = time.perf_counter()
step_timings["generation_total_time"] = end_time_generation - start_time_generation
step_timings["generation_avg_time_per_query"] = (end_time_generation - start_time_generation) / len(questions)
end_time_total = time.perf_counter()
total_processing_time = end_time_total - start_time_total
step_timings["total_processing_time"] = total_processing_time
print("All questions processed.")
all_answers = [answer_text for answer_text in all_answer_texts]
return RunResponse(
answers=all_answers,
#step_timings=step_timings,
#hypothetical_documents=all_hyde_docs
)
except HTTPException as e:
raise e
except Exception as e:
print(f"An unhandled error occurred: {e}")
raise HTTPException(
status_code=500, detail=f"An internal server error occurred: {e}"
)
finally:
if local_markdown_path and os.path.exists(local_markdown_path):
os.unlink(local_markdown_path)
print(f"Cleaned up temporary markdown file: {local_markdown_path}")