#!/usr/bin/env python3
"""
Swarm aggregation protocol — main entry point.

AggregationProtocol is the single entry point of the aggregation protocol.
It wires together:
- node discovery (discovery)
- the communication layer (transport)
- task scheduling (scheduler)
- permission management (permission)

Usage:
    protocol = AggregationProtocol(
        node_id="queen_zhangming",
        role=NodeRole.QUEEN,
        address="http://localhost:8080",
    )
    protocol.start()

    # Submit an aggregation task
    result = protocol.submit_task("Find the performance bottlenecks in this code", ...)

    # Shut down
    protocol.stop()
"""

import hashlib
import itertools
import logging
import threading
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional

from .types import (
    NodeInfo, NodeCapability, NodeRole, NodeStatus,
    PermissionLevel, AggregationStrategy, AggregationTask,
    TaskForce, ProtocolMessage,
)
from .discovery import NodeRegistry
from .transport import MessageBus, LocalTransport
from .scheduler import TaskForceManager
from .permission import PermissionManager

logger = logging.getLogger(__name__)

# Upper bound on how long stop() waits for the heartbeat thread to exit.
_HEARTBEAT_JOIN_TIMEOUT_SEC = 5.0


class AggregationProtocol:
    """
    Swarm aggregation protocol — core entry point.

    Design intent (as stated by the original author):
    - Like a GPU cluster, scattered devices are temporarily combined into a
      "server" (a task force).
    - A device may take part in several temporary servers at once.
    - Permissions limit which nodes, and how many, may be called.
    - When a task completes the temporary server is dissolved and its
      devices return to the idle pool.
    """

    # Process-wide sequence mixed into generated IDs so that two IDs minted
    # within the same clock tick still differ.
    _id_counter = itertools.count()

    def __init__(
        self,
        node_id: str,
        role: NodeRole = NodeRole.QUEEN,
        name: str = "",
        address: str = "",
        permission_level: PermissionLevel = PermissionLevel.QUEEN,
        capabilities: Optional[List[str]] = None,
        compute_score: float = 1.0,
        max_concurrent: int = 3,
        heartbeat_timeout: float = 30.0,
        heartbeat_interval: float = 10.0,
    ):
        """
        Build the local node description and the protocol components.

        Args:
            node_id: Identifier of this node.
            role: Role recorded on this node's NodeInfo.
            name: Display name; falls back to ``node_id`` when empty.
            address: Address registered on the message bus in ``start()``.
            permission_level: Permission level of this node.
            capabilities: Specializations; defaults to ``["general"]``.
            compute_score: Passed to NodeCapability.
            max_concurrent: Passed to NodeCapability.
            heartbeat_timeout: Passed to NodeRegistry.
            heartbeat_interval: Seconds between two heartbeat broadcasts.
        """
        # Description of this node
        self.node_info = NodeInfo(
            node_id=node_id,
            name=name or node_id,
            role=role,
            status=NodeStatus.ONLINE,
            address=address,
            capability=NodeCapability(
                specializations=capabilities or ["general"],
                compute_score=compute_score,
                max_concurrent=max_concurrent,
            ),
            permission_level=permission_level,
        )

        # Core components
        self.registry = NodeRegistry(node_id, heartbeat_timeout)
        self.bus = MessageBus(self.node_info)
        self.scheduler = TaskForceManager(self.registry, self.bus)
        self.permission = PermissionManager(permission_level)

        # Local transport (communication inside the same process)
        self._local_transport = LocalTransport()

        # Run state
        self._running = False
        self._heartbeat_thread = None
        self._heartbeat_interval = heartbeat_interval
        # Replaced by a fresh Event on every start(); set by stop() to wake
        # the heartbeat loop immediately.
        self._stop_event = threading.Event()

        # Message handling
        self._setup_message_handlers()

        # Statistics
        # NOTE(review): "messages_sent" is not updated by any method of this
        # class as shown here.
        self._stats = {
            "start_time": None,
            "tasks_submitted": 0,
            "tasks_completed": 0,
            "messages_sent": 0,
        }

    # ============================================================
    # Lifecycle
    # ============================================================

    def start(self):
        """Start the protocol: register this node, start the bus and the heartbeat thread.

        Calling ``start()`` while already running is a no-op.
        """
        if self._running:
            return

        self._running = True
        self.node_info.status = NodeStatus.ONLINE

        # Register this node with every component
        self.registry.register(self.node_info)
        self.bus.register_node_address(self.node_info.node_id, self.node_info.address)
        self._local_transport.register_bus(self.node_info.node_id, self.bus)
        self.permission.set_node_permission(self.node_info.node_id, self.node_info.permission_level)

        # Start the message bus
        self.bus.start()

        # Start the heartbeat thread. A fresh Event per run means a thread
        # left over from an earlier run keeps seeing its own (already set)
        # event and cannot be revived by this restart.
        self._stop_event = threading.Event()
        self._heartbeat_thread = threading.Thread(
            target=self._heartbeat_loop, daemon=True
        )
        self._heartbeat_thread.start()

        self._stats["start_time"] = datetime.now().isoformat()
        logger.info(f"聚合协议启动: {self.node_info.node_id} ({self.node_info.role.value})")

    def stop(self):
        """Stop the protocol: end the heartbeat thread and stop the message bus."""
        self._running = False
        self.node_info.status = NodeStatus.OFFLINE

        # Wake the heartbeat loop right away instead of waiting out its sleep.
        self._stop_event.set()
        heartbeat = self._heartbeat_thread
        if (heartbeat is not None
                and heartbeat.is_alive()
                and heartbeat is not threading.current_thread()):
            # Bounded wait so a stuck broadcast cannot hang shutdown.
            heartbeat.join(timeout=_HEARTBEAT_JOIN_TIMEOUT_SEC)
        self._heartbeat_thread = None

        self.bus.stop()
        logger.info(f"聚合协议停止: {self.node_info.node_id}")

    # ============================================================
    # Node management
    # ============================================================

    def add_remote_node(self, node_id: str, name: str = "",
                        role: NodeRole = NodeRole.QUEEN,
                        address: str = "",
                        capabilities: Optional[List[str]] = None,
                        compute_score: float = 1.0,
                        permission_level: PermissionLevel = PermissionLevel.QUEEN):
        """Register a remote node in the registry and the permission manager.

        The node's address is registered on the bus only when ``address`` is
        non-empty. ``name`` falls back to ``node_id`` and ``capabilities`` to
        ``["general"]``.
        """
        node = NodeInfo(
            node_id=node_id,
            name=name or node_id,
            role=role,
            address=address,
            capability=NodeCapability(
                specializations=capabilities or ["general"],
                compute_score=compute_score,
            ),
            permission_level=permission_level,
        )
        self.registry.register(node)
        if address:
            self.bus.register_node_address(node_id, address)
        self.permission.set_node_permission(node_id, permission_level)

    def add_local_node(self, protocol: 'AggregationProtocol'):
        """Register another protocol instance living in the same process."""
        self.registry.register(protocol.node_info)
        self._local_transport.register_bus(protocol.node_info.node_id, protocol.bus)
        self.permission.set_node_permission(
            protocol.node_info.node_id, protocol.node_info.permission_level
        )

    def remove_node(self, node_id: str):
        """Remove a node from the registry and drop its bus address."""
        self.registry.unregister(node_id)
        self.bus.unregister_node_address(node_id)

    # ============================================================
    # Aggregation tasks
    # ============================================================

    def submit_task(
        self,
        query: str,
        required_capabilities: Optional[List[str]] = None,
        min_nodes: int = 2,
        max_nodes: int = 5,
        strategy: AggregationStrategy = AggregationStrategy.ADAPTIVE_MIX,
        priority: int = 5,
        timeout_sec: float = 60.0,
    ) -> Dict:
        """
        Submit an aggregation task.

        Flow:
        1. Create the task.
        2. Permission check (how many nodes may be called).
        3. Form a temporary server (task force).
        4. Dispatch the task to each member node.
        5. Aggregate the results.
        6. Dissolve the temporary server.

        Returns:
            On success: ``{"task_id", "taskforce_id", "status": "completed",
            "result", "members", "strategy"}`` where ``result`` is whatever
            the scheduler's ``aggregate_results`` returned.
            On refusal or failure: ``{"task_id", "status", "error"}`` with
            status ``"denied"`` (no node may be called) or ``"failed"`` (no
            task force could be formed).
        """
        self._stats["tasks_submitted"] += 1

        # 1. Create the task
        task_id = self._gen_id("task")
        task = AggregationTask(
            task_id=task_id,
            query=query,
            requester=self.node_info.node_id,
            required_capabilities=required_capabilities or [],
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            strategy=strategy,
            priority=priority,
            timeout_sec=timeout_sec,
        )

        # 2. Permission check: cap the node count at what this node may call
        max_allowed = self.permission.get_max_callable_count(self.node_info.node_id)
        task.max_nodes = min(task.max_nodes, max_allowed)

        # Overmind permission: may always call at least min_nodes (and >= 1)
        is_overmind = self.node_info.permission_level == PermissionLevel.OVERMIND
        if is_overmind:
            task.max_nodes = max(task.max_nodes, max(min_nodes, 1))

        if task.max_nodes < 1:
            return {
                "task_id": task_id,
                "status": "denied",
                "error": "权限不足,无法调用其他节点",
            }

        # 3. Form the temporary server
        tf = self.scheduler.create_taskforce(task)
        if not tf:
            return {
                "task_id": task_id,
                "status": "failed",
                "error": "无法组建临时服务器,可用节点不足",
            }

        # 4. Dispatch the task to each member
        node_results = {}
        for member_id in tf.members:
            # Per-member permission check (the overmind passes automatically)
            if (not is_overmind
                    and not self.permission.can_call_node(self.node_info.node_id, member_id)):
                continue

            # Local simulation: produce a placeholder result.
            # In a real deployment the task would be sent over the message
            # bus and the reply awaited.
            node = self.registry.get_node(member_id)
            node_results[member_id] = {
                "response": f"[来自{node.name if node else member_id}针对\"{query[:20]}\"的回答]",
                "confidence": 0.7 + (node.capability.compute_score * 0.05 if node else 0),
            }

        # 5. Aggregate the results
        # NOTE(review): if aggregate_results raises, the task force is never
        # completed here — confirm whether the scheduler cleans it up.
        final_result = self.scheduler.aggregate_results(tf.taskforce_id, node_results)

        # 6. Dissolve the temporary server
        self.scheduler.complete_taskforce(tf.taskforce_id, final_result)

        self._stats["tasks_completed"] += 1

        return {
            "task_id": task_id,
            "taskforce_id": tf.taskforce_id,
            "status": "completed",
            "result": final_result,
            "members": tf.members,
            "strategy": strategy.value,
        }

    # ============================================================
    # Message handling
    # ============================================================

    def _setup_message_handlers(self):
        """Register this instance's handlers on the message bus."""
        self.bus.register_handler("discover", self._handle_discover)
        self.bus.register_handler("join_taskforce", self._handle_join)
        self.bus.register_handler("leave_taskforce", self._handle_leave)
        self.bus.register_handler("task_assign", self._handle_task_assign)
        self.bus.register_handler("task_result", self._handle_task_result)
        self.bus.register_handler("heartbeat", self._handle_heartbeat)

    def _handle_discover(self, msg: ProtocolMessage):
        """Handle a discovery message by registering the node it describes.

        Messages without a ``node_info`` payload are ignored.
        """
        node_data = msg.payload.get("node_info", {})
        if not node_data:
            return

        node_id = node_data.get("node_id", msg.sender)
        node = NodeInfo(
            node_id=node_id,
            # Same fallback as __init__ / add_remote_node: never store an
            # empty display name.
            name=node_data.get("name") or node_id,
            role=NodeRole(node_data.get("role", "queen")),
            address=node_data.get("address", ""),
            capability=NodeCapability(
                specializations=node_data.get("capabilities", ["general"]),
                compute_score=node_data.get("compute_score", 1.0),
            ),
        )
        self.registry.register(node)

    def _handle_join(self, msg: ProtocolMessage):
        """Handle a request to join a temporary server (task force).

        The task force id is recorded only when this node is listed in the
        payload's ``members`` and the id is not already recorded, so a
        repeated join message cannot leave a stale duplicate behind.
        """
        tf_id = msg.payload.get("taskforce_id", "")
        if not tf_id:
            return
        if self.node_info.node_id not in msg.payload.get("members", []):
            return
        if tf_id not in self.node_info.current_taskforces:
            self.node_info.current_taskforces.append(tf_id)

    def _handle_leave(self, msg: ProtocolMessage):
        """Handle leaving a temporary server: forget its id if recorded."""
        tf_id = msg.payload.get("taskforce_id", "")
        if tf_id in self.node_info.current_taskforces:
            self.node_info.current_taskforces.remove(tf_id)

    def _handle_task_assign(self, msg: ProtocolMessage):
        """Handle a task assignment (no-op by default).

        Subclasses may override this to implement actual task execution.
        """
        pass

    def _handle_task_result(self, msg: ProtocolMessage):
        """Handle a task result (no-op by default)."""
        pass

    def _handle_heartbeat(self, msg: ProtocolMessage):
        """Handle a heartbeat by refreshing the sender's entry in the registry."""
        self.registry.update_heartbeat(msg.sender)

    # ============================================================
    # Heartbeat
    # ============================================================

    def _heartbeat_loop(self):
        """Broadcast a heartbeat and check peers until the protocol stops.

        One cycle runs immediately; later cycles are spaced by the configured
        heartbeat interval. Errors in a cycle are logged and the loop
        continues.
        """
        # Bind this run's event: a later start() installs a new one, and this
        # thread must keep honouring the event that stop() set for it.
        stop_event = self._stop_event
        while self._running and not stop_event.is_set():
            try:
                # Broadcast our heartbeat
                self.bus.broadcast("heartbeat", {
                    "node_id": self.node_info.node_id,
                    "status": self.node_info.status.value,
                    "active_taskforces": len(self.node_info.current_taskforces),
                })

                # Check the heartbeats of the other nodes
                timeout_nodes = self.registry.check_heartbeats()
                if timeout_nodes:
                    logger.warning(f"心跳超时节点: {timeout_nodes}")

            except Exception as e:
                logger.error(f"心跳异常: {e}")

            # Interruptible pause: returns early as soon as stop() sets the event.
            stop_event.wait(self._heartbeat_interval)

    # ============================================================
    # Utilities
    # ============================================================

    def _gen_id(self, prefix: str) -> str:
        """Return an identifier of the form ``<prefix>_<8 hex chars>``.

        The hash seed combines the current time, the prefix, this node's id
        and a process-wide counter, so IDs minted in the same clock tick do
        not collide.
        """
        seed = f"{time.time()}{prefix}{self.node_info.node_id}{next(self._id_counter)}"
        return f"{prefix}_{hashlib.md5(seed.encode()).hexdigest()[:8]}"

    def get_status(self) -> Dict:
        """Return a status snapshot of the node and every component."""
        return {
            "node": self.node_info.to_dict(),
            "registry": self.registry.get_status(),
            "scheduler": self.scheduler.get_stats(),
            "permission": self.permission.get_stats(),
            "transport": self.bus.get_stats(),
            # Copy so callers cannot mutate the internal counters.
            "protocol_stats": dict(self._stats),
            "running": self._running,
        }

    def discover_nodes(self) -> List[Dict]:
        """Return the dict form of every registered node that reports itself available."""
        nodes = self.registry.get_all_nodes()
        return [n.to_dict() for n in nodes if n.is_available()]