| from typing import List, Dict, Any |
| import re |
|
|
# Optional dependency: `math_verify` supplies symbolic parse/verify for
# answer equivalence. When it is missing, downstream code (see
# normalize_answers) falls back to plain string comparison.
try:
    from math_verify import parse, verify
    MATH_VERIFY_AVAILABLE = True
except ImportError:
    MATH_VERIFY_AVAILABLE = False
|
|
| def _fix_sqrt(string: str) -> str: |
| if "\\sqrt" not in string: return string |
| splits = string.split("\\sqrt") |
| new_string = splits[0] |
| for split in splits[1:]: |
| if len(split) > 0 and split[0] != "{": |
| new_string += "\\sqrt{" + split[0] + "}" + split[1:] |
| else: |
| new_string += "\\sqrt" + split |
| return new_string |
|
|
| def _fix_fracs(string: str) -> str: |
| substrs = string.split("\\frac") |
| new_str = substrs[0] |
| if len(substrs) > 1: |
| for substr in substrs[1:]: |
| new_str += "\\frac" |
| if len(substr) > 0 and substr[0] == "{": |
| new_str += substr |
| elif len(substr) >= 2: |
| new_str += "{" + substr[0] + "}{" + substr[1] + "}" + substr[2:] |
| else: |
| new_str += substr |
| return new_str |
|
|
| def _fix_a_slash_b(string: str) -> str: |
| if "/" not in string or len(string.split("/")) != 2: return string |
| a, b = string.split("/") |
| try: |
| a_int = int(re.sub(r'[^0-9-]', '', a)) |
| b_int = int(re.sub(r'[^0-9-]', '', b)) |
| return f"\\frac{{{a_int}}}{{{b_int}}}" |
| except: return string |
|
|
def _strip_string(string: str) -> str:
    """Canonicalize a LaTeX answer string for comparison.

    Removes layout commands, degree/dollar/percent markers and whitespace,
    normalizes tfrac/dfrac to frac, fixes un-braced sqrt/frac forms, maps
    0.5 to \\frac{1}{2}, and rewrites plain a/b fractions.
    """
    string = string.replace("\n", "").replace("\\!", "").replace("\\\\", "\\")
    string = string.replace("tfrac", "frac").replace("dfrac", "frac")
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("^{\\circ}", "").replace("^\\circ", "")
    string = string.replace("\\$", "").replace("$", "")
    # Fix: the original chained .replace("\%", "") here as well. "\%" is an
    # invalid escape sequence (SyntaxWarning on Python 3.12+) that evaluates
    # to the same value as "\\%", so it was a redundant duplicate — removed.
    string = string.replace("\\%", "")
    if "sqrt" in string:
        string = _fix_sqrt(string)
    string = string.replace(" ", "")
    if "frac" in string:
        string = _fix_fracs(string)
    if string == "0.5":
        string = "\\frac{1}{2}"
    string = _fix_a_slash_b(string)
    return string
|
|
def find_math_answer(s: str) -> str:
    """Extract the final answer from a model response and canonicalize it.

    Prefers the content of the last ``\\boxed{...}`` (matched via 'oxed{' so
    both \\boxed and \\fbox variants hit); otherwise takes the text after
    the last '=' / '\\approx'. The result is passed through _strip_string.
    """
    s = s.lower()
    if 'oxed{' in s:
        try:
            # Greedy match can overshoot on nested braces; the trim below
            # cuts back to the first closing brace when it precedes any '{'.
            ans = re.findall(r'oxed{(.*)}', s, flags=re.S)[-1]
            if '}' in ans and ('{' not in ans or ans.find('}') < ans.find('{')):
                ans = ans.split('}')[0]
            s = ans
        except IndexError:
            # Was a bare `except:` — only the [-1] on an empty findall
            # (i.e. 'oxed{' present but never closed) can fail here.
            pass
    s = s.split('=')[-1].split('\\approx')[-1]
    return _strip_string(s)
|
|
def extract_choice(text: str) -> str:
    """Extracts alphabet choice (A, B, C, D) from model response.

    Tries a fixed list of answer-phrase patterns in priority order against
    the lowercased text; returns the uppercase letter or "" on no match.
    """
    lowered = text.lower()  # hoisted out of the loop: invariant across patterns
    patterns = [
        r'the answer is \(([a-e])\)',
        r'the answer is ([a-e])\.',
        r'final answer: ([a-e])',
        r'^\(([a-e])\)',
        r'^([a-e])\n'
    ]
    for p in patterns:
        match = re.search(p, lowered)
        if match:
            return match.group(1).upper()
    return ""
