# rag_agent_app/backend/main.py
import os
import time
from typing import List, Dict, Any
import tempfile
from fastapi import FastAPI, HTTPException, status, UploadFile, File
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langchain_community.document_loaders import PyPDFLoader
from agent import rag_agent
from vectorstore import add_document_to_vectorstore
# Initialize FastAPI app (metadata shown in the auto-generated OpenAPI docs)
app = FastAPI(
    title="LangGraph RAG Agent API",
    description="API for the LangGraph-powered RAG agent with Pinecone and Groq.",
    version="1.0.0",
)
# In-memory session manager for LangGraph checkpoints (for demonstration)
# NOTE(review): `memory` is never referenced elsewhere in this file — presumably the
# agent in agent.py should be compiled with this checkpointer; confirm it isn't dead code.
memory = MemorySaver()
# --- Pydantic Models for API ---
class TraceEvent(BaseModel):
    """One step of the agent's execution trace, streamed back to the frontend."""
    step: int  # 1-based position of this event within the run
    node_name: str  # LangGraph node that produced it (e.g. "router", "rag_lookup", "__end__")
    description: str  # human-readable summary of what the node did
    details: Dict[str, Any] = Field(default_factory=dict)  # node-specific structured payload
    event_type: str  # category tag, e.g. "router_decision", "rag_action", "process_end"
class QueryRequest(BaseModel):
    """Request body for POST /chat/."""
    session_id: str  # used as the LangGraph thread_id; reuse to continue a conversation
    query: str  # the user's natural-language question
    enable_web_search: bool = True  # NEW: frontend toggle; passed to the agent via config
class AgentResponse(BaseModel):
    """Response body for POST /chat/."""
    response: str  # the agent's final answer text
    trace_events: List[TraceEvent] = Field(default_factory=list)  # per-node execution trace
class DocumentUploadResponse(BaseModel):
    """Response body for POST /upload-document/."""
    message: str  # human-readable success message
    filename: str  # original name of the uploaded file
    processed_chunks: int  # number of units indexed (see endpoint: this is PDF page count)
# --- Document Upload Endpoint ---
@app.post("/upload-document/", response_model=DocumentUploadResponse, status_code=status.HTTP_200_OK)
async def upload_document(file: UploadFile = File(...)):
    """
    Uploads a PDF document, extracts its text, and adds it to the RAG knowledge base.

    Args:
        file: The uploaded file. Must carry a .pdf extension (checked case-insensitively).

    Returns:
        DocumentUploadResponse echoing the filename and the number of pages indexed.

    Raises:
        HTTPException: 400 if the upload has no filename or is not a PDF;
            500 if text extraction or indexing fails.
    """
    # Fix: original check was case-sensitive ("report.PDF" rejected) and crashed with
    # AttributeError (-> 500) when the client omitted a filename (filename is None).
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only PDF files are supported."
        )
    # Persist the upload to disk because PyPDFLoader needs a filesystem path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(await file.read())
        temp_file_path = tmp_file.name
    print(f"Received PDF for upload: {file.filename}. Saved temporarily to {temp_file_path}")
    try:
        documents = PyPDFLoader(temp_file_path).load()
        total_chunks_added = 0
        if documents:
            # Concatenate all pages; the vectorstore layer does its own chunking.
            full_text_content = "\n\n".join(doc.page_content for doc in documents)
            add_document_to_vectorstore(full_text_content)
            # NOTE(review): this counts loaded pages, not vectorstore chunks — confirm intent.
            total_chunks_added = len(documents)
        return DocumentUploadResponse(
            message=f"PDF '{file.filename}' successfully uploaded and indexed.",
            filename=file.filename,
            processed_chunks=total_chunks_added
        )
    except Exception as e:
        print(f"Error processing PDF document: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process PDF: {e}"
        )
    finally:
        # Always remove the temp copy, on success and failure alike.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"Cleaned up temporary file: {temp_file_path}")
