# AnyRAG-WebSearch / src/chatbot_backend.py
# Rashid Ali
# chatstate fix
# 10b1f68
# chatbot_backend.py
import os
import sqlite3
from typing import Annotated, Optional, TypedDict

from dotenv import load_dotenv
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from model_selection import load_llm
load_dotenv()
# -------------------
# 1. Tools
# -------------------
# The search region is a setting of the DuckDuckGo API wrapper, not of the
# tool itself: passing ``region=`` directly to DuckDuckGoSearchRun is not a
# recognised tool field and never reaches the search backend.
search_tool = DuckDuckGoSearchRun(
    api_wrapper=DuckDuckGoSearchAPIWrapper(region="us-en")
)
# Tools the LLM may call when search is enabled.
tools = [search_tool]
# -------------------
# 2. State
# -------------------
class ChatState(TypedDict):
    """State schema for the chatbot graph.

    ``messages`` is annotated with the ``add_messages`` reducer, so the
    messages a node returns are merged into the existing history instead of
    replacing it. The other keys carry per-request settings.
    """

    messages: Annotated[list[BaseMessage], add_messages]  # Conversation history
    doc_content: Optional[str]  # Text of the user's uploaded document; chat_node adds it as context when set
    search_enabled: bool  # When true, chat_node invokes the LLM with the search tool bound
    model_choice: str  # Model label passed to load_llm ("OpenAI (Paid)" or "LLaMA (Open Source)")
    doc_source: Optional[str]  # Document source ("Upload Your Documents" or "Use Sample Documents"); not read by chat_node
    system_prompt: Optional[str]  # Optional system prompt; chat_node sends it to the model as a SystemMessage
# -------------------
# 3. Chat Node (Main LLM Node)
# -------------------
def chat_node(state: ChatState):
    """
    Run one LLM turn for the chat graph.

    The prompt sent to the model is assembled in this order:
      1. a SystemMessage built from ``state["system_prompt"]`` (if set),
      2. a HumanMessage carrying ``state["doc_content"]`` (if set),
      3. the conversation history from ``state["messages"]``.

    The model is obtained from ``load_llm(model_choice)`` (default
    ``"OpenAI (Paid)"``). When ``search_enabled`` is truthy the search tools
    are bound, so the model may emit tool calls that ``tools_condition``
    routes to the ``"tools"`` node.

    Returns:
        A partial state update ``{"messages": [response]}``; the
        ``add_messages`` reducer appends ``response`` to the history.
    """
    history = state["messages"]
    doc_content = state.get("doc_content")
    search_enabled = state.get("search_enabled", False)
    model_choice = state.get("model_choice", "OpenAI (Paid)")
    system_prompt = state.get("system_prompt")

    # Build the injected context in final prompt order. The system prompt
    # must stay first: previously the document message was prepended in
    # front of it, and some chat model integrations expect (or require) the
    # system message at the start of the conversation.
    context: list[BaseMessage] = []
    if system_prompt:
        context.append(SystemMessage(content=system_prompt))
    if doc_content:
        context.append(
            HumanMessage(
                content=f"Use the following document context for your answers if relevant:\n\n{doc_content}"
            )
        )
    messages = context + history

    llm = load_llm(model_choice)
    # Bind tools only when search is requested: binding is otherwise unused
    # work, and bind_tools may raise for models without tool-calling support.
    runnable = llm.bind_tools(tools) if search_enabled else llm
    response = runnable.invoke(messages)
    return {"messages": [response]}
# -------------------
# 4. Tool Node
# -------------------
# Execute exactly the tools that chat_node binds to the LLM, so adding a tool
# to ``tools`` automatically makes it both callable and executable.
tool_node = ToolNode(tools)
# -------------------
# 5. Checkpointer
# -------------------
# SQLite file holding conversation checkpoints. Defaults to "chatbot.db",
# resolved against the current working directory; set CHATBOT_DB_PATH to
# store it somewhere else (e.g. an absolute path independent of the CWD).
conn = sqlite3.connect(
    database=os.getenv("CHATBOT_DB_PATH", "chatbot.db"),
    # Allow use of this connection from threads other than the creating one.
    # NOTE(review): assumes SqliteSaver serialises its own access — confirm.
    check_same_thread=False,
)
checkpointer = SqliteSaver(conn=conn)
# -------------------
# 6. Graph Definition
# -------------------
def _build_graph() -> StateGraph:
    """Assemble the (uncompiled) chat <-> tools loop over ChatState."""
    builder = StateGraph(ChatState)
    builder.add_node("chat_node", chat_node)
    builder.add_node("tools", tool_node)
    # Every run starts at the LLM node.
    builder.add_edge(START, "chat_node")
    # LangGraph's prebuilt tools_condition decides whether the LLM's reply
    # goes on to the "tools" node or ends the run.
    builder.add_conditional_edges("chat_node", tools_condition)
    # Tool results are always handed back to the LLM.
    builder.add_edge("tools", "chat_node")
    return builder


graph = _build_graph()
# Compile with the SQLite checkpointer so state is persisted between runs.
chatbot = graph.compile(checkpointer=checkpointer)
# -------------------
# 7. Helper Functions
# -------------------
def retrieve_all_threads(saver=None):
    """
    Return the IDs of all conversation threads stored in a checkpointer.

    Args:
        saver: Checkpoint saver to scan. Defaults to the module-level
            ``checkpointer`` (the SQLite saver the chatbot graph uses).

    Returns:
        A list of unique thread IDs in the order they are first yielded by
        ``saver.list(None)``.
    """
    if saver is None:
        saver = checkpointer
    # dict.fromkeys de-duplicates while preserving first-seen order; the
    # previous set-based version returned IDs in hash-dependent order, which
    # for string IDs can differ from one process to the next.
    # NOTE(review): SqliteSaver appears to list newest checkpoints first,
    # which would make this "most recently active first" — confirm.
    unique_ids = dict.fromkeys(
        checkpoint.config["configurable"]["thread_id"]
        for checkpoint in saver.list(None)
    )
    return list(unique_ids)