# LovecaSim / ai/environments/vec_env_rust.py
# Uploaded by trioskosmos with huggingface_hub (commit 8991165, verified).
import json
import os
import engine_rust
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env import VecEnv
class RustVectorEnv(VecEnv):
    """Stable-Baselines3 ``VecEnv`` backed by the ``engine_rust`` batched game state.

    All ``num_envs`` games live in one ``engine_rust.PyVectorGameState``. The
    engine writes observations, rewards, dones, terminal observations and
    action masks into NumPy buffers pre-allocated here; copies of those
    buffers are handed back to the caller because they are overwritten on
    every call.
    """

    # Data files, resolved relative to the current working directory.
    DB_PATH = "data/cards_compiled.json"
    POOL_PATH = "data/verified_card_pool.json"
    # Flat observation width and number of discrete actions.
    # NOTE(review): these must match what the Rust engine writes into the
    # buffers below — confirm against engine_rust if either side changes.
    OBS_DIM = 350
    ACTION_DIM = 2000

    def __init__(self, num_envs, action_space=None, opp_mode=0, force_start_order=-1, mcts_sims=50):
        """Load the card DB, create the Rust vector state and reset every game.

        Args:
            num_envs: Number of parallel games.
            action_space: Optional override; defaults to ``Discrete(ACTION_DIM)``.
                The mask buffer is always ``ACTION_DIM`` wide regardless.
            opp_mode: Forwarded to ``PyVectorGameState`` (meaning defined by
                the engine).
            force_start_order: Accepted for API compatibility; currently not
                forwarded to the engine. TODO(review): wire up or remove.
            mcts_sims: Forwarded to ``PyVectorGameState``.

        Raises:
            FileNotFoundError: If the compiled card DB does not exist.
        """
        # 1. Load the compiled card database (passed to the engine as raw JSON text).
        if not os.path.exists(self.DB_PATH):
            raise FileNotFoundError(f"Card DB not found at {self.DB_PATH}")
        with open(self.DB_PATH, "r", encoding="utf-8") as f:
            self.db = engine_rust.PyCardDatabase(f.read())

        # 2. Create the batched game state.
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 3. Spaces, then let the SB3 base class set num_envs / spaces and its
        #    own bookkeeping (render_mode, metadata, reset infos, seeds).
        observation_space = spaces.Box(low=0, high=1, shape=(self.OBS_DIM,), dtype=np.float32)
        if action_space is None:
            action_space = spaces.Discrete(self.ACTION_DIM)
        super().__init__(num_envs, observation_space, action_space)
        self.actions = None

        # Pre-allocated buffers the engine fills in place.
        self.obs_buffer = np.zeros((num_envs, self.OBS_DIM), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        # Sized for the worst case where every env finishes in the same step.
        self.term_obs_buffer = np.zeros((num_envs, self.OBS_DIM), dtype=np.float32)
        self.mask_buffer = np.zeros((num_envs, self.ACTION_DIM), dtype=bool)

        # 4. Deck configuration, 5. initial reset (warm-up).
        self._load_decks()
        self.reset()

    def _load_decks(self):
        """Choose the fixed decks used by both players in every game.

        Each player gets 48 copies of one member card and 12 copies of live
        card id 100. Member id 1 is used if the DB has it, otherwise the first
        id from ``get_member_ids()``. On any failure, or if the DB reports no
        members, the member deck falls back to ``[1] * 48`` with a warning.
        """
        fallback_members = [1] * 48
        l_ids = [100] * 12
        try:
            # NOTE(review): the verified pool is only read and parsed here as a
            # sanity check; its contents are not used for deck building yet.
            with open(self.POOL_PATH, "r", encoding="utf-8") as f:
                json.load(f)
            if self.db.has_member(1):
                m_ids = fallback_members
            else:
                ids = self.db.get_member_ids()
                if ids:
                    m_ids = [ids[0]] * 48
                else:
                    print("Warning: Card DB reports no member cards; falling back to member id 1")
                    m_ids = fallback_members
        except Exception as e:
            # Deliberately best-effort: training should still start with a
            # default deck, so report the problem instead of raising.
            print(f"Warning: Failed to load deck config: {e}")
            m_ids = fallback_members

        # Independent list objects per player so neither deck aliases the other.
        self.p0_deck = list(m_ids)
        self.p1_deck = list(m_ids)
        self.p0_lives = list(l_ids)
        self.p1_lives = list(l_ids)

    def reset(self):
        """Re-initialise every game with the configured decks and a fresh seed.

        The seed is drawn from NumPy's global RNG, so it follows
        ``np.random.seed``. Returns a copy of the observation buffer.
        """
        seed = int(np.random.randint(0, 1000000))
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        return self.get_observations()

    def step_async(self, actions):
        """Store one action per env for the next ``step_wait`` call."""
        self.actions = actions

    def step_wait(self):
        """Apply the stored actions and return ``(obs, rewards, dones, infos)``.

        For every env whose game ended this step, ``infos[env_idx]`` gets a
        ``"terminal_observation"`` entry. If ``step_async`` was never called,
        all games are reset and zero rewards / ``False`` dones are returned.
        """
        if self.actions is None:
            # One dict per env: ``[{}] * n`` would alias a single dict across all envs.
            return (
                self.reset(),
                np.zeros(self.num_envs, dtype=np.float32),
                np.zeros(self.num_envs, dtype=bool),
                [{} for _ in range(self.num_envs)],
            )

        # Flat, C-contiguous int32 copy; also accepts plain lists.
        # NOTE(review): assumes the engine expects a 1-D int32 array — confirm.
        actions = np.ascontiguousarray(self.actions, dtype=np.int32).reshape(-1)

        # The engine fills the buffers in place and returns the indices of envs
        # whose game ended; row i of term_obs_buffer belongs to done_indices[i].
        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )

        infos = [{} for _ in range(self.num_envs)]
        # Iterate directly rather than testing truthiness, which would raise
        # for a multi-element NumPy array.
        if done_indices is not None:
            for i, env_idx in enumerate(done_indices):
                infos[int(env_idx)]["terminal_observation"] = self.term_obs_buffer[i].copy()

        # Copies, because the buffers are overwritten by the next step.
        return self.obs_buffer.copy(), self.rewards_buffer.copy(), self.dones_buffer.copy(), infos

    def close(self):
        """Nothing to release on the Python side."""
        pass

    def _resolve_indices(self, indices):
        """Normalise SB3-style ``indices`` (None / int / iterable) to a list of env ids."""
        if indices is None:
            return list(range(self.num_envs))
        if isinstance(indices, (int, np.integer)):
            return [int(indices)]
        return list(indices)

    def get_attr(self, attr_name, indices=None):
        """Per-env attributes are not exposed; return ``None`` for each requested env."""
        return [None] * len(self._resolve_indices(indices))

    def set_attr(self, attr_name, value, indices=None):
        """Setting per-env attributes is not supported; this is a no-op."""
        pass

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Call a per-env method, returning one result per requested env.

        Only ``"action_masks"`` is supported: it returns the corresponding
        rows of ``action_masks()``. Any other method yields ``None`` per env.
        NOTE(review): this lets callers that stack
        ``env_method("action_masks")`` (e.g. sb3-contrib's ``get_action_masks``)
        receive real masks.
        """
        env_ids = self._resolve_indices(indices)
        if method_name == "action_masks":
            masks = self.action_masks()
            return [masks[i] for i in env_ids]
        return [None] * len(env_ids)

    def env_is_wrapped(self, wrapper_class, indices=None):
        """No sub-env is wrapped; return ``False`` for each requested env."""
        return [False] * len(self._resolve_indices(indices))

    def get_observations(self):
        """Have the engine refresh the observation buffer and return a copy of it."""
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer.copy()

    def action_masks(self):
        """Return a ``(num_envs, ACTION_DIM)`` boolean copy of the current action masks."""
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer.copy()