#!/usr/bin/env python3
"""
run_tests.py — Self-contained test runner for support_ticket_env.
Runs all test cases using only the Python standard library.
Usage:
python run_tests.py
"""
import sys
import os
import traceback
from typing import Callable, List, Tuple
# ─── path setup ────────────────────────────────────────────────
# Directory containing this script; the package under test is imported from here.
ROOT = os.path.dirname(os.path.abspath(__file__))
# Sibling "openenv_stub" directory, also made importable.
STUB = os.path.join(ROOT, "openenv_stub")
# Put both at the front of the import path, ROOT ahead of STUB.
sys.path[:0] = [ROOT, STUB]
# ─── minimal test framework ────────────────────────────────────
# Registry of (qualified name, callable) pairs, filled in by the @test decorator.
_tests: List[Tuple[str, Callable]] = []
# Outcome tallies updated by run_all().
_passed = _failed = _errors = 0
def test(fn: Callable) -> Callable:
    """Decorator that registers ``fn`` in ``_tests`` and returns it unchanged.

    The function is stored together with its ``__qualname__`` so the runner
    can print a name for it.
    """
    entry = (fn.__qualname__, fn)
    _tests.append(entry)
    return fn
def assert_eq(a, b, msg=""):
    """Raise ``AssertionError`` when ``a != b``; otherwise return ``None``.

    The error text is ``msg`` followed by `` | `` and the ``repr`` of the
    expected value (``b``) and of the actual value (``a``).
    """
    mismatch = a != b
    if mismatch:
        detail = f"{msg} | expected {b!r}, got {a!r}"
        raise AssertionError(detail)
def assert_true(val, msg=""):
    """Raise ``AssertionError`` when ``val`` is falsy; otherwise return ``None``.

    A non-empty ``msg`` is used as the error text; without one, a default
    text showing ``repr(val)`` is used.
    """
    if val:
        return
    if msg:
        raise AssertionError(msg)
    raise AssertionError(f"Expected truthy, got {val!r}")
def assert_in_range(val, lo, hi, msg=""):
    """Raise ``AssertionError`` unless ``lo <= val <= hi`` (both ends inclusive).

    A non-empty ``msg`` is used as the error text; without one, a default
    text naming the value and the bounds is used.
    """
    if lo <= val <= hi:
        return
    if msg:
        raise AssertionError(msg)
    raise AssertionError(f"Expected {val!r} in [{lo}, {hi}]")
# ─────────────────────────────── imports ───────────────────────
from support_ticket_env.graders import (
grade_task1, grade_task2, grade_task3, loop_penalty,
)
from support_ticket_env.server.support_environment import SupportTicketEnvironment
from support_ticket_env.models import SupportAction
def make_env():
    """Build and return a new ``SupportTicketEnvironment`` with no arguments."""
    env = SupportTicketEnvironment()
    return env
# ────────────────────────────────────────────────────────────────
# GRADER TESTS
# ────────────────────────────────────────────────────────────────
@test
def test_grade1_correct():
    """grade_task1 is expected to return 1.0 when both arguments are "billing"."""
    score = grade_task1("billing", "billing")
    assert_eq(score, 1.0)
@test
def test_grade1_wrong():
    """grade_task1("technical", "billing") is expected to return 0.0."""
    score = grade_task1("technical", "billing")
    assert_eq(score, 0.0)
@test
def test_grade1_all_categories():
    """grade_task1 is expected to return 1.0 for each listed category paired with itself."""
    categories = ["billing", "technical", "account", "general", "refund"]
    for category in categories:
        score = grade_task1(category, category)
        assert_eq(score, 1.0, f"cat={category}")
@test
def test_grade1_empty():
    """grade_task1 is expected to return 0.0 when the first argument is empty."""
    score = grade_task1("", "billing")
    assert_eq(score, 0.0)
@test
def test_grade2_exact_reply():
    """grade_task2("reply", "reply") is expected to return 1.0."""
    score = grade_task2("reply", "reply")
    assert_eq(score, 1.0)
@test
def test_grade2_exact_escalate():
    """grade_task2("escalate", "escalate") is expected to return 1.0."""
    score = grade_task2("escalate", "escalate")
    assert_eq(score, 1.0)
@test
def test_grade2_exact_close():
    """grade_task2("close", "close") is expected to return 1.0."""
    score = grade_task2("close", "close")
    assert_eq(score, 1.0)
