rakesh-dvg commited on
Commit
0c1ed4c
·
verified ·
1 Parent(s): b419bcf

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +36 -26
agent.py CHANGED
@@ -1,4 +1,4 @@
1
- """LangGraph Agent (No Supabase)"""
2
  import os
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition, ToolNode
@@ -12,57 +12,63 @@ from langchain_core.tools import tool
12
 
13
  @tool
14
  def multiply(a: int, b: int) -> int:
15
- """Multiply two integers and return the result."""
16
  return a * b
17
 
18
  @tool
19
  def add(a: int, b: int) -> int:
20
- """Add two integers and return the result."""
21
  return a + b
22
 
23
  @tool
24
  def subtract(a: int, b: int) -> int:
25
- """Subtract b from a and return the result."""
26
  return a - b
27
 
28
  @tool
29
  def divide(a: int, b: int) -> float:
30
- """Divide a by b and return the result. Raises an error if b is zero."""
31
  if b == 0:
32
  raise ValueError("Cannot divide by zero.")
33
  return a / b
34
 
35
  @tool
36
  def modulus(a: int, b: int) -> int:
37
- """Return the modulus (remainder) of a divided by b."""
38
  return a % b
39
 
40
  @tool
41
  def wiki_search(query: str) -> dict:
42
  """Search Wikipedia for a query and return up to 2 results."""
43
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
44
- results = "\n\n---\n\n".join(
45
- f"<Document>\n{doc.page_content}\n</Document>" for doc in search_docs
46
- )
47
- return {"wiki_results": results}
 
 
48
 
49
  @tool
50
  def web_search(query: str) -> dict:
51
- """Search the web via Tavily and return up to 3 results."""
52
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
53
- results = "\n\n---\n\n".join(
54
- f"<Document>\n{doc.page_content}\n</Document>" for doc in search_docs
55
- )
56
- return {"web_results": results}
 
 
57
 
58
  @tool
59
  def arvix_search(query: str) -> dict:
60
  """Search Arxiv and return up to 3 truncated results."""
61
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
62
- results = "\n\n---\n\n".join(
63
- f"<Document>\n{doc.page_content[:500]}\n</Document>" for doc in search_docs
64
- )
65
- return {"arvix_results": results}
 
 
66
 
67
  # Load system prompt
68
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
@@ -76,14 +82,17 @@ tools = [
76
  ]
77
 
78
  def build_graph(provider: str = "groq"):
79
- """Build the LangGraph agent with selected LLM provider."""
 
80
  if provider == "google":
81
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
 
82
  elif provider == "groq":
83
- groq_api_key = os.environ.get("GROQ_API_KEY")
84
  if not groq_api_key:
85
- raise ValueError("GROQ_API_KEY is not set in the environment.")
86
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_api_key)
 
87
  elif provider == "huggingface":
88
  llm = ChatHuggingFace(
89
  llm=HuggingFaceEndpoint(
@@ -91,8 +100,9 @@ def build_graph(provider: str = "groq"):
91
  temperature=0,
92
  )
93
  )
 
94
  else:
95
- raise ValueError("Invalid provider: choose 'google', 'groq' or 'huggingface'.")
96
 
97
  llm_with_tools = llm.bind_tools(tools)
98
 
@@ -108,10 +118,10 @@ def build_graph(provider: str = "groq"):
108
 
109
  return builder.compile()
110
 
 
111
  if __name__ == "__main__":
112
- from langchain_core.messages import HumanMessage
113
- question = "What is the capital of France and its population?"
114
- graph = build_graph()
115
  messages = [HumanMessage(content=question)]
116
  result = graph.invoke({"messages": messages})
117
  for msg in result["messages"]:
 
1
+ """LangGraph Agent (GROQ version without Supabase)"""
2
  import os
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition, ToolNode
 
12
 
13
  @tool
14
  def multiply(a: int, b: int) -> int:
15
+ """Multiply two numbers."""
16
  return a * b
17
 
18
  @tool
19
  def add(a: int, b: int) -> int:
20
+ """Add two numbers."""
21
  return a + b
22
 
23
  @tool
24
  def subtract(a: int, b: int) -> int:
25
+ """Subtract second number from the first."""
26
  return a - b
27
 
28
  @tool
29
  def divide(a: int, b: int) -> float:
30
+ """Divide two numbers."""
31
  if b == 0:
32
  raise ValueError("Cannot divide by zero.")
33
  return a / b
34
 
35
  @tool
36
  def modulus(a: int, b: int) -> int:
37
+ """Get the modulus (remainder) of two numbers."""
38
  return a % b
39
 
40
  @tool
41
  def wiki_search(query: str) -> dict:
42
  """Search Wikipedia for a query and return up to 2 results."""
43
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
44
+ formatted_search_docs = "\n\n---\n\n".join(
45
+ [
46
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
47
+ for doc in search_docs
48
+ ])
49
+ return {"wiki_results": formatted_search_docs}
50
 
51
  @tool
52
  def web_search(query: str) -> dict:
53
+ """Search Tavily for a query and return up to 3 results."""
54
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
55
+ formatted_search_docs = "\n\n---\n\n".join(
56
+ [
57
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
58
+ for doc in search_docs
59
+ ])
60
+ return {"web_results": formatted_search_docs}
61
 
62
  @tool
63
  def arvix_search(query: str) -> dict:
64
  """Search Arxiv and return up to 3 truncated results."""
65
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
66
+ formatted_search_docs = "\n\n---\n\n".join(
67
+ [
68
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
69
+ for doc in search_docs
70
+ ])
71
+ return {"arvix_results": formatted_search_docs}
72
 
73
  # Load system prompt
74
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
 
82
  ]
83
 
84
  def build_graph(provider: str = "groq"):
85
+ """Build the LangGraph agent using specified LLM provider."""
86
+
87
  if provider == "google":
88
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
89
+
90
  elif provider == "groq":
91
+ groq_api_key = os.getenv("GROQ_API_KEY")
92
  if not groq_api_key:
93
+ raise ValueError("GROQ_API_KEY environment variable not set.")
94
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=groq_api_key)
95
+
96
  elif provider == "huggingface":
97
  llm = ChatHuggingFace(
98
  llm=HuggingFaceEndpoint(
 
100
  temperature=0,
101
  )
102
  )
103
+
104
  else:
105
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
106
 
107
  llm_with_tools = llm.bind_tools(tools)
108
 
 
118
 
119
  return builder.compile()
120
 
121
+ # For testing purposes
122
  if __name__ == "__main__":
123
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
124
+ graph = build_graph(provider="groq")
 
125
  messages = [HumanMessage(content=question)]
126
  result = graph.invoke({"messages": messages})
127
  for msg in result["messages"]: