Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群聚合协议 — 主入口 | |
| AggregationProtocol 是整个聚合协议的统一入口,整合: | |
| - 节点发现(discovery) | |
| - 通信层(transport) | |
| - 任务调度(scheduler) | |
| - 权限管理(permission) | |
| 使用方式: | |
| protocol = AggregationProtocol( | |
| node_id="queen_zhangming", | |
| role=NodeRole.QUEEN, | |
| address="http://localhost:8080", | |
| ) | |
| protocol.start() | |
| # 提交聚合任务 | |
| result = protocol.submit_task("分析这段代码的性能瓶颈", ...) | |
| # 关闭 | |
| protocol.stop() | |
| """ | |
| import hashlib | |
| import logging | |
| import threading | |
| import time | |
| from datetime import datetime | |
| from typing import Callable, Dict, List, Optional | |
| from .types import ( | |
| NodeInfo, NodeCapability, NodeRole, NodeStatus, | |
| PermissionLevel, AggregationStrategy, AggregationTask, | |
| TaskForce, ProtocolMessage, | |
| ) | |
| from .discovery import NodeRegistry | |
| from .transport import MessageBus, LocalTransport | |
| from .scheduler import TaskForceManager | |
| from .permission import PermissionManager | |
| logger = logging.getLogger(__name__) | |
class AggregationProtocol:
    """Swarm aggregation protocol: the single entry point for one node.

    An instance wires together the four sub-systems of the protocol:
    node discovery (``NodeRegistry``), transport (``MessageBus`` plus an
    in-process ``LocalTransport``), scheduling (``TaskForceManager``) and
    permission control (``PermissionManager``).

    Design intent:
      - Like a GPU cluster, scattered devices are temporarily combined
        into a server (a "task force").
      - A device may take part in several temporary servers at once.
      - Permissions bound which nodes, and how many, may be called.
      - When a task finishes the temporary server is dissolved and its
        devices go back to the idle pool.
    """

    # Seconds between two heartbeat broadcasts.
    HEARTBEAT_INTERVAL_SEC = 10.0
    # Upper bound on how long stop() waits for the heartbeat thread to exit.
    HEARTBEAT_JOIN_TIMEOUT_SEC = 2.0

    def __init__(
        self,
        node_id: str,
        role: NodeRole = NodeRole.QUEEN,
        name: str = "",
        address: str = "",
        permission_level: PermissionLevel = PermissionLevel.QUEEN,
        capabilities: Optional[List[str]] = None,
        compute_score: float = 1.0,
        max_concurrent: int = 3,
        heartbeat_timeout: float = 30.0,
    ):
        """Build the local node description and all protocol components.

        Args:
            node_id: Identifier of this node.
            role: Role recorded in this node's ``NodeInfo``.
            name: Display name; falls back to ``node_id`` when empty.
            address: Address recorded for this node and registered on the bus.
            permission_level: Permission level of this node.
            capabilities: Specializations of this node; ``["general"]``
                when omitted or empty.
            compute_score: Compute score stored in the node capability.
            max_concurrent: ``max_concurrent`` stored in the node capability.
            heartbeat_timeout: Forwarded to ``NodeRegistry``.
        """
        # Description of this node itself.
        self.node_info = NodeInfo(
            node_id=node_id,
            name=name or node_id,
            role=role,
            status=NodeStatus.ONLINE,
            address=address,
            capability=NodeCapability(
                specializations=capabilities or ["general"],
                compute_score=compute_score,
                max_concurrent=max_concurrent,
            ),
            permission_level=permission_level,
        )

        # Core components.
        self.registry = NodeRegistry(node_id, heartbeat_timeout)
        self.bus = MessageBus(self.node_info)
        self.scheduler = TaskForceManager(self.registry, self.bus)
        self.permission = PermissionManager(permission_level)

        # Transport used for nodes living in the same process.
        self._local_transport = LocalTransport()

        # Run state. The event is replaced on every start() so that a
        # heartbeat thread from an earlier run can never be revived.
        self._running = False
        self._heartbeat_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        # Message handlers.
        self._setup_message_handlers()

        # Counters reported by get_status().
        self._stats = {
            "start_time": None,
            "tasks_submitted": 0,
            "tasks_completed": 0,
            "messages_sent": 0,
        }

    # ============================================================
    # Lifecycle
    # ============================================================
    def start(self) -> None:
        """Start the protocol; a no-op when it is already running.

        Registers this node with the registry, bus, local transport and
        permission manager, starts the bus and launches the heartbeat
        thread. If any of these steps raises, the running flag is rolled
        back so that a later ``start()`` can retry.
        """
        if self._running:
            return

        self._running = True
        try:
            me = self.node_info
            me.status = NodeStatus.ONLINE

            # Register this node everywhere.
            self.registry.register(me)
            self.bus.register_node_address(me.node_id, me.address)
            self._local_transport.register_bus(me.node_id, self.bus)
            self.permission.set_node_permission(me.node_id, me.permission_level)

            # Start the message bus.
            self.bus.start()

            # Fresh event per run: a thread left over from a previous run
            # still holds the old (set) event and therefore exits.
            self._stop_event = threading.Event()
            worker = threading.Thread(target=self._heartbeat_loop, daemon=True)
            worker.start()
            # Only publish the thread once it has really been started, so
            # stop() never joins an unstarted thread.
            self._heartbeat_thread = worker
        except Exception:
            self._running = False
            raise

        self._stats["start_time"] = datetime.now().isoformat()
        logger.info(f"聚合协议启动: {self.node_info.node_id} ({self.node_info.role.value})")

    def stop(self) -> None:
        """Stop the protocol.

        Marks this node OFFLINE, wakes the heartbeat thread and waits up
        to ``HEARTBEAT_JOIN_TIMEOUT_SEC`` for it to exit, then stops the
        message bus.
        """
        self._running = False
        self._stop_event.set()
        self.node_info.status = NodeStatus.OFFLINE

        worker = self._heartbeat_thread
        self._heartbeat_thread = None
        # Joining the current thread would raise RuntimeError, so skip the
        # join when stop() is invoked from the heartbeat thread itself.
        if worker is not None and worker is not threading.current_thread():
            worker.join(timeout=self.HEARTBEAT_JOIN_TIMEOUT_SEC)

        self.bus.stop()
        logger.info(f"聚合协议停止: {self.node_info.node_id}")

    # ============================================================
    # Node management
    # ============================================================
    def add_remote_node(self, node_id: str, name: str = "",
                        role: NodeRole = NodeRole.QUEEN,
                        address: str = "",
                        capabilities: Optional[List[str]] = None,
                        compute_score: float = 1.0,
                        permission_level: PermissionLevel = PermissionLevel.QUEEN) -> None:
        """Register a remote node.

        The node is added to the registry and the permission manager; its
        address is registered on the bus only when ``address`` is non-empty.
        ``name`` falls back to ``node_id`` and ``capabilities`` to
        ``["general"]`` when empty.
        """
        remote = NodeInfo(
            node_id=node_id,
            name=name or node_id,
            role=role,
            address=address,
            capability=NodeCapability(
                specializations=capabilities or ["general"],
                compute_score=compute_score,
            ),
            permission_level=permission_level,
        )
        self.registry.register(remote)
        if address:
            self.bus.register_node_address(node_id, address)
        self.permission.set_node_permission(node_id, permission_level)

    def add_local_node(self, protocol: 'AggregationProtocol') -> None:
        """Register another protocol instance living in the same process.

        Its ``node_info`` goes into this registry, its bus is attached to
        this instance's local transport, and its permission level is
        recorded in this permission manager.
        """
        peer = protocol.node_info
        self.registry.register(peer)
        self._local_transport.register_bus(peer.node_id, protocol.bus)
        self.permission.set_node_permission(peer.node_id, peer.permission_level)

    def remove_node(self, node_id: str) -> None:
        """Remove a node from the registry and drop its bus address."""
        self.registry.unregister(node_id)
        self.bus.unregister_node_address(node_id)

    # ============================================================
    # Aggregation tasks
    # ============================================================
    def submit_task(
        self,
        query: str,
        required_capabilities: Optional[List[str]] = None,
        min_nodes: int = 2,
        max_nodes: int = 5,
        strategy: AggregationStrategy = AggregationStrategy.ADAPTIVE_MIX,
        priority: int = 5,
        timeout_sec: float = 60.0,
    ) -> Dict:
        """Submit an aggregation task and run it to completion.

        Steps:
            1. Create the task.
            2. Permission check (how many nodes may be called).
            3. Assemble the temporary server (task force).
            4. Dispatch the task to every member.
            5. Aggregate the results.
            6. Dissolve the temporary server.

        Returns:
            On success a dict with ``task_id``, ``taskforce_id``,
            ``status`` ("completed"), ``result``, ``members`` and
            ``strategy``. When the permission check leaves fewer than one
            callable node, or no task force can be assembled, a dict with
            ``task_id``, ``status`` ("denied" / "failed") and ``error``.
        """
        self._stats["tasks_submitted"] += 1
        requester_id = self.node_info.node_id
        is_overmind = self.node_info.permission_level == PermissionLevel.OVERMIND

        # 1. Create the task.
        task_id = self._gen_id("task")
        task = AggregationTask(
            task_id=task_id,
            query=query,
            requester=requester_id,
            required_capabilities=required_capabilities or [],
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            strategy=strategy,
            priority=priority,
            timeout_sec=timeout_sec,
        )

        # 2. Permission check: cap the node count by what the requester
        #    is allowed to call.
        max_allowed = self.permission.get_max_callable_count(requester_id)
        task.max_nodes = min(task.max_nodes, max_allowed)
        # An OVERMIND requester may always call at least min_nodes nodes
        # (and never fewer than one).
        if is_overmind:
            task.max_nodes = max(task.max_nodes, max(min_nodes, 1))
        if task.max_nodes < 1:
            return {
                "task_id": task_id,
                "status": "denied",
                "error": "权限不足,无法调用其他节点",
            }

        # 3. Assemble the temporary server.
        tf = self.scheduler.create_taskforce(task)
        if not tf:
            return {
                "task_id": task_id,
                "status": "failed",
                "error": "无法组建临时服务器,可用节点不足",
            }

        # 4. Dispatch the task to every member.
        node_results = {}
        for member_id in tf.members:
            # Per-member permission check; OVERMIND passes automatically.
            if not is_overmind and not self.permission.can_call_node(requester_id, member_id):
                continue
            # Local simulation: produce a placeholder answer. In a real
            # deployment the task would be sent over the message bus and
            # the reply awaited.
            node = self.registry.get_node(member_id)
            display_name = node.name if node else member_id
            bonus = node.capability.compute_score * 0.05 if node else 0
            node_results[member_id] = {
                "response": f"[来自{display_name}针对\"{query[:20]}\"的回答]",
                "confidence": 0.7 + bonus,
            }

        # 5. Aggregate the results.
        final_result = self.scheduler.aggregate_results(tf.taskforce_id, node_results)

        # 6. Dissolve the temporary server.
        self.scheduler.complete_taskforce(tf.taskforce_id, final_result)
        self._stats["tasks_completed"] += 1

        return {
            "task_id": task_id,
            "taskforce_id": tf.taskforce_id,
            "status": "completed",
            "result": final_result,
            "members": tf.members,
            "strategy": strategy.value,
        }

    # ============================================================
    # Message handling
    # ============================================================
    def _setup_message_handlers(self) -> None:
        """Register one bus handler per protocol message type."""
        handlers = (
            ("discover", self._handle_discover),
            ("join_taskforce", self._handle_join),
            ("leave_taskforce", self._handle_leave),
            ("task_assign", self._handle_task_assign),
            ("task_result", self._handle_task_result),
            ("heartbeat", self._handle_heartbeat),
        )
        for message_type, handler in handlers:
            self.bus.register_handler(message_type, handler)

    def _handle_discover(self, msg: ProtocolMessage) -> None:
        """Register the node described by a ``discover`` message.

        Messages without a ``node_info`` payload are ignored, and so are
        messages whose ``role`` is not a valid ``NodeRole`` value: the
        payload comes from another node and must not crash the handler.
        """
        node_data = msg.payload.get("node_info", {})
        if not node_data:
            return

        raw_role = node_data.get("role", "queen")
        try:
            # Lookup by value; presumably NodeRole is an Enum, which raises
            # ValueError for a value it does not define.
            role = NodeRole(raw_role)
        except ValueError:
            logger.warning(
                "Ignoring discover message from %s: unknown role %r",
                msg.sender, raw_role,
            )
            return

        node_id = node_data.get("node_id", msg.sender)
        discovered = NodeInfo(
            node_id=node_id,
            # Same fallback as __init__ / add_remote_node: never store an
            # empty display name.
            name=node_data.get("name") or node_id,
            role=role,
            address=node_data.get("address", ""),
            capability=NodeCapability(
                specializations=node_data.get("capabilities", ["general"]),
                compute_score=node_data.get("compute_score", 1.0),
            ),
        )
        self.registry.register(discovered)

    def _handle_join(self, msg: ProtocolMessage) -> None:
        """Record membership when this node is listed in a join message.

        A taskforce id that is already recorded is not added again, so a
        duplicated join message cannot leave a stale entry behind after
        the matching leave (which removes a single occurrence).
        """
        tf_id = msg.payload.get("taskforce_id", "")
        if not tf_id:
            return
        if self.node_info.node_id not in msg.payload.get("members", []):
            return
        joined = self.node_info.current_taskforces
        if tf_id not in joined:
            joined.append(tf_id)

    def _handle_leave(self, msg: ProtocolMessage) -> None:
        """Drop the taskforce named in a leave message, if it is recorded."""
        tf_id = msg.payload.get("taskforce_id", "")
        joined = self.node_info.current_taskforces
        if tf_id in joined:
            joined.remove(tf_id)

    def _handle_task_assign(self, msg: ProtocolMessage) -> None:
        """Handle a task assignment; intentionally a no-op here.

        Subclasses may override this to implement real task execution.
        """

    def _handle_task_result(self, msg: ProtocolMessage) -> None:
        """Handle a task result; intentionally a no-op here."""

    def _handle_heartbeat(self, msg: ProtocolMessage) -> None:
        """Refresh the sender's heartbeat in the registry."""
        self.registry.update_heartbeat(msg.sender)

    # ============================================================
    # Heartbeat
    # ============================================================
    def _heartbeat_loop(self) -> None:
        """Broadcast a heartbeat and check peers until the protocol stops.

        Runs on the daemon thread created by ``start()``. Errors raised by
        one round are logged and the loop carries on. The pause between
        rounds is an interruptible wait, so ``stop()`` ends the loop
        promptly instead of after a full sleep interval.
        """
        # Bind the event of *this* run; start() installs a new one on
        # restart, which keeps an old thread from resuming.
        stop_event = self._stop_event
        while self._running and not stop_event.is_set():
            try:
                # Broadcast our own heartbeat.
                self.bus.broadcast("heartbeat", {
                    "node_id": self.node_info.node_id,
                    "status": self.node_info.status.value,
                    "active_taskforces": len(self.node_info.current_taskforces),
                })
                # Check the heartbeats of the other nodes.
                timeout_nodes = self.registry.check_heartbeats()
                if timeout_nodes:
                    logger.warning(f"心跳超时节点: {timeout_nodes}")
            except Exception as e:
                logger.error(f"心跳异常: {e}")
            stop_event.wait(self.HEARTBEAT_INTERVAL_SEC)

    # ============================================================
    # Utilities
    # ============================================================
    def _gen_id(self, prefix: str) -> str:
        """Return ``"<prefix>_<8 hex chars>"`` with a random suffix.

        The suffix is taken from a random UUID rather than from a hash of
        the wall clock, so two calls within the same clock tick do not
        produce the same identifier.
        """
        import uuid  # local import: the module header does not import uuid

        return f"{prefix}_{uuid.uuid4().hex[:8]}"

    def get_status(self) -> Dict:
        """Return a status snapshot of this node and all its components.

        ``protocol_stats`` is a copy, so callers cannot alter the internal
        counters through the returned dict.
        """
        return {
            "node": self.node_info.to_dict(),
            "registry": self.registry.get_status(),
            "scheduler": self.scheduler.get_stats(),
            "permission": self.permission.get_stats(),
            "transport": self.bus.get_stats(),
            "protocol_stats": dict(self._stats),
            "running": self._running,
        }

    def discover_nodes(self) -> List[Dict]:
        """Return the dict form of every registered node that is available."""
        return [
            node.to_dict()
            for node in self.registry.get_all_nodes()
            if node.is_available()
        ]