# RobotPai / agent.py
# hua101 — "Update agent.py" (commit 0756a62, verified)
"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
load_dotenv()
# Report only whether the API key is configured: printing its value would leak
# the secret into stdout / hosted-Space logs.
print("GROQ_API_KEY:", "set" if os.getenv("GROQ_API_KEY") else "not set")
print("SUPABASE_URL:", os.getenv("SUPABASE_URL"))
# === Original math tools ===
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    # The docstring above is the tool description shown to the LLM.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: first int
        b: second int
    """
    # Sum of the two operands supplied in the tool call.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: first int
        b: second int
    """
    # Order matters: the result is a minus b.
    difference = a - b
    return difference
@tool
def divide(a: float, b: float) -> float:
    """Divide two numbers.
    Args:
        a: dividend (int or float)
        b: divisor (int or float); must not be zero
    Returns:
        a / b as a float.
    """
    # Reject a zero divisor explicitly so the caller gets a readable message
    # instead of a bare ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division: the result is a float even for two ints (7 / 2 == 3.5),
    # hence the float return annotation; float params also let the model pass
    # fractional operands that an int schema would reject.
    return a / b
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    # Python's % takes the sign of the divisor (e.g. -7 % 3 == 2).
    # NOTE(review): b == 0 raises ZeroDivisionError here, whereas divide()
    # raises ValueError for a zero divisor.
    remainder = a % b
    return remainder
