ktluege committed
Commit 442e89c · verified · 1 Parent(s): 96c81f4

Update agent.py

Files changed (1)
  1. agent.py +36 -129
agent.py CHANGED
@@ -2,78 +2,43 @@
 import os
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
-from langgraph.prebuilt import tools_condition
-from langgraph.prebuilt import ToolNode
 from langchain_openai import ChatOpenAI
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
 from langchain_community.vectorstores import SupabaseVectorStore
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage
 from langchain_core.tools import tool
-from langchain.tools.retriever import create_retriever_tool
-from supabase.client import Client, create_client
+from supabase.client import create_client
 
 load_dotenv()
 
+# ----- TOOLS -----
 @tool
 def multiply(a: int, b: int) -> int:
-    """Multiply two numbers.
-    Args:
-        a: first int
-        b: second int
-    """
     return a * b
 
 @tool
 def add(a: int, b: int) -> int:
-    """Add two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a + b
 
 @tool
 def subtract(a: int, b: int) -> int:
-    """Subtract two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a - b
 
 @tool
-def divide(a: int, b: int) -> int:
-    """Divide two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
+def divide(a: int, b: int) -> float:
     if b == 0:
         raise ValueError("Cannot divide by zero.")
     return a / b
 
 @tool
 def modulus(a: int, b: int) -> int:
-    """Get the modulus of two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a % b
 
 @tool
 def wiki_search(query: str) -> str:
-    """Search Wikipedia for a query and return maximum 2 results.
-
-    Args:
-        query: The search query."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
     formatted_search_docs = "\n\n---\n\n".join(
         [
@@ -84,10 +49,6 @@ def wiki_search(query: str) -> str:
 
 @tool
 def web_search(query: str) -> str:
-    """Search Tavily for a query and return maximum 3 results.
-
-    Args:
-        query: The search query."""
     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
     formatted_search_docs = "\n\n---\n\n".join(
         [
@@ -98,10 +59,6 @@ def web_search(query: str) -> str:
 
 @tool
 def arvix_search(query: str) -> str:
-    """Search Arxiv for a query and return maximum 3 result.
-
-    Args:
-        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
     formatted_search_docs = "\n\n---\n\n".join(
         [
@@ -110,52 +67,37 @@ def arvix_search(query: str) -> str:
         ])
     return {"arvix_results": formatted_search_docs}
 
+tools = [
+    multiply, add, subtract, divide, modulus,
+    wiki_search, web_search, arvix_search,
+]
 
-
-# load the system prompt from the file
+# ----- SYSTEM PROMPT -----
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()
-
-# System message
 sys_msg = SystemMessage(content=system_prompt)
 
-# build a retriever
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
-supabase: Client = create_client(
+# ----- VECTOR STORE -----
+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+supabase = create_client(
     os.environ.get("SUPABASE_URL"),
-    os.environ.get("SUPABASE_SERVICE_KEY"))
+    os.environ.get("SUPABASE_SERVICE_KEY")
+)
 vector_store = SupabaseVectorStore(
     client=supabase,
-    embedding= embeddings,
+    embedding=embeddings,
     table_name="documents",
     query_name="match_documents_langchain",
 )
-create_retriever_tool = create_retriever_tool(
-    retriever=vector_store.as_retriever(),
-    name="Question Search",
-    description="A tool to retrieve similar questions from a vector store.",
-)
 
-
-
-tools = [
-    multiply,
-    add,
-    subtract,
-    divide,
-    modulus,
-    wiki_search,
-    web_search,
-    arvix_search,
-]
-
-# Build graph function
+# ----- GRAPH -----
 def build_graph(provider: str = "openai"):
-    """Build the graph"""
-    if provider == "google":
-        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-    elif provider == "groq":
-        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
+    if provider == "openai":
+        llm = ChatOpenAI(
+            model="gpt-3.5-turbo", # or "gpt-4o"
+            temperature=0,
+            openai_api_key=os.environ.get("OPENAI_API_KEY"),
+        )
     elif provider == "huggingface":
         llm = ChatHuggingFace(
             llm=HuggingFaceEndpoint(
@@ -163,67 +105,32 @@ def build_graph(provider: str = "openai"):
                 temperature=0,
             ),
         )
-    elif provider == "openai":
-        llm = ChatOpenAI(
-            model="gpt-3.5-turbo", # or "gpt-4o"
-            temperature=0,
-            openai_api_key=os.environ.get("OPENAI_API_KEY"),
-        )
     else:
-        raise ValueError("Invalid provider. Choose 'google', 'groq', 'huggingface', or 'openai'.")
+        raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
 
-    # Bind tools to LLM
     llm_with_tools = llm.bind_tools(tools)
 
-    # Node
-    def assistant(state: MessagesState):
-        """Assistant node"""
-        return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-    # def retriever(state: MessagesState):
-    #     """Retriever node"""
-    #     similar_question = vector_store.similarity_search(state["messages"][0].content)
-    #     example_msg = HumanMessage(
-    #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-    #     )
-    #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
     from langchain_core.messages import AIMessage
 
     def retriever(state: MessagesState):
         query = state["messages"][-1].content
         results = vector_store.similarity_search(query, k=1)
-        if not results:
-            from langchain_core.messages import AIMessage
-            return {"messages": [AIMessage(content="FINAL ANSWER: No relevant answer found.")]}
-        similar_doc = results[0]
-        content = similar_doc.page_content
-        if "Final answer :" in content:
-            answer = content.split("Final answer :")[-1].strip()
+        if results:
+            similar_doc = results[0]
+            content = similar_doc.page_content.strip()
+            # Remove "FINAL ANSWER:" if present
+            if "FINAL ANSWER:" in content:
+                answer = content.split("FINAL ANSWER:")[-1].strip()
+            else:
+                answer = content
+            return {"messages": [AIMessage(content=answer)]}
         else:
-            answer = content.strip()
-        from langchain_core.messages import AIMessage
-        return {"messages": [AIMessage(content=answer)]}
-
-
-    # builder = StateGraph(MessagesState)
-    # builder.add_node("retriever", retriever)
-    # builder.add_node("assistant", assistant)
-    # builder.add_node("tools", ToolNode(tools))
-    # builder.add_edge(START, "retriever")
-    # builder.add_edge("retriever", "assistant")
-    # builder.add_conditional_edges(
-    #     "assistant",
-    #     tools_condition,
-    # )
-    # builder.add_edge("tools", "assistant")
+            # Fallback to LLM + tools, only the answer (no prefix)
+            answer_msg = llm_with_tools.invoke([sys_msg, state["messages"][-1]])
+            return {"messages": [AIMessage(content=answer_msg.content.strip())]}
 
     builder = StateGraph(MessagesState)
     builder.add_node("retriever", retriever)
-
-    # Retriever is both entry and finish point
     builder.set_entry_point("retriever")
     builder.set_finish_point("retriever")
-
-    # Compile graph
-    return builder.compile()
+    return builder.compile()
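The rewritten module now connects to Supabase and reads system_prompt.txt at import time, so the environment must be configured before agent.py is imported. Below is a minimal preflight sketch for reviewers; the variable list is inferred from the diff (TAVILY_API_KEY is the variable TavilySearchResults reads from the environment), and the helper itself is not part of this commit.

# Hypothetical preflight check (not part of this commit): fail fast if the
# environment variables the updated agent.py depends on are missing.
import os

from dotenv import load_dotenv

load_dotenv()

REQUIRED_VARS = ["OPENAI_API_KEY", "SUPABASE_URL", "SUPABASE_SERVICE_KEY", "TAVILY_API_KEY"]
missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
if missing:
    raise SystemExit(f"Missing environment variables: {', '.join(missing)}")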
 
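With the environment in place, the new single-node graph can be smoke-tested end to end. A hedged usage sketch, assuming the file is importable as agent and using a made-up question:

# Hypothetical smoke test for the updated build_graph() (not part of this commit).
from langchain_core.messages import HumanMessage

from agent import build_graph

graph = build_graph(provider="openai")
# The compiled graph has a single "retriever" node: it first tries a Supabase
# similarity search and otherwise falls back to the tool-bound LLM.
result = graph.invoke({"messages": [HumanMessage(content="What is 25 * 17?")]})
print(result["messages"][-1].content)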