ktluege committed on
Commit
a023ff4
·
verified ·
1 Parent(s): 0e5949f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +28 -146
agent.py CHANGED
@@ -1,183 +1,65 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
- from langgraph.graph import START, StateGraph, MessagesState
5
- from langchain_openai import ChatOpenAI # <-- ADD THIS IMPORT
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
- from langchain_community.tools.tavily_search import TavilySearchResults
8
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
  from langchain_community.vectorstores import SupabaseVectorStore
10
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
11
  from langchain_core.tools import tool
12
- from supabase.client import Client, create_client
13
 
14
  load_dotenv()
15
 
16
- @tool
17
- def multiply(a: int, b: int) -> int:
18
- """Multiply two numbers.
19
- Args:
20
- a: first int
21
- b: second int
22
- """
23
- return a * b
24
 
25
- @tool
26
- def add(a: int, b: int) -> int:
27
- """Add two numbers.
28
- Args:
29
- a: first int
30
- b: second int
31
- """
32
- return a + b
33
 
34
- @tool
35
- def subtract(a: int, b: int) -> int:
36
- """Subtract two numbers.
37
- Args:
38
- a: first int
39
- b: second int
40
- """
41
- return a - b
42
-
43
- @tool
44
- def divide(a: int, b: int) -> float:
45
- """Divide two numbers.
46
- Args:
47
- a: first int
48
- b: second int
49
- """
50
- if b == 0:
51
- raise ValueError("Cannot divide by zero.")
52
- return a / b
53
-
54
- @tool
55
- def modulus(a: int, b: int) -> int:
56
- """Get the modulus of two numbers.
57
- Args:
58
- a: first int
59
- b: second int
60
- """
61
- return a % b
62
-
63
- @tool
64
- def wiki_search(query: str) -> str:
65
- """Search Wikipedia for a query and return maximum 2 results.
66
- Args:
67
- query: The search query."""
68
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
69
- formatted_search_docs = "\n\n---\n\n".join(
70
- [
71
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
72
- for doc in search_docs
73
- ])
74
- return {"wiki_results": formatted_search_docs}
75
-
76
- @tool
77
- def web_search(query: str) -> str:
78
- """Search Tavily for a query and return maximum 3 results.
79
- Args:
80
- query: The search query."""
81
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
82
- formatted_search_docs = "\n\n---\n\n".join(
83
- [
84
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
85
- for doc in search_docs
86
- ])
87
- return {"web_results": formatted_search_docs}
88
-
89
- @tool
90
- def arvix_search(query: str) -> str:
91
- """Search Arxiv for a query and return maximum 3 result.
92
- Args:
93
- query: The search query."""
94
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
95
- formatted_search_docs = "\n\n---\n\n".join(
96
- [
97
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
98
- for doc in search_docs
99
- ])
100
- return {"arvix_results": formatted_search_docs}
101
-
102
- tools = [
103
- multiply,
104
- add,
105
- subtract,
106
- divide,
107
- modulus,
108
- wiki_search,
109
- web_search,
110
- arvix_search,
111
- ]
112
-
113
- # load the system prompt from the file
114
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
115
  system_prompt = f.read()
116
 
117
- # System message
118
  sys_msg = SystemMessage(content=system_prompt)
119
 
120
- # build a retriever
121
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
122
- supabase: Client = create_client(
123
  os.environ.get("SUPABASE_URL"),
124
  os.environ.get("SUPABASE_SERVICE_KEY"))
125
  vector_store = SupabaseVectorStore(
126
  client=supabase,
127
- embedding= embeddings,
128
  table_name="documents",
129
- query_name="match_documents_langchain",
130
  )
131
 
132
  def build_graph(provider: str = "openai"):
133
- """Build the graph with OpenAI, Gemini, or HuggingFace backend."""
134
  if provider == "openai":
135
  llm = ChatOpenAI(
136
- model="gpt-3.5-turbo", # or "gpt-4o" if available and you want to use it
137
  temperature=0,
138
  openai_api_key=os.environ.get("OPENAI_API_KEY"),
139
  )