|
|
def _normalize_answer(ans: str) -> Any:
    """Uses advanced heuristics + math_verify to normalize answer.

    Returns a math_verify parse result when the library is available and
    parsing succeeds, otherwise the heuristically cleaned string.
    """
    cleaned = find_math_answer(str(ans))
    if MATH_VERIFY_AVAILABLE:
        try:
            return parse(cleaned)
        except Exception:
            # Was a bare `except:` (would swallow KeyboardInterrupt too).
            # parse() may fail on non-mathematical text; fall back to the
            # cleaned string representation.
            return cleaned
    return cleaned
|
|
def normalize_answers(answers: List[str]) -> Dict[str, List[int]]:
    """Group answers that are numerically/symbolically equivalent.

    Returns a mapping from a representative raw answer string to the list of
    indices in `answers` whose normalized form matches it.
    """
    normalized_groups: Dict[str, List[int]] = {}
    # Cache: representative raw answer -> its normalized form. The original
    # re-normalized (re-parsed) every group key for every incoming answer,
    # an O(n^2) number of parses; each key is deterministic, so once is enough.
    key_cache: Dict[str, Any] = {}
    for idx, ans in enumerate(answers):
        clean = _normalize_answer(ans)
        matched = False
        for key in normalized_groups:
            if key not in key_cache:
                key_cache[key] = _normalize_answer(key)
            key_clean = key_cache[key]
            if MATH_VERIFY_AVAILABLE:
                try:
                    match = verify(clean, key_clean)
                except Exception:
                    # Fix: was `except (ValueError, Exception)` — redundant,
                    # Exception already subsumes ValueError. verify() can
                    # raise on unparsable input; fall back to string equality.
                    match = (str(clean) == str(key_clean))
            else:
                match = (key_clean == clean)
            if match:
                normalized_groups[key].append(idx)
                matched = True
                break
        if not matched:
            # First member of a new equivalence class; the raw answer string
            # becomes that class's representative key.
            normalized_groups[ans] = [idx]
    return normalized_groups
