Spaces:
Sleeping
Sleeping
| """Evaluation utilities for agent triage decisions.""" | |
| from __future__ import annotations | |
| from collections.abc import Sequence | |
| from core_engine.schemas import ( | |
| AgentDecision, | |
| CategoryName, | |
| EvaluationRecord, | |
| PriorityName, | |
| ScoreSummary, | |
| SyntheticMail, | |
| ) | |
| from core_engine.score_bounds import enforce_strict_score | |
class TriageEvaluator:
    """Score category and priority predictions against ground truth.

    A single step earns ``_CATEGORY_WEIGHT`` for the right category plus
    ``_PRIORITY_WEIGHT`` for the right priority, minus ``_URGENT_PENALTY`` when
    exactly one of the true/predicted categories is "urgent". The result is
    floored at zero and reported on a 0-100 scale.
    """

    _CATEGORY_WEIGHT = 0.75
    _PRIORITY_WEIGHT = 0.25
    _URGENT_PENALTY = 0.35
    # Ideal priority for each ground-truth category.
    _EXPECTED_PRIORITY: dict[CategoryName, PriorityName] = {
        "urgent": "high",
        "promotion": "medium",
        "spam": "low",
        "general": "low",
    }
    # Confusion-matrix labels, derived from the mapping above so the two can
    # never drift apart. Sorted order: general, promotion, spam, urgent.
    _CATEGORIES = tuple(sorted(_EXPECTED_PRIORITY))

    def evaluate(self, message: SyntheticMail, decision: AgentDecision) -> EvaluationRecord:
        """Compare one decision against the email ground truth.

        Args:
            message: Mail whose ``truth_category`` is the reference label.
            decision: Agent output carrying ``predicted_category``,
                ``priority_level`` and ``confidence``.

        Returns:
            An ``EvaluationRecord`` holding the predicted and expected labels,
            the correctness flags, whether the urgent penalty applied, and the
            step score on a 0-100 scale.

        Raises:
            KeyError: If ``message.truth_category`` has no expected priority.
        """
        expected_priority = self.expected_priority(message.truth_category)
        category_correct = decision.predicted_category == message.truth_category
        priority_correct = decision.priority_level == expected_priority
        # The penalty can only fire when the two categories differ, so a
        # penalized step never earns category credit; with the current weights
        # (0.25 - 0.35 < 0) such a step always floors at 0.
        urgent_penalty_applied = self._has_urgent_miss(
            message.truth_category, decision.predicted_category
        )
        step_score = self._score_pair(
            category_correct, priority_correct, urgent_penalty_applied
        )
        return EvaluationRecord(
            mail_id=message.mail_id,
            predicted_category=decision.predicted_category,
            expected_category=message.truth_category,
            predicted_priority=decision.priority_level,
            expected_priority=expected_priority,
            confidence=decision.confidence,
            urgent_penalty_applied=urgent_penalty_applied,
            category_correct=category_correct,
            priority_correct=priority_correct,
            step_score=step_score,
        )

    def summarize(
        self, records: Sequence[EvaluationRecord], total_count: int
    ) -> ScoreSummary:
        """Compute aggregate accuracy and weighted numeric score.

        Args:
            records: Per-mail evaluation records; may be empty.
            total_count: Passed through unchanged to the summary.

        Returns:
            A ``ScoreSummary`` in which ``accuracy`` mirrors
            ``classification_accuracy`` and ``numeric_score`` mirrors
            ``weighted_score``. Every ratio is rounded to four decimals and
            then passed through ``enforce_strict_score``. An empty ``records``
            yields zero ratios, an all-zero confusion matrix and a zero
            penalty count.
        """
        processed_count = len(records)
        # An empty batch has nothing to average. Dividing by 1 instead of 0
        # makes every ratio 0.0, so a single construction path serves both the
        # empty and the non-empty case.
        divisor = processed_count or 1
        classification_accuracy = round(
            sum(record.category_correct for record in records) / divisor, 4
        )
        priority_correctness = round(
            sum(record.priority_correct for record in records) / divisor, 4
        )
        # step_score as produced by evaluate() is on a 0-100 scale; the mean
        # is brought back to 0-1 here.
        weighted_score = round(
            sum(record.step_score for record in records) / divisor / 100, 4
        )
        return ScoreSummary(
            processed_count=processed_count,
            total_count=total_count,
            # Boundary fix applied only at the final score output stage.
            # NOTE(review): enforce_strict_score is defined in
            # core_engine.score_bounds; its exact bounds are not visible here.
            classification_accuracy=enforce_strict_score(classification_accuracy),
            priority_correctness=enforce_strict_score(priority_correctness),
            numeric_score=enforce_strict_score(weighted_score),
            accuracy=enforce_strict_score(classification_accuracy),
            weighted_score=enforce_strict_score(weighted_score),
            confusion_matrix=self._confusion_matrix(records),
            urgent_penalty_count=sum(record.urgent_penalty_applied for record in records),
        )

    def expected_priority(self, category: CategoryName) -> PriorityName:
        """Return the ideal priority for a hidden category.

        Raises:
            KeyError: If ``category`` is not a key of ``_EXPECTED_PRIORITY``.
        """
        return self._EXPECTED_PRIORITY[category]

    def _score_pair(
        self,
        category_correct: bool,
        priority_correct: bool,
        urgent_penalty_applied: bool,
    ) -> float:
        """Return the 0-100 step score for one prediction.

        Adds the category and priority weights that were earned, subtracts the
        urgent penalty when it applies, floors the result at zero, scales it
        by 100 and rounds to two decimals.
        """
        category_points = self._CATEGORY_WEIGHT if category_correct else 0.0
        priority_points = self._PRIORITY_WEIGHT if priority_correct else 0.0
        penalty = self._URGENT_PENALTY if urgent_penalty_applied else 0.0
        return round(max(category_points + priority_points - penalty, 0.0) * 100, 2)

    def _has_urgent_miss(
        self, expected_category: CategoryName, predicted_category: CategoryName
    ) -> bool:
        """Return True when exactly one of the two labels is "urgent".

        This is symmetric: it covers an urgent mail predicted as something
        else and a non-urgent mail predicted as urgent.
        """
        return (expected_category == "urgent") != (predicted_category == "urgent")

    def _empty_confusion_matrix(self) -> dict[str, dict[str, int]]:
        """Return a fresh all-zero matrix keyed ``[actual][predicted]``."""
        # Each row is built separately so incrementing one cell never leaks
        # into another row.
        return {
            actual: {predicted: 0 for predicted in self._CATEGORIES}
            for actual in self._CATEGORIES
        }

    def _confusion_matrix(
        self, records: Sequence[EvaluationRecord]
    ) -> dict[str, dict[str, int]]:
        """Count records per (expected, predicted) category pair.

        Raises:
            KeyError: If a record carries a category outside the known labels.
        """
        matrix = self._empty_confusion_matrix()
        for record in records:
            matrix[record.expected_category][record.predicted_category] += 1
        return matrix