samrat-rm committed on
Commit
572e42a
·
1 Parent(s): d08def9

feat: initial environment setup

Browse files
Files changed (1) hide show
  1. server/WhyDidItFail_environment.py +77 -70
server/WhyDidItFail_environment.py CHANGED
@@ -4,13 +4,10 @@
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- """
8
- Whydiditfail Environment Implementation.
9
-
10
- A simple test environment that echoes back messages sent to it.
11
- Perfect for testing HTTP server infrastructure.
12
- """
13
 
 
 
14
  from uuid import uuid4
15
 
16
  from openenv.core.env_server.interfaces import Environment
@@ -18,87 +15,97 @@ from openenv.core.env_server.types import State
18
 
19
  try:
20
  from ..models import WhyDidItFailAction, WhyDidItFailObservation
 
21
  except ImportError:
22
  from models import WhyDidItFailAction, WhyDidItFailObservation
 
23
 
24
 
25
class WhydiditfailEnvironment(Environment):
    """
    Echo environment used to exercise the HTTP server infrastructure.

    Every step returns the message carried by the action together with its
    length; the only state kept is the episode ``State`` and a reset counter.

    Example:
        >>> env = WhydiditfailEnvironment()
        >>> obs = env.reset()
        >>> print(obs.echoed_message)  # "Whydiditfail environment ready!"
        >>>
        >>> obs = env.step(WhyDidItFailAction(message="Hello"))
        >>> print(obs.echoed_message)  # "Hello"
        >>> print(obs.message_length)  # 5
    """

    # Allow concurrent WebSocket sessions. Only appropriate when state is
    # isolated per instance: with this set to True several clients can connect
    # at once, each receiving its own environment instance (when app.py runs
    # in factory mode).
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create the initial episode state and zero the reset counter."""
        self._reset_count = 0
        self._state = State(episode_id=str(uuid4()), step_count=0)

    def reset(self) -> WhyDidItFailObservation:
        """
        Start a new episode.

        Returns:
            WhyDidItFailObservation carrying the ready message.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._reset_count += 1

        ready_observation = WhyDidItFailObservation(
            echoed_message="Whydiditfail environment ready!",
            message_length=0,
            done=False,
            reward=0.0,
        )
        return ready_observation

    def step(self, action: WhyDidItFailAction) -> WhyDidItFailObservation:  # type: ignore[override]
        """
        Echo the action's message back to the caller.

        Args:
            action: WhyDidItFailAction whose ``message`` is echoed.

        Returns:
            WhyDidItFailObservation with the echoed message, its length, and
            a reward proportional to that length.
        """
        self._state.step_count += 1

        text = action.message
        n_chars = len(text)

        return WhyDidItFailObservation(
            echoed_message=text,
            message_length=n_chars,
            done=False,
            # Reward grows linearly with message length.
            reward=n_chars * 0.1,
            metadata={"original_message": text, "step": self._state.step_count},
        )

    @property
    def state(self) -> State:
        """
        Current environment state.

        Returns:
            The State object holding episode_id and step_count.
        """
        return self._state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ """WhyDidItFail Environment Implementation."""
 
 
 
 
 
8
 
9
+ import random
10
+ from typing import Any, Optional
11
  from uuid import uuid4
12
 
13
  from openenv.core.env_server.interfaces import Environment
 
15
 
16
  try:
17
  from ..models import WhyDidItFailAction, WhyDidItFailObservation
18
+ from ..server.scenarios import SCENARIOS
19
  except ImportError:
20
  from models import WhyDidItFailAction, WhyDidItFailObservation
21
+ from server.scenarios import SCENARIOS
22
 
23
 
