Spaces:
Running
Running
Coding Ninja committed on
Commit ·
42dd095
1
Parent(s): 5dd60ae
feat: competitive upgrade for hackathon submission
Browse files
- inference.py: single-task mode via TASK_ID env var; clean warn/exit on invalid IDs
- models.py: add last_step_reward, done, history_entries to HelpdeskTicketState
- environment.py: state tracking, enriched history (title+predicted), ambiguity_note
  in observations, extra-field penalty validation, SUPPORTS_CONCURRENT_SESSIONS=True
- reward.py: milestone shaping (+/-0.05 at score thresholds), remove overshoot penalty
- app.py: add GET /web HTML status page
- openenv.yaml: clarify entry_point vs pyproject.toml server script
- dataset.json: add 3 non-default routing tickets (TKT-NONDEFAULT-001/002/003)
- tests: test_competitive_upgrade.py (36 tests) + test_extra_fields_penalty.py (7 tests)
- data/dataset.json +36 -0
- inference.py +23 -5
- models.py +3 -0
- openenv.yaml +3 -0
- server/app.py +20 -0
- server/environment.py +51 -18
- server/reward.py +12 -4
- tests/test_competitive_upgrade.py +468 -0
- tests/test_extra_fields_penalty.py +183 -0
- uv.lock +0 -0
data/dataset.json
CHANGED
|
@@ -538,6 +538,42 @@
|
|
| 538 |
"resolution_action": "escalate",
|
| 539 |
"ambiguity_note": null,
|
| 540 |
"related_ticket_id": "ticket-030"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
}
|
| 542 |
]
|
| 543 |
|
|
|
|
| 538 |
"resolution_action": "escalate",
|
| 539 |
"ambiguity_note": null,
|
| 540 |
"related_ticket_id": "ticket-030"
|
| 541 |
+
},
|
| 542 |
+
{
|
| 543 |
+
"ticket_id": "TKT-NONDEFAULT-001",
|
| 544 |
+
"title": "Billing question from free-tier account",
|
| 545 |
+
"requester": "user@freetier.io",
|
| 546 |
+
"description": "I have a question about my invoice but I am on the free plan and there is no charge. The billing team cannot action this; please route to service desk for general assistance.",
|
| 547 |
+
"issue_type": "billing_license",
|
| 548 |
+
"priority": "low",
|
| 549 |
+
"assignment_group": "service_desk",
|
| 550 |
+
"resolution_action": "fulfill",
|
| 551 |
+
"ambiguity_note": "Account tier is free; billing team cannot action, routed to service desk",
|
| 552 |
+
"related_ticket_id": null
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"ticket_id": "TKT-NONDEFAULT-002",
|
| 556 |
+
"title": "App vulnerability flagged in compliance scan",
|
| 557 |
+
"requester": "security@clientcorp.com",
|
| 558 |
+
"description": "Our compliance scan flagged a product-specific vulnerability in the application layer. This is not a general security policy issue but an app bug requiring the application team to remediate.",
|
| 559 |
+
"issue_type": "security_compliance",
|
| 560 |
+
"priority": "high",
|
| 561 |
+
"assignment_group": "application_team",
|
| 562 |
+
"resolution_action": "escalate",
|
| 563 |
+
"ambiguity_note": "Compliance issue is product-specific (app vulnerability), routed to app team",
|
| 564 |
+
"related_ticket_id": null
|
| 565 |
+
},
|
| 566 |
+
{
|
| 567 |
+
"ticket_id": "TKT-NONDEFAULT-003",
|
| 568 |
+
"title": "Contractor onboarding blocked by access issue",
|
| 569 |
+
"requester": "pm@contractorco.com",
|
| 570 |
+
"description": "A new contractor cannot complete onboarding because their account access is blocked by a permissions error. The onboarding team cannot resolve access issues; routing to service desk.",
|
| 571 |
+
"issue_type": "onboarding",
|
| 572 |
+
"priority": "medium",
|
| 573 |
+
"assignment_group": "service_desk",
|
| 574 |
+
"resolution_action": "fulfill",
|
| 575 |
+
"ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
|
| 576 |
+
"related_ticket_id": null
|
| 577 |
}
|
| 578 |
]
|
| 579 |
|
inference.py
CHANGED
|
@@ -64,7 +64,7 @@ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
|
| 64 |
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 65 |
|
| 66 |
SEED = 42
|
| 67 |
-
|
| 68 |
|
| 69 |
# ---------------------------------------------------------------------------
|
| 70 |
# LLM helper
|
|
@@ -134,6 +134,20 @@ def emit_log(tag: str, **payload: Any) -> None:
|
|
| 134 |
print(f"[{tag}] {json.dumps(payload, sort_keys=True, ensure_ascii=True)}")
|
| 135 |
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# ---------------------------------------------------------------------------
|
| 138 |
# Heuristic fallback (no LLM needed)
|
| 139 |
# ---------------------------------------------------------------------------
|
|
@@ -332,7 +346,10 @@ def run() -> None:
|
|
| 332 |
|
| 333 |
all_results: dict[int, dict[str, float | int]] = {}
|
| 334 |
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
| 336 |
if task_id not in available_tasks:
|
| 337 |
continue
|
| 338 |
|
|
@@ -400,11 +417,12 @@ def run() -> None:
|
|
| 400 |
|
| 401 |
overall = [
|
| 402 |
float(all_results[task_id]["final_reward"])
|
| 403 |
-
for task_id in
|
| 404 |
if task_id in all_results
|
| 405 |
]
|
| 406 |
-
|
| 407 |
-
|
|
|
|
| 408 |
|
| 409 |
|
| 410 |
if __name__ == "__main__":
|
|
|
|
| 64 |
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 65 |
|
| 66 |
SEED = 42
|
| 67 |
+
TASK_ID_ENV = os.getenv("TASK_ID")
|
| 68 |
|
| 69 |
# ---------------------------------------------------------------------------
|
| 70 |
# LLM helper
|
|
|
|
| 134 |
print(f"[{tag}] {json.dumps(payload, sort_keys=True, ensure_ascii=True)}")
|
| 135 |
|
| 136 |
|
| 137 |
+
def get_tasks_to_run(available_tasks: dict) -> list[int]:
    """Decide which task IDs this inference run should execute.

    When the module-level ``TASK_ID_ENV`` value is set, the run is restricted
    to that single task; otherwise every ID in ``TASK_IDS`` is returned.

    Args:
        available_tasks: Mapping keyed by the task IDs that can be run.

    Returns:
        ``[task_id]`` for a valid requested task, ``[]`` when the requested
        task is not among ``available_tasks`` (a warning is printed), or all
        of ``TASK_IDS`` when no task was requested.

    Raises:
        SystemExit: With exit code 1 when ``TASK_ID_ENV`` is not an integer.
    """
    # No explicit request: fall back to all tasks (local dev).
    if not TASK_ID_ENV:
        return list(TASK_IDS)

    try:
        requested = int(TASK_ID_ENV)
    except ValueError:
        print(f"[ERROR] TASK_ID={TASK_ID_ENV!r} is not a valid integer", flush=True)
        raise SystemExit(1)

    if requested in available_tasks:
        return [requested]

    print(f"[WARN] TASK_ID={requested} not in available tasks {list(available_tasks)}", flush=True)
    return []
