Spaces:
Sleeping
Sleeping
feat: initial environment setup
Browse files
server/WhyDidItFail_environment.py
CHANGED
|
@@ -4,13 +4,10 @@
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
"""
|
| 8 |
-
Whydiditfail Environment Implementation.
|
| 9 |
-
|
| 10 |
-
A simple test environment that echoes back messages sent to it.
|
| 11 |
-
Perfect for testing HTTP server infrastructure.
|
| 12 |
-
"""
|
| 13 |
|
|
|
|
|
|
|
| 14 |
from uuid import uuid4
|
| 15 |
|
| 16 |
from openenv.core.env_server.interfaces import Environment
|
|
@@ -18,87 +15,97 @@ from openenv.core.env_server.types import State
|
|
| 18 |
|
| 19 |
try:
|
| 20 |
from ..models import WhyDidItFailAction, WhyDidItFailObservation
|
|
|
|
| 21 |
except ImportError:
|
| 22 |
from models import WhyDidItFailAction, WhyDidItFailObservation
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
-
class
|
| 26 |
-
"""
|
| 27 |
-
A simple echo environment that echoes back messages.
|
| 28 |
-
|
| 29 |
-
This environment is designed for testing the HTTP server infrastructure.
|
| 30 |
-
It maintains minimal state and simply echoes back whatever message it receives.
|
| 31 |
-
|
| 32 |
-
Example:
|
| 33 |
-
>>> env = WhydiditfailEnvironment()
|
| 34 |
-
>>> obs = env.reset()
|
| 35 |
-
>>> print(obs.echoed_message) # "Whydiditfail environment ready!"
|
| 36 |
-
>>>
|
| 37 |
-
>>> obs = env.step(WhyDidItFailAction(message="Hello"))
|
| 38 |
-
>>> print(obs.echoed_message) # "Hello"
|
| 39 |
-
>>> print(obs.message_length) # 5
|
| 40 |
-
"""
|
| 41 |
|
| 42 |
-
# Enable concurrent WebSocket sessions.
|
| 43 |
-
# Set to True if your environment isolates state between instances.
|
| 44 |
-
# When True, multiple WebSocket clients can connect simultaneously, each
|
| 45 |
-
# getting their own environment instance (when using factory mode in app.py).
|
| 46 |
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 47 |
|
| 48 |
def __init__(self):
|
| 49 |
-
"""Initialize the WhyDidItFail environment."""
|
| 50 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 51 |
-
self.
|
|
|
|
| 52 |
|
| 53 |
-
def reset(self) -> WhyDidItFailObservation:
|
| 54 |
-
"""
|
| 55 |
-
Reset the environment.
|
| 56 |
-
|
| 57 |
-
Returns:
|
| 58 |
-
WhyDidItFailObservation with a ready message
|
| 59 |
-
"""
|
| 60 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 61 |
-
self.
|
| 62 |
-
|
| 63 |
return WhyDidItFailObservation(
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
|
| 70 |
-
def step(self, action: WhyDidItFailAction) -> WhyDidItFailObservation:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
Args:
|
| 75 |
-
action: WhyDidItFailAction containing the message to echo
|
| 76 |
-
|
| 77 |
-
Returns:
|
| 78 |
-
WhyDidItFailObservation with the echoed message and its length
|
| 79 |
-
"""
|
| 80 |
self._state.step_count += 1
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
return WhyDidItFailObservation(
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
| 94 |
)
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
""
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
"""WhyDidItFail Environment Implementation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
import random
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
from uuid import uuid4
|
| 12 |
|
| 13 |
from openenv.core.env_server.interfaces import Environment
|
|
|
|
| 15 |
|
| 16 |
try:
|
| 17 |
from ..models import WhyDidItFailAction, WhyDidItFailObservation
|
| 18 |
+
from ..server.scenarios import SCENARIOS
|
| 19 |
except ImportError:
|
| 20 |
from models import WhyDidItFailAction, WhyDidItFailObservation
|
| 21 |
+
from server.scenarios import SCENARIOS
|
| 22 |
|
| 23 |
|
| 24 |
+
class WhyDidItFailEnvironment(Environment):
|
| 25 |
+
"""Diagnostic environment where the agent investigates a failed training run."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 28 |
|
| 29 |
def __init__(self):
|
|
|
|
| 30 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 31 |
+
self.scenario = None
|
| 32 |
+
self.inspected = set() # tracks what the agent has already looked at
|
| 33 |
|
| 34 |
+
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 36 |
+
self.scenario = random.choice(list(SCENARIOS.values()))
|
| 37 |
+
self.inspected = set()
|
| 38 |
return WhyDidItFailObservation(
|
| 39 |
+
task_description="A training run has failed. Diagnose the problem.",
|
| 40 |
+
visible_data={"hint": "Use inspect_logs or inspect_config to begin."},
|
| 41 |
+
available_actions=["inspect_logs","inspect_config",
|
| 42 |
+
"inspect_gradients","submit_diagnosis"],
|
| 43 |
+
steps_taken=0, reward=0.0, done=False,
|
| 44 |
+
feedback="Investigation started."
|
| 45 |
)
|
| 46 |
|
| 47 |
+
def step(self, action: WhyDidItFailAction, timeout_s: Optional[float] = None, **kwargs: Any) -> WhyDidItFailObservation:
|
| 48 |
+
if self.scenario is None:
|
| 49 |
+
raise RuntimeError("Environment must be reset before calling step.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
self._state.step_count += 1
|
| 51 |
|
| 52 |
+
if action.action_type == "inspect_logs":
|
| 53 |
+
self.inspected.add("logs")
|
| 54 |
+
visible = {"training_logs": self.scenario["logs"]}
|
| 55 |
+
feedback = "You examined the training logs."
|
| 56 |
+
|
| 57 |
+
elif action.action_type == "inspect_config":
|
| 58 |
+
self.inspected.add("config")
|
| 59 |
+
visible = {"config": self.scenario["config"]}
|
| 60 |
+
feedback = "You examined the hyperparameter config."
|
| 61 |
+
|
| 62 |
+
elif action.action_type == "inspect_gradients":
|
| 63 |
+
self.inspected.add("gradients")
|
| 64 |
+
visible = {"gradient_norms": self.scenario["gradient_norms"]}
|
| 65 |
+
feedback = "You examined gradient statistics."
|
| 66 |
+
|
| 67 |
+
elif action.action_type == "submit_diagnosis":
|
| 68 |
+
reward, feedback, done = self.grade(action)
|
| 69 |
+
return WhyDidItFailObservation(
|
| 70 |
+
task_description="Diagnosis submitted.",
|
| 71 |
+
visible_data={}, available_actions=[],
|
| 72 |
+
steps_taken=self._state.step_count,
|
| 73 |
+
reward=reward, done=True, feedback=feedback
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
else:
|
| 77 |
+
visible = {}
|
| 78 |
+
feedback = f"Unknown action '{action.action_type}'."
|
| 79 |
|
| 80 |
return WhyDidItFailObservation(
|
| 81 |
+
task_description="Continue your investigation.",
|
| 82 |
+
visible_data=visible,
|
| 83 |
+
available_actions=["inspect_logs","inspect_config",
|
| 84 |
+
"inspect_gradients","submit_diagnosis"],
|
| 85 |
+
steps_taken=self._state.step_count,
|
| 86 |
+
reward=0.0, done=False, feedback=feedback
|
| 87 |
)
|
| 88 |
|
| 89 |
+
def grade(self, action: WhyDidItFailAction) -> tuple[float, str, bool]:
|
| 90 |
+
"""Score a submit_diagnosis action against the current scenario."""
|
| 91 |
+
if self.scenario is None:
|
| 92 |
+
raise RuntimeError("Environment must be reset before calling grade.")
|
| 93 |
+
diagnosis = (action.diagnosis or "").strip().lower()
|
| 94 |
+
correct_diagnosis = self.scenario["correct_diagnosis"].strip().lower()
|
| 95 |
+
correct_fix = (self.scenario.get("correct_fix") or "").strip().lower()
|
| 96 |
+
suggested_fix = (action.suggested_fix or "").strip().lower()
|
| 97 |
+
|
| 98 |
+
diagnosis_correct = diagnosis == correct_diagnosis
|
| 99 |
+
fix_correct = suggested_fix == correct_fix if correct_fix else True
|
| 100 |
+
|
| 101 |
+
if diagnosis_correct and fix_correct:
|
| 102 |
+
reward = 1.0
|
| 103 |
+
feedback = "Correct diagnosis and fix!"
|
| 104 |
+
elif diagnosis_correct:
|
| 105 |
+
reward = 0.5
|
| 106 |
+
feedback = f"Correct diagnosis, but the suggested fix was wrong. Expected: '{self.scenario.get('correct_fix')}'."
|
| 107 |
+
else:
|
| 108 |
+
reward = 0.0
|
| 109 |
+
feedback = f"Incorrect diagnosis. The actual failure mode was '{self.scenario['correct_diagnosis']}'."
|
| 110 |
+
|
| 111 |
+
return reward, feedback, True
|