Sid8421 committed on
Commit
af3f703
·
1 Parent(s): ce6200f

fix: lint/type issues and ensure mypy/ruff pass

Browse files
Files changed (5) hide show
  1. .gitignore +1 -1
  2. env/environment.py +12 -6
  3. inference.py +1 -3
  4. tests/test_environment.py +0 -2
  5. uv.lock +0 -0
.gitignore CHANGED
@@ -15,4 +15,4 @@ __pycache__/
15
  .DS_Store
16
 
17
  # Tool outputs (uv)
18
- uv.lock
 
15
  .DS_Store
16
 
17
  # Tool outputs (uv)
18
+ # uv.lock should be committed for OpenEnv validation; do not ignore it
env/environment.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Tuple, Dict, Any, Optional
2
  from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
3
  from .tasks import TASKS
4
  from .graders import grade
@@ -9,12 +9,12 @@ class SupportTicketEnv:
9
  if task_id not in TASKS:
10
  raise ValueError(f"Unknown task_id: {task_id}")
11
  self.task_data = TASKS[task_id]
12
- self.state = None
13
  self.max_steps = 10
14
  self.reset()
15
 
16
  def reset(self) -> Observation:
17
- ticket_data = self.task_data["ticket"]
18
  self.state = EnvironmentState(
19
  current_task_id=self.task_id,
20
  step_count=0,
@@ -22,11 +22,12 @@ class SupportTicketEnv:
22
  action_history=[],
23
  is_done=False,
24
  final_reward=0.0,
25
- task_difficulty=self.task_data["difficulty"]
26
  )
27
  return self._get_observation("System initialized. Ticket assigned.")
28
 
29
  def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
 
30
  return Observation(
31
  ticket=self.state.ticket,
32
  available_actions=[
@@ -40,6 +41,8 @@ class SupportTicketEnv:
40
  )
41
 
42
  def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
 
 
43
  if self.state.is_done:
44
  return self._get_observation("Episode is over."), 0.0, True, {}
45
 
@@ -53,7 +56,8 @@ class SupportTicketEnv:
53
  if action.action_type == "fetch_user_data":
54
  user_id = action.parameters.get("user_id")
55
  if user_id == self.state.ticket.user_id:
56
- self.state.user_data = UserData(**self.task_data["user_data"])
 
57
  tool_output = f"User Data: Tier = {self.state.user_data.account_tier}, Joined = {self.state.user_data.join_date}"
58
  else:
59
  tool_output = "Error: Invalid user_id."
@@ -61,7 +65,8 @@ class SupportTicketEnv:
61
 
62
  elif action.action_type == "check_policy":
63
  issue_type = action.parameters.get("issue_type", self.state.ticket.issue_type)
64
- policy = self.task_data["policy"].get(issue_type, "No specific policy found.")
 
65
  tool_output = f"Policy for {issue_type}: {policy}"
66
 
67
  elif action.action_type == "issue_refund":
@@ -107,4 +112,5 @@ class SupportTicketEnv:
107
  return self._get_observation(system_message, tool_output), reward, self.state.is_done, info
108
 
109
  def get_state(self) -> EnvironmentState:
 
110
  return self.state
 
1
+ from typing import Tuple, Dict, Any, Optional, cast
2
  from .models import Action, Observation, EnvironmentState, TicketInfo, UserData
3
  from .tasks import TASKS
4
  from .graders import grade
 
9
  if task_id not in TASKS:
10
  raise ValueError(f"Unknown task_id: {task_id}")
11
  self.task_data = TASKS[task_id]
12
+ self.state: Optional[EnvironmentState] = None
13
  self.max_steps = 10
14
  self.reset()
15
 
16
  def reset(self) -> Observation:
17
+ ticket_data = cast(Dict[str, Any], self.task_data["ticket"])
18
  self.state = EnvironmentState(
19
  current_task_id=self.task_id,
20
  step_count=0,
 
22
  action_history=[],
23
  is_done=False,
24
  final_reward=0.0,
25
+ task_difficulty=str(self.task_data["difficulty"])
26
  )
27
  return self._get_observation("System initialized. Ticket assigned.")
28
 
29
  def _get_observation(self, system_message: str, tool_output: Optional[str] = None) -> Observation:
30
+ assert self.state is not None
31
  return Observation(
32
  ticket=self.state.ticket,
33
  available_actions=[
 
41
  )
42
 
43
  def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
44
+ assert self.state is not None
45
+
46
  if self.state.is_done:
47
  return self._get_observation("Episode is over."), 0.0, True, {}
48
 
 
56
  if action.action_type == "fetch_user_data":
57
  user_id = action.parameters.get("user_id")
58
  if user_id == self.state.ticket.user_id:
59
+ user_data = cast(Dict[str, Any], self.task_data["user_data"])
60
+ self.state.user_data = UserData(**user_data)
61
  tool_output = f"User Data: Tier = {self.state.user_data.account_tier}, Joined = {self.state.user_data.join_date}"
62
  else:
63
  tool_output = "Error: Invalid user_id."
 
65
 
66
  elif action.action_type == "check_policy":
67
  issue_type = action.parameters.get("issue_type", self.state.ticket.issue_type)
68
+ policy_map = cast(Dict[str, str], self.task_data["policy"])
69
+ policy = policy_map.get(issue_type, "No specific policy found.")
70
  tool_output = f"Policy for {issue_type}: {policy}"
71
 
72
  elif action.action_type == "issue_refund":
 
112
  return self._get_observation(system_message, tool_output), reward, self.state.is_done, info
113
 
114
  def get_state(self) -> EnvironmentState:
115
+ assert self.state is not None
116
  return self.state
inference.py CHANGED
@@ -131,10 +131,9 @@ async def run_task(task_id: str, client: OpenAI) -> None:
131
  try:
132
  obs = env.reset()
133
  last_echoed = obs.model_dump_json(indent=2)
134
- last_reward = 0.0
135
 
136
  for step in range(1, MAX_STEPS + 1):
137
- if env.state.is_done:
138
  break
139
 
140
  message = get_model_message(client, step, last_echoed, history)
@@ -149,7 +148,6 @@ async def run_task(task_id: str, client: OpenAI) -> None:
149
  rewards.append(actual_reward)
150
  steps_taken = step
151
  last_echoed = obs_json
152
- last_reward = actual_reward
153
 
154
  log_step(step=step, action=message, reward=actual_reward, done=done, error=error)
155
  history.append(f"Step {step}: {message!r} -> reward {actual_reward:+.2f}")
 
131
  try:
132
  obs = env.reset()
133
  last_echoed = obs.model_dump_json(indent=2)
 
134
 
135
  for step in range(1, MAX_STEPS + 1):
136
+ if env.get_state().is_done:
137
  break
138
 
139
  message = get_model_message(client, step, last_echoed, history)
 
148
  rewards.append(actual_reward)
149
  steps_taken = step
150
  last_echoed = obs_json
 
151
 
152
  log_step(step=step, action=message, reward=actual_reward, done=done, error=error)
153
  history.append(f"Step {step}: {message!r} -> reward {actual_reward:+.2f}")
tests/test_environment.py CHANGED
@@ -1,5 +1,3 @@
1
- import pytest
2
-
3
  from env.environment import SupportTicketEnv
4
  from env.models import Action
5
 
 
 
 
1
  from env.environment import SupportTicketEnv
2
  from env.models import Action
3
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff