Spaces:
Running
Running
Keep reset and schema rewards inside unit interval
Browse files- environment.py +7 -6
- inference.py +8 -7
- models.py +4 -4
- tests/test_env.py +3 -3
- tests/test_inference.py +4 -4
environment.py
CHANGED
|
@@ -39,6 +39,7 @@ TASK_SPECS = {
|
|
| 39 |
},
|
| 40 |
}
|
| 41 |
DEFAULT_RESET_SEED = 42
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def validate_ticket_dataset(tickets: list[dict]) -> None:
|
|
@@ -77,9 +78,9 @@ class IncidentEnv:
|
|
| 77 |
self.episode_id = ""
|
| 78 |
self.step_count = 0
|
| 79 |
self.max_steps = 1
|
| 80 |
-
self.total_reward =
|
| 81 |
self.done = False
|
| 82 |
-
self.last_reward =
|
| 83 |
self.last_action_summary = None
|
| 84 |
|
| 85 |
def reset(
|
|
@@ -92,14 +93,14 @@ class IncidentEnv:
|
|
| 92 |
self.current_ticket = self._select_ticket(normalized_task, ticket_id, seed)
|
| 93 |
self.episode_id = str(uuid.uuid4())
|
| 94 |
self.step_count = 0
|
| 95 |
-
self.total_reward =
|
| 96 |
self.done = False
|
| 97 |
-
self.last_reward =
|
| 98 |
self.last_action_summary = None
|
| 99 |
|
| 100 |
return StepResult(
|
| 101 |
observation=self._build_observation(),
|
| 102 |
-
reward=IncidentReward(value=
|
| 103 |
done=False,
|
| 104 |
info={
|
| 105 |
"episode_id": self.episode_id,
|
|
@@ -138,7 +139,7 @@ class IncidentEnv:
|
|
| 138 |
|
| 139 |
self.step_count += 1
|
| 140 |
self.last_reward = reward_value
|
| 141 |
-
self.total_reward
|
| 142 |
self.done = self.step_count >= self.max_steps
|
| 143 |
self.last_action_summary = f"Submitted {selected_field}={agent_answer}"
|
| 144 |
|
|
|
|
| 39 |
},
|
| 40 |
}
|
| 41 |
DEFAULT_RESET_SEED = 42
|
| 42 |
+
INITIAL_REWARD = 0.01
|
| 43 |
|
| 44 |
|
| 45 |
def validate_ticket_dataset(tickets: list[dict]) -> None:
|
|
|
|
| 78 |
self.episode_id = ""
|
| 79 |
self.step_count = 0
|
| 80 |
self.max_steps = 1
|
| 81 |
+
self.total_reward = INITIAL_REWARD
|
| 82 |
self.done = False
|
| 83 |
+
self.last_reward = INITIAL_REWARD
|
| 84 |
self.last_action_summary = None
|
| 85 |
|
| 86 |
def reset(
|
|
|
|
| 93 |
self.current_ticket = self._select_ticket(normalized_task, ticket_id, seed)
|
| 94 |
self.episode_id = str(uuid.uuid4())
|
| 95 |
self.step_count = 0
|
| 96 |
+
self.total_reward = INITIAL_REWARD
|
| 97 |
self.done = False
|
| 98 |
+
self.last_reward = INITIAL_REWARD
|
| 99 |
self.last_action_summary = None
|
| 100 |
|
| 101 |
return StepResult(
|
| 102 |
observation=self._build_observation(),
|
| 103 |
+
reward=IncidentReward(value=INITIAL_REWARD, reason="Episode initialized and awaiting first action."),
|
| 104 |
done=False,
|
| 105 |
info={
|
| 106 |
"episode_id": self.episode_id,
|
|
|
|
| 139 |
|
| 140 |
self.step_count += 1
|
| 141 |
self.last_reward = reward_value
|
| 142 |
+
self.total_reward = reward_value
|
| 143 |
self.done = self.step_count >= self.max_steps
|
| 144 |
self.last_action_summary = f"Submitted {selected_field}={agent_answer}"
|
| 145 |
|
inference.py
CHANGED
|
@@ -27,6 +27,7 @@ BENCHMARK = "incident-triage-env"
|
|
| 27 |
MAX_TOKENS = 300
|
| 28 |
TEMPERATURE = 0.0
|
| 29 |
OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH") or "/tmp/outputs/baseline_scores.json")
|
|
|
|
| 30 |
|
| 31 |
SYSTEM_PROMPT = """You are an expert SRE triaging production incidents.
|
| 32 |
You will receive an incident alert, structured context, and the expected output field.
|
|
@@ -356,8 +357,8 @@ def get_action(model_client: Optional[OpenAI], observation: Dict[str, Any]) -> D
|
|
| 356 |
def reward_value(step_data: Dict[str, Any]) -> float:
|
| 357 |
reward = step_data.get("reward", {})
|
| 358 |
if isinstance(reward, dict):
|
| 359 |
-
return float(reward.get("value",
|
| 360 |
-
return float(reward or
|
| 361 |
|
| 362 |
|
| 363 |
def active_model_name(model_client: Optional[OpenAI]) -> str:
|
|
@@ -379,7 +380,7 @@ def run_episode(
|
|
| 379 |
) -> Dict[str, Any]:
|
| 380 |
rewards: List[float] = []
|
| 381 |
steps_taken = 0
|
| 382 |
-
score =
|
| 383 |
success = False
|
| 384 |
episode_result: Dict[str, Any]
|
| 385 |
|
|
@@ -417,18 +418,18 @@ def run_episode(
|
|
| 417 |
"agent_answer": step_data.get("info", {}).get("agent_answer"),
|
| 418 |
}
|
| 419 |
except Exception as exc:
|
| 420 |
-
log_step(step=max(steps_taken, 1), action="error", reward=
|
| 421 |
-
score =
|
| 422 |
success = False
|
| 423 |
episode_result = {
|
| 424 |
"incident_id": ticket["incident_id"],
|
| 425 |
"task_type": ticket["task_type"],
|
| 426 |
-
"score":
|
| 427 |
"success": False,
|
| 428 |
"error": str(exc),
|
| 429 |
}
|
| 430 |
finally:
|
| 431 |
-
log_end(success=success, steps=max(steps_taken, 1), score=score, rewards=rewards or [
|
| 432 |
|
| 433 |
return episode_result
|
| 434 |
|
|
|
|
| 27 |
MAX_TOKENS = 300
|
| 28 |
TEMPERATURE = 0.0
|
| 29 |
OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH") or "/tmp/outputs/baseline_scores.json")
|
| 30 |
+
MIN_EPISODE_SCORE = 0.01
|
| 31 |
|
| 32 |
SYSTEM_PROMPT = """You are an expert SRE triaging production incidents.
|
| 33 |
You will receive an incident alert, structured context, and the expected output field.
|
|
|
|
| 357 |
def reward_value(step_data: Dict[str, Any]) -> float:
|
| 358 |
reward = step_data.get("reward", {})
|
| 359 |
if isinstance(reward, dict):
|
| 360 |
+
return float(reward.get("value", MIN_EPISODE_SCORE))
|
| 361 |
+
return float(reward or MIN_EPISODE_SCORE)
|
| 362 |
|
| 363 |
|
| 364 |
def active_model_name(model_client: Optional[OpenAI]) -> str:
|
|
|
|
| 380 |
) -> Dict[str, Any]:
|
| 381 |
rewards: List[float] = []
|
| 382 |
steps_taken = 0
|
| 383 |
+
score = MIN_EPISODE_SCORE
|
| 384 |
success = False
|
| 385 |
episode_result: Dict[str, Any]
|
| 386 |
|
|
|
|
| 418 |
"agent_answer": step_data.get("info", {}).get("agent_answer"),
|
| 419 |
}
|
| 420 |
except Exception as exc:
|
| 421 |
+
log_step(step=max(steps_taken, 1), action="error", reward=MIN_EPISODE_SCORE, done=True, error=str(exc))
|
| 422 |
+
score = MIN_EPISODE_SCORE
|
| 423 |
success = False
|
| 424 |
episode_result = {
|
| 425 |
"incident_id": ticket["incident_id"],
|
| 426 |
"task_type": ticket["task_type"],
|
| 427 |
+
"score": MIN_EPISODE_SCORE,
|
| 428 |
"success": False,
|
| 429 |
"error": str(exc),
|
| 430 |
}
|
| 431 |
finally:
|
| 432 |
+
log_end(success=success, steps=max(steps_taken, 1), score=score, rewards=rewards or [MIN_EPISODE_SCORE])
|
| 433 |
|
| 434 |
return episode_result
|
| 435 |
|
models.py
CHANGED
|
@@ -47,7 +47,7 @@ class IncidentObservation(BaseModel):
|
|
| 47 |
step_count: int = 0
|
| 48 |
max_steps: int = 1
|
| 49 |
last_action_summary: Optional[str] = None
|
| 50 |
-
last_reward: float = 0.0
|
| 51 |
episode_status: str = "awaiting_action"
|
| 52 |
|
| 53 |
|
|
@@ -82,7 +82,7 @@ class IncidentAction(BaseModel):
|
|
| 82 |
|
| 83 |
|
| 84 |
class IncidentReward(BaseModel):
|
| 85 |
-
value: float = Field(...,
|
| 86 |
reason: str
|
| 87 |
|
| 88 |
|
|
@@ -98,13 +98,13 @@ class IncidentState(BaseModel):
|
|
| 98 |
session_id: Optional[str] = None
|
| 99 |
step_count: int
|
| 100 |
max_steps: int
|
| 101 |
-
total_reward: float = 0.0
|
| 102 |
done: bool
|
| 103 |
incident_id: str
|
| 104 |
task_type: TaskType
|
| 105 |
difficulty: str
|
| 106 |
status: str
|
| 107 |
-
last_reward: float = 0.0
|
| 108 |
|
| 109 |
|
| 110 |
class ResetRequest(BaseModel):
|
|
|
|
| 47 |
step_count: int = 0
|
| 48 |
max_steps: int = 1
|
| 49 |
last_action_summary: Optional[str] = None
|
| 50 |
+
last_reward: float = Field(default=0.01, gt=0.0, lt=1.0)
|
| 51 |
episode_status: str = "awaiting_action"
|
| 52 |
|
| 53 |
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
class IncidentReward(BaseModel):
|
| 85 |
+
value: float = Field(..., gt=0.0, lt=1.0)
|
| 86 |
reason: str
|
| 87 |
|
| 88 |
|
|
|
|
| 98 |
session_id: Optional[str] = None
|
| 99 |
step_count: int
|
| 100 |
max_steps: int
|
| 101 |
+
total_reward: float = Field(default=0.01, gt=0.0, lt=1.0)
|
| 102 |
done: bool
|
| 103 |
incident_id: str
|
| 104 |
task_type: TaskType
|
| 105 |
difficulty: str
|
| 106 |
status: str
|
| 107 |
+
last_reward: float = Field(default=0.01, gt=0.0, lt=1.0)
|
| 108 |
|
| 109 |
|
| 110 |
class ResetRequest(BaseModel):
|
tests/test_env.py
CHANGED
|
@@ -82,7 +82,7 @@ class IncidentEnvApiTests(unittest.TestCase):
|
|
| 82 |
body = response.json()
|
| 83 |
self.assertEqual(body["observation"]["incident_id"], "INC-014")
|
| 84 |
self.assertEqual(body["observation"]["task_type"], "task3")
|
| 85 |
-
self.assertEqual(body["reward"]["value"], 0.
|
| 86 |
self.assertFalse(body["done"])
|
| 87 |
self.assertIn("session_id", body["info"])
|
| 88 |
self.assertEqual(body["info"]["state"]["status"], "awaiting_action")
|
|
@@ -201,13 +201,13 @@ class IncidentEnvApiTests(unittest.TestCase):
|
|
| 201 |
session_id="done-session",
|
| 202 |
step_count=1,
|
| 203 |
max_steps=1,
|
| 204 |
-
total_reward=
|
| 205 |
done=True,
|
| 206 |
incident_id="INC-001",
|
| 207 |
task_type=TaskType.TASK1,
|
| 208 |
difficulty="easy",
|
| 209 |
status="completed",
|
| 210 |
-
last_reward=
|
| 211 |
)
|
| 212 |
|
| 213 |
with TestClient(app) as client:
|
|
|
|
| 82 |
body = response.json()
|
| 83 |
self.assertEqual(body["observation"]["incident_id"], "INC-014")
|
| 84 |
self.assertEqual(body["observation"]["task_type"], "task3")
|
| 85 |
+
self.assertEqual(body["reward"]["value"], 0.01)
|
| 86 |
self.assertFalse(body["done"])
|
| 87 |
self.assertIn("session_id", body["info"])
|
| 88 |
self.assertEqual(body["info"]["state"]["status"], "awaiting_action")
|
|
|
|
| 201 |
session_id="done-session",
|
| 202 |
step_count=1,
|
| 203 |
max_steps=1,
|
| 204 |
+
total_reward=0.99,
|
| 205 |
done=True,
|
| 206 |
incident_id="INC-001",
|
| 207 |
task_type=TaskType.TASK1,
|
| 208 |
difficulty="easy",
|
| 209 |
status="completed",
|
| 210 |
+
last_reward=0.99,
|
| 211 |
)
|
| 212 |
|
| 213 |
with TestClient(app) as client:
|
tests/test_inference.py
CHANGED
|
@@ -19,7 +19,7 @@ class InferenceOutputTests(unittest.TestCase):
|
|
| 19 |
|
| 20 |
def test_write_results_writes_summary_to_configured_path(self) -> None:
|
| 21 |
results = [
|
| 22 |
-
{"incident_id": "INC-001", "task_type": "task1", "score":
|
| 23 |
{"incident_id": "INC-002", "task_type": "task2", "score": 0.5, "success": False},
|
| 24 |
]
|
| 25 |
|
|
@@ -30,13 +30,13 @@ class InferenceOutputTests(unittest.TestCase):
|
|
| 30 |
self.assertTrue(output_path.exists())
|
| 31 |
payload = json.loads(output_path.read_text())
|
| 32 |
self.assertEqual(payload["episodes"], 2)
|
| 33 |
-
self.assertAlmostEqual(payload["average_score"], 0.
|
| 34 |
-
self.assertEqual(payload["by_task"]["task1"]["average_score"],
|
| 35 |
self.assertEqual(payload["by_task"]["task2"]["average_score"], 0.5)
|
| 36 |
|
| 37 |
def test_write_results_tolerates_unwritable_path(self) -> None:
|
| 38 |
results = [
|
| 39 |
-
{"incident_id": "INC-001", "task_type": "task1", "score":
|
| 40 |
]
|
| 41 |
|
| 42 |
with tempfile.TemporaryDirectory() as temp_dir:
|
|
|
|
| 19 |
|
| 20 |
def test_write_results_writes_summary_to_configured_path(self) -> None:
|
| 21 |
results = [
|
| 22 |
+
{"incident_id": "INC-001", "task_type": "task1", "score": 0.99, "success": True},
|
| 23 |
{"incident_id": "INC-002", "task_type": "task2", "score": 0.5, "success": False},
|
| 24 |
]
|
| 25 |
|
|
|
|
| 30 |
self.assertTrue(output_path.exists())
|
| 31 |
payload = json.loads(output_path.read_text())
|
| 32 |
self.assertEqual(payload["episodes"], 2)
|
| 33 |
+
self.assertAlmostEqual(payload["average_score"], 0.745)
|
| 34 |
+
self.assertEqual(payload["by_task"]["task1"]["average_score"], 0.99)
|
| 35 |
self.assertEqual(payload["by_task"]["task2"]["average_score"], 0.5)
|
| 36 |
|
| 37 |
def test_write_results_tolerates_unwritable_path(self) -> None:
|
| 38 |
results = [
|
| 39 |
+
{"incident_id": "INC-001", "task_type": "task1", "score": 0.99, "success": True},
|
| 40 |
]
|
| 41 |
|
| 42 |
with tempfile.TemporaryDirectory() as temp_dir:
|