Spaces:
Running
Running
Add policy learning loop and strengthen RL-style environment
Browse files- .gitignore +4 -0
- README.md +55 -0
- inference.py +363 -30
- models.py +10 -0
- policy_learning.py +723 -0
- pyproject.toml +2 -1
- server/environment.py +401 -38
- server/tasks.py +3 -1
- tests/test_api_integration.py +71 -0
- tests/test_competitive_upgrade.py +118 -6
- tests/test_inference_unit.py +315 -0
- tests/test_policy_learning.py +193 -0
.gitignore
CHANGED
|
@@ -6,3 +6,7 @@ __pycache__/
|
|
| 6 |
.mypy_cache/
|
| 7 |
.ruff_cache/
|
| 8 |
build/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
.mypy_cache/
|
| 7 |
.ruff_cache/
|
| 8 |
build/
|
| 9 |
+
analysis/policy_learning_runs/
|
| 10 |
+
analysis/policy_learning_test/
|
| 11 |
+
analysis/policy_learning_compare_test/
|
| 12 |
+
analysis/policy_learning_runs_smoke/
|
README.md
CHANGED
|
@@ -38,6 +38,8 @@ The environment models a realistic helpdesk workflow:
|
|
| 38 |
4. the grader assigns deterministic credit
|
| 39 |
5. the environment advances to the next ticket until the queue is complete
|
| 40 |
|
|
|
|
|
|
|
| 41 |
This domain is useful for OpenEnv because it is operationally realistic, easy to evaluate with typed outputs, and naturally supports a clean easy-to-hard task ladder.
|
| 42 |
|
| 43 |
## Why This Is A Good Hackathon Domain
|
|
@@ -59,6 +61,37 @@ The project uses a queue-based episode model.
|
|
| 59 |
|
| 60 |
The environment classes and vocabulary are intentionally frozen to keep collaboration and judging simple.
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
## Task Ladder
|
| 63 |
|
| 64 |
| ID | Name | Difficulty | Required Fields | What The Agent Must Do |
|
|
@@ -125,6 +158,7 @@ Each observation also includes:
|
|
| 125 |
- `task_name`
|
| 126 |
- `instructions`
|
| 127 |
- `allowed_fields`
|
|
|
|
| 128 |
- `available_tools`
|
| 129 |
- `investigation_budget_remaining`
|
| 130 |
- `last_tool_result`
|
|
@@ -133,7 +167,12 @@ Each observation also includes:
|
|
| 133 |
- `tickets_after_current`
|
| 134 |
- `tickets_processed`
|
| 135 |
- `queue_position`
|
|
|
|
|
|
|
| 136 |
- `history`
|
|
|
|
|
|
|
|
|
|
| 137 |
- standard OpenEnv fields such as `done` and `reward`
|
| 138 |
|
| 139 |
The internal `HelpdeskTicketState` tracks:
|
|
@@ -162,6 +201,15 @@ Available tools:
|
|
| 162 |
|
| 163 |
- `lookup_related_ticket`
|
| 164 |
- `lookup_requester_history`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
Per-field behavior:
|
| 167 |
|
|
@@ -190,6 +238,12 @@ Step reward is lightly milestone-shaped: high per-ticket scores get a small bonu
|
|
| 190 |
|
| 191 |
Final reward also includes a tiny queue-economics penalty only when the agent exceeds the free investigation budget. One investigation per queued ticket is free; extra investigation steps reduce the final reward slightly.
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
## Grounded Scoring
|
| 194 |
|
| 195 |
The grader is intentionally not fuzzy by default.
|
|
@@ -343,6 +397,7 @@ Optional target:
|
|
| 343 |
|
| 344 |
- `ENV_URL`
|
| 345 |
- default value: `http://localhost:7860`
|
|
|
|
| 346 |
- `TASK_ID`
|
| 347 |
- `RUN_ALL_TASKS`
|
| 348 |
|
|
|
|
| 38 |
4. the grader assigns deterministic credit
|
| 39 |
5. the environment advances to the next ticket until the queue is complete
|
| 40 |
|
| 41 |
+
For hard-task tickets, the environment can now withhold decisive routing context until the agent uses the right investigation tool. That keeps the task from collapsing into one-shot classification and makes tool choice part of the policy.
|
| 42 |
+
|
| 43 |
This domain is useful for OpenEnv because it is operationally realistic, easy to evaluate with typed outputs, and naturally supports a clean easy-to-hard task ladder.
|
| 44 |
|
| 45 |
## Why This Is A Good Hackathon Domain
|
|
|
|
| 61 |
|
| 62 |
The environment classes and vocabulary are intentionally frozen to keep collaboration and judging simple.
|
| 63 |
|
| 64 |
+
## Lightweight Policy Improvement Loop
|
| 65 |
+
|
| 66 |
+
The repo now includes a small local learning runner in `policy_learning.py`. It does not update model weights, but it does run repeated rollouts over many seeds, log full trajectories, and select the best policy configuration from a discrete candidate set using observed reward.
|
| 67 |
+
|
| 68 |
+
That gives the project a real improvement loop for judge demos:
|
| 69 |
+
|
| 70 |
+
- compare `no_investigation` against `investigate_when_context_hidden`
|
| 71 |
+
- log per-step rewards, feedback summaries, and reward components to JSONL
|
| 72 |
+
- search over small policy variants such as `legacy_single_probe`, `context_chain`, and `hybrid_context`
|
| 73 |
+
- select the best policy on train seeds, then re-evaluate it on holdout seeds
|
| 74 |
+
|
| 75 |
+
Example commands:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
python policy_learning.py compare --seeds 42-51 --task-ids 1,2,3
|
| 79 |
+
python policy_learning.py search --train-seeds 40-49 --eval-seeds 50-59 --task-ids 1,2,3
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
Artifacts are written to `analysis/policy_learning_runs/` by default:
|
| 83 |
+
|
| 84 |
+
- `compare_summary.json`
|
| 85 |
+
- `compare_episodes.jsonl`
|
| 86 |
+
- `compare_trajectories.jsonl`
|
| 87 |
+
- `search_summary.json`
|
| 88 |
+
- `search_train_episodes.jsonl`
|
| 89 |
+
- `search_train_trajectories.jsonl`
|
| 90 |
+
- `search_eval_episodes.jsonl`
|
| 91 |
+
- `search_eval_trajectories.jsonl`
|
| 92 |
+
|
| 93 |
+
The default submit policy inside this runner stays deterministic and local. It reuses the repo's heuristic routing logic, so the discrete policy search focuses on investigation behavior and reward-driven policy selection rather than on external LLM latency or API cost.
|
| 94 |
+
|
| 95 |
## Task Ladder
|
| 96 |
|
| 97 |
| ID | Name | Difficulty | Required Fields | What The Agent Must Do |
|
|
|
|
| 158 |
- `task_name`
|
| 159 |
- `instructions`
|
| 160 |
- `allowed_fields`
|
| 161 |
+
- `available_action_types`
|
| 162 |
- `available_tools`
|
| 163 |
- `investigation_budget_remaining`
|
| 164 |
- `last_tool_result`
|
|
|
|
| 167 |
- `tickets_after_current`
|
| 168 |
- `tickets_processed`
|
| 169 |
- `queue_position`
|
| 170 |
+
- `average_score_so_far`
|
| 171 |
+
- `progress_fraction`
|
| 172 |
- `history`
|
| 173 |
+
- `last_reward_components`
|
| 174 |
+
- `rubric_reward` on terminal observations
|
| 175 |
+
- `metadata.last_feedback_summary` for compact reward / penalty feedback
|
| 176 |
- standard OpenEnv fields such as `done` and `reward`
|
| 177 |
|
| 178 |
The internal `HelpdeskTicketState` tracks:
|
|
|
|
| 201 |
|
| 202 |
- `lookup_related_ticket`
|
| 203 |
- `lookup_requester_history`
|
| 204 |
+
- `lookup_internal_routing_note`
|
| 205 |
+
|
| 206 |
+
Hard-task investigation behavior:
|
| 207 |
+
|
| 208 |
+
- some ambiguous and non-default-routing tickets start with redacted descriptions
|
| 209 |
+
- linked-ticket previews and internal routing notes stay hidden until the matching tool is used
|
| 210 |
+
- useful investigation steps return a small positive shaping reward
|
| 211 |
+
- premature hard-task submission can incur a shaping penalty even when the visible text looks plausible
|
| 212 |
+
- terminal `rubric_reward` remains the objective evaluation signal, while per-step `reward` is the denser training signal
|
| 213 |
|
| 214 |
Per-field behavior:
|
| 215 |
|
|
|
|
| 238 |
|
| 239 |
Final reward also includes a tiny queue-economics penalty only when the agent exceeds the free investigation budget. One investigation per queued ticket is free; extra investigation steps reduce the final reward slightly.
|
| 240 |
|
| 241 |
+
To make the environment more RL-friendly, each observation now also surfaces structured reward telemetry:
|
| 242 |
+
|
| 243 |
+
- `last_reward_components` exposes ticket score, shaped step reward, milestone adjustment, trajectory reward when applicable, and any investigation penalty applied
|
| 244 |
+
- `average_score_so_far` and `progress_fraction` expose trajectory progress without leaking future labels
|
| 245 |
+
- `history` retains the same reward components plus a compact `feedback_summary` string for downstream agents
|
| 246 |
+
|
| 247 |
## Grounded Scoring
|
| 248 |
|
| 249 |
The grader is intentionally not fuzzy by default.
|
|
|
|
| 397 |
|
| 398 |
- `ENV_URL`
|
| 399 |
- default value: `http://localhost:7860`
|
| 400 |
+
- `SEED`
|
| 401 |
- `TASK_ID`
|
| 402 |
- `RUN_ALL_TASKS`
|
| 403 |
|
inference.py
CHANGED
|
@@ -66,13 +66,27 @@ from vocabulary import (
|
|
| 66 |
DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
|
| 67 |
DEFAULT_MODEL_NAME = "<your-active-model>"
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
|
| 70 |
MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)
|
| 71 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 72 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 73 |
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 74 |
|
| 75 |
-
SEED = 42
|
| 76 |
TASK_ID_ENV = os.getenv("TASK_ID")
|
| 77 |
RUN_ALL_TASKS_ENV = os.getenv("RUN_ALL_TASKS", "").strip().lower() in {
|
| 78 |
"1",
|
|
@@ -94,6 +108,14 @@ if llm_mode_enabled():
|
|
| 94 |
llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 95 |
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
SYSTEM_PROMPT = """\
|
| 98 |
You are an expert IT helpdesk ticket routing agent. Given a helpdesk ticket, you must produce a JSON object with the requested fields.
|
| 99 |
|
|
@@ -103,19 +125,79 @@ Valid values:
|
|
| 103 |
- assignment_group: {assignment_groups}
|
| 104 |
- resolution_action: {resolution_actions}
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
Return ONLY valid JSON with the requested fields. No markdown, no explanation.""".format(
|
| 107 |
issue_types=", ".join(ISSUE_TYPES),
|
| 108 |
priorities=", ".join(PRIORITIES),
|
| 109 |
assignment_groups=", ".join(ASSIGNMENT_GROUPS),
|
| 110 |
resolution_actions=", ".join(RESOLUTION_ACTIONS),
|
|
|
|
| 111 |
)
|
| 112 |
|
| 113 |
|
| 114 |
-
def
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
ambiguity_note = ticket.get("ambiguity_note")
|
| 117 |
related_preview = ticket.get("related_ticket_preview") or {}
|
| 118 |
last_tool_result = ticket.get("last_tool_result")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
extra_context_lines: list[str] = []
|
| 120 |
if ambiguity_note:
|
| 121 |
extra_context_lines.append(f"Ambiguity note: {ambiguity_note}")
|
|
@@ -132,20 +214,53 @@ def call_llm(ticket: dict, allowed_fields: list[str], instructions: str) -> dict
|
|
| 132 |
extra_context_lines.append(
|
| 133 |
"Investigation result: " + json.dumps(last_tool_result, sort_keys=True)
|
| 134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
extra_context_block = ""
|
| 136 |
if extra_context_lines:
|
| 137 |
extra_context_block = "\n" + "\n".join(extra_context_lines)
|
| 138 |
|
| 139 |
-
|
| 140 |
f"Instructions: {instructions}\n\n"
|
| 141 |
f"Allowed fields: {', '.join(allowed_fields)}\n\n"
|
| 142 |
-
f"Title: {ticket
|
| 143 |
-
f"Requester: {ticket
|
| 144 |
-
f"Description: {ticket
|
| 145 |
f"{extra_context_block}\n\n"
|
| 146 |
f"Respond with JSON containing ONLY these fields: {', '.join(allowed_fields)}"
|
| 147 |
)
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
response = llm_client.chat.completions.create(
|
| 150 |
model=MODEL_NAME,
|
| 151 |
messages=[
|
|
@@ -298,6 +413,95 @@ FULFILL_KEYWORDS = (
|
|
| 298 |
"mfa enabled",
|
| 299 |
)
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
def heuristic_priority(text: str) -> str:
|
| 303 |
if any(word in text for word in CRITICAL_PRIORITY_KEYWORDS):
|
|
@@ -323,26 +527,32 @@ def heuristic_resolution_action(text: str, issue_type: str) -> str:
|
|
| 323 |
return ISSUE_TYPE_TO_RESOLUTION_ACTION.get(issue_type, "acknowledge")
|
| 324 |
|
| 325 |
|
| 326 |
-
def
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
related_preview.get("description", ""),
|
| 336 |
-
json.dumps(last_tool_result, sort_keys=True),
|
| 337 |
-
]
|
| 338 |
-
).lower()
|
| 339 |
|
|
|
|
|
|
|
| 340 |
issue_type = "general_inquiry"
|
| 341 |
for kw, mapped_issue_type in KEYWORD_ISSUE_TYPES.items():
|
| 342 |
if kw in text:
|
| 343 |
issue_type = mapped_issue_type
|
| 344 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
|
|
|
| 346 |
priority = heuristic_priority(text)
|
| 347 |
resolution_action = heuristic_resolution_action(text, issue_type)
|
| 348 |
|
|
@@ -352,14 +562,75 @@ def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
|
|
| 352 |
if "priority" in allowed_fields:
|
| 353 |
result["priority"] = priority
|
| 354 |
if "assignment_group" in allowed_fields:
|
| 355 |
-
result["assignment_group"] =
|
| 356 |
-
issue_type, "service_desk"
|
| 357 |
-
)
|
| 358 |
if "resolution_action" in allowed_fields:
|
| 359 |
result["resolution_action"] = resolution_action
|
| 360 |
return result
|
| 361 |
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
def build_action(
|
| 364 |
ticket: dict, allowed_fields: list[str], instructions: str
|
| 365 |
) -> tuple[HelpdeskTicketAction, str, str | None]:
|
|
@@ -370,13 +641,50 @@ def build_action(
|
|
| 370 |
|
| 371 |
try:
|
| 372 |
llm_dict = call_llm(ticket, allowed_fields, instructions)
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
raise ValueError("LLM returned no allowed fields")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
return HelpdeskTicketAction(**candidate), "llm", None
|
| 381 |
except Exception as exc:
|
| 382 |
return (
|
|
@@ -389,6 +697,10 @@ def build_action(
|
|
| 389 |
def should_investigate(ticket: dict, history: list[dict[str, Any]]) -> tuple[bool, str | None]:
|
| 390 |
if not ticket:
|
| 391 |
return False, None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
current_ticket_id = ticket.get("ticket_id")
|
| 393 |
already_investigated = any(
|
| 394 |
entry.get("ticket_id") == current_ticket_id
|
|
@@ -408,6 +720,22 @@ def merge_ticket_context(ticket: dict, observation: Any) -> dict:
|
|
| 408 |
merged_ticket = dict(ticket)
|
| 409 |
if getattr(observation, "last_tool_result", None) is not None:
|
| 410 |
merged_ticket["last_tool_result"] = observation.last_tool_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
return merged_ticket
|
| 412 |
|
| 413 |
|
|
@@ -518,7 +846,12 @@ def run() -> None:
|
|
| 518 |
ticket_id=ticket["ticket_id"],
|
| 519 |
)
|
| 520 |
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
all_results[task_id] = {
|
| 523 |
"final_reward": final_reward,
|
| 524 |
"step_count": step_num,
|
|
|
|
| 66 |
DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
|
| 67 |
DEFAULT_MODEL_NAME = "<your-active-model>"
|
| 68 |
|
| 69 |
+
|
| 70 |
+
def _get_int_env(name: str, default: int) -> int:
|
| 71 |
+
raw_value = os.getenv(name)
|
| 72 |
+
if raw_value is None or raw_value.strip() == "":
|
| 73 |
+
return default
|
| 74 |
+
try:
|
| 75 |
+
return int(raw_value)
|
| 76 |
+
except ValueError:
|
| 77 |
+
print(
|
| 78 |
+
f"[WARN] {name}={raw_value!r} is not a valid integer; using {default}.",
|
| 79 |
+
flush=True,
|
| 80 |
+
)
|
| 81 |
+
return default
|
| 82 |
+
|
| 83 |
API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
|
| 84 |
MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)
|
| 85 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 86 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 87 |
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
|
| 88 |
|
| 89 |
+
SEED = _get_int_env("SEED", 42)
|
| 90 |
TASK_ID_ENV = os.getenv("TASK_ID")
|
| 91 |
RUN_ALL_TASKS_ENV = os.getenv("RUN_ALL_TASKS", "").strip().lower() in {
|
| 92 |
"1",
|
|
|
|
| 108 |
llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 109 |
|
| 110 |
|
| 111 |
+
RECENT_HISTORY_LIMIT = 2
|
| 112 |
+
ROUTING_PRIORS = "\n".join(
|
| 113 |
+
f"- {issue_type}: assignment_group={ISSUE_TYPE_TO_ASSIGNMENT_GROUP[issue_type]}, "
|
| 114 |
+
f"resolution_action={ISSUE_TYPE_TO_RESOLUTION_ACTION[issue_type]}"
|
| 115 |
+
for issue_type in ISSUE_TYPES
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
SYSTEM_PROMPT = """\
|
| 120 |
You are an expert IT helpdesk ticket routing agent. Given a helpdesk ticket, you must produce a JSON object with the requested fields.
|
| 121 |
|
|
|
|
| 125 |
- assignment_group: {assignment_groups}
|
| 126 |
- resolution_action: {resolution_actions}
|
| 127 |
|
| 128 |
+
Decision rules:
|
| 129 |
+
- Follow this environment's label ontology exactly; do not invent categories.
|
| 130 |
+
- Prefer the primary operational workflow label over a secondary technical symptom.
|
| 131 |
+
- Keep assignment_group and resolution_action consistent with the chosen issue_type unless the ticket explicitly justifies a different choice.
|
| 132 |
+
- Use investigation results and recent evaluation feedback when provided.
|
| 133 |
+
|
| 134 |
+
Domain conventions:
|
| 135 |
+
- Enterprise pricing, quotes, plan comparisons, and commercial procurement requests map to service_request, usually with medium priority.
|
| 136 |
+
- Onboarding work that is blocked by an access problem still maps to onboarding when the primary workflow is onboarding; the assignment_group may still be service_desk if the ticket says onboarding cannot resolve the access issue.
|
| 137 |
+
- Single-user sign-in, login, MFA, or 2FA lockouts map to identity_access and are usually high priority, not critical.
|
| 138 |
+
- Reserve critical priority for outages, widespread business blockers, or explicit urgent critical incidents.
|
| 139 |
+
|
| 140 |
+
Routing priors:
|
| 141 |
+
{routing_priors}
|
| 142 |
+
|
| 143 |
Return ONLY valid JSON with the requested fields. No markdown, no explanation.""".format(
|
| 144 |
issue_types=", ".join(ISSUE_TYPES),
|
| 145 |
priorities=", ".join(PRIORITIES),
|
| 146 |
assignment_groups=", ".join(ASSIGNMENT_GROUPS),
|
| 147 |
resolution_actions=", ".join(RESOLUTION_ACTIONS),
|
| 148 |
+
routing_priors=ROUTING_PRIORS,
|
| 149 |
)
|
| 150 |
|
| 151 |
|
| 152 |
+
def format_recent_history_entries(
|
| 153 |
+
history: list[dict[str, Any]], limit: int = RECENT_HISTORY_LIMIT
|
| 154 |
+
) -> str:
|
| 155 |
+
if not history:
|
| 156 |
+
return ""
|
| 157 |
+
|
| 158 |
+
lines = ["Recent evaluation feedback (latest last):"]
|
| 159 |
+
for entry in history[-limit:]:
|
| 160 |
+
predicted = json.dumps(entry.get("predicted", {}), sort_keys=True)
|
| 161 |
+
line = (
|
| 162 |
+
f"- Ticket {entry.get('ticket_id', '?')}: predicted={predicted}, "
|
| 163 |
+
f"score={entry.get('score', 0.0)}"
|
| 164 |
+
)
|
| 165 |
+
feedback_summary = entry.get("feedback_summary")
|
| 166 |
+
if feedback_summary:
|
| 167 |
+
line += f", feedback={feedback_summary}"
|
| 168 |
+
reward = entry.get("reward")
|
| 169 |
+
if reward is not None:
|
| 170 |
+
line += f", reward={reward}"
|
| 171 |
+
rubric_reward = entry.get("rubric_reward")
|
| 172 |
+
if rubric_reward is not None:
|
| 173 |
+
line += f", rubric_reward={rubric_reward}"
|
| 174 |
+
breakdown = entry.get("breakdown") or {}
|
| 175 |
+
if breakdown:
|
| 176 |
+
line += f", breakdown={json.dumps(breakdown, sort_keys=True)}"
|
| 177 |
+
penalty_reason = entry.get("penalty_reason")
|
| 178 |
+
if penalty_reason:
|
| 179 |
+
line += f", penalty_reason={penalty_reason}"
|
| 180 |
+
tool_result = entry.get("tool_result")
|
| 181 |
+
if tool_result is not None:
|
| 182 |
+
line += f", tool_result={json.dumps(tool_result, sort_keys=True)}"
|
| 183 |
+
reward_components = entry.get("reward_components")
|
| 184 |
+
if reward_components:
|
| 185 |
+
line += f", reward_components={json.dumps(reward_components, sort_keys=True)}"
|
| 186 |
+
lines.append(line)
|
| 187 |
+
return "\n".join(lines)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def build_llm_user_message(ticket: dict, allowed_fields: list[str], instructions: str) -> str:
|
| 191 |
ambiguity_note = ticket.get("ambiguity_note")
|
| 192 |
related_preview = ticket.get("related_ticket_preview") or {}
|
| 193 |
last_tool_result = ticket.get("last_tool_result")
|
| 194 |
+
context_status = ticket.get("context_status") or {}
|
| 195 |
+
recent_history = ticket.get("recent_history") or []
|
| 196 |
+
feedback_summary = ticket.get("feedback_summary")
|
| 197 |
+
last_reward_components = ticket.get("last_reward_components") or {}
|
| 198 |
+
investigation_budget_remaining = ticket.get("investigation_budget_remaining")
|
| 199 |
+
average_score_so_far = ticket.get("average_score_so_far")
|
| 200 |
+
progress_fraction = ticket.get("progress_fraction")
|
| 201 |
extra_context_lines: list[str] = []
|
| 202 |
if ambiguity_note:
|
| 203 |
extra_context_lines.append(f"Ambiguity note: {ambiguity_note}")
|
|
|
|
| 214 |
extra_context_lines.append(
|
| 215 |
"Investigation result: " + json.dumps(last_tool_result, sort_keys=True)
|
| 216 |
)
|
| 217 |
+
if context_status:
|
| 218 |
+
extra_context_lines.append(
|
| 219 |
+
"Context status: " + json.dumps(context_status, sort_keys=True)
|
| 220 |
+
)
|
| 221 |
+
if feedback_summary:
|
| 222 |
+
extra_context_lines.append(f"Latest environment feedback: {feedback_summary}")
|
| 223 |
+
if last_reward_components:
|
| 224 |
+
extra_context_lines.append(
|
| 225 |
+
"Latest reward components: "
|
| 226 |
+
+ json.dumps(last_reward_components, sort_keys=True)
|
| 227 |
+
)
|
| 228 |
+
recent_history_block = format_recent_history_entries(recent_history)
|
| 229 |
+
if recent_history_block:
|
| 230 |
+
extra_context_lines.append(recent_history_block)
|
| 231 |
+
queue_position = ticket.get("queue_position")
|
| 232 |
+
tickets_remaining = ticket.get("tickets_remaining")
|
| 233 |
+
if queue_position is not None and tickets_remaining is not None:
|
| 234 |
+
extra_context_lines.append(
|
| 235 |
+
f"Queue context: queue_position={queue_position}, tickets_remaining={tickets_remaining}"
|
| 236 |
+
)
|
| 237 |
+
if average_score_so_far is not None:
|
| 238 |
+
extra_context_lines.append(f"Average score so far: {average_score_so_far}")
|
| 239 |
+
if progress_fraction is not None:
|
| 240 |
+
extra_context_lines.append(f"Episode progress: {progress_fraction}")
|
| 241 |
+
if investigation_budget_remaining is not None:
|
| 242 |
+
extra_context_lines.append(
|
| 243 |
+
f"Investigation budget remaining: {investigation_budget_remaining}"
|
| 244 |
+
)
|
| 245 |
extra_context_block = ""
|
| 246 |
if extra_context_lines:
|
| 247 |
extra_context_block = "\n" + "\n".join(extra_context_lines)
|
| 248 |
|
| 249 |
+
return (
|
| 250 |
f"Instructions: {instructions}\n\n"
|
| 251 |
f"Allowed fields: {', '.join(allowed_fields)}\n\n"
|
| 252 |
+
f"Title: {ticket.get('title', '')}\n"
|
| 253 |
+
f"Requester: {ticket.get('requester', '')}\n"
|
| 254 |
+
f"Description: {ticket.get('description', '')}"
|
| 255 |
f"{extra_context_block}\n\n"
|
| 256 |
f"Respond with JSON containing ONLY these fields: {', '.join(allowed_fields)}"
|
| 257 |
)
|
| 258 |
|
| 259 |
+
|
| 260 |
+
def call_llm(ticket: dict, allowed_fields: list[str], instructions: str) -> dict:
|
| 261 |
+
assert llm_client is not None, "LLM client not configured"
|
| 262 |
+
user_msg = build_llm_user_message(ticket, allowed_fields, instructions)
|
| 263 |
+
|
| 264 |
response = llm_client.chat.completions.create(
|
| 265 |
model=MODEL_NAME,
|
| 266 |
messages=[
|
|
|
|
| 413 |
"mfa enabled",
|
| 414 |
)
|
| 415 |
|
| 416 |
+
PRICING_REQUEST_KEYWORDS = (
|
| 417 |
+
"pricing breakdown",
|
| 418 |
+
"enterprise tier pricing",
|
| 419 |
+
"enterprise plan",
|
| 420 |
+
"compare your enterprise plan",
|
| 421 |
+
"comparing your enterprise plan",
|
| 422 |
+
"quote",
|
| 423 |
+
"pricing quote",
|
| 424 |
+
"commercial proposal",
|
| 425 |
+
"vendor comparison",
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
ONBOARDING_WORKFLOW_KEYWORDS = (
|
| 429 |
+
"onboarding",
|
| 430 |
+
"new hire",
|
| 431 |
+
"contractor",
|
| 432 |
+
"provisioned",
|
| 433 |
+
"kickoff onboarding",
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
ACCESS_BLOCKER_KEYWORDS = (
|
| 437 |
+
"access issue",
|
| 438 |
+
"permissions error",
|
| 439 |
+
"permission error",
|
| 440 |
+
"account access is blocked",
|
| 441 |
+
"cannot sign in",
|
| 442 |
+
"can't sign in",
|
| 443 |
+
"locked",
|
| 444 |
+
"2fa",
|
| 445 |
+
"mfa",
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
SERVICE_DESK_ONBOARDING_ESCALATION_KEYWORDS = (
|
| 449 |
+
"onboarding team cannot resolve access issues",
|
| 450 |
+
"routing to service desk",
|
| 451 |
+
"route to service desk",
|
| 452 |
+
"service desk",
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
CRITICAL_INCIDENT_KEYWORDS = (
|
| 456 |
+
"outage",
|
| 457 |
+
"company-wide",
|
| 458 |
+
"all users",
|
| 459 |
+
"widespread",
|
| 460 |
+
"production down",
|
| 461 |
+
"critical incident",
|
| 462 |
+
"sev1",
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
HIGH_PRIORITY_SIGNAL_KEYWORDS = (
|
| 466 |
+
"locked",
|
| 467 |
+
"blocked",
|
| 468 |
+
"cannot sign in",
|
| 469 |
+
"can't sign in",
|
| 470 |
+
"2fa",
|
| 471 |
+
"mfa",
|
| 472 |
+
"expedite",
|
| 473 |
+
"start monday",
|
| 474 |
+
"asap",
|
| 475 |
+
"today",
|
| 476 |
+
"eod",
|
| 477 |
+
"urgent",
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
TIME_SENSITIVE_PRIORITY_KEYWORDS = (
|
| 481 |
+
"expedite",
|
| 482 |
+
"start monday",
|
| 483 |
+
"today",
|
| 484 |
+
"asap",
|
| 485 |
+
"eod",
|
| 486 |
+
"urgent",
|
| 487 |
+
"immediately",
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def build_routing_text(ticket: dict) -> str:
|
| 492 |
+
related_preview = ticket.get("related_ticket_preview") or {}
|
| 493 |
+
last_tool_result = ticket.get("last_tool_result") or {}
|
| 494 |
+
return " ".join(
|
| 495 |
+
[
|
| 496 |
+
ticket.get("title", ""),
|
| 497 |
+
ticket.get("description", ""),
|
| 498 |
+
ticket.get("ambiguity_note", ""),
|
| 499 |
+
related_preview.get("title", ""),
|
| 500 |
+
related_preview.get("description", ""),
|
| 501 |
+
json.dumps(last_tool_result, sort_keys=True),
|
| 502 |
+
]
|
| 503 |
+
).lower()
|
| 504 |
+
|
| 505 |
|
| 506 |
def heuristic_priority(text: str) -> str:
|
| 507 |
if any(word in text for word in CRITICAL_PRIORITY_KEYWORDS):
|
|
|
|
| 527 |
return ISSUE_TYPE_TO_RESOLUTION_ACTION.get(issue_type, "acknowledge")
|
| 528 |
|
| 529 |
|
| 530 |
+
def heuristic_assignment_group(text: str, issue_type: str) -> str:
|
| 531 |
+
if issue_type == "onboarding":
|
| 532 |
+
if any(keyword in text for keyword in SERVICE_DESK_ONBOARDING_ESCALATION_KEYWORDS):
|
| 533 |
+
return "service_desk"
|
| 534 |
+
if any(keyword in text for keyword in ACCESS_BLOCKER_KEYWORDS) and any(
|
| 535 |
+
keyword in text for keyword in ONBOARDING_WORKFLOW_KEYWORDS
|
| 536 |
+
):
|
| 537 |
+
return "service_desk"
|
| 538 |
+
return ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(issue_type, "service_desk")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
|
| 540 |
+
|
| 541 |
+
def infer_issue_type(text: str) -> str:
|
| 542 |
issue_type = "general_inquiry"
|
| 543 |
for kw, mapped_issue_type in KEYWORD_ISSUE_TYPES.items():
|
| 544 |
if kw in text:
|
| 545 |
issue_type = mapped_issue_type
|
| 546 |
break
|
| 547 |
+
return issue_type
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def heuristic_action(
|
| 551 |
+
ticket: dict, allowed_fields: list[str], issue_type_override: str | None = None
|
| 552 |
+
) -> dict:
|
| 553 |
+
text = build_routing_text(ticket)
|
| 554 |
|
| 555 |
+
issue_type = issue_type_override or infer_issue_type(text)
|
| 556 |
priority = heuristic_priority(text)
|
| 557 |
resolution_action = heuristic_resolution_action(text, issue_type)
|
| 558 |
|
|
|
|
| 562 |
if "priority" in allowed_fields:
|
| 563 |
result["priority"] = priority
|
| 564 |
if "assignment_group" in allowed_fields:
|
| 565 |
+
result["assignment_group"] = heuristic_assignment_group(text, issue_type)
|
|
|
|
|
|
|
| 566 |
if "resolution_action" in allowed_fields:
|
| 567 |
result["resolution_action"] = resolution_action
|
| 568 |
return result
|
| 569 |
|
| 570 |
|
| 571 |
+
def apply_domain_overrides(
|
| 572 |
+
ticket: dict, candidate: dict[str, Any], allowed_fields: list[str]
|
| 573 |
+
) -> tuple[dict[str, Any], list[str]]:
|
| 574 |
+
updated = dict(candidate)
|
| 575 |
+
reasons: list[str] = []
|
| 576 |
+
text = build_routing_text(ticket)
|
| 577 |
+
|
| 578 |
+
issue_type = updated.get("issue_type")
|
| 579 |
+
if "issue_type" in allowed_fields and issue_type is not None:
|
| 580 |
+
if (
|
| 581 |
+
issue_type in {"billing_license", "general_inquiry"}
|
| 582 |
+
and any(keyword in text for keyword in PRICING_REQUEST_KEYWORDS)
|
| 583 |
+
):
|
| 584 |
+
updated["issue_type"] = "service_request"
|
| 585 |
+
issue_type = "service_request"
|
| 586 |
+
reasons.append("override_issue_type=service_request(pricing_request)")
|
| 587 |
+
elif (
|
| 588 |
+
issue_type == "identity_access"
|
| 589 |
+
and any(keyword in text for keyword in ONBOARDING_WORKFLOW_KEYWORDS)
|
| 590 |
+
and any(keyword in text for keyword in ACCESS_BLOCKER_KEYWORDS)
|
| 591 |
+
):
|
| 592 |
+
updated["issue_type"] = "onboarding"
|
| 593 |
+
issue_type = "onboarding"
|
| 594 |
+
reasons.append("override_issue_type=onboarding(onboarding_access_blocker)")
|
| 595 |
+
|
| 596 |
+
if issue_type is not None:
|
| 597 |
+
if "assignment_group" in allowed_fields:
|
| 598 |
+
desired_group = heuristic_assignment_group(text, issue_type)
|
| 599 |
+
if updated.get("assignment_group") != desired_group:
|
| 600 |
+
updated["assignment_group"] = desired_group
|
| 601 |
+
reasons.append(f"override_assignment_group={desired_group}")
|
| 602 |
+
if "resolution_action" in allowed_fields:
|
| 603 |
+
desired_resolution = heuristic_resolution_action(text, issue_type)
|
| 604 |
+
if updated.get("resolution_action") != desired_resolution:
|
| 605 |
+
updated["resolution_action"] = desired_resolution
|
| 606 |
+
reasons.append(f"override_resolution_action={desired_resolution}")
|
| 607 |
+
|
| 608 |
+
if "priority" in allowed_fields and updated.get("priority") is not None:
|
| 609 |
+
priority = updated["priority"]
|
| 610 |
+
has_critical_signal = any(keyword in text for keyword in CRITICAL_INCIDENT_KEYWORDS)
|
| 611 |
+
has_high_signal = any(keyword in text for keyword in HIGH_PRIORITY_SIGNAL_KEYWORDS)
|
| 612 |
+
if priority == "critical" and not has_critical_signal:
|
| 613 |
+
updated["priority"] = "high" if has_high_signal else "medium"
|
| 614 |
+
reasons.append(f"override_priority={updated['priority']}(deescalated_from_critical)")
|
| 615 |
+
elif (
|
| 616 |
+
priority == "high"
|
| 617 |
+
and issue_type in {"service_request", "onboarding"}
|
| 618 |
+
and not any(keyword in text for keyword in TIME_SENSITIVE_PRIORITY_KEYWORDS)
|
| 619 |
+
):
|
| 620 |
+
updated["priority"] = "medium"
|
| 621 |
+
reasons.append("override_priority=medium(nonurgent_workflow_request)")
|
| 622 |
+
elif (
|
| 623 |
+
priority == "medium"
|
| 624 |
+
and issue_type == "identity_access"
|
| 625 |
+
and any(keyword in text for keyword in ("cannot sign in", "can't sign in", "2fa", "mfa", "locked"))
|
| 626 |
+
and not has_critical_signal
|
| 627 |
+
):
|
| 628 |
+
updated["priority"] = "high"
|
| 629 |
+
reasons.append("override_priority=high(identity_lockout)")
|
| 630 |
+
|
| 631 |
+
return updated, reasons
|
| 632 |
+
|
| 633 |
+
|
| 634 |
def build_action(
|
| 635 |
ticket: dict, allowed_fields: list[str], instructions: str
|
| 636 |
) -> tuple[HelpdeskTicketAction, str, str | None]:
|
|
|
|
| 641 |
|
| 642 |
try:
|
| 643 |
llm_dict = call_llm(ticket, allowed_fields, instructions)
|
| 644 |
+
validated_llm_fields: dict[str, Any] = {}
|
| 645 |
+
rejected_fields: list[str] = []
|
| 646 |
+
for field in allowed_fields:
|
| 647 |
+
value = llm_dict.get(field)
|
| 648 |
+
if value is None:
|
| 649 |
+
continue
|
| 650 |
+
try:
|
| 651 |
+
HelpdeskTicketAction(**{field: value})
|
| 652 |
+
except Exception:
|
| 653 |
+
rejected_fields.append(field)
|
| 654 |
+
continue
|
| 655 |
+
validated_llm_fields[field] = value
|
| 656 |
+
|
| 657 |
+
if not validated_llm_fields:
|
| 658 |
raise ValueError("LLM returned no allowed fields")
|
| 659 |
+
|
| 660 |
+
candidate = heuristic_action(
|
| 661 |
+
ticket,
|
| 662 |
+
allowed_fields,
|
| 663 |
+
issue_type_override=validated_llm_fields.get("issue_type"),
|
| 664 |
+
)
|
| 665 |
+
candidate.update(validated_llm_fields)
|
| 666 |
+
accepted_fields = list(validated_llm_fields)
|
| 667 |
+
candidate, override_reasons = apply_domain_overrides(
|
| 668 |
+
ticket,
|
| 669 |
+
candidate,
|
| 670 |
+
allowed_fields,
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
backfilled_fields = [field for field in allowed_fields if field not in accepted_fields]
|
| 674 |
+
if backfilled_fields or rejected_fields or override_reasons:
|
| 675 |
+
reason_parts = []
|
| 676 |
+
if backfilled_fields:
|
| 677 |
+
reason_parts.append(f"heuristic_backfill={backfilled_fields}")
|
| 678 |
+
if rejected_fields:
|
| 679 |
+
reason_parts.append(f"invalid_llm_fields={rejected_fields}")
|
| 680 |
+
if override_reasons:
|
| 681 |
+
reason_parts.append(f"domain_overrides={override_reasons}")
|
| 682 |
+
return (
|
| 683 |
+
HelpdeskTicketAction(**candidate),
|
| 684 |
+
"llm_backfilled",
|
| 685 |
+
"; ".join(reason_parts),
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
return HelpdeskTicketAction(**candidate), "llm", None
|
| 689 |
except Exception as exc:
|
| 690 |
return (
|
|
|
|
| 697 |
def should_investigate(ticket: dict, history: list[dict[str, Any]]) -> tuple[bool, str | None]:
|
| 698 |
if not ticket:
|
| 699 |
return False, None
|
| 700 |
+
context_status = ticket.get("context_status") or {}
|
| 701 |
+
remaining_tools = context_status.get("remaining_tools") or []
|
| 702 |
+
if remaining_tools:
|
| 703 |
+
return True, str(remaining_tools[0])
|
| 704 |
current_ticket_id = ticket.get("ticket_id")
|
| 705 |
already_investigated = any(
|
| 706 |
entry.get("ticket_id") == current_ticket_id
|
|
|
|
| 720 |
merged_ticket = dict(ticket)
|
| 721 |
if getattr(observation, "last_tool_result", None) is not None:
|
| 722 |
merged_ticket["last_tool_result"] = observation.last_tool_result
|
| 723 |
+
merged_ticket["recent_history"] = list(getattr(observation, "history", []))
|
| 724 |
+
merged_ticket["queue_position"] = getattr(observation, "queue_position", None)
|
| 725 |
+
merged_ticket["tickets_remaining"] = getattr(observation, "tickets_remaining", None)
|
| 726 |
+
merged_ticket["investigation_budget_remaining"] = getattr(
|
| 727 |
+
observation,
|
| 728 |
+
"investigation_budget_remaining",
|
| 729 |
+
None,
|
| 730 |
+
)
|
| 731 |
+
merged_ticket["average_score_so_far"] = getattr(observation, "average_score_so_far", None)
|
| 732 |
+
merged_ticket["progress_fraction"] = getattr(observation, "progress_fraction", None)
|
| 733 |
+
merged_ticket["last_reward_components"] = dict(
|
| 734 |
+
getattr(observation, "last_reward_components", {}) or {}
|
| 735 |
+
)
|
| 736 |
+
observation_metadata = getattr(observation, "metadata", {}) or {}
|
| 737 |
+
if observation_metadata.get("last_feedback_summary"):
|
| 738 |
+
merged_ticket["feedback_summary"] = observation_metadata["last_feedback_summary"]
|
| 739 |
return merged_ticket
|
| 740 |
|
| 741 |
|
|
|
|
| 846 |
ticket_id=ticket["ticket_id"],
|
| 847 |
)
|
| 848 |
|
| 849 |
+
final_rubric_reward = getattr(obs, "rubric_reward", None)
|
| 850 |
+
final_reward = (
|
| 851 |
+
float(final_rubric_reward)
|
| 852 |
+
if final_rubric_reward is not None
|
| 853 |
+
else (task_step_rewards[-1] if task_step_rewards else 0.0)
|
| 854 |
+
)
|
| 855 |
all_results[task_id] = {
|
| 856 |
"final_reward": final_reward,
|
| 857 |
"step_count": step_num,
|
models.py
CHANGED
|
@@ -18,6 +18,7 @@ ASSIGNMENT_GROUP_SET = set(ASSIGNMENT_GROUPS)
|
|
| 18 |
RESOLUTION_ACTION_SET = set(RESOLUTION_ACTIONS)
|
| 19 |
ACTION_TYPE_SET = {"submit", "investigate"}
|
| 20 |
TOOL_NAME_SET = {"lookup_related_ticket", "lookup_requester_history"}
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def _validate_choice(value: str, allowed: set[str], field_name: str) -> str:
|
|
@@ -113,6 +114,7 @@ class HelpdeskTicketObservation(Observation):
|
|
| 113 |
task_name: str = ""
|
| 114 |
instructions: str = ""
|
| 115 |
allowed_fields: list[str] = Field(default_factory=list)
|
|
|
|
| 116 |
available_tools: list[str] = Field(default_factory=list)
|
| 117 |
investigation_budget_remaining: int = 0
|
| 118 |
last_tool_result: Optional[dict[str, Any]] = None
|
|
@@ -122,7 +124,11 @@ class HelpdeskTicketObservation(Observation):
|
|
| 122 |
tickets_after_current: int = 0
|
| 123 |
tickets_processed: int = 0
|
| 124 |
queue_position: int = 0
|
|
|
|
|
|
|
| 125 |
history: list[dict[str, Any]] = Field(default_factory=list)
|
|
|
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
class HelpdeskTicketState(State):
|
|
@@ -136,7 +142,11 @@ class HelpdeskTicketState(State):
|
|
| 136 |
# `reward` is the field the evaluator checks on GET /state (mentor spec)
|
| 137 |
reward: Optional[float] = None
|
| 138 |
done: bool = False
|
|
|
|
| 139 |
investigation_steps: int = 0
|
| 140 |
investigation_budget_remaining: int = 0
|
|
|
|
| 141 |
last_tool_result: Optional[dict[str, Any]] = None
|
|
|
|
|
|
|
| 142 |
history_entries: list[dict] = Field(default_factory=list)
|
|
|
|
| 18 |
RESOLUTION_ACTION_SET = set(RESOLUTION_ACTIONS)
|
| 19 |
ACTION_TYPE_SET = {"submit", "investigate"}
|
| 20 |
TOOL_NAME_SET = {"lookup_related_ticket", "lookup_requester_history"}
|
| 21 |
+
TOOL_NAME_SET.add("lookup_internal_routing_note")
|
| 22 |
|
| 23 |
|
| 24 |
def _validate_choice(value: str, allowed: set[str], field_name: str) -> str:
|
|
|
|
| 114 |
task_name: str = ""
|
| 115 |
instructions: str = ""
|
| 116 |
allowed_fields: list[str] = Field(default_factory=list)
|
| 117 |
+
available_action_types: list[str] = Field(default_factory=list)
|
| 118 |
available_tools: list[str] = Field(default_factory=list)
|
| 119 |
investigation_budget_remaining: int = 0
|
| 120 |
last_tool_result: Optional[dict[str, Any]] = None
|
|
|
|
| 124 |
tickets_after_current: int = 0
|
| 125 |
tickets_processed: int = 0
|
| 126 |
queue_position: int = 0
|
| 127 |
+
average_score_so_far: float = 0.0
|
| 128 |
+
progress_fraction: float = 0.0
|
| 129 |
history: list[dict[str, Any]] = Field(default_factory=list)
|
| 130 |
+
last_reward_components: dict[str, Any] = Field(default_factory=dict)
|
| 131 |
+
rubric_reward: Optional[float] = None
|
| 132 |
|
| 133 |
|
| 134 |
class HelpdeskTicketState(State):
|
|
|
|
| 142 |
# `reward` is the field the evaluator checks on GET /state (mentor spec)
|
| 143 |
reward: Optional[float] = None
|
| 144 |
done: bool = False
|
| 145 |
+
average_score_so_far: float = 0.0
|
| 146 |
investigation_steps: int = 0
|
| 147 |
investigation_budget_remaining: int = 0
|
| 148 |
+
investigation_penalty_applied: float = 0.0
|
| 149 |
last_tool_result: Optional[dict[str, Any]] = None
|
| 150 |
+
last_reward_components: dict[str, Any] = Field(default_factory=dict)
|
| 151 |
+
ticket_tool_usage: dict[str, list[str]] = Field(default_factory=dict)
|
| 152 |
history_entries: list[dict] = Field(default_factory=list)
|
policy_learning.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import importlib
|
| 6 |
+
import json
|
| 7 |
+
from dataclasses import asdict, dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from statistics import mean
|
| 10 |
+
from typing import Any, Callable, Iterable
|
| 11 |
+
|
| 12 |
+
from models import HelpdeskTicketAction, HelpdeskTicketObservation
|
| 13 |
+
from server.environment import HelpdeskTicketRoutingEnvironment
|
| 14 |
+
from server.tasks import get_task_definition
|
| 15 |
+
from vocabulary import TASK_IDS
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DEFAULT_COMPARE_POLICIES = (
|
| 19 |
+
"no_investigation",
|
| 20 |
+
"investigate_when_context_hidden",
|
| 21 |
+
)
|
| 22 |
+
DEFAULT_SEARCH_POLICIES = (
|
| 23 |
+
"no_investigation",
|
| 24 |
+
"legacy_single_probe",
|
| 25 |
+
"investigate_when_context_hidden",
|
| 26 |
+
"context_chain",
|
| 27 |
+
"hybrid_context",
|
| 28 |
+
)
|
| 29 |
+
DEFAULT_OUTPUT_DIR = "analysis/policy_learning_runs"
|
| 30 |
+
|
| 31 |
+
SubmitBuilder = Callable[[dict[str, Any], list[str]], HelpdeskTicketAction]
|
| 32 |
+
EnvFactory = Callable[[], HelpdeskTicketRoutingEnvironment]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class PolicyConfig:
|
| 37 |
+
name: str
|
| 38 |
+
investigate_hidden_context: bool
|
| 39 |
+
investigate_related_ticket_hint: bool
|
| 40 |
+
investigate_ambiguity_history: bool
|
| 41 |
+
max_investigations_per_ticket: int
|
| 42 |
+
description: str
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
POLICY_LIBRARY: dict[str, PolicyConfig] = {
|
| 46 |
+
"no_investigation": PolicyConfig(
|
| 47 |
+
name="no_investigation",
|
| 48 |
+
investigate_hidden_context=False,
|
| 49 |
+
investigate_related_ticket_hint=False,
|
| 50 |
+
investigate_ambiguity_history=False,
|
| 51 |
+
max_investigations_per_ticket=0,
|
| 52 |
+
description="Always submit immediately and never investigate.",
|
| 53 |
+
),
|
| 54 |
+
"legacy_single_probe": PolicyConfig(
|
| 55 |
+
name="legacy_single_probe",
|
| 56 |
+
investigate_hidden_context=False,
|
| 57 |
+
investigate_related_ticket_hint=True,
|
| 58 |
+
investigate_ambiguity_history=True,
|
| 59 |
+
max_investigations_per_ticket=1,
|
| 60 |
+
description="Mimics the earlier single-tool hint policy.",
|
| 61 |
+
),
|
| 62 |
+
"investigate_when_context_hidden": PolicyConfig(
|
| 63 |
+
name="investigate_when_context_hidden",
|
| 64 |
+
investigate_hidden_context=True,
|
| 65 |
+
investigate_related_ticket_hint=False,
|
| 66 |
+
investigate_ambiguity_history=False,
|
| 67 |
+
max_investigations_per_ticket=1,
|
| 68 |
+
description="Investigate once when the environment says context is hidden.",
|
| 69 |
+
),
|
| 70 |
+
"context_chain": PolicyConfig(
|
| 71 |
+
name="context_chain",
|
| 72 |
+
investigate_hidden_context=True,
|
| 73 |
+
investigate_related_ticket_hint=False,
|
| 74 |
+
investigate_ambiguity_history=False,
|
| 75 |
+
max_investigations_per_ticket=3,
|
| 76 |
+
description="Follow the environment's required-tool chain until context is revealed.",
|
| 77 |
+
),
|
| 78 |
+
"hybrid_context": PolicyConfig(
|
| 79 |
+
name="hybrid_context",
|
| 80 |
+
investigate_hidden_context=True,
|
| 81 |
+
investigate_related_ticket_hint=True,
|
| 82 |
+
investigate_ambiguity_history=True,
|
| 83 |
+
max_investigations_per_ticket=3,
|
| 84 |
+
description="Use hidden-context signals first, then legacy ambiguity hints.",
|
| 85 |
+
),
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _dedupe_preserving_order(values: Iterable[int]) -> list[int]:
|
| 90 |
+
seen: set[int] = set()
|
| 91 |
+
ordered: list[int] = []
|
| 92 |
+
for value in values:
|
| 93 |
+
if value in seen:
|
| 94 |
+
continue
|
| 95 |
+
seen.add(value)
|
| 96 |
+
ordered.append(value)
|
| 97 |
+
return ordered
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def parse_int_spec(spec: str, *, field_name: str) -> list[int]:
|
| 101 |
+
values: list[int] = []
|
| 102 |
+
for chunk in spec.split(","):
|
| 103 |
+
part = chunk.strip()
|
| 104 |
+
if not part:
|
| 105 |
+
continue
|
| 106 |
+
if "-" in part:
|
| 107 |
+
start_raw, end_raw = part.split("-", 1)
|
| 108 |
+
try:
|
| 109 |
+
start = int(start_raw)
|
| 110 |
+
end = int(end_raw)
|
| 111 |
+
except ValueError as exc:
|
| 112 |
+
raise ValueError(f"{field_name} contains an invalid range: {part!r}") from exc
|
| 113 |
+
if end < start:
|
| 114 |
+
raise ValueError(f"{field_name} range must be ascending: {part!r}")
|
| 115 |
+
values.extend(range(start, end + 1))
|
| 116 |
+
continue
|
| 117 |
+
try:
|
| 118 |
+
values.append(int(part))
|
| 119 |
+
except ValueError as exc:
|
| 120 |
+
raise ValueError(f"{field_name} contains an invalid integer: {part!r}") from exc
|
| 121 |
+
if not values:
|
| 122 |
+
raise ValueError(f"{field_name} must not be empty")
|
| 123 |
+
return _dedupe_preserving_order(values)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def parse_task_ids(spec: str) -> list[int]:
|
| 127 |
+
task_ids = parse_int_spec(spec, field_name="task_ids")
|
| 128 |
+
unsupported = [task_id for task_id in task_ids if task_id not in TASK_IDS]
|
| 129 |
+
if unsupported:
|
| 130 |
+
raise ValueError(f"Unsupported task_ids: {unsupported}")
|
| 131 |
+
return task_ids
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def resolve_policies(spec: str) -> list[PolicyConfig]:
|
| 135 |
+
names = [name.strip() for name in spec.split(",") if name.strip()]
|
| 136 |
+
if not names:
|
| 137 |
+
raise ValueError("At least one policy must be specified")
|
| 138 |
+
policies: list[PolicyConfig] = []
|
| 139 |
+
for name in names:
|
| 140 |
+
if name not in POLICY_LIBRARY:
|
| 141 |
+
raise ValueError(
|
| 142 |
+
f"Unknown policy {name!r}. Available policies: {sorted(POLICY_LIBRARY)}"
|
| 143 |
+
)
|
| 144 |
+
policies.append(POLICY_LIBRARY[name])
|
| 145 |
+
return policies
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def default_submit_builder(
|
| 149 |
+
ticket: dict[str, Any], allowed_fields: list[str]
|
| 150 |
+
) -> HelpdeskTicketAction:
|
| 151 |
+
inference = importlib.import_module("inference")
|
| 152 |
+
candidate = inference.heuristic_action(ticket, allowed_fields)
|
| 153 |
+
candidate, _ = inference.apply_domain_overrides(ticket, candidate, allowed_fields)
|
| 154 |
+
return HelpdeskTicketAction(**candidate)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def choose_policy_action(
|
| 158 |
+
policy: PolicyConfig,
|
| 159 |
+
observation: HelpdeskTicketObservation,
|
| 160 |
+
investigations_by_ticket: dict[str, int],
|
| 161 |
+
submit_builder: SubmitBuilder,
|
| 162 |
+
) -> tuple[HelpdeskTicketAction, str]:
|
| 163 |
+
ticket = observation.current_ticket or {}
|
| 164 |
+
ticket_id = str(ticket.get("ticket_id", ""))
|
| 165 |
+
ticket_investigations = investigations_by_ticket.get(ticket_id, 0)
|
| 166 |
+
revealed_tools = set(((ticket.get("context_status") or {}).get("revealed_tools") or []))
|
| 167 |
+
remaining_tools = list(((ticket.get("context_status") or {}).get("remaining_tools") or []))
|
| 168 |
+
|
| 169 |
+
if ticket_investigations < policy.max_investigations_per_ticket:
|
| 170 |
+
if policy.investigate_hidden_context and remaining_tools:
|
| 171 |
+
tool_name = str(remaining_tools[0])
|
| 172 |
+
return (
|
| 173 |
+
HelpdeskTicketAction(action_type="investigate", tool_name=tool_name),
|
| 174 |
+
"investigate_hidden_context",
|
| 175 |
+
)
|
| 176 |
+
if (
|
| 177 |
+
policy.investigate_related_ticket_hint
|
| 178 |
+
and ticket.get("related_ticket_id")
|
| 179 |
+
and "lookup_related_ticket" not in revealed_tools
|
| 180 |
+
):
|
| 181 |
+
return (
|
| 182 |
+
HelpdeskTicketAction(
|
| 183 |
+
action_type="investigate",
|
| 184 |
+
tool_name="lookup_related_ticket",
|
| 185 |
+
),
|
| 186 |
+
"investigate_related_ticket_hint",
|
| 187 |
+
)
|
| 188 |
+
if (
|
| 189 |
+
policy.investigate_ambiguity_history
|
| 190 |
+
and ticket.get("ambiguity_note")
|
| 191 |
+
and "lookup_requester_history" not in revealed_tools
|
| 192 |
+
):
|
| 193 |
+
return (
|
| 194 |
+
HelpdeskTicketAction(
|
| 195 |
+
action_type="investigate",
|
| 196 |
+
tool_name="lookup_requester_history",
|
| 197 |
+
),
|
| 198 |
+
"investigate_ambiguity_history",
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
return submit_builder(ticket, list(observation.allowed_fields)), "submit"
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def rollout_episode(
|
| 205 |
+
*,
|
| 206 |
+
env: HelpdeskTicketRoutingEnvironment,
|
| 207 |
+
policy: PolicyConfig,
|
| 208 |
+
seed: int,
|
| 209 |
+
task_id: int,
|
| 210 |
+
submit_builder: SubmitBuilder,
|
| 211 |
+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
| 212 |
+
task = get_task_definition(task_id)
|
| 213 |
+
observation = env.reset(seed=seed, task_id=task_id)
|
| 214 |
+
investigations_by_ticket: dict[str, int] = {}
|
| 215 |
+
episode_return = 0.0
|
| 216 |
+
trajectories: list[dict[str, Any]] = []
|
| 217 |
+
|
| 218 |
+
while not observation.done:
|
| 219 |
+
ticket = observation.current_ticket or {}
|
| 220 |
+
ticket_id = str(ticket.get("ticket_id", ""))
|
| 221 |
+
action, action_source = choose_policy_action(
|
| 222 |
+
policy,
|
| 223 |
+
observation,
|
| 224 |
+
investigations_by_ticket,
|
| 225 |
+
submit_builder,
|
| 226 |
+
)
|
| 227 |
+
next_observation = env.step(action)
|
| 228 |
+
reward_value = float(next_observation.reward or 0.0)
|
| 229 |
+
episode_return += reward_value
|
| 230 |
+
if action.action_type == "investigate" and ticket_id:
|
| 231 |
+
investigations_by_ticket[ticket_id] = investigations_by_ticket.get(ticket_id, 0) + 1
|
| 232 |
+
|
| 233 |
+
history_entry = env.state.history_entries[-1] if env.state.history_entries else {}
|
| 234 |
+
trajectories.append(
|
| 235 |
+
{
|
| 236 |
+
"policy": policy.name,
|
| 237 |
+
"seed": seed,
|
| 238 |
+
"task_id": task_id,
|
| 239 |
+
"task_name": task["name"],
|
| 240 |
+
"episode_id": env.state.episode_id,
|
| 241 |
+
"step_index": len(trajectories) + 1,
|
| 242 |
+
"ticket_id": history_entry.get("ticket_id", ticket_id),
|
| 243 |
+
"action_source": action_source,
|
| 244 |
+
"action": action.model_dump(exclude_none=True),
|
| 245 |
+
"step_reward": reward_value,
|
| 246 |
+
"rubric_reward": next_observation.rubric_reward,
|
| 247 |
+
"done": next_observation.done,
|
| 248 |
+
"feedback_summary": history_entry.get("feedback_summary"),
|
| 249 |
+
"reward_kind": history_entry.get("reward_kind"),
|
| 250 |
+
"score": history_entry.get("score"),
|
| 251 |
+
"breakdown": history_entry.get("breakdown", {}),
|
| 252 |
+
"reward_components": history_entry.get("reward_components", {}),
|
| 253 |
+
"context_status_before_action": ticket.get("context_status"),
|
| 254 |
+
}
|
| 255 |
+
)
|
| 256 |
+
observation = next_observation
|
| 257 |
+
|
| 258 |
+
queue_size = max(1, len(env.state.queue_ticket_ids))
|
| 259 |
+
terminal_reward = float(observation.reward or 0.0)
|
| 260 |
+
terminal_rubric_reward = (
|
| 261 |
+
float(observation.rubric_reward)
|
| 262 |
+
if observation.rubric_reward is not None
|
| 263 |
+
else terminal_reward
|
| 264 |
+
)
|
| 265 |
+
summary = {
|
| 266 |
+
"policy": policy.name,
|
| 267 |
+
"policy_config": asdict(policy),
|
| 268 |
+
"seed": seed,
|
| 269 |
+
"task_id": task_id,
|
| 270 |
+
"task_name": task["name"],
|
| 271 |
+
"episode_id": env.state.episode_id,
|
| 272 |
+
"queue_size": queue_size,
|
| 273 |
+
"step_count": env.state.step_count,
|
| 274 |
+
"tickets_processed": len(env.state.per_ticket_scores),
|
| 275 |
+
"investigation_steps": env.state.investigation_steps,
|
| 276 |
+
"episode_return": episode_return,
|
| 277 |
+
"normalized_return": episode_return / queue_size,
|
| 278 |
+
"terminal_reward": terminal_reward,
|
| 279 |
+
"terminal_rubric_reward": terminal_rubric_reward,
|
| 280 |
+
"average_ticket_score": env.state.average_score_so_far,
|
| 281 |
+
"per_ticket_scores": list(env.state.per_ticket_scores),
|
| 282 |
+
}
|
| 283 |
+
return summary, trajectories
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _safe_mean(values: list[float]) -> float:
|
| 287 |
+
if not values:
|
| 288 |
+
return 0.0
|
| 289 |
+
return round(mean(values), 6)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def summarize_policy_episodes(
|
| 293 |
+
policy: PolicyConfig,
|
| 294 |
+
episode_summaries: list[dict[str, Any]],
|
| 295 |
+
) -> dict[str, Any]:
|
| 296 |
+
per_task: dict[str, Any] = {}
|
| 297 |
+
for task_id in TASK_IDS:
|
| 298 |
+
task_episodes = [
|
| 299 |
+
episode for episode in episode_summaries if episode["task_id"] == task_id
|
| 300 |
+
]
|
| 301 |
+
if not task_episodes:
|
| 302 |
+
continue
|
| 303 |
+
per_task[str(task_id)] = {
|
| 304 |
+
"episodes": len(task_episodes),
|
| 305 |
+
"avg_episode_return": _safe_mean(
|
| 306 |
+
[float(episode["episode_return"]) for episode in task_episodes]
|
| 307 |
+
),
|
| 308 |
+
"avg_normalized_return": _safe_mean(
|
| 309 |
+
[float(episode["normalized_return"]) for episode in task_episodes]
|
| 310 |
+
),
|
| 311 |
+
"avg_terminal_reward": _safe_mean(
|
| 312 |
+
[float(episode["terminal_reward"]) for episode in task_episodes]
|
| 313 |
+
),
|
| 314 |
+
"avg_terminal_rubric_reward": _safe_mean(
|
| 315 |
+
[float(episode["terminal_rubric_reward"]) for episode in task_episodes]
|
| 316 |
+
),
|
| 317 |
+
"avg_investigation_steps": _safe_mean(
|
| 318 |
+
[float(episode["investigation_steps"]) for episode in task_episodes]
|
| 319 |
+
),
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
return {
|
| 323 |
+
"policy": policy.name,
|
| 324 |
+
"config": asdict(policy),
|
| 325 |
+
"episodes": len(episode_summaries),
|
| 326 |
+
"avg_episode_return": _safe_mean(
|
| 327 |
+
[float(episode["episode_return"]) for episode in episode_summaries]
|
| 328 |
+
),
|
| 329 |
+
"avg_normalized_return": _safe_mean(
|
| 330 |
+
[float(episode["normalized_return"]) for episode in episode_summaries]
|
| 331 |
+
),
|
| 332 |
+
"avg_terminal_reward": _safe_mean(
|
| 333 |
+
[float(episode["terminal_reward"]) for episode in episode_summaries]
|
| 334 |
+
),
|
| 335 |
+
"avg_terminal_rubric_reward": _safe_mean(
|
| 336 |
+
[float(episode["terminal_rubric_reward"]) for episode in episode_summaries]
|
| 337 |
+
),
|
| 338 |
+
"avg_investigation_steps": _safe_mean(
|
| 339 |
+
[float(episode["investigation_steps"]) for episode in episode_summaries]
|
| 340 |
+
),
|
| 341 |
+
"avg_ticket_score": _safe_mean(
|
| 342 |
+
[float(episode["average_ticket_score"]) for episode in episode_summaries]
|
| 343 |
+
),
|
| 344 |
+
"per_task": per_task,
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def evaluate_policy(
|
| 349 |
+
policy: PolicyConfig,
|
| 350 |
+
seeds: Iterable[int],
|
| 351 |
+
task_ids: Iterable[int],
|
| 352 |
+
*,
|
| 353 |
+
env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
|
| 354 |
+
submit_builder: SubmitBuilder = default_submit_builder,
|
| 355 |
+
) -> dict[str, Any]:
|
| 356 |
+
episode_summaries: list[dict[str, Any]] = []
|
| 357 |
+
trajectories: list[dict[str, Any]] = []
|
| 358 |
+
|
| 359 |
+
for seed in seeds:
|
| 360 |
+
for task_id in task_ids:
|
| 361 |
+
env = env_factory()
|
| 362 |
+
summary, episode_trajectories = rollout_episode(
|
| 363 |
+
env=env,
|
| 364 |
+
policy=policy,
|
| 365 |
+
seed=seed,
|
| 366 |
+
task_id=task_id,
|
| 367 |
+
submit_builder=submit_builder,
|
| 368 |
+
)
|
| 369 |
+
episode_summaries.append(summary)
|
| 370 |
+
trajectories.extend(episode_trajectories)
|
| 371 |
+
|
| 372 |
+
return {
|
| 373 |
+
"policy": policy.name,
|
| 374 |
+
"summary": summarize_policy_episodes(policy, episode_summaries),
|
| 375 |
+
"episodes": episode_summaries,
|
| 376 |
+
"trajectories": trajectories,
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def _selection_tuple(summary: dict[str, Any]) -> tuple[float, float, float, float]:
|
| 381 |
+
return (
|
| 382 |
+
float(summary["avg_normalized_return"]),
|
| 383 |
+
float(summary["avg_terminal_reward"]),
|
| 384 |
+
float(summary["avg_terminal_rubric_reward"]),
|
| 385 |
+
-float(summary["avg_investigation_steps"]),
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def select_best_policy(policy_runs: list[dict[str, Any]]) -> dict[str, Any]:
|
| 390 |
+
return max(policy_runs, key=lambda run: _selection_tuple(run["summary"]))
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _delta(best: dict[str, Any], baseline: dict[str, Any], key: str) -> float:
|
| 394 |
+
return round(float(best[key]) - float(baseline[key]), 6)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _write_json(path: Path, payload: dict[str, Any]) -> None:
|
| 398 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 399 |
+
path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _write_jsonl(path: Path, records: Iterable[dict[str, Any]]) -> None:
|
| 403 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 404 |
+
with path.open("w", encoding="utf-8") as handle:
|
| 405 |
+
for record in records:
|
| 406 |
+
handle.write(json.dumps(record, sort_keys=True) + "\n")
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def compare_policies(
|
| 410 |
+
policies: list[PolicyConfig],
|
| 411 |
+
seeds: list[int],
|
| 412 |
+
task_ids: list[int],
|
| 413 |
+
*,
|
| 414 |
+
output_dir: Path,
|
| 415 |
+
env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
|
| 416 |
+
submit_builder: SubmitBuilder = default_submit_builder,
|
| 417 |
+
) -> dict[str, Any]:
|
| 418 |
+
output_dir = Path(output_dir)
|
| 419 |
+
policy_runs = [
|
| 420 |
+
evaluate_policy(
|
| 421 |
+
policy,
|
| 422 |
+
seeds,
|
| 423 |
+
task_ids,
|
| 424 |
+
env_factory=env_factory,
|
| 425 |
+
submit_builder=submit_builder,
|
| 426 |
+
)
|
| 427 |
+
for policy in policies
|
| 428 |
+
]
|
| 429 |
+
best_run = select_best_policy(policy_runs)
|
| 430 |
+
baseline_run = policy_runs[0]
|
| 431 |
+
|
| 432 |
+
report = {
|
| 433 |
+
"mode": "compare",
|
| 434 |
+
"task_ids": task_ids,
|
| 435 |
+
"seeds": seeds,
|
| 436 |
+
"selection_metric": "avg_normalized_return",
|
| 437 |
+
"baseline_policy": baseline_run["policy"],
|
| 438 |
+
"best_policy": best_run["policy"],
|
| 439 |
+
"improvement_vs_baseline": {
|
| 440 |
+
"avg_episode_return": _delta(
|
| 441 |
+
best_run["summary"], baseline_run["summary"], "avg_episode_return"
|
| 442 |
+
),
|
| 443 |
+
"avg_normalized_return": _delta(
|
| 444 |
+
best_run["summary"], baseline_run["summary"], "avg_normalized_return"
|
| 445 |
+
),
|
| 446 |
+
"avg_terminal_reward": _delta(
|
| 447 |
+
best_run["summary"], baseline_run["summary"], "avg_terminal_reward"
|
| 448 |
+
),
|
| 449 |
+
"avg_terminal_rubric_reward": _delta(
|
| 450 |
+
best_run["summary"],
|
| 451 |
+
baseline_run["summary"],
|
| 452 |
+
"avg_terminal_rubric_reward",
|
| 453 |
+
),
|
| 454 |
+
},
|
| 455 |
+
"policy_summaries": [run["summary"] for run in policy_runs],
|
| 456 |
+
"ranking": [
|
| 457 |
+
run["policy"]
|
| 458 |
+
for run in sorted(
|
| 459 |
+
policy_runs,
|
| 460 |
+
key=lambda run: _selection_tuple(run["summary"]),
|
| 461 |
+
reverse=True,
|
| 462 |
+
)
|
| 463 |
+
],
|
| 464 |
+
"artifacts": {
|
| 465 |
+
"summary": str(output_dir / "compare_summary.json"),
|
| 466 |
+
"episodes": str(output_dir / "compare_episodes.jsonl"),
|
| 467 |
+
"trajectories": str(output_dir / "compare_trajectories.jsonl"),
|
| 468 |
+
},
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
_write_json(output_dir / "compare_summary.json", report)
|
| 472 |
+
_write_jsonl(
|
| 473 |
+
output_dir / "compare_episodes.jsonl",
|
| 474 |
+
(
|
| 475 |
+
{"policy": run["policy"], **episode}
|
| 476 |
+
for run in policy_runs
|
| 477 |
+
for episode in run["episodes"]
|
| 478 |
+
),
|
| 479 |
+
)
|
| 480 |
+
_write_jsonl(
|
| 481 |
+
output_dir / "compare_trajectories.jsonl",
|
| 482 |
+
(trajectory for run in policy_runs for trajectory in run["trajectories"]),
|
| 483 |
+
)
|
| 484 |
+
return report
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def search_policies(
|
| 488 |
+
candidate_policies: list[PolicyConfig],
|
| 489 |
+
train_seeds: list[int],
|
| 490 |
+
eval_seeds: list[int],
|
| 491 |
+
task_ids: list[int],
|
| 492 |
+
*,
|
| 493 |
+
output_dir: Path,
|
| 494 |
+
env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
|
| 495 |
+
submit_builder: SubmitBuilder = default_submit_builder,
|
| 496 |
+
baseline_policy_name: str = "no_investigation",
|
| 497 |
+
) -> dict[str, Any]:
|
| 498 |
+
output_dir = Path(output_dir)
|
| 499 |
+
train_runs = [
|
| 500 |
+
evaluate_policy(
|
| 501 |
+
policy,
|
| 502 |
+
train_seeds,
|
| 503 |
+
task_ids,
|
| 504 |
+
env_factory=env_factory,
|
| 505 |
+
submit_builder=submit_builder,
|
| 506 |
+
)
|
| 507 |
+
for policy in candidate_policies
|
| 508 |
+
]
|
| 509 |
+
selected_run = select_best_policy(train_runs)
|
| 510 |
+
selected_policy = POLICY_LIBRARY[selected_run["policy"]]
|
| 511 |
+
eval_selected = evaluate_policy(
|
| 512 |
+
selected_policy,
|
| 513 |
+
eval_seeds,
|
| 514 |
+
task_ids,
|
| 515 |
+
env_factory=env_factory,
|
| 516 |
+
submit_builder=submit_builder,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
baseline_policy = POLICY_LIBRARY.get(baseline_policy_name, candidate_policies[0])
|
| 520 |
+
eval_baseline = evaluate_policy(
|
| 521 |
+
baseline_policy,
|
| 522 |
+
eval_seeds,
|
| 523 |
+
task_ids,
|
| 524 |
+
env_factory=env_factory,
|
| 525 |
+
submit_builder=submit_builder,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
report = {
|
| 529 |
+
"mode": "search",
|
| 530 |
+
"task_ids": task_ids,
|
| 531 |
+
"train_seeds": train_seeds,
|
| 532 |
+
"eval_seeds": eval_seeds,
|
| 533 |
+
"selection_metric": "avg_normalized_return",
|
| 534 |
+
"candidate_policies": [policy.name for policy in candidate_policies],
|
| 535 |
+
"selected_policy": selected_policy.name,
|
| 536 |
+
"baseline_policy": baseline_policy.name,
|
| 537 |
+
"train_policy_summaries": [run["summary"] for run in train_runs],
|
| 538 |
+
"eval_selected_summary": eval_selected["summary"],
|
| 539 |
+
"eval_baseline_summary": eval_baseline["summary"],
|
| 540 |
+
"eval_improvement_vs_baseline": {
|
| 541 |
+
"avg_episode_return": _delta(
|
| 542 |
+
eval_selected["summary"],
|
| 543 |
+
eval_baseline["summary"],
|
| 544 |
+
"avg_episode_return",
|
| 545 |
+
),
|
| 546 |
+
"avg_normalized_return": _delta(
|
| 547 |
+
eval_selected["summary"],
|
| 548 |
+
eval_baseline["summary"],
|
| 549 |
+
"avg_normalized_return",
|
| 550 |
+
),
|
| 551 |
+
"avg_terminal_reward": _delta(
|
| 552 |
+
eval_selected["summary"],
|
| 553 |
+
eval_baseline["summary"],
|
| 554 |
+
"avg_terminal_reward",
|
| 555 |
+
),
|
| 556 |
+
"avg_terminal_rubric_reward": _delta(
|
| 557 |
+
eval_selected["summary"],
|
| 558 |
+
eval_baseline["summary"],
|
| 559 |
+
"avg_terminal_rubric_reward",
|
| 560 |
+
),
|
| 561 |
+
},
|
| 562 |
+
"artifacts": {
|
| 563 |
+
"summary": str(output_dir / "search_summary.json"),
|
| 564 |
+
"train_episodes": str(output_dir / "search_train_episodes.jsonl"),
|
| 565 |
+
"train_trajectories": str(output_dir / "search_train_trajectories.jsonl"),
|
| 566 |
+
"eval_episodes": str(output_dir / "search_eval_episodes.jsonl"),
|
| 567 |
+
"eval_trajectories": str(output_dir / "search_eval_trajectories.jsonl"),
|
| 568 |
+
},
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
_write_json(output_dir / "search_summary.json", report)
|
| 572 |
+
_write_jsonl(
|
| 573 |
+
output_dir / "search_train_episodes.jsonl",
|
| 574 |
+
(
|
| 575 |
+
{"policy": run["policy"], **episode}
|
| 576 |
+
for run in train_runs
|
| 577 |
+
for episode in run["episodes"]
|
| 578 |
+
),
|
| 579 |
+
)
|
| 580 |
+
_write_jsonl(
|
| 581 |
+
output_dir / "search_train_trajectories.jsonl",
|
| 582 |
+
(trajectory for run in train_runs for trajectory in run["trajectories"]),
|
| 583 |
+
)
|
| 584 |
+
_write_jsonl(
|
| 585 |
+
output_dir / "search_eval_episodes.jsonl",
|
| 586 |
+
(
|
| 587 |
+
{"policy": eval_selected["policy"], **episode}
|
| 588 |
+
for episode in eval_selected["episodes"]
|
| 589 |
+
),
|
| 590 |
+
)
|
| 591 |
+
_write_jsonl(
|
| 592 |
+
output_dir / "search_eval_trajectories.jsonl",
|
| 593 |
+
(trajectory for trajectory in eval_selected["trajectories"]),
|
| 594 |
+
)
|
| 595 |
+
return report
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 599 |
+
parser = argparse.ArgumentParser(
|
| 600 |
+
description=(
|
| 601 |
+
"Run seeded local rollouts and a small policy-improvement loop for the "
|
| 602 |
+
"IT helpdesk OpenEnv environment."
|
| 603 |
+
)
|
| 604 |
+
)
|
| 605 |
+
subparsers = parser.add_subparsers(dest="command", required=True)
|
| 606 |
+
|
| 607 |
+
compare_parser = subparsers.add_parser(
|
| 608 |
+
"compare",
|
| 609 |
+
help="Compare fixed policy choices across repeated seeded rollouts.",
|
| 610 |
+
)
|
| 611 |
+
compare_parser.add_argument(
|
| 612 |
+
"--policies",
|
| 613 |
+
default=",".join(DEFAULT_COMPARE_POLICIES),
|
| 614 |
+
help=f"Comma-separated policy names. Available: {', '.join(POLICY_LIBRARY)}",
|
| 615 |
+
)
|
| 616 |
+
compare_parser.add_argument(
|
| 617 |
+
"--seeds",
|
| 618 |
+
default="42-51",
|
| 619 |
+
help="Comma-separated seeds or ranges, for example 42-51 or 42,50,60.",
|
| 620 |
+
)
|
| 621 |
+
compare_parser.add_argument(
|
| 622 |
+
"--task-ids",
|
| 623 |
+
default="1,2,3",
|
| 624 |
+
help="Comma-separated task IDs or ranges, for example 1,2,3 or 1-3.",
|
| 625 |
+
)
|
| 626 |
+
compare_parser.add_argument(
|
| 627 |
+
"--output-dir",
|
| 628 |
+
default=DEFAULT_OUTPUT_DIR,
|
| 629 |
+
help="Directory for JSON and JSONL artifacts.",
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
search_parser = subparsers.add_parser(
|
| 633 |
+
"search",
|
| 634 |
+
help="Select the best policy on train seeds, then re-evaluate on holdout seeds.",
|
| 635 |
+
)
|
| 636 |
+
search_parser.add_argument(
|
| 637 |
+
"--candidate-policies",
|
| 638 |
+
default=",".join(DEFAULT_SEARCH_POLICIES),
|
| 639 |
+
help=f"Comma-separated candidate policy names. Available: {', '.join(POLICY_LIBRARY)}",
|
| 640 |
+
)
|
| 641 |
+
search_parser.add_argument(
|
| 642 |
+
"--train-seeds",
|
| 643 |
+
default="40-49",
|
| 644 |
+
help="Train seeds used for reward-based policy selection.",
|
| 645 |
+
)
|
| 646 |
+
search_parser.add_argument(
|
| 647 |
+
"--eval-seeds",
|
| 648 |
+
default="50-59",
|
| 649 |
+
help="Holdout seeds used for the selected policy evaluation.",
|
| 650 |
+
)
|
| 651 |
+
search_parser.add_argument(
|
| 652 |
+
"--task-ids",
|
| 653 |
+
default="1,2,3",
|
| 654 |
+
help="Comma-separated task IDs or ranges, for example 1,2,3 or 1-3.",
|
| 655 |
+
)
|
| 656 |
+
search_parser.add_argument(
|
| 657 |
+
"--baseline-policy",
|
| 658 |
+
default="no_investigation",
|
| 659 |
+
help="Baseline policy used for the final improvement delta.",
|
| 660 |
+
)
|
| 661 |
+
search_parser.add_argument(
|
| 662 |
+
"--output-dir",
|
| 663 |
+
default=DEFAULT_OUTPUT_DIR,
|
| 664 |
+
help="Directory for JSON and JSONL artifacts.",
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
return parser
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def _print_summary(label: str, summary: dict[str, Any]) -> None:
|
| 671 |
+
print(
|
| 672 |
+
json.dumps(
|
| 673 |
+
{
|
| 674 |
+
label: {
|
| 675 |
+
"policy": summary["policy"],
|
| 676 |
+
"avg_episode_return": summary["avg_episode_return"],
|
| 677 |
+
"avg_normalized_return": summary["avg_normalized_return"],
|
| 678 |
+
"avg_terminal_reward": summary["avg_terminal_reward"],
|
| 679 |
+
"avg_terminal_rubric_reward": summary["avg_terminal_rubric_reward"],
|
| 680 |
+
"avg_investigation_steps": summary["avg_investigation_steps"],
|
| 681 |
+
}
|
| 682 |
+
},
|
| 683 |
+
sort_keys=True,
|
| 684 |
+
)
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
def main() -> None:
|
| 689 |
+
parser = build_parser()
|
| 690 |
+
args = parser.parse_args()
|
| 691 |
+
|
| 692 |
+
output_dir = Path(args.output_dir)
|
| 693 |
+
|
| 694 |
+
if args.command == "compare":
|
| 695 |
+
policies = resolve_policies(args.policies)
|
| 696 |
+
seeds = parse_int_spec(args.seeds, field_name="seeds")
|
| 697 |
+
task_ids = parse_task_ids(args.task_ids)
|
| 698 |
+
report = compare_policies(
|
| 699 |
+
policies,
|
| 700 |
+
seeds,
|
| 701 |
+
task_ids,
|
| 702 |
+
output_dir=output_dir,
|
| 703 |
+
)
|
| 704 |
+
print(json.dumps(report, indent=2, sort_keys=True))
|
| 705 |
+
return
|
| 706 |
+
|
| 707 |
+
candidate_policies = resolve_policies(args.candidate_policies)
|
| 708 |
+
train_seeds = parse_int_spec(args.train_seeds, field_name="train_seeds")
|
| 709 |
+
eval_seeds = parse_int_spec(args.eval_seeds, field_name="eval_seeds")
|
| 710 |
+
task_ids = parse_task_ids(args.task_ids)
|
| 711 |
+
report = search_policies(
|
| 712 |
+
candidate_policies,
|
| 713 |
+
train_seeds,
|
| 714 |
+
eval_seeds,
|
| 715 |
+
task_ids,
|
| 716 |
+
output_dir=output_dir,
|
| 717 |
+
baseline_policy_name=args.baseline_policy,
|
| 718 |
+
)
|
| 719 |
+
print(json.dumps(report, indent=2, sort_keys=True))
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
if __name__ == "__main__":
|
| 723 |
+
main()
|
pyproject.toml
CHANGED
|
@@ -24,12 +24,13 @@ dependencies = [
|
|
| 24 |
|
| 25 |
[project.scripts]
|
| 26 |
server = "server.app:main"
|
|
|
|
| 27 |
|
| 28 |
[project.optional-dependencies]
|
| 29 |
dev = ["pytest", "httpx"]
|
| 30 |
|
| 31 |
[tool.setuptools]
|
| 32 |
-
py-modules = ["models", "client", "vocabulary"]
|
| 33 |
|
| 34 |
[tool.setuptools.packages.find]
|
| 35 |
include = ["server*"]
|
|
|
|
| 24 |
|
| 25 |
[project.scripts]
|
| 26 |
server = "server.app:main"
|
| 27 |
+
policy-learn = "policy_learning:main"
|
| 28 |
|
| 29 |
[project.optional-dependencies]
|
| 30 |
dev = ["pytest", "httpx"]
|
| 31 |
|
| 32 |
[tool.setuptools]
|
| 33 |
+
py-modules = ["models", "client", "policy_learning", "vocabulary"]
|
| 34 |
|
| 35 |
[tool.setuptools.packages.find]
|
| 36 |
include = ["server*"]
|
server/environment.py
CHANGED
|
@@ -18,10 +18,68 @@ from server.tasks import get_task_definition, load_dataset
|
|
| 18 |
|
| 19 |
|
| 20 |
QUEUE_SIZE_RANGE = (3, 5)
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
FREE_INVESTIGATIONS_PER_TICKET = 1
|
| 23 |
EXTRA_INVESTIGATION_COST = 0.02
|
| 24 |
MAX_EXTRA_INVESTIGATION_PENALTY = 0.15
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
|
|
@@ -86,7 +144,11 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 86 |
current_ticket_index=0,
|
| 87 |
per_ticket_scores=[],
|
| 88 |
total_reward=0.0,
|
|
|
|
| 89 |
investigation_budget_remaining=queue_size * FREE_INVESTIGATIONS_PER_TICKET,
|
|
|
|
|
|
|
|
|
|
| 90 |
)
|
| 91 |
|
| 92 |
return self._build_observation(task)
|
|
@@ -122,54 +184,104 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 122 |
if extra_fields:
|
| 123 |
# Penalty: record score 0.0, advance index, return penalty observation
|
| 124 |
self._state.per_ticket_scores.append(0.0)
|
| 125 |
-
self._state.
|
| 126 |
-
self._build_history_entry(
|
| 127 |
-
current_ticket,
|
| 128 |
-
predicted=action.model_dump(exclude_none=True),
|
| 129 |
-
score=0.0,
|
| 130 |
-
breakdown={},
|
| 131 |
-
queue_position=idx + 1,
|
| 132 |
-
penalty_reason=f"extra_fields: {sorted(extra_fields)}",
|
| 133 |
-
)
|
| 134 |
-
)
|
| 135 |
self._state.step_count += 1
|
| 136 |
self._state.current_ticket_index += 1
|
| 137 |
is_done = self._state.current_ticket_index >= len(self._queue)
|
| 138 |
self._state.done = is_done
|
|
|
|
|
|
|
| 139 |
if is_done:
|
| 140 |
-
|
| 141 |
self._state.per_ticket_scores, len(self._queue), self._state.step_count
|
| 142 |
)
|
| 143 |
-
final_reward = self._apply_episode_economics(
|
| 144 |
self._state.total_reward = final_reward
|
| 145 |
else:
|
| 146 |
final_reward = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
self._state.last_step_reward = final_reward
|
| 148 |
self._state.reward = final_reward
|
|
|
|
| 149 |
self._state.last_tool_result = None
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
score, breakdown = grade_action(action, current_ticket, task_id)
|
| 153 |
step_reward = compute_step_reward(score)
|
|
|
|
|
|
|
| 154 |
|
| 155 |
is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
if is_done:
|
| 158 |
self._state.per_ticket_scores.append(score)
|
|
|
|
| 159 |
self._state.step_count += 1
|
| 160 |
self._state.current_ticket_index += 1
|
| 161 |
-
|
| 162 |
self._state.per_ticket_scores,
|
| 163 |
len(self._queue),
|
| 164 |
self._state.step_count,
|
| 165 |
)
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
| 168 |
else:
|
| 169 |
self._state.per_ticket_scores.append(score)
|
|
|
|
| 170 |
self._state.step_count += 1
|
| 171 |
self._state.current_ticket_index += 1
|
| 172 |
-
final_reward = step_reward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
history_entry = self._build_history_entry(
|
| 175 |
current_ticket,
|
|
@@ -177,15 +289,26 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 177 |
score=score,
|
| 178 |
breakdown=breakdown,
|
| 179 |
queue_position=idx + 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
)
|
| 181 |
self._state.history_entries.append(history_entry)
|
| 182 |
|
| 183 |
self._state.last_step_reward = final_reward
|
| 184 |
self._state.reward = final_reward
|
| 185 |
self._state.done = is_done
|
|
|
|
| 186 |
self._state.last_tool_result = None
|
|
|
|
| 187 |
|
| 188 |
-
return self._build_observation(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
@property
|
| 191 |
def state(self) -> HelpdeskTicketState:
|
|
@@ -195,15 +318,112 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 195 |
# Helpers
|
| 196 |
# ------------------------------------------------------------------
|
| 197 |
|
| 198 |
-
def
|
| 199 |
free_investigations = len(self._queue) * FREE_INVESTIGATIONS_PER_TICKET
|
| 200 |
extra_investigations = max(0, self._state.investigation_steps - free_investigations)
|
| 201 |
-
|
| 202 |
MAX_EXTRA_INVESTIGATION_PENALTY,
|
| 203 |
extra_investigations * EXTRA_INVESTIGATION_COST,
|
| 204 |
)
|
|
|
|
|
|
|
|
|
|
| 205 |
return max(0.0, min(1.0, base_reward - penalty))
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
def _lookup_related_ticket(
|
| 208 |
self,
|
| 209 |
current_ticket: HelpdeskTicketRecord,
|
|
@@ -259,6 +479,15 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 259 |
"matches": matches,
|
| 260 |
}
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
def _run_investigation_tool(
|
| 263 |
self,
|
| 264 |
current_ticket: HelpdeskTicketRecord,
|
|
@@ -269,6 +498,8 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 269 |
return self._lookup_related_ticket(current_ticket, target_ticket_id)
|
| 270 |
if tool_name == "lookup_requester_history":
|
| 271 |
return self._lookup_requester_history(current_ticket)
|
|
|
|
|
|
|
| 272 |
raise ValueError(f"Unsupported tool_name: {tool_name}")
|
| 273 |
|
| 274 |
def _handle_investigation_action(
|
|
@@ -296,6 +527,14 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 296 |
action.tool_name,
|
| 297 |
action.tool_target_ticket_id,
|
| 298 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
self._state.step_count += 1
|
| 300 |
self._state.investigation_steps += 1
|
| 301 |
self._state.investigation_budget_remaining = max(
|
|
@@ -303,9 +542,25 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 303 |
self._state.investigation_budget_remaining - 1,
|
| 304 |
)
|
| 305 |
self._state.last_tool_result = tool_result
|
| 306 |
-
|
| 307 |
-
self._state.
|
|
|
|
| 308 |
self._state.done = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
self._state.history_entries.append(
|
| 310 |
self._build_history_entry(
|
| 311 |
current_ticket,
|
|
@@ -313,21 +568,35 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 313 |
score=0.0,
|
| 314 |
breakdown={},
|
| 315 |
queue_position=idx + 1,
|
|
|
|
|
|
|
| 316 |
tool_result=tool_result,
|
|
|
|
| 317 |
)
|
| 318 |
)
|
| 319 |
-
|
|
|
|
| 320 |
|
| 321 |
def _build_ticket_view(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
| 322 |
ticket_view: dict[str, Any] = {
|
| 323 |
"ticket_id": ticket.ticket_id,
|
| 324 |
"title": ticket.title,
|
| 325 |
"requester": ticket.requester,
|
| 326 |
-
"description":
|
| 327 |
}
|
| 328 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
ticket_view["ambiguity_note"] = ticket.ambiguity_note
|
| 330 |
-
if ticket.related_ticket_id is not None:
|
| 331 |
ticket_view["related_ticket_id"] = ticket.related_ticket_id
|
| 332 |
related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
|
| 333 |
if related_ticket is not None:
|
|
@@ -339,6 +608,50 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 339 |
}
|
| 340 |
return ticket_view
|
| 341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
def _build_history_entry(
|
| 343 |
self,
|
| 344 |
ticket: HelpdeskTicketRecord,
|
|
@@ -347,9 +660,15 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 347 |
score: float,
|
| 348 |
breakdown: dict[str, float],
|
| 349 |
queue_position: int,
|
|
|
|
|
|
|
|
|
|
| 350 |
penalty_reason: str | None = None,
|
| 351 |
tool_result: dict[str, Any] | None = None,
|
|
|
|
| 352 |
) -> dict[str, Any]:
|
|
|
|
|
|
|
| 353 |
history_entry: dict[str, Any] = {
|
| 354 |
"ticket_id": ticket.ticket_id,
|
| 355 |
"title": ticket.title,
|
|
@@ -359,9 +678,15 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 359 |
"breakdown": breakdown,
|
| 360 |
"queue_position": queue_position,
|
| 361 |
}
|
| 362 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
history_entry["ambiguity_note"] = ticket.ambiguity_note
|
| 364 |
-
if ticket.related_ticket_id is not None:
|
| 365 |
history_entry["related_ticket_id"] = ticket.related_ticket_id
|
| 366 |
related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
|
| 367 |
if related_ticket is not None:
|
|
@@ -375,6 +700,21 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 375 |
history_entry["penalty_reason"] = penalty_reason
|
| 376 |
if tool_result is not None:
|
| 377 |
history_entry["tool_result"] = tool_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
return history_entry
|
| 379 |
|
| 380 |
def _build_observation(
|
|
@@ -382,6 +722,7 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 382 |
task: dict,
|
| 383 |
done: bool = False,
|
| 384 |
reward: float | None = None,
|
|
|
|
| 385 |
) -> HelpdeskTicketObservation:
|
| 386 |
idx = self._state.current_ticket_index
|
| 387 |
queue_size = len(self._queue)
|
|
@@ -395,28 +736,47 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 395 |
queue_position = 0
|
| 396 |
|
| 397 |
history = list(self._state.history_entries)
|
|
|
|
| 398 |
tickets_remaining = max(0, queue_size - idx)
|
| 399 |
tickets_after_current = max(
|
| 400 |
0,
|
| 401 |
tickets_remaining - (1 if ticket_view is not None else 0),
|
| 402 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
return HelpdeskTicketObservation(
|
| 405 |
done=done,
|
| 406 |
reward=reward,
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
"tickets_remaining_includes_current": ticket_view is not None,
|
| 410 |
-
"has_ambiguity_note": bool(ticket_view and ticket_view.get("ambiguity_note")),
|
| 411 |
-
"has_related_ticket_context": bool(
|
| 412 |
-
ticket_view and ticket_view.get("related_ticket_preview")
|
| 413 |
-
),
|
| 414 |
-
"action_mode": "investigate_or_submit",
|
| 415 |
-
},
|
| 416 |
task_id=task["id"],
|
| 417 |
task_name=task["name"],
|
| 418 |
instructions=task["instructions"],
|
| 419 |
allowed_fields=list(task["allowed_fields"]),
|
|
|
|
| 420 |
available_tools=list(AVAILABLE_TOOLS),
|
| 421 |
investigation_budget_remaining=self._state.investigation_budget_remaining,
|
| 422 |
last_tool_result=self._state.last_tool_result,
|
|
@@ -426,5 +786,8 @@ class HelpdeskTicketRoutingEnvironment(
|
|
| 426 |
tickets_after_current=tickets_after_current,
|
| 427 |
tickets_processed=idx,
|
| 428 |
queue_position=queue_position,
|
|
|
|
|
|
|
| 429 |
history=history,
|
|
|
|
| 430 |
)
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
QUEUE_SIZE_RANGE = (3, 5)
|
| 21 |
+
AVAILABLE_ACTION_TYPES = ("submit", "investigate")
|
| 22 |
+
AVAILABLE_TOOLS = (
|
| 23 |
+
"lookup_related_ticket",
|
| 24 |
+
"lookup_requester_history",
|
| 25 |
+
"lookup_internal_routing_note",
|
| 26 |
+
)
|
| 27 |
FREE_INVESTIGATIONS_PER_TICKET = 1
|
| 28 |
EXTRA_INVESTIGATION_COST = 0.02
|
| 29 |
MAX_EXTRA_INVESTIGATION_PENALTY = 0.15
|
| 30 |
+
USEFUL_INVESTIGATION_REWARD = 0.08
|
| 31 |
+
PREMATURE_SUBMIT_PENALTY = 0.10
|
| 32 |
+
|
| 33 |
+
TASK3_INVESTIGATION_TOOL_PLAN: dict[str, tuple[str, ...]] = {
|
| 34 |
+
"ticket-021": ("lookup_related_ticket", "lookup_requester_history"),
|
| 35 |
+
"ticket-022": ("lookup_internal_routing_note",),
|
| 36 |
+
"ticket-027": ("lookup_internal_routing_note",),
|
| 37 |
+
"ticket-029": ("lookup_internal_routing_note",),
|
| 38 |
+
"ticket-038": ("lookup_related_ticket", "lookup_requester_history"),
|
| 39 |
+
"ticket-045": ("lookup_related_ticket", "lookup_requester_history"),
|
| 40 |
+
"TKT-NONDEFAULT-001": ("lookup_internal_routing_note",),
|
| 41 |
+
"TKT-NONDEFAULT-002": ("lookup_internal_routing_note",),
|
| 42 |
+
"TKT-NONDEFAULT-003": ("lookup_internal_routing_note",),
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
HARD_TASK_DESCRIPTION_REDACTIONS: dict[str, str] = {
|
| 46 |
+
"ticket-021": (
|
| 47 |
+
"Production checkout is still unstable after a recent fix. "
|
| 48 |
+
"Additional routing context is available via investigation."
|
| 49 |
+
),
|
| 50 |
+
"ticket-022": (
|
| 51 |
+
"Usage charges increased while the integration was failing. "
|
| 52 |
+
"Additional routing context is available via investigation."
|
| 53 |
+
),
|
| 54 |
+
"ticket-027": (
|
| 55 |
+
"A vendor offer arrived with a near-term deadline. "
|
| 56 |
+
"Additional routing context is available via investigation."
|
| 57 |
+
),
|
| 58 |
+
"ticket-029": (
|
| 59 |
+
"A team needs a large seat expansion right away. "
|
| 60 |
+
"Additional routing context is available via investigation."
|
| 61 |
+
),
|
| 62 |
+
"ticket-038": (
|
| 63 |
+
"A prior invoice discrepancy is still unresolved and now time-sensitive. "
|
| 64 |
+
"Additional routing context is available via investigation."
|
| 65 |
+
),
|
| 66 |
+
"ticket-045": (
|
| 67 |
+
"A company-wide suspension remains unresolved after repeated follow-ups. "
|
| 68 |
+
"Additional routing context is available via investigation."
|
| 69 |
+
),
|
| 70 |
+
"TKT-NONDEFAULT-001": (
|
| 71 |
+
"A user needs help with a billing-style question. "
|
| 72 |
+
"Additional routing context is available via investigation."
|
| 73 |
+
),
|
| 74 |
+
"TKT-NONDEFAULT-002": (
|
| 75 |
+
"A client compliance scan surfaced a product-specific issue. "
|
| 76 |
+
"Additional routing context is available via investigation."
|
| 77 |
+
),
|
| 78 |
+
"TKT-NONDEFAULT-003": (
|
| 79 |
+
"A contractor onboarding workflow is blocked by an account problem. "
|
| 80 |
+
"Additional routing context is available via investigation."
|
| 81 |
+
),
|
| 82 |
+
}
|
| 83 |
|
| 84 |
|
| 85 |
def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
|
|
|
|
| 144 |
current_ticket_index=0,
|
| 145 |
per_ticket_scores=[],
|
| 146 |
total_reward=0.0,
|
| 147 |
+
average_score_so_far=0.0,
|
| 148 |
investigation_budget_remaining=queue_size * FREE_INVESTIGATIONS_PER_TICKET,
|
| 149 |
+
investigation_penalty_applied=0.0,
|
| 150 |
+
last_reward_components={},
|
| 151 |
+
ticket_tool_usage={},
|
| 152 |
)
|
| 153 |
|
| 154 |
return self._build_observation(task)
|
|
|
|
| 184 |
if extra_fields:
|
| 185 |
# Penalty: record score 0.0, advance index, return penalty observation
|
| 186 |
self._state.per_ticket_scores.append(0.0)
|
| 187 |
+
self._state.average_score_so_far = self._current_average_score()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
self._state.step_count += 1
|
| 189 |
self._state.current_ticket_index += 1
|
| 190 |
is_done = self._state.current_ticket_index >= len(self._queue)
|
| 191 |
self._state.done = is_done
|
| 192 |
+
trajectory_reward = None
|
| 193 |
+
investigation_penalty = self._compute_episode_penalty() if is_done else 0.0
|
| 194 |
if is_done:
|
| 195 |
+
trajectory_reward = compute_trajectory_reward(
|
| 196 |
self._state.per_ticket_scores, len(self._queue), self._state.step_count
|
| 197 |
)
|
| 198 |
+
final_reward = self._apply_episode_economics(trajectory_reward)
|
| 199 |
self._state.total_reward = final_reward
|
| 200 |
else:
|
| 201 |
final_reward = 0.0
|
| 202 |
+
reward_components = self._build_reward_components(
|
| 203 |
+
ticket_score=0.0,
|
| 204 |
+
field_breakdown={},
|
| 205 |
+
shaped_step_reward=0.0,
|
| 206 |
+
reward_kind="trajectory" if is_done else "step_penalty",
|
| 207 |
+
final_reward=final_reward,
|
| 208 |
+
trajectory_reward=trajectory_reward,
|
| 209 |
+
investigation_penalty=investigation_penalty,
|
| 210 |
+
penalty_reason=f"extra_fields: {sorted(extra_fields)}",
|
| 211 |
+
)
|
| 212 |
+
self._state.history_entries.append(
|
| 213 |
+
self._build_history_entry(
|
| 214 |
+
current_ticket,
|
| 215 |
+
predicted=action.model_dump(exclude_none=True),
|
| 216 |
+
score=0.0,
|
| 217 |
+
breakdown={},
|
| 218 |
+
queue_position=idx + 1,
|
| 219 |
+
reward=final_reward,
|
| 220 |
+
rubric_reward=final_reward if is_done else None,
|
| 221 |
+
reward_kind="trajectory" if is_done else "step_penalty",
|
| 222 |
+
penalty_reason=f"extra_fields: {sorted(extra_fields)}",
|
| 223 |
+
reward_components=reward_components,
|
| 224 |
+
)
|
| 225 |
+
)
|
| 226 |
self._state.last_step_reward = final_reward
|
| 227 |
self._state.reward = final_reward
|
| 228 |
+
self._state.investigation_penalty_applied = self._compute_episode_penalty()
|
| 229 |
self._state.last_tool_result = None
|
| 230 |
+
self._state.last_reward_components = reward_components
|
| 231 |
+
return self._build_observation(
|
| 232 |
+
task,
|
| 233 |
+
done=is_done,
|
| 234 |
+
reward=final_reward,
|
| 235 |
+
rubric_reward=final_reward if is_done else None,
|
| 236 |
+
)
|
| 237 |
|
| 238 |
score, breakdown = grade_action(action, current_ticket, task_id)
|
| 239 |
step_reward = compute_step_reward(score)
|
| 240 |
+
context_penalty, missing_required_tools = self._submit_context_penalty(current_ticket)
|
| 241 |
+
milestone_adjustment = step_reward - score
|
| 242 |
|
| 243 |
is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
|
| 244 |
+
trajectory_reward = None
|
| 245 |
+
investigation_penalty = 0.0
|
| 246 |
+
rubric_reward = None
|
| 247 |
|
| 248 |
if is_done:
|
| 249 |
self._state.per_ticket_scores.append(score)
|
| 250 |
+
self._state.average_score_so_far = self._current_average_score()
|
| 251 |
self._state.step_count += 1
|
| 252 |
self._state.current_ticket_index += 1
|
| 253 |
+
trajectory_reward = compute_trajectory_reward(
|
| 254 |
self._state.per_ticket_scores,
|
| 255 |
len(self._queue),
|
| 256 |
self._state.step_count,
|
| 257 |
)
|
| 258 |
+
rubric_reward = self._apply_episode_economics(trajectory_reward)
|
| 259 |
+
final_reward = max(0.0, min(1.0, rubric_reward - context_penalty))
|
| 260 |
+
self._state.total_reward = rubric_reward
|
| 261 |
+
investigation_penalty = self._compute_episode_penalty()
|
| 262 |
else:
|
| 263 |
self._state.per_ticket_scores.append(score)
|
| 264 |
+
self._state.average_score_so_far = self._current_average_score()
|
| 265 |
self._state.step_count += 1
|
| 266 |
self._state.current_ticket_index += 1
|
| 267 |
+
final_reward = max(0.0, min(1.0, step_reward - context_penalty))
|
| 268 |
+
|
| 269 |
+
reward_components = self._build_reward_components(
|
| 270 |
+
ticket_score=score,
|
| 271 |
+
field_breakdown=breakdown,
|
| 272 |
+
shaped_step_reward=step_reward,
|
| 273 |
+
reward_kind="trajectory" if is_done else "step",
|
| 274 |
+
final_reward=final_reward,
|
| 275 |
+
milestone_adjustment=milestone_adjustment,
|
| 276 |
+
trajectory_reward=trajectory_reward,
|
| 277 |
+
investigation_penalty=investigation_penalty,
|
| 278 |
+
extra_details={
|
| 279 |
+
"context_gap_penalty": context_penalty,
|
| 280 |
+
"required_tools": self._required_tools_for_ticket(current_ticket),
|
| 281 |
+
"remaining_required_tools": missing_required_tools,
|
| 282 |
+
"rubric_reward": rubric_reward,
|
| 283 |
+
},
|
| 284 |
+
)
|
| 285 |
|
| 286 |
history_entry = self._build_history_entry(
|
| 287 |
current_ticket,
|
|
|
|
| 289 |
score=score,
|
| 290 |
breakdown=breakdown,
|
| 291 |
queue_position=idx + 1,
|
| 292 |
+
reward=final_reward,
|
| 293 |
+
rubric_reward=rubric_reward if is_done else None,
|
| 294 |
+
reward_kind="trajectory" if is_done else "step",
|
| 295 |
+
reward_components=reward_components,
|
| 296 |
)
|
| 297 |
self._state.history_entries.append(history_entry)
|
| 298 |
|
| 299 |
self._state.last_step_reward = final_reward
|
| 300 |
self._state.reward = final_reward
|
| 301 |
self._state.done = is_done
|
| 302 |
+
self._state.investigation_penalty_applied = self._compute_episode_penalty()
|
| 303 |
self._state.last_tool_result = None
|
| 304 |
+
self._state.last_reward_components = reward_components
|
| 305 |
|
| 306 |
+
return self._build_observation(
|
| 307 |
+
task,
|
| 308 |
+
done=is_done,
|
| 309 |
+
reward=final_reward,
|
| 310 |
+
rubric_reward=rubric_reward if is_done else None,
|
| 311 |
+
)
|
| 312 |
|
| 313 |
@property
|
| 314 |
def state(self) -> HelpdeskTicketState:
|
|
|
|
| 318 |
# Helpers
|
| 319 |
# ------------------------------------------------------------------
|
| 320 |
|
| 321 |
+
def _compute_episode_penalty(self) -> float:
|
| 322 |
free_investigations = len(self._queue) * FREE_INVESTIGATIONS_PER_TICKET
|
| 323 |
extra_investigations = max(0, self._state.investigation_steps - free_investigations)
|
| 324 |
+
return min(
|
| 325 |
MAX_EXTRA_INVESTIGATION_PENALTY,
|
| 326 |
extra_investigations * EXTRA_INVESTIGATION_COST,
|
| 327 |
)
|
| 328 |
+
|
| 329 |
+
def _apply_episode_economics(self, base_reward: float) -> float:
|
| 330 |
+
penalty = self._compute_episode_penalty()
|
| 331 |
return max(0.0, min(1.0, base_reward - penalty))
|
| 332 |
|
| 333 |
+
def _current_average_score(self) -> float:
|
| 334 |
+
if not self._state.per_ticket_scores:
|
| 335 |
+
return 0.0
|
| 336 |
+
return sum(self._state.per_ticket_scores) / len(self._state.per_ticket_scores)
|
| 337 |
+
|
| 338 |
+
def _required_tools_for_ticket(
|
| 339 |
+
self,
|
| 340 |
+
ticket: HelpdeskTicketRecord,
|
| 341 |
+
task_id: int | None = None,
|
| 342 |
+
) -> list[str]:
|
| 343 |
+
resolved_task_id = self._state.current_task_id if task_id is None else task_id
|
| 344 |
+
if resolved_task_id != 3:
|
| 345 |
+
return []
|
| 346 |
+
return list(TASK3_INVESTIGATION_TOOL_PLAN.get(ticket.ticket_id, ()))
|
| 347 |
+
|
| 348 |
+
def _used_tools_for_ticket(self, ticket_id: str) -> list[str]:
|
| 349 |
+
return list(self._state.ticket_tool_usage.get(ticket_id, []))
|
| 350 |
+
|
| 351 |
+
def _remaining_tools_for_ticket(
|
| 352 |
+
self,
|
| 353 |
+
ticket: HelpdeskTicketRecord,
|
| 354 |
+
task_id: int | None = None,
|
| 355 |
+
) -> list[str]:
|
| 356 |
+
required_tools = self._required_tools_for_ticket(ticket, task_id)
|
| 357 |
+
used_tools = set(self._used_tools_for_ticket(ticket.ticket_id))
|
| 358 |
+
return [tool for tool in required_tools if tool not in used_tools]
|
| 359 |
+
|
| 360 |
+
def _record_tool_usage(self, ticket_id: str, tool_name: str) -> None:
|
| 361 |
+
used = self._state.ticket_tool_usage.setdefault(ticket_id, [])
|
| 362 |
+
if tool_name not in used:
|
| 363 |
+
used.append(tool_name)
|
| 364 |
+
|
| 365 |
+
def _investigation_hints_for_ticket(self, ticket: HelpdeskTicketRecord) -> list[str]:
|
| 366 |
+
hints: list[str] = []
|
| 367 |
+
remaining_tools = self._remaining_tools_for_ticket(ticket)
|
| 368 |
+
if "lookup_internal_routing_note" in remaining_tools:
|
| 369 |
+
hints.append("An internal routing note may disambiguate the correct workflow.")
|
| 370 |
+
if "lookup_related_ticket" in remaining_tools:
|
| 371 |
+
hints.append("A linked prior ticket can reveal important follow-up context.")
|
| 372 |
+
if "lookup_requester_history" in remaining_tools:
|
| 373 |
+
hints.append("Requester history may clarify severity or routing intent.")
|
| 374 |
+
return hints
|
| 375 |
+
|
| 376 |
+
def _visible_description(self, ticket: HelpdeskTicketRecord) -> str:
|
| 377 |
+
if (
|
| 378 |
+
self._state.current_task_id == 3
|
| 379 |
+
and self._remaining_tools_for_ticket(ticket)
|
| 380 |
+
and ticket.ticket_id in HARD_TASK_DESCRIPTION_REDACTIONS
|
| 381 |
+
):
|
| 382 |
+
return HARD_TASK_DESCRIPTION_REDACTIONS[ticket.ticket_id]
|
| 383 |
+
return ticket.description
|
| 384 |
+
|
| 385 |
+
def _submit_context_penalty(self, ticket: HelpdeskTicketRecord) -> tuple[float, list[str]]:
|
| 386 |
+
required_tools = self._required_tools_for_ticket(ticket)
|
| 387 |
+
if not required_tools:
|
| 388 |
+
return 0.0, []
|
| 389 |
+
remaining_tools = self._remaining_tools_for_ticket(ticket)
|
| 390 |
+
if not remaining_tools:
|
| 391 |
+
return 0.0, []
|
| 392 |
+
penalty = PREMATURE_SUBMIT_PENALTY * (len(remaining_tools) / len(required_tools))
|
| 393 |
+
return penalty, remaining_tools
|
| 394 |
+
|
| 395 |
+
def _build_reward_components(
|
| 396 |
+
self,
|
| 397 |
+
*,
|
| 398 |
+
ticket_score: float,
|
| 399 |
+
field_breakdown: dict[str, float],
|
| 400 |
+
shaped_step_reward: float,
|
| 401 |
+
reward_kind: str,
|
| 402 |
+
final_reward: float,
|
| 403 |
+
milestone_adjustment: float = 0.0,
|
| 404 |
+
trajectory_reward: float | None = None,
|
| 405 |
+
investigation_penalty: float = 0.0,
|
| 406 |
+
penalty_reason: str | None = None,
|
| 407 |
+
extra_details: dict[str, Any] | None = None,
|
| 408 |
+
) -> dict[str, Any]:
|
| 409 |
+
components: dict[str, Any] = {
|
| 410 |
+
"reward_kind": reward_kind,
|
| 411 |
+
"ticket_score": ticket_score,
|
| 412 |
+
"field_breakdown": field_breakdown,
|
| 413 |
+
"shaped_step_reward": shaped_step_reward,
|
| 414 |
+
"milestone_adjustment": milestone_adjustment,
|
| 415 |
+
"final_reward": final_reward,
|
| 416 |
+
"average_score_so_far": self._current_average_score(),
|
| 417 |
+
"investigation_penalty_applied": investigation_penalty,
|
| 418 |
+
}
|
| 419 |
+
if trajectory_reward is not None:
|
| 420 |
+
components["trajectory_reward"] = trajectory_reward
|
| 421 |
+
if penalty_reason is not None:
|
| 422 |
+
components["penalty_reason"] = penalty_reason
|
| 423 |
+
if extra_details:
|
| 424 |
+
components.update(extra_details)
|
| 425 |
+
return components
|
| 426 |
+
|
| 427 |
def _lookup_related_ticket(
|
| 428 |
self,
|
| 429 |
current_ticket: HelpdeskTicketRecord,
|
|
|
|
| 479 |
"matches": matches,
|
| 480 |
}
|
| 481 |
|
| 482 |
+
def _lookup_internal_routing_note(self, current_ticket: HelpdeskTicketRecord) -> dict[str, Any]:
|
| 483 |
+
found = current_ticket.ambiguity_note is not None
|
| 484 |
+
return {
|
| 485 |
+
"tool_name": "lookup_internal_routing_note",
|
| 486 |
+
"found": found,
|
| 487 |
+
"ticket_id": current_ticket.ticket_id,
|
| 488 |
+
"routing_note": current_ticket.ambiguity_note if found else "",
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
def _run_investigation_tool(
|
| 492 |
self,
|
| 493 |
current_ticket: HelpdeskTicketRecord,
|
|
|
|
| 498 |
return self._lookup_related_ticket(current_ticket, target_ticket_id)
|
| 499 |
if tool_name == "lookup_requester_history":
|
| 500 |
return self._lookup_requester_history(current_ticket)
|
| 501 |
+
if tool_name == "lookup_internal_routing_note":
|
| 502 |
+
return self._lookup_internal_routing_note(current_ticket)
|
| 503 |
raise ValueError(f"Unsupported tool_name: {tool_name}")
|
| 504 |
|
| 505 |
def _handle_investigation_action(
|
|
|
|
| 527 |
action.tool_name,
|
| 528 |
action.tool_target_ticket_id,
|
| 529 |
)
|
| 530 |
+
required_tools = self._required_tools_for_ticket(current_ticket)
|
| 531 |
+
already_used = action.tool_name in self._used_tools_for_ticket(current_ticket.ticket_id)
|
| 532 |
+
useful_investigation = (
|
| 533 |
+
action.tool_name in required_tools
|
| 534 |
+
and not already_used
|
| 535 |
+
and bool(tool_result.get("found", True))
|
| 536 |
+
)
|
| 537 |
+
self._record_tool_usage(current_ticket.ticket_id, action.tool_name)
|
| 538 |
self._state.step_count += 1
|
| 539 |
self._state.investigation_steps += 1
|
| 540 |
self._state.investigation_budget_remaining = max(
|
|
|
|
| 542 |
self._state.investigation_budget_remaining - 1,
|
| 543 |
)
|
| 544 |
self._state.last_tool_result = tool_result
|
| 545 |
+
investigation_reward = USEFUL_INVESTIGATION_REWARD if useful_investigation else 0.0
|
| 546 |
+
self._state.last_step_reward = investigation_reward
|
| 547 |
+
self._state.reward = investigation_reward
|
| 548 |
self._state.done = False
|
| 549 |
+
self._state.investigation_penalty_applied = self._compute_episode_penalty()
|
| 550 |
+
reward_components = self._build_reward_components(
|
| 551 |
+
ticket_score=0.0,
|
| 552 |
+
field_breakdown={},
|
| 553 |
+
shaped_step_reward=investigation_reward,
|
| 554 |
+
reward_kind="investigation",
|
| 555 |
+
final_reward=investigation_reward,
|
| 556 |
+
investigation_penalty=self._state.investigation_penalty_applied,
|
| 557 |
+
extra_details={
|
| 558 |
+
"new_context_revealed": useful_investigation,
|
| 559 |
+
"required_tools": required_tools,
|
| 560 |
+
"remaining_required_tools": self._remaining_tools_for_ticket(current_ticket),
|
| 561 |
+
"tool_name": action.tool_name,
|
| 562 |
+
},
|
| 563 |
+
)
|
| 564 |
self._state.history_entries.append(
|
| 565 |
self._build_history_entry(
|
| 566 |
current_ticket,
|
|
|
|
| 568 |
score=0.0,
|
| 569 |
breakdown={},
|
| 570 |
queue_position=idx + 1,
|
| 571 |
+
reward=investigation_reward,
|
| 572 |
+
reward_kind="investigation",
|
| 573 |
tool_result=tool_result,
|
| 574 |
+
reward_components=reward_components,
|
| 575 |
)
|
| 576 |
)
|
| 577 |
+
self._state.last_reward_components = reward_components
|
| 578 |
+
return self._build_observation(task, done=False, reward=investigation_reward)
|
| 579 |
|
| 580 |
def _build_ticket_view(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
|
| 581 |
+
required_tools = self._required_tools_for_ticket(ticket)
|
| 582 |
+
revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
|
| 583 |
+
remaining_tools = self._remaining_tools_for_ticket(ticket)
|
| 584 |
ticket_view: dict[str, Any] = {
|
| 585 |
"ticket_id": ticket.ticket_id,
|
| 586 |
"title": ticket.title,
|
| 587 |
"requester": ticket.requester,
|
| 588 |
+
"description": self._visible_description(ticket),
|
| 589 |
}
|
| 590 |
+
if required_tools:
|
| 591 |
+
ticket_view["context_status"] = {
|
| 592 |
+
"investigation_required": True,
|
| 593 |
+
"revealed_tools": revealed_tools,
|
| 594 |
+
"remaining_tools": remaining_tools,
|
| 595 |
+
"hints": self._investigation_hints_for_ticket(ticket),
|
| 596 |
+
}
|
| 597 |
+
if ticket.ambiguity_note is not None and "lookup_internal_routing_note" not in remaining_tools:
|
| 598 |
ticket_view["ambiguity_note"] = ticket.ambiguity_note
|
| 599 |
+
if ticket.related_ticket_id is not None and "lookup_related_ticket" not in remaining_tools:
|
| 600 |
ticket_view["related_ticket_id"] = ticket.related_ticket_id
|
| 601 |
related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
|
| 602 |
if related_ticket is not None:
|
|
|
|
| 608 |
}
|
| 609 |
return ticket_view
|
| 610 |
|
| 611 |
+
def _build_feedback_summary(
|
| 612 |
+
self,
|
| 613 |
+
*,
|
| 614 |
+
predicted: dict[str, Any],
|
| 615 |
+
score: float,
|
| 616 |
+
breakdown: dict[str, float],
|
| 617 |
+
reward: float | None = None,
|
| 618 |
+
rubric_reward: float | None = None,
|
| 619 |
+
reward_kind: str | None = None,
|
| 620 |
+
penalty_reason: str | None = None,
|
| 621 |
+
tool_result: dict[str, Any] | None = None,
|
| 622 |
+
reward_components: dict[str, Any] | None = None,
|
| 623 |
+
) -> str:
|
| 624 |
+
parts: list[str] = []
|
| 625 |
+
|
| 626 |
+
if reward_kind == "investigation":
|
| 627 |
+
tool_name = predicted.get("tool_name") or (tool_result or {}).get("tool_name")
|
| 628 |
+
parts.append(f"Investigation step used {tool_name or 'a tool'}")
|
| 629 |
+
if reward_components and reward_components.get("new_context_revealed"):
|
| 630 |
+
parts.append("new context was revealed")
|
| 631 |
+
elif penalty_reason is not None:
|
| 632 |
+
parts.append(f"Penalty applied: {penalty_reason}")
|
| 633 |
+
else:
|
| 634 |
+
parts.append(f"Ticket score={score:.2f}")
|
| 635 |
+
|
| 636 |
+
if breakdown:
|
| 637 |
+
field_scores = ", ".join(
|
| 638 |
+
f"{field}={value:.2f}" for field, value in sorted(breakdown.items())
|
| 639 |
+
)
|
| 640 |
+
parts.append(f"field_scores[{field_scores}]")
|
| 641 |
+
if reward is not None:
|
| 642 |
+
parts.append(f"reward={reward:.2f}")
|
| 643 |
+
if rubric_reward is not None:
|
| 644 |
+
parts.append(f"rubric_reward={rubric_reward:.2f}")
|
| 645 |
+
if reward_components:
|
| 646 |
+
context_gap_penalty = reward_components.get("context_gap_penalty")
|
| 647 |
+
if context_gap_penalty:
|
| 648 |
+
parts.append(f"context_gap_penalty={context_gap_penalty:.2f}")
|
| 649 |
+
remaining_required_tools = reward_components.get("remaining_required_tools") or []
|
| 650 |
+
if remaining_required_tools:
|
| 651 |
+
parts.append(f"missing_context={remaining_required_tools}")
|
| 652 |
+
|
| 653 |
+
return "; ".join(parts)
|
| 654 |
+
|
| 655 |
def _build_history_entry(
|
| 656 |
self,
|
| 657 |
ticket: HelpdeskTicketRecord,
|
|
|
|
| 660 |
score: float,
|
| 661 |
breakdown: dict[str, float],
|
| 662 |
queue_position: int,
|
| 663 |
+
reward: float | None = None,
|
| 664 |
+
rubric_reward: float | None = None,
|
| 665 |
+
reward_kind: str | None = None,
|
| 666 |
penalty_reason: str | None = None,
|
| 667 |
tool_result: dict[str, Any] | None = None,
|
| 668 |
+
reward_components: dict[str, Any] | None = None,
|
| 669 |
) -> dict[str, Any]:
|
| 670 |
+
remaining_tools = self._remaining_tools_for_ticket(ticket)
|
| 671 |
+
revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
|
| 672 |
history_entry: dict[str, Any] = {
|
| 673 |
"ticket_id": ticket.ticket_id,
|
| 674 |
"title": ticket.title,
|
|
|
|
| 678 |
"breakdown": breakdown,
|
| 679 |
"queue_position": queue_position,
|
| 680 |
}
|
| 681 |
+
if reward is not None:
|
| 682 |
+
history_entry["reward"] = reward
|
| 683 |
+
if rubric_reward is not None:
|
| 684 |
+
history_entry["rubric_reward"] = rubric_reward
|
| 685 |
+
if reward_kind is not None:
|
| 686 |
+
history_entry["reward_kind"] = reward_kind
|
| 687 |
+
if ticket.ambiguity_note is not None and "lookup_internal_routing_note" not in remaining_tools:
|
| 688 |
history_entry["ambiguity_note"] = ticket.ambiguity_note
|
| 689 |
+
if ticket.related_ticket_id is not None and "lookup_related_ticket" not in remaining_tools:
|
| 690 |
history_entry["related_ticket_id"] = ticket.related_ticket_id
|
| 691 |
related_ticket = self._tickets_by_id.get(ticket.related_ticket_id)
|
| 692 |
if related_ticket is not None:
|
|
|
|
| 700 |
history_entry["penalty_reason"] = penalty_reason
|
| 701 |
if tool_result is not None:
|
| 702 |
history_entry["tool_result"] = tool_result
|
| 703 |
+
if reward_components is not None:
|
| 704 |
+
history_entry["reward_components"] = reward_components
|
| 705 |
+
if revealed_tools:
|
| 706 |
+
history_entry["revealed_tools"] = revealed_tools
|
| 707 |
+
history_entry["feedback_summary"] = self._build_feedback_summary(
|
| 708 |
+
predicted=predicted,
|
| 709 |
+
score=score,
|
| 710 |
+
breakdown=breakdown,
|
| 711 |
+
reward=reward,
|
| 712 |
+
rubric_reward=rubric_reward,
|
| 713 |
+
reward_kind=reward_kind,
|
| 714 |
+
penalty_reason=penalty_reason,
|
| 715 |
+
tool_result=tool_result,
|
| 716 |
+
reward_components=reward_components,
|
| 717 |
+
)
|
| 718 |
return history_entry
|
| 719 |
|
| 720 |
def _build_observation(
|
|
|
|
| 722 |
task: dict,
|
| 723 |
done: bool = False,
|
| 724 |
reward: float | None = None,
|
| 725 |
+
rubric_reward: float | None = None,
|
| 726 |
) -> HelpdeskTicketObservation:
|
| 727 |
idx = self._state.current_ticket_index
|
| 728 |
queue_size = len(self._queue)
|
|
|
|
| 736 |
queue_position = 0
|
| 737 |
|
| 738 |
history = list(self._state.history_entries)
|
| 739 |
+
last_history_entry = history[-1] if history else None
|
| 740 |
tickets_remaining = max(0, queue_size - idx)
|
| 741 |
tickets_after_current = max(
|
| 742 |
0,
|
| 743 |
tickets_remaining - (1 if ticket_view is not None else 0),
|
| 744 |
)
|
| 745 |
+
progress_fraction = (idx / queue_size) if queue_size else 0.0
|
| 746 |
+
|
| 747 |
+
metadata = {
|
| 748 |
+
"queue_position": queue_position,
|
| 749 |
+
"tickets_remaining_includes_current": ticket_view is not None,
|
| 750 |
+
"has_ambiguity_note": bool(ticket_view and ticket_view.get("ambiguity_note")),
|
| 751 |
+
"has_related_ticket_context": bool(
|
| 752 |
+
ticket_view and ticket_view.get("related_ticket_preview")
|
| 753 |
+
),
|
| 754 |
+
"action_mode": "investigate_or_submit",
|
| 755 |
+
"available_action_types": list(AVAILABLE_ACTION_TYPES),
|
| 756 |
+
"average_score_so_far": self._state.average_score_so_far,
|
| 757 |
+
"progress_fraction": progress_fraction,
|
| 758 |
+
"investigation_penalty_applied": self._state.investigation_penalty_applied,
|
| 759 |
+
}
|
| 760 |
+
if last_history_entry is not None:
|
| 761 |
+
metadata["last_score"] = last_history_entry.get("score")
|
| 762 |
+
metadata["last_reward"] = last_history_entry.get("reward")
|
| 763 |
+
metadata["last_reward_kind"] = last_history_entry.get("reward_kind")
|
| 764 |
+
metadata["last_breakdown"] = last_history_entry.get("breakdown")
|
| 765 |
+
metadata["last_feedback_summary"] = last_history_entry.get("feedback_summary")
|
| 766 |
+
metadata["last_reward_components"] = last_history_entry.get("reward_components", {})
|
| 767 |
+
if "penalty_reason" in last_history_entry:
|
| 768 |
+
metadata["last_penalty_reason"] = last_history_entry["penalty_reason"]
|
| 769 |
|
| 770 |
return HelpdeskTicketObservation(
|
| 771 |
done=done,
|
| 772 |
reward=reward,
|
| 773 |
+
rubric_reward=rubric_reward,
|
| 774 |
+
metadata=metadata,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
task_id=task["id"],
|
| 776 |
task_name=task["name"],
|
| 777 |
instructions=task["instructions"],
|
| 778 |
allowed_fields=list(task["allowed_fields"]),
|
| 779 |
+
available_action_types=list(AVAILABLE_ACTION_TYPES),
|
| 780 |
available_tools=list(AVAILABLE_TOOLS),
|
| 781 |
investigation_budget_remaining=self._state.investigation_budget_remaining,
|
| 782 |
last_tool_result=self._state.last_tool_result,
|
|
|
|
| 786 |
tickets_after_current=tickets_after_current,
|
| 787 |
tickets_processed=idx,
|
| 788 |
queue_position=queue_position,
|
| 789 |
+
average_score_so_far=self._state.average_score_so_far,
|
| 790 |
+
progress_fraction=progress_fraction,
|
| 791 |
history=history,
|
| 792 |
+
last_reward_components=dict(self._state.last_reward_components),
|
| 793 |
)
|
server/tasks.py
CHANGED
|
@@ -37,7 +37,9 @@ TASKS = {
|
|
| 37 |
"Perform full helpdesk routing by selecting the best issue type, "
|
| 38 |
"priority, assignment group, and resolution action for the ticket. "
|
| 39 |
"Use any ambiguity notes or related-ticket previews when present. "
|
| 40 |
-
"
|
|
|
|
|
|
|
| 41 |
),
|
| 42 |
"allowed_fields": [
|
| 43 |
"issue_type",
|
|
|
|
| 37 |
"Perform full helpdesk routing by selecting the best issue type, "
|
| 38 |
"priority, assignment group, and resolution action for the ticket. "
|
| 39 |
"Use any ambiguity notes or related-ticket previews when present. "
|
| 40 |
+
"Some hard tickets intentionally hide decisive routing context until "
|
| 41 |
+
"you investigate with the available tools, so premature submission can "
|
| 42 |
+
"underperform even when the visible text looks plausible."
|
| 43 |
),
|
| 44 |
"allowed_fields": [
|
| 45 |
"issue_type",
|
tests/test_api_integration.py
CHANGED
|
@@ -167,6 +167,9 @@ class TestResetEndpoint(unittest.TestCase):
|
|
| 167 |
def test_reset_reward_is_null(self):
|
| 168 |
self.assertIsNone(self.data["reward"])
|
| 169 |
|
|
|
|
|
|
|
|
|
|
| 170 |
def test_reset_task_id_is_1(self):
|
| 171 |
self.assertEqual(self.data["task_id"], 1)
|
| 172 |
|
|
@@ -177,6 +180,13 @@ class TestResetEndpoint(unittest.TestCase):
|
|
| 177 |
self.assertIsInstance(self.data["allowed_fields"], list)
|
| 178 |
self.assertGreater(len(self.data["allowed_fields"]), 0)
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
class TestStepEndpoint(unittest.TestCase):
|
| 182 |
"""2.1.4 — POST /step returns observation JSON with reward in [0.0, 1.0]."""
|
|
@@ -200,6 +210,35 @@ class TestStepEndpoint(unittest.TestCase):
|
|
| 200 |
def test_step_tickets_processed_is_1(self):
|
| 201 |
self.assertEqual(self.data["tickets_processed"], 1)
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
class TestStateEndpoint(unittest.TestCase):
|
| 205 |
"""2.1.5 — GET /state returns current episode state JSON after a reset."""
|
|
@@ -278,6 +317,38 @@ class TestFullSeededEpisode(unittest.TestCase):
|
|
| 278 |
self.assertGreaterEqual(final_reward, 0.0)
|
| 279 |
self.assertLessEqual(final_reward, 1.0)
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
def test_full_episode_all_tasks_complete(self):
|
| 282 |
"""4.1.1 — Full seeded episode completes for each task ID (1, 2, 3)."""
|
| 283 |
for task_id in (1, 2, 3):
|
|
|
|
| 167 |
def test_reset_reward_is_null(self):
|
| 168 |
self.assertIsNone(self.data["reward"])
|
| 169 |
|
| 170 |
+
def test_reset_rubric_reward_is_null(self):
|
| 171 |
+
self.assertIsNone(self.data["rubric_reward"])
|
| 172 |
+
|
| 173 |
def test_reset_task_id_is_1(self):
|
| 174 |
self.assertEqual(self.data["task_id"], 1)
|
| 175 |
|
|
|
|
| 180 |
self.assertIsInstance(self.data["allowed_fields"], list)
|
| 181 |
self.assertGreater(len(self.data["allowed_fields"]), 0)
|
| 182 |
|
| 183 |
+
def test_reset_available_action_types_exposed(self):
|
| 184 |
+
self.assertEqual(self.data["available_action_types"], ["submit", "investigate"])
|
| 185 |
+
|
| 186 |
+
def test_reset_progress_metrics_start_at_zero(self):
|
| 187 |
+
self.assertEqual(self.data["average_score_so_far"], 0.0)
|
| 188 |
+
self.assertEqual(self.data["progress_fraction"], 0.0)
|
| 189 |
+
|
| 190 |
|
| 191 |
class TestStepEndpoint(unittest.TestCase):
|
| 192 |
"""2.1.4 — POST /step returns observation JSON with reward in [0.0, 1.0]."""
|
|
|
|
| 210 |
def test_step_tickets_processed_is_1(self):
|
| 211 |
self.assertEqual(self.data["tickets_processed"], 1)
|
| 212 |
|
| 213 |
+
def test_step_metadata_exposes_last_feedback_summary(self):
|
| 214 |
+
metadata = self.data.get("metadata", {})
|
| 215 |
+
self.assertIn("last_feedback_summary", metadata)
|
| 216 |
+
self.assertIsInstance(metadata["last_feedback_summary"], str)
|
| 217 |
+
self.assertTrue(metadata["last_feedback_summary"])
|
| 218 |
+
|
| 219 |
+
def test_step_history_entry_includes_feedback_summary(self):
|
| 220 |
+
history = self.data.get("history", [])
|
| 221 |
+
self.assertGreater(len(history), 0)
|
| 222 |
+
self.assertIn("feedback_summary", history[-1])
|
| 223 |
+
self.assertIsInstance(history[-1]["feedback_summary"], str)
|
| 224 |
+
self.assertTrue(history[-1]["feedback_summary"])
|
| 225 |
+
|
| 226 |
+
def test_step_exposes_structured_reward_components(self):
|
| 227 |
+
self.assertIn("last_reward_components", self.data)
|
| 228 |
+
self.assertIsInstance(self.data["last_reward_components"], dict)
|
| 229 |
+
self.assertIn("ticket_score", self.data["last_reward_components"])
|
| 230 |
+
self.assertIn("final_reward", self.data["last_reward_components"])
|
| 231 |
+
self.assertEqual(
|
| 232 |
+
self.data["metadata"].get("last_reward_components"),
|
| 233 |
+
self.data["last_reward_components"],
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def test_step_progress_metrics_are_exposed(self):
|
| 237 |
+
self.assertIn("average_score_so_far", self.data)
|
| 238 |
+
self.assertIn("progress_fraction", self.data)
|
| 239 |
+
self.assertGreaterEqual(self.data["progress_fraction"], 0.0)
|
| 240 |
+
self.assertLessEqual(self.data["progress_fraction"], 1.0)
|
| 241 |
+
|
| 242 |
|
| 243 |
class TestStateEndpoint(unittest.TestCase):
|
| 244 |
"""2.1.5 — GET /state returns current episode state JSON after a reset."""
|
|
|
|
| 317 |
self.assertGreaterEqual(final_reward, 0.0)
|
| 318 |
self.assertLessEqual(final_reward, 1.0)
|
| 319 |
|
| 320 |
+
def test_full_episode_terminal_rubric_reward_in_unit_interval(self):
|
| 321 |
+
reset_resp = _reset(task_id=1, seed=42)
|
| 322 |
+
self.assertEqual(reset_resp.status_code, 200)
|
| 323 |
+
obs = reset_resp.json()
|
| 324 |
+
|
| 325 |
+
allowed_fields = obs["allowed_fields"]
|
| 326 |
+
final_rubric_reward = None
|
| 327 |
+
for _ in range(20):
|
| 328 |
+
action_payload: dict = {}
|
| 329 |
+
if "issue_type" in allowed_fields:
|
| 330 |
+
action_payload["issue_type"] = "general_inquiry"
|
| 331 |
+
if "priority" in allowed_fields:
|
| 332 |
+
action_payload["priority"] = "medium"
|
| 333 |
+
if "assignment_group" in allowed_fields:
|
| 334 |
+
action_payload["assignment_group"] = "service_desk"
|
| 335 |
+
if "resolution_action" in allowed_fields:
|
| 336 |
+
action_payload["resolution_action"] = "acknowledge"
|
| 337 |
+
|
| 338 |
+
step_resp = client.post("/step", json=action_payload)
|
| 339 |
+
self.assertEqual(step_resp.status_code, 200)
|
| 340 |
+
obs = step_resp.json()
|
| 341 |
+
|
| 342 |
+
if obs["done"]:
|
| 343 |
+
final_rubric_reward = obs.get("rubric_reward")
|
| 344 |
+
break
|
| 345 |
+
|
| 346 |
+
self.assertIsNotNone(
|
| 347 |
+
final_rubric_reward, "Terminal observation did not include rubric_reward"
|
| 348 |
+
)
|
| 349 |
+
self.assertGreaterEqual(final_rubric_reward, 0.0)
|
| 350 |
+
self.assertLessEqual(final_rubric_reward, 1.0)
|
| 351 |
+
|
| 352 |
def test_full_episode_all_tasks_complete(self):
|
| 353 |
"""4.1.1 — Full seeded episode completes for each task ID (1, 2, 3)."""
|
| 354 |
for task_id in (1, 2, 3):
|
tests/test_competitive_upgrade.py
CHANGED
|
@@ -182,6 +182,16 @@ class TestStateHasRewardAndDone(unittest.TestCase):
|
|
| 182 |
obs = env.step(_heuristic_action(obs))
|
| 183 |
self.assertFalse(env.state.done)
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
# ---------------------------------------------------------------------------
|
| 187 |
# 9.3 — History entry contains title and predicted
|
|
@@ -318,7 +328,7 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
|
|
| 318 |
return seed
|
| 319 |
return None
|
| 320 |
|
| 321 |
-
def
|
| 322 |
"""Force a ticket with ambiguity_note by patching the dataset."""
|
| 323 |
from unittest.mock import patch
|
| 324 |
from server.tasks import load_dataset
|
|
@@ -336,8 +346,22 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
|
|
| 336 |
obs = env.reset(seed=0, task_id=3)
|
| 337 |
|
| 338 |
self.assertIsNotNone(obs.current_ticket)
|
| 339 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
self.assertEqual(obs.current_ticket["ambiguity_note"], target.ambiguity_note)
|
|
|
|
| 341 |
|
| 342 |
def test_ambiguity_note_absent_when_ticket_has_none(self) -> None:
|
| 343 |
"""Tickets without ambiguity_note should not expose the key."""
|
|
@@ -370,6 +394,13 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
|
|
| 370 |
with patch.object(env, "_dataset", [ticket]):
|
| 371 |
obs = env.reset(seed=0, task_id=3)
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
self.assertIn("ambiguity_note", obs.current_ticket)
|
| 374 |
|
| 375 |
|
|
@@ -397,12 +428,27 @@ class TestRelatedTicketPreviewInObservation(unittest.TestCase):
|
|
| 397 |
):
|
| 398 |
obs = env.reset(seed=0, task_id=3, queue_size=1)
|
| 399 |
|
| 400 |
-
return env, obs, related
|
| 401 |
|
| 402 |
def test_related_ticket_preview_present_when_ticket_has_link(self) -> None:
|
| 403 |
-
env, obs, related = self._reset_linked_ticket_env()
|
| 404 |
|
| 405 |
self.assertIsNotNone(obs.current_ticket)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
self.assertIn("related_ticket_preview", obs.current_ticket)
|
| 407 |
self.assertEqual(
|
| 408 |
obs.current_ticket["related_ticket_preview"]["ticket_id"],
|
|
@@ -414,8 +460,22 @@ class TestRelatedTicketPreviewInObservation(unittest.TestCase):
|
|
| 414 |
)
|
| 415 |
|
| 416 |
def test_history_keeps_related_ticket_preview_after_step(self) -> None:
|
| 417 |
-
env, obs, related = self._reset_linked_ticket_env()
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
self.assertGreaterEqual(len(next_obs.history), 1)
|
| 421 |
self.assertIn("related_ticket_preview", next_obs.history[0])
|
|
@@ -563,6 +623,58 @@ class TestInvestigationActions(unittest.TestCase):
|
|
| 563 |
self.assertTrue(obs2.last_tool_result["found"])
|
| 564 |
self.assertGreaterEqual(len(obs2.last_tool_result["matches"]), 1)
|
| 565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
|
| 567 |
class TestQueueEconomics(unittest.TestCase):
|
| 568 |
"""Free investigations are allowed, but excessive investigation gets a queue-level penalty."""
|
|
|
|
| 182 |
obs = env.step(_heuristic_action(obs))
|
| 183 |
self.assertFalse(env.state.done)
|
| 184 |
|
| 185 |
+
def test_state_tracks_average_score_and_reward_components(self) -> None:
|
| 186 |
+
env = _make_env()
|
| 187 |
+
obs = env.reset(seed=42, task_id=1)
|
| 188 |
+
env.step(_heuristic_action(obs))
|
| 189 |
+
state = env.state
|
| 190 |
+
self.assertGreaterEqual(state.average_score_so_far, 0.0)
|
| 191 |
+
self.assertLessEqual(state.average_score_so_far, 1.0)
|
| 192 |
+
self.assertIsInstance(state.last_reward_components, dict)
|
| 193 |
+
self.assertIn("final_reward", state.last_reward_components)
|
| 194 |
+
|
| 195 |
|
| 196 |
# ---------------------------------------------------------------------------
|
| 197 |
# 9.3 — History entry contains title and predicted
|
|
|
|
| 328 |
return seed
|
| 329 |
return None
|
| 330 |
|
| 331 |
+
def test_ambiguity_note_hidden_until_internal_note_lookup(self) -> None:
|
| 332 |
"""Force a ticket with ambiguity_note by patching the dataset."""
|
| 333 |
from unittest.mock import patch
|
| 334 |
from server.tasks import load_dataset
|
|
|
|
| 346 |
obs = env.reset(seed=0, task_id=3)
|
| 347 |
|
| 348 |
self.assertIsNotNone(obs.current_ticket)
|
| 349 |
+
self.assertNotIn("ambiguity_note", obs.current_ticket)
|
| 350 |
+
self.assertIn("context_status", obs.current_ticket)
|
| 351 |
+
self.assertIn(
|
| 352 |
+
"lookup_internal_routing_note",
|
| 353 |
+
obs.current_ticket["context_status"]["remaining_tools"],
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
obs = env.step(
|
| 357 |
+
HelpdeskTicketAction(
|
| 358 |
+
action_type="investigate",
|
| 359 |
+
tool_name="lookup_internal_routing_note",
|
| 360 |
+
)
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
self.assertEqual(obs.current_ticket["ambiguity_note"], target.ambiguity_note)
|
| 364 |
+
self.assertGreater(obs.reward or 0.0, 0.0)
|
| 365 |
|
| 366 |
def test_ambiguity_note_absent_when_ticket_has_none(self) -> None:
|
| 367 |
"""Tickets without ambiguity_note should not expose the key."""
|
|
|
|
| 394 |
with patch.object(env, "_dataset", [ticket]):
|
| 395 |
obs = env.reset(seed=0, task_id=3)
|
| 396 |
|
| 397 |
+
self.assertNotIn("ambiguity_note", obs.current_ticket)
|
| 398 |
+
obs = env.step(
|
| 399 |
+
HelpdeskTicketAction(
|
| 400 |
+
action_type="investigate",
|
| 401 |
+
tool_name="lookup_internal_routing_note",
|
| 402 |
+
)
|
| 403 |
+
)
|
| 404 |
self.assertIn("ambiguity_note", obs.current_ticket)
|
| 405 |
|
| 406 |
|
|
|
|
| 428 |
):
|
| 429 |
obs = env.reset(seed=0, task_id=3, queue_size=1)
|
| 430 |
|
| 431 |
+
return env, obs, ticket, related
|
| 432 |
|
| 433 |
def test_related_ticket_preview_present_when_ticket_has_link(self) -> None:
|
| 434 |
+
env, obs, ticket, related = self._reset_linked_ticket_env()
|
| 435 |
|
| 436 |
self.assertIsNotNone(obs.current_ticket)
|
| 437 |
+
self.assertNotIn("related_ticket_preview", obs.current_ticket)
|
| 438 |
+
self.assertIn("context_status", obs.current_ticket)
|
| 439 |
+
self.assertIn(
|
| 440 |
+
"lookup_related_ticket",
|
| 441 |
+
obs.current_ticket["context_status"]["remaining_tools"],
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
obs = env.step(
|
| 445 |
+
HelpdeskTicketAction(
|
| 446 |
+
action_type="investigate",
|
| 447 |
+
tool_name="lookup_related_ticket",
|
| 448 |
+
tool_target_ticket_id=ticket.related_ticket_id,
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
self.assertIn("related_ticket_preview", obs.current_ticket)
|
| 453 |
self.assertEqual(
|
| 454 |
obs.current_ticket["related_ticket_preview"]["ticket_id"],
|
|
|
|
| 460 |
)
|
| 461 |
|
| 462 |
def test_history_keeps_related_ticket_preview_after_step(self) -> None:
|
| 463 |
+
env, obs, ticket, related = self._reset_linked_ticket_env()
|
| 464 |
+
env.step(
|
| 465 |
+
HelpdeskTicketAction(
|
| 466 |
+
action_type="investigate",
|
| 467 |
+
tool_name="lookup_related_ticket",
|
| 468 |
+
tool_target_ticket_id=ticket.related_ticket_id,
|
| 469 |
+
)
|
| 470 |
+
)
|
| 471 |
+
next_obs = env.step(
|
| 472 |
+
HelpdeskTicketAction(
|
| 473 |
+
issue_type=ticket.issue_type,
|
| 474 |
+
priority=ticket.priority,
|
| 475 |
+
assignment_group=ticket.assignment_group,
|
| 476 |
+
resolution_action=ticket.resolution_action,
|
| 477 |
+
)
|
| 478 |
+
)
|
| 479 |
|
| 480 |
self.assertGreaterEqual(len(next_obs.history), 1)
|
| 481 |
self.assertIn("related_ticket_preview", next_obs.history[0])
|
|
|
|
| 623 |
self.assertTrue(obs2.last_tool_result["found"])
|
| 624 |
self.assertGreaterEqual(len(obs2.last_tool_result["matches"]), 1)
|
| 625 |
|
| 626 |
+
def test_internal_note_tool_reveals_hidden_hard_task_context(self) -> None:
|
| 627 |
+
from unittest.mock import patch
|
| 628 |
+
|
| 629 |
+
dataset = load_dataset()
|
| 630 |
+
ticket = next((t for t in dataset if t.ticket_id == "TKT-NONDEFAULT-003"), None)
|
| 631 |
+
self.assertIsNotNone(ticket)
|
| 632 |
+
|
| 633 |
+
env = _make_env()
|
| 634 |
+
with patch.object(env, "_dataset", [ticket]):
|
| 635 |
+
with patch.object(env, "_tickets_by_id", {ticket.ticket_id: ticket}):
|
| 636 |
+
obs = env.reset(seed=0, task_id=3, queue_size=1)
|
| 637 |
+
|
| 638 |
+
self.assertNotIn("ambiguity_note", obs.current_ticket)
|
| 639 |
+
obs = env.step(
|
| 640 |
+
HelpdeskTicketAction(
|
| 641 |
+
action_type="investigate",
|
| 642 |
+
tool_name="lookup_internal_routing_note",
|
| 643 |
+
)
|
| 644 |
+
)
|
| 645 |
+
self.assertEqual(obs.last_tool_result["routing_note"], ticket.ambiguity_note)
|
| 646 |
+
self.assertEqual(obs.current_ticket["ambiguity_note"], ticket.ambiguity_note)
|
| 647 |
+
self.assertGreater(obs.reward or 0.0, 0.0)
|
| 648 |
+
|
| 649 |
+
def test_submit_without_required_investigation_gets_shaping_penalty(self) -> None:
|
| 650 |
+
from unittest.mock import patch
|
| 651 |
+
|
| 652 |
+
dataset = load_dataset()
|
| 653 |
+
ticket = next((t for t in dataset if t.ticket_id == "TKT-NONDEFAULT-003"), None)
|
| 654 |
+
self.assertIsNotNone(ticket)
|
| 655 |
+
|
| 656 |
+
env = _make_env()
|
| 657 |
+
with patch.object(env, "_dataset", [ticket]):
|
| 658 |
+
with patch.object(env, "_tickets_by_id", {ticket.ticket_id: ticket}):
|
| 659 |
+
obs = env.reset(seed=0, task_id=3, queue_size=1)
|
| 660 |
+
|
| 661 |
+
final_obs = env.step(
|
| 662 |
+
HelpdeskTicketAction(
|
| 663 |
+
issue_type=ticket.issue_type,
|
| 664 |
+
priority=ticket.priority,
|
| 665 |
+
assignment_group=ticket.assignment_group,
|
| 666 |
+
resolution_action=ticket.resolution_action,
|
| 667 |
+
)
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
self.assertTrue(final_obs.done)
|
| 671 |
+
self.assertIsNotNone(final_obs.rubric_reward)
|
| 672 |
+
self.assertLess(final_obs.reward, final_obs.rubric_reward)
|
| 673 |
+
self.assertGreater(
|
| 674 |
+
final_obs.last_reward_components.get("context_gap_penalty", 0.0),
|
| 675 |
+
0.0,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
|
| 679 |
class TestQueueEconomics(unittest.TestCase):
|
| 680 |
"""Free investigations are allowed, but excessive investigation gets a queue-level penalty."""
|
tests/test_inference_unit.py
CHANGED
|
@@ -140,6 +140,16 @@ class InferenceUnitTests(unittest.TestCase):
|
|
| 140 |
self.assertIsNone(inference.HF_TOKEN)
|
| 141 |
self.assertFalse(inference.llm_mode_enabled())
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
def test_run_uses_only_structured_start_step_end_logs(self) -> None:
|
| 144 |
inference = _load_inference_module()
|
| 145 |
|
|
@@ -179,6 +189,311 @@ class InferenceUnitTests(unittest.TestCase):
|
|
| 179 |
[1, 2, 3],
|
| 180 |
)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
if __name__ == "__main__":
|
| 184 |
unittest.main()
|
|
|
|
| 140 |
self.assertIsNone(inference.HF_TOKEN)
|
| 141 |
self.assertFalse(inference.llm_mode_enabled())
|
| 142 |
|
| 143 |
+
def test_seed_env_override_is_respected(self) -> None:
|
| 144 |
+
inference = _load_inference_module({"SEED": "7"})
|
| 145 |
+
|
| 146 |
+
self.assertEqual(inference.SEED, 7)
|
| 147 |
+
|
| 148 |
+
def test_invalid_seed_env_falls_back_to_default(self) -> None:
|
| 149 |
+
inference = _load_inference_module({"SEED": "not-an-int"})
|
| 150 |
+
|
| 151 |
+
self.assertEqual(inference.SEED, 42)
|
| 152 |
+
|
| 153 |
def test_run_uses_only_structured_start_step_end_logs(self) -> None:
|
| 154 |
inference = _load_inference_module()
|
| 155 |
|
|
|
|
| 189 |
[1, 2, 3],
|
| 190 |
)
|
| 191 |
|
| 192 |
+
def test_build_llm_user_message_includes_recent_history_feedback(self) -> None:
|
| 193 |
+
inference = _load_inference_module()
|
| 194 |
+
|
| 195 |
+
ticket = {
|
| 196 |
+
"ticket_id": "ticket-xyz",
|
| 197 |
+
"title": "Contractor onboarding blocked by access issue",
|
| 198 |
+
"requester": "pm@contractorco.com",
|
| 199 |
+
"description": "Access permissions are blocking contractor setup.",
|
| 200 |
+
"context_status": {
|
| 201 |
+
"investigation_required": True,
|
| 202 |
+
"revealed_tools": [],
|
| 203 |
+
"remaining_tools": ["lookup_internal_routing_note"],
|
| 204 |
+
"hints": ["An internal routing note may disambiguate the correct workflow."],
|
| 205 |
+
},
|
| 206 |
+
"last_tool_result": {"tool_name": "lookup_requester_history", "found": False},
|
| 207 |
+
"feedback_summary": "Ticket score=0.40; field_scores[issue_type=0.40]; reward=0.40",
|
| 208 |
+
"last_reward_components": {"ticket_score": 0.4, "final_reward": 0.4},
|
| 209 |
+
"investigation_budget_remaining": 2,
|
| 210 |
+
"average_score_so_far": 0.7,
|
| 211 |
+
"progress_fraction": 0.5,
|
| 212 |
+
"recent_history": [
|
| 213 |
+
{
|
| 214 |
+
"ticket_id": "ticket-prev",
|
| 215 |
+
"predicted": {"issue_type": "identity_access"},
|
| 216 |
+
"score": 0.4,
|
| 217 |
+
"breakdown": {"issue_type": 0.4},
|
| 218 |
+
"penalty_reason": "extra_fields: ['assignment_group']",
|
| 219 |
+
"feedback_summary": "Penalty applied: extra_fields: ['assignment_group']; reward=0.00",
|
| 220 |
+
"reward_components": {"reward_kind": "step_penalty", "final_reward": 0.0},
|
| 221 |
+
}
|
| 222 |
+
],
|
| 223 |
+
"queue_position": 2,
|
| 224 |
+
"tickets_remaining": 4,
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
message = inference.build_llm_user_message(
|
| 228 |
+
ticket,
|
| 229 |
+
["issue_type"],
|
| 230 |
+
"Read the ticket and select the single best IT issue type.",
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
self.assertIn("Recent evaluation feedback", message)
|
| 234 |
+
self.assertIn("score=0.4", message)
|
| 235 |
+
self.assertIn("penalty_reason=extra_fields", message)
|
| 236 |
+
self.assertIn("Latest environment feedback", message)
|
| 237 |
+
self.assertIn("Context status", message)
|
| 238 |
+
self.assertIn("Latest reward components", message)
|
| 239 |
+
self.assertIn("Average score so far: 0.7", message)
|
| 240 |
+
self.assertIn("Episode progress: 0.5", message)
|
| 241 |
+
self.assertIn("Investigation budget remaining: 2", message)
|
| 242 |
+
self.assertIn("Investigation result", message)
|
| 243 |
+
self.assertIn("queue_position=2", message)
|
| 244 |
+
|
| 245 |
+
def test_build_action_backfills_missing_fields_from_heuristic(self) -> None:
|
| 246 |
+
inference = _load_inference_module()
|
| 247 |
+
inference.llm_client = object()
|
| 248 |
+
|
| 249 |
+
ticket = {
|
| 250 |
+
"ticket_id": "ticket-018",
|
| 251 |
+
"title": "Question about enterprise tier pricing",
|
| 252 |
+
"requester": "finance@urbanstack.io",
|
| 253 |
+
"description": (
|
| 254 |
+
"We're comparing your enterprise plan against two competitors. "
|
| 255 |
+
"Can you send over a detailed pricing breakdown?"
|
| 256 |
+
),
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
with mock.patch.object(
|
| 260 |
+
inference,
|
| 261 |
+
"call_llm",
|
| 262 |
+
return_value={"issue_type": "service_request"},
|
| 263 |
+
):
|
| 264 |
+
action, action_source, fallback_reason = inference.build_action(
|
| 265 |
+
ticket,
|
| 266 |
+
["issue_type", "priority", "assignment_group", "resolution_action"],
|
| 267 |
+
"Perform full helpdesk routing.",
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.assertEqual(action.issue_type, "service_request")
|
| 271 |
+
self.assertEqual(action.priority, "medium")
|
| 272 |
+
self.assertEqual(action.assignment_group, "procurement")
|
| 273 |
+
self.assertEqual(action.resolution_action, "assign")
|
| 274 |
+
self.assertEqual(action_source, "llm_backfilled")
|
| 275 |
+
self.assertIn("heuristic_backfill", fallback_reason or "")
|
| 276 |
+
|
| 277 |
+
def test_build_action_ignores_invalid_llm_fields_and_keeps_valid_ones(self) -> None:
|
| 278 |
+
inference = _load_inference_module()
|
| 279 |
+
inference.llm_client = object()
|
| 280 |
+
|
| 281 |
+
ticket = {
|
| 282 |
+
"ticket_id": "ticket-018",
|
| 283 |
+
"title": "Question about enterprise tier pricing",
|
| 284 |
+
"requester": "finance@urbanstack.io",
|
| 285 |
+
"description": (
|
| 286 |
+
"We're comparing your enterprise plan against two competitors. "
|
| 287 |
+
"Can you send over a detailed pricing breakdown?"
|
| 288 |
+
),
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
with mock.patch.object(
|
| 292 |
+
inference,
|
| 293 |
+
"call_llm",
|
| 294 |
+
return_value={
|
| 295 |
+
"issue_type": "service_request",
|
| 296 |
+
"priority": "urgent",
|
| 297 |
+
},
|
| 298 |
+
):
|
| 299 |
+
action, action_source, fallback_reason = inference.build_action(
|
| 300 |
+
ticket,
|
| 301 |
+
["issue_type", "priority"],
|
| 302 |
+
"Read the ticket, select the best IT issue type, and estimate the priority.",
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
self.assertEqual(action.issue_type, "service_request")
|
| 306 |
+
self.assertEqual(action.priority, "medium")
|
| 307 |
+
self.assertEqual(action_source, "llm_backfilled")
|
| 308 |
+
self.assertIn("invalid_llm_fields=['priority']", fallback_reason or "")
|
| 309 |
+
|
| 310 |
+
def test_build_action_backfills_dependent_fields_from_llm_issue_type(self) -> None:
|
| 311 |
+
inference = _load_inference_module()
|
| 312 |
+
inference.llm_client = object()
|
| 313 |
+
|
| 314 |
+
ticket = {
|
| 315 |
+
"ticket_id": "ticket-002",
|
| 316 |
+
"title": "Can not sign in after 2FA reset",
|
| 317 |
+
"requester": "ops@laneeight.io",
|
| 318 |
+
"description": (
|
| 319 |
+
"I was forced to reset 2FA and now the account stays locked even "
|
| 320 |
+
"with the backup code."
|
| 321 |
+
),
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
with mock.patch.object(
|
| 325 |
+
inference,
|
| 326 |
+
"call_llm",
|
| 327 |
+
return_value={"issue_type": "identity_access"},
|
| 328 |
+
):
|
| 329 |
+
action, action_source, fallback_reason = inference.build_action(
|
| 330 |
+
ticket,
|
| 331 |
+
["issue_type", "assignment_group", "resolution_action"],
|
| 332 |
+
"Perform full helpdesk routing.",
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
self.assertEqual(action.issue_type, "identity_access")
|
| 336 |
+
self.assertEqual(action.assignment_group, "service_desk")
|
| 337 |
+
self.assertEqual(action.resolution_action, "fulfill")
|
| 338 |
+
self.assertEqual(action_source, "llm_backfilled")
|
| 339 |
+
self.assertIn("heuristic_backfill", fallback_reason or "")
|
| 340 |
+
|
| 341 |
+
def test_build_action_normalizes_pricing_request_issue_type(self) -> None:
|
| 342 |
+
inference = _load_inference_module()
|
| 343 |
+
inference.llm_client = object()
|
| 344 |
+
|
| 345 |
+
ticket = {
|
| 346 |
+
"ticket_id": "ticket-018",
|
| 347 |
+
"title": "Question about enterprise tier pricing",
|
| 348 |
+
"requester": "finance@urbanstack.io",
|
| 349 |
+
"description": (
|
| 350 |
+
"We're comparing your enterprise plan against two competitors. "
|
| 351 |
+
"Can you send over a detailed pricing breakdown?"
|
| 352 |
+
),
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
with mock.patch.object(
|
| 356 |
+
inference,
|
| 357 |
+
"call_llm",
|
| 358 |
+
return_value={
|
| 359 |
+
"issue_type": "billing_license",
|
| 360 |
+
"priority": "medium",
|
| 361 |
+
},
|
| 362 |
+
):
|
| 363 |
+
action, action_source, fallback_reason = inference.build_action(
|
| 364 |
+
ticket,
|
| 365 |
+
["issue_type", "priority", "assignment_group", "resolution_action"],
|
| 366 |
+
"Perform full helpdesk routing.",
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
self.assertEqual(action.issue_type, "service_request")
|
| 370 |
+
self.assertEqual(action.assignment_group, "procurement")
|
| 371 |
+
self.assertEqual(action.resolution_action, "assign")
|
| 372 |
+
self.assertEqual(action.priority, "medium")
|
| 373 |
+
self.assertEqual(action_source, "llm_backfilled")
|
| 374 |
+
self.assertIn("domain_overrides", fallback_reason or "")
|
| 375 |
+
|
| 376 |
+
def test_build_action_normalizes_onboarding_access_blocker(self) -> None:
|
| 377 |
+
inference = _load_inference_module()
|
| 378 |
+
inference.llm_client = object()
|
| 379 |
+
|
| 380 |
+
ticket = {
|
| 381 |
+
"ticket_id": "TKT-NONDEFAULT-003",
|
| 382 |
+
"title": "Contractor onboarding blocked by access issue",
|
| 383 |
+
"requester": "pm@contractorco.com",
|
| 384 |
+
"description": (
|
| 385 |
+
"A new contractor cannot complete onboarding because their account "
|
| 386 |
+
"access is blocked by a permissions error. The onboarding team "
|
| 387 |
+
"cannot resolve access issues; routing to service desk."
|
| 388 |
+
),
|
| 389 |
+
"ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
with mock.patch.object(
|
| 393 |
+
inference,
|
| 394 |
+
"call_llm",
|
| 395 |
+
return_value={
|
| 396 |
+
"issue_type": "identity_access",
|
| 397 |
+
"priority": "high",
|
| 398 |
+
},
|
| 399 |
+
):
|
| 400 |
+
action, action_source, fallback_reason = inference.build_action(
|
| 401 |
+
ticket,
|
| 402 |
+
["issue_type", "priority", "assignment_group", "resolution_action"],
|
| 403 |
+
"Perform full helpdesk routing.",
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
self.assertEqual(action.issue_type, "onboarding")
|
| 407 |
+
self.assertEqual(action.priority, "medium")
|
| 408 |
+
self.assertEqual(action.assignment_group, "service_desk")
|
| 409 |
+
self.assertEqual(action.resolution_action, "fulfill")
|
| 410 |
+
self.assertEqual(action_source, "llm_backfilled")
|
| 411 |
+
self.assertIn("domain_overrides", fallback_reason or "")
|
| 412 |
+
|
| 413 |
+
def test_build_action_deescalates_nonurgent_onboarding_priority(self) -> None:
|
| 414 |
+
inference = _load_inference_module()
|
| 415 |
+
inference.llm_client = object()
|
| 416 |
+
|
| 417 |
+
ticket = {
|
| 418 |
+
"ticket_id": "ticket-008",
|
| 419 |
+
"title": "Kickoff onboarding session for newly activated account",
|
| 420 |
+
"requester": "admin@brightpath.io",
|
| 421 |
+
"description": (
|
| 422 |
+
"We activated our account this week and need an onboarding call plus "
|
| 423 |
+
"admin setup guidance for six internal users."
|
| 424 |
+
),
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
with mock.patch.object(
|
| 428 |
+
inference,
|
| 429 |
+
"call_llm",
|
| 430 |
+
return_value={
|
| 431 |
+
"issue_type": "onboarding",
|
| 432 |
+
"priority": "high",
|
| 433 |
+
},
|
| 434 |
+
):
|
| 435 |
+
action, action_source, fallback_reason = inference.build_action(
|
| 436 |
+
ticket,
|
| 437 |
+
["issue_type", "priority"],
|
| 438 |
+
"Read the ticket, select the best IT issue type, and estimate the priority.",
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
self.assertEqual(action.issue_type, "onboarding")
|
| 442 |
+
self.assertEqual(action.priority, "medium")
|
| 443 |
+
self.assertEqual(action_source, "llm_backfilled")
|
| 444 |
+
self.assertIn("domain_overrides", fallback_reason or "")
|
| 445 |
+
|
| 446 |
+
def test_merge_ticket_context_carries_feedback_summary_from_observation(self) -> None:
|
| 447 |
+
inference = _load_inference_module()
|
| 448 |
+
|
| 449 |
+
observation = SimpleNamespace(
|
| 450 |
+
last_tool_result={"tool_name": "lookup_requester_history", "found": True},
|
| 451 |
+
history=[{"ticket_id": "ticket-prev", "score": 0.4}],
|
| 452 |
+
queue_position=2,
|
| 453 |
+
tickets_remaining=4,
|
| 454 |
+
investigation_budget_remaining=1,
|
| 455 |
+
average_score_so_far=0.55,
|
| 456 |
+
progress_fraction=0.4,
|
| 457 |
+
last_reward_components={"ticket_score": 0.4, "final_reward": 0.4},
|
| 458 |
+
metadata={"last_feedback_summary": "Ticket score=0.40; reward=0.40"},
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
merged = inference.merge_ticket_context(
|
| 462 |
+
{
|
| 463 |
+
"ticket_id": "ticket-xyz",
|
| 464 |
+
"title": "Contractor onboarding blocked by access issue",
|
| 465 |
+
},
|
| 466 |
+
observation,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
self.assertEqual(merged["feedback_summary"], "Ticket score=0.40; reward=0.40")
|
| 470 |
+
self.assertEqual(merged["investigation_budget_remaining"], 1)
|
| 471 |
+
self.assertEqual(merged["average_score_so_far"], 0.55)
|
| 472 |
+
self.assertEqual(merged["progress_fraction"], 0.4)
|
| 473 |
+
self.assertEqual(merged["last_reward_components"]["final_reward"], 0.4)
|
| 474 |
+
self.assertEqual(merged["queue_position"], 2)
|
| 475 |
+
self.assertEqual(merged["tickets_remaining"], 4)
|
| 476 |
+
self.assertEqual(merged["last_tool_result"]["tool_name"], "lookup_requester_history")
|
| 477 |
+
|
| 478 |
+
def test_should_investigate_uses_remaining_tools_from_context_status(self) -> None:
|
| 479 |
+
inference = _load_inference_module()
|
| 480 |
+
|
| 481 |
+
investigate, tool_name = inference.should_investigate(
|
| 482 |
+
{
|
| 483 |
+
"ticket_id": "ticket-021",
|
| 484 |
+
"context_status": {
|
| 485 |
+
"remaining_tools": [
|
| 486 |
+
"lookup_related_ticket",
|
| 487 |
+
"lookup_requester_history",
|
| 488 |
+
]
|
| 489 |
+
},
|
| 490 |
+
},
|
| 491 |
+
[],
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
self.assertTrue(investigate)
|
| 495 |
+
self.assertEqual(tool_name, "lookup_related_ticket")
|
| 496 |
+
|
| 497 |
|
| 498 |
if __name__ == "__main__":
|
| 499 |
unittest.main()
|
tests/test_policy_learning.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import types as _types
|
| 6 |
+
import unittest
|
| 7 |
+
|
| 8 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 9 |
+
|
| 10 |
+
import openenv_test_stubs # noqa: F401
|
| 11 |
+
|
| 12 |
+
if "openenv.core.env_server.interfaces" not in sys.modules:
|
| 13 |
+
_interfaces_mod = _types.ModuleType("openenv.core.env_server.interfaces")
|
| 14 |
+
|
| 15 |
+
class _Environment:
|
| 16 |
+
def __init__(self) -> None:
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
def __init_subclass__(cls, **kwargs: object) -> None:
|
| 20 |
+
super().__init_subclass__(**kwargs)
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
def __class_getitem__(cls, item: object) -> type:
|
| 24 |
+
return cls
|
| 25 |
+
|
| 26 |
+
_interfaces_mod.Environment = _Environment # type: ignore[attr-defined]
|
| 27 |
+
sys.modules["openenv.core.env_server.interfaces"] = _interfaces_mod
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
from models import HelpdeskTicketAction, HelpdeskTicketObservation
|
| 31 |
+
from policy_learning import (
|
| 32 |
+
POLICY_LIBRARY,
|
| 33 |
+
choose_policy_action,
|
| 34 |
+
compare_policies,
|
| 35 |
+
parse_int_spec,
|
| 36 |
+
rollout_episode,
|
| 37 |
+
search_policies,
|
| 38 |
+
)
|
| 39 |
+
from server.environment import HelpdeskTicketRoutingEnvironment
|
| 40 |
+
from server.tasks import get_task_definition
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class SingleTicketEnvironment(HelpdeskTicketRoutingEnvironment):
|
| 44 |
+
def __init__(self, ticket_id: str) -> None:
|
| 45 |
+
super().__init__()
|
| 46 |
+
self._forced_ticket_id = ticket_id
|
| 47 |
+
|
| 48 |
+
def reset(self, seed=None, episode_id=None, **kwargs):
|
| 49 |
+
observation = super().reset(seed=seed, episode_id=episode_id, **kwargs)
|
| 50 |
+
ticket = self._tickets_by_id[self._forced_ticket_id]
|
| 51 |
+
self._queue = [ticket]
|
| 52 |
+
self._state.current_task_id = int(kwargs.get("task_id", 3))
|
| 53 |
+
self._state.queue_ticket_ids = [ticket.ticket_id]
|
| 54 |
+
self._state.current_ticket_index = 0
|
| 55 |
+
self._state.per_ticket_scores = []
|
| 56 |
+
self._state.total_reward = 0.0
|
| 57 |
+
self._state.last_step_reward = None
|
| 58 |
+
self._state.reward = None
|
| 59 |
+
self._state.done = False
|
| 60 |
+
self._state.average_score_so_far = 0.0
|
| 61 |
+
self._state.investigation_steps = 0
|
| 62 |
+
self._state.investigation_budget_remaining = len(self._queue)
|
| 63 |
+
self._state.investigation_penalty_applied = 0.0
|
| 64 |
+
self._state.last_tool_result = None
|
| 65 |
+
self._state.last_reward_components = {}
|
| 66 |
+
self._state.ticket_tool_usage = {}
|
| 67 |
+
self._state.history_entries = []
|
| 68 |
+
return self._build_observation(get_task_definition(self._state.current_task_id))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _context_sensitive_submit_builder(
|
| 72 |
+
ticket: dict[str, object], allowed_fields: list[str]
|
| 73 |
+
) -> HelpdeskTicketAction:
|
| 74 |
+
if ticket.get("ambiguity_note"):
|
| 75 |
+
values = {
|
| 76 |
+
"issue_type": "onboarding",
|
| 77 |
+
"priority": "medium",
|
| 78 |
+
"assignment_group": "service_desk",
|
| 79 |
+
"resolution_action": "fulfill",
|
| 80 |
+
}
|
| 81 |
+
else:
|
| 82 |
+
values = {
|
| 83 |
+
"issue_type": "identity_access",
|
| 84 |
+
"priority": "high",
|
| 85 |
+
"assignment_group": "service_desk",
|
| 86 |
+
"resolution_action": "fulfill",
|
| 87 |
+
}
|
| 88 |
+
return HelpdeskTicketAction(
|
| 89 |
+
**{field: value for field, value in values.items() if field in allowed_fields}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class PolicyLearningTests(unittest.TestCase):
|
| 94 |
+
def test_parse_int_spec_expands_ranges(self) -> None:
|
| 95 |
+
self.assertEqual(parse_int_spec("42-44,44,46", field_name="seeds"), [42, 43, 44, 46])
|
| 96 |
+
|
| 97 |
+
def test_choose_policy_action_prefers_hidden_context_tools(self) -> None:
|
| 98 |
+
policy = POLICY_LIBRARY["investigate_when_context_hidden"]
|
| 99 |
+
observation = HelpdeskTicketObservation(
|
| 100 |
+
current_ticket={
|
| 101 |
+
"ticket_id": "ticket-021",
|
| 102 |
+
"context_status": {
|
| 103 |
+
"remaining_tools": ["lookup_related_ticket", "lookup_requester_history"],
|
| 104 |
+
"revealed_tools": [],
|
| 105 |
+
}
|
| 106 |
+
},
|
| 107 |
+
allowed_fields=["issue_type"],
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
action, source = choose_policy_action(policy, observation, {}, _context_sensitive_submit_builder)
|
| 111 |
+
|
| 112 |
+
self.assertEqual(action.action_type, "investigate")
|
| 113 |
+
self.assertEqual(action.tool_name, "lookup_related_ticket")
|
| 114 |
+
self.assertEqual(source, "investigate_hidden_context")
|
| 115 |
+
|
| 116 |
+
def test_choose_policy_action_submits_when_investigation_disabled(self) -> None:
|
| 117 |
+
policy = POLICY_LIBRARY["no_investigation"]
|
| 118 |
+
observation = HelpdeskTicketObservation(
|
| 119 |
+
current_ticket={
|
| 120 |
+
"ticket_id": "ticket-021",
|
| 121 |
+
"context_status": {"remaining_tools": ["lookup_related_ticket"]},
|
| 122 |
+
},
|
| 123 |
+
allowed_fields=["issue_type", "priority"],
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
action, source = choose_policy_action(policy, observation, {}, _context_sensitive_submit_builder)
|
| 127 |
+
|
| 128 |
+
self.assertEqual(action.action_type, "submit")
|
| 129 |
+
self.assertEqual(action.issue_type, "identity_access")
|
| 130 |
+
self.assertEqual(source, "submit")
|
| 131 |
+
|
| 132 |
+
def test_rollout_episode_rewards_context_aware_policy(self) -> None:
|
| 133 |
+
no_investigation = POLICY_LIBRARY["no_investigation"]
|
| 134 |
+
context_aware = POLICY_LIBRARY["investigate_when_context_hidden"]
|
| 135 |
+
|
| 136 |
+
no_summary, _ = rollout_episode(
|
| 137 |
+
env=SingleTicketEnvironment("TKT-NONDEFAULT-003"),
|
| 138 |
+
policy=no_investigation,
|
| 139 |
+
seed=42,
|
| 140 |
+
task_id=3,
|
| 141 |
+
submit_builder=_context_sensitive_submit_builder,
|
| 142 |
+
)
|
| 143 |
+
context_summary, _ = rollout_episode(
|
| 144 |
+
env=SingleTicketEnvironment("TKT-NONDEFAULT-003"),
|
| 145 |
+
policy=context_aware,
|
| 146 |
+
seed=42,
|
| 147 |
+
task_id=3,
|
| 148 |
+
submit_builder=_context_sensitive_submit_builder,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
self.assertLess(no_summary["terminal_reward"], context_summary["terminal_reward"])
|
| 152 |
+
self.assertLess(no_summary["normalized_return"], context_summary["normalized_return"])
|
| 153 |
+
self.assertEqual(context_summary["investigation_steps"], 1)
|
| 154 |
+
|
| 155 |
+
def test_search_policies_selects_better_policy(self) -> None:
|
| 156 |
+
report = search_policies(
|
| 157 |
+
[
|
| 158 |
+
POLICY_LIBRARY["no_investigation"],
|
| 159 |
+
POLICY_LIBRARY["investigate_when_context_hidden"],
|
| 160 |
+
],
|
| 161 |
+
train_seeds=[41, 42],
|
| 162 |
+
eval_seeds=[43],
|
| 163 |
+
task_ids=[3],
|
| 164 |
+
output_dir=os.path.join(os.getcwd(), "analysis", "policy_learning_test"),
|
| 165 |
+
env_factory=lambda: SingleTicketEnvironment("TKT-NONDEFAULT-003"),
|
| 166 |
+
submit_builder=_context_sensitive_submit_builder,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
self.assertEqual(report["selected_policy"], "investigate_when_context_hidden")
|
| 170 |
+
self.assertGreater(
|
| 171 |
+
report["eval_improvement_vs_baseline"]["avg_normalized_return"],
|
| 172 |
+
0.0,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def test_compare_policies_reports_improvement(self) -> None:
|
| 176 |
+
report = compare_policies(
|
| 177 |
+
[
|
| 178 |
+
POLICY_LIBRARY["no_investigation"],
|
| 179 |
+
POLICY_LIBRARY["investigate_when_context_hidden"],
|
| 180 |
+
],
|
| 181 |
+
seeds=[42],
|
| 182 |
+
task_ids=[3],
|
| 183 |
+
output_dir=os.path.join(os.getcwd(), "analysis", "policy_learning_compare_test"),
|
| 184 |
+
env_factory=lambda: SingleTicketEnvironment("TKT-NONDEFAULT-003"),
|
| 185 |
+
submit_builder=_context_sensitive_submit_builder,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.assertEqual(report["best_policy"], "investigate_when_context_hidden")
|
| 189 |
+
self.assertGreater(report["improvement_vs_baseline"]["avg_terminal_reward"], 0.0)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
if __name__ == "__main__":
|
| 193 |
+
unittest.main()
|