aravsaxena884 committed on
Commit
a5ec459
·
1 Parent(s): 54693e5
Files changed (3) hide show
  1. Dockerfile +30 -0
  2. app.py +486 -0
  3. req.txt +18 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the FastAPI RAG service defined in app.py.
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies (presumably gcc/g++ are needed to build native
# wheels — confirm) and drop the apt package index so the layer stays small.
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies before the application
# code, so this layer is only rebuilt when req.txt changes.
COPY req.txt .
RUN pip install --no-cache-dir -r req.txt

# Copy application code
COPY app.py .

# app.py writes downloaded PDFs to /tmp, so it must be writable by any user.
# Mode 1777 (instead of 777) keeps the sticky bit, so one user cannot delete
# or rename another user's temporary files.
RUN mkdir -p /tmp && chmod 1777 /tmp

# Port the application listens on (app.py reads it from $PORT).
EXPOSE 7860

# Set environment variables
ENV PORT=7860
ENV PYTHONUNBUFFERED=1

# Run the application
CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import logging
4
+ from typing import Annotated, Literal, Sequence, TypedDict, Optional, List
5
+ import asyncio
6
+ from contextlib import asynccontextmanager
7
+
8
+ import requests
9
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, HttpUrl
12
+ import uvicorn
13
+
14
+ # LangChain imports
15
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
16
+ from langchain_core.prompts import PromptTemplate
17
+ from langchain_core.pydantic_v1 import Field
18
+ from langchain_core.tools import tool
19
+ from langchain_groq import ChatGroq
20
+ from langchain_community.embeddings import HuggingFaceEmbeddings
21
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
22
+ from langchain_core.documents import Document
23
+
24
+ # LangGraph imports
25
+ from langgraph.graph import END, StateGraph, START
26
+ from langgraph.graph.message import add_messages
27
+ from langgraph.prebuilt import tools_condition, ToolNode
28
+
29
+ # Docling imports
30
+ from docling.document_converter import DocumentConverter
31
+ from docling.datamodel.base_models import InputFormat
32
+
33
+ # Qdrant imports
34
+ from qdrant_client import QdrantClient
35
+ from qdrant_client.http import models
36
+ from qdrant_client.http.models import Distance, VectorParams, PointStruct
37
+
38
# Logging: INFO level on the root logger plus a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Credentials and endpoints are taken from the process environment.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# Refuse to start without a Groq key (unset or empty).
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY environment variable is required")

# Shared clients; they stay None until the FastAPI lifespan hook fills them in.
qdrant_client = None
embeddings_model = None
llm = None
54
+
55
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared Qdrant, embedding and LLM clients at startup.

    The clients are published through the module-level globals
    ``qdrant_client``, ``embeddings_model`` and ``llm``. Control returns to
    FastAPI at the ``yield``; the code after it runs on shutdown and only
    logs a message.
    """
    global qdrant_client, embeddings_model, llm

    # Vector store connection; URL and API key come from the environment.
    qdrant_settings = {
        "url": QDRANT_URL,
        "api_key": QDRANT_API_KEY,
        "timeout": 60,
    }
    qdrant_client = QdrantClient(**qdrant_settings)

    # Sentence-transformer embedding model, pinned to the CPU.
    embeddings_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )

    # Groq-hosted chat model, queried with temperature 0.
    llm = ChatGroq(
        groq_api_key=GROQ_API_KEY,
        model_name="mixtral-8x7b-32768",
        temperature=0,
    )

    logger.info("Application initialized successfully")
    yield

    # Shutdown: nothing to release explicitly, just record the event.
    logger.info("Application shutting down")
85
+
86
# FastAPI application; startup and shutdown work is delegated to ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="Agentic RAG with PDF Processing",
    description="Production-ready RAG system with agentic workflow for PDF Q&A",
    version="1.0.0",
)

# CORS: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive —
# confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
101
+
102
+ # Pydantic models
103
class PDFUploadRequest(BaseModel):
    # Request body of POST /upload-pdf.
    # (Documented with comments rather than a docstring so the generated
    # API schema is unchanged.)

    # Where to fetch the PDF from; validated as a URL by pydantic.
    pdf_url: HttpUrl
    # Target collection; the endpoint generates a name when this is omitted.
    collection_name: Optional[str] = None
106
+
107
class QuestionRequest(BaseModel):
    # Request body of POST /chat.
    # (Documented with comments rather than a docstring so the generated
    # API schema is unchanged.)

    # The user's question, passed to the workflow as the first message.
    question: str
    # Collection to answer from; the endpoint returns 404 if the lookup fails.
    collection_name: str
110
+
111
class ChatResponse(BaseModel):
    # Response body of POST /chat.
    # (Documented with comments rather than a docstring so the generated
    # API schema is unchanged.)

    # Content of the last message produced by the workflow.
    answer: str
    # The endpoint currently fills this with the collection name only.
    sources: List[str] = []
    # Free-form details; the endpoint adds the collection name and message count.
    metadata: dict = {}
115
+
116
+ # Agent State
117
class AgentState(TypedDict):
    # State passed between the nodes of the LangGraph workflow.

    # Conversation so far. Annotated with langgraph's ``add_messages``
    # (presumably a reducer that appends node outputs — confirm against the
    # langgraph docs); the nodes read index 0 as the original question.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Qdrant collection the current conversation should search.
    collection_name: str
120
+
121
+ # Document processing functions
122
async def download_pdf(url: str) -> bytes:
    """Download a PDF and return its raw bytes.

    ``requests`` is a blocking library, so the call is run in a worker
    thread; the event loop stays free to serve other requests while the
    transfer is in progress.

    Args:
        url: Address of the PDF. Any object whose ``str()`` is a URL is
            accepted (the upload endpoint passes a pydantic ``HttpUrl``).

    Returns:
        The body of the HTTP response.

    Raises:
        HTTPException: 400 when the request fails or the server answers
            with an error status.
    """
    # NOTE(review): the URL is caller-supplied and fetched from the server with
    # no host allow-list and no size cap (SSRF / memory exhaustion risk) —
    # confirm this is acceptable for the deployment.
    try:
        response = await asyncio.to_thread(requests.get, str(url), timeout=30)
        response.raise_for_status()
        return response.content
    except Exception as e:
        logger.error(f"Failed to download PDF: {e}")
        raise HTTPException(
            status_code=400, detail=f"Failed to download PDF: {e}"
        ) from e
131
+
132
async def extract_pdf_content(pdf_content: bytes) -> List[Document]:
    """Convert PDF bytes to Markdown with Docling and split it into chunks.

    The bytes are written to a uniquely named file under ``/tmp`` because the
    converter is given a path. The file is removed in a ``finally`` clause,
    so it no longer leaks when conversion or splitting fails.

    Args:
        pdf_content: Raw bytes of the PDF.

    Returns:
        One ``Document`` per chunk. Each carries ``source``, ``chunk_id`` and
        ``total_chunks`` in its metadata.

    Raises:
        HTTPException: 500 when writing, converting or splitting fails.
    """
    temp_file = f"/tmp/{uuid.uuid4()}.pdf"
    try:
        # Save PDF content to the temporary file
        with open(temp_file, "wb") as f:
            f.write(pdf_content)

        # Convert the document and export the whole text as Markdown
        converter = DocumentConverter()
        result = converter.convert(temp_file)
        full_text = result.document.export_to_markdown()

        # Split text into overlapping chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", " ", ""],
        )
        chunks = text_splitter.split_text(full_text)

        documents = [
            Document(
                page_content=chunk,
                metadata={
                    "source": "pdf",
                    "chunk_id": i,
                    "total_chunks": len(chunks),
                },
            )
            for i, chunk in enumerate(chunks)
        ]

        logger.info(f"Extracted {len(documents)} document chunks")
        return documents

    except Exception as e:
        logger.error(f"Failed to extract PDF content: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to extract PDF content: {e}"
        ) from e

    finally:
        # Best-effort cleanup: runs on success and on failure alike. A cleanup
        # problem is logged instead of masking the real outcome.
        try:
            os.remove(temp_file)
        except FileNotFoundError:
            pass  # the file was never created (open() itself failed)
        except OSError as cleanup_error:
            logger.warning(
                f"Could not remove temporary file {temp_file}: {cleanup_error}"
            )
179
+
180
async def store_in_qdrant(documents: List[Document], collection_name: str):
    """Embed the documents and upsert them into a Qdrant collection.

    The collection is created (384-dimensional vectors, cosine distance) when
    looking it up fails.

    Point ids are UUIDs derived from the chunk index and chunk text. The
    previous ids were the bare chunk indices, so loading a second PDF into an
    existing collection overwrote the points of the first one. Derived ids
    keep re-uploads of the same PDF idempotent while letting different PDFs
    coexist.

    Args:
        documents: Chunks to store; ``page_content`` is embedded and kept in
            the payload together with ``metadata``.
        collection_name: Name of the target collection.

    Raises:
        HTTPException: 500 when the collection cannot be prepared or the
            embedding / upsert fails.
    """
    try:
        # Create collection if it doesn't exist
        try:
            qdrant_client.get_collection(collection_name)
        except Exception:
            qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            )

        if documents:
            # Embed all chunks in one call instead of one model call per chunk.
            texts = [doc.page_content for doc in documents]
            vectors = embeddings_model.embed_documents(texts)

            points = [
                PointStruct(
                    id=str(
                        uuid.uuid5(uuid.NAMESPACE_OID, f"{i}:{doc.page_content}")
                    ),
                    vector=vector,
                    payload={
                        "text": doc.page_content,
                        "metadata": doc.metadata,
                    },
                )
                for i, (doc, vector) in enumerate(zip(documents, vectors))
            ]

            # Upload points in batches
            batch_size = 100
            for start in range(0, len(points), batch_size):
                qdrant_client.upsert(
                    collection_name=collection_name,
                    points=points[start:start + batch_size],
                )

        logger.info(f"Stored {len(documents)} documents in Qdrant collection: {collection_name}")

    except Exception as e:
        logger.error(f"Failed to store documents in Qdrant: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to store documents: {e}"
        ) from e
221
+
222
+ # RAG Tools
223
@tool
def retriever_tool(query: str, collection_name: str) -> str:
    """Retrieve relevant documents from Qdrant based on the query."""
    # NOTE: langchain's @tool uses the docstring above as the tool description,
    # so its wording is deliberately left untouched.
    #
    # Embeds the query, asks Qdrant for the 5 nearest points in the given
    # collection and returns their "text" payloads separated by blank lines.
    # Any failure is logged and reported to the caller as a fixed message.
    try:
        hits = qdrant_client.search(
            collection_name=collection_name,
            query_vector=embeddings_model.embed_query(query),
            limit=5,
        )
        return "\n\n".join(hit.payload["text"] for hit in hits)
    except Exception as e:
        logger.error(f"Retrieval failed: {e}")
        return "No relevant documents found."
247
+
248
+ # Agent workflow functions
249
def grade_documents(state) -> Literal["generate", "rewrite"]:
    """Decide whether the retrieved text is relevant to the user's question.

    The first message in the state is taken as the question and the last one
    as the retrieved text. The LLM is asked for a yes/no verdict on the first
    500 characters of that text.

    Args:
        state: Workflow state; only ``state["messages"]`` is read.

    Returns:
        ``"generate"`` when the reply contains "yes" or when the check itself
        fails (the failure is logged), otherwise ``"rewrite"``.
    """
    logger.info("---CHECK RELEVANCE---")

    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    # The prompt body is written flush-left so no indentation leaks into the
    # text sent to the model.
    prompt = f"""
You are assessing the relevance of retrieved documents to a user question.

Question: {question}
Documents: {docs[:500]}...

Are these documents relevant to answer the question? Respond with only 'yes' or 'no'.
"""

    try:
        response = llm.invoke([HumanMessage(content=prompt)])
        decision = response.content.strip().lower()
    except Exception as e:
        # Default to generate if assessment fails, but leave a trace of why
        # instead of swallowing the error silently.
        logger.error(f"Relevance check failed, defaulting to generate: {e}")
        return "generate"

    if "yes" in decision:
        logger.info("---DECISION: DOCS RELEVANT---")
        return "generate"

    logger.info("---DECISION: DOCS NOT RELEVANT---")
    return "rewrite"
281
+
282
def agent(state):
    """Ask the tool-enabled model what to do next.

    The retriever tool is bound to the LLM, an instruction message naming the
    current collection is placed in front of the conversation, and the
    model's reply is returned as the new message for the state.

    Args:
        state: Workflow state; ``messages`` and ``collection_name`` are read.

    Returns:
        ``{"messages": [response]}`` with the model's reply.
    """
    logger.info("---CALL AGENT---")
    history = state["messages"]
    collection_name = state["collection_name"]

    # Give the model access to the retriever.
    model_with_tools = llm.bind_tools([retriever_tool])

    # Instructions are sent as a HumanMessage placed before the history.
    # The text is written flush-left so no indentation leaks into the prompt.
    instructions = HumanMessage(
        content=f"""You are an AI assistant with access to a document retrieval tool.
Use the retriever_tool to find relevant information from the collection '{collection_name}'
to answer user questions. Always use the tool first before providing an answer."""
    )

    response = model_with_tools.invoke([instructions] + history)
    return {"messages": [response]}
303
+
304
def rewrite(state):
    """Ask the LLM to rephrase the original question for better retrieval.

    Args:
        state: Workflow state; the first entry of ``messages`` is taken as
            the original question.

    Returns:
        ``{"messages": [...]}`` holding the model's reformulation, or the
        original question as a ``HumanMessage`` when the model call fails.
    """
    logger.info("---TRANSFORM QUERY---")
    question = state["messages"][0].content

    # Written flush-left so no indentation leaks into the prompt text.
    rewrite_prompt = f"""
Look at the input and try to reason about the underlying semantic intent/meaning.

Original question: {question}

Formulate an improved, more specific question that would help retrieve better documents:
"""

    try:
        improved = llm.invoke([HumanMessage(content=rewrite_prompt)])
    except Exception as e:
        logger.error(f"Rewrite failed: {e}")
        return {"messages": [HumanMessage(content=question)]}
    return {"messages": [improved]}
324
+
325
def generate(state):
    """Produce the final answer from the retrieved context.

    Args:
        state: Workflow state; the first message is taken as the question
            and the content of the last message as the context.

    Returns:
        ``{"messages": [...]}`` holding the model's answer, or a fixed
        apology ``AIMessage`` when the model call fails.
    """
    logger.info("---GENERATE---")
    history = state["messages"]
    question = history[0].content
    docs = history[-1].content

    # RAG prompt, written flush-left so no indentation leaks into the text.
    rag_prompt = f"""
Use the following pieces of context to answer the question at the end.
If you don't know the answer based on the context, just say that you don't know,
don't try to make up an answer.

Context:
{docs}

Question: {question}

Answer:
"""

    try:
        answer = llm.invoke([HumanMessage(content=rag_prompt)])
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        answer = AIMessage(content="I apologize, but I encountered an error generating the response.")
    return {"messages": [answer]}
354
+
355
+ # Create workflow
356
def create_workflow():
    """Assemble and compile the LangGraph workflow.

    Wiring:
        START -> agent
        agent -> retrieve (when ``tools_condition`` yields "tools") or END
        retrieve -> generate or rewrite (chosen by ``grade_documents``)
        generate -> END
        rewrite -> agent

    Returns:
        The compiled graph.
    """
    graph = StateGraph(AgentState)

    # Nodes, registered in the same order as they appear in the wiring above.
    nodes = {
        "agent": agent,
        "retrieve": ToolNode([retriever_tool]),
        "rewrite": rewrite,
        "generate": generate,
    }
    for node_name, node in nodes.items():
        graph.add_node(node_name, node)

    # Entry point.
    graph.add_edge(START, "agent")

    # After the agent: run the tool or stop.
    agent_routes = {"tools": "retrieve", END: END}
    graph.add_conditional_edges("agent", tools_condition, agent_routes)

    # After retrieval: answer or rephrase, depending on the relevance grade.
    grading_routes = {"generate": "generate", "rewrite": "rewrite"}
    graph.add_conditional_edges("retrieve", grade_documents, grading_routes)

    graph.add_edge("generate", END)
    graph.add_edge("rewrite", "agent")

    return graph.compile()
392
+
393
+ # API Endpoints
394
@app.post("/upload-pdf", response_model=dict)
async def upload_pdf(request: PDFUploadRequest, background_tasks: BackgroundTasks):
    """Upload and process PDF from URL"""
    # (The docstring is kept short and unchanged so the endpoint description
    # in the generated API docs stays the same.)
    #
    # Steps: download the PDF, split it into chunks, store the chunks in
    # Qdrant. Returns the status, the collection name used and the number of
    # chunks stored.
    #
    # HTTPExceptions raised by the helpers (e.g. the 400 from a failed
    # download) are passed through unchanged; previously the generic handler
    # below turned every one of them into a 500.
    #
    # ``background_tasks`` is not used here; it stays in the signature so the
    # endpoint's interface is unchanged.
    try:
        # Generate collection name if not provided
        collection_name = request.collection_name or f"pdf_{uuid.uuid4().hex[:8]}"

        # Download PDF
        pdf_content = await download_pdf(request.pdf_url)

        # Extract content
        documents = await extract_pdf_content(pdf_content)

        # Store in vector database
        await store_in_qdrant(documents, collection_name)

        return {
            "status": "success",
            "message": "PDF processed successfully",
            "collection_name": collection_name,
            "document_count": len(documents),
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"PDF upload failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
420
+
421
@app.post("/chat", response_model=ChatResponse)
async def chat(request: QuestionRequest):
    """Chat with the documents using agentic RAG"""
    # (The docstring is kept short and unchanged so the endpoint description
    # in the generated API docs stays the same.)
    #
    # Checks the collection, runs the agent workflow on the question and
    # returns the content of the last message as the answer.
    #
    # The Qdrant lookup and the workflow run are synchronous and can take
    # seconds (several LLM round trips), so both are executed in a worker
    # thread to keep the event loop free for other requests.
    #
    # Errors: 404 when the collection lookup fails, 500 for any other failure.
    try:
        # Check if collection exists
        try:
            await asyncio.to_thread(
                qdrant_client.get_collection, request.collection_name
            )
        except Exception:
            raise HTTPException(
                status_code=404,
                detail=f"Collection '{request.collection_name}' not found. Please upload a PDF first."
            )

        # Create workflow
        workflow = create_workflow()

        # Initial state
        initial_state = {
            "messages": [HumanMessage(content=request.question)],
            "collection_name": request.collection_name,
        }

        # Run the workflow off the event loop
        result = await asyncio.to_thread(workflow.invoke, initial_state)

        # Extract final answer
        final_message = result["messages"][-1]
        answer = final_message.content if hasattr(final_message, 'content') else str(final_message)

        return ChatResponse(
            answer=answer,
            sources=[request.collection_name],
            metadata={
                "collection_name": request.collection_name,
                "message_count": len(result["messages"]),
            },
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Chat failed: {e}")
        raise HTTPException(status_code=500, detail=f"Chat failed: {e}") from e
464
+
465
@app.get("/collections", response_model=List[str])
async def list_collections():
    """List all available collections"""
    # Best effort: when Qdrant cannot be queried the error is logged and an
    # empty list is returned instead of an error response.
    try:
        listing = qdrant_client.get_collections()
        names = []
        for entry in listing.collections:
            names.append(entry.name)
        return names
    except Exception as e:
        logger.error(f"Failed to list collections: {e}")
        return []
474
+
475
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Static payload; no backing service is contacted.
    payload = {
        "status": "healthy",
        "message": "Agentic RAG service is running",
    }
    return payload
479
+
480
if __name__ == "__main__":
    # Started as a script: serve the app on all interfaces. The port comes
    # from $PORT and falls back to 7860; auto-reload is off.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=False)
req.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ langchain
4
+ langchain-core
5
+ langchain-groq
6
+ langchain-community
7
+ langgraph
8
+ docling
9
+ qdrant-client
10
+ sentence-transformers
11
+ transformers
12
+ torch
13
+ requests
14
+ pydantic
15
+ python-multipart
16
+ numpy
17
+ pandas
18
+ Pillow