File size: 2,029 Bytes
8e579f9
d0812dd
bd740be
8e579f9
fb8e216
d0812dd
 
bd740be
d0812dd
 
bd740be
 
 
 
 
 
8e579f9
bd740be
d0812dd
bd740be
 
 
d0812dd
 
bd740be
 
 
 
 
d0812dd
 
 
bd740be
 
 
d0812dd
 
 
 
bd740be
 
 
d0812dd
8e579f9
bd740be
fb8e216
 
 
bd740be
 
 
8e579f9
 
bd740be
 
 
 
 
 
 
8e579f9
 
d0812dd
 
bd740be
 
 
8e579f9
 
 
 
 
 
bd740be
 
 
d0812dd
8e579f9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

from typing import TypedDict, Annotated
from langchain_core.messages import BaseMessage, SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

from tools import (
    create_rag_tool,
    arxiv_search,
    wikipedia_search,
    tavily_search,
)

# Load variables from a local .env file into the process environment before the
# chat model and tools below are constructed (ChatOpenAI presumably reads its
# API key from the environment — confirm against deployment setup).
# NOTE(review): this call runs *after* `from tools import ...`; if the tools
# module reads environment variables at import time, they are not loaded yet.
load_dotenv()

# ===============================
# SYSTEM PROMPT
# ===============================
# Instruction text for the assistant, kept one sentence per entry. The empty
# first and last entries make the joined prompt begin and end with a newline.
_SYSTEM_PROMPT_LINES = (
    "",
    "You are an AI assistant using Retrieval-Augmented Generation.",
    "",
    "If a document is uploaded, you MUST answer using it.",
    "If no relevant info exists, clearly say so.",
    "Never hallucinate document content.",
    "",
)
SYSTEM_PROMPT = SystemMessage(content="\n".join(_SYSTEM_PROMPT_LINES))

# ===============================
# STATE
# ===============================
# Graph state schema: a single "messages" list of chat messages. The Annotated
# metadata names add_messages as the reducer LangGraph uses to combine each
# node's returned messages with the existing history instead of replacing it.
ChatState = TypedDict(
    "ChatState",
    {"messages": Annotated[list[BaseMessage], add_messages]},
)


# ===============================
# LLM
# ===============================
# Settings for the chat model used by the graph: a small GPT-4.1 variant,
# a low sampling temperature, and token streaming enabled.
_LLM_OPTIONS = {
    "streaming": True,
    "temperature": 0.3,
    "model": "gpt-4.1-nano",
}
llm = ChatOpenAI(**_LLM_OPTIONS)

# ===============================
# TOOLS
# ===============================
# Tools available to the model: the document-retrieval tool built here, followed
# by the external search tools imported from the tools module.
rag_tool = create_rag_tool()
tools = [rag_tool, wikipedia_search, arxiv_search, tavily_search]

# The same list feeds both the node that executes tool calls and the model's
# tool binding, so what the model is offered matches what the graph can run.
tool_node = ToolNode(tools)
llm = llm.bind_tools(tools)


# ===============================
# CHAT NODE
# ===============================
def chatbot(state: ChatState):
    """Run one model turn over the conversation held in ``state``.

    The system prompt is placed in front of the history on every call instead
    of being stored in the state, so it never appears in the returned update.

    Returns:
        A partial state update whose ``messages`` list holds only the model's
        reply; the state's reducer merges it into the existing history.
    """
    history = state["messages"]
    reply = llm.invoke([SYSTEM_PROMPT, *history])
    return {"messages": [reply]}


# ===============================
# GRAPH
# ===============================
# Checkpointer that keeps conversation state in process memory
# (presumably not persisted across restarts — swap in a durable saver if needed).
memory = MemorySaver()
graph = StateGraph(ChatState)

graph.add_node("chat", chatbot)
graph.add_node("tools", tool_node)

graph.add_edge(START, "chat")
# tools_condition sends "chat" to "tools" when the model's last message requests
# tool calls, and to END otherwise. That makes it the only outgoing edge "chat"
# needs; an additional static chat -> END edge would be redundant and would
# misrepresent the routing.
graph.add_conditional_edges("chat", tools_condition)
# After tools run, feed their results back to the model for another turn.
graph.add_edge("tools", "chat")

app = graph.compile(checkpointer=memory)