lk080424 committed on
Commit
3c72bc3
·
verified ·
1 Parent(s): 466030f

Upload core/aggregation_protocol/scheduler.py with huggingface_hub

Browse files
core/aggregation_protocol/scheduler.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 虫群聚合协议 — 任务调度器
4
+
5
+ 核心功能:
6
+ - 接收任务请求
7
+ - 组建临时服务器(TaskForce)
8
+ - 分配任务到各节点
9
+ - 聚合各节点结果
10
+ - 任务完成后解散临时服务器
11
+
12
+ 类比GPU集群:
13
+ - 任务调度器 = SLURM/调度系统
14
+ - 临时服务器 = 临时分配的GPU组
15
+ - 节点 = 单块GPU
16
+ """
17
+
18
+ import hashlib
19
+ import logging
20
+ import threading
21
+ import time
22
+ from collections import deque
23
+ from datetime import datetime
24
+ from typing import Callable, Dict, List, Optional
25
+
26
+ from .types import (
27
+ AggregationStrategy, AggregationTask, TaskForce, TaskForceStatus,
28
+ NodeInfo, NodeStatus, ProtocolMessage,
29
+ )
30
+ from .discovery import NodeRegistry
31
+ from .transport import MessageBus
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class TaskForceManager:
    """
    Manager for temporary servers ("task forces").

    Lifecycle of a task force: create -> assemble -> run -> complete -> disband.

    The manager asks the node registry for suitable nodes, records the
    resulting task force, announces membership changes on the message bus and
    merges the per-node results according to the task force's strategy.
    """

    def __init__(self, node_registry: "NodeRegistry", message_bus: "MessageBus"):
        """
        Store the collaborators and initialise empty bookkeeping.

        Args:
            node_registry: Registry used to discover and look up nodes.
            message_bus: Bus used to broadcast join/leave notifications.
        """
        self.registry = node_registry
        self.bus = message_bus

        # taskforce_id -> TaskForce. Finished task forces are not removed by
        # this class, so their results stay readable via get_taskforce().
        self._taskforces: Dict[str, "TaskForce"] = {}
        # Re-entrant because get_stats() calls get_active_taskforces() while
        # already holding the lock.
        self._lock = threading.RLock()

        # Task queue (no method of this class reads or writes it).
        self._task_queue = deque()

        # Counters reported by get_stats().
        self._stats = {
            "taskforces_created": 0,
            "taskforces_completed": 0,
            "taskforces_failed": 0,
            "tasks_processed": 0,
        }

    # ============================================================
    # Task force lifecycle
    # ============================================================

    def create_taskforce(self, task: "AggregationTask") -> Optional["TaskForce"]:
        """
        Create a task force for ``task``.

        Steps:
            1. Ask the registry for nodes matching the task's requirements
               (the requester itself is excluded).
            2. Build the TaskForce with the requester as coordinator.
            3. Enrol every candidate node and record the task force.
            4. Broadcast ``join_taskforce`` on the message bus.

        Args:
            task: The task to serve. Its ``taskforce_id`` and ``status`` are
                updated in place on success.

        Returns:
            The new TaskForce, or ``None`` when fewer than ``task.min_nodes``
            candidates were found.
        """
        # 1. Discover suitable nodes.
        candidates = self.registry.discover_for_task(
            required_caps=task.required_capabilities,
            min_nodes=task.min_nodes,
            max_nodes=task.max_nodes,
            exclude=[task.requester],
        )

        if len(candidates) < task.min_nodes:
            logger.warning(f"节点不足: 需要{task.min_nodes},找到{len(candidates)}")
            return None

        # 2. Build the task force.
        tf_id = self._gen_id("tf")
        tf = TaskForce(
            taskforce_id=tf_id,
            name=f"TaskForce-{tf_id}",
            coordinator=task.requester,
            strategy=task.strategy,
            task_description=task.query,
        )

        # 3. Enrol members and record the task force. The node bookkeeping is
        #    done under the same lock that complete/fail use when they undo
        #    it, so the two sides cannot interleave.
        with self._lock:
            for node in candidates:
                tf.add_member(node.node_id)
                node.current_taskforces.append(tf_id)

            self._taskforces[tf_id] = tf
            task.taskforce_id = tf_id
            task.status = "assigned"
            self._stats["taskforces_created"] += 1

        # 4. Notify the members through the message bus.
        self.bus.broadcast("join_taskforce", {
            "taskforce_id": tf_id,
            "coordinator": task.requester,
            "task": task.query,
            "members": tf.members,
            "strategy": tf.strategy.value,
        })

        logger.info(f"临时服务器创建: {tf_id}, 成员: {tf.members}")
        return tf

    def _release_members(self, tf: "TaskForce", tf_id: str) -> None:
        """Remove ``tf_id`` from the task list of every member node still known to the registry."""
        for member_id in tf.members:
            node = self.registry.get_node(member_id)
            if node and tf_id in node.current_taskforces:
                node.current_taskforces.remove(tf_id)

    def complete_taskforce(self, tf_id: str, result: Dict):
        """
        Mark a task force as completed, store its result and disband it.

        Unknown ids are ignored. After the bookkeeping is done a
        ``leave_taskforce`` message with status ``"completed"`` is broadcast.

        Args:
            tf_id: Id of the task force to complete.
            result: Final result, stored on the task force as ``results``.
        """
        with self._lock:
            tf = self._taskforces.get(tf_id)
            if not tf:
                return

            tf.status = TaskForceStatus.COMPLETED
            tf.completed_at = datetime.now()
            tf.results = result

            self._release_members(tf, tf_id)
            self._stats["taskforces_completed"] += 1

        # Tell the members the task force is disbanded.
        self.bus.broadcast("leave_taskforce", {
            "taskforce_id": tf_id,
            "status": "completed",
        })

        logger.info(f"临时服务器解散: {tf_id}")

    def fail_taskforce(self, tf_id: str, reason: str = ""):
        """
        Mark a task force as failed and disband it.

        Unknown ids are ignored. Mirrors ``complete_taskforce``: members are
        released, a ``leave_taskforce`` message with status ``"failed"`` is
        broadcast so nodes that were told to join are also told to leave, and
        the failure reason is logged instead of being dropped.

        Args:
            tf_id: Id of the task force that failed.
            reason: Optional human-readable failure reason.
        """
        with self._lock:
            tf = self._taskforces.get(tf_id)
            if not tf:
                return

            tf.status = TaskForceStatus.FAILED
            tf.completed_at = datetime.now()

            self._release_members(tf, tf_id)
            self._stats["taskforces_failed"] += 1

        self.bus.broadcast("leave_taskforce", {
            "taskforce_id": tf_id,
            "status": "failed",
            "reason": reason,
        })

        logger.warning(f"临时服务器失败: {tf_id}, 原因: {reason}")

    # ============================================================
    # Task scheduling
    # ============================================================

    def submit_task(self, task: "AggregationTask") -> Optional[str]:
        """
        Submit a task.

        Returns:
            The id of the task force created for the task, or ``None`` when
            no task force could be created.
        """
        tf = self.create_taskforce(task)
        return tf.taskforce_id if tf else None

    def get_taskforce(self, tf_id: str) -> Optional["TaskForce"]:
        """Return the task force registered under ``tf_id``, or ``None``."""
        with self._lock:
            return self._taskforces.get(tf_id)

    def get_active_taskforces(self) -> List["TaskForce"]:
        """Return every registered task force whose status is ACTIVE."""
        with self._lock:
            return [tf for tf in self._taskforces.values()
                    if tf.status == TaskForceStatus.ACTIVE]

    # ============================================================
    # Result aggregation
    # ============================================================

    def aggregate_results(self, tf_id: str,
                          node_results: Dict[str, Dict]) -> Dict:
        """
        Merge the per-node results of a task force.

        Strategies:
            - PARAMETER_AVERAGE: confidence-weighted average.
            - ENSEMBLE_VOTE: the most frequent response wins.
            - SEQUENTIAL_REFINE: later nodes replace earlier responses.
            - ADAPTIVE_MIX: confidence plus a per-node compute bonus; also
              used for any strategy not listed above.

        Args:
            tf_id: Id of the task force whose strategy decides the method.
            node_results: node_id -> result dict. Falsy results are ignored.

        Returns:
            The aggregated result, or ``{"error": ...}`` when the task force
            is unknown or no non-empty result was supplied.
        """
        tf = self.get_taskforce(tf_id)
        if not tf:
            return {"error": "临时服务器不存在"}

        strategy = tf.strategy
        results = {k: v for k, v in node_results.items() if v}

        if not results:
            return {"error": "无有效结果"}

        if strategy == AggregationStrategy.ENSEMBLE_VOTE:
            return self._vote_aggregate(results)
        elif strategy == AggregationStrategy.SEQUENTIAL_REFINE:
            return self._sequential_aggregate(results)
        elif strategy == AggregationStrategy.PARAMETER_AVERAGE:
            return self._parameter_average(results)
        else:  # ADAPTIVE_MIX
            return self._adaptive_mix(results)

    def _vote_aggregate(self, results: Dict[str, Dict]) -> Dict:
        """
        Vote aggregation: return the response given by the most nodes.

        Empty responses do not vote. ``confidence`` is the winner's share of
        the non-empty responses. Ties go to the response seen first.
        """
        from collections import Counter

        responses = [
            resp
            for resp in (result.get("response", "") for result in results.values())
            if resp
        ]

        if not responses:
            return {"response": "", "confidence": 0.0}

        best, votes = Counter(responses).most_common(1)[0]
        return {
            "response": best,
            "confidence": votes / len(responses),
            "method": "vote",
        }

    def _sequential_aggregate(self, results: Dict[str, Dict]) -> Dict:
        """
        Sequential refinement: each node's response supersedes the previous one.

        A node that reports no response (missing, ``None`` or empty) leaves
        the answer refined so far untouched instead of erasing it.
        ``confidence`` is the highest confidence reported by any node.
        """
        refined = ""
        confidence = 0.0

        for result in results.values():
            response = result.get("response")
            if response:
                refined = response
            confidence = max(confidence, result.get("confidence", 0.0))

        return {"response": refined, "confidence": confidence, "method": "sequential"}

    def _parameter_average(self, results: Dict[str, Dict]) -> Dict:
        """
        Confidence-weighted average.

        The returned response is the one with the highest confidence; the
        returned confidence is the average of the confidences weighted by
        themselves. A missing confidence counts as 0.5.
        """
        total_weight = 0.0
        weighted_confidence = 0.0
        best_response = ""
        # None (not 0.0) so the first result always seeds the selection, even
        # when every node reports a confidence of 0.
        best_conf = None

        for result in results.values():
            conf = result.get("confidence", 0.5)
            total_weight += conf
            weighted_confidence += conf * conf

            if best_conf is None or conf > best_conf:
                best_conf = conf
                best_response = result.get("response", "")

        # The floor on the divisor avoids dividing by zero when all weights are 0.
        avg_confidence = weighted_confidence / max(total_weight, 0.01)
        return {
            "response": best_response,
            "confidence": avg_confidence,
            "method": "parameter_average",
            "contributing_nodes": len(results),
        }

    def _adaptive_mix(self, results: Dict[str, Dict]) -> Dict:
        """
        Adaptive mix: rank responses by confidence plus a per-node bonus.

        Each node's score is its confidence (0.5 when missing) plus 10% of the
        node's ``capability.compute_score`` when the registry knows the node.
        The highest-scoring response is returned; ``confidence`` is the mean
        score capped at 1.0, and up to three responses are kept as
        alternatives.
        """
        total_score = 0.0
        best_response = ""
        # None (not 0.0) so the first result always seeds the selection, even
        # when every score is 0.
        best_score = None
        all_responses = []

        for node_id, result in results.items():
            conf = result.get("confidence", 0.5)
            node = self.registry.get_node(node_id)
            expertise_bonus = node.capability.compute_score * 0.1 if node else 0.0

            score = conf + expertise_bonus
            total_score += score
            all_responses.append(result.get("response", ""))

            if best_score is None or score > best_score:
                best_score = score
                best_response = result.get("response", "")

        avg_confidence = total_score / max(len(results), 1)
        return {
            "response": best_response,
            "confidence": min(avg_confidence, 1.0),
            "method": "adaptive_mix",
            "contributing_nodes": len(results),
            "all_responses": all_responses[:3],  # keep the first 3 as alternatives
        }

    # ============================================================
    # Utilities
    # ============================================================

    def _gen_id(self, prefix: str) -> str:
        """
        Return a new id of the form ``<prefix>_<8 hex chars>``.

        The suffix comes from a random UUID rather than a hash of the current
        time, so two ids requested within the same clock tick do not collide.
        """
        import uuid

        return f"{prefix}_{uuid.uuid4().hex[:8]}"

    def get_stats(self) -> Dict:
        """Return the lifecycle counters plus the active and total task force counts."""
        with self._lock:
            return {
                **self._stats,
                "active_taskforces": len(self.get_active_taskforces()),
                "total_taskforces": len(self._taskforces),
            }