openenv-email-triage / tests /test_environment.py
Aneesha Das
Updated
61e4f68
"""
pytest suite for Email Triage OpenEnv v2.
Covers: model validation, grader, environment lifecycle,
sequential mechanics (SLA/budget/queue/cascade), and full episodes.
Run: pytest tests/ -v
"""
import pytest
from environment import EmailTriageEnv
from grader import score_action, grade_episode
from models import (
Action, Priority, Category, RouteTo,
Observation, Reward, EnvironmentState,
TEAM_CAPACITY, SLA_STEPS, TASK_ESCALATION_BUDGET,
)
from dataset import EASY_EMAILS, MEDIUM_EMAILS, HARD_EMAILS
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture
def easy_env():
    """Provide an ``easy``-task environment that has already been reset."""
    env = EmailTriageEnv(task_id="easy")
    env.reset()
    return env
@pytest.fixture
def hard_env():
    """Provide a ``hard``-task environment that has already been reset."""
    env = EmailTriageEnv(task_id="hard")
    env.reset()
    return env
# Ground-truth labels for every email in the three datasets, keyed by email_id.
GT = {
    record["email"]["header"]["email_id"]: record["ground_truth"]
    for emails in (EASY_EMAILS, MEDIUM_EMAILS, HARD_EMAILS)
    for record in emails
}
def perfect(email_id: str) -> Action:
    """Build the oracle action for *email_id* from its ground-truth labels.

    Args:
        email_id: Identifier of an email that has an entry in ``GT``.

    Returns:
        An ``Action`` whose priority, category and route are taken from the
        ground truth; ``flag_review`` comes from the ground truth's
        ``requires_escalation`` entry and defaults to ``False`` when absent.

    Raises:
        KeyError: If *email_id* has no ground-truth entry.
    """
    gt = GT[email_id]
    return Action(
        email_id=email_id,
        priority=Priority(gt["priority"]),
        category=Category(gt["category"]),
        route_to=RouteTo(gt["route_to"]),
        # Real arrow character; the previous text was a mis-decoded UTF-8
        # sequence that ended up verbatim in every oracle summary.
        summary=f"Triaged {email_id}: {gt['category']} → {gt['route_to']}",
        flag_review=gt.get("requires_escalation", False),
    )
# ── 1. Model layer ────────────────────────────────────────────────────────────
def test_action_valid():
    """A fully specified Action keeps the priority and review flag it was given."""
    action = Action(
        email_id="e001",
        priority=Priority.URGENT,
        category=Category.TECHNICAL_SUPPORT,
        route_to=RouteTo.SUPPORT_TIER2,
        # Em dash restored; the literal previously held mis-decoded UTF-8.
        summary="Account compromised — urgent lockdown needed.",
        flag_review=True,
    )
    assert action.priority == Priority.URGENT
    assert action.flag_review is True
def test_action_summary_max_length():
    """Constructing an Action with a 281-character summary must raise."""
    oversized_summary = "x" * 281
    fields = dict(
        email_id="e001",
        priority=Priority.LOW,
        category=Category.GENERAL_INQUIRY,
        route_to=RouteTo.SUPPORT_TIER1,
        summary=oversized_summary,
        flag_review=False,
    )
    # NOTE(review): Exception is broad; narrow it to the model's validation
    # error type once that type is confirmed.
    with pytest.raises(Exception):
        Action(**fields)
def test_observation_has_sequential_fields(easy_env):
    """The observation exposes every sequential-constraint attribute."""
    observation = easy_env._make_observation()
    expected_fields = (
        "escalation_budget_remaining",
        "team_queue_remaining",
        "active_sla_warnings",
        "sla_breaches_so_far",
        "cascade_active",
    )
    for field_name in expected_fields:
        assert hasattr(observation, field_name)
def test_reward_has_sequential_breakdown_fields(easy_env):
    """A step's reward breakdown carries all four sequential penalty fields."""
    current_id = easy_env._make_observation().current_email.header.email_id
    _, reward, _, _ = easy_env.step(perfect(current_id))
    breakdown = reward.breakdown
    for penalty_name in (
        "sla_penalty",
        "queue_penalty",
        "budget_penalty",
        "cascade_penalty",
    ):
        assert hasattr(breakdown, penalty_name)
# ── 2. Grader ─────────────────────────────────────────────────────────────────
def test_perfect_action_high_score():
    """The oracle action for e001 scores at least 0.85 in the grader."""
    oracle_action = perfect("e001")
    oracle_reward, _detail = score_action(oracle_action)
    assert oracle_reward.total >= 0.85
