Roopalgn committed on
Commit 67ce1eb · 1 Parent(s): 8ccf96d

Add policy learning loop and strengthen RL-style environment

.gitignore CHANGED
@@ -6,3 +6,7 @@ __pycache__/
6
  .mypy_cache/
7
  .ruff_cache/
8
  build/
6
  .mypy_cache/
7
  .ruff_cache/
8
  build/
9
+ analysis/policy_learning_runs/
10
+ analysis/policy_learning_test/
11
+ analysis/policy_learning_compare_test/
12
+ analysis/policy_learning_runs_smoke/
README.md CHANGED
@@ -38,6 +38,8 @@ The environment models a realistic helpdesk workflow:
38
  4. the grader assigns deterministic credit
39
  5. the environment advances to the next ticket until the queue is complete
40
 
 
 
41
  This domain is useful for OpenEnv because it is operationally realistic, easy to evaluate with typed outputs, and naturally supports a clean easy-to-hard task ladder.
42
 
43
  ## Why This Is A Good Hackathon Domain
@@ -59,6 +61,37 @@ The project uses a queue-based episode model.
59
 
60
  The environment classes and vocabulary are intentionally frozen to keep collaboration and judging simple.
61
 
62
  ## Task Ladder
63
 
64
  | ID | Name | Difficulty | Required Fields | What The Agent Must Do |
@@ -125,6 +158,7 @@ Each observation also includes:
125
  - `task_name`
126
  - `instructions`
127
  - `allowed_fields`
 
128
  - `available_tools`
129
  - `investigation_budget_remaining`
130
  - `last_tool_result`
@@ -133,7 +167,12 @@ Each observation also includes:
133
  - `tickets_after_current`
134
  - `tickets_processed`
135
  - `queue_position`
 
 
136
  - `history`
 
 
 
137
  - standard OpenEnv fields such as `done` and `reward`
138
 
139
  The internal `HelpdeskTicketState` tracks:
@@ -162,6 +201,15 @@ Available tools:
162
 
163
  - `lookup_related_ticket`
164
  - `lookup_requester_history`
165
 
166
  Per-field behavior:
167
 
@@ -190,6 +238,12 @@ Step reward is lightly milestone-shaped: high per-ticket scores get a small bonu
190
 
191
  Final reward also includes a tiny queue-economics penalty only when the agent exceeds the free investigation budget. One investigation per queued ticket is free; extra investigation steps reduce the final reward slightly.
192
 
193
  ## Grounded Scoring
194
 
195
  The grader is intentionally not fuzzy by default.
@@ -343,6 +397,7 @@ Optional target:
343
 
344
  - `ENV_URL`
345
  - default value: `http://localhost:7860`
 
346
  - `TASK_ID`
347
  - `RUN_ALL_TASKS`
348
 
 
38
  4. the grader assigns deterministic credit
39
  5. the environment advances to the next ticket until the queue is complete
40
 
41
+ For hard-task tickets, the environment can now withhold decisive routing context until the agent uses the right investigation tool. That keeps the task from collapsing into one-shot classification and makes tool choice part of the policy.
42
+
43
  This domain is useful for OpenEnv because it is operationally realistic, easy to evaluate with typed outputs, and naturally supports a clean easy-to-hard task ladder.
44
 
45
  ## Why This Is A Good Hackathon Domain
 
61
 
62
  The environment classes and vocabulary are intentionally frozen to keep collaboration and judging simple.
63
 
64
+ ## Lightweight Policy Improvement Loop
65
+
66
+ The repo now includes a small local learning runner in `policy_learning.py`. It does not update model weights, but it does run repeated rollouts over many seeds, log full trajectories, and select the best policy configuration from a discrete candidate set using observed reward.
67
+
68
+ That gives the project a real improvement loop for judge demos:
69
+
70
+ - compare `no_investigation` against `investigate_when_context_hidden`
71
+ - log per-step rewards, feedback summaries, and reward components to JSONL
72
+ - search over small policy variants such as `legacy_single_probe`, `context_chain`, and `hybrid_context`
73
+ - select the best policy on train seeds, then re-evaluate it on holdout seeds
74
+
75
+ Example commands:
76
+
77
+ ```bash
78
+ python policy_learning.py compare --seeds 42-51 --task-ids 1,2,3
79
+ python policy_learning.py search --train-seeds 40-49 --eval-seeds 50-59 --task-ids 1,2,3
80
+ ```
81
+
82
+ Artifacts are written to `analysis/policy_learning_runs/` by default:
83
+
84
+ - `compare_summary.json`
85
+ - `compare_episodes.jsonl`
86
+ - `compare_trajectories.jsonl`
87
+ - `search_summary.json`
88
+ - `search_train_episodes.jsonl`
89
+ - `search_train_trajectories.jsonl`
90
+ - `search_eval_episodes.jsonl`
91
+ - `search_eval_trajectories.jsonl`
92
+
93
+ The default submit policy inside this runner stays deterministic and local. It reuses the repo's heuristic routing logic, so the discrete policy search focuses on investigation behavior and reward-driven policy selection rather than on external LLM latency or API cost.
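The selection loop described in this section (score each candidate on train seeds, keep the best, re-score it on holdout seeds) can be sketched as follows. This is a minimal illustration, not the code in `policy_learning.py`: the `rollout` callable, the `select_policy` helper, and the reward numbers in `fake_rollout` are all assumptions made for the example.

```python
from statistics import mean
from typing import Callable, Iterable

# Hypothetical rollout signature: (policy_name, seed) -> episode reward.
Rollout = Callable[[str, int], float]


def mean_reward(policy: str, seeds: Iterable[int], rollout: Rollout) -> float:
    """Average episode reward for one policy over a set of seeds."""
    return mean(rollout(policy, seed) for seed in seeds)


def select_policy(candidates, train_seeds, eval_seeds, rollout):
    """Pick the best candidate on train seeds, then score it on holdout seeds."""
    train_scores = {
        name: mean_reward(name, train_seeds, rollout) for name in candidates
    }
    best = max(train_scores, key=train_scores.get)
    return {
        "best_policy": best,
        "train_scores": train_scores,
        "eval_score": mean_reward(best, eval_seeds, rollout),
    }


def fake_rollout(policy: str, seed: int) -> float:
    # Stand-in for a real environment episode; these rewards are made up.
    base = {"no_investigation": 0.5, "investigate_when_context_hidden": 0.75}[policy]
    return base + (seed % 2) * 0.125


summary = select_policy(
    ["no_investigation", "investigate_when_context_hidden"],
    train_seeds=range(40, 50),
    eval_seeds=range(50, 60),
    rollout=fake_rollout,
)
print(summary["best_policy"], summary["eval_score"])
```

Keeping the holdout seeds out of the selection step is what lets the final score say something about generalization rather than about the seeds the policy was chosen on.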
94
+
95
  ## Task Ladder
96
 
97
  | ID | Name | Difficulty | Required Fields | What The Agent Must Do |
 
158
  - `task_name`
159
  - `instructions`
160
  - `allowed_fields`
161
+ - `available_action_types`
162
  - `available_tools`
163
  - `investigation_budget_remaining`
164
  - `last_tool_result`
 
167
  - `tickets_after_current`
168
  - `tickets_processed`
169
  - `queue_position`
170
+ - `average_score_so_far`
171
+ - `progress_fraction`
172
  - `history`
173
+ - `last_reward_components`
174
+ - `rubric_reward` on terminal observations
175
+ - `metadata.last_feedback_summary` for compact reward / penalty feedback
176
  - standard OpenEnv fields such as `done` and `reward`
177
 
178
  The internal `HelpdeskTicketState` tracks:
 
201
 
202
  - `lookup_related_ticket`
203
  - `lookup_requester_history`
204
+ - `lookup_internal_routing_note`
205
+
206
+ Hard-task investigation behavior:
207
+
208
+ - some ambiguous and non-default-routing tickets start with redacted descriptions
209
+ - linked-ticket previews and internal routing notes stay hidden until the matching tool is used
210
+ - useful investigation steps return a small positive shaping reward
211
+ - premature hard-task submission can incur a shaping penalty even when the visible text looks plausible
212
+ - terminal `rubric_reward` remains the objective evaluation signal, while per-step `reward` is the denser training signal
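The gating behavior in the list above can be pictured with a small sketch. It assumes a hypothetical `GatedTicket` class and a made-up redaction marker; the real environment may store and reveal hidden context differently.

```python
from __future__ import annotations

from dataclasses import dataclass, field

REDACTED = "[redacted until investigated]"  # hypothetical placeholder text


@dataclass
class GatedTicket:
    """Hypothetical ticket whose extra context is unlocked per tool."""

    description: str
    hidden_context: dict[str, str]  # tool name -> context that tool reveals
    revealed: set[str] = field(default_factory=set)

    def investigate(self, tool_name: str) -> str | None:
        """Reveal the context tied to tool_name; None if the tool has nothing."""
        if tool_name not in self.hidden_context:
            return None
        self.revealed.add(tool_name)
        return self.hidden_context[tool_name]

    def visible_view(self) -> dict[str, str]:
        """What the agent sees: context not yet revealed stays redacted."""
        view = {"description": self.description}
        for tool_name, context in self.hidden_context.items():
            view[tool_name] = context if tool_name in self.revealed else REDACTED
        return view

    def remaining_tools(self) -> list[str]:
        """Tools that would still reveal something new."""
        return [t for t in self.hidden_context if t not in self.revealed]


ticket = GatedTicket(
    description="User cannot complete setup.",
    hidden_context={"lookup_internal_routing_note": "Route to service desk."},
)
before = ticket.visible_view()
ticket.investigate("lookup_internal_routing_note")
after = ticket.visible_view()
```

A policy that submits while `remaining_tools()` is non-empty is acting on the redacted view, which is the situation the shaping penalty is meant to discourage.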
213
 
214
  Per-field behavior:
215
 
 
238
 
239
  Final reward also includes a tiny queue-economics penalty only when the agent exceeds the free investigation budget. One investigation per queued ticket is free; extra investigation steps reduce the final reward slightly.
240
 
241
+ To make the environment more RL-friendly, each observation now also surfaces structured reward telemetry:
242
+
243
+ - `last_reward_components` exposes ticket score, shaped step reward, milestone adjustment, trajectory reward when applicable, and any investigation penalty applied
244
+ - `average_score_so_far` and `progress_fraction` expose trajectory progress without leaking future labels
245
+ - `history` retains the same reward components plus a compact `feedback_summary` string for downstream agents
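The queue-economics penalty described above (one free investigation per queued ticket, a small deduction for each extra step) reduces to simple arithmetic. The sketch below is only an illustration: the function name and the `per_step_penalty` value are assumptions, and the actual penalty size is not stated in this diff.

```python
def investigation_penalty(
    investigation_steps: int,
    queue_size: int,
    per_step_penalty: float = 0.01,  # assumed value, for illustration only
) -> float:
    """Penalty charged for investigation steps beyond the free budget."""
    free_budget = queue_size  # one free investigation per queued ticket
    extra_steps = max(0, investigation_steps - free_budget)
    return extra_steps * per_step_penalty


within_budget = investigation_penalty(investigation_steps=3, queue_size=5)
over_budget = investigation_penalty(investigation_steps=7, queue_size=5)
```

With these assumed numbers, an agent that stays within the budget pays nothing, while two extra investigation steps cost twice the per-step penalty.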
246
+
247
  ## Grounded Scoring
248
 
249
  The grader is intentionally not fuzzy by default.
 
397
 
398
  - `ENV_URL`
399
  - default value: `http://localhost:7860`
400
+ - `SEED`
401
  - `TASK_ID`
402
  - `RUN_ALL_TASKS`
403
 
inference.py CHANGED
@@ -66,13 +66,27 @@ from vocabulary import (
66
  DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
67
  DEFAULT_MODEL_NAME = "<your-active-model>"
68
 
69
  API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
70
  MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)
71
  HF_TOKEN = os.getenv("HF_TOKEN")
72
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
73
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
74
 
75
- SEED = 42
76
  TASK_ID_ENV = os.getenv("TASK_ID")
77
  RUN_ALL_TASKS_ENV = os.getenv("RUN_ALL_TASKS", "").strip().lower() in {
78
  "1",
@@ -94,6 +108,14 @@ if llm_mode_enabled():
94
  llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
95
 
96
 
97
  SYSTEM_PROMPT = """\
98
  You are an expert IT helpdesk ticket routing agent. Given a helpdesk ticket, you must produce a JSON object with the requested fields.
99
 
@@ -103,19 +125,79 @@ Valid values:
103
  - assignment_group: {assignment_groups}
104
  - resolution_action: {resolution_actions}
105
 
106
  Return ONLY valid JSON with the requested fields. No markdown, no explanation.""".format(
107
  issue_types=", ".join(ISSUE_TYPES),
108
  priorities=", ".join(PRIORITIES),
109
  assignment_groups=", ".join(ASSIGNMENT_GROUPS),
110
  resolution_actions=", ".join(RESOLUTION_ACTIONS),
 
111
  )
112
 
113
 
114
- def call_llm(ticket: dict, allowed_fields: list[str], instructions: str) -> dict:
115
- assert llm_client is not None, "LLM client not configured"
116
  ambiguity_note = ticket.get("ambiguity_note")
117
  related_preview = ticket.get("related_ticket_preview") or {}
118
  last_tool_result = ticket.get("last_tool_result")
119
  extra_context_lines: list[str] = []
120
  if ambiguity_note:
121
  extra_context_lines.append(f"Ambiguity note: {ambiguity_note}")
@@ -132,20 +214,53 @@ def call_llm(ticket: dict, allowed_fields: list[str], instructions: str) -> dict
132
  extra_context_lines.append(
133
  "Investigation result: " + json.dumps(last_tool_result, sort_keys=True)
134
  )
135
  extra_context_block = ""
136
  if extra_context_lines:
137
  extra_context_block = "\n" + "\n".join(extra_context_lines)
138
 
139
- user_msg = (
140
  f"Instructions: {instructions}\n\n"
141
  f"Allowed fields: {', '.join(allowed_fields)}\n\n"
142
- f"Title: {ticket['title']}\n"
143
- f"Requester: {ticket['requester']}\n"
144
- f"Description: {ticket['description']}"
145
  f"{extra_context_block}\n\n"
146
  f"Respond with JSON containing ONLY these fields: {', '.join(allowed_fields)}"
147
  )
148
 
149
  response = llm_client.chat.completions.create(
150
  model=MODEL_NAME,
151
  messages=[
@@ -298,6 +413,95 @@ FULFILL_KEYWORDS = (
298
  "mfa enabled",
299
  )
300
 
301
 
302
  def heuristic_priority(text: str) -> str:
303
  if any(word in text for word in CRITICAL_PRIORITY_KEYWORDS):
@@ -323,26 +527,32 @@ def heuristic_resolution_action(text: str, issue_type: str) -> str:
323
  return ISSUE_TYPE_TO_RESOLUTION_ACTION.get(issue_type, "acknowledge")
324
 
325
 
326
- def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
327
- related_preview = ticket.get("related_ticket_preview") or {}
328
- last_tool_result = ticket.get("last_tool_result") or {}
329
- text = " ".join(
330
- [
331
- ticket.get("title", ""),
332
- ticket.get("description", ""),
333
- ticket.get("ambiguity_note", ""),
334
- related_preview.get("title", ""),
335
- related_preview.get("description", ""),
336
- json.dumps(last_tool_result, sort_keys=True),
337
- ]
338
- ).lower()
339
 
 
 
340
  issue_type = "general_inquiry"
341
  for kw, mapped_issue_type in KEYWORD_ISSUE_TYPES.items():
342
  if kw in text:
343
  issue_type = mapped_issue_type
344
  break
 
 
 
 
 
 
 
345
 
 
346
  priority = heuristic_priority(text)
347
  resolution_action = heuristic_resolution_action(text, issue_type)
348
 
@@ -352,14 +562,75 @@ def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
352
  if "priority" in allowed_fields:
353
  result["priority"] = priority
354
  if "assignment_group" in allowed_fields:
355
- result["assignment_group"] = ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(
356
- issue_type, "service_desk"
357
- )
358
  if "resolution_action" in allowed_fields:
359
  result["resolution_action"] = resolution_action
360
  return result
361
 
362
 
363
  def build_action(
364
  ticket: dict, allowed_fields: list[str], instructions: str
365
  ) -> tuple[HelpdeskTicketAction, str, str | None]:
@@ -370,13 +641,50 @@ def build_action(
370
 
371
  try:
372
  llm_dict = call_llm(ticket, allowed_fields, instructions)
373
- candidate = {
374
- field: llm_dict[field]
375
- for field in allowed_fields
376
- if llm_dict.get(field) is not None
377
- }
378
- if not candidate:
379
  raise ValueError("LLM returned no allowed fields")
380
  return HelpdeskTicketAction(**candidate), "llm", None
381
  except Exception as exc:
382
  return (
@@ -389,6 +697,10 @@ def build_action(
389
  def should_investigate(ticket: dict, history: list[dict[str, Any]]) -> tuple[bool, str | None]:
390
  if not ticket:
391
  return False, None
 
 
 
 
392
  current_ticket_id = ticket.get("ticket_id")
393
  already_investigated = any(
394
  entry.get("ticket_id") == current_ticket_id
@@ -408,6 +720,22 @@ def merge_ticket_context(ticket: dict, observation: Any) -> dict:
408
  merged_ticket = dict(ticket)
409
  if getattr(observation, "last_tool_result", None) is not None:
410
  merged_ticket["last_tool_result"] = observation.last_tool_result
411
  return merged_ticket
412
 
413
 
@@ -518,7 +846,12 @@ def run() -> None:
518
  ticket_id=ticket["ticket_id"],
519
  )
520
 
521
- final_reward = task_step_rewards[-1] if task_step_rewards else 0.0
522
  all_results[task_id] = {
523
  "final_reward": final_reward,
524
  "step_count": step_num,
 
66
  DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
67
  DEFAULT_MODEL_NAME = "<your-active-model>"
68
 
69
+
70
+ def _get_int_env(name: str, default: int) -> int:
71
+ raw_value = os.getenv(name)
72
+ if raw_value is None or raw_value.strip() == "":
73
+ return default
74
+ try:
75
+ return int(raw_value)
76
+ except ValueError:
77
+ print(
78
+ f"[WARN] {name}={raw_value!r} is not a valid integer; using {default}.",
79
+ flush=True,
80
+ )
81
+ return default
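To see the fallback behavior of `_get_int_env` in isolation, the helper can be restated and exercised against a few values of `SEED`. The function body mirrors the added lines above; the environment values used here are examples only.

```python
import os


def _get_int_env(name: str, default: int) -> int:
    """Read an integer from the environment, falling back to a default."""
    raw_value = os.getenv(name)
    if raw_value is None or raw_value.strip() == "":
        return default
    try:
        return int(raw_value)
    except ValueError:
        print(
            f"[WARN] {name}={raw_value!r} is not a valid integer; using {default}.",
            flush=True,
        )
        return default


os.environ["SEED"] = "7"
seed_valid = _get_int_env("SEED", 42)  # parsed from the environment

os.environ["SEED"] = "not-a-number"
seed_invalid = _get_int_env("SEED", 42)  # warns and falls back to the default

os.environ["SEED"] = "   "
seed_blank = _get_int_env("SEED", 42)  # blank values are treated as unset
```

Falling back with a warning, rather than raising, keeps a mistyped `SEED` from aborting a run while still making the problem visible in the logs.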
82
+
83
  API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
84
  MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)
85
  HF_TOKEN = os.getenv("HF_TOKEN")
86
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
87
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
88
 
89
+ SEED = _get_int_env("SEED", 42)
90
  TASK_ID_ENV = os.getenv("TASK_ID")
91
  RUN_ALL_TASKS_ENV = os.getenv("RUN_ALL_TASKS", "").strip().lower() in {
92
  "1",
 
108
  llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
109
 
110
 
111
+ RECENT_HISTORY_LIMIT = 2
112
+ ROUTING_PRIORS = "\n".join(
113
+ f"- {issue_type}: assignment_group={ISSUE_TYPE_TO_ASSIGNMENT_GROUP[issue_type]}, "
114
+ f"resolution_action={ISSUE_TYPE_TO_RESOLUTION_ACTION[issue_type]}"
115
+ for issue_type in ISSUE_TYPES
116
+ )
117
+
118
+
119
  SYSTEM_PROMPT = """\
120
  You are an expert IT helpdesk ticket routing agent. Given a helpdesk ticket, you must produce a JSON object with the requested fields.
121
 
 
125
  - assignment_group: {assignment_groups}
126
  - resolution_action: {resolution_actions}
127
 
128
+ Decision rules:
129
+ - Follow this environment's label ontology exactly; do not invent categories.
130
+ - Prefer the primary operational workflow label over a secondary technical symptom.
131
+ - Keep assignment_group and resolution_action consistent with the chosen issue_type unless the ticket explicitly justifies a different choice.
132
+ - Use investigation results and recent evaluation feedback when provided.
133
+
134
+ Domain conventions:
135
+ - Enterprise pricing, quotes, plan comparisons, and commercial procurement requests map to service_request, usually with medium priority.
136
+ - Onboarding work that is blocked by an access problem still maps to onboarding when the primary workflow is onboarding; the assignment_group may still be service_desk if the ticket says onboarding cannot resolve the access issue.
137
+ - Single-user sign-in, login, MFA, or 2FA lockouts map to identity_access and are usually high priority, not critical.
138
+ - Reserve critical priority for outages, widespread business blockers, or explicit urgent critical incidents.
139
+
140
+ Routing priors:
141
+ {routing_priors}
142
+
143
  Return ONLY valid JSON with the requested fields. No markdown, no explanation.""".format(
144
  issue_types=", ".join(ISSUE_TYPES),
145
  priorities=", ".join(PRIORITIES),
146
  assignment_groups=", ".join(ASSIGNMENT_GROUPS),
147
  resolution_actions=", ".join(RESOLUTION_ACTIONS),
148
+ routing_priors=ROUTING_PRIORS,
149
  )
150
 
151
 
152
+ def format_recent_history_entries(
153
+ history: list[dict[str, Any]], limit: int = RECENT_HISTORY_LIMIT
154
+ ) -> str:
155
+ if not history:
156
+ return ""
157
+
158
+ lines = ["Recent evaluation feedback (latest last):"]
159
+ for entry in history[-limit:]:
160
+ predicted = json.dumps(entry.get("predicted", {}), sort_keys=True)
161
+ line = (
162
+ f"- Ticket {entry.get('ticket_id', '?')}: predicted={predicted}, "
163
+ f"score={entry.get('score', 0.0)}"
164
+ )
165
+ feedback_summary = entry.get("feedback_summary")
166
+ if feedback_summary:
167
+ line += f", feedback={feedback_summary}"
168
+ reward = entry.get("reward")
169
+ if reward is not None:
170
+ line += f", reward={reward}"
171
+ rubric_reward = entry.get("rubric_reward")
172
+ if rubric_reward is not None:
173
+ line += f", rubric_reward={rubric_reward}"
174
+ breakdown = entry.get("breakdown") or {}
175
+ if breakdown:
176
+ line += f", breakdown={json.dumps(breakdown, sort_keys=True)}"
177
+ penalty_reason = entry.get("penalty_reason")
178
+ if penalty_reason:
179
+ line += f", penalty_reason={penalty_reason}"
180
+ tool_result = entry.get("tool_result")
181
+ if tool_result is not None:
182
+ line += f", tool_result={json.dumps(tool_result, sort_keys=True)}"
183
+ reward_components = entry.get("reward_components")
184
+ if reward_components:
185
+ line += f", reward_components={json.dumps(reward_components, sort_keys=True)}"
186
+ lines.append(line)
187
+ return "\n".join(lines)
188
+
189
+
190
+ def build_llm_user_message(ticket: dict, allowed_fields: list[str], instructions: str) -> str:
191
  ambiguity_note = ticket.get("ambiguity_note")
192
  related_preview = ticket.get("related_ticket_preview") or {}
193
  last_tool_result = ticket.get("last_tool_result")
194
+ context_status = ticket.get("context_status") or {}
195
+ recent_history = ticket.get("recent_history") or []
196
+ feedback_summary = ticket.get("feedback_summary")
197
+ last_reward_components = ticket.get("last_reward_components") or {}
198
+ investigation_budget_remaining = ticket.get("investigation_budget_remaining")
199
+ average_score_so_far = ticket.get("average_score_so_far")
200
+ progress_fraction = ticket.get("progress_fraction")
201
  extra_context_lines: list[str] = []
202
  if ambiguity_note:
203
  extra_context_lines.append(f"Ambiguity note: {ambiguity_note}")
 
214
  extra_context_lines.append(
215
  "Investigation result: " + json.dumps(last_tool_result, sort_keys=True)
216
  )
217
+ if context_status:
218
+ extra_context_lines.append(
219
+ "Context status: " + json.dumps(context_status, sort_keys=True)
220
+ )
221
+ if feedback_summary:
222
+ extra_context_lines.append(f"Latest environment feedback: {feedback_summary}")
223
+ if last_reward_components:
224
+ extra_context_lines.append(
225
+ "Latest reward components: "
226
+ + json.dumps(last_reward_components, sort_keys=True)
227
+ )
228
+ recent_history_block = format_recent_history_entries(recent_history)
229
+ if recent_history_block:
230
+ extra_context_lines.append(recent_history_block)
231
+ queue_position = ticket.get("queue_position")
232
+ tickets_remaining = ticket.get("tickets_remaining")
233
+ if queue_position is not None and tickets_remaining is not None:
234
+ extra_context_lines.append(
235
+ f"Queue context: queue_position={queue_position}, tickets_remaining={tickets_remaining}"
236
+ )
237
+ if average_score_so_far is not None:
238
+ extra_context_lines.append(f"Average score so far: {average_score_so_far}")
239
+ if progress_fraction is not None:
240
+ extra_context_lines.append(f"Episode progress: {progress_fraction}")
241
+ if investigation_budget_remaining is not None:
242
+ extra_context_lines.append(
243
+ f"Investigation budget remaining: {investigation_budget_remaining}"
244
+ )
245
  extra_context_block = ""
246
  if extra_context_lines:
247
  extra_context_block = "\n" + "\n".join(extra_context_lines)
248
 
249
+ return (
250
  f"Instructions: {instructions}\n\n"
251
  f"Allowed fields: {', '.join(allowed_fields)}\n\n"
252
+ f"Title: {ticket.get('title', '')}\n"
253
+ f"Requester: {ticket.get('requester', '')}\n"
254
+ f"Description: {ticket.get('description', '')}"
255
  f"{extra_context_block}\n\n"
256
  f"Respond with JSON containing ONLY these fields: {', '.join(allowed_fields)}"
257
  )
258
 
259
+
260
+ def call_llm(ticket: dict, allowed_fields: list[str], instructions: str) -> dict:
261
+ assert llm_client is not None, "LLM client not configured"
262
+ user_msg = build_llm_user_message(ticket, allowed_fields, instructions)
263
+
264
  response = llm_client.chat.completions.create(
265
  model=MODEL_NAME,
266
  messages=[
 
413
  "mfa enabled",
414
  )
415
 
416
+ PRICING_REQUEST_KEYWORDS = (
417
+ "pricing breakdown",
418
+ "enterprise tier pricing",
419
+ "enterprise plan",
420
+ "compare your enterprise plan",
421
+ "comparing your enterprise plan",
422
+ "quote",
423
+ "pricing quote",
424
+ "commercial proposal",
425
+ "vendor comparison",
426
+ )
427
+
428
+ ONBOARDING_WORKFLOW_KEYWORDS = (
429
+ "onboarding",
430
+ "new hire",
431
+ "contractor",
432
+ "provisioned",
433
+ "kickoff onboarding",
434
+ )
435
+
436
+ ACCESS_BLOCKER_KEYWORDS = (
437
+ "access issue",
438
+ "permissions error",
439
+ "permission error",
440
+ "account access is blocked",
441
+ "cannot sign in",
442
+ "can't sign in",
443
+ "locked",
444
+ "2fa",
445
+ "mfa",
446
+ )
447
+
448
+ SERVICE_DESK_ONBOARDING_ESCALATION_KEYWORDS = (
449
+ "onboarding team cannot resolve access issues",
450
+ "routing to service desk",
451
+ "route to service desk",
452
+ "service desk",
453
+ )
454
+
455
+ CRITICAL_INCIDENT_KEYWORDS = (
456
+ "outage",
457
+ "company-wide",
458
+ "all users",
459
+ "widespread",
460
+ "production down",
461
+ "critical incident",
462
+ "sev1",
463
+ )
464
+
465
+ HIGH_PRIORITY_SIGNAL_KEYWORDS = (
466
+ "locked",
467
+ "blocked",
468
+ "cannot sign in",
469
+ "can't sign in",
470
+ "2fa",
471
+ "mfa",
472
+ "expedite",
473
+ "start monday",
474
+ "asap",
475
+ "today",
476
+ "eod",
477
+ "urgent",
478
+ )
479
+
480
+ TIME_SENSITIVE_PRIORITY_KEYWORDS = (
481
+ "expedite",
482
+ "start monday",
483
+ "today",
484
+ "asap",
485
+ "eod",
486
+ "urgent",
487
+ "immediately",
488
+ )
489
+
490
+
491
+ def build_routing_text(ticket: dict) -> str:
492
+ related_preview = ticket.get("related_ticket_preview") or {}
493
+ last_tool_result = ticket.get("last_tool_result") or {}
494
+ return " ".join(
495
+ [
496
+ ticket.get("title", ""),
497
+ ticket.get("description", ""),
498
+ ticket.get("ambiguity_note", ""),
499
+ related_preview.get("title", ""),
500
+ related_preview.get("description", ""),
501
+ json.dumps(last_tool_result, sort_keys=True),
502
+ ]
503
+ ).lower()
504
+
505
 
506
  def heuristic_priority(text: str) -> str:
507
  if any(word in text for word in CRITICAL_PRIORITY_KEYWORDS):
 
527
  return ISSUE_TYPE_TO_RESOLUTION_ACTION.get(issue_type, "acknowledge")
528
 
529
 
530
+ def heuristic_assignment_group(text: str, issue_type: str) -> str:
531
+ if issue_type == "onboarding":
532
+ if any(keyword in text for keyword in SERVICE_DESK_ONBOARDING_ESCALATION_KEYWORDS):
533
+ return "service_desk"
534
+ if any(keyword in text for keyword in ACCESS_BLOCKER_KEYWORDS) and any(
535
+ keyword in text for keyword in ONBOARDING_WORKFLOW_KEYWORDS
536
+ ):
537
+ return "service_desk"
538
+ return ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(issue_type, "service_desk")
 
 
 
 
539
 
540
+
541
+ def infer_issue_type(text: str) -> str:
542
  issue_type = "general_inquiry"
543
  for kw, mapped_issue_type in KEYWORD_ISSUE_TYPES.items():
544
  if kw in text:
545
  issue_type = mapped_issue_type
546
  break
547
+ return issue_type
548
+
549
+
550
+ def heuristic_action(
551
+ ticket: dict, allowed_fields: list[str], issue_type_override: str | None = None
552
+ ) -> dict:
553
+ text = build_routing_text(ticket)
554
 
555
+ issue_type = issue_type_override or infer_issue_type(text)
556
  priority = heuristic_priority(text)
557
  resolution_action = heuristic_resolution_action(text, issue_type)
558
 
 
562
  if "priority" in allowed_fields:
563
  result["priority"] = priority
564
  if "assignment_group" in allowed_fields:
565
+ result["assignment_group"] = heuristic_assignment_group(text, issue_type)
 
 
566
  if "resolution_action" in allowed_fields:
567
  result["resolution_action"] = resolution_action
568
  return result
569
 
570
 
571
+ def apply_domain_overrides(
572
+ ticket: dict, candidate: dict[str, Any], allowed_fields: list[str]
573
+ ) -> tuple[dict[str, Any], list[str]]:
574
+ updated = dict(candidate)
575
+ reasons: list[str] = []
576
+ text = build_routing_text(ticket)
577
+
578
+ issue_type = updated.get("issue_type")
579
+ if "issue_type" in allowed_fields and issue_type is not None:
580
+ if (
581
+ issue_type in {"billing_license", "general_inquiry"}
582
+ and any(keyword in text for keyword in PRICING_REQUEST_KEYWORDS)
583
+ ):
584
+ updated["issue_type"] = "service_request"
585
+ issue_type = "service_request"
586
+ reasons.append("override_issue_type=service_request(pricing_request)")
587
+ elif (
588
+ issue_type == "identity_access"
589
+ and any(keyword in text for keyword in ONBOARDING_WORKFLOW_KEYWORDS)
590
+ and any(keyword in text for keyword in ACCESS_BLOCKER_KEYWORDS)
591
+ ):
592
+ updated["issue_type"] = "onboarding"
593
+ issue_type = "onboarding"
594
+ reasons.append("override_issue_type=onboarding(onboarding_access_blocker)")
595
+
596
+ if issue_type is not None:
597
+ if "assignment_group" in allowed_fields:
598
+ desired_group = heuristic_assignment_group(text, issue_type)
599
+ if updated.get("assignment_group") != desired_group:
600
+ updated["assignment_group"] = desired_group
601
+ reasons.append(f"override_assignment_group={desired_group}")
602
+ if "resolution_action" in allowed_fields:
603
+ desired_resolution = heuristic_resolution_action(text, issue_type)
604
+ if updated.get("resolution_action") != desired_resolution:
605
+ updated["resolution_action"] = desired_resolution
606
+ reasons.append(f"override_resolution_action={desired_resolution}")
607
+
608
+ if "priority" in allowed_fields and updated.get("priority") is not None:
609
+ priority = updated["priority"]
610
+ has_critical_signal = any(keyword in text for keyword in CRITICAL_INCIDENT_KEYWORDS)
611
+ has_high_signal = any(keyword in text for keyword in HIGH_PRIORITY_SIGNAL_KEYWORDS)
612
+ if priority == "critical" and not has_critical_signal:
613
+ updated["priority"] = "high" if has_high_signal else "medium"
614
+ reasons.append(f"override_priority={updated['priority']}(deescalated_from_critical)")
615
+ elif (
616
+ priority == "high"
617
+ and issue_type in {"service_request", "onboarding"}
618
+ and not any(keyword in text for keyword in TIME_SENSITIVE_PRIORITY_KEYWORDS)
619
+ ):
620
+ updated["priority"] = "medium"
621
+ reasons.append("override_priority=medium(nonurgent_workflow_request)")
622
+ elif (
623
+ priority == "medium"
624
+ and issue_type == "identity_access"
625
+ and any(keyword in text for keyword in ("cannot sign in", "can't sign in", "2fa", "mfa", "locked"))
626
+ and not has_critical_signal
627
+ ):
628
+ updated["priority"] = "high"
629
+ reasons.append("override_priority=high(identity_lockout)")
630
+
631
+ return updated, reasons
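The critical-priority de-escalation branch of `apply_domain_overrides` can be isolated into a small sketch. The keyword tuples are copied from the added constants above; the standalone `deescalate_priority` function is a hypothetical extraction for illustration and does not exist in `inference.py`.

```python
CRITICAL_INCIDENT_KEYWORDS = (
    "outage", "company-wide", "all users", "widespread",
    "production down", "critical incident", "sev1",
)
HIGH_PRIORITY_SIGNAL_KEYWORDS = (
    "locked", "blocked", "cannot sign in", "can't sign in", "2fa", "mfa",
    "expedite", "start monday", "asap", "today", "eod", "urgent",
)


def deescalate_priority(priority: str, text: str) -> str:
    """Downgrade 'critical' when the text carries no critical-incident signal."""
    lowered = text.lower()
    if priority != "critical":
        return priority
    if any(keyword in lowered for keyword in CRITICAL_INCIDENT_KEYWORDS):
        return "critical"
    has_high_signal = any(k in lowered for k in HIGH_PRIORITY_SIGNAL_KEYWORDS)
    return "high" if has_high_signal else "medium"


lockout = deescalate_priority("critical", "My account is locked after a 2FA reset")
outage = deescalate_priority("critical", "Company-wide outage of email")
pricing = deescalate_priority("critical", "Please send a pricing quote")
```

Because the checks are plain substring matches, short keywords such as `eod` can also match inside unrelated words, so this kind of rule works best as a guardrail on top of a model prediction rather than as the sole classifier.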
632
+
633
+
634
  def build_action(
635
  ticket: dict, allowed_fields: list[str], instructions: str
636
  ) -> tuple[HelpdeskTicketAction, str, str | None]:
 
641
 
642
  try:
643
  llm_dict = call_llm(ticket, allowed_fields, instructions)
644
+ validated_llm_fields: dict[str, Any] = {}
645
+ rejected_fields: list[str] = []
646
+ for field in allowed_fields:
647
+ value = llm_dict.get(field)
648
+ if value is None:
649
+ continue
650
+ try:
651
+ HelpdeskTicketAction(**{field: value})
652
+ except Exception:
653
+ rejected_fields.append(field)
654
+ continue
655
+ validated_llm_fields[field] = value
656
+
657
+ if not validated_llm_fields:
658
  raise ValueError("LLM returned no allowed fields")
659
+
660
+ candidate = heuristic_action(
661
+ ticket,
662
+ allowed_fields,
663
+ issue_type_override=validated_llm_fields.get("issue_type"),
664
+ )
665
+ candidate.update(validated_llm_fields)
666
+ accepted_fields = list(validated_llm_fields)
667
+ candidate, override_reasons = apply_domain_overrides(
668
+ ticket,
669
+ candidate,
670
+ allowed_fields,
671
+ )
672
+
673
+ backfilled_fields = [field for field in allowed_fields if field not in accepted_fields]
674
+ if backfilled_fields or rejected_fields or override_reasons:
675
+ reason_parts = []
676
+ if backfilled_fields:
677
+ reason_parts.append(f"heuristic_backfill={backfilled_fields}")
678
+ if rejected_fields:
679
+ reason_parts.append(f"invalid_llm_fields={rejected_fields}")
680
+ if override_reasons:
681
+ reason_parts.append(f"domain_overrides={override_reasons}")
682
+ return (
683
+ HelpdeskTicketAction(**candidate),
684
+ "llm_backfilled",
685
+ "; ".join(reason_parts),
686
+ )
687
+
688
  return HelpdeskTicketAction(**candidate), "llm", None
689
  except Exception as exc:
690
  return (
 
697
  def should_investigate(ticket: dict, history: list[dict[str, Any]]) -> tuple[bool, str | None]:
698
  if not ticket:
699
  return False, None
700
+ context_status = ticket.get("context_status") or {}
701
+ remaining_tools = context_status.get("remaining_tools") or []
702
+ if remaining_tools:
703
+ return True, str(remaining_tools[0])
704
  current_ticket_id = ticket.get("ticket_id")
705
  already_investigated = any(
706
  entry.get("ticket_id") == current_ticket_id
 
720
  merged_ticket = dict(ticket)
721
  if getattr(observation, "last_tool_result", None) is not None:
722
  merged_ticket["last_tool_result"] = observation.last_tool_result
723
+ merged_ticket["recent_history"] = list(getattr(observation, "history", []))
724
+ merged_ticket["queue_position"] = getattr(observation, "queue_position", None)
725
+ merged_ticket["tickets_remaining"] = getattr(observation, "tickets_remaining", None)
726
+ merged_ticket["investigation_budget_remaining"] = getattr(
727
+ observation,
728
+ "investigation_budget_remaining",
729
+ None,
730
+ )
731
+ merged_ticket["average_score_so_far"] = getattr(observation, "average_score_so_far", None)
732
+ merged_ticket["progress_fraction"] = getattr(observation, "progress_fraction", None)
733
+ merged_ticket["last_reward_components"] = dict(
734
+ getattr(observation, "last_reward_components", {}) or {}
735
+ )
736
+ observation_metadata = getattr(observation, "metadata", {}) or {}
737
+ if observation_metadata.get("last_feedback_summary"):
738
+ merged_ticket["feedback_summary"] = observation_metadata["last_feedback_summary"]
739
  return merged_ticket
740
 
741
 
 
846
  ticket_id=ticket["ticket_id"],
847
  )
848
 
849
+ final_rubric_reward = getattr(obs, "rubric_reward", None)
850
+ final_reward = (
851
+ float(final_rubric_reward)
852
+ if final_rubric_reward is not None
853
+ else (task_step_rewards[-1] if task_step_rewards else 0.0)
854
+ )
855
  all_results[task_id] = {
856
  "final_reward": final_reward,
857
  "step_count": step_num,
models.py CHANGED
@@ -18,6 +18,7 @@ ASSIGNMENT_GROUP_SET = set(ASSIGNMENT_GROUPS)
18
  RESOLUTION_ACTION_SET = set(RESOLUTION_ACTIONS)
19
  ACTION_TYPE_SET = {"submit", "investigate"}
20
  TOOL_NAME_SET = {"lookup_related_ticket", "lookup_requester_history"}
 
21
 
22
 
23
  def _validate_choice(value: str, allowed: set[str], field_name: str) -> str:
@@ -113,6 +114,7 @@ class HelpdeskTicketObservation(Observation):
113
  task_name: str = ""
114
  instructions: str = ""
115
  allowed_fields: list[str] = Field(default_factory=list)
 
116
  available_tools: list[str] = Field(default_factory=list)
117
  investigation_budget_remaining: int = 0
118
  last_tool_result: Optional[dict[str, Any]] = None
@@ -122,7 +124,11 @@ class HelpdeskTicketObservation(Observation):
122
  tickets_after_current: int = 0
123
  tickets_processed: int = 0
124
  queue_position: int = 0
 
 
125
  history: list[dict[str, Any]] = Field(default_factory=list)
 
 
126
 
127
 
128
  class HelpdeskTicketState(State):
@@ -136,7 +142,11 @@ class HelpdeskTicketState(State):
136
  # `reward` is the field the evaluator checks on GET /state (mentor spec)
137
  reward: Optional[float] = None
138
  done: bool = False
 
139
  investigation_steps: int = 0
140
  investigation_budget_remaining: int = 0
 
141
  last_tool_result: Optional[dict[str, Any]] = None
 
 
142
  history_entries: list[dict] = Field(default_factory=list)
 
18
  RESOLUTION_ACTION_SET = set(RESOLUTION_ACTIONS)
19
  ACTION_TYPE_SET = {"submit", "investigate"}
20
  TOOL_NAME_SET = {"lookup_related_ticket", "lookup_requester_history"}
21
+ TOOL_NAME_SET.add("lookup_internal_routing_note")
22
 
23
 
24
  def _validate_choice(value: str, allowed: set[str], field_name: str) -> str:
 
114
  task_name: str = ""
115
  instructions: str = ""
116
  allowed_fields: list[str] = Field(default_factory=list)
117
+ available_action_types: list[str] = Field(default_factory=list)
118
  available_tools: list[str] = Field(default_factory=list)
119
  investigation_budget_remaining: int = 0
120
  last_tool_result: Optional[dict[str, Any]] = None
 
124
  tickets_after_current: int = 0
125
  tickets_processed: int = 0
126
  queue_position: int = 0
127
+ average_score_so_far: float = 0.0
128
+ progress_fraction: float = 0.0
129
  history: list[dict[str, Any]] = Field(default_factory=list)
130
+ last_reward_components: dict[str, Any] = Field(default_factory=dict)
131
+ rubric_reward: Optional[float] = None
132
 
133
 
134
  class HelpdeskTicketState(State):
 
142
  # `reward` is the field the evaluator checks on GET /state (mentor spec)
143
  reward: Optional[float] = None
144
  done: bool = False
145
+ average_score_so_far: float = 0.0
146
  investigation_steps: int = 0
147
  investigation_budget_remaining: int = 0
148
+ investigation_penalty_applied: float = 0.0
149
  last_tool_result: Optional[dict[str, Any]] = None
150
+ last_reward_components: dict[str, Any] = Field(default_factory=dict)
151
+ ticket_tool_usage: dict[str, list[str]] = Field(default_factory=dict)
152
  history_entries: list[dict] = Field(default_factory=list)
policy_learning.py ADDED
@@ -0,0 +1,723 @@
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import importlib
6
+ import json
7
+ from dataclasses import asdict, dataclass
8
+ from pathlib import Path
9
+ from statistics import mean
10
+ from typing import Any, Callable, Iterable
11
+
12
+ from models import HelpdeskTicketAction, HelpdeskTicketObservation
13
+ from server.environment import HelpdeskTicketRoutingEnvironment
14
+ from server.tasks import get_task_definition
15
+ from vocabulary import TASK_IDS
16
+
17
+
18
+ DEFAULT_COMPARE_POLICIES = (
19
+ "no_investigation",
20
+ "investigate_when_context_hidden",
21
+ )
22
+ DEFAULT_SEARCH_POLICIES = (
23
+ "no_investigation",
24
+ "legacy_single_probe",
25
+ "investigate_when_context_hidden",
26
+ "context_chain",
27
+ "hybrid_context",
28
+ )
29
+ DEFAULT_OUTPUT_DIR = "analysis/policy_learning_runs"
30
+
31
+ SubmitBuilder = Callable[[dict[str, Any], list[str]], HelpdeskTicketAction]
32
+ EnvFactory = Callable[[], HelpdeskTicketRoutingEnvironment]
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class PolicyConfig:
37
+ name: str
38
+ investigate_hidden_context: bool
39
+ investigate_related_ticket_hint: bool
40
+ investigate_ambiguity_history: bool
41
+ max_investigations_per_ticket: int
42
+ description: str
43
+
44
+
45
+ POLICY_LIBRARY: dict[str, PolicyConfig] = {
46
+ "no_investigation": PolicyConfig(
47
+ name="no_investigation",
48
+ investigate_hidden_context=False,
49
+ investigate_related_ticket_hint=False,
50
+ investigate_ambiguity_history=False,
51
+ max_investigations_per_ticket=0,
52
+ description="Always submit immediately and never investigate.",
53
+ ),
54
+ "legacy_single_probe": PolicyConfig(
55
+ name="legacy_single_probe",
56
+ investigate_hidden_context=False,
57
+ investigate_related_ticket_hint=True,
58
+ investigate_ambiguity_history=True,
59
+ max_investigations_per_ticket=1,
60
+ description="Mimics the earlier single-tool hint policy.",
61
+ ),
62
+ "investigate_when_context_hidden": PolicyConfig(
63
+ name="investigate_when_context_hidden",
64
+ investigate_hidden_context=True,
65
+ investigate_related_ticket_hint=False,
66
+ investigate_ambiguity_history=False,
67
+ max_investigations_per_ticket=1,
68
+ description="Investigate once when the environment says context is hidden.",
69
+ ),
70
+ "context_chain": PolicyConfig(
71
+ name="context_chain",
72
+ investigate_hidden_context=True,
73
+ investigate_related_ticket_hint=False,
74
+ investigate_ambiguity_history=False,
75
+ max_investigations_per_ticket=3,
76
+ description="Follow the environment's required-tool chain until context is revealed.",
77
+ ),
78
+ "hybrid_context": PolicyConfig(
79
+ name="hybrid_context",
80
+ investigate_hidden_context=True,
81
+ investigate_related_ticket_hint=True,
82
+ investigate_ambiguity_history=True,
83
+ max_investigations_per_ticket=3,
84
+ description="Use hidden-context signals first, then legacy ambiguity hints.",
85
+ ),
86
+ }
87
+
88
+
89
+ def _dedupe_preserving_order(values: Iterable[int]) -> list[int]:
90
+ seen: set[int] = set()
91
+ ordered: list[int] = []
92
+ for value in values:
93
+ if value in seen:
94
+ continue
95
+ seen.add(value)
96
+ ordered.append(value)
97
+ return ordered
98
+
99
+
100
+ def parse_int_spec(spec: str, *, field_name: str) -> list[int]:
101
+ values: list[int] = []
102
+ for chunk in spec.split(","):
103
+ part = chunk.strip()
104
+ if not part:
105
+ continue
106
+ if "-" in part:
107
+ start_raw, end_raw = part.split("-", 1)
108
+ try:
109
+ start = int(start_raw)
110
+ end = int(end_raw)
111
+ except ValueError as exc:
112
+ raise ValueError(f"{field_name} contains an invalid range: {part!r}") from exc
113
+ if end < start:
114
+ raise ValueError(f"{field_name} range must be ascending: {part!r}")
115
+ values.extend(range(start, end + 1))
116
+ continue
117
+ try:
118
+ values.append(int(part))
119
+ except ValueError as exc:
120
+ raise ValueError(f"{field_name} contains an invalid integer: {part!r}") from exc
121
+ if not values:
122
+ raise ValueError(f"{field_name} must not be empty")
123
+ return _dedupe_preserving_order(values)
124
+
125
+
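The seed and task-ID specs accepted above (for example `42-51` or `42,50,60`) can be illustrated with a standalone sketch. This is not the committed function: it drops the `field_name` parameter and uses `dict.fromkeys` in place of the `_dedupe_preserving_order` helper, both simplifications made for illustration only.

```python
from __future__ import annotations


def parse_int_spec(spec: str) -> list[int]:
    """Expand a spec such as "42-44,50" into [42, 43, 44, 50]."""
    values: list[int] = []
    for chunk in spec.split(","):
        part = chunk.strip()
        if not part:
            continue  # tolerate stray or trailing commas
        if "-" in part:
            # "a-b" is an inclusive, ascending range
            start_raw, end_raw = part.split("-", 1)
            start, end = int(start_raw), int(end_raw)
            if end < start:
                raise ValueError(f"range must be ascending: {part!r}")
            values.extend(range(start, end + 1))
        else:
            values.append(int(part))
    if not values:
        raise ValueError("spec must not be empty")
    # dict.fromkeys drops repeats while keeping first-seen order
    return list(dict.fromkeys(values))


expanded = parse_int_spec("42-44,50")
deduped = parse_int_spec("1,2,2,1-3")
```

Because every `-` is treated as a range separator, negative integers are not expressible in this format; that matches the behavior of the diffed code as far as can be read from it.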
126
+ def parse_task_ids(spec: str) -> list[int]:
127
+ task_ids = parse_int_spec(spec, field_name="task_ids")
128
+ unsupported = [task_id for task_id in task_ids if task_id not in TASK_IDS]
129
+ if unsupported:
130
+ raise ValueError(f"Unsupported task_ids: {unsupported}")
131
+ return task_ids
132
+
133
+
134
+ def resolve_policies(spec: str) -> list[PolicyConfig]:
135
+ names = [name.strip() for name in spec.split(",") if name.strip()]
136
+ if not names:
137
+ raise ValueError("At least one policy must be specified")
138
+ policies: list[PolicyConfig] = []
139
+ for name in names:
140
+ if name not in POLICY_LIBRARY:
141
+ raise ValueError(
142
+ f"Unknown policy {name!r}. Available policies: {sorted(POLICY_LIBRARY)}"
143
+ )
144
+ policies.append(POLICY_LIBRARY[name])
145
+ return policies
146
+
147
+
148
+ def default_submit_builder(
149
+ ticket: dict[str, Any], allowed_fields: list[str]
150
+ ) -> HelpdeskTicketAction:
151
+ inference = importlib.import_module("inference")
152
+ candidate = inference.heuristic_action(ticket, allowed_fields)
153
+ candidate, _ = inference.apply_domain_overrides(ticket, candidate, allowed_fields)
154
+ return HelpdeskTicketAction(**candidate)
155
+
156
+
157
+ def choose_policy_action(
158
+ policy: PolicyConfig,
159
+ observation: HelpdeskTicketObservation,
160
+ investigations_by_ticket: dict[str, int],
161
+ submit_builder: SubmitBuilder,
162
+ ) -> tuple[HelpdeskTicketAction, str]:
163
+ ticket = observation.current_ticket or {}
164
+ ticket_id = str(ticket.get("ticket_id", ""))
165
+ ticket_investigations = investigations_by_ticket.get(ticket_id, 0)
166
+ revealed_tools = set(((ticket.get("context_status") or {}).get("revealed_tools") or []))
167
+ remaining_tools = list(((ticket.get("context_status") or {}).get("remaining_tools") or []))
168
+
169
+ if ticket_investigations < policy.max_investigations_per_ticket:
170
+ if policy.investigate_hidden_context and remaining_tools:
171
+ tool_name = str(remaining_tools[0])
172
+ return (
173
+ HelpdeskTicketAction(action_type="investigate", tool_name=tool_name),
174
+ "investigate_hidden_context",
175
+ )
176
+ if (
177
+ policy.investigate_related_ticket_hint
178
+ and ticket.get("related_ticket_id")
179
+ and "lookup_related_ticket" not in revealed_tools
180
+ ):
181
+ return (
182
+ HelpdeskTicketAction(
183
+ action_type="investigate",
184
+ tool_name="lookup_related_ticket",
185
+ ),
186
+ "investigate_related_ticket_hint",
187
+ )
188
+ if (
189
+ policy.investigate_ambiguity_history
190
+ and ticket.get("ambiguity_note")
191
+ and "lookup_requester_history" not in revealed_tools
192
+ ):
193
+ return (
194
+ HelpdeskTicketAction(
195
+ action_type="investigate",
196
+ tool_name="lookup_requester_history",
197
+ ),
198
+ "investigate_ambiguity_history",
199
+ )
200
+
201
+ return submit_builder(ticket, list(observation.allowed_fields)), "submit"
202
+
203
+
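The decision order in `choose_policy_action` can be sketched with plain dictionaries in place of `PolicyConfig`, `HelpdeskTicketObservation`, and `HelpdeskTicketAction`. The `pick_action` name and the `(action_type, tool_name)` return shape are assumptions made for this illustration; only the priority order (hidden context first, then the related-ticket hint, then the ambiguity hint, otherwise submit) is taken from the diff.

```python
from __future__ import annotations

from typing import Any, Optional


def pick_action(
    policy: dict[str, Any], ticket: dict[str, Any], investigations_used: int
) -> tuple[str, Optional[str]]:
    """Return (action_type, tool_name) following the diffed priority order."""
    status = ticket.get("context_status") or {}
    revealed = set(status.get("revealed_tools") or [])
    remaining = list(status.get("remaining_tools") or [])

    if investigations_used < policy["max_investigations_per_ticket"]:
        # 1. The environment reports hidden context: follow its tool chain.
        if policy["investigate_hidden_context"] and remaining:
            return "investigate", remaining[0]
        # 2. Legacy hint: the ticket references a related ticket.
        if (
            policy["investigate_related_ticket_hint"]
            and ticket.get("related_ticket_id")
            and "lookup_related_ticket" not in revealed
        ):
            return "investigate", "lookup_related_ticket"
        # 3. Legacy hint: the ticket carries an ambiguity note.
        if (
            policy["investigate_ambiguity_history"]
            and ticket.get("ambiguity_note")
            and "lookup_requester_history" not in revealed
        ):
            return "investigate", "lookup_requester_history"
    # Budget spent or nothing to investigate: submit a routing decision.
    return "submit", None


hybrid = {
    "investigate_hidden_context": True,
    "investigate_related_ticket_hint": True,
    "investigate_ambiguity_history": True,
    "max_investigations_per_ticket": 3,
}
ticket = {
    "ticket_id": "T-1",  # hypothetical ticket for illustration
    "related_ticket_id": "T-0",
    "context_status": {"remaining_tools": ["lookup_internal_routing_note"]},
}
first = pick_action(hybrid, ticket, investigations_used=0)
exhausted = pick_action(hybrid, ticket, investigations_used=3)
```

Here `first` selects the environment-reported tool ahead of the related-ticket hint, and `exhausted` falls through to a submit once the per-ticket cap is reached.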
204
+ def rollout_episode(
205
+ *,
206
+ env: HelpdeskTicketRoutingEnvironment,
207
+ policy: PolicyConfig,
208
+ seed: int,
209
+ task_id: int,
210
+ submit_builder: SubmitBuilder,
211
+ ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
212
+ task = get_task_definition(task_id)
213
+ observation = env.reset(seed=seed, task_id=task_id)
214
+ investigations_by_ticket: dict[str, int] = {}
215
+ episode_return = 0.0
216
+ trajectories: list[dict[str, Any]] = []
217
+
218
+ while not observation.done:
219
+ ticket = observation.current_ticket or {}
220
+ ticket_id = str(ticket.get("ticket_id", ""))
221
+ action, action_source = choose_policy_action(
222
+ policy,
223
+ observation,
224
+ investigations_by_ticket,
225
+ submit_builder,
226
+ )
227
+ next_observation = env.step(action)
228
+ reward_value = float(next_observation.reward or 0.0)
229
+ episode_return += reward_value
230
+ if action.action_type == "investigate" and ticket_id:
231
+ investigations_by_ticket[ticket_id] = investigations_by_ticket.get(ticket_id, 0) + 1
232
+
233
+ history_entry = env.state.history_entries[-1] if env.state.history_entries else {}
234
+ trajectories.append(
235
+ {
236
+ "policy": policy.name,
237
+ "seed": seed,
238
+ "task_id": task_id,
239
+ "task_name": task["name"],
240
+ "episode_id": env.state.episode_id,
241
+ "step_index": len(trajectories) + 1,
242
+ "ticket_id": history_entry.get("ticket_id", ticket_id),
243
+ "action_source": action_source,
244
+ "action": action.model_dump(exclude_none=True),
245
+ "step_reward": reward_value,
246
+ "rubric_reward": next_observation.rubric_reward,
247
+ "done": next_observation.done,
248
+ "feedback_summary": history_entry.get("feedback_summary"),
249
+ "reward_kind": history_entry.get("reward_kind"),
250
+ "score": history_entry.get("score"),
251
+ "breakdown": history_entry.get("breakdown", {}),
252
+ "reward_components": history_entry.get("reward_components", {}),
253
+ "context_status_before_action": ticket.get("context_status"),
254
+ }
255
+ )
256
+ observation = next_observation
257
+
258
+ queue_size = max(1, len(env.state.queue_ticket_ids))
259
+ terminal_reward = float(observation.reward or 0.0)
260
+ terminal_rubric_reward = (
261
+ float(observation.rubric_reward)
262
+ if observation.rubric_reward is not None
263
+ else terminal_reward
264
+ )
265
+ summary = {
266
+ "policy": policy.name,
267
+ "policy_config": asdict(policy),
268
+ "seed": seed,
269
+ "task_id": task_id,
270
+ "task_name": task["name"],
271
+ "episode_id": env.state.episode_id,
272
+ "queue_size": queue_size,
273
+ "step_count": env.state.step_count,
274
+ "tickets_processed": len(env.state.per_ticket_scores),
275
+ "investigation_steps": env.state.investigation_steps,
276
+ "episode_return": episode_return,
277
+ "normalized_return": episode_return / queue_size,
278
+ "terminal_reward": terminal_reward,
279
+ "terminal_rubric_reward": terminal_rubric_reward,
280
+ "average_ticket_score": env.state.average_score_so_far,
281
+ "per_ticket_scores": list(env.state.per_ticket_scores),
282
+ }
283
+ return summary, trajectories
284
+
285
+
286
+ def _safe_mean(values: list[float]) -> float:
287
+ if not values:
288
+ return 0.0
289
+ return round(mean(values), 6)
290
+
291
+
292
+ def summarize_policy_episodes(
293
+ policy: PolicyConfig,
294
+ episode_summaries: list[dict[str, Any]],
295
+ ) -> dict[str, Any]:
296
+ per_task: dict[str, Any] = {}
297
+ for task_id in TASK_IDS:
298
+ task_episodes = [
299
+ episode for episode in episode_summaries if episode["task_id"] == task_id
300
+ ]
301
+ if not task_episodes:
302
+ continue
303
+ per_task[str(task_id)] = {
304
+ "episodes": len(task_episodes),
305
+ "avg_episode_return": _safe_mean(
306
+ [float(episode["episode_return"]) for episode in task_episodes]
307
+ ),
308
+ "avg_normalized_return": _safe_mean(
309
+ [float(episode["normalized_return"]) for episode in task_episodes]
310
+ ),
311
+ "avg_terminal_reward": _safe_mean(
312
+ [float(episode["terminal_reward"]) for episode in task_episodes]
313
+ ),
314
+ "avg_terminal_rubric_reward": _safe_mean(
315
+ [float(episode["terminal_rubric_reward"]) for episode in task_episodes]
316
+ ),
317
+ "avg_investigation_steps": _safe_mean(
318
+ [float(episode["investigation_steps"]) for episode in task_episodes]
319
+ ),
320
+ }
321
+
322
+ return {
323
+ "policy": policy.name,
324
+ "config": asdict(policy),
325
+ "episodes": len(episode_summaries),
326
+ "avg_episode_return": _safe_mean(
327
+ [float(episode["episode_return"]) for episode in episode_summaries]
328
+ ),
329
+ "avg_normalized_return": _safe_mean(
330
+ [float(episode["normalized_return"]) for episode in episode_summaries]
331
+ ),
332
+ "avg_terminal_reward": _safe_mean(
333
+ [float(episode["terminal_reward"]) for episode in episode_summaries]
334
+ ),
335
+ "avg_terminal_rubric_reward": _safe_mean(
336
+ [float(episode["terminal_rubric_reward"]) for episode in episode_summaries]
337
+ ),
338
+ "avg_investigation_steps": _safe_mean(
339
+ [float(episode["investigation_steps"]) for episode in episode_summaries]
340
+ ),
341
+ "avg_ticket_score": _safe_mean(
342
+ [float(episode["average_ticket_score"]) for episode in episode_summaries]
343
+ ),
344
+ "per_task": per_task,
345
+ }
346
+
347
+
348
+ def evaluate_policy(
349
+ policy: PolicyConfig,
350
+ seeds: Iterable[int],
351
+ task_ids: Iterable[int],
352
+ *,
353
+ env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
354
+ submit_builder: SubmitBuilder = default_submit_builder,
355
+ ) -> dict[str, Any]:
356
+ episode_summaries: list[dict[str, Any]] = []
357
+ trajectories: list[dict[str, Any]] = []
358
+
359
+ for seed in seeds:
360
+ for task_id in task_ids:
361
+ env = env_factory()
362
+ summary, episode_trajectories = rollout_episode(
363
+ env=env,
364
+ policy=policy,
365
+ seed=seed,
366
+ task_id=task_id,
367
+ submit_builder=submit_builder,
368
+ )
369
+ episode_summaries.append(summary)
370
+ trajectories.extend(episode_trajectories)
371
+
372
+ return {
373
+ "policy": policy.name,
374
+ "summary": summarize_policy_episodes(policy, episode_summaries),
375
+ "episodes": episode_summaries,
376
+ "trajectories": trajectories,
377
+ }
378
+
379
+
380
+ def _selection_tuple(summary: dict[str, Any]) -> tuple[float, float, float, float]:
381
+ return (
382
+ float(summary["avg_normalized_return"]),
383
+ float(summary["avg_terminal_reward"]),
384
+ float(summary["avg_terminal_rubric_reward"]),
385
+ -float(summary["avg_investigation_steps"]),
386
+ )
387
+
388
+
389
+ def select_best_policy(policy_runs: list[dict[str, Any]]) -> dict[str, Any]:
390
+ return max(policy_runs, key=lambda run: _selection_tuple(run["summary"]))
391
+
392
+
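A small sketch of how the tuple key ranks candidates may help here. Python compares tuples lexicographically, so later entries only matter when earlier ones tie; negating `avg_investigation_steps` makes fewer investigations win a tie. The policy names and numbers below are invented for illustration and are not results from this repository.

```python
from __future__ import annotations

from typing import Any


def selection_key(summary: dict[str, Any]) -> tuple[float, float, float, float]:
    """Same ordering as _selection_tuple: return first, fewer probes last."""
    return (
        float(summary["avg_normalized_return"]),
        float(summary["avg_terminal_reward"]),
        float(summary["avg_terminal_rubric_reward"]),
        -float(summary["avg_investigation_steps"]),
    )


# Hypothetical summaries: "a" and "b" tie on every reward metric.
summaries = [
    {"policy": "a", "avg_normalized_return": 0.80, "avg_terminal_reward": 0.70,
     "avg_terminal_rubric_reward": 0.70, "avg_investigation_steps": 4.0},
    {"policy": "b", "avg_normalized_return": 0.80, "avg_terminal_reward": 0.70,
     "avg_terminal_rubric_reward": 0.70, "avg_investigation_steps": 2.0},
    {"policy": "c", "avg_normalized_return": 0.75, "avg_terminal_reward": 0.90,
     "avg_terminal_rubric_reward": 0.90, "avg_investigation_steps": 0.0},
]

best = max(summaries, key=selection_key)
ranking = [s["policy"] for s in sorted(summaries, key=selection_key, reverse=True)]
```

In this made-up data, "c" has the highest terminal reward but still ranks last, because `avg_normalized_return` is compared first.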
393
+ def _delta(best: dict[str, Any], baseline: dict[str, Any], key: str) -> float:
394
+ return round(float(best[key]) - float(baseline[key]), 6)
395
+
396
+
397
+ def _write_json(path: Path, payload: dict[str, Any]) -> None:
398
+ path.parent.mkdir(parents=True, exist_ok=True)
399
+ path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
400
+
401
+
402
+ def _write_jsonl(path: Path, records: Iterable[dict[str, Any]]) -> None:
403
+ path.parent.mkdir(parents=True, exist_ok=True)
404
+ with path.open("w", encoding="utf-8") as handle:
405
+ for record in records:
406
+ handle.write(json.dumps(record, sort_keys=True) + "\n")
407
+
408
+
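The JSONL artifacts written by `_write_jsonl` hold one JSON object per line, so they can be read back line by line. The sketch below repeats the writer and adds a hypothetical `read_jsonl` reader, which is not part of this commit, to show a round trip through a temporary directory.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path
from typing import Any, Iterable


def write_jsonl(path: Path, records: Iterable[dict[str, Any]]) -> None:
    """Write one JSON object per line, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, sort_keys=True) + "\n")


def read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Hypothetical reader: parse each non-empty line as JSON."""
    with path.open(encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


records = [
    {"policy": "no_investigation", "seed": 42, "episode_return": 1.5},
    {"policy": "context_chain", "seed": 43, "episode_return": 2.0},
]

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "runs" / "compare_episodes.jsonl"
    write_jsonl(target, records)
    loaded = read_jsonl(target)
```

Accepting any iterable lets callers pass a generator expression, as `compare_policies` does, without building the full record list in memory first.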
409
+ def compare_policies(
410
+ policies: list[PolicyConfig],
411
+ seeds: list[int],
412
+ task_ids: list[int],
413
+ *,
414
+ output_dir: Path,
415
+ env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
416
+ submit_builder: SubmitBuilder = default_submit_builder,
417
+ ) -> dict[str, Any]:
418
+ output_dir = Path(output_dir)
419
+ policy_runs = [
420
+ evaluate_policy(
421
+ policy,
422
+ seeds,
423
+ task_ids,
424
+ env_factory=env_factory,
425
+ submit_builder=submit_builder,
426
+ )
427
+ for policy in policies
428
+ ]
429
+ best_run = select_best_policy(policy_runs)
430
+ baseline_run = policy_runs[0]
431
+
432
+ report = {
433
+ "mode": "compare",
434
+ "task_ids": task_ids,
435
+ "seeds": seeds,
436
+ "selection_metric": "avg_normalized_return",
437
+ "baseline_policy": baseline_run["policy"],
438
+ "best_policy": best_run["policy"],
439
+ "improvement_vs_baseline": {
440
+ "avg_episode_return": _delta(
441
+ best_run["summary"], baseline_run["summary"], "avg_episode_return"
442
+ ),
443
+ "avg_normalized_return": _delta(
444
+ best_run["summary"], baseline_run["summary"], "avg_normalized_return"
445
+ ),
446
+ "avg_terminal_reward": _delta(
447
+ best_run["summary"], baseline_run["summary"], "avg_terminal_reward"
448
+ ),
449
+ "avg_terminal_rubric_reward": _delta(
450
+ best_run["summary"],
451
+ baseline_run["summary"],
452
+ "avg_terminal_rubric_reward",
453
+ ),
454
+ },
455
+ "policy_summaries": [run["summary"] for run in policy_runs],
456
+ "ranking": [
457
+ run["policy"]
458
+ for run in sorted(
459
+ policy_runs,
460
+ key=lambda run: _selection_tuple(run["summary"]),
461
+ reverse=True,
462
+ )
463
+ ],
464
+ "artifacts": {
465
+ "summary": str(output_dir / "compare_summary.json"),
466
+ "episodes": str(output_dir / "compare_episodes.jsonl"),
467
+ "trajectories": str(output_dir / "compare_trajectories.jsonl"),
468
+ },
469
+ }
470
+
471
+ _write_json(output_dir / "compare_summary.json", report)
472
+ _write_jsonl(
473
+ output_dir / "compare_episodes.jsonl",
474
+ (
475
+ {"policy": run["policy"], **episode}
476
+ for run in policy_runs
477
+ for episode in run["episodes"]
478
+ ),
479
+ )
480
+ _write_jsonl(
481
+ output_dir / "compare_trajectories.jsonl",
482
+ (trajectory for run in policy_runs for trajectory in run["trajectories"]),
483
+ )
484
+ return report
485
+
486
+
487
+ def search_policies(
488
+ candidate_policies: list[PolicyConfig],
489
+ train_seeds: list[int],
490
+ eval_seeds: list[int],
491
+ task_ids: list[int],
492
+ *,
493
+ output_dir: Path,
494
+ env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
495
+ submit_builder: SubmitBuilder = default_submit_builder,
496
+ baseline_policy_name: str = "no_investigation",
497
+ ) -> dict[str, Any]:
498
+ output_dir = Path(output_dir)
499
+ train_runs = [
500
+ evaluate_policy(
501
+ policy,
502
+ train_seeds,
503
+ task_ids,
504
+ env_factory=env_factory,
505
+ submit_builder=submit_builder,
506
+ )
507
+ for policy in candidate_policies
508
+ ]
509
+ selected_run = select_best_policy(train_runs)
510
+ selected_policy = POLICY_LIBRARY[selected_run["policy"]]
511
+ eval_selected = evaluate_policy(
512
+ selected_policy,
513
+ eval_seeds,
514
+ task_ids,
515
+ env_factory=env_factory,
516
+ submit_builder=submit_builder,
517
+ )
518
+
519
+ baseline_policy = POLICY_LIBRARY.get(baseline_policy_name, candidate_policies[0])
520
+ eval_baseline = evaluate_policy(
521
+ baseline_policy,
522
+ eval_seeds,
523
+ task_ids,
524
+ env_factory=env_factory,
525
+ submit_builder=submit_builder,
526
+ )
527
+
528
+ report = {
529
+ "mode": "search",
530
+ "task_ids": task_ids,
531
+ "train_seeds": train_seeds,
532
+ "eval_seeds": eval_seeds,
533
+ "selection_metric": "avg_normalized_return",
534
+ "candidate_policies": [policy.name for policy in candidate_policies],
535
+ "selected_policy": selected_policy.name,
536
+ "baseline_policy": baseline_policy.name,
537
+ "train_policy_summaries": [run["summary"] for run in train_runs],
538
+ "eval_selected_summary": eval_selected["summary"],
539
+ "eval_baseline_summary": eval_baseline["summary"],
540
+ "eval_improvement_vs_baseline": {
541
+ "avg_episode_return": _delta(
542
+ eval_selected["summary"],
543
+ eval_baseline["summary"],
544
+ "avg_episode_return",
545
+ ),
546
+ "avg_normalized_return": _delta(
547
+ eval_selected["summary"],
548
+ eval_baseline["summary"],
549
+ "avg_normalized_return",
550
+ ),
551
+ "avg_terminal_reward": _delta(
552
+ eval_selected["summary"],
553
+ eval_baseline["summary"],
554
+ "avg_terminal_reward",
555
+ ),
556
+ "avg_terminal_rubric_reward": _delta(
557
+ eval_selected["summary"],
558
+ eval_baseline["summary"],
559
+ "avg_terminal_rubric_reward",
560
+ ),
561
+ },
562
+ "artifacts": {
563
+ "summary": str(output_dir / "search_summary.json"),
564
+ "train_episodes": str(output_dir / "search_train_episodes.jsonl"),
565
+ "train_trajectories": str(output_dir / "search_train_trajectories.jsonl"),
566
+ "eval_episodes": str(output_dir / "search_eval_episodes.jsonl"),
567
+ "eval_trajectories": str(output_dir / "search_eval_trajectories.jsonl"),
568
+ },
569
+ }
570
+
571
+ _write_json(output_dir / "search_summary.json", report)
572
+ _write_jsonl(
573
+ output_dir / "search_train_episodes.jsonl",
574
+ (
575
+ {"policy": run["policy"], **episode}
576
+ for run in train_runs
577
+ for episode in run["episodes"]
578
+ ),
579
+ )
580
+ _write_jsonl(
581
+ output_dir / "search_train_trajectories.jsonl",
582
+ (trajectory for run in train_runs for trajectory in run["trajectories"]),
583
+ )
584
+ _write_jsonl(
585
+ output_dir / "search_eval_episodes.jsonl",
586
+ (
587
+ {"policy": eval_selected["policy"], **episode}
588
+ for episode in eval_selected["episodes"]
589
+ ),
590
+ )
591
+ _write_jsonl(
592
+ output_dir / "search_eval_trajectories.jsonl",
593
+ (trajectory for trajectory in eval_selected["trajectories"]),
594
+ )
595
+ return report
596
+
597
+
598
+ def build_parser() -> argparse.ArgumentParser:
599
+ parser = argparse.ArgumentParser(
600
+ description=(
601
+ "Run seeded local rollouts and a small policy-improvement loop for the "
602
+ "IT helpdesk OpenEnv environment."
603
+ )
604
+ )
605
+ subparsers = parser.add_subparsers(dest="command", required=True)
606
+
607
+ compare_parser = subparsers.add_parser(
608
+ "compare",
609
+ help="Compare fixed policy choices across repeated seeded rollouts.",
610
+ )
611
+ compare_parser.add_argument(
612
+ "--policies",
613
+ default=",".join(DEFAULT_COMPARE_POLICIES),
614
+ help=f"Comma-separated policy names. Available: {', '.join(POLICY_LIBRARY)}",
615
+ )
616
+ compare_parser.add_argument(
617
+ "--seeds",
618
+ default="42-51",
619
+ help="Comma-separated seeds or ranges, for example 42-51 or 42,50,60.",
620
+ )
621
+ compare_parser.add_argument(
622
+ "--task-ids",
623
+ default="1,2,3",
624
+ help="Comma-separated task IDs or ranges, for example 1,2,3 or 1-3.",
625
+ )
626
+ compare_parser.add_argument(
627
+ "--output-dir",
628
+ default=DEFAULT_OUTPUT_DIR,
629
+ help="Directory for JSON and JSONL artifacts.",
630
+ )
631
+
632
+ search_parser = subparsers.add_parser(
633
+ "search",
634
+ help="Select the best policy on train seeds, then re-evaluate on holdout seeds.",
635
+ )
636
+ search_parser.add_argument(
637
+ "--candidate-policies",
638
+ default=",".join(DEFAULT_SEARCH_POLICIES),
639
+ help=f"Comma-separated candidate policy names. Available: {', '.join(POLICY_LIBRARY)}",
640
+ )
641
+ search_parser.add_argument(
642
+ "--train-seeds",
643
+ default="40-49",
644
+ help="Train seeds used for reward-based policy selection.",
645
+ )
646
+ search_parser.add_argument(
647
+ "--eval-seeds",
648
+ default="50-59",
649
+ help="Holdout seeds used for the selected policy evaluation.",
650
+ )
651
+ search_parser.add_argument(
652
+ "--task-ids",
653
+ default="1,2,3",
654
+ help="Comma-separated task IDs or ranges, for example 1,2,3 or 1-3.",
655
+ )
656
+ search_parser.add_argument(
657
+ "--baseline-policy",
658
+ default="no_investigation",
659
+ help="Baseline policy used for the final improvement delta.",
660
+ )
661
+ search_parser.add_argument(
662
+ "--output-dir",
663
+ default=DEFAULT_OUTPUT_DIR,
664
+ help="Directory for JSON and JSONL artifacts.",
665
+ )
666
+
667
+ return parser
668
+
669
+
670
+ def _print_summary(label: str, summary: dict[str, Any]) -> None:
671
+ print(
672
+ json.dumps(
673
+ {
674
+ label: {
675
+ "policy": summary["policy"],
676
+ "avg_episode_return": summary["avg_episode_return"],
677
+ "avg_normalized_return": summary["avg_normalized_return"],
678
+ "avg_terminal_reward": summary["avg_terminal_reward"],
679
+ "avg_terminal_rubric_reward": summary["avg_terminal_rubric_reward"],
680
+ "avg_investigation_steps": summary["avg_investigation_steps"],
681
+ }
682
+ },
683
+ sort_keys=True,
684
+ )
685
+ )
686
+
687
+
688
+ def main() -> None:
689
+ parser = build_parser()
690
+ args = parser.parse_args()
691
+
692
+ output_dir = Path(args.output_dir)
693
+
694
+ if args.command == "compare":
695
+ policies = resolve_policies(args.policies)
696
+ seeds = parse_int_spec(args.seeds, field_name="seeds")
697
+ task_ids = parse_task_ids(args.task_ids)
698
+ report = compare_policies(
699
+ policies,
700
+ seeds,
701
+ task_ids,
702
+ output_dir=output_dir,
703
+ )
704
+ print(json.dumps(report, indent=2, sort_keys=True))
705
+ return
706
+
707
+ candidate_policies = resolve_policies(args.candidate_policies)
708
+ train_seeds = parse_int_spec(args.train_seeds, field_name="train_seeds")
709
+ eval_seeds = parse_int_spec(args.eval_seeds, field_name="eval_seeds")
710
+ task_ids = parse_task_ids(args.task_ids)
711
+ report = search_policies(
712
+ candidate_policies,
713
+ train_seeds,
714
+ eval_seeds,
715
+ task_ids,
716
+ output_dir=output_dir,
717
+ baseline_policy_name=args.baseline_policy,
718
+ )
719
+ print(json.dumps(report, indent=2, sort_keys=True))
720
+
721
+
722
+ if __name__ == "__main__":
723
+ main()
pyproject.toml CHANGED
@@ -24,12 +24,13 @@ dependencies = [
24
 
25
  [project.scripts]
26
  server = "server.app:main"
 
27
 
28
  [project.optional-dependencies]
29
  dev = ["pytest", "httpx"]
30
 
31
  [tool.setuptools]
32
- py-modules = ["models", "client", "vocabulary"]
33
 
34
  [tool.setuptools.packages.find]
35
  include = ["server*"]
 
24
 
25
  [project.scripts]
26
  server = "server.app:main"
27
+ policy-learn = "policy_learning:main"
28
 
29
  [project.optional-dependencies]
30
  dev = ["pytest", "httpx"]
31
 
32
  [tool.setuptools]
33
+ py-modules = ["models", "client", "policy_learning", "vocabulary"]
34
 
35
  [tool.setuptools.packages.find]
36
  include = ["server*"]
server/environment.py CHANGED
@@ -18,10 +18,68 @@ from server.tasks import get_task_definition, load_dataset
18
 
19
 
20
  QUEUE_SIZE_RANGE = (3, 5)
21
- AVAILABLE_TOOLS = ("lookup_related_ticket", "lookup_requester_history")
 
 
 
 
 
22
  FREE_INVESTIGATIONS_PER_TICKET = 1
23
  EXTRA_INVESTIGATION_COST = 0.02
24
  MAX_EXTRA_INVESTIGATION_PENALTY = 0.15
25
 
26
 
27
  def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
@@ -86,7 +144,11 @@ class HelpdeskTicketRoutingEnvironment(
86
  current_ticket_index=0,
87
  per_ticket_scores=[],
88
  total_reward=0.0,
 
89
  investigation_budget_remaining=queue_size * FREE_INVESTIGATIONS_PER_TICKET,
 
 
 
90
  )
91
 
92
  return self._build_observation(task)
@@ -122,54 +184,104 @@ class HelpdeskTicketRoutingEnvironment(
122
  if extra_fields:
123
  # Penalty: record score 0.0, advance index, return penalty observation
124
  self._state.per_ticket_scores.append(0.0)
125
- self._state.history_entries.append(
126
- self._build_history_entry(
127
- current_ticket,
128
- predicted=action.model_dump(exclude_none=True),
129
- score=0.0,
130
- breakdown={},
131
- queue_position=idx + 1,
132
- penalty_reason=f"extra_fields: {sorted(extra_fields)}",
133
- )
134
- )
135
  self._state.step_count += 1
136
  self._state.current_ticket_index += 1
137
  is_done = self._state.current_ticket_index >= len(self._queue)
138
  self._state.done = is_done
 
 
139
  if is_done:
140
- traj_reward = compute_trajectory_reward(
141
  self._state.per_ticket_scores, len(self._queue), self._state.step_count
142
  )
143
- final_reward = self._apply_episode_economics(traj_reward)
144
  self._state.total_reward = final_reward
145
  else:
146
  final_reward = 0.0
147
  self._state.last_step_reward = final_reward
148
  self._state.reward = final_reward
 
149
  self._state.last_tool_result = None
150
- return self._build_observation(task, done=is_done, reward=final_reward)
 
 
 
 
 
 
151
 
152
  score, breakdown = grade_action(action, current_ticket, task_id)
153
  step_reward = compute_step_reward(score)
 
 
154
 
155
  is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
 
 
 
156
 
157
  if is_done:
158
  self._state.per_ticket_scores.append(score)
 
159
  self._state.step_count += 1
160
  self._state.current_ticket_index += 1
161
- traj_reward = compute_trajectory_reward(
162
  self._state.per_ticket_scores,
163
  len(self._queue),
164
  self._state.step_count,
165
  )
166
- final_reward = self._apply_episode_economics(traj_reward)
167
- self._state.total_reward = final_reward
 
 
168
  else:
169
  self._state.per_ticket_scores.append(score)
 
170
  self._state.step_count += 1
171
  self._state.current_ticket_index += 1
172
- final_reward = step_reward
173
 
174
  history_entry = self._build_history_entry(
175
  current_ticket,
@@ -177,15 +289,26 @@ class HelpdeskTicketRoutingEnvironment(
177
  score=score,
178
  breakdown=breakdown,
179
  queue_position=idx + 1,
 
 
 
 
180
  )
181
  self._state.history_entries.append(history_entry)
182
 
183
  self._state.last_step_reward = final_reward
184
  self._state.reward = final_reward
185
  self._state.done = is_done
 
186
  self._state.last_tool_result = None
 
187
 
188
- return self._build_observation(task, done=is_done, reward=final_reward)
 
 
 
 
 
189
 
190
  @property
191
  def state(self) -> HelpdeskTicketState:
@@ -195,15 +318,112 @@ class HelpdeskTicketRoutingEnvironment(
195
  # Helpers
196
  # ------------------------------------------------------------------
197
 
198
- def _apply_episode_economics(self, base_reward: float) -> float:
199
  free_investigations = len(self._queue) * FREE_INVESTIGATIONS_PER_TICKET
200
  extra_investigations = max(0, self._state.investigation_steps - free_investigations)
201
- penalty = min(
202
  MAX_EXTRA_INVESTIGATION_PENALTY,
203
  extra_investigations * EXTRA_INVESTIGATION_COST,
204
  )
 
 
 
205
  return max(0.0, min(1.0, base_reward - penalty))
206
 
207
  def _lookup_related_ticket(
208
  self,
209
  current_ticket: HelpdeskTicketRecord,
@@ -259,6 +479,15 @@ class HelpdeskTicketRoutingEnvironment(
259
  "matches": matches,
260
  }
261
 
262
  def _run_investigation_tool(
263
  self,
264
  current_ticket: HelpdeskTicketRecord,
@@ -269,6 +498,8 @@ class HelpdeskTicketRoutingEnvironment(
269
  return self._lookup_related_ticket(current_ticket, target_ticket_id)
270
  if tool_name == "lookup_requester_history":
271
  return self._lookup_requester_history(current_ticket)
 
 
272
  raise ValueError(f"Unsupported tool_name: {tool_name}")
273
 
274
  def _handle_investigation_action(
@@ -296,6 +527,14 @@ class HelpdeskTicketRoutingEnvironment(
296
  action.tool_name,
297
  action.tool_target_ticket_id,
298
  )
299
  self._state.step_count += 1
300
  self._state.investigation_steps += 1
301
  self._state.investigation_budget_remaining = max(
@@ -303,9 +542,25 @@ class HelpdeskTicketRoutingEnvironment(
303
  self._state.investigation_budget_remaining - 1,
304
  )
305
  self._state.last_tool_result = tool_result
306
- self._state.last_step_reward = 0.0
307
- self._state.reward = 0.0
 
308
  self._state.done = False
309
  self._state.history_entries.append(
310
  self._build_history_entry(
311
  current_ticket,
@@ -313,21 +568,35 @@ class HelpdeskTicketRoutingEnvironment(
313
  score=0.0,
314
  breakdown={},
315
  queue_position=idx + 1,
 
 
316
  tool_result=tool_result,
 
317
  )
318
  )
319
- return self._build_observation(task, done=False, reward=0.0)
 
320
 
321
  def _build_ticket_view(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
 
 
 
322
  ticket_view: dict[str, Any] = {
323
  "ticket_id": ticket.ticket_id,
324
  "title": ticket.title,
325
  "requester": ticket.requester,
326
- "description": ticket.description,
327
  }
328
- if ticket.ambiguity_note is not None:
329
  ticket_view["ambiguity_note"] = ticket.ambiguity_note
330
- if ticket.related_ticket_id is not None:
331
  ticket_view["related_ticket_id"] = ticket.related_ticket_id
332
  related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
333
  if related_ticket is not None:
@@ -339,6 +608,50 @@ class HelpdeskTicketRoutingEnvironment(
339
  }
340
  return ticket_view
341
 
342
  def _build_history_entry(
343
  self,
344
  ticket: HelpdeskTicketRecord,
@@ -347,9 +660,15 @@ class HelpdeskTicketRoutingEnvironment(
347
  score: float,
348
  breakdown: dict[str, float],
349
  queue_position: int,
 
 
 
350
  penalty_reason: str | None = None,
351
  tool_result: dict[str, Any] | None = None,
 
352
  ) -> dict[str, Any]:
 
 
353
  history_entry: dict[str, Any] = {
354
  "ticket_id": ticket.ticket_id,
355
  "title": ticket.title,
@@ -359,9 +678,15 @@ class HelpdeskTicketRoutingEnvironment(
359
  "breakdown": breakdown,
360
  "queue_position": queue_position,
361
  }
362
- if ticket.ambiguity_note is not None:
 
 
 
 
 
 
363
  history_entry["ambiguity_note"] = ticket.ambiguity_note
364
- if ticket.related_ticket_id is not None:
365
  history_entry["related_ticket_id"] = ticket.related_ticket_id
366
  related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
367
  if related_ticket is not None:
@@ -375,6 +700,21 @@ class HelpdeskTicketRoutingEnvironment(
375
  history_entry["penalty_reason"] = penalty_reason
376
  if tool_result is not None:
377
  history_entry["tool_result"] = tool_result
378
  return history_entry
379
 
380
  def _build_observation(
@@ -382,6 +722,7 @@ class HelpdeskTicketRoutingEnvironment(
382
  task: dict,
383
  done: bool = False,
384
  reward: float | None = None,
 
385
  ) -> HelpdeskTicketObservation:
386
  idx = self._state.current_ticket_index
387
  queue_size = len(self._queue)
@@ -395,28 +736,47 @@ class HelpdeskTicketRoutingEnvironment(
395
  queue_position = 0
396
 
397
  history = list(self._state.history_entries)
 
398
  tickets_remaining = max(0, queue_size - idx)
399
  tickets_after_current = max(
400
  0,
401
  tickets_remaining - (1 if ticket_view is not None else 0),
402
  )
403
 
404
  return HelpdeskTicketObservation(
405
  done=done,
406
  reward=reward,
407
- metadata={
408
- "queue_position": queue_position,
409
- "tickets_remaining_includes_current": ticket_view is not None,
410
- "has_ambiguity_note": bool(ticket_view and ticket_view.get("ambiguity_note")),
411
- "has_related_ticket_context": bool(
412
- ticket_view and ticket_view.get("related_ticket_preview")
413
- ),
414
- "action_mode": "investigate_or_submit",
415
- },
416
  task_id=task["id"],
417
  task_name=task["name"],
418
  instructions=task["instructions"],
419
  allowed_fields=list(task["allowed_fields"]),
 
420
  available_tools=list(AVAILABLE_TOOLS),
421
  investigation_budget_remaining=self._state.investigation_budget_remaining,
422
  last_tool_result=self._state.last_tool_result,
@@ -426,5 +786,8 @@ class HelpdeskTicketRoutingEnvironment(
426
  tickets_after_current=tickets_after_current,
427
  tickets_processed=idx,
428
  queue_position=queue_position,
 
 
429
  history=history,
 
430
  )
 
18
 
19
 
20
  QUEUE_SIZE_RANGE = (3, 5)
21
+ AVAILABLE_ACTION_TYPES = ("submit", "investigate")
22
+ AVAILABLE_TOOLS = (
23
+ "lookup_related_ticket",
24
+ "lookup_requester_history",
25
+ "lookup_internal_routing_note",
26
+ )
27
  FREE_INVESTIGATIONS_PER_TICKET = 1
28
  EXTRA_INVESTIGATION_COST = 0.02
29
  MAX_EXTRA_INVESTIGATION_PENALTY = 0.15
30
+ USEFUL_INVESTIGATION_REWARD = 0.08
31
+ PREMATURE_SUBMIT_PENALTY = 0.10
32
+
33
+ TASK3_INVESTIGATION_TOOL_PLAN: dict[str, tuple[str, ...]] = {
34
+ "ticket-021": ("lookup_related_ticket", "lookup_requester_history"),
35
+ "ticket-022": ("lookup_internal_routing_note",),
36
+ "ticket-027": ("lookup_internal_routing_note",),
37
+ "ticket-029": ("lookup_internal_routing_note",),
38
+ "ticket-038": ("lookup_related_ticket", "lookup_requester_history"),
39
+ "ticket-045": ("lookup_related_ticket", "lookup_requester_history"),
40
+ "TKT-NONDEFAULT-001": ("lookup_internal_routing_note",),
41
+ "TKT-NONDEFAULT-002": ("lookup_internal_routing_note",),
42
+ "TKT-NONDEFAULT-003": ("lookup_internal_routing_note",),
43
+ }
44
+
45
+ HARD_TASK_DESCRIPTION_REDACTIONS: dict[str, str] = {
46
+ "ticket-021": (
47
+ "Production checkout is still unstable after a recent fix. "
48
+ "Additional routing context is available via investigation."
49
+ ),
50
+ "ticket-022": (
51
+ "Usage charges increased while the integration was failing. "
52
+ "Additional routing context is available via investigation."
53
+ ),
54
+ "ticket-027": (
55
+ "A vendor offer arrived with a near-term deadline. "
56
+ "Additional routing context is available via investigation."
57
+ ),
58
+ "ticket-029": (
59
+ "A team needs a large seat expansion right away. "
60
+ "Additional routing context is available via investigation."
61
+ ),
62
+ "ticket-038": (
63
+ "A prior invoice discrepancy is still unresolved and now time-sensitive. "
64
+ "Additional routing context is available via investigation."
65
+ ),
66
+ "ticket-045": (
67
+ "A company-wide suspension remains unresolved after repeated follow-ups. "
68
+ "Additional routing context is available via investigation."
69
+ ),
70
+ "TKT-NONDEFAULT-001": (
71
+ "A user needs help with a billing-style question. "
72
+ "Additional routing context is available via investigation."
73
+ ),
74
+ "TKT-NONDEFAULT-002": (
75
+ "A client compliance scan surfaced a product-specific issue. "
76
+ "Additional routing context is available via investigation."
77
+ ),
78
+ "TKT-NONDEFAULT-003": (
79
+ "A contractor onboarding workflow is blocked by an account problem. "
80
+ "Additional routing context is available via investigation."
81
+ ),
82
+ }
83
 
84
 
85
  def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
 
144
  current_ticket_index=0,
145
  per_ticket_scores=[],
146
  total_reward=0.0,
147
+ average_score_so_far=0.0,
148
  investigation_budget_remaining=queue_size * FREE_INVESTIGATIONS_PER_TICKET,
149
+ investigation_penalty_applied=0.0,
150
+ last_reward_components={},
151
+ ticket_tool_usage={},
152
  )
153
 
154
  return self._build_observation(task)
 
184
  if extra_fields:
185
  # Penalty: record score 0.0, advance index, return penalty observation
186
  self._state.per_ticket_scores.append(0.0)
187
+ self._state.average_score_so_far = self._current_average_score()
188
  self._state.step_count += 1
189
  self._state.current_ticket_index += 1
190
  is_done = self._state.current_ticket_index >= len(self._queue)
191
  self._state.done = is_done
192
+ trajectory_reward = None
193
+ investigation_penalty = self._compute_episode_penalty() if is_done else 0.0
194
  if is_done:
195
+ trajectory_reward = compute_trajectory_reward(
196
  self._state.per_ticket_scores, len(self._queue), self._state.step_count
197
  )
198
+ final_reward = self._apply_episode_economics(trajectory_reward)
199
  self._state.total_reward = final_reward
200
  else:
201
  final_reward = 0.0
202
+ reward_components = self._build_reward_components(
203
+ ticket_score=0.0,
204
+ field_breakdown={},
205
+ shaped_step_reward=0.0,
206
+ reward_kind="trajectory" if is_done else "step_penalty",
207
+ final_reward=final_reward,
208
+ trajectory_reward=trajectory_reward,
209
+ investigation_penalty=investigation_penalty,
210
+ penalty_reason=f"extra_fields: {sorted(extra_fields)}",
211
+ )
212
+ self._state.history_entries.append(
213
+ self._build_history_entry(
214
+ current_ticket,
215
+ predicted=action.model_dump(exclude_none=True),
216
+ score=0.0,
217
+ breakdown={},
218
+ queue_position=idx + 1,
219
+ reward=final_reward,
220
+ rubric_reward=final_reward if is_done else None,
221
+ reward_kind="trajectory" if is_done else "step_penalty",
222
+ penalty_reason=f"extra_fields: {sorted(extra_fields)}",
223
+ reward_components=reward_components,
224
+ )
225
+ )
226
  self._state.last_step_reward = final_reward
227
  self._state.reward = final_reward
228
+ self._state.investigation_penalty_applied = self._compute_episode_penalty()
229
  self._state.last_tool_result = None
230
+ self._state.last_reward_components = reward_components
231
+ return self._build_observation(
232
+ task,
233
+ done=is_done,
234
+ reward=final_reward,
235
+ rubric_reward=final_reward if is_done else None,
236
+ )
237
 
238
  score, breakdown = grade_action(action, current_ticket, task_id)
239
  step_reward = compute_step_reward(score)
240
+ context_penalty, missing_required_tools = self._submit_context_penalty(current_ticket)
241
+ milestone_adjustment = step_reward - score
242
 
243
  is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
244
+ trajectory_reward = None
245
+ investigation_penalty = 0.0
246
+ rubric_reward = None
247
 
248
  if is_done:
249
  self._state.per_ticket_scores.append(score)
250
+ self._state.average_score_so_far = self._current_average_score()
251
  self._state.step_count += 1
252
  self._state.current_ticket_index += 1
253
+ trajectory_reward = compute_trajectory_reward(
254
  self._state.per_ticket_scores,
255
  len(self._queue),
256
  self._state.step_count,
257
  )
258
+ rubric_reward = self._apply_episode_economics(trajectory_reward)
259
+ final_reward = max(0.0, min(1.0, rubric_reward - context_penalty))
260
+ self._state.total_reward = rubric_reward
261
+ investigation_penalty = self._compute_episode_penalty()
262
  else:
263
  self._state.per_ticket_scores.append(score)
264
+ self._state.average_score_so_far = self._current_average_score()
265
  self._state.step_count += 1
266
  self._state.current_ticket_index += 1
267
+ final_reward = max(0.0, min(1.0, step_reward - context_penalty))
268
+
269
+ reward_components = self._build_reward_components(
270
+ ticket_score=score,
271
+ field_breakdown=breakdown,
272
+ shaped_step_reward=step_reward,
273
+ reward_kind="trajectory" if is_done else "step",
274
+ final_reward=final_reward,
275
+ milestone_adjustment=milestone_adjustment,
276
+ trajectory_reward=trajectory_reward,
277
+ investigation_penalty=investigation_penalty,
278
+ extra_details={
279
+ "context_gap_penalty": context_penalty,
280
+ "required_tools": self._required_tools_for_ticket(current_ticket),
281
+ "remaining_required_tools": missing_required_tools,
282
+ "rubric_reward": rubric_reward,
283
+ },
284
+ )
285
 
286
  history_entry = self._build_history_entry(
287
  current_ticket,
 
289
  score=score,
290
  breakdown=breakdown,
291
  queue_position=idx + 1,
292
+ reward=final_reward,
293
+ rubric_reward=rubric_reward if is_done else None,
294
+ reward_kind="trajectory" if is_done else "step",
295
+ reward_components=reward_components,
296
  )
297
  self._state.history_entries.append(history_entry)
298
 
299
  self._state.last_step_reward = final_reward
300
  self._state.reward = final_reward
301
  self._state.done = is_done
302
+ self._state.investigation_penalty_applied = self._compute_episode_penalty()
303
  self._state.last_tool_result = None
304
+ self._state.last_reward_components = reward_components
305
 
306
+ return self._build_observation(
307
+ task,
308
+ done=is_done,
309
+ reward=final_reward,
310
+ rubric_reward=rubric_reward if is_done else None,
311
+ )
312
 
313
  @property
314
  def state(self) -> HelpdeskTicketState:
 
318
  # Helpers
319
  # ------------------------------------------------------------------
320
 
321
+ def _compute_episode_penalty(self) -> float:
322
  free_investigations = len(self._queue) * FREE_INVESTIGATIONS_PER_TICKET
323
  extra_investigations = max(0, self._state.investigation_steps - free_investigations)
324
+ return min(
325
  MAX_EXTRA_INVESTIGATION_PENALTY,
326
  extra_investigations * EXTRA_INVESTIGATION_COST,
327
  )
328
+
329
+ def _apply_episode_economics(self, base_reward: float) -> float:
330
+ penalty = self._compute_episode_penalty()
331
  return max(0.0, min(1.0, base_reward - penalty))
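The split into `_compute_episode_penalty` and `_apply_episode_economics` amounts to a capped linear penalty subtracted from the trajectory reward and clamped to [0, 1]. A minimal standalone sketch, assuming the constant values shown in this diff; the free functions `episode_penalty` and `apply_episode_economics` are illustrative stand-ins for the methods and take the queue size and investigation count as plain arguments instead of reading `self._state`:

```python
# Constant values copied from the diff above.
FREE_INVESTIGATIONS_PER_TICKET = 1
EXTRA_INVESTIGATION_COST = 0.02
MAX_EXTRA_INVESTIGATION_PENALTY = 0.15


def episode_penalty(queue_size: int, investigation_steps: int) -> float:
    """Capped linear penalty for investigations beyond the free allowance."""
    free_investigations = queue_size * FREE_INVESTIGATIONS_PER_TICKET
    extra_investigations = max(0, investigation_steps - free_investigations)
    return min(
        MAX_EXTRA_INVESTIGATION_PENALTY,
        extra_investigations * EXTRA_INVESTIGATION_COST,
    )


def apply_episode_economics(
    base_reward: float, queue_size: int, investigation_steps: int
) -> float:
    """Subtract the episode penalty and clamp the result to [0, 1]."""
    penalty = episode_penalty(queue_size, investigation_steps)
    return max(0.0, min(1.0, base_reward - penalty))


if __name__ == "__main__":
    # Four tickets, six investigations: two are beyond the free allowance.
    print(episode_penalty(4, 6))
    # A heavily over-investigated episode hits the cap.
    print(episode_penalty(3, 30))
```

Keeping the penalty computation separate from the clamping step is what lets the environment report `investigation_penalty_applied` in state and metadata without recomputing the reward.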
332
 
333
+ def _current_average_score(self) -> float:
334
+ if not self._state.per_ticket_scores:
335
+ return 0.0
336
+ return sum(self._state.per_ticket_scores) / len(self._state.per_ticket_scores)
337
+
338
+ def _required_tools_for_ticket(
339
+ self,
340
+ ticket: HelpdeskTicketRecord,
341
+ task_id: int | None = None,
342
+ ) -> list[str]:
343
+ resolved_task_id = self._state.current_task_id if task_id is None else task_id
344
+ if resolved_task_id != 3:
345
+ return []
346
+ return list(TASK3_INVESTIGATION_TOOL_PLAN.get(ticket.ticket_id, ()))
347
+
348
+ def _used_tools_for_ticket(self, ticket_id: str) -> list[str]:
349
+ return list(self._state.ticket_tool_usage.get(ticket_id, []))
350
+
351
+ def _remaining_tools_for_ticket(
352
+ self,
353
+ ticket: HelpdeskTicketRecord,
354
+ task_id: int | None = None,
355
+ ) -> list[str]:
356
+ required_tools = self._required_tools_for_ticket(ticket, task_id)
357
+ used_tools = set(self._used_tools_for_ticket(ticket.ticket_id))
358
+ return [tool for tool in required_tools if tool not in used_tools]
359
+
360
+ def _record_tool_usage(self, ticket_id: str, tool_name: str) -> None:
361
+ used = self._state.ticket_tool_usage.setdefault(ticket_id, [])
362
+ if tool_name not in used:
363
+ used.append(tool_name)
364
+
365
+ def _investigation_hints_for_ticket(self, ticket: HelpdeskTicketRecord) -> list[str]:
366
+ hints: list[str] = []
367
+ remaining_tools = self._remaining_tools_for_ticket(ticket)
368
+ if "lookup_internal_routing_note" in remaining_tools:
369
+ hints.append("An internal routing note may disambiguate the correct workflow.")
370
+ if "lookup_related_ticket" in remaining_tools:
371
+ hints.append("A linked prior ticket can reveal important follow-up context.")
372
+ if "lookup_requester_history" in remaining_tools:
373
+ hints.append("Requester history may clarify severity or routing intent.")
374
+ return hints
375
+
376
+ def _visible_description(self, ticket: HelpdeskTicketRecord) -> str:
377
+ if (
378
+ self._state.current_task_id == 3
379
+ and self._remaining_tools_for_ticket(ticket)
380
+ and ticket.ticket_id in HARD_TASK_DESCRIPTION_REDACTIONS
381
+ ):
382
+ return HARD_TASK_DESCRIPTION_REDACTIONS[ticket.ticket_id]
383
+ return ticket.description
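The redaction rule in `_visible_description` can be read as a three-way guard: task 3 only, at least one required tool still unused, and a redaction entry for the ticket. A small sketch of that guard as a pure function; the function name, its parameters, and the sample full description are hypothetical and only illustrate the branching:

```python
def visible_description(
    task_id: int,
    ticket_id: str,
    description: str,
    remaining_tools: list[str],
    redactions: dict[str, str],
) -> str:
    """Return the redacted text while required investigation is outstanding."""
    if task_id == 3 and remaining_tools and ticket_id in redactions:
        return redactions[ticket_id]
    return description


if __name__ == "__main__":
    # Redaction text for ticket-022 is taken from the diff; the full
    # description below is a made-up placeholder.
    redactions = {
        "ticket-022": (
            "Usage charges increased while the integration was failing. "
            "Additional routing context is available via investigation."
        )
    }
    full = "Placeholder full description (hypothetical)."
    print(visible_description(3, "ticket-022", full, ["lookup_internal_routing_note"], redactions))
    print(visible_description(3, "ticket-022", full, [], redactions))
```

Once every required tool has been used, `remaining_tools` is empty and the full description is shown again.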
384
+
385
+ def _submit_context_penalty(self, ticket: HelpdeskTicketRecord) -> tuple[float, list[str]]:
386
+ required_tools = self._required_tools_for_ticket(ticket)
387
+ if not required_tools:
388
+ return 0.0, []
389
+ remaining_tools = self._remaining_tools_for_ticket(ticket)
390
+ if not remaining_tools:
391
+ return 0.0, []
392
+ penalty = PREMATURE_SUBMIT_PENALTY * (len(remaining_tools) / len(required_tools))
393
+ return penalty, remaining_tools
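The premature-submit penalty scales with the fraction of required tools that were never used, and the step path then clamps `step_reward - context_penalty` to [0, 1]. A minimal sketch under the assumption that `PREMATURE_SUBMIT_PENALTY = 0.10` as in this diff; `submit_context_penalty` here is a free function taking explicit tool lists, and `clamp_reward` is a hypothetical helper name for the inline `max(0.0, min(1.0, ...))` expression:

```python
PREMATURE_SUBMIT_PENALTY = 0.10  # value copied from the diff


def submit_context_penalty(
    required_tools: list[str], used_tools: list[str]
) -> tuple[float, list[str]]:
    """Penalty proportional to the share of required tools left unused."""
    if not required_tools:
        return 0.0, []
    used = set(used_tools)
    remaining = [tool for tool in required_tools if tool not in used]
    if not remaining:
        return 0.0, []
    penalty = PREMATURE_SUBMIT_PENALTY * (len(remaining) / len(required_tools))
    return penalty, remaining


def clamp_reward(step_reward: float, context_penalty: float) -> float:
    """Mirror of the inline clamp applied before the reward is returned."""
    return max(0.0, min(1.0, step_reward - context_penalty))


if __name__ == "__main__":
    plan = ["lookup_related_ticket", "lookup_requester_history"]
    # One of two required tools used: half of the penalty applies.
    print(submit_context_penalty(plan, ["lookup_related_ticket"]))
    print(clamp_reward(0.9, 0.05))
```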
394
+
395
+ def _build_reward_components(
396
+ self,
397
+ *,
398
+ ticket_score: float,
399
+ field_breakdown: dict[str, float],
400
+ shaped_step_reward: float,
401
+ reward_kind: str,
402
+ final_reward: float,
403
+ milestone_adjustment: float = 0.0,
404
+ trajectory_reward: float | None = None,
405
+ investigation_penalty: float = 0.0,
406
+ penalty_reason: str | None = None,
407
+ extra_details: dict[str, Any] | None = None,
408
+ ) -> dict[str, Any]:
409
+ components: dict[str, Any] = {
410
+ "reward_kind": reward_kind,
411
+ "ticket_score": ticket_score,
412
+ "field_breakdown": field_breakdown,
413
+ "shaped_step_reward": shaped_step_reward,
414
+ "milestone_adjustment": milestone_adjustment,
415
+ "final_reward": final_reward,
416
+ "average_score_so_far": self._current_average_score(),
417
+ "investigation_penalty_applied": investigation_penalty,
418
+ }
419
+ if trajectory_reward is not None:
420
+ components["trajectory_reward"] = trajectory_reward
421
+ if penalty_reason is not None:
422
+ components["penalty_reason"] = penalty_reason
423
+ if extra_details:
424
+ components.update(extra_details)
425
+ return components
426
+
427
  def _lookup_related_ticket(
428
  self,
429
  current_ticket: HelpdeskTicketRecord,
 
479
  "matches": matches,
480
  }
481
 
482
+ def _lookup_internal_routing_note(self, current_ticket: HelpdeskTicketRecord) -> dict[str, Any]:
483
+ found = current_ticket.ambiguity_note is not None
484
+ return {
485
+ "tool_name": "lookup_internal_routing_note",
486
+ "found": found,
487
+ "ticket_id": current_ticket.ticket_id,
488
+ "routing_note": current_ticket.ambiguity_note if found else "",
489
+ }
490
+
491
  def _run_investigation_tool(
492
  self,
493
  current_ticket: HelpdeskTicketRecord,
 
498
  return self._lookup_related_ticket(current_ticket, target_ticket_id)
499
  if tool_name == "lookup_requester_history":
500
  return self._lookup_requester_history(current_ticket)
501
+ if tool_name == "lookup_internal_routing_note":
502
+ return self._lookup_internal_routing_note(current_ticket)
503
  raise ValueError(f"Unsupported tool_name: {tool_name}")
504
 
505
  def _handle_investigation_action(
 
527
  action.tool_name,
528
  action.tool_target_ticket_id,
529
  )
530
+ required_tools = self._required_tools_for_ticket(current_ticket)
531
+ already_used = action.tool_name in self._used_tools_for_ticket(current_ticket.ticket_id)
532
+ useful_investigation = (
533
+ action.tool_name in required_tools
534
+ and not already_used
535
+ and bool(tool_result.get("found", True))
536
+ )
537
+ self._record_tool_usage(current_ticket.ticket_id, action.tool_name)
538
  self._state.step_count += 1
539
  self._state.investigation_steps += 1
540
  self._state.investigation_budget_remaining = max(
 
542
  self._state.investigation_budget_remaining - 1,
543
  )
544
  self._state.last_tool_result = tool_result
545
+ investigation_reward = USEFUL_INVESTIGATION_REWARD if useful_investigation else 0.0
546
+ self._state.last_step_reward = investigation_reward
547
+ self._state.reward = investigation_reward
548
  self._state.done = False
549
+ self._state.investigation_penalty_applied = self._compute_episode_penalty()
550
+ reward_components = self._build_reward_components(
551
+ ticket_score=0.0,
552
+ field_breakdown={},
553
+ shaped_step_reward=investigation_reward,
554
+ reward_kind="investigation",
555
+ final_reward=investigation_reward,
556
+ investigation_penalty=self._state.investigation_penalty_applied,
557
+ extra_details={
558
+ "new_context_revealed": useful_investigation,
559
+ "required_tools": required_tools,
560
+ "remaining_required_tools": self._remaining_tools_for_ticket(current_ticket),
561
+ "tool_name": action.tool_name,
562
+ },
563
+ )
564
  self._state.history_entries.append(
565
  self._build_history_entry(
566
  current_ticket,
 
568
  score=0.0,
569
  breakdown={},
570
  queue_position=idx + 1,
571
+ reward=investigation_reward,
572
+ reward_kind="investigation",
573
  tool_result=tool_result,
574
+ reward_components=reward_components,
575
  )
576
  )
577
+ self._state.last_reward_components = reward_components
578
+ return self._build_observation(task, done=False, reward=investigation_reward)
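In the investigation handler, a step earns `USEFUL_INVESTIGATION_REWARD` only when three conditions hold together: the tool is in the ticket's required plan, it has not been used on this ticket before, and the tool result reports `found` (defaulting to true when the key is absent). A compact sketch of that predicate, assuming the 0.08 value from this diff; `investigation_reward` is an illustrative free function, not a repository API:

```python
from typing import Any

USEFUL_INVESTIGATION_REWARD = 0.08  # value copied from the diff


def investigation_reward(
    tool_name: str,
    required_tools: list[str],
    used_tools: list[str],
    tool_result: dict[str, Any],
) -> float:
    """Reward only the first successful use of a required tool."""
    useful = (
        tool_name in required_tools
        and tool_name not in used_tools
        and bool(tool_result.get("found", True))
    )
    return USEFUL_INVESTIGATION_REWARD if useful else 0.0


if __name__ == "__main__":
    plan = ["lookup_internal_routing_note"]
    # First use of a required tool that finds something.
    print(investigation_reward("lookup_internal_routing_note", plan, [], {"found": True}))
    # Repeating the same tool earns nothing.
    print(investigation_reward("lookup_internal_routing_note", plan, plan, {"found": True}))
```

Note that the handler checks `already_used` before calling `_record_tool_usage`, so the order of those two calls matters for the first use to count.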
579
 
580
  def _build_ticket_view(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
581
+ required_tools = self._required_tools_for_ticket(ticket)
582
+ revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
583
+ remaining_tools = self._remaining_tools_for_ticket(ticket)
584
  ticket_view: dict[str, Any] = {
585
  "ticket_id": ticket.ticket_id,
586
  "title": ticket.title,
587
  "requester": ticket.requester,
588
+ "description": self._visible_description(ticket),
589
  }
590
+ if required_tools:
591
+ ticket_view["context_status"] = {
592
+ "investigation_required": True,
593
+ "revealed_tools": revealed_tools,
594
+ "remaining_tools": remaining_tools,
595
+ "hints": self._investigation_hints_for_ticket(ticket),
596
+ }
597
+ if ticket.ambiguity_note is not None and "lookup_internal_routing_note" not in remaining_tools:
598
  ticket_view["ambiguity_note"] = ticket.ambiguity_note
599
+ if ticket.related_ticket_id is not None and "lookup_related_ticket" not in remaining_tools:
600
  ticket_view["related_ticket_id"] = ticket.related_ticket_id
601
  related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
602
  if related_ticket is not None:
 
608
  }
609
  return ticket_view
610
 
611
+ def _build_feedback_summary(
612
+ self,
613
+ *,
614
+ predicted: dict[str, Any],
615
+ score: float,
616
+ breakdown: dict[str, float],
617
+ reward: float | None = None,
618
+ rubric_reward: float | None = None,
619
+ reward_kind: str | None = None,
620
+ penalty_reason: str | None = None,
621
+ tool_result: dict[str, Any] | None = None,
622
+ reward_components: dict[str, Any] | None = None,
623
+ ) -> str:
624
+ parts: list[str] = []
625
+
626
+ if reward_kind == "investigation":
627
+ tool_name = predicted.get("tool_name") or (tool_result or {}).get("tool_name")
628
+ parts.append(f"Investigation step used {tool_name or 'a tool'}")
629
+ if reward_components and reward_components.get("new_context_revealed"):
630
+ parts.append("new context was revealed")
631
+ elif penalty_reason is not None:
632
+ parts.append(f"Penalty applied: {penalty_reason}")
633
+ else:
634
+ parts.append(f"Ticket score={score:.2f}")
635
+
636
+ if breakdown:
637
+ field_scores = ", ".join(
638
+ f"{field}={value:.2f}" for field, value in sorted(breakdown.items())
639
+ )
640
+ parts.append(f"field_scores[{field_scores}]")
641
+ if reward is not None:
642
+ parts.append(f"reward={reward:.2f}")
643
+ if rubric_reward is not None:
644
+ parts.append(f"rubric_reward={rubric_reward:.2f}")
645
+ if reward_components:
646
+ context_gap_penalty = reward_components.get("context_gap_penalty")
647
+ if context_gap_penalty:
648
+ parts.append(f"context_gap_penalty={context_gap_penalty:.2f}")
649
+ remaining_required_tools = reward_components.get("remaining_required_tools") or []
650
+ if remaining_required_tools:
651
+ parts.append(f"missing_context={remaining_required_tools}")
652
+
653
+ return "; ".join(parts)
654
+
655
  def _build_history_entry(
656
  self,
657
  ticket: HelpdeskTicketRecord,
 
660
  score: float,
661
  breakdown: dict[str, float],
662
  queue_position: int,
663
+ reward: float | None = None,
664
+ rubric_reward: float | None = None,
665
+ reward_kind: str | None = None,
666
  penalty_reason: str | None = None,
667
  tool_result: dict[str, Any] | None = None,
668
+ reward_components: dict[str, Any] | None = None,
669
  ) -> dict[str, Any]:
670
+ remaining_tools = self._remaining_tools_for_ticket(ticket)
671
+ revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
672
  history_entry: dict[str, Any] = {
673
  "ticket_id": ticket.ticket_id,
674
  "title": ticket.title,
 
678
  "breakdown": breakdown,
679
  "queue_position": queue_position,
680
  }
681
+ if reward is not None:
682
+ history_entry["reward"] = reward
683
+ if rubric_reward is not None:
684
+ history_entry["rubric_reward"] = rubric_reward
685
+ if reward_kind is not None:
686
+ history_entry["reward_kind"] = reward_kind
687
+ if ticket.ambiguity_note is not None and "lookup_internal_routing_note" not in remaining_tools:
688
  history_entry["ambiguity_note"] = ticket.ambiguity_note
689
+ if ticket.related_ticket_id is not None and "lookup_related_ticket" not in remaining_tools:
690
  history_entry["related_ticket_id"] = ticket.related_ticket_id
691
  related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
692
  if related_ticket is not None:
 
700
  history_entry["penalty_reason"] = penalty_reason
701
  if tool_result is not None:
702
  history_entry["tool_result"] = tool_result
703
+ if reward_components is not None:
704
+ history_entry["reward_components"] = reward_components
705
+ if revealed_tools:
706
+ history_entry["revealed_tools"] = revealed_tools
707
+ history_entry["feedback_summary"] = self._build_feedback_summary(
708
+ predicted=predicted,
709
+ score=score,
710
+ breakdown=breakdown,
711
+ reward=reward,
712
+ rubric_reward=rubric_reward,
713
+ reward_kind=reward_kind,
714
+ penalty_reason=penalty_reason,
715
+ tool_result=tool_result,
716
+ reward_components=reward_components,
717
+ )
718
  return history_entry
719
 
720
  def _build_observation(
 
722
  task: dict,
723
  done: bool = False,
724
  reward: float | None = None,
725
+ rubric_reward: float | None = None,
726
  ) -> HelpdeskTicketObservation:
727
  idx = self._state.current_ticket_index
728
  queue_size = len(self._queue)
 
736
  queue_position = 0
737
 
738
  history = list(self._state.history_entries)
739
+ last_history_entry = history[-1] if history else None
740
  tickets_remaining = max(0, queue_size - idx)
741
  tickets_after_current = max(
742
  0,
743
  tickets_remaining - (1 if ticket_view is not None else 0),
744
  )
745
+ progress_fraction = (idx / queue_size) if queue_size else 0.0
746
+
747
+ metadata = {
748
+ "queue_position": queue_position,
749
+ "tickets_remaining_includes_current": ticket_view is not None,
750
+ "has_ambiguity_note": bool(ticket_view and ticket_view.get("ambiguity_note")),
751
+ "has_related_ticket_context": bool(
752
+ ticket_view and ticket_view.get("related_ticket_preview")
753
+ ),
754
+ "action_mode": "investigate_or_submit",
755
+ "available_action_types": list(AVAILABLE_ACTION_TYPES),
756
+ "average_score_so_far": self._state.average_score_so_far,
757
+ "progress_fraction": progress_fraction,
758
+ "investigation_penalty_applied": self._state.investigation_penalty_applied,
759
+ }
760
+ if last_history_entry is not None:
761
+ metadata["last_score"] = last_history_entry.get("score")
762
+ metadata["last_reward"] = last_history_entry.get("reward")
763
+ metadata["last_reward_kind"] = last_history_entry.get("reward_kind")
764
+ metadata["last_breakdown"] = last_history_entry.get("breakdown")
765
+ metadata["last_feedback_summary"] = last_history_entry.get("feedback_summary")
766
+ metadata["last_reward_components"] = last_history_entry.get("reward_components", {})
767
+ if "penalty_reason" in last_history_entry:
768
+ metadata["last_penalty_reason"] = last_history_entry["penalty_reason"]
769
 
770
  return HelpdeskTicketObservation(
771
  done=done,
772
  reward=reward,
773
+ rubric_reward=rubric_reward,
774
+ metadata=metadata,
775
  task_id=task["id"],
776
  task_name=task["name"],
777
  instructions=task["instructions"],
778
  allowed_fields=list(task["allowed_fields"]),
779
+ available_action_types=list(AVAILABLE_ACTION_TYPES),
780
  available_tools=list(AVAILABLE_TOOLS),
781
  investigation_budget_remaining=self._state.investigation_budget_remaining,
782
  last_tool_result=self._state.last_tool_result,
 
786
  tickets_after_current=tickets_after_current,
787
  tickets_processed=idx,
788
  queue_position=queue_position,
789
+ average_score_so_far=self._state.average_score_so_far,
790
+ progress_fraction=progress_fraction,
791
  history=history,
792
+ last_reward_components=dict(self._state.last_reward_components),
793
  )
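The two progress fields added to the observation are simple ratios with zero-guards, which is why the reset tests later in this commit can assert that both start at 0.0. A minimal sketch, with `progress_metrics` as a hypothetical helper that bundles `_current_average_score` and the `progress_fraction` expression into one function:

```python
def progress_metrics(
    per_ticket_scores: list[float], current_ticket_index: int, queue_size: int
) -> dict[str, float]:
    """Compute the running average score and the fraction of the queue processed."""
    if per_ticket_scores:
        average = sum(per_ticket_scores) / len(per_ticket_scores)
    else:
        average = 0.0  # no tickets graded yet, e.g. right after reset
    progress = (current_ticket_index / queue_size) if queue_size else 0.0
    return {"average_score_so_far": average, "progress_fraction": progress}


if __name__ == "__main__":
    # Fresh episode: both metrics start at zero.
    print(progress_metrics([], 0, 4))
    # Two of four tickets graded.
    print(progress_metrics([1.0, 0.5], 2, 4))
```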
server/tasks.py CHANGED
@@ -37,7 +37,9 @@ TASKS = {
  "Perform full helpdesk routing by selecting the best issue type, "
  "priority, assignment group, and resolution action for the ticket. "
  "Use any ambiguity notes or related-ticket previews when present. "
- "You may investigate with tools before you submit the final action."
+ "Some hard tickets intentionally hide decisive routing context until "
+ "you investigate with the available tools, so premature submission can "
+ "underperform even when the visible text looks plausible."
  ),
  "allowed_fields": [
  "issue_type",
tests/test_api_integration.py CHANGED
@@ -167,6 +167,9 @@ class TestResetEndpoint(unittest.TestCase):
167
  def test_reset_reward_is_null(self):
168
  self.assertIsNone(self.data["reward"])
169
 
 
 
 
170
  def test_reset_task_id_is_1(self):
171
  self.assertEqual(self.data["task_id"], 1)
172
 
@@ -177,6 +180,13 @@ class TestResetEndpoint(unittest.TestCase):
177
  self.assertIsInstance(self.data["allowed_fields"], list)
178
  self.assertGreater(len(self.data["allowed_fields"]), 0)
179
 
180
 
181
  class TestStepEndpoint(unittest.TestCase):
182
  """2.1.4 — POST /step returns observation JSON with reward in [0.0, 1.0]."""
@@ -200,6 +210,35 @@ class TestStepEndpoint(unittest.TestCase):
200
  def test_step_tickets_processed_is_1(self):
201
  self.assertEqual(self.data["tickets_processed"], 1)
202
 
203
 
204
  class TestStateEndpoint(unittest.TestCase):
205
  """2.1.5 — GET /state returns current episode state JSON after a reset."""
@@ -278,6 +317,38 @@ class TestFullSeededEpisode(unittest.TestCase):
278
  self.assertGreaterEqual(final_reward, 0.0)
279
  self.assertLessEqual(final_reward, 1.0)
280
 
281
  def test_full_episode_all_tasks_complete(self):
282
  """4.1.1 — Full seeded episode completes for each task ID (1, 2, 3)."""
283
  for task_id in (1, 2, 3):
 
167
  def test_reset_reward_is_null(self):
168
  self.assertIsNone(self.data["reward"])
169
 
170
+ def test_reset_rubric_reward_is_null(self):
171
+ self.assertIsNone(self.data["rubric_reward"])
172
+
173
  def test_reset_task_id_is_1(self):
174
  self.assertEqual(self.data["task_id"], 1)
175
 
 
180
  self.assertIsInstance(self.data["allowed_fields"], list)
181
  self.assertGreater(len(self.data["allowed_fields"]), 0)
182
 
183
+ def test_reset_available_action_types_exposed(self):
184
+ self.assertEqual(self.data["available_action_types"], ["submit", "investigate"])
185
+
186
+ def test_reset_progress_metrics_start_at_zero(self):
187
+ self.assertEqual(self.data["average_score_so_far"], 0.0)
188
+ self.assertEqual(self.data["progress_fraction"], 0.0)
189
+
190
 
191
  class TestStepEndpoint(unittest.TestCase):
192
  """2.1.4 — POST /step returns observation JSON with reward in [0.0, 1.0]."""
 
210
  def test_step_tickets_processed_is_1(self):
211
  self.assertEqual(self.data["tickets_processed"], 1)
212
 
213
+ def test_step_metadata_exposes_last_feedback_summary(self):
214
+ metadata = self.data.get("metadata", {})
215
+ self.assertIn("last_feedback_summary", metadata)
216
+ self.assertIsInstance(metadata["last_feedback_summary"], str)
217
+ self.assertTrue(metadata["last_feedback_summary"])
218
+
219
+ def test_step_history_entry_includes_feedback_summary(self):
220
+ history = self.data.get("history", [])
221
+ self.assertGreater(len(history), 0)
222
+ self.assertIn("feedback_summary", history[-1])
223
+ self.assertIsInstance(history[-1]["feedback_summary"], str)
224
+ self.assertTrue(history[-1]["feedback_summary"])
225
+
226
+ def test_step_exposes_structured_reward_components(self):
227
+ self.assertIn("last_reward_components", self.data)
228
+ self.assertIsInstance(self.data["last_reward_components"], dict)
229
+ self.assertIn("ticket_score", self.data["last_reward_components"])
230
+ self.assertIn("final_reward", self.data["last_reward_components"])
231
+ self.assertEqual(
232
+ self.data["metadata"].get("last_reward_components"),
233
+ self.data["last_reward_components"],
234
+ )
235
+
236
+ def test_step_progress_metrics_are_exposed(self):
237
+ self.assertIn("average_score_so_far", self.data)
238
+ self.assertIn("progress_fraction", self.data)
239
+ self.assertGreaterEqual(self.data["progress_fraction"], 0.0)
240
+ self.assertLessEqual(self.data["progress_fraction"], 1.0)
241
+
242
 
243
  class TestStateEndpoint(unittest.TestCase):
244
  """2.1.5 — GET /state returns current episode state JSON after a reset."""
 
317
  self.assertGreaterEqual(final_reward, 0.0)
318
  self.assertLessEqual(final_reward, 1.0)
319
 
320
+ def test_full_episode_terminal_rubric_reward_in_unit_interval(self):
321
+ reset_resp = _reset(task_id=1, seed=42)
322
+ self.assertEqual(reset_resp.status_code, 200)
323
+ obs = reset_resp.json()
324
+
325
+ allowed_fields = obs["allowed_fields"]
326
+ final_rubric_reward = None
327
+ for _ in range(20):
328
+ action_payload: dict = {}
329
+ if "issue_type" in allowed_fields:
330
+ action_payload["issue_type"] = "general_inquiry"
331
+ if "priority" in allowed_fields:
332
+ action_payload["priority"] = "medium"
333
+ if "assignment_group" in allowed_fields:
334
+ action_payload["assignment_group"] = "service_desk"
335
+ if "resolution_action" in allowed_fields:
336
+ action_payload["resolution_action"] = "acknowledge"
337
+
338
+ step_resp = client.post("/step", json=action_payload)
339
+ self.assertEqual(step_resp.status_code, 200)
340
+ obs = step_resp.json()
341
+
342
+ if obs["done"]:
343
+ final_rubric_reward = obs.get("rubric_reward")
344
+ break
345
+
346
+ self.assertIsNotNone(
347
+ final_rubric_reward, "Terminal observation did not include rubric_reward"
348
+ )
349
+ self.assertGreaterEqual(final_rubric_reward, 0.0)
350
+ self.assertLessEqual(final_rubric_reward, 1.0)
351
+
352
  def test_full_episode_all_tasks_complete(self):
353
  """4.1.1 — Full seeded episode completes for each task ID (1, 2, 3)."""
354
  for task_id in (1, 2, 3):
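
Several tests in this commit expect hard-task tickets to hide `ambiguity_note` and `related_ticket_preview` until the matching tool has been called, and to list the pending tools under `context_status["remaining_tools"]`. The sketch here shows how that gating could work. The helper name `build_visible_ticket`, the `HIDDEN_KEY_TOOLS` mapping, and the `used_tools` argument are assumptions for illustration; only the key names and tool names come from the tests.

```python
from typing import Any, Dict, Set

# Hypothetical mapping from a hidden ticket key to the tool that reveals it.
HIDDEN_KEY_TOOLS = {
    "ambiguity_note": "lookup_internal_routing_note",
    "related_ticket_preview": "lookup_related_ticket",
}


def build_visible_ticket(ticket: Dict[str, Any], used_tools: Set[str]) -> Dict[str, Any]:
    """Return the agent-facing view of a ticket, hiding unrevealed context."""
    visible = {k: v for k, v in ticket.items() if k not in HIDDEN_KEY_TOOLS}
    revealed, remaining = [], []
    for key, tool in HIDDEN_KEY_TOOLS.items():
        if ticket.get(key) is None:
            continue  # nothing to hide for this ticket
        if tool in used_tools:
            visible[key] = ticket[key]  # context was earned by investigating
            revealed.append(tool)
        else:
            remaining.append(tool)
    visible["context_status"] = {
        "investigation_required": bool(remaining),
        "revealed_tools": revealed,
        "remaining_tools": remaining,
    }
    return visible
```

Under this sketch, a policy only has to read `remaining_tools` to decide whether another investigation step is worthwhile before submitting.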
tests/test_competitive_upgrade.py CHANGED
@@ -182,6 +182,16 @@ class TestStateHasRewardAndDone(unittest.TestCase):
182
  obs = env.step(_heuristic_action(obs))
183
  self.assertFalse(env.state.done)
184
 
185
 
186
  # ---------------------------------------------------------------------------
187
  # 9.3 — History entry contains title and predicted
@@ -318,7 +328,7 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
318
  return seed
319
  return None
320
 
321
- def test_ambiguity_note_present_when_ticket_has_one(self) -> None:
322
  """Force a ticket with ambiguity_note by patching the dataset."""
323
  from unittest.mock import patch
324
  from server.tasks import load_dataset
@@ -336,8 +346,22 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
336
  obs = env.reset(seed=0, task_id=3)
337
 
338
  self.assertIsNotNone(obs.current_ticket)
339
- self.assertIn("ambiguity_note", obs.current_ticket)
340
  self.assertEqual(obs.current_ticket["ambiguity_note"], target.ambiguity_note)
 
341
 
342
  def test_ambiguity_note_absent_when_ticket_has_none(self) -> None:
343
  """Tickets without ambiguity_note should not expose the key."""
@@ -370,6 +394,13 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
370
  with patch.object(env, "_dataset", [ticket]):
371
  obs = env.reset(seed=0, task_id=3)
372
 
373
  self.assertIn("ambiguity_note", obs.current_ticket)
374
 
375
 
@@ -397,12 +428,27 @@ class TestRelatedTicketPreviewInObservation(unittest.TestCase):
397
  ):
398
  obs = env.reset(seed=0, task_id=3, queue_size=1)
399
 
400
- return env, obs, related
401
 
402
  def test_related_ticket_preview_present_when_ticket_has_link(self) -> None:
403
- env, obs, related = self._reset_linked_ticket_env()
404
 
405
  self.assertIsNotNone(obs.current_ticket)
406
  self.assertIn("related_ticket_preview", obs.current_ticket)
407
  self.assertEqual(
408
  obs.current_ticket["related_ticket_preview"]["ticket_id"],
@@ -414,8 +460,22 @@ class TestRelatedTicketPreviewInObservation(unittest.TestCase):
414
  )
415
 
416
  def test_history_keeps_related_ticket_preview_after_step(self) -> None:
417
- env, obs, related = self._reset_linked_ticket_env()
418
- next_obs = env.step(_heuristic_action(obs))
419
 
420
  self.assertGreaterEqual(len(next_obs.history), 1)
421
  self.assertIn("related_ticket_preview", next_obs.history[0])
@@ -563,6 +623,58 @@ class TestInvestigationActions(unittest.TestCase):
563
  self.assertTrue(obs2.last_tool_result["found"])
564
  self.assertGreaterEqual(len(obs2.last_tool_result["matches"]), 1)
565
 
566
 
567
  class TestQueueEconomics(unittest.TestCase):
568
  """Free investigations are allowed, but excessive investigation gets a queue-level penalty."""
 
182
  obs = env.step(_heuristic_action(obs))
183
  self.assertFalse(env.state.done)
184
 
185
+ def test_state_tracks_average_score_and_reward_components(self) -> None:
186
+ env = _make_env()
187
+ obs = env.reset(seed=42, task_id=1)
188
+ env.step(_heuristic_action(obs))
189
+ state = env.state
190
+ self.assertGreaterEqual(state.average_score_so_far, 0.0)
191
+ self.assertLessEqual(state.average_score_so_far, 1.0)
192
+ self.assertIsInstance(state.last_reward_components, dict)
193
+ self.assertIn("final_reward", state.last_reward_components)
194
+
195
 
196
  # ---------------------------------------------------------------------------
197
  # 9.3 — History entry contains title and predicted
 
328
  return seed
329
  return None
330
 
331
+ def test_ambiguity_note_hidden_until_internal_note_lookup(self) -> None:
332
  """Force a ticket with ambiguity_note by patching the dataset."""
333
  from unittest.mock import patch
334
  from server.tasks import load_dataset
 
346
  obs = env.reset(seed=0, task_id=3)
347
 
348
  self.assertIsNotNone(obs.current_ticket)
349
+ self.assertNotIn("ambiguity_note", obs.current_ticket)
350
+ self.assertIn("context_status", obs.current_ticket)
351
+ self.assertIn(
352
+ "lookup_internal_routing_note",
353
+ obs.current_ticket["context_status"]["remaining_tools"],
354
+ )
355
+
356
+ obs = env.step(
357
+ HelpdeskTicketAction(
358
+ action_type="investigate",
359
+ tool_name="lookup_internal_routing_note",
360
+ )
361
+ )
362
+
363
  self.assertEqual(obs.current_ticket["ambiguity_note"], target.ambiguity_note)
364
+ self.assertGreater(obs.reward or 0.0, 0.0)
365
 
366
  def test_ambiguity_note_absent_when_ticket_has_none(self) -> None:
367
  """Tickets without ambiguity_note should not expose the key."""
 
394
  with patch.object(env, "_dataset", [ticket]):
395
  obs = env.reset(seed=0, task_id=3)
396
 
397
+ self.assertNotIn("ambiguity_note", obs.current_ticket)
398
+ obs = env.step(
399
+ HelpdeskTicketAction(
400
+ action_type="investigate",
401
+ tool_name="lookup_internal_routing_note",
402
+ )
403
+ )
404
  self.assertIn("ambiguity_note", obs.current_ticket)
405
 
406
 
 
428
  ):
429
  obs = env.reset(seed=0, task_id=3, queue_size=1)
430
 
431
+ return env, obs, ticket, related
432
 
433
  def test_related_ticket_preview_present_when_ticket_has_link(self) -> None:
434
+ env, obs, ticket, related = self._reset_linked_ticket_env()
435
 
436
  self.assertIsNotNone(obs.current_ticket)
437
+ self.assertNotIn("related_ticket_preview", obs.current_ticket)
438
+ self.assertIn("context_status", obs.current_ticket)
439
+ self.assertIn(
440
+ "lookup_related_ticket",
441
+ obs.current_ticket["context_status"]["remaining_tools"],
442
+ )
443
+
444
+ obs = env.step(
445
+ HelpdeskTicketAction(
446
+ action_type="investigate",
447
+ tool_name="lookup_related_ticket",
448
+ tool_target_ticket_id=ticket.related_ticket_id,
449
+ )
450
+ )
451
+
452
  self.assertIn("related_ticket_preview", obs.current_ticket)
453
  self.assertEqual(
454
  obs.current_ticket["related_ticket_preview"]["ticket_id"],
 
460
  )
461
 
462
  def test_history_keeps_related_ticket_preview_after_step(self) -> None:
463
+ env, obs, ticket, related = self._reset_linked_ticket_env()
464
+ env.step(
465
+ HelpdeskTicketAction(
466
+ action_type="investigate",
467
+ tool_name="lookup_related_ticket",
468
+ tool_target_ticket_id=ticket.related_ticket_id,
469
+ )
470
+ )
471
+ next_obs = env.step(
472
+ HelpdeskTicketAction(
473
+ issue_type=ticket.issue_type,
474
+ priority=ticket.priority,
475
+ assignment_group=ticket.assignment_group,
476
+ resolution_action=ticket.resolution_action,
477
+ )
478
+ )
479
 
480
  self.assertGreaterEqual(len(next_obs.history), 1)
481
  self.assertIn("related_ticket_preview", next_obs.history[0])
 
623
  self.assertTrue(obs2.last_tool_result["found"])
624
  self.assertGreaterEqual(len(obs2.last_tool_result["matches"]), 1)
625
 
626
+ def test_internal_note_tool_reveals_hidden_hard_task_context(self) -> None:
627
+ from unittest.mock import patch
628
+
629
+ dataset = load_dataset()
630
+ ticket = next((t for t in dataset if t.ticket_id == "TKT-NONDEFAULT-003"), None)
631
+ self.assertIsNotNone(ticket)
632
+
633
+ env = _make_env()
634
+ with patch.object(env, "_dataset", [ticket]):
635
+ with patch.object(env, "_tickets_by_id", {ticket.ticket_id: ticket}):
636
+ obs = env.reset(seed=0, task_id=3, queue_size=1)
637
+
638
+ self.assertNotIn("ambiguity_note", obs.current_ticket)
639
+ obs = env.step(
640
+ HelpdeskTicketAction(
641
+ action_type="investigate",
642
+ tool_name="lookup_internal_routing_note",
643
+ )
644
+ )
645
+ self.assertEqual(obs.last_tool_result["routing_note"], ticket.ambiguity_note)
646
+ self.assertEqual(obs.current_ticket["ambiguity_note"], ticket.ambiguity_note)
647
+ self.assertGreater(obs.reward or 0.0, 0.0)
648
+
649
+ def test_submit_without_required_investigation_gets_shaping_penalty(self) -> None:
650
+ from unittest.mock import patch
651
+
652
+ dataset = load_dataset()
653
+ ticket = next((t for t in dataset if t.ticket_id == "TKT-NONDEFAULT-003"), None)
654
+ self.assertIsNotNone(ticket)
655
+
656
+ env = _make_env()
657
+ with patch.object(env, "_dataset", [ticket]):
658
+ with patch.object(env, "_tickets_by_id", {ticket.ticket_id: ticket}):
659
+ obs = env.reset(seed=0, task_id=3, queue_size=1)
660
+
661
+ final_obs = env.step(
662
+ HelpdeskTicketAction(
663
+ issue_type=ticket.issue_type,
664
+ priority=ticket.priority,
665
+ assignment_group=ticket.assignment_group,
666
+ resolution_action=ticket.resolution_action,
667
+ )
668
+ )
669
+
670
+ self.assertTrue(final_obs.done)
671
+ self.assertIsNotNone(final_obs.rubric_reward)
672
+ self.assertLess(final_obs.reward, final_obs.rubric_reward)
673
+ self.assertGreater(
674
+ final_obs.last_reward_components.get("context_gap_penalty", 0.0),
675
+ 0.0,
676
+ )
677
+
678
 
679
  class TestQueueEconomics(unittest.TestCase):
680
  """Free investigations are allowed, but excessive investigation gets a queue-level penalty."""
tests/test_inference_unit.py CHANGED
@@ -140,6 +140,16 @@ class InferenceUnitTests(unittest.TestCase):
140
  self.assertIsNone(inference.HF_TOKEN)
141
  self.assertFalse(inference.llm_mode_enabled())
142
 
143
  def test_run_uses_only_structured_start_step_end_logs(self) -> None:
144
  inference = _load_inference_module()
145
 
@@ -179,6 +189,311 @@ class InferenceUnitTests(unittest.TestCase):
179
  [1, 2, 3],
180
  )
181
 
182
 
183
  if __name__ == "__main__":
184
  unittest.main()
 
140
  self.assertIsNone(inference.HF_TOKEN)
141
  self.assertFalse(inference.llm_mode_enabled())
142
 
143
+ def test_seed_env_override_is_respected(self) -> None:
144
+ inference = _load_inference_module({"SEED": "7"})
145
+
146
+ self.assertEqual(inference.SEED, 7)
147
+
148
+ def test_invalid_seed_env_falls_back_to_default(self) -> None:
149
+ inference = _load_inference_module({"SEED": "not-an-int"})
150
+
151
+ self.assertEqual(inference.SEED, 42)
152
+
153
  def test_run_uses_only_structured_start_step_end_logs(self) -> None:
154
  inference = _load_inference_module()
155
 
 
189
  [1, 2, 3],
190
  )
191
 
192
+ def test_build_llm_user_message_includes_recent_history_feedback(self) -> None:
193
+ inference = _load_inference_module()
194
+
195
+ ticket = {
196
+ "ticket_id": "ticket-xyz",
197
+ "title": "Contractor onboarding blocked by access issue",
198
+ "requester": "pm@contractorco.com",
199
+ "description": "Access permissions are blocking contractor setup.",
200
+ "context_status": {
201
+ "investigation_required": True,
202
+ "revealed_tools": [],
203
+ "remaining_tools": ["lookup_internal_routing_note"],
204
+ "hints": ["An internal routing note may disambiguate the correct workflow."],
205
+ },
206
+ "last_tool_result": {"tool_name": "lookup_requester_history", "found": False},
207
+ "feedback_summary": "Ticket score=0.40; field_scores[issue_type=0.40]; reward=0.40",
208
+ "last_reward_components": {"ticket_score": 0.4, "final_reward": 0.4},
209
+ "investigation_budget_remaining": 2,
210
+ "average_score_so_far": 0.7,
211
+ "progress_fraction": 0.5,
212
+ "recent_history": [
213
+ {
214
+ "ticket_id": "ticket-prev",
215
+ "predicted": {"issue_type": "identity_access"},
216
+ "score": 0.4,
217
+ "breakdown": {"issue_type": 0.4},
218
+ "penalty_reason": "extra_fields: ['assignment_group']",
219
+ "feedback_summary": "Penalty applied: extra_fields: ['assignment_group']; reward=0.00",
220
+ "reward_components": {"reward_kind": "step_penalty", "final_reward": 0.0},
221
+ }
222
+ ],
223
+ "queue_position": 2,
224
+ "tickets_remaining": 4,
225
+ }
226
+
227
+ message = inference.build_llm_user_message(
228
+ ticket,
229
+ ["issue_type"],
230
+ "Read the ticket and select the single best IT issue type.",
231
+ )
232
+
233
+ self.assertIn("Recent evaluation feedback", message)
234
+ self.assertIn("score=0.4", message)
235
+ self.assertIn("penalty_reason=extra_fields", message)
236
+ self.assertIn("Latest environment feedback", message)
237
+ self.assertIn("Context status", message)
238
+ self.assertIn("Latest reward components", message)
239
+ self.assertIn("Average score so far: 0.7", message)
240
+ self.assertIn("Episode progress: 0.5", message)
241
+ self.assertIn("Investigation budget remaining: 2", message)
242
+ self.assertIn("Investigation result", message)
243
+ self.assertIn("queue_position=2", message)
244
+
245
+ def test_build_action_backfills_missing_fields_from_heuristic(self) -> None:
246
+ inference = _load_inference_module()
247
+ inference.llm_client = object()
248
+
249
+ ticket = {
250
+ "ticket_id": "ticket-018",
251
+ "title": "Question about enterprise tier pricing",
252
+ "requester": "finance@urbanstack.io",
253
+ "description": (
254
+ "We're comparing your enterprise plan against two competitors. "
255
+ "Can you send over a detailed pricing breakdown?"
256
+ ),
257
+ }
258
+
259
+ with mock.patch.object(
260
+ inference,
261
+ "call_llm",
262
+ return_value={"issue_type": "service_request"},
263
+ ):
264
+ action, action_source, fallback_reason = inference.build_action(
265
+ ticket,
266
+ ["issue_type", "priority", "assignment_group", "resolution_action"],
267
+ "Perform full helpdesk routing.",
268
+ )
269
+
270
+ self.assertEqual(action.issue_type, "service_request")
271
+ self.assertEqual(action.priority, "medium")
272
+ self.assertEqual(action.assignment_group, "procurement")
273
+ self.assertEqual(action.resolution_action, "assign")
274
+ self.assertEqual(action_source, "llm_backfilled")
275
+ self.assertIn("heuristic_backfill", fallback_reason or "")
276
+
277
+ def test_build_action_ignores_invalid_llm_fields_and_keeps_valid_ones(self) -> None:
278
+ inference = _load_inference_module()
279
+ inference.llm_client = object()
280
+
281
+ ticket = {
282
+ "ticket_id": "ticket-018",
283
+ "title": "Question about enterprise tier pricing",
284
+ "requester": "finance@urbanstack.io",
285
+ "description": (
286
+ "We're comparing your enterprise plan against two competitors. "
287
+ "Can you send over a detailed pricing breakdown?"
288
+ ),
289
+ }
290
+
291
+ with mock.patch.object(
292
+ inference,
293
+ "call_llm",
294
+ return_value={
295
+ "issue_type": "service_request",
296
+ "priority": "urgent",
297
+ },
298
+ ):
299
+ action, action_source, fallback_reason = inference.build_action(
300
+ ticket,
301
+ ["issue_type", "priority"],
302
+ "Read the ticket, select the best IT issue type, and estimate the priority.",
303
+ )
304
+
305
+ self.assertEqual(action.issue_type, "service_request")
306
+ self.assertEqual(action.priority, "medium")
307
+ self.assertEqual(action_source, "llm_backfilled")
308
+ self.assertIn("invalid_llm_fields=['priority']", fallback_reason or "")
309
+
310
+ def test_build_action_backfills_dependent_fields_from_llm_issue_type(self) -> None:
311
+ inference = _load_inference_module()
312
+ inference.llm_client = object()
313
+
314
+ ticket = {
315
+ "ticket_id": "ticket-002",
316
+ "title": "Can not sign in after 2FA reset",
317
+ "requester": "ops@laneeight.io",
318
+ "description": (
319
+ "I was forced to reset 2FA and now the account stays locked even "
320
+ "with the backup code."
321
+ ),
322
+ }
323
+
324
+ with mock.patch.object(
325
+ inference,
326
+ "call_llm",
327
+ return_value={"issue_type": "identity_access"},
328
+ ):
329
+ action, action_source, fallback_reason = inference.build_action(
330
+ ticket,
331
+ ["issue_type", "assignment_group", "resolution_action"],
332
+ "Perform full helpdesk routing.",
333
+ )
334
+
335
+ self.assertEqual(action.issue_type, "identity_access")
336
+ self.assertEqual(action.assignment_group, "service_desk")
337
+ self.assertEqual(action.resolution_action, "fulfill")
338
+ self.assertEqual(action_source, "llm_backfilled")
339
+ self.assertIn("heuristic_backfill", fallback_reason or "")
340
+
341
+ def test_build_action_normalizes_pricing_request_issue_type(self) -> None:
342
+ inference = _load_inference_module()
343
+ inference.llm_client = object()
344
+
345
+ ticket = {
346
+ "ticket_id": "ticket-018",
347
+ "title": "Question about enterprise tier pricing",
348
+ "requester": "finance@urbanstack.io",
349
+ "description": (
350
+ "We're comparing your enterprise plan against two competitors. "
351
+ "Can you send over a detailed pricing breakdown?"
352
+ ),
353
+ }
354
+
355
+ with mock.patch.object(
356
+ inference,
357
+ "call_llm",
358
+ return_value={
359
+ "issue_type": "billing_license",
360
+ "priority": "medium",
361
+ },
362
+ ):
363
+ action, action_source, fallback_reason = inference.build_action(
364
+ ticket,
365
+ ["issue_type", "priority", "assignment_group", "resolution_action"],
366
+ "Perform full helpdesk routing.",
367
+ )
368
+
369
+ self.assertEqual(action.issue_type, "service_request")
370
+ self.assertEqual(action.assignment_group, "procurement")
371
+ self.assertEqual(action.resolution_action, "assign")
372
+ self.assertEqual(action.priority, "medium")
373
+ self.assertEqual(action_source, "llm_backfilled")
374
+ self.assertIn("domain_overrides", fallback_reason or "")
375
+
376
+ def test_build_action_normalizes_onboarding_access_blocker(self) -> None:
377
+ inference = _load_inference_module()
378
+ inference.llm_client = object()
379
+
380
+ ticket = {
381
+ "ticket_id": "TKT-NONDEFAULT-003",
382
+ "title": "Contractor onboarding blocked by access issue",
383
+ "requester": "pm@contractorco.com",
384
+ "description": (
385
+ "A new contractor cannot complete onboarding because their account "
386
+ "access is blocked by a permissions error. The onboarding team "
387
+ "cannot resolve access issues; routing to service desk."
388
+ ),
389
+ "ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
390
+ }
391
+
392
+ with mock.patch.object(
393
+ inference,
394
+ "call_llm",
395
+ return_value={
396
+ "issue_type": "identity_access",
397
+ "priority": "high",
398
+ },
399
+ ):
400
+ action, action_source, fallback_reason = inference.build_action(
401
+ ticket,
402
+ ["issue_type", "priority", "assignment_group", "resolution_action"],
403
+ "Perform full helpdesk routing.",
404
+ )
405
+
406
+ self.assertEqual(action.issue_type, "onboarding")
407
+ self.assertEqual(action.priority, "medium")
408
+ self.assertEqual(action.assignment_group, "service_desk")
409
+ self.assertEqual(action.resolution_action, "fulfill")
410
+ self.assertEqual(action_source, "llm_backfilled")
411
+ self.assertIn("domain_overrides", fallback_reason or "")
412
+
413
+ def test_build_action_deescalates_nonurgent_onboarding_priority(self) -> None:
414
+ inference = _load_inference_module()
415
+ inference.llm_client = object()
416
+
417
+ ticket = {
418
+ "ticket_id": "ticket-008",
419
+ "title": "Kickoff onboarding session for newly activated account",
420
+ "requester": "admin@brightpath.io",
421
+ "description": (
422
+ "We activated our account this week and need an onboarding call plus "
423
+ "admin setup guidance for six internal users."
424
+ ),
425
+ }
426
+
427
+ with mock.patch.object(
428
+ inference,
429
+ "call_llm",
430
+ return_value={
431
+ "issue_type": "onboarding",
432
+ "priority": "high",
433
+ },
434
+ ):
435
+ action, action_source, fallback_reason = inference.build_action(
436
+ ticket,
437
+ ["issue_type", "priority"],
438
+ "Read the ticket, select the best IT issue type, and estimate the priority.",
439
+ )
440
+
441
+ self.assertEqual(action.issue_type, "onboarding")
442
+ self.assertEqual(action.priority, "medium")
443
+ self.assertEqual(action_source, "llm_backfilled")
444
+ self.assertIn("domain_overrides", fallback_reason or "")
445
+
446
+ def test_merge_ticket_context_carries_feedback_summary_from_observation(self) -> None:
447
+ inference = _load_inference_module()
448
+
449
+ observation = SimpleNamespace(
450
+ last_tool_result={"tool_name": "lookup_requester_history", "found": True},
451
+ history=[{"ticket_id": "ticket-prev", "score": 0.4}],
452
+ queue_position=2,
453
+ tickets_remaining=4,
454
+ investigation_budget_remaining=1,
455
+ average_score_so_far=0.55,
456
+ progress_fraction=0.4,
457
+ last_reward_components={"ticket_score": 0.4, "final_reward": 0.4},
458
+ metadata={"last_feedback_summary": "Ticket score=0.40; reward=0.40"},
459
+ )
460
+
461
+ merged = inference.merge_ticket_context(
462
+ {
463
+ "ticket_id": "ticket-xyz",
464
+ "title": "Contractor onboarding blocked by access issue",
465
+ },
466
+ observation,
467
+ )
468
+
469
+ self.assertEqual(merged["feedback_summary"], "Ticket score=0.40; reward=0.40")
470
+ self.assertEqual(merged["investigation_budget_remaining"], 1)
471
+ self.assertEqual(merged["average_score_so_far"], 0.55)
472
+ self.assertEqual(merged["progress_fraction"], 0.4)
473
+ self.assertEqual(merged["last_reward_components"]["final_reward"], 0.4)
474
+ self.assertEqual(merged["queue_position"], 2)
475
+ self.assertEqual(merged["tickets_remaining"], 4)
476
+ self.assertEqual(merged["last_tool_result"]["tool_name"], "lookup_requester_history")
477
+
478
+ def test_should_investigate_uses_remaining_tools_from_context_status(self) -> None:
479
+ inference = _load_inference_module()
480
+
481
+ investigate, tool_name = inference.should_investigate(
482
+ {
483
+ "ticket_id": "ticket-021",
484
+ "context_status": {
485
+ "remaining_tools": [
486
+ "lookup_related_ticket",
487
+ "lookup_requester_history",
488
+ ]
489
+ },
490
+ },
491
+ [],
492
+ )
493
+
494
+ self.assertTrue(investigate)
495
+ self.assertEqual(tool_name, "lookup_related_ticket")
496
+
497
 
498
  if __name__ == "__main__":
499
  unittest.main()
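
The `SEED` tests in this file expect an integer override from the environment and a fallback to `42` when the value cannot be parsed. The sketch here shows one way to get that behavior. The helper name `read_seed` and its `env` parameter are assumptions for illustration; the actual `inference` module may read the variable differently.

```python
import os
from typing import Mapping, Optional

DEFAULT_SEED = 42  # default value asserted by the tests


def read_seed(env: Optional[Mapping[str, str]] = None, default: int = DEFAULT_SEED) -> int:
    """Read SEED from a mapping (os.environ by default), falling back on bad input."""
    source = os.environ if env is None else env
    raw = source.get("SEED")
    if raw is None:
        return default
    try:
        return int(raw.strip())
    except ValueError:
        # Non-integer values such as "not-an-int" fall back to the default.
        return default
```

Passing the mapping explicitly keeps the helper easy to test without mutating the process environment.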
tests/test_policy_learning.py ADDED
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import types as _types
6
+ import unittest
7
+
8
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
9
+
10
+ import openenv_test_stubs # noqa: F401
11
+
12
+ if "openenv.core.env_server.interfaces" not in sys.modules:
13
+ _interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")
14
+
15
+ class _Environment:
16
+ def __init__(self) -> None:
17
+ pass
18
+
19
+ def __init_subclass__(cls, **kwargs: object) -> None:
20
+ super().__init_subclass__(**kwargs)
21
+
22
+ @classmethod
23
+ def __class_getitem__(cls, item: object) -> type:
24
+ return cls
25
+
26
+ _interfaces_mod.Environment = _Environment # type: ignore[attr-defined]
27
+ sys.modules["openenv.core.env_server.interfaces"] = _interfaces_mod
28
+
29
+
30
+ from models import HelpdeskTicketAction, HelpdeskTicketObservation
31
+ from policy_learning import (
32
+ POLICY_LIBRARY,
33
+ choose_policy_action,
34
+ compare_policies,
35
+ parse_int_spec,
36
+ rollout_episode,
37
+ search_policies,
38
+ )
39
+ from server.environment import HelpdeskTicketRoutingEnvironment
40
+ from server.tasks import get_task_definition
41
+
42
+
43
+ class SingleTicketEnvironment(HelpdeskTicketRoutingEnvironment):
44
+ def __init__(self, ticket_id: str) -> None:
45
+ super().__init__()
46
+ self._forced_ticket_id = ticket_id
47
+
48
+ def reset(self, seed=None, episode_id=None, **kwargs):
49
+ observation = super().reset(seed=seed, episode_id=episode_id, **kwargs)
50
+ ticket = self._tickets_by_id[self._forced_ticket_id]
51
+ self._queue = [ticket]
52
+ self._state.current_task_id = int(kwargs.get("task_id", 3))
53
+ self._state.queue_ticket_ids = [ticket.ticket_id]
54
+ self._state.current_ticket_index = 0
55
+ self._state.per_ticket_scores = []
56
+ self._state.total_reward = 0.0
57
+ self._state.last_step_reward = None
58
+ self._state.reward = None
59
+ self._state.done = False
60
+ self._state.average_score_so_far = 0.0
61
+ self._state.investigation_steps = 0
62
+ self._state.investigation_budget_remaining = len(self._queue)
63
+ self._state.investigation_penalty_applied = 0.0
64
+ self._state.last_tool_result = None
65
+ self._state.last_reward_components = {}
66
+ self._state.ticket_tool_usage = {}
67
+ self._state.history_entries = []
68
+ return self._build_observation(get_task_definition(self._state.current_task_id))
69
+
70
+
71
+ def _context_sensitive_submit_builder(
72
+ ticket: dict[str, object], allowed_fields: list[str]
73
+ ) -> HelpdeskTicketAction:
74
+ if ticket.get("ambiguity_note"):
75
+ values = {
76
+ "issue_type": "onboarding",
77
+ "priority": "medium",
78
+ "assignment_group": "service_desk",
79
+ "resolution_action": "fulfill",
80
+ }
81
+ else:
82
+ values = {
83
+ "issue_type": "identity_access",
84
+ "priority": "high",
85
+ "assignment_group": "service_desk",
86
+ "resolution_action": "fulfill",
87
+ }
88
+ return HelpdeskTicketAction(
89
+ **{field: value for field, value in values.items() if field in allowed_fields}
90
+ )
91
+
92
+
93
+ class PolicyLearningTests(unittest.TestCase):
94
+ def test_parse_int_spec_expands_ranges(self) -> None:
95
+ self.assertEqual(parse_int_spec("42-44,44,46", field_name="seeds"), [42, 43, 44, 46])
96
+
97
+ def test_choose_policy_action_prefers_hidden_context_tools(self) -> None:
98
+ policy = POLICY_LIBRARY["investigate_when_context_hidden"]
99
+ observation = HelpdeskTicketObservation(
100
+ current_ticket={
101
+ "ticket_id": "ticket-021",
102
+ "context_status": {
103
+ "remaining_tools": ["lookup_related_ticket", "lookup_requester_history"],
104
+ "revealed_tools": [],
105
+ }
106
+ },
107
+ allowed_fields=["issue_type"],
108
+ )
109
+
110
+ action, source = choose_policy_action(policy, observation, {}, _context_sensitive_submit_builder)
111
+
112
+ self.assertEqual(action.action_type, "investigate")
113
+ self.assertEqual(action.tool_name, "lookup_related_ticket")
114
+ self.assertEqual(source, "investigate_hidden_context")
115
+
+    def test_choose_policy_action_submits_when_investigation_disabled(self) -> None:
+        policy = POLICY_LIBRARY["no_investigation"]
+        observation = HelpdeskTicketObservation(
+            current_ticket={
+                "ticket_id": "ticket-021",
+                "context_status": {"remaining_tools": ["lookup_related_ticket"]},
+            },
+            allowed_fields=["issue_type", "priority"],
+        )
+
+        action, source = choose_policy_action(policy, observation, {}, _context_sensitive_submit_builder)
+
+        self.assertEqual(action.action_type, "submit")
+        self.assertEqual(action.issue_type, "identity_access")
+        self.assertEqual(source, "submit")
+
+    def test_rollout_episode_rewards_context_aware_policy(self) -> None:
+        no_investigation = POLICY_LIBRARY["no_investigation"]
+        context_aware = POLICY_LIBRARY["investigate_when_context_hidden"]
+
+        no_summary, _ = rollout_episode(
+            env=SingleTicketEnvironment("TKT-NONDEFAULT-003"),
+            policy=no_investigation,
+            seed=42,
+            task_id=3,
+            submit_builder=_context_sensitive_submit_builder,
+        )
+        context_summary, _ = rollout_episode(
+            env=SingleTicketEnvironment("TKT-NONDEFAULT-003"),
+            policy=context_aware,
+            seed=42,
+            task_id=3,
+            submit_builder=_context_sensitive_submit_builder,
+        )
+
+        self.assertLess(no_summary["terminal_reward"], context_summary["terminal_reward"])
+        self.assertLess(no_summary["normalized_return"], context_summary["normalized_return"])
+        self.assertEqual(context_summary["investigation_steps"], 1)
+
+    def test_search_policies_selects_better_policy(self) -> None:
+        report = search_policies(
+            [
+                POLICY_LIBRARY["no_investigation"],
+                POLICY_LIBRARY["investigate_when_context_hidden"],
+            ],
+            train_seeds=[41, 42],
+            eval_seeds=[43],
+            task_ids=[3],
+            output_dir=os.path.join(os.getcwd(), "analysis", "policy_learning_test"),
+            env_factory=lambda: SingleTicketEnvironment("TKT-NONDEFAULT-003"),
+            submit_builder=_context_sensitive_submit_builder,
+        )
+
+        self.assertEqual(report["selected_policy"], "investigate_when_context_hidden")
+        self.assertGreater(
+            report["eval_improvement_vs_baseline"]["avg_normalized_return"],
+            0.0,
+        )
+
+    def test_compare_policies_reports_improvement(self) -> None:
+        report = compare_policies(
+            [
+                POLICY_LIBRARY["no_investigation"],
+                POLICY_LIBRARY["investigate_when_context_hidden"],
+            ],
+            seeds=[42],
+            task_ids=[3],
+            output_dir=os.path.join(os.getcwd(), "analysis", "policy_learning_compare_test"),
+            env_factory=lambda: SingleTicketEnvironment("TKT-NONDEFAULT-003"),
+            submit_builder=_context_sensitive_submit_builder,
+        )
+
+        self.assertEqual(report["best_policy"], "investigate_when_context_hidden")
+        self.assertGreater(report["improvement_vs_baseline"]["avg_terminal_reward"], 0.0)
+
+
+if __name__ == "__main__":
+    unittest.main()