Shekarss commited on
Commit
0a0a03c
·
verified ·
1 Parent(s): 668ca23

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +112 -186
agent.py CHANGED
@@ -1,226 +1,152 @@
1
- """LangGraph Agent"""
2
- import os
3
- from dotenv import load_dotenv
4
- from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
- from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
- from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
- from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
 
 
 
 
18
 
19
- load_dotenv()
 
 
 
 
 
 
 
 
 
20
 
21
  @tool
22
- def multiply(a: int, b: int) -> int:
23
- """Multiply two numbers.
 
24
  Args:
25
- a: first int
26
- b: second int
27
  """
28
- return a * b
 
29
 
30
  @tool
31
- def add(a: int, b: int) -> int:
32
- """Add two numbers.
33
-
34
  Args:
35
- a: first int
36
- b: second int
 
37
  """
38
- return a + b
39
 
40
  @tool
41
- def subtract(a: int, b: int) -> int:
42
- """Subtract two numbers.
 
 
 
 
 
 
 
 
 
 
43
 
 
 
 
 
44
  Args:
45
- a: first int
46
- b: second int
 
47
  """
48
  return a - b
49
 
50
  @tool
51
- def divide(a: int, b: int) -> int:
52
- """Divide two numbers.
53
-
54
  Args:
55
- a: first int
56
- b: second int
 
57
  """
58
- if b == 0:
59
- raise ValueError("Cannot divide by zero.")
60
- return a / b
61
 
62
  @tool
63
- def modulus(a: int, b: int) -> int:
64
- """Get the modulus of two numbers.
65
-
66
  Args:
67
- a: first int
68
- b: second int
 
69
  """
70
  return a % b
71
 
72
  @tool
73
- def wiki_search(query: str) -> str:
74
- """Search Wikipedia for a query and return maximum 2 results.
75
-
76
  Args:
77
- query: The search query."""
78
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
- formatted_search_docs = "\n\n---\n\n".join(
80
- [
81
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
- for doc in search_docs
83
- ])
84
- return {"wiki_results": formatted_search_docs}
85
 
86
  @tool
87
- def web_search(query: str) -> str:
88
- """Search Tavily for a query and return maximum 3 results.
89
-
90
  Args:
91
- query: The search query."""
92
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
- formatted_search_docs = "\n\n---\n\n".join(
94
- [
95
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
- for doc in search_docs
97
- ])
98
- return {"web_results": formatted_search_docs}
99
 
100
  @tool
101
- def arvix_search(query: str) -> str:
102
- """Search Arxiv for a query and return maximum 3 result.
103
-
104
  Args:
105
- query: The search query."""
106
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
- formatted_search_docs = "\n\n---\n\n".join(
108
- [
109
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
- for doc in search_docs
111
- ])
112
- return {"arvix_results": formatted_search_docs}
113
-
114
-
115
-
116
- # load the system prompt from the file
117
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
- system_prompt = f.read()
119
-
120
- # System message
121
- sys_msg = SystemMessage(content=system_prompt)
122
-
123
- supabase_url = os.environ["SUPABASE_URL"]
124
- supabase_key = os.environ["SUPABASE_SERVICE_KEY"]
125
-
126
- # build a retriever
127
- embeddings = HuggingFaceEmbeddings(
128
- model_name="sentence-transformers/all-mpnet-base-v2"
129
- ) # dim=768
130
- supabase: Client = create_client(supabase_url, supabase_key)
131
- vector_store = SupabaseVectorStore(
132
- client=supabase,
133
- embedding=embeddings,
134
- table_name="documents",
135
- query_name="match_documents_langchain",
136
- )
137
- create_retriever_tool = create_retriever_tool(
138
- retriever=vector_store.as_retriever(),
139
- name="Question Search",
140
- description="A tool to retrieve similar questions from a vector store.",
141
- )
142
-
143
 
 
144
 
145
- tools = [
146
- multiply,
147
- add,
148
- subtract,
149
- divide,
150
- modulus,
151
- wiki_search,
152
- web_search,
153
- arvix_search,
154
- ]
155
 