def test_wrong_priority_lowers_score():
    """Labelling e001 as LOW priority scores strictly below the oracle action."""
    baseline, _ = score_action(perfect("e001"))
    mislabelled = Action(
        email_id="e001",
        priority=Priority.LOW,
        category=Category.TECHNICAL_SUPPORT,
        route_to=RouteTo.SUPPORT_TIER2,
        summary="Test.",
        flag_review=True,
    )
    degraded, _ = score_action(mislabelled)
    assert degraded.total < baseline.total
def test_spam_correctly_identified():
    """The oracle action for e002 scores at least 0.85 in the grader."""
    oracle_action = perfect("e002")
    reward, _detail = score_action(oracle_action)
    assert reward.total >= 0.85
def test_spam_misrouted_penalised():
    """Routing e002 to SALES while labelling it spam yields a negative base penalty."""
    misrouted = Action(
        email_id="e002",
        priority=Priority.SPAM,
        category=Category.SPAM_PHISHING,
        route_to=RouteTo.SALES,
        summary="Spam.",
        flag_review=False,
    )
    _, detail = score_action(misrouted)
    base_penalty = detail["scores"]["base_penalty"]
    assert base_penalty < 0
def test_urgent_as_spam_heavy_penalty():
    """Trashing e001 as spam produces a base penalty of -0.3 or worse."""
    trashed = Action(
        email_id="e001",
        priority=Priority.SPAM,
        category=Category.SPAM_PHISHING,
        route_to=RouteTo.TRASH,
        summary="Spam.",
        flag_review=False,
    )
    _, detail = score_action(trashed)
    base_penalty = detail["scores"]["base_penalty"]
    assert base_penalty <= -0.3
def test_adjacent_priority_partial_credit():
    """HIGH instead of the expected urgent on e001 earns a priority score strictly in (0, 1)."""
    near_miss = Action(
        email_id="e001",
        priority=Priority.HIGH,  # adjacent to urgent
        category=Category.TECHNICAL_SUPPORT,
        route_to=RouteTo.SUPPORT_TIER2,
        summary="Security issue.",
        flag_review=True,
    )
    _, detail = score_action(near_miss)
    priority_score = detail["scores"]["priority"]
    assert 0 < priority_score < 1
def test_missed_escalation_zero_score():
    """Leaving e001 unflagged for review scores 0.0 on the escalation component."""
    unflagged = Action(
        email_id="e001",
        priority=Priority.URGENT,
        category=Category.TECHNICAL_SUPPORT,
        route_to=RouteTo.SUPPORT_TIER2,
        summary="Security.",
        flag_review=False,  # should be True
    )
    _, detail = score_action(unflagged)
    escalation_score = detail["scores"]["escalation"]
    assert escalation_score == 0.0
def test_over_escalation_half_score():
    """Flagging e003 for review when it is not needed scores 0.5 on escalation."""
    over_flagged = Action(
        email_id="e003",
        priority=Priority.LOW,
        category=Category.INTERNAL_HR,
        route_to=RouteTo.HR,
        summary="Team lunch reminder.",
        flag_review=True,  # not needed
    )
    _, detail = score_action(over_flagged)
    escalation_score = detail["scores"]["escalation"]
    assert escalation_score == 0.5
def test_grade_episode_aggregation():
    """grade_episode reports a label score in [0, 1] and counts both emails."""
    security_action = {
        "email_id": "e001",
        "priority": "urgent",
        "category": "technical_support",
        "route_to": "support_tier2",
        "summary": "Security incident.",
        "flag_review": True,
        "reasoning": "",
    }
    spam_action = {
        "email_id": "e002",
        "priority": "spam",
        "category": "spam_phishing",
        "route_to": "trash",
        "summary": "Spam.",
        "flag_review": False,
        "reasoning": "",
    }
    graded = grade_episode([security_action, spam_action])
    assert 0.0 <= graded["label_score"] <= 1.0
    assert graded["num_emails"] == 2
# ── 3. Environment lifecycle ──────────────────────────────────────────────────
def test_reset_returns_observation(easy_env):
    """reset() on the easy task returns an Observation with five emails pending."""
    observation = easy_env.reset()
    assert isinstance(observation, Observation)
    # The easy task is expected to hold exactly five emails, none processed yet.
    assert observation.total_emails == 5
    assert observation.remaining == 5
    assert observation.current_email is not None
