File size: 2,504 Bytes
443c22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Pyre Environment Client."""

from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import PyreAction, PyreObservation, PyreState


class PyreEnv(EnvClient[PyreAction, PyreObservation, PyreState]):
    """Client for the Pyre Environment.

    The environment is async by default; use .sync() for synchronous access:

        with PyreEnv(base_url="http://localhost:8000").sync() as env:
            result = env.reset()
            print(result.observation.narrative)
            result = env.step(PyreAction(action="move", direction="north"))
            print(f"Health: {result.observation.agent_health}")

    Or use async:

        async with PyreEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
    """

    def _step_payload(self, action: PyreAction) -> Dict:
        """Serialize *action* into the request body sent to the server.

        Fields whose value is ``None`` are dropped (``exclude_none=True``),
        so unset optional action fields are not sent at all.
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict) -> StepResult[PyreObservation]:
        """Build a ``StepResult`` from a server response payload.

        Observation fields are read from ``payload["observation"]`` when that
        key holds a value; otherwise the payload itself is treated as the
        observation. ``done``, ``reward`` and ``metadata`` are always read from
        the top level of *payload*. Missing observation fields fall back to the
        defaults written below.

        Args:
            payload: Decoded JSON response from the server.

        Returns:
            A ``StepResult`` whose ``reward``/``done`` match the values stored
            on the wrapped ``PyreObservation``.
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # Flat response, or an explicit null observation: read the
            # observation fields from the top level instead of crashing on
            # ``None.get``.
            obs_data = payload

        # Read the episode-level signals once so the observation and the
        # StepResult wrapper always agree. Previously the observation
        # defaulted reward to 0.0 while StepResult defaulted it to None.
        # NOTE(review): assumes PyreObservation.reward accepts None (the
        # OpenEnv Observation base declares it optional) — confirm in models.
        reward = payload.get("reward")
        done = payload.get("done", False)

        obs = PyreObservation(
            narrative=obs_data.get("narrative", ""),
            agent_evacuated=obs_data.get("agent_evacuated", False),
            location_label=obs_data.get("location_label", ""),
            smoke_level=obs_data.get("smoke_level", "none"),
            fire_visible=obs_data.get("fire_visible", False),
            fire_direction=obs_data.get("fire_direction"),
            agent_health=obs_data.get("agent_health", 100.0),
            health_status=obs_data.get("health_status", "Good"),
            wind_dir=obs_data.get("wind_dir", "CALM"),
            visible_objects=obs_data.get("visible_objects", []),
            blocked_exit_ids=obs_data.get("blocked_exit_ids", []),
            audible_signals=obs_data.get("audible_signals", []),
            elapsed_steps=obs_data.get("elapsed_steps", 0),
            last_action_feedback=obs_data.get("last_action_feedback", ""),
            available_actions_hint=obs_data.get("available_actions_hint", []),
            done=done,
            reward=reward,
            metadata=payload.get("metadata", {}),
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> PyreState:
        """Build a ``PyreState`` by passing every payload key as a keyword."""
        return PyreState(**payload)