philincloud commited on
Commit
3d2ef38
Β·
verified Β·
1 Parent(s): 966c093

Update langgraph_agent.py

Browse files
Files changed (1) hide show
  1. langgraph_agent.py +12 -7
langgraph_agent.py CHANGED
@@ -8,31 +8,37 @@ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
  from langchain_core.messages import SystemMessage, HumanMessage
9
  from langchain_core.tools import tool
10
 
11
- #tools
12
  @tool
13
  def multiply(a: int, b: int) -> int:
 
14
  return a * b
15
 
16
  @tool
17
  def add(a: int, b: int) -> int:
 
18
  return a + b
19
 
20
  @tool
21
  def subtract(a: int, b: int) -> int:
 
22
  return a - b
23
 
24
  @tool
25
  def divide(a: int, b: int) -> float:
 
26
  if b == 0:
27
  raise ValueError("Cannot divide by zero.")
28
  return a / b
29
 
30
  @tool
31
  def modulus(a: int, b: int) -> int:
 
32
  return a % b
33
 
 
34
  @tool
35
  def wiki_search(query: str) -> dict:
 
36
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
37
  formatted = "\n\n---\n\n".join(
38
  f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
@@ -42,6 +48,7 @@ def wiki_search(query: str) -> dict:
42
 
43
  @tool
44
  def web_search(query: str) -> dict:
 
45
  docs = TavilySearchResults(max_results=3).invoke(query=query)
46
  formatted = "\n\n---\n\n".join(
47
  f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
@@ -51,6 +58,7 @@ def web_search(query: str) -> dict:
51
 
52
  @tool
53
  def arvix_search(query: str) -> dict:
 
54
  docs = ArxivLoader(query=query, load_max_docs=3).load()
55
  formatted = "\n\n---\n\n".join(
56
  f'<Document source="{d.metadata["source"]}"/>\n{d.page_content[:1000]}'
@@ -58,24 +66,21 @@ def arvix_search(query: str) -> dict:
58
  )
59
  return {"arvix_results": formatted}
60
 
61
-
62
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
63
  HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN")
64
 
65
- # ───────────────────────────────────────────────────────────────────────────────
66
- # 4) Assemble tool list
67
  tools = [
68
  multiply, add, subtract, divide, modulus,
69
  wiki_search, web_search, arvix_search,
70
  ]
71
 
72
- # ───────────────────────────────────────────────────────────────────────────────
73
- # 5) Load your system prompt
74
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
75
  system_prompt = f.read()
76
  sys_msg = SystemMessage(content=system_prompt)
77
 
78
- # ───────────────────────────────────────────────────────────────────────────────
79
  def build_graph(provider: str = "openai"):
80
  """Build the LangGraph agent with chosen LLM (default: OpenAI)."""
81
  if provider == "openai":
 
8
  from langchain_core.messages import SystemMessage, HumanMessage
9
  from langchain_core.tools import tool
10
 
 
11
  @tool
12
  def multiply(a: int, b: int) -> int:
13
+ """Multiply two integers."""
14
  return a * b
15
 
16
  @tool
17
  def add(a: int, b: int) -> int:
18
+ """Add two integers."""
19
  return a + b
20
 
21
  @tool
22
  def subtract(a: int, b: int) -> int:
23
+ """Subtract the second integer from the first."""
24
  return a - b
25
 
26
  @tool
27
  def divide(a: int, b: int) -> float:
28
+ """Divide first integer by second; error if divisor is zero."""
29
  if b == 0:
30
  raise ValueError("Cannot divide by zero.")
31
  return a / b
32
 
33
  @tool
34
  def modulus(a: int, b: int) -> int:
35
+ """Return the remainder of dividing first integer by second."""
36
  return a % b
37
 
38
+
39
  @tool
40
  def wiki_search(query: str) -> dict:
41
+ """Search Wikipedia for a query and return up to 2 documents."""
42
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
43
  formatted = "\n\n---\n\n".join(
44
  f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
 
48
 
49
  @tool
50
  def web_search(query: str) -> dict:
51
+ """Perform a web search (via Tavily) and return up to 3 results."""
52
  docs = TavilySearchResults(max_results=3).invoke(query=query)
53
  formatted = "\n\n---\n\n".join(
54
  f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
 
58
 
59
  @tool
60
  def arvix_search(query: str) -> dict:
61
+ """Search arXiv for a query and return up to 3 paper excerpts."""
62
  docs = ArxivLoader(query=query, load_max_docs=3).load()
63
  formatted = "\n\n---\n\n".join(
64
  f'<Document source="{d.metadata["source"]}"/>\n{d.page_content[:1000]}'
 
66
  )
67
  return {"arvix_results": formatted}
68
 
 
69
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
70
  HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN")
71
 
72
+
 
73
  tools = [
74
  multiply, add, subtract, divide, modulus,
75
  wiki_search, web_search, arvix_search,
76
  ]
77
 
78
+
 
79
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
80
  system_prompt = f.read()
81
  sys_msg = SystemMessage(content=system_prompt)
82
 
83
+
84
  def build_graph(provider: str = "openai"):
85
  """Build the LangGraph agent with chosen LLM (default: OpenAI)."""
86
  if provider == "openai":