Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群聚合协议 — 节点发现模块 | |
| 功能: | |
| - 节点注册与注销 | |
| - 心跳检测(判断节点在线/离线) | |
| - 能力查询(找到满足条件的节点) | |
| - 广播发现(局域网/已知节点广播) | |
| """ | |
| import hashlib | |
| import threading | |
| import time | |
| from collections import defaultdict | |
| from datetime import datetime, timedelta | |
| from typing import Callable, Dict, List, Optional | |
| from .types import ( | |
| NodeInfo, NodeCapability, NodeRole, NodeStatus, | |
| PermissionLevel, ProtocolMessage, | |
| ) | |
class NodeRegistry:
    """Registry of every node this process currently knows about.

    Comparable to the server list of a GPU cluster, except that it is
    distributed: each node keeps its own view of the swarm.

    - Nodes are indexed by capability (``capability.specializations``) and
      by role so discovery does not have to scan every node.
    - ``last_heartbeat`` timestamps are compared against
      ``heartbeat_timeout`` to mark silent nodes offline.
    - Discovery helpers pick suitable nodes for assembling a temporary
      server for a task.

    All state is guarded by a re-entrant lock. Event callbacks run while
    that lock is held, so a callback may call back into the registry from
    the same thread.
    """

    def __init__(self, self_node_id: str, heartbeat_timeout: float = 30.0):
        """Create an empty registry.

        Args:
            self_node_id: Id of the local node; it is never returned by the
                discovery methods and is skipped by ``check_heartbeats``.
            heartbeat_timeout: Seconds without a heartbeat after which a
                node is marked offline.
        """
        self.self_node_id = self_node_id
        self.heartbeat_timeout = heartbeat_timeout
        # Known nodes: node_id -> NodeInfo
        self._nodes: Dict[str, "NodeInfo"] = {}
        self._lock = threading.RLock()
        # Capability index: capability name -> set of node ids
        self._cap_index: Dict[str, set] = defaultdict(set)
        # Role index: role -> set of node ids
        self._role_index: Dict["NodeRole", set] = defaultdict(set)
        # Event callbacks: event name -> list of callables
        self._callbacks: Dict[str, List[Callable]] = defaultdict(list)

    # ============================================================
    # Internal helpers
    # ============================================================
    def _unindex(self, node_id: str) -> None:
        """Remove ``node_id`` from both indexes and drop emptied entries.

        The indexes are scanned instead of trusting the node's current
        attributes, so entries are removed even if the node's capabilities
        or role changed after it was indexed. Caller must hold the lock.
        """
        for index in (self._cap_index, self._role_index):
            stale_keys = [key for key, ids in index.items() if node_id in ids]
            for key in stale_keys:
                ids = index[key]
                ids.discard(node_id)
                if not ids:
                    # Keep the index free of empty sets so status counts
                    # only reflect capabilities/roles that have nodes.
                    del index[key]

    def _excluded_ids(self, exclude: Optional[List[str]]) -> set:
        """Return the ids discovery must skip: ``exclude`` plus the local node."""
        skipped = set(exclude) if exclude else set()
        skipped.add(self.self_node_id)
        return skipped

    # ============================================================
    # Node management
    # ============================================================
    def register(self, node: "NodeInfo") -> bool:
        """Register ``node``, or replace an existing node with the same id.

        Sets ``node.last_heartbeat`` to the current time, rebuilds the
        node's index entries and fires the ``"node_joined"`` event.

        Returns:
            Always ``True``.
        """
        with self._lock:
            # Clear entries left by an earlier registration of this id so a
            # node that changed capabilities or role is not found under the
            # old ones.
            self._unindex(node.node_id)
            self._nodes[node.node_id] = node
            node.last_heartbeat = datetime.now()
            for cap in node.capability.specializations:
                self._cap_index[cap].add(node.node_id)
            self._role_index[node.role].add(node.node_id)
            self._fire("node_joined", node)
        return True

    def unregister(self, node_id: str) -> bool:
        """Remove a node and fire the ``"node_left"`` event.

        Returns:
            ``True`` if the node was known and removed, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.pop(node_id, None)
            if not node:
                return False
            self._unindex(node_id)
            self._fire("node_left", node)
        return True

    def update_heartbeat(self, node_id: str) -> bool:
        """Record a heartbeat for ``node_id`` and mark the node online.

        Returns:
            ``True`` if the node is known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if not node:
                return False
            node.last_heartbeat = datetime.now()
            node.status = NodeStatus.ONLINE
            return True

    def get_node(self, node_id: str) -> Optional["NodeInfo"]:
        """Return the node registered under ``node_id``, or ``None``."""
        with self._lock:
            return self._nodes.get(node_id)

    def get_all_nodes(self) -> List["NodeInfo"]:
        """Return a new list containing every known node."""
        with self._lock:
            return list(self._nodes.values())

    # ============================================================
    # Node discovery - find nodes that satisfy given criteria
    # ============================================================
    def discover_by_capability(self, capability: str,
                               min_compute: float = 0.0,
                               exclude: Optional[List[str]] = None) -> List["NodeInfo"]:
        """Find available nodes that advertise ``capability``.

        Args:
            capability: Capability name to look up in the index.
            min_compute: Minimum ``capability.compute_score`` required.
            exclude: Node ids to leave out; the local node is always
                left out.

        Returns:
            Matching nodes sorted by compute score, highest first.
        """
        skipped = self._excluded_ids(exclude)
        with self._lock:
            found = []
            for nid in self._cap_index.get(capability, ()):
                if nid in skipped:
                    continue
                node = self._nodes.get(nid)
                if (node and node.is_available()
                        and node.capability.compute_score >= min_compute):
                    found.append(node)
        found.sort(key=lambda n: n.capability.compute_score, reverse=True)
        return found

    def discover_by_role(self, role: "NodeRole",
                         exclude: Optional[List[str]] = None) -> List["NodeInfo"]:
        """Find available nodes registered with ``role``.

        Args:
            role: Role to look up in the index.
            exclude: Node ids to leave out; the local node is always
                left out.

        Returns:
            Matching nodes, in no particular order.
        """
        skipped = self._excluded_ids(exclude)
        with self._lock:
            found = []
            for nid in self._role_index.get(role, ()):
                if nid in skipped:
                    continue
                node = self._nodes.get(nid)
                if node and node.is_available():
                    found.append(node)
        return found

    def discover_for_task(self, required_caps: List[str],
                          min_nodes: int = 2,
                          max_nodes: int = 5,
                          exclude: Optional[List[str]] = None) -> List["NodeInfo"]:
        """Pick nodes for a task, best capability match first.

        Every node that can accept a task is scored by the fraction of
        ``required_caps`` found in its specializations. Nodes matching at
        least one required capability (or all nodes, when nothing is
        required) are ranked by that score and then by compute score. If
        fewer than ``min_nodes`` qualify, the selection is topped up with
        any other node that can accept a task, in registry order.

        Args:
            required_caps: Capability names the task needs; empty or
                ``None`` means no special requirement.
            min_nodes: Number of nodes to aim for when topping up.
            max_nodes: Hard upper bound on the result size; it takes
                precedence over ``min_nodes``.
            exclude: Node ids to leave out; the local node is always
                left out.

        Returns:
            At most ``max_nodes`` nodes. The list may be shorter than
            ``min_nodes`` when not enough nodes can accept a task.
        """
        required_caps = required_caps or []
        skipped = self._excluded_ids(exclude)
        with self._lock:
            eligible = [
                node for node in self._nodes.values()
                if node.node_id not in skipped and node.can_accept_task()
            ]

            # Rank candidates by capability match ratio.
            ranked = []
            for node in eligible:
                specs = node.capability.specializations
                matched = sum(1 for cap in required_caps if cap in specs)
                if matched > 0 or not required_caps:
                    score = matched / max(len(required_caps), 1)
                    ranked.append((score, node))
            # Best match first; compute score breaks ties.
            ranked.sort(key=lambda item: (item[0], item[1].capability.compute_score),
                        reverse=True)
            selected = [node for _, node in ranked[:max_nodes]]

            # Too few matches: relax the capability requirement.
            if len(selected) < min_nodes:
                chosen_ids = {node.node_id for node in selected}
                for node in eligible:
                    if node.node_id in chosen_ids:
                        continue
                    selected.append(node)
                    chosen_ids.add(node.node_id)
                    if len(selected) >= min_nodes:
                        break
        return selected[:max_nodes]

    # ============================================================
    # Heartbeat checking
    # ============================================================
    def check_heartbeats(self) -> List[str]:
        """Mark nodes whose heartbeat is too old as offline.

        A node with no recorded heartbeat is also marked offline, but it is
        not included in the returned list. The local node is skipped.

        Returns:
            Ids of nodes whose last heartbeat is older than
            ``heartbeat_timeout`` seconds.
        """
        timed_out = []
        with self._lock:
            now = datetime.now()
            for node_id, node in self._nodes.items():
                if node_id == self.self_node_id:
                    continue
                if not node.last_heartbeat:
                    # No heartbeat recorded at all.
                    node.status = NodeStatus.OFFLINE
                    continue
                elapsed = (now - node.last_heartbeat).total_seconds()
                if elapsed > self.heartbeat_timeout:
                    node.status = NodeStatus.OFFLINE
                    timed_out.append(node_id)
        return timed_out

    def get_online_count(self) -> int:
        """Return how many known nodes report themselves available."""
        with self._lock:
            return sum(1 for node in self._nodes.values() if node.is_available())

    # ============================================================
    # Event callbacks
    # ============================================================
    def on(self, event: str, callback: Callable):
        """Register ``callback`` to be called with the event data for ``event``."""
        self._callbacks[event].append(callback)

    def _fire(self, event: str, data=None):
        """Invoke every callback registered for ``event`` with ``data``.

        Dispatch is best-effort: a failing callback is logged and the
        remaining callbacks still run.
        """
        import logging

        # Iterate over a snapshot so callbacks registered during dispatch
        # are not invoked in the same round.
        for callback in tuple(self._callbacks.get(event, ())):
            try:
                callback(data)
            except Exception:
                logging.getLogger(__name__).exception(
                    "callback for event %r failed", event
                )

    # ============================================================
    # Status information
    # ============================================================
    def get_status(self) -> Dict:
        """Return a summary of the registry.

        Returns:
            A dict with the local node id, total/online/offline node
            counts, the number of indexed capabilities and a mapping of
            role value to node count. Capabilities and roles that no
            registered node has are not counted.
        """
        with self._lock:
            total = len(self._nodes)
            online = self.get_online_count()
            return {
                "self_node_id": self.self_node_id,
                "total_nodes": total,
                "online_nodes": online,
                "offline_nodes": total - online,
                "capabilities_indexed": len(self._cap_index),
                "roles": {role.value: len(ids) for role, ids in self._role_index.items()},
            }