# File size: 7,300 Bytes (revision b25b8f2)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, Any, List, Optional
import uvicorn
import httpx
import os
# Application instance for the classifier service.
app = FastAPI(version="1.0.0", title="Classifier Service")

# CORS policy: every origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with credentials is a very open policy;
# confirm this is intended outside of local development.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# URL that every classification payload is POSTed to.
# Can be overridden with the DOWNSTREAM_REPORTING_URL environment variable; when the
# variable is unset the previously hard-coded address is used.
DOWNSTREAM_REPORTING_URL = os.getenv(
    "DOWNSTREAM_REPORTING_URL", "http://reporting-service:8006/report"
)
# The 11 error classification labels used by this service (the last one is the fallback).
ERROR_TYPES = [
    # Caught by SymPy checks evaluating basic ops
    "Arithmetic Error",
    # Subset of Arithmetic/Algebraic where only the sign is flipped
    "Sign Error",
    # High OCR uncertainty leading to garbage states
    "Copying / OCR Error",
    # High divergence across agents, missing intermediate steps
    "Logical Jump",
    # SymPy failed to parse entirely
    "Syntax Error",
    # Applied wrong formula (e.g. Area = 2*pi*r)
    "Formula Error",
    # Plugged in wrong values into a correct formula
    "Substitution Error",
    # Correct algebraically but not final (e.g. 4/2 instead of 2)
    "Unsimplified Form",
    # Non-math query
    "Out of Scope",
    # Agents diverged at the very last step
    "Final Answer Mismatch",
    # Blanket fallback
    "Unknown / Unscorable",
]
class ClassificationRequest(BaseModel):
    """Request body accepted by the ``/classify`` endpoint.

    The field notes below describe only how this module reads each field;
    the producers of these values are not visible here.
    """

    # When true, classification short-circuits to the "Out of Scope" verdict.
    out_of_scope: bool
    # Not read anywhere in the visible part of this module.
    sympy_valid: bool
    # Stringified and keyword-searched ("syntax", "parse", "sign", "-") to pick an error category.
    sympy_errors: List[Dict[str, Any]]
    # One dict per agent; keys read here: "agent_name", "valid", "final_answer",
    # "reasoning" and "steps".
    llm_details: List[Dict[str, Any]]
    # Looked up by agent name; each row's values are treated as divergences and
    # turned into consensus via ``1.0 - d``.
    divergence_matrix: Dict[str, Any]
    # Read for "ocr_confidence" (default 1.0) and forwarded unchanged in the payload.
    # NOTE(review): Optional means an explicit null arrives as None, not {}.
    metadata: Optional[Dict[str, Any]] = {}
def compute_symbolic_score(agent_result: Dict) -> float:
    """Return 1.0 when the result's ``"valid"`` entry is truthy, otherwise 0.0.

    A missing ``"valid"`` key counts as not valid.
    """
    return float(bool(agent_result.get("valid", False)))
def compute_logical_score(agent_result: Dict) -> float:
    """Heuristically score how complete an agent's answer looks, in ``[0.0, 1.0]``.

    Starts from 1.0 and subtracts penalties:

    * 0.5 when ``"final_answer"`` is missing or empty;
    * 0.3 when the lower-cased ``"reasoning"`` contains one of the failure
      markers ("unknown", "error", "cannot solve", "invalid");
    * 0.2 when a final answer is given without any ``"steps"``.

    Args:
        agent_result: One agent's result dict. All keys are optional.

    Returns:
        The penalised score, clamped at 0.0.
    """
    failure_markers = ("unknown", "error", "cannot solve", "invalid")

    final_answer = agent_result.get("final_answer")
    # ``or ""`` also covers a key that is present with a null value, which would
    # otherwise fail on ``.lower()``; ``str`` tolerates non-string reasoning.
    reasoning = str(agent_result.get("reasoning") or "").lower()

    score = 1.0
    if not final_answer:
        score -= 0.5
    if any(marker in reasoning for marker in failure_markers):
        score -= 0.3
    if final_answer and not agent_result.get("steps"):
        score -= 0.2
    return max(0.0, score)
def determine_error_category(agent: Dict, request: "ClassificationRequest", avg_consensus: float) -> str:
    """Pick an error category label for one agent's result.

    The checks run in a fixed priority order and the first match wins:

    1. request flagged out of scope;
    2. OCR confidence (``metadata["ocr_confidence"]``, default 1.0) below 0.6;
    3. SymPy errors present: syntax/parse, then sign, otherwise arithmetic;
    4. average consensus below 0.4;
    5. keywords found in the agent's lower-cased reasoning;
    6. missing final answer;
    7. fallback.

    Args:
        agent: One agent's result dict; ``"reasoning"`` and ``"final_answer"`` are read.
        request: The classification request; ``out_of_scope``, ``metadata`` and
            ``sympy_errors`` are read.
        avg_consensus: Average consensus computed for this agent by the caller.

    Returns:
        One of the category label strings.
    """
    if request.out_of_scope:
        return "Out of Scope"

    # ``metadata`` is Optional on the request model, so an explicit null must not crash.
    metadata = request.metadata or {}
    ocr_conf = metadata.get("ocr_confidence")
    if ocr_conf is None:
        # A missing or null confidence is treated as fully confident.
        ocr_conf = 1.0
    if ocr_conf < 0.6:
        return "Copying / OCR Error"

    # Check SymPy formal verification errors.
    if request.sympy_errors:
        err_msg = str(request.sympy_errors).lower()
        if "syntax" in err_msg or "parse" in err_msg:
            return "Syntax Error"
        # NOTE(review): the bare "-" test matches any hyphen or minus sign in the
        # stringified errors, so most non-syntax errors likely land here; confirm
        # this is the intended heuristic.
        if "sign" in err_msg or "-" in err_msg:
            return "Sign Error"
        return "Arithmetic Error"

    # High divergence across agents.
    if avg_consensus < 0.4:
        return "Logical Jump"

    # Keyword heuristics on the reasoning text; a null reasoning is treated as empty.
    reasoning = str(agent.get("reasoning") or "").lower()
    keyword_routes = (
        (("formula", "theorem"), "Formula Error"),
        (("substituted", "plugged"), "Substitution Error"),
        (("simplify", "reduce"), "Unsimplified Form"),
    )
    for keywords, category in keyword_routes:
        if any(keyword in reasoning for keyword in keywords):
            return category

    if not agent.get("final_answer"):
        return "Final Answer Mismatch"
    return "Unknown / Unscorable"
