rakesh-dvg commited on
Commit
4c446cf
·
verified ·
1 Parent(s): f0d1562

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +26 -17
agent.py CHANGED
@@ -20,68 +20,76 @@ load_dotenv()
20
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
 
23
  return a * b
24
 
25
  @tool
26
  def add(a: int, b: int) -> int:
 
27
  return a + b
28
 
29
  @tool
30
  def subtract(a: int, b: int) -> int:
 
31
  return a - b
32
 
33
  @tool
34
  def divide(a: int, b: int) -> float:
 
35
  if b == 0:
36
  raise ValueError("Cannot divide by zero.")
37
  return a / b
38
 
39
  @tool
40
  def modulus(a: int, b: int) -> int:
 
41
  return a % b
42
 
43
  @tool
44
- def wiki_search(query: str) -> str:
 
45
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
46
- formatted = "\n\n---\n\n".join(
47
  [
48
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
49
  for doc in search_docs
50
  ])
51
- return {"wiki_results": formatted}
52
 
53
  @tool
54
- def web_search(query: str) -> str:
 
55
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
56
- formatted = "\n\n---\n\n".join(
57
  [
58
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
59
  for doc in search_docs
60
  ])
61
- return {"web_results": formatted}
62
 
63
  @tool
64
- def arvix_search(query: str) -> str:
 
65
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
66
- formatted = "\n\n---\n\n".join(
67
  [
68
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
69
  for doc in search_docs
70
  ])
71
- return {"arvix_results": formatted}
72
 
73
-
74
- # Load system prompt
75
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
76
  system_prompt = f.read()
77
 
78
  sys_msg = SystemMessage(content=system_prompt)
79
 
80
- # Setup vector store retriever
81
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
82
  supabase: Client = create_client(
83
- os.environ.get("SUPABASE_URL"),
84
- os.environ.get("SUPABASE_SERVICE_KEY"))
 
85
  vector_store = SupabaseVectorStore(
86
  client=supabase,
87
  embedding=embeddings,
@@ -106,6 +114,7 @@ tools = [
106
  ]
107
 
108
  def build_graph(provider: str = "groq"):
 
109
  if provider == "google":
110
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
111
  elif provider == "groq":
@@ -119,13 +128,14 @@ def build_graph(provider: str = "groq"):
119
  )
120
  else:
121
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
122
-
123
  llm_with_tools = llm.bind_tools(tools)
124
 
125
  def assistant(state: MessagesState):
 
126
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
127
 
128
  def retriever(state: MessagesState):
 
129
  similar_question = vector_store.similarity_search(state["messages"][0].content)
130
  example_msg = HumanMessage(
131
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
@@ -143,7 +153,6 @@ def build_graph(provider: str = "groq"):
143
 
144
  return builder.compile()
145
 
146
-
147
  if __name__ == "__main__":
148
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
149
  graph = build_graph(provider="groq")
 
20
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
+ """Multiply two integers and return the product."""
24
  return a * b
25
 
26
  @tool
27
  def add(a: int, b: int) -> int:
28
+ """Add two integers and return the sum."""
29
  return a + b
30
 
31
  @tool
32
  def subtract(a: int, b: int) -> int:
33
+ """Subtract second integer from first and return the difference."""
34
  return a - b
35
 
36
  @tool
37
  def divide(a: int, b: int) -> float:
38
+ """Divide first integer by second and return the quotient. Raises error if divisor is zero."""
39
  if b == 0:
40
  raise ValueError("Cannot divide by zero.")
41
  return a / b
42
 
43
  @tool
44
  def modulus(a: int, b: int) -> int:
45
+ """Return the modulus (remainder) of first integer divided by second."""
46
  return a % b
47
 
48
  @tool
49
+ def wiki_search(query: str) -> dict:
50
+ """Search Wikipedia for a query and return formatted top 2 results."""
51
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
+ formatted_search_docs = "\n\n---\n\n".join(
53
  [
54
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
55
  for doc in search_docs
56
  ])
57
+ return {"wiki_results": formatted_search_docs}
58
 
59
  @tool
60
+ def web_search(query: str) -> dict:
61
+ """Search the web via Tavily and return formatted top 3 results."""
62
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
63
+ formatted_search_docs = "\n\n---\n\n".join(
64
  [
65
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
66
  for doc in search_docs
67
  ])
68
+ return {"web_results": formatted_search_docs}
69
 
70
  @tool
71
+ def arvix_search(query: str) -> dict:
72
+ """Search Arxiv for a query and return formatted top 3 results (truncated content)."""
73
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
74
+ formatted_search_docs = "\n\n---\n\n".join(
75
  [
76
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
77
  for doc in search_docs
78
  ])
79
+ return {"arvix_results": formatted_search_docs}
80
 
81
+ # Load the system prompt from file
 
82
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
83
  system_prompt = f.read()
84
 
85
  sys_msg = SystemMessage(content=system_prompt)
86
 
87
+ # Build retriever
88
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
89
  supabase: Client = create_client(
90
+ os.environ.get("SUPABASE_URL"),
91
+ os.environ.get("SUPABASE_SERVICE_KEY"),
92
+ )
93
  vector_store = SupabaseVectorStore(
94
  client=supabase,
95
  embedding=embeddings,
 
114
  ]
115
 
116
  def build_graph(provider: str = "groq"):
117
+ """Build the LangGraph agent graph with the specified provider."""
118
  if provider == "google":
119
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
120
  elif provider == "groq":
 
128
  )
129
  else:
130
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
131
  llm_with_tools = llm.bind_tools(tools)
132
 
133
  def assistant(state: MessagesState):
134
+ """Assistant node to process messages with LLM and tools."""
135
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
136
 
137
  def retriever(state: MessagesState):
138
+ """Retriever node to find similar questions from vector store."""
139
  similar_question = vector_store.similarity_search(state["messages"][0].content)
140
  example_msg = HumanMessage(
141
  content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
 
153
 
154
  return builder.compile()
155
 
 
156
  if __name__ == "__main__":
157
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
158
  graph = build_graph(provider="groq")