# --- Chat Endpoint ---
def _describe_node_event(node_name: str, state: Dict[str, Any]) -> tuple:
    """Map one streamed LangGraph node update to (description, details, event_type) for the UI."""
    description = f"Executing node: {node_name}"
    details: Dict[str, Any] = {}
    event_type = "generic_node_execution"
    if node_name == "router":
        route_decision = state.get('route')
        # The router may be overridden when web search is disabled by the client.
        initial_decision = state.get('initial_router_decision', route_decision)
        override_reason = state.get('router_override_reason', None)
        if override_reason:
            description = f"Router initially decided: '{initial_decision}'. Overridden to: '{route_decision}' because {override_reason}."
            details = {"initial_decision": initial_decision, "final_decision": route_decision, "override_reason": override_reason}
        else:
            description = f"Router decided: '{route_decision}'"
            details = {"decision": route_decision, "reason": "Based on initial query analysis."}
        event_type = "router_decision"
    elif node_name == "rag_lookup":
        # Fix: `or ""` guards against an explicit None value, which the original
        # .get("rag", "") did not (slicing None raised TypeError).
        rag_content_summary = (state.get("rag") or "")[:200] + "..."
        if state.get("route") == "answer":
            description = "RAG Lookup performed. Content found and deemed sufficient. Proceeding to answer."
            details = {"retrieved_content_summary": rag_content_summary, "sufficiency_verdict": "Sufficient"}
        else:
            description = "RAG Lookup performed. Content NOT sufficient. Diverting to web search."
            details = {"retrieved_content_summary": rag_content_summary, "sufficiency_verdict": "Not Sufficient"}
        event_type = "rag_action"
    elif node_name == "web_search":
        description = "Web Search performed. Results retrieved. Proceeding to answer."
        details = {"retrieved_content_summary": (state.get("web") or "")[:200] + "..."}
        event_type = "web_action"
    elif node_name == "answer":
        description = "Generating final answer using gathered context."
        event_type = "answer_generation"
    elif node_name == "__end__":
        description = "Agent process completed."
        event_type = "process_end"
    return description, details, event_type
@app.post("/chat/", response_model=AgentResponse)
async def chat_with_agent(request: QueryRequest):
    """
    Runs the RAG agent for one query, collecting a per-node trace of the run.

    Args:
        request: Session id, query text, and the web-search toggle.

    Returns:
        AgentResponse with the final AI message and the full trace.

    Raises:
        HTTPException: 500 if the agent produces no final AIMessage or raises.
    """
    trace_events_for_frontend: List[TraceEvent] = []
    try:
        # Pass enable_web_search into the config for the agent to access
        config = {
            "configurable": {
                "thread_id": request.session_id,
                "web_search_enabled": request.enable_web_search
            }
        }
        inputs = {"messages": [HumanMessage(content=request.query)]}
        final_message = ""
        # Fix: the original read the loop variable after the loop, raising NameError
        # when the stream yielded nothing; track the last node state explicitly.
        last_node_state = None
        print(f"--- Starting Agent Stream for session {request.session_id} ---")
        print(f"Web Search Enabled: {request.enable_web_search}")  # For server-side debugging
        for i, chunk in enumerate(rag_agent.stream(inputs, config=config)):
            # Each chunk maps one node name to that node's output state.
            node_name = '__end__' if '__end__' in chunk else next(iter(chunk))
            node_state = chunk[node_name]
            last_node_state = node_state
            description, details, event_type = _describe_node_event(node_name, node_state)
            trace_events_for_frontend.append(
                TraceEvent(
                    step=i + 1,
                    node_name=node_name,
                    description=description,
                    details=details,
                    event_type=event_type
                )
            )
            print(f"Streamed Event: Step {i+1} - Node: {node_name} - Desc: {description}")
        # Pull the newest AIMessage out of the final state as the answer.
        if last_node_state and "messages" in last_node_state:
            for msg in reversed(last_node_state["messages"]):
                if isinstance(msg, AIMessage):
                    final_message = msg.content
                    break
        if not final_message:
            print("Agent finished, but no final AIMessage found in the final state after stream completion.")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Agent did not return a valid response (final AI message not found).")
        print(f"--- Agent Stream Ended. Final Response: {final_message[:200]}... ---")
        return AgentResponse(response=final_message, trace_events=trace_events_for_frontend)
    except HTTPException:
        # Fix: the deliberate 500 above was previously caught by the generic handler
        # below and re-wrapped, mangling its detail message. Let it propagate as-is.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        error_details = f"Error during agent invocation: {e}"
        print(error_details)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal Server Error: {e}")
@app.get("/health")
async def health_check():
    """Liveness probe: confirms the API process is up and serving requests."""
    return dict(status="ok")