File size: 8,518 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
from __future__ import annotations
from typing import List, Dict, Any, Optional, Tuple
import math
from ...models.model_configs import LLMConfig
from ...agents.customize_agent import CustomizeAgent
def _tokenize(text: str) -> List[str]:
return [t for t in text.lower().split() if t.strip()]
def _tf_vector(tokens: List[str]) -> Dict[str, float]:
vec: Dict[str, float] = {}
for t in tokens:
vec[t] = vec.get(t, 0.0) + 1.0
# l2 normalize
norm = math.sqrt(sum(v * v for v in vec.values())) or 1.0
for k in list(vec.keys()):
vec[k] /= norm
return vec
def _cosine_sim(a: Dict[str, float], b: Dict[str, float]) -> float:
if len(a) < len(b):
a, b = b, a
return sum(v * b.get(k, 0.0) for k, v in a.items())
def _js_divergence(p: Dict[str, float], q: Dict[str, float]) -> float:
# simple smoothed unigram distributions over union vocab
vocab = set(p.keys()) | set(q.keys())
eps = 1e-9
def _norm(d: Dict[str, float]) -> Dict[str, float]:
s = sum(d.get(w, 0.0) for w in vocab) or 1.0
return {w: (d.get(w, 0.0) + eps) / (s + eps * len(vocab)) for w in vocab}
P = _norm(p)
Q = _norm(q)
M = {w: 0.5 * (P[w] + Q[w]) for w in vocab}
def _kl(X, Y):
return sum(X[w] * math.log((X[w] + eps) / (Y[w] + eps)) for w in vocab)
return 0.5 * _kl(P, M) + 0.5 * _kl(Q, M)
class PruningPipeline:
    """Pluggable pruning pipeline: quality pruning (QP) -> diversity pruning (DP)
    -> misunderstanding rebuttal (MR).

    Candidate input format: ``List[{"agent_id": int, "text": str}]``.
    The output keeps the same structure; entries may gain optional metrics
    such as ``qp_score`` (from QP) or ``mr_issues``/``mr_rebuttal`` (from MR).
    """

    def __init__(
        self,
        enable_qp: bool = True,
        enable_dp: bool = True,
        enable_mr: bool = False,
        qp_threshold: float = 0.15,
        qp_top_k: Optional[int] = None,
        dp_similarity_threshold: float = 0.92,
        dp_max_candidates: Optional[int] = None,
        mr_llm_config: Optional[LLMConfig] = None,
        min_keep_count: Optional[int] = None,
    ) -> None:
        """Configure the pipeline.

        Args:
            enable_qp: Run quality pruning (relevance filter).
            enable_dp: Run diversity pruning (near-duplicate removal).
            enable_mr: Run the LLM-based misunderstanding-rebuttal critic.
            qp_threshold: Minimum bag-of-words cosine score to survive QP.
            qp_top_k: If set (> 0), keep at most this many top-scored
                candidates in QP before thresholding.
            dp_similarity_threshold: Cosine similarity at or above which two
                candidates count as duplicates in DP.
            dp_max_candidates: If set, cap the number of candidates DP keeps.
            mr_llm_config: LLM config for the critic agent; MR is a no-op
                when this is None.
            min_keep_count: Minimum number of candidates to keep (computed by
                the caller, e.g. from the participant count). If None, no
                minimum is enforced.
        """
        self.enable_qp = enable_qp
        self.enable_dp = enable_dp
        self.enable_mr = enable_mr
        self.qp_threshold = qp_threshold
        self.qp_top_k = qp_top_k
        self.dp_similarity_threshold = dp_similarity_threshold
        self.dp_max_candidates = dp_max_candidates
        self.mr_llm_config = mr_llm_config
        # Minimum retained count supplied by the caller; None disables the floor.
        self.min_keep_count = min_keep_count

    # -------------------- QP: quality pruning --------------------
    def _qp_score(self, problem: str, text: str) -> float:
        """Cheap relevance score: bag-of-words cosine between problem and text."""
        qv = _tf_vector(_tokenize(problem))
        tv = _tf_vector(_tokenize(text))
        return _cosine_sim(qv, tv)

    def _quality_prune(self, problem: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Drop candidates whose relevance to *problem* is below the threshold.

        Returns shallow copies of the surviving entries with ``qp_score``
        filled in. At least one candidate is always kept; if ``min_keep_count``
        is set, the result is topped up with the next-best-scored candidates.
        """
        if not self.enable_qp or len(candidates) <= 1:
            return candidates
        scored: List[Tuple[float, Dict[str, Any]]] = []
        for c in candidates:
            s = self._qp_score(problem, c.get("text", ""))
            c = dict(c)  # copy so the caller's dicts are never mutated
            c["qp_score"] = s
            scored.append((s, c))
        scored.sort(key=lambda x: x[0], reverse=True)
        if self.qp_top_k is not None and self.qp_top_k > 0:
            scored = scored[: self.qp_top_k]
        kept = [c for s, c in scored if s >= self.qp_threshold]
        # Always keep at least one; if a minimum is configured, top up to it
        # in descending-score order.
        if not kept:
            kept = [scored[0][1]]
        if self.min_keep_count and len(kept) < self.min_keep_count:
            existing_ids = set(id(obj) for obj in kept)
            for _, c in scored:
                if id(c) not in existing_ids:
                    kept.append(c)
                    if len(kept) >= self.min_keep_count:
                        break
        return kept

    # -------------------- DP: diversity pruning --------------------
    def _diversity_prune(self, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Greedily drop near-duplicates (cosine >= ``dp_similarity_threshold``).

        Candidates are scanned in order; each one is kept only if it is
        sufficiently dissimilar from everything already kept. If
        ``min_keep_count`` is set, the result is topped up by ``qp_score``.
        """
        if not self.enable_dp or len(candidates) <= 1:
            return candidates
        vecs = [_tf_vector(_tokenize(c.get("text", ""))) for c in candidates]
        kept: List[int] = []
        for i, v in enumerate(vecs):
            diverse = True
            for j in kept:
                sim = _cosine_sim(v, vecs[j])
                if sim >= self.dp_similarity_threshold:
                    diverse = False
                    break
            if diverse:
                kept.append(i)
                if self.dp_max_candidates and len(kept) >= self.dp_max_candidates:
                    break
        pruned = [candidates[i] for i in kept]
        # If a minimum is configured, top up using the global qp_score ranking
        # (highest first; missing/zero scores rank last).
        if self.min_keep_count and len(pruned) < self.min_keep_count:
            ranked = sorted(
                range(len(candidates)),
                key=lambda idx: float(candidates[idx].get("qp_score") or 0.0),
                reverse=True,
            )
            chosen = set(kept)
            for idx in ranked:
                if idx in chosen:
                    continue
                pruned.append(candidates[idx])
                chosen.add(idx)
                if len(pruned) >= self.min_keep_count:
                    break
        return pruned

    # -------------------- MR: misunderstanding rebuttal --------------------
    def _build_critic(self) -> Optional[CustomizeAgent]:
        """Construct the critic agent, or return None when MR has no LLM config."""
        if not self.mr_llm_config:
            return None
        prompt = (
            """
You are a critical reviewer. Given a problem and a set of condensed candidate answers, identify common misunderstandings or mistakes, and propose a corrected consolidated answer.
Problem:
{problem}
Candidates:
{candidates_text}
Return XML:
<response>
<issues>Common mistakes found</issues>
<rebuttal>How to fix them</rebuttal>
<corrected>Single corrected final answer</corrected>
</response>
"""
        ).strip()
        inputs = [
            {"name": "problem", "type": "str", "description": "Problem statement"},
            {"name": "candidates_text", "type": "str", "description": "Concatenated candidates"},
        ]
        outputs = [
            {"name": "issues", "type": "str", "description": "Common mistakes", "required": True},
            {"name": "rebuttal", "type": "str", "description": "Corrections", "required": True},
            {"name": "corrected", "type": "str", "description": "Corrected final answer", "required": True},
        ]
        return CustomizeAgent(
            name="CriticAgent",
            description="Detects misunderstandings and proposes corrected answer",
            prompt=prompt,
            llm_config=self.mr_llm_config,
            inputs=inputs,
            outputs=outputs,
            parse_mode="xml",
        )

    def _misunderstanding_rebuttal(self, problem: str, candidates: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, str]]]:
        """Run the critic over *candidates* and attach its findings.

        Returns the (possibly annotated) candidates plus an optional suggested
        final answer dict with keys ``issues``/``rebuttal``/``corrected``.
        Candidate texts are never altered; annotations go on copies so the
        caller's dicts stay untouched (consistent with ``_quality_prune``).
        """
        if not self.enable_mr:
            return candidates, None
        critic = self._build_critic()
        if critic is None:
            return candidates, None
        concat = "\n\n".join(f"#{c.get('agent_id')}: {c.get('text','').strip()}" for c in candidates)
        msg = critic(inputs={"problem": problem, "candidates_text": concat})
        st = msg.content.get_structured_data()
        annotated: List[Dict[str, Any]] = []
        for c in candidates:
            c = dict(c)  # fix: previously mutated the input dicts in place
            c["mr_issues"] = st.get("issues", "")
            c["mr_rebuttal"] = st.get("rebuttal", "")
            annotated.append(c)
        suggested = {
            "issues": st.get("issues", ""),
            "rebuttal": st.get("rebuttal", ""),
            "corrected": st.get("corrected", ""),
        }
        return annotated, suggested

    # -------------------- Pipeline entry --------------------
    def apply(self, problem: str, candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Run QP -> DP -> MR and return {"candidates": pruned, "mr_suggested": optional}."""
        step1 = self._quality_prune(problem, candidates)
        step2 = self._diversity_prune(step1)
        step3, suggested = self._misunderstanding_rebuttal(problem, step2)
        return {"candidates": step3, "mr_suggested": suggested}
|