XcodeAddy committed on
Commit
af2ccc5
·
1 Parent(s): 18aa055

Keep reset and schema rewards inside unit interval

Browse files
Files changed (5) hide show
  1. environment.py +7 -6
  2. inference.py +8 -7
  3. models.py +4 -4
  4. tests/test_env.py +3 -3
  5. tests/test_inference.py +4 -4
environment.py CHANGED
@@ -39,6 +39,7 @@ TASK_SPECS = {
39
  },
40
  }
41
  DEFAULT_RESET_SEED = 42
 
42
 
43
 
44
  def validate_ticket_dataset(tickets: list[dict]) -> None:
@@ -77,9 +78,9 @@ class IncidentEnv:
77
  self.episode_id = ""
78
  self.step_count = 0
79
  self.max_steps = 1
80
- self.total_reward = 0.0
81
  self.done = False
82
- self.last_reward = 0.0
83
  self.last_action_summary = None
84
 
85
  def reset(
@@ -92,14 +93,14 @@ class IncidentEnv:
92
  self.current_ticket = self._select_ticket(normalized_task, ticket_id, seed)
93
  self.episode_id = str(uuid.uuid4())
94
  self.step_count = 0
95
- self.total_reward = 0.0
96
  self.done = False
97
- self.last_reward = 0.0
98
  self.last_action_summary = None
99
 
100
  return StepResult(
101
  observation=self._build_observation(),
102
- reward=IncidentReward(value=0.0, reason="Episode initialized."),
103
  done=False,
104
  info={
105
  "episode_id": self.episode_id,
@@ -138,7 +139,7 @@ class IncidentEnv:
138
 
139
  self.step_count += 1
140
  self.last_reward = reward_value
141
- self.total_reward += reward_value
142
  self.done = self.step_count >= self.max_steps
143
  self.last_action_summary = f"Submitted {selected_field}={agent_answer}"
144
 
 
39
  },
40
  }
41
  DEFAULT_RESET_SEED = 42
42
+ INITIAL_REWARD = 0.01
43
 
44
 
45
  def validate_ticket_dataset(tickets: list[dict]) -> None:
 
78
  self.episode_id = ""
79
  self.step_count = 0
80
  self.max_steps = 1
81
+ self.total_reward = INITIAL_REWARD
82
  self.done = False
83
+ self.last_reward = INITIAL_REWARD
84
  self.last_action_summary = None
85
 
86
  def reset(
 
93
  self.current_ticket = self._select_ticket(normalized_task, ticket_id, seed)
94
  self.episode_id = str(uuid.uuid4())
95
  self.step_count = 0
96
+ self.total_reward = INITIAL_REWARD
97
  self.done = False
98
+ self.last_reward = INITIAL_REWARD
99
  self.last_action_summary = None
100
 
101
  return StepResult(
102
  observation=self._build_observation(),
103
+ reward=IncidentReward(value=INITIAL_REWARD, reason="Episode initialized and awaiting first action."),
104
  done=False,
105
  info={
106
  "episode_id": self.episode_id,
 
139
 
140
  self.step_count += 1
141
  self.last_reward = reward_value
142
+ self.total_reward = reward_value
143
  self.done = self.step_count >= self.max_steps
144
  self.last_action_summary = f"Submitted {selected_field}={agent_answer}"
145
 
inference.py CHANGED
@@ -27,6 +27,7 @@ BENCHMARK = "incident-triage-env"
27
  MAX_TOKENS = 300
28
  TEMPERATURE = 0.0
29
  OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH") or "/tmp/outputs/baseline_scores.json")
 
30
 
31
  SYSTEM_PROMPT = """You are an expert SRE triaging production incidents.
32
  You will receive an incident alert, structured context, and the expected output field.
@@ -356,8 +357,8 @@ def get_action(model_client: Optional[OpenAI], observation: Dict[str, Any]) -> D
356
  def reward_value(step_data: Dict[str, Any]) -> float:
357
  reward = step_data.get("reward", {})
358
  if isinstance(reward, dict):
359
- return float(reward.get("value", 0.0))
360
- return float(reward or 0.0)
361
 
362
 
363
  def active_model_name(model_client: Optional[OpenAI]) -> str:
@@ -379,7 +380,7 @@ def run_episode(
379
  ) -> Dict[str, Any]:
380
  rewards: List[float] = []
381
  steps_taken = 0
382
- score = 0.0
383
  success = False
384
  episode_result: Dict[str, Any]
385
 
@@ -417,18 +418,18 @@ def run_episode(
417
  "agent_answer": step_data.get("info", {}).get("agent_answer"),
418
  }
419
  except Exception as exc:
420
- log_step(step=max(steps_taken, 1), action="error", reward=0.0, done=True, error=str(exc))
421
- score = 0.0
422
  success = False
423
  episode_result = {
424
  "incident_id": ticket["incident_id"],
425
  "task_type": ticket["task_type"],
426
- "score": 0.0,
427
  "success": False,
428
  "error": str(exc),
429
  }
430
  finally:
431
- log_end(success=success, steps=max(steps_taken, 1), score=score, rewards=rewards or [0.0])
432
 
433
  return episode_result
434
 
 
27
  MAX_TOKENS = 300
28
  TEMPERATURE = 0.0
29
  OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH") or "/tmp/outputs/baseline_scores.json")
30
+ MIN_EPISODE_SCORE = 0.01
31
 
32
  SYSTEM_PROMPT = """You are an expert SRE triaging production incidents.
33
  You will receive an incident alert, structured context, and the expected output field.
 
357
  def reward_value(step_data: Dict[str, Any]) -> float:
358
  reward = step_data.get("reward", {})
359
  if isinstance(reward, dict):
360
+ return float(reward.get("value", MIN_EPISODE_SCORE))
361
+ return float(reward or MIN_EPISODE_SCORE)
362
 
363
 
364
  def active_model_name(model_client: Optional[OpenAI]) -> str:
 
380
  ) -> Dict[str, Any]:
381
  rewards: List[float] = []
382
  steps_taken = 0
383
+ score = MIN_EPISODE_SCORE
384
  success = False
385
  episode_result: Dict[str, Any]
386
 
 
418
  "agent_answer": step_data.get("info", {}).get("agent_answer"),
419
  }
420
  except Exception as exc:
421
+ log_step(step=max(steps_taken, 1), action="error", reward=MIN_EPISODE_SCORE, done=True, error=str(exc))
422
+ score = MIN_EPISODE_SCORE
423
  success = False
424
  episode_result = {
425
  "incident_id": ticket["incident_id"],
426
  "task_type": ticket["task_type"],
427
+ "score": MIN_EPISODE_SCORE,
428
  "success": False,
429
  "error": str(exc),
430
  }
431
  finally:
432
+ log_end(success=success, steps=max(steps_taken, 1), score=score, rewards=rewards or [MIN_EPISODE_SCORE])
433
 
434
  return episode_result
435
 
models.py CHANGED
@@ -47,7 +47,7 @@ class IncidentObservation(BaseModel):
47
  step_count: int = 0
48
  max_steps: int = 1
49
  last_action_summary: Optional[str] = None
50
- last_reward: float = 0.0
51
  episode_status: str = "awaiting_action"
52
 
53
 
@@ -82,7 +82,7 @@ class IncidentAction(BaseModel):
82
 
83
 
84
  class IncidentReward(BaseModel):
85
- value: float = Field(..., ge=0.0, le=1.0)
86
  reason: str
87
 
88
 
@@ -98,13 +98,13 @@ class IncidentState(BaseModel):
98
  session_id: Optional[str] = None
99
  step_count: int
100
  max_steps: int
101
- total_reward: float = 0.0
102
  done: bool
103
  incident_id: str
104
  task_type: TaskType
105
  difficulty: str
106
  status: str
107
- last_reward: float = 0.0
108
 
109
 
110
  class ResetRequest(BaseModel):
 
47
  step_count: int = 0
48
  max_steps: int = 1
49
  last_action_summary: Optional[str] = None
50
+ last_reward: float = Field(default=0.01, gt=0.0, lt=1.0)
51
  episode_status: str = "awaiting_action"
52
 
53
 
 
82
 
83
 
84
  class IncidentReward(BaseModel):
85
+ value: float = Field(..., gt=0.0, lt=1.0)
86
  reason: str
87
 
88
 
 
98
  session_id: Optional[str] = None
99
  step_count: int
100
  max_steps: int
101
+ total_reward: float = Field(default=0.01, gt=0.0, lt=1.0)
102
  done: bool
103
  incident_id: str
104
  task_type: TaskType
105
  difficulty: str
106
  status: str
107
+ last_reward: float = Field(default=0.01, gt=0.0, lt=1.0)
108
 
109
 
110
  class ResetRequest(BaseModel):
tests/test_env.py CHANGED
@@ -82,7 +82,7 @@ class IncidentEnvApiTests(unittest.TestCase):
82
  body = response.json()
83
  self.assertEqual(body["observation"]["incident_id"], "INC-014")
84
  self.assertEqual(body["observation"]["task_type"], "task3")
85
- self.assertEqual(body["reward"]["value"], 0.0)
86
  self.assertFalse(body["done"])
87
  self.assertIn("session_id", body["info"])
88
  self.assertEqual(body["info"]["state"]["status"], "awaiting_action")
@@ -201,13 +201,13 @@ class IncidentEnvApiTests(unittest.TestCase):
201
  session_id="done-session",
202
  step_count=1,
203
  max_steps=1,
204
- total_reward=1.0,
205
  done=True,
206
  incident_id="INC-001",
207
  task_type=TaskType.TASK1,
208
  difficulty="easy",
209
  status="completed",
210
- last_reward=1.0,
211
  )
212
 
213
  with TestClient(app) as client:
 
82
  body = response.json()
83
  self.assertEqual(body["observation"]["incident_id"], "INC-014")
84
  self.assertEqual(body["observation"]["task_type"], "task3")
85
+ self.assertEqual(body["reward"]["value"], 0.01)
86
  self.assertFalse(body["done"])
87
  self.assertIn("session_id", body["info"])
88
  self.assertEqual(body["info"]["state"]["status"], "awaiting_action")
 
201
  session_id="done-session",
202
  step_count=1,
203
  max_steps=1,
204
+ total_reward=0.99,
205
  done=True,
206
  incident_id="INC-001",
207
  task_type=TaskType.TASK1,
208
  difficulty="easy",
209
  status="completed",
210
+ last_reward=0.99,
211
  )
