#!/usr/bin/env python3
"""
Swarm aggregation protocol — task scheduler.

Responsibilities:
- accept task requests
- form a temporary server (TaskForce) for each task
- dispatch the task to the member nodes
- aggregate the member nodes' results
- dissolve the temporary server once the task has finished

GPU-cluster analogy:
- task scheduler   = SLURM / cluster scheduler
- temporary server = a temporarily allocated group of GPUs
- node             = a single GPU
"""

import hashlib
import itertools
import logging
import threading
import time
from collections import Counter, deque
from datetime import datetime
from typing import Callable, Dict, List, Optional

from .types import (
    AggregationStrategy,
    AggregationTask,
    TaskForce,
    TaskForceStatus,
    NodeInfo,
    NodeStatus,
    ProtocolMessage,
)
from .discovery import NodeRegistry
from .transport import MessageBus

logger = logging.getLogger(__name__)


class TaskForceManager:
    """
    Manager for temporary servers (TaskForces).

    Lifecycle: create -> assemble -> run -> complete / fail -> dissolve.
    """

    def __init__(self, node_registry: NodeRegistry, message_bus: MessageBus):
        self.registry = node_registry
        self.bus = message_bus

        # Known temporary servers: taskforce_id -> TaskForce. Finished entries
        # are not removed, so their final status/results stay queryable.
        self._taskforces: Dict[str, TaskForce] = {}
        # Re-entrant because get_stats() calls get_active_taskforces() while
        # already holding the lock.
        self._lock = threading.RLock()

        # Task queue (not read by any method of this class).
        self._task_queue = deque()

        # Sequence number mixed into generated IDs so two IDs created within
        # the same time.time() tick still differ.
        self._id_seq = itertools.count()

        # Statistics. NOTE(review): "tasks_processed" is never incremented in
        # this module — confirm whether a caller is expected to update it.
        self._stats = {
            "taskforces_created": 0,
            "taskforces_completed": 0,
            "taskforces_failed": 0,
            "tasks_processed": 0,
        }

    # ============================================================
    # TaskForce lifecycle
    # ============================================================

    def create_taskforce(self, task: AggregationTask) -> Optional[TaskForce]:
        """
        Form a temporary server for *task*.

        Analogous to resource allocation on a GPU cluster:
        1. look at the capabilities the task requires,
        2. pick suitable nodes from the registry,
        3. assemble the TaskForce,
        4. notify the member nodes.

        Args:
            task: The task to schedule. On success its ``taskforce_id`` is set
                and its ``status`` becomes ``"assigned"``.

        Returns:
            The new TaskForce, or None when fewer than ``task.min_nodes``
            candidate nodes were found.
        """
        # 1. Discover suitable nodes; the requester itself is excluded.
        candidates = self.registry.discover_for_task(
            required_caps=task.required_capabilities,
            min_nodes=task.min_nodes,
            max_nodes=task.max_nodes,
            exclude=[task.requester],
        )

        if len(candidates) < task.min_nodes:
            logger.warning(f"节点不足: 需要{task.min_nodes},找到{len(candidates)}")
            return None

        # 2. Create the temporary server; the requester acts as coordinator.
        tf_id = self._gen_id("tf")
        tf = TaskForce(
            taskforce_id=tf_id,
            name=f"TaskForce-{tf_id}",
            coordinator=task.requester,
            strategy=task.strategy,
            task_description=task.query,
        )

        # 3. Add members and record the membership on each node.
        for node in candidates:
            tf.add_member(node.node_id)
            node.current_taskforces.append(tf_id)

        # 4. Register the TaskForce and mark the task as assigned.
        with self._lock:
            self._taskforces[tf_id] = tf
            task.taskforce_id = tf_id
            task.status = "assigned"
            self._stats["taskforces_created"] += 1

        # 5. Notify members over the message bus.
        self.bus.broadcast("join_taskforce", {
            "taskforce_id": tf_id,
            "coordinator": task.requester,
            "task": task.query,
            "members": tf.members,
            "strategy": tf.strategy.value,
        })

        logger.info(f"临时服务器创建: {tf_id}, 成员: {tf.members}")
        return tf

    def complete_taskforce(self, tf_id: str, result: Dict) -> None:
        """
        Mark a TaskForce as completed, store its result and dissolve it.

        Member nodes are detached from the TaskForce and told to leave through
        a ``leave_taskforce`` broadcast. Unknown IDs, and TaskForces that have
        already completed or failed, are ignored. This stops repeated calls
        from double-counting statistics or re-broadcasting.

        Args:
            tf_id: ID of the TaskForce to complete.
            result: Final (aggregated) result, stored on ``tf.results``.
        """
        with self._lock:
            tf = self._taskforces.get(tf_id)
            if tf is None or self._is_finished(tf):
                return

            tf.status = TaskForceStatus.COMPLETED
            tf.completed_at = datetime.now()
            tf.results = result
            self._release_members(tf_id, tf)
            self._stats["taskforces_completed"] += 1

        # Notify outside the lock so a slow bus cannot block other callers.
        self.bus.broadcast("leave_taskforce", {
            "taskforce_id": tf_id,
            "status": "completed",
        })

        logger.info(f"临时服务器解散: {tf_id}")

    def fail_taskforce(self, tf_id: str, reason: str = "") -> None:
        """
        Mark a TaskForce as failed and dissolve it.

        Members are detached and, as on completion, told to leave via a
        ``leave_taskforce`` broadcast. That broadcast carries
        ``status="failed"`` and the *reason*. Unknown IDs and already-finished
        TaskForces are ignored.

        Args:
            tf_id: ID of the TaskForce that failed.
            reason: Human-readable failure reason; it is logged and broadcast.
        """
        with self._lock:
            tf = self._taskforces.get(tf_id)
            if tf is None or self._is_finished(tf):
                return

            tf.status = TaskForceStatus.FAILED
            tf.completed_at = datetime.now()
            self._release_members(tf_id, tf)
            self._stats["taskforces_failed"] += 1

        # Members were told to join, so they must also be told to leave on
        # failure; otherwise they would keep waiting on a dead TaskForce.
        self.bus.broadcast("leave_taskforce", {
            "taskforce_id": tf_id,
            "status": "failed",
            "reason": reason,
        })

        logger.warning(f"临时服务器失败: {tf_id}, 原因: {reason}")

    @staticmethod
    def _is_finished(tf: TaskForce) -> bool:
        """Return True if *tf* is already in a terminal (completed/failed) state."""
        return tf.status in (TaskForceStatus.COMPLETED, TaskForceStatus.FAILED)

    def _release_members(self, tf_id: str, tf: TaskForce) -> None:
        """Remove *tf_id* from ``current_taskforces`` of every member node still in the registry."""
        for member_id in tf.members:
            node = self.registry.get_node(member_id)
            if node and tf_id in node.current_taskforces:
                node.current_taskforces.remove(tf_id)

    # ============================================================
    # Task scheduling
    # ============================================================

    def submit_task(self, task: AggregationTask) -> Optional[str]:
        """Submit *task*; return the new TaskForce ID, or None if it could not be formed."""
        tf = self.create_taskforce(task)
        if tf:
            return tf.taskforce_id
        return None

    def get_taskforce(self, tf_id: str) -> Optional[TaskForce]:
        """Return the TaskForce registered under *tf_id*, or None."""
        with self._lock:
            return self._taskforces.get(tf_id)

    def get_active_taskforces(self) -> List[TaskForce]:
        """Return all registered TaskForces whose status is ACTIVE."""
        with self._lock:
            return [tf for tf in self._taskforces.values()
                    if tf.status == TaskForceStatus.ACTIVE]

    # ============================================================
    # Result aggregation
    # ============================================================

    def aggregate_results(self, tf_id: str, node_results: Dict[str, Dict]) -> Dict:
        """
        Aggregate per-node results according to the TaskForce's strategy.

        Strategies:
        - ENSEMBLE_VOTE: majority vote over the responses.
        - SEQUENTIAL_REFINE: last non-empty response wins (refinement chain).
        - PARAMETER_AVERAGE: confidence-weighted confidence, most confident answer.
        - ADAPTIVE_MIX (and any other strategy): confidence plus a compute
          bonus per node; highest score wins.

        Args:
            tf_id: ID of the TaskForce whose results are aggregated.
            node_results: Mapping node_id -> result dict. Results are read via
                their "response" and "confidence" keys. Falsy (e.g. empty)
                results are discarded.

        Returns:
            The aggregated result dict, or ``{"error": ...}`` when the
            TaskForce is unknown or no usable result remains.
        """
        tf = self.get_taskforce(tf_id)
        if not tf:
            return {"error": "临时服务器不存在"}

        strategy = tf.strategy
        results = {k: v for k, v in node_results.items() if v}

        if not results:
            return {"error": "无有效结果"}

        if strategy == AggregationStrategy.ENSEMBLE_VOTE:
            return self._vote_aggregate(results)
        elif strategy == AggregationStrategy.SEQUENTIAL_REFINE:
            return self._sequential_aggregate(results)
        elif strategy == AggregationStrategy.PARAMETER_AVERAGE:
            return self._parameter_average(results)
        else:  # ADAPTIVE_MIX (also the fallback for unknown strategies)
            return self._adaptive_mix(results)

    def _vote_aggregate(self, results: Dict[str, Dict]) -> Dict:
        """
        Majority vote: return the response given by the most nodes.

        Confidence is the winning response's share of all non-empty responses.
        Ties are broken in favour of the longest response, which is usually
        the most informative. A remaining tie keeps the first-seen response.
        """
        responses = [r.get("response", "") for r in results.values()]
        responses = [resp for resp in responses if resp]

        if not responses:
            return {"response": "", "confidence": 0.0, "method": "vote"}

        counts = Counter(responses)
        top_votes = max(counts.values())
        # Counter preserves insertion order, and max() keeps the first maximum,
        # so equal-length ties resolve to the earliest response.
        best = max((resp for resp, n in counts.items() if n == top_votes), key=len)
        confidence = top_votes / len(responses)
        return {"response": best, "confidence": confidence, "method": "vote"}

    def _sequential_aggregate(self, results: Dict[str, Dict]) -> Dict:
        """
        Sequential refinement: each node's response supersedes the previous one.

        Nodes are visited in the mapping's iteration order. A node with an
        empty or missing response keeps the previous result instead of wiping
        it out. Confidence is the maximum reported by any node.
        """
        refined = ""
        confidence = 0.0
        for result in results.values():
            refined = result.get("response") or refined
            confidence = max(confidence, result.get("confidence", 0.0))
        return {"response": refined, "confidence": confidence, "method": "sequential"}

    def _parameter_average(self, results: Dict[str, Dict]) -> Dict:
        """
        "Parameter average" (federated-learning style) over node confidences.

        No model parameters are averaged here. The returned confidence is the
        confidence-weighted mean of the node confidences,
        ``sum(c*c) / sum(c)``, with the denominator floored at 0.01. The
        response is that of the most confident node (first one on ties);
        missing confidences count as 0.5.
        """
        total_weight = 0.0
        weighted_confidence = 0.0
        best_response = ""
        # -inf so a response is still chosen when every confidence is 0.
        best_conf = float("-inf")

        for result in results.values():
            conf = result.get("confidence", 0.5)
            total_weight += conf
            weighted_confidence += conf * conf
            if conf > best_conf:
                best_conf = conf
                best_response = result.get("response", "")

        avg_confidence = weighted_confidence / max(total_weight, 0.01)

        return {
            "response": best_response,
            "confidence": avg_confidence,
            "method": "parameter_average",
            "contributing_nodes": len(results),
        }

    def _adaptive_mix(self, results: Dict[str, Dict]) -> Dict:
        """
        Adaptive mix: score each node by confidence plus a compute bonus.

        score = confidence (default 0.5) + 0.1 * node.capability.compute_score
        for nodes still in the registry. The highest-scoring response is
        returned (first one on ties). Confidence is the mean score capped at
        1.0, and up to 3 responses are kept as alternatives.
        """
        total_score = 0.0
        best_response = ""
        # -inf so a response is still chosen when every score is 0.
        best_score = float("-inf")
        all_responses = []

        for node_id, result in results.items():
            conf = result.get("confidence", 0.5)
            response = result.get("response", "")

            # Bonus from the node's capability; presumably compute_score is a
            # normalised rating — TODO confirm its range.
            node = self.registry.get_node(node_id)
            expertise_bonus = node.capability.compute_score * 0.1 if node else 0.0

            score = conf + expertise_bonus
            total_score += score
            all_responses.append(response)

            if score > best_score:
                best_score = score
                best_response = response

        avg_confidence = total_score / max(len(results), 1)

        return {
            "response": best_response,
            "confidence": min(avg_confidence, 1.0),
            "method": "adaptive_mix",
            "contributing_nodes": len(results),
            "all_responses": all_responses[:3],  # keep the first 3 as alternatives
        }

    # ============================================================
    # Utilities
    # ============================================================

    def _gen_id(self, prefix: str) -> str:
        """
        Return a short ID of the form ``"<prefix>_<8 hex chars>"``.

        A per-manager sequence number is hashed together with the timestamp.
        ``time.time()`` can return the same value for back-to-back calls on
        clocks with coarse resolution. Without the sequence number that would
        yield duplicate IDs, and a new TaskForce would overwrite an existing
        one in ``_taskforces``.
        """
        with self._lock:
            seq = next(self._id_seq)
        raw = f"{time.time()}{prefix}{seq}"
        return f"{prefix}_{hashlib.md5(raw.encode()).hexdigest()[:8]}"

    def get_stats(self) -> Dict:
        """Return a snapshot of the counters plus active/total TaskForce counts."""
        with self._lock:
            return {
                **self._stats,
                "active_taskforces": len(self.get_active_taskforces()),
                "total_taskforces": len(self._taskforces),
            }