File size: 11,309 Bytes
bb6d5ae a1d2691 99f7550 7bff042 99f7550 7bff042 99f7550 7bff042 99f7550 7bff042 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 99f7550 71e6f1d 99f7550 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 7bff042 a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 7bff042 a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae a1d2691 bb6d5ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 | from typing import List, Dict, Any
import re
try:
from math_verify import parse, verify
MATH_VERIFY_AVAILABLE = True
except ImportError:
MATH_VERIFY_AVAILABLE = False
def _fix_sqrt(string: str) -> str:
if "\\sqrt" not in string: return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if len(split) > 0 and split[0] != "{":
new_string += "\\sqrt{" + split[0] + "}" + split[1:]
else:
new_string += "\\sqrt" + split
return new_string
def _fix_fracs(string: str) -> str:
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
for substr in substrs[1:]:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
elif len(substr) >= 2:
new_str += "{" + substr[0] + "}{" + substr[1] + "}" + substr[2:]
else:
new_str += substr
return new_str
def _fix_a_slash_b(string: str) -> str:
if "/" not in string or len(string.split("/")) != 2: return string
a, b = string.split("/")
try:
a_int = int(re.sub(r'[^0-9-]', '', a))
b_int = int(re.sub(r'[^0-9-]', '', b))
return f"\\frac{{{a_int}}}{{{b_int}}}"
except: return string
def _strip_string(string: str) -> str:
    """Canonicalize a LaTeX answer string for comparison (MATH-dataset style)."""
    # Drop linebreaks and negative-space commands; collapse double backslashes.
    string = string.replace("\n", "").replace("\\!", "").replace("\\\\", "\\")
    # Normalize fraction variants to plain \frac.
    string = string.replace("tfrac", "frac").replace("dfrac", "frac")
    # Remove sizing commands and degree/currency decorations.
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("^{\\circ}", "").replace("^\\circ", "")
    string = string.replace("\\$", "").replace("$", "")
    # The original chained .replace("\\%", "").replace("\%", ""): "\%" is an
    # invalid escape (SyntaxWarning on Python 3.12+) that evaluates to the
    # same two characters as "\\%", so the second call was a duplicate no-op.
    string = string.replace("\\%", "")
    if "sqrt" in string:
        string = _fix_sqrt(string)
    string = string.replace(" ", "")
    if "frac" in string:
        string = _fix_fracs(string)
    # Canonical form used by the MATH grader for one common decimal.
    if string == "0.5":
        string = "\\frac{1}{2}"
    string = _fix_a_slash_b(string)
    return string
def find_math_answer(s: str) -> str:
    """Extract the final answer from a model response, preferring \\boxed{...} content.

    Lowercases the input, pulls the last boxed expression when present,
    keeps only the right-hand side of the final '='/'\\approx', then
    canonicalizes via _strip_string.
    """
    s = s.lower()
    if 'oxed{' in s:
        try:
            # Take the contents of the last boxed expression in the text.
            ans = re.findall(r'oxed{(.*)}', s, flags=re.S)[-1]
            # A stray closing brace before any opening brace means the match
            # overran; cut at the first '}'.
            if '}' in ans and ('{' not in ans or ans.find('}') < ans.find('{')):
                ans = ans.split('}')[0]
            s = ans
        except IndexError:
            # Was a bare `except:`; only the [-1] on an empty findall result
            # can raise here. Fall through to the raw string.
            pass
    # Keep only the final right-hand side of the expression.
    s = s.split('=')[-1].split('\\approx')[-1]
    return _strip_string(s)
def extract_choice(text: str) -> str:
    """Extracts alphabet choice (A, B, C, D) from model response."""
    lowered = text.lower()
    # Ordered from most explicit phrasing to bare leading letter.
    patterns = (
        r'the answer is \(([a-e])\)',
        r'the answer is ([a-e])\.',
        r'final answer: ([a-e])',
        r'^\(([a-e])\)',
        r'^([a-e])\n',
    )
    for pattern in patterns:
        hit = re.search(pattern, lowered)
        if hit:
            return hit.group(1).upper()
    return ""
def _normalize_answer(ans: str) -> Any:
    """Uses advanced heuristics + math_verify to normalize answer.

    Returns the math_verify parse result when the library is available and
    the cleaned string parses; otherwise returns the heuristic cleanup.
    """
    cleaned = find_math_answer(str(ans))
    if MATH_VERIFY_AVAILABLE:
        try:
            return parse(cleaned)
        except Exception:
            # Was a bare `except:` (which also swallowed KeyboardInterrupt /
            # SystemExit). parse() can fail on free-form text; fall back to
            # the cleaned string.
            return cleaned
    return cleaned
