Coding Ninja committed on
Commit
42dd095
·
1 Parent(s): 5dd60ae

feat: competitive upgrade for hackathon submission

Browse files

- inference.py: single-task mode via TASK_ID env var; exits with status 1 on a non-integer ID, warns and runs no tasks on an unavailable ID
- models.py: add last_step_reward, done, history_entries to HelpdeskTicketState
- environment.py: state tracking, enriched history (title+predicted), ambiguity_note
in observations, extra-field penalty validation, SUPPORTS_CONCURRENT_SESSIONS=True
- reward.py: milestone shaping (+0.05 when score >= 0.8, -0.05 when score < 0.2, clamped to [0, 1]), remove overshoot penalty
- app.py: add GET /web HTML status page
- openenv.yaml: clarify entry_point vs pyproject.toml server script
- dataset.json: add 3 non-default routing tickets (TKT-NONDEFAULT-001/002/003)
- tests: test_competitive_upgrade.py (36 tests) + test_extra_fields_penalty.py (7 tests)

data/dataset.json CHANGED
@@ -538,6 +538,42 @@
538
  "resolution_action": "escalate",
539
  "ambiguity_note": null,
540
  "related_ticket_id": "ticket-030"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  }
542
  ]
543
 
 
538
  "resolution_action": "escalate",
539
  "ambiguity_note": null,
540
  "related_ticket_id": "ticket-030"
541
+ },
542
+ {
543
+ "ticket_id": "TKT-NONDEFAULT-001",
544
+ "title": "Billing question from free-tier account",
545
+ "requester": "user@freetier.io",
546
+ "description": "I have a question about my invoice but I am on the free plan and there is no charge. The billing team cannot action this; please route to service desk for general assistance.",
547
+ "issue_type": "billing_license",
548
+ "priority": "low",
549
+ "assignment_group": "service_desk",
550
+ "resolution_action": "fulfill",
551
+ "ambiguity_note": "Account tier is free; billing team cannot action, routed to service desk",
552
+ "related_ticket_id": null
553
+ },
554
+ {
555
+ "ticket_id": "TKT-NONDEFAULT-002",
556
+ "title": "App vulnerability flagged in compliance scan",
557
+ "requester": "security@clientcorp.com",
558
+ "description": "Our compliance scan flagged a product-specific vulnerability in the application layer. This is not a general security policy issue but an app bug requiring the application team to remediate.",
559
+ "issue_type": "security_compliance",
560
+ "priority": "high",
561
+ "assignment_group": "application_team",
562
+ "resolution_action": "escalate",
563
+ "ambiguity_note": "Compliance issue is product-specific (app vulnerability), routed to app team",
564
+ "related_ticket_id": null
565
+ },
566
+ {
567
+ "ticket_id": "TKT-NONDEFAULT-003",
568
+ "title": "Contractor onboarding blocked by access issue",
569
+ "requester": "pm@contractorco.com",
570
+ "description": "A new contractor cannot complete onboarding because their account access is blocked by a permissions error. The onboarding team cannot resolve access issues; routing to service desk.",
571
+ "issue_type": "onboarding",
572
+ "priority": "medium",
573
+ "assignment_group": "service_desk",
574
+ "resolution_action": "fulfill",
575
+ "ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
576
+ "related_ticket_id": null
577
  }
578
  ]
579
 
inference.py CHANGED
@@ -64,7 +64,7 @@ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
64
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
65
 
66
  SEED = 42
67
- TASKS = list(TASK_IDS)
68
 
69
  # ---------------------------------------------------------------------------
70
  # LLM helper
@@ -134,6 +134,20 @@ def emit_log(tag: str, **payload: Any) -> None:
134
  print(f"[{tag}] {json.dumps(payload, sort_keys=True, ensure_ascii=True)}")
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # ---------------------------------------------------------------------------
138
  # Heuristic fallback (no LLM needed)
139
  # ---------------------------------------------------------------------------
@@ -332,7 +346,10 @@ def run() -> None:
332
 
333
  all_results: dict[int, dict[str, float | int]] = {}
334
 
335
- for task_id in TASKS:
 
 
 
336
  if task_id not in available_tasks:
337
  continue
338
 
@@ -400,11 +417,12 @@ def run() -> None:
400
 
401
  overall = [
402
  float(all_results[task_id]["final_reward"])
403
- for task_id in TASKS
404
  if task_id in all_results
405
  ]
406
- overall_avg = round(sum(overall) / len(overall), 4) if overall else 0.0
407
- emit_log("END", overall_avg=overall_avg, tasks_completed=len(overall))
 
408
 
409
 
410
  if __name__ == "__main__":
 
64
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
65
 
66
  SEED = 42
67
+ TASK_ID_ENV = os.getenv("TASK_ID")
68
 
69
  # ---------------------------------------------------------------------------
70
  # LLM helper
 
134
  print(f"[{tag}] {json.dumps(payload, sort_keys=True, ensure_ascii=True)}")
135
 
136
 
137
+ def get_tasks_to_run(available_tasks: dict) -> list[int]:
138
+ if TASK_ID_ENV:
139
+ try:
140
+ task_id = int(TASK_ID_ENV)
141
+ except ValueError:
142
+ print(f"[ERROR] TASK_ID={TASK_ID_ENV!r} is not a valid integer", flush=True)
143
+ raise SystemExit(1)
144
+ if task_id not in available_tasks:
145
+ print(f"[WARN] TASK_ID={task_id} not in available tasks {list(available_tasks)}", flush=True)
146
+ return []
147
+ return [task_id]
148
+ return list(TASK_IDS) # fallback: all tasks (local dev)
149
+
150
+
151
  # ---------------------------------------------------------------------------
152
  # Heuristic fallback (no LLM needed)
153
  # ---------------------------------------------------------------------------
 
346
 
347
  all_results: dict[int, dict[str, float | int]] = {}
348
 
349
+ tasks_to_run = get_tasks_to_run(available_tasks)
350
+ single_task_mode = bool(TASK_ID_ENV)
351
+
352
+ for task_id in tasks_to_run:
353
  if task_id not in available_tasks:
354
  continue
