# lk080424's picture
# Upload core/aggregation_protocol/discovery.py with huggingface_hub
# 4523329 verified
#!/usr/bin/env python3
"""
虫群聚合协议 — 节点发现模块
功能:
- 节点注册与注销
- 心跳检测(判断节点在线/离线)
- 能力查询(找到满足条件的节点)
- 广播发现(局域网/已知节点广播)
"""
import hashlib
import threading
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional
from .types import (
NodeInfo, NodeCapability, NodeRole, NodeStatus,
PermissionLevel, ProtocolMessage,
)
class NodeRegistry:
    """Node registry: keeps the information of every node known to this process.

    Comparable to the server list of a GPU cluster, but distributed:

    - every node maintains its own view of the other nodes;
    - a heartbeat mechanism is used to detect whether a node is alive;
    - capability matching picks suitable nodes to assemble a temporary server.

    Nodes are expected to expose ``node_id``, ``role``, ``status``,
    ``last_heartbeat``, ``capability.specializations``,
    ``capability.compute_score``, ``is_available()`` and
    ``can_accept_task()``; these are the only members this class touches.

    Every public method takes ``self._lock`` (a re-entrant lock).  Event
    callbacks for ``node_joined`` / ``node_left`` are invoked while that lock
    is held by the calling thread.
    """

    def __init__(self, self_node_id: str, heartbeat_timeout: float = 30.0):
        """Create an empty registry.

        Args:
            self_node_id: Id of the local node.  It is skipped by every
                ``discover_*`` method and by ``check_heartbeats``.
            heartbeat_timeout: Seconds without a heartbeat after which
                ``check_heartbeats`` marks a node offline.
        """
        self.self_node_id = self_node_id
        self.heartbeat_timeout = heartbeat_timeout
        # Known nodes: node_id -> NodeInfo
        self._nodes: Dict[str, "NodeInfo"] = {}
        self._lock = threading.RLock()
        # Capability index: capability name -> set of node_ids
        self._cap_index: Dict[str, set] = defaultdict(set)
        # Role index: role -> set of node_ids
        self._role_index: Dict["NodeRole", set] = defaultdict(set)
        # Event callbacks: event name -> list of callables
        self._callbacks: Dict[str, List[Callable]] = defaultdict(list)

    # ============================================================
    # Node management
    # ============================================================
    def register(self, node: "NodeInfo") -> bool:
        """Register ``node`` (or replace the entry with the same ``node_id``).

        The node's ``last_heartbeat`` is set to the current time, the
        capability and role indexes are updated and the ``node_joined`` event
        is fired.  When the id was already registered, its previous index
        entries are removed first so that capabilities or a role the node no
        longer has cannot be matched any more.

        Returns:
            Always ``True``.
        """
        with self._lock:
            if node.node_id in self._nodes:
                # Re-registration: forget whatever was indexed for this id.
                self._drop_from_indexes(node.node_id)
            # Plain assignment keeps the node's original position in the dict.
            self._nodes[node.node_id] = node
            node.last_heartbeat = datetime.now()
            for cap in node.capability.specializations:
                self._cap_index[cap].add(node.node_id)
            self._role_index[node.role].add(node.node_id)
            self._fire("node_joined", node)
            return True

    def unregister(self, node_id: str) -> bool:
        """Remove a node and fire ``node_left``.

        Returns:
            ``True`` when the node was known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.pop(node_id, None)
            if node is None:
                return False
            self._drop_from_indexes(node_id)
            self._fire("node_left", node)
            return True

    def _drop_from_indexes(self, node_id: str) -> None:
        """Remove ``node_id`` from both indexes and delete emptied entries.

        The indexes are scanned instead of trusting the node's current
        ``specializations`` / ``role``, so entries are cleaned up even if the
        node object was modified after it had been indexed.  Callers must
        hold ``self._lock``.
        """
        for index in (self._cap_index, self._role_index):
            stale_keys = [key for key, ids in index.items() if node_id in ids]
            for key in stale_keys:
                ids = index[key]
                ids.discard(node_id)
                if not ids:
                    # Keep status counters honest: no empty buckets.
                    del index[key]

    def update_heartbeat(self, node_id: str) -> bool:
        """Record a heartbeat for ``node_id`` and mark the node online.

        Returns:
            ``True`` when the node is known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if node is None:
                return False
            node.last_heartbeat = datetime.now()
            node.status = NodeStatus.ONLINE
            return True

    def get_node(self, node_id: str) -> Optional["NodeInfo"]:
        """Return the stored node object for ``node_id`` or ``None``."""
        with self._lock:
            return self._nodes.get(node_id)

    def get_all_nodes(self) -> List["NodeInfo"]:
        """Return a new list holding every known node."""
        with self._lock:
            return list(self._nodes.values())

    # ============================================================
    # Node discovery: find nodes that satisfy a condition
    # ============================================================
    def _available_from(self, candidate_ids,
                        exclude: Optional[List[str]]) -> List["NodeInfo"]:
        """Resolve ids to available nodes, skipping excluded ids and self.

        Callers must hold ``self._lock``.
        """
        skipped = set(exclude or ())
        skipped.add(self.self_node_id)
        found = []
        for nid in candidate_ids:
            if nid in skipped:
                continue
            node = self._nodes.get(nid)
            if node is not None and node.is_available():
                found.append(node)
        return found

    def discover_by_capability(self, capability: str,
                               min_compute: float = 0.0,
                               exclude: Optional[List[str]] = None) -> List["NodeInfo"]:
        """Find available nodes that advertise ``capability``.

        Args:
            capability: Specialization name to look up in the index.
            min_compute: Minimum ``capability.compute_score`` (inclusive).
            exclude: Node ids to leave out; the local node is always left out.

        Returns:
            Matching nodes, highest ``compute_score`` first.
        """
        with self._lock:
            candidates = self._cap_index.get(capability, ())
            results = [
                node for node in self._available_from(candidates, exclude)
                if node.capability.compute_score >= min_compute
            ]
            results.sort(key=lambda n: n.capability.compute_score, reverse=True)
            return results

    def discover_by_role(self, role: "NodeRole",
                         exclude: Optional[List[str]] = None) -> List["NodeInfo"]:
        """Find available nodes registered with ``role``.

        Args:
            role: Role to look up in the index.
            exclude: Node ids to leave out; the local node is always left out.

        Returns:
            Matching nodes in no particular order.
        """
        with self._lock:
            return self._available_from(self._role_index.get(role, ()), exclude)

    def discover_for_task(self, required_caps: List[str],
                          min_nodes: int = 2,
                          max_nodes: int = 5,
                          exclude: Optional[List[str]] = None) -> List["NodeInfo"]:
        """Pick nodes for a task.

        Nodes that can accept a task are ranked by the fraction of
        ``required_caps`` they advertise and then by ``compute_score``; nodes
        matching none of the required capabilities are left out of that
        ranking (unless no capability is required).  If fewer than
        ``min_nodes`` were ranked, any other node that can accept a task is
        appended, in registration order, until ``min_nodes`` is reached.

        Args:
            required_caps: Capability names the task needs; empty or ``None``
                means no special requirement.
            min_nodes: Size below which the capability filter is relaxed.
            max_nodes: Hard upper bound on the result size; values below zero
                are treated as zero.
            exclude: Node ids to leave out; the local node is always left out.

        Returns:
            At most ``max_nodes`` nodes, best match first.  The list may hold
            fewer than ``min_nodes`` nodes when not enough are eligible.
        """
        with self._lock:
            wanted = list(required_caps or ())
            skipped = set(exclude or ())
            skipped.add(self.self_node_id)
            limit = max(max_nodes, 0)

            eligible = [
                node for node in self._nodes.values()
                if node.can_accept_task() and node.node_id not in skipped
            ]

            # Rank by (match ratio, compute score); keep at least one match
            # unless the task has no special requirement.
            ranked = []
            for node in eligible:
                matched = sum(1 for cap in wanted
                              if cap in node.capability.specializations)
                if matched > 0 or not wanted:
                    ratio = matched / max(len(wanted), 1)
                    ranked.append((ratio, node.capability.compute_score, node))
            ranked.sort(key=lambda item: (item[0], item[1]), reverse=True)

            selected = [node for _, _, node in ranked[:limit]]

            if len(selected) < min_nodes:
                # Relax the capability filter: take any node able to work.
                chosen_ids = {node.node_id for node in selected}
                for node in eligible:
                    if len(selected) >= min_nodes:
                        break
                    if node.node_id not in chosen_ids:
                        selected.append(node)
                        chosen_ids.add(node.node_id)

            return selected[:limit]

    # ============================================================
    # Heartbeat monitoring
    # ============================================================
    def check_heartbeats(self) -> List[str]:
        """Mark nodes whose heartbeat is overdue as offline.

        The local node is skipped.  A node without any heartbeat record is
        marked offline as well, but it is not included in the returned list.

        Returns:
            Ids of the nodes whose last heartbeat is older than
            ``heartbeat_timeout`` seconds.
        """
        now = datetime.now()
        timeout_ids = []
        with self._lock:
            for node_id, node in list(self._nodes.items()):
                if node_id == self.self_node_id:
                    continue
                if not node.last_heartbeat:
                    # No heartbeat record at all.
                    node.status = NodeStatus.OFFLINE
                    continue
                elapsed = (now - node.last_heartbeat).total_seconds()
                if elapsed > self.heartbeat_timeout:
                    node.status = NodeStatus.OFFLINE
                    timeout_ids.append(node_id)
        return timeout_ids

    def get_online_count(self) -> int:
        """Return how many known nodes report ``is_available()``."""
        with self._lock:
            return sum(1 for n in self._nodes.values() if n.is_available())

    # ============================================================
    # Event callbacks
    # ============================================================
    def on(self, event: str, callback: Callable):
        """Subscribe ``callback`` to ``event``.

        This class fires ``node_joined`` and ``node_left``; the callback
        receives the affected node as its only argument.
        """
        with self._lock:
            self._callbacks[event].append(callback)

    def _fire(self, event: str, data=None):
        """Invoke every callback registered for ``event`` with ``data``.

        A failing callback is logged and does not prevent the remaining
        callbacks from running, nor does it propagate to the caller.
        """
        import logging  # function-scope import keeps the class self-contained

        with self._lock:
            # Snapshot so that subscribing during dispatch is safe.
            listeners = list(self._callbacks.get(event, ()))
        for cb in listeners:
            try:
                cb(data)
            except Exception:
                logging.getLogger(__name__).exception(
                    "NodeRegistry callback for event %r failed", event
                )

    # ============================================================
    # Status information
    # ============================================================
    def get_status(self) -> Dict:
        """Return a summary dict describing the registry.

        Keys: ``self_node_id``, ``total_nodes``, ``online_nodes``,
        ``offline_nodes``, ``capabilities_indexed`` (capabilities with at
        least one node) and ``roles`` (role value -> number of nodes).
        """
        with self._lock:
            online = self.get_online_count()
            total = len(self._nodes)
            return {
                "self_node_id": self.self_node_id,
                "total_nodes": total,
                "online_nodes": online,
                "offline_nodes": total - online,
                "capabilities_indexed": sum(
                    1 for ids in self._cap_index.values() if ids
                ),
                "roles": {
                    r.value: len(ids)
                    for r, ids in self._role_index.items() if ids
                },
            }