# === Original search tools ===
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.
    Args:
        query: The search query."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Each page becomes a pseudo-XML <Document> element; pages are separated by
    # "---" so the model can tell them apart. .get() keeps a document with
    # incomplete metadata from aborting the whole search.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # Returns a dict (not a str as previously annotated).
    return {"wiki_results": formatted_search_docs}
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.
    Args:
        query: The search query."""
    # TavilySearchResults is itself a LangChain tool: its input is passed
    # positionally ({"query": ...}); `.invoke(query=...)` has no input and
    # raises TypeError. It returns plain dicts (url/content/...), not
    # Documents, so there is no .metadata / .page_content to read.
    search_results = TavilySearchResults(max_results=3).invoke({"query": query})
    if isinstance(search_results, str):
        # The tool reports API failures as a string instead of a list; pass it
        # through rather than iterating over its characters.
        return {"web_results": search_results}
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{result.get("url", "")}" title="{result.get("title", "")}"/>\n'
        f'{result.get("content", "")}\n</Document>'
        for result in search_results
    )
    return {"web_results": formatted_search_docs}
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    parts = []
    for doc in search_docs:
        meta = doc.metadata
        # ArxivLoader metadata is keyed "Title"/"Published"/"Authors"/... and
        # does not reliably carry a "source" key, so indexing it directly
        # raised KeyError; read every field with .get().
        source = meta.get("source") or meta.get("entry_id", "")
        parts.append(
            f'<Document source="{source}" title="{meta.get("Title", "")}" '
            f'page="{meta.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        )
    # Paper text is capped at 1000 characters each to keep the context small.
    return {"arvix_results": "\n\n---\n\n".join(parts)}
# === New: Supabase tools ===
@tool
def supabase_vector_search(query: str, max_results: int = 3) -> dict:
    """Search the Supabase knowledge base using vector similarity.
    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 3)
    """
    try:
        # Reuse the module-level store created once by setup_vector_store():
        # constructing HuggingFaceEmbeddings loads the sentence-transformers
        # model, which is expensive to repeat on every tool call. Only build a
        # fresh store when that start-up setup failed (left it as None).
        store = vector_store
        if store is None:
            store = SupabaseVectorStore(
                client=create_client(
                    os.environ.get("SUPABASE_URL"),
                    os.environ.get("SUPABASE_SERVICE_KEY"),
                ),
                embedding=HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-mpnet-base-v2"
                ),
                table_name="supabase_docs",  # actual table name
                query_name="match_documents",  # SQL similarity-search function
            )
        # Fetch the real similarity scores instead of labelling every hit
        # "high" regardless of how well it matched.
        results = store.similarity_search_with_relevance_scores(query, k=max_results)
        if not results:
            return {"message": "No relevant documents found in knowledge base"}
        snippets = []
        for doc, score in results:
            text = doc.page_content
            # Cap each snippet at 800 chars; add "..." only when actually cut.
            if len(text) > 800:
                text = text[:800] + "..."
            snippets.append(f'<Document similarity="{score:.3f}"/>\n{text}\n</Document>')
        return {"supabase_vector_results": "\n\n---\n\n".join(snippets)}
    except Exception as e:
        # Best-effort tool: report the failure to the model instead of raising.
        return {"error": f"Supabase vector search failed: {str(e)}"}
@tool
def supabase_text_search(query: str, max_results: int = 3) -> dict:
    """Search the Supabase knowledge base using text search.
    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 3)
    """
    try:
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY")
        )
        # Call the custom hybrid_search SQL function in text-only mode.
        result = supabase.rpc('hybrid_search', {
            'search_query': query,
            'search_type': 'text',
            'max_results': max_results
        }).execute()
        if not result.data:
            return {"message": "No relevant documents found in knowledge base"}
        snippets = []
        for item in result.data:
            # A SQL NULL arrives as None; coerce so the ":.3f" format spec and
            # the slicing below cannot raise.
            similarity = item.get("similarity") or 0
            text = item.get("content") or ""
            # Cap each snippet at 800 chars; add "..." only when actually cut.
            if len(text) > 800:
                text = text[:800] + "..."
            snippets.append(f'<Document similarity="{similarity:.3f}"/>\n{text}\n</Document>')
        return {"supabase_text_results": "\n\n---\n\n".join(snippets)}
    except Exception as e:
        # Best-effort tool: report the failure to the model instead of raising.
        return {"error": f"Supabase text search failed: {str(e)}"}
@tool
def get_knowledge_context(query: str) -> dict:
    """Get contextual information from the knowledge base for better understanding.
    Args:
        query: The user's question
    """
    try:
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY")
        )
        # Custom SQL function; only the first returned row is used.
        result = supabase.rpc('get_agent_context', {
            'user_query': query,
            'context_limit': 2
        }).execute()
        if not result.data:
            return {"message": "No context available"}
        context_data = result.data[0]
        # SQL NULLs arrive as None; coerce so the comparison and ":.2f"
        # formatting below cannot raise TypeError.
        context_text = context_data.get("context_text") or ""
        confidence = context_data.get("confidence_score") or 0
        source_count = context_data.get("source_count") or 0
        if not context_text or source_count <= 0:
            return {"message": "No relevant context found"}
        return {
            "context": context_text[:1000],  # cap length
            "confidence": f"{confidence:.2f}",
            "sources": source_count
        }
    except Exception as e:
        # Best-effort tool: report the failure to the model instead of raising.
        return {"error": f"Context retrieval failed: {str(e)}"}
# Read the system prompt from "system_prompt.txt" (opened relative to the
# current working directory); when that file is absent, use the built-in
# default below instead.
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
        system_prompt = prompt_file.read()
except FileNotFoundError:
    # Built-in default (Chinese): describes the available tool families
    # (math, Wikipedia/Arxiv/web search, Supabase knowledge base) and asks the
    # model to prefer the Supabase tools for knowledge-base questions.
    system_prompt = """你是一个智能助手,可以使用多种工具来回答用户的问题。
