# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Pokemon Red OpenEnv Environment.

This module defines the Action, Observation, and State types for the Pokemon Red
Game Boy emulator environment using PyBoy.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import Field

from openenv.core.env_server.types import Action, Observation, State


class PokemonRedAction(Action):
    """
    Action for Pokemon Red environment.

    Represents a single Game Boy button press.

    Attributes:
        action: Discrete action index (0-6) mapping to:
            0=Down, 1=Left, 2=Right, 3=Up, 4=A, 5=B, 6=Start
    """

    # The ge/le bounds make validation reject indices outside the 7 buttons.
    action: int = Field(
        default=0,
        ge=0,
        le=6,
        description="Button index: 0=Down, 1=Left, 2=Right, 3=Up, 4=A, 5=B, 6=Start",
    )


class PokemonRedObservation(Observation):
    """
    Multi-modal observation from Pokemon Red environment.

    Contains the game screen, player stats, and game state information
    suitable for both vision-based and structured RL agents.

    Attributes:
        screen_b64: Base64-encoded PNG of the current game frame.
        screen_shape: Shape of the screen array [height, width, channels].
        health: Party HP fraction in range [0, 1].
        level_sum: Sum of all party Pokemon levels (validated >= 0).
        badges: 8-element list of badge flags (0 or 1).
        position: Current position as [x, y, map_id].
        in_battle: Whether player is currently in a battle.
        seen_coords_count: Number of unique coordinates visited
            (exploration metric, validated >= 0).
        legal_actions: List of valid action indices.
    """

    screen_b64: str = Field(default="", description="Base64-encoded PNG frame")
    # List defaults go through default_factory so every instance gets its own
    # list object instead of sharing one mutable default.
    screen_shape: List[int] = Field(
        default_factory=lambda: [144, 160, 3],
        description="Frame dimensions [H,W,C]",
    )
    health: float = Field(
        default=0.0, ge=0.0, le=1.0, description="Party HP fraction"
    )
    level_sum: int = Field(
        default=0, ge=0, description="Sum of party Pokemon levels"
    )
    # NOTE(review): the 8-element / 0-or-1 shape of `badges` and the 3-element
    # shape of `position` are documented but not enforced by validation here.
    badges: List[int] = Field(
        default_factory=lambda: [0] * 8,
        description="8-element badge flags",
    )
    position: List[int] = Field(
        default_factory=lambda: [0, 0, 0],
        description="Player position [x, y, map_id]",
    )
    in_battle: bool = Field(default=False, description="Whether in battle")
    # A count cannot be negative; bounded like the sibling `level_sum` field.
    seen_coords_count: int = Field(
        default=0,
        ge=0,
        description="Unique coordinates visited (exploration)",
    )
    # Defaults to every index 0..6, i.e. the full range accepted by
    # PokemonRedAction.action.
    legal_actions: List[int] = Field(
        default_factory=lambda: list(range(7)),
        description="Valid action indices",
    )


class PokemonRedState(State):
    """
    State for Pokemon Red environment.

    Tracks episode progress and game state for the server.

    Attributes:
        total_reward: Cumulative reward for current episode.
        reset_count: Number of times environment has been reset (validated >= 0).
        badges_obtained: Number of gym badges obtained (0-8).
        max_level_sum: Maximum level sum achieved this episode (validated >= 0).
        events_triggered: Count of game event flags triggered (validated >= 0).
    """

    # Deliberately unbounded: no sign constraint is placed on the reward total.
    total_reward: float = Field(
        default=0.0, description="Cumulative episode reward"
    )
    # The counters below are non-negative by definition; they carry the same
    # lower bound as `badges_obtained` so a negative value fails validation
    # instead of being stored silently.
    reset_count: int = Field(default=0, ge=0, description="Number of resets")
    badges_obtained: int = Field(
        default=0, ge=0, le=8, description="Badges count"
    )
    max_level_sum: int = Field(
        default=0, ge=0, description="Max level sum achieved"
    )
    events_triggered: int = Field(
        default=0, ge=0, description="Event flags triggered"
    )