File size: 4,717 Bytes
17fba62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
虫群智能体系统 — 聚合引擎
将多个模型的结果智能合并为最终响应
"""

import logging
from typing import Dict, List
from datetime import datetime

from core.types import (
    AggregationMethod, AggregationResult, ModelResult, PerformanceStats
)

logger = logging.getLogger(__name__)


class AggregationEngine:
    """Aggregation engine that merges the results of several models.

    A strategy is chosen per call through ``method``; when the caller does not
    supply one, the engine-wide default stored in ``self.method`` is used.
    Every strategy returns a ``(response, primary_model_id, primary_confidence)``
    tuple that ``aggregate`` wraps into an ``AggregationResult``.
    """

    # Runner-up is appended as a supplement only if its confidence trails the
    # best result by less than this gap.
    _SUPPLEMENT_MAX_GAP = 0.3
    # Maximum number of characters of the runner-up response that are appended.
    _SUPPLEMENT_MAX_CHARS = 200
    # Length of the supplement prefix used to detect that the best response
    # already contains the same text.
    _DUPLICATE_PROBE_CHARS = 50
    # Results whose normalized weight is not above this value are hidden in
    # the weighted strategy.
    _MIN_DISPLAY_WEIGHT = 0.1
    # Maximum number of characters shown per result in the weighted strategy.
    _WEIGHTED_MAX_CHARS = 300

    def __init__(self):
        self.method = AggregationMethod.CONFIDENCE

    def aggregate(self, query: str, results: List[ModelResult],
                  method: AggregationMethod = None) -> AggregationResult:
        """Aggregate the results of several models into one response.

        Args:
            query: The original query.
            results: Per-model results; failed or empty ones are ignored when
                building the response but are still attached to the returned
                object.
            method: Aggregation strategy. ``None`` selects ``self.method``.
                Any value other than CONFIDENCE, WEIGHTED or VOTE is handled
                by the sequential strategy.

        Returns:
            An ``AggregationResult``. When no result succeeded, its final
            response is an apology message quoting up to two model errors.
        """
        # Explicit None check so that a falsy enum member is never replaced
        # by the default.
        if method is None:
            method = self.method

        # Keep only results that succeeded and actually carry a response.
        successful = [r for r in results if r.success and r.response]
        if not successful:
            # Everything failed: collect the errors for a friendly message.
            # str() keeps the join safe if a model stored an exception object.
            errors = [str(r.error) for r in results if r.error]
            err_msg = "; ".join(errors[:2]) if errors else "所有模型均未返回有效结果"
            return AggregationResult(
                query=query,
                final_response=f"抱歉,暂时无法处理您的请求。原因: {err_msg}",
                method=method,
                model_results=results,
            )

        # Dispatch to the selected strategy.
        if method == AggregationMethod.CONFIDENCE:
            resp, primary, primary_conf = self._by_confidence(successful)
        elif method == AggregationMethod.WEIGHTED:
            resp, primary, primary_conf = self._by_weight(successful)
        elif method == AggregationMethod.VOTE:
            resp, primary, primary_conf = self._by_vote(successful)
        else:  # SEQUENTIAL
            resp, primary, primary_conf = self._by_sequential(successful)

        # Reported latency is the sum over the successful results.
        total_ms = sum(r.latency_ms for r in successful)

        return AggregationResult(
            query=query,
            final_response=resp,
            method=method,
            primary_model=primary,
            primary_confidence=primary_conf,
            model_results=results,
            total_latency_ms=total_ms,
        )

    # ============================================================
    # Aggregation strategies
    # ============================================================

    def _by_confidence(self, results: List[ModelResult]) -> tuple:
        """Pick the most confident result, optionally supplemented by the runner-up.

        The runner-up is appended (truncated) when its confidence is close to
        the best one and its opening text is not already part of the best
        response.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` of the most confident result;
            the response may carry the appended supplement.
        """
        # sorted() is stable, so ties keep their input order.
        ranked = sorted(results, key=lambda r: r.confidence, reverse=True)
        best = ranked[0]

        if len(ranked) > 1:
            second = ranked[1]
            if best.confidence - second.confidence < self._SUPPLEMENT_MAX_GAP:
                supp = second.response[:self._SUPPLEMENT_MAX_CHARS]
                # Skip the supplement when the best answer already contains it.
                if supp and supp[:self._DUPLICATE_PROBE_CHARS] not in best.response:
                    return (
                        f"{best.response}\n\n📌 补充参考: {supp}",
                        best.model_id, best.confidence
                    )
        return best.response, best.model_id, best.confidence

    def _by_weight(self, results: List[ModelResult]) -> tuple:
        """Concatenate responses, each labelled with its confidence share.

        Each result's weight is its confidence divided by the total
        confidence; only results above the display threshold are listed. If
        no result qualifies (total confidence is not positive, or the weights
        are spread too thin), the most confident response is returned on its
        own so that successful results never yield an empty answer.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` where the model id and
            confidence belong to the most confident result (first one on ties).
        """
        best = max(results, key=lambda r: r.confidence)
        total_conf = sum(r.confidence for r in results)

        parts = []
        if total_conf > 0:
            for r in results:
                weight = r.confidence / total_conf
                if weight > self._MIN_DISPLAY_WEIGHT:
                    snippet = r.response[:self._WEIGHTED_MAX_CHARS]
                    parts.append(f"[{weight*100:.0f}%] {snippet}")

        if not parts:
            # Nothing cleared the threshold: fall back to the single best result.
            return best.response, best.model_id, best.confidence
        return "\n---\n".join(parts), best.model_id, best.confidence

    def _by_vote(self, results: List[ModelResult]) -> tuple:
        """Pick the longest response, using length as a proxy for information content.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` of the longest response
            (first one on ties).
        """
        best = max(results, key=lambda r: len(r.response))
        return best.response, best.model_id, best.confidence

    def _by_sequential(self, results: List[ModelResult]) -> tuple:
        """Pick the result with the lowest latency.

        Args:
            results: Non-empty list of successful results.

        Returns:
            ``(response, model_id, confidence)`` of the fastest result
            (first one on ties).
        """
        fastest = min(results, key=lambda r: r.latency_ms)
        return fastest.response, fastest.model_id, fastest.confidence