Roopalgn commited on
Commit
8ada670
·
1 Parent(s): 67ce1eb

Use evaluator API_KEY for LLM proxy and strengthen env

Browse files
Dockerfile CHANGED
@@ -1,7 +1,8 @@
1
  FROM python:3.11-slim
2
 
3
  ENV PYTHONDONTWRITEBYTECODE=1 \
4
- PYTHONUNBUFFERED=1
 
5
 
6
  WORKDIR /app
7
 
@@ -14,6 +15,14 @@ RUN python -m pip install --upgrade pip \
14
  && python -m pip install --no-cache-dir -r requirements.txt \
15
  && python -m pip install --no-cache-dir .
16
 
 
 
 
17
  EXPOSE 7860
18
 
 
 
 
 
 
19
  CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.11-slim
2
 
3
  ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1
6
 
7
  WORKDIR /app
8
 
 
15
  && python -m pip install --no-cache-dir -r requirements.txt \
16
  && python -m pip install --no-cache-dir .
17
 
18
+ RUN useradd --create-home --uid 10001 appuser \
19
+ && chown -R appuser:appuser /app
20
+
21
  EXPOSE 7860
22
 
23
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
24
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:7860/health', timeout=3)"
25
+
26
+ USER appuser
27
+
28
  CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -383,9 +383,10 @@ TASK_ID=3 python inference.py
383
 
384
  Set these environment variables first:
385
 
386
- - `API_BASE_URL`
387
- - `MODEL_NAME`
388
- - `HF_TOKEN`
 
389
 
390
  Then run:
391
 
 
383
 
384
  Set these environment variables first:
385
 
386
+ - `API_BASE_URL`
387
+ - `MODEL_NAME`
388
+ - `API_KEY`
389
+ - `HF_TOKEN`
390
 
391
  Then run:
392
 
ROADMAP.md CHANGED
@@ -130,7 +130,7 @@ These come directly from `required.md` and `KNOWLEDGE.md`:
130
  - 3 tasks exist and remain meaningfully different
131
  - grader scores stay in `[0.0, 1.0]`
132
  - `inference.py` runs reproducibly without crashing
133
- - `inference.py` uses the OpenAI client with `API_BASE_URL`, `MODEL_NAME`, and `HF_TOKEN`
134
  - structured stdout logs follow the official `[START]`, `[STEP]`, and `[END]` format
135
  - `openenv validate` passes
136
  - Docker builds and starts cleanly
 
130
  - 3 tasks exist and remain meaningfully different
131
  - grader scores stay in `[0.0, 1.0]`
132
  - `inference.py` runs reproducibly without crashing
133
+ - `inference.py` uses the OpenAI client with `API_BASE_URL`, `MODEL_NAME`, and the evaluator-injected `API_KEY` (`HF_TOKEN` remains a local fallback)
134
  - structured stdout logs follow the official `[START]`, `[STEP]`, and `[END]` format
135
  - `openenv validate` passes
136
  - Docker builds and starts cleanly
data/dataset.json CHANGED
@@ -574,6 +574,126 @@
574
  "resolution_action": "fulfill",
575
  "ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
576
  "related_ticket_id": null
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  }
578
  ]
579
 
 
574
  "resolution_action": "fulfill",
575
  "ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
576
  "related_ticket_id": null
577
+ },
578
+ {
579
+ "ticket_id": "ticket-046",
580
+ "title": "Privileged admin login blocked during security review",
581
+ "requester": "security-ops@atlasbank.io",
582
+ "description": "Our privileged admin account was locked during an internal security review. We need access restored, but the security team must verify the incident trail before the account is reopened.",
583
+ "issue_type": "identity_access",
584
+ "priority": "critical",
585
+ "assignment_group": "security_team",
586
+ "resolution_action": "escalate",
587
+ "ambiguity_note": "Looks like a login problem, but security owns the privileged-access review and release decision.",
588
+ "related_ticket_id": null
589
+ },
590
+ {
591
+ "ticket_id": "ticket-047",
592
+ "title": "Temporary sandbox extension for signed pilot",
593
+ "requester": "solutions@bluequarry.io",
594
+ "description": "The commercial pilot is already approved. We only need the existing sandbox kept alive for two more weeks so the customer can finish testing.",
595
+ "issue_type": "service_request",
596
+ "priority": "medium",
597
+ "assignment_group": "service_desk",
598
+ "resolution_action": "fulfill",
599
+ "ambiguity_note": "Commercial context is present, but the actual action is an operational extension the service desk can fulfill directly.",
600
+ "related_ticket_id": null
601
+ },
602
+ {
603
+ "ticket_id": "ticket-048",
604
+ "title": "Who approves seat-transfer terms in the vendor questionnaire?",
605
+ "requester": "vendorops@aurorahealth.org",
606
+ "description": "Our procurement team is filling out your vendor questionnaire and needs clarification on who approves seat-transfer language before we continue the review.",
607
+ "issue_type": "general_inquiry",
608
+ "priority": "medium",
609
+ "assignment_group": "procurement",
610
+ "resolution_action": "assign",
611
+ "ambiguity_note": "The request is a question, but it belongs with the commercial owner rather than the generic service desk.",
612
+ "related_ticket_id": null
613
+ },
614
+ {
615
+ "ticket_id": "ticket-049",
616
+ "title": "Credential-defense rollout is causing auth API failures",
617
+ "requester": "platform@nightferry.dev",
618
+ "description": "The authentication API is returning intermittent 403 errors after a credential-stuffing defense rule was enabled. Product behavior is broken, but security needs to triage the mitigation first.",
619
+ "issue_type": "application_support",
620
+ "priority": "high",
621
+ "assignment_group": "security_team",
622
+ "resolution_action": "escalate",
623
+ "ambiguity_note": "The symptom looks like application support, but the active security control owns the first response path.",
624
+ "related_ticket_id": null
625
+ },
626
+ {
627
+ "ticket_id": "ticket-050",
628
+ "title": "Acquired-team onboarding needs cross-functional coordination",
629
+ "requester": "integration@mergerco.com",
630
+ "description": "Thirty acquired employees start next week and need onboarding, access setup, hardware coordination, and shared mailbox provisioning across multiple internal teams.",
631
+ "issue_type": "onboarding",
632
+ "priority": "high",
633
+ "assignment_group": "service_desk",
634
+ "resolution_action": "assign",
635
+ "ambiguity_note": "The workflow is onboarding, but it requires central service-desk coordination instead of a single onboarding-ops fulfillment step.",
636
+ "related_ticket_id": null
637
+ },
638
+ {
639
+ "ticket_id": "ticket-051",
640
+ "title": "Renewal credit memo requires contract amendment approval",
641
+ "requester": "procurement@crownlogistics.com",
642
+ "description": "Finance approved the renewal credit memo, but the contract amendment still needs commercial approval before the invoice can be corrected.",
643
+ "issue_type": "billing_license",
644
+ "priority": "medium",
645
+ "assignment_group": "procurement",
646
+ "resolution_action": "assign",
647
+ "ambiguity_note": "This sounds billing-related, but the remaining work is a commercial contract amendment owned by procurement.",
648
+ "related_ticket_id": null
649
+ },
650
+ {
651
+ "ticket_id": "ticket-052",
652
+ "title": "Need remediation evidence package for product vulnerability",
653
+ "requester": "assurance@clientgrid.com",
654
+ "description": "Our assurance team needs the remediation evidence package for a previously confirmed application vulnerability before we close the compliance review.",
655
+ "issue_type": "security_compliance",
656
+ "priority": "high",
657
+ "assignment_group": "application_team",
658
+ "resolution_action": "fulfill",
659
+ "ambiguity_note": "The request is compliance-driven, but the application team must provide the concrete remediation evidence.",
660
+ "related_ticket_id": null
661
+ },
662
+ {
663
+ "ticket_id": "ticket-053",
664
+ "title": "Customer requests penetration-test window and allowlist",
665
+ "requester": "engsec@vectorlabs.io",
666
+ "description": "We want to schedule a penetration test and need the approved window plus the process for allowlisting our source IPs.",
667
+ "issue_type": "service_request",
668
+ "priority": "medium",
669
+ "assignment_group": "security_team",
670
+ "resolution_action": "assign",
671
+ "ambiguity_note": "This is a request, but the security team owns approval and coordination instead of procurement.",
672
+ "related_ticket_id": null
673
+ },
674
+ {
675
+ "ticket_id": "ticket-054",
676
+ "title": "Need archived invoice copies for board audit binder",
677
+ "requester": "boardops@silverpine.com",
678
+ "description": "The board audit binder needs PDF copies of invoices from the last four quarters. No billing change is required, just document retrieval.",
679
+ "issue_type": "general_inquiry",
680
+ "priority": "low",
681
+ "assignment_group": "license_ops",
682
+ "resolution_action": "fulfill",
683
+ "ambiguity_note": "The request is informational, but license operations owns the archived invoice records and can fulfill it directly.",
684
+ "related_ticket_id": null
685
+ },
686
+ {
687
+ "ticket_id": "ticket-055",
688
+ "title": "Re: Renewal credit memo requires contract amendment approval",
689
+ "requester": "procurement@crownlogistics.com",
690
+ "description": "Following up on ticket-051. Quarter close is tomorrow and the contract amendment is still pending, so the corrected invoice cannot be issued yet.",
691
+ "issue_type": "billing_license",
692
+ "priority": "high",
693
+ "assignment_group": "procurement",
694
+ "resolution_action": "escalate",
695
+ "ambiguity_note": null,
696
+ "related_ticket_id": "ticket-051"
697
  }
698
  ]
699
 
inference.py CHANGED
@@ -16,8 +16,12 @@ MODEL_NAME
16
  Model identifier to use for LLM inference.
17
  Default: ``<your-active-model>``
18
 
 
 
 
 
19
  HF_TOKEN
20
- HuggingFace authentication token for the LLM provider.
21
  No default is set.
22
 
23
  TASK_ID
@@ -33,8 +37,9 @@ LOCAL_IMAGE_NAME
33
  Optional compatibility variable from the sample inference pattern.
34
  This script does not use ``from_docker_image()``, so the value is unused here.
35
 
36
- When both MODEL_NAME and HF_TOKEN are set explicitly, the script calls the LLM via the
37
- OpenAI-compatible API at API_BASE_URL. Otherwise it falls back to the deterministic
 
38
  heuristic baseline automatically.
39
 
40
  All stdout logs use the required structured tags: ``[START]``, ``[STEP]``, and ``[END]``.
@@ -83,6 +88,7 @@ def _get_int_env(name: str, default: int) -> int:
83
  API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
84
  MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)
85
  HF_TOKEN = os.getenv("HF_TOKEN")
 
86
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
87
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
88
 