@test
def test_grade2_partial_reply_escalate():
    """Mixing "reply" and "escalate" in either order is expected to score 0.5."""
    pairs = (("reply", "escalate"), ("escalate", "reply"))
    for first, second in pairs:
        assert_eq(grade_task2(first, second), 0.5)
@test
def test_grade2_close_wrong():
    """grade_task2("close", "reply") is expected to return 0.0."""
    score = grade_task2("close", "reply")
    assert_eq(score, 0.0)
@test
def test_grade3_perfect():
    """grade_task3 with these eight arguments is expected to score at least 0.9."""
    reply = "we will process your refund billing payment"
    score = grade_task3(True, True, False, reply, "billing", True, 1, 5)
    high_enough = score >= 0.9
    assert_true(high_enough, f"Expected >=0.9, got {score}")
@test
def test_grade3_capped_at_one():
    """grade_task3 must not return more than 1.0 for this keyword-heavy reply."""
    reply = "refund billing payment account cancel subscription"
    score = grade_task3(True, True, False, reply, "billing", True, 1, 5)
    within_cap = score <= 1.0
    assert_true(within_cap, f"Score exceeds 1.0: {score}")
@test
def test_grade3_partial_action_less_than_full():
    """The (True, False, True, ...) call must score below the (True, True, False, ...) call."""
    # Both calls share the same trailing arguments; only the flags differ.
    shared_tail = (None, "technical", True, 2)
    s_partial = grade_task3(True, False, True, *shared_tail)
    s_full = grade_task3(True, True, False, *shared_tail)
    ordered = s_partial < s_full
    assert_true(ordered, f"partial={s_partial} should < full={s_full}")
@test
def test_loop_penalty_none_within_limit():
    """loop_penalty is expected to return 0.0 for inputs 5 and 10."""
    for step_count in (5, 10):
        assert_eq(loop_penalty(step_count), 0.0)
@test
def test_loop_penalty_grows():
    """loop_penalty(12) must be below loop_penalty(11), which must be negative."""
    at_twelve, at_eleven = loop_penalty(12), loop_penalty(11)
    assert_true(at_twelve < at_eleven)
    is_negative = loop_penalty(11) < 0
    assert_true(is_negative)
# ────────────────────────────────────────────────────────────────
# ENVIRONMENT TESTS
# ────────────────────────────────────────────────────────────────
@test
def test_env_reset_task1():
    """reset(task_id=1, seed=42) must give a non-empty ticket, task_id 1, and done False."""
    env = make_env()
    first_obs = env.reset(task_id=1, seed=42)
    has_text = first_obs.ticket_text != ""
    assert_true(has_text, "ticket_text should not be empty")
    assert_eq(first_obs.task_id, 1)
    assert_eq(first_obs.done, False)
@test
def test_env_task1_correct_classification():
    """Classifying with the state's correct_category on task 1 must give reward 1.0 and done."""
    env = make_env()
    env.reset(task_id=1, seed=42)
    answer = env.state.correct_category
    action = SupportAction(action_type="classify", category=answer)
    result = env.step(action)
    assert_eq(result.reward, 1.0)
    assert_eq(result.done, True)
@test
def test_env_task1_wrong_classification():
    """Classifying with a category other than correct_category must give reward 0.0 and done."""
    env = make_env()
    env.reset(task_id=1, seed=42)
    state = env.state
    candidates = ["billing", "technical", "account", "general", "refund"]
    # First listed category that differs from the correct one.
    wrong = next(name for name in candidates if name != state.correct_category)
    action = SupportAction(action_type="classify", category=wrong)
    result = env.step(action)
    assert_eq(result.reward, 0.0)
    assert_eq(result.done, True)
@test
def test_env_task2_must_classify_first():
    """Escalating first on task 2 must leave done False and mention "classify" in the feedback."""
    env = make_env()
    env.reset(task_id=2, seed=42)
    premature = SupportAction(action_type="escalate")
    result = env.step(premature)
    assert_eq(result.done, False)
    mentions_classify = "classify" in result.feedback.lower()
    assert_true(mentions_classify)
@test
def test_env_task2_full_correct_episode():
    """Classify correctly, then take correct_action, on task 2: must be done with reward >= 0.5."""
    env = make_env()
    env.reset(task_id=2, seed=42)
    # The state is captured once, before the first step, and reused for both actions.
    state = env.state
    classify = SupportAction(action_type="classify", category=state.correct_category)
    env.step(classify)
    follow_up = SupportAction(action_type=state.correct_action)
    result = env.step(follow_up)
    assert_eq(result.done, True)
    good_reward = result.reward >= 0.5
    assert_true(good_reward, f"reward={result.reward}")
