akrstova commited on
Commit
2a6057d
·
1 Parent(s): 4c241f1

Initial version of agent with math and search tools

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
agent.py CHANGED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_core.messages import SystemMessage, HumanMessage
8
+ from langchain_core.tools import tool
9
+ from tools.math_tools import add, subtract, multiply, divide, modulus, power, sqrt
10
+ from tools.search_tools import search_wikipedia, web_search
11
+
12
+ load_dotenv()
13
+
14
+ def build_graph():
15
+
16
+ llm = ChatGoogleGenerativeAI(
17
+ model="gemini-2.0-flash-001",
18
+ temperature=0.8,
19
+ max_tokens=None,
20
+ timeout=None,
21
+ max_retries=2,
22
+ )
23
+ tools = [add, subtract, multiply, divide, modulus, power, sqrt, web_search, search_wikipedia]
24
+
25
+ llm_with_tools = llm.bind_tools(tools)
26
+
27
+ def assistant(state: MessagesState):
28
+ """Assistant node for invoking the LLM."""
29
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
30
+
31
+ builder = StateGraph(MessagesState)
32
+ builder.add_node("assistant", assistant)
33
+ builder.add_node("tools", ToolNode(tools))
34
+ builder.add_edge(START, "assistant")
35
+ builder.add_conditional_edges(
36
+ "assistant",
37
+ tools_condition,
38
+ )
39
+ builder.add_edge("tools", "assistant")
40
+
41
+ # Compile graph
42
+ return builder.compile()
43
+
44
+ if __name__ == "__main__":
45
+ question = "Give me a summary of the wikipedia page for photosynthesis"
46
+ # Build the graph
47
+ graph = build_graph()
48
+ # Run the graph
49
+ messages = [HumanMessage(content=question)]
50
+ messages = graph.invoke({"messages": messages})
51
+ for m in messages["messages"]:
52
+ m.pretty_print()
app.py CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -12,12 +14,29 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
  class BasicAgent:
14
  def __init__(self):
 
15
  print("BasicAgent initialized.")
16
  def __call__(self, question: str) -> str:
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from langchain_core.messages import HumanMessage
7
+ from agent import build_graph
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
 
14
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
15
  class BasicAgent:
16
  def __init__(self):
17
+ self.graph = build_graph()
18
  print("BasicAgent initialized.")
19
  def __call__(self, question: str) -> str:
20
  print(f"Agent received question (first 50 chars): {question[:50]}...")
21
+ messages = [HumanMessage(content=question)]
22
+ try:
23
+ result = self.graph.invoke({"messages": messages})
24
+
25
+ if not result.get("messages") or len(result["messages"]) == 0:
26
+ raise ValueError("No messages returned from the graph")
27
+
28
+ last_message = result["messages"][-1].content
29
+
30
+ if "final answer" in last_message.lower():
31
+ split_final_answer = last_message.lower().split("final answer")
32
+ final_answer = split_final_answer[1].strip()
33
+ else:
34
+ final_answer = last_message
35
+ return final_answer
36
+ except Exception as e:
37
+ print("Agent processing failed")
38
+ return
39
+
40
 
41
  def run_and_submit_all( profile: gr.OAuthProfile | None):
