# mvm2-math-verification / consensus_fusion.py
# Uploaded by Varshithdharmajv with huggingface_hub (commit 71e6f1d, verified)
from typing import List, Dict, Any
import re
# Optional dependency: `math_verify` provides symbolic parsing (`parse`) and
# equivalence checking (`verify`). When it is not installed, the flag below
# makes callers fall back to plain string comparison.
try:
    from math_verify import parse, verify
    MATH_VERIFY_AVAILABLE = True
except ImportError:
    MATH_VERIFY_AVAILABLE = False
def _fix_sqrt(string: str) -> str:
if "\\sqrt" not in string: return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if len(split) > 0 and split[0] != "{":
new_string += "\\sqrt{" + split[0] + "}" + split[1:]
else:
new_string += "\\sqrt" + split
return new_string
def _fix_fracs(string: str) -> str:
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
for substr in substrs[1:]:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
elif len(substr) >= 2:
new_str += "{" + substr[0] + "}{" + substr[1] + "}" + substr[2:]
else:
new_str += substr
return new_str
def _fix_a_slash_b(string: str) -> str:
if "/" not in string or len(string.split("/")) != 2: return string
a, b = string.split("/")
try:
a_int = int(re.sub(r'[^0-9-]', '', a))
b_int = int(re.sub(r'[^0-9-]', '', b))
return f"\\frac{{{a_int}}}{{{b_int}}}"
except: return string
def _strip_string(string: str) -> str:
    """Canonicalize a LaTeX answer string for equality comparison.

    Removes formatting noise (newlines, ``\\!``, ``\\left``/``\\right``,
    degree marks, currency and percent signs, spaces), normalizes
    ``tfrac``/``dfrac`` to ``frac``, braces sqrt/frac arguments, and
    rewrites ``0.5`` and integer ``a/b`` into ``\\frac`` form.
    """
    string = string.replace("\n", "").replace("\\!", "").replace("\\\\", "\\")
    string = string.replace("tfrac", "frac").replace("dfrac", "frac")
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("^{\\circ}", "").replace("^\\circ", "")
    string = string.replace("\\$", "").replace("$", "")
    # The original chained .replace("\\%", "").replace("\%", ""): "\%" is an
    # invalid escape sequence (SyntaxWarning on Python 3.12+) whose runtime
    # value equals "\\%", so the second replace was redundant.
    string = string.replace("\\%", "")
    if "sqrt" in string:
        string = _fix_sqrt(string)
    string = string.replace(" ", "")
    if "frac" in string:
        string = _fix_fracs(string)
    if string == "0.5":
        string = "\\frac{1}{2}"
    return _fix_a_slash_b(string)
def find_math_answer(s: str) -> str:
    """Extract and canonicalize the final answer from a model response.

    Prefers the content of the last ``\\boxed{...}`` (matched loosely via
    'oxed{' so ``\\fbox`` also hits), then takes the right-hand side of the
    last '=' / '\\approx', and finally normalizes via _strip_string.
    """
    s = s.lower()
    if 'oxed{' in s:
        try:
            # Greedy match: captures up to the LAST '}' after the LAST 'oxed{'.
            ans = re.findall(r'oxed{(.*)}', s, flags=re.S)[-1]
            # If a stray '}' precedes any '{', cut there to drop trailing text.
            if '}' in ans and ('{' not in ans or ans.find('}') < ans.find('{')):
                ans = ans.split('}')[0]
            s = ans
        except IndexError:
            # Was a bare `except:`; findall()[-1] raises IndexError when
            # 'oxed{' has no closing '}' (no regex match at all).
            pass
    s = s.split('=')[-1].split('\\approx')[-1]
    return _strip_string(s)
def extract_choice(text: str) -> str:
    """Extracts alphabet choice (A, B, C, D) from model response."""
    lowered = text.lower()
    candidate_patterns = (
        r'the answer is \(([a-e])\)',
        r'the answer is ([a-e])\.',
        r'final answer: ([a-e])',
        r'^\(([a-e])\)',
        r'^([a-e])\n',
    )
    for pattern in candidate_patterns:
        hit = re.search(pattern, lowered)
        if hit is not None:
            return hit.group(1).upper()
    return ""
def _normalize_answer(ans: str) -> Any:
    """Uses advanced heuristics + math_verify to normalize answer.

    Returns a math_verify parse result when the library is available and
    parsing succeeds; otherwise returns the heuristically cleaned string.
    """
    cleaned = find_math_answer(str(ans))
    if MATH_VERIFY_AVAILABLE:
        try:
            return parse(cleaned)
        except Exception:
            # Was a bare `except:`; any parse failure falls back to the
            # cleaned string rather than propagating.
            return cleaned
    return cleaned
