rohittayde committed on
Commit
ade0954
·
verified ·
1 Parent(s): ea76d69

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +136 -81
agent.py CHANGED
@@ -1,5 +1,6 @@
1
- """LangGraph Agent"""
2
  import os
 
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
@@ -13,14 +14,11 @@ from langchain_community.document_loaders import ArxivLoader
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
-
17
  from supabase.client import Client, create_client
18
- # --- langchain create_retriever_tool fallback (paste near other imports) ---
19
- # NOTE: removed the unconditional import that caused ModuleNotFoundError.
20
- import traceback
21
 
 
22
  try:
23
- # Prefer the real helper if available
24
  from langchain.tools.retriever import create_retriever_tool # type: ignore
25
  HAS_CREATE_RETRIEVER_TOOL = True
26
  except Exception:
@@ -29,6 +27,10 @@ except Exception:
29
  print(traceback.format_exc())
30
 
31
  class _SimpleRetrieverTool:
 
 
 
 
32
  def __init__(self, retriever, name="retriever", description=""):
33
  self.name = name
34
  self.description = description
@@ -69,6 +71,7 @@ except Exception:
69
  """
70
  return _SimpleRetrieverTool(retriever, name=name, description=description)
71
 
 
72
  load_dotenv()
73
 
74
  @tool
@@ -128,13 +131,16 @@ def wiki_search(query: str) -> str:
128
 
129
  Args:
130
  query: The search query."""
131
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
132
- formatted_search_docs = "\n\n---\n\n".join(
133
- [
134
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
135
- for doc in search_docs
136
- ])
137
- return {"wiki_results": formatted_search_docs}
 
 
 
138
 
139
  @tool
140
  def web_search(query: str) -> str:
@@ -142,13 +148,16 @@ def web_search(query: str) -> str:
142
 
143
  Args:
144
  query: The search query."""
145
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
146
- formatted_search_docs = "\n\n---\n\n".join(
147
- [
148
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
149
- for doc in search_docs
150
- ])
151
- return {"web_results": formatted_search_docs}
 
 
 
152
 
153
  @tool
154
  def arvix_search(query: str) -> str:
@@ -156,14 +165,16 @@ def arvix_search(query: str) -> str:
156
 
157
  Args:
158
  query: The search query."""
159
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
160
- formatted_search_docs = "\n\n---\n\n".join(
161
- [
162
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
163
- for doc in search_docs
164
- ])
165
- return {"arvix_results": formatted_search_docs}
166
-
 
 
167
 
168
 
169
  # load the system prompt from the file
@@ -173,24 +184,53 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
173
  # System message
174
  sys_msg = SystemMessage(content=system_prompt)
175
 
176
- # build a retriever
177
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
178
- supabase: Client = create_client(
179
- os.environ.get("SUPABASE_URL"),
180
- os.environ.get("SUPABASE_SERVICE_KEY"))
181
- vector_store = SupabaseVectorStore(
182
- client=supabase,
183
- embedding= embeddings,
184
- table_name="documents",
185
- query_name="match_documents_langchain",
186
- )
187
- retriever_tool = create_retriever_tool(
188
- retriever=vector_store.as_retriever(),
189
- name="Question Search",
190
- description="A tool to retrieve similar questions from a vector store.",
191
- )
192
-
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
 
