File size: 11,107 Bytes
2badd2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from typing import List

import numpy as np

from engine.game.ai_compat import njit
from engine.game.fast_logic import batch_apply_action


@njit(cache=True)
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # Bytecode maps
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> Index in map
):
    """
    Advance every environment in the batch by one action.

    One action per environment is read from ``actions``. Before delegating,
    each environment's entry in ``batch_scores`` is mirrored into column 0 of
    ``batch_global_ctx``. The actual game logic lives in
    ``batch_apply_action``, which is invoked for player 0 with all batched
    state buffers and the bytecode tables.

    Nothing is returned; this function itself only writes to
    ``batch_global_ctx``. Whatever else changes is decided by
    ``batch_apply_action`` (not visible here).
    """
    env_count = len(actions)

    # Mirror per-environment scores into the global context before stepping.
    for env_idx in range(env_count):
        batch_global_ctx[env_idx, 0] = batch_scores[env_idx]

    # NOTE: the positional order below follows batch_apply_action's signature
    # (tapped, scores, live, opp_tapped), which differs from this function's
    # own parameter order. Do not "tidy" it.
    batch_apply_action(
        actions,
        0,  # player_id
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        bytecode_map,
        bytecode_index,
    )


class VectorGameState:
    """
    Manage a batch of independent game states for high-throughput training.

    All per-environment state is held in pre-allocated NumPy buffers whose
    first axis is the environment index, so the whole batch can be passed to
    the JIT kernels (``step_vectorized`` / ``encode_observations_vectorized``)
    in a single call.
    """

    def __init__(self, num_envs: int):
        """Allocate the state buffers for ``num_envs`` environments and load card data."""
        self.num_envs = num_envs
        self.turn = 1

        # Batched state buffers (stage uses -1 as the "empty slot" marker).
        self.batch_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.batch_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)
        self.batch_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_continuous_vec = np.zeros((num_envs, 32, 10), dtype=np.int32)
        self.batch_continuous_ptr = np.zeros(num_envs, dtype=np.int32)
        self.batch_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_live = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_opp_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_scores = np.zeros(num_envs, dtype=np.int32)

        # Pre-allocated context buffers (avoids per-step allocation).
        self.batch_flat_ctx = np.zeros((num_envs, 64), dtype=np.int32)
        self.batch_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)
        self.batch_hand = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_deck = np.zeros((num_envs, 50), dtype=np.int32)

        # Pre-allocated observation buffer, reused by get_observations().
        self.obs_buffer = np.zeros((num_envs, 320), dtype=np.float32)

        # Card data: compiled ability bytecode and the pool of usable card ids.
        self._load_bytecode()
        self._load_verified_deck_pool()

    def _load_bytecode(self):
        """
        Load compiled ability bytecode from ``data/cards_numba.json``.

        The JSON maps ``"<card_id>_<ability_idx>"`` to a flat list of ints that
        is reshaped to rows of 4. Two tables are built:

        * ``bytecode_map``   -- (entries + 1, max_len, 4); row 0 is the empty/no-op program.
        * ``bytecode_index`` -- (max_cards, max_abilities); value is the row in
          ``bytecode_map`` (0 when the card/ability has no compiled program).

        If the file is missing, both tables are created empty but with the
        same index shape, so every lookup resolves to the no-op row.
        """
        import json

        # Set the table dimensions up front so the attributes always exist and
        # the fallback tables are sized consistently with the real ones.
        self.max_cards = 2000
        self.max_abilities = 4
        self.max_len = 64  # Max 64 instructions per ability

        try:
            with open("data/cards_numba.json", "r", encoding="utf-8") as f:
                raw_map = json.load(f)
        except FileNotFoundError:
            print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
            self.bytecode_map = np.zeros((1, self.max_len, 4), dtype=np.int32)
            # Keep the full (card, ability) shape: a (1, 1) table would make any
            # lookup for card id >= 1 fall outside the array.
            self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)
            return

        # Row 0 stays all-zero and acts as the empty/no-op program.
        self.bytecode_map = np.zeros((len(raw_map) + 1, self.max_len, 4), dtype=np.int32)
        self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)

        next_row = 1
        for key, bc_list in raw_map.items():
            cid, aid = map(int, key.split("_"))
            if cid >= self.max_cards or aid >= self.max_abilities:
                # Entry does not fit in the lookup table; skip it.
                continue
            # Flat list -> (M, 4) instruction rows, truncated to max_len.
            bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
            length = min(bc_arr.shape[0], self.max_len)
            self.bytecode_map[next_row, :length] = bc_arr[:length]
            self.bytecode_index[cid, aid] = next_row
            next_row += 1

        # Report what was actually stored, not the raw number of JSON entries.
        loaded = next_row - 1
        print(f" [VectorEnv] Loaded {loaded} compiled abilities.")

    def _load_verified_deck_pool(self):
        """
        Build ``verified_card_ids``: the int32 array of card ids decks are drawn from.

        Card numbers listed under ``"verified_abilities"`` in
        ``verified_card_pool.json`` are translated to card ids through the
        ``"member_db"`` section of ``data/cards_compiled.json``. Numbers with
        no match are ignored. On any failure, or if nothing matches, the pool
        falls back to the single id ``1``.
        """
        import json

        try:
            # Verified list of card numbers.
            with open("verified_card_pool.json", "r", encoding="utf-8") as f:
                verified_data = json.load(f)

            # Card database, used to map CardNo -> CardID.
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)

            card_no_map = {cdata["card_no"]: int(cid) for cid, cdata in db_data["member_db"].items()}
            verified_ids = [
                card_no_map[v_no]
                for v_no in verified_data.get("verified_abilities", [])
                if v_no in card_no_map
            ]

            if not verified_ids:
                print(" [VectorEnv] Warning: No verified cards found. Using ID 1.")
                verified_ids = [1]
            else:
                print(f" [VectorEnv] Loaded {len(verified_ids)} verified cards for training.")

            self.verified_card_ids = np.array(verified_ids, dtype=np.int32)

        except Exception as e:
            # Deliberate best-effort: training should still start with a
            # minimal pool if the data files are missing or malformed.
            print(f" [VectorEnv] Deck Load Error: {e}")
            self.verified_card_ids = np.array([1], dtype=np.int32)

    def reset(self, indices: List[int] = None):
        """
        Reset the given environments (all of them when ``indices`` is None).

        Every per-environment buffer is restored to its initial value, a fresh
        50-card deck is sampled (with replacement) from the verified pool, and
        the first 5 deck entries are copied into the hand.

        Note: ``self.turn`` is shared by the whole batch and is set back to 1
        even when only a subset of environments is reset.
        """
        if indices is None:
            indices = list(range(self.num_envs))

        # A plain loop is acceptable here: reset is rare compared to step.
        for i in indices:
            self.batch_stage[i].fill(-1)
            self.batch_energy_vec[i].fill(0)
            self.batch_energy_count[i].fill(0)
            self.batch_continuous_vec[i].fill(0)
            self.batch_continuous_ptr[i] = 0
            self.batch_tapped[i].fill(0)
            self.batch_live[i].fill(0)
            self.batch_opp_tapped[i].fill(0)
            self.batch_scores[i] = 0

            # Reset contexts
            self.batch_flat_ctx[i].fill(0)
            self.batch_global_ctx[i].fill(0)

            # Clear the hand as well; otherwise slots beyond the first 5 would
            # carry cards over from the previous episode.
            self.batch_hand[i].fill(0)

            # Initialize the deck with 50 random cards from the verified pool.
            if len(self.verified_card_ids) > 0:
                self.batch_deck[i] = np.random.choice(self.verified_card_ids, 50)

            # Opening hand: copy the top 5 deck entries. The deck itself is not
            # shifted; the cards stay in the deck buffer as well.
            self.batch_hand[i, :5] = self.batch_deck[i, :5]

        self.turn = 1

    def step(self, actions: np.ndarray):
        """
        Apply a batch of actions across all environments.

        Delegates to ``step_vectorized``, which updates the state buffers in
        place. Turn advancement is not handled here; ``self.turn`` is left
        unchanged.
        """
        step_vectorized(
            actions,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_opp_tapped,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.bytecode_map,
            self.bytecode_index,
        )

    def get_observations(self):
        """
        Return the batched observation array for RL models.

        The result is ``self.obs_buffer`` itself, overwritten in place on every
        call. Copy it if the values must survive the next call.
        """
        return encode_observations_vectorized(
            self.num_envs,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.turn,
            self.obs_buffer,
        )


