Spaces:
Running
Running
Upload ai/environments/rust_env_lite.py with huggingface_hub
Browse files
ai/environments/rust_env_lite.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import engine_rust
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RustEnvLite:
    """
    A minimal, high-performance wrapper for the LovecaSim Rust engine.

    Bypasses Gymnasium/SB3 for direct, zero-copy training loops: every
    observation/reward/done/mask array is allocated once here and filled
    in place by the Rust side, so the hot loop does no per-step allocation.
    """

    def __init__(self, num_envs, db_path="data/cards_compiled.json", opp_mode=0, mcts_sims=50):
        """
        Args:
            num_envs: Number of parallel game instances driven by the engine.
            db_path: Path to the compiled card-database JSON file.
            opp_mode: Opponent policy selector, passed through to the engine.
            mcts_sims: MCTS simulation budget, passed through to the engine.

        Raises:
            FileNotFoundError: If ``db_path`` does not exist.
        """
        # 1. Load DB — read the JSON ourselves and hand the raw string to Rust.
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Card DB not found at {db_path}")

        with open(db_path, "r", encoding="utf-8") as f:
            json_str = f.read()
        self.db = engine_rust.PyCardDatabase(json_str)

        # 2. Params. obs_dim/action_dim must match the Rust engine's layout —
        #    TODO(review): confirm 350/2000 against the engine's constants.
        self.num_envs = num_envs
        self.obs_dim = 350
        self.action_dim = 2000

        # 3. Create Vector Engine (one Rust object simulating all envs).
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 4. Pre-allocate Buffers (Zero-Copy) — the engine writes into these
        #    in place; callers receive views, never fresh arrays.
        self.obs_buffer = np.zeros((num_envs, self.obs_dim), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        self.term_obs_buffer = np.zeros((num_envs, self.obs_dim), dtype=np.float32)
        self.mask_buffer = np.zeros((num_envs, self.action_dim), dtype=bool)

        # 5. Default Decks (Standard Play)
        # Using ID 1 (Member) and ID 100 (Live) as placeholders or from DB
        self.p0_deck = [1] * 48
        self.p1_deck = [1] * 48
        self.p0_lives = [100] * 12
        self.p1_lives = [100] * 12

    def reset(self, seed=None):
        """Re-initialize every env and return the observation buffer.

        Args:
            seed: RNG seed forwarded to the engine; a random one is drawn
                when omitted.

        Returns:
            np.ndarray of shape (num_envs, obs_dim), float32 — a VIEW into
            the shared buffer, overwritten by the next reset()/step().
        """
        if seed is None:
            seed = np.random.randint(0, 1000000)
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer

    def step(self, actions):
        """
        Advance every env by one action.

        Args:
            actions: array-like of per-env action indices. Coerced to a
                C-contiguous int32 ndarray, as the Rust engine expects.

        Returns:
            obs (view), rewards (view), dones (view), done_indices — the
            first three are views into the pre-allocated buffers and are
            overwritten on the next call.
        """
        # Accept lists, non-contiguous slices, and other dtypes; this is a
        # no-op when `actions` is already a contiguous int32 ndarray.
        # (The previous `actions.dtype` check crashed on plain lists.)
        actions = np.ascontiguousarray(actions, dtype=np.int32)

        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )
        return self.obs_buffer, self.rewards_buffer, self.dones_buffer, done_indices

    def get_masks(self):
        """Fill and return (a view of) the legal-action mask buffer,
        shape (num_envs, action_dim), dtype bool."""
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer
|