196
  tools = [
@@ -204,6 +244,20 @@ tools = [
204
  arvix_search,
205
  ]
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  # Build graph function
208
  def build_graph(provider: str = "google"):
209
  """Build the graph"""
@@ -213,7 +267,7 @@ def build_graph(provider: str = "google"):
213
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
214
  elif provider == "groq":
215
  # Groq https://console.groq.com/docs/models
216
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
217
  elif provider == "huggingface":
218
  # TODO: Add huggingface endpoint
219
  llm = ChatHuggingFace(
@@ -224,52 +278,53 @@ def build_graph(provider: str = "google"):
224
  )
225
  else:
226
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
227
- # Bind tools to LLM
228
- llm_with_tools = llm.bind_tools(tools)
229
 
230
- # Node
 
 
 
 
 
 
 
 
231
  def assistant(state: MessagesState):
232
  """Assistant node"""
233
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
234
-
235
- # def retriever(state: MessagesState):
236
- # """Retriever node"""
237
- # similar_question = vector_store.similarity_search(state["messages"][0].content)
238
- #example_msg = HumanMessage(
239
- # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
240
- # )
241
- # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
242
 
243
  from langchain_core.messages import AIMessage
244
 
245
  def retriever(state: MessagesState):
246
  query = state["messages"][-1].content
247
- similar_doc = vector_store.similarity_search(query, k=1)[0]
248
-
249
- content = similar_doc.page_content
250
- if "Final answer :" in content:
251
- answer = content.split("Final answer :")[-1].strip()
252
- else:
253
- answer = content.strip()
254
-
255
- return {"messages": [AIMessage(content=answer)]}
256
-
257
- # builder = StateGraph(MessagesState)
258
- #builder.add_node("retriever", retriever)
259
- #builder.add_node("assistant", assistant)
260
- #builder.add_node("tools", ToolNode(tools))
261
- #builder.add_edge(START, "retriever")
262
- #builder.add_edge("retriever", "assistant")
263
- #builder.add_conditional_edges(
264
- # "assistant",
265
- # tools_condition,
266
- #)
267
- #builder.add_edge("tools", "assistant")
268
-
269
  builder = StateGraph(MessagesState)
270
  builder.add_node("retriever", retriever)
271
 
272
- # Retriever ist Start und Endpunkt
273
  builder.set_entry_point("retriever")
274
  builder.set_finish_point("retriever")
275
 
 
1
+ """LangGraph Agent (patched for robustness)"""
2
  import os
3
+ import traceback
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
 
14
  from langchain_community.vectorstores import SupabaseVectorStore
15
  from langchain_core.messages import SystemMessage, HumanMessage
16
  from langchain_core.tools import tool
 
17
  from supabase.client import Client, create_client
 
 
 
18
 
19
+ # --- Safe import + fallback for langchain.tools.retriever.create_retriever_tool ---
20
  try:
21
+ # Try to import the real helper (if the installed langchain provides it)
22
  from langchain.tools.retriever import create_retriever_tool # type: ignore
23
  HAS_CREATE_RETRIEVER_TOOL = True
24
  except Exception:
 
27
  print(traceback.format_exc())
28
 
29
  class _SimpleRetrieverTool:
30
+ """
31
+ Minimal tool-like wrapper providing a `.run(query)` method.
32
+ Most templates call tool.run(query) — adapt if your code uses a different interface.
33
+ """
34
  def __init__(self, retriever, name="retriever", description=""):
35
  self.name = name
36
  self.description = description
 
71
  """
72
  return _SimpleRetrieverTool(retriever, name=name, description=description)
73
 
74
+
75
  load_dotenv()
76
 
77
  @tool
 
131
 
132
  Args:
133
  query: The search query."""
134
+ try:
135
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
136
+ formatted_search_docs = "\n\n---\n\n".join(
137
+ [
138
+ f'<Document source="{doc.metadata.get("source","")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
139
+ for doc in search_docs
140
+ ])
141
+ return {"wiki_results": formatted_search_docs}
142
+ except Exception as e:
143
+ return {"wiki_results_error": str(e)}
144
 
145
  @tool
146
  def web_search(query: str) -> str:
 
148
 
149
  Args:
150
  query: The search query."""
151
+ try:
152
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
153
+ formatted_search_docs = "\n\n---\n\n".join(
154
+ [
155
+ f'<Document source="{doc.metadata.get("source","")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
156
+ for doc in search_docs
157
+ ])
158
+ return {"web_results": formatted_search_docs}
159
+ except Exception as e:
160
+ return {"web_results_error": str(e)}
161
 
162
  @tool
163
  def arvix_search(query: str) -> str:
 
165
 
166
  Args:
167
  query: The search query."""
168
+ try:
169
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
170
+ formatted_search_docs = "\n\n---\n\n".join(
171
+ [
172
+ f'<Document source="{doc.metadata.get("source","")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
173
+ for doc in search_docs
174
+ ])
175
+ return {"arvix_results": formatted_search_docs}
176
+ except Exception as e:
177
+ return {"arvix_results_error": str(e)}
178
 
179
 
180
  # load the system prompt from the file
 
184
  # System message
185
  sys_msg = SystemMessage(content=system_prompt)
186
 
187
+ # --- Build a retriever (defensive: don't crash if heavy deps or credentials missing) ---
188
+ retriever_tool = None
189
+ vector_store = None
190
+ embeddings = None
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
+ # Try to create HuggingFaceEmbeddings and SupabaseVectorStore if dependencies and env are present.
193
+ try:
194
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
195
+ except Exception as e:
196
+ print(f"⚠️ Could not initialize HuggingFaceEmbeddings: {e}")
197
+ embeddings = None
198
+
199
+ SUPABASE_URL = os.environ.get("SUPABASE_URL")
200
+ SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_KEY")
201
+
202
+ if SUPABASE_URL and SUPABASE_SERVICE_KEY and embeddings is not None:
203
+ try:
204
+ supabase: Client = create_client(SUPABASE_URL, SUPABASE_SERVICE_KEY)
205
+ vector_store = SupabaseVectorStore(
206
+ client=supabase,
207
+ embedding=embeddings,
208
+ table_name="documents",
209
+ query_name="match_documents_langchain",
210
+ )
211
+ except Exception as e:
212
+ print(f"⚠️ Could not initialize SupabaseVectorStore: {e}")
213
+ vector_store = None
214
+ else:
215
+ if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
216
+ print("⚠️ SUPABASE_URL or SUPABASE_SERVICE_KEY not set — skipping vector store initialization.")
217
+ elif embeddings is None:
218
+ print("⚠️ Embeddings not available — skipping vector store initialization.")
219
+ vector_store = None
220
+
221
+ # Create a retriever tool only if vector_store exists
222
+ if vector_store is not None:
223
+ try:
224
+ retriever_tool = create_retriever_tool(
225
+ retriever=vector_store.as_retriever(),
226
+ name="Question Search",
227
+ description="A tool to retrieve similar questions from a vector store.",
228
+ )
229
+ except Exception as e:
230
+ print(f"⚠️ Failed to create retriever tool from vector store: {e}")
231
+ retriever_tool = None
232
+ else:
233
+ retriever_tool = None
234
 
235
 
236
  tools = [
 
244
  arvix_search,
245
  ]
246
 
247
+ # Add retriever_tool to tools if available and matches the callable interface
248
+ if retriever_tool is not None:
249
+ try:
250
+ if hasattr(retriever_tool, "run"):
251
+ @tool
252
+ def retriever_wrapper(query: str) -> str:
253
+ return retriever_tool.run(query)
254
+ tools.append(retriever_wrapper)
255
+ else:
256
+ tools.append(retriever_tool)
257
+ except Exception as e:
258
+ print(f"⚠️ Could not append retriever tool to tools list: {e}")
259
+
260
+
261
  # Build graph function
262
  def build_graph(provider: str = "google"):
263
  """Build the graph"""
 
267
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
268
  elif provider == "groq":
269
  # Groq https://console.groq.com/docs/models
270
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
271
  elif provider == "huggingface":
272
  # TODO: Add huggingface endpoint
273
  llm = ChatHuggingFace(
 
278
  )
279
  else:
280
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
 
281
 
282
+ # Bind tools to LLM
283
+ try:
284
+ llm_with_tools = llm.bind_tools(tools)
285
+ except Exception as e:
286
+ print(f"⚠️ Could not bind tools to LLM: {e}")
287
+ # fallback: keep LLM without tools
288
+ llm_with_tools = llm
289
+
290
+ # Node: assistant
291
  def assistant(state: MessagesState):
292
  """Assistant node"""
293
+ try:
294
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
295
+ except Exception as e:
296
+ print(f"⚠️ assistant node failed: {e}")
297
+ # return empty message so graph can continue
298
+ return {"messages": [HumanMessage(content="")]}
 
 
 
299
 
300
  from langchain_core.messages import AIMessage
301
 
302
  def retriever(state: MessagesState):
303
  query = state["messages"][-1].content
304
+ # If vector_store not available, return empty message so assistant proceeds normally
305
+ if vector_store is None:
306
+ return {"messages": [AIMessage(content="")]}
307
+
308
+ try:
309
+ similar_docs = vector_store.similarity_search(query, k=1)
310
+ if not similar_docs:
311
+ return {"messages": [AIMessage(content="")]}
312
+ similar_doc = similar_docs[0]
313
+ content = similar_doc.page_content
314
+ if "Final answer :" in content:
315
+ answer = content.split("Final answer :")[-1].strip()
316
+ else:
317
+ answer = content.strip()
318
+ return {"messages": [AIMessage(content=answer)]}
319
+ except Exception as e:
320
+ print(f"⚠️ retriever node failed: {e}")
321
+ return {"messages": [AIMessage(content="")]}
322
+
323
+ # Build the state graph: a simple retriever-only entry point (defensive)
 
 
324
  builder = StateGraph(MessagesState)
325
  builder.add_node("retriever", retriever)
326
 
327
+ # Retriever is both the entry and finish point in this design
328
  builder.set_entry_point("retriever")
329
  builder.set_finish_point("retriever")
330