Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群v7 — 任务模型树(Task Tree) | |
| 知识树结构的按需加载任务模型群 | |
| 设计思路(类比高中物理课程体系): | |
| - 高中物理 = 力学 + 电学 + 热学 + 光学 ... | |
| - 力学 = 运动学 + 动力学 + 动量 + 能量 ... | |
| - 每个知识点 = 一个任务模型(Drone) | |
| - 上课时只加载当前章节的模型,用完卸载 | |
| - 这样内存中永远只有需要的模型在运行 | |
| """ | |
| import json | |
| import logging | |
| import os | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
| # ============================================================ | |
| # 任务节点 — 知识树的叶子/分支 | |
| # ============================================================ | |
class TaskNode:
    """A single node of the task knowledge tree.

    A node is either a branch (e.g. "mechanics") that groups child nodes, or
    a leaf (e.g. "Newton's second law") that is bound to a concrete task
    model through ``model_id``.

    Args:
        node_id: Identifier of the node within its tree.
        name: Human-readable display name.
        domain: Name of the domain this node belongs to.
        parent_id: ``node_id`` of the parent node ("" when there is none).
        is_leaf: Whether this node is a leaf bound to a task model.
        model_id: Identifier of the task model bound to a leaf node.
    """

    def __init__(self, node_id: str, name: str,
                 domain: str = "", parent_id: str = "",
                 is_leaf: bool = False, model_id: str = ""):
        self.node_id = node_id
        self.name = name
        self.domain = domain
        self.parent_id = parent_id
        self.is_leaf = is_leaf
        self.model_id = model_id          # model bound to a leaf node
        self.children: List[str] = []     # node_ids of the child nodes
        # Runtime state (not part of the tree definition itself).
        self.loaded = False               # whether the node's model is loaded into memory
        self.use_count = 0                # number of recorded uses
        self.last_used: Optional[float] = None  # timestamp of the most recent use

    def __repr__(self) -> str:
        kind = "leaf" if self.is_leaf else "branch"
        return (f"{type(self).__name__}(node_id={self.node_id!r}, "
                f"name={self.name!r}, {kind}, children={len(self.children)})")

    def to_dict(self) -> Dict:
        """Return a JSON-serializable snapshot of this node.

        ``children`` is copied, so mutating the returned dict cannot alter
        the node's own child list. ``last_used`` is included so it can be
        persisted alongside ``use_count``.
        """
        return {
            "node_id": self.node_id,
            "name": self.name,
            "domain": self.domain,
            "parent_id": self.parent_id,
            "is_leaf": self.is_leaf,
            "model_id": self.model_id,
            "children": list(self.children),
            "use_count": self.use_count,
            "last_used": self.last_used,
            "loaded": self.loaded,
        }
| # ============================================================ | |
| # 任务模型树核心 | |
| # ============================================================ | |
class TaskTree:
    """Task-model tree whose leaf models are loaded on demand.

    Layout, by analogy with a course syllabus:
      - the root node stands for all of the user's knowledge domains;
      - branch nodes are domain categories (coding / science / writing /
        daily / work ...);
      - leaf nodes are concrete task models (python / physics / email ...).

    Core operations:
      1. execute(): route a TaskAnalysis to a leaf node.
      2. record_usage(): count domain usage for later tree optimisation.
      3. load/unload: keep at most ``MAX_LOADED`` leaf models resident and
         evict the least recently used one when the limit is reached.

    Args:
        user_id: Owner of the tree; also the stem of the persisted file
            ``<data_dir>/<user_id>.json``.
        data_dir: Directory holding persisted trees. ``None`` (the default)
            means the class attribute ``DATA_DIR``.
    """

    DATA_DIR = "/home/admin/swarm/data/task_tree"
    MAX_LOADED = 3  # maximum number of task models resident at the same time

    # TaskAnalysis.intent -> id of the domain branch that serves it;
    # unknown intents fall back to "daily".
    _INTENT_DOMAIN_MAP: Dict[str, str] = {
        "code": "coding", "reasoning": "science", "compute": "science",
        "translate": "daily", "query": "daily", "chat": "daily",
        "memory": "daily", "write": "writing",
    }
    # TaskAnalysis.knowledge_domains label -> id of the domain branch.
    _DOMAIN_NAME_MAP: Dict[str, str] = {
        "物理": "science", "编程": "coding", "数学": "science",
        "写作": "writing", "日常": "daily",
    }

    def __init__(self, user_id: str = "default",
                 data_dir: Optional[str] = None):
        self.user_id = user_id
        self._data_dir = data_dir
        self._nodes: Dict[str, TaskNode] = {}
        self._root_id = "root"
        # model_id -> time.time() of its last load/use. Dict order is kept
        # oldest -> newest so LRU eviction does not depend on clock resolution.
        self._loaded_models: Dict[str, float] = {}
        # statistics
        self._execute_count = 0
        self._hit_count = 0
        self._load_count = 0
        self._unload_count = 0
        # built-in tree first, then the user's persisted nodes on top of it
        self._init_default_tree()
        self._load()

    # ------------------------------------------------------------------
    # Default knowledge tree
    # ------------------------------------------------------------------
    def _init_default_tree(self):
        """Build the built-in tree: root -> domain branches -> leaves.

        Leaf ids are ``"<domain_id>.<leaf>"`` and their models are
        ``"drone_<domain_id>_<leaf>"``.
        """
        root = TaskNode("root", "全部领域")
        self._nodes["root"] = root
        domains = {
            "coding": ("编程", ["python", "web", "algorithm", "database"]),
            "science": ("科学", ["physics", "math", "chemistry"]),
            "writing": ("写作", ["email", "report", "creative"]),
            "daily": ("日常", ["chat", "translate", "search"]),
            "work": ("工作", ["office", "project", "meeting"]),
        }
        for domain_id, (domain_name, leaf_names) in domains.items():
            domain_node = TaskNode(domain_id, domain_name,
                                   domain=domain_name, parent_id="root")
            self._nodes[domain_id] = domain_node
            root.children.append(domain_id)
            for leaf_name in leaf_names:
                leaf_id = f"{domain_id}.{leaf_name}"
                self._nodes[leaf_id] = TaskNode(
                    leaf_id, leaf_name,
                    domain=domain_name, parent_id=domain_id,
                    is_leaf=True, model_id=f"drone_{domain_id}_{leaf_name}",
                )
                domain_node.children.append(leaf_id)

    # ------------------------------------------------------------------
    # Core interface
    # ------------------------------------------------------------------
    def _known_domain_ids(self, labels) -> List[str]:
        """Translate knowledge-domain labels into domain-branch ids.

        A label is resolved through ``_DOMAIN_NAME_MAP`` first, then by the
        display name of a top-level branch (so "工作" reaches "work" and
        user-added domains are reachable too). Unknown labels are dropped;
        ``None`` counts as no labels. Duplicates are kept, in input order.
        """
        by_name: Dict[str, str] = {}
        root = self._nodes.get(self._root_id)
        if root is not None:
            for child_id in root.children:
                child = self._nodes.get(child_id)
                if child is not None:
                    by_name.setdefault(child.name, child_id)
        resolved = []
        for label in labels or []:
            domain_id = self._DOMAIN_NAME_MAP.get(label) or by_name.get(label)
            if domain_id is not None:
                resolved.append(domain_id)
        return resolved

    def execute(self, analysis, query: str) -> Optional[str]:
        """Route ``analysis`` to a leaf node and mark its model as used.

        The domain branch comes from ``analysis.intent`` (default "chat"),
        overridden by the first recognised label in
        ``analysis.knowledge_domains``. If the branch's first child is a
        leaf, its model is loaded (LRU) and its usage statistics updated.

        Args:
            analysis: TaskAnalysis-like object; attributes are read with
                ``getattr`` so missing ones fall back to defaults.
            query: The user query (not used yet).

        Returns:
            Always ``None`` for now, so the caller (QueenAgent) takes its
            fallback path. Running the leaf's model here is future work.
        """
        self._execute_count += 1
        target_domain = self._INTENT_DOMAIN_MAP.get(
            getattr(analysis, "intent", "chat"), "daily")
        matched = self._known_domain_ids(
            getattr(analysis, "knowledge_domains", None))
        if matched:
            # an explicit knowledge-domain match wins over the intent mapping
            target_domain = matched[0]

        domain_node = self._nodes.get(target_domain)
        if domain_node is None or not domain_node.children:
            return None
        # the first child is the default leaf for now (finer matching is TODO)
        leaf = self._nodes.get(domain_node.children[0])
        if leaf is not None and leaf.is_leaf:
            self._ensure_loaded(leaf.model_id)
            leaf.use_count += 1
            leaf.last_used = time.time()
            self._hit_count += 1
        return None

    def record_usage(self, analysis):
        """Count one use of every recognised domain label of ``analysis``.

        Labels come from ``analysis.knowledge_domains`` (missing or ``None``
        means none); unknown labels and removed branches are ignored.
        """
        now = time.time()
        for domain_id in self._known_domain_ids(
                getattr(analysis, "knowledge_domains", None)):
            node = self._nodes.get(domain_id)
            if node is not None:
                node.use_count += 1
                node.last_used = now

    # ------------------------------------------------------------------
    # Dynamic load management
    # ------------------------------------------------------------------
    def _set_loaded(self, model_id: str, loaded: bool):
        """Sync the ``loaded`` flag of every node bound to ``model_id``."""
        for node in self._nodes.values():
            if node.model_id == model_id:
                node.loaded = loaded

    def _ensure_loaded(self, model_id: str):
        """Make ``model_id`` resident, evicting LRU models when at capacity.

        A model that is already resident simply becomes the most recently
        used one. An empty ``model_id`` (a leaf with no bound model) is
        ignored so it cannot push a real model out.
        """
        if not model_id:
            return
        if model_id in self._loaded_models:
            # re-insert at the end so dict order stays oldest -> newest
            del self._loaded_models[model_id]
            self._loaded_models[model_id] = time.time()
            return
        while self._loaded_models and len(self._loaded_models) >= self.MAX_LOADED:
            self._evict_one()
        self._loaded_models[model_id] = time.time()
        self._set_loaded(model_id, True)
        self._load_count += 1
        logger.debug("加载任务模型: %s", model_id)

    def _evict_one(self):
        """Unload the least recently used resident model, if any."""
        if not self._loaded_models:
            return
        # dict order is oldest -> newest (maintained by _ensure_loaded)
        oldest = next(iter(self._loaded_models))
        del self._loaded_models[oldest]
        self._set_loaded(oldest, False)
        self._unload_count += 1
        logger.debug("卸载任务模型: %s", oldest)

    def unload_all(self):
        """Unload every resident model and clear all ``loaded`` flags."""
        self._unload_count += len(self._loaded_models)
        self._loaded_models.clear()
        for node in self._nodes.values():
            node.loaded = False

    # ------------------------------------------------------------------
    # Tree editing
    # ------------------------------------------------------------------
    def add_node(self, parent_id: str, node_id: str, name: str,
                 is_leaf: bool = False, model_id: str = "") -> bool:
        """Attach a new node under ``parent_id``.

        The node inherits the parent's domain; a node added directly under
        the root starts a new domain named after itself, like the built-in
        domain branches.

        Returns:
            True on success; False (with a warning) when the parent does not
            exist or ``node_id`` is already taken.
        """
        parent = self._nodes.get(parent_id)
        if parent is None:
            logger.warning("父节点 %s 不存在", parent_id)
            return False
        if node_id in self._nodes:
            logger.warning("节点 %s 已存在", node_id)
            return False
        domain = name if parent_id == self._root_id else parent.domain
        node = TaskNode(
            node_id, name,
            domain=domain,
            parent_id=parent_id,
            is_leaf=is_leaf,
            model_id=model_id,
        )
        # the bound model may already be resident via another node
        node.loaded = bool(model_id) and model_id in self._loaded_models
        self._nodes[node_id] = node
        parent.children.append(node_id)
        return True

    def remove_node(self, node_id: str) -> bool:
        """Remove ``node_id`` together with all of its descendants.

        Returns:
            False for an unknown id or the root (which cannot be removed),
            True otherwise.
        """
        if node_id not in self._nodes or node_id == self._root_id:
            return False
        node = self._nodes[node_id]
        # iterate a copy: each child detaches itself from node.children
        for child_id in list(node.children):
            self.remove_node(child_id)
        parent = self._nodes.get(node.parent_id)
        if parent is not None:
            parent.children = [c for c in parent.children if c != node_id]
        del self._nodes[node_id]
        return True

    def get_tree(self, node_id: str = "root", depth: int = 0) -> Dict:
        """Return the subtree rooted at ``node_id`` as nested dicts.

        Each dict has ``id``, ``name``, ``domain`` and ``use_count``, plus a
        ``children`` list of such dicts when the node has children. Unknown
        ids yield ``{}``. ``depth`` is the current recursion depth; it is
        passed down but not otherwise used.
        """
        node = self._nodes.get(node_id)
        if node is None:
            return {}
        result = {
            "id": node.node_id,
            "name": node.name,
            "domain": node.domain,
            "use_count": node.use_count,
        }
        if node.children:
            result["children"] = [
                self.get_tree(c, depth + 1) for c in node.children
            ]
        return result

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def _filepath(self) -> str:
        """Return the path of this user's persisted tree file."""
        data_dir = self.DATA_DIR if self._data_dir is None else self._data_dir
        # NOTE(review): user_id is used verbatim as a file name; sanitize it
        # if it can ever come from untrusted input (e.g. "../x").
        return os.path.join(data_dir, f"{self.user_id}.json")

    def _load(self):
        """Merge the user's persisted nodes into the tree (best effort).

        Nodes that already exist (e.g. the built-in ones) only get their
        usage statistics restored, so their child lists are preserved. New
        nodes are created first and attached to their parents in a second
        pass, so the order of entries in the file does not matter. Malformed
        entries are skipped with a warning; an unreadable or malformed file
        leaves the tree unchanged.
        """
        filepath = self._filepath()
        if not os.path.exists(filepath):
            return
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError is a ValueError
            logger.warning("任务树加载失败: %s", e)
            return
        items = data.get("custom_nodes", []) if isinstance(data, dict) else None
        if not isinstance(items, list):
            logger.warning("任务树加载失败: 文件格式无效 (%s)", filepath)
            return

        created: List[TaskNode] = []
        for item in items:
            if not (isinstance(item, dict)
                    and isinstance(item.get("node_id"), str)
                    and isinstance(item.get("name"), str)
                    and isinstance(item.get("parent_id", "root"), str)):
                logger.warning("任务树加载: 跳过无效节点 %r", item)
                continue
            existing = self._nodes.get(item["node_id"])
            if existing is not None:
                # keep the existing node (and its children); restore stats only
                existing.use_count = item.get("use_count", existing.use_count)
                existing.last_used = item.get("last_used", existing.last_used)
                continue
            node = TaskNode(
                node_id=item["node_id"],
                name=item["name"],
                domain=item.get("domain", ""),
                parent_id=item.get("parent_id", "root"),
                is_leaf=item.get("is_leaf", False),
                model_id=item.get("model_id", ""),
            )
            node.use_count = item.get("use_count", 0)
            node.last_used = item.get("last_used")
            self._nodes[node.node_id] = node
            created.append(node)

        # second pass: attach new nodes once every node of the file exists
        for node in created:
            parent = self._nodes.get(node.parent_id)
            if parent is None or parent is node:
                # the node's data is kept, but it is not reachable from root
                logger.warning("任务树加载: 节点 %s 无法挂接到父节点 %s",
                               node.node_id, node.parent_id)
                continue
            if node.node_id not in parent.children:
                parent.children.append(node.node_id)
        logger.info("任务树加载: +%d自定义节点", len(items))

    def save(self):
        """Persist the tree to ``<data_dir>/<user_id>.json`` (best effort).

        Every node except the root is written, so user-added nodes survive
        whatever their id looks like, and the usage statistics of built-in
        nodes survive a restart. The JSON is written to a temporary sibling
        file that then replaces the target via ``os.replace``, so a crash
        mid-write cannot corrupt the previous save. Failures are logged, not
        raised.
        """
        filepath = self._filepath()
        tmp_path = filepath + ".tmp"
        nodes = [node.to_dict() for node_id, node in self._nodes.items()
                 if node_id != self._root_id]
        try:
            directory = os.path.dirname(filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump({"custom_nodes": nodes}, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, filepath)
        except (OSError, TypeError, ValueError) as e:
            logger.warning("任务树保存失败: %s", e)
            try:
                os.remove(tmp_path)  # best-effort cleanup of a partial write
            except OSError:
                pass

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------
    def get_status(self) -> Dict:
        """Return tree size, load state and call statistics as a dict."""
        leaf_count = sum(1 for n in self._nodes.values() if n.is_leaf)
        return {
            "user_id": self.user_id,
            "total_nodes": len(self._nodes),
            "leaf_nodes": leaf_count,
            "loaded_models": len(self._loaded_models),
            "max_loaded": self.MAX_LOADED,
            "execute_count": self._execute_count,
            "hit_count": self._hit_count,
            "load_count": self._load_count,
            "unload_count": self._unload_count,
        }