355
 
 
417
 
418
  overall = [
419
  float(all_results[task_id]["final_reward"])
420
+ for task_id in tasks_to_run
421
  if task_id in all_results
422
  ]
423
+ if not single_task_mode:
424
+ overall_avg = round(sum(overall) / len(overall), 4) if overall else 0.0
425
+ emit_log("END", overall_avg=overall_avg, tasks_completed=len(overall))
426
 
427
 
428
  if __name__ == "__main__":
models.py CHANGED
@@ -112,3 +112,6 @@ class HelpdeskTicketState(State):
112
  current_ticket_index: int = 0
113
  per_ticket_scores: list[float] = Field(default_factory=list)
114
  total_reward: float = 0.0
 
 
 
 
112
  current_ticket_index: int = 0
113
  per_ticket_scores: list[float] = Field(default_factory=list)
114
  total_reward: float = 0.0
115
+ last_step_reward: Optional[float] = None
116
+ done: bool = False
117
+ history_entries: list[dict] = Field(default_factory=list)
openenv.yaml CHANGED
@@ -7,6 +7,9 @@ author: Hackstreet Boys - Roopal Guha Neogi, Suyash Kumar
7
 
8
  environment:
9
  type: openenv
 
 
 
10
  entry_point: server.environment:HelpdeskTicketRoutingEnvironment
11
  action_model: models:HelpdeskTicketAction
12
  observation_model: models:HelpdeskTicketObservation
 
7
 
8
  environment:
9
  type: openenv
10
+ # entry_point identifies the Environment class for the OpenEnv validator.
11
+ # The HTTP server entrypoint for deployment is defined separately in
12
+ # pyproject.toml under [project.scripts] as: server = "server.app:main"
13
  entry_point: server.environment:HelpdeskTicketRoutingEnvironment
14
  action_model: models:HelpdeskTicketAction
15
  observation_model: models:HelpdeskTicketObservation
server/app.py CHANGED
@@ -6,6 +6,7 @@ _repo_root = str(Path(__file__).resolve().parent.parent)
6
  if _repo_root not in sys.path:
7
  sys.path.insert(0, _repo_root)
8
 
 
9
  from openenv.core.env_server import create_app
10
 
11
  from models import HelpdeskTicketAction, HelpdeskTicketObservation
@@ -37,6 +38,25 @@ def list_tasks():
37
  }
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def main() -> None:
41
  import uvicorn
42
 
 
6
  if _repo_root not in sys.path:
7
  sys.path.insert(0, _repo_root)
8
 
9
+ from fastapi.responses import HTMLResponse
10
  from openenv.core.env_server import create_app
11
 
12
  from models import HelpdeskTicketAction, HelpdeskTicketObservation
 
38
  }
39
 
40
 
41
+ @app.get("/web", response_class=HTMLResponse)
42
+ def web_ui():
43
+ task_rows = "".join(
44
+ f"<tr><td>{t['id']}</td><td>{t['name']}</td><td>{t['difficulty']}</td></tr>"
45
+ for t in TASKS.values()
46
+ )
47
+ html = f"""<!DOCTYPE html>
48
+ <html><head><title>{APP_ENV_NAME}</title></head>
49
+ <body>
50
+ <h1>{APP_ENV_NAME}</h1>
51
+ <p>Version: 0.1.0 | <a href="/health">Health</a> | <a href="/docs">API Docs</a></p>
52
+ <h2>Tasks</h2>
53
+ <table border="1"><tr><th>ID</th><th>Name</th><th>Difficulty</th></tr>
54
+ {task_rows}
55
+ </table>
56
+ </body></html>"""
57
+ return HTMLResponse(content=html)
58
+
59
+
60
  def main() -> None:
61
  import uvicorn
62
 
server/environment.py CHANGED
@@ -36,6 +36,8 @@ def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
36
  class HelpdeskTicketRoutingEnvironment(
37
  Environment[HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState]
38
  ):
 
 
39
  def __init__(self) -> None:
40
  super().__init__()
41
  self._dataset = load_dataset()