@njit(cache=True)
def encode_observations_vectorized(
    num_envs: int,
    batch_stage: np.ndarray,  # (N, 3)
    batch_energy_count: np.ndarray,  # (N, 3)
    batch_tapped: np.ndarray,  # (N, 3)
    batch_scores: np.ndarray,  # (N,)
    turn_number: int,
    observations: np.ndarray,  # (N, 320)
):
    """
    Encode the batched game state into ``observations`` in place and return it.

    The buffer is zeroed first, then for each of the first ``num_envs`` rows:

    * index 5   -- set to 1.0 (phase flag; the vector env always reports Main phase)
    * index 16  -- set to 1.0 (current player flag; always player 0)
    * 168 + slot * 12 (+0, +1, +2, +3, +11) -- per stage slot, only when the
      slot holds a card (id >= 0): presence, id / 2000, tapped flag,
      placeholder power 0.5, and energy count / 5 clamped to 1.0
    * index 270 -- score / 5 clamped to 1.0

    Hand, opponent stage and live zone sections are left at zero.
    ``turn_number`` is accepted but currently unused.
    """
    # Layout / normalization constants
    id_scale = 2000.0
    clamp_divisor = 5.0
    phase_idx = 5  # original note: GameState encodes phase as int(phase) + 2, Main = 3
    player_idx = 16
    stage_start = 168
    slot_stride = 12
    score_idx = 270

    # Start from a clean buffer (cheap, since it is pre-allocated).
    observations.fill(0.0)

    for env in range(num_envs):
        # --- Metadata ---
        observations[env, phase_idx] = 1.0
        observations[env, player_idx] = 1.0

        # --- Own stage: 3 slots, 12 features each ---
        for lane in range(3):
            card = batch_stage[env, lane]
            if card < 0:
                # Empty slot: leave its features at zero.
                continue

            offset = stage_start + lane * slot_stride
            observations[env, offset] = 1.0
            observations[env, offset + 1] = card / id_scale

            if batch_tapped[env, lane]:
                observations[env, offset + 2] = 1.0
            else:
                observations[env, offset + 2] = 0.0

            # Placeholder power value; real card attributes are not
            # available inside this kernel.
            observations[env, offset + 3] = 0.5

            # Energy attached to the slot, normalized and capped at 1.0.
            energy_ratio = batch_energy_count[env, lane] / clamp_divisor
            if energy_ratio > 1.0:
                energy_ratio = 1.0
            observations[env, offset + 11] = energy_ratio

        # --- Score, normalized and capped at 1.0 ---
        score_ratio = batch_scores[env] / clamp_divisor
        if score_ratio > 1.0:
            score_ratio = 1.0
        observations[env, score_idx] = score_ratio

    return observations