File size: 7,300 Bytes
b25b8f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, Any, List, Optional
import uvicorn
import httpx
import os

# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI(title="Classifier Service", version="1.0.0")

# CORS is left fully open: every origin, method, and header is accepted.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers should not be combined with allow_credentials=True — confirm whether
# credentialed cross-origin requests are really needed, and if so list the
# allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Endpoint that receives the final classification payload. The in-cluster
# address remains the default; set the DOWNSTREAM_REPORTING_URL environment
# variable to point at a different reporting service without editing the code.
# An unset or empty variable falls back to the default.
DOWNSTREAM_REPORTING_URL = (
    os.environ.get("DOWNSTREAM_REPORTING_URL")
    or "http://reporting-service:8006/report"
)

# The 11 error classification labels: ten specific categories plus a blanket
# fallback. determine_error_category() returns one of these strings; the
# /classify handler additionally uses the literal "None" for agents that
# score above the error threshold, which is not part of this list.
ERROR_TYPES = [
    "Arithmetic Error",      # Caught by SymPy checks evaluating basic ops
    "Sign Error",            # Subset of Arithmetic/Algebraic where only the sign is flipped
    "Copying / OCR Error",   # High OCR uncertainty leading to garbage states
    "Logical Jump",          # High divergence across agents, missing intermediate steps
    "Syntax Error",          # SymPy failed to parse entirely
    "Formula Error",         # Applied wrong formula (e.g. Area = 2*pi*r)
    "Substitution Error",    # Plugged in wrong values into a correct formula
    "Unsimplified Form",     # Correct algebraically but not final (e.g. 4/2 instead of 2)
    "Out of Scope",          # Non-math query
    "Final Answer Mismatch", # Agents diverged at the very last step
    "Unknown / Unscorable"   # Blanket fallback
]

class ClassificationRequest(BaseModel):
    """Request body accepted by the ``/classify`` endpoint.

    Bundles the upstream verification results (SymPy checks, per-agent LLM
    outputs, and the inter-agent divergence matrix) that the classifier
    turns into a single verdict.
    """

    # When true, classification short-circuits to the "Out of Scope" category
    # and no agent is scored.
    out_of_scope: bool
    # Overall SymPy validity flag. Not read anywhere in the visible code.
    sympy_valid: bool
    # SymPy error records. The classifier stringifies the whole list and
    # searches it for keywords, so no particular dict schema is relied upon.
    sympy_errors: List[Dict[str, Any]]
    # One dict per agent. Keys read by this service: "agent_name", "valid",
    # "final_answer", "reasoning", and "steps".
    llm_details: List[Dict[str, Any]]
    # Maps an agent name to a dict whose values are divergence scores;
    # consensus is computed as 1.0 - divergence.
    # NOTE(review): presumably the inner keys are the other agents' names —
    # verify against the producer of this matrix.
    divergence_matrix: Dict[str, Any]
    # Free-form metadata; only "ocr_confidence" is read by this service.
    # Optional means an explicit null is accepted, so consumers of this field
    # should tolerate None as well as a dict.
    metadata: Optional[Dict[str, Any]] = {}

def compute_symbolic_score(agent_result: Dict) -> float:
    """Return the binary symbolic score for one agent result.

    Args:
        agent_result: Agent output dict; only the optional ``valid`` key is
            consulted.

    Returns:
        1.0 when ``valid`` is present and truthy, otherwise 0.0.
    """
    if agent_result.get("valid", False):
        return 1.0
    return 0.0

def compute_logical_score(agent_result: Dict) -> float:
    """Heuristically score the logical quality of one agent's result.

    The score starts at 1.0 and penalties are subtracted:

    * 0.5 when ``final_answer`` is missing or empty;
    * 0.3 when the reasoning text contains a failure keyword
      ("unknown", "error", "cannot solve", "invalid");
    * 0.2 when a final answer is given without any ``steps``.

    Args:
        agent_result: Agent output dict. Reads ``final_answer``,
            ``reasoning`` and ``steps``; every key is optional.

    Returns:
        The penalised score, never lower than 0.0.
    """
    failure_keywords = ("unknown", "error", "cannot solve", "invalid")
    has_answer = bool(agent_result.get("final_answer"))
    # "reasoning" may be present but null (or not a string at all); coerce it
    # so that .lower() cannot raise on malformed agent output.
    reasoning = str(agent_result.get("reasoning") or "").lower()

    score = 1.0
    if not has_answer:
        score -= 0.5
    if any(keyword in reasoning for keyword in failure_keywords):
        score -= 0.3
    if has_answer and not agent_result.get("steps"):
        score -= 0.2
    return max(0.0, score)

