File size: 9,373 Bytes
4523329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#!/usr/bin/env python3
"""
虫群聚合协议 — 节点发现模块

功能:
- 节点注册与注销
- 心跳检测(判断节点在线/离线)
- 能力查询(找到满足条件的节点)
- 广播发现(局域网/已知节点广播)
"""

import hashlib
import logging
import threading
import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional

from .types import (
    NodeInfo, NodeCapability, NodeRole, NodeStatus,
    PermissionLevel, ProtocolMessage,
)


class NodeRegistry:
    """
    Node registry: keeps the information of every node known to this node.

    Comparable to the server list of a GPU cluster, except distributed:
    - every node maintains its own view of the nodes it has seen
    - a heartbeat mechanism detects whether nodes are still alive
    - nodes are matched by capability so suitable ones can be grouped
      into an ad-hoc server

    All access to the node table, the indexes and the callback table is
    serialized through one re-entrant lock. Event callbacks are invoked
    after the lock has been released.
    """

    def __init__(self, self_node_id: str, heartbeat_timeout: float = 30.0):
        """
        Create an empty registry.

        Args:
            self_node_id: ID of the local node; it is skipped by the
                discovery methods and by the heartbeat check.
            heartbeat_timeout: Seconds without a heartbeat after which
                ``check_heartbeats`` marks a node as offline.
        """
        self.self_node_id = self_node_id
        self.heartbeat_timeout = heartbeat_timeout

        # Known nodes: node_id -> NodeInfo
        self._nodes: Dict[str, NodeInfo] = {}
        self._lock = threading.RLock()

        # Capability index: capability name -> set of node_ids
        self._cap_index: Dict[str, set] = defaultdict(set)

        # Role index: role -> set of node_ids
        self._role_index: Dict[NodeRole, set] = defaultdict(set)

        # Event callbacks: event name -> list of callables
        self._callbacks: Dict[str, List[Callable]] = defaultdict(list)

    # ============================================================
    # Index maintenance (callers must hold self._lock)
    # ============================================================

    def _add_to_indexes(self, node: NodeInfo) -> None:
        """Record ``node`` under each of its specializations and its role."""
        for cap in node.capability.specializations:
            self._cap_index[cap].add(node.node_id)
        self._role_index[node.role].add(node.node_id)

    def _remove_from_indexes(self, node_id: str) -> None:
        """
        Remove ``node_id`` from every index entry that contains it.

        The indexes are scanned instead of trusting the node's current
        specializations/role, so entries are cleaned up even when the node
        object was mutated after it was indexed. Entries that become empty
        are deleted so the index sizes reported by ``get_status`` only
        count capabilities/roles that at least one node has.
        """
        for index in (self._cap_index, self._role_index):
            stale_keys = [key for key, ids in index.items() if node_id in ids]
            for key in stale_keys:
                ids = index[key]
                ids.discard(node_id)
                if not ids:
                    del index[key]

    # ============================================================
    # Node management
    # ============================================================

    def register(self, node: NodeInfo) -> bool:
        """
        Register a node, or replace the entry of an already known node.

        The node's ``last_heartbeat`` is set to the current time. When the
        node ID is already registered, its previous index entries are
        dropped first so changed capabilities or a changed role do not
        leave stale matches behind. Fires the ``"node_joined"`` event.

        Returns:
            Always ``True``.
        """
        with self._lock:
            if node.node_id in self._nodes:
                self._remove_from_indexes(node.node_id)
            self._nodes[node.node_id] = node
            node.last_heartbeat = datetime.now()
            self._add_to_indexes(node)

        self._fire("node_joined", node)
        return True

    def unregister(self, node_id: str) -> bool:
        """
        Remove a node and its index entries.

        Fires the ``"node_left"`` event with the removed node.

        Returns:
            ``True`` if the node was known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.pop(node_id, None)
            if node is None:
                return False
            self._remove_from_indexes(node_id)

        self._fire("node_left", node)
        return True

    def update_heartbeat(self, node_id: str) -> bool:
        """
        Record a heartbeat for ``node_id`` and mark the node as online.

        Returns:
            ``True`` if the node is known, ``False`` otherwise.
        """
        with self._lock:
            node = self._nodes.get(node_id)
            if node:
                node.last_heartbeat = datetime.now()
                node.status = NodeStatus.ONLINE
                return True
        return False

    def get_node(self, node_id: str) -> Optional[NodeInfo]:
        """Return the stored node for ``node_id``, or ``None`` if unknown."""
        with self._lock:
            return self._nodes.get(node_id)

    def get_all_nodes(self) -> List[NodeInfo]:
        """Return a new list holding every known node."""
        with self._lock:
            return list(self._nodes.values())

    # ============================================================
    # Node discovery: find nodes that satisfy given conditions
    # ============================================================

    def _available_nodes(self, node_ids, exclude) -> List[NodeInfo]:
        """
        Resolve ``node_ids`` to available nodes.

        Skips the local node, IDs listed in ``exclude``, IDs no longer in
        the node table and nodes whose ``is_available()`` is false.
        Callers must hold ``self._lock``.
        """
        skip = set(exclude or ())
        skip.add(self.self_node_id)

        found = []
        for nid in node_ids:
            if nid in skip:
                continue
            node = self._nodes.get(nid)
            if node is not None and node.is_available():
                found.append(node)
        return found

    def discover_by_capability(self, capability: str,
                               min_compute: float = 0.0,
                               exclude: Optional[List[str]] = None) -> List[NodeInfo]:
        """
        Find available nodes indexed under ``capability``.

        Args:
            capability: Specialization name to look up.
            min_compute: Minimum ``capability.compute_score`` (inclusive).
            exclude: Node IDs to leave out; the local node is always
                left out.

        Returns:
            Matching nodes sorted by compute score, highest first.
        """
        with self._lock:
            candidates = self._cap_index.get(capability, ())
            results = [
                node for node in self._available_nodes(candidates, exclude)
                if node.capability.compute_score >= min_compute
            ]
            # Strongest node first.
            results.sort(key=lambda n: n.capability.compute_score, reverse=True)
            return results

    def discover_by_role(self, role: NodeRole,
                         exclude: Optional[List[str]] = None) -> List[NodeInfo]:
        """
        Find available nodes indexed under ``role``.

        Args:
            role: Role to look up.
            exclude: Node IDs to leave out; the local node is always
                left out.

        Returns:
            Matching nodes, in no particular order.
        """
        with self._lock:
            candidates = self._role_index.get(role, ())
            return self._available_nodes(candidates, exclude)

    def discover_for_task(self, required_caps: List[str],
                          min_nodes: int = 2,
                          max_nodes: int = 5,
                          exclude: Optional[List[str]] = None) -> List[NodeInfo]:
        """
        Pick nodes for a task.

        Nodes that can accept a task are ranked by how many of the
        required capabilities they offer, then by compute score. A node
        needs to match at least one required capability to be ranked;
        when no capability is required every eligible node is ranked.
        If fewer than ``min_nodes`` are found this way, the requirement is
        relaxed and any other eligible node is added until ``min_nodes``
        is reached.

        Args:
            required_caps: Capability names the task wants.
            min_nodes: Number of nodes to aim for when relaxing the match.
            max_nodes: Upper bound on the number of returned nodes; values
                below zero are treated as zero.
            exclude: Node IDs to leave out; the local node is always
                left out.

        Returns:
            At most ``max_nodes`` nodes, best match first. May hold fewer
            than ``min_nodes`` nodes when not enough are eligible.
        """
        required = list(required_caps or ())
        limit = max(max_nodes, 0)
        skip = set(exclude or ())
        skip.add(self.self_node_id)

        with self._lock:
            eligible = [
                node for node in self._nodes.values()
                if node.node_id not in skip and node.can_accept_task()
            ]

            # Rank by (matched capabilities, compute score), best first.
            ranked = []
            for node in eligible:
                specs = node.capability.specializations
                matched = sum(1 for cap in required if cap in specs)
                if matched > 0 or not required:
                    ranked.append((matched, node.capability.compute_score, node))
            ranked.sort(key=lambda item: item[:2], reverse=True)

            selected = [node for _, _, node in ranked[:limit]]

            # Not enough matches: relax and take any other eligible node.
            if len(selected) < min_nodes:
                chosen = {node.node_id for node in selected}
                for node in eligible:
                    if len(selected) >= min_nodes:
                        break
                    if node.node_id not in chosen:
                        selected.append(node)
                        chosen.add(node.node_id)

            return selected[:limit]

    # ============================================================
    # Heartbeat checking
    # ============================================================

    def check_heartbeats(self) -> List[str]:
        """
        Mark nodes whose heartbeat is overdue as offline.

        The local node is skipped. A node without any heartbeat record is
        marked offline too, but is not included in the returned list.

        Returns:
            IDs of nodes whose last heartbeat is older than
            ``heartbeat_timeout`` seconds.
        """
        timeout_ids = []

        with self._lock:
            now = datetime.now()
            for node_id, node in self._nodes.items():
                if node_id == self.self_node_id:
                    continue
                if not node.last_heartbeat:
                    # No heartbeat record at all.
                    node.status = NodeStatus.OFFLINE
                    continue
                elapsed = (now - node.last_heartbeat).total_seconds()
                if elapsed > self.heartbeat_timeout:
                    node.status = NodeStatus.OFFLINE
                    timeout_ids.append(node_id)

        return timeout_ids

    def get_online_count(self) -> int:
        """Return how many known nodes report ``is_available()``."""
        with self._lock:
            return sum(1 for n in self._nodes.values() if n.is_available())

    # ============================================================
    # Event callbacks
    # ============================================================

    def on(self, event: str, callback: Callable) -> None:
        """
        Register ``callback`` for ``event``.

        This class fires ``"node_joined"`` and ``"node_left"``; each
        callback is called with the affected node as its only argument.
        """
        with self._lock:
            self._callbacks[event].append(callback)

    def _fire(self, event: str, data=None):
        """
        Invoke every callback registered for ``event`` with ``data``.

        Delivery is best effort: a callback that raises is logged and the
        remaining callbacks still run.
        """
        # Snapshot under the lock so callbacks run without holding it and
        # may themselves call back into the registry.
        with self._lock:
            callbacks = tuple(self._callbacks.get(event, ()))

        for cb in callbacks:
            try:
                cb(data)
            except Exception:
                logging.getLogger(__name__).exception(
                    "Callback %r for event %r raised", cb, event
                )

    # ============================================================
    # Status information
    # ============================================================

    def get_status(self) -> Dict:
        """
        Return a summary of the registry.

        Returns:
            Dict with the local node ID, total/online/offline node counts,
            the number of indexed capabilities and the node count per role
            (keyed by the role's ``value``).
        """
        with self._lock:
            total = len(self._nodes)
            online = self.get_online_count()
            return {
                "self_node_id": self.self_node_id,
                "total_nodes": total,
                "online_nodes": online,
                "offline_nodes": total - online,
                "capabilities_indexed": len(self._cap_index),
                "roles": {r.value: len(ids) for r, ids in self._role_index.items()},
            }