42
  """
main.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def main():
2
+ print("Hello from final-assignment-template!")
3
+
4
+
5
+ if __name__ == "__main__":
6
+ main()
pyproject.toml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "final-assignment-template"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "dotenv>=0.9.9",
9
+ "gradio>=5.29.0",
10
+ "langchain-community>=0.3.23",
11
+ "langchain-core>=0.3.59",
12
+ "langchain-google-genai>=2.1.4",
13
+ "langgraph>=0.4.3",
14
+ "pandas>=2.2.3",
15
+ "requests>=2.32.3",
16
+ "wikipedia>=1.4.0",
17
+ ]
tools/__init__.py ADDED
File without changes
tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (172 Bytes). View file
 
tools/__pycache__/math_tools.cpython-312.pyc ADDED
Binary file (2.68 kB). View file
 
tools/__pycache__/search_tools.cpython-312.pyc ADDED
Binary file (1.91 kB). View file
 
tools/math_tools.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains tools for performing mathematical operations."""
2
+ from langchain_core.tools import tool
3
+
4
+ @tool
5
+ def add(x: float, y: float) -> float:
6
+ """
7
+ Adds two numbers.
8
+
9
+ Args:
10
+ x (float): the first number
11
+ y (float): the second number
12
+ """
13
+ return x + y
14
+
15
+ @tool
16
+ def subtract(x: float, y: float) -> float:
17
+ """
18
+ Subtracts two numbers.
19
+
20
+ Args:
21
+ x (float): the first number
22
+ y (float): the second number
23
+ """
24
+ return x - y
25
+
26
+ @tool
27
+ def multiply(x: float, y: float) -> float:
28
+ """
29
+ Multiplies two numbers.
30
+
31
+ Args:
32
+ x (float): the first number
33
+ y (float): the second number
34
+ """
35
+ return x * y
36
+
37
+ @tool
38
+ def divide(x: float, y: float) -> float:
39
+ """
40
+ Divides two numbers.
41
+
42
+ Args:
43
+ x (float): the first number
44
+ y (float): the second number to divide by
45
+ """
46
+ if y == 0:
47
+ raise ValueError("Cannot divide by zero")
48
+ return x / y
49
+
50
+ @tool
51
+ def modulus(x: float, y: float) -> float:
52
+ """
53
+ Returns the remainder when x is divided by y.
54
+
55
+ Args:
56
+ x (float): the dividend
57
+ y (float): the divisor
58
+ """
59
+ if y == 0:
60
+ raise ValueError("Cannot compute modulus with zero divisor")
61
+ return x % y
62
+
63
+ @tool
64
+ def power(x: float, y: float) -> float:
65
+ """
66
+ Raises x to the power of y.
67
+
68
+ Args:
69
+ x (float): the base number
70
+ y (float): the exponent
71
+ """
72
+ return x ** y
73
+
74
+ @tool
75
+ def sqrt(x: float) -> float:
76
+ """
77
+ Returns the square root of x.
78
+
79
+ Args:
80
+ x (float): the number to find the square root of
81
+ """
82
+ if x < 0:
83
+ raise ValueError("Cannot compute square root of negative number")
84
+ return x ** 0.5
85
+
tools/search_tools.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains tools for searching external sources."""
2
+ from langchain_core.tools import tool
3
+ from langchain_community.tools.tavily_search import TavilySearchResults
4
+ from langchain_community.document_loaders import WikipediaLoader
5
+
6
+
7
+ @tool
8
+ def web_search(query: str) -> str:
9
+ """Search the web via Tavily for a given query and return maximum 3 results.
10
+
11
+ Args:
12
+ query (str): The search query
13
+
14
+ Returns:
15
+ str: Content from the Tavily search
16
+ """
17
+ search_results = TavilySearchResults(max_results=3).invoke(query=query)
18
+ return {"web_results": search_results}
19
+
20
+ @tool
21
+ def search_wikipedia(query: str) -> str:
22
+ """
23
+ Searches Wikipedia and returns content from the topic.
24
+
25
+ Args:
26
+ query (str): The search query/topic to look up on Wikipedia
27
+
28
+ Returns:
29
+ str: Content from the Wikipedia article for the query
30
+
31
+ Raises:
32
+ ValueError: If no matching page is found
33
+ """
34
+ try:
35
+ loader = WikipediaLoader(query=query, load_max_docs=1)
36
+ docs = loader.load()
37
+ if not docs:
38
+ return f"No Wikipedia page found for '{query}'"
39
+ return docs[0].page_content
40
+ except Exception as e:
41
+ return f"Error searching Wikipedia: {str(e)}"
42
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff