# File: swarm-backend/core/task_tree.py
# Uploaded with huggingface_hub by lk080424 (revision 17fba62, verified).
#!/usr/bin/env python3
"""
虫群v7 — 任务模型树(Task Tree)
知识树结构的按需加载任务模型群
设计思路(类比高中物理课程体系):
- 高中物理 = 力学 + 电学 + 热学 + 光学 ...
- 力学 = 运动学 + 动力学 + 动量 + 能量 ...
- 每个知识点 = 一个任务模型(Drone)
- 上课时只加载当前章节的模型,用完卸载
- 这样内存中永远只有需要的模型在运行
"""
import json
import logging
import os
import time
from datetime import datetime
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# ============================================================
# 任务节点 — 知识树的叶子/分支
# ============================================================
class TaskNode:
    """
    One node of the task knowledge tree.

    A node is either a branch (e.g. "mechanics") or a leaf (e.g. "Newton's
    second law"). A leaf is associated with one concrete task model through
    ``model_id``.

    Args:
        node_id: Unique identifier of the node inside the tree.
        name: Human-readable node name.
        domain: Name of the knowledge domain the node belongs to.
        parent_id: Identifier of the parent node ("" when there is none).
        is_leaf: True when the node is a leaf bound to a task model.
        model_id: Identifier of the model bound to a leaf node.
    """

    def __init__(self, node_id: str, name: str,
                 domain: str = "", parent_id: str = "",
                 is_leaf: bool = False, model_id: str = ""):
        self.node_id = node_id
        self.name = name
        self.domain = domain
        self.parent_id = parent_id
        self.is_leaf = is_leaf
        self.model_id = model_id  # model bound to this node when it is a leaf
        self.children: List[str] = []  # ids of the child nodes

        # Runtime state
        self.loaded = False  # whether the model is currently loaded in memory
        self.use_count = 0  # number of times the node has been used
        self.last_used: Optional[float] = None  # timestamp of the last use

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of the node.

        ``children`` is copied so that callers mutating the snapshot cannot
        change the node's own child list.
        """
        return {
            "node_id": self.node_id,
            "name": self.name,
            "domain": self.domain,
            "parent_id": self.parent_id,
            "is_leaf": self.is_leaf,
            "model_id": self.model_id,
            "children": list(self.children),
            "use_count": self.use_count,
            "loaded": self.loaded,
        }

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(node_id={self.node_id!r}, "
            f"name={self.name!r}, is_leaf={self.is_leaf!r}, "
            f"model_id={self.model_id!r})"
        )
# ============================================================
# 任务模型树核心
# ============================================================
class TaskTree:
    """
    Task-model tree: a knowledge tree whose leaf models are loaded on demand.

    Structure (by analogy with a school curriculum):
    - the root node stands for all of the user's knowledge domains;
    - branch nodes are domain categories (coding / science / writing / ...);
    - leaf nodes are concrete task models (python / physics / email / ...).

    Main operations:
    1. execute(): route a task analysis to the matching leaf node.
    2. record_usage(): record usage on domain nodes for later optimisation.
    3. load / unload bookkeeping: at most MAX_LOADED models are tracked as
       loaded; the least recently used one is evicted when the limit is hit.
    """

    DATA_DIR = "/home/admin/swarm/data/task_tree"
    MAX_LOADED = 3  # maximum number of models tracked as loaded at once

    # intent -> id of the domain branch that handles it
    _INTENT_DOMAIN_MAP = {
        "code": "coding", "reasoning": "science", "compute": "science",
        "translate": "daily", "query": "daily", "chat": "daily",
        "memory": "daily", "write": "writing",
    }
    # knowledge-domain label -> id of the domain branch
    _DOMAIN_NAME_MAP = {
        "物理": "science", "编程": "coding", "数学": "science",
        "写作": "writing", "日常": "daily",
    }

    def __init__(self, user_id: str = "default"):
        self.user_id = user_id
        self._nodes: Dict[str, TaskNode] = {}
        self._root_id = "root"
        # Ids created by _init_default_tree(); used by save() to tell the
        # built-in nodes apart from user-defined ones.
        self._default_ids = set()
        self._loaded_models: Dict[str, float] = {}  # model_id -> last-use time
        # Statistics
        self._execute_count = 0
        self._hit_count = 0
        self._load_count = 0
        self._unload_count = 0
        # Build the built-in tree, then merge the user's persisted data.
        self._init_default_tree()
        self._load()

    # ============================================================
    # Default knowledge tree
    # ============================================================
    def _init_default_tree(self) -> None:
        """Create the built-in root, domain branches and leaf nodes."""
        root = TaskNode("root", "全部领域")
        self._nodes["root"] = root
        self._default_ids.add("root")

        # First-level domains: id -> (display name, leaf names)
        domains = {
            "coding": ("编程", ["python", "web", "algorithm", "database"]),
            "science": ("科学", ["physics", "math", "chemistry"]),
            "writing": ("写作", ["email", "report", "creative"]),
            "daily": ("日常", ["chat", "translate", "search"]),
            "work": ("工作", ["office", "project", "meeting"]),
        }
        for domain_id, (domain_name, leaf_names) in domains.items():
            domain_node = TaskNode(
                domain_id, domain_name,
                domain=domain_name, parent_id="root",
            )
            self._nodes[domain_id] = domain_node
            self._default_ids.add(domain_id)
            root.children.append(domain_id)

            for leaf_name in leaf_names:
                leaf_id = f"{domain_id}.{leaf_name}"
                self._nodes[leaf_id] = TaskNode(
                    leaf_id, leaf_name,
                    domain=domain_name, parent_id=domain_id,
                    is_leaf=True, model_id=f"drone_{domain_id}_{leaf_name}",
                )
                self._default_ids.add(leaf_id)
                domain_node.children.append(leaf_id)

    # ============================================================
    # Core interface
    # ============================================================
    def execute(self, analysis, query: str) -> Optional[str]:
        """
        Route a task analysis to a leaf node and record the hit.

        The target domain is derived from ``analysis.intent`` and overridden
        by the first entry of ``analysis.knowledge_domains`` that has a known
        mapping. The first leaf child of that domain is marked as loaded and
        used.

        Args:
            analysis: Object optionally exposing ``intent`` and
                ``knowledge_domains`` attributes (read with getattr).
            query: The user query; not used by the current implementation.

        Returns:
            Always None at this stage, so that the caller falls back to its
            own path. Actually running the leaf's model is future work.
        """
        self._execute_count += 1
        # A missing attribute or an explicit None both mean "no domains".
        domains = getattr(analysis, 'knowledge_domains', None) or []
        intent = getattr(analysis, 'intent', 'chat')

        target_domain = self._INTENT_DOMAIN_MAP.get(intent, "daily")
        # An exact domain match takes precedence over the intent mapping.
        for d in domains:
            if d in self._DOMAIN_NAME_MAP:
                target_domain = self._DOMAIN_NAME_MAP[d]
                break

        leaf = self._first_leaf(target_domain)
        if leaf is not None:
            self._ensure_loaded(leaf.model_id)
            leaf.use_count += 1
            leaf.last_used = time.time()
            self._hit_count += 1
        return None

    def _first_leaf(self, domain_id: str) -> Optional["TaskNode"]:
        """Return the first existing leaf child of *domain_id*, or None."""
        domain_node = self._nodes.get(domain_id)
        if domain_node is None:
            return None
        for child_id in domain_node.children:
            child = self._nodes.get(child_id)
            # Skip dangling ids and branch children instead of giving up.
            if child is not None and child.is_leaf:
                return child
        return None

    def record_usage(self, analysis) -> None:
        """Increment usage on the domain node of every known knowledge domain.

        Args:
            analysis: Object optionally exposing ``knowledge_domains``.
        """
        domains = getattr(analysis, 'knowledge_domains', None) or []
        for domain in domains:
            node = self._nodes.get(self._DOMAIN_NAME_MAP.get(domain))
            if node is not None:
                node.use_count += 1
                node.last_used = time.time()

    # ============================================================
    # On-demand load management
    # ============================================================
    def _ensure_loaded(self, model_id: str) -> None:
        """Mark *model_id* as loaded, evicting the least recently used model
        when the MAX_LOADED limit is reached."""
        if not model_id:
            # A node without a model must not occupy a load slot.
            return
        now = time.time()
        if model_id in self._loaded_models:
            self._loaded_models[model_id] = now  # refresh recency
            return
        while self._loaded_models and len(self._loaded_models) >= self.MAX_LOADED:
            self._evict_one()
        self._loaded_models[model_id] = now
        self._load_count += 1
        self._mark_loaded(model_id, True)
        logger.debug(f"加载任务模型: {model_id}")

    def _evict_one(self) -> None:
        """Evict the model with the oldest last-use timestamp."""
        if not self._loaded_models:
            return
        oldest = min(self._loaded_models, key=self._loaded_models.get)
        del self._loaded_models[oldest]
        self._unload_count += 1
        self._mark_loaded(oldest, False)
        logger.debug(f"卸载任务模型: {oldest}")

    def _mark_loaded(self, model_id: str, loaded: bool) -> None:
        """Keep the ``loaded`` flag of nodes bound to *model_id* in sync."""
        for node in self._nodes.values():
            if node.model_id == model_id:
                node.loaded = loaded

    def unload_all(self) -> None:
        """Forget every loaded model and clear all ``loaded`` flags."""
        self._loaded_models.clear()
        for node in self._nodes.values():
            node.loaded = False

    # ============================================================
    # Tree operations
    # ============================================================
    def add_node(self, parent_id: str, node_id: str, name: str,
                 is_leaf: bool = False, model_id: str = "") -> bool:
        """Add a node under *parent_id*; the node inherits the parent's domain.

        Returns:
            True on success, False when the parent does not exist or the id
            is already taken.
        """
        if parent_id not in self._nodes:
            logger.warning(f"父节点 {parent_id} 不存在")
            return False
        if node_id in self._nodes:
            logger.warning(f"节点 {node_id} 已存在")
            return False
        parent = self._nodes[parent_id]
        node = TaskNode(
            node_id, name,
            domain=parent.domain,
            parent_id=parent_id,
            is_leaf=is_leaf,
            model_id=model_id,
        )
        # The model may already be loaded on behalf of another node.
        node.loaded = model_id in self._loaded_models
        self._nodes[node_id] = node
        parent.children.append(node_id)
        return True

    def remove_node(self, node_id: str) -> bool:
        """Remove a node together with all of its descendants.

        Returns:
            True when the node was removed, False when it does not exist or
            is the root.
        """
        if node_id == "root" or node_id not in self._nodes:
            return False
        node = self._nodes[node_id]
        # Iterate over a copy: removing a child rewrites node.children.
        for child_id in list(node.children):
            self.remove_node(child_id)
        parent = self._nodes.get(node.parent_id)
        if parent is not None:
            parent.children = [c for c in parent.children if c != node_id]
        del self._nodes[node_id]
        self._release_model(node.model_id)
        return True

    def _release_model(self, model_id: str) -> None:
        """Unload *model_id* once no remaining node refers to it."""
        if model_id not in self._loaded_models:
            return
        if any(n.model_id == model_id for n in self._nodes.values()):
            return
        del self._loaded_models[model_id]
        self._unload_count += 1
        logger.debug(f"卸载任务模型: {model_id}")

    def get_tree(self, node_id: str = "root", depth: int = 0) -> Dict:
        """Return the subtree rooted at *node_id* as nested dicts.

        Args:
            node_id: Id of the subtree root.
            depth: Current recursion depth; carried along but not used to
                limit the traversal.

        Returns:
            A dict with id/name/domain/use_count and, when the node has
            children, a "children" list. Empty dict for an unknown id.
        """
        if node_id not in self._nodes:
            return {}
        node = self._nodes[node_id]
        result = {
            "id": node.node_id,
            "name": node.name,
            "domain": node.domain,
            "use_count": node.use_count,
        }
        if node.children:
            result["children"] = [
                self.get_tree(c, depth + 1) for c in node.children
            ]
        return result

    # ============================================================
    # Persistence
    # ============================================================
    def _filepath(self) -> str:
        """Return the path of this user's JSON file inside DATA_DIR."""
        # NOTE(review): user_id is interpolated into the path unescaped;
        # confirm that callers only pass trusted identifiers.
        return os.path.join(self.DATA_DIR, f"{self.user_id}.json")

    def _load(self) -> None:
        """Merge the user's persisted nodes into the tree (best effort).

        Entries whose id already exists (built-in nodes) only restore their
        ``use_count`` so the existing node keeps its children. Other entries
        are created as new nodes and attached to their parent. Malformed
        entries are skipped individually instead of aborting the whole load.
        """
        filepath = self._filepath()
        if not os.path.exists(filepath):
            return
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            logger.warning(f"任务树加载失败: {e}")
            return
        items = data.get("custom_nodes", []) if isinstance(data, dict) else None
        if not isinstance(items, list):
            logger.warning(f"任务树加载失败: 文件格式无效 {filepath}")
            return

        added = []
        for item in items:
            if not self._is_valid_item(item):
                logger.warning(f"任务树加载: 跳过无效节点 {item!r}")
                continue
            use_count = item.get("use_count", 0)
            if not isinstance(use_count, int):
                use_count = 0
            existing = self._nodes.get(item["node_id"])
            if existing is not None:
                existing.use_count = use_count
                continue
            if "name" not in item:
                logger.warning(f"任务树加载: 跳过无效节点 {item!r}")
                continue
            node = TaskNode(
                node_id=item["node_id"],
                name=item["name"],
                domain=item.get("domain", ""),
                parent_id=item.get("parent_id", "root"),
                is_leaf=item.get("is_leaf", False),
                model_id=item.get("model_id", ""),
            )
            node.use_count = use_count
            self._nodes[node.node_id] = node
            added.append(node)

        # Attach in a second pass so a child listed before its parent still
        # finds it.
        for node in added:
            parent = self._nodes.get(node.parent_id)
            if parent is None:
                logger.warning(
                    f"任务树加载: 节点 {node.node_id} 的父节点 {node.parent_id} 不存在"
                )
                continue
            if node.node_id not in parent.children:
                parent.children.append(node.node_id)
        logger.info(f"任务树加载: +{len(added)}自定义节点")

    @staticmethod
    def _is_valid_item(item) -> bool:
        """Check that a persisted entry has usable id fields."""
        return (
            isinstance(item, dict)
            and isinstance(item.get("node_id"), str)
            and isinstance(item.get("parent_id", "root"), str)
        )

    def _should_persist(self, node) -> bool:
        """Decide whether save() writes *node*.

        User-defined nodes are always written. Built-in nodes are rebuilt on
        start-up, so they are written only when they carry usage to restore.
        The root is never written.
        """
        if node.node_id == "root":
            return False
        return node.node_id not in self._default_ids or node.use_count > 0

    def save(self) -> None:
        """Persist user-defined nodes and usage counts (best effort).

        The file is written to a temporary path and then moved into place,
        so a failed write does not leave a truncated file behind.
        """
        filepath = self._filepath()
        tmp_path = filepath + ".tmp"
        custom_nodes = [
            node.to_dict()
            for node in self._nodes.values()
            if self._should_persist(node)
        ]
        try:
            os.makedirs(self.DATA_DIR, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump({"custom_nodes": custom_nodes}, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, filepath)
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"任务树保存失败: {e}")

    # ============================================================
    # Status
    # ============================================================
    def get_status(self) -> Dict:
        """Return node counts, load bookkeeping and execution statistics."""
        leaf_count = sum(1 for n in self._nodes.values() if n.is_leaf)
        return {
            "user_id": self.user_id,
            "total_nodes": len(self._nodes),
            "leaf_nodes": leaf_count,
            "loaded_models": len(self._loaded_models),
            "max_loaded": self.MAX_LOADED,
            "execute_count": self._execute_count,
            "hit_count": self._hit_count,
            "load_count": self._load_count,
            "unload_count": self._unload_count,
        }