140
- elif provider == "google":
141
- from langchain_google_genai import ChatGoogleGenerativeAI
142
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
143
- elif provider == "huggingface":
144
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
145
- llm = ChatHuggingFace(
146
- llm=HuggingFaceEndpoint(
147
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
148
- temperature=0,
149
- ),
150
- )
151
  else:
152
- raise ValueError("Invalid provider. Choose 'openai', 'google', or 'huggingface'.")
153
 
154
  llm_with_tools = llm.bind_tools(tools)
155
 
156
- from langchain_core.messages import AIMessage
157
-
158
  def retriever(state: MessagesState):
159
- query = state["messages"][-1].content
160
- results = vector_store.similarity_search(query, k=1)
161
- if not results:
162
- return {"messages": [AIMessage(content="FINAL ANSWER: No relevant answer found.")]}
163
- similar_doc = results[0]
164
- content = similar_doc.page_content
165
- if "FINAL ANSWER:" in content:
166
- answer = content.split("FINAL ANSWER:")[-1].strip()
167
- return {"messages": [AIMessage(content=f"FINAL ANSWER: {answer}")]}
168
- else:
169
- return {"messages": [AIMessage(content=content.strip())]}
170
-
171
-
172
- builder = StateGraph(MessagesState)
173
- builder.add_node("retriever", retriever)
174
- builder.add_node("assistant", assistant)
175
- builder.add_edge(START, "retriever")
176
- builder.add_edge("retriever", "assistant")
177
- builder.set_finish_point("assistant")
178
-
179
- return builder.compile()
180
-
181
 
182
  builder = StateGraph(MessagesState)
183
  builder.add_node("retriever", retriever)
 
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from langchain.graph import START, StateGraph, MessagesState
4
+ from langchain_openai import ChatOpenAI
5
  from langchain_huggingface import HuggingFaceEmbeddings
 
 
6
  from langchain_community.vectorstores import SupabaseVectorStore
7
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
8
  from langchain_core.tools import tool
 
9
 
10
  load_dotenv()
11
 
12
+ # ... [Your tool definitions here] ...
 
 
 
 
 
 
 
13
 
14
+ tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
17
  system_prompt = f.read()
18
 
 
19
  sys_msg = SystemMessage(content=system_prompt)
20
 
21
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
22
+ from supabase.client import create_client
23
+ supabase = create_client(
24
  os.environ.get("SUPABASE_URL"),
25
  os.environ.get("SUPABASE_SERVICE_KEY"))
26
  vector_store = SupabaseVectorStore(
27
  client=supabase,
28
+ embedding=embeddings,
29
  table_name="documents",
30
+ query_name="match_documents_langchain"
31
  )
32
 
33
  def build_graph(provider: str = "openai"):
 
34
  if provider == "openai":
35
  llm = ChatOpenAI(
36
+ model="gpt-3.5-turbo", # or "gpt-4o"
37
  temperature=0,
38
  openai_api_key=os.environ.get("OPENAI_API_KEY"),
39
  )
 
 
 
 
 
 
 
 
 
 
 
40
  else:
41
+ raise ValueError("Invalid provider.")
42
 
43
  llm_with_tools = llm.bind_tools(tools)
44
 
 
 
45
  def retriever(state: MessagesState):
46
+ query = state["messages"][-1].content
47
+ results = vector_store.similarity_search(query, k=1)
48
+ if not results:
49
+ return {"messages": [AIMessage(content="FINAL ANSWER: No relevant answer found.")]}
50
+ similar_doc = results[0]
51
+ content = similar_doc.page_content
52
+ if "FINAL ANSWER:" in content:
53
+ answer = content.split("FINAL ANSWER:")[-1].strip()
54
+ return {"messages": [AIMessage(content=f"FINAL ANSWER: {answer}")]}
55
+ else:
56
+ return {"messages": [AIMessage(content=content.strip())]}
57
+
58
+ def assistant(state: MessagesState):
59
+ user_message = state["messages"][-1]
60
+ # Make sure you send both system and user message
61
+ result = llm_with_tools.invoke([sys_msg, user_message])
62
+ return {"messages": [result]}
 
 
 
 
 
63
 
64
  builder = StateGraph(MessagesState)
65
  builder.add_node("retriever", retriever)