def test_reset_initialises_sequential_state(easy_env):
    """After reset the budget, breach count, cascade flag and queues are at their initial values."""
    observation = easy_env.reset()
    assert observation.escalation_budget_remaining == TASK_ESCALATION_BUDGET["easy"]
    assert observation.sla_breaches_so_far == 0
    assert not observation.cascade_active
    # Every team queue must start at its full configured capacity.
    for team, capacity in TEAM_CAPACITY.items():
        assert observation.team_queue_remaining[team] == capacity
def test_step_types(easy_env):
    """step() returns an (Observation, Reward, bool, dict) tuple."""
    current_id = easy_env._make_observation().current_email.header.email_id
    next_obs, reward, done, info = easy_env.step(perfect(current_id))
    expected_types = (
        (next_obs, Observation),
        (reward, Reward),
        (done, bool),
        (info, dict),
    )
    for value, expected_type in expected_types:
        assert isinstance(value, expected_type)
def test_step_removes_email_from_inbox(easy_env):
    """Processing an email lowers ``remaining`` by one and drops it from the inbox."""
    before = easy_env._make_observation()
    handled_id = before.current_email.header.email_id
    after, _, _, _ = easy_env.step(perfect(handled_id))
    assert after.remaining == before.remaining - 1
    assert all(mail.header.email_id != handled_id for mail in after.inbox)
def test_step_decrements_queue(easy_env):
    """Routing the current email uses one slot of the support_tier2 queue."""
    before = easy_env._make_observation()
    # e001 → support_tier2 (assumes e001 is the first email served on easy)
    handled_id = before.current_email.header.email_id
    after, _, _, _ = easy_env.step(perfect(handled_id))
    expected_remaining = TEAM_CAPACITY["support_tier2"] - 1
    assert after.team_queue_remaining["support_tier2"] == expected_remaining
def test_step_decrements_escalation_budget(easy_env):
    """An escalated action uses exactly one unit of the escalation budget."""
    # e001 requires escalation, so perfect(e001) has flag_review=True
    before = easy_env._make_observation()
    starting_budget = before.escalation_budget_remaining
    handled_id = before.current_email.header.email_id
    after, _, _, _ = easy_env.step(perfect(handled_id))
    assert after.escalation_budget_remaining == starting_budget - 1
def test_invalid_email_id_penalised(easy_env):
    """Acting on an email id that is not in the inbox yields a negative total reward."""
    bogus_action = Action(
        email_id="INVALID",
        priority=Priority.LOW,
        category=Category.GENERAL_INQUIRY,
        route_to=RouteTo.SUPPORT_TIER1,
        summary="Test.",
        flag_review=False,
    )
    _, reward, _, _ = easy_env.step(bogus_action)
    assert reward.total < 0
def test_state_returns_correct_type(easy_env):
    """state() returns an EnvironmentState for the easy task that is not done."""
    snapshot = easy_env.state()
    assert isinstance(snapshot, EnvironmentState)
    assert snapshot.task_id == "easy"
    assert not snapshot.done
    assert "escalation_budget" in snapshot.constraints
# ── 4. Sequential mechanics ───────────────────────────────────────────────────
@pytest.mark.parametrize(
    "task_id,n",
    [
        ("easy", 5),
        ("medium", 8),
        ("hard", 10),
    ],
)
def test_task_email_count(task_id, n):
    """Each task reports its expected number of emails after reset."""
    first_obs = EmailTriageEnv(task_id=task_id).reset()
    assert first_obs.total_emails == n
@pytest.mark.parametrize("task_id", ["easy", "medium", "hard"])
def test_task_escalation_budget(task_id):
    """Each task starts with the escalation budget configured for it."""
    first_obs = EmailTriageEnv(task_id=task_id).reset()
    expected_budget = TASK_ESCALATION_BUDGET[task_id]
    assert first_obs.escalation_budget_remaining == expected_budget
def test_sla_breach_fires_for_delayed_urgent():
    """Leaving an urgent email untouched for 3 steps triggers SLA breach."""
    env = EmailTriageEnv("easy")
    env.reset()
    # e001 (urgent) deadline = arrived_at_step(0) + SLA_STEPS["urgent"](2) = 2
    # Breach fires when self._step >= 2, which happens at the START of step
    # index 2 (3rd action)
    deferrable = [
        record
        for record in EASY_EMAILS
        if record["ground_truth"]["priority"] != "urgent"
    ]
    last_reward = None
    # 3 actions → tick at step=2 → breach
    for record in deferrable[:3]:
        email_id = record["email"]["header"]["email_id"]
        _, last_reward, _, _ = env.step(perfect(email_id))
    assert env._constraints.sla_breaches >= 1
    assert last_reward.breakdown.sla_penalty < 0
