Siteshcodes commited on
Commit
ca4e18e
Β·
1 Parent(s): 442df7c

upgrade environment.py: done guard, fix tasks_completed, sample_bug

Browse files
Files changed (1) hide show
  1. server/environment.py +79 -20
server/environment.py CHANGED
@@ -8,23 +8,41 @@ import random
8
  from typing import Optional
9
  from openenv.core.env_server.interfaces import Environment
10
  from model import TriageAction, TriageObservation, TriageState, BugReport
11
- from task import TASKS, grade_action
 
 
 
 
 
12
  class BugTriageEnvironment(Environment):
13
  """
14
  Bug Report Triage RL environment.
15
- Agent reads GitHub-style bug reports and must triage them
16
- by assigning priority, labels, team, and milestone.
 
 
 
 
 
 
17
  """
18
 
19
  def __init__(self):
20
  self._state = TriageState(episode_id=str(uuid.uuid4()))
21
  self._current_bug: Optional[BugReport] = None
22
  self._current_task_key: str = "easy"
 
 
 
 
 
23
 
24
  def reset(self) -> TriageObservation:
 
25
  self._state = TriageState(episode_id=str(uuid.uuid4()))
26
  self._current_task_key = "easy"
27
- self._current_bug = random.choice(TASKS["easy"]["bugs"])
 
28
 
29
  return TriageObservation(
30
  bug_report=self._current_bug,
@@ -35,36 +53,77 @@ class BugTriageEnvironment(Environment):
35
  reward=0.0,
36
  )
37
 
 
 
 
 
38
  def step(self, action: TriageAction) -> TriageObservation:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  self._state.step_count += 1
40
  task_key = self._current_task_key
41
 
42
- # Grade the action
43
- assert self._current_bug is not None, "step() called before reset()"
44
  score, feedback = grade_action(task_key, self._current_bug, action)
45
  self._state.total_score += score
 
46
 
47
- # Determine if episode is done
48
- # Each task = 1 step; easy β†’ medium β†’ hard β†’ done
49
- progression = ["easy", "medium", "hard"]
50
- current_idx = progression.index(task_key)
51
- done = current_idx == 2 # done after hard task
52
 
53
- if not done:
54
- next_key = progression[current_idx + 1]
55
- self._current_task_key = next_key
56
- self._current_bug = random.choice(TASKS[next_key]["bugs"])
57
- self._state.tasks_completed.append(task_key)
 
 
 
 
 
 
 
 
58
 
59
  return TriageObservation(
60
- bug_report=self._current_bug,
61
- task_id=self._current_task_key,
62
- score=self._state.total_score / max(self._state.step_count, 1),
63
  feedback=feedback,
64
  done=done,
65
- reward=score,
66
  )
67
 
 
 
 
 
68
  @property
69
  def state(self) -> TriageState:
 
 
 
 
 
70
  return self._state
 
8
  from typing import Optional
9
  from openenv.core.env_server.interfaces import Environment
10
  from model import TriageAction, TriageObservation, TriageState, BugReport
11
+ from task import TASKS, grade_action, sample_bug
12
+
13
+
14
+ TASK_PROGRESSION = ["easy", "medium", "hard"]
15
+
16
+
17
  class BugTriageEnvironment(Environment):
18
  """
19
  Bug Report Triage RL environment.
20
+
21
+ Episode structure:
22
+ Step 1 β†’ easy task (priority only)
23
+ Step 2 β†’ medium task (priority + labels + team)
24
+ Step 3 β†’ hard task (priority + labels + team + milestone)
25
+
26
+ Each reset() picks a fresh random bug from each task pool,
27
+ so the agent never sees the same sequence twice.
28
  """
29
 
30
  def __init__(self):
31
  self._state = TriageState(episode_id=str(uuid.uuid4()))
32
  self._current_bug: Optional[BugReport] = None
33
  self._current_task_key: str = "easy"
34
+ self._episode_done: bool = False
35
+
36
+ # ─────────────────────────────────────────
37
+ # reset()
38
+ # ─────────────────────────────────────────
39
 
40
  def reset(self) -> TriageObservation:
41
+ """Start a fresh episode. Picks a random bug from the easy pool."""
42
  self._state = TriageState(episode_id=str(uuid.uuid4()))
43
  self._current_task_key = "easy"
44
+ self._episode_done = False
45
+ self._current_bug = sample_bug("easy")
46
 
47
  return TriageObservation(
48
  bug_report=self._current_bug,
 
53
  reward=0.0,
54
  )
55
 
56
+ # ─────────────────────────────────────────
57
+ # step()
58
+ # ─────────────────────────────────────────
59
+
60
  def step(self, action: TriageAction) -> TriageObservation:
61
+ """
62
+ Process the agent's triage action and return the next observation.
63
+
64
+ - Grades the current task
65
+ - Advances to next task (easy β†’ medium β†’ hard)
66
+ - Returns done=True after the hard task is graded
67
+ """
68
+ # Guard: prevent stepping after episode is over
69
+ if self._episode_done:
70
+ assert self._current_bug is not None
71
+ return TriageObservation(
72
+ bug_report=self._current_bug,
73
+ task_id=self._current_task_key,
74
+ score=self._state.total_score / max(self._state.step_count, 1),
75
+ feedback="Episode already complete. Call reset() to start a new episode.",
76
+ done=True,
77
+ reward=0.0,
78
+ )
79
+
80
+ # Guard: step() must be called after reset()
81
+ assert self._current_bug is not None, "step() called before reset()"
82
+
83
  self._state.step_count += 1
84
  task_key = self._current_task_key
85
 
86
+ # Grade the action for this task
 
87
  score, feedback = grade_action(task_key, self._current_bug, action)
88
  self._state.total_score += score
89
+ self._state.tasks_completed.append(task_key)
90
 
91
+ # Determine progression
92
+ current_idx = TASK_PROGRESSION.index(task_key)
93
+ done = current_idx == len(TASK_PROGRESSION) - 1 # True after hard task
 
 
94
 
95
+ if done:
96
+ # Episode complete β€” keep current bug/task for final observation
97
+ self._episode_done = True
98
+ next_bug = self._current_bug
99
+ next_task = self._current_task_key
100
+ else:
101
+ # Advance to next task with a fresh random bug
102
+ next_task = TASK_PROGRESSION[current_idx + 1]
103
+ next_bug = sample_bug(next_task)
104
+ self._current_task_key = next_task
105
+ self._current_bug = next_bug
106
+
107
+ avg_score = self._state.total_score / self._state.step_count
108
 
109
  return TriageObservation(
110
+ bug_report=next_bug,
111
+ task_id=next_task,
112
+ score=round(avg_score, 3),
113
  feedback=feedback,
114
  done=done,
115
+ reward=round(score, 3),
116
  )
117
 
118
+ # ─────────────────────────────────────────
119
+ # state() β€” both property and method forms
120
+ # ─────────────────────────────────────────
121
+
122
  @property
123
  def state(self) -> TriageState:
124
+ """Property form β€” used internally."""
125
+ return self._state
126
+
127
+ def get_state(self) -> TriageState:
128
+ """Method form β€” satisfies OpenEnv spec's state() interface."""
129
  return self._state