Spaces:
Sleeping
Sleeping
Upload agent.py
Browse files
agent.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""LangGraph Agent"""
|
| 2 |
import os
|
|
|
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
from langgraph.graph import START, StateGraph, MessagesState
|
| 5 |
from langgraph.prebuilt import tools_condition
|
|
@@ -18,6 +19,20 @@ from supabase.client import Client, create_client
|
|
| 18 |
|
| 19 |
load_dotenv()
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
@tool
|
| 22 |
def multiply(a: int, b: int) -> int:
|
| 23 |
"""Multiply two numbers.
|
|
@@ -75,13 +90,16 @@ def wiki_search(query: str) -> str:
|
|
| 75 |
|
| 76 |
Args:
|
| 77 |
query: The search query."""
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
@tool
|
| 87 |
def web_search(query: str) -> str:
|
|
@@ -89,13 +107,25 @@ def web_search(query: str) -> str:
|
|
| 89 |
|
| 90 |
Args:
|
| 91 |
query: The search query."""
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
@tool
|
| 101 |
def arvix_search(query: str) -> str:
|
|
@@ -103,19 +133,23 @@ def arvix_search(query: str) -> str:
|
|
| 103 |
|
| 104 |
Args:
|
| 105 |
query: The search query."""
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
# load the system prompt from the file
|
| 117 |
-
|
| 118 |
-
system_prompt = f
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
# System message
|
| 121 |
sys_msg = SystemMessage(content=system_prompt)
|
|
@@ -125,20 +159,22 @@ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-b
|
|
| 125 |
supabase_url = "https://ajnakgegqblhwltzkzbz.supabase.co"
|
| 126 |
supabase_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFqbmFrZ2VncWJsaHdsdHpremJ6Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkyMDgxODgsImV4cCI6MjA2NDc4NDE4OH0.b9RPF-5otedg4yiaQu_uhOgYpXVXd9D_0oR-9cluUjo"
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
|
| 143 |
tools = [
|
| 144 |
multiply,
|
|
@@ -169,39 +205,31 @@ def build_graph(provider: str = "groq"):
|
|
| 169 |
"""Assistant node"""
|
| 170 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
| 171 |
|
| 172 |
-
# def retriever(state: MessagesState):
|
| 173 |
-
# """Retriever node"""
|
| 174 |
-
# similar_question = vector_store.similarity_search(state["messages"][0].content)
|
| 175 |
-
#example_msg = HumanMessage(
|
| 176 |
-
# content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
|
| 177 |
-
# )
|
| 178 |
-
# return {"messages": [sys_msg] + state["messages"] + [example_msg]}
|
| 179 |
-
|
| 180 |
from langchain_core.messages import AIMessage
|
| 181 |
|
| 182 |
def retriever(state: MessagesState):
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
|
| 206 |
builder = StateGraph(MessagesState)
|
| 207 |
builder.add_node("retriever", retriever)
|
|
@@ -211,4 +239,4 @@ def build_graph(provider: str = "groq"):
|
|
| 211 |
builder.set_finish_point("retriever")
|
| 212 |
|
| 213 |
# Compile graph
|
| 214 |
-
return builder.compile()
|
|
|
|
| 1 |
"""LangGraph Agent"""
|
| 2 |
import os
|
| 3 |
+
import json
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
from langgraph.graph import START, StateGraph, MessagesState
|
| 6 |
from langgraph.prebuilt import tools_condition
|
|
|
|
| 19 |
|
| 20 |
load_dotenv()
|
| 21 |
|
| 22 |
+
def safe_get_metadata(doc, key, default=""):
    """Safely extract metadata from document, handling string and dict formats"""
    try:
        raw = doc.metadata
        if isinstance(raw, dict):
            # Already a mapping — look the key up directly.
            return raw.get(key, default)
        if isinstance(raw, str):
            # String metadata is expected to hold a JSON object; parse first.
            return json.loads(raw).get(key, default)
        # Any other metadata type carries nothing we can extract.
        return default
    except (json.JSONDecodeError, AttributeError):
        # Unparseable JSON, a non-dict parse result, or a doc with no
        # .metadata attribute all degrade to the caller's default.
        return default
|
| 35 |
+
|
| 36 |
@tool
|
| 37 |
def multiply(a: int, b: int) -> int:
|
| 38 |
"""Multiply two numbers.
|
|
|
|
| 90 |
|
| 91 |
Args:
|
| 92 |
query: The search query."""
|
| 93 |
+
try:
|
| 94 |
+
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
|
| 95 |
+
formatted_search_docs = "\n\n---\n\n".join(
|
| 96 |
+
[
|
| 97 |
+
f'<Document source="{safe_get_metadata(doc, "source")}" page="{safe_get_metadata(doc, "page")}"/>\n{doc.page_content}\n</Document>'
|
| 98 |
+
for doc in search_docs
|
| 99 |
+
])
|
| 100 |
+
return {"wiki_results": formatted_search_docs}
|
| 101 |
+
except Exception as e:
|
| 102 |
+
return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}
|
| 103 |
|
| 104 |
@tool
|
| 105 |
def web_search(query: str) -> str:
|
|
|
|
| 107 |
|
| 108 |
Args:
|
| 109 |
query: The search query."""
|
| 110 |
+
try:
|
| 111 |
+
search_tool = TavilySearchResults(max_results=3)
|
| 112 |
+
search_results = search_tool.invoke(query)
|
| 113 |
+
|
| 114 |
+
# Handle the case where search_results might be a list of dicts or Document objects
|
| 115 |
+
if isinstance(search_results, list):
|
| 116 |
+
formatted_search_docs = "\n\n---\n\n".join(
|
| 117 |
+
[
|
| 118 |
+
f'<Document source="{result.get("url", "")}" />\n{result.get("content", "")}\n</Document>'
|
| 119 |
+
if isinstance(result, dict) else
|
| 120 |
+
f'<Document source="{safe_get_metadata(result, "source")}" page="{safe_get_metadata(result, "page")}"/>\n{result.page_content}\n</Document>'
|
| 121 |
+
for result in search_results
|
| 122 |
+
])
|
| 123 |
+
else:
|
| 124 |
+
formatted_search_docs = str(search_results)
|
| 125 |
+
|
| 126 |
+
return {"web_results": formatted_search_docs}
|
| 127 |
+
except Exception as e:
|
| 128 |
+
return {"web_results": f"Error searching web: {str(e)}"}
|
| 129 |
|
| 130 |
@tool
|
| 131 |
def arvix_search(query: str) -> str:
|
|
|
|
| 133 |
|
| 134 |
Args:
|
| 135 |
query: The search query."""
|
| 136 |
+
try:
|
| 137 |
+
search_docs = ArxivLoader(query=query, load_max_docs=3).load()
|
| 138 |
+
formatted_search_docs = "\n\n---\n\n".join(
|
| 139 |
+
[
|
| 140 |
+
f'<Document source="{safe_get_metadata(doc, "source")}" page="{safe_get_metadata(doc, "page")}"/>\n{doc.page_content[:1000]}\n</Document>'
|
| 141 |
+
for doc in search_docs
|
| 142 |
+
])
|
| 143 |
+
return {"arvix_results": formatted_search_docs}
|
| 144 |
+
except Exception as e:
|
| 145 |
+
return {"arvix_results": f"Error searching Arxiv: {str(e)}"}
|
| 146 |
|
| 147 |
# load the system prompt from the file
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
        system_prompt = prompt_file.read()
except FileNotFoundError:
    # No prompt file shipped alongside the agent — use a generic default.
    system_prompt = "You are a helpful AI assistant."
|
| 153 |
|
| 154 |
# System message
|
| 155 |
sys_msg = SystemMessage(content=system_prompt)
|
|
|
|
| 159 |
supabase_url = "https://ajnakgegqblhwltzkzbz.supabase.co"
# NOTE(review): anon key hardcoded in source — move URL and key into
# environment variables (load_dotenv() is already called above); rotate
# this key since it has been committed.
supabase_key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImFqbmFrZ2VncWJsaHdsdHpremJ6Iiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDkyMDgxODgsImV4cCI6MjA2NDc4NDE4OH0.b9RPF-5otedg4yiaQu_uhOgYpXVXd9D_0oR-9cluUjo"

# Best-effort initialization: if Supabase is unreachable or misconfigured,
# the agent still starts with vector_store = None (the retriever node below
# checks for that sentinel).
try:
    supabase: Client = create_client(supabase_url, supabase_key)
    # Vector store backed by the "documents" table; similarity queries go
    # through the "match_documents_langchain" SQL function on the server.
    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding= embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    # NOTE(review): this rebinds the imported create_retriever_tool function
    # to the tool object it returns, shadowing the helper for any later use —
    # consider renaming the result (e.g. question_retriever_tool) after
    # confirming no other code references this name.
    create_retriever_tool = create_retriever_tool(
        retriever=vector_store.as_retriever(),
        name="Question Search",
        description="A tool to retrieve similar questions from a vector store.",
    )
except Exception as e:
    print(f"Warning: Could not initialize vector store: {e}")
    vector_store = None
|
| 178 |
|
| 179 |
tools = [
|
| 180 |
multiply,
|
|
|
|
| 205 |
"""Assistant node"""
|
| 206 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
from langchain_core.messages import AIMessage
|
| 209 |
|
| 210 |
def retriever(state: MessagesState):
    """Retriever node with error handling"""
    try:
        # Guard: initialization above may have left the store unset.
        if vector_store is None:
            return {"messages": [AIMessage(content="Vector store not available.")]}

        latest_query = state["messages"][-1].content
        matches = vector_store.similarity_search(latest_query, k=1)
        if not matches:
            return {"messages": [AIMessage(content="No similar documents found.")]}

        text = matches[0].page_content
        # Stored reference answers may embed the result after a
        # "Final answer :" marker; keep only what follows the last marker.
        marker = "Final answer :"
        if marker in text:
            answer = text.split(marker)[-1].strip()
        else:
            answer = text.strip()

        return {"messages": [AIMessage(content=answer)]}
    except Exception as e:
        # Surface failures as a message rather than crashing the graph.
        return {"messages": [AIMessage(content=f"Error in retriever: {str(e)}")]}
|
| 233 |
|
| 234 |
builder = StateGraph(MessagesState)
|
| 235 |
builder.add_node("retriever", retriever)
|
|
|
|
| 239 |
builder.set_finish_point("retriever")
|
| 240 |
|
| 241 |
# Compile graph
|
| 242 |
+
return builder.compile()
|