@app.get("/health")
async def health_check():
    """Health probe: report this service as healthy under the name "classifier"."""
    status_report = dict(status="healthy", service="classifier")
    return status_report
def _mean_consensus(divergence_row: Any) -> float:
    """Average ``1 - divergence`` over one agent's divergence row; 0.0 when there is none."""
    # A missing, null or non-dict row is treated as "no consensus information".
    if not isinstance(divergence_row, dict) or not divergence_row:
        return 0.0
    consensuses = [1.0 - d for d in divergence_row.values()]
    return sum(consensuses) / len(consensuses)


def _score_agent(agent_res: Dict[str, Any], request: ClassificationRequest, ocr_conf: float) -> Dict[str, Any]:
    """Score one agent result and attach an error category when it is not confident enough."""
    name = agent_res.get("agent_name", "Unknown")
    sym = compute_symbolic_score(agent_res)
    logic = compute_logical_score(agent_res)
    avg_cons = _mean_consensus(request.divergence_matrix.get(name))

    # Weighted blend of the three components, then scaled into [0.9, 1.0] x raw
    # depending on the OCR confidence.
    raw_score = (0.4 * sym) + (0.35 * logic) + (0.25 * avg_cons)
    final_conf = raw_score * (0.9 + 0.1 * ocr_conf)

    # Only results at or below the validity threshold receive an error category.
    error_category = "None"
    if final_conf <= 0.6:
        error_category = determine_error_category(agent_res, request, avg_cons)

    return {
        "agent": name,
        "final_conf": final_conf,
        "error_category": error_category,
        "components": {"sym": sym, "logic": logic, "consensus": avg_cons},
        "data": agent_res,
    }


def _build_classification_payload(request: ClassificationRequest) -> Dict[str, Any]:
    """Build the verdict payload that is sent to the reporting service."""
    if request.out_of_scope:
        return {
            "final_verdict": "OUT_OF_SCOPE",
            "confidence_score": 0.0,
            "error_category": "Out of Scope",
            "best_agent": "None",
            "metadata": request.metadata,
        }

    # ``metadata`` is Optional on the request model; tolerate an explicit null.
    metadata = request.metadata or {}
    ocr_conf = metadata.get("ocr_confidence")
    if ocr_conf is None:
        # A missing or null confidence is treated as fully confident.
        ocr_conf = 1.0

    scored_agents = [_score_agent(agent_res, request, ocr_conf) for agent_res in request.llm_details]
    if not scored_agents:
        return {
            "final_verdict": "ERROR",
            "confidence_score": 0.0,
            "error_category": "Unknown / Unscorable",
            "metadata": request.metadata,
        }

    best_agent = max(scored_agents, key=lambda scored: scored["final_conf"])
    is_valid = best_agent["final_conf"] > 0.6
    return {
        "final_verdict": "VALID" if is_valid else "ERROR",
        "confidence_score": round(best_agent["final_conf"], 3),
        "error_category": best_agent["error_category"],
        "best_agent": best_agent["agent"],
        "final_answer": best_agent["data"].get("final_answer", ""),
        "all_scores": [
            {
                "name": scored["agent"],
                "score": round(scored["final_conf"], 3),
                "breakdown": scored["components"],
                "error": scored["error_category"],
            }
            for scored in scored_agents
        ],
        "winning_reasoning": best_agent["data"].get("reasoning", ""),
        "divergence_matrix": request.divergence_matrix,
        "metadata": request.metadata,
    }


@app.post("/classify")
async def classify_endpoint(request: ClassificationRequest):
    """Score every agent result, pick the best one and forward the verdict downstream.

    Each agent gets a confidence combining its symbolic score (weight 0.4),
    logical score (0.35) and average consensus from the divergence matrix
    (0.25), scaled by the OCR confidence. The highest-scoring agent decides the
    verdict: ``VALID`` above 0.6, otherwise ``ERROR`` with an error category.
    Out-of-scope requests and requests without agent results get fixed payloads.

    Returns:
        The JSON body returned by the reporting service.

    Raises:
        HTTPException: 503 when the reporting service cannot be reached, the
            downstream status code when it answers with an error status, and
            502 when its response body is not valid JSON.
    """
    payload = _build_classification_payload(request)

    # Forward to the reporting service.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(DOWNSTREAM_REPORTING_URL, json=payload, timeout=60.0)
            response.raise_for_status()
    except httpx.RequestError as exc:
        raise HTTPException(status_code=503, detail=f"Downstream Reporting service unavailable: {exc}") from exc
    except httpx.HTTPStatusError as exc:
        raise HTTPException(status_code=exc.response.status_code, detail="Downstream Reporting service error") from exc

    # Decoding is kept outside the request ``try`` so that only a malformed
    # response body is reported as a bad gateway.
    try:
        return response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="Downstream Reporting service returned a non-JSON response") from exc
# Script entry point: when this module is executed directly, serve ``app`` with
# uvicorn on all interfaces (0.0.0.0) at port 8005.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8005)
|