def test_sla_penalty_magnitude():
    """The reward of the third non-urgent step carries an SLA penalty of about -0.15."""
    env = EmailTriageEnv("easy")
    env.reset()
    deferrable = [
        record
        for record in EASY_EMAILS
        if record["ground_truth"]["priority"] != "urgent"
    ]
    for record in deferrable[:3]:
        email_id = record["email"]["header"]["email_id"]
        _, r, _, _ = env.step(perfect(email_id))
    # Only the reward from the final step of the loop is checked.
    expected_penalty = pytest.approx(-0.15, abs=0.001)
    assert r.breakdown.sla_penalty == expected_penalty
def test_budget_exhaustion_penalty():
    """Escalating beyond the budget incurs -0.20 penalty.

    Every easy email is submitted with ``flag_review=True`` and otherwise
    correct labels; at least two of the resulting rewards must carry the
    budget penalty.
    """
    env = EmailTriageEnv("easy")
    env.reset()  # budget=3
    budget_penalties = []
    for d in EASY_EMAILS:
        eid = d["email"]["header"]["email_id"]
        gt = d["ground_truth"]
        # Force flag_review=True for every email regardless of need
        action = Action(
            email_id=eid,
            priority=Priority(gt["priority"]),
            category=Category(gt["category"]),
            route_to=RouteTo(gt["route_to"]),
            summary=f"Forced escalation of {eid}.",
            flag_review=True,
        )
        _, reward, _, _ = env.step(action)
        budget_penalties.append(reward.breakdown.budget_penalty)
    # Budget=3, 5 emails all escalated → emails 4 and 5 overflow.
    # Compare with a tolerance, as the other penalty tests in this file do,
    # rather than relying on exact float equality.
    overflow_hits = sum(
        1 for p in budget_penalties if p == pytest.approx(-0.20, abs=0.001)
    )
    assert overflow_hits >= 2
def test_queue_saturation_penalty(hard_env):
    """Routing 3+ emails to legal (capacity=2) triggers overflow penalty.

    Each hard email whose ground truth routes to legal is sent there
    without escalation; at least one reward must carry the queue penalty
    and the environment must record an overflow.
    """
    legal_emails = [
        d for d in HARD_EMAILS if d["ground_truth"]["route_to"] == "legal"
    ]
    assert len(legal_emails) >= 3, "Need ≥3 legal-routed emails"
    queue_penalties = []
    for d in legal_emails:
        eid = d["email"]["header"]["email_id"]
        gt = d["ground_truth"]
        # Force routing to legal regardless of capacity
        action = Action(
            email_id=eid,
            priority=Priority(gt["priority"]),
            category=Category(gt["category"]),
            route_to=RouteTo.LEGAL,
            summary=f"Routing {eid} to legal.",
            flag_review=False,
        )
        _, reward, _, _ = hard_env.step(action)
        queue_penalties.append(reward.breakdown.queue_penalty)
    # Tolerance-based comparison instead of exact float equality.
    assert any(p == pytest.approx(-0.10, abs=0.001) for p in queue_penalties)
    assert hard_env._constraints.queue_overflows >= 1
def test_cascade_triggers_after_two_urgent_sla_breaches():
    """Processing urgents in wrong order causes cascade."""
    env = EmailTriageEnv("hard")
    env.reset()
    # h001 deadline=2, h002 deadline=3
    deferrable, urgent = [], []
    for record in HARD_EMAILS:
        is_urgent = record["ground_truth"]["priority"] == "urgent"
        (urgent if is_urgent else deferrable).append(record)

    # Handle every non-urgent email first so the urgent ones are left waiting.
    for record in deferrable:  # steps 0, 1
        env.step(perfect(record["email"]["header"]["email_id"]))

    first_urgent_id = urgent[0]["email"]["header"]["email_id"]
    second_urgent_id = urgent[1]["email"]["header"]["email_id"]
    env.step(perfect(first_urgent_id))  # step 2: h001 SLA fires
    _, reward, _, _ = env.step(perfect(second_urgent_id))  # step 3: h002 SLA fires

    assert env._constraints.sla_breaches >= 2
    assert env._constraints.cascade_triggered
    assert reward.breakdown.cascade_penalty == pytest.approx(-0.25, abs=0.001)
