# swarm-backend/core/moa_engine.py
# Uploaded by lk080424 via huggingface_hub (commit 17fba62, verified)
#!/usr/bin/env python3
"""
虫群智能体系统 — MOA多模型聚合引擎
Mixtures of Agents: 路由 → 并行执行 → 聚合
"""
import logging
import os
import time
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional
import requests
from core.types import (
ModelResult, AggregationResult, AggregationMethod,
TaskContext, ComplexityLevel, ModelType
)
from core.model_registry import ModelRegistry, get_registry
from core.router import SwarmRouter
from core.aggregation import AggregationEngine
from core.memory_core import MemoryCore
from core.performance_monitor import get_monitor
from core.conversation import get_conversation_manager
from core.health_checker import get_health_checker
from core.smart_cache import get_cache
from core.config_center import get_config
from core.logging_config import setup_logging
# Initialize the logging system (executed once, when this module is imported).
setup_logging("swarm")
logger = logging.getLogger(__name__)
class MOAEngine:
    """Mixture-of-Agents (MOA) aggregation engine — the core processing pipeline.

    A query flows through: skill-pipeline pre-processing -> routing ->
    parallel model execution -> aggregation -> cache / memory / conversation
    bookkeeping. ``process`` returns one aggregated result; ``process_stream``
    streams the answer of a single API model.
    """

    # Legacy key file; may be overridden with the SWARM_API_ENV environment variable.
    _DEFAULT_API_ENV_PATH = "/home/admin/swarm/api.env"
    # HTTP status codes that _exec_api treats as transient and retries.
    _RETRYABLE_STATUS = (429, 502, 503, 504)
    # Inputs answered by _quick_match without calling any model.
    _GREETINGS = ("你好", "您好", "嗨", "hi", "hello", "早上好", "下午好", "晚上好")
    _IDENTITY_KEYWORDS = ("你是谁", "你叫什么", "自我介绍")
    _ACKNOWLEDGEMENTS = ("好的", "明白", "收到", "ok", "谢谢", "感谢")
    # Trailing characters ignored when matching greetings/acknowledgements.
    _QUICK_MATCH_TRAILING = " \t!！。.~～?？,，"

    def __init__(self, registry: Optional[ModelRegistry] = None):
        """Wire up the engine's collaborators.

        Args:
            registry: model registry to use; defaults to the shared registry
                returned by ``get_registry()``.
        """
        self.registry = registry or get_registry()
        self.router = SwarmRouter(self.registry)
        self.aggregator = AggregationEngine()
        self.memory = MemoryCore()
        # Configuration center — must exist before API keys are loaded.
        self._config = get_config()
        # API keys: configuration center first, then the api.env file.
        self._api_keys = self._load_api_keys()
        # Skill pipeline is created lazily on first use (see ``pipeline``).
        self._pipeline = None
        self._pipeline_init_failed = False
        # Performance monitor.
        self._monitor = get_monitor()
        # Multi-turn conversation context.
        self._conv_mgr = get_conversation_manager()
        # Model health checks / failover.
        self._health = get_health_checker()
        # Response cache.
        self._cache = get_cache()
        logger.info("MOA引擎初始化完成(含配置中心+日志系统)")

    @property
    def pipeline(self):
        """Return the skill pipeline, creating it on first access.

        Returns ``None`` when creation failed. The failure is remembered so
        that the import is not retried (and the warning not re-logged) on
        every request.
        """
        if self._pipeline is None and not getattr(self, "_pipeline_init_failed", False):
            try:
                from skills import create_default_pipeline
                self._pipeline = create_default_pipeline()
                logger.info("技能管线初始化完成")
            except Exception as e:
                self._pipeline_init_failed = True
                logger.warning(f"技能管线初始化失败: {e}")
        return self._pipeline

    # ============================================================
    # Main entry point
    # ============================================================
    def process(self, query: str, user_id: str = "default",
                conversation_id: str = "",
                method: Optional[AggregationMethod] = None) -> AggregationResult:
        """Run the full pipeline for one user query.

        Steps:
          -1. Cache lookup (keyed on the query text only).
           0. Skill pipeline (safety check, text analysis, task
              classification); unsafe queries are refused.
           1. Routing (complexity score + recommended model chain).
           2. Parallel model execution.
           3. Aggregation (``AggregationMethod.CONFIDENCE`` by default).
           4-6. Cache, memory and conversation-history bookkeeping; these
              are best effort — failures are logged, never raised.

        Args:
            query: the user's query text.
            user_id: user identifier used for memory lookup and storage.
            conversation_id: conversation identifier for multi-turn context.
            method: aggregation method; ``None`` means CONFIDENCE.

        Returns:
            The aggregated result.
        """
        start_time = time.time()

        # -1. Cache lookup.
        # NOTE(review): the cache is keyed on the query text only, so a hit may
        # replay an answer produced for another user or conversation; confirm
        # this is acceptable for context-dependent follow-up questions.
        cached = self._cache.get(query)
        if cached:
            logger.info(f"缓存命中: {query[:30]} (命中{cached.hit_count}次)")
            return AggregationResult(
                query=query,
                final_response=cached.response,
                primary_model=f"{cached.model_id}(cached)",
            )

        # 0. Skill pipeline; unsafe requests are refused outright.
        skill_analysis = self._run_skill_pipeline(query, user_id, conversation_id)
        if skill_analysis and not skill_analysis.get("is_safe", True):
            return AggregationResult(
                query=query,
                final_response="抱歉,您的请求包含不安全内容,无法处理。",
                primary_model="safety_filter",
            )

        # 1. Routing, adjusted by the skill analysis.
        ctx = self.router.analyze(query, user_id, conversation_id)
        self._apply_skill_analysis(ctx, skill_analysis)
        logger.info(f"MOA路由: 复杂度={ctx.complexity_score:.2f} "
                    f"模型链={ctx.model_chain}")

        if not ctx.model_chain:
            # No model available: answer from memory only.
            mem_ctx = self.memory.get_relevant_context(query, user_id, top_k=3)
            return AggregationResult(
                query=query,
                final_response=mem_ctx or "暂无可用模型处理您的请求。",
                primary_model="memory_only",
            )

        # 2. Parallel execution.
        results = self._execute_models(ctx)

        # 3. Aggregation.
        agg_result = self.aggregator.aggregate(
            query, results, method=method or AggregationMethod.CONFIDENCE
        )

        # 4. Cache the answer.
        try:
            if agg_result.final_response and agg_result.primary_model:
                self._cache.put(
                    query, agg_result.final_response,
                    agg_result.primary_model, agg_result.primary_confidence
                )
        except Exception as e:
            logger.warning(f"缓存存储失败: {e}")

        # 5. Store to memory. This call is synchronous (it does block the
        #    caller); failures are only logged.
        try:
            self.memory.store(
                user_id=user_id,
                conversation_id=conversation_id,
                title=query[:50],
                user_message=query,
                ai_response=agg_result.final_response,
            )
        except Exception as e:
            logger.warning(f"记忆存储失败: {e}")

        # 6. Append the turn to the conversation history.
        self._record_conversation(query, agg_result.final_response,
                                  conversation_id, agg_result.primary_model)

        elapsed_ms = (time.time() - start_time) * 1000
        logger.info(f"MOA处理完成: {elapsed_ms:.0f}ms 模型数={len(results)} "
                    f"主模型={agg_result.primary_model}")
        return agg_result

    def _apply_skill_analysis(self, ctx: TaskContext,
                              skill_analysis: Optional[dict]) -> None:
        """Attach the skill analysis to ``ctx`` and adjust routing.

        For "code" / "analysis" tasks whose complexity score is below 0.50,
        the score is raised by 0.15 (capped at 0.70), the complexity level is
        recomputed and the model chain re-recommended. The adjustment depends
        only on the analysis, not on whether ``ctx.metadata`` is available.
        """
        if not skill_analysis:
            return
        if ctx.metadata is not None:
            ctx.metadata["skill_analysis"] = skill_analysis
        task_cat = skill_analysis.get("task_category", "")
        if task_cat in ("code", "analysis") and ctx.complexity_score < 0.50:
            ctx.complexity_score = min(ctx.complexity_score + 0.15, 0.70)
            ctx.compute_complexity_level()
            ctx.model_chain = self.router._recommend_models(ctx)
            logger.info(f"技能调整: {task_cat} → 复杂度提升至 {ctx.complexity_score:.2f}")

    def _record_conversation(self, query: str, response: str,
                             conversation_id: str, model: str) -> None:
        """Append a user/assistant turn to the conversation history (best effort)."""
        try:
            self._conv_mgr.add_message(
                "user", query, conversation_id=conversation_id)
            self._conv_mgr.add_message(
                "assistant", response,
                conversation_id=conversation_id,
                model=model,
            )
        except Exception as e:
            logger.warning(f"对话记录失败: {e}")

    # ============================================================
    # Model execution
    # ============================================================
    def _execute_models(self, ctx: TaskContext) -> List[ModelResult]:
        """Run every registered model in ``ctx.model_chain`` in parallel.

        Model ids unknown to the registry are skipped with a warning. A model
        that raises yields a failed ``ModelResult``; a failure while
        recording metrics never turns a successful result into a failure.
        Results are returned in completion order.
        """
        results: List[ModelResult] = []
        model_ids = list(ctx.model_chain or [])
        if not model_ids:
            # ThreadPoolExecutor rejects max_workers=0.
            return results

        with ThreadPoolExecutor(max_workers=len(model_ids)) as executor:
            future_map = {}
            for mid in model_ids:
                model = self.registry.get(mid)
                if not model:
                    logger.warning(f"模型 {mid} 未注册,跳过")
                    continue
                runner = self._exec_local if model.is_local else self._exec_api
                future_map[executor.submit(runner, model, ctx)] = mid

            for future in as_completed(future_map):
                mid = future_map[future]
                try:
                    result = future.result()
                except Exception as e:
                    logger.error(f"模型 {mid} 执行异常: {e}")
                    results.append(ModelResult(
                        model_id=mid, response="",
                        success=False, error=str(e)
                    ))
                    continue
                results.append(result)
                self._record_metrics(mid, result, ctx)
        return results

    def _record_metrics(self, mid: str, result: ModelResult,
                        ctx: TaskContext) -> None:
        """Report one model result to registry, monitor and health checker (best effort)."""
        try:
            self.registry.record_performance(
                mid, result.latency_ms, result.confidence, result.success
            )
            self._monitor.record(
                model_id=mid,
                latency_ms=result.latency_ms,
                confidence=result.confidence,
                success=result.success,
                query=ctx.query[:100],
                error=result.error or "",
            )
            if result.success:
                self._health.record_success(mid, result.latency_ms)
            else:
                self._health.record_failure(mid)
        except Exception as e:
            logger.warning(f"模型 {mid} 性能记录失败: {e}")

    def _exec_local(self, model, ctx: TaskContext) -> ModelResult:
        """Run a local model: local inference -> quick rules -> memory retrieval.

        The first stage that produces an answer wins. Never raises: any
        exception becomes a failed ``ModelResult``.
        """
        start = time.time()
        try:
            # 0. Trained local inference model, if one matches model_id.
            local_result = self._exec_local_inference(model.model_id, ctx.query)
            if local_result:
                return local_result

            # 1. Rule-based quick answers for common trivial inputs.
            quick = self._quick_match(ctx.query)
            if quick:
                return ModelResult(
                    model_id=model.model_id,
                    response=quick,
                    confidence=0.6,
                    latency_ms=(time.time() - start) * 1000,
                    success=True,
                )

            # 2. Memory retrieval: join up to three stored (non-empty) answers.
            memories = self.memory.retrieve(ctx.query, user_id=ctx.user_id, top_k=5)
            parts = [m.get("ai_response", "") for m in (memories or [])[:3]]
            parts = [p for p in parts if p]
            if parts:
                response = "\n".join(parts)
                # More retrieved memories -> more confidence, capped at 0.7.
                confidence = min(0.3 + len(memories) * 0.1, 0.7)
            else:
                response = "暂无相关记忆,建议使用API模型获取更准确的回复。"
                confidence = 0.1

            return ModelResult(
                model_id=model.model_id,
                response=response,
                confidence=confidence,
                latency_ms=(time.time() - start) * 1000,
                success=bool(response),
            )
        except Exception as e:
            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False, error=str(e),
                latency_ms=(time.time() - start) * 1000,
            )

    @staticmethod
    def _match_local_backend(model_id: str, available) -> Optional[str]:
        """Map a registry model id (e.g. "swarm_tiny") to a local backend name ("tiny").

        An exact ``<name>`` / ``swarm_<name>`` match is preferred, so that
        "swarm_tiny_v2" picks "tiny_v2" rather than "tiny". Otherwise the
        first name satisfying the legacy substring rule is used.
        """
        for name in available:
            if model_id in (name, f"swarm_{name}"):
                return name
        for name in available:
            if name in model_id or model_id in f"swarm_{name}":
                return name
        return None

    def _exec_local_inference(self, model_id: str, query: str) -> Optional[ModelResult]:
        """Try the trained local SwarmModel backend for ``model_id``.

        Returns ``None`` when the backend is unavailable, no backend matches
        ``model_id``, or inference failed or produced no text; the caller
        then falls through to its other stages.
        """
        start = time.time()
        try:
            from core.local_inference import get_local_backend
            backend = get_local_backend()
            target = self._match_local_backend(model_id, list(backend.list_available() or []))
            if not target:
                return None
            result = backend.infer(target, query, max_new_tokens=128)
            response = result.get("response") or ""
            if result.get("success") and response:
                return ModelResult(
                    model_id=model_id,
                    response=response,
                    confidence=0.7,  # fixed confidence for local models
                    latency_ms=result.get("latency_ms", (time.time() - start) * 1000),
                    success=True,
                )
        except Exception as e:
            logger.debug(f"本地推理不可用({model_id}): {e}")
        return None

    @staticmethod
    def _auth_headers(api_key: str) -> Dict[str, str]:
        """Build headers for a Bearer-authenticated JSON request."""
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def _build_messages(self, ctx: TaskContext) -> List[dict]:
        """Assemble chat messages for the current query.

        Order: an optional system message carrying the first 500 characters
        of relevant memory, then the conversation history, then the current
        user query as the last message.
        """
        mem_ctx = self.memory.get_relevant_context(ctx.query, ctx.user_id, top_k=3)
        conv_messages = self._conv_mgr.get_context(ctx.conversation_id or "", max_tokens=2000)
        messages = []
        if mem_ctx:
            messages.append({
                "role": "system",
                "content": f"以下是相关历史上下文,供参考:\n{mem_ctx[:500]}"
            })
        if conv_messages:
            messages.extend(conv_messages)
        messages.append({"role": "user", "content": ctx.query})
        return messages

    @staticmethod
    def _first_choice(data) -> dict:
        """Return the first entry of an OpenAI-style ``choices`` list, or ``{}``.

        Tolerates a missing or empty ``choices`` list (some providers send
        such chunks, e.g. usage-only stream chunks).
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            return choices[0]
        return {}

    def _exec_api(self, model, ctx: TaskContext,
                  _visited: Optional[set] = None) -> ModelResult:
        """Call an OpenAI-compatible chat-completions endpoint for ``model``.

        Transient failures (HTTP 429/502/503/504, timeouts, connection
        errors) are retried up to 3 attempts with 1 s / 2 s back-off. If all
        attempts fail, the health checker's fallback model (if any) is tried.
        Local fallbacks are run through ``_exec_local``. ``_visited`` holds
        the model ids already tried in this failover chain, so a cyclic
        fallback configuration cannot recurse forever.

        Args:
            model: registry entry (model_id, endpoint, max_tokens,
                default_confidence).
            ctx: task context (query, user_id, conversation_id).
            _visited: internal — model ids already attempted in this chain.

        Returns:
            A ``ModelResult``. HTTP and network errors are reported in it
            rather than raised.
        """
        start = time.time()
        api_key = self._api_keys.get(model.model_id, "")
        if not api_key:
            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False,
                error=f"API密钥未配置: {model.model_id}",
                latency_ms=(time.time() - start) * 1000,
            )

        headers = self._auth_headers(api_key)
        payload = {
            "model": self._get_api_model_name(model.model_id),
            "messages": self._build_messages(ctx),
            "max_tokens": model.max_tokens,
        }

        max_retries = 3
        error_msg = ""
        try:
            for attempt in range(max_retries):
                can_retry = attempt < max_retries - 1
                wait = 2 ** attempt  # 1s, 2s — no sleep after the last attempt
                try:
                    resp = requests.post(
                        model.endpoint,
                        headers=headers,
                        json=payload,
                        timeout=30,
                    )
                except requests.exceptions.Timeout:
                    if can_retry:
                        logger.warning(f"{model.model_id} 超时,{wait}秒后重试({attempt+1}/{max_retries})")
                        time.sleep(wait)
                        continue
                    error_msg = f"API超时 (重试{max_retries}次后仍超时)"
                    break
                except requests.exceptions.ConnectionError:
                    if can_retry:
                        logger.warning(f"{model.model_id} 连接失败,{wait}秒后重试({attempt+1}/{max_retries})")
                        time.sleep(wait)
                        continue
                    error_msg = f"API连接失败 (重试{max_retries}次后仍失败)"
                    break

                if resp.status_code == 200:
                    choice = self._first_choice(resp.json())
                    content = (choice.get("message") or {}).get("content") or ""
                    return ModelResult(
                        model_id=model.model_id,
                        response=content,
                        confidence=model.default_confidence,
                        latency_ms=(time.time() - start) * 1000,
                        success=bool(content),
                    )
                if resp.status_code in self._RETRYABLE_STATUS:
                    if can_retry:
                        logger.warning(f"{model.model_id} 返回 {resp.status_code},{wait}秒后重试({attempt+1}/{max_retries})")
                        time.sleep(wait)
                        continue
                    error_msg = f"API错误 {resp.status_code} (重试{max_retries}次后仍失败)"
                else:
                    # Non-transient errors (e.g. 4xx) are not retried.
                    error_msg = f"API错误 {resp.status_code}: {resp.text[:200]}"
                break

            logger.error(f"{model.model_id}: {error_msg}")
            latency = (time.time() - start) * 1000

            # Failover to the health checker's fallback, never revisiting a model.
            visited = set(_visited or ())
            visited.add(model.model_id)
            fallback_id = self._health.get_fallback(model.model_id)
            if fallback_id and fallback_id not in visited:
                fallback_model = self.registry.get(fallback_id)
                if fallback_model:
                    logger.info(f"故障转移到 {fallback_id}")
                    if fallback_model.is_local:
                        return self._exec_local(fallback_model, ctx)
                    return self._exec_api(fallback_model, ctx, _visited=visited)

            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False, error=error_msg,
                latency_ms=latency,
            )
        except Exception as e:
            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False, error=str(e),
                latency_ms=(time.time() - start) * 1000,
            )

    # ============================================================
    # Streaming output
    # ============================================================
    def process_stream(self, query: str, user_id: str = "default",
                       conversation_id: str = ""):
        """Stream the answer to a user query.

        Generator yielding ``(chunk_text, metadata)`` tuples, where metadata
        is ``{"model": <id>, "done": <bool>}``. Only the first non-local
        model in the routed chain is used (no aggregation, no cache). After
        a successful stream, the streamed content (error texts excluded) is
        stored to memory and appended to the conversation history.
        """
        # 0. Skill pipeline; unsafe requests are refused outright.
        skill_analysis = self._run_skill_pipeline(query, user_id, conversation_id)
        if skill_analysis and not skill_analysis.get("is_safe", True):
            yield ("抱歉,您的请求包含不安全内容,无法处理。",
                   {"model": "safety_filter", "done": True})
            return

        # 1. Rule-based quick answer.
        quick = self._quick_match(query)
        if quick:
            yield (quick, {"model": "local_memory", "done": True})
            return

        # 2. Routing, adjusted by the skill analysis.
        ctx = self.router.analyze(query, user_id, conversation_id)
        self._apply_skill_analysis(ctx, skill_analysis)

        # 3. Pick the first API (non-local) model of the chain.
        api_model = None
        for mid in ctx.model_chain or []:
            candidate = self.registry.get(mid)
            if candidate and not candidate.is_local:
                api_model = candidate
                break

        if api_model is None:
            # No API model: answer from memory only.
            mem_ctx = self.memory.get_relevant_context(query, user_id, top_k=3)
            yield (mem_ctx or "暂无可用模型处理您的请求。",
                   {"model": "memory_only", "done": True})
            return

        # 4. Stream. Content arrives with done=False; _stream_api reports
        #    errors as a final done=True chunk, which is shown but not stored.
        content_parts = []
        for chunk_text, done in self._stream_api(api_model, ctx):
            if not done:
                content_parts.append(chunk_text)
            yield (chunk_text, {"model": api_model.model_id, "done": done})

        full_response = "".join(content_parts)
        if not full_response:
            return

        # 5. Store to memory and conversation history (best effort).
        try:
            self.memory.store(
                user_id=user_id,
                conversation_id=conversation_id,
                title=query[:50],
                user_message=query,
                ai_response=full_response,
            )
        except Exception as e:
            logger.warning(f"流式记忆存储失败: {e}")
        self._record_conversation(query, full_response, conversation_id,
                                  api_model.model_id)

    def _stream_api(self, model, ctx: TaskContext):
        """Stream an OpenAI-compatible completion, yielding ``(chunk_text, is_done)``.

        Content chunks are yielded with ``is_done=False``. The stream ends
        with ``("", True)`` on success, or with a single ``(error_text, True)``
        on failure. The HTTP response is always closed, including when the
        consumer stops iterating early.
        """
        api_key = self._api_keys.get(model.model_id, "")
        if not api_key:
            yield ("API密钥未配置", True)
            return

        headers = self._auth_headers(api_key)
        payload = {
            "model": self._get_api_model_name(model.model_id),
            "messages": self._build_messages(ctx),
            "max_tokens": model.max_tokens,
            "stream": True,  # key: enable server-side streaming
        }

        resp = None
        try:
            resp = requests.post(
                model.endpoint,
                headers=headers,
                json=payload,
                timeout=60,
                stream=True,  # HTTP streaming response
            )
            if resp.status_code != 200:
                yield (f"API错误 {resp.status_code}", True)
                return

            # SSE streams are UTF-8 by spec. Without a charset, requests would
            # decode text/* as ISO-8859-1 and garble CJK text.
            resp.encoding = "utf-8"
            for line in resp.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data:"):
                    continue
                data_str = line[5:].strip()
                if data_str == "[DONE]":
                    yield ("", True)
                    return
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                delta = self._first_choice(data).get("delta") or {}
                content = delta.get("content") or ""
                if content:
                    yield (content, False)
            yield ("", True)
        except requests.exceptions.Timeout:
            yield ("请求超时,请稍后重试", True)
        except Exception as e:
            yield (f"流式调用异常: {str(e)[:50]}", True)
        finally:
            if resp is not None:
                resp.close()

    # ============================================================
    # Helpers
    # ============================================================
    def _load_api_keys(self) -> Dict[str, str]:
        """Load API keys from the configuration center, then the api.env file.

        Config-center keys (``api_keys`` mapping) win. From the env file,
        only the first non-empty ``*GLM*KEY*`` entry is used, as ``glm_api``;
        surrounding quotes are removed. The file path defaults to
        ``_DEFAULT_API_ENV_PATH`` and can be overridden with ``SWARM_API_ENV``.
        """
        keys: Dict[str, str] = {}

        try:
            config_keys = self._config.get("api_keys", {})
            if config_keys:
                keys.update(config_keys)
                logger.info(f"从配置中心加载{len(config_keys)}个API密钥")
        except Exception as e:
            logger.warning(f"配置中心API密钥读取失败: {e}")

        env_path = os.environ.get("SWARM_API_ENV", self._DEFAULT_API_ENV_PATH)
        if os.path.exists(env_path):
            try:
                with open(env_path, "r", encoding="utf-8") as f:
                    for raw in f:
                        line = raw.strip()
                        if not line or line.startswith("#") or "=" not in line:
                            continue
                        k, v = line.split("=", 1)
                        k, v = k.strip(), v.strip()
                        if len(v) >= 2 and v[0] == v[-1] and v[0] in "\"'":
                            v = v[1:-1]
                        if v and "GLM" in k and "KEY" in k:
                            keys.setdefault("glm_api", v)
            except Exception as e:
                logger.warning(f"加载API密钥失败: {e}")
        return keys

    def _get_api_model_name(self, model_id: str) -> str:
        """Map a registry model id to the provider's model name (identity if unknown)."""
        names = {
            "glm_api": "glm-4-flash",
        }
        return names.get(model_id, model_id)

    def _quick_match(self, query: str) -> str:
        """Answer trivial inputs (greetings, identity, acknowledgements) by rule.

        Greetings and acknowledgements are matched case-insensitively and
        ignore trailing punctuation ("Hello!" and "OK" match). Identity
        questions match by substring. Returns "" when nothing matches.
        """
        q = query.strip()
        normalized = q.rstrip(self._QUICK_MATCH_TRAILING).lower()

        if normalized in self._GREETINGS:
            return "你好!我是虫群智能体,有什么可以帮助你的吗?"
        if any(kw in q for kw in self._IDENTITY_KEYWORDS):
            return "我是虫群智能体,基于多模型聚合架构,可以为你提供智能问答服务。"
        if normalized in self._ACKNOWLEDGEMENTS:
            return "不客气!如有其他问题随时问我。"
        return ""

    def _run_skill_pipeline(self, query: str, user_id: str = "default",
                            conversation_id: str = "") -> Optional[dict]:
        """Run the skill pipeline and condense its output into a summary dict.

        Returns ``None`` when no pipeline is available or it raised.
        Otherwise returns a dict with ``is_safe`` (default True),
        ``task_category`` and ``text_type`` (default ""). It also carries
        ``keywords`` (first 5) / ``sentiment`` when the text_parser skill
        succeeded, and ``task_priority`` when the task_classifier succeeded.
        """
        pipeline = self.pipeline
        if not pipeline:
            return None
        try:
            results = pipeline.execute(
                query, user_id=user_id, session_id=conversation_id
            )

            def _payload(name):
                # A skill's result payload, only if it ran successfully.
                if name not in results:
                    return None
                item = results[name]
                return item.result if item.success and item.result else None

            analysis = {"is_safe": True, "task_category": "", "text_type": ""}

            safety = _payload("safety_filter")
            if safety:
                analysis["is_safe"] = safety.get("is_safe", True)

            parsed = _payload("text_parser")
            if parsed:
                analysis["text_type"] = parsed.get("text_type", "")
                analysis["keywords"] = parsed.get("keywords", [])[:5]
                analysis["sentiment"] = parsed.get("sentiment", "")

            task = _payload("task_classifier")
            if task:
                analysis["task_category"] = task.get("category", "")
                analysis["task_priority"] = task.get("priority", 1)

            return analysis
        except Exception as e:
            logger.warning(f"技能管线执行失败: {e}")
            return None