@@ -100,12 +106,12 @@ RUN_ALL_TASKS_ENV = os.getenv("RUN_ALL_TASKS", "").strip().lower() in {
100
 
101
 
102
  def llm_mode_enabled() -> bool:
103
- return bool(HF_TOKEN) and MODEL_NAME != DEFAULT_MODEL_NAME
104
 
105
 
106
  llm_client: OpenAI | None = None
107
  if llm_mode_enabled():
108
- llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
109
 
110
 
111
  RECENT_HISTORY_LIMIT = 2
@@ -698,21 +704,108 @@ def should_investigate(ticket: dict, history: list[dict[str, Any]]) -> tuple[boo
698
  if not ticket:
699
  return False, None
700
  context_status = ticket.get("context_status") or {}
701
- remaining_tools = context_status.get("remaining_tools") or []
702
- if remaining_tools:
703
- return True, str(remaining_tools[0])
704
  current_ticket_id = ticket.get("ticket_id")
 
 
 
 
 
705
  already_investigated = any(
706
  entry.get("ticket_id") == current_ticket_id
707
  and entry.get("predicted", {}).get("action_type") == "investigate"
708
  for entry in history
709
  )
710
- if already_investigated:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
  return False, None
712
- if ticket.get("related_ticket_id"):
 
 
713
  return True, "lookup_related_ticket"
714
- if ticket.get("ambiguity_note"):
715
- return True, "lookup_requester_history"
716
  return False, None
717
 
718
 
 
16
  Model identifier to use for LLM inference.
17
  Default: ``<your-active-model>``
18
 
19
+ API_KEY
20
+ Proxy/API authentication token injected by the evaluator.
21
+ No default is set.
22
+
23
  HF_TOKEN
24
+ Backward-compatible local fallback alias for API_KEY.
25
  No default is set.
26
 
27
  TASK_ID
 
37
  Optional compatibility variable from the sample inference pattern.
38
  This script does not use ``from_docker_image()``, so the value is unused here.
39
 
40
+ When MODEL_NAME and API_KEY are set explicitly, the script calls the LLM via the
41
+ OpenAI-compatible API at API_BASE_URL. For local compatibility, HF_TOKEN is accepted
42
+ as a fallback alias for API_KEY. Otherwise it falls back to the deterministic
43
  heuristic baseline automatically.
44
 
45
  All stdout logs use the required structured tags: ``[START]``, ``[STEP]``, and ``[END]``.
 
88
  API_BASE_URL = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
89
  MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME)
90
  HF_TOKEN = os.getenv("HF_TOKEN")
91
+ API_KEY = os.getenv("API_KEY") or HF_TOKEN
92
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
93
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
94
 
 
106
 
107
 
108
  def llm_mode_enabled() -> bool:
109
+ return bool(API_KEY) and MODEL_NAME != DEFAULT_MODEL_NAME
110
 
111
 
112
  llm_client: OpenAI | None = None
113
  if llm_mode_enabled():
114
+ llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
115
 
116
 
117
  RECENT_HISTORY_LIMIT = 2
 
704
  if not ticket:
705
  return False, None
706
  context_status = ticket.get("context_status") or {}
 
 
 
707
  current_ticket_id = ticket.get("ticket_id")
708
+ prior_ticket_history = [
709
+ entry
710
+ for entry in history
711
+ if entry.get("ticket_id") == current_ticket_id
712
+ ]
713
  already_investigated = any(
714
  entry.get("ticket_id") == current_ticket_id
715
  and entry.get("predicted", {}).get("action_type") == "investigate"
716
  for entry in history
717
  )
718
+ investigations_used = sum(
719
+ 1
720
+ for entry in prior_ticket_history
721
+ if entry.get("predicted", {}).get("action_type") == "investigate"
722
+ )
723
+ hidden_context_remaining = bool(context_status.get("hidden_context_remaining"))
724
+ if investigations_used >= 3:
725
+ return False, None
726
+
727
+ used_tools = {
728
+ entry.get("predicted", {}).get("tool_name")
729
+ for entry in prior_ticket_history
730
+ if entry.get("predicted", {}).get("action_type") == "investigate"
731
+ }
732
+ routing_text = build_routing_text(ticket)
733
+ last_tool_result = ticket.get("last_tool_result") or {}
734
+ last_tool_name = str(last_tool_result.get("tool_name", "") or "")
735
+
736
+ follow_up_signal = any(
737
+ phrase in routing_text
738
+ for phrase in (
739
+ "re:",
740
+ "follow-up",
741
+ "following up",
742
+ "regression",
743
+ "reference ticket",
744
+ "third update",
745
+ "still",
746
+ "unresolved",
747
+ )
748
+ )
749
+ routing_ambiguity_signal = any(
750
+ phrase in routing_text
751
+ for phrase in (
752
+ "billing-style",
753
+ "prorating",
754
+ "seat expansion",
755
+ "vendor offer",
756
+ "pricing",
757
+ "compliance scan",
758
+ "vulnerability",
759
+ "onboarding workflow",
760
+ "blocked by an account problem",
761
+ "permissions error",
762
+ "mixed workflow",
763
+ )
764
+ )
765
+ requester_history_signal = any(
766
+ phrase in routing_text
767
+ for phrase in (
768
+ "still haven't",
769
+ "third update",
770
+ "again",
771
+ "follow-up",
772
+ "priority",
773
+ "legal",
774
+ "overdue",
775
+ "escalating",
776
+ )
777
+ )
778
+
779
+ preferred_tools: list[str] = []
780
+ if last_tool_name == "lookup_related_ticket":
781
+ preferred_tools.append("lookup_requester_history")
782
+ if last_tool_name == "lookup_requester_history":
783
+ preferred_tools.append("lookup_internal_routing_note")
784
+ if follow_up_signal or ticket.get("related_ticket_id"):
785
+ preferred_tools.append("lookup_related_ticket")
786
+ if routing_ambiguity_signal or hidden_context_remaining:
787
+ preferred_tools.append("lookup_internal_routing_note")
788
+ if requester_history_signal:
789
+ preferred_tools.append("lookup_requester_history")
790
+ if hidden_context_remaining:
791
+ preferred_tools.extend(
792
+ [
793
+ "lookup_related_ticket",
794
+ "lookup_internal_routing_note",
795
+ "lookup_requester_history",
796
+ ]
797
+ )
798
+
799
+ for tool_name in preferred_tools:
800
+ if tool_name not in used_tools:
801
+ return True, tool_name
802
+
803
+ if already_investigated and not hidden_context_remaining:
804
  return False, None
805
+ if ticket.get("ambiguity_note") and "lookup_internal_routing_note" not in used_tools:
806
+ return True, "lookup_internal_routing_note"
807
+ if ticket.get("related_ticket_id") and "lookup_related_ticket" not in used_tools:
808
  return True, "lookup_related_ticket"
 
 
809
  return False, None
810
 
811
 
openenv.yaml CHANGED
@@ -32,7 +32,11 @@ api:
32
  - /reset
33
  - /step
34
  - /state
 
35
  - /tasks
 
 
 
36
  - /docs
37
 
38
  evaluation:
@@ -51,9 +55,13 @@ inference:
51
  env_vars:
52
  - API_BASE_URL
53
  - MODEL_NAME
 
54
  - HF_TOKEN
55
  - ENV_URL
56
  - TASK_ID
 
 
 
57
 
58
  requirements:
59
  python: ">=3.11"
 
32
  - /reset
33
  - /step
34
  - /state
35
+ - /ws
36
  - /tasks
37
+ - /web
38
+ - /baseline
39
+ - /grader
40
  - /docs
41
 
42
  evaluation:
 
55
  env_vars:
56
  - API_BASE_URL
57
  - MODEL_NAME
58
+ - API_KEY
59
  - HF_TOKEN
60
  - ENV_URL
61
  - TASK_ID
62
+ - SEED
63
+ - RUN_ALL_TASKS
64
+ - LOCAL_IMAGE_NAME
65
 
66
  requirements:
67
  python: ">=3.11"
policy_learning.py CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
4
  import argparse
5
  import importlib
6
  import json
7
- from dataclasses import asdict, dataclass
8
  from pathlib import Path
9
  from statistics import mean
10
  from typing import Any, Callable, Iterable
@@ -18,13 +18,13 @@ from vocabulary import TASK_IDS
18
  DEFAULT_COMPARE_POLICIES = (
19
  "no_investigation",
20
  "investigate_when_context_hidden",
 
21
  )
22
  DEFAULT_SEARCH_POLICIES = (
23
  "no_investigation",
24
  "legacy_single_probe",
25
  "investigate_when_context_hidden",
26
- "context_chain",
27
- "hybrid_context",
28
  )
29
  DEFAULT_OUTPUT_DIR = "analysis/policy_learning_runs"
30
 
@@ -40,11 +40,13 @@ class PolicyConfig:
40
  investigate_ambiguity_history: bool
41
  max_investigations_per_ticket: int
42
  description: str
 
43
 
44
 
45
  POLICY_LIBRARY: dict[str, PolicyConfig] = {
46
  "no_investigation": PolicyConfig(
47
  name="no_investigation",
 
48
  investigate_hidden_context=False,
49
  investigate_related_ticket_hint=False,
50
  investigate_ambiguity_history=False,
@@ -53,6 +55,7 @@ POLICY_LIBRARY: dict[str, PolicyConfig] = {
53
  ),
54
  "legacy_single_probe": PolicyConfig(
55
  name="legacy_single_probe",
 
56
  investigate_hidden_context=False,
57
  investigate_related_ticket_hint=True,
58
  investigate_ambiguity_history=True,
@@ -61,30 +64,105 @@ POLICY_LIBRARY: dict[str, PolicyConfig] = {
61
  ),
62
  "investigate_when_context_hidden": PolicyConfig(
63
  name="investigate_when_context_hidden",
 
64
  investigate_hidden_context=True,
65
  investigate_related_ticket_hint=False,
66
  investigate_ambiguity_history=False,
67
  max_investigations_per_ticket=1,
68
- description="Investigate once when the environment says context is hidden.",
69
  ),
70
- "context_chain": PolicyConfig(
71
- name="context_chain",
72
- investigate_hidden_context=True,
73
- investigate_related_ticket_hint=False,
74
- investigate_ambiguity_history=False,
75
- max_investigations_per_ticket=3,
76
- description="Follow the environment's required-tool chain until context is revealed.",
77
- ),
78
- "hybrid_context": PolicyConfig(
79
- name="hybrid_context",
80
  investigate_hidden_context=True,
81
  investigate_related_ticket_hint=True,
82
  investigate_ambiguity_history=True,
83
  max_investigations_per_ticket=3,
84
- description="Use hidden-context signals first, then legacy ambiguity hints.",
 
 
85
  ),
86
  }
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def _dedupe_preserving_order(values: Iterable[int]) -> list[int]:
90
  seen: set[int] = set()
@@ -154,29 +232,199 @@ def default_submit_builder(
154
  return HelpdeskTicketAction(**candidate)
155
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def choose_policy_action(
158
  policy: PolicyConfig,
159
  observation: HelpdeskTicketObservation,
160
  investigations_by_ticket: dict[str, int],
161
  submit_builder: SubmitBuilder,
162
- ) -> tuple[HelpdeskTicketAction, str]:
 
 
 
163
  ticket = observation.current_ticket or {}
164
  ticket_id = str(ticket.get("ticket_id", ""))
165
  ticket_investigations = investigations_by_ticket.get(ticket_id, 0)
166
- revealed_tools = set(((ticket.get("context_status") or {}).get("revealed_tools") or []))
167
- remaining_tools = list(((ticket.get("context_status") or {}).get("remaining_tools") or []))
 
 
 
168
 
169
  if ticket_investigations < policy.max_investigations_per_ticket:
170
- if policy.investigate_hidden_context and remaining_tools:
171
- tool_name = str(remaining_tools[0])
172
- return (
173
- HelpdeskTicketAction(action_type="investigate", tool_name=tool_name),
174
- "investigate_hidden_context",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
 
 
 
 
 
 
176
  if (
177
  policy.investigate_related_ticket_hint
178
  and ticket.get("related_ticket_id")
179
- and "lookup_related_ticket" not in revealed_tools
180
  ):
181
  return (
182
  HelpdeskTicketAction(
@@ -184,11 +432,16 @@ def choose_policy_action(
184
  tool_name="lookup_related_ticket",
185
  ),
186
  "investigate_related_ticket_hint",
 
187
  )
188
  if (
189
  policy.investigate_ambiguity_history
190
- and ticket.get("ambiguity_note")
191
- and "lookup_requester_history" not in revealed_tools
 
 
 
 
192
  ):
193
  return (
194
  HelpdeskTicketAction(
@@ -196,9 +449,10 @@ def choose_policy_action(
196
  tool_name="lookup_requester_history",
197
  ),
198
  "investigate_ambiguity_history",
 
199
  )
200
 
201
- return submit_builder(ticket, list(observation.allowed_fields)), "submit"
202
 
203
 
204
  def rollout_episode(
@@ -208,27 +462,39 @@ def rollout_episode(
208
  seed: int,
209
  task_id: int,
210
  submit_builder: SubmitBuilder,
 
 
211
  ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
212
  task = get_task_definition(task_id)
213
  observation = env.reset(seed=seed, task_id=task_id)
214
  investigations_by_ticket: dict[str, int] = {}
 
215
  episode_return = 0.0
216
  trajectories: list[dict[str, Any]] = []
217
 
218
  while not observation.done:
219
  ticket = observation.current_ticket or {}
220
  ticket_id = str(ticket.get("ticket_id", ""))
221
- action, action_source = choose_policy_action(
222
  policy,
223
  observation,
224
  investigations_by_ticket,
225
  submit_builder,
 
 
226
  )
227
  next_observation = env.step(action)
228
  reward_value = float(next_observation.reward or 0.0)
229
  episode_return += reward_value
230
  if action.action_type == "investigate" and ticket_id:
231
  investigations_by_ticket[ticket_id] = investigations_by_ticket.get(ticket_id, 0) + 1
 
 
 
 
 
 
 
232
 
233
  history_entry = env.state.history_entries[-1] if env.state.history_entries else {}
234
  trajectories.append(
@@ -241,6 +507,7 @@ def rollout_episode(
241
  "step_index": len(trajectories) + 1,
242
  "ticket_id": history_entry.get("ticket_id", ticket_id),
243
  "action_source": action_source,
 
244
  "action": action.model_dump(exclude_none=True),
245
  "step_reward": reward_value,
246
  "rubric_reward": next_observation.rubric_reward,
@@ -280,6 +547,8 @@ def rollout_episode(
280
  "average_ticket_score": env.state.average_score_so_far,
281
  "per_ticket_scores": list(env.state.per_ticket_scores),
282
  }
 
 
283
  return summary, trajectories
284
 
285
 
@@ -352,6 +621,8 @@ def evaluate_policy(
352
  *,
353
  env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
354
  submit_builder: SubmitBuilder = default_submit_builder,
 
 
355
  ) -> dict[str, Any]:
356
  episode_summaries: list[dict[str, Any]] = []
357
  trajectories: list[dict[str, Any]] = []
@@ -365,16 +636,21 @@ def evaluate_policy(
365
  seed=seed,
366
  task_id=task_id,
367
  submit_builder=submit_builder,
 
 
368
  )
369
  episode_summaries.append(summary)
370
  trajectories.extend(episode_trajectories)
371
 
372
- return {
373
  "policy": policy.name,
374
  "summary": summarize_policy_episodes(policy, episode_summaries),
375
  "episodes": episode_summaries,
376
  "trajectories": trajectories,
377
  }
 
 
 
378
 
379
 
380
  def _selection_tuple(summary: dict[str, Any]) -> tuple[float, float, float, float]:
@@ -416,16 +692,20 @@ def compare_policies(
416
  submit_builder: SubmitBuilder = default_submit_builder,
417
  ) -> dict[str, Any]:
418
  output_dir = Path(output_dir)
419
- policy_runs = [
420
- evaluate_policy(
421
- policy,
422
- seeds,
423
- task_ids,
424
- env_factory=env_factory,
425
- submit_builder=submit_builder,
 
 
 
 
 
 
426
  )
427
- for policy in policies
428
- ]
429
  best_run = select_best_policy(policy_runs)
430
  baseline_run = policy_runs[0]
431
 
@@ -461,6 +741,11 @@ def compare_policies(
461
  reverse=True,
462
  )
463
  ],
 
 
 
 
 
464
  "artifacts": {
465
  "summary": str(output_dir / "compare_summary.json"),
466
  "episodes": str(output_dir / "compare_episodes.jsonl"),
@@ -496,16 +781,22 @@ def search_policies(
496
  baseline_policy_name: str = "no_investigation",
497
  ) -> dict[str, Any]:
498
  output_dir = Path(output_dir)
499
- train_runs = [
500
- evaluate_policy(
 
 
 
501
  policy,
502
  train_seeds,
503
  task_ids,
504
  env_factory=env_factory,
505
  submit_builder=submit_builder,
 
 
506
  )
507
- for policy in candidate_policies
508
- ]
 
509
  selected_run = select_best_policy(train_runs)
510
  selected_policy = POLICY_LIBRARY[selected_run["policy"]]
511
  eval_selected = evaluate_policy(
@@ -514,6 +805,8 @@ def search_policies(
514
  task_ids,
515
  env_factory=env_factory,
516
  submit_builder=submit_builder,
 
 
517
  )
518
 
519
  baseline_policy = POLICY_LIBRARY.get(baseline_policy_name, candidate_policies[0])
@@ -523,6 +816,8 @@ def search_policies(
523
  task_ids,
524
  env_factory=env_factory,
525
  submit_builder=submit_builder,
 
 
526
  )
527
 
528
  report = {
@@ -535,6 +830,9 @@ def search_policies(
535
  "selected_policy": selected_policy.name,
536
  "baseline_policy": baseline_policy.name,
537
  "train_policy_summaries": [run["summary"] for run in train_runs],
 
 
 
538
  "eval_selected_summary": eval_selected["summary"],
539
  "eval_baseline_summary": eval_baseline["summary"],
540
  "eval_improvement_vs_baseline": {
 
4
  import argparse
5
  import importlib
6
  import json
7
+ from dataclasses import asdict, dataclass, field
8
  from pathlib import Path
9
  from statistics import mean
10
  from typing import Any, Callable, Iterable
 
18
  DEFAULT_COMPARE_POLICIES = (
19
  "no_investigation",
20
  "investigate_when_context_hidden",
21
+ "adaptive_cue_bandit",
22
  )
23
  DEFAULT_SEARCH_POLICIES = (
24
  "no_investigation",
25
  "legacy_single_probe",
26
  "investigate_when_context_hidden",
27
+ "adaptive_cue_bandit",
 
28
  )
29
  DEFAULT_OUTPUT_DIR = "analysis/policy_learning_runs"
30
 
 
40
  investigate_ambiguity_history: bool
41
  max_investigations_per_ticket: int
42
  description: str
43
+ strategy: str = "static"
44
 
45
 
46
  POLICY_LIBRARY: dict[str, PolicyConfig] = {
47
  "no_investigation": PolicyConfig(
48
  name="no_investigation",
49
+ strategy="static",
50
  investigate_hidden_context=False,
51
  investigate_related_ticket_hint=False,
52
  investigate_ambiguity_history=False,
 
55
  ),
56
  "legacy_single_probe": PolicyConfig(
57
  name="legacy_single_probe",
58
+ strategy="static",
59
  investigate_hidden_context=False,
60
  investigate_related_ticket_hint=True,
61
  investigate_ambiguity_history=True,
 
64
  ),
65
  "investigate_when_context_hidden": PolicyConfig(
66
  name="investigate_when_context_hidden",
67
+ strategy="static",
68
  investigate_hidden_context=True,
69
  investigate_related_ticket_hint=False,
70
  investigate_ambiguity_history=False,
71
  max_investigations_per_ticket=1,
72
+ description="Investigate once when the environment shows hidden-context pressure.",
73
  ),
74
+ "adaptive_cue_bandit": PolicyConfig(
75
+ name="adaptive_cue_bandit",
76
+ strategy="adaptive",
 
 
 
 
 
 
 
77
  investigate_hidden_context=True,
78
  investigate_related_ticket_hint=True,
79
  investigate_ambiguity_history=True,
80
  max_investigations_per_ticket=3,
81
+ description=(
82
+ "Learn cue-conditioned tool preferences from investigation rewards on train seeds."
83
+ ),
84
  ),
85
  }
86
 
87
+ AVAILABLE_TOOLS = (
88
+ "lookup_related_ticket",
89
+ "lookup_requester_history",
90
+ "lookup_internal_routing_note",
91
+ )
92
+
93
+
94
+ @dataclass
95
+ class AdaptiveToolBandit:
96
+ exploration_rounds: int = 1
97
+ cue_tool_totals: dict[str, dict[str, float]] = field(default_factory=dict)
98
+ cue_tool_counts: dict[str, dict[str, int]] = field(default_factory=dict)
99
+ global_tool_totals: dict[str, float] = field(default_factory=dict)
100
+ global_tool_counts: dict[str, int] = field(default_factory=dict)
101
+
102
+ def choose_tool(self, cue: str, candidate_tools: list[str]) -> str:
103
+ for tool_name in candidate_tools:
104
+ if self.cue_tool_counts.get(cue, {}).get(tool_name, 0) < self.exploration_rounds:
105
+ return tool_name
106
+ return max(
107
+ candidate_tools,
108
+ key=lambda tool_name: (
109
+ self._cue_average(cue, tool_name),
110
+ self._global_average(tool_name),
111
+ -candidate_tools.index(tool_name),
112
+ ),
113
+ )
114
+
115
+ def record_reward(self, cue: str, tool_name: str, reward: float) -> None:
116
+ cue_totals = self.cue_tool_totals.setdefault(cue, {})
117
+ cue_counts = self.cue_tool_counts.setdefault(cue, {})
118
+ cue_totals[tool_name] = cue_totals.get(tool_name, 0.0) + reward
119
+ cue_counts[tool_name] = cue_counts.get(tool_name, 0) + 1
120
+ self.global_tool_totals[tool_name] = self.global_tool_totals.get(tool_name, 0.0) + reward
121
+ self.global_tool_counts[tool_name] = self.global_tool_counts.get(tool_name, 0) + 1
122
+
123
+ def export(self) -> dict[str, Any]:
124
+ return {
125
+ "exploration_rounds": self.exploration_rounds,
126
+ "cue_tool_averages": {
127
+ cue: {
128
+ tool_name: round(self._cue_average(cue, tool_name), 6)
129
+ for tool_name in sorted(tool_totals)
130
+ }
131
+ for cue, tool_totals in sorted(self.cue_tool_totals.items())
132
+ },
133
+ "global_tool_averages": {
134
+ tool_name: round(self._global_average(tool_name), 6)
135
+ for tool_name in sorted(self.global_tool_totals)
136
+ },
137
+ }
138
+
139
+ def frozen_copy(self) -> "AdaptiveToolBandit":
140
+ return AdaptiveToolBandit(
141
+ exploration_rounds=self.exploration_rounds,
142
+ cue_tool_totals={
143
+ cue: dict(tool_totals) for cue, tool_totals in self.cue_tool_totals.items()
144
+ },
145
+ cue_tool_counts={
146
+ cue: dict(tool_counts) for cue, tool_counts in self.cue_tool_counts.items()
147
+ },
148
+ global_tool_totals=dict(self.global_tool_totals),
149
+ global_tool_counts=dict(self.global_tool_counts),
150
+ )
151
+
152
+ def _cue_average(self, cue: str, tool_name: str) -> float:
153
+ total = self.cue_tool_totals.get(cue, {}).get(tool_name, 0.0)
154
+ count = self.cue_tool_counts.get(cue, {}).get(tool_name, 0)
155
+ if count == 0:
156
+ return self._global_average(tool_name)
157
+ return total / count
158
+
159
+ def _global_average(self, tool_name: str) -> float:
160
+ total = self.global_tool_totals.get(tool_name, 0.0)
161
+ count = self.global_tool_counts.get(tool_name, 0)
162
+ if count == 0:
163
+ return 0.0
164
+ return total / count
165
+
166
 
167
  def _dedupe_preserving_order(values: Iterable[int]) -> list[int]:
168
  seen: set[int] = set()
 
232
  return HelpdeskTicketAction(**candidate)
233
 
234
 
235
+ def _routing_text(ticket: dict[str, Any]) -> str:
236
+ parts = [
237
+ str(ticket.get("title", "")),
238
+ str(ticket.get("description", "")),
239
+ str(ticket.get("ambiguity_note", "")),
240
+ json.dumps(ticket.get("last_tool_result") or {}, sort_keys=True),
241
+ ]
242
+ related_preview = ticket.get("related_ticket_preview") or {}
243
+ parts.extend(
244
+ [
245
+ str(related_preview.get("title", "")),
246
+ str(related_preview.get("description", "")),
247
+ ]
248
+ )
249
+ return " ".join(parts).lower()
250
+
251
+
252
+ def infer_ticket_cue(ticket: dict[str, Any]) -> str:
253
+ text = _routing_text(ticket)
254
+ if any(
255
+ phrase in text
256
+ for phrase in ("re:", "follow-up", "following up", "regression", "reference ticket", "third update")
257
+ ):
258
+ return "follow_up"
259
+ if any(
260
+ phrase in text
261
+ for phrase in (
262
+ "pricing",
263
+ "quote",
264
+ "vendor offer",
265
+ "prorating",
266
+ "seat expansion",
267
+ "commercial",
268
+ )
269
+ ):
270
+ return "commercial_ambiguity"
271
+ if any(
272
+ phrase in text
273
+ for phrase in (
274
+ "onboarding",
275
+ "contractor",
276
+ "permissions error",
277
+ "blocked by an account problem",
278
+ )
279
+ ):
280
+ return "workflow_blocker"
281
+ if any(
282
+ phrase in text
283
+ for phrase in ("compliance scan", "vulnerability", "policy issue", "routing note")
284
+ ):
285
+ return "routing_note"
286
+ if any(
287
+ phrase in text
288
+ for phrase in ("still", "again", "overdue", "legal", "priority")
289
+ ):
290
+ return "history_pressure"
291
+ return "generic_hidden_context"
292
+
293
+
294
+ def preferred_tool_order(
295
+ ticket: dict[str, Any],
296
+ *,
297
+ hidden_context_remaining: bool,
298
+ ) -> list[str]:
299
+ text = _routing_text(ticket)
300
+ last_tool_result = ticket.get("last_tool_result") or {}
301
+ last_tool_name = str(last_tool_result.get("tool_name", "") or "")
302
+
303
+ preferred_tools: list[str] = []
304
+ if last_tool_name == "lookup_related_ticket":
305
+ preferred_tools.append("lookup_requester_history")
306
+ if last_tool_name == "lookup_requester_history":
307
+ preferred_tools.append("lookup_internal_routing_note")
308
+
309
+ if any(
310
+ phrase in text
311
+ for phrase in ("re:", "follow-up", "following up", "regression", "reference ticket")
312
+ ) or ticket.get("related_ticket_id"):
313
+ preferred_tools.append("lookup_related_ticket")
314
+
315
+ if any(
316
+ phrase in text
317
+ for phrase in (
318
+ "pricing",
319
+ "quote",
320
+ "vendor offer",
321
+ "prorating",
322
+ "seat expansion",
323
+ "billing-style",
324
+ "compliance scan",
325
+ "vulnerability",
326
+ "onboarding workflow",
327
+ "permissions error",
328
+ "blocked by an account problem",
329
+ )
330
+ ):
331
+ preferred_tools.append("lookup_internal_routing_note")
332
+
333
+ if any(
334
+ phrase in text
335
+ for phrase in ("still", "again", "overdue", "legal", "third update", "priority")
336
+ ):
337
+ preferred_tools.append("lookup_requester_history")
338
+
339
+ if hidden_context_remaining:
340
+ preferred_tools.extend(
341
+ [
342
+ "lookup_internal_routing_note",
343
+ "lookup_related_ticket",
344
+ "lookup_requester_history",
345
+ ]
346
+ )
347
+
348
+ deduped_tools: list[str] = []
349
+ for tool_name in preferred_tools:
350
+ if tool_name not in deduped_tools:
351
+ deduped_tools.append(tool_name)
352
+ return deduped_tools
353
+
354
+
355
+ def select_cue_based_tool(
356
+ ticket: dict[str, Any],
357
+ *,
358
+ hidden_context_remaining: bool,
359
+ used_tools: set[str],
360
+ ) -> str | None:
361
+ preferred_tools = preferred_tool_order(
362
+ ticket,
363
+ hidden_context_remaining=hidden_context_remaining,
364
+ )
365
+ for tool_name in preferred_tools:
366
+ if tool_name not in used_tools:
367
+ return tool_name
368
+ return None
369
+
370
+
371
  def choose_policy_action(
372
  policy: PolicyConfig,
373
  observation: HelpdeskTicketObservation,
374
  investigations_by_ticket: dict[str, int],
375
  submit_builder: SubmitBuilder,
376
+ *,
377
+ used_tools_by_ticket: dict[str, set[str]] | None = None,
378
+ adaptive_bandit: AdaptiveToolBandit | None = None,
379
+ ) -> tuple[HelpdeskTicketAction, str, str | None]:
380
  ticket = observation.current_ticket or {}
381
  ticket_id = str(ticket.get("ticket_id", ""))
382
  ticket_investigations = investigations_by_ticket.get(ticket_id, 0)
383
+ used_tools = set()
384
+ if used_tools_by_ticket is not None:
385
+ used_tools = set(used_tools_by_ticket.get(ticket_id, set()))
386
+ context_status = ticket.get("context_status") or {}
387
+ hidden_context_remaining = bool(context_status.get("hidden_context_remaining"))
388
 
389
  if ticket_investigations < policy.max_investigations_per_ticket:
390
+ if policy.strategy == "adaptive" and adaptive_bandit is not None and hidden_context_remaining:
391
+ candidate_tools = [
392
+ tool_name
393
+ for tool_name in preferred_tool_order(
394
+ ticket,
395
+ hidden_context_remaining=hidden_context_remaining,
396
+ )
397
+ if tool_name not in used_tools
398
+ ]
399
+ if not candidate_tools:
400
+ candidate_tools = [
401
+ tool_name for tool_name in AVAILABLE_TOOLS if tool_name not in used_tools
402
+ ]
403
+ if candidate_tools:
404
+ cue = infer_ticket_cue(ticket)
405
+ tool_name = adaptive_bandit.choose_tool(cue, candidate_tools)
406
+ return (
407
+ HelpdeskTicketAction(action_type="investigate", tool_name=tool_name),
408
+ "adaptive_bandit_investigate",
409
+ cue,
410
+ )
411
+
412
+ if policy.investigate_hidden_context and hidden_context_remaining:
413
+ tool_name = select_cue_based_tool(
414
+ ticket,
415
+ hidden_context_remaining=hidden_context_remaining,
416
+ used_tools=used_tools,
417
  )
418
+ if tool_name is not None:
419
+ return (
420
+ HelpdeskTicketAction(action_type="investigate", tool_name=tool_name),
421
+ "investigate_hidden_context",
422
+ infer_ticket_cue(ticket),
423
+ )
424
  if (
425
  policy.investigate_related_ticket_hint
426
  and ticket.get("related_ticket_id")
427
+ and "lookup_related_ticket" not in used_tools
428
  ):
429
  return (
430
  HelpdeskTicketAction(
 
432
  tool_name="lookup_related_ticket",
433
  ),
434
  "investigate_related_ticket_hint",
435
+ infer_ticket_cue(ticket),
436
  )
437
  if (
438
  policy.investigate_ambiguity_history
439
+ and (
440
+ ticket.get("ambiguity_note")
441
+ or ticket.get("feedback_summary")
442
+ or hidden_context_remaining
443
+ )
444
+ and "lookup_requester_history" not in used_tools
445
  ):
446
  return (
447
  HelpdeskTicketAction(
 
449
  tool_name="lookup_requester_history",
450
  ),
451
  "investigate_ambiguity_history",
452
+ infer_ticket_cue(ticket),
453
  )
454
 
455
+ return submit_builder(ticket, list(observation.allowed_fields)), "submit", None
456
 
457
 
458
  def rollout_episode(
 
462
  seed: int,
463
  task_id: int,
464
  submit_builder: SubmitBuilder,
465
+ adaptive_bandit: AdaptiveToolBandit | None = None,
466
+ update_adaptive: bool = False,
467
  ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
468
  task = get_task_definition(task_id)
469
  observation = env.reset(seed=seed, task_id=task_id)
470
  investigations_by_ticket: dict[str, int] = {}
471
+ used_tools_by_ticket: dict[str, set[str]] = {}
472
  episode_return = 0.0
473
  trajectories: list[dict[str, Any]] = []
474
 
475
  while not observation.done:
476
  ticket = observation.current_ticket or {}
477
  ticket_id = str(ticket.get("ticket_id", ""))
478
+ action, action_source, action_cue = choose_policy_action(
479
  policy,
480
  observation,
481
  investigations_by_ticket,
482
  submit_builder,
483
+ used_tools_by_ticket=used_tools_by_ticket,
484
+ adaptive_bandit=adaptive_bandit,
485
  )
486
  next_observation = env.step(action)
487
  reward_value = float(next_observation.reward or 0.0)
488
  episode_return += reward_value
489
  if action.action_type == "investigate" and ticket_id:
490
  investigations_by_ticket[ticket_id] = investigations_by_ticket.get(ticket_id, 0) + 1
491
+ used_tools_by_ticket.setdefault(ticket_id, set()).add(str(action.tool_name))
492
+ if policy.strategy == "adaptive" and adaptive_bandit is not None and update_adaptive:
493
+ adaptive_bandit.record_reward(
494
+ action_cue or infer_ticket_cue(ticket),
495
+ str(action.tool_name),
496
+ reward_value,
497
+ )
498
 
499
  history_entry = env.state.history_entries[-1] if env.state.history_entries else {}
500
  trajectories.append(
 
507
  "step_index": len(trajectories) + 1,
508
  "ticket_id": history_entry.get("ticket_id", ticket_id),
509
  "action_source": action_source,
510
+ "action_cue": action_cue,
511
  "action": action.model_dump(exclude_none=True),
512
  "step_reward": reward_value,
513
  "rubric_reward": next_observation.rubric_reward,
 
547
  "average_ticket_score": env.state.average_score_so_far,
548
  "per_ticket_scores": list(env.state.per_ticket_scores),
549
  }
550
+ if adaptive_bandit is not None and policy.strategy == "adaptive":
551
+ summary["learned_tool_values"] = adaptive_bandit.export()
552
  return summary, trajectories
553
 
554
 
 
621
  *,
622
  env_factory: EnvFactory = HelpdeskTicketRoutingEnvironment,
623
  submit_builder: SubmitBuilder = default_submit_builder,
624
+ adaptive_bandit: AdaptiveToolBandit | None = None,
625
+ update_adaptive: bool = False,
626
  ) -> dict[str, Any]:
627
  episode_summaries: list[dict[str, Any]] = []
628
  trajectories: list[dict[str, Any]] = []
 
636
  seed=seed,
637
  task_id=task_id,
638
  submit_builder=submit_builder,
639
+ adaptive_bandit=adaptive_bandit,
640
+ update_adaptive=update_adaptive,
641
  )
642
  episode_summaries.append(summary)
643
  trajectories.extend(episode_trajectories)
644
 
645
+ result = {
646
  "policy": policy.name,
647
  "summary": summarize_policy_episodes(policy, episode_summaries),
648
  "episodes": episode_summaries,
649
  "trajectories": trajectories,
650
  }
651
+ if adaptive_bandit is not None and policy.strategy == "adaptive":
652
+ result["adaptive_bandit"] = adaptive_bandit.export()
653
+ return result
654
 
655
 
656
  def _selection_tuple(summary: dict[str, Any]) -> tuple[float, float, float, float]:
 
692
  submit_builder: SubmitBuilder = default_submit_builder,
693
  ) -> dict[str, Any]:
694
  output_dir = Path(output_dir)
695
+ policy_runs = []
696
+ for policy in policies:
697
+ adaptive_bandit = AdaptiveToolBandit() if policy.strategy == "adaptive" else None
698
+ policy_runs.append(
699
+ evaluate_policy(
700
+ policy,
701
+ seeds,
702
+ task_ids,
703
+ env_factory=env_factory,
704
+ submit_builder=submit_builder,
705
+ adaptive_bandit=adaptive_bandit,
706
+ update_adaptive=policy.strategy == "adaptive",
707
+ )
708
  )
 
 
709
  best_run = select_best_policy(policy_runs)
710
  baseline_run = policy_runs[0]
711
 
 
741
  reverse=True,
742
  )
743
  ],
744
+ "adaptive_bandits": {
745
+ run["policy"]: run["adaptive_bandit"]
746
+ for run in policy_runs
747
+ if "adaptive_bandit" in run
748
+ },
749
  "artifacts": {
750
  "summary": str(output_dir / "compare_summary.json"),
751
  "episodes": str(output_dir / "compare_episodes.jsonl"),
 
781
  baseline_policy_name: str = "no_investigation",
782
  ) -> dict[str, Any]:
783
  output_dir = Path(output_dir)
784
+ train_runs = []
785
+ trained_bandits: dict[str, AdaptiveToolBandit] = {}
786
+ for policy in candidate_policies:
787
+ adaptive_bandit = AdaptiveToolBandit() if policy.strategy == "adaptive" else None
788
+ train_run = evaluate_policy(
789
  policy,
790
  train_seeds,
791
  task_ids,
792
  env_factory=env_factory,
793
  submit_builder=submit_builder,
794
+ adaptive_bandit=adaptive_bandit,
795
+ update_adaptive=policy.strategy == "adaptive",
796
  )
797
+ train_runs.append(train_run)
798
+ if adaptive_bandit is not None:
799
+ trained_bandits[policy.name] = adaptive_bandit.frozen_copy()
800
  selected_run = select_best_policy(train_runs)
801
  selected_policy = POLICY_LIBRARY[selected_run["policy"]]
802
  eval_selected = evaluate_policy(
 
805
  task_ids,
806
  env_factory=env_factory,
807
  submit_builder=submit_builder,
808
+ adaptive_bandit=trained_bandits.get(selected_policy.name),
809
+ update_adaptive=False,
810
  )
811
 
812
  baseline_policy = POLICY_LIBRARY.get(baseline_policy_name, candidate_policies[0])
 
816
  task_ids,
817
  env_factory=env_factory,
818
  submit_builder=submit_builder,
819
+ adaptive_bandit=trained_bandits.get(baseline_policy.name),
820
+ update_adaptive=False,
821
  )
822
 
823
  report = {
 
830
  "selected_policy": selected_policy.name,
831
  "baseline_policy": baseline_policy.name,
832
  "train_policy_summaries": [run["summary"] for run in train_runs],
833
+ "trained_adaptive_bandits": {
834
+ name: bandit.export() for name, bandit in trained_bandits.items()
835
+ },
836
  "eval_selected_summary": eval_selected["summary"],
837
  "eval_baseline_summary": eval_baseline["summary"],
838
  "eval_improvement_vs_baseline": {
required.md CHANGED
@@ -154,11 +154,12 @@ All of these must pass:
154
 
155
  ### Required inference environment variables
156
 
157
- - `API_BASE_URL`
158
- - `MODEL_NAME`
159
- - `HF_TOKEN`
 
160
 
161
- The official text also mentions `OPENAI_API_KEY` in one place, but the more specific submission instructions above consistently emphasize `API_BASE_URL`, `MODEL_NAME`, and `HF_TOKEN`. We should follow the later, more specific instruction while continuing to use the OpenAI client.
162
 
163
  ### Inference script constraints
164
 
@@ -302,7 +303,7 @@ The project keeps three tasks:
302
  ### Inference
303
 
304
  - heuristic mode works without model credentials
305
- - LLM mode reads `API_BASE_URL`, `MODEL_NAME`, and `HF_TOKEN`
306
  - uses the OpenAI client
307
  - stdout follows `[START]`, `[STEP]`, and `[END]`
308
  - output is reproducible when the seed is fixed
 
154
 
155
  ### Required inference environment variables
156
 
157
+ - `API_BASE_URL`
158
+ - `MODEL_NAME`
159
+ - `API_KEY`
160
+ - `HF_TOKEN`
161
 
162
+ Use `API_KEY` as the primary evaluator-injected credential for the OpenAI client. `HF_TOKEN` can remain as a backward-compatible local fallback, but submission-time LLM traffic should flow through the injected proxy key.
163
 
164
  ### Inference script constraints
165
 
 
303
  ### Inference
304
 
305
  - heuristic mode works without model credentials
306
+ - LLM mode reads `API_BASE_URL`, `MODEL_NAME`, and `API_KEY` (`HF_TOKEN` remains a local fallback)
307
  - uses the OpenAI client
308
  - stdout follows `[START]`, `[STEP]`, and `[END]`
309
  - output is reproducible when the seed is fixed
server/Dockerfile CHANGED
@@ -1,7 +1,8 @@
1
  FROM python:3.11-slim
2
 
3
  ENV PYTHONDONTWRITEBYTECODE=1 \
4
- PYTHONUNBUFFERED=1
 
5
 
6
  WORKDIR /app
7
 
@@ -14,6 +15,14 @@ RUN python -m pip install --upgrade pip \
14
  && python -m pip install --no-cache-dir -r requirements.txt \
15
  && python -m pip install --no-cache-dir .
16
 
 
 
 
17
  EXPOSE 7860
18
 
 
 
 
 
 
19
  CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.11-slim
2
 
3
  ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1
6
 
7
  WORKDIR /app
8
 
 
15
  && python -m pip install --no-cache-dir -r requirements.txt \
16
  && python -m pip install --no-cache-dir .
17
 
18
+ RUN useradd --create-home --uid 10001 appuser \
19
+ && chown -R appuser:appuser /app
20
+
21
  EXPOSE 7860
22
 
23
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
24
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://127.0.0.1:7860/health', timeout=3)"
25
+
26
+ USER appuser
27
+
28
  CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
server/app.py CHANGED
@@ -1,17 +1,21 @@
1
  import sys
2
  from pathlib import Path
 
3
 
4
  # Ensure repo root is on sys.path so `models` and `server` are importable
5
  _repo_root = str(Path(__file__).resolve().parent.parent)
6
  if _repo_root not in sys.path:
7
  sys.path.insert(0, _repo_root)
8
 
9
- from fastapi.responses import HTMLResponse
 
 
10
  from openenv.core.env_server import create_app
11
 
12
  from models import HelpdeskTicketAction, HelpdeskTicketObservation
13
  from server.environment import HelpdeskTicketRoutingEnvironment
14
- from server.tasks import TASKS
 
15
  from vocabulary import APP_ENV_NAME
16
 
17
  app = create_app(
@@ -22,6 +26,17 @@ app = create_app(
22
  )
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  @app.get("/tasks")
26
  def list_tasks():
27
  return {
@@ -57,6 +72,109 @@ def web_ui():
57
  return HTMLResponse(content=html)
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def main() -> None:
61
  import uvicorn
62
 
 
1
  import sys
2
  from pathlib import Path
3
+ from typing import Any
4
 
5
  # Ensure repo root is on sys.path so `models` and `server` are importable
6
  _repo_root = str(Path(__file__).resolve().parent.parent)
7
  if _repo_root not in sys.path:
8
  sys.path.insert(0, _repo_root)
9
 
10
+ from fastapi import HTTPException
11
+ from pydantic import BaseModel
12
+ from fastapi.responses import HTMLResponse, RedirectResponse
13
  from openenv.core.env_server import create_app
14
 
15
  from models import HelpdeskTicketAction, HelpdeskTicketObservation
16
  from server.environment import HelpdeskTicketRoutingEnvironment
17
+ from server.grader import grade_action
18
+ from server.tasks import TASKS, load_dataset
19
  from vocabulary import APP_ENV_NAME
20
 
21
  app = create_app(
 
26
  )
27
 
28
 
29
+ class GraderRequest(BaseModel):
30
+ task_id: int
31
+ ticket_id: str
32
+ action: dict[str, Any]
33
+
34
+
35
+ @app.get("/", include_in_schema=False)
36
+ def root_redirect():
37
+ return RedirectResponse(url="/web", status_code=307)
38
+
39
+
40
  @app.get("/tasks")
41
  def list_tasks():
42
  return {
 
72
  return HTMLResponse(content=html)
73
 
74
 
75
+ def _build_baseline_submit_action(
76
+ ticket: dict[str, Any], allowed_fields: list[str]
77
+ ) -> HelpdeskTicketAction:
78
+ import inference
79
+
80
+ candidate = inference.heuristic_action(ticket, allowed_fields)
81
+ candidate, _ = inference.apply_domain_overrides(ticket, candidate, allowed_fields)
82
+ return HelpdeskTicketAction(**candidate)
83
+
84
+
85
+ @app.get("/baseline")
86
+ def baseline_rollout(task_id: int = 1, seed: int = 42):
87
+ import inference
88
+
89
+ env = HelpdeskTicketRoutingEnvironment()
90
+ observation = env.reset(seed=seed, task_id=task_id)
91
+ steps: list[dict[str, Any]] = []
92
+
93
+ while not observation.done:
94
+ ticket = observation.current_ticket
95
+ if ticket is None:
96
+ break
97
+
98
+ investigate, tool_name = inference.should_investigate(ticket, observation.history)
99
+ if (
100
+ investigate
101
+ and tool_name is not None
102
+ and observation.investigation_budget_remaining > 0
103
+ ):
104
+ investigate_action = HelpdeskTicketAction(
105
+ action_type="investigate",
106
+ tool_name=tool_name,
107
+ tool_target_ticket_id=ticket.get("related_ticket_id"),
108
+ )
109
+ observation = env.step(investigate_action)
110
+ steps.append(
111
+ {
112
+ "action": investigate_action.model_dump(exclude_none=True),
113
+ "reward": observation.reward,
114
+ "done": observation.done,
115
+ "action_source": "baseline_investigate",
116
+ }
117
+ )
118
+ if observation.done:
119
+ break
120
+ ticket = observation.current_ticket
121
+ if ticket is None:
122
+ break
123
+
124
+ action = _build_baseline_submit_action(
125
+ inference.merge_ticket_context(ticket, observation),
126
+ list(observation.allowed_fields),
127
+ )
128
+ observation = env.step(action)
129
+ steps.append(
130
+ {
131
+ "action": action.model_dump(exclude_none=True),
132
+ "reward": observation.reward,
133
+ "done": observation.done,
134
+ "action_source": "baseline_submit",
135
+ }
136
+ )
137
+
138
+ return {
139
+ "task_id": task_id,
140
+ "seed": seed,
141
+ "step_count": len(steps),
142
+ "final_reward": observation.reward,
143
+ "rubric_reward": observation.rubric_reward,
144
+ "steps": steps,
145
+ }
146
+
147
+
148
+ @app.post("/grader")
149
+ def grader_preview(request: GraderRequest):
150
+ ticket = next(
151
+ (record for record in load_dataset() if record.ticket_id == request.ticket_id),
152
+ None,
153
+ )
154
+ if ticket is None:
155
+ raise HTTPException(status_code=404, detail=f"Unknown ticket_id: {request.ticket_id}")
156
+
157
+ try:
158
+ action = HelpdeskTicketAction.model_validate(request.action)
159
+ except Exception as exc:
160
+ raise HTTPException(status_code=422, detail=str(exc)) from exc
161
+
162
+ score, breakdown = grade_action(action, ticket, request.task_id)
163
+ return {
164
+ "task_id": request.task_id,
165
+ "ticket_id": request.ticket_id,
166
+ "score": score,
167
+ "breakdown": breakdown,
168
+ "expected": {
169
+ "issue_type": ticket.issue_type,
170
+ "priority": ticket.priority,
171
+ "assignment_group": ticket.assignment_group,
172
+ "resolution_action": ticket.resolution_action,
173
+ },
174
+ "submitted": action.model_dump(exclude_none=True),
175
+ }
176
+
177
+
178
  def main() -> None:
179
  import uvicorn
180
 
server/environment.py CHANGED
@@ -13,8 +13,15 @@ from models import (
13
  HelpdeskTicketState,
14
  )
15
  from server.grader import grade_action
16
- from server.reward import compute_step_reward, compute_trajectory_reward
 
 
 
17
  from server.tasks import get_task_definition, load_dataset
 
 
 
 
18
 
19
 
20
  QUEUE_SIZE_RANGE = (3, 5)
@@ -29,6 +36,12 @@ EXTRA_INVESTIGATION_COST = 0.02
29
  MAX_EXTRA_INVESTIGATION_PENALTY = 0.15
30
  USEFUL_INVESTIGATION_REWARD = 0.08
31
  PREMATURE_SUBMIT_PENALTY = 0.10
 
 
 
 
 
 
32
 
33
  TASK3_INVESTIGATION_TOOL_PLAN: dict[str, tuple[str, ...]] = {
34
  "ticket-021": ("lookup_related_ticket", "lookup_requester_history"),
@@ -190,11 +203,16 @@ class HelpdeskTicketRoutingEnvironment(
190
  is_done = self._state.current_ticket_index >= len(self._queue)
191
  self._state.done = is_done
192
  trajectory_reward = None
 
193
  investigation_penalty = self._compute_episode_penalty() if is_done else 0.0
194
  if is_done:
195
- trajectory_reward = compute_trajectory_reward(
196
- self._state.per_ticket_scores, len(self._queue), self._state.step_count
 
 
 
197
  )
 
198
  final_reward = self._apply_episode_economics(trajectory_reward)
199
  self._state.total_reward = final_reward
200
  else:
@@ -208,6 +226,23 @@ class HelpdeskTicketRoutingEnvironment(
208
  trajectory_reward=trajectory_reward,
209
  investigation_penalty=investigation_penalty,
210
  penalty_reason=f"extra_fields: {sorted(extra_fields)}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  )
212
  self._state.history_entries.append(
213
  self._build_history_entry(
@@ -235,13 +270,30 @@ class HelpdeskTicketRoutingEnvironment(
235
  rubric_reward=final_reward if is_done else None,
236
  )
237
 
 
238
  score, breakdown = grade_action(action, current_ticket, task_id)
239
- step_reward = compute_step_reward(score)
240
- context_penalty, missing_required_tools = self._submit_context_penalty(current_ticket)
241
- milestone_adjustment = step_reward - score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
244
  trajectory_reward = None
 
245
  investigation_penalty = 0.0
246
  rubric_reward = None
247
 
@@ -250,11 +302,13 @@ class HelpdeskTicketRoutingEnvironment(
250
  self._state.average_score_so_far = self._current_average_score()
251
  self._state.step_count += 1
252
  self._state.current_ticket_index += 1
253
- trajectory_reward = compute_trajectory_reward(
254
  self._state.per_ticket_scores,
255
  len(self._queue),
256
  self._state.step_count,
 
257
  )
 
258
  rubric_reward = self._apply_episode_economics(trajectory_reward)
259
  final_reward = max(0.0, min(1.0, rubric_reward - context_penalty))
260
  self._state.total_reward = rubric_reward
@@ -272,14 +326,35 @@ class HelpdeskTicketRoutingEnvironment(
272
  shaped_step_reward=step_reward,
273
  reward_kind="trajectory" if is_done else "step",
274
  final_reward=final_reward,
275
- milestone_adjustment=milestone_adjustment,
276
  trajectory_reward=trajectory_reward,
277
  investigation_penalty=investigation_penalty,
278
  extra_details={
279
  "context_gap_penalty": context_penalty,
280
- "required_tools": self._required_tools_for_ticket(current_ticket),
281
- "remaining_required_tools": missing_required_tools,
 
 
 
 
 
 
282
  "rubric_reward": rubric_reward,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  },
284
  )
285
 
@@ -335,6 +410,35 @@ class HelpdeskTicketRoutingEnvironment(
335
  return 0.0
336
  return sum(self._state.per_ticket_scores) / len(self._state.per_ticket_scores)
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def _required_tools_for_ticket(
339
  self,
340
  ticket: HelpdeskTicketRecord,
@@ -343,7 +447,25 @@ class HelpdeskTicketRoutingEnvironment(
343
  resolved_task_id = self._state.current_task_id if task_id is None else task_id
344
  if resolved_task_id != 3:
345
  return []
346
- return list(TASK3_INVESTIGATION_TOOL_PLAN.get(ticket.ticket_id, ()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  def _used_tools_for_ticket(self, ticket_id: str) -> list[str]:
349
  return list(self._state.ticket_tool_usage.get(ticket_id, []))
@@ -362,35 +484,122 @@ class HelpdeskTicketRoutingEnvironment(
362
  if tool_name not in used:
363
  used.append(tool_name)
364
 
365
- def _investigation_hints_for_ticket(self, ticket: HelpdeskTicketRecord) -> list[str]:
366
- hints: list[str] = []
 
367
  remaining_tools = self._remaining_tools_for_ticket(ticket)
368
- if "lookup_internal_routing_note" in remaining_tools:
369
- hints.append("An internal routing note may disambiguate the correct workflow.")
370
- if "lookup_related_ticket" in remaining_tools:
371
- hints.append("A linked prior ticket can reveal important follow-up context.")
372
- if "lookup_requester_history" in remaining_tools:
373
- hints.append("Requester history may clarify severity or routing intent.")
374
- return hints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
  def _visible_description(self, ticket: HelpdeskTicketRecord) -> str:
377
- if (
378
- self._state.current_task_id == 3
379
- and self._remaining_tools_for_ticket(ticket)
380
- and ticket.ticket_id in HARD_TASK_DESCRIPTION_REDACTIONS
381
- ):
382
- return HARD_TASK_DESCRIPTION_REDACTIONS[ticket.ticket_id]
383
  return ticket.description
384
 
385
- def _submit_context_penalty(self, ticket: HelpdeskTicketRecord) -> tuple[float, list[str]]:
386
- required_tools = self._required_tools_for_ticket(ticket)
387
- if not required_tools:
388
- return 0.0, []
389
- remaining_tools = self._remaining_tools_for_ticket(ticket)
390
- if not remaining_tools:
391
- return 0.0, []
392
- penalty = PREMATURE_SUBMIT_PENALTY * (len(remaining_tools) / len(required_tools))
393
- return penalty, remaining_tools
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
  def _build_reward_components(
396
  self,
@@ -547,6 +756,7 @@ class HelpdeskTicketRoutingEnvironment(
547
  self._state.reward = investigation_reward
548
  self._state.done = False
549
  self._state.investigation_penalty_applied = self._compute_episode_penalty()
 
550
  reward_components = self._build_reward_components(
551
  ticket_score=0.0,
552
  field_breakdown={},
@@ -556,8 +766,10 @@ class HelpdeskTicketRoutingEnvironment(
556
  investigation_penalty=self._state.investigation_penalty_applied,
557
  extra_details={
558
  "new_context_revealed": useful_investigation,
559
- "required_tools": required_tools,
560
- "remaining_required_tools": self._remaining_tools_for_ticket(current_ticket),
 
 
561
  "tool_name": action.tool_name,
562
  },
563
  )
@@ -578,21 +790,22 @@ class HelpdeskTicketRoutingEnvironment(
578
  return self._build_observation(task, done=False, reward=investigation_reward)
579
 
580
  def _build_ticket_view(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
581
- required_tools = self._required_tools_for_ticket(ticket)
582
- revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
583
- remaining_tools = self._remaining_tools_for_ticket(ticket)
584
  ticket_view: dict[str, Any] = {
585
  "ticket_id": ticket.ticket_id,
586
  "title": ticket.title,
587
  "requester": ticket.requester,
588
  "description": self._visible_description(ticket),
589
  }
590
- if required_tools:
591
  ticket_view["context_status"] = {
592
  "investigation_required": True,
593
- "revealed_tools": revealed_tools,
594
- "remaining_tools": remaining_tools,
595
- "hints": self._investigation_hints_for_ticket(ticket),
 
 
596
  }
597
  if ticket.ambiguity_note is not None and "lookup_internal_routing_note" not in remaining_tools:
598
  ticket_view["ambiguity_note"] = ticket.ambiguity_note
@@ -646,9 +859,19 @@ class HelpdeskTicketRoutingEnvironment(
646
  context_gap_penalty = reward_components.get("context_gap_penalty")
647
  if context_gap_penalty:
648
  parts.append(f"context_gap_penalty={context_gap_penalty:.2f}")
649
- remaining_required_tools = reward_components.get("remaining_required_tools") or []
650
- if remaining_required_tools:
651
- parts.append(f"missing_context={remaining_required_tools}")
 
 
 
 
 
 
 
 
 
 
652
 
653
  return "; ".join(parts)
654
 
@@ -667,8 +890,8 @@ class HelpdeskTicketRoutingEnvironment(
667
  tool_result: dict[str, Any] | None = None,
668
  reward_components: dict[str, Any] | None = None,
669
  ) -> dict[str, Any]:
670
- remaining_tools = self._remaining_tools_for_ticket(ticket)
671
- revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
672
  history_entry: dict[str, Any] = {
673
  "ticket_id": ticket.ticket_id,
674
  "title": ticket.title,
@@ -702,8 +925,13 @@ class HelpdeskTicketRoutingEnvironment(
702
  history_entry["tool_result"] = tool_result
703
  if reward_components is not None:
704
  history_entry["reward_components"] = reward_components
705
- if revealed_tools:
706
- history_entry["revealed_tools"] = revealed_tools
 
 
 
 
 
707
  history_entry["feedback_summary"] = self._build_feedback_summary(
708
  predicted=predicted,
709
  score=score,
@@ -751,6 +979,10 @@ class HelpdeskTicketRoutingEnvironment(
751
  "has_related_ticket_context": bool(
752
  ticket_view and ticket_view.get("related_ticket_preview")
753
  ),
 
 
 
 
754
  "action_mode": "investigate_or_submit",
755
  "available_action_types": list(AVAILABLE_ACTION_TYPES),
756
  "average_score_so_far": self._state.average_score_so_far,
 
13
  HelpdeskTicketState,
14
  )
15
  from server.grader import grade_action
16
+ from server.reward import (
17
+ compute_step_adjustments,
18
+ compute_trajectory_adjustments,
19
+ )
20
  from server.tasks import get_task_definition, load_dataset
21
+ from vocabulary import (
22
+ ISSUE_TYPE_TO_ASSIGNMENT_GROUP,
23
+ ISSUE_TYPE_TO_RESOLUTION_ACTION,
24
+ )
25
 
26
 
27
  QUEUE_SIZE_RANGE = (3, 5)
 
36
  MAX_EXTRA_INVESTIGATION_PENALTY = 0.15
37
  USEFUL_INVESTIGATION_REWARD = 0.08
38
  PREMATURE_SUBMIT_PENALTY = 0.10
39
+ CONTEXT_COMPLETION_BONUS = 0.04
40
+ TRAJECTORY_CONTEXT_COMPLETION_BONUS = 0.03
41
+ PRIORITY_UNDERSHOOT_PENALTY = 0.03
42
+ SEVERE_PRIORITY_UNDERSHOOT_PENALTY = 0.07
43
+ DANGEROUS_RESOLUTION_PENALTY = 0.05
44
+ NONDEFAULT_ROUTING_FOLLOWTHROUGH_BONUS = 0.02
45
 
46
  TASK3_INVESTIGATION_TOOL_PLAN: dict[str, tuple[str, ...]] = {
47
  "ticket-021": ("lookup_related_ticket", "lookup_requester_history"),
 
203
  is_done = self._state.current_ticket_index >= len(self._queue)
204
  self._state.done = is_done
205
  trajectory_reward = None
206
+ trajectory_components = None
207
  investigation_penalty = self._compute_episode_penalty() if is_done else 0.0
208
  if is_done:
209
+ trajectory_components = compute_trajectory_adjustments(
210
+ self._state.per_ticket_scores,
211
+ len(self._queue),
212
+ self._state.step_count,
213
+ completion_bonus=self._trajectory_consistency_bonus(),
214
  )
215
+ trajectory_reward = trajectory_components["final_reward"]
216
  final_reward = self._apply_episode_economics(trajectory_reward)
217
  self._state.total_reward = final_reward
218
  else:
 
226
  trajectory_reward=trajectory_reward,
227
  investigation_penalty=investigation_penalty,
228
  penalty_reason=f"extra_fields: {sorted(extra_fields)}",
229
+ extra_details={
230
+ "trajectory_average_reward": (
231
+ trajectory_components["average_reward"]
232
+ if trajectory_components is not None
233
+ else None
234
+ ),
235
+ "trajectory_completion_bonus": (
236
+ trajectory_components["completion_bonus"]
237
+ if trajectory_components is not None
238
+ else None
239
+ ),
240
+ "trajectory_consistency_bonus": (
241
+ trajectory_components["consistency_bonus"]
242
+ if trajectory_components is not None
243
+ else None
244
+ ),
245
+ },
246
  )
247
  self._state.history_entries.append(
248
  self._build_history_entry(
 
270
  rubric_reward=final_reward if is_done else None,
271
  )
272
 
273
+ previous_average = self._current_average_score()
274
  score, breakdown = grade_action(action, current_ticket, task_id)
275
+ context_penalty, missing_required_count = self._submit_context_penalty(current_ticket)
276
+ process_bonus = self._context_completion_bonus(
277
+ current_ticket,
278
+ missing_required_count=missing_required_count,
279
+ score=score,
280
+ )
281
+ risk_penalty = self._operational_risk_penalty(
282
+ current_ticket,
283
+ action,
284
+ task_id=task_id,
285
+ )
286
+ step_adjustments = compute_step_adjustments(
287
+ score,
288
+ previous_average=previous_average,
289
+ process_bonus=process_bonus,
290
+ risk_penalty=risk_penalty,
291
+ )
292
+ step_reward = step_adjustments["final_reward"]
293
 
294
  is_done = (self._state.current_ticket_index + 1) >= len(self._queue)
295
  trajectory_reward = None
296
+ trajectory_components = None
297
  investigation_penalty = 0.0
298
  rubric_reward = None
299
 
 
302
  self._state.average_score_so_far = self._current_average_score()
303
  self._state.step_count += 1
304
  self._state.current_ticket_index += 1
305
+ trajectory_components = compute_trajectory_adjustments(
306
  self._state.per_ticket_scores,
307
  len(self._queue),
308
  self._state.step_count,
309
+ completion_bonus=self._trajectory_consistency_bonus(),
310
  )
311
+ trajectory_reward = trajectory_components["final_reward"]
312
  rubric_reward = self._apply_episode_economics(trajectory_reward)
313
  final_reward = max(0.0, min(1.0, rubric_reward - context_penalty))
314
  self._state.total_reward = rubric_reward
 
326
  shaped_step_reward=step_reward,
327
  reward_kind="trajectory" if is_done else "step",
328
  final_reward=final_reward,
329
+ milestone_adjustment=step_adjustments["milestone_adjustment"],
330
  trajectory_reward=trajectory_reward,
331
  investigation_penalty=investigation_penalty,
332
  extra_details={
333
  "context_gap_penalty": context_penalty,
334
+ "context_completion_bonus": process_bonus,
335
+ "risk_penalty": risk_penalty,
336
+ "delta_adjustment": step_adjustments["delta_adjustment"],
337
+ "required_investigation_count": len(self._required_tools_for_ticket(current_ticket)),
338
+ "hidden_context_remaining_count": missing_required_count,
339
+ "hidden_context_revealed_count": len(
340
+ self._used_tools_for_ticket(current_ticket.ticket_id)
341
+ ),
342
  "rubric_reward": rubric_reward,
343
+ "trajectory_average_reward": (
344
+ trajectory_components["average_reward"]
345
+ if trajectory_components is not None
346
+ else None
347
+ ),
348
+ "trajectory_completion_bonus": (
349
+ trajectory_components["completion_bonus"]
350
+ if trajectory_components is not None
351
+ else None
352
+ ),
353
+ "trajectory_consistency_bonus": (
354
+ trajectory_components["consistency_bonus"]
355
+ if trajectory_components is not None
356
+ else None
357
+ ),
358
  },
359
  )
360
 
 
410
  return 0.0
411
  return sum(self._state.per_ticket_scores) / len(self._state.per_ticket_scores)
412
 
413
+ def _ticket_has_nondefault_routing(self, ticket: HelpdeskTicketRecord) -> bool:
414
+ return (
415
+ ticket.assignment_group
416
+ != ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(ticket.issue_type, ticket.assignment_group)
417
+ or ticket.resolution_action
418
+ != ISSUE_TYPE_TO_RESOLUTION_ACTION.get(
419
+ ticket.issue_type, ticket.resolution_action
420
+ )
421
+ )
422
+
423
+ def _ticket_mentions_follow_up(self, ticket: HelpdeskTicketRecord) -> bool:
424
+ text = f"{ticket.title} {ticket.description}".lower()
425
+ return any(
426
+ phrase in text
427
+ for phrase in (
428
+ "re:",
429
+ "follow-up",
430
+ "following up",
431
+ "still",
432
+ "third update",
433
+ "reference ticket",
434
+ "regression",
435
+ "unresolved",
436
+ )
437
+ )
438
+
439
+ def _ticket_repeated_requester_count(self, ticket: HelpdeskTicketRecord) -> int:
440
+ return sum(1 for candidate in self._dataset if candidate.requester == ticket.requester)
441
+
442
  def _required_tools_for_ticket(
443
  self,
444
  ticket: HelpdeskTicketRecord,
 
447
  resolved_task_id = self._state.current_task_id if task_id is None else task_id
448
  if resolved_task_id != 3:
449
  return []
450
+ required_tools: list[str] = list(TASK3_INVESTIGATION_TOOL_PLAN.get(ticket.ticket_id, ()))
451
+ if ticket.related_ticket_id is not None and "lookup_related_ticket" not in required_tools:
452
+ required_tools.append("lookup_related_ticket")
453
+ if (
454
+ ticket.ambiguity_note is not None or self._ticket_has_nondefault_routing(ticket)
455
+ ) and "lookup_internal_routing_note" not in required_tools:
456
+ required_tools.append("lookup_internal_routing_note")
457
+ if (
458
+ self._ticket_repeated_requester_count(ticket) >= 2
459
+ and (
460
+ ticket.related_ticket_id is not None
461
+ or self._ticket_mentions_follow_up(ticket)
462
+ or self._ticket_has_nondefault_routing(ticket)
463
+ or ticket.priority in {"high", "critical"}
464
+ )
465
+ and "lookup_requester_history" not in required_tools
466
+ ):
467
+ required_tools.append("lookup_requester_history")
468
+ return required_tools
469
 
470
  def _used_tools_for_ticket(self, ticket_id: str) -> list[str]:
471
  return list(self._state.ticket_tool_usage.get(ticket_id, []))
 
484
  if tool_name not in used:
485
  used.append(tool_name)
486
 
487
+ def _tool_progress_for_ticket(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
488
+ required_tools = self._required_tools_for_ticket(ticket)
489
+ revealed_tools = self._used_tools_for_ticket(ticket.ticket_id)
490
  remaining_tools = self._remaining_tools_for_ticket(ticket)
491
+ total_required = max(1, len(required_tools))
492
+ return {
493
+ "required_tools": required_tools,
494
+ "revealed_tools": revealed_tools,
495
+ "remaining_tools": remaining_tools,
496
+ "revealed_count": len(revealed_tools),
497
+ "remaining_count": len(remaining_tools),
498
+ "completeness": round(len(revealed_tools) / total_required, 2),
499
+ }
500
+
501
+ def _default_redacted_description(self, ticket: HelpdeskTicketRecord) -> str:
502
+ if ticket.related_ticket_id is not None:
503
+ return (
504
+ "This is a follow-up operational issue that references prior work. "
505
+ "Additional routing context is available via investigation."
506
+ )
507
+ if ticket.ambiguity_note is not None:
508
+ return (
509
+ "This ticket mixes multiple plausible workflows. "
510
+ "Additional routing context is available via investigation."
511
+ )
512
+ if self._ticket_has_nondefault_routing(ticket):
513
+ return (
514
+ "The visible request looks straightforward, but the decisive routing "
515
+ "detail is hidden until investigation."
516
+ )
517
+ return (
518
+ "Additional routing context is available via investigation before final submission."
519
+ )
520
 
521
  def _visible_description(self, ticket: HelpdeskTicketRecord) -> str:
522
+ if self._state.current_task_id == 3 and self._remaining_tools_for_ticket(ticket):
523
+ return HARD_TASK_DESCRIPTION_REDACTIONS.get(
524
+ ticket.ticket_id,
525
+ self._default_redacted_description(ticket),
526
+ )
 
527
  return ticket.description
528
 
529
+ def _submit_context_penalty(self, ticket: HelpdeskTicketRecord) -> tuple[float, int]:
530
+ progress = self._tool_progress_for_ticket(ticket)
531
+ required_tools = progress["required_tools"]
532
+ remaining_tools = progress["remaining_tools"]
533
+ if not required_tools or not remaining_tools:
534
+ return 0.0, 0
535
+ penalty = PREMATURE_SUBMIT_PENALTY * (
536
+ len(remaining_tools) / max(1, len(required_tools))
537
+ )
538
+ return penalty, len(remaining_tools)
539
+
540
+ def _context_completion_bonus(
541
+ self,
542
+ ticket: HelpdeskTicketRecord,
543
+ *,
544
+ missing_required_count: int,
545
+ score: float,
546
+ ) -> float:
547
+ if not self._required_tools_for_ticket(ticket):
548
+ return 0.0
549
+ if missing_required_count != 0 or score < 0.75:
550
+ return 0.0
551
+ bonus = CONTEXT_COMPLETION_BONUS
552
+ if self._ticket_has_nondefault_routing(ticket):
553
+ bonus += NONDEFAULT_ROUTING_FOLLOWTHROUGH_BONUS
554
+ return bonus
555
+
556
+ def _trajectory_consistency_bonus(self) -> float:
557
+ if not self._queue:
558
+ return 0.0
559
+ hidden_context_tickets = [
560
+ ticket for ticket in self._queue if self._required_tools_for_ticket(ticket)
561
+ ]
562
+ if not hidden_context_tickets:
563
+ return 0.0
564
+ resolved = sum(
565
+ 1 for ticket in hidden_context_tickets if not self._remaining_tools_for_ticket(ticket)
566
+ )
567
+ resolution_rate = resolved / len(hidden_context_tickets)
568
+ return round(TRAJECTORY_CONTEXT_COMPLETION_BONUS * resolution_rate, 4)
569
+
570
+ def _operational_risk_penalty(
571
+ self,
572
+ ticket: HelpdeskTicketRecord,
573
+ action: HelpdeskTicketAction,
574
+ *,
575
+ task_id: int,
576
+ ) -> float:
577
+ if task_id < 2 or action.priority is None:
578
+ priority_penalty = 0.0
579
+ else:
580
+ priority_rank = {"critical": 3, "high": 2, "medium": 1, "low": 0}
581
+ expected_rank = priority_rank.get(ticket.priority, 0)
582
+ predicted_rank = priority_rank.get(action.priority, 0)
583
+ gap = expected_rank - predicted_rank
584
+ if gap >= 2:
585
+ priority_penalty = SEVERE_PRIORITY_UNDERSHOOT_PENALTY
586
+ elif gap == 1 and ticket.priority in {"high", "critical"}:
587
+ priority_penalty = PRIORITY_UNDERSHOOT_PENALTY
588
+ else:
589
+ priority_penalty = 0.0
590
+
591
+ resolution_penalty = 0.0
592
+ if task_id == 3 and action.resolution_action is not None:
593
+ if (
594
+ ticket.issue_type in {"identity_access", "application_support", "security_compliance"}
595
+ and ticket.priority in {"high", "critical"}
596
+ and action.resolution_action == "acknowledge"
597
+ ):
598
+ resolution_penalty += DANGEROUS_RESOLUTION_PENALTY
599
+ if ticket.issue_type == "spam_phishing" and action.resolution_action == "fulfill":
600
+ resolution_penalty += PRIORITY_UNDERSHOOT_PENALTY
601
+
602
+ return round(priority_penalty + resolution_penalty, 4)
603
 
604
  def _build_reward_components(
605
  self,
 
756
  self._state.reward = investigation_reward
757
  self._state.done = False
758
  self._state.investigation_penalty_applied = self._compute_episode_penalty()
759
+ progress = self._tool_progress_for_ticket(current_ticket)
760
  reward_components = self._build_reward_components(
761
  ticket_score=0.0,
762
  field_breakdown={},
 
766
  investigation_penalty=self._state.investigation_penalty_applied,
767
  extra_details={
768
  "new_context_revealed": useful_investigation,
769
+ "required_investigation_count": len(required_tools),
770
+ "hidden_context_remaining_count": progress["remaining_count"],
771
+ "hidden_context_revealed_count": progress["revealed_count"],
772
+ "context_completeness": progress["completeness"],
773
  "tool_name": action.tool_name,
774
  },
775
  )
 
790
  return self._build_observation(task, done=False, reward=investigation_reward)
791
 
792
  def _build_ticket_view(self, ticket: HelpdeskTicketRecord) -> dict[str, Any]:
793
+ progress = self._tool_progress_for_ticket(ticket)
794
+ remaining_tools = progress["remaining_tools"]
 
795
  ticket_view: dict[str, Any] = {
796
  "ticket_id": ticket.ticket_id,
797
  "title": ticket.title,
798
  "requester": ticket.requester,
799
  "description": self._visible_description(ticket),
800
  }
801
+ if progress["required_tools"]:
802
  ticket_view["context_status"] = {
803
  "investigation_required": True,
804
+ "hidden_context_remaining": bool(progress["remaining_count"]),
805
+ "context_gap_count": progress["remaining_count"],
806
+ "revealed_context_count": progress["revealed_count"],
807
+ "context_completeness": progress["completeness"],
808
+ "investigations_used_for_ticket": progress["revealed_count"],
809
  }
810
  if ticket.ambiguity_note is not None and "lookup_internal_routing_note" not in remaining_tools:
811
  ticket_view["ambiguity_note"] = ticket.ambiguity_note
 
859
  context_gap_penalty = reward_components.get("context_gap_penalty")
860
  if context_gap_penalty:
861
  parts.append(f"context_gap_penalty={context_gap_penalty:.2f}")
862
+ hidden_context_remaining_count = reward_components.get(
863
+ "hidden_context_remaining_count"
864
+ )
865
+ if hidden_context_remaining_count:
866
+ parts.append(
867
+ f"hidden_context_remaining={hidden_context_remaining_count}"
868
+ )
869
+ context_completion_bonus = reward_components.get("context_completion_bonus")
870
+ if context_completion_bonus:
871
+ parts.append(f"context_bonus={context_completion_bonus:.2f}")
872
+ risk_penalty = reward_components.get("risk_penalty")
873
+ if risk_penalty:
874
+ parts.append(f"risk_penalty={risk_penalty:.2f}")
875
 
876
  return "; ".join(parts)
877
 
 
890
  tool_result: dict[str, Any] | None = None,
891
  reward_components: dict[str, Any] | None = None,
892
  ) -> dict[str, Any]:
893
+ progress = self._tool_progress_for_ticket(ticket)
894
+ remaining_tools = progress["remaining_tools"]
895
  history_entry: dict[str, Any] = {
896
  "ticket_id": ticket.ticket_id,
897
  "title": ticket.title,
 
925
  history_entry["tool_result"] = tool_result
926
  if reward_components is not None:
927
  history_entry["reward_components"] = reward_components
928
+ if progress["required_tools"]:
929
+ history_entry["context_progress"] = {
930
+ "hidden_context_remaining": bool(progress["remaining_count"]),
931
+ "context_gap_count": progress["remaining_count"],
932
+ "revealed_context_count": progress["revealed_count"],
933
+ "context_completeness": progress["completeness"],
934
+ }
935
  history_entry["feedback_summary"] = self._build_feedback_summary(
936
  predicted=predicted,
937
  score=score,
 
979
  "has_related_ticket_context": bool(
980
  ticket_view and ticket_view.get("related_ticket_preview")
981
  ),
982
+ "has_hidden_context": bool(
983
+ ticket_view
984
+ and (ticket_view.get("context_status") or {}).get("hidden_context_remaining")
985
+ ),
986
  "action_mode": "investigate_or_submit",
987
  "available_action_types": list(AVAILABLE_ACTION_TYPES),
988
  "average_score_so_far": self._state.average_score_so_far,
server/reward.py CHANGED
@@ -4,21 +4,113 @@ MILESTONE_HIGH_THRESHOLD = 0.8
4
  MILESTONE_LOW_THRESHOLD = 0.2
5
  MILESTONE_BONUS = 0.05
6
  MILESTONE_PENALTY = 0.05
 
 
 
 
7
 
8
 
9
- def compute_step_reward(score: float) -> float:
10
- base = max(0.0, min(1.0, score))
 
 
 
 
 
 
 
 
 
 
 
11
  if score >= MILESTONE_HIGH_THRESHOLD:
12
- return min(1.0, base + MILESTONE_BONUS)
13
- if score < MILESTONE_LOW_THRESHOLD:
14
- return max(0.0, base - MILESTONE_PENALTY)
15
- return base
 
16
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def compute_trajectory_reward(
19
- per_ticket_scores: list[float], queue_size: int, steps_taken: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  ) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  if not per_ticket_scores:
22
- return 0.0
 
 
 
 
 
23
  avg = sum(per_ticket_scores) / len(per_ticket_scores)
24
- return max(0.0, min(1.0, avg))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  MILESTONE_LOW_THRESHOLD = 0.2
5
  MILESTONE_BONUS = 0.05
6
  MILESTONE_PENALTY = 0.05
7
+ DELTA_REWARD_WEIGHT = 0.08
8
+ DELTA_REWARD_CAP = 0.04
9
+ PROCESS_BONUS_CAP = 0.08
10
+ RISK_PENALTY_CAP = 0.12
11
 
12
 
13
+ def _clamp_unit_interval(value: float) -> float:
14
+ return max(0.0, min(1.0, value))
15
+
16
+
17
+ def compute_step_adjustments(
18
+ score: float,
19
+ *,
20
+ previous_average: float = 0.0,
21
+ process_bonus: float = 0.0,
22
+ risk_penalty: float = 0.0,
23
+ ) -> dict[str, float]:
24
+ base = _clamp_unit_interval(score)
25
+
26
  if score >= MILESTONE_HIGH_THRESHOLD:
27
+ milestone_adjustment = MILESTONE_BONUS
28
+ elif score < MILESTONE_LOW_THRESHOLD:
29
+ milestone_adjustment = -MILESTONE_PENALTY
30
+ else:
31
+ milestone_adjustment = 0.0
32
 
33
+ delta_adjustment = _clamp_delta((base - previous_average) * DELTA_REWARD_WEIGHT)
34
+ bounded_process_bonus = max(0.0, min(PROCESS_BONUS_CAP, process_bonus))
35
+ bounded_risk_penalty = max(0.0, min(RISK_PENALTY_CAP, risk_penalty))
36
+ final_reward = _clamp_unit_interval(
37
+ base
38
+ + milestone_adjustment
39
+ + delta_adjustment
40
+ + bounded_process_bonus
41
+ - bounded_risk_penalty
42
+ )
43
 
44
+ return {
45
+ "base_reward": base,
46
+ "milestone_adjustment": milestone_adjustment,
47
+ "delta_adjustment": delta_adjustment,
48
+ "process_bonus": bounded_process_bonus,
49
+ "risk_penalty": bounded_risk_penalty,
50
+ "final_reward": final_reward,
51
+ }
52
+
53
+
54
+ def _clamp_delta(value: float) -> float:
55
+ return max(-DELTA_REWARD_CAP, min(DELTA_REWARD_CAP, value))
56
+
57
+
58
+ def compute_step_reward(
59
+ score: float,
60
+ *,
61
+ previous_average: float = 0.0,
62
+ process_bonus: float = 0.0,
63
+ risk_penalty: float = 0.0,
64
  ) -> float:
65
+ return compute_step_adjustments(
66
+ score,
67
+ previous_average=previous_average,
68
+ process_bonus=process_bonus,
69
+ risk_penalty=risk_penalty,
70
+ )["final_reward"]
71
+
72
+
73
+ def compute_trajectory_adjustments(
74
+ per_ticket_scores: list[float],
75
+ queue_size: int,
76
+ steps_taken: int,
77
+ *,
78
+ completion_bonus: float = 0.0,
79
+ consistency_bonus: float = 0.0,
80
+ ) -> dict[str, float]:
81
  if not per_ticket_scores:
82
+ return {
83
+ "average_reward": 0.0,
84
+ "completion_bonus": 0.0,
85
+ "consistency_bonus": 0.0,
86
+ "final_reward": 0.0,
87
+ }
88
  avg = sum(per_ticket_scores) / len(per_ticket_scores)
89
+ bounded_completion_bonus = max(0.0, min(0.08, completion_bonus))
90
+ bounded_consistency_bonus = max(0.0, min(0.05, consistency_bonus))
91
+ final_reward = _clamp_unit_interval(
92
+ avg + bounded_completion_bonus + bounded_consistency_bonus
93
+ )
94
+ return {
95
+ "average_reward": avg,
96
+ "completion_bonus": bounded_completion_bonus,
97
+ "consistency_bonus": bounded_consistency_bonus,
98
+ "final_reward": final_reward,
99
+ }
100
+
101
+
102
+ def compute_trajectory_reward(
103
+ per_ticket_scores: list[float],
104
+ queue_size: int,
105
+ steps_taken: int,
106
+ *,
107
+ completion_bonus: float = 0.0,
108
+ consistency_bonus: float = 0.0,
109
+ ) -> float:
110
+ return compute_trajectory_adjustments(
111
+ per_ticket_scores,
112
+ queue_size,
113
+ steps_taken,
114
+ completion_bonus=completion_bonus,
115
+ consistency_bonus=consistency_bonus,
116
+ )["final_reward"]
tests/test_competitive_upgrade.py CHANGED
@@ -245,27 +245,27 @@ class TestMilestoneRewardShaping(unittest.TestCase):
245
 
246
  def test_high_score_gets_bonus(self) -> None:
247
  # score=0.9 >= 0.8 threshold → base=0.9, bonus=0.05 → 0.95
248
- result = compute_step_reward(0.9)
249
  self.assertAlmostEqual(result, 0.95, places=9)
250
 
251
  def test_low_score_gets_penalty(self) -> None:
252
  # score=0.1 < 0.2 threshold → base=0.1, penalty=0.05 → 0.05
253
- result = compute_step_reward(0.1)
254
  self.assertAlmostEqual(result, 0.05, places=9)
255
 
256
  def test_mid_score_is_neutral(self) -> None:
257
  # score=0.5 is in [0.2, 0.8) → no shaping → 0.5
258
- result = compute_step_reward(0.5)
259
  self.assertAlmostEqual(result, 0.5, places=9)
260
 
261
  def test_boundary_high_threshold_gets_bonus(self) -> None:
262
  # score=0.8 exactly → bonus applies → 0.85
263
- result = compute_step_reward(0.8)
264
  self.assertAlmostEqual(result, 0.85, places=9)
265
 
266
  def test_boundary_low_threshold_is_neutral(self) -> None:
267
  # score=0.2 exactly → not < 0.2, so neutral → 0.2
268
- result = compute_step_reward(0.2)
269
  self.assertAlmostEqual(result, 0.2, places=9)
270
 
271
  def test_reward_clamped_to_unit_interval(self) -> None:
@@ -274,6 +274,11 @@ class TestMilestoneRewardShaping(unittest.TestCase):
274
  self.assertLessEqual(result, 1.0)
275
  self.assertGreaterEqual(result, 0.0)
276
 
 
 
 
 
 
277
  def test_zero_score_clamped_to_zero(self) -> None:
278
  # score=0.0 < 0.2 → base=0.0, penalty → max(0.0, -0.05) = 0.0
279
  result = compute_step_reward(0.0)
@@ -348,10 +353,8 @@ class TestAmbiguityNoteInObservation(unittest.TestCase):
348
  self.assertIsNotNone(obs.current_ticket)
349
  self.assertNotIn("ambiguity_note", obs.current_ticket)
350
  self.assertIn("context_status", obs.current_ticket)
351
- self.assertIn(
352
- "lookup_internal_routing_note",
353
- obs.current_ticket["context_status"]["remaining_tools"],
354
- )
355
 
356
  obs = env.step(
357
  HelpdeskTicketAction(
@@ -436,10 +439,8 @@ class TestRelatedTicketPreviewInObservation(unittest.TestCase):
436
  self.assertIsNotNone(obs.current_ticket)
437
  self.assertNotIn("related_ticket_preview", obs.current_ticket)
438
  self.assertIn("context_status", obs.current_ticket)
439
- self.assertIn(
440
- "lookup_related_ticket",
441
- obs.current_ticket["context_status"]["remaining_tools"],
442
- )
443
 
444
  obs = env.step(
445
  HelpdeskTicketAction(
@@ -766,8 +767,8 @@ class TestDatasetNonDefaultRouting(unittest.TestCase):
766
  if t.assignment_group != ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(t.issue_type)
767
  ]
768
  self.assertGreaterEqual(
769
- len(non_default), 3,
770
- f"Expected >= 3 non-default routing tickets, found {len(non_default)}: "
771
  + str([(t.ticket_id, t.issue_type, t.assignment_group) for t in non_default])
772
  )
773
 
 
245
 
246
  def test_high_score_gets_bonus(self) -> None:
247
  # score=0.9 >= 0.8 threshold → base=0.9, bonus=0.05 → 0.95
248
+ result = compute_step_reward(0.9, previous_average=0.9)
249
  self.assertAlmostEqual(result, 0.95, places=9)
250
 
251
  def test_low_score_gets_penalty(self) -> None:
252
  # score=0.1 < 0.2 threshold → base=0.1, penalty=0.05 → 0.05
253
+ result = compute_step_reward(0.1, previous_average=0.1)
254
  self.assertAlmostEqual(result, 0.05, places=9)
255
 
256
  def test_mid_score_is_neutral(self) -> None:
257
  # score=0.5 is in [0.2, 0.8) → no shaping → 0.5
258
+ result = compute_step_reward(0.5, previous_average=0.5)
259
  self.assertAlmostEqual(result, 0.5, places=9)
260
 
261
  def test_boundary_high_threshold_gets_bonus(self) -> None:
262
  # score=0.8 exactly → bonus applies → 0.85
263
+ result = compute_step_reward(0.8, previous_average=0.8)
264
  self.assertAlmostEqual(result, 0.85, places=9)
265
 
266
  def test_boundary_low_threshold_is_neutral(self) -> None:
267
  # score=0.2 exactly → not < 0.2, so neutral → 0.2
268
+ result = compute_step_reward(0.2, previous_average=0.2)
269
  self.assertAlmostEqual(result, 0.2, places=9)
270
 
271
  def test_reward_clamped_to_unit_interval(self) -> None:
 
274
  self.assertLessEqual(result, 1.0)
275
  self.assertGreaterEqual(result, 0.0)
276
 
277
+ def test_improvement_delta_adds_small_bonus(self) -> None:
278
+ improved = compute_step_reward(0.7, previous_average=0.2)
279
+ flat = compute_step_reward(0.7, previous_average=0.7)
280
+ self.assertGreater(improved, flat)
281
+
282
  def test_zero_score_clamped_to_zero(self) -> None:
283
  # score=0.0 < 0.2 → base=0.0, penalty → max(0.0, -0.05) = 0.0
284
  result = compute_step_reward(0.0)
 
353
  self.assertIsNotNone(obs.current_ticket)
354
  self.assertNotIn("ambiguity_note", obs.current_ticket)
355
  self.assertIn("context_status", obs.current_ticket)
356
+ self.assertTrue(obs.current_ticket["context_status"]["hidden_context_remaining"])
357
+ self.assertGreater(obs.current_ticket["context_status"]["context_gap_count"], 0)
 
 
358
 
359
  obs = env.step(
360
  HelpdeskTicketAction(
 
439
  self.assertIsNotNone(obs.current_ticket)
440
  self.assertNotIn("related_ticket_preview", obs.current_ticket)
441
  self.assertIn("context_status", obs.current_ticket)
442
+ self.assertTrue(obs.current_ticket["context_status"]["hidden_context_remaining"])
443
+ self.assertGreater(obs.current_ticket["context_status"]["context_gap_count"], 0)
 
 
444
 
445
  obs = env.step(
446
  HelpdeskTicketAction(
 
767
  if t.assignment_group != ISSUE_TYPE_TO_ASSIGNMENT_GROUP.get(t.issue_type)
768
  ]
769
  self.assertGreaterEqual(
770
+ len(non_default), 10,
771
+ f"Expected >= 10 non-default routing tickets, found {len(non_default)}: "
772
  + str([(t.ticket_id, t.issue_type, t.assignment_group) for t in non_default])
773
  )
774
 
tests/test_inference_unit.py CHANGED
@@ -129,7 +129,7 @@ class FakeEnvClient:
129
 
130
 
131
  class InferenceUnitTests(unittest.TestCase):
132
- def test_hf_token_has_no_default_and_model_name_keeps_allowed_default(self) -> None:
133
  inference = _load_inference_module()
134
 
135
  self.assertEqual(
@@ -137,9 +137,22 @@ class InferenceUnitTests(unittest.TestCase):
137
  "https://router.huggingface.co/v1",
138
  )
139
  self.assertEqual(inference.MODEL_NAME, "<your-active-model>")
 
140
  self.assertIsNone(inference.HF_TOKEN)
141
  self.assertFalse(inference.llm_mode_enabled())
142
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def test_seed_env_override_is_respected(self) -> None:
144
  inference = _load_inference_module({"SEED": "7"})
145
 
@@ -199,9 +212,11 @@ class InferenceUnitTests(unittest.TestCase):
199
  "description": "Access permissions are blocking contractor setup.",
200
  "context_status": {
201
  "investigation_required": True,
202
- "revealed_tools": [],
203
- "remaining_tools": ["lookup_internal_routing_note"],
204
- "hints": ["An internal routing note may disambiguate the correct workflow."],
 
 
205
  },
206
  "last_tool_result": {"tool_name": "lookup_requester_history", "found": False},
207
  "feedback_summary": "Ticket score=0.40; field_scores[issue_type=0.40]; reward=0.40",
@@ -475,24 +490,24 @@ class InferenceUnitTests(unittest.TestCase):
475
  self.assertEqual(merged["tickets_remaining"], 4)
476
  self.assertEqual(merged["last_tool_result"]["tool_name"], "lookup_requester_history")
477
 
478
- def test_should_investigate_uses_remaining_tools_from_context_status(self) -> None:
479
  inference = _load_inference_module()
480
 
481
  investigate, tool_name = inference.should_investigate(
482
  {
483
- "ticket_id": "ticket-021",
 
 
484
  "context_status": {
485
- "remaining_tools": [
486
- "lookup_related_ticket",
487
- "lookup_requester_history",
488
- ]
489
- },
490
  },
491
  [],
492
  )
493
 
494
  self.assertTrue(investigate)
495
- self.assertEqual(tool_name, "lookup_related_ticket")
496
 
497
 
498
  if __name__ == "__main__":
 
129
 
130
 
131
  class InferenceUnitTests(unittest.TestCase):
132
+ def test_api_credentials_have_no_defaults_and_model_name_keeps_allowed_default(self) -> None:
133
  inference = _load_inference_module()
134
 
135
  self.assertEqual(
 
137
  "https://router.huggingface.co/v1",
138
  )
139
  self.assertEqual(inference.MODEL_NAME, "<your-active-model>")
140
+ self.assertIsNone(inference.API_KEY)
141
  self.assertIsNone(inference.HF_TOKEN)
142
  self.assertFalse(inference.llm_mode_enabled())
143
 
144
+ def test_api_key_enables_llm_mode_without_hf_token(self) -> None:
145
+ inference = _load_inference_module(
146
+ {
147
+ "API_KEY": "validator-proxy-key",
148
+ "MODEL_NAME": "meta/test-model",
149
+ }
150
+ )
151
+
152
+ self.assertEqual(inference.API_KEY, "validator-proxy-key")
153
+ self.assertIsNone(inference.HF_TOKEN)
154
+ self.assertTrue(inference.llm_mode_enabled())
155
+
156
  def test_seed_env_override_is_respected(self) -> None:
157
  inference = _load_inference_module({"SEED": "7"})
158
 
 
212
  "description": "Access permissions are blocking contractor setup.",
213
  "context_status": {
214
  "investigation_required": True,
215
+ "hidden_context_remaining": True,
216
+ "context_gap_count": 1,
217
+ "revealed_context_count": 0,
218
+ "context_completeness": 0.0,
219
+ "investigations_used_for_ticket": 0,
220
  },
221
  "last_tool_result": {"tool_name": "lookup_requester_history", "found": False},
222
  "feedback_summary": "Ticket score=0.40; field_scores[issue_type=0.40]; reward=0.40",
 
490
  self.assertEqual(merged["tickets_remaining"], 4)
491
  self.assertEqual(merged["last_tool_result"]["tool_name"], "lookup_requester_history")
492
 
493
+ def test_should_investigate_uses_hidden_context_and_ticket_cues(self) -> None:
494
  inference = _load_inference_module()
495
 
496
  investigate, tool_name = inference.should_investigate(
497
  {
498
+ "ticket_id": "TKT-NONDEFAULT-003",
499
+ "title": "Contractor onboarding blocked by access issue",
500
+ "description": "Additional routing context is available via investigation.",
501
  "context_status": {
502
+ "hidden_context_remaining": True,
503
+ "context_gap_count": 1,
504
+ }
 
 
505
  },
506
  [],
507
  )
508
 
509
  self.assertTrue(investigate)
510
+ self.assertEqual(tool_name, "lookup_internal_routing_note")
511
 
512
 
513
  if __name__ == "__main__":
tests/test_policy_learning.py CHANGED
@@ -32,6 +32,7 @@ from policy_learning import (
32
  POLICY_LIBRARY,
33
  choose_policy_action,
34
  compare_policies,
 
35
  parse_int_spec,
36
  rollout_episode,
37
  search_policies,
@@ -99,35 +100,55 @@ class PolicyLearningTests(unittest.TestCase):
99
  observation = HelpdeskTicketObservation(
100
  current_ticket={
101
  "ticket_id": "ticket-021",
 
 
102
  "context_status": {
103
- "remaining_tools": ["lookup_related_ticket", "lookup_requester_history"],
104
- "revealed_tools": [],
 
 
105
  }
106
  },
107
  allowed_fields=["issue_type"],
108
  )
109
 
110
- action, source = choose_policy_action(policy, observation, {}, _context_sensitive_submit_builder)
 
 
 
 
 
 
111
 
112
  self.assertEqual(action.action_type, "investigate")
113
  self.assertEqual(action.tool_name, "lookup_related_ticket")
114
  self.assertEqual(source, "investigate_hidden_context")
 
115
 
116
  def test_choose_policy_action_submits_when_investigation_disabled(self) -> None:
117
  policy = POLICY_LIBRARY["no_investigation"]
118
  observation = HelpdeskTicketObservation(
119
  current_ticket={
120
  "ticket_id": "ticket-021",
121
- "context_status": {"remaining_tools": ["lookup_related_ticket"]},
 
 
122
  },
123
  allowed_fields=["issue_type", "priority"],
124
  )
125
 
126
- action, source = choose_policy_action(policy, observation, {}, _context_sensitive_submit_builder)
 
 
 
 
 
 
127
 
128
  self.assertEqual(action.action_type, "submit")
129
  self.assertEqual(action.issue_type, "identity_access")
130
  self.assertEqual(source, "submit")
 
131
 
132
  def test_rollout_episode_rewards_context_aware_policy(self) -> None:
133
  no_investigation = POLICY_LIBRARY["no_investigation"]
@@ -152,11 +173,11 @@ class PolicyLearningTests(unittest.TestCase):
152
  self.assertLess(no_summary["normalized_return"], context_summary["normalized_return"])
153
  self.assertEqual(context_summary["investigation_steps"], 1)
154
 
155
- def test_search_policies_selects_better_policy(self) -> None:
156
  report = search_policies(
157
  [
158
  POLICY_LIBRARY["no_investigation"],
159
- POLICY_LIBRARY["investigate_when_context_hidden"],
160
  ],
161
  train_seeds=[41, 42],
162
  eval_seeds=[43],
@@ -166,17 +187,18 @@ class PolicyLearningTests(unittest.TestCase):
166
  submit_builder=_context_sensitive_submit_builder,
167
  )
168
 
169
- self.assertEqual(report["selected_policy"], "investigate_when_context_hidden")
170
  self.assertGreater(
171
  report["eval_improvement_vs_baseline"]["avg_normalized_return"],
172
  0.0,
173
  )
 
174
 
175
  def test_compare_policies_reports_improvement(self) -> None:
176
  report = compare_policies(
177
  [
178
  POLICY_LIBRARY["no_investigation"],
179
- POLICY_LIBRARY["investigate_when_context_hidden"],
180
  ],
181
  seeds=[42],
182
  task_ids=[3],
@@ -185,9 +207,18 @@ class PolicyLearningTests(unittest.TestCase):
185
  submit_builder=_context_sensitive_submit_builder,
186
  )
187
 
188
- self.assertEqual(report["best_policy"], "investigate_when_context_hidden")
189
  self.assertGreater(report["improvement_vs_baseline"]["avg_terminal_reward"], 0.0)
190
 
 
 
 
 
 
 
 
 
 
191
 
192
  if __name__ == "__main__":
193
  unittest.main()
 
32
  POLICY_LIBRARY,
33
  choose_policy_action,
34
  compare_policies,
35
+ infer_ticket_cue,
36
  parse_int_spec,
37
  rollout_episode,
38
  search_policies,
 
100
  observation = HelpdeskTicketObservation(
101
  current_ticket={
102
  "ticket_id": "ticket-021",
103
+ "title": "Re: Production checkout throwing null reference exception",
104
+ "description": "Additional routing context is available via investigation.",
105
  "context_status": {
106
+ "hidden_context_remaining": True,
107
+ "context_gap_count": 2,
108
+ "revealed_context_count": 0,
109
+ "context_completeness": 0.0,
110
  }
111
  },
112
  allowed_fields=["issue_type"],
113
  )
114
 
115
+ action, source, cue = choose_policy_action(
116
+ policy,
117
+ observation,
118
+ {},
119
+ _context_sensitive_submit_builder,
120
+ used_tools_by_ticket={},
121
+ )
122
 
123
  self.assertEqual(action.action_type, "investigate")
124
  self.assertEqual(action.tool_name, "lookup_related_ticket")
125
  self.assertEqual(source, "investigate_hidden_context")
126
+ self.assertEqual(cue, "follow_up")
127
 
128
  def test_choose_policy_action_submits_when_investigation_disabled(self) -> None:
129
  policy = POLICY_LIBRARY["no_investigation"]
130
  observation = HelpdeskTicketObservation(
131
  current_ticket={
132
  "ticket_id": "ticket-021",
133
+ "title": "Re: Production checkout throwing null reference exception",
134
+ "description": "Additional routing context is available via investigation.",
135
+ "context_status": {"hidden_context_remaining": True, "context_gap_count": 1},
136
  },
137
  allowed_fields=["issue_type", "priority"],
138
  )
139
 
140
+ action, source, cue = choose_policy_action(
141
+ policy,
142
+ observation,
143
+ {},
144
+ _context_sensitive_submit_builder,
145
+ used_tools_by_ticket={},
146
+ )
147
 
148
  self.assertEqual(action.action_type, "submit")
149
  self.assertEqual(action.issue_type, "identity_access")
150
  self.assertEqual(source, "submit")
151
+ self.assertIsNone(cue)
152
 
153
  def test_rollout_episode_rewards_context_aware_policy(self) -> None:
154
  no_investigation = POLICY_LIBRARY["no_investigation"]
 
173
  self.assertLess(no_summary["normalized_return"], context_summary["normalized_return"])
174
  self.assertEqual(context_summary["investigation_steps"], 1)
175
 
176
+ def test_search_policies_selects_adaptive_policy(self) -> None:
177
  report = search_policies(
178
  [
179
  POLICY_LIBRARY["no_investigation"],
180
+ POLICY_LIBRARY["adaptive_cue_bandit"],
181
  ],
182
  train_seeds=[41, 42],
183
  eval_seeds=[43],
 
187
  submit_builder=_context_sensitive_submit_builder,
188
  )
189
 
190
+ self.assertEqual(report["selected_policy"], "adaptive_cue_bandit")
191
  self.assertGreater(
192
  report["eval_improvement_vs_baseline"]["avg_normalized_return"],
193
  0.0,
194
  )
195
+ self.assertIn("adaptive_cue_bandit", report["trained_adaptive_bandits"])
196
 
197
  def test_compare_policies_reports_improvement(self) -> None:
198
  report = compare_policies(
199
  [
200
  POLICY_LIBRARY["no_investigation"],
201
+ POLICY_LIBRARY["adaptive_cue_bandit"],
202
  ],
203
  seeds=[42],
204
  task_ids=[3],
 
207
  submit_builder=_context_sensitive_submit_builder,
208
  )
209
 
210
+ self.assertEqual(report["best_policy"], "adaptive_cue_bandit")
211
  self.assertGreater(report["improvement_vs_baseline"]["avg_terminal_reward"], 0.0)
212
 
213
+ def test_infer_ticket_cue_distinguishes_workflow_blocker(self) -> None:
214
+ cue = infer_ticket_cue(
215
+ {
216
+ "title": "Contractor onboarding blocked by access issue",
217
+ "description": "A contractor onboarding workflow is blocked by a permissions error.",
218
+ }
219
+ )
220
+ self.assertEqual(cue, "workflow_blocker")
221
+
222
 
223
  if __name__ == "__main__":
224
  unittest.main()
tests/test_real_openenv_integration.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import unittest
6
+
7
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
8
+
9
+ REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
10
+ SITE_PACKAGES = os.path.join(REPO_ROOT, ".venv", "Lib", "site-packages")
11
+ if SITE_PACKAGES not in sys.path:
12
+ sys.path.insert(0, SITE_PACKAGES)
13
+
14
+ for module_name in list(sys.modules):
15
+ if module_name == "openenv" or module_name.startswith("openenv."):
16
+ del sys.modules[module_name]
17
+ for module_name in list(sys.modules):
18
+ if module_name in {"models", "server.app", "server.environment", "client"}:
19
+ del sys.modules[module_name]
20
+
21
+ try:
22
+ from starlette.testclient import TestClient
23
+ from server.app import app
24
+
25
+ REAL_OPENENV_AVAILABLE = True
26
+ IMPORT_ERROR: Exception | None = None
27
+ except Exception as exc: # pragma: no cover - only used for skip messaging
28
+ REAL_OPENENV_AVAILABLE = False
29
+ IMPORT_ERROR = exc
30
+
31
+
32
+ @unittest.skipUnless(
33
+ REAL_OPENENV_AVAILABLE,
34
+ f"real OpenEnv stack unavailable: {IMPORT_ERROR}",
35
+ )
36
+ class RealOpenEnvIntegrationTests(unittest.TestCase):
37
+ @classmethod
38
+ def setUpClass(cls) -> None:
39
+ cls.client = TestClient(app)
40
+
41
+ def test_root_redirects_to_web(self) -> None:
42
+ response = self.client.get("/", follow_redirects=False)
43
+ self.assertEqual(response.status_code, 307)
44
+ self.assertEqual(response.headers["location"], "/web")
45
+
46
+ def test_grader_endpoint_scores_known_action(self) -> None:
47
+ response = self.client.post(
48
+ "/grader",
49
+ json={
50
+ "task_id": 3,
51
+ "ticket_id": "ticket-002",
52
+ "action": {
53
+ "issue_type": "identity_access",
54
+ "priority": "high",
55
+ "assignment_group": "service_desk",
56
+ "resolution_action": "fulfill",
57
+ },
58
+ },
59
+ )
60
+ self.assertEqual(response.status_code, 200)
61
+ payload = response.json()
62
+ self.assertEqual(payload["score"], 1.0)
63
+ self.assertEqual(payload["breakdown"]["issue_type"], 1.0)
64
+
65
+ def test_baseline_endpoint_runs_episode(self) -> None:
66
+ response = self.client.get("/baseline", params={"task_id": 3, "seed": 42})
67
+ self.assertEqual(response.status_code, 200)
68
+ payload = response.json()
69
+ self.assertEqual(payload["task_id"], 3)
70
+ self.assertGreater(payload["step_count"], 0)
71
+ self.assertIn("steps", payload)
72
+ self.assertIsInstance(payload["steps"], list)
73
+
74
+ def test_websocket_round_trip_reset_state_step(self) -> None:
75
+ with self.client.websocket_connect("/ws") as websocket:
76
+ websocket.send_json({"type": "reset", "data": {"task_id": 1, "seed": 42}})
77
+ reset_message = websocket.receive_json()
78
+ self.assertEqual(reset_message["type"], "observation")
79
+ reset_payload = reset_message["data"]
80
+ reset_obs = reset_payload.get("observation", reset_payload)
81
+ self.assertEqual(reset_obs["task_id"], 1)
82
+ self.assertFalse(reset_payload.get("done", reset_obs.get("done", False)))
83
+
84
+ websocket.send_json({"type": "state"})
85
+ state_message = websocket.receive_json()
86
+ self.assertEqual(state_message["type"], "state")
87
+ self.assertEqual(state_message["data"]["current_task_id"], 1)
88
+
89
+ websocket.send_json(
90
+ {
91
+ "type": "step",
92
+ "data": {
93
+ "issue_type": "billing_license",
94
+ },
95
+ }
96
+ )
97
+ step_message = websocket.receive_json()
98
+ self.assertEqual(step_message["type"], "observation")
99
+ step_payload = step_message["data"]
100
+ step_obs = step_payload.get("observation", step_payload)
101
+ reward = step_payload.get("reward", step_obs.get("reward"))
102
+ self.assertGreaterEqual(reward, 0.0)
103
+ self.assertLessEqual(reward, 1.0)
104
+
105
+
106
+ if __name__ == "__main__":
107
+ unittest.main()