|
| 149 |
+
|
| 150 |
+
|
| 151 |
# ---------------------------------------------------------------------------
|
| 152 |
# Heuristic fallback (no LLM needed)
|
| 153 |
# ---------------------------------------------------------------------------
|
|
|
|
| 346 |
|
| 347 |
all_results: dict[int, dict[str, float | int]] = {}
|
| 348 |
|
| 349 |
+
tasks_to_run = get_tasks_to_run(available_tasks)
|
| 350 |
+
single_task_mode = bool(TASK_ID_ENV)
|
| 351 |
+
|
| 352 |
+
for task_id in tasks_to_run:
|
| 353 |
if task_id not in available_tasks:
|
| 354 |
continue
|
| 355 |
|
|
|
|
| 417 |
|
| 418 |
overall = [
|
| 419 |
float(all_results[task_id]["final_reward"])
|
| 420 |
+
for task_id in tasks_to_run
|
| 421 |
if task_id in all_results
|
| 422 |
]
|
| 423 |
+
if not single_task_mode:
|
| 424 |
+
overall_avg = round(sum(overall) / len(overall), 4) if overall else 0.0
|
| 425 |
+
emit_log("END", overall_avg=overall_avg, tasks_completed=len(overall))
|
| 426 |
|
| 427 |
|
| 428 |
if __name__ == "__main__":
|
models.py
CHANGED
|
@@ -112,3 +112,6 @@ class HelpdeskTicketState(State):
|
|
| 112 |
current_ticket_index: int = 0
|
| 113 |
per_ticket_scores: list[float] = Field(default_factory=list)
|
| 114 |
total_reward: float = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
current_ticket_index: int = 0
|
| 113 |
per_ticket_scores: list[float] = Field(default_factory=list)
|
| 114 |
total_reward: float = 0.0
|
| 115 |
+
last_step_reward: Optional[float] = None
|
| 116 |
+
done: bool = False
|
| 117 |
+
history_entries: list[dict] = Field(default_factory=list)
|
openenv.yaml
CHANGED
|
@@ -7,6 +7,9 @@ author: Hackstreet Boys - Roopal Guha Neogi, Suyash Kumar
|
|
| 7 |
|
| 8 |
environment:
|
| 9 |
type: openenv
|
|
|
|
|
|
|
|
|
|
| 10 |
entry_point: server.environment:HelpdeskTicketRoutingEnvironment
|
| 11 |
action_model: models:HelpdeskTicketAction
|
| 12 |
observation_model: models:HelpdeskTicketObservation
|
|
|
|
| 7 |
|
| 8 |
environment:
|
| 9 |
type: openenv
|
| 10 |
+
# entry_point identifies the Environment class for the OpenEnv validator.
|
| 11 |
+
# The HTTP server entrypoint for deployment is defined separately in
|
| 12 |
+
# pyproject.toml under [project.scripts] as: server = "server.app:main"
|
| 13 |
entry_point: server.environment:HelpdeskTicketRoutingEnvironment
|
| 14 |
action_model: models:HelpdeskTicketAction
|
| 15 |
observation_model: models:HelpdeskTicketObservation
|
server/app.py
CHANGED
|
@@ -6,6 +6,7 @@ _repo_root = str(Path(__file__).resolve().parent.parent)
|
|
| 6 |
if _repo_root not in sys.path:
|
| 7 |
sys.path.insert(0, _repo_root)
|
| 8 |
|
|
|
|
| 9 |
from openenv.core.env_server import create_app
|
| 10 |
|
| 11 |
from models import HelpdeskTicketAction, HelpdeskTicketObservation
|
|
@@ -37,6 +38,25 @@ def list_tasks():
|
|
| 37 |
}
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def main() -> None:
|
| 41 |
import uvicorn
|
| 42 |
|
|
|
|
| 6 |
if _repo_root not in sys.path:
|
| 7 |
sys.path.insert(0, _repo_root)
|
| 8 |
|
| 9 |
+
from fastapi.responses import HTMLResponse
|
| 10 |
from openenv.core.env_server import create_app
|
| 11 |
|
| 12 |
from models import HelpdeskTicketAction, HelpdeskTicketObservation
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
|
| 41 |
+
@app.get("/web", response_class=HTMLResponse)
def web_ui():
    """Serve a minimal HTML status page that lists the configured tasks.

    Returns:
        An ``HTMLResponse`` containing the environment name, links to the
        health and API-docs endpoints, and one table row per entry in
        ``TASKS`` (id, name, difficulty).
    """
    # Local import keeps this block self-contained; the module-level imports
    # are left as they are.
    from html import escape

    # Every interpolated value is HTML-escaped so that characters such as
    # "&" or "<" in a task name cannot break the generated markup.
    task_rows = "".join(
        "<tr>"
        f"<td>{escape(str(t['id']))}</td>"
        f"<td>{escape(str(t['name']))}</td>"
        f"<td>{escape(str(t['difficulty']))}</td>"
        "</tr>"
        for t in TASKS.values()
    )
    env_name = escape(str(APP_ENV_NAME))
    page = f"""<!DOCTYPE html>
<html><head><title>{env_name}</title></head>
<body>
<h1>{env_name}</h1>
<p>Version: 0.1.0 | <a href="/health">Health</a> | <a href="/docs">API Docs</a></p>
<h2>Tasks</h2>
<table border="1"><tr><th>ID</th><th>Name</th><th>Difficulty</th></tr>
{task_rows}
</table>
</body></html>"""
    return HTMLResponse(content=page)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