156
- # Build graph function
157
- def build_graph(provider: str = "google"):
158
- """Build the graph"""
159
- # Load environment variables from .env file
160
- if provider == "google":
161
- # Google Gemini
162
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
163
- elif provider == "groq":
164
- # Groq https://console.groq.com/docs/models
165
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
166
- elif provider == "huggingface":
167
- # TODO: Add huggingface endpoint
168
- llm = ChatHuggingFace(
169
- llm=HuggingFaceEndpoint(
170
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
171
- temperature=0,
172
- ),
173
- )
174
- else:
175
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
176
- # Bind tools to LLM
177
  llm_with_tools = llm.bind_tools(tools)
178
-
179
- # Node
180
- def assistant(state: MessagesState):
181
- """Assistant node"""
182
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
183
-
184
- # def retriever(state: MessagesState):
185
- # """Retriever node"""
186
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
187
- #example_msg = HumanMessage(
188
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
189
- # )
190
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
191
-
192
- from langchain_core.messages import AIMessage
193
-
194
- def retriever(state: MessagesState):
195
- query = state["messages"][-1].content
196
- similar_doc = vector_store.similarity_search(query, k=1)[0]
197
-
198
- content = similar_doc.page_content
199
- if "Final answer :" in content:
200
- answer = content.split("Final answer :")[-1].strip()
201
- else:
202
- answer = content.strip()
203
-
204
- return {"messages": [AIMessage(content=answer)]}
205
-
206
- # builder = StateGraph(MessagesState)
207
- #builder.add_node("retriever", retriever)
208
- #builder.add_node("assistant", assistant)
209
- #builder.add_node("tools", ToolNode(tools))
210
- #builder.add_edge(START, "retriever")
211
- #builder.add_edge("retriever", "assistant")
212
- #builder.add_conditional_edges(
213
- # "assistant",
214
- # tools_condition,
215
- #)
216
- #builder.add_edge("tools", "assistant")
217
-
218
- builder = StateGraph(MessagesState)
219
- builder.add_node("retriever", retriever)
220
-
221
- # Retriever ist Start und Endpunkt
222
- builder.set_entry_point("retriever")
223
- builder.set_finish_point("retriever")
224
-
225
- # Compile graph
226
- return builder.compile()
 
1
+ from llama_index.core import SimpleDirectoryReader
 
 
 
 
2
  from langgraph.prebuilt import ToolNode
 
 
 
 
3
  from langchain_community.document_loaders import WikipediaLoader
4
+ from langchain_community.tools import DuckDuckGoSearchRun
 
 
5
  from langchain_core.tools import tool
6
+ from typing import TypedDict, Annotated, List, Any, Dict
7
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, AnyMessage, SystemMessage
8
+ from langgraph.graph.message import add_messages
9
+ from langchain_groq import ChatGroq
10
+ from langgraph.graph import StateGraph, START, END
11
+ from langgraph.prebuilt import ToolNode, tools_condition
12
 
13
+ @tool
14
def read_file(directory: str) -> str:
    """
    Reads text from files in a directory
    Args:
        directory (string): Takes in a directory name.
        output_type: text
    """
    # Load every document found under the directory, then stitch their
    # text payloads together into one space-separated string.
    documents = SimpleDirectoryReader(directory).load_data()
    pieces = [document.text for document in documents]
    return ' '.join(pieces)
23
 
24
  @tool
25
def web_search(query: str) -> str:
    """
    Inputs a query and use DuckDuckGoSearchRun to search and fetch information from web.
    Args:
        query (string): Takes a query
        output_type: text
    """
    # Build the search runner and pass the query straight through.
    return DuckDuckGoSearchRun().invoke(query)
34
 
35
  @tool
36
def multiply(a: float, b: float) -> float:
    """
    Multiplies two numbers.
    Args:
        a (float): the first number
        b (float): the second number
        output_type: float
    """
    # Name the intermediate result before returning it.
    product = a * b
    return product
45
 
46
  @tool
