manohargottam commited on
Commit
d14dccf
·
verified ·
1 Parent(s): 1551d5f

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +185 -145
agent.py CHANGED
@@ -1,180 +1,156 @@
1
- """LangGraph Agent"""
2
  import os
3
  import json
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
  from langgraph.prebuilt import ToolNode
8
- from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
11
  from langchain_community.tools.tavily_search import TavilySearchResults
12
  from langchain_community.document_loaders import WikipediaLoader
13
  from langchain_community.document_loaders import ArxivLoader
14
- from langchain_community.vectorstores import SupabaseVectorStore
15
- from langchain_core.messages import SystemMessage, HumanMessage
16
  from langchain_core.tools import tool
17
- from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
 
20
  load_dotenv()
21
 
22
- def safe_get_metadata(doc, key, default=""):
23
- """Safely extract metadata from document, handling string and dict formats"""
24
- try:
25
- if isinstance(doc.metadata, str):
26
- # Try to parse as JSON if it's a string
27
- metadata = json.loads(doc.metadata)
28
- elif isinstance(doc.metadata, dict):
29
- metadata = doc.metadata
30
- else:
31
- return default
32
- return metadata.get(key, default)
33
- except (json.JSONDecodeError, AttributeError):
34
- return default
35
-
36
  @tool
37
  def multiply(a: int, b: int) -> int:
38
- """Multiply two numbers.
39
- Args:
40
- a: first int
41
- b: second int
42
- """
43
  return a * b
44
 
45
  @tool
46
  def add(a: int, b: int) -> int:
47
- """Add two numbers.
48
-
49
- Args:
50
- a: first int
51
- b: second int
52
- """
53
  return a + b
54
 
55
  @tool
56
  def subtract(a: int, b: int) -> int:
57
- """Subtract two numbers.
58
-
59
- Args:
60
- a: first int
61
- b: second int
62
- """
63
  return a - b
64
 
65
  @tool
66
  def divide(a: int, b: int) -> int:
67
- """Divide two numbers.
68
-
69
- Args:
70
- a: first int
71
- b: second int
72
- """
73
  if b == 0:
74
  raise ValueError("Cannot divide by zero.")
75
  return a / b
76
 
77
  @tool
78
  def modulus(a: int, b: int) -> int:
79
- """Get the modulus of two numbers.
80
-
81
- Args:
82
- a: first int
83
- b: second int
84
- """
85
  return a % b
86
 
87
  @tool
88
  def wiki_search(query: str) -> str:
89
- """Search Wikipedia for a query and return maximum 2 results.
90
-
91
- Args:
92
- query: The search query."""
93
  try:
94
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
95
- formatted_search_docs = "\n\n---\n\n".join(
96
- [
97
- f'<Document source="{safe_get_metadata(doc, "source")}" page="{safe_get_metadata(doc, "page")}"/>\n{doc.page_content}\n</Document>'
98
- for doc in search_docs
99
- ])
100
- return {"wiki_results": formatted_search_docs}
 
 
101
  except Exception as e:
102
- return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}
103
 
104
  @tool
105
  def web_search(query: str) -> str:
106
- """Search Tavily for a query and return maximum 3 results.
107
-
108
- Args:
109
- query: The search query."""
110
  try:
111
  search_tool = TavilySearchResults(max_results=3)
112
- search_results = search_tool.invoke(query)
113
 
114
- # Handle the case where search_results might be a list of dicts or Document objects
115
- if isinstance(search_results, list):
116
- formatted_search_docs = "\n\n---\n\n".join(
117
- [
118
- f'<Document source="{result.get("url", "")}" />\n{result.get("content", "")}\n</Document>'
119
- if isinstance(result, dict) else
120
- f'<Document source="{safe_get_metadata(result, "source")}" page="{safe_get_metadata(result, "page")}"/>\n{result.page_content}\n</Document>'
121
- for result in search_results
122
- ])
123
- else:
124
- formatted_search_docs = str(search_results)
125
-
126
- return {"web_results": formatted_search_docs}
127
  except Exception as e:
128
- return {"web_results": f"Error searching web: {str(e)}"}
129
 
130
  @tool
131
- def arvix_search(query: str) -> str:
132
- """Search Arxiv for a query and return maximum 3 result.
133
-
134
- Args:
135
- query: The search query."""
136
  try:
137
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
138
- formatted_search_docs = "\n\n---\n\n".join(
139
- [
140
- f'<Document source="{safe_get_metadata(doc, "source")}" page="{safe_get_metadata(doc, "page")}"/>\n{doc.page_content[:1000]}\n</Document>'
141
- for doc in search_docs
142
- ])
143
- return {"arvix_results": formatted_search_docs}
 
 
144
  except Exception as e:
