meta-learning-push / tests /test_engine.py
Vansh Jagetia
clean deploy for hf
a1933cb
from __future__ import annotations

import math

from brain.decision_maker import choose_mail_action
from core_engine.evaluator import TriageEvaluator
from core_engine.schemas import AgentDecision, SyntheticMail
from core_engine.simulator import SimulationEngine
def test_reset_generates_emails():
    """Resetting a seeded engine yields a full batch with nothing processed yet.

    The first email of the batch is expected to be exposed as ``current_email``.
    """
    batch = 3
    engine = SimulationEngine(batch_size=batch, random_seed=11)
    initial = engine.reset()

    emails = initial["emails"]
    assert len(emails) == batch
    assert initial["progress"] == {"processed": 0, "remaining": batch, "total": batch}
    assert initial["current_email"]["id"] == emails[0]["id"]
def test_step_updates_state_and_score():
    """One step on a two-email batch advances progress without finishing.

    The step's reward and classification accuracy must both lie strictly
    between 0 and 1.
    """
    engine = SimulationEngine(batch_size=2, random_seed=13)
    current = engine.reset()["current_email"]
    outcome = engine.step(choose_mail_action(current))

    assert outcome["done"] is False
    assert outcome["state"]["progress"]["processed"] == 1
    reward = outcome["reward"]
    assert 0 < reward < 1
    accuracy = outcome["score"]["classification_accuracy"]
    assert 0 < accuracy < 1
def test_evaluator_returns_accuracy_scores_between_zero_and_one():
    """A correctly triaged urgent email yields summary scores strictly inside (0, 1).

    Also checks the numeric score value and that the confusion matrix records
    the urgent -> urgent hit.
    """
    evaluator = TriageEvaluator()
    message = SyntheticMail(
        mail_id="mail_test",
        sender="ops@example.com",
        subject="Urgent production review needed",
        body="Please review asap.",
        truth_category="urgent",
    )
    decision = AgentDecision(
        mail_id="mail_test",
        predicted_category="urgent",
        priority_level="high",
    )
    record = evaluator.evaluate(message, decision)
    summary = evaluator.summarize([record], total_count=1)
    assert 0 < summary.classification_accuracy < 1
    assert 0 < summary.priority_correctness < 1
    assert 0 < summary.weighted_score < 1
    # Floats are compared with a tolerance instead of exact ``==``. The
    # expected value suggests the evaluator keeps a perfect score just below
    # 1.0 (NOTE(review): confirm against TriageEvaluator.summarize).
    assert math.isclose(summary.numeric_score, 0.999999)
    assert summary.confusion_matrix["urgent"]["urgent"] == 1
def test_urgent_misclassification_gets_penalty():
    """Misclassifying an urgent email as general/low flags the urgent penalty.

    The penalty must appear on the per-mail record and in the summary count.
    The weighted score must land on the expected floor value.
    """
    evaluator = TriageEvaluator()
    message = SyntheticMail(
        mail_id="mail_urgent",
        sender="ops@example.com",
        subject="Urgent production review needed",
        body="Please review asap.",
        truth_category="urgent",
    )
    decision = AgentDecision(
        mail_id="mail_urgent",
        predicted_category="general",
        priority_level="low",
    )
    record = evaluator.evaluate(message, decision)
    summary = evaluator.summarize([record], total_count=1)
    assert record.urgent_penalty_applied is True
    assert summary.urgent_penalty_count == 1
    # Compare with a tolerance rather than exact float equality. The
    # expected value suggests the score is floored just above 0.0
    # (NOTE(review): confirm against TriageEvaluator.summarize).
    assert math.isclose(summary.weighted_score, 0.000001)
def test_all_wrong_score_is_strictly_above_zero():
    """A completely wrong triage of an urgent email never scores exactly 0.

    The numeric score must stay strictly inside the open interval (0, 1).
    """
    evaluator = TriageEvaluator()
    urgent_mail = SyntheticMail(
        mail_id="mail_wrong",
        sender="ops@example.com",
        subject="Urgent production review needed",
        body="Please review asap.",
        truth_category="urgent",
    )
    wrong_call = AgentDecision(
        mail_id="mail_wrong",
        predicted_category="general",
        priority_level="low",
    )
    records = [evaluator.evaluate(urgent_mail, wrong_call)]
    summary = evaluator.summarize(records, total_count=1)
    score = summary.numeric_score
    assert 0 < score < 1
def test_all_correct_score_is_strictly_below_one():
    """A fully correct triage of a promotion email never scores exactly 1.

    The numeric score must stay strictly inside the open interval (0, 1).
    """
    evaluator = TriageEvaluator()
    promo_mail = SyntheticMail(
        mail_id="mail_correct",
        sender="shop@example.com",
        subject="Weekend sale",
        body="Discount available today.",
        truth_category="promotion",
    )
    right_call = AgentDecision(
        mail_id="mail_correct",
        predicted_category="promotion",
        priority_level="medium",
    )
    records = [evaluator.evaluate(promo_mail, right_call)]
    summary = evaluator.summarize(records, total_count=1)
    score = summary.numeric_score
    assert 0 < score < 1
def test_normal_non_boundary_score_is_unchanged():
    """A partially correct decision yields an interior score that is not clamped.

    The promotion email is misclassified as general with a medium priority.
    The resulting score is expected to be exactly 0.25, i.e. left untouched by
    any boundary adjustment.
    """
    evaluator = TriageEvaluator()
    message = SyntheticMail(
        mail_id="mail_partial",
        sender="shop@example.com",
        subject="Weekend sale",
        body="Discount available today.",
        truth_category="promotion",
    )
    decision = AgentDecision(
        mail_id="mail_partial",
        predicted_category="general",
        priority_level="medium",
    )
    summary = evaluator.summarize([evaluator.evaluate(message, decision)], total_count=1)
    # Tolerant float comparison: a weighted sum can carry tiny rounding error.
    assert math.isclose(summary.numeric_score, 0.25)