Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| 虫群聚合协议 — 任务调度器 | |
| 核心功能: | |
| - 接收任务请求 | |
| - 组建临时服务器(TaskForce) | |
| - 分配任务到各节点 | |
| - 聚合各节点结果 | |
| - 任务完成后解散临时服务器 | |
| 类比GPU集群: | |
| - 任务调度器 = SLURM/调度系统 | |
| - 临时服务器 = 临时分配的GPU组 | |
| - 节点 = 单块GPU | |
| """ | |
| import hashlib | |
| import logging | |
| import threading | |
| import time | |
| from collections import deque | |
| from datetime import datetime | |
| from typing import Callable, Dict, List, Optional | |
| from .types import ( | |
| AggregationStrategy, AggregationTask, TaskForce, TaskForceStatus, | |
| NodeInfo, NodeStatus, ProtocolMessage, | |
| ) | |
| from .discovery import NodeRegistry | |
| from .transport import MessageBus | |
| logger = logging.getLogger(__name__) | |
class TaskForceManager:
    """Manage the lifecycle of temporary servers (task forces).

    Lifecycle: create -> assemble -> run -> complete -> disband.

    A task force groups the nodes returned by the node registry for one
    aggregation task. The manager records every task force it creates,
    notifies nodes over the message bus, combines per-node results and
    keeps simple counters. Internal bookkeeping is guarded by a
    re-entrant lock.
    """

    def __init__(self, node_registry: NodeRegistry, message_bus: MessageBus):
        """Store the collaborators and initialise empty bookkeeping.

        Args:
            node_registry: Registry used to discover and look up nodes.
            message_bus: Bus used to broadcast join/leave notifications.
        """
        self.registry = node_registry
        self.bus = message_bus
        # Every task force created by this manager: taskforce_id -> TaskForce.
        # Finished task forces stay in the mapping.
        self._taskforces: Dict[str, TaskForce] = {}
        # Re-entrant because get_stats() calls get_active_taskforces() while
        # already holding the lock.
        self._lock = threading.RLock()
        # Task queue; no method of this class consumes it yet.
        self._task_queue = deque()
        # Counters reported by get_stats().
        self._stats = {
            "taskforces_created": 0,
            "taskforces_completed": 0,
            "taskforces_failed": 0,
            "tasks_processed": 0,
        }

    # ============================================================
    # Task force lifecycle
    # ============================================================
    def create_taskforce(self, task: AggregationTask) -> Optional[TaskForce]:
        """Create a temporary server (task force) for ``task``.

        Steps:
            1. Ask the registry for candidate nodes matching the task's
               required capabilities, excluding the requester.
            2. Build the task force and add every candidate as a member.
            3. Record it and mark the task as ``"assigned"``.
            4. Broadcast ``"join_taskforce"`` on the message bus.

        Args:
            task: The aggregation task that needs a task force. Its
                ``taskforce_id`` and ``status`` attributes are updated on
                success.

        Returns:
            The new task force, or ``None`` when fewer than
            ``task.min_nodes`` candidates were found.
        """
        candidates = self.registry.discover_for_task(
            required_caps=task.required_capabilities,
            min_nodes=task.min_nodes,
            max_nodes=task.max_nodes,
            exclude=[task.requester],
        )
        if len(candidates) < task.min_nodes:
            logger.warning(f"节点不足: 需要{task.min_nodes},找到{len(candidates)}")
            return None

        # Membership and bookkeeping change together under the lock, because
        # complete_taskforce()/fail_taskforce() edit the same per-node lists
        # while holding it.
        with self._lock:
            tf_id = self._gen_id("tf")
            # The id is only 8 hex characters, so guard against the rare
            # collision instead of overwriting a live task force.
            while tf_id in self._taskforces:
                tf_id = self._gen_id("tf")

            tf = TaskForce(
                taskforce_id=tf_id,
                name=f"TaskForce-{tf_id}",
                coordinator=task.requester,
                strategy=task.strategy,
                task_description=task.query,
            )
            for node in candidates:
                tf.add_member(node.node_id)
                # Track the assignment on the node itself as well.
                node.current_taskforces.append(tf_id)

            self._taskforces[tf_id] = tf
            task.taskforce_id = tf_id
            task.status = "assigned"
            self._stats["taskforces_created"] += 1

        # Notify outside the lock so the bus call cannot block other callers.
        self.bus.broadcast("join_taskforce", {
            "taskforce_id": tf_id,
            "coordinator": task.requester,
            "task": task.query,
            "members": tf.members,
            "strategy": tf.strategy.value,
        })
        logger.info(f"临时服务器创建: {tf_id}, 成员: {tf.members}")
        return tf

    def complete_taskforce(self, tf_id: str, result: Dict):
        """Mark a task force as completed and disband it.

        Stores ``result`` on the task force, removes the task force from
        each member's task list, bumps the completed counter and
        broadcasts ``"leave_taskforce"``. Unknown ids are ignored.

        Args:
            tf_id: Identifier of the task force to complete.
            result: Final result to store on the task force.
        """
        with self._lock:
            tf = self._taskforces.get(tf_id)
            if tf is None:
                return
            tf.status = TaskForceStatus.COMPLETED
            tf.completed_at = datetime.now()
            tf.results = result
            self._release_members(tf, tf_id)
            self._stats["taskforces_completed"] += 1

        self.bus.broadcast("leave_taskforce", {
            "taskforce_id": tf_id,
            "status": "completed",
        })
        logger.info(f"临时服务器解散: {tf_id}")

    def fail_taskforce(self, tf_id: str, reason: str = ""):
        """Mark a task force as failed and disband it.

        Mirrors ``complete_taskforce``: members are released, the failed
        counter is bumped, the reason is logged and members are told to
        leave via a ``"leave_taskforce"`` broadcast with status
        ``"failed"``. Unknown ids are ignored.

        Args:
            tf_id: Identifier of the task force that failed.
            reason: Optional human-readable failure reason.
        """
        with self._lock:
            tf = self._taskforces.get(tf_id)
            if tf is None:
                return
            tf.status = TaskForceStatus.FAILED
            tf.completed_at = datetime.now()
            self._release_members(tf, tf_id)
            self._stats["taskforces_failed"] += 1

        # Without this notification members would never learn that the
        # task force no longer exists.
        self.bus.broadcast("leave_taskforce", {
            "taskforce_id": tf_id,
            "status": "failed",
            "reason": reason,
        })
        logger.warning(f"临时服务器失败: {tf_id}, 原因: {reason}")

    def _release_members(self, tf: TaskForce, tf_id: str) -> None:
        """Remove ``tf_id`` from the task list of each member node.

        Members that the registry no longer knows, or that do not list
        the task force, are skipped. Callers hold ``self._lock``.
        """
        for member_id in tf.members:
            node = self.registry.get_node(member_id)
            if node and tf_id in node.current_taskforces:
                node.current_taskforces.remove(tf_id)

    # ============================================================
    # Task scheduling
    # ============================================================
    def submit_task(self, task: AggregationTask) -> Optional[str]:
        """Submit a task and return the id of the task force created for it.

        Returns:
            The task force id, or ``None`` when no task force could be
            created.
        """
        tf = self.create_taskforce(task)
        return tf.taskforce_id if tf is not None else None

    def get_taskforce(self, tf_id: str) -> Optional[TaskForce]:
        """Return the task force registered under ``tf_id``, or ``None``."""
        with self._lock:
            return self._taskforces.get(tf_id)

    def get_active_taskforces(self) -> List[TaskForce]:
        """Return all task forces whose status is ``TaskForceStatus.ACTIVE``."""
        with self._lock:
            return [
                tf for tf in self._taskforces.values()
                if tf.status == TaskForceStatus.ACTIVE
            ]

    # ============================================================
    # Result aggregation
    # ============================================================
    def aggregate_results(self, tf_id: str,
                          node_results: Dict[str, Dict]) -> Dict:
        """Aggregate per-node results using the task force's strategy.

        Strategies:
            - PARAMETER_AVERAGE: confidence averaging (federated style).
            - ENSEMBLE_VOTE: majority vote over the responses.
            - SEQUENTIAL_REFINE: later nodes refine earlier results.
            - ADAPTIVE_MIX: confidence plus node-capability weighting;
              also used for any strategy not listed above.

        Args:
            tf_id: Identifier of the task force the results belong to.
            node_results: Mapping of node id to that node's result dict.
                Falsy entries (``None``, ``{}``) are ignored.

        Returns:
            The aggregated result, or a dict with an ``"error"`` key when
            the task force is unknown or no usable result was supplied.
        """
        tf = self.get_taskforce(tf_id)
        if tf is None:
            return {"error": "临时服务器不存在"}

        # Tolerate a missing mapping and drop nodes that returned nothing.
        results = {k: v for k, v in (node_results or {}).items() if v}
        if not results:
            return {"error": "无有效结果"}

        strategy = tf.strategy
        if strategy == AggregationStrategy.ENSEMBLE_VOTE:
            return self._vote_aggregate(results)
        if strategy == AggregationStrategy.SEQUENTIAL_REFINE:
            return self._sequential_aggregate(results)
        if strategy == AggregationStrategy.PARAMETER_AVERAGE:
            return self._parameter_average(results)
        return self._adaptive_mix(results)  # ADAPTIVE_MIX

    @staticmethod
    def _confidence_of(result: Dict, default: float) -> float:
        """Return ``result["confidence"]``, or ``default`` if missing or None."""
        value = result.get("confidence")
        return default if value is None else value

    def _vote_aggregate(self, results: Dict[str, Dict]) -> Dict:
        """Vote aggregation: return the response given by the most nodes.

        Empty responses do not vote. Confidence is the winner's share of
        the non-empty responses. Ties go to the response seen first.
        """
        from collections import Counter

        responses = [
            response
            for response in (r.get("response", "") for r in results.values())
            if response
        ]
        if not responses:
            return {"response": "", "confidence": 0.0}

        best, votes = Counter(responses).most_common(1)[0]
        return {
            "response": best,
            "confidence": votes / len(responses),
            "method": "vote",
        }

    def _sequential_aggregate(self, results: Dict[str, Dict]) -> Dict:
        """Sequential refinement: the last non-empty response wins.

        Results are visited in mapping order. A node without a response,
        or with an empty one, leaves the accumulated answer untouched.
        Confidence is the highest confidence reported by any node.
        """
        refined = ""
        confidence = 0.0
        for result in results.values():
            response = result.get("response")
            if response:
                refined = response
            confidence = max(confidence, self._confidence_of(result, 0.0))
        return {"response": refined, "confidence": confidence, "method": "sequential"}

    def _parameter_average(self, results: Dict[str, Dict]) -> Dict:
        """Parameter averaging (federated style) over node confidences.

        The reported confidence is the confidence-weighted mean of the
        confidences; nodes without one count as 0.5. The response comes
        from the most confident node, first one on ties.
        """
        total_weight = 0.0
        weighted_confidence = 0.0
        best_response = ""
        # None instead of 0.0 so a result is still picked when every
        # confidence is zero.
        best_conf = None
        for result in results.values():
            conf = self._confidence_of(result, 0.5)
            total_weight += conf
            weighted_confidence += conf * conf
            if best_conf is None or conf > best_conf:
                best_conf = conf
                best_response = result.get("response", "")

        # The 0.01 floor avoids dividing by zero when all weights are zero.
        avg_confidence = weighted_confidence / max(total_weight, 0.01)
        return {
            "response": best_response,
            "confidence": avg_confidence,
            "method": "parameter_average",
            "contributing_nodes": len(results),
        }

    def _adaptive_mix(self, results: Dict[str, Dict]) -> Dict:
        """Adaptive mix: weight by confidence plus a node-capability bonus.

        Each node's score is its confidence (0.5 when absent) plus 10% of
        the registered node's ``capability.compute_score``. The response
        of the highest-scoring node is returned, first one on ties. The
        reported confidence is the mean score capped at 1.0. Up to three
        responses are included as alternatives.
        """
        total_score = 0.0
        best_response = ""
        # None instead of 0.0 so a result is still picked when every
        # score is zero.
        best_score = None
        all_responses = []
        for node_id, result in results.items():
            conf = self._confidence_of(result, 0.5)
            # Nodes unknown to the registry get no capability bonus.
            node = self.registry.get_node(node_id)
            expertise_bonus = node.capability.compute_score * 0.1 if node else 0.0
            score = conf + expertise_bonus
            total_score += score

            response = result.get("response", "")
            all_responses.append(response)
            if best_score is None or score > best_score:
                best_score = score
                best_response = response

        avg_confidence = total_score / max(len(results), 1)
        return {
            "response": best_response,
            "confidence": min(avg_confidence, 1.0),
            "method": "adaptive_mix",
            "contributing_nodes": len(results),
            "all_responses": all_responses[:3],  # keep the first 3 alternatives
        }

    # ============================================================
    # Utilities
    # ============================================================
    def _gen_id(self, prefix: str) -> str:
        """Return a random identifier of the form ``<prefix>_<8 hex chars>``.

        Uses a random UUID rather than a hash of the current time, so two
        calls within the same clock tick do not produce the same id.
        """
        import uuid

        return f"{prefix}_{uuid.uuid4().hex[:8]}"

    def get_stats(self) -> Dict:
        """Return the counters plus active and total task force counts."""
        with self._lock:
            return {
                **self._stats,
                "active_taskforces": len(self.get_active_taskforces()),
                "total_taskforces": len(self._taskforces),
            }