ktluege committed · Commit 6440130 · verified · 1 Parent(s): 426a424

Update agent.py

Files changed (1)
agent.py +85 -33
agent.py CHANGED
@@ -1,71 +1,123 @@
  # agent.py
  import os
  from dotenv import load_dotenv
- from langgraph.graph import StateGraph, MessagesState
  from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_core.tools import tool

- # Load environment variables (API keys, etc.)
  load_dotenv()

- # === TOOLS ===

  @tool
  def add(a: int, b: int) -> int:
      """Add two numbers."""
      return a + b

  @tool
  def wiki_search(query: str) -> dict:
      """Search Wikipedia and return up to 2 results."""
      search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted = "\n\n---\n\n".join(
-         [f"{doc.metadata}\n{doc.page_content[:500]}" for doc in search_docs]
-     )
-     return {"wiki_results": formatted}

  @tool
  def web_search(query: str) -> dict:
      """Search Tavily for a query and return max 3 results."""
      search_docs = TavilySearchResults(max_results=3).invoke(query)
-     formatted = "\n\n---\n\n".join(
-         [f"{doc.metadata}\n{doc.page_content[:500]}" for doc in search_docs]
-     )
-     return {"web_results": formatted}

- tools = [add, wiki_search, web_search]

- # === SYSTEM PROMPT ===
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()
  sys_msg = SystemMessage(content=system_prompt)

- # === GRAPH BUILD ===
- def build_graph():
-     # For the demo, use the OpenAI API if a key is provided; else fall back to Gemini
-     from langchain_openai import ChatOpenAI
-     openai_api_key = os.environ.get("OPENAI_API_KEY")
-     if openai_api_key:
-         llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, openai_api_key=openai_api_key)
-     else:
-         # Gemini fallback (or any other provider, e.g. Groq)
-         from langchain_google_genai import ChatGoogleGenerativeAI
-         llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)

-     # Bind tools to the LLM
      llm_with_tools = llm.bind_tools(tools)

-     # Simple assistant node
      def assistant(state: MessagesState):
-         # Always include the system prompt
          messages = [sys_msg] + state["messages"]
          return {"messages": [llm_with_tools.invoke(messages)]}

      builder = StateGraph(MessagesState)
      builder.add_node("assistant", assistant)
-     builder.set_entry_point("assistant")
      builder.set_finish_point("assistant")
      return builder.compile()
 
  # agent.py
  import os
  from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition, ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_huggingface import HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import create_client

  load_dotenv()

+ # ========== TOOLS ==========
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers."""
+     return a * b
+
  @tool
  def add(a: int, b: int) -> int:
      """Add two numbers."""
      return a + b

+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract b from a."""
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide a by b."""
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Return a modulo b."""
+     return a % b
+
  @tool
  def wiki_search(query: str) -> dict:
      """Search Wikipedia and return up to 2 results."""
      search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'{doc.metadata}\n{doc.page_content[:500]}'
+             for doc in search_docs
+         ])
+     return {"wiki_results": formatted_search_docs}

  @tool
  def web_search(query: str) -> dict:
      """Search Tavily for a query and return max 3 results."""
      search_docs = TavilySearchResults(max_results=3).invoke(query)
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'{doc.metadata}\n{doc.page_content[:500]}'
+             for doc in search_docs
+         ])
+     return {"web_results": formatted_search_docs}

+ @tool
+ def arvix_search(query: str) -> dict:
+     """Search arXiv and return up to 3 results."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'{doc.metadata}\n{doc.page_content[:1000]}'
+             for doc in search_docs
+         ])
+     return {"arvix_results": formatted_search_docs}
+
+ tools = [
+     multiply, add, subtract, divide, modulus,
+     wiki_search, web_search, arvix_search,
+ ]

+ # ========== SYSTEM PROMPT ==========
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()
  sys_msg = SystemMessage(content=system_prompt)

+ # ========== SUPABASE VECTORSTORE ==========
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+ supabase = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY")
+ )
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )

+ # ========== GRAPH BUILD ==========
+ def build_graph(provider: str = "google"):
+     if provider == "google":
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+     else:
+         raise ValueError("Invalid provider. (Expand for openai/groq if needed.)")
      llm_with_tools = llm.bind_tools(tools)

+     def retriever(state: MessagesState):
+         # Look up the stored document most similar to the incoming question.
+         query = state["messages"][-1].content
+         similar_doc = vector_store.similarity_search(query, k=1)[0]
+         content = similar_doc.page_content
+         # Try to extract the final-answer format from the stored document.
+         if "FINAL ANSWER:" in content:
+             answer = content.split("FINAL ANSWER:")[-1].strip()
+             return {"messages": [AIMessage(content=f"FINAL ANSWER: {answer}")]}
+         else:
+             return {"messages": [AIMessage(content=content.strip())]}
+
      def assistant(state: MessagesState):
          messages = [sys_msg] + state["messages"]
          return {"messages": [llm_with_tools.invoke(messages)]}

      builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
      builder.set_finish_point("assistant")
      return builder.compile()
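
For context, a minimal sketch of how the compiled graph from this commit might be driven. The script name and example question are assumptions, not part of the commit; it also presumes GOOGLE_API_KEY, TAVILY_API_KEY, SUPABASE_URL, and SUPABASE_SERVICE_KEY are set in .env and that system_prompt.txt sits next to agent.py:

# demo_agent.py -- usage sketch, not part of this commit.
# Assumes .env provides GOOGLE_API_KEY, TAVILY_API_KEY, SUPABASE_URL,
# and SUPABASE_SERVICE_KEY, and that system_prompt.txt exists.
from langchain_core.messages import HumanMessage

from agent import build_graph

graph = build_graph(provider="google")

# The retriever node reads the latest message, so the question goes in as a HumanMessage.
question = "What is 25 modulus 7?"  # hypothetical example question
result = graph.invoke({"messages": [HumanMessage(content=question)]})

# The assistant's reply is the last message in the returned state.
print(result["messages"][-1].content)

Because the graph runs retriever before assistant, the vector store's nearest match is prepended to the conversation; the assistant then answers with tools bound, so the printed reply should end with the "FINAL ANSWER: ..." format enforced by system_prompt.txt.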