File size: 2,292 Bytes
c80e50a
12d85aa
c80e50a
12d85aa
 
 
 
 
c80e50a
12d85aa
 
c80e50a
12d85aa
c80e50a
12d85aa
 
 
c80e50a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d85aa
 
c80e50a
 
12d85aa
 
 
 
 
 
 
 
104c835
 
12d85aa
 
 
c80e50a
 
 
 
 
 
12d85aa
 
c80e50a
 
12d85aa
 
c80e50a
12d85aa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""PRobe Environment Client."""

from __future__ import annotations

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import ProbeAction, ProbeObservation


class ProbeEnv(EnvClient[ProbeAction, ProbeObservation, State]):
    """
    Client for the PRobe environment.

    Maintains a persistent WebSocket connection to the server.

    Example::

        with ProbeEnv(base_url="http://localhost:8000") as env:
            result = env.reset()
            print(result.observation.task_description)

            action = ProbeAction(
                action_type="add_comment",
                line_number=4,
                comment="Off-by-one: range(len+1) causes IndexError",
                severity="error",
                category="bug",
            )
            result = env.step(action)
            print(result.reward)
    """

    @staticmethod
    def _wire_value(value: object) -> object:
        """Return ``value.value`` if it has one (e.g. an enum member), else ``value``.

        Lets enum-typed action fields serialize correctly whether they hold
        enum members or already-plain strings.
        """
        return getattr(value, "value", value)

    def _step_payload(self, action: ProbeAction) -> dict:
        """Serialize *action* into the JSON body sent to the server for a step.

        ``action_type`` is always included. Every other field is included only
        when it is not ``None``. Enum-valued fields are sent as their ``.value``.
        """
        payload: dict = {"action_type": self._wire_value(action.action_type)}
        if action.line_number is not None:
            payload["line_number"] = action.line_number
        if action.comment is not None:
            payload["comment"] = action.comment
        # Enum-valued optional fields, kept in the original key order.
        for field_name in ("severity", "category", "classification"):
            value = getattr(action, field_name)
            if value is not None:
                payload[field_name] = self._wire_value(value)
        return payload

    def _parse_result(
        self, payload: dict
    ) -> StepResult[ProbeObservation]:
        """Build a ``StepResult`` from a server response.

        Args:
            payload: Decoded server response. The ``observation``, ``reward``
                and ``done`` keys are read. Any of them may be missing or null.

        Returns:
            A ``StepResult`` whose ``reward`` is coerced to ``float``
            (missing/null -> 0.0) and whose ``done`` is coerced to ``bool``
            (missing -> False).
        """
        # ``or {}`` also covers an explicit ``"observation": null``. With
        # ``.get(key, {})`` that would be passed through as None and
        # model_validate would reject it. Copy so the caller's dict is not
        # mutated below.
        obs_data: dict = dict(payload.get("observation") or {})
        raw_reward = payload.get("reward")
        reward = float(raw_reward or 0.0)
        done = bool(payload.get("done", False))

        # If the observation model declares ``done``/``reward`` fields, mirror
        # the top-level values onto it so ``observation.done`` agrees with
        # ``result.done``. Values the server placed inside the observation win.
        # The guard keeps this safe for models without those fields.
        obs_fields = ProbeObservation.model_fields
        if "done" in obs_fields:
            obs_data.setdefault("done", done)
        if "reward" in obs_fields and raw_reward is not None:
            obs_data.setdefault("reward", raw_reward)

        # Use model_validate so new fields added to ProbeObservation
        # are picked up automatically without changing this method.
        observation = ProbeObservation.model_validate(obs_data)
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict) -> State:
        """Build a ``State`` from a server state response.

        ``episode_id`` is passed through as-is (``None`` if absent). A missing
        or null ``step_count`` is treated as 0.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count") or 0,
        )