# chatbot_backend.py
from langgraph.graph import StateGraph, START
from typing import TypedDict, Annotated, Optional
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools import DuckDuckGoSearchRun
from dotenv import load_dotenv
from model_selection import load_llm
import sqlite3
import os
load_dotenv()
# -------------------
# 1. Tools
# -------------------
# DuckDuckGo web search restricted to US/English results. This is the only
# tool in the toolset: it is bound to the LLM in chat_node (when search is
# enabled) and executed by the ToolNode defined further down.
search_tool = DuckDuckGoSearchRun(region="us-en")
tools = [search_tool]
# -------------------
# 2. State
# -------------------
class ChatState(TypedDict):
    """Shared LangGraph state passed between nodes of the chat graph."""
    # add_messages reducer: new messages are appended to the history
    # rather than replacing it.
    messages: Annotated[list[BaseMessage], add_messages]
    doc_content: Optional[str] # User's uploaded document content
    search_enabled: bool # Whether to use external search
    model_choice: str # Model selected ("OpenAI (Paid)" or "LLaMA (Open Source)")
    doc_source: Optional[str] # Document source ("Upload Your Documents" or "Use Sample Documents")
    system_prompt: Optional[str] # ✅ New field for system prompt
# -------------------
# 3. Chat Node (Main LLM Node)
# -------------------
def chat_node(state: ChatState):
    """
    Main LLM node: runs the selected model over the conversation.

    Handles chat flow dynamically:
    - Loads the model chosen in state ("OpenAI (Paid)" or open-source LLaMA)
    - Prepends the system prompt (always first) and document context
    - Binds the search tools only when search is enabled

    Args:
        state: Current ChatState (messages, optional doc context,
            search flag, model choice, optional system prompt).

    Returns:
        dict with a one-element "messages" list holding the LLM response;
        the add_messages reducer appends it to the history.
    """
    messages = list(state["messages"])
    doc_content = state.get("doc_content")
    search_enabled = state.get("search_enabled", False)
    model_choice = state.get("model_choice", "OpenAI (Paid)")
    system_prompt = state.get("system_prompt")

    # 1️⃣ Inject document context ahead of the chat history (if available).
    if doc_content:
        messages = [
            HumanMessage(
                content=f"Use the following document context for your answers if relevant:\n\n{doc_content}"
            )
        ] + messages

    # 2️⃣ The system prompt must be the very first message. Prepending it
    # AFTER the doc context fixes the original ordering bug, where the
    # document HumanMessage ended up in front of the SystemMessage.
    if system_prompt:
        messages = [SystemMessage(content=system_prompt)] + messages

    # 3️⃣ Load the appropriate LLM; bind tools only when search is enabled
    # so plain chats don't pay for tool binding they never use.
    llm = load_llm(model_choice)
    if search_enabled:
        response = llm.bind_tools(tools).invoke(messages)
    else:
        response = llm.invoke(messages)

    return {"messages": [response]}
# -------------------
# 4. Tool Node
# -------------------
# Executes the search tool whenever the LLM emits a tool call.
tool_node = ToolNode([search_tool])
# -------------------
# 5. Checkpointer
# -------------------
# check_same_thread=False lets the single connection be used from multiple
# threads (e.g. a web frontend); SqliteSaver persists graph checkpoints
# per conversation thread in chatbot.db.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
# -------------------
# 6. Graph Definition
# -------------------
# START -> chat_node; tools_condition routes to "tools" when the LLM
# requested a tool call (then back to chat_node), otherwise to END.
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_node("tools", tool_node)
graph.add_edge(START, "chat_node")
graph.add_conditional_edges("chat_node", tools_condition)
graph.add_edge("tools", "chat_node")
# Compile chatbot graph
chatbot = graph.compile(checkpointer=checkpointer)
# -------------------
# 7. Helper Functions
# -------------------
def retrieve_all_threads():
    """Return the unique conversation thread IDs stored by the SQLite checkpointer."""
    # Deduplicate via a set comprehension over every stored checkpoint.
    return list(
        {
            checkpoint.config["configurable"]["thread_id"]
            for checkpoint in checkpointer.list(None)
        }
    )