# Page residue from the hosting site (not part of tests.py):
# Spaces: Sleeping | File size: 7,268 Bytes | revision 96a5caf
"""
tests.py — Unit tests for the Email Triage environment.
Run with: python tests.py
"""
import math
import sys

from environment import (
    EmailTriageEnv,
    Action,
    grade_task1,
    grade_task2,
    InboxState,
    Email,
    TASK1_GROUND_TRUTH,
    TASK1_EMAILS
)
def run_test(name: str, fn):
    """Execute one test callable and print a one-line verdict.

    Args:
        name: Label printed next to the verdict.
        fn: Zero-argument callable holding the test body.

    Returns:
        True when ``fn`` finishes without raising; False when it raises
        ``AssertionError`` (reported as a failure) or any other
        ``Exception`` (reported as a crash with the exception type).
    """
    succeeded = False
    try:
        fn()
        print(f" ✅ {name}")
        succeeded = True
    except AssertionError as failure:
        # An assertion inside the test did not hold.
        print(f" ❌ {name}: {failure}")
    except Exception as crash:
        # Any other error: include the exception type to aid debugging.
        print(f" 💥 {name}: {type(crash).__name__}: {crash}")
    return succeeded
# ---------------------------------------------------------------------------
# Task 1 tests
# ---------------------------------------------------------------------------
def test_task1_reset():
    """Resetting task 1 yields an "ok" observation reporting five emails."""
    triage_env = EmailTriageEnv(task=1)
    initial_obs = triage_env.reset()
    assert initial_obs.status == "ok"
    reported_size = initial_obs.data["inbox_size"]
    assert reported_size == 5
def test_task1_list():
    """The list_inbox action on task 1 succeeds and returns five emails."""
    triage_env = EmailTriageEnv(task=1)
    triage_env.reset()
    listing = triage_env.step(Action(action="list_inbox"))
    observation = listing.observation
    assert observation.status == "ok"
    assert len(observation.data["emails"]) == 5
def test_task1_read():
    """Reading email t1_001 succeeds and exposes a non-empty subject."""
    triage_env = EmailTriageEnv(task=1)
    triage_env.reset()
    read_action = Action(action="read", email_id="t1_001")
    outcome = triage_env.step(read_action)
    assert outcome.observation.status == "ok"
    subject = outcome.observation.data["subject"]
    assert len(subject) > 0
def test_task1_label_correct():
    """Labelling t1_001 with its ground-truth priority earns a 0.2 reward."""
    env = EmailTriageEnv(task=1)
    env.reset()
    gt = TASK1_GROUND_TRUTH["t1_001"]
    result = env.step(Action(action="label", email_id="t1_001", priority=gt))
    # Compare with a tolerance: exact float equality breaks as soon as the
    # reward is produced by arithmetic rather than a literal.
    assert math.isclose(result.reward, 0.2, abs_tol=1e-9), (
        f"Expected 0.2, got {result.reward}"
    )
def test_task1_label_wrong():
    """Labelling t1_001 with a priority other than the ground truth earns 0.0."""
    triage_env = EmailTriageEnv(task=1)
    triage_env.reset()
    expected_priority = TASK1_GROUND_TRUTH["t1_001"]
    # Pick a priority guaranteed to differ from the ground truth.
    if expected_priority in ("urgent", "normal"):
        incorrect_priority = "low"
    else:
        incorrect_priority = "urgent"
    label_action = Action(
        action="label", email_id="t1_001", priority=incorrect_priority
    )
    outcome = triage_env.step(label_action)
    assert outcome.reward == 0.0
def test_task1_full_score():
    """Labelling every task-1 email with its ground truth scores 1.0."""
    env = EmailTriageEnv(task=1)
    env.reset()
    for eid, priority in TASK1_GROUND_TRUTH.items():
        env.step(Action(action="label", email_id=eid, priority=priority))
    # Read the score once so the assertion and its message agree, and compare
    # with a tolerance because the total may be accumulated from float parts.
    score = env.score()
    assert math.isclose(score, 1.0, abs_tol=1e-9), f"Expected 1.0, got {score}"
def test_task1_partial_score():
    """Correctly labelling the first two task-1 emails scores 0.4."""
    env = EmailTriageEnv(task=1)
    env.reset()
    eids = list(TASK1_GROUND_TRUTH)
    for eid in eids[:2]:
        env.step(Action(action="label", email_id=eid, priority=TASK1_GROUND_TRUTH[eid]))
    score = env.score()
    # Tolerant comparison: a partial score built from float increments need
    # not be bit-identical to the literal 0.4.
    assert math.isclose(score, 0.4, abs_tol=1e-9), f"Expected 0.4, got {score}"
# ---------------------------------------------------------------------------
# Task 2 tests
# ---------------------------------------------------------------------------
def test_task2_reset():
    """Resetting task 2 reports an inbox containing a single email."""
    triage_env = EmailTriageEnv(task=2)
    initial_obs = triage_env.reset()
    reported_size = initial_obs.data["inbox_size"]
    assert reported_size == 1
def test_task2_no_reply_zero():
    """Task 2 scores 0.0 when no reply has been drafted after reset."""
    triage_env = EmailTriageEnv(task=2)
    triage_env.reset()
    untouched_score = triage_env.score()
    assert untouched_score == 0.0