可用工具包括:
1. 数学计算工具(加减乘除等)
2. 网络搜索工具(Wikipedia, Arxiv, Web搜索)
3. Supabase 知识库工具(向量搜索、文本搜索、上下文获取)
请根据用户的问题选择最合适的工具,并提供准确、有用的答案。对于知识库中的信息,优先使用 Supabase 工具。"""
# System message wrapping whichever prompt text was chosen above.
sys_msg = SystemMessage(content=system_prompt)
# === Retriever setup ===
def setup_vector_store():
    """Build the Supabase vector store and a retriever tool over it.

    Returns:
        (vector_store, retriever_tool) on success, or (None, None) if any step
        fails; the error is printed rather than raised so the agent can still
        start without the knowledge base.
    """
    try:
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        supabase: Client = create_client(
            os.environ.get("SUPABASE_URL"),
            os.environ.get("SUPABASE_SERVICE_KEY")
        )
        vector_store = SupabaseVectorStore(
            client=supabase,
            embedding=embeddings,
            table_name="supabase_docs",  # actual table name
            query_name="match_documents",  # the SQL function we created
        )
        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            # Function-calling APIs (OpenAI-compatible Groq, Gemini) only accept
            # tool names made of letters, digits, "_" and "-"; the previous
            # "Knowledge Base Search" (with spaces) got every request rejected.
            name="knowledge_base_search",
            description="Search the knowledge base for similar questions and answers.",
        )
        return vector_store, retriever_tool
    except Exception as e:
        print(f"❌ Vector store setup failed: {e}")
        return None, None


# Build the vector store once, at import time.
vector_store, retriever_tool = setup_vector_store()
# === Tool registry ===
# Every tool the LLM may call; this list is bound to the model and wrapped in
# the graph's ToolNode inside build_graph().
tools = [
    # math
    multiply,
    add,
    subtract,
    divide,
    modulus,
    # web / reference search
    wiki_search,
    web_search,
    arvix_search,
    # Supabase knowledge base
    supabase_vector_search,
    supabase_text_search,
    get_knowledge_context,
]
# The retriever tool exists only if setup_vector_store() succeeded.
if not retriever_tool:
    print("⚠️ Knowledge base retriever tool not available")
else:
    tools.append(retriever_tool)
    print("✅ Knowledge base retriever tool added")
# Build graph function
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Flow: START -> retriever -> assistant, then assistant <-> tools for as
    long as the model keeps requesting tool calls.

    Args:
        provider: chat model backend: "google", "groq" or "huggingface".

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # The endpoint parameter is `endpoint_url`; a `url=` kwarg is
                # not recognised as the endpoint.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the tool-enabled LLM on the conversation, system prompt first."""
        # The system message is prepended here rather than stored in state:
        # MessagesState's add_messages reducer appends messages it has not seen,
        # so a SystemMessage returned from a node ends up *after* the user's
        # question, while providers expect it first.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Append up to two similar knowledge-base entries as reference context."""
        try:
            if vector_store and state["messages"]:
                user_query = state["messages"][-1].content
                similar_questions = vector_store.similarity_search(user_query, k=2)
                if similar_questions:
                    example_content = "\n\n".join(
                        f"Similar Q&A {i+1}: {doc.page_content[:400]}..."
                        for i, doc in enumerate(similar_questions)
                    )
                    example_msg = HumanMessage(
                        content=f"Here are similar questions and answers from the knowledge base for reference:\n\n{example_content}",
                    )
                    # Return only the new message; the reducer appends it.
                    return {"messages": [example_msg]}
        except Exception as e:
            print(f"Retriever error: {e}")
        # No vector store, no hits, or the search failed: leave state unchanged.
        return {"messages": []}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last AI message requests tool
    # calls, otherwise ends the run.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    # Compile graph
    return builder.compile()
# test
if __name__ == "__main__":
    # Smoke-test several kinds of questions.
    test_questions = [
        "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?",
        "What is the area of the green polygon?",  # exercises knowledge-base search
        "Calculate 25 times 17",  # exercises the math tools
    ]
    print("🚀 开始测试 Agent...")
    # Build the graph
    graph = build_graph(provider="groq")
    total = len(test_questions)
    for i, question in enumerate(test_questions, 1):
        print(f"\n{'='*60}")
        # Show the real question count instead of a hard-coded "/3".
        print(f"测试 {i}/{total}: {question}")
        print(f"{'='*60}")
        try:
            messages = [HumanMessage(content=question)]
            result = graph.invoke({"messages": messages})
            print("\n📋 对话历史:")
            for m in result["messages"]:
                m.pretty_print()
        except Exception as e:
            # Keep going with the remaining questions if one fails.
            print(f"❌ 处理问题时出错: {e}")
        print(f"\n{'-'*60}")
    print("\n🎉 测试完成!")