def determine_error_category(agent: Dict, request: ClassificationRequest, avg_consensus: float) -> str:
    """Pick one error label for an agent result using ordered heuristics.

    Rules are evaluated in priority order and the first match wins:

    1. request flagged out of scope      -> "Out of Scope"
    2. OCR confidence below 0.6          -> "Copying / OCR Error"
    3. SymPy reported errors             -> "Syntax Error", "Sign Error" or
                                            "Arithmetic Error" (keyword based)
    4. ``avg_consensus`` below 0.4       -> "Logical Jump"
    5. keywords in the agent's reasoning -> "Formula Error",
                                            "Substitution Error" or
                                            "Unsimplified Form"
    6. no final answer                   -> "Final Answer Mismatch"
    7. anything else                     -> "Unknown / Unscorable"

    Args:
        agent: Agent output dict; reads the optional ``reasoning`` and
            ``final_answer`` keys.
        request: The classification request; reads ``out_of_scope``,
            ``metadata`` (may be None) and ``sympy_errors``.
        avg_consensus: Mean consensus of this agent with the others.

    Returns:
        One of the labels listed above.
    """
    if request.out_of_scope:
        return "Out of Scope"

    # metadata is Optional on the request model, so it can legitimately be None.
    metadata = request.metadata or {}
    try:
        ocr_conf = float(metadata.get("ocr_confidence", 1.0))
    except (TypeError, ValueError):
        # A null or non-numeric confidence is treated like a missing one.
        ocr_conf = 1.0
    if ocr_conf < 0.6:
        return "Copying / OCR Error"

    # Check SymPy formal verification errors
    if request.sympy_errors:
        # The whole error list (keys and values) is searched as one string.
        err_msg = str(request.sympy_errors).lower()
        if "syntax" in err_msg or "parse" in err_msg:
            return "Syntax Error"
        # NOTE(review): "-" matches any hyphen or negative number and "sign"
        # matches words such as "assignment", so this branch is very broad —
        # confirm that is intended before tightening it.
        if "sign" in err_msg or "-" in err_msg:
            return "Sign Error"
        return "Arithmetic Error"

    # High divergence (hallucination) across agents -> Logical Jump
    if avg_consensus < 0.4:
        return "Logical Jump"

    # "reasoning" may be null or a non-string; coerce before lower-casing.
    reasoning = str(agent.get("reasoning") or "").lower()

    # NLP keyword heuristics based on Critic Agent feedbacks
    if "formula" in reasoning or "theorem" in reasoning:
        return "Formula Error"
    if "substituted" in reasoning or "plugged" in reasoning:
        return "Substitution Error"
    if "simplify" in reasoning or "reduce" in reasoning:
        return "Unsimplified Form"

    if not agent.get("final_answer"):
        return "Final Answer Mismatch"

    return "Unknown / Unscorable"

@app.get("/health")
async def health_check():
    """Liveness probe: report that the classifier service is up."""
    report = dict(status="healthy", service="classifier")
    return report

