3rushi committed on
Commit
bd26388
·
verified ·
1 Parent(s): 297492b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +71 -162
agent.py CHANGED
@@ -1,187 +1,96 @@
1
- """LangGraph Agent"""
2
- import os
3
- from dotenv import load_dotenv
4
- from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
- from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
- # from langchain_community.tools import create_retriever_tool
17
- from supabase.client import Client, create_client
18
 
19
- load_dotenv()
20
 
21
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
29
 
30
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
39
 
40
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
49
 
50
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Returns:
        The true-division quotient ``a / b``.

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # ``/`` is true division and yields a float, so the return annotation
    # is ``float`` (the original ``-> int`` was incorrect).
    return a / b
61
 
62
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
71
 
72
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        The matching pages formatted as ``<Document .../>`` blocks joined
        by ``---`` separators.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n'
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    # Return the string itself: the declared return type is ``str``; the
    # original ``{"wiki_results": ...}`` dict contradicted the annotation.
    return formatted_search_docs
 
85
 
86
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        The search hits formatted as ``<Document .../>`` blocks joined by
        ``---`` separators.
    """
    # The tool input is passed positionally: ``invoke(query=query)`` left the
    # required ``input`` argument unset. TavilySearchResults returns a list
    # of dicts with "url"/"content" keys, not Document objects, so the
    # original ``doc.metadata`` / ``doc.page_content`` access would fail.
    results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{r.get("url", "")}"/>\n{r.get("content", "")}\n</Document>'
        for r in results
    )
    # Return a plain string to match the declared ``-> str`` annotation.
    return formatted_search_docs
99
-
100
-
101
# Load the system prompt from the file shipped next to this module.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message that seeds every conversation.
sys_msg = SystemMessage(content=system_prompt)

# Build a retriever backed by a Supabase vector store.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# ``create_retriever_tool`` was referenced while its import was commented
# out, which raised NameError at import time; import it explicitly here.
from langchain_core.tools import create_retriever_tool

retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Full tool set handed to the LLM in build_graph.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    retriever_tool,
]
137
-
138
# Build graph function
def build_graph(provider: str = "google"):
    """Build the graph.

    Selects an LLM backend for ``provider`` ('google', 'groq' or
    'huggingface'), binds the module-level ``tools`` list, and compiles a
    graph whose only node is ``retriever``.

    NOTE(review): ``assistant``/``llm_with_tools`` are defined but never
    added to the graph — the compiled graph answers purely from the vector
    store; presumably unfinished. Confirm before relying on tool use.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        # TODO: Add huggingface endpoint
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node"""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    from langchain_core.messages import AIMessage

    def retriever(state: MessagesState):
        # Look up the single most similar stored document for the latest message.
        query = state["messages"][-1].content
        similar_doc = vector_store.similarity_search(query, k=1)[0]

        content = similar_doc.page_content
        # Stored documents apparently embed a "Final answer :" marker —
        # TODO confirm against the Supabase "documents" table contents.
        if "Final answer :" in content:
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = content.strip()

        return {"messages": [AIMessage(content=answer)]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)

    # Retriever is both the entry and the finish point.
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")

    # Compile graph
    return builder.compile()
 
1
+ from typing import List
2
+ from langgraph.graph import StateGraph, MessagesState, END
3
+ from langgraph.prebuilt import ToolNode, tools_condition
4
+
 
 
 
 
 
 
 
 
 
5
  from langchain_core.messages import SystemMessage, HumanMessage
6
  from langchain_core.tools import tool
 
 
7
 
8
+ from duckduckgo_search import DDGS
9
 
10
+ from langchain_google_genai import ChatGoogleGenerativeAI
11
+ from langchain_groq import ChatGroq
 
 
 
 
 
 
12
 
13
+ import os
 
 
 
 
 
 
 
 
14
 
15
+ # ---------------------------------------------------------
16
+ # TOOLS
17
+ # ---------------------------------------------------------
 
 
 
 
 
 
18
 
19
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo."""
    # Materialize up to three hits while the search session is open.
    with DDGS() as search_client:
        hits = list(search_client.text(query, max_results=3))
    if not hits:
        return "No results found."
    bodies = [hit["body"] for hit in hits]
    return "\n\n".join(bodies)
 
 
 
27
 
 
 
 
 
 
 
 
 
 
28
 
29
@tool
def calculator(expression: str) -> str:
    """Evaluate a basic arithmetic expression.

    Supports +, -, *, /, //, %, ** and unary +/- over numeric literals.
    The expression comes from an LLM/user, i.e. untrusted input, so it is
    evaluated through an AST whitelist instead of ``eval()`` — arbitrary
    code in the string cannot execute.

    Args:
        expression: The arithmetic expression to evaluate.

    Returns:
        The result as a string, or "Error evaluating expression." on any
        parse/evaluation failure (matching the original error contract).
    """
    import ast
    import operator

    # Whitelisted AST operator nodes -> the corresponding functions.
    ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in ops:
            return ops[type(node.op)](_eval(node.operand))
        # Anything else (names, calls, attributes, ...) is rejected.
        raise ValueError("unsupported expression")

    try:
        return str(_eval(ast.parse(expression, mode="eval")))
    except Exception:
        return "Error evaluating expression."
36
# Tool set exposed to the LLM; bound via ``llm.bind_tools`` in ``build_graph``.
TOOLS = [web_search, calculator]

# ---------------------------------------------------------
# BUILD GRAPH
# ---------------------------------------------------------
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
def build_graph(provider: str = "google"):
    """
    Build and return the compiled LangGraph agent.

    The graph is the standard ReAct loop: the assistant node runs the
    tool-bound LLM, ``tools_condition`` routes to the tool node when the
    last AI message carries tool calls and to END otherwise, and tool
    results loop back into the assistant.

    Args:
        provider: LLM backend to use, 'google' (Gemini) or 'groq'.

    Returns:
        The compiled graph (a runnable over ``MessagesState``).

    Raises:
        ValueError: if ``provider`` is not supported.
    """

    # -------------------------------
    # LLM SELECTION
    # -------------------------------
    if provider == "google":
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
        )
    elif provider == "groq":
        llm = ChatGroq(
            model="qwen-qwq-32b",
            temperature=0,
        )
    else:
        raise ValueError("Invalid provider. Use 'google' or 'groq'.")

    llm_with_tools = llm.bind_tools(TOOLS)

    # -------------------------------
    # ASSISTANT NODE
    # -------------------------------
    def assistant(state: MessagesState):
        # Run the tool-aware LLM on the running conversation.
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    # -------------------------------
    # GRAPH
    # -------------------------------
    builder = StateGraph(MessagesState)

    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))

    builder.set_entry_point("assistant")

    # ``tools_condition`` already routes assistant -> END when the reply has
    # no tool calls, so the original unconditional
    # ``builder.add_edge("assistant", END)`` created a conflicting parallel
    # route to END on every assistant step; it has been removed.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )

    builder.add_edge("tools", "assistant")

    return builder.compile()