chore: prepare Round-1 OpenEnv submission (validator, evaluator, tests, CI, docs)
Browse files- .github/workflows/openenv-validation.yml +40 -10
- .gitignore +1 -1
- CONTRIBUTING.md +20 -0
- LICENSE +18 -0
- MANIFEST.in +8 -0
- README.md +29 -0
- env/__init__.py +0 -0
- env/environment.py +110 -0
- env/graders.py +69 -0
- env/models.py +37 -0
- env/tasks.py +66 -0
- evaluate.py +51 -0
- inference.py +72 -24
- server/app.py +26 -16
- tests/conftest.py +7 -0
- tests/test_environment.py +81 -0
.github/workflows/openenv-validation.yml
CHANGED
|
@@ -18,20 +18,50 @@ jobs:
|
|
| 18 |
uses: actions/setup-python@v4
|
| 19 |
with:
|
| 20 |
python-version: '3.11'
|
| 21 |
-
|
| 22 |
-
- name: Install dependencies
|
| 23 |
run: |
|
| 24 |
python -m pip install --upgrade pip
|
| 25 |
-
pip install
|
| 26 |
-
pip install
|
| 27 |
|
| 28 |
-
- name:
|
| 29 |
-
run:
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
- name: Run OpenEnv
|
| 32 |
run: |
|
| 33 |
openenv validate .
|
| 34 |
-
|
| 35 |
-
- name:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
run: |
|
| 37 |
-
docker build -t test-openenv .
|
|
|
|
| 18 |
uses: actions/setup-python@v4
|
| 19 |
with:
|
| 20 |
python-version: '3.11'
|
| 21 |
+
|
| 22 |
+
- name: Install dependencies
|
| 23 |
run: |
|
| 24 |
python -m pip install --upgrade pip
|
| 25 |
+
pip install -r requirements.txt
|
| 26 |
+
pip install -e .
|
| 27 |
|
| 28 |
+
- name: Install OpenEnv validator
|
| 29 |
+
run: |
|
| 30 |
+
python -m pip install --upgrade pip
|
| 31 |
+
pip install openenv-core
|
| 32 |
|
| 33 |
+
- name: Run OpenEnv validator
|
| 34 |
run: |
|
| 35 |
openenv validate .
|
| 36 |
+
|
| 37 |
+
- name: Run tests
|
| 38 |
+
run: |
|
| 39 |
+
python -m pip install pytest
|
| 40 |
+
pytest -q
|
| 41 |
+
|
| 42 |
+
lint:
|
| 43 |
+
runs-on: ubuntu-latest
|
| 44 |
+
needs: validate
|
| 45 |
+
steps:
|
| 46 |
+
- name: Checkout Repository
|
| 47 |
+
uses: actions/checkout@v4
|
| 48 |
+
|
| 49 |
+
- name: Set up Python
|
| 50 |
+
uses: actions/setup-python@v4
|
| 51 |
+
with:
|
| 52 |
+
python-version: '3.11'
|
| 53 |
+
|
| 54 |
+
- name: Install lint tools
|
| 55 |
+
run: |
|
| 56 |
+
python -m pip install --upgrade pip
|
| 57 |
+
pip install ruff mypy
|
| 58 |
+
|
| 59 |
+
- name: Run ruff
|
| 60 |
+
run: ruff check .
|
| 61 |
+
|
| 62 |
+
- name: Run mypy
|
| 63 |
+
run: mypy --ignore-missing-imports . || echo "mypy found issues"
|
| 64 |
+
|
| 65 |
+
- name: Verify Docker Builds (optional)
|
| 66 |
run: |
|
| 67 |
+
docker build -t test-openenv . || echo "Docker build failed or not available on runner"
|
.gitignore
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
# Virtual Environments
|
| 2 |
.venv/
|
| 3 |
venv/
|
| 4 |
-
env/
|
| 5 |
|
| 6 |
# Python caching
|
| 7 |
__pycache__/
|
|
|
|
| 1 |
# Virtual Environments
|
| 2 |
.venv/
|
| 3 |
venv/
|
| 4 |
+
# Note: `env/` is the package source directory for this project and must NOT be ignored
|
| 5 |
|
| 6 |
# Python caching
|
| 7 |
__pycache__/
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Contributing
|
| 2 |
+
|
| 3 |
+
Run tests:
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
python -m venv .venv
|
| 7 |
+
source .venv/bin/activate
|
| 8 |
+
pip install -r requirements.txt
|
| 9 |
+
pip install -e .
|
| 10 |
+
pip install pytest
|
| 11 |
+
pytest -q
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
To run the API locally:
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
Please open PRs against `main`. Add tests for new behavior and keep changes small and focused.
|
LICENSE
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 13 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 14 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 15 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 16 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 17 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 18 |
+
SOFTWARE.
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include README.md
|
| 2 |
+
include PRD.md
|
| 3 |
+
include openenv.yaml
|
| 4 |
+
include LICENSE
|
| 5 |
+
include CONTRIBUTING.md
|
| 6 |
+
recursive-include env *.py
|
| 7 |
+
recursive-include server *.py
|
| 8 |
+
recursive-include tests *.py
|
README.md
CHANGED
|
@@ -39,3 +39,32 @@ export OPENAI_API_KEY="your-key"
|
|
| 39 |
export MODEL_NAME="gpt-4o"
|
| 40 |
python inference.py
|
| 41 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
export MODEL_NAME="gpt-4o"
|
| 40 |
python inference.py
|
| 41 |
```
|
| 42 |
+
|
| 43 |
+
Evaluation harness
|
| 44 |
+
------------------
|
| 45 |
+
To reproduce grader outputs for Round 1, run the lightweight evaluator which executes the canonical correct action sequences:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
source .venv/bin/activate
|
| 49 |
+
pip install -r requirements.txt
|
| 50 |
+
pip install -e .
|
| 51 |
+
python evaluate.py
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Packaging notes
|
| 55 |
+
---------------
|
| 56 |
+
This project includes `env/` as the package containing the OpenEnv environment. We include `openenv.yaml` and `PRD.md` in the source distribution to ensure validator and reviewers can find metadata.
|
| 57 |
+
|
| 58 |
+
Developer setup (recommended)
|
| 59 |
+
-----------------------------
|
| 60 |
+
For reviewers or contributors, it's helpful to install the package in editable mode so imports resolve and tests run without extra environment variables:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
python -m venv .venv
|
| 64 |
+
source .venv/bin/activate
|
| 65 |
+
pip install -r requirements.txt
|
| 66 |
+
pip install -e .
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
This ensures `pytest` and local imports work out-of-the-box.
|
| 70 |
+
|
env/__init__.py
ADDED
|
File without changes
|
env/environment.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Dict, Any, Optional
|
| 2 |
+
from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
|
| 3 |
+
from .tasks import TASKS
|
| 4 |
+
from .graders import grade
|
| 5 |
+
|
| 6 |
+
class SupportTicketEnv:
    """A single-ticket customer-support episode.

    One ticket (selected by ``task_id``) is assigned per episode.  The agent
    issues tool-style actions until the episode terminates: on escalation, on
    ticket closure, or after ``max_steps`` actions.  When the episode ends,
    the difficulty-specific grader computes the final reward.
    """

    def __init__(self, task_id: str = "task_easy_1"):
        self.task_id = task_id
        if task_id not in TASKS:
            raise ValueError(f"Unknown task_id: {task_id}")
        self.task_data = TASKS[task_id]
        self.state = None
        self.max_steps = 10  # hard cap on actions per episode
        self.reset()

    def reset(self) -> Observation:
        """Start a fresh episode for the configured task; return the first observation."""
        self.state = EnvironmentState(
            current_task_id=self.task_id,
            step_count=0,
            ticket=TicketInfo(**self.task_data["ticket"]),
            action_history=[],
            is_done=False,
            final_reward=0.0,
            task_difficulty=self.task_data["difficulty"],
        )
        return self._get_observation("System initialized. Ticket assigned.")

    def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
        """Snapshot the current episode as an Observation for the agent."""
        trace = [f"{act.action_type}({act.parameters})" for act in self.state.action_history]
        return Observation(
            ticket=self.state.ticket,
            available_actions=[
                "fetch_user_data", "check_policy", "issue_refund",
                "reply_to_customer", "escalate", "close_ticket",
            ],
            system_message=system_message,
            history=trace,
            tool_output=tool_output,
            step_count=self.state.step_count,
        )

    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        Reward is 0.0 on intermediate steps; the grader runs only once the
        episode terminates (escalate, close_ticket, or step budget exhausted).
        """
        if self.state.is_done:
            return self._get_observation("Episode is over."), 0.0, True, {}

        self.state.step_count += 1
        self.state.action_history.append(action)

        message = f"Action {action.action_type} executed."
        output = None
        kind = action.action_type

        if kind == "fetch_user_data":
            # Only the ticket's own user_id resolves to a CRM record.
            if action.parameters.get("user_id") == self.state.ticket.user_id:
                self.state.user_data = UserData(**self.task_data["user_data"])
                output = (
                    f"User Data: Tier = {self.state.user_data.account_tier}, "
                    f"Joined = {self.state.user_data.join_date}"
                )
            else:
                output = "Error: Invalid user_id."
                message = "Failed to fetch user data."

        elif kind == "check_policy":
            issue = action.parameters.get("issue_type", self.state.ticket.issue_type)
            policy = self.task_data["policy"].get(issue, "No specific policy found.")
            output = f"Policy for {issue}: {policy}"

        elif kind == "issue_refund":
            amount = action.parameters.get("amount", "fully")
            output = f"Refund issued for {amount}."

        elif kind == "reply_to_customer":
            msg = action.parameters.get("message", "")
            output = f"Replied: '{msg}'"

        elif kind == "escalate":
            # Terminal action: hand the ticket to another tier.
            reason = action.parameters.get("reason", "support_tier2")
            output = f"Escalated to {reason}."
            self.state.ticket.status = "escalated"
            self.state.is_done = True

        elif kind == "close_ticket":
            # Terminal action: resolve and close.
            res = action.parameters.get("resolution", "")
            output = f"Ticket closed. Resolution: {res}"
            self.state.ticket.status = "closed"
            self.state.is_done = True

        else:
            output = "Invalid action."
            message = "Action unrecognized."

        # Step budget exhausted -> force termination (overrides the per-action message).
        if self.state.step_count >= self.max_steps:
            self.state.is_done = True
            message = "Max steps reached."

        reward = 0.0
        if self.state.is_done:
            reward = grade(self.state)
            self.state.final_reward = reward

        info = {
            "current_reward": reward,
            "step_count": self.state.step_count,
        }

        return self._get_observation(message, output), reward, self.state.is_done, info

    def get_state(self) -> EnvironmentState:
        """Expose the full episode state (served verbatim by the /state endpoint)."""
        return self.state
|
env/graders.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models import EnvironmentState
|
| 2 |
+
|
| 3 |
+
def grade_easy(state: EnvironmentState) -> float:
    """Grade the easy task: check policy (0.2), refund (0.5), close (0.3).

    Unnecessary escalation costs 0.5.  Result is clamped to [0, 1].
    """
    taken = {act.action_type for act in state.action_history}
    score = 0.0
    if "check_policy" in taken:
        score += 0.2
    if "issue_refund" in taken:
        score += 0.5
    if "close_ticket" in taken:
        score += 0.3

    if "escalate" in taken:
        score -= 0.5  # penalty for unnecessary escalation
    return max(0.0, min(1.0, score))


def grade_medium(state: EnvironmentState) -> float:
    """Grade the medium task: explain the policy and close — never refund.

    Issuing any refund is a fatal mistake and zeroes the score outright.
    """
    taken = {act.action_type for act in state.action_history}

    score = 0.0
    if "check_policy" in taken:
        score += 0.3
    if "reply_to_customer" in taken:
        score += 0.4
    if "close_ticket" in taken:
        score += 0.3

    if "issue_refund" in taken:  # fatal mistake
        return 0.0

    return max(0.0, min(1.0, score))


def grade_hard(state: EnvironmentState) -> float:
    """Grade the hard task: fetch user data, escalate to billing_tier2, reply.

    Refunding (-0.5) or closing the ticket (-0.3) is penalized: enterprise
    double charges must be resolved through escalation, not locally.
    """
    taken = {act.action_type for act in state.action_history}

    score = 0.0
    if "fetch_user_data" in taken:
        score += 0.2

    correctly_escalated = any(
        act.action_type == "escalate" and act.parameters.get("reason") == "billing_tier2"
        for act in state.action_history
    )
    if correctly_escalated:
        score += 0.5

    if "reply_to_customer" in taken:
        score += 0.3

    if "issue_refund" in taken:
        score -= 0.5  # can't refund enterprise double charges directly
    if "close_ticket" in taken:
        score -= 0.3  # can't close without resolving escalate

    return max(0.0, min(1.0, score))


def grade(state: EnvironmentState) -> float:
    """Dispatch to the grader matching the episode's task difficulty.

    Unknown difficulties score 0.0.
    """
    dispatch = {
        "easy": grade_easy,
        "medium": grade_medium,
        "hard": grade_hard,
    }
    grader = dispatch.get(state.task_difficulty)
    return grader(state) if grader else 0.0
|
env/models.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import List, Optional, Literal, Dict, Any
|
| 3 |
+
|
| 4 |
+
class TicketInfo(BaseModel):
    """A support ticket as presented to the agent."""
    ticket_id: str
    user_id: str
    issue_type: str
    subject: str
    body: str
    status: str


class UserData(BaseModel):
    """CRM record revealed by a successful fetch_user_data call."""
    user_id: str
    account_tier: str
    join_date: str


class Action(BaseModel):
    """A single tool invocation chosen by the agent."""
    action_type: Literal[
        "fetch_user_data", "check_policy", "issue_refund",
        "reply_to_customer", "escalate", "close_ticket",
    ]
    parameters: Dict[str, Any] = Field(default_factory=dict)


class Observation(BaseModel):
    """What the agent sees after reset() or each step()."""
    ticket: TicketInfo
    available_actions: List[str]
    system_message: str
    history: List[str]
    tool_output: Optional[str] = None
    step_count: int


class EnvironmentState(BaseModel):
    """Full mutable episode state tracked by SupportTicketEnv."""
    current_task_id: str
    step_count: int
    ticket: TicketInfo
    user_data: Optional[UserData] = None
    action_history: List[Action]
    is_done: bool
    final_reward: float
    task_difficulty: str
|
env/tasks.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum


class Difficulty(Enum):
    """Task difficulty tiers; values match the grader dispatch keys."""
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"


# Static task catalogue.  Each entry bundles the opening ticket, the CRM
# record returned by fetch_user_data, and the policy text surfaced by
# check_policy for the ticket's issue type.
TASKS = {
    # Easy: policy allows a full refund — check policy, refund, close.
    "task_easy_1": {
        "difficulty": Difficulty.EASY.value,
        "ticket": {
            "ticket_id": "TKT-1001",
            "user_id": "USR-A1",
            "issue_type": "refund_request",
            "subject": "Accidental purchase",
            "body": "I clicked buy by mistake on the Premium plan today. Can I get a refund?",
            "status": "open",
        },
        "user_data": {
            "user_id": "USR-A1",
            "account_tier": "premium",
            "join_date": "2023-01-15",
        },
        "policy": {
            "refund_request": "If requested within 7 days of accidental purchase, issue full refund.",
        },
    },
    # Medium: policy forbids the refund — explain and close, never refund.
    "task_medium_1": {
        "difficulty": Difficulty.MEDIUM.value,
        "ticket": {
            "ticket_id": "TKT-2002",
            "user_id": "USR-B2",
            "issue_type": "refund_request",
            "subject": "Refund for last year",
            "body": "I didn't use my account much last year, please refund the annual fee.",
            "status": "open",
        },
        "user_data": {
            "user_id": "USR-B2",
            "account_tier": "standard",
            "join_date": "2021-05-20",
        },
        "policy": {
            "refund_request": "Strictly no refunds for unused time from previous billing cycles. Explain policy and close ticket.",
        },
    },
    # Hard: enterprise billing problem — fetch data, escalate to billing_tier2, reply.
    "task_hard_1": {
        "difficulty": Difficulty.HARD.value,
        "ticket": {
            "ticket_id": "TKT-3003",
            "user_id": "USR-C3",
            "issue_type": "billing_discrepancy",
            "subject": "Double charged again!",
            "body": "This is the third month in a row I've been charged twice! Fix this or I'm leaving.",
            "status": "open",
        },
        "user_data": {
            "user_id": "USR-C3",
            "account_tier": "enterprise",
            "join_date": "2019-11-01",
        },
        "policy": {
            "billing_discrepancy": "For enterprise clients with recurring double charges, fetch user data, escalate immediately to billing_tier2, and reply to customer apologizing for the delay.",
        },
    },
}
|
evaluate.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Small evaluation harness that executes the expected action sequence for each task
|
| 2 |
+
and prints a JSON summary of grader scores. Use this to reproduce Round-1 evaluation outputs.
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
from env.environment import SupportTicketEnv
|
| 6 |
+
from env.models import Action
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Canonical "golden" trajectories: the correct action sequence for each task.
EXPECTED_ACTIONS = {
    "task_easy_1": [
        Action(action_type="check_policy", parameters={}),
        Action(action_type="issue_refund", parameters={"amount": "full"}),
        Action(action_type="close_ticket", parameters={"resolution": "refunded"}),
    ],
    "task_medium_1": [
        Action(action_type="check_policy", parameters={}),
        Action(action_type="reply_to_customer", parameters={"message": "Policy explained - no refund"}),
        Action(action_type="close_ticket", parameters={"resolution": "policy_explained"}),
    ],
    "task_hard_1": [
        Action(action_type="fetch_user_data", parameters={"user_id": "USR-C3"}),
        Action(action_type="escalate", parameters={"reason": "billing_tier2"}),
        Action(action_type="reply_to_customer", parameters={"message": "We're escalating this to billing tier 2 and will follow up."}),
    ],
}


def run_sequence(task_id: str, actions):
    """Replay *actions* against a fresh episode and return the last reported reward."""
    env = SupportTicketEnv(task_id=task_id)
    env.reset()
    last_reward = 0.0
    for act in actions:
        _obs, _reward, finished, info = env.step(act)
        last_reward = info.get("current_reward", last_reward)
        if finished:
            break
    return last_reward


def main():
    """Score every golden trajectory and print the summary as JSON."""
    results = {
        task_id: {"score": run_sequence(task_id, sequence)}
        for task_id, sequence in EXPECTED_ACTIONS.items()
    }
    print(json.dumps({"results": results}, indent=2))


if __name__ == "__main__":
    main()
|
inference.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
| 1 |
import os
|
| 2 |
import json
|
|
|
|
| 3 |
import asyncio
|
| 4 |
from typing import List, Optional
|
| 5 |
from openai import OpenAI
|
| 6 |
from env.environment import SupportTicketEnv
|
| 7 |
from env.models import Action
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 10 |
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 11 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
@@ -26,19 +30,31 @@ def log_end(success: bool, steps: int, score: float, rewards: list):
|
|
| 26 |
print(f"[END] success={success} steps={steps} score={score} rewards={rewards}", flush=True)
|
| 27 |
|
| 28 |
def parse_action(text: str) -> Action:
|
|
|
|
| 29 |
try:
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def get_model_message(client, step: int, env_state: str, history: List[str]) -> str:
|
| 44 |
system_prompt = (
|
|
@@ -56,20 +72,52 @@ def get_model_message(client, step: int, env_state: str, history: List[str]) ->
|
|
| 56 |
history_str = "\n".join(history)
|
| 57 |
user_prompt = f"History:\n{history_str}\n\nCurrent Observation:\n{env_state}\n\nWhat is your next action JSON?"
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
try:
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
except Exception as exc:
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
async def run_task(task_id: str, client: OpenAI) -> None:
|
| 75 |
env = SupportTicketEnv(task_id=task_id)
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
+
import logging
|
| 4 |
import asyncio
|
| 5 |
from typing import List, Optional
|
| 6 |
from openai import OpenAI
|
| 7 |
from env.environment import SupportTicketEnv
|
| 8 |
from env.models import Action
|
| 9 |
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
|
| 13 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 14 |
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 15 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
| 30 |
print(f"[END] success={success} steps={steps} score={score} rewards={rewards}", flush=True)
|
| 31 |
|
| 32 |
def parse_action(text: str) -> Action:
    """Extract the first JSON object embedded in *text* and turn it into an Action.

    Scans for '{' candidates and decodes with ``json.JSONDecoder.raw_decode``,
    validating via Pydantic.  Falls back to a safe ``close_ticket`` action when
    nothing parses, so a malformed model reply never crashes the episode loop.
    """
    try:
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                candidate, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                # Not valid JSON at this brace; try the next one.
                start = text.find('{', start + 1)
                continue
            if isinstance(candidate, dict):
                try:
                    return Action.model_validate(candidate)
                except Exception as val_err:
                    logger.warning("Action validation failed: %s", val_err)
                    # fallback to manual construction
                    return Action(
                        action_type=candidate.get("action_type", "close_ticket"),
                        parameters=candidate.get("parameters", {}),
                    )
            start = text.find('{', start + 1)
    except Exception as exc:
        logger.exception("Unexpected error while parsing action: %s", exc)

    # Safe default when parsing/validation fails
    return Action(action_type="close_ticket", parameters={"resolution": "invalid_parse"})
|
| 58 |
|
| 59 |
def get_model_message(client, step: int, env_state: str, history: List[str]) -> str:
|
| 60 |
system_prompt = (
|
|
|
|
| 72 |
history_str = "\n".join(history)
|
| 73 |
user_prompt = f"History:\n{history_str}\n\nCurrent Observation:\n{env_state}\n\nWhat is your next action JSON?"
|
| 74 |
|
| 75 |
+
import time
|
| 76 |
+
# retry/backoff parameters
|
| 77 |
+
max_retries = 3
|
| 78 |
+
backoff_base = 0.5
|
| 79 |
+
|
| 80 |
try:
|
| 81 |
+
# Support a few possible client interfaces (chat.completions or responses)
|
| 82 |
+
for attempt in range(1, max_retries + 1):
|
| 83 |
+
try:
|
| 84 |
+
if hasattr(client, "chat") and hasattr(client.chat, "completions"):
|
| 85 |
+
completion = client.chat.completions.create(
|
| 86 |
+
model=MODEL_NAME,
|
| 87 |
+
messages=[
|
| 88 |
+
{"role": "system", "content": system_prompt},
|
| 89 |
+
{"role": "user", "content": user_prompt}
|
| 90 |
+
],
|
| 91 |
+
temperature=0.1
|
| 92 |
+
)
|
| 93 |
+
text = (completion.choices[0].message.content or "").strip()
|
| 94 |
+
return text if text else "{}"
|
| 95 |
+
|
| 96 |
+
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
| 97 |
+
completion = client.responses.create(model=MODEL_NAME, input=user_prompt, temperature=0.1)
|
| 98 |
+
text = getattr(completion, "output_text", None)
|
| 99 |
+
if text:
|
| 100 |
+
return text.strip()
|
| 101 |
+
|
| 102 |
+
out = []
|
| 103 |
+
for item in getattr(completion, "output", []) or []:
|
| 104 |
+
for c in item.get("content", []):
|
| 105 |
+
if c.get("type") == "output_text":
|
| 106 |
+
out.append(c.get("text", ""))
|
| 107 |
+
if out:
|
| 108 |
+
return "".join(out).strip()
|
| 109 |
+
|
| 110 |
+
raise RuntimeError("No supported model client method available")
|
| 111 |
+
except Exception as exc:
|
| 112 |
+
logger.warning("Model request attempt %d failed: %s", attempt, exc)
|
| 113 |
+
if attempt == max_retries:
|
| 114 |
+
break
|
| 115 |
+
sleep_time = backoff_base * (2 ** (attempt - 1))
|
| 116 |
+
time.sleep(sleep_time)
|
| 117 |
except Exception as exc:
|
| 118 |
+
logger.exception("Unexpected error in get_model_message: %s", exc)
|
| 119 |
+
|
| 120 |
+
return "{}"
|
| 121 |
|
| 122 |
async def run_task(task_id: str, client: OpenAI) -> None:
|
| 123 |
env = SupportTicketEnv(task_id=task_id)
|
server/app.py
CHANGED
|
@@ -2,35 +2,45 @@ from fastapi import FastAPI, HTTPException
|
|
| 2 |
from pydantic import BaseModel
|
| 3 |
from env.environment import SupportTicketEnv
|
| 4 |
from env.models import Action
|
|
|
|
|
|
|
| 5 |
|
| 6 |
app = FastAPI(title="OpenEnv Support Ticket API")
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class InitRequest(BaseModel):
|
| 11 |
task_id: str = "task_easy_1"
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
@app.get("/")
|
| 14 |
def read_root():
|
| 15 |
return {"status": "ok", "message": "Support Ticket OpenEnv is live."}
|
| 16 |
|
| 17 |
@app.post("/reset")
|
| 18 |
def reset_env(req: InitRequest):
|
| 19 |
-
global CURRENT_ENV_SESSION
|
| 20 |
try:
|
| 21 |
-
|
| 22 |
-
obs =
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
except ValueError as e:
|
| 25 |
raise HTTPException(status_code=400, detail=str(e))
|
| 26 |
|
| 27 |
@app.post("/step")
|
| 28 |
-
def step_env(
|
| 29 |
-
|
| 30 |
-
if not
|
| 31 |
-
raise HTTPException(status_code=400, detail="
|
| 32 |
-
|
| 33 |
-
obs, reward, done, info =
|
| 34 |
return {
|
| 35 |
"observation": obs.model_dump(),
|
| 36 |
"reward": reward,
|
|
@@ -39,11 +49,11 @@ def step_env(action: Action):
|
|
| 39 |
}
|
| 40 |
|
| 41 |
@app.get("/state")
|
| 42 |
-
def state_env():
|
| 43 |
-
|
| 44 |
-
if not
|
| 45 |
-
raise HTTPException(status_code=400, detail="
|
| 46 |
-
return
|
| 47 |
|
| 48 |
def main():
|
| 49 |
import uvicorn
|
|
|
|
| 2 |
from pydantic import BaseModel
|
| 3 |
from env.environment import SupportTicketEnv
|
| 4 |
from env.models import Action
|
| 5 |
+
from typing import Dict
|
| 6 |
+
from uuid import uuid4
|
| 7 |
|
| 8 |
app = FastAPI(title="OpenEnv Support Ticket API")
|
| 9 |
|
| 10 |
+
# Store sessions keyed by UUID to allow concurrent sessions
|
| 11 |
+
SESSIONS: Dict[str, SupportTicketEnv] = {}
|
| 12 |
+
|
| 13 |
|
| 14 |
class InitRequest(BaseModel):
    """Payload for POST /reset; defaults to the easy task."""
    task_id: str = "task_easy_1"
|
| 16 |
|
| 17 |
+
|
| 18 |
+
class StepRequest(BaseModel):
    """Payload for POST /step: which session to advance and with what action."""
    session_id: str
    action: Action
|
| 21 |
+
|
| 22 |
@app.get("/")
def read_root():
    """Liveness probe for the service."""
    return {"status": "ok", "message": "Support Ticket OpenEnv is live."}
|
| 25 |
|
| 26 |
@app.post("/reset")
def reset_env(req: InitRequest):
    """Create a brand-new environment session and return its first observation.

    Responds with a ``session_id`` the client must pass to /step and /state.
    An unknown task_id is reported as a 400.
    """
    try:
        env = SupportTicketEnv(task_id=req.task_id)
        first_obs = env.reset()
        token = str(uuid4())
        SESSIONS[token] = env
        return {"session_id": token, "observation": first_obs.model_dump()}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 36 |
|
| 37 |
@app.post("/step")
|
| 38 |
+
def step_env(req: StepRequest):
|
| 39 |
+
env = SESSIONS.get(req.session_id)
|
| 40 |
+
if not env:
|
| 41 |
+
raise HTTPException(status_code=400, detail="Invalid or expired session_id. Call /reset to create a session.")
|
| 42 |
+
|
| 43 |
+
obs, reward, done, info = env.step(req.action)
|
| 44 |
return {
|
| 45 |
"observation": obs.model_dump(),
|
| 46 |
"reward": reward,
|
|
|
|
| 49 |
}
|
| 50 |
|
| 51 |
@app.get("/state")
def state_env(session_id: str):
    """Return the raw environment state for an existing session as JSON."""
    env = SESSIONS.get(session_id)
    if not env:
        raise HTTPException(status_code=400, detail="Invalid or expired session_id. Call /reset to create a session.")
    return env.get_state().model_dump()
|
| 57 |
|
| 58 |
def main():
|
| 59 |
import uvicorn
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys

# Make the repository root importable so the suite can `import env`
# without requiring the package to be installed first.
_here = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(_here, os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
from env.environment import SupportTicketEnv
|
| 4 |
+
from env.models import Action
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _replay(env, actions):
    """Apply *actions* in order; return the final (obs, reward, done, info)."""
    outcome = None
    for act in actions:
        outcome = env.step(act)
    return outcome


def test_reset_and_initial_observation():
    # A fresh easy episode exposes the seeded ticket and the action menu.
    env = SupportTicketEnv(task_id="task_easy_1")
    obs = env.reset()
    assert obs.ticket.ticket_id == "TKT-1001"
    assert obs.step_count == 0
    assert "fetch_user_data" in obs.available_actions


def test_fetch_user_data_success_and_failure():
    env = SupportTicketEnv(task_id="task_easy_1")
    env.reset()

    # Matching user_id succeeds and surfaces the CRM record.
    obs, reward, done, info = env.step(
        Action(action_type="fetch_user_data", parameters={"user_id": "USR-A1"})
    )
    assert not done
    assert "User Data" in (obs.tool_output or "")

    # Mismatched user_id is reported as an error.
    obs2, reward2, done2, info2 = env.step(
        Action(action_type="fetch_user_data", parameters={"user_id": "WRONG"})
    )
    assert "Invalid user_id" in (obs2.tool_output or "") or "Failed to fetch" in obs2.system_message


def test_easy_flow_grader_rewards():
    env = SupportTicketEnv(task_id="task_easy_1")
    env.reset()

    # Follow the expected sequence for the easy task.
    obs, r, done, info = _replay(env, [
        Action(action_type="check_policy", parameters={}),
        Action(action_type="issue_refund", parameters={"amount": "full"}),
        Action(action_type="close_ticket", parameters={"resolution": "refunded"}),
    ])

    # Closing ends the episode with a positive final reward.
    assert done is True
    assert info.get("current_reward", 0.0) > 0.0


def test_medium_flow_no_refund_penalty():
    env = SupportTicketEnv(task_id="task_medium_1")
    env.reset()

    obs, r, done, info = _replay(env, [
        Action(action_type="check_policy", parameters={}),
        Action(action_type="reply_to_customer", parameters={"message": "Sorry, no refunds for prior billing."}),
        Action(action_type="close_ticket", parameters={"resolution": "policy_explained"}),
    ])

    assert done is True
    assert info.get("current_reward", 0.0) > 0.0


def test_hard_flow_requirements():
    env = SupportTicketEnv(task_id="task_hard_1")
    env.reset()

    # Fetch user data, then escalate with the correct reason;
    # escalation terminates the episode immediately.
    obs, r, done, info = _replay(env, [
        Action(action_type="fetch_user_data", parameters={"user_id": "USR-C3"}),
        Action(action_type="escalate", parameters={"reason": "billing_tier2"}),
    ])

    assert done is True
    assert info.get("current_reward", 0.0) >= 0.0
|