# Provenance: uploaded by lk080424 with huggingface_hub (commit cb60518, verified)
#!/usr/bin/env python3
"""
虫群聚合协议 — 主入口
AggregationProtocol 是整个聚合协议的统一入口,整合:
- 节点发现(discovery)
- 通信层(transport)
- 任务调度(scheduler)
- 权限管理(permission)
使用方式:
protocol = AggregationProtocol(
node_id="queen_zhangming",
role=NodeRole.QUEEN,
address="http://localhost:8080",
)
protocol.start()
# 提交聚合任务
result = protocol.submit_task("分析这段代码的性能瓶颈", ...)
# 关闭
protocol.stop()
"""
import hashlib
import logging
import threading
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional
from .types import (
NodeInfo, NodeCapability, NodeRole, NodeStatus,
PermissionLevel, AggregationStrategy, AggregationTask,
TaskForce, ProtocolMessage,
)
from .discovery import NodeRegistry
from .transport import MessageBus, LocalTransport
from .scheduler import TaskForceManager
from .permission import PermissionManager
logger = logging.getLogger(__name__)
class AggregationProtocol:
    """
    Swarm aggregation protocol — core entry point.

    Design philosophy:
    - Like a GPU cluster: temporarily combine scattered devices into a server.
    - Each device may participate in several temporary servers at once.
    - Permissions control the call scope and the number of callable nodes.
    - When a task completes, the taskforce auto-disbands and devices return
      to the idle pool.
    """

    def __init__(
        self,
        node_id: str,
        role: NodeRole = NodeRole.QUEEN,
        name: str = "",
        address: str = "",
        permission_level: PermissionLevel = PermissionLevel.QUEEN,
        capabilities: Optional[List[str]] = None,
        compute_score: float = 1.0,
        max_concurrent: int = 3,
        heartbeat_timeout: float = 30.0,
    ):
        """
        Args:
            node_id: Unique identifier of this node.
            role: Role of this node in the swarm.
            name: Human-readable name; falls back to ``node_id`` when empty.
            address: Network address for the message bus (may be empty for
                purely in-process nodes).
            permission_level: Permission level registered for this node.
            capabilities: Specialization tags; defaults to ``["general"]``.
            compute_score: Relative compute capability score.
            max_concurrent: Maximum concurrent tasks this node accepts.
            heartbeat_timeout: Seconds of silence after which the registry
                considers a peer timed out.
        """
        # This node's own descriptor.
        self.node_info = NodeInfo(
            node_id=node_id,
            name=name or node_id,
            role=role,
            status=NodeStatus.ONLINE,
            address=address,
            capability=NodeCapability(
                specializations=capabilities or ["general"],
                compute_score=compute_score,
                max_concurrent=max_concurrent,
            ),
            permission_level=permission_level,
        )
        # Core components.
        self.registry = NodeRegistry(node_id, heartbeat_timeout)
        self.bus = MessageBus(self.node_info)
        self.scheduler = TaskForceManager(self.registry, self.bus)
        self.permission = PermissionManager(permission_level)
        # Local transport (same-process communication).
        self._local_transport = LocalTransport()
        # Run state. The event lets stop() wake the heartbeat loop at once
        # instead of waiting out a full sleep interval.
        self._running = False
        self._stop_event = threading.Event()
        self._heartbeat_thread: Optional[threading.Thread] = None
        # Message handling.
        self._setup_message_handlers()
        # Statistics.
        self._stats: Dict = {
            "start_time": None,
            "tasks_submitted": 0,
            "tasks_completed": 0,
            "messages_sent": 0,
        }

    # ============================================================
    # Lifecycle
    # ============================================================
    def start(self) -> None:
        """Start the protocol: register self, start the bus and heartbeats.

        Idempotent — calling start() while already running is a no-op.
        """
        if self._running:
            return
        self._running = True
        self._stop_event.clear()  # allow restart after a previous stop()
        self.node_info.status = NodeStatus.ONLINE
        # Register ourselves with every subsystem.
        self.registry.register(self.node_info)
        self.bus.register_node_address(self.node_info.node_id, self.node_info.address)
        self._local_transport.register_bus(self.node_info.node_id, self.bus)
        self.permission.set_node_permission(self.node_info.node_id, self.node_info.permission_level)
        # Start the message bus.
        self.bus.start()
        # Start the heartbeat thread (daemon: never blocks interpreter exit).
        self._heartbeat_thread = threading.Thread(
            target=self._heartbeat_loop, daemon=True
        )
        self._heartbeat_thread.start()
        self._stats["start_time"] = datetime.now().isoformat()
        logger.info(f"聚合协议启动: {self.node_info.node_id} ({self.node_info.role.value})")

    def stop(self) -> None:
        """Stop the protocol and shut the heartbeat thread down promptly.

        BUGFIX: the original left the heartbeat thread inside a plain
        ``time.sleep(10)``, so it could linger up to 10 s after stop() with
        no join. We now signal the stop event (which interrupts the wait)
        and join with a short bound.
        """
        self._running = False
        self._stop_event.set()  # wake the heartbeat loop immediately
        self.node_info.status = NodeStatus.OFFLINE
        self.bus.stop()
        if self._heartbeat_thread is not None:
            self._heartbeat_thread.join(timeout=1.0)
            self._heartbeat_thread = None
        logger.info(f"聚合协议停止: {self.node_info.node_id}")

    # ============================================================
    # Node management
    # ============================================================
    def add_remote_node(self, node_id: str, name: str = "",
                        role: NodeRole = NodeRole.QUEEN,
                        address: str = "",
                        capabilities: Optional[List[str]] = None,
                        compute_score: float = 1.0,
                        permission_level: PermissionLevel = PermissionLevel.QUEEN) -> None:
        """Register a remote node with the registry, bus and permission table."""
        node = NodeInfo(
            node_id=node_id,
            name=name or node_id,
            role=role,
            address=address,
            capability=NodeCapability(
                specializations=capabilities or ["general"],
                compute_score=compute_score,
            ),
            permission_level=permission_level,
        )
        self.registry.register(node)
        if address:
            self.bus.register_node_address(node_id, address)
        self.permission.set_node_permission(node_id, permission_level)

    def add_local_node(self, protocol: 'AggregationProtocol') -> None:
        """Register another protocol instance living in the same process."""
        self.registry.register(protocol.node_info)
        self._local_transport.register_bus(protocol.node_info.node_id, protocol.bus)
        self.permission.set_node_permission(
            protocol.node_info.node_id,
            protocol.node_info.permission_level
        )

    def remove_node(self, node_id: str) -> None:
        """Remove a node from the registry and the bus address table."""
        self.registry.unregister(node_id)
        self.bus.unregister_node_address(node_id)

    # ============================================================
    # Aggregation tasks
    # ============================================================
    def submit_task(
        self,
        query: str,
        required_capabilities: Optional[List[str]] = None,
        min_nodes: int = 2,
        max_nodes: int = 5,
        strategy: AggregationStrategy = AggregationStrategy.ADAPTIVE_MIX,
        priority: int = 5,
        timeout_sec: float = 60.0,
    ) -> Dict:
        """
        Submit an aggregation task.

        Flow:
        1. Create the task
        2. Permission check (how many nodes may be called)
        3. Form a temporary server (taskforce)
        4. Dispatch the task to member nodes
        5. Aggregate the results
        6. Disband the temporary server

        Args:
            query: The question/task text to aggregate answers for.
            required_capabilities: Specializations members must have.
            min_nodes: Minimum taskforce size.
            max_nodes: Maximum taskforce size (may be reduced by permissions).
            strategy: Result aggregation strategy.
            priority: Task priority (higher = more urgent).
            timeout_sec: Overall task timeout in seconds.

        Returns:
            {"taskforce_id": str, "status": str, "result": Dict}
            (on denial/failure: {"task_id": str, "status": str, "error": str})
        """
        self._stats["tasks_submitted"] += 1
        # 1. Create the task.
        task_id = self._gen_id("task")
        task = AggregationTask(
            task_id=task_id,
            query=query,
            requester=self.node_info.node_id,
            required_capabilities=required_capabilities or [],
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            strategy=strategy,
            priority=priority,
            timeout_sec=timeout_sec,
        )
        # 2. Permission check: clamp to the caller's allowed node count.
        max_allowed = self.permission.get_max_callable_count(self.node_info.node_id)
        task.max_nodes = min(task.max_nodes, max_allowed)
        # Overmind privilege: may always call at least min_nodes nodes.
        if self.node_info.permission_level == PermissionLevel.OVERMIND:
            task.max_nodes = max(task.max_nodes, max(min_nodes, 1))
        if task.max_nodes < 1:
            return {
                "task_id": task_id,
                "status": "denied",
                "error": "权限不足,无法调用其他节点",
            }
        # 3. Form the temporary server.
        tf = self.scheduler.create_taskforce(task)
        if not tf:
            return {
                "task_id": task_id,
                "status": "failed",
                "error": "无法组建临时服务器,可用节点不足",
            }
        # 4. Dispatch the task to member nodes.
        node_results = {}
        for member_id in tf.members:
            # Per-member permission check (Overmind passes automatically).
            if (self.node_info.permission_level != PermissionLevel.OVERMIND and
                not self.permission.can_call_node(self.node_info.node_id, member_id)):
                continue
            # Local simulation: generate a placeholder result.
            # In a real deployment the task is sent over the message bus
            # and we wait for replies.
            node = self.registry.get_node(member_id)
            node_results[member_id] = {
                "response": f"[来自{node.name if node else member_id}针对\"{query[:20]}\"的回答]",
                "confidence": 0.7 + (node.capability.compute_score * 0.05 if node else 0),
            }
        # 5. Aggregate the results.
        final_result = self.scheduler.aggregate_results(tf.taskforce_id, node_results)
        # 6. Disband the temporary server.
        self.scheduler.complete_taskforce(tf.taskforce_id, final_result)
        self._stats["tasks_completed"] += 1
        return {
            "task_id": task_id,
            "taskforce_id": tf.taskforce_id,
            "status": "completed",
            "result": final_result,
            "members": tf.members,
            "strategy": strategy.value,
        }

    # ============================================================
    # Message handling
    # ============================================================
    def _setup_message_handlers(self) -> None:
        """Wire protocol message types to their handler methods."""
        self.bus.register_handler("discover", self._handle_discover)
        self.bus.register_handler("join_taskforce", self._handle_join)
        self.bus.register_handler("leave_taskforce", self._handle_leave)
        self.bus.register_handler("task_assign", self._handle_task_assign)
        self.bus.register_handler("task_result", self._handle_task_result)
        self.bus.register_handler("heartbeat", self._handle_heartbeat)

    def _handle_discover(self, msg: ProtocolMessage) -> None:
        """Register the sender's node info carried in a discovery message."""
        node_data = msg.payload.get("node_info", {})
        if node_data:
            node = NodeInfo(
                node_id=node_data.get("node_id", msg.sender),
                name=node_data.get("name", ""),
                # NOTE(review): defaults to "queen" when the role is missing
                # — confirm that is the intended fallback role.
                role=NodeRole(node_data.get("role", "queen")),
                address=node_data.get("address", ""),
                capability=NodeCapability(
                    specializations=node_data.get("capabilities", ["general"]),
                    compute_score=node_data.get("compute_score", 1.0),
                ),
            )
            self.registry.register(node)

    def _handle_join(self, msg: ProtocolMessage) -> None:
        """Track a taskforce we were enlisted into."""
        tf_id = msg.payload.get("taskforce_id", "")
        if tf_id and self.node_info.node_id in msg.payload.get("members", []):
            self.node_info.current_taskforces.append(tf_id)

    def _handle_leave(self, msg: ProtocolMessage) -> None:
        """Drop a taskforce we are no longer part of."""
        tf_id = msg.payload.get("taskforce_id", "")
        if tf_id in self.node_info.current_taskforces:
            self.node_info.current_taskforces.remove(tf_id)

    def _handle_task_assign(self, msg: ProtocolMessage) -> None:
        """Handle a task assignment.

        Subclasses override this to implement actual task execution.
        """
        pass

    def _handle_task_result(self, msg: ProtocolMessage) -> None:
        """Handle a task result (no-op in the base protocol)."""
        pass

    def _handle_heartbeat(self, msg: ProtocolMessage) -> None:
        """Refresh the sender's liveness timestamp."""
        self.registry.update_heartbeat(msg.sender)

    # ============================================================
    # Heartbeat
    # ============================================================
    def _heartbeat_loop(self) -> None:
        """Periodic heartbeat broadcast and peer liveness check.

        Runs on a daemon thread; exits promptly once stop() sets the
        stop event.
        """
        while self._running:
            try:
                # Broadcast our own heartbeat.
                self.bus.broadcast("heartbeat", {
                    "node_id": self.node_info.node_id,
                    "status": self.node_info.status.value,
                    "active_taskforces": len(self.node_info.current_taskforces),
                })
                # Expire peers whose heartbeats timed out.
                timeout_nodes = self.registry.check_heartbeats()
                if timeout_nodes:
                    logger.warning(f"心跳超时节点: {timeout_nodes}")
            except Exception as e:
                # Keep the loop alive: a single failed broadcast must not
                # kill heartbeating for the whole node.
                logger.error(f"心跳异常: {e}")
            # Sleep ~10 s, but wake immediately when stop() fires the event.
            if self._stop_event.wait(10):
                break

    # ============================================================
    # Utilities
    # ============================================================
    def _gen_id(self, prefix: str) -> str:
        """Generate a short unique-ish id like ``task_1a2b3c4d``.

        Uses time.time_ns() so two calls in quick succession hash different
        inputs (the original hashed the float time.time(), which can repeat
        within its resolution). md5 is fine here — ids, not security.
        """
        return f"{prefix}_{hashlib.md5(f'{time.time_ns()}{prefix}'.encode()).hexdigest()[:8]}"

    def get_status(self) -> Dict:
        """Snapshot of every subsystem's status plus protocol-level stats."""
        return {
            "node": self.node_info.to_dict(),
            "registry": self.registry.get_status(),
            "scheduler": self.scheduler.get_stats(),
            "permission": self.permission.get_stats(),
            "transport": self.bus.get_stats(),
            "protocol_stats": self._stats,
            "running": self._running,
        }

    def discover_nodes(self) -> List[Dict]:
        """Return dicts for every registered node that is currently available."""
        nodes = self.registry.get_all_nodes()
        return [n.to_dict() for n in nodes if n.is_available()]