Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群智能体系统 — MOA多模型聚合引擎 | |
| Mixtures of Agents: 路由 → 并行执行 → 聚合 | |
| """ | |
| import logging | |
| import os | |
| import time | |
| import json | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from typing import Dict, List, Optional | |
| import requests | |
| from core.types import ( | |
| ModelResult, AggregationResult, AggregationMethod, | |
| TaskContext, ComplexityLevel, ModelType | |
| ) | |
| from core.model_registry import ModelRegistry, get_registry | |
| from core.router import SwarmRouter | |
| from core.aggregation import AggregationEngine | |
| from core.memory_core import MemoryCore | |
| from core.performance_monitor import get_monitor | |
| from core.conversation import get_conversation_manager | |
| from core.health_checker import get_health_checker | |
| from core.smart_cache import get_cache | |
| from core.config_center import get_config | |
| from core.logging_config import setup_logging | |
# Initialise the project's logging setup for the "swarm" service at import
# time, then obtain this module's logger.
setup_logging("swarm")
logger = logging.getLogger(name=__name__)
class MOAEngine:
    """MOA (Mixture of Agents) multi-model aggregation engine.

    Core pipeline: skill pre-processing -> routing -> parallel model
    execution -> aggregation -> cache / memory / conversation bookkeeping.
    """

    # Registry model id -> model name sent in the API payload.
    _API_MODEL_NAMES: Dict[str, str] = {
        "glm_api": "glm-4-flash",
    }

    # Default api.env location. It can be overridden via the environment
    # variable named by _API_ENV_PATH_VAR.
    _DEFAULT_API_ENV_PATH = "/home/admin/swarm/api.env"
    _API_ENV_PATH_VAR = "SWARM_API_ENV"

    # HTTP status codes that _exec_api treats as transient and retries.
    _RETRYABLE_STATUS = (429, 502, 503, 504)
    _API_MAX_RETRIES = 3
    _API_TIMEOUT_S = 30
    _STREAM_TIMEOUT_S = 60

    # Quick-match vocabulary. It is compared after stripping surrounding
    # whitespace and trailing punctuation, case-insensitively for ASCII.
    _GREETINGS = frozenset({"你好", "您好", "嗨", "hi", "hello", "早上好", "下午好", "晚上好"})
    _IDENTITY_KEYWORDS = ("你是谁", "你叫什么", "自我介绍")
    _ACKNOWLEDGEMENTS = frozenset({"好的", "明白", "收到", "ok", "谢谢", "感谢"})
    _TRAILING_PUNCTUATION = "!?.~,!?。~,、 \t\r\n"
    # An identity keyword triggers the canned reply only if the query is at
    # most this many characters longer than the keyword. Longer requests
    # (e.g. asking the bot to write a self-introduction) go to a real model.
    _IDENTITY_MAX_EXTRA_CHARS = 4

    def __init__(self, registry: Optional["ModelRegistry"] = None):
        """Wire up routing, aggregation, memory and the shared singletons.

        Args:
            registry: Model registry to use. Defaults to ``get_registry()``.
        """
        self.registry = registry or get_registry()
        self.router = SwarmRouter(self.registry)
        self.aggregator = AggregationEngine()
        self.memory = MemoryCore()
        # The config center must exist before API keys are loaded, because
        # keys are read from it first.
        self._config = get_config()
        self._api_keys = self._load_api_keys()
        # The skill pipeline is built lazily on first access (see `pipeline`).
        self._pipeline = None
        self._monitor = get_monitor()
        self._conv_mgr = get_conversation_manager()
        self._health = get_health_checker()
        self._cache = get_cache()
        logger.info("MOA引擎初始化完成(含配置中心+日志系统)")

    @property
    def pipeline(self):
        """Lazily built skill pipeline, or None if it could not be created.

        Creation is attempted again on every access while it keeps failing.
        It must be a property because callers read ``self.pipeline`` as an
        attribute.
        """
        if self._pipeline is None:
            try:
                from skills import create_default_pipeline
                self._pipeline = create_default_pipeline()
                logger.info("技能管线初始化完成")
            except Exception as e:
                logger.warning(f"技能管线初始化失败: {e}")
        return self._pipeline

    # ============================================================
    # Main entry point
    # ============================================================
    def process(self, query: str, user_id: str = "default",
                conversation_id: str = "",
                method: Optional["AggregationMethod"] = None) -> "AggregationResult":
        """Run the full (non-streaming) pipeline for one user query.

        Steps: cache lookup -> skill pipeline (safety check, text analysis,
        task classification) -> routing -> parallel model execution ->
        aggregation -> cache / memory / conversation-history bookkeeping.

        Args:
            query: The user's message.
            user_id: User the memory records are stored under.
            conversation_id: Conversation the exchange is appended to.
            method: Aggregation strategy. ``AggregationMethod.CONFIDENCE``
                is used when None.

        Returns:
            The aggregated result. Three cases return a synthetic result
            instead, with ``primary_model`` set as follows:
            - cache hit: ``"<model>(cached)"``
            - safety rejection: ``"safety_filter"``
            - no model available: ``"memory_only"``
        """
        start_time = time.time()

        # The cache is only an optimisation; a failing cache must not fail
        # the request.
        try:
            cached = self._cache.get(query)
        except Exception as e:
            logger.warning(f"缓存读取失败: {e}")
            cached = None
        if cached:
            logger.info(f"缓存命中: {query[:30]} (命中{cached.hit_count}次)")
            return AggregationResult(
                query=query,
                final_response=cached.response,
                primary_model=f"{cached.model_id}(cached)",
            )

        skill_analysis = self._run_skill_pipeline(query, user_id, conversation_id)
        if skill_analysis and not skill_analysis.get("is_safe", True):
            return AggregationResult(
                query=query,
                final_response="抱歉,您的请求包含不安全内容,无法处理。",
                primary_model="safety_filter",
            )

        ctx = self.router.analyze(query, user_id, conversation_id)
        self._apply_skill_analysis(ctx, skill_analysis)
        logger.info(f"MOA路由: 复杂度={ctx.complexity_score:.2f} "
                    f"模型链={ctx.model_chain}")

        if not ctx.model_chain:
            # No model available: answer from memory alone.
            mem_ctx = self.memory.get_relevant_context(query, user_id, top_k=3)
            return AggregationResult(
                query=query,
                final_response=mem_ctx or "暂无可用模型处理您的请求。",
                primary_model="memory_only",
            )

        results = self._execute_models(ctx)
        agg_result = self.aggregator.aggregate(
            query, results, method=method or AggregationMethod.CONFIDENCE
        )

        try:
            if agg_result.final_response and agg_result.primary_model:
                self._cache.put(
                    query, agg_result.final_response,
                    agg_result.primary_model, agg_result.primary_confidence
                )
        except Exception as e:
            logger.warning(f"缓存存储失败: {e}")

        self._remember_exchange(
            query, agg_result.final_response, user_id, conversation_id,
            agg_result.primary_model,
        )

        elapsed_ms = (time.time() - start_time) * 1000
        logger.info(f"MOA处理完成: {elapsed_ms:.0f}ms 模型数={len(results)} "
                    f"主模型={agg_result.primary_model}")
        return agg_result

    def _apply_skill_analysis(self, ctx: "TaskContext",
                              skill_analysis: Optional[dict]) -> None:
        """Attach the skill analysis to ``ctx`` and raise complexity for heavy tasks.

        For "code" / "analysis" tasks routed below 0.50 complexity, the score
        is raised by 0.15 (capped at 0.70) and the model chain is recomputed.
        Does nothing when there is no analysis or ``ctx.metadata`` is None.
        """
        if not skill_analysis or ctx.metadata is None:
            return
        ctx.metadata["skill_analysis"] = skill_analysis
        task_cat = skill_analysis.get("task_category", "")
        if task_cat not in ("code", "analysis"):
            return
        if not ctx.complexity_score < 0.50:
            return
        ctx.complexity_score = min(ctx.complexity_score + 0.15, 0.70)
        ctx.compute_complexity_level()
        ctx.model_chain = self.router._recommend_models(ctx)
        logger.info(f"技能调整: {task_cat} → 复杂度提升至 {ctx.complexity_score:.2f}")

    def _remember_exchange(self, query: str, response: str, user_id: str,
                           conversation_id: str, model_id: str,
                           memory_error_prefix: str = "记忆存储失败") -> None:
        """Best-effort persistence of one exchange to memory and conversation history.

        Each failure is logged as a warning and swallowed, so bookkeeping
        never fails a request.
        """
        try:
            self.memory.store(
                user_id=user_id,
                conversation_id=conversation_id,
                title=query[:50],
                user_message=query,
                ai_response=response,
            )
        except Exception as e:
            logger.warning(f"{memory_error_prefix}: {e}")
        try:
            self._conv_mgr.add_message(
                "user", query, conversation_id=conversation_id)
            self._conv_mgr.add_message(
                "assistant", response,
                conversation_id=conversation_id,
                model=model_id,
            )
        except Exception as e:
            logger.warning(f"对话记录失败: {e}")

    # ============================================================
    # Model execution
    # ============================================================
    def _execute_models(self, ctx: "TaskContext") -> List["ModelResult"]:
        """Run every model in ``ctx.model_chain`` concurrently and collect results.

        Ids unknown to the registry are skipped. A model whose execution
        raises yields a failed ModelResult instead of propagating the error.
        Results are returned in completion order, not chain order.
        """
        model_ids = list(ctx.model_chain or [])
        if not model_ids:
            # ThreadPoolExecutor rejects max_workers=0.
            return []

        results: List["ModelResult"] = []
        with ThreadPoolExecutor(max_workers=len(model_ids)) as executor:
            future_map = {}
            for mid in model_ids:
                model = self.registry.get(mid)
                if not model:
                    continue
                runner = self._exec_local if model.is_local else self._exec_api
                future_map[executor.submit(runner, model, ctx)] = mid

            for future in as_completed(future_map):
                mid = future_map[future]
                try:
                    result = future.result()
                except Exception as e:
                    logger.error(f"模型 {mid} 执行异常: {e}")
                    results.append(ModelResult(
                        model_id=mid, response="",
                        success=False, error=str(e)
                    ))
                    continue
                results.append(result)
                # Telemetry is recorded outside the try above. A monitor or
                # health-checker failure must not add a second, bogus
                # "failed" result for a model that already succeeded.
                self._record_model_outcome(mid, result, ctx)
        return results

    def _record_model_outcome(self, mid: str, result: "ModelResult",
                              ctx: "TaskContext") -> None:
        """Send one result to the registry, performance monitor and health checker.

        Best-effort: a failure here is logged and otherwise ignored.
        """
        try:
            self.registry.record_performance(
                mid, result.latency_ms, result.confidence, result.success
            )
            self._monitor.record(
                model_id=mid,
                latency_ms=result.latency_ms,
                confidence=result.confidence,
                success=result.success,
                query=ctx.query[:100],
                error=result.error or "",
            )
            if result.success:
                self._health.record_success(mid, result.latency_ms)
            else:
                self._health.record_failure(mid)
        except Exception as e:
            logger.warning(f"模型 {mid} 性能记录失败: {e}")

    def _exec_local(self, model, ctx: "TaskContext") -> "ModelResult":
        """Answer with a local model: inference backend, then quick rules, then memory.

        The first strategy that produces an answer wins. Any exception is
        turned into a failed ModelResult.
        """
        start = time.time()
        try:
            local_result = self._exec_local_inference(model.model_id, ctx.query)
            if local_result:
                return local_result

            quick = self._quick_match(ctx.query)
            if quick:
                return ModelResult(
                    model_id=model.model_id,
                    response=quick,
                    confidence=0.6,
                    latency_ms=(time.time() - start) * 1000,
                    success=True,
                )

            memories = self.memory.retrieve(ctx.query, user_id=ctx.user_id, top_k=5)
            # Skip blank stored answers so they cannot produce a
            # whitespace-only "successful" reply.
            recalled = "\n".join(
                text
                for text in (m.get("ai_response", "") for m in (memories or [])[:3])
                if text
            )
            if recalled:
                response = recalled
                confidence = min(0.3 + len(memories) * 0.1, 0.7)
            else:
                response = "暂无相关记忆,建议使用API模型获取更准确的回复。"
                confidence = 0.1

            return ModelResult(
                model_id=model.model_id,
                response=response,
                confidence=confidence,
                latency_ms=(time.time() - start) * 1000,
                success=bool(response),
            )
        except Exception as e:
            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False, error=str(e),
                latency_ms=(time.time() - start) * 1000,
            )

    def _exec_local_inference(self, model_id: str, query: str) -> Optional["ModelResult"]:
        """Try the local trained-model backend (``core.local_inference``) for ``model_id``.

        Returns None, so the caller falls back to rules or memory, when:
        - the backend cannot be imported,
        - no backend model matches ``model_id``, or
        - inference fails or returns an empty reply.
        """
        try:
            from core.local_inference import get_local_backend
            backend = get_local_backend()
            available = list(backend.list_available())
            # A model_id such as "swarm_tiny" maps to backend name "tiny".
            # Exact matches win. The original loose substring rule is kept
            # as a fallback so ids like "local_tiny" still resolve.
            target = next(
                (name for name in available if model_id in (name, f"swarm_{name}")),
                None,
            )
            if not target:
                target = next(
                    (name for name in available
                     if name in model_id or model_id in f"swarm_{name}"),
                    None,
                )
            if not target:
                return None

            result = backend.infer(target, query, max_new_tokens=128)
            if result.get("success") and result.get("response"):
                return ModelResult(
                    model_id=model_id,
                    response=result["response"],
                    confidence=0.7,  # fixed confidence for local-model answers
                    latency_ms=result["latency_ms"],
                    success=True,
                )
        except Exception as e:
            logger.debug(f"本地推理不可用({model_id}): {e}")
        return None

    def _exec_api(self, model, ctx: "TaskContext",
                  _visited: Optional[set] = None) -> "ModelResult":
        """Call an OpenAI-compatible chat-completions endpoint for ``model``.

        Transient failures (HTTP 429/502/503/504, timeouts, connection errors)
        are retried up to ``_API_MAX_RETRIES`` attempts with exponential
        back-off. If every attempt fails, the health checker's fallback model
        (if any) is tried next. ``_visited`` holds the model ids already tried
        in this fail-over chain, so a cyclic fallback mapping (A -> B -> A)
        cannot recurse forever.

        Returns:
            A ModelResult. Errors are reported in the result, never raised.
        """
        start = time.time()
        api_key = self._api_keys.get(model.model_id, "")
        if not api_key:
            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False,
                error=f"API密钥未配置: {model.model_id}",
                latency_ms=(time.time() - start) * 1000,
            )

        try:
            payload = {
                "model": self._get_api_model_name(model.model_id),
                "messages": self._build_messages(ctx),
                "max_tokens": model.max_tokens,
            }
            headers = self._auth_headers(api_key)
            max_retries = self._API_MAX_RETRIES
            error_msg = ""

            for attempt in range(max_retries):
                retries_left = attempt < max_retries - 1
                try:
                    resp = requests.post(
                        model.endpoint,
                        headers=headers,
                        json=payload,
                        timeout=self._API_TIMEOUT_S,
                    )
                except requests.exceptions.Timeout:
                    if retries_left:
                        self._backoff(model.model_id, "超时", attempt, max_retries)
                        continue
                    error_msg = f"API超时 (重试{max_retries}次后仍超时)"
                    break
                except requests.exceptions.ConnectionError:
                    if retries_left:
                        self._backoff(model.model_id, "连接失败", attempt, max_retries)
                        continue
                    error_msg = f"API连接失败 (重试{max_retries}次后仍失败)"
                    break

                latency = (time.time() - start) * 1000
                if resp.status_code == 200:
                    message = self._first_choice(resp.json()).get("message") or {}
                    content = message.get("content") or ""
                    return ModelResult(
                        model_id=model.model_id,
                        response=content,
                        confidence=model.default_confidence,
                        latency_ms=latency,
                        success=bool(content),
                    )
                if resp.status_code in self._RETRYABLE_STATUS:
                    if retries_left:
                        self._backoff(model.model_id, f"返回 {resp.status_code}",
                                      attempt, max_retries)
                        continue
                    error_msg = f"API错误 {resp.status_code} (重试{max_retries}次后仍失败)"
                    break
                # Any other status (e.g. 4xx) is not retried.
                error_msg = f"API错误 {resp.status_code}: {resp.text[:200]}"
                break

            logger.error(f"{model.model_id}: {error_msg}")
            latency = (time.time() - start) * 1000

            # Fail-over to the health checker's fallback, at most once per
            # model in this chain.
            visited = set(_visited or ()) | {model.model_id}
            fallback_id = self._health.get_fallback(model.model_id)
            if fallback_id and fallback_id not in visited:
                fallback_model = self.registry.get(fallback_id)
                if fallback_model:
                    logger.info(f"故障转移到 {fallback_id}")
                    if fallback_model.is_local:
                        return self._exec_local(fallback_model, ctx)
                    return self._exec_api(fallback_model, ctx, _visited=visited)

            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False, error=error_msg,
                latency_ms=latency,
            )
        except Exception as e:
            return ModelResult(
                model_id=model.model_id, response="",
                confidence=0.0, success=False, error=str(e),
                latency_ms=(time.time() - start) * 1000,
            )

    @staticmethod
    def _backoff(model_id: str, reason: str, attempt: int, max_retries: int) -> None:
        """Log a retry notice, then sleep ``2 ** attempt`` seconds (1s, 2s, ...)."""
        wait = 2 ** attempt
        logger.warning(f"{model_id} {reason},{wait}秒后重试({attempt+1}/{max_retries})")
        time.sleep(wait)

    @staticmethod
    def _auth_headers(api_key: str) -> Dict[str, str]:
        """Headers for a bearer-token JSON request."""
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _first_choice(data) -> dict:
        """Return ``data["choices"][0]`` as a dict, or {} when it is absent.

        Responses or stream chunks may carry an empty ``choices`` list, for
        example usage-only chunks. Indexing that blindly raises IndexError.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if not choices:
            return {}
        first = choices[0]
        return first if isinstance(first, dict) else {}

    def _build_messages(self, ctx: "TaskContext") -> List[dict]:
        """Build the chat ``messages`` list for an API call.

        The list contains, in order:
        - a system message with up to 500 characters of related memory,
          when memory returns any;
        - the prior conversation turns from the conversation manager;
        - the current query as the final user message.
        """
        mem_ctx = self.memory.get_relevant_context(ctx.query, ctx.user_id, top_k=3)
        conv_messages = self._conv_mgr.get_context(ctx.conversation_id or "", max_tokens=2000)

        messages: List[dict] = []
        if mem_ctx:
            messages.append({
                "role": "system",
                "content": f"以下是相关历史上下文,供参考:\n{mem_ctx[:500]}"
            })
        if conv_messages:
            messages.extend(conv_messages)
        messages.append({"role": "user", "content": ctx.query})
        return messages

    # ============================================================
    # Streaming output
    # ============================================================
    def process_stream(self, query: str, user_id: str = "default",
                       conversation_id: str = ""):
        """Streaming variant of ``process``.

        Generator yielding ``(chunk_text, metadata)`` tuples, where metadata
        has ``"model"`` and ``"done"`` keys. Unlike ``process``, it does not
        consult the cache, and it streams from a single API model: the first
        non-local entry of the model chain. The exchange is written to memory
        and conversation history only when the stream finished without error.
        """
        skill_analysis = self._run_skill_pipeline(query, user_id, conversation_id)
        if skill_analysis and not skill_analysis.get("is_safe", True):
            yield ("抱歉,您的请求包含不安全内容,无法处理。",
                   {"model": "safety_filter", "done": True})
            return

        quick = self._quick_match(query)
        if quick:
            yield (quick, {"model": "local_memory", "done": True})
            return

        ctx = self.router.analyze(query, user_id, conversation_id)
        self._apply_skill_analysis(ctx, skill_analysis)

        candidates = (self.registry.get(mid) for mid in (ctx.model_chain or []))
        api_model = next((m for m in candidates if m and not m.is_local), None)
        if api_model is None:
            # No API model available: answer from local memory.
            mem_ctx = self.memory.get_relevant_context(query, user_id, top_k=3)
            yield (mem_ctx or "暂无可用模型处理您的请求。",
                   {"model": "memory_only", "done": True})
            return

        pieces: List[str] = []
        failed = False
        for chunk_text, done in self._stream_api(api_model, ctx):
            if done and chunk_text:
                # _stream_api reports errors as a non-empty final chunk.
                # Never persist those as an AI answer.
                failed = True
            elif chunk_text:
                pieces.append(chunk_text)
            yield (chunk_text, {"model": api_model.model_id, "done": done})

        full_response = "".join(pieces)
        if failed or not full_response:
            return
        self._remember_exchange(
            query, full_response, user_id, conversation_id,
            api_model.model_id, memory_error_prefix="流式记忆存储失败",
        )

    def _stream_api(self, model, ctx: "TaskContext"):
        """Stream a chat completion from ``model`` over Server-Sent Events.

        Yields ``(chunk_text, is_done)``:
        - content deltas arrive with ``is_done=False``;
        - a successful stream ends with ``("", True)``;
        - any failure ends the stream with a single ``(error_message, True)``.
        A final chunk with non-empty text therefore signals an error.
        """
        api_key = self._api_keys.get(model.model_id, "")
        if not api_key:
            yield ("API密钥未配置", True)
            return

        try:
            payload = {
                "model": self._get_api_model_name(model.model_id),
                "messages": self._build_messages(ctx),
                "max_tokens": model.max_tokens,
                "stream": True,
            }
            # Using the response as a context manager releases the connection
            # on every exit path: errors, [DONE], or the consumer abandoning
            # the generator.
            with requests.post(
                model.endpoint,
                headers=self._auth_headers(api_key),
                json=payload,
                timeout=self._STREAM_TIMEOUT_S,
                stream=True,
            ) as resp:
                if resp.status_code != 200:
                    yield (f"API错误 {resp.status_code}", True)
                    return
                # Decode each line as UTF-8 explicitly. requests would fall
                # back to ISO-8859-1 for a text/* content type without a
                # charset, which garbles non-ASCII text.
                for raw in resp.iter_lines():
                    if not raw:
                        continue
                    line = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else raw
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        yield ("", True)
                        return
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    delta = self._first_choice(data).get("delta") or {}
                    content = delta.get("content") if isinstance(delta, dict) else None
                    if content:
                        yield (content, False)
            yield ("", True)
        except requests.exceptions.Timeout:
            yield ("请求超时,请稍后重试", True)
        except Exception as e:
            yield (f"流式调用异常: {str(e)[:50]}", True)

    # ============================================================
    # Helpers
    # ============================================================
    def _load_api_keys(self) -> Dict[str, str]:
        """Load API keys from the config center, then from an api.env file.

        Config-center keys take precedence. The api.env file only fills in
        ``"glm_api"``, using the first non-empty value whose name contains
        both "GLM" and "KEY". Values may be wrapped in matching single or
        double quotes. The file path is read from the ``SWARM_API_ENV``
        environment variable, defaulting to ``_DEFAULT_API_ENV_PATH``.
        """
        keys: Dict[str, str] = {}

        try:
            config_keys = self._config.get("api_keys", {})
            if config_keys:
                keys.update(config_keys)
                logger.info(f"从配置中心加载{len(config_keys)}个API密钥")
        except Exception as e:
            logger.warning(f"配置中心API密钥读取失败: {e}")

        env_path = os.environ.get(self._API_ENV_PATH_VAR) or self._DEFAULT_API_ENV_PATH
        if os.path.exists(env_path):
            try:
                with open(env_path, "r", encoding="utf-8") as f:
                    for raw_line in f:
                        line = raw_line.strip()
                        if "=" not in line or line.startswith("#"):
                            continue
                        k, v = line.split("=", 1)
                        k, v = k.strip(), v.strip()
                        if len(v) >= 2 and v[0] == v[-1] and v[0] in ("'", '"'):
                            v = v[1:-1]
                        # Skip empty values, so a blank placeholder line does
                        # not shadow a real key further down.
                        if v and "GLM" in k and "KEY" in k:
                            keys.setdefault("glm_api", v)
            except Exception as e:
                logger.warning(f"加载API密钥失败: {e}")
        return keys

    def _get_api_model_name(self, model_id: str) -> str:
        """Map a registry model id to the model name sent to the API.

        Unknown ids are passed through unchanged.
        """
        return self._API_MODEL_NAMES.get(model_id, model_id)

    def _quick_match(self, query: str) -> str:
        """Answer trivial messages (greetings, identity questions, thanks) without a model.

        Matching ignores surrounding whitespace, trailing punctuation and ASCII
        case. Identity keywords only match short queries (at most
        ``_IDENTITY_MAX_EXTRA_CHARS`` characters beyond the keyword), so
        requests such as "帮我写一段自我介绍" still reach a real model.

        Returns:
            The canned reply, or "" when nothing matches.
        """
        text = query.strip().rstrip(self._TRAILING_PUNCTUATION)
        folded = text.lower()

        if folded in self._GREETINGS:
            return "你好!我是虫群智能体,有什么可以帮助你的吗?"

        if any(kw in text and len(text) <= len(kw) + self._IDENTITY_MAX_EXTRA_CHARS
               for kw in self._IDENTITY_KEYWORDS):
            return "我是虫群智能体,基于多模型聚合架构,可以为你提供智能问答服务。"

        if folded in self._ACKNOWLEDGEMENTS:
            return "不客气!如有其他问题随时问我。"

        return ""

    def _run_skill_pipeline(self, query: str, user_id: str = "default",
                            conversation_id: str = "") -> Optional[dict]:
        """Run the skill pipeline and condense its output into an analysis dict.

        Returns:
            None when the pipeline is unavailable or raises. Otherwise a dict
            that always contains ``is_safe``, ``task_category`` and
            ``text_type``. It also gets ``keywords`` (at most 5),
            ``sentiment`` and ``task_priority`` when the corresponding skills
            succeeded. Missing or failed skills leave the defaults in place,
            including ``is_safe=True``.
        """
        pipeline = self.pipeline
        if not pipeline:
            return None
        try:
            results = pipeline.execute(
                query, user_id=user_id, session_id=conversation_id
            )
            analysis = {"is_safe": True, "task_category": "", "text_type": ""}

            safety = self._skill_payload(results, "safety_filter")
            if safety:
                analysis["is_safe"] = safety.get("is_safe", True)

            parsed = self._skill_payload(results, "text_parser")
            if parsed:
                analysis["text_type"] = parsed.get("text_type", "")
                analysis["keywords"] = (parsed.get("keywords") or [])[:5]
                analysis["sentiment"] = parsed.get("sentiment", "")

            task = self._skill_payload(results, "task_classifier")
            if task:
                analysis["task_category"] = task.get("category", "")
                analysis["task_priority"] = task.get("priority", 1)

            return analysis
        except Exception as e:
            logger.warning(f"技能管线执行失败: {e}")
            return None

    @staticmethod
    def _skill_payload(results, name: str):
        """Return ``results[name].result`` if that skill succeeded with a truthy result, else None."""
        if name not in results:
            return None
        entry = results[name]
        if entry.success and entry.result:
            return entry.result
        return None