# chatbot_backend.py
"""LangGraph chatbot backend.

Builds a two-node graph (an LLM chat node plus a search tool node) whose
conversation state is checkpointed to a local SQLite database.
"""

import os
import sqlite3
from typing import Annotated, Optional, TypedDict

from dotenv import load_dotenv
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from model_selection import load_llm

load_dotenv()

# -------------------
# 1. Tools
# -------------------
search_tool = DuckDuckGoSearchRun(region="us-en")
tools = [search_tool]


# -------------------
# 2. State
# -------------------
class ChatState(TypedDict):
    """State carried through the chatbot graph."""

    # Conversation history; `add_messages` is the reducer used to merge updates.
    messages: Annotated[list[BaseMessage], add_messages]
    # User's uploaded document content.
    doc_content: Optional[str]
    # Whether to use external search.
    search_enabled: bool
    # Model selected ("OpenAI (Paid)" or "LLaMA (Open Source)").
    model_choice: str
    # Document source ("Upload Your Documents" or "Use Sample Documents").
    doc_source: Optional[str]
    # Optional system prompt placed at the head of the conversation.
    system_prompt: Optional[str]


# -------------------
# 3. Chat Node (Main LLM Node)
# -------------------
def chat_node(state: ChatState) -> dict[str, list[BaseMessage]]:
    """Run one LLM turn over the conversation stored in ``state``.

    The model is loaded via ``load_llm(model_choice)``. The prompt sent to it
    is assembled as: system prompt (if set), document context (if set), then
    the stored message history. The search tool is bound to the model only
    when ``search_enabled`` is truthy.

    Args:
        state: Current graph state. ``messages`` is required; the other keys
            are read with ``.get`` and may be absent.

    Returns:
        A state update ``{"messages": [response]}`` holding the model reply.
        The system prompt and document context are not part of the update.
    """
    history = state["messages"]
    doc_content = state.get("doc_content")
    search_enabled = state.get("search_enabled", False)
    model_choice = state.get("model_choice", "OpenAI (Paid)")
    system_prompt = state.get("system_prompt")

    # Build the prefix in final order. The system prompt must stay first:
    # some chat providers reject or ignore a system message that is not at
    # the head of the list, so the document context goes after it.
    prefix: list[BaseMessage] = []
    if system_prompt:
        prefix.append(SystemMessage(content=system_prompt))
    if doc_content:
        prefix.append(
            HumanMessage(
                content=f"Use the following document context for your answers if relevant:\n\n{doc_content}"
            )
        )

    llm = load_llm(model_choice)
    if search_enabled:
        # Bind tools only when they will be used, so a model without
        # tool-calling support still works with search turned off.
        llm = llm.bind_tools(tools)

    response = llm.invoke([*prefix, *history])
    return {"messages": [response]}


# -------------------
# 4. Tool Node
# -------------------
tool_node = ToolNode(tools)

# -------------------
# 5. Checkpointer
# -------------------
# check_same_thread=False lets the connection be used from threads other than
# the one that created it.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)

# -------------------
# 6. Graph Definition
# -------------------
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_node("tools", tool_node)

graph.add_edge(START, "chat_node")
# `tools_condition` picks the next step after each chat turn.
graph.add_conditional_edges("chat_node", tools_condition)
# After a tool runs, control returns to the chat node.
graph.add_edge("tools", "chat_node")

# Compile chatbot graph
chatbot = graph.compile(checkpointer=checkpointer)


# -------------------
# 7. Helper Functions
# -------------------
def retrieve_all_threads() -> list:
    """Retrieve all unique thread IDs from the SQLite checkpointer.

    Returns:
        The distinct ``thread_id`` values found in the configs of the
        checkpoints returned by ``checkpointer.list(None)``, in no
        particular order.
    """
    thread_ids = {
        checkpoint.config["configurable"]["thread_id"]
        for checkpoint in checkpointer.list(None)
    }
    return list(thread_ids)