| from fastapi import FastAPI, HTTPException
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from pydantic import BaseModel
|
| from typing import Dict, Any, List, Optional
|
| import uvicorn
|
| import httpx
|
| import os
|
|
|
# ASGI application object; routes below register themselves on it.
app = FastAPI(title="Classifier Service", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy, and FastAPI's CORS documentation advises against the
# combination. Confirm whether credentialed cross-origin calls are really
# needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
# Endpoint that receives the final classification payload. It can be overridden
# per deployment through the DOWNSTREAM_REPORTING_URL environment variable; an
# unset or empty variable falls back to the in-cluster default.
DOWNSTREAM_REPORTING_URL = (
    os.environ.get("DOWNSTREAM_REPORTING_URL")
    or "http://reporting-service:8006/report"
)
|
|
|
|
|
# Catalogue of error category labels. Every label returned by
# determine_error_category() appears in this list.
ERROR_TYPES = [
    # Computation-level mistakes
    "Arithmetic Error", "Sign Error",
    # Input / transcription problems
    "Copying / OCR Error",
    # Reasoning and notation problems
    "Logical Jump", "Syntax Error", "Formula Error", "Substitution Error",
    # Presentation and scope
    "Unsimplified Form", "Out of Scope",
    # Outcome-level labels
    "Final Answer Mismatch", "Unknown / Unscorable",
]
|
|
|
class ClassificationRequest(BaseModel):
    """Request body accepted by the POST /classify endpoint."""

    # When true, the endpoint skips scoring and reports an OUT_OF_SCOPE verdict.
    out_of_scope: bool

    # Overall SymPy validity flag. No code in this file reads it.
    sympy_valid: bool

    # SymPy error records; a non-empty list is stringified and keyword-scanned
    # when an error category is chosen.
    sympy_errors: List[Dict[str, Any]]

    # One dict per LLM agent. Keys read in this file: "agent_name", "valid",
    # "final_answer", "reasoning" and "steps".
    llm_details: List[Dict[str, Any]]

    # Looked up by agent name; each row's values are treated as divergence
    # numbers and turned into consensus as (1.0 - divergence).
    divergence_matrix: Dict[str, Any]

    # Free-form extras forwarded downstream; "ocr_confidence" is read from it
    # and defaults to 1.0 when the key is absent.
    metadata: Optional[Dict[str, Any]] = {}
|
|
|
def compute_symbolic_score(agent_result: Dict) -> float:
    """Return the symbolic component of an agent's score.

    Args:
        agent_result: Agent result dict; only its "valid" entry is inspected.

    Returns:
        1.0 when "valid" is present and truthy, otherwise 0.0.
    """
    if agent_result.get("valid", False):
        return 1.0
    return 0.0
|
|
|
def compute_logical_score(agent_result: Dict) -> float:
    """Heuristically score the logical quality of one agent result.

    Starts from 1.0 and applies deductions:
      * 0.5 when there is no final answer,
      * 0.3 when the reasoning text contains a failure keyword,
      * 0.2 when an answer is given without any steps.

    Args:
        agent_result: Agent result dict; reads "final_answer", "reasoning"
            and "steps".

    Returns:
        The remaining score, never below 0.0.
    """
    answer = agent_result.get("final_answer")
    # A numeric 0 is a legitimate answer, so it must not count as "missing"
    # just because it is falsy. Every other value keeps plain truthiness.
    answer_is_number = isinstance(answer, (int, float)) and not isinstance(answer, bool)
    has_answer = bool(answer) or answer_is_number

    score = 1.0
    if not has_answer:
        score -= 0.5

    # "reasoning" may be present but null or non-string in agent output;
    # coerce it so .lower() cannot raise.
    reasoning = str(agent_result.get("reasoning") or "").lower()
    failure_markers = ("unknown", "error", "cannot solve", "invalid")
    if any(marker in reasoning for marker in failure_markers):
        score -= 0.3

    if has_answer and not agent_result.get("steps"):
        score -= 0.2

    return max(0.0, score)
|
|
|
def determine_error_category(agent: Dict, request: ClassificationRequest, avg_consensus: float) -> str:
    """Pick a single error category label for a low-confidence agent result.

    Rules are evaluated in priority order and the first match wins:
      1. request flagged out of scope,
      2. OCR confidence below 0.6,
      3. SymPy errors present (keyword scan of their string form),
      4. average consensus below 0.4,
      5. keyword scan of the agent's reasoning text,
      6. missing final answer,
      7. fallback "Unknown / Unscorable".

    Args:
        agent: Agent result dict; reads "reasoning" and "final_answer".
        request: The classification request; reads out_of_scope, metadata
            and sympy_errors.
        avg_consensus: Mean consensus of this agent with its peers.

    Returns:
        One of the category label strings.
    """
    if request.out_of_scope:
        return "Out of Scope"

    # metadata is Optional, so an explicit null must not crash the lookup.
    metadata = request.metadata or {}
    try:
        ocr_conf = float(metadata.get("ocr_confidence", 1.0))
    except (TypeError, ValueError):
        # A null or unparseable confidence is treated like a missing one.
        ocr_conf = 1.0
    if ocr_conf < 0.6:
        return "Copying / OCR Error"

    if request.sympy_errors:
        err_msg = str(request.sympy_errors).lower()
        if "syntax" in err_msg or "parse" in err_msg:
            return "Syntax Error"
        # NOTE(review): this scan runs over the stringified list of dicts, so a
        # hyphen or the substring "sign" anywhere (keys included) selects
        # "Sign Error" -- confirm this breadth is intended.
        if "sign" in err_msg or "-" in err_msg:
            return "Sign Error"
        return "Arithmetic Error"

    if avg_consensus < 0.4:
        return "Logical Jump"

    # "reasoning" may be null or non-string in agent output; coerce before
    # lowering so the keyword scan cannot raise.
    reasoning = str(agent.get("reasoning") or "").lower()
    keyword_rules = (
        (("formula", "theorem"), "Formula Error"),
        (("substituted", "plugged"), "Substitution Error"),
        (("simplify", "reduce"), "Unsimplified Form"),
    )
    for keywords, category in keyword_rules:
        if any(word in reasoning for word in keywords):
            return category

    answer = agent.get("final_answer")
    # A numeric 0 is a real answer even though it is falsy.
    answer_is_number = isinstance(answer, (int, float)) and not isinstance(answer, bool)
    if not (answer or answer_is_number):
        return "Final Answer Mismatch"

    return "Unknown / Unscorable"
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: report a static healthy status for this service."""
    status_report = dict(status="healthy", service="classifier")
    return status_report
|
|
|
def _score_agent(agent_res: Dict[str, Any], request: ClassificationRequest, ocr_conf: float) -> Dict[str, Any]:
    """Score one agent result and attach an error category when confidence is at most 0.6."""
    name = agent_res.get("agent_name", "Unknown")

    # NOTE(review): request.sympy_valid is not consulted; the symbolic
    # component comes from the agent's own "valid" flag -- confirm intent.
    sym = compute_symbolic_score(agent_res)
    logic = compute_logical_score(agent_res)

    # Consensus is the complement of each divergence in this agent's row.
    # A missing or non-dict row, and non-numeric entries, contribute nothing
    # instead of crashing the request.
    row = request.divergence_matrix.get(name)
    divergences = row.values() if isinstance(row, dict) else ()
    consensuses = [1.0 - d for d in divergences if isinstance(d, (int, float))]
    avg_cons = sum(consensuses) / len(consensuses) if consensuses else 0.0

    # Weighted blend, then scaled by OCR confidence (0.9x to 1.0x for
    # confidences in [0, 1]).
    raw_score = (0.4 * sym) + (0.35 * logic) + (0.25 * avg_cons)
    final_conf = raw_score * (0.9 + 0.1 * ocr_conf)

    error_category = "None"
    if final_conf <= 0.6:
        error_category = determine_error_category(agent_res, request, avg_cons)

    return {
        "agent": name,
        "final_conf": final_conf,
        "error_category": error_category,
        "components": {"sym": sym, "logic": logic, "consensus": avg_cons},
        "data": agent_res,
    }


async def _forward_report(payload: Dict[str, Any]) -> Any:
    """POST the verdict payload to the reporting service and return its decoded JSON body."""
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(DOWNSTREAM_REPORTING_URL, json=payload, timeout=60.0)
            response.raise_for_status()
    except httpx.RequestError as exc:
        raise HTTPException(status_code=503, detail=f"Downstream Reporting service unavailable: {exc}") from exc
    except httpx.HTTPStatusError as exc:
        raise HTTPException(status_code=exc.response.status_code, detail="Downstream Reporting service error") from exc

    # Decode separately so a malformed body is reported as a bad gateway
    # rather than surfacing as an unhandled 500.
    try:
        return response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="Downstream Reporting service returned a non-JSON response") from exc


@app.post("/classify")
async def classify_endpoint(request: ClassificationRequest):
    """
    Combines SymPy scores and LLM reasoning matrices into a final verdict,
    and formally categorizes the errors into 10+ logical types.

    Each agent in ``llm_details`` is scored; the highest-scoring agent decides
    the verdict ("VALID" when its confidence exceeds 0.6, otherwise "ERROR").
    The resulting payload is forwarded to the reporting service and that
    service's JSON response is returned to the caller.

    Raises:
        HTTPException: 503 when the reporting service cannot be reached, the
            downstream status code when it answers with an error status, and
            502 when its response body is not valid JSON.
    """
    if request.out_of_scope:
        return await _forward_report({
            "final_verdict": "OUT_OF_SCOPE",
            "confidence_score": 0.0,
            "error_category": "Out of Scope",
            "best_agent": "None",
            "metadata": request.metadata,
        })

    # metadata is Optional, and its ocr_confidence may be null or a string;
    # fall back to full confidence (1.0) when it is absent or unusable.
    metadata = request.metadata or {}
    try:
        ocr_conf = float(metadata.get("ocr_confidence", 1.0))
    except (TypeError, ValueError):
        ocr_conf = 1.0

    scored_agents = [_score_agent(agent_res, request, ocr_conf) for agent_res in request.llm_details]

    if not scored_agents:
        payload = {
            "final_verdict": "ERROR",
            "confidence_score": 0.0,
            "error_category": "Unknown / Unscorable",
            "metadata": request.metadata,
        }
    else:
        best_agent = max(scored_agents, key=lambda scored: scored["final_conf"])
        is_valid = best_agent["final_conf"] > 0.6

        payload = {
            "final_verdict": "VALID" if is_valid else "ERROR",
            "confidence_score": round(best_agent["final_conf"], 3),
            "error_category": best_agent["error_category"],
            "best_agent": best_agent["agent"],
            "final_answer": best_agent["data"].get("final_answer", ""),
            "all_scores": [
                {
                    "name": scored["agent"],
                    "score": round(scored["final_conf"], 3),
                    "breakdown": scored["components"],
                    "error": scored["error_category"],
                }
                for scored in scored_agents
            ],
            "winning_reasoning": best_agent["data"].get("reasoning", ""),
            "divergence_matrix": request.divergence_matrix,
            "metadata": request.metadata,
        }

    return await _forward_report(payload)
|
|
|
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8005.
    server_options = {"host": "0.0.0.0", "port": 8005}
    uvicorn.run(app, **server_options)
|
|
|