import json
import os

import engine_rust
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env import VecEnv


class RustVectorEnv(VecEnv):
    """Stable-Baselines3 ``VecEnv`` backed by ``engine_rust.PyVectorGameState``.

    All environments are stepped by a single call into the Rust engine, which
    writes observations, rewards, done flags and terminal observations into
    NumPy buffers that are allocated once in ``__init__`` and reused on every
    step. Copies of those buffers are handed to the caller, so the returned
    arrays stay valid after the next step overwrites the buffers.
    """

    # Width of a single observation vector.
    OBS_DIM = 350
    # Size of the default discrete action space and width of the mask buffer.
    NUM_ACTIONS = 2000
    # Deck composition used by ``_load_decks``.
    _MEMBER_COUNT = 48
    _LIVE_COUNT = 12
    _DEFAULT_MEMBER_ID = 1
    _DEFAULT_LIVE_ID = 100

    def __init__(self, num_envs, action_space=None, opp_mode=0, force_start_order=-1, mcts_sims=50):
        """Load the card database, build the vector game state and reset it.

        Args:
            num_envs: Number of parallel environments.
            action_space: Optional action space; defaults to
                ``spaces.Discrete(2000)`` when ``None``.
            opp_mode: Forwarded unchanged to ``PyVectorGameState``.
            force_start_order: Accepted for interface compatibility; it is not
                used anywhere in this class.
            mcts_sims: Forwarded unchanged to ``PyVectorGameState``.

        Raises:
            FileNotFoundError: If ``data/cards_compiled.json`` does not exist.
        """
        # 1. Load DB
        db_path = "data/cards_compiled.json"
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Card DB not found at {db_path}")
        with open(db_path, "r", encoding="utf-8") as f:
            json_str = f.read()
        self.db = engine_rust.PyCardDatabase(json_str)

        # 2. Create Vector State
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 3. Setup Spaces
        obs_dim = self.OBS_DIM
        observation_space = spaces.Box(low=0, high=1, shape=(obs_dim,), dtype=np.float32)
        if action_space is None:
            action_space = spaces.Discrete(self.NUM_ACTIONS)

        # Let the base class set num_envs / observation_space / action_space
        # together with whatever bookkeeping attributes it maintains itself.
        super().__init__(num_envs, observation_space, action_space)
        self.actions = None

        # Pre-allocate buffers for Zero-Copy
        self.obs_buffer = np.zeros((num_envs, obs_dim), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        # Term obs buffer needs to accommodate worst case (all done)
        self.term_obs_buffer = np.zeros((num_envs, obs_dim), dtype=np.float32)
        # NOTE(review): the mask width stays at NUM_ACTIONS even when a custom
        # action_space is supplied -- confirm this matches what the Rust side
        # writes in get_action_masks.
        self.mask_buffer = np.zeros((num_envs, self.NUM_ACTIONS), dtype=bool)

        # 4. Load Deck Config
        self._load_decks()

        # 5. Initialize (Warmup)
        self.reset()

    def _load_decks(self):
        """Build the deck and lives lists passed to the engine on ``reset``.

        Best-effort: any failure while reading ``data/verified_card_pool.json``
        or querying the card database prints a warning and falls back to the
        default deck (member id 1 x 48, live id 100 x 12).
        """
        default_members = [self._DEFAULT_MEMBER_ID] * self._MEMBER_COUNT
        default_lives = [self._DEFAULT_LIVE_ID] * self._LIVE_COUNT

        m_ids = []
        l_ids = []
        try:
            # The parsed contents are not used; reading the file only checks
            # that it exists and is valid JSON before a deck is chosen.
            with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
                json.load(f)

            if self.db.has_member(self._DEFAULT_MEMBER_ID):
                m_ids = list(default_members)
            else:
                ids = self.db.get_member_ids()
                if ids:
                    m_ids = [ids[0]] * self._MEMBER_COUNT

            l_ids = list(default_lives)
        except Exception as e:  # deliberate best-effort boundary
            print(f"Warning: Failed to load deck config: {e}")
            m_ids = list(default_members)
            l_ids = list(default_lives)

        if not m_ids:
            # The database reported no member ids at all; use the same default
            # as the failure path instead of initializing with an empty deck.
            print("Warning: Failed to load deck config: no member ids available")
            m_ids = list(default_members)

        # Give each player its own list so the two sides never alias.
        self.p0_deck = m_ids
        self.p1_deck = list(m_ids)
        self.p0_lives = l_ids
        self.p1_lives = list(l_ids)

    def reset(self):
        """Re-initialize every environment with a fresh random seed.

        Returns:
            A copy of the observation buffer after initialization.
        """
        seed = np.random.randint(0, 1000000)
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        return self.get_observations()

    def step_async(self, actions):
        """Store ``actions`` so the next ``step_wait`` call can apply them."""
        self.actions = actions

    def step_wait(self):
        """Apply the stored actions and return ``(obs, rewards, dones, infos)``.

        If no actions were stored, the environments are reset and zero rewards,
        all-False dones and empty info dicts are returned. For every finished
        environment the info dict carries a ``"terminal_observation"`` entry.
        """
        if self.actions is None:
            return (
                self.reset(),
                np.zeros(self.num_envs, dtype=np.float32),
                np.zeros(self.num_envs, dtype=bool),
                # One distinct dict per env; [{}] * n would share a single dict.
                [{} for _ in range(self.num_envs)],
            )

        # Ensure int32 (also accepts plain Python sequences)
        actions = np.ascontiguousarray(self.actions, dtype=np.int32)

        # Call Rust step with pre-allocated buffers
        # Returns list of done indices
        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )

        infos = [{} for _ in range(self.num_envs)]

        # Populate infos for done envs. Row i of the terminal buffer is read
        # for the i-th reported done env, not row env_idx.
        for i, env_idx in enumerate(done_indices):
            infos[env_idx]["terminal_observation"] = self.term_obs_buffer[i].copy()

        # The buffers are overwritten by the next step, so return copies.
        return self.obs_buffer.copy(), self.rewards_buffer.copy(), self.dones_buffer.copy(), infos

    def close(self):
        """No-op; this class performs no explicit cleanup."""
        pass

    def _selected_count(self, indices):
        """Return how many environments ``indices`` selects (``None`` = all)."""
        if indices is None:
            return self.num_envs
        if isinstance(indices, (int, np.integer)):
            return 1
        return len(indices)

    def get_attr(self, attr_name, indices=None):
        """Return ``None`` for each selected env; no per-env attributes exist."""
        return [None] * self._selected_count(indices)

    def set_attr(self, attr_name, value, indices=None):
        """No-op; there are no per-env Python objects to set attributes on."""
        pass

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Return ``None`` for each selected env; no per-env methods exist."""
        return [None] * self._selected_count(indices)

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Return ``False`` for each selected env; nothing is ever wrapped."""
        return [False] * self._selected_count(indices)

    def get_observations(self):
        """Fill the observation buffer from the engine and return a copy."""
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer.copy()

    def action_masks(self):
        """Fill the action-mask buffer from the engine and return a copy."""
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer.copy()