rakesh-dvg commited on
Commit
209de17
·
verified ·
1 Parent(s): 4c446cf

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +39 -83
agent.py CHANGED
@@ -1,120 +1,85 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
- from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
- """Multiply two integers and return the product."""
24
  return a * b
25
 
26
  @tool
27
  def add(a: int, b: int) -> int:
28
- """Add two integers and return the sum."""
29
  return a + b
30
 
31
  @tool
32
  def subtract(a: int, b: int) -> int:
33
- """Subtract second integer from first and return the difference."""
34
  return a - b
35
 
36
  @tool
37
  def divide(a: int, b: int) -> float:
38
- """Divide first integer by second and return the quotient. Raises error if divisor is zero."""
39
  if b == 0:
40
  raise ValueError("Cannot divide by zero.")
41
  return a / b
42
 
43
  @tool
44
  def modulus(a: int, b: int) -> int:
45
- """Return the modulus (remainder) of first integer divided by second."""
46
  return a % b
47
 
48
  @tool
49
  def wiki_search(query: str) -> dict:
50
- """Search Wikipedia for a query and return formatted top 2 results."""
51
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
- formatted_search_docs = "\n\n---\n\n".join(
53
- [
54
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
55
- for doc in search_docs
56
- ])
57
- return {"wiki_results": formatted_search_docs}
58
 
59
  @tool
60
  def web_search(query: str) -> dict:
61
- """Search the web via Tavily and return formatted top 3 results."""
62
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
63
- formatted_search_docs = "\n\n---\n\n".join(
64
- [
65
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
66
- for doc in search_docs
67
- ])
68
- return {"web_results": formatted_search_docs}
69
 
70
  @tool
71
  def arvix_search(query: str) -> dict:
72
- """Search Arxiv for a query and return formatted top 3 results (truncated content)."""
73
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
74
- formatted_search_docs = "\n\n---\n\n".join(
75
- [
76
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
77
- for doc in search_docs
78
- ])
79
- return {"arvix_results": formatted_search_docs}
80
-
81
- # Load the system prompt from file
82
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
83
  system_prompt = f.read()
84
 
85
  sys_msg = SystemMessage(content=system_prompt)
86
 
87
- # Build retriever
88
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
89
- supabase: Client = create_client(
90
- os.environ.get("SUPABASE_URL"),
91
- os.environ.get("SUPABASE_SERVICE_KEY"),
92
- )
93
- vector_store = SupabaseVectorStore(
94
- client=supabase,
95
- embedding=embeddings,
96
- table_name="documents",
97
- query_name="match_documents_langchain",
98
- )
99
- create_retriever_tool = create_retriever_tool(
100
- retriever=vector_store.as_retriever(),
101
- name="Question Search",
102
- description="A tool to retrieve similar questions from a vector store.",
103
- )
104
-
105
  tools = [
106
- multiply,
107
- add,
108
- subtract,
109
- divide,
110
- modulus,
111
- wiki_search,
112
- web_search,
113
- arvix_search,
114
  ]
115
 
116
  def build_graph(provider: str = "groq"):
117
- """Build the LangGraph agent graph with the specified provider."""
118
  if provider == "google":
119
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
120
  elif provider == "groq":
@@ -124,39 +89,30 @@ def build_graph(provider: str = "groq"):
124
  llm=HuggingFaceEndpoint(
125
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
126
  temperature=0,
127
- ),
128
  )
129
  else:
130
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
131
  llm_with_tools = llm.bind_tools(tools)
132
 
133
  def assistant(state: MessagesState):
134
- """Assistant node to process messages with LLM and tools."""
135
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
136
 
137
- def retriever(state: MessagesState):
138
- """Retriever node to find similar questions from vector store."""
139
- similar_question = vector_store.similarity_search(state["messages"][0].content)
140
- example_msg = HumanMessage(
141
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
142
- )
143
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
144
-
145
  builder = StateGraph(MessagesState)
146
- builder.add_node("retriever", retriever)
147
  builder.add_node("assistant", assistant)
148
  builder.add_node("tools", ToolNode(tools))
149
- builder.add_edge(START, "retriever")
150
- builder.add_edge("retriever", "assistant")
151
  builder.add_conditional_edges("assistant", tools_condition)
