File size: 2,941 Bytes
b37875f
 
 
 
 
 
90c8812
b37875f
 
 
 
 
 
ff8ce5f
b37875f
 
ff8ce5f
b37875f
90c8812
b37875f
90c8812
 
b37875f
 
90c8812
 
 
b37875f
90c8812
 
 
 
 
 
 
 
 
 
b37875f
 
87037e2
90c8812
 
b37875f
87037e2
90c8812
b37875f
87037e2
90c8812
 
 
 
d3b224f
90c8812
 
b37875f
 
 
 
 
 
 
ff8ce5f
 
 
b37875f
 
ff8ce5f
 
 
 
 
90c8812
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""WhyDidItFail Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from models import WhyDidItFailAction, WhyDidItFailObservation, WhyDidItFailState


class WhyDidItFailEnv(EnvClient[WhyDidItFailAction, WhyDidItFailObservation, WhyDidItFailState]):
    """
    Typed client for talking to a WhyDidItFail environment server.

    The client keeps a persistent WebSocket connection open to the server,
    and every instance works within its own dedicated session. This class
    only supplies the three conversion hooks between the typed models and
    the plain-dict payloads exchanged with the server; everything else is
    inherited from ``EnvClient``.

    Example:
        >>> with WhyDidItFailEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset()
        ...     print(result.observation.task_description)
        ...
        ...     # Look at a data source before committing to an answer.
        ...     result = env.step(WhyDidItFailAction(action_type="inspect_logs"))
        ...     print(result.observation.visible_data)
        ...
        ...     # Submit the final diagnosis and read back feedback + reward.
        ...     result = env.step(
        ...         WhyDidItFailAction(
        ...             action_type="submit_diagnosis",
        ...             diagnosis="exploding gradients",
        ...         )
        ...     )
        ...     print(result.observation.feedback, result.reward)
    """

    def _step_payload(self, action: WhyDidItFailAction) -> Dict:
        """Serialize an action into the dict sent to the server.

        Fields whose value is ``None`` are left out of the payload.
        """
        serialized = action.model_dump(exclude_none=True)
        return serialized

    def _parse_result(self, payload: Dict) -> StepResult[WhyDidItFailObservation]:
        """Build a ``StepResult`` from a server response dict.

        The observation is read from ``payload["observation"]`` (an empty
        dict is assumed when the key is absent); any observation field that
        is missing falls back to the default listed below. The step-level
        ``reward`` and ``done`` are taken from the top level of ``payload``,
        independently of the observation's own ``reward``/``done`` fields.
        """
        raw_obs = payload.get("observation", {})

        # Assembled inside the method so the mutable defaults ({} and [])
        # are brand-new objects on every call.
        obs_kwargs = {
            "task_description": raw_obs.get("task_description", ""),
            "visible_data": raw_obs.get("visible_data", {}),
            "available_actions": raw_obs.get("available_actions", []),
            "steps_taken": raw_obs.get("steps_taken", 0),
            "reward": raw_obs.get("reward", 0.10),
            "done": raw_obs.get("done", False),
            "feedback": raw_obs.get("feedback", ""),
        }
        parsed_obs = WhyDidItFailObservation(**obs_kwargs)

        # Top-level reward has no default: it stays None when not provided.
        step_reward = payload.get("reward")
        episode_done = payload.get("done", False)
        return StepResult(
            observation=parsed_obs,
            reward=step_reward,
            done=episode_done,
        )

    def _parse_state(self, payload: Dict) -> WhyDidItFailState:
        """Build a ``WhyDidItFailState`` from a server response dict.

        ``episode_id``, ``scenario_key`` and ``difficulty`` become ``None``
        when absent; the counters default to ``0`` and the list fields to a
        new empty list.
        """
        state_kwargs = {
            "episode_id": payload.get("episode_id"),
            "step_count": payload.get("step_count", 0),
            "scenario_key": payload.get("scenario_key"),
            "difficulty": payload.get("difficulty"),
            "inspection_order": payload.get("inspection_order", []),
            "required_sources": payload.get("required_sources", []),
            "max_steps": payload.get("max_steps", 0),
        }
        return WhyDidItFailState(**state_kwargs)