145
- return {"arvix_results": f"Error searching Arxiv: {str(e)}"}
146
 
147
- # load the system prompt from the file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  try:
149
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
150
  system_prompt = f.read()
151
  except FileNotFoundError:
152
  system_prompt = "You are a helpful AI assistant."
153
 
154
- # System message
155
  sys_msg = SystemMessage(content=system_prompt)
156
 
157
- # build a retriever
158
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
159
  supabase_url = "https://ajnakgegqblhwltzkzbz.supabase.co"
160
  supabase_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFqbmFrZ2VncWJsaHdsdHpremJ6Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkyMDgxODgsImV4cCI6MjA2NDc4NDE4OH0.b9RPF-5otedg4yiaQu_uhOgYpXVXd9D_0oR-9cluUjo"
161
 
162
  try:
163
- supabase: Client = create_client(supabase_url, supabase_key)
164
- vector_store = SupabaseVectorStore(
165
- client=supabase,
166
- embedding= embeddings,
167
- table_name="documents",
168
- query_name="match_documents_langchain",
169
- )
170
- create_retriever_tool = create_retriever_tool(
171
- retriever=vector_store.as_retriever(),
172
- name="Question Search",
173
- description="A tool to retrieve similar questions from a vector store.",
174
- )
175
  except Exception as e:
176
- print(f"Warning: Could not initialize vector store: {e}")
177
- vector_store = None
178
 