@test
def test_env_task3_three_tickets():
    """After reset(task_id=3, seed=42) the state must report tickets_total == 3."""
    env = make_env()
    env.reset(task_id=3, seed=42)
    ticket_count = env.state.tickets_total
    assert_eq(ticket_count, 3)
@test
def test_env_task3_resolves_all():
    """Driving task 3 with the state's answers must finish within 30 steps with 3 resolved."""
    env = make_env()
    env.reset(task_id=3, seed=42)
    finished = False
    # At most 30 steps; stop as soon as the environment reports done.
    for _ in range(30):
        state = env.state
        if state.classified:
            wanted = state.correct_action
            if wanted == "reply":
                # Replies carry a text that names the ticket's category.
                action = SupportAction(
                    action_type="reply",
                    reply_text=f"Handling your {state.correct_category} issue.",
                )
            else:
                action = SupportAction(action_type=wanted)
        else:
            action = SupportAction(action_type="classify", category=state.correct_category)
        if env.step(action).done:
            finished = True
            break
    assert_true(finished, "Episode did not finish")
    assert_eq(env.state.tickets_resolved, 3)
@test
def test_env_state_step_count():
    """step_count must be 0 right after reset and 1 after a single step."""
    env = make_env()
    env.reset(task_id=1, seed=0)
    assert_eq(env.state.step_count, 0)
    before = env.state
    action = SupportAction(action_type="classify", category=before.correct_category)
    env.step(action)
    assert_eq(env.state.step_count, 1)
@test
def test_env_reward_always_in_range():
    """The first classify reward must lie in [-1.0, 1.0] for every seed/task pair tried."""
    cases = [(seed, task_id) for seed in [0, 1, 2, 42, 99] for task_id in [1, 2, 3]]
    for seed, task_id in cases:
        env = make_env()
        env.reset(task_id=task_id, seed=seed)
        action = SupportAction(action_type="classify", category=env.state.correct_category)
        obs = env.step(action)
        # A falsy reward (e.g. None) is treated as 0.0.
        r = obs.reward or 0.0
        assert_in_range(r, -1.0, 1.0, f"task={task_id} seed={seed} reward={r}")
@test
def test_env_task3_total_reward_positive():
    """Following the state's answers on task 3 (seed 7) must accumulate a positive reward."""
    env = make_env()
    env.reset(task_id=3, seed=7)
    total = 0.0
    # At most 20 steps; stop as soon as the environment reports done.
    for _ in range(20):
        state = env.state
        if state.classified:
            action = SupportAction(action_type=state.correct_action)
        else:
            action = SupportAction(action_type="classify", category=state.correct_category)
        obs = env.step(action)
        # A falsy reward (e.g. None) contributes 0.0.
        total += obs.reward or 0.0
        if obs.done:
            break
    assert_true(total > 0.0, f"total_reward={total}")
# ────────────────────────────────────────────────────────────────
# Runner
# ────────────────────────────────────────────────────────────────
def run_all():
    """Run every registered test, print a report, and return overall success.

    Each callable in ``_tests`` is invoked with no arguments. An
    ``AssertionError`` counts as a failure and its message is printed; any
    other ``Exception`` counts as an error and a traceback limited to three
    frames is printed. The module-level counters ``_passed``, ``_failed``
    and ``_errors`` are reset at the start and updated as tests finish.

    Returns:
        ``True`` when no test failed or errored (including when no tests
        are registered), ``False`` otherwise.
    """
    global _passed, _failed, _errors
    # Reset so that a repeated call reports only its own run.
    _passed = _failed = _errors = 0
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" Running {len(_tests)} tests")
    print(banner)
    for name, fn in _tests:
        try:
            fn()
        except AssertionError as e:
            print(f" ❌ {name}")
            print(f" {e}")
            _failed += 1
        except Exception:
            # Top-level boundary: report the unexpected error and keep going.
            print(f" 💥 {name}")
            traceback.print_exc(limit=3)
            _errors += 1
        else:
            print(f" ✅ {name}")
            _passed += 1
    total = _passed + _failed + _errors
    print(f"\n{banner}")
    print(f" Results: {_passed}/{total} passed | {_failed} failed | {_errors} errors")
    print(f"{banner}\n")
    return _failed + _errors == 0
if __name__ == "__main__":
    # Exit status 0 when run_all() reports success, 1 otherwise.
    sys.exit(int(not run_all()))
|