152
  builder.add_edge("tools", "assistant")
153
 
154
  return builder.compile()
155
 
156
  if __name__ == "__main__":
157
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
158
- graph = build_graph(provider="groq")
 
159
  messages = [HumanMessage(content=question)]
160
- messages = graph.invoke({"messages": messages})
161
- for m in messages["messages"]:
162
- m.pretty_print()
 
1
+ """LangGraph Agent (No Supabase)"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition, ToolNode
 
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
 
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
  from langchain_core.tools import tool
 
 
13
 
14
  load_dotenv()
15
 
16
  @tool
17
  def multiply(a: int, b: int) -> int:
18
+ """Multiply two integers and return the result."""
19
  return a * b
20
 
21
  @tool
22
  def add(a: int, b: int) -> int:
23
+ """Add two integers and return the result."""
24
  return a + b
25
 
26
  @tool
27
  def subtract(a: int, b: int) -> int:
28
+ """Subtract b from a and return the result."""
29
  return a - b
30
 
31
  @tool
32
  def divide(a: int, b: int) -> float:
33
+ """Divide a by b and return the result."""
34
  if b == 0:
35
  raise ValueError("Cannot divide by zero.")
36
  return a / b
37
 
38
  @tool
39
  def modulus(a: int, b: int) -> int:
40
+ """Return the modulus of a and b."""
41
  return a % b
42
 
43
  @tool
44
  def wiki_search(query: str) -> dict:
45
+ """Search Wikipedia for a query and return up to 2 results."""
46
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
47
+ results = "\n\n---\n\n".join(
48
+ f"<Document>\n{doc.page_content}\n</Document>" for doc in search_docs
49
+ )
50
+ return {"wiki_results": results}
 
 
51
 
52
  @tool
53
  def web_search(query: str) -> dict:
54
+ """Search the web via Tavily and return up to 3 results."""
55
  search_docs = TavilySearchResults(max_results=3).invoke(query=query)
56
+ results = "\n\n---\n\n".join(
57
+ f"<Document>\n{doc.page_content}\n</Document>" for doc in search_docs
58
+ )
59
+ return {"web_results": results}
 
 
60
 
61
  @tool
62
  def arvix_search(query: str) -> dict:
63
+ """Search Arxiv and return up to 3 truncated results."""
64
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
65
+ results = "\n\n---\n\n".join(
66
+ f"<Document>\n{doc.page_content[:500]}\n</Document>" for doc in search_docs
67
+ )
68
+ return {"arvix_results": results}
69
+
70
+ # Load system prompt
 
 
71
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
72
  system_prompt = f.read()
73
 
74
  sys_msg = SystemMessage(content=system_prompt)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  tools = [
77
+ multiply, add, subtract, divide, modulus,
78
+ wiki_search, web_search, arvix_search
 
 
 
 
 
 
79
  ]
80
 
81
  def build_graph(provider: str = "groq"):
82
+ """Build the LangGraph agent with selected LLM provider."""
83
  if provider == "google":
84
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
85
  elif provider == "groq":
 
89
  llm=HuggingFaceEndpoint(
90
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
91
  temperature=0,
92
+ )
93
  )
94
  else:
95
+ raise ValueError("Invalid provider: choose 'google', 'groq' or 'huggingface'.")
96
+
97
  llm_with_tools = llm.bind_tools(tools)
98
 
99
  def assistant(state: MessagesState):
 
100
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
101
 
 
 
 
 
 
 
 
 
102
  builder = StateGraph(MessagesState)
 
103
  builder.add_node("assistant", assistant)
104
  builder.add_node("tools", ToolNode(tools))
105
+ builder.add_edge(START, "assistant")
 
106
  builder.add_conditional_edges("assistant", tools_condition)
107
  builder.add_edge("tools", "assistant")
108
 
109
  return builder.compile()
110
 
111
  if __name__ == "__main__":
112
+ from langchain_core.messages import HumanMessage
113
+ question = "What is the capital of France and its population?"
114
+ graph = build_graph()
115
  messages = [HumanMessage(content=question)]
116
+ result = graph.invoke({"messages": messages})
117
+ for msg in result["messages"]:
118
+ print(msg.content)