def normalize_answers(answers: List[str]) -> Dict[str, List[int]]:
    """Group answers that are numerically/symbolically equivalent.

    Returns a mapping from the first raw answer seen in each equivalence
    class to the list of indices (into `answers`) of all class members.
    """
    normalized_groups: Dict[str, List[int]] = {}
    # Cache each group key's normalized form: the original re-normalized
    # every key on every incoming answer (O(n^2) normalize calls).
    # _normalize_answer is deterministic, so caching preserves behavior.
    key_cache: Dict[str, Any] = {}
    for idx, ans in enumerate(answers):
        clean = _normalize_answer(ans)
        matched = False
        for key in normalized_groups:
            if key not in key_cache:
                key_cache[key] = _normalize_answer(key)
            key_clean = key_cache[key]
            if MATH_VERIFY_AVAILABLE:
                try:
                    match = verify(clean, key_clean)
                except Exception:
                    # Was `except (ValueError, Exception)` — Exception already
                    # covers ValueError. Fallback when math_verify fails in a
                    # threaded env (signal.alarm restriction).
                    match = (str(clean) == str(key_clean))
            else:
                match = (key_clean == clean)
            if match:
                normalized_groups[key].append(idx)
                matched = True
                break
        if not matched:
            normalized_groups[ans] = [idx]
    return normalized_groups
