PradeepBodhi commited on
Commit
d8aae99
·
verified ·
1 Parent(s): 9eeedaf

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +32 -83
agent.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
@@ -5,24 +6,18 @@ from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
- from langchain_community.chat_models import ChatHuggingFace
9
- from langchain_community.llms import HuggingFaceEndpoint
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
- from langchain_core.documents import Document
 
13
  from langchain_core.messages import SystemMessage, HumanMessage
14
  from langchain_core.tools import tool
15
  from langchain.tools.retriever import create_retriever_tool
16
- from langchain_community.embeddings import HuggingFaceEmbeddings
17
- from langchain_community.vectorstores import SupabaseVectorStore
18
- from langchain.tools.retriever import create_retriever_tool
19
  from supabase.client import Client, create_client
20
- # from langchain_community.vectorstores import FAISS
21
 
22
  load_dotenv()
23
 
24
- # -------------------- Tools --------------------
25
-
26
  @tool
27
  def multiply(a: int, b: int) -> int:
28
  """Multiply two numbers.
@@ -118,26 +113,19 @@ def arvix_search(query: str) -> str:
118
  return {"arvix_results": formatted_search_docs}
119
 
120
 
121
- # -------------------- System Prompt --------------------
122
 
 
123
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
124
  system_prompt = f.read()
125
 
 
126
  sys_msg = SystemMessage(content=system_prompt)
127
 
128
- # -------------------- FAISS Retriever Setup --------------------
129
-
130
- # Use FAISS with HuggingFace embeddings
131
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
132
-
133
- supabase_url = os.environ.get("SUPABASE_URL")
134
- supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
135
-
136
- if not supabase_url or not supabase_key:
137
- raise ValueError("Supabase URL or Service Key not found in environment variables.")
138
-
139
- supabase: Client = create_client(supabase_url, supabase_key)
140
-
141
  vector_store = SupabaseVectorStore(
142
  client=supabase,
143
  embedding= embeddings,
@@ -150,24 +138,7 @@ create_retriever_tool = create_retriever_tool(
150
  description="A tool to retrieve similar questions from a vector store.",
151
  )
152
 
153
- # # FAISS must be initialized with data; here we use placeholder/example docs for illustration
154
- # # Replace with real documents if available
155
- # documents = [
156
- # Document(page_content="What is LangChain?", metadata={"source": "faq"}),
157
- # Document(page_content="How to use vector stores in LangChain?", metadata={"source": "guide"}),
158
- # ]
159
 
160
- # vector_store = FAISS.from_documents(documents, embeddings)
161
-
162
- # # Optional: save/load index to persist
163
- # # vector_store.save_local("faiss_index")
164
- # # vector_store = FAISS.load_local("faiss_index", embeddings)
165
-
166
- # retriever_tool = create_retriever_tool(
167
- # retriever=vector_store.as_retriever(),
168
- # name="Question Search",
169
- # description="A tool to retrieve similar questions from FAISS vector store.",
170
- # )
171
 
172
  tools = [
173
  multiply,
@@ -180,8 +151,6 @@ tools = [
180
  arvix_search,
181
  ]
182
 
183
- # -------------------- Graph --------------------
184
-
185
  # Build graph function
186
  def build_graph(provider: str = "groq"):
187
  """Build the graph"""
@@ -190,7 +159,6 @@ def build_graph(provider: str = "groq"):
190
  # Google Gemini
191
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
192
  elif provider == "groq":
193
- os.environ.get("GROQ_API_KEY")
194
  # Groq https://console.groq.com/docs/models
195
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
196
  elif provider == "huggingface":
@@ -207,37 +175,17 @@ def build_graph(provider: str = "groq"):
207
  llm_with_tools = llm.bind_tools(tools)
208
 
209
  # Node
210
- def retriever(state: MessagesState):
211
- try:
212
- if not state["messages"] or not hasattr(state["messages"][0], "content"):
213
- return {"messages": [sys_msg]}
214
-
215
- query = state["messages"][0].content
216
- print(f"Retriever query: {query}")
217
-
218
- similar_question = vector_store.similarity_search(query)
219
- print(f"Found {len(similar_question)} similar questions")
220
-
221
- if similar_question:
222
- example_msg = HumanMessage(
223
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
224
- )
225
- else:
226
- example_msg = HumanMessage(content="No similar questions found in the vector store.")
227
-
228
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
229
- except Exception as e:
230
- print(f"Retriever error: {e}")
231
- return {"messages": [sys_msg, HumanMessage(content=f"Retriever error: {e}")]}
232
-
233
  def assistant(state: MessagesState):
234
- try:
235
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
236
- except Exception as e:
237
- print(f"Assistant error: {e}")
238
- return {"messages": [sys_msg, HumanMessage(content=f"Assistant error: {e}")]}
239
-
240
-
 
 
 
241
 
242
  builder = StateGraph(MessagesState)
243
  builder.add_node("retriever", retriever)
@@ -254,12 +202,13 @@ def build_graph(provider: str = "groq"):
254
  # Compile graph
255
  return builder.compile()
256
 
257
- # -------------------- Optional Test --------------------
258
-
259
- # if __name__ == "__main__":
260
- # question = "What is LangChain used for?"
261
- # graph = build_graph(provider="groq")
262
- # messages = [HumanMessage(content=question)]
263
- # messages = graph.invoke({"messages": messages})
264
- # for m in messages["messages"]:
265
- # m.pretty_print()
 
 
1
+ """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
 
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
 
 
 
17
  from supabase.client import Client, create_client
 
18
 
19
  load_dotenv()
20
 
 
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
 
113
  return {"arvix_results": formatted_search_docs}
114
 
115
 
 
116
 
117
+ # load the system prompt from the file
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
  system_prompt = f.read()
120
 
121
+ # System message
122
  sys_msg = SystemMessage(content=system_prompt)
123
 
124
+ # build a retriever
125
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
+ supabase: Client = create_client(
127
+ os.environ.get("SUPABASE_URL"),
128
+ os.environ.get("SUPABASE_SERVICE_KEY"))
 
 
 
 
 
 
 
 
129
  vector_store = SupabaseVectorStore(
130
  client=supabase,
131
  embedding= embeddings,
 
138
  description="A tool to retrieve similar questions from a vector store.",
139
  )
140
 
 
 
 
 
 
 
141
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  tools = [
144
  multiply,
 
151
  arvix_search,
152
  ]
153
 
 
 
154
  # Build graph function
155
  def build_graph(provider: str = "groq"):
156
  """Build the graph"""
 
159
  # Google Gemini
160
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
161
  elif provider == "groq":
 
162
  # Groq https://console.groq.com/docs/models
163
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
164
  elif provider == "huggingface":
 
175
  llm_with_tools = llm.bind_tools(tools)
176
 
177
  # Node
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def assistant(state: MessagesState):
179
+ """Assistant node"""
180
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
181
+
182
+ def retriever(state: MessagesState):
183
+ """Retriever node"""
184
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
185
+ example_msg = HumanMessage(
186
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
187
+ )
188
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
189
 
190
  builder = StateGraph(MessagesState)
191
  builder.add_node("retriever", retriever)
 
202
  # Compile graph
203
  return builder.compile()
204
 
205
+ # test
206
+ if __name__ == "__main__":
207
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
208
+ # Build the graph
209
+ graph = build_graph(provider="groq")
210
+ # Run the graph
211
+ messages = [HumanMessage(content=question)]
212
+ messages = graph.invoke({"messages": messages})
213
+ for m in messages["messages"]:
214
+ m.pretty_print()