Spaces:
Running
Running
File size: 4,470 Bytes
ac5cfba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Pokemon Red Environment Client.
This module provides the client for connecting to a Pokemon Red Environment server
via WebSocket for persistent sessions.
"""
from __future__ import annotations
from typing import Any, Dict, TYPE_CHECKING
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from .models import PokemonRedAction, PokemonRedObservation, PokemonRedState
if TYPE_CHECKING:
from openenv.core.containers.runtime import ContainerProvider
class PokemonRedEnv(EnvClient[PokemonRedAction, PokemonRedObservation, PokemonRedState]):
    """
    Client for the Pokemon Red Environment.

    The client keeps a persistent WebSocket connection open to the environment
    server, so multi-step interactions run with lower latency.

    Example:
        >>> # Connect to a server that is already running
        >>> with PokemonRedEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.screen_shape)
        ...
        ...     result = client.step(PokemonRedAction(action=4))  # Press A
        ...     print(result.reward, result.done)

    Example with Docker:
        >>> # Start a container automatically, then connect to it
        >>> client = PokemonRedEnv.from_docker_image("pokemonred-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(PokemonRedAction(action=0))  # Press Down
        ... finally:
        ...     client.close()

    Example from HuggingFace Hub:
        >>> # Connect to a hosted environment and take random actions
        >>> import random
        >>> client = PokemonRedEnv.from_hub("openenv/pokemonred")
        >>> with client:
        ...     result = client.reset()
        ...     for _ in range(100):
        ...         action = PokemonRedAction(action=random.randint(0, 6))
        ...         result = client.step(action)
    """

    def _step_payload(self, action: PokemonRedAction) -> Dict[str, Any]:
        """
        Serialize an action into the JSON body of a step request.

        Args:
            action: The action to send to the server.

        Returns:
            A JSON-encodable dict whose only key, ``"action"``, holds
            ``action.action``.
        """
        chosen = action.action
        return {"action": chosen}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[PokemonRedObservation]:
        """
        Build a ``StepResult`` from a server step/reset response.

        Observation fields are read from ``payload["observation"]``. Any field
        the server omits is replaced by the fallback listed below. ``reward``
        and ``done`` are read from the top level of ``payload``. They are
        copied onto both the observation and the returned ``StepResult``.

        Args:
            payload: Decoded JSON response from the server.

        Returns:
            A ``StepResult`` wrapping the parsed ``PokemonRedObservation``.
        """
        obs = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        # Fallback for each observation field that is missing from the
        # response. The table is rebuilt on every call, so the mutable
        # list/dict fallbacks are never shared between observations.
        fallbacks = {
            "screen_b64": "",
            "screen_shape": [144, 160, 3],
            "health": 0.0,
            "level_sum": 0,
            "badges": [0] * 8,
            "position": [0, 0, 0],
            "in_battle": False,
            "seen_coords_count": 0,
            "legal_actions": list(range(7)),
            "metadata": {},
        }
        fields = {name: obs.get(name, fallback) for name, fallback in fallbacks.items()}

        observation = PokemonRedObservation(done=done, reward=reward, **fields)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> PokemonRedState:
        """
        Build a ``PokemonRedState`` from a ``/state`` endpoint response.

        ``episode_id`` becomes ``None`` when absent. Every other field falls
        back to zero (``0.0`` for ``total_reward``).

        Args:
            payload: Decoded JSON response from the ``/state`` endpoint.

        Returns:
            The parsed ``PokemonRedState``.
        """
        defaults = {
            "step_count": 0,
            "total_reward": 0.0,
            "reset_count": 0,
            "badges_obtained": 0,
            "max_level_sum": 0,
            "events_triggered": 0,
        }
        values = {key: payload.get(key, default) for key, default in defaults.items()}
        return PokemonRedState(episode_id=payload.get("episode_id"), **values)
|