@app.post("/classify")
async def classify_endpoint(request: ClassificationRequest):
    """
    Score every agent result, pick the best one, and forward the verdict.

    Each agent receives a weighted score
    ``0.4 * symbolic + 0.35 * logical + 0.25 * consensus`` which is then
    scaled by ``0.9 + 0.1 * ocr_confidence``. Agents scoring at or below 0.6
    are assigned an error category. The highest-scoring agent decides the
    verdict: "VALID" above 0.6, otherwise "ERROR". Out-of-scope requests skip
    scoring entirely and yield an "OUT_OF_SCOPE" verdict.

    The resulting payload is POSTed to ``DOWNSTREAM_REPORTING_URL`` and the
    reporting service's JSON reply is returned unchanged.

    Raises:
        HTTPException: 503 when the reporting service cannot be reached; the
            downstream status code when it answers with an error status; 502
            when its reply is not valid JSON.
    """
    # Scores at or below this threshold are treated as errors.
    valid_threshold = 0.6

    if request.out_of_scope:
        payload = {
            "final_verdict": "OUT_OF_SCOPE",
            "confidence_score": 0.0,
            "error_category": "Out of Scope",
            "best_agent": "None",
            "metadata": request.metadata
        }
    else:
        # metadata is Optional on the model, so guard against an explicit null.
        metadata = request.metadata or {}
        try:
            ocr_conf = float(metadata.get("ocr_confidence", 1.0))
        except (TypeError, ValueError):
            # A null or non-numeric confidence is treated like a missing one.
            ocr_conf = 1.0
        # Keep the scaling factor within [0.9, 1.0] so the final confidence
        # cannot exceed 1 when upstream sends an out-of-range value.
        ocr_conf = min(1.0, max(0.0, ocr_conf))

        scored_agents = []

        for agent_res in request.llm_details:
            name = agent_res.get("agent_name", "Unknown")

            sym = compute_symbolic_score(agent_res)
            logic = compute_logical_score(agent_res)

            # Consensus from matrix: a missing or malformed row yields no
            # divergences, which scores as zero consensus.
            row = request.divergence_matrix.get(name)
            divergences = row.values() if isinstance(row, dict) else ()
            consensuses = [1.0 - d for d in divergences if isinstance(d, (int, float))]
            avg_cons = sum(consensuses) / len(consensuses) if consensuses else 0.0

            # Weighted Scoring Engine
            raw_score = (0.4 * sym) + (0.35 * logic) + (0.25 * avg_cons)
            final_conf = raw_score * (0.9 + 0.1 * ocr_conf)

            error_category = "None"
            if final_conf <= valid_threshold:
                error_category = determine_error_category(agent_res, request, avg_cons)

            scored_agents.append({
                "agent": name,
                "final_conf": final_conf,
                "error_category": error_category,
                "components": {"sym": sym, "logic": logic, "consensus": avg_cons},
                "data": agent_res
            })

        if not scored_agents:
            # Nothing to score: report a generic error verdict.
            payload = {
                "final_verdict": "ERROR",
                "confidence_score": 0.0,
                "error_category": "Unknown / Unscorable",
                "metadata": request.metadata
            }
        else:
            best_agent = max(scored_agents, key=lambda entry: entry["final_conf"])
            is_valid = best_agent["final_conf"] > valid_threshold
            best_data = best_agent["data"]

            payload = {
                "final_verdict": "VALID" if is_valid else "ERROR",
                "confidence_score": round(best_agent["final_conf"], 3),
                "error_category": best_agent["error_category"],
                "best_agent": best_agent["agent"],
                "final_answer": best_data.get("final_answer", ""),
                "all_scores": [
                    {
                        "name": entry["agent"],
                        "score": round(entry["final_conf"], 3),
                        "breakdown": entry["components"],
                        "error": entry["error_category"],
                    }
                    for entry in scored_agents
                ],
                "winning_reasoning": best_data.get("reasoning", ""),
                "divergence_matrix": request.divergence_matrix,
                "metadata": request.metadata
            }

    # Forward to Reporting Service
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(DOWNSTREAM_REPORTING_URL, json=payload, timeout=60.0)
            response.raise_for_status()
    except httpx.RequestError as exc:
        raise HTTPException(status_code=503, detail=f"Downstream Reporting service unavailable: {exc}") from exc
    except httpx.HTTPStatusError as exc:
        raise HTTPException(status_code=exc.response.status_code, detail="Downstream Reporting service error") from exc

    # The body is already fully read, so it can be decoded after the client
    # is closed. A non-JSON reply is reported as a bad gateway instead of
    # surfacing as an unhandled server error.
    try:
        return response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="Downstream Reporting service returned invalid JSON") from exc

if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn, listening on all
    # interfaces (0.0.0.0) at port 8005.
    uvicorn.run(app, host="0.0.0.0", port=8005)