File size: 2,330 Bytes
20bb95e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94eed90
 
 
 
20bb95e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Snake Environment.

The Snake environment is a multi-agent reinforcement learning environment
based on marlenv's Snake-v1. Multiple snakes battle on a fixed size grid map.
"""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Support both in-repo and standalone imports
try:
    # In-repo imports (when running from OpenEnv repository)
    from core.env_server.types import Action, Observation
except ImportError:
    # Standalone imports (when environment is standalone with openenv-core from pip)
    from openenv_core.env_server.types import Action, Observation


@dataclass(kw_only=True)
class SnakeAction(Action):
    """
    Action for the Snake environment.

    For single snake (observer='snake'):
        action: int in [0, 1, 2]
            0 = noop (continue in same direction)
            1 = turn left (90 degrees)
            2 = turn right (90 degrees)

    For single snake (observer='human'):
        action: int in [0, 1, 2, 3, 4]
            0 = noop
            1 = left
            2 = right
            3 = down
            4 = up
    """

    action: int

    def __post_init__(self):
        """Ensure action is converted to int (handles string inputs from web interface)."""
        self.action = int(self.action)


@dataclass(kw_only=True)
class SnakeObservation(Observation):
    """
    Observation from the Snake environment.

    Attributes:
        grid: The current game grid as a nested list (height x width)
        observation: The encoded observation for the snake (can be full grid or vision range)
        episode_score: Total score accumulated in this episode
        episode_steps: Number of steps taken in this episode
        episode_fruits: Number of fruits eaten in this episode
        episode_kills: Number of kills in this episode
        alive: Whether the snake is still alive
    """

    grid: List[List[int]]
    observation: List[List[List[float]]]  # H x W x C observation
    episode_score: float = 0.0
    episode_steps: int = 0
    episode_fruits: int = 0
    episode_kills: int = 0
    alive: bool = True