File size: 2,469 Bytes
b05f799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os

import engine_rust
import numpy as np


class RustEnvLite:
    """Minimal, high-performance wrapper around the LovecaSim Rust engine.

    Bypasses Gymnasium/SB3 so a training loop can drive the vectorized engine
    directly. Every observation/reward/done/mask array is allocated once in
    ``__init__`` and handed to the engine on each call (the engine is expected
    to fill it in place). The public methods return those same array objects,
    so their contents change on the next call -- copy anything that has to
    outlive it.
    """

    def __init__(
        self,
        num_envs: int,
        db_path: str = "data/cards_compiled.json",
        opp_mode: int = 0,
        mcts_sims: int = 50,
    ):
        """Load the card database and build the vectorized engine.

        Args:
            num_envs: Number of environments; also the leading dimension of
                every pre-allocated buffer.
            db_path: Path to the compiled card database (JSON text).
            opp_mode: Forwarded unchanged to ``PyVectorGameState``.
            mcts_sims: Forwarded unchanged to ``PyVectorGameState``.

        Raises:
            FileNotFoundError: If ``db_path`` does not exist.
        """
        # 1. Load DB. The explicit check keeps the project-specific message.
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Card DB not found at {db_path}")

        with open(db_path, "r", encoding="utf-8") as f:
            json_str = f.read()
        self.db = engine_rust.PyCardDatabase(json_str)

        # 2. Params.
        # NOTE(review): 350 / 2000 presumably mirror the observation width and
        # action-space size compiled into the Rust engine -- confirm before
        # changing either value.
        self.num_envs = num_envs
        self.obs_dim = 350
        self.action_dim = 2000

        # 3. Create Vector Engine
        self.game_state = engine_rust.PyVectorGameState(num_envs, self.db, opp_mode, mcts_sims)

        # 4. Pre-allocate Buffers (Zero-Copy): created once, reused every call.
        self.obs_buffer = np.zeros((num_envs, self.obs_dim), dtype=np.float32)
        self.rewards_buffer = np.zeros(num_envs, dtype=np.float32)
        self.dones_buffer = np.zeros(num_envs, dtype=bool)
        # Passed to the engine on step() but not returned by it; read it via
        # the attribute. Presumably holds terminal observations -- TODO confirm.
        self.term_obs_buffer = np.zeros((num_envs, self.obs_dim), dtype=np.float32)
        self.mask_buffer = np.zeros((num_envs, self.action_dim), dtype=bool)

        # 5. Default Decks (Standard Play)
        # Using ID 1 (Member) and ID 100 (Live) as placeholders or from DB
        self.p0_deck = [1] * 48
        self.p1_deck = [1] * 48
        self.p0_lives = [100] * 12
        self.p1_lives = [100] * 12

    def reset(self, seed=None):
        """(Re)initialize all environments and return the observation buffer.

        Args:
            seed: Seed forwarded to the engine. When ``None``, one is drawn
                from NumPy's global RNG in ``[0, 1000000)``.

        Returns:
            ``self.obs_buffer`` after the engine has been asked to write the
            current observations into it (the buffer itself, not a copy).
        """
        if seed is None:
            seed = np.random.randint(0, 1000000)
        self.game_state.initialize(self.p0_deck, self.p1_deck, self.p0_lives, self.p1_lives, seed)
        self.game_state.get_observations(self.obs_buffer)
        return self.obs_buffer

    def step(self, actions):
        """Advance the engine with ``actions``.

        Args:
            actions: Array-like of integer action ids (ndarray, list, tuple,
                ...). Anything that is not already an int32 ndarray is
                converted to one; an int32 ndarray is passed through without
                copying. The expected length is presumably ``num_envs`` -- it
                is not validated here.

        Returns:
            ``(obs, rewards, dones, done_indices)``. The first three are the
            pre-allocated buffers themselves (not copies); ``done_indices`` is
            whatever the engine's ``step`` returns.
        """
        # np.asarray also accepts plain sequences, which the former
        # ``actions.dtype`` check rejected with AttributeError.
        actions = np.asarray(actions, dtype=np.int32)

        done_indices = self.game_state.step(
            actions, self.obs_buffer, self.rewards_buffer, self.dones_buffer, self.term_obs_buffer
        )
        return self.obs_buffer, self.rewards_buffer, self.dones_buffer, done_indices

    def get_masks(self):
        """Ask the engine to fill the action-mask buffer and return it.

        Returns:
            ``self.mask_buffer`` (the buffer itself, not a copy).
        """
        self.game_state.get_action_masks(self.mask_buffer)
        return self.mask_buffer