File size: 9,820 Bytes
572e42a
 
b37875f
 
 
 
 
ff8ce5f
04666da
e216a2f
 
b37875f
572e42a
b37875f
 
 
 
e216a2f
a818334
1288c52
b37875f
3613ecf
ff8ce5f
 
 
 
 
 
 
 
 
 
3613ecf
572e42a
a0518e7
a818334
e216a2f
 
1288c52
e216a2f
 
 
 
 
 
1288c52
 
 
 
87037e2
e216a2f
 
 
 
 
 
 
 
d3b224f
e216a2f
 
b37875f
 
572e42a
 
 
e216a2f
b37875f
1288c52
d3b224f
1288c52
 
 
 
 
 
d3b224f
1288c52
 
 
d3b224f
1288c52
 
e216a2f
b37875f
572e42a
e216a2f
a818334
 
e216a2f
 
 
 
 
 
 
 
 
572e42a
 
e216a2f
a818334
 
e216a2f
 
 
 
 
 
 
 
 
572e42a
 
e216a2f
a818334
 
e216a2f
 
 
 
 
 
 
 
 
572e42a
 
e216a2f
572e42a
 
e216a2f
 
572e42a
e216a2f
 
 
572e42a
 
 
e216a2f
 
 
 
 
d3b224f
e216a2f
d3b224f
e216a2f
b37875f
a818334
d3b224f
 
a818334
e216a2f
a818334
 
d3b224f
 
 
 
a818334
 
d3b224f
a818334
e216a2f
a818334
 
 
 
d3b224f
e216a2f
 
 
a818334
e216a2f
a818334
d29cfdb
a818334
d29cfdb
 
 
a818334
 
e216a2f
 
 
 
a818334
e216a2f
a818334
e216a2f
 
 
 
 
572e42a
a818334
e216a2f
b37875f
 
e216a2f
 
 
 
572e42a
e216a2f
572e42a
e216a2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import random
from typing import Any, Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import WhyDidItFailAction, WhyDidItFailObservation, WhyDidItFailState
from server.scenarios import SCENARIOS
from server.graders import grade


