trioskosmos committed on
Commit
8991165
·
verified ·
1 Parent(s): 8c0b3c9

Upload ai/environments/vec_env_rust.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/environments/vec_env_rust.py +131 -0
ai/environments/vec_env_rust.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import engine_rust
5
+ import numpy as np
6
+ from gymnasium import spaces
7
+ from stable_baselines3.common.vec_env import VecEnv
8
+
9
+
10
class RustVectorEnv(VecEnv):
    """Stable-Baselines3 ``VecEnv`` backed by ``engine_rust.PyVectorGameState``.

    All ``num_envs`` games live inside a single Rust-side vector state.
    Observations, rewards, done flags, terminal observations and action
    masks are exchanged through NumPy buffers that are allocated once in
    ``__init__`` and passed to the engine on every call; callers always
    receive copies, so the internal buffers can be overwritten on the next
    step.

    ``get_attr``, ``set_attr``, ``env_method`` and ``env_is_wrapped`` are
    placeholders: there are no per-environment Python objects to forward to.
    """

    # Number of features in one observation row.
    OBS_DIM = 350
    # Size of the default discrete action space and width of one mask row.
    ACTION_DIM = 2000

    def __init__(self, num_envs, action_space=None, opp_mode=0, force_start_order=-1, mcts_sims=50):
        """Load the card database, build the Rust vector state and reset it.

        Args:
            num_envs: Number of games simulated in parallel.
            action_space: Optional action space; defaults to
                ``spaces.Discrete(ACTION_DIM)`` when ``None``.
            opp_mode: Forwarded unchanged to ``PyVectorGameState``.
            force_start_order: Accepted for interface compatibility; it is
                not used anywhere in this class.
            mcts_sims: Forwarded unchanged to ``PyVectorGameState``.

        Raises:
            FileNotFoundError: If ``data/cards_compiled.json`` does not exist
                relative to the current working directory.
        """
        # 1. Load DB
        db_path = "data/cards_compiled.json"
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Card DB not found at {db_path}")

        with open(db_path, "r", encoding="utf-8") as f:
            json_str = f.read()
        self.db = engine_rust.PyCardDatabase(json_str)

        # 2. Create Vector State
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 3. Setup Spaces
        obs_dim = self.OBS_DIM
        observation_space = spaces.Box(low=0, high=1, shape=(obs_dim,), dtype=np.float32)
        if action_space is None:
            action_space = spaces.Discrete(self.ACTION_DIM)

        # Let the base class record num_envs and the spaces and create its
        # own bookkeeping (reset_infos, seeds/options, render_mode). Skipping
        # this call leaves those attributes undefined.
        super().__init__(num_envs, observation_space, action_space)

        # Kept explicit so the attributes exist regardless of what the
        # installed base-class version assigns.
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_envs = num_envs
        self.actions = None

        # Pre-allocated buffers that the engine fills on each call.
        self.obs_buffer = np.zeros((num_envs, obs_dim), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        # Term obs buffer needs to accommodate worst case (all done)
        self.term_obs_buffer = np.zeros((num_envs, obs_dim), dtype=np.float32)
        # NOTE(review): the mask width is fixed at ACTION_DIM even when a
        # custom action_space is supplied -- confirm the engine's expected
        # mask width before tying this to action_space.n.
        self.mask_buffer = np.zeros((num_envs, self.ACTION_DIM), dtype=bool)

        # 4. Load Deck Config
        self._load_decks()

        # 5. Initialize (Warmup)
        self.reset()

    def _load_decks(self):
        """Build the placeholder decks used for both players.

        The member deck is 48 copies of one member id (id ``1`` when the
        database has it, otherwise the first id the database reports) and the
        life deck is 12 copies of id ``100``. If anything in the lookup
        fails -- including opening or parsing
        ``data/verified_card_pool.json`` -- a warning is printed and the
        hard-coded defaults are used instead.
        """
        member_ids = []
        life_ids = []
        try:
            with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
                # The parsed pool is not consulted; the file is read only so
                # that a missing or malformed pool triggers the fallback.
                json.load(f)

            if self.db.has_member(1):
                member_ids = [1] * 48
            else:
                ids = self.db.get_member_ids()
                if ids:
                    member_ids = [ids[0]] * 48
            # Set unconditionally so the life deck is never left empty on the
            # success path, whichever member branch was taken.
            life_ids = [100] * 12

        except Exception as e:
            # Deliberate best-effort: fall back to the default decks.
            print(f"Warning: Failed to load deck config: {e}")
            member_ids = [1] * 48
            life_ids = [100] * 12

        # Give each player its own list so later in-place edits to one deck
        # cannot silently change the other.
        self.p0_deck = list(member_ids)
        self.p1_deck = list(member_ids)
        self.p0_lives = list(life_ids)
        self.p1_lives = list(life_ids)

    def reset(self):
        """Re-initialize every game with a fresh random seed.

        Returns:
            A copy of the observation buffer after initialization.
        """
        seed = np.random.randint(0, 1000000)
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        return self.get_observations()

    def step_async(self, actions):
        """Store ``actions`` for the next ``step_wait`` call."""
        self.actions = actions

    def step_wait(self):
        """Advance all games with the stored actions.

        Returns:
            ``(observations, rewards, dones, infos)``. The arrays are copies
            of the internal buffers. ``infos[i]`` holds a
            ``"terminal_observation"`` entry for every environment index the
            engine reported as done. If no actions were stored, the
            environments are reset and zero rewards / ``False`` dones are
            returned instead.
        """
        if self.actions is None:
            return (
                self.reset(),
                np.zeros(self.num_envs, dtype=np.float32),
                np.zeros(self.num_envs, dtype=bool),
                # One dict per env: ``[{}] * n`` would alias a single dict.
                [{} for _ in range(self.num_envs)],
            )

        # Ensure a contiguous int32 array; also accepts plain lists/tuples.
        actions = np.ascontiguousarray(self.actions, dtype=np.int32)

        # Call Rust step with pre-allocated buffers
        # Returns list of done indices
        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )

        infos = [{} for _ in range(self.num_envs)]

        # Populate infos for done envs
        if done_indices:
            # Row i of the terminal buffer belongs to the i-th reported env.
            for i, env_idx in enumerate(done_indices):
                infos[env_idx]["terminal_observation"] = self.term_obs_buffer[i].copy()

        # Copies are returned because the buffers are overwritten next step.
        return self.obs_buffer.copy(), self.rewards_buffer.copy(), self.dones_buffer.copy(), infos

    def close(self):
        """No-op; this class holds no resources that need explicit release."""
        pass

    def get_attr(self, attr_name, indices=None):
        """Placeholder: returns ``None`` for each environment."""
        return [None] * self.num_envs

    def set_attr(self, attr_name, value, indices=None):
        """Placeholder: the value is ignored."""
        pass

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Placeholder: nothing is called; returns ``None`` per environment."""
        return [None] * self.num_envs

    def env_is_wrapped(self, wrapper_class, indices=None):
        """Report that no environment is wrapped by ``wrapper_class``."""
        return [False] * self.num_envs

    def get_observations(self):
        """Have the engine fill the observation buffer and return a copy."""
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer.copy()

    def action_masks(self):
        """Have the engine fill the action-mask buffer and return a copy."""
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer.copy()