Vansh Jagetia
clean deploy for hf
a1933cb
"""Evaluation utilities for agent triage decisions."""
from __future__ import annotations
from collections.abc import Sequence
from core_engine.schemas import (
AgentDecision,
CategoryName,
EvaluationRecord,
PriorityName,
ScoreSummary,
SyntheticMail,
)
from core_engine.score_bounds import enforce_strict_score
class TriageEvaluator:
    """Score category and priority predictions for triaged mail.

    Each decision earns up to 100 points: ``_CATEGORY_WEIGHT`` of the credit
    for predicting the right category and ``_PRIORITY_WEIGHT`` for the right
    priority. ``_URGENT_PENALTY`` is subtracted whenever "urgent" is confused
    with any other category (in either direction). Step scores are floored
    at 0.
    """

    _CATEGORY_WEIGHT = 0.75
    _PRIORITY_WEIGHT = 0.25
    _URGENT_PENALTY = 0.35
    # Ideal priority for each ground-truth category. Its keys are also the
    # single source of truth for the confusion-matrix axes.
    _EXPECTED_PRIORITY: dict[CategoryName, PriorityName] = {
        "urgent": "high",
        "promotion": "medium",
        "spam": "low",
        "general": "low",
    }

    def evaluate(self, message: SyntheticMail, decision: AgentDecision) -> EvaluationRecord:
        """Compare one decision against the email ground truth.

        Args:
            message: Mail whose ``truth_category`` is the hidden label.
            decision: Agent output carrying the predicted category, priority
                level and confidence.

        Returns:
            An ``EvaluationRecord`` with both predictions, the expected
            values, the correctness flags, the urgent-penalty flag and the
            step score (0-100).

        Raises:
            KeyError: If ``message.truth_category`` is not a known category.
        """
        expected_priority = self.expected_priority(message.truth_category)
        category_correct = decision.predicted_category == message.truth_category
        priority_correct = decision.priority_level == expected_priority
        urgent_penalty_applied = self._has_urgent_miss(
            message.truth_category, decision.predicted_category
        )
        step_score = self._score_pair(
            category_correct, priority_correct, urgent_penalty_applied
        )
        return EvaluationRecord(
            mail_id=message.mail_id,
            predicted_category=decision.predicted_category,
            expected_category=message.truth_category,
            predicted_priority=decision.priority_level,
            expected_priority=expected_priority,
            confidence=decision.confidence,
            urgent_penalty_applied=urgent_penalty_applied,
            category_correct=category_correct,
            priority_correct=priority_correct,
            step_score=step_score,
        )

    def summarize(
        self, records: Sequence[EvaluationRecord], total_count: int
    ) -> ScoreSummary:
        """Compute aggregate accuracy and weighted numeric score.

        Args:
            records: Per-message results, as produced by ``evaluate``.
            total_count: Passed through to the summary unchanged.

        Returns:
            A ``ScoreSummary``. Before being passed through
            ``enforce_strict_score``, each ratio lies in [0, 1] and is rounded
            to 4 decimals. ``numeric_score`` equals ``weighted_score`` and
            ``accuracy`` equals ``classification_accuracy``. With no records,
            every score is ``enforce_strict_score(0.0)`` and the confusion
            matrix is all zeros.
        """
        processed_count = len(records)
        if processed_count == 0:
            return ScoreSummary(
                processed_count=0,
                total_count=total_count,
                classification_accuracy=enforce_strict_score(0.0),
                priority_correctness=enforce_strict_score(0.0),
                numeric_score=enforce_strict_score(0.0),
                accuracy=enforce_strict_score(0.0),
                weighted_score=enforce_strict_score(0.0),
                confusion_matrix=self._empty_confusion_matrix(),
                urgent_penalty_count=0,
            )

        category_hits = sum(record.category_correct for record in records)
        priority_hits = sum(record.priority_correct for record in records)
        classification_accuracy = round(category_hits / processed_count, 4)
        priority_correctness = round(priority_hits / processed_count, 4)
        # step_score is a 0-100 percentage; dividing by 100 yields a 0-1 ratio.
        weighted_score = round(
            sum(record.step_score for record in records) / processed_count / 100, 4
        )
        urgent_penalty_count = sum(
            record.urgent_penalty_applied for record in records
        )
        return ScoreSummary(
            processed_count=processed_count,
            total_count=total_count,
            # Boundary fix applied only at the final score output stage.
            classification_accuracy=enforce_strict_score(classification_accuracy),
            priority_correctness=enforce_strict_score(priority_correctness),
            numeric_score=enforce_strict_score(weighted_score),
            accuracy=enforce_strict_score(classification_accuracy),
            weighted_score=enforce_strict_score(weighted_score),
            confusion_matrix=self._confusion_matrix(records),
            urgent_penalty_count=urgent_penalty_count,
        )

    def expected_priority(self, category: CategoryName) -> PriorityName:
        """Return the ideal priority for a hidden category.

        Raises:
            KeyError: If ``category`` is not in ``_EXPECTED_PRIORITY``.
        """
        return self._EXPECTED_PRIORITY[category]

    def _score_pair(
        self,
        category_correct: bool,
        priority_correct: bool,
        urgent_penalty_applied: bool,
    ) -> float:
        """Return the step score in points (0-100, rounded to 2 decimals).

        The score is the weighted sum of the two correctness flags, minus the
        urgent penalty when it applies, clamped at 0 and scaled to 100.
        """
        category_points = self._CATEGORY_WEIGHT if category_correct else 0.0
        priority_points = self._PRIORITY_WEIGHT if priority_correct else 0.0
        penalty = self._URGENT_PENALTY if urgent_penalty_applied else 0.0
        return round(max(category_points + priority_points - penalty, 0.0) * 100, 2)

    def _has_urgent_miss(
        self, expected_category: CategoryName, predicted_category: CategoryName
    ) -> bool:
        """Return True when exactly one of the two categories is "urgent".

        This covers both a missed urgent mail (urgent predicted as something
        else) and a false urgent alarm (non-urgent predicted as urgent).
        """
        return (expected_category == "urgent") != (predicted_category == "urgent")

    def _empty_confusion_matrix(self) -> dict[str, dict[str, int]]:
        """Return a zeroed ``actual -> predicted -> count`` matrix.

        Both axes are the categories of ``_EXPECTED_PRIORITY`` in sorted
        order, so the matrix cannot drift out of sync with that mapping.
        """
        categories = tuple(sorted(self._EXPECTED_PRIORITY))
        return {actual: dict.fromkeys(categories, 0) for actual in categories}

    def _confusion_matrix(
        self, records: Sequence[EvaluationRecord]
    ) -> dict[str, dict[str, int]]:
        """Count records by (expected_category, predicted_category).

        Raises:
            KeyError: If a record uses a category outside the matrix axes.
        """
        matrix = self._empty_confusion_matrix()
        for record in records:
            matrix[record.expected_category][record.predicted_category] += 1
        return matrix