212
 
213
  with TestClient(app) as client:
tests/test_inference.py CHANGED
@@ -19,7 +19,7 @@ class InferenceOutputTests(unittest.TestCase):
19
 
20
  def test_write_results_writes_summary_to_configured_path(self) -> None:
21
  results = [
22
- {"incident_id": "INC-001", "task_type": "task1", "score": 1.0, "success": True},
23
  {"incident_id": "INC-002", "task_type": "task2", "score": 0.5, "success": False},
24
  ]
25
 
@@ -30,13 +30,13 @@ class InferenceOutputTests(unittest.TestCase):
30
  self.assertTrue(output_path.exists())
31
  payload = json.loads(output_path.read_text())
32
  self.assertEqual(payload["episodes"], 2)
33
- self.assertAlmostEqual(payload["average_score"], 0.75)
34
- self.assertEqual(payload["by_task"]["task1"]["average_score"], 1.0)
35
  self.assertEqual(payload["by_task"]["task2"]["average_score"], 0.5)
36
 
37
  def test_write_results_tolerates_unwritable_path(self) -> None:
38
  results = [
39
- {"incident_id": "INC-001", "task_type": "task1", "score": 1.0, "success": True},
40
  ]
41
 
42
  with tempfile.TemporaryDirectory() as temp_dir:
 
19
 
20
  def test_write_results_writes_summary_to_configured_path(self) -> None:
21
  results = [
22
+ {"incident_id": "INC-001", "task_type": "task1", "score": 0.99, "success": True},
23
  {"incident_id": "INC-002", "task_type": "task2", "score": 0.5, "success": False},
24
  ]
25
 
 
30
  self.assertTrue(output_path.exists())
31
  payload = json.loads(output_path.read_text())
32
  self.assertEqual(payload["episodes"], 2)
33
+ self.assertAlmostEqual(payload["average_score"], 0.745)
34
+ self.assertEqual(payload["by_task"]["task1"]["average_score"], 0.99)
35
  self.assertEqual(payload["by_task"]["task2"]["average_score"], 0.5)
36
 
37
  def test_write_results_tolerates_unwritable_path(self) -> None:
38
  results = [
39
+ {"incident_id": "INC-001", "task_type": "task1", "score": 0.99, "success": True},
40
  ]
41
 
42
  with tempfile.TemporaryDirectory() as temp_dir: