File size: 11,309 Bytes
bb6d5ae
a1d2691
 
99f7550
 
 
 
 
 
7bff042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99f7550
7bff042
 
99f7550
 
7bff042
99f7550
7bff042
 
bb6d5ae
 
a1d2691
bb6d5ae
 
a1d2691
bb6d5ae
a1d2691
99f7550
 
 
71e6f1d
 
 
 
 
 
99f7550
 
 
 
 
 
 
 
bb6d5ae
 
 
 
a1d2691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6d5ae
a1d2691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6d5ae
a1d2691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6d5ae
 
a1d2691
bb6d5ae
a1d2691
 
 
 
 
bb6d5ae
 
 
a1d2691
 
 
7bff042
 
 
 
a1d2691
 
 
 
bb6d5ae
a1d2691
 
 
 
 
 
 
 
bb6d5ae
a1d2691
 
bb6d5ae
a1d2691
 
 
 
 
bb6d5ae
a1d2691
 
 
 
 
 
 
 
 
 
bb6d5ae
 
 
a1d2691
 
bb6d5ae
 
 
 
 
a1d2691
7bff042
a1d2691
 
 
bb6d5ae
 
 
 
a1d2691
 
bb6d5ae
 
a1d2691
 
bb6d5ae
a1d2691
 
 
 
 
 
 
 
 
bb6d5ae
 
a1d2691
 
bb6d5ae
a1d2691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6d5ae
 
a1d2691
bb6d5ae
 
a1d2691
 
 
 
bb6d5ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
from typing import List, Dict, Any
import re

# math_verify is an optional dependency: when it is not installed, answer
# equivalence checks degrade to normalized-string comparison (see
# _normalize_answer / normalize_answers, which branch on this flag).
try:
    from math_verify import parse, verify
    MATH_VERIFY_AVAILABLE = True
except ImportError:
    MATH_VERIFY_AVAILABLE = False

def _fix_sqrt(string: str) -> str:
    if "\\sqrt" not in string: return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if len(split) > 0 and split[0] != "{":
            new_string += "\\sqrt{" + split[0] + "}" + split[1:]
        else:
            new_string += "\\sqrt" + split
    return new_string

def _fix_fracs(string: str) -> str:
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        for substr in substrs[1:]:
            new_str += "\\frac"
            if len(substr) > 0 and substr[0] == "{":
                new_str += substr
            elif len(substr) >= 2:
                new_str += "{" + substr[0] + "}{" + substr[1] + "}" + substr[2:]
            else:
                new_str += substr
    return new_str

def _fix_a_slash_b(string: str) -> str:
    if "/" not in string or len(string.split("/")) != 2: return string
    a, b = string.split("/")
    try:
        a_int = int(re.sub(r'[^0-9-]', '', a))
        b_int = int(re.sub(r'[^0-9-]', '', b))
        return f"\\frac{{{a_int}}}{{{b_int}}}"
    except: return string

def _strip_string(string: str) -> str:
    """Normalize a LaTeX answer string for comparison.

    Removes formatting noise (newlines, \\left/\\right, degree marks,
    currency/percent signs, spaces), canonicalizes sqrt/frac shorthand via
    the _fix_* helpers, and maps the literal "0.5" to \\frac{1}{2}.
    """
    string = string.replace("\n", "").replace("\\!", "").replace("\\\\", "\\")
    string = string.replace("tfrac", "frac").replace("dfrac", "frac")
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("^{\\circ}", "").replace("^\\circ", "")
    string = string.replace("\\$", "").replace("$", "")
    # Was `.replace("\\%", "").replace("\%", "")`: both arguments are the
    # same two-character string '\%' (the second used an invalid escape
    # sequence — a SyntaxWarning on modern Python), so a single raw-string
    # replace is exactly equivalent.
    string = string.replace(r"\%", "")
    if "sqrt" in string: string = _fix_sqrt(string)
    string = string.replace(" ", "")
    if "frac" in string: string = _fix_fracs(string)
    if string == "0.5": string = "\\frac{1}{2}"
    string = _fix_a_slash_b(string)
    return string

def find_math_answer(s: str) -> str:
    """Extract and normalize the final answer from a model response.

    Prefers the payload of the last \\boxed{...} (matched loosely via
    'oxed{' so \\fbox also hits), then takes the text after the final '='
    or '\\approx', and normalizes the result with _strip_string.
    """
    s = s.lower()
    if 'oxed{' in s:
        try:
            ans = re.findall(r'oxed{(.*)}', s, flags=re.S)[-1]
            # If the capture closes a brace before opening one, the greedy
            # match overshot — keep only the text before the first '}'.
            if '}' in ans and ('{' not in ans or ans.find('}') < ans.find('{')):
                ans = ans.split('}')[0]
            s = ans
        except IndexError:
            # 'oxed{' present but no closing '}' → findall is empty.
            # (Was a bare `except:`, which also hid real bugs.)
            pass
    s = s.split('=')[-1].split('\\approx')[-1]
    return _strip_string(s)

def extract_choice(text: str) -> str:
    """Return the multiple-choice letter (A-E, uppercased) found in *text*.

    Tries a fixed list of answer-announcement patterns against the
    lowercased text and returns "" when none match.
    """
    candidate_patterns = (
        r'the answer is \(([a-e])\)',
        r'the answer is ([a-e])\.',
        r'final answer: ([a-e])',
        r'^\(([a-e])\)',
        r'^([a-e])\n'
    )
    lowered = text.lower()
    for pattern in candidate_patterns:
        hit = re.search(pattern, lowered)
        if hit:
            return hit.group(1).upper()
    return ""

def _normalize_answer(ans: str) -> Any:
    """Normalize *ans* via heuristics, then math_verify.parse when available.

    Returns a math_verify parse result on success, otherwise the cleaned
    string produced by find_math_answer.
    """
    cleaned = find_math_answer(str(ans))
    if MATH_VERIFY_AVAILABLE:
        try:
            return parse(cleaned)
        except Exception:
            # parse() raises assorted library-specific errors on malformed
            # input; fall back to the heuristic string form. (Was a bare
            # `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
            return cleaned
    return cleaned

def normalize_answers(answers: List[str]) -> Dict[str, List[int]]:
    """Group answers that are numerically/symbolically equivalent.

    Returns a mapping from a representative raw answer (the first seen for
    its equivalence class) to the list of indices in *answers* belonging
    to that class.
    """
    normalized_groups: Dict[str, List[int]] = {}
    # Cache each representative's normalized form: the original code
    # re-ran _normalize_answer (and therefore math_verify.parse) on every
    # key for every answer — O(n^2) parses for all-distinct inputs.
    key_norms: Dict[str, Any] = {}
    for idx, ans in enumerate(answers):
        clean = _normalize_answer(ans)
        matched = False
        for key, indices in normalized_groups.items():
            key_clean = key_norms[key]
            if MATH_VERIFY_AVAILABLE:
                try:
                    match = verify(clean, key_clean)
                except Exception:
                    # math_verify can fail in threaded environments
                    # (signal.alarm restriction); fall back to comparing
                    # normalized string forms. (Was `except (ValueError,
                    # Exception)` — Exception already subsumes ValueError.)
                    match = (str(clean) == str(key_clean))
            else:
                match = (key_clean == clean)
            if match:
                indices.append(idx)
                matched = True
                break
        if not matched:
            normalized_groups[ans] = [idx]
            key_norms[ans] = clean
    return normalized_groups

def _calculate_logical_score(trace: List[str]) -> float:
    """
    L_logic: measures intra-agent logical flow.
    Checks for contradiction signals, empty steps, and step count.
    """
    if not trace:
        return 0.0
    contradiction_terms = ["incorrect", "divergence", "wrong", "error", "divergent", "hallucin"]
    score = 1.0
    for step in trace:
        if any(t in step.lower() for t in contradiction_terms):
            score -= 0.3
    # Longer traces with more reasoning steps are rewarded slightly
    score += min(0.1 * (len(trace) - 1), 0.3)
    return max(0.0, min(1.0, score))

def _calculate_classifier_score(conf_exp: str, is_divergent: bool) -> float:
    """
    C_clf: maps confidence explanation to numerical probability.
    """
    if is_divergent:
        return 0.1
    text = conf_exp.lower()
    if any(w in text for w in ["high confidence", "certain", "guaranteed", "verified", "proof"]):
        return 0.95
    elif any(w in text for w in ["divergent", "divergence", "wrong", "hallucin", "low confidence"]):
        return 0.1
    elif any(w in text for w in ["likely", "confident", "probably"]):
        return 0.75
    elif any(w in text for w in ["unsure", "guess", "uncertain"]):
        return 0.3
    return 0.55  # Neutral default

def evaluate_consensus(
    agent_responses: List[Dict[str, Any]],
    ocr_confidence: float = 1.0
) -> Dict[str, Any]:
    """
    Adaptive Multi-Signal Consensus:
    Score_j = 0.40 * V_sym + 0.35 * L_logic + 0.25 * C_clf
    FinalConf = Score_j * (0.9 + 0.1 * OCR_conf)

    Also detects:
    - Answer divergence (agents disagree → flag as uncertain)
    - Individual hallucination (score < 0.65 OR marked as divergent by agent)
    - High-confidence wrong answers

    Each item of agent_responses carries an "agent" label and a "response"
    dict read for "Answer", "Reasoning Trace" (list of step strings) and
    "Confidence Explanation" — NOTE(review): schema inferred from the key
    accesses below; confirm against the producer of these dicts.

    Returns a dict with the winning answer and score, per-agent detail
    scores, per-answer aggregation groups, hallucination alerts, a
    divergence flag, the unique answers, and a human-readable verdict.
    """
    if not agent_responses:
        # Degenerate case: nothing to score.
        return {
            "final_verified_answer": "No agents responded",
            "winning_score": 0.0,
            "detail_scores": [],
            "divergence_groups": {},
            "hallucination_alerts": [],
            "verdict": "ERROR"
        }

    # Import compute symbolic score
    # (Optional project dependency; the fallback scores 1.0 for any
    # non-empty trace so the 0.40-weighted term stays meaningful.)
    try:
        from verification_service import calculate_symbolic_score
    except ImportError:
        def calculate_symbolic_score(trace): return 1.0 if trace else 0.0

    scores = []
    hallucination_alerts = []
    answers = [res["response"].get("Answer", "N/A") for res in agent_responses]
    answer_groups = normalize_answers(answers)

    # Determine if there is significant divergence between agents
    num_unique_answers = len(answer_groups)
    has_divergence = num_unique_answers > 1

    for idx, agent_data in enumerate(agent_responses):
        res = agent_data["response"]
        trace = res.get("Reasoning Trace", [])
        conf_exp = res.get("Confidence Explanation", "")
        raw_ans = res.get("Answer", "N/A")

        # Heuristic Bonus: Capture choices (A/B/C/D)
        # (extract_choice actually accepts A-E; empty string means "not a
        # multiple-choice answer" and falls through to math normalization.)
        choice = extract_choice(str(raw_ans))
        normalized_ans = choice if choice else _normalize_answer(raw_ans)
        
        # Check if the agent itself marked this as divergent/hallucinating
        is_self_flagged = any(t in conf_exp.lower() for t in ["divergent", "wrong", "hallucin", "low confidence", "divergence"])

        # V_sym: SymPy symbolic reasoning verification (weight 0.40)
        v_sym = calculate_symbolic_score(trace)

        # L_logic: logical consistency & step quality (weight 0.35)
        l_logic = _calculate_logical_score(trace)

        # C_clf: confidence classifier (weight 0.25)
        c_clf = _calculate_classifier_score(conf_exp, is_self_flagged)

        # Core scoring formula
        score_j = (0.40 * v_sym) + (0.35 * l_logic) + (0.25 * c_clf)

        # OCR calibration
        # Scales final confidence between 90% (ocr_confidence=0) and 100%
        # (ocr_confidence=1) of the raw score.
        final_conf = score_j * (0.9 + 0.1 * ocr_confidence)

        # Hallucination detection — flag if:
        # 1. Score is below threshold (lowered from 0.7 to 0.65 for better sensitivity)
        # 2. Agent self-flagged as divergent
        # 3. High-confidence answer but symbolic score is 0 (contradiction)
        # Checks are elif-chained, so only the first matching reason is
        # reported per agent.
        is_hallucinating = False
        alert_reason = None

        if score_j < 0.65:
            alert_reason = f"Low consensus score ({score_j:.3f} < 0.65)"
        elif is_self_flagged:
            alert_reason = "Agent self-reported divergent reasoning path"
        elif v_sym == 0.0 and c_clf > 0.7:
            alert_reason = "High-confidence answer with zero symbolic validity"

        if alert_reason:
            is_hallucinating = True
            hallucination_alerts.append({
                "agent": agent_data["agent"],
                "answer": raw_ans,
                "reason": alert_reason,
                "score": round(score_j, 3)
            })

        scores.append({
            "agent": agent_data["agent"],
            "raw_answer": raw_ans,
            "normalized_answer": str(normalized_ans),
            "V_sym": round(v_sym, 3),
            "L_logic": round(l_logic, 3),
            "C_clf": round(c_clf, 3),
            "Score_j": round(score_j, 3),
            "FinalConf": round(final_conf, 3),
            "is_hallucinating": is_hallucinating
        })

    # Aggregate: find the most supported, highest-scoring answer group
    final_consensus = {}
    top_score = -1.0
    best_answer = "Unresolvable Divergence"

    for rep_ans, indices in answer_groups.items():
        # Prefer non-hallucinating agents when aggregating
        # (falls back to all supporters if every agent in the group was
        # flagged, so a group is never scored as empty).
        valid_idx = [i for i in indices if not scores[i]["is_hallucinating"]]
        base_idx = valid_idx if valid_idx else indices

        group_score = sum(scores[i]["FinalConf"] for i in base_idx)
        # Consistency bonus: more agents agreeing on same answer → stronger signal
        consistency_multiplier = 1.0 + (0.15 * (len(base_idx) - 1))
        weighted = group_score * consistency_multiplier

        final_consensus[rep_ans] = {
            "agents_supporting": [scores[i]["agent"] for i in indices],
            "valid_agent_count": len(valid_idx),
            "aggregate_score": round(weighted, 3)
        }

        if weighted > top_score:
            top_score = weighted
            best_answer = rep_ans

    # Determine overall verdict with clearer thresholds
    # (top_score is a sum over agents, so 1.0/1.5 roughly mean "more than
    # one strong agent's worth of confidence" — NOTE(review): thresholds
    # look empirically tuned; confirm against evaluation data.)
    if top_score >= 1.5 and not has_divergence and not hallucination_alerts:
        verdict = "✅ STRONGLY VERIFIED"
    elif top_score >= 1.0 and len(hallucination_alerts) == 0:
        verdict = "✅ VERIFIED"
    elif has_divergence and len(hallucination_alerts) > 0:
        verdict = "❌ DIVERGENCE DETECTED — LIKELY WRONG"
    elif has_divergence:
        verdict = "⚠️ UNCERTAIN — AGENTS DISAGREE"
    elif hallucination_alerts:
        verdict = "⚠️ UNCERTAIN — HALLUCINATION RISK"
    else:
        verdict = "⚠️ LOW CONFIDENCE"

    return {
        "final_verified_answer": best_answer,
        "winning_score": round(top_score, 3),
        "detail_scores": scores,
        "divergence_groups": final_consensus,
        "hallucination_alerts": hallucination_alerts,
        "has_divergence": has_divergence,
        "unique_answers": list(answer_groups.keys()),
        "verdict": verdict
    }