@@ -94,16 +96,43 @@ class HelpdeskTicketRoutingEnvironment(
94
  task_id = self._state.current_task_id
95
  task = get_task_definition(task_id)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  score, breakdown = grade_action(action, current_ticket, task_id)
98
  step_reward = compute_step_reward(score)
99
 
100
- self._state.per_ticket_scores.append(score)
101
- self._state.step_count += 1
102
- self._state.current_ticket_index += 1
103
-
104
- is_done = self._state.current_ticket_index >= len(self._queue)
105
 
106
  if is_done:
 
 
 
107
  traj_reward = compute_trajectory_reward(
108
  self._state.per_ticket_scores,
109
  len(self._queue),
@@ -112,20 +141,24 @@ class HelpdeskTicketRoutingEnvironment(
112
  self._state.total_reward = traj_reward
113
  final_reward = traj_reward
114
  else:
 
 
 
115
  final_reward = step_reward
116
 
117
  history_entry = {
118
  "ticket_id": current_ticket.ticket_id,
 
 
119
  "score": score,
120
  "breakdown": breakdown,
121
  }
 
122
 
123
- return self._build_observation(
124
- task,
125
- done=is_done,
126
- reward=final_reward,
127
- extra_history=history_entry,
128
- )
129
 
130
  @property
131
  def state(self) -> HelpdeskTicketState:
@@ -140,27 +173,26 @@ class HelpdeskTicketRoutingEnvironment(
140
  task: dict,
141
  done: bool = False,
142
  reward: float | None = None,
143
- extra_history: dict | None = None,
144
  ) -> HelpdeskTicketObservation:
145
  idx = self._state.current_ticket_index
146
  queue_size = len(self._queue)
147
 
148
  if idx < queue_size:
149
  ticket = self._queue[idx]
150
- ticket_view = {
151
  "ticket_id": ticket.ticket_id,
152
  "title": ticket.title,
153
  "requester": ticket.requester,
154
  "description": ticket.description,
155
  }
 
 
 
 
156
  else:
157
  ticket_view = None
158
 
159
- history: list[dict] = []
160
- for i, s in enumerate(self._state.per_ticket_scores):
161
- history.append({"step": i + 1, "score": s})
162
- if extra_history and history:
163
- history[-1] = {"step": len(history), **extra_history}
164
 
165
  return HelpdeskTicketObservation(
166
  done=done,
@@ -172,6 +204,7 @@ class HelpdeskTicketRoutingEnvironment(
172
  allowed_fields=list(task["allowed_fields"]),
173
  current_ticket=ticket_view,
174
  queue_size=queue_size,
 
175
  tickets_remaining=max(0, queue_size - idx),
176
  tickets_processed=idx,
177
  history=history,
 
36
  class HelpdeskTicketRoutingEnvironment(
37
  Environment[HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState]
38
  ):
39
+ SUPPORTS_CONCURRENT_SESSIONS = True
40
+
41
  def __init__(self) -> None:
42
  super().__init__()
43
  self._dataset = load_dataset()
 
96
  task_id = self._state.current_task_id
97
  task = get_task_definition(task_id)
98
 
99
+ submitted_fields = {
100
+ f for f, v in action.model_dump(exclude_none=True).items() if v is not None
101
+ }
102
+ allowed = set(task["allowed_fields"])
103
+ extra_fields = submitted_fields - allowed
104
+ if extra_fields:
105
+ # Penalty: record score 0.0, advance index, return penalty observation
106
+ self._state.per_ticket_scores.append(0.0)
107
+ self._state.history_entries.append({
108
+ "ticket_id": current_ticket.ticket_id,
109
+ "title": current_ticket.title,
110
+ "predicted": action.model_dump(exclude_none=True),
111
+ "score": 0.0,
112
+ "breakdown": {},
113
+ "penalty_reason": f"extra_fields: {sorted(extra_fields)}",
114
+ })
115
+ self._state.step_count += 1
116
+ self._state.current_ticket_index += 1
117
+ is_done = self._state.current_ticket_index >= len(self._queue)
118
+ self._state.last_step_reward = 0.0
119
+ self._state.done = is_done
120
+ if is_done:
121
+ traj_reward = compute_trajectory_reward(
122
+ self._state.per_ticket_scores, len(self._queue), self._state.step_count
123
+ )
124
+ self._state.total_reward = traj_reward
125
+ return self._build_observation(task, done=is_done, reward=0.0)
126
+
127
  score, breakdown = grade_action(action, current_ticket, task_id)
128
  step_reward = compute_step_reward(score)
129
 
130
+ is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
 
 
 
 
131
 
132
  if is_done:
133
+ self._state.per_ticket_scores.append(score)
134
+ self._state.step_count += 1
135
+ self._state.current_ticket_index += 1
136
  traj_reward = compute_trajectory_reward(
137
  self._state.per_ticket_scores,
138
  len(self._queue),
 
141
  self._state.total_reward = traj_reward
142
  final_reward = traj_reward
143
  else:
144
+ self._state.per_ticket_scores.append(score)
145
+ self._state.step_count += 1
146
+ self._state.current_ticket_index += 1
147
  final_reward = step_reward
148
 
149
  history_entry = {
150
  "ticket_id": current_ticket.ticket_id,
151
+ "title": current_ticket.title,
152
+ "predicted": action.model_dump(exclude_none=True),
153
  "score": score,
154
  "breakdown": breakdown,
155
  }
156
+ self._state.history_entries.append(history_entry)
157
 
158
+ self._state.last_step_reward = final_reward
159
+ self._state.done = is_done
160
+
161
+ return self._build_observation(task, done=is_done, reward=final_reward)
 
 
162
 
163
  @property
164
  def state(self) -> HelpdeskTicketState:
 
173
  task: dict,
174
  done: bool = False,
175
  reward: float | None = None,
 
176
  ) -> HelpdeskTicketObservation:
177
  idx = self._state.current_ticket_index
178
  queue_size = len(self._queue)
179
 
180
  if idx < queue_size:
181
  ticket = self._queue[idx]
182
+ ticket_view: dict[str, Any] = {
183
  "ticket_id": ticket.ticket_id,
184
  "title": ticket.title,
185
  "requester": ticket.requester,
186
  "description": ticket.description,
187
  }
188
+ if ticket.ambiguity_note is not None:
189
+ ticket_view["ambiguity_note"] = ticket.ambiguity_note
190
+ if ticket.related_ticket_id is not None:
191
+ ticket_view["related_ticket_id"] = ticket.related_ticket_id
192
  else:
193
  ticket_view = None
194
 
195
+ history = list(self._state.history_entries)
 
 
 
 
196
 
197
  return HelpdeskTicketObservation(
198
  done=done,
 
204
  allowed_fields=list(task["allowed_fields"]),
205
  current_ticket=ticket_view,
206
  queue_size=queue_size,
207
+ # tickets_remaining: count of tickets not yet processed after this step
208
  tickets_remaining=max(0, queue_size - idx),
209
  tickets_processed=idx,
210
  history=history,
server/reward.py CHANGED
@@ -1,8 +1,18 @@
1
  from __future__ import annotations
2
 
 
 
 
 
 
3
 
4
  def compute_step_reward(score: float) -> float:
5
- return max(0.0, min(1.0, score))
 
 
 
 
 
6
 
7
 
8
  def compute_trajectory_reward(
@@ -11,6 +21,4 @@ def compute_trajectory_reward(
11
  if not per_ticket_scores:
12
  return 0.0
13
  avg = sum(per_ticket_scores) / len(per_ticket_scores)
14
- overshoot = max(0, steps_taken - queue_size)
15
- penalty = overshoot * 0.03
16
- return max(0.0, min(1.0, avg - penalty))
 
1
  from __future__ import annotations
2
 
3
+ MILESTONE_HIGH_THRESHOLD = 0.8
4
+ MILESTONE_LOW_THRESHOLD = 0.2
5
+ MILESTONE_BONUS = 0.05
6
+ MILESTONE_PENALTY = 0.05
7
+
8
 
9
  def compute_step_reward(score: float) -> float:
10
+ base = max(0.0, min(1.0, score))
11
+ if score >= MILESTONE_HIGH_THRESHOLD:
12
+ return min(1.0, base + MILESTONE_BONUS)
13
+ if score < MILESTONE_LOW_THRESHOLD:
14
+ return max(0.0, base - MILESTONE_PENALTY)
15
+ return base
16
 
17
 
18
  def compute_trajectory_reward(
 
21
  if not per_ticket_scores:
22
  return 0.0
23
  avg = sum(per_ticket_scores) / len(per_ticket_scores)
24
+ return max(0.0, min(1.0, avg))
 
 
tests/test_competitive_upgrade.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for the helpdesk-competitive-upgrade spec (Task 9).
3
+
4
+ Covers:
5
+ 9.1 test_inference_single_task_mode
6
+ 9.2 test_state_has_reward_and_done
7
+ 9.3 test_history_has_title_and_predicted
8
+ 9.4 test_milestone_reward_shaping
9
+ 9.5 test_trajectory_reward_no_overshoot
10
+ 9.6 test_ambiguity_note_in_observation
11
+ 9.7 test_dataset_nondefault_routing
12
+ 9.9 test_concurrent_sessions_flag
13
+ 9.10 test_web_ui_endpoint
14
+
15
+ Run with:
16
+ pytest tests/test_competitive_upgrade.py
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import os
21
+ import sys
22
+ import types as _types
23
+ import unittest
24
+
25
+ # Ensure repo root is on sys.path
26
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
27
+
28
+ import openenv_test_stubs # noqa: F401 — must come before any openenv imports
29
+
30
+ # Patch in the interfaces module so environment.py can import Environment.
31
+ if "openenv.core.env_server.interfaces" not in sys.modules:
32
+ _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")
33
+
34
+ class _Environment:
35
+ """Minimal stub matching the openenv-core Environment base class."""
36
+
37
+ def __init__(self) -> None:
38
+ pass
39
+
40
+ def __init_subclass__(cls, **kwargs: object) -> None:
41
+ super().__init_subclass__(**kwargs)
42
+
43
+ @classmethod
44
+ def __class_getitem__(cls, item: object) -> type:
45
+ return cls
46
+
47
+ _interfaces_mod.Environment = _Environment # type: ignore[attr-defined]
48
+ sys.modules["openenv.core.env_server.interfaces"] = _interfaces_mod
49
+
50
+
51
+ from models import HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState
52
+ from server.environment import HelpdeskTicketRoutingEnvironment
53
+ from server.reward import compute_step_reward, compute_trajectory_reward
54
+ from server.tasks import load_dataset
55
+ from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS, TASK_IDS
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Helpers
60
+ # ---------------------------------------------------------------------------
61
+
62
+ def _make_env() -> HelpdeskTicketRoutingEnvironment:
63
+ return HelpdeskTicketRoutingEnvironment()
64
+
65
+
66
+ def _heuristic_action(obs: HelpdeskTicketObservation) -> HelpdeskTicketAction:
67
+ allowed = obs.allowed_fields
68
+ kwargs: dict = {}
69
+ if "issue_type" in allowed:
70
+ kwargs["issue_type"] = ISSUE_TYPES[0]
71
+ if "priority" in allowed:
72
+ kwargs["priority"] = PRIORITIES[0]
73
+ if "assignment_group" in allowed:
74
+ kwargs["assignment_group"] = ASSIGNMENT_GROUPS[0]
75
+ if "resolution_action" in allowed:
76
+ kwargs["resolution_action"] = RESOLUTION_ACTIONS[0]
77
+ return HelpdeskTicketAction(**kwargs)
78
+
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # 9.1 — Inference single-task mode
82
+ # ---------------------------------------------------------------------------
83
+
84
+ def _get_tasks_to_run_impl(task_id_env: str | None, available_tasks: dict) -> list[int]:
85
+ """
86
+ Standalone re-implementation of inference.get_tasks_to_run() logic for testing.
87
+
88
+ This mirrors the logic in inference.py without importing the full module
89
+ (which has heavy dependencies like openai, httpx, and client.py).
90
+ """
91
+ if task_id_env:
92
+ try:
93
+ task_id = int(task_id_env)
94
+ except ValueError:
95
+ raise SystemExit(1)
96
+ if task_id not in available_tasks:
97
+ return []
98
+ return [task_id]
99
+ return list(TASK_IDS)
100
+
101
+
102
+ class TestInferenceSingleTaskMode(unittest.TestCase):
103
+ """9.1 — get_tasks_to_run() respects TASK_ID env var."""
104
+
105
+ def test_task_id_set_to_valid_id_returns_single_element_list(self) -> None:
106
+ available = {1: {}, 2: {}, 3: {}}
107
+ result = _get_tasks_to_run_impl("1", available)
108
+ self.assertEqual(result, [1])
109
+
110
+ def test_task_id_set_to_unavailable_id_returns_empty_list(self) -> None:
111
+ available = {1: {}, 2: {}, 3: {}}
112
+ result = _get_tasks_to_run_impl("999", available)
113
+ self.assertEqual(result, [])
114
+
115
+ def test_task_id_unset_returns_all_task_ids(self) -> None:
116
+ available = {1: {}, 2: {}, 3: {}}
117
+ result = _get_tasks_to_run_impl(None, available)
118
+ self.assertEqual(sorted(result), sorted(list(TASK_IDS)))
119
+
120
+ def test_task_id_set_to_2_returns_only_task_2(self) -> None:
121
+ available = {1: {}, 2: {}, 3: {}}
122
+ result = _get_tasks_to_run_impl("2", available)
123
+ self.assertEqual(result, [2])
124
+
125
+ def test_task_id_set_to_3_returns_only_task_3(self) -> None:
126
+ available = {1: {}, 2: {}, 3: {}}
127
+ result = _get_tasks_to_run_impl("3", available)
128
+ self.assertEqual(result, [3])
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # 9.2 — State has last_step_reward and done after step()
133
+ # ---------------------------------------------------------------------------
134
+
135
+ class TestStateHasRewardAndDone(unittest.TestCase):
136
+ """9.2 — state.last_step_reward and state.done are set after step()."""
137
+
138
+ def test_last_step_reward_is_none_after_reset(self) -> None:
139
+ env = _make_env()
140
+ env.reset(seed=42, task_id=1)
141
+ self.assertIsNone(env.state.last_step_reward)
142
+
143
+ def test_done_is_false_after_reset(self) -> None:
144
+ env = _make_env()
145
+ env.reset(seed=42, task_id=1)
146
+ self.assertFalse(env.state.done)
147
+
148
+ def test_last_step_reward_set_after_step(self) -> None:
149
+ env = _make_env()
150
+ obs = env.reset(seed=42, task_id=1)
151
+ action = _heuristic_action(obs)
152
+ env.step(action)
153
+ state = env.state
154
+ self.assertIsNotNone(state.last_step_reward)
155
+ self.assertGreaterEqual(state.last_step_reward, 0.0)
156
+ self.assertLessEqual(state.last_step_reward, 1.0)
157
+
158
+ def test_done_is_true_after_last_ticket(self) -> None:
159
+ env = _make_env()
160
+ obs = env.reset(seed=42, task_id=1)
161
+ while not obs.done:
162
+ obs = env.step(_heuristic_action(obs))
163
+ self.assertTrue(env.state.done)
164
+
165
+ def test_done_is_false_before_last_ticket(self) -> None:
166
+ env = _make_env()
167
+ obs = env.reset(seed=42, task_id=1)
168
+ if obs.queue_size > 1:
169
+ obs = env.step(_heuristic_action(obs))
170
+ self.assertFalse(env.state.done)
171
+
172
+
173
+ # ---------------------------------------------------------------------------
174
+ # 9.3 — History entry contains title and predicted
175
+ # ---------------------------------------------------------------------------
176
+
177
+ class TestHistoryHasTitleAndPredicted(unittest.TestCase):
178
+ """9.3 — observation.history[0] contains 'title' and 'predicted' keys."""
179
+
180
+ def test_history_entry_has_title(self) -> None:
181
+ env = _make_env()
182
+ obs = env.reset(seed=42, task_id=1)
183
+ action = _heuristic_action(obs)
184
+ obs2 = env.step(action)
185
+ self.assertEqual(len(obs2.history), 1)
186
+ self.assertIn("title", obs2.history[0])
187
+ self.assertIsInstance(obs2.history[0]["title"], str)
188
+ self.assertTrue(obs2.history[0]["title"]) # non-empty
189
+
190
+ def test_history_entry_has_predicted(self) -> None:
191
+ env = _make_env()
192
+ obs = env.reset(seed=42, task_id=1)
193
+ action = _heuristic_action(obs)
194
+ obs2 = env.step(action)
195
+ self.assertIn("predicted", obs2.history[0])
196
+ self.assertIsInstance(obs2.history[0]["predicted"], dict)
197
+
198
+ def test_history_predicted_matches_action(self) -> None:
199
+ env = _make_env()
200
+ obs = env.reset(seed=42, task_id=1)
201
+ action = _heuristic_action(obs)
202
+ obs2 = env.step(action)
203
+ predicted = obs2.history[0]["predicted"]
204
+ action_dict = action.model_dump(exclude_none=True)
205
+ self.assertEqual(predicted, action_dict)
206
+
207
+ def test_history_entry_has_ticket_id_and_score(self) -> None:
208
+ env = _make_env()
209
+ obs = env.reset(seed=42, task_id=1)
210
+ obs2 = env.step(_heuristic_action(obs))
211
+ entry = obs2.history[0]
212
+ self.assertIn("ticket_id", entry)
213
+ self.assertIn("score", entry)
214
+
215
+
216
+ # ---------------------------------------------------------------------------
217
+ # 9.4 — Milestone reward shaping
218
+ # ---------------------------------------------------------------------------
219
+
220
+ class TestMilestoneRewardShaping(unittest.TestCase):
221
+ """9.4 — compute_step_reward applies bonus at high scores, penalty at low scores."""
222
+
223
+ def test_high_score_gets_bonus(self) -> None:
224
+ # score=0.9 >= 0.8 threshold → base=0.9, bonus=0.05 → 0.95
225
+ result = compute_step_reward(0.9)
226
+ self.assertAlmostEqual(result, 0.95, places=9)
227
+
228
+ def test_low_score_gets_penalty(self) -> None:
229
+ # score=0.1 < 0.2 threshold → base=0.1, penalty=0.05 → 0.05
230
+ result = compute_step_reward(0.1)
231
+ self.assertAlmostEqual(result, 0.05, places=9)
232
+
233
+ def test_mid_score_is_neutral(self) -> None:
234
+ # score=0.5 is in [0.2, 0.8) → no shaping → 0.5
235
+ result = compute_step_reward(0.5)
236
+ self.assertAlmostEqual(result, 0.5, places=9)
237
+
238
+ def test_boundary_high_threshold_gets_bonus(self) -> None:
239
+ # score=0.8 exactly → bonus applies → 0.85
240
+ result = compute_step_reward(0.8)
241
+ self.assertAlmostEqual(result, 0.85, places=9)
242
+
243
+ def test_boundary_low_threshold_is_neutral(self) -> None:
244
+ # score=0.2 exactly → not < 0.2, so neutral → 0.2
245
+ result = compute_step_reward(0.2)
246
+ self.assertAlmostEqual(result, 0.2, places=9)
247
+
248
+ def test_reward_clamped_to_unit_interval(self) -> None:
249
+ # score=1.0 → base=1.0, bonus would push to 1.05 → clamped to 1.0
250
+ result = compute_step_reward(1.0)
251
+ self.assertLessEqual(result, 1.0)
252
+ self.assertGreaterEqual(result, 0.0)
253
+
254
+ def test_zero_score_clamped_to_zero(self) -> None:
255
+ # score=0.0 < 0.2 → base=0.0, penalty → max(0.0, -0.05) = 0.0
256
+ result = compute_step_reward(0.0)
257
+ self.assertGreaterEqual(result, 0.0)
258
+
259
+
260
+ # ---------------------------------------------------------------------------
261
+ # 9.5 — Trajectory reward has no overshoot penalty
262
+ # ---------------------------------------------------------------------------
263
+
264
+ class TestTrajectoryRewardNoOvershoot(unittest.TestCase):
265
+ """9.5 — compute_trajectory_reward does not penalise when steps > queue_size."""
266
+
267
+ def test_no_penalty_when_steps_exceed_queue_size(self) -> None:
268
+ scores = [0.8, 0.9, 0.7]
269
+ queue_size = 3
270
+ steps_taken = 10 # more steps than queue_size
271
+ result = compute_trajectory_reward(scores, queue_size, steps_taken)
272
+ expected_avg = sum(scores) / len(scores)
273
+ self.assertAlmostEqual(result, expected_avg, places=9)
274
+
275
+ def test_result_equals_average_regardless_of_steps(self) -> None:
276
+ scores = [0.5, 0.6]
277
+ for steps in [1, 2, 5, 100]:
278
+ result = compute_trajectory_reward(scores, len(scores), steps)
279
+ self.assertAlmostEqual(result, 0.55, places=9,
280
+ msg=f"Failed for steps={steps}")
281
+
282
+ def test_empty_scores_returns_zero(self) -> None:
283
+ self.assertEqual(compute_trajectory_reward([], 3, 3), 0.0)
284
+
285
+ def test_result_in_unit_interval(self) -> None:
286
+ scores = [0.9, 1.0, 0.95]
287
+ result = compute_trajectory_reward(scores, 3, 3)
288
+ self.assertGreaterEqual(result, 0.0)
289
+ self.assertLessEqual(result, 1.0)
290
+
291
+
292
+ # ---------------------------------------------------------------------------
293
+ # 9.6 — ambiguity_note appears in current_ticket observation
294
+ # ---------------------------------------------------------------------------
295
+
296
+ class TestAmbiguityNoteInObservation(unittest.TestCase):
297
+ """9.6 — current_ticket includes ambiguity_note when the ticket has one."""
298
+
299
+ def _find_seed_with_ambiguity_note(self, task_id: int = 3) -> int | None:
300
+ """Try seeds 0..999 to find one where the first ticket has ambiguity_note."""
301
+ env = _make_env()
302
+ for seed in range(1000):
303
+ obs = env.reset(seed=seed, task_id=task_id)
304
+ if obs.current_ticket and obs.current_ticket.get("ambiguity_note"):
305
+ return seed
306
+ return None
307
+
308
+ def test_ambiguity_note_present_when_ticket_has_one(self) -> None:
309
+ """Force a ticket with ambiguity_note by patching the dataset."""
310
+ from unittest.mock import patch
311
+ from server.tasks import load_dataset
312
+
313
+ dataset = load_dataset()
314
+ # Find a ticket with ambiguity_note
315
+ ambiguous_tickets = [t for t in dataset if t.ambiguity_note is not None]
316
+ self.assertGreater(len(ambiguous_tickets), 0, "No tickets with ambiguity_note in dataset")
317
+
318
+ target = ambiguous_tickets[0]
319
+
320
+ env = _make_env()
321
+ # Patch the dataset to only contain the ambiguous ticket
322
+ with patch.object(env, "_dataset", [target]):
323
+ obs = env.reset(seed=0, task_id=3)
324
+
325
+ self.assertIsNotNone(obs.current_ticket)
326
+ self.assertIn("ambiguity_note", obs.current_ticket)
327
+ self.assertEqual(obs.current_ticket["ambiguity_note"], target.ambiguity_note)
328
+
329
+ def test_ambiguity_note_absent_when_ticket_has_none(self) -> None:
330
+ """Tickets without ambiguity_note should not expose the key."""
331
+ from unittest.mock import patch
332
+ from server.tasks import load_dataset
333
+
334
+ dataset = load_dataset()
335
+ non_ambiguous = [t for t in dataset if t.ambiguity_note is None]
336
+ self.assertGreater(len(non_ambiguous), 0)
337
+
338
+ target = non_ambiguous[0]
339
+ env = _make_env()
340
+ with patch.object(env, "_dataset", [target]):
341
+ obs = env.reset(seed=0, task_id=3)
342
+
343
+ self.assertIsNotNone(obs.current_ticket)
344
+ self.assertNotIn("ambiguity_note", obs.current_ticket)
345
+
346
def test_tkt_nondefault_001_has_ambiguity_note(self) -> None:
    """TKT-NONDEFAULT-001 has ambiguity_note set and the observation exposes it.

    The environment's ``_dataset`` attribute is patched to contain only that
    ticket while ``reset(seed=0, task_id=3)`` runs, then the resulting
    observation's ``current_ticket`` is checked for the note.
    """
    from unittest.mock import patch
    from server.tasks import load_dataset

    dataset = load_dataset()
    ticket = next((t for t in dataset if t.ticket_id == "TKT-NONDEFAULT-001"), None)
    self.assertIsNotNone(ticket, "TKT-NONDEFAULT-001 not found in dataset")
    self.assertIsNotNone(ticket.ambiguity_note)

    env = _make_env()
    with patch.object(env, "_dataset", [ticket]):
        obs = env.reset(seed=0, task_id=3)

    # Guard first: assertIn against None raises TypeError, which unittest
    # reports as a test *error* rather than a clean assertion failure.
    self.assertIsNotNone(obs.current_ticket)
    self.assertIn("ambiguity_note", obs.current_ticket)
    # The exposed value must be this ticket's own note, not merely present.
    self.assertEqual(obs.current_ticket["ambiguity_note"], ticket.ambiguity_note)
361
+
362
+
363
+ # ---------------------------------------------------------------------------
364
+ # 9.7 — Dataset has >= 3 non-default routing tickets
365
+ # ---------------------------------------------------------------------------
366
+
367
class TestDatasetNonDefaultRouting(unittest.TestCase):
    """9.7 — The dataset holds three or more tickets whose assignment_group is not the mapped default."""

    def test_at_least_three_nondefault_routing_tickets(self) -> None:
        """Count tickets whose group differs from ISSUE_TYPE_TO_ASSIGNMENT_GROUP for their issue_type."""
        from vocabulary import ISSUE_TYPE_TO_ASSIGNMENT_GROUP

        non_default = []
        for ticket in load_dataset():
            mapped_group = ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(ticket.issue_type)
            if ticket.assignment_group != mapped_group:
                non_default.append(ticket)

        # Listing the offending tickets makes a failure self-explanatory.
        listing = str(
            [(t.ticket_id, t.issue_type, t.assignment_group) for t in non_default]
        )
        failure_message = (
            f"Expected >= 3 non-default routing tickets, found {len(non_default)}: "
            + listing
        )
        self.assertGreaterEqual(len(non_default), 3, failure_message)

    def test_tkt_nondefault_tickets_exist(self) -> None:
        """Each of the three TKT-NONDEFAULT ticket ids appears in the loaded dataset."""
        known_ids = {entry.ticket_id for entry in load_dataset()}
        required_ids = (
            "TKT-NONDEFAULT-001",
            "TKT-NONDEFAULT-002",
            "TKT-NONDEFAULT-003",
        )
        for expected_id in required_ids:
            self.assertIn(expected_id, known_ids, f"{expected_id} not found in dataset")
389
+
390
+
391
+ # ---------------------------------------------------------------------------
392
+ # 9.9 — SUPPORTS_CONCURRENT_SESSIONS is True
393
+ # ---------------------------------------------------------------------------
394
+
395
class TestConcurrentSessionsFlag(unittest.TestCase):
    """9.9 — The environment class declares SUPPORTS_CONCURRENT_SESSIONS as True."""

    def test_supports_concurrent_sessions_is_true(self) -> None:
        """The class attribute evaluates as truthy."""
        env_cls = HelpdeskTicketRoutingEnvironment
        self.assertTrue(env_cls.SUPPORTS_CONCURRENT_SESSIONS)

    def test_flag_is_boolean_true(self) -> None:
        """The class attribute is the ``True`` singleton, not just a truthy value."""
        self.assertIs(
            HelpdeskTicketRoutingEnvironment.SUPPORTS_CONCURRENT_SESSIONS, True
        )
404
+
405
+
406
+ # ---------------------------------------------------------------------------
407
+ # 9.10 — GET /web returns 200 with HTML content
408
+ # ---------------------------------------------------------------------------
409
+
410
def _build_web_test_app():
    """Create a FastAPI app that serves nothing but a GET /web HTML page.

    The page shows APP_ENV_NAME as title and heading, plus a table with one
    row (id, name, difficulty) for every entry in TASKS.
    """
    from fastapi import FastAPI
    from fastapi.responses import HTMLResponse
    from server.tasks import TASKS
    from vocabulary import APP_ENV_NAME

    test_app = FastAPI()

    @test_app.get("/web", response_class=HTMLResponse)
    def web_ui():
        # Render one <tr> per task, then glue the rows together.
        row_fragments = [
            f"<tr><td>{task['id']}</td><td>{task['name']}</td><td>{task['difficulty']}</td></tr>"
            for task in TASKS.values()
        ]
        rows = "".join(row_fragments)
        page = f"""<!DOCTYPE html>
<html><head><title>{APP_ENV_NAME}</title></head>
<body>
<h1>{APP_ENV_NAME}</h1>
<p>Version: 0.1.0 | <a href="/health">Health</a> | <a href="/docs">API Docs</a></p>
<h2>Tasks</h2>
<table border="1"><tr><th>ID</th><th>Name</th><th>Difficulty</th></tr>
{rows}
</table>
</body></html>"""
        return HTMLResponse(content=page)

    return test_app
438
+
439
+
440
class TestWebUIEndpoint(unittest.TestCase):
    """9.10 — GET /web returns HTTP 200 with HTML content.

    One TestClient, built around the app from ``_build_web_test_app()``, is
    shared by every test in the class and closed once the class has finished.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared test client once for the whole class."""
        from starlette.testclient import TestClient
        app = _build_web_test_app()
        cls.client = TestClient(app)

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the shared client so it is not left open after the tests."""
        cls.client.close()

    def test_web_returns_200(self) -> None:
        """The endpoint answers with status code 200."""
        response = self.client.get("/web")
        self.assertEqual(response.status_code, 200)

    def test_web_returns_html_content_type(self) -> None:
        """The Content-Type header announces text/html."""
        response = self.client.get("/web")
        # Substring match because the header may carry extra parameters.
        self.assertIn("text/html", response.headers.get("content-type", ""))

    def test_web_response_contains_html_tag(self) -> None:
        """The body contains the HTML doctype declaration."""
        response = self.client.get("/web")
        self.assertIn("<!DOCTYPE html>", response.text)

    def test_web_response_contains_env_name(self) -> None:
        """The body mentions the environment name from the vocabulary module."""
        from vocabulary import APP_ENV_NAME
        response = self.client.get("/web")
        self.assertIn(APP_ENV_NAME, response.text)
465
+
466
+
467
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
tests/test_extra_fields_penalty.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for action field validation (Task 4) in HelpdeskTicketRoutingEnvironment.step().
3
+
4
+ Validates Requirement 7: Step Validates Action Fields Against Task Contract.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import sys
9
+ import os
10
+ import unittest
11
+ import types as _types
12
+
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
+
15
+ import openenv_test_stubs # noqa: F401
16
+
17
# Register a minimal stand-in module under the name
# "openenv.core.env_server.interfaces", but only when no module of that name
# is already present in sys.modules.
# NOTE(review): presumably this lets the project imports below resolve
# `Environment` when the real openenv package is not installed — confirm
# against server/environment.py.
if "openenv.core.env_server.interfaces" not in sys.modules:
    _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")

    class _Environment:
        # Bare placeholder class; published below as `Environment`.
        def __init__(self) -> None:
            pass

        def __init_subclass__(cls, **kwargs: object) -> None:
            super().__init_subclass__(**kwargs)

        @classmethod
        def __class_getitem__(cls, item: object) -> type:
            # Subscripting the class (Environment[...]) returns the class
            # itself; the subscript argument is ignored.
            return cls

    _interfaces_mod.Environment = _Environment  # type: ignore[attr-defined]
    sys.modules["openenv.core.env_server.interfaces"] = _interfaces_mod
33
+
34
+ from models import HelpdeskTicketAction, HelpdeskTicketObservation
35
+ from server.environment import HelpdeskTicketRoutingEnvironment
36
+ from server.tasks import TASKS
37
+ from vocabulary import ISSUE_TYPES, PRIORITIES, ASSIGNMENT_GROUPS, RESOLUTION_ACTIONS
38
+
39
+
40
def _make_env() -> HelpdeskTicketRoutingEnvironment:
    """Create a HelpdeskTicketRoutingEnvironment via its no-argument constructor."""
    environment = HelpdeskTicketRoutingEnvironment()
    return environment
42
+
43
+
44
class TestExtraFieldsPenalty(unittest.TestCase):
    """Requirement 7: step() penalizes actions with fields outside the task's allowed_fields.

    Every test resets a fresh environment with ``seed=42, task_id=1`` and then
    submits either an action limited to the allowed fields or one that also
    sets ``assignment_group``, which the first test asserts is not allowed
    for task 1.
    """

    @staticmethod
    def _action_within(allowed_fields) -> HelpdeskTicketAction:
        """Build an action that sets issue_type/priority only where allowed_fields lists them."""
        action_kwargs = {}
        if "issue_type" in allowed_fields:
            action_kwargs["issue_type"] = ISSUE_TYPES[0]
        if "priority" in allowed_fields:
            action_kwargs["priority"] = PRIORITIES[0]
        return HelpdeskTicketAction(**action_kwargs)

    @staticmethod
    def _action_with_extra_field() -> HelpdeskTicketAction:
        """Build an action that sets issue_type plus the disallowed assignment_group."""
        return HelpdeskTicketAction(
            issue_type=ISSUE_TYPES[0],
            assignment_group=ASSIGNMENT_GROUPS[0],  # extra field for task 1
        )

    def test_extra_fields_returns_reward_zero(self) -> None:
        """Task 1 only allows issue_type and priority; submitting assignment_group triggers penalty."""
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)

        # Task 1 allowed_fields should NOT include assignment_group
        self.assertNotIn("assignment_group", obs.allowed_fields)

        # Submit an action with an extra field (assignment_group) not in task 1's allowed_fields
        action = HelpdeskTicketAction(
            issue_type=ISSUE_TYPES[0],
            priority=PRIORITIES[0],
            assignment_group=ASSIGNMENT_GROUPS[0],  # extra field
        )
        penalty_obs = env.step(action)

        self.assertIsInstance(penalty_obs, HelpdeskTicketObservation)
        self.assertEqual(penalty_obs.reward, 0.0)

    def test_extra_fields_advances_ticket_index(self) -> None:
        """Penalty step must advance tickets_processed by 1."""
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)
        self.assertEqual(obs.tickets_processed, 0)

        penalty_obs = env.step(self._action_with_extra_field())

        self.assertEqual(penalty_obs.tickets_processed, 1)

    def test_extra_fields_records_score_zero(self) -> None:
        """per_ticket_scores must contain 0.0 after a penalty step."""
        env = _make_env()
        env.reset(seed=42, task_id=1)

        env.step(self._action_with_extra_field())

        state = env.state
        self.assertEqual(len(state.per_ticket_scores), 1)
        self.assertEqual(state.per_ticket_scores[0], 0.0)

    def test_extra_fields_history_entry_has_penalty_reason(self) -> None:
        """History entry for a penalty step must include penalty_reason."""
        env = _make_env()
        env.reset(seed=42, task_id=1)

        penalty_obs = env.step(self._action_with_extra_field())

        self.assertEqual(len(penalty_obs.history), 1)
        entry = penalty_obs.history[0]
        self.assertIn("penalty_reason", entry)
        # The reason is expected to name the offending field.
        self.assertIn("assignment_group", entry["penalty_reason"])
        self.assertEqual(entry["score"], 0.0)

    def test_no_extra_fields_grades_normally(self) -> None:
        """When action fields are within allowed_fields, grading proceeds normally (reward != forced 0.0)."""
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)

        # Build action using only allowed fields
        result_obs = env.step(self._action_within(obs.allowed_fields))

        # Should be a valid observation; reward may be any value in [0.0, 1.0]
        self.assertIsInstance(result_obs, HelpdeskTicketObservation)
        self.assertIsNotNone(result_obs.reward)
        # No penalty_reason in history
        self.assertEqual(len(result_obs.history), 1)
        self.assertNotIn("penalty_reason", result_obs.history[0])

    def test_extra_fields_no_exception_raised(self) -> None:
        """Requirement 7.4: extra fields must not raise an unhandled exception."""
        env = _make_env()
        env.reset(seed=42, task_id=1)

        action = HelpdeskTicketAction(
            issue_type=ISSUE_TYPES[0],
            priority=PRIORITIES[0],
            assignment_group=ASSIGNMENT_GROUPS[0],
            resolution_action=RESOLUTION_ACTIONS[0],  # multiple extra fields
        )
        # Convert any exception into an explicit test failure with context.
        try:
            obs = env.step(action)
        except Exception as exc:  # noqa: BLE001
            self.fail(f"step() raised an unexpected exception: {exc}")

        self.assertIsInstance(obs, HelpdeskTicketObservation)

    def test_extra_fields_done_flag_set_correctly_on_last_ticket(self) -> None:
        """When the penalty step is on the last ticket, done must be True."""
        env = _make_env()
        obs = env.reset(seed=42, task_id=1)
        queue_size = obs.queue_size
        # Precondition: with an empty queue the loop below would be skipped
        # silently and the final step would not be a "last ticket" at all.
        self.assertGreater(queue_size, 0, "reset() produced an empty queue")

        # Process all tickets except the last one normally
        for _ in range(queue_size - 1):
            obs = env.step(self._action_within(obs.allowed_fields))

        # Now trigger penalty on the last ticket
        final_obs = env.step(self._action_with_extra_field())

        self.assertTrue(final_obs.done)
        self.assertEqual(final_obs.reward, 0.0)
180
+
181
+
182
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
uv.lock CHANGED
The diff for this file is too large to render. See raw diff