Paulito Palmes, PhD commited on
Commit
f6d455c
·
1 Parent(s): 40035ea
Files changed (2) hide show
  1. app.py +5 -10
  2. snake_environment.py +10 -15
app.py CHANGED
@@ -22,16 +22,11 @@ Usage:
22
  """
23
 
24
  # Support both in-repo and standalone imports
25
- try:
26
- # In-repo imports (when running from OpenEnv repository)
27
- from core.env_server.http_server import create_app
28
- from ..models import SnakeAction, SnakeObservation
29
- from .snake_environment import SnakeEnvironment
30
- except ImportError:
31
- # Standalone imports (when environment is standalone with openenv-core from pip)
32
- from openenv_core.env_server.http_server import create_app
33
- from models import SnakeAction, SnakeObservation
34
- from server.snake_environment import SnakeEnvironment
35
 
36
  # Create the environment instance
37
  env = SnakeEnvironment()
 
22
  """
23
 
24
  # Support both in-repo and standalone imports
25
+ # In-repo imports (when running from OpenEnv repository)
26
+ # Standalone imports (when environment is standalone with openenv-core from pip)
27
+ from openenv_core.env_server.http_server import create_app
28
+ from envs.snake_env import SnakeAction, SnakeObservation
29
+ from snake_environment import SnakeEnvironment
 
 
 
 
 
30
 
31
  # Create the environment instance
32
  env = SnakeEnvironment()
snake_environment.py CHANGED
@@ -19,18 +19,13 @@ import marlenv.envs # Register marlenv environments with gym
19
  import numpy as np
20
 
21
  # Support both in-repo and standalone imports
22
- try:
23
- # In-repo imports (when running from OpenEnv repository)
24
- from core.env_server.interfaces import Environment
25
- from core.env_server.types import State
26
 
27
- from ..models import SnakeAction, SnakeObservation
28
- except ImportError:
29
- from models import SnakeAction, SnakeObservation
30
-
31
- # Standalone imports (when environment is standalone with openenv-core from pip)
32
- from openenv_core.env_server.interfaces import Environment
33
- from openenv_core.env_server.types import State
34
 
35
 
36
  class SingleAgentWrapper(gym.Wrapper):
@@ -44,15 +39,15 @@ class SingleAgentWrapper(gym.Wrapper):
44
  def __init__(self, env):
45
  super().__init__(env)
46
  # Unwrap observation and action spaces for single agent
47
- if hasattr(env.observation_space, '__getitem__'):
48
  self.observation_space = env.observation_space[0]
49
- if hasattr(env.action_space, '__getitem__'):
50
  self.action_space = env.action_space[0]
51
 
52
  def reset(self, **kwargs):
53
  obs = self.env.reset(**kwargs)
54
  # Remove first dimension if it's a multi-agent array (num_agents, H, W, C)
55
- if hasattr(obs, 'shape') and len(obs.shape) == 4 and obs.shape[0] == 1:
56
  return obs[0] # Return (H, W, C)
57
  # Return first agent's observation if it's a list
58
  if isinstance(obs, list):
@@ -65,7 +60,7 @@ class SingleAgentWrapper(gym.Wrapper):
65
 
66
  # Unwrap returns for single agent
67
  # Handle observation: remove first dimension if shape is (1, H, W, C)
68
- if hasattr(obs, 'shape') and len(obs.shape) == 4 and obs.shape[0] == 1:
69
  obs = obs[0] # Convert (1, H, W, C) -> (H, W, C)
70
  elif isinstance(obs, list):
71
  obs = obs[0]
 
19
  import numpy as np
20
 
21
  # Support both in-repo and standalone imports
22
+ # In-repo imports (when running from OpenEnv repository)
23
+ from core.env_server.interfaces import Environment
24
+ from core.env_server.types import State
25
+ from envs.snake_env import SnakeAction, SnakeObservation
26
 
27
+ # from openenv_core.env_server.interfaces import Environment
28
+ # from openenv_core.env_server.types import State
 
 
 
 
 
29
 
30
 
31
  class SingleAgentWrapper(gym.Wrapper):
 
39
  def __init__(self, env):
40
  super().__init__(env)
41
  # Unwrap observation and action spaces for single agent
42
+ if hasattr(env.observation_space, "__getitem__"):
43
  self.observation_space = env.observation_space[0]
44
+ if hasattr(env.action_space, "__getitem__"):
45
  self.action_space = env.action_space[0]
46
 
47
  def reset(self, **kwargs):
48
  obs = self.env.reset(**kwargs)
49
  # Remove first dimension if it's a multi-agent array (num_agents, H, W, C)
50
+ if hasattr(obs, "shape") and len(obs.shape) == 4 and obs.shape[0] == 1:
51
  return obs[0] # Return (H, W, C)
52
  # Return first agent's observation if it's a list
53
  if isinstance(obs, list):
 
60
 
61
  # Unwrap returns for single agent
62
  # Handle observation: remove first dimension if shape is (1, H, W, C)
63
+ if hasattr(obs, "shape") and len(obs.shape) == 4 and obs.shape[0] == 1:
64
  obs = obs[0] # Convert (1, H, W, C) -> (H, W, C)
65
  elif isinstance(obs, list):
66
  obs = obs[0]