def _calculate_logical_score(trace: List[str]) -> float:
"""
L_logic: measures intra-agent logical flow.
Checks for contradiction signals, empty steps, and step count.
"""
if not trace:
return 0.0
contradiction_terms = ["incorrect", "divergence", "wrong", "error", "divergent", "hallucin"]
score = 1.0
for step in trace:
if any(t in step.lower() for t in contradiction_terms):
score -= 0.3
# Longer traces with more reasoning steps are rewarded slightly
score += min(0.1 * (len(trace) - 1), 0.3)
return max(0.0, min(1.0, score))
def _calculate_classifier_score(conf_exp: str, is_divergent: bool) -> float:
"""
C_clf: maps confidence explanation to numerical probability.
"""
if is_divergent:
return 0.1
text = conf_exp.lower()
if any(w in text for w in ["high confidence", "certain", "guaranteed", "verified", "proof"]):
return 0.95
elif any(w in text for w in ["divergent", "divergence", "wrong", "hallucin", "low confidence"]):
return 0.1
elif any(w in text for w in ["likely", "confident", "probably"]):
return 0.75
elif any(w in text for w in ["unsure", "guess", "uncertain"]):
return 0.3
return 0.55 # Neutral default
def evaluate_consensus(
    agent_responses: List[Dict[str, Any]],
    ocr_confidence: float = 1.0
) -> Dict[str, Any]:
    """
    Adaptive Multi-Signal Consensus:
    Score_j = 0.40 * V_sym + 0.35 * L_logic + 0.25 * C_clf
    FinalConf = Score_j * (0.9 + 0.1 * OCR_conf)
    Also detects:
    - Answer divergence (agents disagree → flag as uncertain)
    - Individual hallucination (score < 0.65 OR marked as divergent by agent)
    - High-confidence wrong answers

    NOTE(review): each element of `agent_responses` is expected to carry an
    "agent" key and a "response" dict with "Answer", "Reasoning Trace" and
    "Confidence Explanation" keys — confirm against the caller.
    Returns a dict with the winning answer, per-agent detail scores,
    divergence grouping, hallucination alerts and an overall verdict string.
    """
    if not agent_responses:
        # Nothing to score — return an ERROR-shaped result with empty fields.
        return {
            "final_verified_answer": "No agents responded",
            "winning_score": 0.0,
            "detail_scores": [],
            "divergence_groups": {},
            "hallucination_alerts": [],
            "verdict": "ERROR"
        }
    # Import compute symbolic score
    # Project-local dependency; when verification_service is absent, fall back
    # to a trivial scorer: 1.0 for any non-empty trace, 0.0 otherwise.
    try:
        from verification_service import calculate_symbolic_score
    except ImportError:
        def calculate_symbolic_score(trace): return 1.0 if trace else 0.0
    scores = []
    hallucination_alerts = []
    answers = [res["response"].get("Answer", "N/A") for res in agent_responses]
    answer_groups = normalize_answers(answers)
    # Determine if there is significant divergence between agents
    num_unique_answers = len(answer_groups)
    has_divergence = num_unique_answers > 1
    for idx, agent_data in enumerate(agent_responses):
        res = agent_data["response"]
        trace = res.get("Reasoning Trace", [])
        conf_exp = res.get("Confidence Explanation", "")
        raw_ans = res.get("Answer", "N/A")
        # Heuristic Bonus: Capture choices (A/B/C/D)
        choice = extract_choice(str(raw_ans))
        normalized_ans = choice if choice else _normalize_answer(raw_ans)
        # Check if the agent itself marked this as divergent/hallucinating
        is_self_flagged = any(t in conf_exp.lower() for t in ["divergent", "wrong", "hallucin", "low confidence", "divergence"])
        # V_sym: SymPy symbolic reasoning verification (weight 0.40)
        v_sym = calculate_symbolic_score(trace)
        # L_logic: logical consistency & step quality (weight 0.35)
        l_logic = _calculate_logical_score(trace)
        # C_clf: confidence classifier (weight 0.25)
        c_clf = _calculate_classifier_score(conf_exp, is_self_flagged)
        # Core scoring formula
        score_j = (0.40 * v_sym) + (0.35 * l_logic) + (0.25 * c_clf)
        # OCR calibration
        final_conf = score_j * (0.9 + 0.1 * ocr_confidence)
        # Hallucination detection — flag if:
        # 1. Score is below threshold (lowered from 0.7 to 0.65 for better sensitivity)
        # 2. Agent self-flagged as divergent
        # 3. High-confidence answer but symbolic score is 0 (contradiction)
        is_hallucinating = False
        alert_reason = None
        if score_j < 0.65:
            alert_reason = f"Low consensus score ({score_j:.3f} < 0.65)"
        elif is_self_flagged:
            alert_reason = "Agent self-reported divergent reasoning path"
        elif v_sym == 0.0 and c_clf > 0.7:
            alert_reason = "High-confidence answer with zero symbolic validity"
        if alert_reason:
            is_hallucinating = True
            hallucination_alerts.append({
                "agent": agent_data["agent"],
                "answer": raw_ans,
                "reason": alert_reason,
                "score": round(score_j, 3)
            })
        # Per-agent breakdown surfaced in the returned "detail_scores" list.
        scores.append({
            "agent": agent_data["agent"],
            "raw_answer": raw_ans,
            "normalized_answer": str(normalized_ans),
            "V_sym": round(v_sym, 3),
            "L_logic": round(l_logic, 3),
            "C_clf": round(c_clf, 3),
            "Score_j": round(score_j, 3),
            "FinalConf": round(final_conf, 3),
            "is_hallucinating": is_hallucinating
        })
    # Aggregate: find the most supported, highest-scoring answer group
    final_consensus = {}
    top_score = -1.0
    best_answer = "Unresolvable Divergence"
    for rep_ans, indices in answer_groups.items():
        # Prefer non-hallucinating agents when aggregating
        valid_idx = [i for i in indices if not scores[i]["is_hallucinating"]]
        # If every supporter is flagged, fall back to all supporters.
        base_idx = valid_idx if valid_idx else indices
        group_score = sum(scores[i]["FinalConf"] for i in base_idx)
        # Consistency bonus: more agents agreeing on same answer → stronger signal
        consistency_multiplier = 1.0 + (0.15 * (len(base_idx) - 1))
        weighted = group_score * consistency_multiplier
        final_consensus[rep_ans] = {
            "agents_supporting": [scores[i]["agent"] for i in indices],
            "valid_agent_count": len(valid_idx),
            "aggregate_score": round(weighted, 3)
        }
        if weighted > top_score:
            top_score = weighted
            best_answer = rep_ans
    # Determine overall verdict with clearer thresholds
    if top_score >= 1.5 and not has_divergence and not hallucination_alerts:
        verdict = "✅ STRONGLY VERIFIED"
    elif top_score >= 1.0 and len(hallucination_alerts) == 0:
        verdict = "✅ VERIFIED"
    elif has_divergence and len(hallucination_alerts) > 0:
        verdict = "❌ DIVERGENCE DETECTED — LIKELY WRONG"
    elif has_divergence:
        verdict = "⚠️ UNCERTAIN — AGENTS DISAGREE"
    elif hallucination_alerts:
        verdict = "⚠️ UNCERTAIN — HALLUCINATION RISK"
    else:
        verdict = "⚠️ LOW CONFIDENCE"
    return {
        "final_verified_answer": best_answer,
        "winning_score": round(top_score, 3),
        "detail_scores": scores,
        "divergence_groups": final_consensus,
        "hallucination_alerts": hallucination_alerts,
        "has_divergence": has_divergence,
        "unique_answers": list(answer_groups.keys()),
        "verdict": verdict
    }