"""Evaluation utilities for agent triage decisions."""

from __future__ import annotations

from collections.abc import Sequence

from core_engine.schemas import (
    AgentDecision,
    CategoryName,
    EvaluationRecord,
    PriorityName,
    ScoreSummary,
    SyntheticMail,
)
from core_engine.score_bounds import enforce_strict_score


class TriageEvaluator:
    """Grade an agent's category and priority predictions against ground truth."""

    # Fraction of a step's score earned by a correct category / priority.
    _CATEGORY_WEIGHT = 0.75
    _PRIORITY_WEIGHT = 0.25
    # Deducted when exactly one of (truth, prediction) is "urgent".
    _URGENT_PENALTY = 0.35
    # Priority an agent is expected to assign for each true category.
    _EXPECTED_PRIORITY: dict[CategoryName, PriorityName] = {
        "urgent": "high",
        "promotion": "medium",
        "spam": "low",
        "general": "low",
    }

    def evaluate(self, message: SyntheticMail, decision: AgentDecision) -> EvaluationRecord:
        """Grade a single decision against the mail's true category.

        Args:
            message: The mail being triaged; supplies ``mail_id`` and
                ``truth_category``.
            decision: The agent's output; supplies ``predicted_category``,
                ``priority_level`` and ``confidence``.

        Returns:
            An ``EvaluationRecord`` holding predicted and expected values,
            the two correctness flags, whether the urgent penalty fired, and
            the step score on a 0-100 scale.

        Raises:
            KeyError: If the mail's true category has no entry in the
                expected-priority table.
        """
        truth = message.truth_category
        ideal_priority = self.expected_priority(truth)
        guess = decision.predicted_category
        chosen_priority = decision.priority_level

        category_ok = guess == truth
        priority_ok = chosen_priority == ideal_priority
        penalized = self._has_urgent_miss(truth, guess)
        points = self._score_pair(category_ok, priority_ok, penalized)

        return EvaluationRecord(
            mail_id=message.mail_id,
            predicted_category=guess,
            expected_category=truth,
            predicted_priority=chosen_priority,
            expected_priority=ideal_priority,
            confidence=decision.confidence,
            urgent_penalty_applied=penalized,
            category_correct=category_ok,
            priority_correct=priority_ok,
            step_score=points,
        )

    def summarize(
        self, records: Sequence[EvaluationRecord], total_count: int
    ) -> ScoreSummary:
        """Aggregate per-mail records into accuracy figures and a 0-1 score.

        Args:
            records: Evaluation records for the mails processed so far.
            total_count: Total number of mails; copied into the summary
                unchanged.

        Returns:
            A ``ScoreSummary``. ``accuracy`` repeats ``classification_accuracy``
            and ``numeric_score`` repeats ``weighted_score``. Each ratio is
            rounded to four decimals and then passed through
            ``enforce_strict_score``. With no records, every ratio is
            ``enforce_strict_score(0.0)`` and the confusion matrix is all
            zeros.
        """
        processed = len(records)

        if not processed:
            # Nothing graded yet: report zeros without dividing by zero.
            return ScoreSummary(
                processed_count=0,
                total_count=total_count,
                classification_accuracy=enforce_strict_score(0.0),
                priority_correctness=enforce_strict_score(0.0),
                numeric_score=enforce_strict_score(0.0),
                accuracy=enforce_strict_score(0.0),
                weighted_score=enforce_strict_score(0.0),
                confusion_matrix=self._empty_confusion_matrix(),
                urgent_penalty_count=0,
            )

        category_ratio = sum(1 for rec in records if rec.category_correct) / processed
        priority_ratio = sum(1 for rec in records if rec.priority_correct) / processed
        # Step scores are on a 0-100 scale; bring the mean back to 0-1.
        mean_step_score = sum([rec.step_score for rec in records]) / processed
        weighted = mean_step_score / 100

        # Bounds enforcement happens only here, on the rounded output values.
        bounded_category = enforce_strict_score(round(category_ratio, 4))
        bounded_priority = enforce_strict_score(round(priority_ratio, 4))
        bounded_numeric = enforce_strict_score(round(weighted, 4))
        bounded_accuracy = enforce_strict_score(round(category_ratio, 4))
        bounded_weighted = enforce_strict_score(round(weighted, 4))

        matrix = self._confusion_matrix(records)
        penalty_total = sum(1 for rec in records if rec.urgent_penalty_applied)

        return ScoreSummary(
            processed_count=processed,
            total_count=total_count,
            classification_accuracy=bounded_category,
            priority_correctness=bounded_priority,
            numeric_score=bounded_numeric,
            accuracy=bounded_accuracy,
            weighted_score=bounded_weighted,
            confusion_matrix=matrix,
            urgent_penalty_count=penalty_total,
        )

    def expected_priority(self, category: CategoryName) -> PriorityName:
        """Look up the priority that matches a true category.

        Raises:
            KeyError: If ``category`` is not in the expected-priority table.
        """
        priority_by_category = self._EXPECTED_PRIORITY
        return priority_by_category[category]

    def _score_pair(
        self,
        category_correct: bool,
        priority_correct: bool,
        urgent_penalty_applied: bool,
    ) -> float:
        """Turn the three grading flags into a step score in [0, 100].

        Adds the category and priority weights for each correct flag,
        subtracts the urgent penalty if it applies, floors the result at
        zero, scales by 100 and rounds to two decimals.
        """
        fraction = 0.0
        if category_correct:
            fraction += self._CATEGORY_WEIGHT
        if priority_correct:
            fraction += self._PRIORITY_WEIGHT
        if urgent_penalty_applied:
            fraction -= self._URGENT_PENALTY

        # The penalty can exceed the points earned; never go negative.
        if fraction < 0.0:
            fraction = 0.0
        return round(fraction * 100, 2)

    def _has_urgent_miss(
        self, expected_category: CategoryName, predicted_category: CategoryName
    ) -> bool:
        """Return True when exactly one of the two categories is "urgent".

        Covers both an urgent mail predicted as something else and a
        non-urgent mail predicted as urgent.
        """
        truly_urgent = expected_category == "urgent"
        flagged_urgent = predicted_category == "urgent"
        return truly_urgent ^ flagged_urgent

    def _empty_confusion_matrix(self) -> dict[str, dict[str, int]]:
        """Build a zero-filled matrix keyed ``[actual][predicted]``."""
        labels = ("general", "promotion", "spam", "urgent")
        grid: dict[str, dict[str, int]] = {}
        for actual in labels:
            # A fresh row per actual label so rows never share state.
            grid[actual] = dict.fromkeys(labels, 0)
        return grid

    def _confusion_matrix(
        self, records: Sequence[EvaluationRecord]
    ) -> dict[str, dict[str, int]]:
        """Count records per (expected, predicted) category pair.

        Raises:
            KeyError: If a record carries a category outside the four
                labels of the empty matrix.
        """
        counts = self._empty_confusion_matrix()
        for rec in records:
            row = counts[rec.expected_category]
            row[rec.predicted_category] += 1
        return counts