Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 虫群智能体系统 — 聚合引擎 | |
| 将多个模型的结果智能合并为最终响应 | |
| """ | |
| import logging | |
| from typing import Dict, List | |
| from datetime import datetime | |
| from core.types import ( | |
| AggregationMethod, AggregationResult, ModelResult, PerformanceStats | |
| ) | |
| logger = logging.getLogger(__name__) | |
class AggregationEngine:
    """Aggregation engine that merges several model results into one response.

    Each strategy method takes the non-empty list of successful results and
    returns a ``(response_text, primary_model_id, primary_confidence)`` tuple.
    """

    # Runner-up is appended as a supplement only when its confidence trails
    # the winner by less than this gap.
    _SUPPLEMENT_MAX_GAP = 0.3
    # Maximum number of characters of the runner-up response that are quoted.
    _SUPPLEMENT_MAX_CHARS = 200
    # Length of the supplement prefix used to detect a duplicate answer.
    _SUPPLEMENT_DEDUP_CHARS = 50
    # Weighted strategy only shows results whose share exceeds this fraction.
    _MIN_DISPLAY_WEIGHT = 0.1
    # Maximum number of characters shown per result in the weighted strategy.
    _WEIGHTED_MAX_CHARS = 300

    def __init__(self):
        # Strategy used when ``aggregate`` is called without ``method``.
        self.method = AggregationMethod.CONFIDENCE

    def aggregate(self, query: str, results: List["ModelResult"],
                  method: "AggregationMethod" = None) -> "AggregationResult":
        """Aggregate the results of several models into a single response.

        Args:
            query: The original query.
            results: Per-model results; failed or empty ones are ignored when
                picking the answer but are still attached to the output.
            method: Aggregation strategy; falls back to ``self.method`` when
                omitted.

        Returns:
            An ``AggregationResult``. When no result is usable, its
            ``final_response`` is an apology message quoting up to two of the
            reported errors.
        """
        method = method or self.method

        # Only results that succeeded AND carry a non-empty response count.
        successful = [r for r in results if r.success and r.response]

        if not successful:
            # Everything failed: collect the errors for a friendly message.
            # str() keeps the join safe if an error is an exception object.
            errors = [str(r.error) for r in results if r.error]
            err_msg = "; ".join(errors[:2]) if errors else "所有模型均未返回有效结果"
            return AggregationResult(
                query=query,
                final_response=f"抱歉,暂时无法处理您的请求。原因: {err_msg}",
                method=method,
                model_results=results,
            )

        # Dispatch on the strategy; anything unrecognised is handled as
        # SEQUENTIAL.
        if method == AggregationMethod.CONFIDENCE:
            resp, primary, primary_conf = self._by_confidence(successful)
        elif method == AggregationMethod.WEIGHTED:
            resp, primary, primary_conf = self._by_weight(successful)
        elif method == AggregationMethod.VOTE:
            resp, primary, primary_conf = self._by_vote(successful)
        else:
            resp, primary, primary_conf = self._by_sequential(successful)

        # Latency is the sum over the successful results.
        total_ms = sum(r.latency_ms for r in successful)

        return AggregationResult(
            query=query,
            final_response=resp,
            method=method,
            primary_model=primary,
            primary_confidence=primary_conf,
            model_results=results,
            total_latency_ms=total_ms,
        )

    # ============================================================
    # Aggregation strategies
    # ============================================================

    def _by_confidence(self, results: List["ModelResult"]) -> tuple:
        """Pick the most confident result, optionally quoting the runner-up.

        The runner-up is appended as a supplement when its confidence is
        within ``_SUPPLEMENT_MAX_GAP`` of the winner and its leading text is
        not already contained in the winning response.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` of the winning result.
        """
        # sorted() is stable, so ties keep their input order.
        ranked = sorted(results, key=lambda r: r.confidence, reverse=True)
        best = ranked[0]

        if len(ranked) > 1:
            runner_up = ranked[1]
            gap = best.confidence - runner_up.confidence
            if gap < self._SUPPLEMENT_MAX_GAP:
                supp = runner_up.response[:self._SUPPLEMENT_MAX_CHARS]
                # Skip the supplement when it merely repeats the best answer.
                is_new = supp[:self._SUPPLEMENT_DEDUP_CHARS] not in best.response
                if supp and is_new:
                    return (
                        f"{best.response}\n\n📌 补充参考: {supp}",
                        best.model_id,
                        best.confidence,
                    )

        return best.response, best.model_id, best.confidence

    def _by_weight(self, results: List["ModelResult"]) -> tuple:
        """Concatenate responses, each labelled with its confidence share.

        Only results whose share of the total confidence exceeds
        ``_MIN_DISPLAY_WEIGHT`` are shown. If none qualifies (for example ten
        equally confident models, or all confidences are zero), the most
        confident response is returned on its own instead of an empty string.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` where the id and confidence
            belong to the most confident result.
        """
        # max() returns the first maximal element, matching a strict '>' scan.
        best = max(results, key=lambda r: r.confidence)
        total_conf = sum(r.confidence for r in results)

        parts = []
        if total_conf > 0:
            for r in results:
                weight = r.confidence / total_conf
                if weight > self._MIN_DISPLAY_WEIGHT:
                    snippet = r.response[:self._WEIGHTED_MAX_CHARS]
                    parts.append(f"[{weight*100:.0f}%] {snippet}")

        if not parts:
            # Nothing passed the display threshold; never return an empty
            # answer when usable responses exist.
            return best.response, best.model_id, best.confidence

        return "\n---\n".join(parts), best.model_id, best.confidence

    def _by_vote(self, results: List["ModelResult"]) -> tuple:
        """Pick the longest response as a simple stand-in for voting.

        No real voting is performed: the longest response is treated as the
        most informative one. Ties go to the earliest result.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` of the chosen result.
        """
        best = max(results, key=lambda r: len(r.response))
        return best.response, best.model_id, best.confidence

    def _by_sequential(self, results: List["ModelResult"]) -> tuple:
        """Pick the result with the lowest latency.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` of the fastest result.
        """
        fastest = min(results, key=lambda r: r.latency_ms)
        return fastest.response, fastest.model_id, fastest.confidence