# NOTE: removed non-Python scrape residue ("Spaces:" / "Sleeping" status lines)
# chatbot_backend.py
import os
import sqlite3
from typing import TypedDict, Annotated, Optional

from dotenv import load_dotenv
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from model_selection import load_llm

# Pull API keys and other settings from a local .env file into os.environ.
load_dotenv()
# -------------------
# 1. Tools
# -------------------
# Web search tool that the chat node can expose to the LLM on demand.
search_tool = DuckDuckGoSearchRun(region="us-en")
tools = [search_tool]
# -------------------
# 2. State
# -------------------
class ChatState(TypedDict):
    """Conversation state threaded through the LangGraph nodes."""

    # Running transcript; `add_messages` merges new messages by appending.
    messages: Annotated[list[BaseMessage], add_messages]
    # Raw text of the user's uploaded document, if any.
    doc_content: Optional[str]
    # Whether the external web-search tool should be available this turn.
    search_enabled: bool
    # Model selected ("OpenAI (Paid)" or "LLaMA (Open Source)").
    model_choice: str
    # Document source ("Upload Your Documents" or "Use Sample Documents").
    doc_source: Optional[str]
    # Optional system prompt placed at the front of the transcript.
    system_prompt: Optional[str]
# -------------------
# 3. Chat Node (Main LLM Node)
# -------------------
def chat_node(state: ChatState):
    """Run one LLM turn over the accumulated conversation state.

    Assembles the prompt in a fixed order -- system prompt first, then the
    optional document context, then the chat history -- loads the model the
    user selected, and invokes it (binding the search tool only when search
    is enabled).

    Args:
        state: Current ChatState for this thread.

    Returns:
        dict: ``{"messages": [response]}``; merged into state via
        ``add_messages``.
    """
    messages = list(state["messages"])
    doc_content = state.get("doc_content")
    search_enabled = state.get("search_enabled", False)
    model_choice = state.get("model_choice", "OpenAI (Paid)")
    system_prompt = state.get("system_prompt")

    # 1️⃣ Inject document context ahead of the chat history (if provided).
    if doc_content:
        messages = [
            HumanMessage(
                content=f"Use the following document context for your answers if relevant:\n\n{doc_content}"
            )
        ] + messages

    # 2️⃣ Prepend the system prompt LAST so it ends up first in the prompt.
    # (Previously it was prepended before the doc context, which pushed the
    # doc message ahead of the system message.)
    if system_prompt:
        messages = [SystemMessage(content=system_prompt)] + messages

    # 3️⃣ Load the selected LLM; bind tools only when search is enabled so a
    # no-search turn skips the needless bind_tools call.
    llm = load_llm(model_choice)
    if search_enabled:
        response = llm.bind_tools(tools).invoke(messages)
    else:
        response = llm.invoke(messages)

    return {"messages": [response]}
# -------------------
# 4. Tool Node
# -------------------
# Executes any tool calls emitted by the chat node (currently just search).
tool_node = ToolNode([search_tool])

# -------------------
# 5. Checkpointer
# -------------------
# Persist conversation checkpoints in SQLite. The DB path may be overridden
# via the CHATBOT_DB environment variable; defaults to "chatbot.db" as before.
# check_same_thread=False lets the single connection be used from multiple
# threads (e.g. a web server); NOTE(review): assumes SqliteSaver serializes
# access to the connection -- confirm against the langgraph docs.
conn = sqlite3.connect(
    database=os.getenv("CHATBOT_DB", "chatbot.db"),
    check_same_thread=False,
)
checkpointer = SqliteSaver(conn=conn)
# -------------------
# 6. Graph Definition
# -------------------
graph = StateGraph(ChatState)

# Nodes: one LLM turn plus a tool-execution node.
graph.add_node("chat_node", chat_node)
graph.add_node("tools", tool_node)

# Wiring: every run starts at the chat node. If the model emitted tool
# calls, `tools_condition` routes to the tool node, whose output loops back
# into the chat node; otherwise the run terminates.
graph.add_edge(START, "chat_node")
graph.add_conditional_edges("chat_node", tools_condition)
graph.add_edge("tools", "chat_node")

# Compile the runnable chatbot with SQLite-backed persistence.
chatbot = graph.compile(checkpointer=checkpointer)
# -------------------
# 7. Helper Functions
# -------------------
def retrieve_all_threads(saver=None):
    """Return the unique conversation thread IDs known to the checkpointer.

    Args:
        saver: Optional checkpoint saver to query. Defaults to the
            module-level SQLite ``checkpointer``, preserving the original
            zero-argument behavior.

    Returns:
        list: Unique ``thread_id`` values (order unspecified).
    """
    if saver is None:
        saver = checkpointer
    # Each listed checkpoint carries its thread id in config["configurable"].
    return list({
        ckpt.config["configurable"]["thread_id"]
        for ckpt in saver.list(None)
    })