#!/usr/bin/env python3
"""
Swarm aggregation protocol -- node discovery module.

Features:
- Node registration and deregistration
- Heartbeat checking (deciding whether a node is online or offline)
- Capability queries (finding nodes that satisfy given requirements)
- Broadcast discovery (LAN / known-node broadcast); listed in the original
  module header, but no broadcast code lives in ``NodeRegistry`` itself
"""

import hashlib
import logging
import threading
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional

from .types import (
    NodeInfo,
    NodeCapability,
    NodeRole,
    NodeStatus,
    PermissionLevel,
    ProtocolMessage,
)

logger = logging.getLogger(__name__)


class NodeRegistry:
    """
    Node registry -- keeps the information of every node known locally.

    Comparable to the server list of a GPU cluster, but distributed:
    - each node maintains its own view of the other nodes
    - a heartbeat mechanism is used to detect whether nodes are alive
    - nodes are matched by capability so that suitable ones can be grouped
      into a temporary server

    All access to the internal tables goes through one re-entrant lock.
    Event callbacks are invoked while that lock is held, so a callback may
    call back into the registry from the same thread.
    """

    def __init__(self, self_node_id: str, heartbeat_timeout: float = 30.0):
        """
        Create an empty registry.

        Args:
            self_node_id: Id of the local node; it is never returned by the
                ``discover_*`` methods and is skipped by ``check_heartbeats``.
            heartbeat_timeout: Seconds without a heartbeat after which
                ``check_heartbeats`` marks a node offline.
        """
        self.self_node_id = self_node_id
        self.heartbeat_timeout = heartbeat_timeout

        # Known nodes: node_id -> NodeInfo
        self._nodes: Dict[str, NodeInfo] = {}
        self._lock = threading.RLock()

        # Capability index: capability name -> set of node ids
        self._cap_index: Dict[str, set] = defaultdict(set)
        # Role index: role -> set of node ids
        self._role_index: Dict[NodeRole, set] = defaultdict(set)

        # Event callbacks: event name -> list of callables
        self._callbacks: Dict[str, List[Callable]] = defaultdict(list)

    # ============================================================
    # Node management
    # ============================================================

    def _index_node(self, node: NodeInfo) -> None:
        """Add *node* to the capability and role indexes (lock must be held)."""
        for cap in node.capability.specializations:
            self._cap_index[cap].add(node.node_id)
        self._role_index[node.role].add(node.node_id)

    def _unindex_node(self, node_id: str) -> None:
        """Remove *node_id* from every index bucket (lock must be held).

        The buckets are scanned by id instead of being derived from the
        node's current attributes, so entries are removed even when the
        node's capabilities or role changed after it was indexed. Buckets
        that become empty are dropped so the indexes do not grow forever.
        """
        for index in (self._cap_index, self._role_index):
            stale_keys = [key for key, ids in index.items() if node_id in ids]
            for key in stale_keys:
                ids = index[key]
                ids.discard(node_id)
                if not ids:
                    del index[key]

    def register(self, node: NodeInfo) -> bool:
        """
        Register a node, or replace the entry stored under the same id.

        The node's ``last_heartbeat`` is set to the current time and the
        ``"node_joined"`` event is fired with the node as payload.

        Returns:
            Always ``True``.
        """
        with self._lock:
            # Re-registration: drop index entries left by the previous record
            # so that removed capabilities / an old role do not linger.
            if node.node_id in self._nodes:
                self._unindex_node(node.node_id)

            self._nodes[node.node_id] = node
            node.last_heartbeat = datetime.now()
            self._index_node(node)

            self._fire("node_joined", node)
            return True

    def unregister(self, node_id: str) -> bool:
        """
        Remove a node from the registry and from all indexes.

        Fires the ``"node_left"`` event with the removed node as payload.

        Returns:
            ``True`` if the node was known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.pop(node_id, None)
            if node is None:
                return False

            self._unindex_node(node_id)

            self._fire("node_left", node)
            return True

    def update_heartbeat(self, node_id: str) -> bool:
        """
        Record a heartbeat for *node_id* and set its status to ONLINE.

        Returns:
            ``True`` if the node is known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if node is None:
                return False
            node.last_heartbeat = datetime.now()
            node.status = NodeStatus.ONLINE
            return True

    def get_node(self, node_id: str) -> Optional[NodeInfo]:
        """Return the stored ``NodeInfo`` for *node_id*, or ``None``."""
        with self._lock:
            return self._nodes.get(node_id)

    def get_all_nodes(self) -> List[NodeInfo]:
        """Return a new list holding every known node."""
        with self._lock:
            return list(self._nodes.values())

    # ============================================================
    # Node discovery -- find nodes that satisfy given conditions
    # ============================================================

    def _excluded_ids(self, exclude: Optional[List[str]]) -> set:
        """Build the set of ids to skip: caller's *exclude* plus the local node."""
        excluded = set(exclude or ())
        excluded.add(self.self_node_id)
        return excluded

    def discover_by_capability(self, capability: str,
                               min_compute: float = 0.0,
                               exclude: Optional[List[str]] = None) -> List[NodeInfo]:
        """
        Find available nodes that are indexed under *capability*.

        Args:
            capability: Capability name to look up in the capability index.
            min_compute: Minimum ``capability.compute_score`` a node must have.
            exclude: Node ids to skip; the local node is always skipped.

        Returns:
            Matching nodes sorted by compute score, highest first.
        """
        with self._lock:
            excluded = self._excluded_ids(exclude)
            results = []
            for nid in self._cap_index.get(capability, ()):
                if nid in excluded:
                    continue
                node = self._nodes.get(nid)
                if (node is not None
                        and node.is_available()
                        and node.capability.compute_score >= min_compute):
                    results.append(node)

            # Strongest nodes first.
            results.sort(key=lambda n: n.capability.compute_score, reverse=True)
            return results

    def discover_by_role(self, role: NodeRole,
                         exclude: Optional[List[str]] = None) -> List[NodeInfo]:
        """
        Find available nodes that are indexed under *role*.

        Args:
            role: Role to look up in the role index.
            exclude: Node ids to skip; the local node is always skipped.

        Returns:
            Matching nodes, in no particular order.
        """
        with self._lock:
            excluded = self._excluded_ids(exclude)
            results = []
            for nid in self._role_index.get(role, ()):
                if nid in excluded:
                    continue
                node = self._nodes.get(nid)
                if node is not None and node.is_available():
                    results.append(node)
            return results

    def discover_for_task(self, required_caps: List[str],
                          min_nodes: int = 2,
                          max_nodes: int = 5,
                          exclude: Optional[List[str]] = None) -> List[NodeInfo]:
        """
        Pick nodes for a task.

        Every node that can accept a task is scored by the fraction of
        *required_caps* found in its specializations. Nodes matching at least
        one capability (or all nodes when *required_caps* is empty) are
        ranked by that score, then by compute score, and the best
        *max_nodes* are taken. If fewer than *min_nodes* were found, any
        other node that can accept a task is added until *min_nodes* is
        reached.

        Args:
            required_caps: Capability names the task would like to have.
            min_nodes: Number of nodes to aim for when relaxing the match.
            max_nodes: Upper bound on the result size; values below zero are
                treated as zero.
            exclude: Node ids to skip; the local node is always skipped.

        Returns:
            At most *max_nodes* nodes. The list may be shorter than
            *min_nodes* when not enough nodes can accept a task.
        """
        with self._lock:
            # A negative bound would make the slices below count from the end.
            max_nodes = max(max_nodes, 0)
            excluded = self._excluded_ids(exclude)

            eligible = [
                node for node in self._nodes.values()
                if node.can_accept_task() and node.node_id not in excluded
            ]

            # Score each eligible node by how many required capabilities it has.
            scored_nodes = []
            cap_total = max(len(required_caps), 1)
            for node in eligible:
                specializations = node.capability.specializations
                matched = sum(1 for cap in required_caps if cap in specializations)
                # Keep nodes matching at least one capability, or every node
                # when the task has no specific requirement.
                if matched > 0 or not required_caps:
                    scored_nodes.append((node, matched / cap_total))

            # Best match first, compute score as tie-breaker.
            scored_nodes.sort(
                key=lambda item: (item[1], item[0].capability.compute_score),
                reverse=True,
            )
            selected = [node for node, _score in scored_nodes[:max_nodes]]

            # Not enough matches: relax the condition and take any eligible node.
            if len(selected) < min_nodes:
                selected_ids = {node.node_id for node in selected}
                for node in eligible:
                    if len(selected) >= min_nodes:
                        break
                    if node.node_id not in selected_ids:
                        selected.append(node)
                        selected_ids.add(node.node_id)

            return selected[:max_nodes]

    # ============================================================
    # Heartbeat checking
    # ============================================================

    def check_heartbeats(self) -> List[str]:
        """
        Mark nodes whose heartbeat is overdue as OFFLINE.

        The local node is skipped. A node without any heartbeat record is
        also marked OFFLINE, but it is not included in the returned list.

        Returns:
            Ids of the nodes whose last heartbeat is older than
            ``heartbeat_timeout`` seconds.
        """
        timeout_ids = []
        with self._lock:
            now = datetime.now()
            for node_id, node in list(self._nodes.items()):
                if node_id == self.self_node_id:
                    continue

                if node.last_heartbeat:
                    elapsed = (now - node.last_heartbeat).total_seconds()
                    if elapsed > self.heartbeat_timeout:
                        node.status = NodeStatus.OFFLINE
                        timeout_ids.append(node_id)
                else:
                    # No heartbeat has ever been recorded for this node.
                    node.status = NodeStatus.OFFLINE

        return timeout_ids

    def get_online_count(self) -> int:
        """Return the number of known nodes whose ``is_available()`` is true."""
        with self._lock:
            return sum(1 for n in self._nodes.values() if n.is_available())

    # ============================================================
    # Event callbacks
    # ============================================================

    def on(self, event: str, callback: Callable) -> None:
        """Register *callback* to be called with the payload of *event*."""
        with self._lock:
            self._callbacks[event].append(callback)

    def _fire(self, event: str, data=None) -> None:
        """
        Call every callback registered for *event* with *data*.

        A failing callback is logged and does not stop the remaining ones.
        """
        with self._lock:
            # Snapshot so a callback may register further callbacks safely.
            callbacks = list(self._callbacks.get(event, ()))

        for cb in callbacks:
            try:
                cb(data)
            except Exception:
                # Best effort: one broken listener must not break the registry.
                logger.exception("Callback %r for event %r failed", cb, event)

    # ============================================================
    # Status information
    # ============================================================

    def get_status(self) -> Dict:
        """
        Return a summary of the registry.

        ``online_nodes`` counts nodes whose ``is_available()`` is true;
        ``capabilities_indexed`` is the number of capability names in the
        index; ``roles`` maps each indexed role's ``value`` to its node count.
        """
        with self._lock:
            online = sum(1 for n in self._nodes.values() if n.is_available())
            total = len(self._nodes)
            return {
                "self_node_id": self.self_node_id,
                "total_nodes": total,
                "online_nodes": online,
                "offline_nodes": total - online,
                "capabilities_indexed": len(self._cap_index),
                "roles": {r.value: len(ids) for r, ids in self._role_index.items()},
            }