manohargottam committed on
Commit
1551d5f
·
verified ·
1 Parent(s): 1b2b2d6

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +98 -70
agent.py CHANGED
@@ -1,5 +1,6 @@
1
  """LangGraph Agent"""
2
  import os
 
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
@@ -18,6 +19,20 @@ from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
@@ -75,13 +90,16 @@ def wiki_search(query: str) -> str:
75
 
76
  Args:
77
  query: The search query."""
78
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
- formatted_search_docs = "\n\n---\n\n".join(
80
- [
81
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
- for doc in search_docs
83
- ])
84
- return {"wiki_results": formatted_search_docs}
 
 
 
85
 
86
  @tool
87
  def web_search(query: str) -> str:
@@ -89,13 +107,25 @@ def web_search(query: str) -> str:
89
 
90
  Args:
91
  query: The search query."""
92
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
- formatted_search_docs = "\n\n---\n\n".join(
94
- [
95
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
- for doc in search_docs
97
- ])
98
- return {"web_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  @tool
101
  def arvix_search(query: str) -> str:
@@ -103,19 +133,23 @@ def arvix_search(query: str) -> str:
103
 
104
  Args:
105
  query: The search query."""
106
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
- formatted_search_docs = "\n\n---\n\n".join(
108
- [
109
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
- for doc in search_docs
111
- ])
112
- return {"arvix_results": formatted_search_docs}
113
-
114
-
 
115
 
116
  # load the system prompt from the file
117
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
- system_prompt = f.read()
 
 
 
119
 
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
@@ -125,20 +159,22 @@ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-b
125
  supabase_url = "https://ajnakgegqblhwltzkzbz.supabase.co"
126
  supabase_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFqbmFrZ2VncWJsaHdsdHpremJ6Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkyMDgxODgsImV4cCI6MjA2NDc4NDE4OH0.b9RPF-5otedg4yiaQu_uhOgYpXVXd9D_0oR-9cluUjo"
127
 
128
- supabase: Client = create_client(supabase_url, supabase_key)
129
- vector_store = SupabaseVectorStore(
130
- client=supabase,
131
- embedding= embeddings,
132
- table_name="documents",
133
- query_name="match_documents_langchain",
134
- )
135
- create_retriever_tool = create_retriever_tool(
136
- retriever=vector_store.as_retriever(),
137
- name="Question Search",
138
- description="A tool to retrieve similar questions from a vector store.",
139
- )
140
-
141
-
 
 
142
 
143
  tools = [
144
  multiply,
@@ -169,39 +205,31 @@ def build_graph(provider: str = "groq"):
169
  """Assistant node"""
170
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
171
 
172
- # def retriever(state: MessagesState):
173
- # """Retriever node"""
174
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
175
- #example_msg = HumanMessage(
176
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
177
- # )
178
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
179
-
180
  from langchain_core.messages import AIMessage
181
 
182
  def retriever(state: MessagesState):
183
- query = state["messages"][-1].content
184
- similar_doc = vector_store.similarity_search(query, k=1)[0]
185
-
186
- content = similar_doc.page_content
187
- if "Final answer :" in content:
188
- answer = content.split("Final answer :")[-1].strip()
189
- else:
190
- answer = content.strip()
191
-
192
- return {"messages": [AIMessage(content=answer)]}
193
-
194
- # builder = StateGraph(MessagesState)
195
- #builder.add_node("retriever", retriever)
196
- #builder.add_node("assistant", assistant)
197
- #builder.add_node("tools", ToolNode(tools))
198
- #builder.add_edge(START, "retriever")
199
- #builder.add_edge("retriever", "assistant")
200
- #builder.add_conditional_edges(
201
- # "assistant",
202
- # tools_condition,
203
- #)
204
- #builder.add_edge("tools", "assistant")
205
 
206
  builder = StateGraph(MessagesState)
207
  builder.add_node("retriever", retriever)
@@ -211,4 +239,4 @@ def build_graph(provider: str = "groq"):
211
  builder.set_finish_point("retriever")
212
 
213
  # Compile graph
214
- return builder.compile()
 
1
  """LangGraph Agent"""
2
  import os
3
+ import json
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
 
19
 
20
  load_dotenv()
21
 
22
def safe_get_metadata(doc, key, default=""):
    """Safely read *key* from a document's metadata.

    Accepts metadata stored either as a dict or as a JSON-encoded string;
    any other shape, a missing ``metadata`` attribute, or a JSON parse
    failure yields *default*.
    """
    try:
        raw = doc.metadata
    except AttributeError:
        return default
    if isinstance(raw, dict):
        return raw.get(key, default)
    if isinstance(raw, str):
        # Metadata sometimes arrives serialized; decode before lookup.
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return default
        # A decoded non-dict (e.g. a JSON list) has no .get — treat as absent.
        return parsed.get(key, default) if isinstance(parsed, dict) else default
    return default
36
  @tool
37
  def multiply(a: int, b: int) -> int:
38
  """Multiply two numbers.
 
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to two formatted documents.

    Args:
        query: The search query.
    """
    try:
        docs = WikipediaLoader(query=query, load_max_docs=2).load()
        rendered = []
        for doc in docs:
            src = safe_get_metadata(doc, "source")
            page = safe_get_metadata(doc, "page")
            rendered.append(
                f'<Document source="{src}" page="{page}"/>\n{doc.page_content}\n</Document>'
            )
        return {"wiki_results": "\n\n---\n\n".join(rendered)}
    except Exception as e:
        # Report failures to the calling agent instead of raising.
        return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}
 
104
  @tool
105
  def web_search(query: str) -> str:
 
@tool
def web_search(query: str) -> str:
    """Search the web via Tavily and return up to three formatted results.

    Args:
        query: The search query.
    """
    try:
        results = TavilySearchResults(max_results=3).invoke(query)
        if not isinstance(results, list):
            # Unexpected payload shape — return it verbatim as text.
            return {"web_results": str(results)}

        def _render(item):
            # Tavily may yield plain dicts or Document-like objects.
            if isinstance(item, dict):
                return f'<Document source="{item.get("url", "")}" />\n{item.get("content", "")}\n</Document>'
            return f'<Document source="{safe_get_metadata(item, "source")}" page="{safe_get_metadata(item, "page")}"/>\n{item.page_content}\n</Document>'

        return {"web_results": "\n\n---\n\n".join(_render(r) for r in results)}
    except Exception as e:
        # Report failures to the calling agent instead of raising.
        return {"web_results": f"Error searching web: {str(e)}"}
129
 
130
  @tool
131
  def arvix_search(query: str) -> str:
 
@tool
def arvix_search(query: str) -> str:
    """Search arXiv and return up to three results (first 1000 chars each).

    Args:
        query: The search query.
    """
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        parts = [
            f'<Document source="{safe_get_metadata(doc, "source")}" page="{safe_get_metadata(doc, "page")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in docs
        ]
        return {"arvix_results": "\n\n---\n\n".join(parts)}
    except Exception as e:
        # Report failures to the calling agent instead of raising.
        return {"arvix_results": f"Error searching Arxiv: {str(e)}"}
 
# Load the system prompt from the file; fall back to a generic prompt when it
# is missing OR unreadable (FileNotFoundError alone would still crash module
# import on e.g. a permission error — OSError covers both).
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
except OSError:
    system_prompt = "You are a helpful AI assistant."
 
154
  # System message
155
  sys_msg = SystemMessage(content=system_prompt)
 
159
# Supabase connection settings.
# NOTE(security): these credentials were previously hard-coded in the repo.
# Prefer the environment (.env is loaded above via load_dotenv()); the old
# literals remain only as a backward-compatible fallback and should be
# rotated/removed — a key committed to history must be considered leaked.
supabase_url = os.environ.get(
    "SUPABASE_URL",
    "https://ajnakgegqblhwltzkzbz.supabase.co",
)
supabase_key = os.environ.get(
    "SUPABASE_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFqbmFrZ2VncWJsaHdsdHpremJ6Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkyMDgxODgsImV4cCI6MjA2NDc4NDE4OH0.b9RPF-5otedg4yiaQu_uhOgYpXVXd9D_0oR-9cluUjo",
)

try:
    supabase: Client = create_client(supabase_url, supabase_key)
    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    # NOTE(review): this rebinding shadows the imported create_retriever_tool
    # factory, so the factory cannot be called again later in this module.
    # Kept as-is because other code may reference this module-level name.
    create_retriever_tool = create_retriever_tool(
        retriever=vector_store.as_retriever(),
        name="Question Search",
        description="A tool to retrieve similar questions from a vector store.",
    )
except Exception as e:
    # Degrade gracefully: retriever() checks `vector_store is None`.
    print(f"Warning: Could not initialize vector store: {e}")
    vector_store = None
178
 
179
  tools = [
180
  multiply,
 
205
  """Assistant node"""
206
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
207
 
 
 
 
 
 
 
 
 
208
  from langchain_core.messages import AIMessage
209
 
210
def retriever(state: MessagesState):
    """Retriever node: answer from the closest stored Q/A document.

    Looks up the document most similar to the latest message and returns
    the text after its "Final answer :" marker (or the whole content when
    the marker is absent) as an AIMessage. Any failure — including an
    unavailable vector store — is reported as a message rather than raised.
    """
    try:
        if vector_store is None:
            return {"messages": [AIMessage(content="Vector store not available.")]}

        latest_query = state["messages"][-1].content
        matches = vector_store.similarity_search(latest_query, k=1)
        if not matches:
            return {"messages": [AIMessage(content="No similar documents found.")]}

        text = matches[0].page_content
        marker = "Final answer :"
        if marker in text:
            answer = text.split(marker)[-1].strip()
        else:
            answer = text.strip()
        return {"messages": [AIMessage(content=answer)]}
    except Exception as e:
        return {"messages": [AIMessage(content=f"Error in retriever: {str(e)}")]}
 
234
  builder = StateGraph(MessagesState)
235
  builder.add_node("retriever", retriever)
 
239
  builder.set_finish_point("retriever")
240
 
241
  # Compile graph
242
+ return builder.compile()