Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,9 +34,9 @@
|
|
| 34 |
|
| 35 |
|
| 36 |
import gradio as gr
|
| 37 |
-
import
|
|
|
|
| 38 |
from langgraph.graph import StateGraph, END
|
| 39 |
-
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
|
| 40 |
from langchain_core.prompts import ChatPromptTemplate
|
| 41 |
from langchain_openai import ChatOpenAI
|
| 42 |
|
|
@@ -47,15 +47,12 @@ from retriever import load_guest_dataset
|
|
| 47 |
# 定义状态对象
|
| 48 |
class AgentState(TypedDict):
|
| 49 |
messages: List[BaseMessage] # 消息历史
|
| 50 |
-
plan: Optional[str] # 当前计划
|
| 51 |
-
tool_results: List[
|
| 52 |
-
step_count: int # 当前步骤计数
|
|
|
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
|
| 56 |
-
# model = ChatOpenAI(temperature=0)
|
| 57 |
-
|
| 58 |
-
# Initialize the tools
|
| 59 |
search_tool = DuckDuckGoSearchTool()
|
| 60 |
weather_info_tool = WeatherInfoTool()
|
| 61 |
hub_stats_tool = HubStatsTool()
|
|
@@ -67,9 +64,16 @@ TOOLS = {
|
|
| 67 |
"weather_info": weather_info_tool,
|
| 68 |
"hub_stats": hub_stats_tool,
|
| 69 |
"guest_info": guest_info_tool,
|
| 70 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
}
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
# ========================
|
| 74 |
# LangGraph 节点函数
|
| 75 |
# ========================
|
|
@@ -79,17 +83,283 @@ def plan_node(state: AgentState):
|
|
| 79 |
messages = state["messages"]
|
| 80 |
step_count = state["step_count"]
|
| 81 |
|
| 82 |
-
# 每3步进行规划
|
| 83 |
-
if step_count % 3 == 0:
|
| 84 |
# 创建规划提示
|
| 85 |
prompt = ChatPromptTemplate.from_messages([
|
| 86 |
-
("system", "你是一个智能助手Alfred。根据对话历史和当前状态,规划下一步行动。"
|
| 87 |
-
|
|
|
|
| 88 |
])
|
| 89 |
|
|
|
|
|
|
|
|
|
|
| 90 |
chain = prompt | model
|
| 91 |
-
response = chain.invoke({"
|
| 92 |
|
| 93 |
return {
|
| 94 |
"plan": response.content,
|
| 95 |
-
"messages": messages}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
import gradio as gr
|
| 37 |
+
from typing import TypedDict, List, Dict, Optional, Annotated
|
| 38 |
+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
| 39 |
from langgraph.graph import StateGraph, END
|
|
|
|
| 40 |
from langchain_core.prompts import ChatPromptTemplate
|
| 41 |
from langchain_openai import ChatOpenAI
|
| 42 |
|
|
|
|
| 47 |
# 定义状态对象
|
| 48 |
class AgentState(TypedDict):
|
| 49 |
messages: List[BaseMessage] # 消息历史
|
| 50 |
+
plan: Annotated[Optional[str], "当前计划"] # 当前计划
|
| 51 |
+
tool_results: Annotated[List[Dict], "工具执行结果"] # 工具执行结果
|
| 52 |
+
step_count: Annotated[int, "步骤计数器"] # 当前步骤计数
|
| 53 |
+
final_response: Annotated[Optional[str], "最终响应"] # 最终响应
|
| 54 |
|
| 55 |
+
# 初始化工具
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
search_tool = DuckDuckGoSearchTool()
|
| 57 |
weather_info_tool = WeatherInfoTool()
|
| 58 |
hub_stats_tool = HubStatsTool()
|
|
|
|
| 64 |
"weather_info": weather_info_tool,
|
| 65 |
"hub_stats": hub_stats_tool,
|
| 66 |
"guest_info": guest_info_tool,
|
| 67 |
+
# 基础工具(模拟smolagents的add_base_tools)
|
| 68 |
+
"search_web": search_tool,
|
| 69 |
+
"get_weather": weather_info_tool,
|
| 70 |
+
"get_hub_stats": hub_stats_tool,
|
| 71 |
+
"get_guest_info": guest_info_tool
|
| 72 |
}
|
| 73 |
|
| 74 |
+
# 初始化模型 - 使用ChatOpenAI替代HfApiModel
|
| 75 |
+
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7, max_tokens=512)
|
| 76 |
+
|
| 77 |
# ========================
|
| 78 |
# LangGraph 节点函数
|
| 79 |
# ========================
|
|
|
|
| 83 |
messages = state["messages"]
|
| 84 |
step_count = state["step_count"]
|
| 85 |
|
| 86 |
+
# 每3步进行规划(模拟planning_interval=3)
|
| 87 |
+
if step_count % 3 == 0 or step_count == 0:
|
| 88 |
# 创建规划提示
|
| 89 |
prompt = ChatPromptTemplate.from_messages([
|
| 90 |
+
("system", "你是一个智能助手Alfred。根据对话历史和当前状态,规划下一步行动。"
|
| 91 |
+
"提供清晰、具体的计划,包括要使用的工具(如果适用)。"),
|
| 92 |
+
("human", "当前对话历史:\n{history}\n\n请规划下一步行动。")
|
| 93 |
])
|
| 94 |
|
| 95 |
+
# 提取对话历史
|
| 96 |
+
history = "\n".join([f"{m.type}: {m.content}" for m in messages])
|
| 97 |
+
|
| 98 |
chain = prompt | model
|
| 99 |
+
response = chain.invoke({"history": history})
|
| 100 |
|
| 101 |
return {
|
| 102 |
"plan": response.content,
|
| 103 |
+
"messages": messages + [AIMessage(content=f"计划: {response.content}")],
|
| 104 |
+
"step_count": step_count + 1
|
| 105 |
+
}
|
| 106 |
+
return state
|
| 107 |
+
|
| 108 |
+
def tool_selection_node(state: AgentState):
|
| 109 |
+
"""工具选择节点 - 决定使用哪个工具"""
|
| 110 |
+
messages = state["messages"]
|
| 111 |
+
plan = state.get("plan", "")
|
| 112 |
+
|
| 113 |
+
# 创建工具选择提示
|
| 114 |
+
tool_list = ", ".join(TOOLS.keys())
|
| 115 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 116 |
+
("system", "你是一个智能助手Alfred。根据用户请求和当前计划,选择最合适的工具。"),
|
| 117 |
+
("system", f"可用工具: {tool_list}"),
|
| 118 |
+
("system", f"当前计划: {plan}"),
|
| 119 |
+
("human", "用户请求: {last_message}\n\n请选择要使用的工具(只返回工具名称)。")
|
| 120 |
+
])
|
| 121 |
+
|
| 122 |
+
# 获取最后一条用户消息
|
| 123 |
+
last_user_message = next((m.content for m in reversed(messages) if isinstance(m, HumanMessage)), "")
|
| 124 |
+
|
| 125 |
+
chain = prompt | model
|
| 126 |
+
response = chain.invoke({"last_message": last_user_message})
|
| 127 |
+
|
| 128 |
+
# 提取选择的工具
|
| 129 |
+
tool_name = response.content.strip().lower()
|
| 130 |
+
return {
|
| 131 |
+
"selected_tool": tool_name,
|
| 132 |
+
"messages": messages + [AIMessage(content=f"选择工具: {tool_name}")]
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
def tool_execution_node(state: AgentState):
|
| 136 |
+
"""工具执行节点 - 执行选定的工具"""
|
| 137 |
+
tool_name = state["selected_tool"]
|
| 138 |
+
messages = state["messages"]
|
| 139 |
+
|
| 140 |
+
# 获取工具实例
|
| 141 |
+
tool = TOOLS.get(tool_name)
|
| 142 |
+
if not tool:
|
| 143 |
+
error_msg = f"未知工具: {tool_name}"
|
| 144 |
+
return {
|
| 145 |
+
"tool_results": [{"tool": tool_name, "error": error_msg}],
|
| 146 |
+
"messages": messages + [AIMessage(content=error_msg)]
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
# 提取工具参数(最后一条用户消息)
|
| 150 |
+
last_user_message = next((m.content for m in reversed(messages) if isinstance(m, HumanMessage)), "")
|
| 151 |
+
|
| 152 |
+
# 执行工具
|
| 153 |
+
try:
|
| 154 |
+
result = tool.run(last_user_message)
|
| 155 |
+
tool_result = {"tool": tool_name, "result": result}
|
| 156 |
+
except Exception as e:
|
| 157 |
+
tool_result = {"tool": tool_name, "error": str(e)}
|
| 158 |
+
|
| 159 |
+
return {
|
| 160 |
+
"tool_results": [tool_result],
|
| 161 |
+
"messages": messages + [AIMessage(content=f"工具执行结果: {str(tool_result)[:200]}...")]
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
def response_generation_node(state: AgentState):
|
| 165 |
+
"""响应生成节点 - 基于工具结果生成最终响应"""
|
| 166 |
+
tool_results = state["tool_results"]
|
| 167 |
+
messages = state["messages"]
|
| 168 |
+
|
| 169 |
+
# 创建响应生成提示
|
| 170 |
+
results_str = "\n".join([f"{res['tool']}: {res.get('result', res.get('error', '无结果'))}"
|
| 171 |
+
for res in tool_results])
|
| 172 |
+
|
| 173 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 174 |
+
("system", "你是一个智能助手Alfred。基于工具执行结果生成对用户的响应。"),
|
| 175 |
+
("system", f"工具执行结果:\n{results_str}"),
|
| 176 |
+
("system", "确保响应完整、友好且直接回答用户问题。"),
|
| 177 |
+
("human", "用户原始请求: {last_message}")
|
| 178 |
+
])
|
| 179 |
+
|
| 180 |
+
# 获取最后一条用户消息
|
| 181 |
+
last_user_message = next((m.content for m in reversed(messages) if isinstance(m, HumanMessage)), "")
|
| 182 |
+
|
| 183 |
+
chain = prompt | model
|
| 184 |
+
response = chain.invoke({"last_message": last_user_message})
|
| 185 |
+
|
| 186 |
+
return {
|
| 187 |
+
"messages": messages + [AIMessage(content=response.content)],
|
| 188 |
+
"final_response": response.content
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
def should_continue(state: AgentState):
|
| 192 |
+
"""决定是否继续的条件函数"""
|
| 193 |
+
messages = state["messages"]
|
| 194 |
+
step_count = state["step_count"]
|
| 195 |
+
|
| 196 |
+
# 检查是否有最终响应
|
| 197 |
+
if "final_response" in state:
|
| 198 |
+
last_ai_message = state["final_response"]
|
| 199 |
+
|
| 200 |
+
# 检查AI是否表示需要更多操作
|
| 201 |
+
if "需要更多信息" in last_ai_message or "下一步" in last_ai_message:
|
| 202 |
+
return "continue"
|
| 203 |
+
|
| 204 |
+
# 检查是否有未完成的工具调用
|
| 205 |
+
if any("error" in res for res in state.get("tool_results", [])):
|
| 206 |
+
return "continue"
|
| 207 |
+
|
| 208 |
+
# 默认情况下结束
|
| 209 |
+
return "end"
|
| 210 |
+
|
| 211 |
+
# 默认情况下继续
|
| 212 |
+
return "continue"
|
| 213 |
+
|
| 214 |
+
# ========================
|
| 215 |
+
# 构建LangGraph工作流
|
| 216 |
+
# ========================
|
| 217 |
+
|
| 218 |
+
# 创建状态图
|
| 219 |
+
workflow = StateGraph(AgentState)
|
| 220 |
+
|
| 221 |
+
# 添加节点
|
| 222 |
+
workflow.add_node("plan", plan_node)
|
| 223 |
+
workflow.add_node("select_tool", tool_selection_node)
|
| 224 |
+
workflow.add_node("execute_tool", tool_execution_node)
|
| 225 |
+
workflow.add_node("generate_response", response_generation_node)
|
| 226 |
+
|
| 227 |
+
# 设置入口点
|
| 228 |
+
workflow.set_entry_point("plan")
|
| 229 |
+
|
| 230 |
+
# 添加边
|
| 231 |
+
workflow.add_edge("plan", "select_tool")
|
| 232 |
+
workflow.add_edge("select_tool", "execute_tool")
|
| 233 |
+
workflow.add_edge("execute_tool", "generate_response")
|
| 234 |
+
|
| 235 |
+
# 添加条件边
|
| 236 |
+
workflow.add_conditional_edges(
|
| 237 |
+
"generate_response",
|
| 238 |
+
should_continue,
|
| 239 |
+
{
|
| 240 |
+
"continue": "plan", # 继续规划
|
| 241 |
+
"end": END # 结束
|
| 242 |
+
}
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# 编译工作流
|
| 246 |
+
agent_workflow = workflow.compile()
|
| 247 |
+
|
| 248 |
+
# ========================
|
| 249 |
+
# Gradio界面 (模拟GradioUI)
|
| 250 |
+
# ========================
|
| 251 |
+
|
| 252 |
+
class GradioUI:
|
| 253 |
+
def __init__(self, agent_workflow):
|
| 254 |
+
self.agent_workflow = agent_workflow
|
| 255 |
+
|
| 256 |
+
def run_agent(self, message):
|
| 257 |
+
"""运行代理工作流"""
|
| 258 |
+
# 初始化状态
|
| 259 |
+
initial_state = {
|
| 260 |
+
"messages": [HumanMessage(content=message)],
|
| 261 |
+
"plan": None,
|
| 262 |
+
"tool_results": [],
|
| 263 |
+
"step_count": 0,
|
| 264 |
+
"final_response": None
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
# 执行工作流
|
| 268 |
+
for step in self.agent_workflow.stream(initial_state):
|
| 269 |
+
for node_name, node_state in step.items():
|
| 270 |
+
# 当有最终响应时返回
|
| 271 |
+
if "final_response" in node_state:
|
| 272 |
+
return self.format_history(node_state["messages"])
|
| 273 |
+
|
| 274 |
+
# 如果未返回,则使用最终状态
|
| 275 |
+
return self.format_history(initial_state["messages"])
|
| 276 |
+
|
| 277 |
+
def format_history(self, messages):
|
| 278 |
+
"""格式化消息历史用于显示"""
|
| 279 |
+
formatted = []
|
| 280 |
+
for msg in messages:
|
| 281 |
+
if isinstance(msg, HumanMessage):
|
| 282 |
+
formatted.append(("用户", msg.content))
|
| 283 |
+
elif isinstance(msg, AIMessage):
|
| 284 |
+
# 区分不同类型的AI消息
|
| 285 |
+
if "计划:" in msg.content:
|
| 286 |
+
formatted.append(("Alfred (计划)", msg.content))
|
| 287 |
+
elif "选择工具:" in msg.content:
|
| 288 |
+
formatted.append(("Alfred (工具选择)", msg.content))
|
| 289 |
+
elif "工具执行结果:" in msg.content:
|
| 290 |
+
formatted.append(("Alfred (工具结果)", msg.content))
|
| 291 |
+
else:
|
| 292 |
+
formatted.append(("Alfred", msg.content))
|
| 293 |
+
return formatted
|
| 294 |
+
|
| 295 |
+
def chat_interface(self, message, history):
|
| 296 |
+
"""Gradio聊天界面函数"""
|
| 297 |
+
# 添加当前消息到历史
|
| 298 |
+
history = history or []
|
| 299 |
+
history.append(("用户", message))
|
| 300 |
+
|
| 301 |
+
# 运行代理
|
| 302 |
+
agent_history = self.run_agent(message)
|
| 303 |
+
|
| 304 |
+
# 更新Gradio历史记录
|
| 305 |
+
for role, content in agent_history:
|
| 306 |
+
# 只添加新的消息
|
| 307 |
+
if (role, content) not in history:
|
| 308 |
+
history.append((role, content))
|
| 309 |
+
|
| 310 |
+
# 提取最终响应
|
| 311 |
+
final_response = next((content for role, content in reversed(history) if role == "Alfred"), "")
|
| 312 |
+
|
| 313 |
+
# 转换为Gradio格式
|
| 314 |
+
gradio_history = []
|
| 315 |
+
for role, content in history:
|
| 316 |
+
if role.startswith("用户"):
|
| 317 |
+
gradio_history.append((content, None))
|
| 318 |
+
else:
|
| 319 |
+
if gradio_history and gradio_history[-1][1] is None:
|
| 320 |
+
gradio_history[-1] = (gradio_history[-1][0], content)
|
| 321 |
+
else:
|
| 322 |
+
gradio_history.append((None, f"[{role}]: {content}"))
|
| 323 |
+
|
| 324 |
+
return "", gradio_history
|
| 325 |
+
|
| 326 |
+
def launch(self):
|
| 327 |
+
"""启动Gradio界面"""
|
| 328 |
+
with gr.Blocks(title="Alfred Agent with LangGraph") as demo:
|
| 329 |
+
gr.Markdown("# 🤖 Alfred 智能代理 (LangGraph 实现)")
|
| 330 |
+
gr.Markdown("使用LangGraph框架实现的智能代理系统,支持多种工具调用")
|
| 331 |
+
|
| 332 |
+
with gr.Row():
|
| 333 |
+
with gr.Column(scale=3):
|
| 334 |
+
chatbot = gr.Chatbot(label="对话历史", height=500)
|
| 335 |
+
msg = gr.Textbox(label="输入消息", placeholder="在此输入您的问题...")
|
| 336 |
+
clear_btn = gr.ClearButton([msg, chatbot])
|
| 337 |
+
|
| 338 |
+
# 示例问题
|
| 339 |
+
gr.Examples(
|
| 340 |
+
examples=[
|
| 341 |
+
"今天的天气怎么样?",
|
| 342 |
+
"搜索LangGraph的最新信息",
|
| 343 |
+
"Hugging Face上有多少模型?",
|
| 344 |
+
"昨天的嘉宾是谁?"
|
| 345 |
+
],
|
| 346 |
+
inputs=msg
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
with gr.Column(scale=1):
|
| 350 |
+
gr.Markdown("### 可用工具")
|
| 351 |
+
for tool_name in TOOLS.keys():
|
| 352 |
+
gr.Markdown(f"- {tool_name}")
|
| 353 |
+
|
| 354 |
+
# 事件处理
|
| 355 |
+
msg.submit(
|
| 356 |
+
self.chat_interface,
|
| 357 |
+
[msg, chatbot],
|
| 358 |
+
[msg, chatbot]
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
demo.launch()
|
| 362 |
+
|
| 363 |
+
if __name__ == "__main__":
|
| 364 |
+
# 创建并启动Gradio界面
|
| 365 |
+
GradioUI(agent_workflow).launch()
|