#!/usr/bin/env python3
"""
Swarm v7 — Task Tree (tree of task models).

An on-demand-loaded group of task models organised as a knowledge tree.

Design idea (by analogy with a high-school physics curriculum):
- High-school physics = mechanics + electricity + thermodynamics + optics ...
- Mechanics = kinematics + dynamics + momentum + energy ...
- Each knowledge point = one task model (a Drone)
- During a lesson only the current chapter's models are loaded; they are
  unloaded once no longer needed
- That way only the models actually needed are resident in memory
"""

import json
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


# ============================================================
# Task node — a leaf or branch of the knowledge tree
# ============================================================

class TaskNode:
    """
    One node of the knowledge tree.

    A node is either a branch (e.g. "mechanics") or a leaf (e.g. "Newton's
    second law"). Leaf nodes are associated with a concrete task model through
    ``model_id``.
    """

    def __init__(self, node_id: str, name: str, domain: str = "",
                 parent_id: str = "", is_leaf: bool = False, model_id: str = ""):
        self.node_id = node_id
        self.name = name
        self.domain = domain
        self.parent_id = parent_id
        self.is_leaf = is_leaf
        self.model_id = model_id  # ID of the model associated with a leaf node
        self.children: List[str] = []  # IDs of the child nodes

        # Runtime state
        self.loaded = False  # whether this node's model is currently loaded
        self.use_count = 0  # how many times this node has been used
        self.last_used: Optional[float] = None  # time.time() of the last use

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable snapshot of this node (used when saving)."""
        return {
            "node_id": self.node_id,
            "name": self.name,
            "domain": self.domain,
            "parent_id": self.parent_id,
            "is_leaf": self.is_leaf,
            "model_id": self.model_id,
            # Copy so callers cannot mutate the live child list through the dict.
            "children": list(self.children),
            "use_count": self.use_count,
            "loaded": self.loaded,
        }


# ============================================================
# Task model tree core
# ============================================================

class TaskTree:
    """
    Task model tree — a knowledge tree whose models are loaded on demand.

    Curriculum analogy:
    - root node    = all of the user's knowledge domains
    - branch nodes = domain categories (coding / science / writing / daily ...)
    - leaf nodes   = concrete task models (python basics / mechanics / email ...)

    Core operations:
    1. execute(): route a TaskAnalysis to the matching leaf node
    2. record_usage(): record usage statistics for later tree optimisation
    3. load/unload: load and unload models on demand (memory management),
       keeping at most ``MAX_LOADED`` models resident with LRU eviction
    """

    DATA_DIR = "/home/admin/swarm/data/task_tree"
    MAX_LOADED = 3  # maximum number of models loaded at the same time

    # analysis.intent -> default domain branch ID
    _INTENT_DOMAIN_MAP: Dict[str, str] = {
        "code": "coding",
        "reasoning": "science",
        "compute": "science",
        "translate": "daily",
        "query": "daily",
        "chat": "daily",
        "memory": "daily",
        "write": "writing",
    }

    # Knowledge-domain label expected in analysis.knowledge_domains -> domain branch ID
    _DOMAIN_NAME_MAP: Dict[str, str] = {
        "物理": "science",
        "编程": "coding",
        "数学": "science",
        "写作": "writing",
        "日常": "daily",
    }

    def __init__(self, user_id: str = "default", data_dir: Optional[str] = None):
        """
        Build the default tree, then merge in the user's persisted nodes.

        Args:
            user_id: owner of the tree; also names the persisted JSON file.
            data_dir: directory holding persisted trees. Falls back to
                ``DATA_DIR`` when not given.
        """
        self.user_id = user_id
        self.data_dir = data_dir or self.DATA_DIR
        self._nodes: Dict[str, TaskNode] = {}
        self._root_id = "root"
        # model_id -> time of last load/use. The dict's insertion order is kept
        # as LRU order (least recently used first); see _ensure_loaded.
        self._loaded_models: Dict[str, float] = {}
        # IDs of the built-in branch nodes (root + domains); never persisted.
        self._default_branch_ids = set()

        # Statistics
        self._execute_count = 0
        self._hit_count = 0
        self._load_count = 0
        self._unload_count = 0

        # Initialise the default knowledge tree
        self._init_default_tree()
        # Merge in the user's customised tree
        self._load()

    # ============================================================
    # Default knowledge tree
    # ============================================================

    def _init_default_tree(self):
        """Create the built-in tree: root -> 5 domain branches -> leaf models."""
        root = TaskNode("root", "全部领域")
        self._nodes["root"] = root
        self._default_branch_ids.add("root")

        # First-level domains: domain_id -> (display name, leaf names)
        domains = {
            "coding": ("编程", ["python", "web", "algorithm", "database"]),
            "science": ("科学", ["physics", "math", "chemistry"]),
            "writing": ("写作", ["email", "report", "creative"]),
            "daily": ("日常", ["chat", "translate", "search"]),
            "work": ("工作", ["office", "project", "meeting"]),
        }

        for domain_id, (domain_name, leaf_names) in domains.items():
            domain_node = TaskNode(
                domain_id, domain_name,
                domain=domain_name, parent_id="root"
            )
            self._nodes[domain_id] = domain_node
            self._default_branch_ids.add(domain_id)
            root.children.append(domain_id)

            for leaf_name in leaf_names:
                leaf_id = f"{domain_id}.{leaf_name}"
                leaf_node = TaskNode(
                    leaf_id, leaf_name,
                    domain=domain_name, parent_id=domain_id,
                    is_leaf=True,
                    model_id=f"drone_{domain_id}_{leaf_name}"
                )
                self._nodes[leaf_id] = leaf_node
                domain_node.children.append(leaf_id)

    # ============================================================
    # Core interface
    # ============================================================

    def execute(self, analysis, query: str) -> Optional[str]:
        """
        Route ``analysis`` to a leaf node of the tree and mark it as used.

        The target domain comes from ``analysis.intent`` via
        ``_INTENT_DOMAIN_MAP`` (default "daily"). The first entry of
        ``analysis.knowledge_domains`` found in ``_DOMAIN_NAME_MAP`` overrides
        it. The first leaf child of that domain has its model loaded (if it has
        one) and its usage counters updated.

        Args:
            analysis: object with optional ``intent`` and ``knowledge_domains``
                attributes (duck-typed; missing or None values use defaults).
            query: the user query (not used yet).

        Returns:
            Currently always ``None``, so the caller (QueenAgent) takes its
            fallback path. In future this will run the leaf's small model.
        """
        self._execute_count += 1

        domains = getattr(analysis, 'knowledge_domains', None) or []
        intent = getattr(analysis, 'intent', 'chat')

        target_domain = self._INTENT_DOMAIN_MAP.get(intent, "daily")
        # An explicit knowledge-domain match takes precedence over the intent.
        for d in domains:
            if d in self._DOMAIN_NAME_MAP:
                target_domain = self._DOMAIN_NAME_MAP[d]
                break

        # Take the first leaf as the default (finer matching can come later).
        leaf = self._first_leaf(target_domain)
        if leaf is not None:
            if leaf.model_id:
                self._ensure_loaded(leaf.model_id)
            leaf.use_count += 1
            leaf.last_used = time.time()
            self._hit_count += 1

        # Current stage: return None so QueenAgent uses its fallback path.
        # Future: actually run the small model associated with the leaf here.
        return None

    def _first_leaf(self, domain_id: str) -> Optional[TaskNode]:
        """Return the first direct child of ``domain_id`` that is a leaf, or None."""
        domain_node = self._nodes.get(domain_id)
        if domain_node is None:
            return None
        for child_id in domain_node.children:
            child = self._nodes.get(child_id)
            if child is not None and child.is_leaf:
                return child
        return None

    def record_usage(self, analysis):
        """
        Record usage for tree optimisation.

        Increments ``use_count`` and refreshes ``last_used`` on the domain
        branch of every entry of ``analysis.knowledge_domains`` that appears in
        ``_DOMAIN_NAME_MAP``. Missing or None ``knowledge_domains`` is a no-op.
        """
        domains = getattr(analysis, 'knowledge_domains', None) or []
        now = time.time()
        for domain in domains:
            domain_id = self._DOMAIN_NAME_MAP.get(domain)
            node = self._nodes.get(domain_id) if domain_id else None
            if node is not None:
                node.use_count += 1
                node.last_used = now

    # ============================================================
    # Dynamic load management
    # ============================================================

    def _set_loaded_flag(self, model_id: str, loaded: bool):
        """Set ``TaskNode.loaded`` on every node associated with ``model_id``."""
        for node in self._nodes.values():
            if node.model_id == model_id:
                node.loaded = loaded

    def _ensure_loaded(self, model_id: str):
        """Ensure ``model_id`` is loaded, evicting least-recently-used models when full."""
        if model_id in self._loaded_models:
            # Re-insert so the model moves to the most-recently-used end.
            del self._loaded_models[model_id]
            self._loaded_models[model_id] = time.time()
            return

        # Make room (a loop, so a lowered MAX_LOADED is honoured too).
        while self._loaded_models and len(self._loaded_models) >= self.MAX_LOADED:
            self._evict_one()

        self._loaded_models[model_id] = time.time()
        self._set_loaded_flag(model_id, True)
        self._load_count += 1
        logger.debug(f"加载任务模型: {model_id}")

    def _evict_one(self):
        """Unload the least recently used model (no-op when nothing is loaded)."""
        if not self._loaded_models:
            return
        # Insertion order is LRU order (see _ensure_loaded), so the first key is
        # the least recently used one. Unlike comparing time.time() values, this
        # cannot be confused by equal timestamps or clock adjustments.
        oldest = next(iter(self._loaded_models))
        del self._loaded_models[oldest]
        self._set_loaded_flag(oldest, False)
        self._unload_count += 1
        logger.debug(f"卸载任务模型: {oldest}")

    def unload_all(self):
        """Unload every model and clear all nodes' ``loaded`` flags."""
        self._loaded_models.clear()
        for node in self._nodes.values():
            node.loaded = False

    # ============================================================
    # Knowledge-tree operations
    # ============================================================

    def add_node(self, parent_id: str, node_id: str, name: str,
                 is_leaf: bool = False, model_id: str = "") -> bool:
        """
        Add a node under ``parent_id``; the node inherits the parent's domain.

        Returns:
            True on success; False (with a warning) if the parent is missing or
            ``node_id`` already exists.
        """
        if parent_id not in self._nodes:
            logger.warning(f"父节点 {parent_id} 不存在")
            return False
        if node_id in self._nodes:
            logger.warning(f"节点 {node_id} 已存在")
            return False

        parent = self._nodes[parent_id]
        node = TaskNode(
            node_id, name,
            domain=parent.domain, parent_id=parent_id,
            is_leaf=is_leaf, model_id=model_id,
        )
        # The model may already be loaded through another node sharing it.
        node.loaded = bool(model_id) and model_id in self._loaded_models
        self._nodes[node_id] = node
        parent.children.append(node_id)
        return True

    def remove_node(self, node_id: str) -> bool:
        """
        Remove a node and, recursively, all of its descendants.

        A removed node's model is unloaded if no remaining node uses it.

        Returns:
            False if the node does not exist or is the root, else True.
        """
        if node_id not in self._nodes or node_id == "root":
            return False

        node = self._nodes[node_id]
        # Remove children first (copy: the recursion edits this list).
        for child_id in list(node.children):
            self.remove_node(child_id)

        # Drop the reference from the parent
        parent = self._nodes.get(node.parent_id)
        if parent is not None:
            parent.children = [c for c in parent.children if c != node_id]

        del self._nodes[node_id]

        # Free the load slot of a model that no remaining node references.
        model_id = node.model_id
        if (model_id and model_id in self._loaded_models
                and not any(n.model_id == model_id for n in self._nodes.values())):
            del self._loaded_models[model_id]
            self._unload_count += 1
            logger.debug(f"卸载任务模型: {model_id}")
        return True

    def get_tree(self, node_id: str = "root", depth: int = 0) -> Dict:
        """
        Return the subtree rooted at ``node_id`` as nested dicts.

        Each dict carries id/name/domain/use_count, plus a "children" list
        when the node has children. Returns {} for an unknown ``node_id``.
        ``depth`` is only passed down the recursion; it does not limit it.
        """
        node = self._nodes.get(node_id)
        if node is None:
            return {}
        result = {
            "id": node.node_id,
            "name": node.name,
            "domain": node.domain,
            "use_count": node.use_count,
        }
        if node.children:
            result["children"] = [
                self.get_tree(c, depth + 1) for c in node.children
            ]
        return result

    # ============================================================
    # Persistence
    # ============================================================

    def _filepath(self) -> str:
        """Path of this user's persisted tree JSON file."""
        # NOTE(review): user_id is joined into the path unsanitised; validate it
        # if it can come from untrusted input (path traversal).
        return os.path.join(self.data_dir, f"{self.user_id}.json")

    def _load(self):
        """
        Merge the user's persisted nodes (``custom_nodes``) into the tree.

        This is best-effort: any error is logged as a warning. All items are
        parsed before the tree is touched. Nodes are linked to their parents
        only after every node exists, so the order of entries in the file
        does not matter.
        """
        filepath = self._filepath()
        if not os.path.exists(filepath):
            return

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            parsed: List[TaskNode] = []
            for item in data.get("custom_nodes", []):
                node = TaskNode(
                    node_id=item["node_id"],
                    name=item["name"],
                    domain=item.get("domain", ""),
                    parent_id=item.get("parent_id", "root"),
                    is_leaf=item.get("is_leaf", False),
                    model_id=item.get("model_id", ""),
                )
                node.use_count = item.get("use_count", 0)
                if node.node_id == "root":
                    continue  # never replace the root (it would become its own child)
                parsed.append(node)

            for node in parsed:
                existing = self._nodes.get(node.node_id)
                if existing is not None:
                    # Overriding an existing node (e.g. a default leaf): keep its links.
                    node.children = existing.children
                self._nodes[node.node_id] = node

            for node in parsed:
                parent = self._nodes.get(node.parent_id)
                if parent is not None and node.node_id not in parent.children:
                    parent.children.append(node.node_id)

            logger.info(f"任务树加载: +{len(data.get('custom_nodes', []))}自定义节点")
        except Exception as e:
            logger.warning(f"任务树加载失败: {e}")

    def save(self):
        """
        Persist every node except the built-in branches (root and domains).

        The file is written to a temporary path and then atomically renamed
        over the target, so a crash mid-write cannot leave a truncated file.
        Write errors are logged as warnings rather than raised.
        """
        os.makedirs(self.data_dir, exist_ok=True)
        filepath = self._filepath()

        # Skip the built-in branches; custom nodes are kept whatever their ID looks like.
        custom_nodes = [
            node.to_dict() for node in self._nodes.values()
            if node.node_id not in self._default_branch_ids
        ]

        tmp_path = filepath + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump({"custom_nodes": custom_nodes}, f,
                          ensure_ascii=False, indent=2)
            os.replace(tmp_path, filepath)
        except Exception as e:
            logger.warning(f"任务树保存失败: {e}")

    # ============================================================
    # Status queries
    # ============================================================

    def get_status(self) -> Dict:
        """Return a snapshot of tree size, load state and counters."""
        leaf_count = sum(1 for n in self._nodes.values() if n.is_leaf)
        return {
            "user_id": self.user_id,
            "total_nodes": len(self._nodes),
            "leaf_nodes": leaf_count,
            "loaded_models": len(self._loaded_models),
            "max_loaded": self.MAX_LOADED,
            "execute_count": self._execute_count,
            "hit_count": self._hit_count,
            "load_count": self._load_count,
            "unload_count": self._unload_count,
        }