surya07 committed on
Commit
398ce8d
·
verified ·
1 Parent(s): cb4bb20

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +188 -255
agent.py CHANGED
@@ -1,279 +1,212 @@
 
1
  import os
2
- import certifi
3
- os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
  from langgraph.prebuilt import ToolNode
 
 
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
- from langchain_community.vectorstores import Chroma
 
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
  from langchain_core.tools import tool
14
  from langchain.tools.retriever import create_retriever_tool
15
- from langchain_core.documents import Document
16
- from langchain_community.embeddings import HuggingFaceEmbeddings
17
- from langchain.text_splitter import RecursiveCharacterTextSplitter
18
- from langchain_groq import ChatGroq
19
 
20
  load_dotenv()
21
 
22
-
23
- # ---------------- CONFIGURATION ----------------
24
- # Change this to any valid Hugging Face model endpoint (e.g., meta-llama/Llama-3-8b-chat-hf)
25
- HF_MODEL_NAME = os.getenv("LLAMA_MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
26
- HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
27
- HF_MODEL_URL = f"https://api-inference.huggingface.co/models/{HF_MODEL_NAME}"
28
- # Use the OpenAI-compatible inference endpoint
29
- HF_OPENAI_URL = "https://api-inference.huggingface.co/openai"
30
- # ---------------- UTILITY TOOLS ----------------
31
- @tool
32
- def multiply_numbers(x: int, y: int) -> int:
33
- """Multiply two integers and return the result."""
34
- return x * y
35
-
36
- @tool
37
- def add_numbers(x: int, y: int) -> int:
38
- """Add two integers and return the sum."""
39
- return x + y
40
-
41
- @tool
42
- def subtract_numbers(x: int, y: int) -> int:
43
- """Subtract the second integer from the first and return the result."""
44
- return x - y
45
-
46
- @tool
47
- def divide_numbers(x: int, y: int) -> float:
48
- """Divide the first number by the second and return the result. Raises an error on division by zero."""
49
- if y == 0:
50
- raise ValueError("Division by zero is not allowed.")
51
- return x / y
52
-
53
- @tool
54
- def modulus_numbers(x: int, y: int) -> int:
55
- """Return the remainder when the first number is divided by the second."""
56
- return x % y
57
-
58
- @tool
59
- def power_numbers(base: float, exponent: float) -> float:
60
- """Raise the base to the power of exponent and return the result."""
61
- return base ** exponent
62
-
63
- @tool
64
- def root_number(value: float, n: float) -> float:
65
- """Compute the nth root of a value and return the result."""
66
- return value ** (1 / n)
67
-
68
- @tool
69
- def wiki_lookup(query: str) -> str:
70
- """Search Wikipedia for the query and return up to 2 summarized documents."""
71
- docs = WikipediaLoader(query=query, load_max_docs=2).load()
72
- return "\n\n---\n\n".join(
73
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>{d.page_content}</Document>' for d in docs
74
- )
75
-
76
- @tool
77
- def web_lookup(query: str) -> str:
78
- """Search the web using Tavily and return up to 3 summarized results."""
79
- docs = TavilySearchResults(max_results=3).invoke(query=query)
80
- return "\n\n---\n\n".join(
81
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>{d.page_content}</Document>' for d in docs
82
- )
83
-
84
- @tool
85
- def arxiv_lookup(query: str) -> str:
86
- """Search arXiv for the query and return summaries of up to 3 papers."""
87
- docs = ArxivLoader(query=query, load_max_docs=3).load()
88
- return "\n\n---\n\n".join(
89
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>{d.page_content[:800]}</Document>' for d in docs
90
- )
91
- @tool
92
- def add_numbers(x: int, y: int) -> int:
93
- """Add two integers and return the sum."""
94
- return x + y
95
-
96
- @tool
97
- def subtract_numbers(x: int, y: int) -> int:
98
- """Subtract the second integer from the first and return the result."""
99
- return x - y
100
-
101
- @tool
102
- def divide_numbers(x: int, y: int) -> float:
103
- """Divide the first number by the second and return the result. Raises an error on division by zero."""
104
- if y == 0:
105
- raise ValueError("Division by zero is not allowed.")
106
- return x / y
107
-
108
  @tool
109
- def modulus_numbers(x: int, y: int) -> int:
110
- """Return the remainder when the first number is divided by the second."""
111
- return x % y
112
-
113
- @tool
114
- def power_numbers(base: float, exponent: float) -> float:
115
- """Raise the base to the power of exponent and return the result."""
116
- return base ** exponent
117
-
118
- @tool
119
- def root_number(value: float, n: float) -> float:
120
- """Compute the nth root of a value and return the result."""
121
- return value ** (1 / n)
122
-
123
- @tool
124
- def wiki_lookup(query: str) -> str:
125
- """Search Wikipedia for the query and return up to 2 summarized documents."""
126
- docs = WikipediaLoader(query=query, load_max_docs=2).load()
127
- return "\n\n---\n\n".join(
128
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>{d.page_content}</Document>' for d in docs
129
- )
130
-
131
- @tool
132
- def web_lookup(query: str) -> str:
133
- """Search the web using Tavily and return up to 3 summarized results."""
134
- docs = TavilySearchResults(max_results=3).invoke(query=query)
135
- return "\n\n---\n\n".join(
136
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>{d.page_content}</Document>' for d in docs
137
- )
138
-
139
- @tool
140
- def arxiv_lookup(query: str) -> str:
141
- """Search arXiv for the query and return summaries of up to 3 papers."""
142
- docs = ArxivLoader(query=query, load_max_docs=3).load()
143
- return "\n\n---\n\n".join(
144
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page", "")}"/>{d.page_content[:800]}</Document>' for d in docs
145
- )
146
-
147
- # # ---------------- SETUP LOCAL VECTORSTORE ----------------
148
- # embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
149
- # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
150
- # sample_docs = [Document(page_content="St. Thomas Aquinas was a medieval Catholic priest and philosopher.", metadata={"source": "wiki", "page": "St. Thomas Aquinas"})]
151
- # split_docs = text_splitter.split_documents(sample_docs)
152
- # vector_db = Chroma.from_documents(documents=split_docs, embedding=embedding_model)
153
- # retriever_tool = create_retriever_tool(
154
- # retriever=vector_db.as_retriever(),
155
- # name="SimilarQuestionFinder",
156
- # description="Retrieve similar questions and examples from vector DB."
157
- # )
158
-
159
- # # ---------------- SYSTEM PROMPT ----------------
160
- # with open("system_prompt.txt", "r", encoding="utf-8") as f:
161
- # system_content = f.read()
162
- # system_message = SystemMessage(content=system_content)
163
-
164
- # # ---------------- BUILD STATE GRAPH ----------------
165
- # def construct_agent_graph():
166
- # llama_llm = ChatHuggingFace(
167
- # llm=HuggingFaceEndpoint(
168
- # endpoint_url=HF_OPENAI_URL,
169
- # temperature=0
170
- # )
171
- # ).bind_tools([
172
- # multiply_numbers,
173
- # add_numbers,
174
- # subtract_numbers,
175
- # divide_numbers,
176
- # modulus_numbers,
177
- # power_numbers,
178
- # root_number,
179
- # wiki_lookup,
180
- # web_lookup,
181
- # arxiv_lookup,
182
- # retriever_tool,
183
- # ])
184
-
185
- # def retrieve_node(state: MessagesState):
186
- # similar = vector_db.similarity_search(state["messages"][0].content)
187
- # hint = HumanMessage(content=f"Reference example:\n{similar[0].page_content}" if similar else "")
188
- # return {"messages": [system_message] + state["messages"] + [hint]}
189
-
190
- # def respond_node(state: MessagesState):
191
- # return {"messages": [llama_llm.invoke(state["messages"]) ]}
192
-
193
- # graph_builder = StateGraph(MessagesState)
194
- # graph_builder.add_node("find_similar", retrieve_node)
195
- # graph_builder.add_node("generate_answer", respond_node)
196
- # graph_builder.add_node("tool_executor", ToolNode([]))
197
-
198
- # graph_builder.add_edge(START, "find_similar")
199
- # graph_builder.add_edge("find_similar", "generate_answer")
200
- # graph_builder.add_conditional_edges(
201
- # "generate_answer",
202
- # tools_condition,
203
- # {"tools": "tool_executor", "default": "generate_answer"}
204
- # )
205
- # graph_builder.add_edge("tool_executor", "generate_answer")
206
-
207
- # return graph_builder.compile()
208
-
209
- # # ---------------- RUN EXAMPLE ----------------
210
- # if __name__ == "__main__":
211
- # sample_q = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
212
- # agent = construct_agent_graph()
213
- # msgs = [HumanMessage(content=sample_q)]
214
- # out = agent.invoke({"messages": msgs})
215
- # for m in out["messages"]:
216
- # m.pretty_print()
217
-
218
- # ---------------- EMBEDDINGS & VECTOR DB ----------------
219
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
220
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
221
- sample_docs = [Document(page_content="Sample doc.", metadata={"source":"wiki"})]
222
- split_docs = text_splitter.split_documents(sample_docs)
223
- vector_db = Chroma.from_documents(documents=split_docs, embedding=embedding_model)
224
- retriever_tool = create_retriever_tool(
225
- retriever=vector_db.as_retriever(),
226
- name="SimilarQuestionFinder",
227
- description="Retrieve similar questions and examples from vector DB."
228
- )
229
-
230
- all_tools = [multiply_numbers, add_numbers, subtract_numbers, divide_numbers,
231
- modulus_numbers, power_numbers, root_number,
232
- wiki_lookup, web_lookup, arxiv_lookup, retriever_tool]
233
-
234
- # ---------------- SYSTEM PROMPT ----------------
235
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
236
- system_content = f.read()
237
- system_message = SystemMessage(content=system_content)
238
- # ---------------- BUILD GRAPH ----------------
239
- def construct_agent_graph():
240
- llama_llm = ChatGroq(
241
- model="qwen-qwq-32b",
242
- api_key=os.environ["GROQ_API_KEY"],
243
- temperature=0,
244
- )
245
-
246
- def retrieve_node(state: MessagesState):
247
- msgs = [system_message] + state["messages"]
248
- similar = vector_db.similarity_search(state["messages"][0].content)
249
- if similar:
250
- msgs.append(HumanMessage(content=f"Reference example:\n{similar[0].page_content}"))
251
- return {"messages": msgs}
 
 
 
 
 
252
 
253
- def respond_node(state: MessagesState):
254
- return {"messages": [llama_llm.invoke(state["messages"])]}
255
 
256
- graph = StateGraph(MessagesState)
257
- graph.add_node("find_similar", retrieve_node)
258
- graph.add_node("generate_answer", respond_node)
259
- graph.add_node("tool_executor", ToolNode(tools=all_tools))
260
 
261
- graph.add_edge(START, "find_similar")
262
- graph.add_edge("find_similar", "generate_answer")
263
- graph.add_conditional_edges(
264
- "generate_answer",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  tools_condition,
266
- {"tools": "tool_executor", "__end__": "__end__"}
267
  )
268
- graph.add_edge("tool_executor", "generate_answer")
269
 
270
- return graph.compile()
 
271
 
272
- # ---------------- RUN EXAMPLE ----------------
273
  if __name__ == "__main__":
274
- agent = construct_agent_graph()
275
- sample_q = "When was St. Thomas Aquinas added to that page?"
276
- out = agent.invoke({"messages": [HumanMessage(content=sample_q)]})
277
- for m in out["messages"]:
 
 
 
278
  m.pretty_print()
279
-
 
1
+ """LangGraph Agent"""
2
  import os
 
 
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
 
 
 
18
 
19
  load_dotenv()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
24
+
25
+ Args:
26
+ a: first int
27
+ b: second int
28
+ """
29
+ return a * b
30
+
31
+ @tool
32
+ def add(a: int, b: int) -> int:
33
+ """Add two numbers.
34
+
35
+ Args:
36
+ a: first int
37
+ b: second int
38
+ """
39
+ return a + b
40
+
41
+ @tool
42
+ def subtract(a: int, b: int) -> int:
43
+ """Subtract two numbers.
44
+
45
+ Args:
46
+ a: first int
47
+ b: second int
48
+ """
49
+ return a - b
50
+
51
+ @tool
52
+ def divide(a: int, b: int) -> int:
53
+ """Divide two numbers.
54
+
55
+ Args:
56
+ a: first int
57
+ b: second int
58
+ """
59
+ if b == 0:
60
+ raise ValueError("Cannot divide by zero.")
61
+ return a / b
62
+
63
+ @tool
64
+ def modulus(a: int, b: int) -> int:
65
+ """Get the modulus of two numbers.
66
+
67
+ Args:
68
+ a: first int
69
+ b: second int
70
+ """
71
+ return a % b
72
+
73
+ @tool
74
+ def wiki_search(query: str) -> str:
75
+ """Search Wikipedia for a query and return maximum 2 results.
76
+
77
+ Args:
78
+ query: The search query."""
79
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
80
+ formatted_search_docs = "\n\n---\n\n".join(
81
+ [
82
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
83
+ for doc in search_docs
84
+ ])
85
+ return {"wiki_results": formatted_search_docs}
86
+
87
+ @tool
88
+ def web_search(query: str) -> str:
89
+ """Search Tavily for a query and return maximum 3 results.
90
+
91
+ Args:
92
+ query: The search query."""
93
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
94
+ formatted_search_docs = "\n\n---\n\n".join(
95
+ [
96
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
97
+ for doc in search_docs
98
+ ])
99
+ return {"web_results": formatted_search_docs}
100
+
101
+ @tool
102
+ def arvix_search(query: str) -> str:
103
+ """Search Arxiv for a query and return maximum 3 result.
104
+
105
+ Args:
106
+ query: The search query."""
107
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
108
+ formatted_search_docs = "\n\n---\n\n".join(
109
+ [
110
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
111
+ for doc in search_docs
112
+ ])
113
+ return {"arvix_results": formatted_search_docs}
114
+
115
+
116
+
117
+ # load the system prompt from the file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
+ system_prompt = f.read()
120
+
121
+ # System message
122
+ sys_msg = SystemMessage(content=system_prompt)
123
+
124
+ # build a retriever
125
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
+ supabase: Client = create_client(
127
+ os.environ.get("SUPABASE_URL"),
128
+ os.environ.get("SUPABASE_SERVICE_KEY"))
129
+ vector_store = SupabaseVectorStore(
130
+ client=supabase,
131
+ embedding= embeddings,
132
+ table_name="documents",
133
+ query_name="match_documents_langchain",
134
+ )
135
+ create_retriever_tool = create_retriever_tool(
136
+ retriever=vector_store.as_retriever(),
137
+ name="Question Search",
138
+ description="A tool to retrieve similar questions from a vector store.",
139
+ )
140
 
 
 
141
 
 
 
 
 
142
 
143
+ tools = [
144
+ multiply,
145
+ add,
146
+ subtract,
147
+ divide,
148
+ modulus,
149
+ wiki_search,
150
+ web_search,
151
+ arvix_search,
152
+ ]
153
+
154
+ # Build graph function
155
+ def build_graph(provider: str = "huggingface"):
156
+ """Build the graph"""
157
+ # Load environment variables from .env file
158
+ if provider == "google":
159
+ # Google Gemini
160
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
161
+ elif provider == "groq":
162
+ # Groq https://console.groq.com/docs/models
163
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
164
+ elif provider == "huggingface":
165
+ llm = ChatHuggingFace(
166
+ llm=HuggingFaceEndpoint(
167
+ repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
168
+ ),
169
+ )
170
+ else:
171
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
172
+ # Bind tools to LLM
173
+ llm_with_tools = llm.bind_tools(tools)
174
+
175
+ # Node
176
+ def assistant(state: MessagesState):
177
+ """Assistant node"""
178
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
179
+
180
+ def retriever(state: MessagesState):
181
+ """Retriever node"""
182
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
183
+ example_msg = HumanMessage(
184
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
185
+ )
186
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
187
+
188
+ builder = StateGraph(MessagesState)
189
+ builder.add_node("retriever", retriever)
190
+ builder.add_node("assistant", assistant)
191
+ builder.add_node("tools", ToolNode(tools))
192
+ builder.add_edge(START, "retriever")
193
+ builder.add_edge("retriever", "assistant")
194
+ builder.add_conditional_edges(
195
+ "assistant",
196
  tools_condition,
 
197
  )
198
+ builder.add_edge("tools", "assistant")
199
 
200
+ # Compile graph
201
+ return builder.compile()
202
 
203
+ # test
204
  if __name__ == "__main__":
205
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
206
+ # Build the graph
207
+ graph = build_graph(provider="groq")
208
+ # Run the graph
209
+ messages = [HumanMessage(content=question)]
210
+ messages = graph.invoke({"messages": messages})
211
+ for m in messages["messages"]:
212
  m.pretty_print()