47
def divide(a: float, b: float) -> float:
    """
    Divides two numbers.
    Args:
        a (float): the first number
        b (float): the second number
        output_type: float
    """
    # Guard clause: reject a zero divisor up front, then divide.
    if b == 0:
        raise ValueError('Cannot divide a number by zero')
    return a / b
59
 
60
+ @tool
61
def subtract(a: float, b: float) -> float:
    """
    Subtracts two numbers.
    Args:
        a (float): the first number
        b (float): the second number
        output_type: float
    """
    # Compute the difference into a named local, then hand it back.
    difference = a - b
    return difference
70
 
71
  @tool
72
def add(a: float, b: float) -> float:
    """
    Adds two numbers.
    Args:
        a (float): the first number
        b (float): the second number
        output_type: float
    """
    # Sum the operands via a named intermediate.
    total = a + b
    return total
 
 
81
 
82
  @tool
83
def modulus(a: int, b: int) -> int:
    """
    Get the modulus of two numbers.
    Args:
        a (int): the first number
        b (int): the second number
        output_type: int
    """
    # Python's % follows the sign of the divisor; return it directly.
    remainder = a % b
    return remainder
92
 
93
  @tool
94
def power(a: float, b: float) -> float:
    """
    Get the power of two numbers.
    Args:
        a (float): the first number
        b (float): the second number
        output_type: float
    """
    # pow() is the functional form of the ** operator; semantics are identical.
    return pow(a, b)
 
 
 
103
 
104
  @tool
105
+ def square_root(a:float)->float | complex:
106
+ """
107
+ Get the square root of a numbers.
108
  Args:
109
+ a (float): the first number
110
+ output_type: float
111
+ """
112
+ if a >= 0:
113
+ return a**0.5
114
+ return cmath.sqrt(a)
 
 
115
 
116
  @tool
117
def wikipedia_fetcher(query: str) -> str:
    """
    Inputs a query and use WikipediaLoader to fetch query related information.
    Args:
        query (string): Takes a query
        output_type: text
    """
    docs = WikipediaLoader(query=query, load_max_docs=3).load()
    # Robustness fix: the original indexed docs[0] unconditionally, raising
    # IndexError when Wikipedia returns no match. Report that case instead
    # so the agent gets a usable tool observation.
    if not docs:
        return f"No Wikipedia results found for query: {query}"
    return docs[0].page_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
# Registry of every tool exposed to the LLM: file reading, web/Wikipedia search, and arithmetic helpers.
tools = [read_file, wikipedia_fetcher, web_search, add, subtract, divide, multiply, modulus, power, square_root]
129
 
130
class AgentState(TypedDict):
    """State flowing through the LangGraph graph.

    messages holds the running conversation; the add_messages reducer
    appends/merges new messages rather than replacing the list.
    """
    messages: Annotated[List[AnyMessage], add_messages]
 
 
 
 
 
 
 
 
132
 
133
def call_model(state: AgentState):
    """LLM node: invoke the tool-bound Groq chat model on the conversation.

    Args:
        state: current graph state; state['messages'] is the message history.

    Returns:
        A state delta {'messages': [response]} that add_messages merges in.
    """
    # SECURITY FIX: the original committed a live Groq API key in source.
    # With api_key omitted, ChatGroq reads the GROQ_API_KEY environment
    # variable instead — rotate the leaked key and set the env var.
    llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
    llm_with_tools = llm.bind_tools(tools)
    response = llm_with_tools.invoke(state['messages'])
    return {'messages': [response]}
138
+
139
def build_graph():
    """Assemble and compile the agent graph.

    Flow: START -> 'llm'; when the model requests tool calls,
    tools_condition routes to 'call_tool', which loops back to 'llm';
    otherwise the run ends.
    """
    graph = StateGraph(AgentState)
    graph.add_node('llm', call_model)
    graph.add_node("call_tool", ToolNode(tools))
    graph.add_edge(START, 'llm')
    graph.add_conditional_edges(
        'llm',
        tools_condition,
        {'tools': 'call_tool', '__end__': END},
    )
    graph.add_edge('call_tool', 'llm')
    return graph.compile()