trioskosmos commited on
Commit
c1209b4
·
verified ·
1 Parent(s): 2badd2f

Upload ai/environments/vec_env_adapter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/environments/vec_env_adapter.py +191 -0
ai/environments/vec_env_adapter.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ from gymnasium import spaces
5
+ from stable_baselines3.common.vec_env import VecEnv
6
+
7
+ # RUST Engine Toggle
8
+ USE_RUST_ENGINE = os.getenv("USE_RUST_ENGINE", "0") == "1"
9
+
10
+ if USE_RUST_ENGINE:
11
+ print(" [VecEnvAdapter] RUST Engine ENABLED (USE_RUST_ENGINE=1)")
12
+ from ai.vec_env_rust import RustVectorEnv
13
+
14
+ # Wrapper to inject MCTS_SIMS from env
15
+ class VectorEnvAdapter(RustVectorEnv):
16
+ def __init__(self, num_envs, action_space=None, opp_mode=0, force_start_order=-1):
17
+ mcts_sims = int(os.getenv("MCTS_SIMS", "50"))
18
+ super().__init__(num_envs, action_space, opp_mode, force_start_order, mcts_sims)
19
+ else:
20
+ # GPU Environment Toggle
21
+ USE_GPU_ENV = os.getenv("USE_GPU_ENV", "0") == "1" or os.getenv("GPU_ENV", "0") == "1"
22
+
23
+ if USE_GPU_ENV:
24
+ try:
25
+ from ai.vector_env_gpu import HAS_CUDA, VectorEnvGPU
26
+
27
+ if HAS_CUDA:
28
+ print(" [VecEnvAdapter] GPU Environment ENABLED (USE_GPU_ENV=1)")
29
+ else:
30
+ print(" [VecEnvAdapter] Warning: USE_GPU_ENV=1 but CUDA not available. Falling back to CPU.")
31
+ USE_GPU_ENV = False
32
+ except ImportError as e:
33
+ print(f" [VecEnvAdapter] Warning: Failed to import GPU env: {e}. Falling back to CPU.")
34
+ USE_GPU_ENV = False
35
+
36
+ if not USE_GPU_ENV:
37
+ from ai.environments.vector_env import VectorGameState
38
+
39
+ class VectorEnvAdapter(VecEnv):
40
+ """
41
+ Wraps the Numba-accelerated VectorGameState to be compatible with Stable-Baselines3.
42
+
43
+ When USE_GPU_ENV=1 is set, uses VectorEnvGPU for GPU-resident environments
44
+ with zero-copy observation transfer to PyTorch.
45
+ """
46
+
47
+ metadata = {"render_modes": ["rgb_array"]}
48
+
49
+ def __init__(self, num_envs, action_space=None, opp_mode=0, force_start_order=-1):
50
+ self.num_envs = num_envs
51
+ self.use_gpu = USE_GPU_ENV
52
+
53
+ # For Legacy Adapter: Read MCTS_SIMS env var or default
54
+ mcts_sims = int(os.getenv("MCTS_SIMS", "50"))
55
+
56
+ if self.use_gpu:
57
+ # GPU Env doesn't support MCTS yet, pass legacy args
58
+ self.game_state = VectorEnvGPU(num_envs, opp_mode=opp_mode, force_start_order=force_start_order)
59
+ else:
60
+ self.game_state = VectorGameState(num_envs, opp_mode=opp_mode, force_start_order=force_start_order)
61
+
62
+ # Use Dynamic Dimension from Engine (IMAX 8k, Standard 2k, or Compressed 512)
63
+ obs_dim = self.game_state.obs_dim
64
+ self.observation_space = spaces.Box(low=0, high=1, shape=(obs_dim,), dtype=np.float32)
65
+ if action_space is None:
66
+ # Check if game_state has defined action_space_dim (default 2000)
67
+ if hasattr(self.game_state, "action_space_dim"):
68
+ action_dim = self.game_state.action_space_dim
69
+ else:
70
+ # Fallback: The Engine always produces 2000-dim masks (Action IDs 0-1999)
71
+ action_dim = 2000
72
+
73
+ action_space = spaces.Discrete(action_dim)
74
+
75
+ # Manually initialize VecEnv fields to bypass render_modes crash
76
+ self.action_space = action_space
77
+ self.actions = None
78
+ self.render_mode = None
79
+
80
+ # Track previous scores for delta-based rewards
81
+ self.prev_scores = np.zeros(num_envs, dtype=np.int32)
82
+ self.prev_turns = np.zeros(num_envs, dtype=np.int32)
83
+ # Pre-allocate empty infos list (reused when no envs done)
84
+ self._empty_infos = [{} for _ in range(num_envs)]
85
+
86
+ def reset(self):
87
+ """
88
+ Reset all environments.
89
+ """
90
+ self.game_state.reset()
91
+ self.prev_scores.fill(0) # Reset score tracking
92
+ self.prev_turns.fill(0) # Reset turn tracking
93
+
94
+ obs = self.game_state.get_observations()
95
+ # Convert CuPy to NumPy if using GPU (SB3 expects numpy)
96
+ if self.use_gpu:
97
+ try:
98
+ import cupy as cp
99
+
100
+ if isinstance(obs, cp.ndarray):
101
+ obs = cp.asnumpy(obs)
102
+ except:
103
+ pass
104
+ return obs
105
+
106
+ def step_async(self, actions):
107
+ """
108
+ Tell the generic VecEnv wrapper to hold these actions.
109
+ """
110
+ self.actions = actions
111
+
112
+ def step_wait(self):
113
+ """
114
+ Execute the actions on the Numba engine.
115
+ """
116
+ # Ensure actions are int32 for Numba (avoid copy if already correct type)
117
+ if self.actions.dtype != np.int32:
118
+ actions_int32 = self.actions.astype(np.int32)
119
+ else:
120
+ actions_int32 = self.actions
121
+
122
+ # Step the engine
123
+ obs, rewards, dones, infos = self.game_state.step(actions_int32)
124
+
125
+ # Convert CuPy arrays to NumPy if using GPU (SB3 expects numpy)
126
+ if self.use_gpu:
127
+ try:
128
+ import cupy as cp
129
+
130
+ if isinstance(obs, cp.ndarray):
131
+ obs = cp.asnumpy(obs)
132
+ if isinstance(rewards, cp.ndarray):
133
+ rewards = cp.asnumpy(rewards)
134
+ if isinstance(dones, cp.ndarray):
135
+ dones = cp.asnumpy(dones)
136
+ except:
137
+ pass
138
+
139
+ return obs, rewards, dones, infos
140
+
141
+ def close(self):
142
+ pass
143
+
144
+ def get_attr(self, attr_name, indices=None):
145
+ """
146
+ Return attribute from vectorized environments.
147
+ """
148
+ if attr_name == "action_masks":
149
+ # Return function reference or result? SB3 usually looks for method
150
+ pass
151
+ return [None] * self.num_envs
152
+
153
+ def set_attr(self, attr_name, value, indices=None):
154
+ pass
155
+
156
+ def env_method(self, method_name, *method_args, **method_kwargs):
157
+ """
158
+ Call instance methods of vectorized environments.
159
+ """
160
+ if method_name == "action_masks":
161
+ # Return list of masks for all envs
162
+ masks = self.game_state.get_action_masks()
163
+ if self.use_gpu:
164
+ try:
165
+ import cupy as cp
166
+
167
+ if isinstance(masks, cp.ndarray):
168
+ masks = cp.asnumpy(masks)
169
+ except:
170
+ pass
171
+ return [masks[i] for i in range(self.num_envs)]
172
+
173
+ return [None] * self.num_envs
174
+
175
+ def env_is_wrapped(self, wrapper_class, indices=None):
176
+ return [False] * self.num_envs
177
+
178
+ def action_masks(self):
179
+ """
180
+ Required for MaskablePPO. Returns (num_envs, action_space.n) boolean array.
181
+ """
182
+ masks = self.game_state.get_action_masks()
183
+ if self.use_gpu:
184
+ try:
185
+ import cupy as cp
186
+
187
+ if isinstance(masks, cp.ndarray):
188
+ masks = cp.asnumpy(masks)
189
+ except:
190
+ pass
191
+ return masks