|
|
| def _calculate_logical_score(trace: List[str]) -> float: |
| """ |
| L_logic: measures intra-agent logical flow. |
| Checks for contradiction signals, empty steps, and step count. |
| """ |
| if not trace: |
| return 0.0 |
| contradiction_terms = ["incorrect", "divergence", "wrong", "error", "divergent", "hallucin"] |
| score = 1.0 |
| for step in trace: |
| if any(t in step.lower() for t in contradiction_terms): |
| score -= 0.3 |
| |
| score += min(0.1 * (len(trace) - 1), 0.3) |
| return max(0.0, min(1.0, score)) |
|
|
| def _calculate_classifier_score(conf_exp: str, is_divergent: bool) -> float: |
| """ |
| C_clf: maps confidence explanation to numerical probability. |
| """ |
| if is_divergent: |
| return 0.1 |
| text = conf_exp.lower() |
| if any(w in text for w in ["high confidence", "certain", "guaranteed", "verified", "proof"]): |
| return 0.95 |
| elif any(w in text for w in ["divergent", "divergence", "wrong", "hallucin", "low confidence"]): |
| return 0.1 |
| elif any(w in text for w in ["likely", "confident", "probably"]): |
| return 0.75 |
| elif any(w in text for w in ["unsure", "guess", "uncertain"]): |
| return 0.3 |
| return 0.55 |
|
|
def evaluate_consensus(
    agent_responses: List[Dict[str, Any]],
    ocr_confidence: float = 1.0
) -> Dict[str, Any]:
    """
    Adaptive Multi-Signal Consensus:
    Score_j = 0.40 * V_sym + 0.35 * L_logic + 0.25 * C_clf
    FinalConf = Score_j * (0.9 + 0.1 * OCR_conf)

    Also detects:
    - Answer divergence (agents disagree → flag as uncertain)
    - Individual hallucination (score < 0.65 OR marked as divergent by agent)
    - High-confidence wrong answers

    Parameters:
        agent_responses: list of dicts with keys "agent" (name) and
            "response" (a dict expected to carry "Answer",
            "Reasoning Trace" and "Confidence Explanation").
        ocr_confidence: scalar in [0, 1] scaling each score into
            [0.9 * Score_j, 1.0 * Score_j].

    Returns a dict with the winning answer, per-agent detail scores,
    per-answer-group aggregates, hallucination alerts and a verdict string.
    """
    if not agent_responses:
        # No input — return a well-formed empty result shell.
        return {
            "final_verified_answer": "No agents responded",
            "winning_score": 0.0,
            "detail_scores": [],
            "divergence_groups": {},
            "hallucination_alerts": [],
            "verdict": "ERROR"
        }

    # Optional project-local symbolic scorer; degrade to a trivial
    # "non-empty trace" check when verification_service is unavailable.
    try:
        from verification_service import calculate_symbolic_score
    except ImportError:
        def calculate_symbolic_score(trace): return 1.0 if trace else 0.0

    scores = []
    hallucination_alerts = []
    answers = [res["response"].get("Answer", "N/A") for res in agent_responses]
    answer_groups = normalize_answers(answers)

    # Divergence := more than one equivalence class among the answers.
    num_unique_answers = len(answer_groups)
    has_divergence = num_unique_answers > 1

    for idx, agent_data in enumerate(agent_responses):
        res = agent_data["response"]
        trace = res.get("Reasoning Trace", [])
        conf_exp = res.get("Confidence Explanation", "")
        raw_ans = res.get("Answer", "N/A")

        # A multiple-choice letter takes precedence over math normalization.
        choice = extract_choice(str(raw_ans))
        normalized_ans = choice if choice else _normalize_answer(raw_ans)

        # Agent admitted doubt in its own confidence explanation.
        is_self_flagged = any(t in conf_exp.lower() for t in ["divergent", "wrong", "hallucin", "low confidence", "divergence"])

        # V_sym: symbolic validity of the reasoning trace.
        v_sym = calculate_symbolic_score(trace)

        # L_logic: intra-agent logical-flow score.
        l_logic = _calculate_logical_score(trace)

        # C_clf: stated-confidence classifier score (forced low if self-flagged).
        c_clf = _calculate_classifier_score(conf_exp, is_self_flagged)

        # Fixed-weight combination of the three signals.
        score_j = (0.40 * v_sym) + (0.35 * l_logic) + (0.25 * c_clf)

        # OCR quality scales the score by a factor in [0.9, 1.0].
        final_conf = score_j * (0.9 + 0.1 * ocr_confidence)

        # Hallucination checks, in priority order: low score, then
        # self-reported divergence, then confident-but-symbolically-invalid.
        is_hallucinating = False
        alert_reason = None

        if score_j < 0.65:
            alert_reason = f"Low consensus score ({score_j:.3f} < 0.65)"
        elif is_self_flagged:
            alert_reason = "Agent self-reported divergent reasoning path"
        elif v_sym == 0.0 and c_clf > 0.7:
            alert_reason = "High-confidence answer with zero symbolic validity"

        if alert_reason:
            is_hallucinating = True
            hallucination_alerts.append({
                "agent": agent_data["agent"],
                "answer": raw_ans,
                "reason": alert_reason,
                "score": round(score_j, 3)
            })

        scores.append({
            "agent": agent_data["agent"],
            "raw_answer": raw_ans,
            "normalized_answer": str(normalized_ans),
            "V_sym": round(v_sym, 3),
            "L_logic": round(l_logic, 3),
            "C_clf": round(c_clf, 3),
            "Score_j": round(score_j, 3),
            "FinalConf": round(final_conf, 3),
            "is_hallucinating": is_hallucinating
        })

    # Aggregate per answer group and pick the highest-weighted answer.
    final_consensus = {}
    top_score = -1.0
    best_answer = "Unresolvable Divergence"

    for rep_ans, indices in answer_groups.items():
        # Score only non-hallucinating supporters; if all supporters are
        # flagged, fall back to the full group rather than dropping it.
        valid_idx = [i for i in indices if not scores[i]["is_hallucinating"]]
        base_idx = valid_idx if valid_idx else indices

        group_score = sum(scores[i]["FinalConf"] for i in base_idx)
        # +15% per additional agreeing agent rewards cross-agent consistency.
        consistency_multiplier = 1.0 + (0.15 * (len(base_idx) - 1))
        weighted = group_score * consistency_multiplier

        final_consensus[rep_ans] = {
            "agents_supporting": [scores[i]["agent"] for i in indices],
            "valid_agent_count": len(valid_idx),
            "aggregate_score": round(weighted, 3)
        }

        if weighted > top_score:
            top_score = weighted
            best_answer = rep_ans

    # Map (top_score, divergence, alerts) to a human-readable verdict.
    if top_score >= 1.5 and not has_divergence and not hallucination_alerts:
        verdict = "✅ STRONGLY VERIFIED"
    elif top_score >= 1.0 and len(hallucination_alerts) == 0:
        verdict = "✅ VERIFIED"
    elif has_divergence and len(hallucination_alerts) > 0:
        verdict = "❌ DIVERGENCE DETECTED — LIKELY WRONG"
    elif has_divergence:
        verdict = "⚠️ UNCERTAIN — AGENTS DISAGREE"
    elif hallucination_alerts:
        verdict = "⚠️ UNCERTAIN — HALLUCINATION RISK"
    else:
        verdict = "⚠️ LOW CONFIDENCE"

    return {
        "final_verified_answer": best_answer,
        "winning_score": round(top_score, 3),
        "detail_scores": scores,
        "divergence_groups": final_consensus,
        "hallucination_alerts": hallucination_alerts,
        "has_divergence": has_divergence,
        "unique_answers": list(answer_groups.keys()),
        "verdict": verdict
    }
|
|