class WhyDidItFailEnvironment(Environment):
    """Single-episode diagnostic game: a training run has failed and the agent
    must inspect evidence sources (logs / config / gradients) and submit a
    root-cause diagnosis before the step budget runs out.

    All rewards lie in [0.10, 0.90]: inspecting a still-unseen *required*
    source pays a decaying discovery bonus, everything else (re-inspection,
    irrelevant source, unknown action) pays the 0.10 floor, and the final
    diagnosis is scored by the shared ``grade()`` function.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Evidence source -> (key used in visible_data, key in the scenario dict).
    _SOURCE_DATA: dict[str, tuple[str, str]] = {
        "logs": ("training_logs", "logs"),
        "config": ("config", "config"),
        "gradients": ("gradient_norms", "gradient_norms"),
    }

    # The action menu is constant while the episode is live; copied into each
    # observation so callers cannot mutate the shared class attribute.
    _ACTIONS: list[str] = ["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"]

    # Rewards decay as more required sources are discovered — first clue is worth most.
    # All values are in [0.10, 0.90] — no negative rewards.
    _REQUIRED_STEP_REWARDS = [0.50, 0.30, 0.15]

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.scenario: dict | None = None
        self.inspection_order: list[str] = []  # first-visit order; doubles as membership check
        self.max_steps: int = 0

    @property
    def state(self) -> WhyDidItFailState:
        """Serializable snapshot of the episode (mutable lists are copied)."""
        return WhyDidItFailState(
            episode_id=self._state.episode_id,
            step_count=self._state.step_count,
            scenario_key=self.scenario.get("failure_mode") if self.scenario else None,
            difficulty=self.scenario.get("difficulty") if self.scenario else None,
            inspection_order=list(self.inspection_order),
            required_sources=list(self.scenario.get("required_sources", [])) if self.scenario else [],
            max_steps=self.max_steps,
        )

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any) -> WhyDidItFailObservation:
        """Start a new episode and return the initial observation.

        ``kwargs`` may carry ``scenario_key`` to pin a specific scenario;
        otherwise one is drawn at random (deterministic when ``seed`` is given).
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self.inspection_order = []

        scenario_key = kwargs.get("scenario_key")

        if scenario_key and scenario_key in SCENARIOS:
            self.scenario = SCENARIOS[scenario_key]
        else:
            # Use a private RNG rather than random.seed(): with
            # SUPPORTS_CONCURRENT_SESSIONS, seeding the module-global RNG
            # would leak determinism across sessions. random.Random(seed)
            # yields the exact same choice sequence as random.seed(seed).
            rng = random.Random(seed) if seed is not None else random
            self.scenario = rng.choice(list(SCENARIOS.values()))

        required_sources = self.scenario.get("required_sources", ["logs"])
        # Budget: three visits per required source plus slack for the submission.
        self.max_steps = len(required_sources) * 3 + 2

        return WhyDidItFailObservation(
            task_description=(
                "A training run has failed. Diagnose the root cause.\n"
                f"Difficulty: {self.scenario['difficulty']}. "
                "Available actions: inspect_logs, inspect_config, inspect_gradients, submit_diagnosis."
            ),
            visible_data={"hint": "Start by inspecting the training logs."},
            available_actions=list(self._ACTIONS),
            steps_taken=0,
            reward=0.10,
            done=False,
            feedback="Investigation started.",
        )

    def step(self, action: WhyDidItFailAction, timeout_s: Optional[float] = None, **kwargs: Any) -> WhyDidItFailObservation:
        """Advance the episode by one action.

        Raises:
            RuntimeError: if called before :meth:`reset`.
        """
        if self.scenario is None:
            raise RuntimeError("Environment must be reset before calling step.")

        self._state.step_count += 1

        # Hard step limit — terminate immediately, grade() will return 0.10.
        if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
            return WhyDidItFailObservation(
                task_description="Step limit reached. Episode terminated.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=0.10,
                done=True,
                feedback=(
                    f"Step limit ({self.max_steps}) reached without a diagnosis. "
                    f"Score: 0.10. Actual failure: '{self.scenario['correct_diagnosis']}'."
                ),
            )

        required: list[str] = self.scenario.get("required_sources", ["logs"])

        if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
            return self._handle_inspect(action.action_type.removeprefix("inspect_"), required)

        if action.action_type == "submit_diagnosis":
            final_reward, feedback = self._grade(action)
            return WhyDidItFailObservation(
                task_description="Diagnosis submitted.",
                visible_data={},
                available_actions=[],
                steps_taken=self._state.step_count,
                reward=final_reward,
                done=True,
                feedback=feedback,
            )

        return WhyDidItFailObservation(
            task_description="Continue your investigation.",
            visible_data={},
            available_actions=list(self._ACTIONS),
            steps_taken=self._state.step_count,
            reward=0.10,
            done=False,
            feedback=f"Unknown action '{action.action_type}'. Minimum reward.",
        )

    def _handle_inspect(self, source: str, required: list[str]) -> WhyDidItFailObservation:
        """Score one inspect_* action, describe it, then record the visit.

        BUGFIX: feedback is computed BEFORE the visit is recorded. Previously
        the source was appended to ``inspection_order`` first, so
        ``_inspect_feedback``'s "already examined" branch fired on every
        first-time discovery, contradicting the discovery reward being paid.
        """
        step_reward = self._inspect_reward(source, required)
        feedback = self._inspect_feedback(source, required, step_reward)
        if source not in self.inspection_order:
            self.inspection_order.append(source)

        obs_key, scenario_key = self._SOURCE_DATA[source]
        return WhyDidItFailObservation(
            task_description="Continue your investigation.",
            visible_data={obs_key: self.scenario[scenario_key]},
            available_actions=list(self._ACTIONS),
            steps_taken=self._state.step_count,
            reward=step_reward,
            done=False,
            feedback=feedback,
        )

    def _inspect_reward(self, source: str, required: list[str]) -> float:
        """Return step reward for an inspect action.

        Required sources:   progressive — 0.50 / 0.30 / 0.15 for 1st/2nd/3rd discovery.
        Irrelevant sources: 0.10 (minimum; mild penalty via contrast with required rewards).
        Re-inspection:      0.10 (minimum; waste with no new information).
        All values are strictly in [0.10, 0.90].
        """
        if source in self.inspection_order:
            return 0.10   # redundant inspection — minimum reward

        if source in required:
            n_found = sum(1 for s in self.inspection_order if s in required)
            idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
            return self._REQUIRED_STEP_REWARDS[idx]

        return 0.10       # irrelevant source — minimum reward

    def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
        """Human-readable feedback for an inspect action.

        Must be called *before* ``source`` is added to ``inspection_order`` —
        the first check distinguishes re-inspection from first discovery.
        """
        label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
        if source in self.inspection_order:
            return f"You already examined the {label}. No new information gained."
        if source in required:
            remaining_sources = [s for s in required if s not in self.inspection_order and s != source]
            msg = f"You examined the {label}. Relevant clue found (+{reward:.2f})."
            if remaining_sources:
                next_source = f"inspect_{remaining_sources[0]}"
                msg += f" {len(remaining_sources)} required source(s) still unexamined. Next required action: {next_source}."
            return msg
        return f"You examined the {label}. This source is not required for this failure mode."

    def _grade(self, action: WhyDidItFailAction) -> tuple[float, str]:
        """Delegate to the unified grade() function and return (reward, feedback)."""
        assert self.scenario is not None
        diagnosis     = (action.diagnosis or "").strip().lower()
        suggested_fix = (action.suggested_fix or "").strip().lower() or None
        difficulty    = self.scenario["difficulty"]

        reward = grade(
            diagnosis=diagnosis,
            suggested_fix=suggested_fix,
            scenario=self.scenario,
            steps_taken=self._state.step_count,
            inspection_order=self.inspection_order,
            difficulty=difficulty,
        )

        # Feedback tiers mirror the grader's scale; only high scores omit the answer.
        if reward >= 0.80:
            feedback = f"Excellent diagnosis! Score: {reward:.2f}"
        elif reward >= 0.50:
            feedback = f"Partially correct. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."
        else:
            feedback = f"Incorrect diagnosis. Score: {reward:.2f}. Actual failure: '{self.scenario['correct_diagnosis']}'."

        return reward, feedback