File size: 3,381 Bytes
aaa9e08
 
a2dd353
aaa9e08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2dd353
aaa9e08
 
10b1f68
 
aaa9e08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# chatbot_backend.py
from langgraph.graph import StateGraph, START
from typing import TypedDict, Annotated, Optional
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools import DuckDuckGoSearchRun
from dotenv import load_dotenv
from model_selection import load_llm
import sqlite3
import os

# Load variables from a local .env file into os.environ (python-dotenv).
# NOTE(review): presumably supplies provider API keys used by load_llm — confirm.
load_dotenv()

# -------------------
# 1. Tools
# -------------------
# DuckDuckGo web-search tool; region="us-en" selects the search region.
search_tool = DuckDuckGoSearchRun(region="us-en")
# Tools bound to the model in chat_node (used when search is enabled).
tools = [search_tool]

# -------------------
# 2. State
# -------------------
class ChatState(TypedDict):
    """
    State carried through the chatbot graph.

    Only ``messages`` declares a reducer (``add_messages``), so message lists
    returned by nodes are merged into the existing history. The remaining keys
    are plain configuration values read by ``chat_node``.
    """

    # Conversation history; node output is merged in by add_messages.
    messages: Annotated[list[BaseMessage], add_messages]
    # Text of the user's uploaded document, given to the model as context when set.
    doc_content: Optional[str]
    # When true, chat_node invokes the model with the search tool bound.
    search_enabled: bool
    # Model selector passed to load_llm: "OpenAI (Paid)" or "LLaMA (Open Source)".
    model_choice: str
    # Document origin: "Upload Your Documents" or "Use Sample Documents".
    doc_source: Optional[str]
    # Optional system prompt, sent to the model as a SystemMessage by chat_node.
    system_prompt: Optional[str]
    
# -------------------
# 3. Chat Node (Main LLM Node)
# -------------------
def chat_node(state: ChatState):
    """
    Run one LLM turn over the conversation stored in ``state``.

    The prompt sent to the model is assembled in this order:
      1. a ``SystemMessage`` from ``state["system_prompt"]`` (if set),
      2. a ``HumanMessage`` carrying ``state["doc_content"]`` (if set),
      3. the conversation history ``state["messages"]``.

    The model comes from ``load_llm(model_choice)`` (default ``"OpenAI (Paid)"``).
    When ``search_enabled`` is true, the model is bound to ``tools`` so it can
    emit tool calls. Otherwise it is invoked without tools.

    Returns:
        A partial state update ``{"messages": [response]}``. The
        ``add_messages`` reducer on ``ChatState.messages`` merges it into the
        history. The injected system/document messages are not written back.
    """
    history = state["messages"]
    doc_content = state.get("doc_content")
    search_enabled = state.get("search_enabled", False)
    model_choice = state.get("model_choice", "OpenAI (Paid)")
    system_prompt = state.get("system_prompt")

    # Build the prompt prefix front-to-back. The system prompt must stay at
    # index 0, ahead of the document context: some chat backends reject or
    # ignore a system message that is not the first message.
    prefix: list[BaseMessage] = []
    if system_prompt:
        prefix.append(SystemMessage(content=system_prompt))
    if doc_content:
        prefix.append(
            HumanMessage(
                content=f"Use the following document context for your answers if relevant:\n\n{doc_content}"
            )
        )
    prompt = prefix + history

    llm = load_llm(model_choice)
    # Bind tools only when search is requested, so a model without tool
    # support still works for plain chat.
    model = llm.bind_tools(tools) if search_enabled else llm
    response = model.invoke(prompt)

    return {"messages": [response]}

# -------------------
# 4. Tool Node
# -------------------
# Built from the same ``tools`` list that chat_node binds to the model, so
# every tool the model may call is one this node can execute.
tool_node = ToolNode(tools)

# -------------------
# 5. Checkpointer
# -------------------
# check_same_thread=False allows this connection to be used from threads other
# than the creating one (NOTE(review): presumably the UI runs the graph off the
# main thread — confirm). "chatbot.db" is resolved relative to the process CWD.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)

# -------------------
# 6. Graph Definition
# -------------------
# START -> chat_node. tools_condition routes to "tools" when the last AI
# message requests tool calls, otherwise the run ends. Tool results flow back
# into chat_node so the model can use them.
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_node("tools", tool_node)
graph.add_edge(START, "chat_node")
graph.add_conditional_edges("chat_node", tools_condition)
graph.add_edge("tools", "chat_node")

# Compile chatbot graph; state is persisted per thread_id via the SQLite checkpointer.
chatbot = graph.compile(checkpointer=checkpointer)

# -------------------
# 7. Helper Functions
# -------------------
def retrieve_all_threads(*, saver=None):
    """
    Return the unique thread IDs stored in a checkpointer.

    Iterates every checkpoint yielded by ``saver.list(None)`` and returns each
    ``config["configurable"]["thread_id"]`` once. Thread IDs appear in the
    order first seen.

    Order-preserving de-duplication (``dict.fromkeys``) is used instead of a
    ``set``. Set iteration order for strings varies between processes because
    of hash randomization, which made the returned list reshuffle on every run.

    Args:
        saver: Checkpointer to read from. Defaults to the module-level
            ``checkpointer``.

    Returns:
        list of unique thread IDs.
    """
    source = checkpointer if saver is None else saver
    thread_ids = (cp.config["configurable"]["thread_id"] for cp in source.list(None))
    return list(dict.fromkeys(thread_ids))