179
  tools = [
180
  multiply,
@@ -184,59 +160,123 @@ tools = [
184
  modulus,
185
  wiki_search,
186
  web_search,
187
- arvix_search,
188
  ]
189
 
190
- # Build graph function
191
  def build_graph(provider: str = "groq"):
192
- """Build the graph"""
193
- # Load environment variables from .env file
194
  if provider == "groq":
195
- # Groq https://console.groq.com/docs/models
196
- llm = ChatGroq(model="qwen-qwq-32b",api_key="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", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
197
-
 
 
198
  else:
199
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
200
- # Bind tools to LLM
201
- llm_with_tools = llm.bind_tools(tools)
202
-
203
- # Node
204
- def assistant(state: MessagesState):
205
- """Assistant node"""
206
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
207
 
208
- from langchain_core.messages import AIMessage
209
-
210
  def retriever(state: MessagesState):
211
- """Retriever node with error handling"""
212
  try:
213
- if vector_store is None:
214
- return {"messages": [AIMessage(content="Vector store not available.")]}
215
-
216
  query = state["messages"][-1].content
217
- similar_docs = vector_store.similarity_search(query, k=1)
218
 
219
- if not similar_docs:
220
- return {"messages": [AIMessage(content="No similar documents found.")]}
 
 
 
221
 
222
- similar_doc = similar_docs[0]
223
- content = similar_doc.page_content
 
 
 
 
 
224
 
225
- if "Final answer :" in content:
226
- answer = content.split("Final answer :")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  else:
228
- answer = content.strip()
229
-
230
- return {"messages": [AIMessage(content=answer)]}
231
  except Exception as e:
232
- return {"messages": [AIMessage(content=f"Error in retriever: {str(e)}")]}
233
 
 
234
  builder = StateGraph(MessagesState)
235
  builder.add_node("retriever", retriever)
236
-
237
- # Retriever ist Start und Endpunkt
238
  builder.set_entry_point("retriever")
239
  builder.set_finish_point("retriever")
 
 
240
 
241
- # Compile graph
242
- return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent - Complete bypass of problematic vector store"""
2
  import os
3
  import json
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
6
  from langgraph.prebuilt import tools_condition
7
  from langgraph.prebuilt import ToolNode
 
8
  from langchain_groq import ChatGroq
 
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
  from langchain_community.document_loaders import WikipediaLoader
11
  from langchain_community.document_loaders import ArxivLoader
12
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 
13
  from langchain_core.tools import tool
 
14
  from supabase.client import Client, create_client
15
 
16
  load_dotenv()
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @tool
19
  def multiply(a: int, b: int) -> int:
20
+ """Multiply two numbers."""
 
 
 
 
21
  return a * b
22
 
23
  @tool
24
  def add(a: int, b: int) -> int:
25
+ """Add two numbers."""
 
 
 
 
 
26
  return a + b
27
 
28
  @tool
29
  def subtract(a: int, b: int) -> int:
30
+ """Subtract two numbers."""
 
 
 
 
 
31
  return a - b
32
 
33
  @tool
34
  def divide(a: int, b: int) -> int:
35
+ """Divide two numbers."""
 
 
 
 
 
36
  if b == 0:
37
  raise ValueError("Cannot divide by zero.")
38
  return a / b
39
 
40
  @tool
41
  def modulus(a: int, b: int) -> int:
42
+ """Get the modulus of two numbers."""
 
 
 
 
 
43
  return a % b
44
 
45
  @tool
46
  def wiki_search(query: str) -> str:
47
+ """Search Wikipedia for a query and return maximum 2 results."""
 
 
 
48
  try:
49
  search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
50
+ formatted_docs = []
51
+ for doc in search_docs:
52
+ source = "Wikipedia"
53
+ if hasattr(doc, 'metadata') and isinstance(doc.metadata, dict):
54
+ source = doc.metadata.get('source', 'Wikipedia')
55
+ formatted_docs.append(f"Source: {source}\n{doc.page_content[:1000]}...")
56
+
57
+ return "\n\n---\n\n".join(formatted_docs)
58
  except Exception as e:
59
+ return f"Error searching Wikipedia: {str(e)}"
60
 
61
  @tool
62
  def web_search(query: str) -> str:
63
+ """Search the web using Tavily."""
 
 
 
64
  try:
65
  search_tool = TavilySearchResults(max_results=3)
66
+ results = search_tool.invoke(query)
67
 
68
+ if isinstance(results, list):
69
+ formatted_results = []
70
+ for result in results:
71
+ if isinstance(result, dict):
72
+ url = result.get('url', 'Unknown')
73
+ content = result.get('content', '')[:1000]
74
+ formatted_results.append(f"Source: {url}\n{content}...")
75
+ return "\n\n---\n\n".join(formatted_results)
76
+ return str(results)
 
 
 
 
77
  except Exception as e:
78
+ return f"Error searching web: {str(e)}"
79
 
80
  @tool
81
+ def arxiv_search(query: str) -> str:
82
+ """Search Arxiv for academic papers."""
 
 
 
83
  try:
84
  search_docs = ArxivLoader(query=query, load_max_docs=3).load()
85
+ formatted_docs = []
86
+ for doc in search_docs:
87
+ source = "ArXiv"
88
+ if hasattr(doc, 'metadata') and isinstance(doc.metadata, dict):
89
+ source = doc.metadata.get('source', 'ArXiv')
90
+ formatted_docs.append(f"Source: {source}\n{doc.page_content[:1000]}...")
91
+
92
+ return "\n\n---\n\n".join(formatted_docs)
93
  except Exception as e:
94
+ return f"Error searching ArXiv: {str(e)}"
95
 
96
+ # Raw Supabase search function that bypasses LangChain entirely
97
+ def raw_supabase_search(query: str, supabase_client):
98
+ """Direct Supabase search without any LangChain components"""
99
+ try:
100
+ # Simple text-based search using Supabase's built-in functions
101
+ # This assumes you have a simple text search function in your database
102
+ result = supabase_client.table('documents').select('content').text_search('content', query).limit(1).execute()
103
+
104
+ if result.data:
105
+ return result.data[0]['content']
106
+ else:
107
+ # Fallback: get any document (for testing)
108
+ result = supabase_client.table('documents').select('content').limit(1).execute()
109
+ if result.data:
110
+ return result.data[0]['content']
111
+ return "No documents found in database"
112
+
113
+ except Exception as e:
114
+ return f"Database search error: {str(e)}"
115
+
116
+ # Alternative: Use simple SQL query
117
+ def simple_sql_search(query: str, supabase_client):
118
+ """Simple SQL-based search"""
119
+ try:
120
+ # Use a simple SQL query to avoid metadata issues
121
+ sql_query = f"""
122
+ SELECT content
123
+ FROM documents
124
+ WHERE content ILIKE '%{query}%'
125
+ LIMIT 1
126
+ """
127
+ result = supabase_client.rpc('execute_sql', {'query': sql_query}).execute()
128
+
129
+ if result.data:
130
+ return result.data[0]['content']
131
+ return "No matching documents found"
132
+
133
+ except Exception as e:
134
+ return f"SQL search error: {str(e)}"
135
+
136
+ # Load system prompt
137
  try:
138
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
139
  system_prompt = f.read()
140
  except FileNotFoundError:
141
  system_prompt = "You are a helpful AI assistant."
142
 
 
143
  sys_msg = SystemMessage(content=system_prompt)
144
 
145
+ # Initialize Supabase without vector store
 
146
  supabase_url = "https://ajnakgegqblhwltzkzbz.supabase.co"
147
  supabase_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFqbmFrZ2VncWJsaHdsdHpremJ6Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkyMDgxODgsImV4cCI6MjA2NDc4NDE4OH0.b9RPF-5otedg4yiaQu_uhOgYpXVXd9D_0oR-9cluUjo"
148
 
149
  try:
150
+ supabase_client = create_client(supabase_url, supabase_key)
 
 
 
 
 
 
 
 
 
 
 
151
  except Exception as e:
152
+ print(f"Warning: Could not initialize Supabase client: {e}")
153
+ supabase_client = None
154
 
155
  tools = [
156
  multiply,
 
160
  modulus,
161
  wiki_search,
162
  web_search,
163
+ arxiv_search,
164
  ]
165
 
 
166
  def build_graph(provider: str = "groq"):
167
+ """Build the graph without problematic vector store operations"""
 
168
  if provider == "groq":
169
+ llm = ChatGroq(
170
+ model="qwen-qwq-32b",
171
+ api_key="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",
172
+ temperature=0
173
+ )
174
  else:
175
+ raise ValueError("Invalid provider. Choose 'groq'.")
 
 
 
 
 
 
 
176
 
 
 
177
  def retriever(state: MessagesState):
178
+ """Simple retriever that avoids all Document/metadata validation"""
179
  try:
 
 
 
180
  query = state["messages"][-1].content
 
181
 
182
+ if supabase_client is None:
183
+ return {"messages": [AIMessage(content="Database not available. Try using the web search tools instead.")]}
184
+
185
+ # Try different approaches in order of preference
186
+ content = None
187
 
188
+ # Approach 1: Simple table query
189
+ try:
190
+ result = supabase_client.table('documents').select('content').limit(1).execute()
191
+ if result.data and len(result.data) > 0:
192
+ content = result.data[0].get('content', '')
193
+ except Exception as e:
194
+ print(f"Table query failed: {e}")
195
 
196
+ # Approach 2: Raw supabase search
197
+ if not content:
198
+ content = raw_supabase_search(query, supabase_client)
199
+
200
+ # Process the content
201
+ if content and content.strip():
202
+ # Look for final answer pattern
203
+ if "Final answer :" in content:
204
+ answer = content.split("Final answer :")[-1].strip()
205
+ else:
206
+ # Take first 500 characters as answer
207
+ answer = content.strip()[:500]
208
+ if len(content) > 500:
209
+ answer += "..."
210
+
211
+ return {"messages": [AIMessage(content=answer)]}
212
  else:
213
+ return {"messages": [AIMessage(content="No relevant information found. Please try using the search tools.")]}
214
+
 
215
  except Exception as e:
216
+ return {"messages": [AIMessage(content=f"Search unavailable: {str(e)}. Please try using the web search tools.")]}
217
 
218
+ # Build simple graph
219
  builder = StateGraph(MessagesState)
220
  builder.add_node("retriever", retriever)
 
 
221
  builder.set_entry_point("retriever")
222
  builder.set_finish_point("retriever")
223
+
224
+ return builder.compile()
225
 
226
+ # Alternative: Build graph without retriever at all
227
+ def build_assistant_graph(provider: str = "groq"):
228
+ """Build a graph with just assistant and tools (no problematic retriever)"""
229
+ if provider == "groq":
230
+ llm = ChatGroq(
231
+ model="qwen-qwq-32b",
232
+ api_key="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",
233
+ temperature=0
234
+ )
235
+ else:
236
+ raise ValueError("Invalid provider.")
237
+
238
+ llm_with_tools = llm.bind_tools(tools)
239
+
240
+ def assistant(state: MessagesState):
241
+ """Assistant node that can use tools"""
242
+ messages = [sys_msg] + state["messages"]
243
+ return {"messages": [llm_with_tools.invoke(messages)]}
244
+
245
+ builder = StateGraph(MessagesState)
246
+ builder.add_node("assistant", assistant)
247
+ builder.add_node("tools", ToolNode(tools))
248
+
249
+ builder.set_entry_point("assistant")
250
+ builder.add_conditional_edges("assistant", tools_condition)
251
+ builder.add_edge("tools", "assistant")
252
+
253
+ return builder.compile()
254
+
255
+ # Test function
256
+ def test_graph():
257
+ """Test the graph builds successfully"""
258
+ try:
259
+ print("Testing retriever-based graph...")
260
+ graph1 = build_graph()
261
+ print("✓ Retriever graph built successfully!")
262
+ return graph1
263
+ except Exception as e:
264
+ print(f"✗ Retriever graph failed: {e}")
265
+ print("Testing assistant-only graph...")
266
+ try:
267
+ graph2 = build_assistant_graph()
268
+ print("✓ Assistant graph built successfully!")
269
+ return graph2
270
+ except Exception as e2:
271
+ print(f"✗ Assistant graph also failed: {e2}")
272
+ return None
273
+
274
+ if __name__ == "__main__":
275
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
276
+
277
+ graph = build_graph(provider="groq")
278
+
279
+ messages = [HumanMessage(content=question)]
280
+ messages = graph.invoke({"messages": messages})
281
+ for m in messages["messages"]:
282
+ m.pretty_print()