Spaces:
Running
Running
File size: 4,629 Bytes
8991165 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import json
import os
import engine_rust
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env import VecEnv
class RustVectorEnv(VecEnv):
    """Stable-Baselines3 ``VecEnv`` driven by ``engine_rust.PyVectorGameState``.

    All ``num_envs`` games live inside a single Rust-side vector state.  The
    wrapper owns a set of pre-allocated NumPy buffers that are handed to the
    engine on every call; the arrays returned to the caller are copies, so
    they stay valid after the next step overwrites the buffers.

    Args:
        num_envs: Number of games stepped in lockstep.
        action_space: Optional action space override.  Defaults to
            ``spaces.Discrete(NUM_ACTIONS)``.  The action-mask buffer is
            always ``NUM_ACTIONS`` wide regardless of this override.
        opp_mode: Forwarded unchanged to ``PyVectorGameState``.
        force_start_order: Accepted for interface compatibility; this wrapper
            does not currently forward or use it.
        mcts_sims: Forwarded unchanged to ``PyVectorGameState``.

    Raises:
        FileNotFoundError: If the card database file does not exist.
    """

    # File locations are relative to the current working directory.
    DB_PATH = "data/cards_compiled.json"
    POOL_PATH = "data/verified_card_pool.json"

    # Sizes of the flat observation vector and of the discrete action set.
    OBS_DIM = 350
    NUM_ACTIONS = 2000

    # Fallback deck layout used by ``_load_decks``.
    DECK_SIZE = 48
    LIVES_SIZE = 12
    DEFAULT_MEMBER_ID = 1
    DEFAULT_LIVE_ID = 100

    def __init__(self, num_envs, action_space=None, opp_mode=0, force_start_order=-1, mcts_sims=50):
        # 1. Load the card database.
        db_path = self.DB_PATH
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Card DB not found at {db_path}")
        with open(db_path, "r", encoding="utf-8") as f:
            json_str = f.read()
        self.db = engine_rust.PyCardDatabase(json_str)

        # 2. Create the vectorised game state.
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 3. Set up spaces and let the base class initialise its own state.
        #    VecEnv.__init__ assigns num_envs / observation_space /
        #    action_space and creates the bookkeeping attributes that SB3
        #    utilities and wrappers expect to find on every VecEnv.
        obs_dim = self.OBS_DIM
        observation_space = spaces.Box(low=0, high=1, shape=(obs_dim,), dtype=np.float32)
        if action_space is None:
            action_space = spaces.Discrete(self.NUM_ACTIONS)
        super().__init__(num_envs, observation_space, action_space)

        # Actions stored by step_async() until step_wait() consumes them.
        self.actions = None

        # Pre-allocated buffers passed to the engine on every call.
        self.obs_buffer = np.zeros((num_envs, obs_dim), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        # Terminal-observation buffer must fit the worst case (all envs done).
        self.term_obs_buffer = np.zeros((num_envs, obs_dim), dtype=np.float32)
        self.mask_buffer = np.zeros((num_envs, self.NUM_ACTIONS), dtype=bool)

        # 4. Load deck configuration.
        self._load_decks()

        # 5. Initialise (warm-up).
        self.reset()

    def _load_decks(self):
        """Build the deck and lives lists used to initialise every game.

        The card-pool file is opened and parsed, but its contents are not
        used yet; a missing or malformed file only triggers the fallback
        path.  On success the deck is ``DECK_SIZE`` copies of member id
        ``DEFAULT_MEMBER_ID`` when the database has it, otherwise of the
        first id returned by ``get_member_ids()``.  On any failure a warning
        is printed and the default ids are used.
        """
        default_deck = [self.DEFAULT_MEMBER_ID] * self.DECK_SIZE
        default_lives = [self.DEFAULT_LIVE_ID] * self.LIVES_SIZE

        m_ids = []
        l_ids = []
        try:
            with open(self.POOL_PATH, "r", encoding="utf-8") as f:
                json.load(f)  # parsed only to validate; contents are unused

            if self.db.has_member(self.DEFAULT_MEMBER_ID):
                m_ids = list(default_deck)
            else:
                ids = self.db.get_member_ids()
                if ids:
                    m_ids = [ids[0]] * self.DECK_SIZE
                else:
                    # Behaviour kept from the original: the deck stays empty.
                    # Surface it, since an empty deck is almost certainly a
                    # data problem rather than an intended configuration.
                    print("Warning: card DB reports no member ids; deck is empty")
            l_ids = list(default_lives)
        except Exception as e:
            # Deliberate best-effort: any failure falls back to the defaults.
            print(f"Warning: Failed to load deck config: {e}")
            m_ids = list(default_deck)
            l_ids = list(default_lives)

        # Give each player its own list so mutating one side's deck can never
        # silently change the other's.
        self.p0_deck = list(m_ids)
        self.p1_deck = list(m_ids)
        self.p0_lives = list(l_ids)
        self.p1_lives = list(l_ids)

    def reset(self):
        """Re-initialise every game with a fresh random seed.

        Returns:
            A copy of the observation buffer after initialisation.
        """
        seed = np.random.randint(0, 1000000)
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        return self.get_observations()

    def step_async(self, actions):
        """Store ``actions`` for the next ``step_wait`` call."""
        self.actions = actions

    def step_wait(self):
        """Advance every game by one step using the stored actions.

        If no actions were ever submitted, the environments are reset and
        zero rewards / ``False`` dones are returned.

        Returns:
            ``(observations, rewards, dones, infos)``.  The arrays are copies
            of the internal buffers.  ``infos[i]`` carries a
            ``"terminal_observation"`` entry for every env index reported as
            done by the engine.
        """
        if self.actions is None:
            return (
                self.reset(),
                # float32 to match the dtype returned on the normal path.
                np.zeros(self.num_envs, dtype=np.float32),
                np.zeros(self.num_envs, dtype=bool),
                # One distinct dict per env; ``[{}] * n`` would alias a
                # single dict across all of them.
                [{} for _ in range(self.num_envs)],
            )

        # Always hand the engine a fresh int32 array; this also accepts
        # plain Python sequences, which ``.astype`` would not.
        actions = np.array(self.actions, dtype=np.int32)

        # The engine fills the buffers and returns the indices of done envs.
        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )
        if done_indices is None:
            done_indices = ()

        infos = [{} for _ in range(self.num_envs)]
        # Row ``slot`` of the terminal buffer belongs to the ``slot``-th done
        # env.  Iterating directly (no truthiness test) also works if the
        # engine returns an array instead of a list.
        for slot, env_idx in enumerate(done_indices):
            infos[env_idx]["terminal_observation"] = self.term_obs_buffer[slot].copy()

        # The buffers are overwritten on the next step, so return copies.
        return self.obs_buffer.copy(), self.rewards_buffer.copy(), self.dones_buffer.copy(), infos

    def close(self):
        """No-op; this wrapper releases nothing explicitly."""
        pass

    def get_attr(self, attr_name, indices=None):
        """Return ``None`` for every env; per-env attributes are not exposed.

        ``indices`` is ignored and the result always has ``num_envs`` items.
        """
        return [None] * self.num_envs

    def set_attr(self, attr_name, value, indices=None):
        """No-op; per-env attributes are not exposed."""
        pass

    def env_method(self, method_name, *method_args, **method_kwargs):
        """Dispatch a per-env method call.

        Only ``"action_masks"`` is supported: it returns one mask row per
        env, so callers that stack the result get a ``(num_envs,
        NUM_ACTIONS)`` array.  Any other method name yields ``None`` for
        every env.  An ``indices`` keyword, if given, is ignored.
        """
        if method_name == "action_masks":
            return list(self.action_masks())
        return [None] * self.num_envs

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Report that no env is wrapped; ``indices`` is ignored."""
        return [False] * self.num_envs

    def get_observations(self):
        """Ask the engine to fill the observation buffer and return a copy."""
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer.copy()

    def action_masks(self):
        """Ask the engine to fill the action-mask buffer and return a copy."""
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer.copy()
|