# Source: base-teams/team_framework/load_agents.py (author: airsltd, commit 44a03f6 "update")
from google.adk.agents import Agent
import os
_DEFAULT_MODEL = "gemini-2.0-flash-exp"


def _field(record, key, default):
    """Return ``record[key]``, or *default* when the key is missing or NULL (None)."""
    value = record.get(key)
    return default if value is None else value


def _record_name(record):
    """Return the agent name for *record*, falling back to ``agent_<id>``."""
    return _field(record, "name", f"agent_{_field(record, 'id', 'unknown')}")


def _build_agent(record, tool_map):
    """Create one Agent from *record* with its tools attached (sub-agents are linked later)."""
    agent = Agent(
        name=_record_name(record),
        model=_field(record, "model", _DEFAULT_MODEL),
        # NULL text columns become "" instead of being passed through as None.
        instruction=_field(record, "instruction", ""),
        description=_field(record, "description", ""),
        global_instruction=_field(record, "global_instruction", ""),
    )
    # Map tool names from the record to tool objects; unknown names are skipped.
    agent.tools = [tool_map[tool_name] for tool_name in record.get("tools") or [] if tool_name in tool_map]
    return agent


def _link_sub_agents(agent, record, agents_by_name):
    """Replace ``agent.sub_agents`` with the agents named in ``record["sub_agents"]`` (if it is a list)."""
    sub_agent_names = record.get("sub_agents")
    if not isinstance(sub_agent_names, list):
        return
    agent.sub_agents = []
    for sub_agent_name in sub_agent_names:
        sub_agent = agents_by_name.get(sub_agent_name)
        if sub_agent:
            agent.sub_agents.append(sub_agent)
            print(f"为 agent '{agent.name}' 添加 sub-agent: '{sub_agent.name}'")
        else:
            print(f"警告: 未找到 sub-agent '{sub_agent_name}' (被 agent '{agent.name}' 引用)")


def load_agents_from_db_records(agent_records: list, tool_map: dict, before_model_callback):
    """Load agent configs from database records, build Agent objects and wire them together.

    Each record is expected to be a mapping. The keys read here are ``name``,
    ``id``, ``model``, ``instruction``, ``description``, ``global_instruction``,
    ``tools``, ``sub_agents`` and ``is_root``. The schema itself comes from the
    DB; confirm it against the query.

    Args:
        agent_records: Sequence of record mappings. ``None``/empty yields only
            the default root agent.
        tool_map: Maps tool names (as stored in ``record["tools"]``) to tool
            objects.
        before_model_callback: Assigned to every created agent, including the
            fallback root.

    Returns:
        ``(root_agent, agents_by_name)``. ``root_agent`` is the first
        successfully built agent whose record has a truthy ``is_root``, or a
        freshly created ``default_root_agent`` when there is none.
        ``agents_by_name`` maps names to successfully built agents.
    """
    records = agent_records or []
    agents_by_name = {}
    # (record index, record, agent) for every record that built successfully.
    # Later passes use this instead of re-reading records, so failed records
    # can never be confused with built agents.
    built = []

    for index, record in enumerate(records):
        try:
            agent = _build_agent(record, tool_map)
        except Exception as e:
            # Best-effort loading: report the bad record and continue.
            print(f"创建 agent 对象时发生错误 (数据: {record}): {e}")
            continue
        agents_by_name[agent.name] = agent
        built.append((index, record, agent))
        print(f"成功创建 agent 对象: {agent.name}")
    if records:
        print(f"总共加载了 {len(agents_by_name)} 个 agents。")

    # Second pass: every agent now exists, so sub-agent names can be resolved.
    for _, record, agent in built:
        _link_sub_agents(agent, record, agents_by_name)
    if records:
        print(f"总共加载并关联了 {len(agents_by_name)} 个 agents。")

    # Pick the root from built agents directly. Indexing the built list with a
    # position in agent_records picks the wrong agent once any earlier record
    # has failed.
    root_agent = None
    for index, record, agent in built:
        if record.get("is_root", False):
            print(f"找到 root agent 的索引: {index}")
            root_agent = agent
            break  # only the first root is used

    if root_agent is None:
        print("警告: 未找到 root agent 或索引无效,使用默认 root_agent。")
        root_agent = Agent(
            name="default_root_agent",
            model=_DEFAULT_MODEL,
            instruction="You are a helpful assistant.",
            description="You are a helpful assistant.",
            tools=[],
            # The fallback root gets the same callback as every loaded agent.
            before_model_callback=before_model_callback,
        )

    for _, _, agent in built:
        agent.before_model_callback = before_model_callback
        for sub_agent in agent.sub_agents or []:
            sub_agent.parent_agent = agent

    return root_agent, agents_by_name
def load_agent(item, before_model_callback):
    """Build a single Agent from a config mapping.

    ``item`` must provide 'name', 'model', 'instruction' and 'description';
    a missing key raises KeyError. The item is printed for debugging before
    the Agent is constructed.
    """
    agent_name = item['name']
    print(f"\nload_agent item: {agent_name}", item)
    return Agent(
        name=agent_name,
        model=item['model'],
        instruction=item['instruction'],
        description=item['description'],
        before_model_callback=before_model_callback,
    )
def load_agents(agents_dict: list, before_model_callback):
    """Build an Agent for every config item and return them keyed by name.

    Args:
        agents_dict: Despite its name, an iterable of config mappings. Each
            item is passed to ``load_agent``, so it needs 'name', 'model',
            'instruction' and 'description'. Iterating a plain dict would yield
            its keys, not configs.
        before_model_callback: Passed through to every created Agent.

    Returns:
        dict mapping each item's 'name' to its Agent. A later item with a
        duplicate name replaces the earlier one.
    """
    print("\n\n\n######## load_agents\nagents_dict: ", agents_dict)
    ret_agents = {}
    for item in agents_dict:
        ret_agents[item['name']] = load_agent(item, before_model_callback)
    # Debug dump of whatever was actually loaded. The previous version indexed
    # the hard-coded keys 'greeting_agent' and 'farewell_agent', which raised
    # KeyError for any other (or empty) configuration.
    for name, agent in ret_agents.items():
        print(f"\n\n\n@@@@@@@@@@@@@@@@@@@@@\nload_agents\nret_agents[{name}]: ", agent)
    return ret_agents
def load_root_agent(_root_agent, before_model_callback):
    """Build the root Agent from its config mapping via ``load_agent``.

    The created agent is printed for debugging and then returned.
    """
    root = load_agent(_root_agent, before_model_callback)
    print("\n\n\n######## load_root_agent\nret_agent: ", root)
    return root