def main() -> None:
|
| 61 |
import uvicorn
|
| 62 |
|
server/environment.py
CHANGED
|
@@ -36,6 +36,8 @@ def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
|
|
| 36 |
class HelpdeskTicketRoutingEnvironment(
|
| 37 |
Environment[HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState]
|
| 38 |
):
|
|
|
|
|
|
|
| 39 |
def __init__(self) -> None:
|
| 40 |
super().__init__()
|
| 41 |
self._dataset = load_dataset()
|
|
@@ -94,16 +96,43 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 94 |
task_id = self._state.current_task_id
|
| 95 |
task = get_task_definition(task_id)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
score, breakdown = grade_action(action, current_ticket, task_id)
|
| 98 |
step_reward = compute_step_reward(score)
|
| 99 |
|
| 100 |
-
self._state.
|
| 101 |
-
self._state.step_count += 1
|
| 102 |
-
self._state.current_ticket_index += 1
|
| 103 |
-
|
| 104 |
-
is_done = self._state.current_ticket_index >= len(self._queue)
|
| 105 |
|
| 106 |
if is_done:
|
|
|
|
|
|
|
|
|
|
| 107 |
traj_reward = compute_trajectory_reward(
|
| 108 |
self._state.per_ticket_scores,
|
| 109 |
len(self._queue),
|
|
@@ -112,20 +141,24 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 112 |
self._state.total_reward = traj_reward
|
| 113 |
final_reward = traj_reward
|
| 114 |
else:
|
|
|
|
|
|
|
|
|
|
| 115 |
final_reward = step_reward
|
| 116 |
|
| 117 |
history_entry = {
|
| 118 |
"ticket_id": current_ticket.ticket_id,
|
|
|
|
|
|
|
| 119 |
"score": score,
|
| 120 |
"breakdown": breakdown,
|
| 121 |
}
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
extra_history=history_entry,
|
| 128 |
-
)
|
| 129 |
|
| 130 |
@property
|
| 131 |
def state(self) -> HelpdeskTicketState:
|
|
@@ -140,27 +173,26 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 140 |
task: dict,
|
| 141 |
done: bool = False,
|
| 142 |
reward: float | None = None,
|
| 143 |
-
extra_history: dict | None = None,
|
| 144 |
) -> HelpdeskTicketObservation:
|
| 145 |
idx = self._state.current_ticket_index
|
| 146 |
queue_size = len(self._queue)
|
| 147 |
|
| 148 |
if idx < queue_size:
|
| 149 |
ticket = self._queue[idx]
|
| 150 |
-
ticket_view = {
|
| 151 |
"ticket_id": ticket.ticket_id,
|
| 152 |
"title": ticket.title,
|
| 153 |
"requester": ticket.requester,
|
| 154 |
"description": ticket.description,
|
| 155 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
else:
|
| 157 |
ticket_view = None
|
| 158 |
|
| 159 |
-
history
|
| 160 |
-
for i, s in enumerate(self._state.per_ticket_scores):
|
| 161 |
-
history.append({"step": i + 1, "score": s})
|
| 162 |
-
if extra_history and history:
|
| 163 |
-
history[-1] = {"step": len(history), **extra_history}
|
| 164 |
|
| 165 |
return HelpdeskTicketObservation(
|
| 166 |
done=done,
|
|
@@ -172,6 +204,7 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 172 |
allowed_fields=list(task["allowed_fields"]),
|
| 173 |
current_ticket=ticket_view,
|
| 174 |
queue_size=queue_size,
|
|
|
|
| 175 |
tickets_remaining=max(0, queue_size - idx),
|
| 176 |
tickets_processed=idx,
|
| 177 |
history=history,
|
|
|
|
| 36 |
class HelpdeskTicketRoutingEnvironment(
|
| 37 |
Environment[HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState]
|
| 38 |
):
|
| 39 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 40 |
+
|
| 41 |
def __init__(self) -> None:
|
| 42 |
super().__init__()
|
| 43 |
self._dataset = load_dataset()
|
|
|
|
| 96 |
task_id = self._state.current_task_id
|
| 97 |
task = get_task_definition(task_id)
|
| 98 |
|
| 99 |
+
submitted_fields = {
|
| 100 |
+
f for f, v in action.model_dump(exclude_none=True).items() if v is not None
|
| 101 |
+
}
|
| 102 |
+
allowed = set(task["allowed_fields"])
|
| 103 |
+
extra_fields = submitted_fields - allowed
|
| 104 |
+
if extra_fields:
|
| 105 |
+
# Penalty: record score 0.0, advance index, return penalty observation
|
| 106 |
+
self._state.per_ticket_scores.append(0.0)
|
| 107 |
+
self._state.history_entries.append({
|
| 108 |
+
"ticket_id": current_ticket.ticket_id,
|
| 109 |
+
"title": current_ticket.title,
|
| 110 |
+
"predicted": action.model_dump(exclude_none=True),
|
| 111 |
+
"score": 0.0,
|
| 112 |
+
"breakdown": {},
|
| 113 |
+
"penalty_reason": f"extra_fields: {sorted(extra_fields)}",
|
| 114 |
+
})
|
| 115 |
+
self._state.step_count += 1
|
| 116 |
+
self._state.current_ticket_index += 1
|
| 117 |
+
is_done = self._state.current_ticket_index >= len(self._queue)
|
| 118 |
+
self._state.last_step_reward = 0.0
|
| 119 |
+
self._state.done = is_done
|
| 120 |
+
if is_done:
|
| 121 |
+
traj_reward = compute_trajectory_reward(
|
| 122 |
+
self._state.per_ticket_scores, len(self._queue), self._state.step_count
|
| 123 |
+
)
|
| 124 |
+
self._state.total_reward = traj_reward
|
| 125 |
+
return self._build_observation(task, done=is_done, reward=0.0)
|
| 126 |
+
|
| 127 |
score, breakdown = grade_action(action, current_ticket, task_id)
|
| 128 |
step_reward = compute_step_reward(score)
|
| 129 |
|
| 130 |
+
is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
if is_done:
|
| 133 |
+
self._state.per_ticket_scores.append(score)
|
| 134 |
+
self._state.step_count += 1
|
| 135 |
+
self._state.current_ticket_index += 1
|
| 136 |
traj_reward = compute_trajectory_reward(
|
| 137 |
self._state.per_ticket_scores,
|
| 138 |
len(self._queue),
|
|
|
|
| 141 |
self._state.total_reward = traj_reward
|
| 142 |
final_reward = traj_reward
|
| 143 |
else:
|
| 144 |
+
self._state.per_ticket_scores.append(score)
|
| 145 |
+
self._state.step_count += 1
|
| 146 |
+
self._state.current_ticket_index += 1
|
| 147 |
final_reward = step_reward
|
| 148 |
|
| 149 |
history_entry = {
|
| 150 |
"ticket_id": current_ticket.ticket_id,
|
| 151 |
+
"title": current_ticket.title,
|
| 152 |
+
"predicted": action.model_dump(exclude_none=True),
|
| 153 |
"score": score,
|
| 154 |
"breakdown": breakdown,
|
| 155 |
}
|
| 156 |
+
self._state.history_entries.append(history_entry)
|
| 157 |
|
| 158 |
+
self._state.last_step_reward = final_reward
|
| 159 |
+
self._state.done = is_done
|
| 160 |
+
|
| 161 |
+
return self._build_observation(task, done=is_done, reward=final_reward)
|
|
|
|
|
|
|
| 162 |
|
| 163 |
@property
|
| 164 |
def state(self) -> HelpdeskTicketState:
|
|
|
|
| 173 |
task: dict,
|
| 174 |
done: bool = False,
|
| 175 |
reward: float | None = None,
|
|
|
|
| 176 |
) -> HelpdeskTicketObservation:
|
| 177 |
idx = self._state.current_ticket_index
|
| 178 |
queue_size = len(self._queue)
|
| 179 |
|
| 180 |
if idx < queue_size:
|
| 181 |
ticket = self._queue[idx]
|
| 182 |
+
ticket_view: dict[str, Any] = {
|
| 183 |
"ticket_id": ticket.ticket_id,
|
| 184 |
"title": ticket.title,
|
| 185 |
"requester": ticket.requester,
|
| 186 |
"description": ticket.description,
|
| 187 |
}
|
| 188 |
+
if ticket.ambiguity_note is not None:
|
| 189 |
+
ticket_view["ambiguity_note"] = ticket.ambiguity_note
|
| 190 |
+
if ticket.related_ticket_id is not None:
|
| 191 |
+
ticket_view["related_ticket_id"] = ticket.related_ticket_id
|
| 192 |
else:
|
| 193 |
ticket_view = None
|
| 194 |
|
| 195 |
+
history = list(self._state.history_entries)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
return HelpdeskTicketObservation(
|
| 198 |
done=done,
|
|
|
|
| 204 |
allowed_fields=list(task["allowed_fields"]),
|
| 205 |
current_ticket=ticket_view,
|
| 206 |
queue_size=queue_size,
|
| 207 |
+
# tickets_remaining: count of tickets not yet processed after this step
|
| 208 |
tickets_remaining=max(0, queue_size - idx),
|
| 209 |
tickets_processed=idx,
|
| 210 |
history=history,
|
server/reward.py
CHANGED
|
@@ -1,8 +1,18 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
def compute_step_reward(score: float) -> float:
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def compute_trajectory_reward(
|
|
@@ -11,6 +21,4 @@ def compute_trajectory_reward(
|
|
| 11 |
if not per_ticket_scores:
|
| 12 |
return 0.0
|
| 13 |
avg = sum(per_ticket_scores) / len(per_ticket_scores)
|
| 14 |
-
|
| 15 |
-
penalty = overshoot * 0.03
|
| 16 |
-
return max(0.0, min(1.0, avg - penalty))
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
# Milestone shaping: scores at or above the high threshold earn a bonus,
# scores strictly below the low threshold incur a penalty.
MILESTONE_HIGH_THRESHOLD = 0.8
MILESTONE_LOW_THRESHOLD = 0.2
MILESTONE_BONUS = 0.05
MILESTONE_PENALTY = 0.05


def compute_step_reward(score: float) -> float:
    """Turn a per-ticket score into a shaped step reward in ``[0.0, 1.0]``.

    The score is clamped to the unit interval, then adjusted:

    * ``score >= MILESTONE_HIGH_THRESHOLD``: add ``MILESTONE_BONUS``, capped at 1.0.
    * ``score < MILESTONE_LOW_THRESHOLD``: subtract ``MILESTONE_PENALTY``, floored at 0.0.
    * otherwise: the clamped score is returned unchanged.

    Args:
        score: Raw grading score for the ticket.

    Returns:
        The shaped reward. A NaN score yields 0.0.
    """
    import math

    # NaN compares False against everything, so without this guard it would
    # be clamped to 1.0 by max/min, skip both threshold branches and be
    # returned as a full reward.
    if math.isnan(score):
        return 0.0

    base = max(0.0, min(1.0, score))
    if score >= MILESTONE_HIGH_THRESHOLD:
        return min(1.0, base + MILESTONE_BONUS)
    if score < MILESTONE_LOW_THRESHOLD:
        return max(0.0, base - MILESTONE_PENALTY)
    return base
|
| 16 |
|
| 17 |
|
| 18 |
def compute_trajectory_reward(
|
|
|
|
| 21 |
if not per_ticket_scores:
|
| 22 |
return 0.0
|
| 23 |
avg = sum(per_ticket_scores) / len(per_ticket_scores)
|
| 24 |
+
return max(0.0, min(1.0, avg))
|
|
|
|
|
|
tests/test_competitive_upgrade.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the helpdesk-competitive-upgrade spec (Task 9).
|
| 3 |
+
|
| 4 |
+
Covers:
|
| 5 |
+
9.1 test_inference_single_task_mode
|
| 6 |
+
9.2 test_state_has_reward_and_done
|
| 7 |
+
9.3 test_history_has_title_and_predicted
|
| 8 |
+
9.4 test_milestone_reward_shaping
|
| 9 |
+
9.5 test_trajectory_reward_no_overshoot
|
| 10 |
+
9.6 test_ambiguity_note_in_observation
|
| 11 |
+
9.7 test_dataset_nondefault_routing
|
| 12 |
+
9.9 test_concurrent_sessions_flag
|
| 13 |
+
9.10 test_web_ui_endpoint
|
| 14 |
+
|
| 15 |
+
Run with:
|
| 16 |
+
pytest tests/test_competitive_upgrade.py
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
import types as _types
|
| 23 |
+
import unittest
|
| 24 |
+
|
| 25 |
+
# Ensure repo root is on sys.path
|
| 26 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 27 |
+
|
| 28 |
+
import openenv_test_stubs # noqa: F401 — must come before any openenv imports
|
| 29 |
+
|
| 30 |
+
# Patch in the interfaces module so environment.py can import Environment.
|
| 31 |
+
_INTERFACES_MODULE_NAME = "openenv.core.env_server.interfaces"

# Register a stand-in only when no module is registered under this name yet.
if _INTERFACES_MODULE_NAME not in sys.modules:

    class _Environment:
        """Minimal stub matching the openenv-core Environment base class."""

        def __init__(self) -> None:
            """Take no arguments and set up no state."""

        def __init_subclass__(cls, **kwargs: object) -> None:
            super().__init_subclass__(**kwargs)

        @classmethod
        def __class_getitem__(cls, item: object) -> type:
            # Subscripting (``Environment[A, O, S]``) hands back the class itself.
            return cls

    _interfaces_mod = _types.ModuleType(_INTERFACES_MODULE_NAME)
    _interfaces_mod.Environment = _Environment  # type: ignore[attr-defined]
    sys.modules[_INTERFACES_MODULE_NAME] = _interfaces_mod
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
from models import HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState
|
| 52 |
+
from server.environment import HelpdeskTicketRoutingEnvironment
|
| 53 |
+
from server.reward import compute_step_reward, compute_trajectory_reward
|
| 54 |
+
from server.tasks import load_dataset
|
| 55 |
+
from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS, TASK_IDS
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# Helpers
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
def _make_env() -> HelpdeskTicketRoutingEnvironment:
    """Construct a new environment instance; the caller is expected to reset it."""
    environment = HelpdeskTicketRoutingEnvironment()
    return environment
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _heuristic_action(obs: HelpdeskTicketObservation) -> HelpdeskTicketAction:
    """Build an action that fills each allowed field with its first vocabulary value.

    Args:
        obs: Observation whose ``allowed_fields`` decides which fields are set.

    Returns:
        An action containing only fields listed in ``obs.allowed_fields``.
    """
    permitted = obs.allowed_fields
    vocab_by_field = {
        "issue_type": ISSUE_TYPES,
        "priority": PRIORITIES,
        "assignment_group": ASSIGNMENT_GROUPS,
        "resolution_action": RESOLUTION_ACTIONS,
    }
    chosen: dict = {
        field: options[0]
        for field, options in vocab_by_field.items()
        if field in permitted
    }
    return HelpdeskTicketAction(**chosen)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
# 9.1 — Inference single-task mode
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
def _get_tasks_to_run_impl(task_id_env: str | None, available_tasks: dict) -> list[int]:
    """
    Standalone re-implementation of inference.get_tasks_to_run() logic for testing.

    This mirrors the logic in inference.py without importing the full module
    (which has heavy dependencies like openai, httpx, and client.py).

    Returns ``[task_id]`` for an available task, ``[]`` for an unavailable one,
    and every entry of ``TASK_IDS`` when ``task_id_env`` is empty or ``None``.
    Raises ``SystemExit(1)`` when ``task_id_env`` is not an integer.
    """
    if not task_id_env:
        return list(TASK_IDS)

    try:
        requested = int(task_id_env)
    except ValueError:
        raise SystemExit(1)

    return [requested] if requested in available_tasks else []
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class TestInferenceSingleTaskMode(unittest.TestCase):
    """9.1 — get_tasks_to_run() respects TASK_ID env var.

    The tests exercise ``_get_tasks_to_run_impl``, the local mirror of the
    inference logic, covering valid IDs, unavailable IDs, unset/empty values
    and non-integer values.
    """

    def test_task_id_set_to_valid_id_returns_single_element_list(self) -> None:
        available = {1: {}, 2: {}, 3: {}}
        result = _get_tasks_to_run_impl("1", available)
        self.assertEqual(result, [1])

    def test_task_id_set_to_unavailable_id_returns_empty_list(self) -> None:
        available = {1: {}, 2: {}, 3: {}}
        result = _get_tasks_to_run_impl("999", available)
        self.assertEqual(result, [])

    def test_task_id_unset_returns_all_task_ids(self) -> None:
        available = {1: {}, 2: {}, 3: {}}
        result = _get_tasks_to_run_impl(None, available)
        self.assertEqual(sorted(result), sorted(list(TASK_IDS)))

    def test_task_id_empty_string_returns_all_task_ids(self) -> None:
        # An empty string is falsy, so it is treated the same as "unset".
        available = {1: {}, 2: {}, 3: {}}
        result = _get_tasks_to_run_impl("", available)
        self.assertEqual(sorted(result), sorted(list(TASK_IDS)))

    def test_task_id_not_an_integer_exits_with_code_1(self) -> None:
        # A non-integer TASK_ID must abort the run rather than be ignored.
        available = {1: {}, 2: {}, 3: {}}
        with self.assertRaises(SystemExit) as ctx:
            _get_tasks_to_run_impl("not-a-number", available)
        self.assertEqual(ctx.exception.code, 1)

    def test_task_id_set_to_2_returns_only_task_2(self) -> None:
        available = {1: {}, 2: {}, 3: {}}
        result = _get_tasks_to_run_impl("2", available)
        self.assertEqual(result, [2])

    def test_task_id_set_to_3_returns_only_task_3(self) -> None:
        available = {1: {}, 2: {}, 3: {}}
        result = _get_tasks_to_run_impl("3", available)
        self.assertEqual(result, [3])
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ---------------------------------------------------------------------------
|
| 132 |
+
# 9.2 — State has last_step_reward and done after step()
|
| 133 |
+
# ---------------------------------------------------------------------------
|
| 134 |
+
|
| 135 |
+
class TestStateHasRewardAndDone(unittest.TestCase):
    """9.2 — state.last_step_reward and state.done are set after step()."""

    def test_last_step_reward_is_none_after_reset(self) -> None:
        env = _make_env()
        env.reset(seed=42, task_id=1)
        self.assertIsNone(env.state.last_step_reward)

    def test_done_is_false_after_reset(self) -> None:
        env = _make_env()
        env.reset(seed=42, task_id=1)
        self.assertFalse(env.state.done)

    def test_last_step_reward_set_after_step(self) -> None:
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)
        action = _heuristic_action(obs)
        env.step(action)
        state = env.state
        self.assertIsNotNone(state.last_step_reward)
        self.assertGreaterEqual(state.last_step_reward, 0.0)
        self.assertLessEqual(state.last_step_reward, 1.0)

    def test_done_is_true_after_last_ticket(self) -> None:
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)
        while not obs.done:
            obs = env.step(_heuristic_action(obs))
        self.assertTrue(env.state.done)

    def test_done_is_false_before_last_ticket(self) -> None:
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)
        # With a single-ticket queue there is no "before the last ticket"
        # step; report a skip instead of passing without asserting anything.
        if obs.queue_size <= 1:
            self.skipTest("queue has a single ticket; no intermediate step to check")
        env.step(_heuristic_action(obs))
        self.assertFalse(env.state.done)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# ---------------------------------------------------------------------------
|
| 174 |
+
# 9.3 — History entry contains title and predicted
|
| 175 |
+
# ---------------------------------------------------------------------------
|
| 176 |
+
|
| 177 |
+
class TestHistoryHasTitleAndPredicted(unittest.TestCase):
    """9.3 — observation.history[0] contains 'title' and 'predicted' keys."""

    @staticmethod
    def _take_first_step():
        """Reset a fresh environment (seed 42, task 1) and submit one heuristic action.

        Returns the submitted action together with the observation produced
        by that step.
        """
        env = _make_env()
        initial_obs = env.reset(seed=42, task_id=1)
        submitted = _heuristic_action(initial_obs)
        return submitted, env.step(submitted)

    def test_history_entry_has_title(self) -> None:
        _, after_step = self._take_first_step()
        self.assertEqual(len(after_step.history), 1)
        first_entry = after_step.history[0]
        self.assertIn("title", first_entry)
        self.assertIsInstance(first_entry["title"], str)
        self.assertTrue(first_entry["title"])  # non-empty

    def test_history_entry_has_predicted(self) -> None:
        _, after_step = self._take_first_step()
        first_entry = after_step.history[0]
        self.assertIn("predicted", first_entry)
        self.assertIsInstance(first_entry["predicted"], dict)

    def test_history_predicted_matches_action(self) -> None:
        submitted, after_step = self._take_first_step()
        recorded = after_step.history[0]["predicted"]
        expected = submitted.model_dump(exclude_none=True)
        self.assertEqual(recorded, expected)

    def test_history_entry_has_ticket_id_and_score(self) -> None:
        _, after_step = self._take_first_step()
        first_entry = after_step.history[0]
        for required_key in ("ticket_id", "score"):
            self.assertIn(required_key, first_entry)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# ---------------------------------------------------------------------------
|
| 217 |
+
# 9.4 — Milestone reward shaping
|
| 218 |
+
# ---------------------------------------------------------------------------
|
| 219 |
+
|
| 220 |
+
class TestMilestoneRewardShaping(unittest.TestCase):
    """9.4 — compute_step_reward adds a bonus at high scores and a penalty at low ones."""

    def _assert_shaped(self, score: float, expected: float) -> None:
        """Assert compute_step_reward(score) equals *expected* to 9 decimal places."""
        self.assertAlmostEqual(compute_step_reward(score), expected, places=9)

    def test_high_score_gets_bonus(self) -> None:
        # 0.9 is at/above the 0.8 threshold: 0.9 + 0.05 bonus = 0.95.
        self._assert_shaped(0.9, 0.95)

    def test_low_score_gets_penalty(self) -> None:
        # 0.1 is below the 0.2 threshold: 0.1 - 0.05 penalty = 0.05.
        self._assert_shaped(0.1, 0.05)

    def test_mid_score_is_neutral(self) -> None:
        # 0.5 lies in [0.2, 0.8): no shaping is expected.
        self._assert_shaped(0.5, 0.5)

    def test_boundary_high_threshold_gets_bonus(self) -> None:
        # Exactly 0.8 is expected to receive the bonus: 0.85.
        self._assert_shaped(0.8, 0.85)

    def test_boundary_low_threshold_is_neutral(self) -> None:
        # Exactly 0.2 is not below the low threshold, so it stays 0.2.
        self._assert_shaped(0.2, 0.2)

    def test_reward_clamped_to_unit_interval(self) -> None:
        # A bonus on top of 1.0 would give 1.05; the result must stay within [0, 1].
        shaped = compute_step_reward(1.0)
        self.assertLessEqual(shaped, 1.0)
        self.assertGreaterEqual(shaped, 0.0)

    def test_zero_score_clamped_to_zero(self) -> None:
        # A penalty on 0.0 would give -0.05; the result must not go negative.
        shaped = compute_step_reward(0.0)
        self.assertGreaterEqual(shaped, 0.0)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ---------------------------------------------------------------------------
|
| 261 |
+
# 9.5 — Trajectory reward has no overshoot penalty
|
| 262 |
+
# ---------------------------------------------------------------------------
|
| 263 |
+
|
| 264 |
+
class TestTrajectoryRewardNoOvershoot(unittest.TestCase):
    """9.5 — compute_trajectory_reward applies no penalty when steps > queue_size."""

    def test_no_penalty_when_steps_exceed_queue_size(self) -> None:
        per_ticket = [0.8, 0.9, 0.7]
        mean_score = sum(per_ticket) / len(per_ticket)
        # 10 steps against a queue of 3: far more steps than tickets.
        reward = compute_trajectory_reward(per_ticket, 3, 10)
        self.assertAlmostEqual(reward, mean_score, places=9)

    def test_result_equals_average_regardless_of_steps(self) -> None:
        per_ticket = [0.5, 0.6]
        for step_count in [1, 2, 5, 100]:
            reward = compute_trajectory_reward(per_ticket, len(per_ticket), step_count)
            failure_note = f"Failed for steps={step_count}"
            self.assertAlmostEqual(reward, 0.55, places=9, msg=failure_note)

    def test_empty_scores_returns_zero(self) -> None:
        reward = compute_trajectory_reward([], 3, 3)
        self.assertEqual(reward, 0.0)

    def test_result_in_unit_interval(self) -> None:
        reward = compute_trajectory_reward([0.9, 1.0, 0.95], 3, 3)
        self.assertGreaterEqual(reward, 0.0)
        self.assertLessEqual(reward, 1.0)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ---------------------------------------------------------------------------
|
| 293 |
+
# 9.6 — ambiguity_note appears in current_ticket observation
|
| 294 |
+
# ---------------------------------------------------------------------------
|
| 295 |
+
|
| 296 |
+
class TestAmbiguityNoteInObservation(unittest.TestCase):
    """9.6 — current_ticket includes ambiguity_note when the ticket has one."""

    def _find_seed_with_ambiguity_note(self, task_id: int = 3) -> int | None:
        """Return the first seed in 0..999 whose first ticket has an ambiguity_note.

        Returns ``None`` when no seed in that range qualifies. The tests in
        this class do not call this helper; they patch the dataset instead.
        """
        env = _make_env()
        for seed in range(1000):
            obs = env.reset(seed=seed, task_id=task_id)
            if obs.current_ticket and obs.current_ticket.get("ambiguity_note"):
                return seed
        return None

    @staticmethod
    def _reset_with_only(ticket):
        """Reset a fresh env whose ``_dataset`` is patched to contain only *ticket*.

        The patch is active only for the duration of ``reset(seed=0, task_id=3)``;
        the observation returned by that reset is handed back to the caller.
        """
        from unittest.mock import patch

        env = _make_env()
        with patch.object(env, "_dataset", [ticket]):
            return env.reset(seed=0, task_id=3)

    def test_ambiguity_note_present_when_ticket_has_one(self) -> None:
        """Force a ticket with ambiguity_note by patching the dataset."""
        from server.tasks import load_dataset

        dataset = load_dataset()
        ambiguous_tickets = [t for t in dataset if t.ambiguity_note is not None]
        self.assertGreater(len(ambiguous_tickets), 0, "No tickets with ambiguity_note in dataset")

        target = ambiguous_tickets[0]
        obs = self._reset_with_only(target)

        self.assertIsNotNone(obs.current_ticket)
        self.assertIn("ambiguity_note", obs.current_ticket)
        self.assertEqual(obs.current_ticket["ambiguity_note"], target.ambiguity_note)

    def test_ambiguity_note_absent_when_ticket_has_none(self) -> None:
        """Tickets without ambiguity_note should not expose the key."""
        from server.tasks import load_dataset

        dataset = load_dataset()
        non_ambiguous = [t for t in dataset if t.ambiguity_note is None]
        self.assertGreater(len(non_ambiguous), 0)

        obs = self._reset_with_only(non_ambiguous[0])

        self.assertIsNotNone(obs.current_ticket)
        self.assertNotIn("ambiguity_note", obs.current_ticket)

    def test_tkt_nondefault_001_has_ambiguity_note(self) -> None:
        """TKT-NONDEFAULT-001 specifically has ambiguity_note set."""
        from server.tasks import load_dataset

        dataset = load_dataset()
        ticket = next((t for t in dataset if t.ticket_id == "TKT-NONDEFAULT-001"), None)
        self.assertIsNotNone(ticket, "TKT-NONDEFAULT-001 not found in dataset")
        self.assertIsNotNone(ticket.ambiguity_note)

        obs = self._reset_with_only(ticket)

        # Guard first, as the sibling tests do: assertIn on a None container
        # raises TypeError (a test error) instead of a clean assertion failure.
        self.assertIsNotNone(obs.current_ticket)
        self.assertIn("ambiguity_note", obs.current_ticket)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# ---------------------------------------------------------------------------
|
| 364 |
+
# 9.7 — Dataset has >= 3 non-default routing tickets
|
| 365 |
+
# ---------------------------------------------------------------------------
|
| 366 |
+
|
| 367 |
+
class TestDatasetNonDefaultRouting(unittest.TestCase):
    """9.7 — the dataset holds at least 3 tickets routed to a non-default assignment_group."""

    def test_at_least_three_nondefault_routing_tickets(self) -> None:
        from vocabulary import ISSUE_TYPE_TO_ASSIGNMENT_GROUP

        # A ticket is "non-default" when its assignment_group differs from the
        # group the vocabulary maps its issue_type to.
        off_default = []
        for ticket in load_dataset():
            default_group = ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(ticket.issue_type)
            if ticket.assignment_group != default_group:
                off_default.append(ticket)

        found = len(off_default)
        summary = [
            (ticket.ticket_id, ticket.issue_type, ticket.assignment_group)
            for ticket in off_default
        ]
        failure_note = (
            f"Expected >= 3 non-default routing tickets, found {found}: " + str(summary)
        )
        self.assertGreaterEqual(found, 3, failure_note)

    def test_tkt_nondefault_tickets_exist(self) -> None:
        known_ids = {ticket.ticket_id for ticket in load_dataset()}
        required_ids = ("TKT-NONDEFAULT-001", "TKT-NONDEFAULT-002", "TKT-NONDEFAULT-003")
        for expected_id in required_ids:
            self.assertIn(expected_id, known_ids, f"{expected_id} not found in dataset")
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# ---------------------------------------------------------------------------
|
| 392 |
+
# 9.9 — SUPPORTS_CONCURRENT_SESSIONS is True
|
| 393 |
+
# ---------------------------------------------------------------------------
|
| 394 |
+
|
| 395 |
+
class TestConcurrentSessionsFlag(unittest.TestCase):
    """9.9 — the environment class advertises SUPPORTS_CONCURRENT_SESSIONS as True."""

    def test_supports_concurrent_sessions_is_true(self) -> None:
        # Truthiness check on the class-level attribute.
        advertised = HelpdeskTicketRoutingEnvironment.SUPPORTS_CONCURRENT_SESSIONS
        self.assertTrue(advertised)

    def test_flag_is_boolean_true(self) -> None:
        # Identity check: the attribute must be the bool True, not merely truthy.
        self.assertIs(HelpdeskTicketRoutingEnvironment.SUPPORTS_CONCURRENT_SESSIONS, True)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# ---------------------------------------------------------------------------
|
| 407 |
+
# 9.10 — GET /web returns 200 with HTML content
|
| 408 |
+
# ---------------------------------------------------------------------------
|
| 409 |
+
|
| 410 |
+
def _build_web_test_app():
    """Return a bare FastAPI app that serves only a GET /web status page.

    The page lists every entry of ``TASKS`` (id, name, difficulty) in an HTML
    table under a heading showing ``APP_ENV_NAME``.
    """
    from fastapi import FastAPI
    from fastapi.responses import HTMLResponse
    from server.tasks import TASKS
    from vocabulary import APP_ENV_NAME

    test_app = FastAPI()

    def web_ui():
        # One table row per task, concatenated without separators.
        row_markup = []
        for task in TASKS.values():
            row_markup.append(
                f"<tr><td>{task['id']}</td><td>{task['name']}</td><td>{task['difficulty']}</td></tr>"
            )
        task_rows = "".join(row_markup)
        page = f"""<!DOCTYPE html>
<html><head><title>{APP_ENV_NAME}</title></head>
<body>
<h1>{APP_ENV_NAME}</h1>
<p>Version: 0.1.0 | <a href="/health">Health</a> | <a href="/docs">API Docs</a></p>
<h2>Tasks</h2>
<table border="1"><tr><th>ID</th><th>Name</th><th>Difficulty</th></tr>
{task_rows}
</table>
</body></html>"""
        return HTMLResponse(content=page)

    # Equivalent to decorating web_ui with @test_app.get(...).
    test_app.get("/web", response_class=HTMLResponse)(web_ui)

    return test_app
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class TestWebUIEndpoint(unittest.TestCase):
    """9.10 — GET /web answers with HTTP 200 and an HTML body."""

    @classmethod
    def setUpClass(cls) -> None:
        from starlette.testclient import TestClient

        # One client over the minimal /web-only app is shared by all tests.
        cls.client = TestClient(_build_web_test_app())

    def _fetch_web(self):
        """Issue GET /web against the shared test client and return the response."""
        return self.client.get("/web")

    def test_web_returns_200(self) -> None:
        self.assertEqual(self._fetch_web().status_code, 200)

    def test_web_returns_html_content_type(self) -> None:
        content_type = self._fetch_web().headers.get("content-type", "")
        self.assertIn("text/html", content_type)

    def test_web_response_contains_html_tag(self) -> None:
        body = self._fetch_web().text
        self.assertIn("<!DOCTYPE html>", body)

    def test_web_response_contains_env_name(self) -> None:
        from vocabulary import APP_ENV_NAME

        body = self._fetch_web().text
        self.assertIn(APP_ENV_NAME, body)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# Run the tests in this module via unittest when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
tests/test_extra_fields_penalty.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for action field validation (Task 4) in HelpdeskTicketRoutingEnvironment.step().
|
| 3 |
+
|
| 4 |
+
Validates Requirement 7: Step Validates Action Fields Against Task Contract.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
import unittest
|
| 11 |
+
import types as _types
|
| 12 |
+
|
| 13 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 14 |
+
|
| 15 |
+
import openenv_test_stubs # noqa: F401
|
| 16 |
+
|
| 17 |
+
# Register a minimal stand-in for ``openenv.core.env_server.interfaces`` when
# no module with that name is already present in ``sys.modules``.
# NOTE(review): presumably this lets ``server.environment`` be imported without
# the real openenv package installed — confirm against openenv_test_stubs.
if "openenv.core.env_server.interfaces" not in sys.modules:
    _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")

    class _Environment:
        # Placeholder base class exposed below as ``Environment`` on the stub module.

        def __init__(self) -> None:
            pass

        def __init_subclass__(cls, **kwargs: object) -> None:
            # Forwards any class keyword arguments to the default implementation.
            super().__init_subclass__(**kwargs)

        @classmethod
        def __class_getitem__(cls, item: object) -> type:
            # Subscripting (``_Environment[...]``) returns the class itself unchanged.
            return cls

    _interfaces_mod.Environment = _Environment  # type: ignore[attr-defined]
    sys.modules["openenv.core.env_server.interfaces"] = _interfaces_mod
|
| 33 |
+
|
| 34 |
+
from models import HelpdeskTicketAction, HelpdeskTicketObservation
|
| 35 |
+
from server.environment import HelpdeskTicketRoutingEnvironment
|
| 36 |
+
from server.tasks import TASKS
|
| 37 |
+
from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _make_env() -> HelpdeskTicketRoutingEnvironment:
    """Build a new HelpdeskTicketRoutingEnvironment with default constructor arguments."""
    environment = HelpdeskTicketRoutingEnvironment()
    return environment
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TestExtraFieldsPenalty(unittest.TestCase):
    """Requirement 7: step() penalises actions carrying fields outside the task's allowed_fields."""

    @staticmethod
    def _start_task_one():
        """Return ``(env, obs)`` for a fresh environment reset with seed=42, task_id=1."""
        env = _make_env()
        return env, env.reset(seed=42, task_id=1)

    @staticmethod
    def _action_with_extra_field() -> HelpdeskTicketAction:
        """Build an action whose assignment_group is outside task 1's allowed fields."""
        return HelpdeskTicketAction(
            issue_type=ISSUE_TYPES[0],
            assignment_group=ASSIGNMENT_GROUPS[0],  # extra field for task 1
        )

    @staticmethod
    def _action_within(allowed) -> HelpdeskTicketAction:
        """Build an action that sets issue_type/priority only when *allowed* lists them."""
        vocab_by_field = (("issue_type", ISSUE_TYPES), ("priority", PRIORITIES))
        chosen = {
            field: vocab[0]
            for field, vocab in vocab_by_field
            if field in allowed
        }
        return HelpdeskTicketAction(**chosen)

    def test_extra_fields_returns_reward_zero(self) -> None:
        """Task 1 only allows issue_type and priority; submitting assignment_group triggers penalty."""
        env, obs = self._start_task_one()

        # Precondition: assignment_group really is outside task 1's contract.
        self.assertNotIn("assignment_group", obs.allowed_fields)

        over_specified = HelpdeskTicketAction(
            issue_type=ISSUE_TYPES[0],
            priority=PRIORITIES[0],
            assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
        )
        penalty_obs = env.step(over_specified)

        self.assertIsInstance(penalty_obs, HelpdeskTicketObservation)
        self.assertEqual(penalty_obs.reward, 0.0)

    def test_extra_fields_advances_ticket_index(self) -> None:
        """Penalty step must advance tickets_processed by 1."""
        env, obs = self._start_task_one()
        self.assertEqual(obs.tickets_processed, 0)

        penalty_obs = env.step(self._action_with_extra_field())

        self.assertEqual(penalty_obs.tickets_processed, 1)

    def test_extra_fields_records_score_zero(self) -> None:
        """per_ticket_scores must contain 0.0 after a penalty step."""
        env, _ = self._start_task_one()

        env.step(self._action_with_extra_field())

        recorded_scores = env.state.per_ticket_scores
        self.assertEqual(len(recorded_scores), 1)
        self.assertEqual(recorded_scores[0], 0.0)

    def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
        """History entry for a penalty step must include penalty_reason."""
        env, _ = self._start_task_one()

        penalty_obs = env.step(self._action_with_extra_field())

        self.assertEqual(len(penalty_obs.history), 1)
        entry = penalty_obs.history[0]
        self.assertIn("penalty_reason", entry)
        # The reason is expected to name the offending field.
        self.assertIn("assignment_group", entry["penalty_reason"])
        self.assertEqual(entry["score"], 0.0)

    def test_no_extra_fields_grades_normally(self) -> None:
        """When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0)."""
        env, obs = self._start_task_one()

        result_obs = env.step(self._action_within(obs.allowed_fields))

        # Any reward value is acceptable here; it only has to be present.
        self.assertIsInstance(result_obs, HelpdeskTicketObservation)
        self.assertIsNotNone(result_obs.reward)
        # A normally graded step must not be tagged as a penalty.
        self.assertEqual(len(result_obs.history), 1)
        self.assertNotIn("penalty_reason", result_obs.history[0])

    def test_extra_fields_no_exception_raised(self) -> None:
        """Requirement 7.4: extra fields must not raise an unhandled exception."""
        env, _ = self._start_task_one()

        over_specified = HelpdeskTicketAction(
            issue_type=ISSUE_TYPES[0],
            priority=PRIORITIES[0],
            assignment_group=ASSIGNMENT_GROUPS[0],
            resolution_action=RESOLUTION_ACTIONS[0],  # multiple extra fields
        )
        try:
            obs = env.step(over_specified)
        except Exception as exc:  # noqa: BLE001
            self.fail(f"step() raised an unexpected exception: {exc}")

        self.assertIsInstance(obs, HelpdeskTicketObservation)

    def test_extra_fields_done_flag_set_correctly_on_last_ticket(self) -> None:
        """When the penalty step is on the last ticket, done must be True."""
        env, obs = self._start_task_one()

        # Work through every ticket but the last with in-contract actions,
        # so the penalty below lands on the final ticket of the queue.
        remaining_before_last = obs.queue_size - 1
        for _ in range(remaining_before_last):
            obs = env.step(self._action_within(obs.allowed_fields))

        final_obs = env.step(self._action_with_extra_field())

        self.assertTrue(final_obs.done)
        self.assertEqual(final_obs.reward, 0.0)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Run the tests in this module via unittest when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|