24
class WhyDidItFailEnvironment(Environment):
    """Diagnostic environment where the agent investigates a failed training run.

    Each episode selects one scenario from ``SCENARIOS``. The agent can inspect
    the scenario's logs, config and gradient norms, and ends the episode with a
    ``submit_diagnosis`` action, which is graded by case-insensitive exact match
    against the scenario's expected diagnosis (and fix, when one is defined).
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Actions advertised to the agent while an episode is in progress.
    _ACTIONS = (
        "inspect_logs",
        "inspect_config",
        "inspect_gradients",
        "submit_diagnosis",
    )

    # action_type -> (tag recorded in ``inspected``, key in visible_data,
    #                 key read from the scenario, feedback text)
    _INSPECTIONS = {
        "inspect_logs": (
            "logs",
            "training_logs",
            "logs",
            "You examined the training logs.",
        ),
        "inspect_config": (
            "config",
            "config",
            "config",
            "You examined the hyperparameter config.",
        ),
        "inspect_gradients": (
            "gradients",
            "gradient_norms",
            "gradient_norms",
            "You examined gradient statistics.",
        ),
    }

    def __init__(self):
        """Create the initial state; no scenario is active until ``reset``."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.scenario = None
        self.inspected = set()  # tracks what the agent has already looked at

    @property
    def state(self) -> State:
        """Return the current State object (episode_id and step_count)."""
        return self._state

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> WhyDidItFailObservation:
        """Start a new episode with a freshly chosen scenario.

        Args:
            seed: When given, scenario selection is reproducible for that seed;
                otherwise the module-level random generator is used.
            episode_id: Identifier for the new episode; a random UUID is
                generated when omitted.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The opening observation listing the available actions.
        """
        new_id = episode_id if episode_id is not None else str(uuid4())
        self._state = State(episode_id=new_id, step_count=0)

        # A private generator keeps seeded resets from disturbing global RNG state.
        rng = random.Random(seed) if seed is not None else random
        self.scenario = rng.choice(list(SCENARIOS.values()))
        self.inspected = set()

        return WhyDidItFailObservation(
            task_description="A training run has failed. Diagnose the problem.",
            visible_data={"hint": "Use inspect_logs or inspect_config to begin."},
            available_actions=list(self._ACTIONS),
            steps_taken=0,
            reward=0.0,
            done=False,
            feedback="Investigation started.",
        )

    def step(
        self,
        action: WhyDidItFailAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> WhyDidItFailObservation:
        """Apply one agent action and return the resulting observation.

        Args:
            action: The action to perform; ``action.action_type`` selects the
                behaviour.
            timeout_s: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            An observation with the requested data, or the graded terminal
            observation for ``submit_diagnosis``.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self.scenario is None:
            raise RuntimeError("Environment must be reset before calling step.")

        # Every call counts as a step, including unknown actions.
        self._state.step_count += 1

        if action.action_type == "submit_diagnosis":
            reward, feedback, _ = self.grade(action)
            return WhyDidItFailObservation(
                task_description="Diagnosis submitted.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=reward,
                done=True,
                feedback=feedback,
            )

        inspection = self._INSPECTIONS.get(action.action_type)
        if inspection is None:
            visible = {}
            feedback = f"Unknown action '{action.action_type}'."
        else:
            tag, visible_key, scenario_key, feedback = inspection
            self.inspected.add(tag)
            visible = {visible_key: self.scenario[scenario_key]}

        return WhyDidItFailObservation(
            task_description="Continue your investigation.",
            visible_data=visible,
            available_actions=list(self._ACTIONS),
            steps_taken=self._state.step_count,
            reward=0.0,
            done=False,
            feedback=feedback,
        )

    def grade(self, action: WhyDidItFailAction) -> tuple[float, str, bool]:
        """Score a submit_diagnosis action against the current scenario.

        Comparison is exact after stripping whitespace and lower-casing. A
        scenario without a ``correct_fix`` accepts any suggested fix.

        Returns:
            ``(reward, feedback, done)``: reward is 1.0 for a correct diagnosis
            and fix, 0.5 for a correct diagnosis only, 0.0 otherwise; ``done``
            is always True.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self.scenario is None:
            raise RuntimeError("Environment must be reset before calling grade.")

        diagnosis = (action.diagnosis or "").strip().lower()
        suggested_fix = (action.suggested_fix or "").strip().lower()
        correct_diagnosis = self.scenario["correct_diagnosis"].strip().lower()
        correct_fix = (self.scenario.get("correct_fix") or "").strip().lower()

        diagnosis_correct = diagnosis == correct_diagnosis
        fix_correct = not correct_fix or suggested_fix == correct_fix

        if diagnosis_correct and fix_correct:
            reward = 1.0
            feedback = "Correct diagnosis and fix!"
        elif diagnosis_correct:
            reward = 0.5
            feedback = f"Correct diagnosis, but the suggested fix was wrong. Expected: '{self.scenario.get('correct_fix')}'."
        else:
            reward = 0.0
            feedback = f"Incorrect diagnosis. The actual failure mode was '{self.scenario['correct_diagnosis']}'."

        return reward, feedback, True