# chatbot/main.py — Hugging Face Space by Keshabwi66, commit 4184ffc
# rag_agent_app/backend/main.py
import os
import time
from typing import List, Dict, Any
import tempfile
from fastapi import FastAPI, HTTPException, status, UploadFile, File
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langchain_community.document_loaders import PyPDFLoader
from agent import rag_agent
from vectorstore import add_document_to_vectorstore
# Initialize FastAPI app
app = FastAPI(
    title="LangGraph RAG Agent API",
    description="API for the LangGraph-powered RAG agent with Pinecone and Groq.",
    version="1.0.0",
)
# In-memory session manager for LangGraph checkpoints (for demonstration).
# NOTE(review): process-local and lost on restart — not suitable for multi-worker
# deployments. Presumably wired into rag_agent at compile time in agent.py; it is
# not referenced again in this module — verify against agent.py.
memory = MemorySaver()
# --- Pydantic Models for API ---
class TraceEvent(BaseModel):
    """One step of the agent's execution trace, returned to the frontend."""
    step: int  # 1-based position of this event in the stream
    node_name: str  # LangGraph node that produced it ("router", "rag_lookup", "web_search", "answer", "__end__")
    description: str  # human-readable summary of what the node did
    details: Dict[str, Any] = Field(default_factory=dict)  # node-specific extras (decisions, content summaries)
    event_type: str  # category tag, e.g. "router_decision", "rag_action", "process_end"
class QueryRequest(BaseModel):
    """Request body for POST /chat/."""
    session_id: str  # used as the LangGraph thread_id, keys the conversation checkpoint
    query: str  # the user's natural-language question
    enable_web_search: bool = True  # NEW: Add web search toggle state
class AgentResponse(BaseModel):
    """Response body for POST /chat/: the final answer plus the execution trace."""
    response: str  # content of the agent's final AIMessage
    trace_events: List[TraceEvent] = Field(default_factory=list)  # one entry per streamed node execution
class DocumentUploadResponse(BaseModel):
    """Response body for POST /upload-document/."""
    message: str  # human-readable success message
    filename: str  # original name of the uploaded file
    processed_chunks: int  # number of units indexed (currently the PDF's page count)
# --- Document Upload Endpoint ---
@app.post("/upload-document/", response_model=DocumentUploadResponse, status_code=status.HTTP_200_OK)
async def upload_document(file: UploadFile = File(...)):
    """
    Upload a PDF document, extract its text, and add it to the RAG knowledge base.

    Raises:
        HTTPException 400: if the uploaded file is not a PDF.
        HTTPException 500: if text extraction or indexing fails.
    """
    # BUGFIX: file.filename may be None (Starlette allows uploads without a
    # filename), and the extension check must be case-insensitive (.PDF, .Pdf).
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only PDF files are supported."
        )
    # PyPDFLoader needs a real filesystem path, so spool the upload to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        file_content = await file.read()
        tmp_file.write(file_content)
        temp_file_path = tmp_file.name
    print(f"Received PDF for upload: {file.filename}. Saved temporarily to {temp_file_path}")
    try:
        loader = PyPDFLoader(temp_file_path)
        documents = loader.load()  # one Document per page
        total_chunks_added = 0
        if documents:
            # Join all pages into one string; chunking is presumably done inside
            # add_document_to_vectorstore — verify against vectorstore.py.
            full_text_content = "\n\n".join(doc.page_content for doc in documents)
            add_document_to_vectorstore(full_text_content)
            # NOTE(review): this reports the page count, not the actual number of
            # chunks written to the vectorstore.
            total_chunks_added = len(documents)
        return DocumentUploadResponse(
            message=f"PDF '{file.filename}' successfully uploaded and indexed.",
            filename=file.filename,
            processed_chunks=total_chunks_added
        )
    except Exception as e:
        print(f"Error processing PDF document: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process PDF: {e}"
        ) from e
    finally:
        # Always remove the temp file, even when extraction/indexing fails.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"Cleaned up temporary file: {temp_file_path}")
# --- Chat Endpoint ---
@app.post("/chat/", response_model=AgentResponse)
async def chat_with_agent(request: QueryRequest):
    """
    Run one user query through the LangGraph agent and return the final answer
    plus a step-by-step trace of the nodes that executed.

    Raises:
        HTTPException 500: if the agent fails or produces no final AIMessage.
    """
    trace_events_for_frontend: List[TraceEvent] = []
    try:
        # Pass enable_web_search into the config for the agent to access.
        config = {
            "configurable": {
                "thread_id": request.session_id,
                "web_search_enabled": request.enable_web_search
            }
        }
        inputs = {"messages": [HumanMessage(content=request.query)]}
        final_message = ""
        print(f"--- Starting Agent Stream for session {request.session_id} ---")
        print(f"Web Search Enabled: {request.enable_web_search}")  # For server-side debugging

        # BUGFIX: initialize before the loop so an empty stream does not raise
        # NameError at the "if s:" check below (previously masked as a generic 500).
        s = None
        for i, s in enumerate(rag_agent.stream(inputs, config=config)):
            # Each streamed item maps a single node name to that node's output state.
            if '__end__' in s:
                current_node_name = '__end__'
                node_output_state = s['__end__']
            else:
                current_node_name = list(s.keys())[0]
                node_output_state = s[current_node_name]

            event_description = f"Executing node: {current_node_name}"
            event_details = {}
            event_type = "generic_node_execution"

            if current_node_name == "router":
                route_decision = node_output_state.get('route')
                # Check for an overridden route if web search was disabled.
                initial_decision = node_output_state.get('initial_router_decision', route_decision)
                override_reason = node_output_state.get('router_override_reason', None)
                if override_reason:
                    event_description = f"Router initially decided: '{initial_decision}'. Overridden to: '{route_decision}' because {override_reason}."
                    event_details = {"initial_decision": initial_decision, "final_decision": route_decision, "override_reason": override_reason}
                else:
                    event_description = f"Router decided: '{route_decision}'"
                    event_details = {"decision": route_decision, "reason": "Based on initial query analysis."}
                event_type = "router_decision"
            elif current_node_name == "rag_lookup":
                # BUGFIX: tolerate a present-but-None "rag" value (dict.get's
                # default only applies when the key is missing entirely).
                rag_content_summary = (node_output_state.get("rag") or "")[:200] + "..."
                rag_sufficient = node_output_state.get("route") == "answer"
                if rag_sufficient:
                    event_description = f"RAG Lookup performed. Content found and deemed sufficient. Proceeding to answer."
                    event_details = {"retrieved_content_summary": rag_content_summary, "sufficiency_verdict": "Sufficient"}
                else:
                    event_description = f"RAG Lookup performed. Content NOT sufficient. Diverting to web search."
                    event_details = {"retrieved_content_summary": rag_content_summary, "sufficiency_verdict": "Not Sufficient"}
                event_type = "rag_action"
            elif current_node_name == "web_search":
                # BUGFIX: same None-tolerance as the "rag" key above.
                web_content_summary = (node_output_state.get("web") or "")[:200] + "..."
                event_description = f"Web Search performed. Results retrieved. Proceeding to answer."
                event_details = {"retrieved_content_summary": web_content_summary}
                event_type = "web_action"
            elif current_node_name == "answer":
                event_description = "Generating final answer using gathered context."
                event_type = "answer_generation"
            elif current_node_name == "__end__":
                event_description = "Agent process completed."
                event_type = "process_end"

            trace_events_for_frontend.append(
                TraceEvent(
                    step=i + 1,
                    node_name=current_node_name,
                    description=event_description,
                    details=event_details,
                    event_type=event_type
                )
            )
            print(f"Streamed Event: Step {i+1} - Node: {current_node_name} - Desc: {event_description}")

        # Derive the final state from the last item yielded by the stream (if any).
        final_actual_state_dict = None
        if s:
            if '__end__' in s:
                final_actual_state_dict = s['__end__']
            elif list(s.keys()):
                final_actual_state_dict = s[list(s.keys())[0]]
        if final_actual_state_dict and "messages" in final_actual_state_dict:
            # Walk backwards so we pick the most recent AI reply.
            for msg in reversed(final_actual_state_dict["messages"]):
                if isinstance(msg, AIMessage):
                    final_message = msg.content
                    break
        if not final_message:
            print("Agent finished, but no final AIMessage found in the final state after stream completion.")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Agent did not return a valid response (final AI message not found).")
        print(f"--- Agent Stream Ended. Final Response: {final_message[:200]}... ---")
        return AgentResponse(response=final_message, trace_events=trace_events_for_frontend)
    except HTTPException:
        # BUGFIX: let the deliberate 500 raised above propagate as-is instead of
        # being re-wrapped by the generic handler below (which mangled its detail).
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        error_details = f"Error during agent invocation: {e}"
        print(error_details)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal Server Error: {e}")
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up."""
    payload = {"status": "ok"}
    return payload