fix: lint/type issues and ensure mypy/ruff pass
Browse files
- .gitignore +1 -1
- env/environment.py +12 -6
- inference.py +1 -3
- tests/test_environment.py +0 -2
- uv.lock +0 -0
.gitignore
CHANGED
|
@@ -15,4 +15,4 @@ __pycache__/
|
|
| 15 |
.DS_Store
|
| 16 |
|
| 17 |
# Tool outputs (uv)
|
| 18 |
-
uv.lock
|
|
|
|
| 15 |
.DS_Store
|
| 16 |
|
| 17 |
# Tool outputs (uv)
|
| 18 |
+
# uv.lock should be committed for OpenEnv validation; do not ignore it
|
env/environment.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Tuple, Dict, Any, Optional
|
| 2 |
from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
|
| 3 |
from .tasks import TASKS
|
| 4 |
from .graders import grade
|
|
@@ -9,12 +9,12 @@ class SupportTicketEnv:
|
|
| 9 |
if task_id not in TASKS:
|
| 10 |
raise ValueError(f"Unknown task_id: {task_id}")
|
| 11 |
self.task_data = TASKS[task_id]
|
| 12 |
-
self.state = None
|
| 13 |
self.max_steps = 10
|
| 14 |
self.reset()
|
| 15 |
|
| 16 |
def reset(self) -> Observation:
|
| 17 |
-
ticket_data = self.task_data["ticket"]
|
| 18 |
self.state = EnvironmentState(
|
| 19 |
current_task_id=self.task_id,
|
| 20 |
step_count=0,
|
|
@@ -22,11 +22,12 @@ class SupportTicketEnv:
|
|
| 22 |
action_history=[],
|
| 23 |
is_done=False,
|
| 24 |
final_reward=0.0,
|
| 25 |
-
task_difficulty=self.task_data["difficulty"]
|
| 26 |
)
|
| 27 |
return self._get_observation("System initialized. Ticket assigned.")
|
| 28 |
|
| 29 |
def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
|
|
|
|
| 30 |
return Observation(
|
| 31 |
ticket=self.state.ticket,
|
| 32 |
available_actions=[
|
|
@@ -40,6 +41,8 @@ class SupportTicketEnv:
|
|
| 40 |
)
|
| 41 |
|
| 42 |
def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
|
|
|
|
|
|
|
| 43 |
if self.state.is_done:
|
| 44 |
return self._get_observation("Episode is over."), 0.0, True, {}
|
| 45 |
|
|
@@ -53,7 +56,8 @@ class SupportTicketEnv:
|
|
| 53 |
if action.action_type == "fetch_user_data":
|
| 54 |
user_id = action.parameters.get("user_id")
|
| 55 |
if user_id == self.state.ticket.user_id:
|
| 56 |
-
|
|
|
|
| 57 |
tool_output = f"User Data: Tier = {self.state.user_data.account_tier}, Joined = {self.state.user_data.join_date}"
|
| 58 |
else:
|
| 59 |
tool_output = "Error: Invalid user_id."
|
|
@@ -61,7 +65,8 @@ class SupportTicketEnv:
|
|
| 61 |
|
| 62 |
elif action.action_type == "check_policy":
|
| 63 |
issue_type = action.parameters.get("issue_type", self.state.ticket.issue_type)
|
| 64 |
-
|
|
|
|
| 65 |
tool_output = f"Policy for {issue_type}: {policy}"
|
| 66 |
|
| 67 |
elif action.action_type == "issue_refund":
|
|
@@ -107,4 +112,5 @@ class SupportTicketEnv:
|
|
| 107 |
return self._get_observation(system_message, tool_output), reward, self.state.is_done, info
|
| 108 |
|
| 109 |
def get_state(self) -> EnvironmentState:
|
|
|
|
| 110 |
return self.state
|
|
|
|
| 1 |
+
from typing import Tuple, Dict, Any, Optional, cast
|
| 2 |
from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
|
| 3 |
from .tasks import TASKS
|
| 4 |
from .graders import grade
|
|
|
|
| 9 |
if task_id not in TASKS:
|
| 10 |
raise ValueError(f"Unknown task_id: {task_id}")
|
| 11 |
self.task_data = TASKS[task_id]
|
| 12 |
+
self.state: Optional[EnvironmentState] = None
|
| 13 |
self.max_steps = 10
|
| 14 |
self.reset()
|
| 15 |
|
| 16 |
def reset(self) -> Observation:
|
| 17 |
+
ticket_data = cast(Dict[str, Any], self.task_data["ticket"])
|
| 18 |
self.state = EnvironmentState(
|
| 19 |
current_task_id=self.task_id,
|
| 20 |
step_count=0,
|
|
|
|
| 22 |
action_history=[],
|
| 23 |
is_done=False,
|
| 24 |
final_reward=0.0,
|
| 25 |
+
task_difficulty=str(self.task_data["difficulty"])
|
| 26 |
)
|
| 27 |
return self._get_observation("System initialized. Ticket assigned.")
|
| 28 |
|
| 29 |
def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
|
| 30 |
+
assert self.state is not None
|
| 31 |
return Observation(
|
| 32 |
ticket=self.state.ticket,
|
| 33 |
available_actions=[
|
|
|
|
| 41 |
)
|
| 42 |
|
| 43 |
def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
|
| 44 |
+
assert self.state is not None
|
| 45 |
+
|
| 46 |
if self.state.is_done:
|
| 47 |
return self._get_observation("Episode is over."), 0.0, True, {}
|
| 48 |
|
|
|
|
| 56 |
if action.action_type == "fetch_user_data":
|
| 57 |
user_id = action.parameters.get("user_id")
|
| 58 |
if user_id == self.state.ticket.user_id:
|
| 59 |
+
user_data = cast(Dict[str, Any], self.task_data["user_data"])
|
| 60 |
+
self.state.user_data = UserData(**user_data)
|
| 61 |
tool_output = f"User Data: Tier = {self.state.user_data.account_tier}, Joined = {self.state.user_data.join_date}"
|
| 62 |
else:
|
| 63 |
tool_output = "Error: Invalid user_id."
|
|
|
|
| 65 |
|
| 66 |
elif action.action_type == "check_policy":
|
| 67 |
issue_type = action.parameters.get("issue_type", self.state.ticket.issue_type)
|
| 68 |
+
policy_map = cast(Dict[str, str], self.task_data["policy"])
|
| 69 |
+
policy = policy_map.get(issue_type, "No specific policy found.")
|
| 70 |
tool_output = f"Policy for {issue_type}: {policy}"
|
| 71 |
|
| 72 |
elif action.action_type == "issue_refund":
|
|
|
|
| 112 |
return self._get_observation(system_message, tool_output), reward, self.state.is_done, info
|
| 113 |
|
| 114 |
def get_state(self) -> EnvironmentState:
|
| 115 |
+
assert self.state is not None
|
| 116 |
return self.state
|
inference.py
CHANGED
|
@@ -131,10 +131,9 @@ async def run_task(task_id: str, client: OpenAI) -> None:
|
|
| 131 |
try:
|
| 132 |
obs = env.reset()
|
| 133 |
last_echoed = obs.model_dump_json(indent=2)
|
| 134 |
-
last_reward = 0.0
|
| 135 |
|
| 136 |
for step in range(1, MAX_STEPS + 1):
|
| 137 |
-
if env.
|
| 138 |
break
|
| 139 |
|
| 140 |
message = get_model_message(client, step, last_echoed, history)
|
|
@@ -149,7 +148,6 @@ async def run_task(task_id: str, client: OpenAI) -> None:
|
|
| 149 |
rewards.append(actual_reward)
|
| 150 |
steps_taken = step
|
| 151 |
last_echoed = obs_json
|
| 152 |
-
last_reward = actual_reward
|
| 153 |
|
| 154 |
log_step(step=step, action=message, reward=actual_reward, done=done, error=error)
|
| 155 |
history.append(f"Step {step}: {message!r} -> reward {actual_reward:+.2f}")
|
|
|
|
| 131 |
try:
|
| 132 |
obs = env.reset()
|
| 133 |
last_echoed = obs.model_dump_json(indent=2)
|
|
|
|
| 134 |
|
| 135 |
for step in range(1, MAX_STEPS + 1):
|
| 136 |
+
if env.get_state().is_done:
|
| 137 |
break
|
| 138 |
|
| 139 |
message = get_model_message(client, step, last_echoed, history)
|
|
|
|
| 148 |
rewards.append(actual_reward)
|
| 149 |
steps_taken = step
|
| 150 |
last_echoed = obs_json
|
|
|
|
| 151 |
|
| 152 |
log_step(step=step, action=message, reward=actual_reward, done=done, error=error)
|
| 153 |
history.append(f"Step {step}: {message!r} -> reward {actual_reward:+.2f}")
|
tests/test_environment.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
import pytest
|
| 2 |
-
|
| 3 |
from env.environment import SupportTicketEnv
|
| 4 |
from env.models import Action
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 1 |
from env.environment import SupportTicketEnv
|
| 2 |
from env.models import Action
|
| 3 |
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|