def test_cascade_fires_only_once():
    """Cascade penalty is one-time even if more urgents breach.

    All non-urgent hard emails are processed first, then up to four urgent
    ones; exactly one of those urgent steps may carry the cascade penalty.
    """
    env = EmailTriageEnv("hard")
    env.reset()
    non_urgent = [
        d for d in HARD_EMAILS if d["ground_truth"]["priority"] != "urgent"
    ]
    urgents = [
        d for d in HARD_EMAILS if d["ground_truth"]["priority"] == "urgent"
    ]
    for d in non_urgent:
        env.step(perfect(d["email"]["header"]["email_id"]))
    cascade_penalties = []
    for d in urgents[:4]:
        _, reward, _, _ = env.step(perfect(d["email"]["header"]["email_id"]))
        cascade_penalties.append(reward.breakdown.cascade_penalty)
    # Tolerance-based comparison instead of exact float equality, matching
    # the other cascade-penalty assertion in this file.
    cascade_hits = sum(
        1 for p in cascade_penalties if p == pytest.approx(-0.25, abs=0.001)
    )
    assert cascade_hits == 1  # fires exactly once
# ── 5. Full episodes ──────────────────────────────────────────────────────────
@pytest.mark.parametrize("task_id", ["easy", "medium", "hard"])
def test_oracle_episode_completes(task_id):
    """Perfect oracle agent completes episode with all rewards in [-1, 1]."""
    emails_by_task = {
        "easy": EASY_EMAILS,
        "medium": MEDIUM_EMAILS,
        "hard": HARD_EMAILS,
    }
    env = EmailTriageEnv(task_id=task_id)
    env.reset()
    for record in emails_by_task[task_id]:
        email_id = record["email"]["header"]["email_id"]
        _, reward, _, _ = env.step(perfect(email_id))
        # Every per-step total must stay within the bounded reward range.
        assert -1.0 <= reward.total <= 1.0
    assert env.is_done
@pytest.mark.parametrize("task_id", ["easy", "medium", "hard"])
def test_oracle_label_scores_above_floor(task_id):
    """Oracle label scores should be high (before sequential penalties).

    Oracle actions are stepped through the environment and their dumped
    dictionaries are graded with ``grade_episode``; the resulting
    ``label_score`` must be at least 0.75.
    """
    # grade_episode comes from the module-level import; the redundant
    # function-local re-import was removed.
    datasets = {"easy": EASY_EMAILS, "medium": MEDIUM_EMAILS, "hard": HARD_EMAILS}
    env = EmailTriageEnv(task_id=task_id)
    env.reset()
    actions = []
    for d in datasets[task_id]:
        eid = d["email"]["header"]["email_id"]
        action = perfect(eid)
        env.step(action)
        actions.append(action.model_dump())
    result = grade_episode(actions)
    assert result["label_score"] >= 0.75, (
        f"Oracle label score too low on {task_id}: {result['label_score']}"
    )
def test_done_raises_on_further_step():
    """Calling step() after episode is done raises RuntimeError."""
    env = EmailTriageEnv("easy")
    env.reset()
    # Exhaust the inbox with oracle actions.
    for record in EASY_EMAILS:
        email_id = record["email"]["header"]["email_id"]
        env.step(perfect(email_id))
    assert env.is_done
    with pytest.raises(RuntimeError):
        env.step(perfect("e001"))
def test_invalid_task_raises():
    """An unknown task id is rejected with ValueError at construction."""
    unknown_task = "impossible"
    with pytest.raises(ValueError):
        EmailTriageEnv(task_id=unknown_task)
def test_episode_info_contains_summary_on_done():
    """The info dict of the final step contains an episode summary with a label score."""
    env = EmailTriageEnv("easy")
    env.reset()
    final_info = None
    for record in EASY_EMAILS:
        email_id = record["email"]["header"]["email_id"]
        # Keep only the info dict; the last iteration leaves the final one.
        _, _, _, final_info = env.step(perfect(email_id))
    assert "episode_summary" in final_info
    assert "label_score" in final_info["episode_summary"]
def test_state_constraints_exposed():
    """state().constraints on the hard task exposes every constraint key."""
    env = EmailTriageEnv("hard")
    env.reset()
    constraints = env.state().constraints
    required_keys = (
        "escalation_budget",
        "escalations_used",
        "sla_breaches",
        "queue_overflows",
        "team_queues",
    )
    for key in required_keys:
        assert key in constraints