def normalize_answers(answers: List[str]) -> Dict[str, List[int]]:
    """Group answers that are numerically/symbolically equivalent.

    Args:
        answers: raw answer strings, one per agent.

    Returns:
        Mapping from a representative raw answer (the first seen in its
        equivalence class) to the list of indices in ``answers`` belonging
        to that class.
    """
    normalized_groups: Dict[str, List[int]] = {}
    # Cache normalized forms of the group keys: the original re-ran
    # _normalize_answer(key) on every inner iteration, re-parsing each
    # representative answer O(n^2) times.
    key_cache: Dict[str, Any] = {}
    for idx, ans in enumerate(answers):
        clean = _normalize_answer(ans)
        matched = False
        for key, indices in normalized_groups.items():
            if key not in key_cache:
                key_cache[key] = _normalize_answer(key)
            key_clean = key_cache[key]
            if MATH_VERIFY_AVAILABLE:
                try:
                    match = verify(clean, key_clean)
                except Exception:
                    # Was `except (ValueError, Exception)` — ValueError is a
                    # subclass of Exception, so the tuple was redundant.
                    # Fallback when math_verify fails in threaded env
                    # (signal.alarm restriction): compare string forms.
                    match = (str(clean) == str(key_clean))
            else:
                match = (key_clean == clean)
            if match:
                indices.append(idx)
                matched = True
                break
        if not matched:
            # New equivalence class keyed by the raw answer text.
            normalized_groups[ans] = [idx]
    return normalized_groups
def _calculate_logical_score(trace: List[str]) -> float:
"""
L_logic: measures intra-agent logical flow.
Checks for contradiction signals, empty steps, and step count.
"""
if not trace:
return 0.0
contradiction_terms = ["incorrect", "divergence", "wrong", "error", "divergent", "hallucin"]
score = 1.0
for step in trace:
if any(t in step.lower() for t in contradiction_terms):
score -= 0.3
# Longer traces with more reasoning steps are rewarded slightly
score += min(0.1 * (len(trace) - 1), 0.3)
return max(0.0, min(1.0, score))
def _calculate_classifier_score(conf_exp: str, is_divergent: bool) -> float:
"""
C_clf: maps confidence explanation to numerical probability.
"""
if is_divergent:
return 0.1
text = conf_exp.lower()
if any(w in text for w in ["high confidence", "certain", "guaranteed", "verified", "proof"]):
return 0.95
elif any(w in text for w in ["divergent", "divergence", "wrong", "hallucin", "low confidence"]):
return 0.1
elif any(w in text for w in ["likely", "confident", "probably"]):
return 0.75
elif any(w in text for w in ["unsure", "guess", "uncertain"]):
return 0.3
return 0.55 # Neutral default
def evaluate_consensus(
    agent_responses: List[Dict[str, Any]],
    ocr_confidence: float = 1.0
) -> Dict[str, Any]:
    """
    Adaptive Multi-Signal Consensus:
    Score_j = 0.40 * V_sym + 0.35 * L_logic + 0.25 * C_clf
    FinalConf = Score_j * (0.9 + 0.1 * OCR_conf)
    Also detects:
    - Answer divergence (agents disagree → flag as uncertain)
    - Individual hallucination (score < 0.65 OR marked as divergent by agent)
    - High-confidence wrong answers

    Args:
        agent_responses: one dict per agent with keys "agent" (name) and
            "response" (a dict read for "Answer", "Reasoning Trace", and
            "Confidence Explanation").
        ocr_confidence: scales FinalConf via (0.9 + 0.1 * ocr_confidence),
            i.e. at most a 10% downward calibration.

    Returns:
        Dict with the winning answer, winning score, per-agent detail
        scores, answer groupings, hallucination alerts, divergence flag,
        unique answers, and a human-readable verdict string.
    """
    if not agent_responses:
        # Degenerate case: nothing to score.
        return {
            "final_verified_answer": "No agents responded",
            "winning_score": 0.0,
            "detail_scores": [],
            "divergence_groups": {},
            "hallucination_alerts": [],
            "verdict": "ERROR"
        }
    # Import compute symbolic score; fall back to a trivial scorer when the
    # project-local verification_service module is unavailable.
    try:
        from verification_service import calculate_symbolic_score
    except ImportError:
        def calculate_symbolic_score(trace): return 1.0 if trace else 0.0
    scores = []
    hallucination_alerts = []
    # Group equivalent answers across agents before per-agent scoring.
    answers = [res["response"].get("Answer", "N/A") for res in agent_responses]
    answer_groups = normalize_answers(answers)
    # Determine if there is significant divergence between agents
    num_unique_answers = len(answer_groups)
    has_divergence = num_unique_answers > 1
    for idx, agent_data in enumerate(agent_responses):
        res = agent_data["response"]
        trace = res.get("Reasoning Trace", [])
        conf_exp = res.get("Confidence Explanation", "")
        raw_ans = res.get("Answer", "N/A")
        # Heuristic Bonus: Capture choices (A/B/C/D)
        choice = extract_choice(str(raw_ans))
        normalized_ans = choice if choice else _normalize_answer(raw_ans)
        # Check if the agent itself marked this as divergent/hallucinating
        is_self_flagged = any(t in conf_exp.lower() for t in ["divergent", "wrong", "hallucin", "low confidence", "divergence"])
        # V_sym: SymPy symbolic reasoning verification (weight 0.40)
        v_sym = calculate_symbolic_score(trace)
        # L_logic: logical consistency & step quality (weight 0.35)
        l_logic = _calculate_logical_score(trace)
        # C_clf: confidence classifier (weight 0.25)
        c_clf = _calculate_classifier_score(conf_exp, is_self_flagged)
        # Core scoring formula
        score_j = (0.40 * v_sym) + (0.35 * l_logic) + (0.25 * c_clf)
        # OCR calibration
        final_conf = score_j * (0.9 + 0.1 * ocr_confidence)
        # Hallucination detection — flag if:
        # 1. Score is below threshold (lowered from 0.7 to 0.65 for better sensitivity)
        # 2. Agent self-flagged as divergent
        # 3. High-confidence answer but symbolic score is 0 (contradiction)
        is_hallucinating = False
        alert_reason = None
        if score_j < 0.65:
            alert_reason = f"Low consensus score ({score_j:.3f} < 0.65)"
        elif is_self_flagged:
            alert_reason = "Agent self-reported divergent reasoning path"
        elif v_sym == 0.0 and c_clf > 0.7:
            alert_reason = "High-confidence answer with zero symbolic validity"
        if alert_reason:
            is_hallucinating = True
            hallucination_alerts.append({
                "agent": agent_data["agent"],
                "answer": raw_ans,
                "reason": alert_reason,
                "score": round(score_j, 3)
            })
        # Per-agent detail record; index in this list lines up with the
        # indices stored in answer_groups.
        scores.append({
            "agent": agent_data["agent"],
            "raw_answer": raw_ans,
            "normalized_answer": str(normalized_ans),
            "V_sym": round(v_sym, 3),
            "L_logic": round(l_logic, 3),
            "C_clf": round(c_clf, 3),
            "Score_j": round(score_j, 3),
            "FinalConf": round(final_conf, 3),
            "is_hallucinating": is_hallucinating
        })
    # Aggregate: find the most supported, highest-scoring answer group
    final_consensus = {}
    top_score = -1.0
    best_answer = "Unresolvable Divergence"
    for rep_ans, indices in answer_groups.items():
        # Prefer non-hallucinating agents when aggregating
        valid_idx = [i for i in indices if not scores[i]["is_hallucinating"]]
        base_idx = valid_idx if valid_idx else indices
        group_score = sum(scores[i]["FinalConf"] for i in base_idx)
        # Consistency bonus: more agents agreeing on same answer → stronger signal
        consistency_multiplier = 1.0 + (0.15 * (len(base_idx) - 1))
        weighted = group_score * consistency_multiplier
        final_consensus[rep_ans] = {
            "agents_supporting": [scores[i]["agent"] for i in indices],
            "valid_agent_count": len(valid_idx),
            "aggregate_score": round(weighted, 3)
        }
        if weighted > top_score:
            top_score = weighted
            best_answer = rep_ans
    # Determine overall verdict with clearer thresholds
    if top_score >= 1.5 and not has_divergence and not hallucination_alerts:
        verdict = "✅ STRONGLY VERIFIED"
    elif top_score >= 1.0 and len(hallucination_alerts) == 0:
        verdict = "✅ VERIFIED"
    elif has_divergence and len(hallucination_alerts) > 0:
        verdict = "❌ DIVERGENCE DETECTED — LIKELY WRONG"
    elif has_divergence:
        verdict = "⚠️ UNCERTAIN — AGENTS DISAGREE"
    elif hallucination_alerts:
        verdict = "⚠️ UNCERTAIN — HALLUCINATION RISK"
    else:
        verdict = "⚠️ LOW CONFIDENCE"
    return {
        "final_verified_answer": best_answer,
        "winning_score": round(top_score, 3),
        "detail_scores": scores,
        "divergence_groups": final_consensus,
        "hallucination_alerts": hallucination_alerts,
        "has_divergence": has_divergence,
        "unique_answers": list(answer_groups.keys()),
        "verdict": verdict
    }
|