def test_task2_good_reply():
    """A full, apologetic reply to t2_001 scores above 0.5."""
    triage_env = EmailTriageEnv(task=2)
    triage_env.reset()
    # Reply text kept verbatim; the environment's score is computed from it.
    reply_body = (
        "Dear Jamie,\n\nThank you for reaching out. We sincerely apologize for the "
        "experience you have had with order #48291. We understand how frustrating "
        "this must be.\n\nWe are urgently investigating the status of your delivery "
        "and will provide an update within 2 hours. If we cannot confirm delivery "
        "within 48 hours we will process a full refund immediately. We will also "
        "review the service failures you experienced and follow up regarding "
        "compensation.\n\nWe truly value your business and are committed to "
        "making this right.\n\nSincerely,\nCustomer Support Team"
    )
    draft = Action(action="draft_reply", email_id="t2_001", body=reply_body)
    triage_env.step(draft)
    score = triage_env.score()
    assert score > 0.5, f"Expected score > 0.5, got {score}"
def test_task2_short_reply_penalised():
    """Drafting the two-character reply "ok" yields an "error" observation."""
    triage_env = EmailTriageEnv(task=2)
    triage_env.reset()
    terse_draft = Action(action="draft_reply", email_id="t2_001", body="ok")
    outcome = triage_env.step(terse_draft)
    assert outcome.observation.status == "error"
# ---------------------------------------------------------------------------
# Task 3 tests
# ---------------------------------------------------------------------------
def test_task3_reset():
    """Resetting task 3 reports an inbox of ten emails."""
    triage_env = EmailTriageEnv(task=3)
    initial_obs = triage_env.reset()
    reported_size = initial_obs.data["inbox_size"]
    assert reported_size == 10
def test_task3_archive_spam_no_penalty():
    """Archiving t3_002 after labelling it "low" returns an "ok" status."""
    triage_env = EmailTriageEnv(task=3)
    triage_env.reset()
    # Mark the spam as low priority first so the archive step is not
    # treated as archiving an urgent email.
    triage_env.step(Action(action="label", email_id="t3_002", priority="low"))
    archive_action = Action(action="archive", email_id="t3_002")
    outcome = triage_env.step(archive_action)
    assert outcome.observation.status == "ok"
def test_task3_archive_urgent_penalty():
    """Archiving t3_001 after labelling it urgent costs -0.1 and warns."""
    env = EmailTriageEnv(task=3)
    env.reset()
    env.step(Action(action="label", email_id="t3_001", priority="urgent"))
    result = env.step(Action(action="archive", email_id="t3_001"))
    # Tolerant comparison: -0.1 has no exact binary representation, so an
    # arithmetically derived penalty may differ from the literal in the last bit.
    assert math.isclose(result.reward, -0.1, abs_tol=1e-9), (
        f"Expected -0.1, got {result.reward}"
    )
    assert result.observation.status == "warning"
def test_task3_flag():
    """Flagging t3_009 with a reason returns an "ok" status."""
    triage_env = EmailTriageEnv(task=3)
    triage_env.reset()
    flag_action = Action(
        action="flag",
        email_id="t3_009",
        reason="Missing context — need sender identity",
    )
    outcome = triage_env.step(flag_action)
    assert outcome.observation.status == "ok"
def test_task3_loop_detection():
    """Repeating the same label action three times records a loop penalty."""
    triage_env = EmailTriageEnv(task=3)
    triage_env.reset()
    repeat_count = 3
    for _attempt in range(repeat_count):
        # A fresh Action is built on every iteration, as in a real agent loop.
        triage_env.step(Action(action="label", email_id="t3_006", priority="normal"))
    # Reads the environment's private penalty counters directly.
    loop_penalties = triage_env._penalties["loop_actions"]
    assert loop_penalties >= 1
def test_task3_not_found():
    """Reading an unknown email id returns an "error" status."""
    triage_env = EmailTriageEnv(task=3)
    triage_env.reset()
    missing_read = Action(action="read", email_id="nonexistent")
    outcome = triage_env.step(missing_read)
    assert outcome.observation.status == "error"
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Registry of (display name, test callable) pairs, run in listed order.
    tests = [
        # Task 1
        ("Task1 reset", test_task1_reset),
        ("Task1 list inbox", test_task1_list),
        ("Task1 read email", test_task1_read),
        ("Task1 correct label reward", test_task1_label_correct),
        ("Task1 wrong label no reward", test_task1_label_wrong),
        ("Task1 full score 1.0", test_task1_full_score),
        ("Task1 partial score 0.4", test_task1_partial_score),
        # Task 2
        ("Task2 reset", test_task2_reset),
        ("Task2 no reply = 0.0", test_task2_no_reply_zero),
        ("Task2 good reply > 0.5", test_task2_good_reply),
        ("Task2 short reply error", test_task2_short_reply_penalised),
        # Task 3
        ("Task3 reset", test_task3_reset),
        ("Task3 archive spam no penalty", test_task3_archive_spam_no_penalty),
        ("Task3 archive urgent = penalty", test_task3_archive_urgent_penalty),
        ("Task3 flag ambiguous", test_task3_flag),
        ("Task3 loop detection", test_task3_loop_detection),
        ("Task3 not found error", test_task3_not_found),
    ]

    print("\nRunning Email Triage Environment Tests")
    print("=" * 45)

    # Run every test, counting the ones that report success.
    passed = 0
    for label, test_fn in tests:
        if run_test(label, test_fn):
            passed += 1
    total = len(tests)

    print(f"\n{passed}/{total} tests passed")

    # Exit status 0 only when every test passed; 1 otherwise.
    exit_code = 0 if passed == total else 1
    sys.exit(exit_code)
|