trioskosmos committed on
Commit
b05f799
·
verified ·
1 Parent(s): 216335a

Upload ai/environments/rust_env_lite.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/environments/rust_env_lite.py +66 -0
ai/environments/rust_env_lite.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import engine_rust
4
+ import numpy as np
5
+
6
+
7
class RustEnvLite:
    """
    A minimal, high-performance wrapper for the LovecaSim Rust engine.
    Bypasses Gymnasium/SB3 for direct, zero-copy training loops.
    """

    def __init__(self, num_envs, db_path="data/cards_compiled.json", opp_mode=0,
                 mcts_sims=50, obs_dim=350, action_dim=2000):
        """Load the card DB, build the vectorized engine, and pre-allocate buffers.

        Args:
            num_envs: Number of parallel environments simulated by the engine.
            db_path: Path to the compiled card database (JSON).
            opp_mode: Opponent policy selector forwarded to the Rust engine.
            mcts_sims: MCTS simulation count forwarded to the Rust engine.
            obs_dim: Observation vector length; must match the engine's
                observation encoding (default 350, previously hard-coded).
            action_dim: Action-mask length; must match the engine's action
                space (default 2000, previously hard-coded).

        Raises:
            FileNotFoundError: If `db_path` does not exist.
        """
        # 1. Load DB
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Card DB not found at {db_path}")

        with open(db_path, "r", encoding="utf-8") as f:
            json_str = f.read()
        self.db = engine_rust.PyCardDatabase(json_str)

        # 2. Params (defaults preserve the previously hard-coded values)
        self.num_envs = num_envs
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # 3. Create Vector Engine
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 4. Pre-allocate Buffers (Zero-Copy): the engine writes into these
        # in place each step, so callers receive views, never fresh arrays.
        self.obs_buffer = np.zeros((num_envs, self.obs_dim), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        self.term_obs_buffer = np.zeros((num_envs, self.obs_dim), dtype=np.float32)
        self.mask_buffer = np.zeros((num_envs, self.action_dim), dtype=bool)

        # 5. Default Decks (Standard Play)
        # Using ID 1 (Member) and ID 100 (Live) as placeholders or from DB
        self.p0_deck = [1] * 48
        self.p1_deck = [1] * 48
        self.p0_lives = [100] * 12
        self.p1_lives = [100] * 12

    def reset(self, seed=None):
        """Re-initialize every environment and return a view of the obs buffer.

        A random seed is drawn when none is provided, so back-to-back
        unseeded resets produce different games.
        """
        if seed is None:
            seed = np.random.randint(0, 1000000)
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer

    def step(self, actions):
        """
        Actions: np.ndarray (int32)
        Returns: obs (view), rewards (view), dones (view), done_indices
        """
        # The zero-copy FFI boundary needs a C-contiguous int32 array.
        # ascontiguousarray fixes both dtype and layout (e.g. a strided
        # slice) and is a no-op when the input already conforms.
        actions = np.ascontiguousarray(actions, dtype=np.int32)

        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )
        return self.obs_buffer, self.rewards_buffer, self.dones_buffer, done_indices

    def get_masks(self):
        """Fill and return the (num_envs, action_dim) legal-action mask view."""
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer