trioskosmos committed on
Commit
d980970
·
verified ·
1 Parent(s): b05f799

Upload ai/environments/vector_env.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/environments/vector_env.py +1418 -0
ai/environments/vector_env.py ADDED
@@ -0,0 +1,1418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from numba import njit, prange
6
+
7
+ import ai.research.integrated_step_numba as isn
8
+ from engine.game.fast_logic import (
9
+ batch_apply_action,
10
+ resolve_bytecode,
11
+ )
12
+
13
+
14
@njit(cache=True)
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # Bytecode lookup tables
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> row in bytecode_map
    card_stats: np.ndarray,
    batch_trash: np.ndarray,
):
    """
    Apply one action per environment for player 0 and return placeholder results.

    All state arrays are forwarded to ``batch_apply_action`` (player id fixed
    to 0). Note that the forwarding order differs from this function's
    parameter order: ``batch_scores`` is passed before ``batch_live`` and
    ``batch_trash`` before the bytecode tables.

    Returns:
        ``(rewards, dones)`` -- a float32 array of zeros and a bool array of
        ``False``, each with one entry per row of ``actions``. This function
        itself never computes a reward or a terminal flag.
    """
    # Per the original note, score syncing happens inside batch_apply_action.
    batch_apply_action(
        actions,
        0,  # player_id
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        batch_trash,
        bytecode_map,
        bytecode_index,
        card_stats,
    )

    n_envs = actions.shape[0]
    rewards = np.zeros(n_envs, dtype=np.float32)
    dones = np.zeros(n_envs, dtype=np.bool_)
    return rewards, dones
66
+
67
+
68
+ class VectorGameState:
69
+ """
70
+ Manages a batch of independent GameStates for high-throughput training.
71
+ """
72
+
73
def __init__(self, num_envs: int, opp_mode: int = 0, force_start_order: int = -1):
    """
    Allocate every per-environment buffer and load the card data tables.

    Args:
        num_envs: Number of environments held in the batch.
        opp_mode: 0=Heuristic, 1=Random, 2=Solitaire (Pass Only).
        force_start_order: -1=Random, 0=P1, 1=P2.

    Environment variables read here: ``USE_SCENARIOS``, ``SCENARIO_REWARD_SCALE``,
    ``OBS_MODE``, ``GAME_TURN_LIMIT``, ``GAME_STEP_LIMIT``, ``GAME_REWARD_WIN``,
    ``GAME_REWARD_LOSE``, ``GAME_REWARD_SCORE_SCALE``, ``GAME_REWARD_TURN_PENALTY``
    and ``USE_FIXED_DECK``.
    """

    def int_zeros(*tail):
        # int32 zero buffer with one leading row per environment.
        return np.zeros((num_envs, *tail), dtype=np.int32)

    def empty_stage():
        # Three stage slots per environment, -1 marking an empty slot.
        return np.full((num_envs, 3), -1, dtype=np.int32)

    self.num_envs = num_envs
    self.opp_mode = opp_mode
    self.force_start_order = force_start_order
    self.turn = 1

    # --- Player 0 (agent) state ---
    self.batch_stage = empty_stage()
    self.batch_energy_vec = int_zeros(3, 32)
    self.batch_energy_count = int_zeros(3)
    self.batch_continuous_vec = int_zeros(32, 10)
    self.batch_continuous_ptr = int_zeros()
    self.batch_tapped = int_zeros(16)  # Slots 0-2, Energy 3-15
    self.batch_live = int_zeros(50)
    self.batch_opp_tapped = int_zeros(16)
    self.batch_scores = int_zeros()

    # --- Player 1 (opponent) state ---
    self.opp_stage = empty_stage()
    self.opp_energy_vec = int_zeros(3, 32)  # same shape as the agent's
    self.opp_energy_count = int_zeros(3)
    self.opp_tapped = np.zeros((num_envs, 16), dtype=np.int8)
    self.opp_live = int_zeros(50)
    self.opp_scores = int_zeros()

    # --- Trackers consumed by the integrated step ---
    self.prev_scores = int_zeros()
    self.prev_opp_scores = int_zeros()
    self.prev_phases = int_zeros()
    self.episode_returns = np.zeros(num_envs, dtype=np.float32)
    self.episode_lengths = int_zeros()

    # --- Opponent finite deck ---
    self.opp_hand = int_zeros(60)
    self.opp_deck = int_zeros(60)

    # --- Optional scenario bank and its reward scale ---
    scenarios_enabled = os.getenv("USE_SCENARIOS", "0") == "1"
    if scenarios_enabled:
        self._load_scenarios()

    self.scenario_reward_scale = float(os.getenv("SCENARIO_REWARD_SCALE", "1.0"))
    if scenarios_enabled and self.scenario_reward_scale != 1.0:
        print(f" [VectorEnv] Scenario Reward Scale: {self.scenario_reward_scale}")

    # --- Opponent history and pre-allocated context buffers ---
    self.batch_opp_history = int_zeros(50)
    self.batch_flat_ctx = int_zeros(64)
    self.batch_global_ctx = int_zeros(128)
    self.opp_global_ctx = int_zeros(128)  # persistent opponent context
    self.batch_hand = int_zeros(60)
    self.batch_deck = int_zeros(60)
    self.batch_trash = int_zeros(60)
    self.opp_trash = int_zeros(60)

    # --- Observation mode: (obs_dim, action_space_dim, banner) per mode ---
    mode_table = {
        "COMPRESSED": (512, 2000, " [VectorEnv] Observation Mode: COMPRESSED (512-dim)"),
        "IMAX": (8192, 2000, " [VectorEnv] Observation Mode: IMAX (8192-dim)"),
        "ATTENTION": (2240, 512, " [VectorEnv] Observation Mode: ATTENTION (2240-dim)"),
    }
    # Any unrecognised mode string falls back to the STANDARD layout.
    standard_mode = (2304, 2000, " [VectorEnv] Observation Mode: STANDARD (2304-dim)")
    self.obs_mode = os.getenv("OBS_MODE", "STANDARD")
    self.obs_dim, self.action_space_dim, banner = mode_table.get(self.obs_mode, standard_mode)
    print(banner)

    obs_shape = (self.num_envs, self.obs_dim)
    self.obs_buffer = np.zeros(obs_shape, dtype=np.float32)
    # Separate buffer for the terminal observation (used with auto-reset).
    self.terminal_obs_buffer = np.zeros(obs_shape, dtype=np.float32)

    # Global turn counter kept in a 1-element array so Numba code can mutate it.
    self.turn_number_ptr = np.zeros(1, dtype=np.int32)
    self.turn_number_ptr[0] = 1

    # --- Game config vector ---
    # 0: Turn Limit, 1: Step Limit, 2: Win Reward, 3: Lose Reward,
    # 4: Score Scale, 5: Turn Penalty (slots 6-9 stay zero).
    config_defaults = (
        ("GAME_TURN_LIMIT", "100"),
        ("GAME_STEP_LIMIT", "1000"),
        ("GAME_REWARD_WIN", "100.0"),
        ("GAME_REWARD_LOSE", "-100.0"),
        ("GAME_REWARD_SCORE_SCALE", "50.0"),
        ("GAME_REWARD_TURN_PENALTY", "-0.05"),
    )
    self.game_config = np.zeros(10, dtype=np.float32)
    for slot, (env_name, fallback) in enumerate(config_defaults):
        self.game_config[slot] = float(os.getenv(env_name, fallback))
    print(
        f" [VectorEnv] Game Config: Turns={int(self.game_config[0])}, Steps={int(self.game_config[1])}, Win={self.game_config[2]}, Lose={self.game_config[3]}"
    )

    # --- Card data ---
    self._load_bytecode()

    # A fixed deck file, when configured, replaces the verified card pool.
    fixed_deck_path = os.getenv("USE_FIXED_DECK")
    if fixed_deck_path:
        self._load_fixed_deck_pool(fixed_deck_path)
    else:
        self._load_verified_deck_pool()
187
+
188
def _load_bytecode(self):
    """
    Load compiled ability bytecode and per-card stats into NumPy lookup tables.

    Reads ``data/cards_numba.json`` (keys ``"<card_id>_<ability_idx>"`` mapping to
    flat integer lists, reshaped to rows of 4) and builds:

    * ``self.bytecode_map``   -- (entries + 1 + patches, max_len, 4); row 0 stays empty.
    * ``self.bytecode_index`` -- (max_cards, max_abilities) -> row in ``bytecode_map``
      (0 means "no bytecode").
    * ``self.card_stats``     -- (max_cards, 80), packed from ``data/cards_compiled.json``.

    Abilities whose ``raw_text`` matches the Baton Pass pattern get a patched copy of
    their bytecode with ``[231, target_id, 0, 0]`` prepended. Patched copies are
    appended after the rows used by the regular load (growing ``bytecode_map`` as
    needed), so they never overwrite another ability's row.

    If ``data/cards_numba.json`` is missing, a warning is printed and tiny empty
    ``bytecode_map`` / ``bytecode_index`` arrays are installed instead. Any failure
    while reading ``data/cards_compiled.json`` is reported as a warning and leaves
    ``card_stats`` partially filled (best effort).
    """
    import json
    import re

    # Only this open() can raise the FileNotFoundError handled here; a missing
    # cards_compiled.json is handled by the best-effort block further down.
    try:
        with open("data/cards_numba.json", "r") as f:
            raw_map = json.load(f)
    except FileNotFoundError:
        # NOTE(review): card_stats and max_cards/max_abilities/max_len are not
        # initialised on this path -- confirm callers tolerate that.
        print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
        self.bytecode_map = np.zeros((1, 64, 4), dtype=np.int32)
        self.bytecode_index = np.zeros((1, 1), dtype=np.int32)
        return

    self.max_cards = 2000
    self.max_abilities = 8
    self.max_len = 128  # Max 128 instructions per ability for future expansion

    # Row 0 of the map is the empty/nop program, hence the +1.
    unique_entries = len(raw_map)
    self.bytecode_map = np.zeros((unique_entries + 1, self.max_len, 4), dtype=np.int32)
    self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)

    next_slot = 1
    for key, bc_list in raw_map.items():
        cid, aid = map(int, key.split("_"))
        if cid < self.max_cards and aid < self.max_abilities:
            bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
            length = min(bc_arr.shape[0], self.max_len)  # silently truncates long programs
            self.bytecode_map[next_slot, :length] = bc_arr[:length]
            self.bytecode_index[cid, aid] = next_slot
            next_slot += 1

    print(f" [VectorEnv] Loaded {unique_entries} compiled abilities.")

    # --- IMAX PRO VISION (Stride 80) ---
    # Fixed geography: dedicated space per ability.
    #   0-19:  Stats (Cost, Hearts, Traits, Live Reqs)
    #   20-35: Ability 1 (Trig, Cond, Opts, 3 Effs)
    #   36-47: Ability 2 (Trig, Cond, 3 Effs)
    #   48-59: Ability 3 (Trig, Cond, 3 Effs)
    #   60-71: Ability 4 (Trig, Cond, 3 Effs)
    #   79:    Location Signal (Runtime Only)
    # NOTE(review): ability 1 with options and 3 effects writes up to index 37,
    # and blade hearts (40-46) / live score (38) share indices with the ability 2
    # block (36-47). Layout kept as-is; confirm against the consumers of card_stats.
    self.card_stats = np.zeros((self.max_cards, 80), dtype=np.int32)
    stats = self.card_stats

    def pack_triplet(cid, offset, node):
        # (effect_type, value, target) in three consecutive columns.
        stats[cid, offset] = node.get("effect_type", 0)
        stats[cid, offset + 1] = node.get("value", 0)
        stats[cid, offset + 2] = node.get("target", 0)

    def pack_ability_block(cid, ab, base_idx, has_opts=False):
        # Pack one ability into its fixed block starting at base_idx.
        if not ab:
            return

        stats[cid, base_idx] = ab.get("trigger", 0)

        conds = ab.get("conditions", [])
        if conds:
            stats[cid, base_idx + 1] = conds[0].get("type", 0)
            stats[cid, base_idx + 2] = conds[0].get("params", {}).get("value", 0)

        effs = ab.get("effects", [])
        eff_start = base_idx + 3
        if has_opts:
            # Ability 1 reserves 6 columns (two triplets) for modal options
            # taken from its first effect; effects start after them.
            eff_start = base_idx + 9
            if effs:
                m_opts = effs[0].get("modal_options", [])
                for opt_i in range(min(len(m_opts), 2)):
                    if len(m_opts[opt_i]) > 0:
                        pack_triplet(cid, base_idx + 3 + opt_i * 3, m_opts[opt_i][0])

        for e_i in range(min(len(effs), 3)):
            pack_triplet(cid, eff_start + e_i * 3, effs[e_i])

    def traits_mask(card):
        # Groups occupy bits (g % 20); units occupy bits (u % 20) + 5.
        # NOTE(review): the two ranges overlap for group ids >= 5 -- confirm intent.
        mask = 0
        for shift, values in ((0, card.get("groups", [])), (5, card.get("units", []))):
            for raw in values:
                try:
                    mask |= 1 << ((int(raw) % 20) + shift)
                except (TypeError, ValueError, OverflowError):
                    # Non-numeric trait ids are skipped, as before.
                    continue
        return mask

    # (base index, has option slots) for abilities 1-4.
    ability_blocks = ((20, True), (36, False), (48, False), (60, False))

    try:
        with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
            db = json.load(f)

        # cards_compiled.json is keyed by string integers "0", "1", ...
        member_db = db.get("member_db", {})
        live_db = db.get("live_db", {})

        count = 0

        # Character name -> card id (last member with that name wins); used for
        # the Baton Pass condition.
        name_to_id = {}
        for cid_str, card in member_db.items():
            cid = int(cid_str)
            if cid < self.max_cards:
                name = card.get("name", "")
                if name:
                    name_to_id[name] = cid

        # --- Members ---
        for cid_str, card in member_db.items():
            cid = int(cid_str)
            if cid >= self.max_cards:
                continue

            stats[cid, 10] = 1  # card type: member
            stats[cid, 0] = card.get("cost", 0)
            stats[cid, 1] = card.get("blades", 0)

            # Total hearts plus the per-colour breakdown in 12-18
            # ([Pn, Rd, Yl, Gr, Bl, Pu, All]).
            h_arr = card.get("hearts", [])
            stats[cid, 2] = sum(h_arr)
            for r_idx in range(min(len(h_arr), 7)):
                stats[cid, 12 + r_idx] = h_arr[r_idx]

            # Character id for the Baton Pass condition.
            name = card.get("name", "")
            if name in name_to_id:
                stats[cid, 19] = name_to_id[name]

            # Primary colour: 1-based index of the first non-zero heart, else 0.
            col = 0
            for cidx, val in enumerate(h_arr):
                if val > 0:
                    col = cidx + 1
                    break
            stats[cid, 3] = col

            stats[cid, 4] = card.get("volume_icons", 0)
            stats[cid, 5] = card.get("draw_icons", 0)

            # Blade hearts (flipped as yell).
            bh = card.get("blade_hearts", [])
            for b_idx in range(min(len(bh), 7)):
                stats[cid, 40 + b_idx] = bh[b_idx]

            # A member entry carrying required_hearts overwrites 12-18 with them.
            if "required_hearts" in card:
                reqs = card.get("required_hearts", [])
                for r_idx in range(min(len(reqs), 7)):
                    stats[cid, 12 + r_idx] = reqs[r_idx]

            # zip stops at the shorter sequence, so at most four abilities are packed.
            for ab, (base_idx, has_opts) in zip(card.get("abilities", []), ability_blocks):
                pack_ability_block(cid, ab, base_idx, has_opts)

            stats[cid, 11] = traits_mask(card)
            count += 1

        # --- Lives ---
        for cid_str, card in live_db.items():
            cid = int(cid_str)
            if cid >= self.max_cards:
                continue

            stats[cid, 10] = 2  # card type: live

            reqs = card.get("required_hearts", [])
            for r_idx in range(min(len(reqs), 7)):
                stats[cid, 12 + r_idx] = reqs[r_idx]

            stats[cid, 38] = card.get("score", 0)

            name = card.get("name", "")
            if name in name_to_id:
                stats[cid, 19] = name_to_id[name]

            count += 1

        print(f" [VectorEnv] Loaded detailed stats/abilities for {count} cards.")

        # --- RUNTIME PATCHING FOR BATON PASS CARDS ---
        print(" [VectorEnv] Starting runtime patching for Baton Pass cards...")

        baton_pattern = re.compile(r"「(.+?)」からバトンタッチして")

        # First pass: collect every ability that needs a patched program.
        baton_cards = []
        for cid_str, card in {**member_db, **live_db}.items():
            cid = int(cid_str)
            if cid >= self.max_cards:
                continue

            for ab_idx, ability in enumerate(card.get("abilities", [])):
                match = baton_pattern.search(ability.get("raw_text", ""))
                if not match or ab_idx >= self.max_abilities:
                    continue

                target_name = match.group(1)
                target_cid = name_to_id.get(target_name, -1)
                original_key = f"{cid}_{ab_idx}"
                if target_cid != -1 and original_key in raw_map:
                    baton_cards.append((cid, ab_idx, target_cid, raw_map[original_key], target_name))

        # Patched programs go AFTER the rows used by the regular load. Reusing
        # rows from 1 upward (as the previous code did) would clobber bytecode
        # that other abilities still point at.
        needed_rows = next_slot + len(baton_cards)
        current_rows = self.bytecode_map.shape[0]
        if needed_rows > current_rows:
            extra = np.zeros((needed_rows - current_rows, self.max_len, 4), dtype=np.int32)
            self.bytecode_map = np.concatenate((self.bytecode_map, extra), axis=0)

        # Second pass: write the patched programs and repoint the index.
        patched_count = 0
        for cid, ab_idx, target_cid, original_bytecode, target_name in baton_cards:
            card_key = str(cid)
            card = member_db.get(card_key, live_db.get(card_key, {}))

            print(
                f" [VectorEnv] Patching Baton Pass for card {cid} ('{card.get('name', '')}') targeting '{target_name}' (ID: {target_cid})"
            )

            # CHECK_BATON (231) with the target character id, then the original program.
            new_bytecode = [231, target_cid, 0, 0] + original_bytecode

            bc_arr = np.array(new_bytecode, dtype=np.int32).reshape(-1, 4)
            length = min(bc_arr.shape[0], self.max_len)
            self.bytecode_map[next_slot, :length] = bc_arr[:length]
            self.bytecode_index[cid, ab_idx] = next_slot

            patched_count += 1
            print(
                f" [VectorEnv] Successfully patched ability {ab_idx} for card {cid}, new bytecode index: {next_slot}"
            )
            next_slot += 1

        print(f" [VectorEnv] Runtime patching completed. {patched_count} cards patched.")

    except Exception as e:
        # Deliberate best effort: stats are optional, so report and carry on.
        print(f" [VectorEnv] Warning: Failed to load compiled stats: {e}")
508
+
509
def _load_verified_deck_pool(self):
    """
    Build the card-id pools from ``data/verified_card_pool.json``.

    Card numbers in the pool file are translated to integer card ids via the
    ``card_no`` field of ``data/cards_compiled.json``; unknown numbers are
    ignored. The pool file may be a plain list (compatibility mode: every entry
    goes to the ability pools, members taking precedence over lives) or a dict
    with ``verified_abilities``/``members``, ``verified_lives``/``lives``,
    ``vanilla_members`` and ``vanilla_lives`` lists.

    Sets ``ability_member_ids``, ``ability_live_ids``, ``vanilla_member_ids`` and
    ``vanilla_live_ids`` as int32 arrays. An empty ability pool falls back to the
    vanilla pool, then to a dummy id (1 for members, 999 for lives). Any exception
    is reported and replaced by the dummy pools.
    """
    import json

    def resolve(card_numbers, number_map):
        # Keep file order; drop numbers the card database does not know.
        return [number_map[no] for no in card_numbers if no in number_map]

    try:
        with open("data/verified_card_pool.json", "r", encoding="utf-8") as pool_file:
            verified_data = json.load(pool_file)

        with open("data/cards_compiled.json", "r", encoding="utf-8") as db_file:
            db_data = json.load(db_file)

        # card_no -> integer card id, per card type.
        member_no_map = {
            entry["card_no"]: int(card_id) for card_id, entry in db_data.get("member_db", {}).items()
        }
        live_no_map = {
            entry["card_no"]: int(card_id) for card_id, entry in db_data.get("live_db", {}).items()
        }

        vanilla_members = []
        vanilla_lives = []

        if isinstance(verified_data, list):
            print(" [VectorEnv] Loading Verified Pool from List (Compatibility Mode)")
            ability_members = resolve(verified_data, member_no_map)
            # A number known as a member is never also counted as a live.
            ability_lives = [
                live_no_map[no] for no in verified_data if no not in member_no_map and no in live_no_map
            ]
        else:
            # Primary pool: both the old keys (verified_*) and the new keys are accepted.
            ability_members = resolve(
                verified_data.get("verified_abilities", []) + verified_data.get("members", []),
                member_no_map,
            )
            ability_lives = resolve(
                verified_data.get("verified_lives", []) + verified_data.get("lives", []),
                live_no_map,
            )
            # Secondary pool: vanilla cards.
            vanilla_members = resolve(verified_data.get("vanilla_members", []), member_no_map)
            vanilla_lives = resolve(verified_data.get("vanilla_lives", []), live_no_map)

        # Fallbacks for empty ability pools.
        if not ability_members and vanilla_members:
            print(" [VectorEnv] Warning: No ability members. using vanilla members.")
            ability_members = vanilla_members
        elif not ability_members:
            print(" [VectorEnv] Warning: No members found. Using ID 1.")
            ability_members = [1]

        if not ability_lives and vanilla_lives:
            print(" [VectorEnv] Warning: No ability lives. Using vanilla lives.")
            ability_lives = vanilla_lives
        elif not ability_lives:
            print(" [VectorEnv] Warning: No lives found. Using ID 999 (Dummy).")
            ability_lives = [999]

        print(
            f" [VectorEnv] Pools: {len(ability_members)} Ability Members, {len(ability_lives)} Ability Lives."
        )
        print(
            f" [VectorEnv] Fallbacks: {len(vanilla_members)} Vanilla Members, {len(vanilla_lives)} Vanilla Lives."
        )

        self.ability_member_ids = np.array(ability_members, dtype=np.int32)
        self.ability_live_ids = np.array(ability_lives, dtype=np.int32)
        self.vanilla_member_ids = np.array(vanilla_members, dtype=np.int32)
        self.vanilla_live_ids = np.array(vanilla_lives, dtype=np.int32)

    except Exception as e:
        print(f" [VectorEnv] Deck Load Error: {e}")
        self.ability_member_ids = np.array([1], dtype=np.int32)
        self.ability_live_ids = np.array([999], dtype=np.int32)
        self.vanilla_member_ids = np.array([], dtype=np.int32)
        self.vanilla_live_ids = np.array([], dtype=np.int32)
598
+
599
def _load_fixed_deck_pool(self, deck_path: str):
    """
    Build the card pools from a single deck list file.

    Every line of ``deck_path`` containing ``<count>x ... [PL!-...]`` contributes
    ``count`` copies of that card; the card number is resolved through the
    ``card_no`` field of ``data/cards_compiled.json``, and unknown numbers are
    ignored. Members and lives end up in ``ability_member_ids`` /
    ``ability_live_ids``; the vanilla pools are set to empty arrays.

    Deck sizes other than 48 members / 12 lives only produce a warning. On any
    exception the error is printed and ``_load_verified_deck_pool`` is used instead.
    """
    import json
    import re

    print(f" [VectorEnv] Loading FIXED DECK from: {deck_path}")
    try:
        # 1. card_no -> integer card id, per card type.
        with open("data/cards_compiled.json", "r", encoding="utf-8") as db_file:
            db_data = json.load(db_file)

        member_no_map = {
            entry["card_no"]: int(card_id) for card_id, entry in db_data.get("member_db", {}).items()
        }
        live_no_map = {
            entry["card_no"]: int(card_id) for card_id, entry in db_data.get("live_db", {}).items()
        }

        # 2. Parse the deck list. The pattern tolerates markdown decoration
        #    between the count and the bracket, e.g. "**4x** [PL!-...]".
        entry_pattern = re.compile(r"(\d+)x.*?\[(PL!-[^\]]+)\]")
        members = []
        lives = []

        with open(deck_path, "r", encoding="utf-8") as deck_file:
            for line in deck_file:
                found = entry_pattern.search(line)
                if found is None:
                    continue

                copies = int(found.group(1))
                card_no = found.group(2)
                if card_no in member_no_map:
                    members.extend([member_no_map[card_no]] * copies)
                elif card_no in live_no_map:
                    lives.extend([live_no_map[card_no]] * copies)

        # 3. Sanity warnings, then publish the pools.
        if len(members) != 48:
            print(f" [VectorEnv] Warning: Fixed deck members count is {len(members)}, expected 48.")
        if len(lives) != 12:
            print(f" [VectorEnv] Warning: Fixed deck lives count is {len(lives)}, expected 12.")

        self.ability_member_ids = np.array(members, dtype=np.int32)
        self.ability_live_ids = np.array(lives, dtype=np.int32)
        self.vanilla_member_ids = np.array([], dtype=np.int32)
        self.vanilla_live_ids = np.array([], dtype=np.int32)

        print(
            f" [VectorEnv] Fixed Deck Loaded: {len(self.ability_member_ids)} members, {len(self.ability_live_ids)} lives."
        )

    except Exception as e:
        print(f" [VectorEnv] Fixed Deck Load Error: {e}")
        self._load_verified_deck_pool()
654
+
655
def _load_scenarios(self, path="data/scenarios.npz"):
    """
    Load a bank of pre-built scenarios from an ``.npz`` archive.

    On success ``self.scenarios`` maps every array name in the archive to its
    array and ``self.num_scenarios`` is the length of the ``"batch_hand"`` entry.
    On any failure (missing file, unreadable archive, no ``"batch_hand"`` entry)
    the error is printed, ``self.scenarios`` is set to ``None`` and
    ``self.num_scenarios`` to 0, so both attributes always exist afterwards.

    Args:
        path: Location of the ``.npz`` archive.
    """
    try:
        # Copy the arrays out inside the context manager so the archive's
        # file handle is closed instead of staying open behind the NpzFile.
        with np.load(path) as data:
            loaded = {k: data[k] for k in data.files}
        self.scenarios = loaded
        self.num_scenarios = len(loaded["batch_hand"])
        print(f" [VectorEnv] Loaded {self.num_scenarios} scenarios from {path}")
    except Exception as e:
        # Scenarios are optional: report the problem and run without them.
        print(f" [VectorEnv] Failed to load scenarios: {e}")
        self.scenarios = None
        self.num_scenarios = 0
666
+
667
def reset(self, indices: List[int] = None):
    """
    Reset the selected environments and return fresh observations.

    Args:
        indices: Environment indices to reset; ``None`` resets every environment.

    The state buffers are re-initialised by ``isn.reset_kernel_numba``. When a
    scenario bank is loaded and ``USE_SCENARIOS`` is ``"1"``, each reset
    environment is then overwritten with a randomly drawn scenario (errors in
    that step are printed and otherwise ignored). Finally the per-environment
    trackers are zeroed, and on a full reset ``self.turn`` returns to 1.

    Returns:
        Whatever ``self.get_observations()`` returns.
    """
    reset_all = indices is None
    if reset_all:
        indices_arr = np.arange(self.num_envs, dtype=np.int32)
    else:
        indices_arr = np.array(indices, dtype=np.int32)

    # Standalone reset kernel from the integrated-step module.
    isn.reset_kernel_numba(
        indices_arr,
        self.batch_stage,
        self.batch_energy_vec,
        self.batch_energy_count,
        self.batch_continuous_vec,
        self.batch_continuous_ptr,
        self.batch_tapped,
        self.batch_live,
        self.batch_scores,
        self.batch_flat_ctx,
        self.batch_global_ctx,
        self.batch_hand,
        self.batch_deck,
        self.opp_stage,
        self.opp_energy_vec,
        self.opp_energy_count,
        self.opp_tapped,
        self.opp_live,
        self.opp_scores,
        self.opp_global_ctx,
        self.opp_hand,
        self.opp_deck,
        self.batch_trash,
        self.opp_trash,
        self.batch_opp_history,
        self.ability_member_ids,
        self.ability_live_ids,
        int(self.force_start_order),
    )

    # Scenario overwrite: each name is both the archive key and the attribute.
    scenario_fields = (
        "batch_hand",
        "batch_deck",
        "batch_stage",
        "batch_energy_vec",
        "batch_energy_count",
        "batch_continuous_vec",
        "batch_continuous_ptr",
        "batch_tapped",
        "batch_live",
        "batch_scores",
        "batch_flat_ctx",
        "batch_global_ctx",
        "opp_hand",
        "opp_deck",
        "opp_stage",
        "opp_energy_vec",
        "opp_energy_count",
        "opp_tapped",
        "opp_live",
        "opp_scores",
        "opp_global_ctx",
    )
    if getattr(self, "scenarios", None) is not None and os.getenv("USE_SCENARIOS", "0") == "1":
        try:
            # One random scenario per environment being reset.
            num_reset = self.num_envs if reset_all else len(indices_arr)
            reset_indices = np.arange(self.num_envs) if reset_all else indices_arr
            scen_indices = np.random.randint(0, self.num_scenarios, size=num_reset)

            for field_name in scenario_fields:
                target = getattr(self, field_name)
                if field_name not in self.scenarios:
                    continue
                data = self.scenarios[field_name][scen_indices]
                # Stored (N, 1) columns are flattened for 1-D targets.
                if target.ndim == 1 and data.ndim == 2 and data.shape[1] == 1:
                    data = data.ravel()
                target[reset_indices] = data

        except Exception as e:
            print(f" [VectorEnv] Error loading scenario data: {e}")

    # Zero the local trackers for the environments that were reset.
    trackers = (
        self.prev_scores,
        self.prev_opp_scores,
        self.prev_phases,
        self.episode_returns,
        self.episode_lengths,
    )
    if reset_all:
        self.turn = 1
        for tracker in trackers:
            tracker.fill(0)
    else:
        for idx in indices:
            for tracker in trackers:
                tracker[idx] = 0

    return self.get_observations()
770
+
771
def step(self, actions: np.ndarray):
    """Advance every environment by one action.

    The action batch is coerced to ``np.int32`` (the same array object is
    forwarded untouched when it already has that dtype) and handed to
    ``integrated_step``, whose return value is passed back unchanged.
    """
    int_actions = actions.astype(np.int32, copy=False)
    return self.integrated_step(int_actions)
778
+
779
def integrated_step(self, actions: np.ndarray):
    """
    Run one fused Numba step over the whole batch.

    Calls ``isn.integrated_step_numba`` with every state buffer held by this
    environment, optionally rescales the rewards, and builds one info dict
    per environment.

    Returns:
        (obs, rewards, dones, infos) where ``obs`` is ``self.obs_buffer`` and
        ``infos`` is a list with one dict per environment. The dict is empty
        for environments that are still running; for finished ones it holds
        ``terminal_observation``, ``episode`` (summed return ``r`` and
        length ``l``), ``terminal_score_agent`` and ``terminal_score_opp``.
    """
    env_count = self.num_envs

    # Output slots handed to the kernel; read back below for finished envs.
    term_scores_agent = np.zeros(env_count, dtype=np.int32)
    term_scores_opp = np.zeros(env_count, dtype=np.int32)

    # Integer observation-mode code expected by the kernel.
    if self.obs_mode == "IMAX":
        obs_mode_code = 0
    elif self.obs_mode == "STANDARD":
        obs_mode_code = 1
    elif self.obs_mode == "ATTENTION":
        obs_mode_code = 3
    else:
        obs_mode_code = 2

    rewards, dones = isn.integrated_step_numba(
        self.num_envs,
        actions,
        self.batch_hand,
        self.batch_deck,
        self.batch_stage,
        self.batch_energy_vec,
        self.batch_energy_count,
        self.batch_continuous_vec,
        self.batch_continuous_ptr,
        self.batch_tapped,
        self.batch_live,
        self.batch_scores,
        self.batch_flat_ctx,
        self.batch_global_ctx,
        self.opp_hand,
        self.opp_deck,
        self.opp_stage,
        self.opp_energy_vec,
        self.opp_energy_count,
        self.opp_tapped,
        self.opp_live,
        self.opp_scores,
        self.opp_global_ctx,
        self.card_stats,
        self.bytecode_map,
        self.bytecode_index,
        self.batch_opp_history,
        self.obs_buffer,
        self.prev_scores,
        self.prev_opp_scores,
        self.prev_phases,
        self.ability_member_ids,
        self.ability_live_ids,
        self.turn_number_ptr,
        self.terminal_obs_buffer,
        self.batch_trash,
        self.opp_trash,
        term_scores_agent,
        term_scores_opp,
        obs_mode_code,
        self.game_config,
        int(self.opp_mode),
        int(self.force_start_order),
    )

    # Scenario reward scaling only applies when scenarios are enabled via env var.
    if self.scenario_reward_scale != 1.0 and os.getenv("USE_SCENARIOS", "0") == "1":
        rewards *= self.scenario_reward_scale

    infos = []
    for env in range(self.num_envs):
        step_reward = rewards[env]  # delta reward for this integrated step

        if not dones[env]:
            # Episode still running: keep accumulating.
            self.episode_returns[env] += step_reward
            self.episode_lengths[env] += 1
            infos.append({})
            continue

        # Episode finished: report the summed return including this last step.
        total_return = self.episode_returns[env] + step_reward
        total_length = self.episode_lengths[env] + 1
        infos.append(
            {
                "terminal_observation": self.terminal_obs_buffer[env].copy(),
                "episode": {"r": float(total_return), "l": int(total_length)},
                "terminal_score_agent": int(term_scores_agent[env]),
                "terminal_score_opp": int(term_scores_opp[env]),
            }
        )
        # Clear the accumulators for the next episode in this slot.
        self.episode_returns[env] = 0
        self.episode_lengths[env] = 0

    return self.obs_buffer, rewards, dones, infos
870
+
871
def get_action_masks(self):
    """Return the legal-action mask array for every environment.

    ``ATTENTION`` observation mode uses ``compute_action_masks_attention``;
    every other mode uses ``compute_action_masks``. Both kernels receive the
    same arguments.
    """
    if self.obs_mode == "ATTENTION":
        mask_kernel = compute_action_masks_attention
    else:
        mask_kernel = compute_action_masks

    return mask_kernel(
        self.num_envs,
        self.batch_hand,
        self.batch_stage,
        self.batch_tapped,
        self.batch_global_ctx,
        self.batch_live,
        self.card_stats,
    )
893
+
894
def get_observations(self):
    """Encode the current batch state into observations for RL models.

    The encoder is chosen by ``self.obs_mode`` (``COMPRESSED``, ``IMAX``,
    ``ATTENTION``; anything else falls back to the standard encoder). All
    encoders take the same leading arguments; the attention encoder
    additionally receives ``self.opp_global_ctx`` before the turn number.
    ``self.obs_buffer`` is passed as the final argument in every case.
    """
    mode = self.obs_mode

    # Leading arguments common to all four encoders.
    shared_args = (
        self.num_envs,
        self.batch_hand,
        self.batch_stage,
        self.batch_energy_count,
        self.batch_tapped,
        self.batch_scores,
        self.opp_scores,
        self.opp_stage,
        self.opp_tapped,
        self.card_stats,
        self.batch_global_ctx,
        self.batch_live,
        self.batch_opp_history,
    )

    if mode == "ATTENTION":
        return isn.encode_observations_attention(
            *shared_args,
            self.opp_global_ctx,
            self.turn,
            self.obs_buffer,
        )

    if mode == "COMPRESSED":
        encoder = isn.encode_observations_compressed
    elif mode == "IMAX":
        encoder = isn.encode_observations_imax
    else:
        encoder = isn.encode_observations_standard

    return encoder(*shared_args, self.turn, self.obs_buffer)
969
+
970
+
971
@njit(cache=True)
def step_opponent_vectorized(
    opp_hand: np.ndarray,  # (N, 60)
    opp_deck: np.ndarray,  # (N, 60)
    opp_stage: np.ndarray,
    opp_energy_vec: np.ndarray,
    opp_energy_count: np.ndarray,
    opp_tapped: np.ndarray,
    opp_scores: np.ndarray,
    agent_tapped: np.ndarray,
    opp_global_ctx: np.ndarray,  # (N, 128)
    bytecode_map: np.ndarray,
    bytecode_index: np.ndarray,
):
    """
    Simplified opponent turn, applied in place to the opponent buffers.

    For each environment:
      1. pick one occupied hand slot uniformly at random;
      2. if the card id has a bytecode entry, run it through
         ``resolve_bytecode`` against the opponent's own buffers;
      3. if fewer than 5 cards remain in hand, move the first card found in
         the deck into the first empty hand slot.

    Environments whose hand is empty are skipped entirely (no draw either).
    Nothing is returned.
    """
    env_total = len(opp_hand)

    # Scratch buffers shared by all environments to avoid per-env allocation.
    scratch_ctx = np.zeros(64, dtype=np.int32)
    live_stub = np.zeros(50, dtype=np.int32)  # stand-in live zone for the opponent
    cont_vec_stub = np.zeros((32, 10), dtype=np.int32)
    ptr_ref = np.zeros(1, dtype=np.int32)  # one-element array used as a reference cell
    bonus_ref = np.zeros(1, dtype=np.int32)  # one-element array used as a reference cell
    # NOTE(review): only scratch_ctx, ptr_ref and bonus_ref are cleared between
    # environments; live_stub / cont_vec_stub are not -- confirm resolve_bytecode
    # leaves them untouched.

    for env in range(env_total):
        scratch_ctx.fill(0)

        # --- 1. Choose a random occupied hand slot ---
        # Fixed-size array instead of a list for Numba compatibility.
        occupied = np.zeros(60, dtype=np.int32)
        n_occupied = 0
        for h in range(60):
            if opp_hand[env, h] > 0:
                occupied[n_occupied] = h
                n_occupied += 1

        if n_occupied == 0:
            continue

        pick = np.random.randint(0, n_occupied)
        hand_slot = occupied[pick]
        card_id = opp_hand[env, hand_slot]

        # --- 2. Execute the card's bytecode, if it has any ---
        if 0 < card_id < bytecode_index.shape[0]:
            map_row = bytecode_index[card_id, 0]
            if map_row > 0:
                program = bytecode_map[map_row]

                # Global ctx layout: SC=0 OS=1 TR=2 HD=3 DI=4 EN=5 DK=6 OT=7
                opp_global_ctx[env, 0] = opp_scores[env]
                opp_global_ctx[env, 3] -= 1  # hand count drops for the played card

                ptr_ref[0] = 0
                bonus_ref[0] = 0

                # Row slices are views, so in-place writes made by
                # resolve_bytecode land in the batch arrays.
                resolve_bytecode(
                    program,
                    scratch_ctx,
                    opp_global_ctx[env],
                    1,
                    opp_hand[env],
                    opp_deck[env],
                    opp_stage[env],
                    opp_energy_vec[env],
                    opp_energy_count[env],
                    cont_vec_stub,
                    ptr_ref,
                    opp_tapped[env],
                    live_stub,
                    agent_tapped[env],
                    bytecode_map,
                    bytecode_index,
                    bonus_ref,
                )
                # The score is deliberately NOT copied back from
                # opp_global_ctx[env, 0] into opp_scores[env] (neutralized).

        # --- 3. Refill: draw one card when the hand holds fewer than 5 ---
        cards_in_hand = 0
        for h in range(60):
            if opp_hand[env, h] > 0:
                cards_in_hand += 1

        if cards_in_hand >= 5:
            continue

        # First non-empty deck position is treated as the top card.
        deck_pos = -1
        for d in range(60):
            if opp_deck[env, d] > 0:
                deck_pos = d
                break

        if deck_pos < 0:
            continue

        drawn = opp_deck[env, deck_pos]
        for h in range(60):
            if opp_hand[env, h] == 0:
                opp_hand[env, h] = drawn
                opp_deck[env, deck_pos] = 0
                opp_global_ctx[env, 3] += 1  # hand count (HD)
                opp_global_ctx[env, 6] -= 1  # deck count (DK)
                break
1097
+
1098
+
1099
@njit(cache=True)
def resolve_auto_phases(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_tapped: np.ndarray,
    single_step: bool = False,
):
    """
    Advance each environment through the non-interactive phases in place.

    The phase is read from ``batch_global_ctx[i, 8]`` and stepped along
    0/8 (MULLIGAN / LIVE_RESULT) -> 1 (ACTIVE) -> 2 (ENERGY) -> 3 (DRAW)
    -> 4 (MAIN). Processing of an environment stops at MAIN, at any phase
    not listed here, or after the iteration cap (1 transition when
    ``single_step`` is true, otherwise 10).

    Side effects per transition:
      * 0/8 -> 1: clears the slot-played flags (ctx 51-53) and tapped flags
        0-15, raises the energy count (ctx 5) to 3 if it was 0 or by one
        while below 12, and increments the turn counter (ctx 54).
      * 3 -> 4: moves the first card found in the deck to the first empty
        hand slot, recounts the hand into ctx 3 and decrements ctx 6.
        An empty deck or a full hand simply skips the draw.
    """
    iteration_cap = 1 if single_step else 10

    for env in range(num_envs):
        for _ in range(iteration_cap):
            phase = int(batch_global_ctx[env, 8])

            if phase == 0 or phase == 8:
                # Turn start bookkeeping.
                batch_global_ctx[env, 51:54] = 0
                batch_tapped[env, 0:16] = 0

                energy = batch_global_ctx[env, 5]
                if energy == 0:
                    batch_global_ctx[env, 5] = 3
                elif energy < 12:
                    batch_global_ctx[env, 5] = energy + 1

                batch_global_ctx[env, 54] += 1
                batch_global_ctx[env, 8] = 1

            elif phase == 1:
                batch_global_ctx[env, 8] = 2

            elif phase == 2:
                batch_global_ctx[env, 8] = 3

            elif phase == 3:
                # Locate the first non-empty deck position.
                deck_pos = -1
                for d in range(60):
                    if batch_deck[env, d] > 0:
                        deck_pos = d
                        break

                if deck_pos >= 0:
                    drawn = batch_deck[env, deck_pos]
                    for h in range(60):
                        if batch_hand[env, h] == 0:
                            batch_hand[env, h] = drawn
                            batch_deck[env, deck_pos] = 0

                            # Recount the hand rather than trusting the old value.
                            batch_global_ctx[env, 3] = 0
                            for k in range(60):
                                if batch_hand[env, k] > 0:
                                    batch_global_ctx[env, 3] += 1

                            batch_global_ctx[env, 6] -= 1
                            break

                batch_global_ctx[env, 8] = 4

            else:
                # MAIN (4) hands control to the agent; any other phase is
                # not handled here, so stop instead of spinning.
                break
1188
+
1189
+
1190
@njit(parallel=True, cache=True)
def compute_action_masks_attention(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_stage: np.ndarray,
    batch_tapped: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_live: np.ndarray,
    card_stats: np.ndarray,
):
    """
    Compute legal action masks for ATTENTION mode (512 actions).

    Action layout produced by this function:
    - 0: Pass (always legal)
    - 1-45: Play Member (hand index 0-14 * 3 stage slots)
    - 46-60: Set Live (hand index 0-14)
    - 61-63: Activate Ability (stage slot 0-2)
    - 64-69: Mulligan Select (first 6 hand cards)
    - 100-159: Choice: hand selection (hand index 0-59)
    - 160-171: Choice: energy selection (energy index 0-11)

    The phase is read from ``batch_global_ctx[i, 8]``:
    - phase <= 0: only Pass and not-yet-selected mulligan picks are legal.
    - phase == 4 (Main): play / set-live / activate actions are evaluated.
    - phase >= 7 or phase == 4: the choice ranges are opened as well.

    Returns:
        Boolean array of shape ``(num_envs, 512)``.
    """
    masks = np.zeros((num_envs, 512), dtype=np.bool_)
    masks[:, 0] = True  # Pass is always legal

    for i in prange(num_envs):
        phase = batch_global_ctx[i, 8]

        # --- Mulligan (phase -1 / 0) ---
        if phase <= 0:
            # One-way selection: a card already flagged (ctx 120+h) stays masked.
            for h_idx in range(6):
                if batch_hand[i, h_idx] > 0 and batch_global_ctx[i, 120 + h_idx] == 0:
                    masks[i, 64 + h_idx] = True
            continue

        # --- Main Phase (4) ---
        if phase == 4:
            # Untapped energy = energy count minus tapped energy slots (3..14).
            ec = batch_global_ctx[i, 5]
            tapped_count = 0
            for e_idx in range(min(ec, 12)):
                if batch_tapped[i, 3 + e_idx] > 0:
                    tapped_count += 1
            available_energy = ec - tapped_count

            # Live-zone occupancy does not depend on the hand card, so count
            # it once per environment (6 slots: 3 pending + 3 success).
            live_count = 0
            for lx in range(6):
                if batch_live[i, lx] > 0:
                    live_count += 1
            can_set_live = live_count < 3

            # 1. Play Member (1-45) and Set Live (46-60) over the first 15 hand slots
            for h_idx in range(15):
                cid = batch_hand[i, h_idx]
                if cid <= 0 or cid >= card_stats.shape[0]:
                    continue

                if card_stats[cid, 10] == 1:  # member card
                    cost = card_stats[cid, 0]
                    for slot in range(3):
                        # One play per slot per turn (ctx 51-53)
                        if batch_global_ctx[i, 51 + slot] > 0:
                            continue

                        # Baton Touch: an occupied slot discounts the cost
                        effective_cost = cost
                        prev_cid = batch_stage[i, slot]
                        if prev_cid > 0 and prev_cid < card_stats.shape[0]:
                            effective_cost = max(0, cost - card_stats[prev_cid, 0])

                        if effective_cost <= available_energy:
                            masks[i, 1 + h_idx * 3 + slot] = True

                # Any valid card may be set as a live while the zone has room.
                if can_set_live:
                    masks[i, 46 + h_idx] = True

            # 2. Activate Ability (61-63): any occupied, untapped stage slot
            for slot in range(3):
                cid = batch_stage[i, slot]
                if cid > 0 and not batch_tapped[i, slot]:
                    masks[i, 61 + slot] = True

        # --- Choice handling (phase 7+, also opened during Main) ---
        if phase >= 7 or phase == 4:
            # Hand selection (100-159)
            for h_idx in range(60):
                if batch_hand[i, h_idx] > 0:
                    masks[i, 100 + h_idx] = True

            # Energy selection (160-171)
            ec_val = batch_global_ctx[i, 5]
            for e_idx in range(min(ec_val, 12)):
                masks[i, 160 + e_idx] = True

    return masks
1295
+
1296
+
1297
@njit(parallel=True, cache=True)
def compute_action_masks(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_stage: np.ndarray,
    batch_tapped: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_live: np.ndarray,
    card_stats: np.ndarray,
):
    """
    Compute legal action masks using Python-compatible action IDs:
    - 0: Pass (always legal)
    - 1-180: Play Member from Hand (HandIdx * 3 + Slot + 1)
    - 200-202: Activate Ability (Slot)
    - 300-305: Mulligan Select (first 6 hand cards)
    - 400-459: Set Live Card (HandIdx)
    - 500-559: Choice: hand selection / discard (HandIdx)
    - 600-611: Choice: energy selection (EnergyIdx)

    The phase is read from ``batch_global_ctx[i, 8]``:
    - phase -1 / 0: only Pass and not-yet-selected mulligan picks are legal.
    - phase 4 (Main): member play, set-live and ability actions are evaluated.
    - phase >= 7 or phase 4: the choice ranges are opened as well.

    Returns:
        Boolean array of shape ``(num_envs, 2000)``.
    """
    masks = np.zeros((num_envs, 2000), dtype=np.bool_)

    # Action 0 (Pass) is always legal
    masks[:, 0] = True

    # Never scan past the real width of the live-zone array.
    live_slots = min(50, batch_live.shape[1])

    for i in prange(num_envs):
        phase = batch_global_ctx[i, 8]

        # --- Mulligan Phases (-1, 0) ---
        if phase == -1 or phase == 0:
            # Selection is one-way: only cards that exist and are not yet
            # flagged (ctx 120+h) can be picked. Only the first 6 cards count.
            for h_idx in range(6):
                if batch_hand[i, h_idx] > 0 and batch_global_ctx[i, 120 + h_idx] == 0:
                    masks[i, 300 + h_idx] = True
            continue

        # --- Main Phase (4) ---
        if phase == 4:
            # Available untapped energy (energy count lives at ctx index 5)
            ec = batch_global_ctx[i, 5]
            tapped_count = 0
            for e_idx in range(min(ec, 12)):
                if batch_tapped[i, 3 + e_idx] > 0:
                    tapped_count += 1
            available_energy = ec - tapped_count

            # Live-zone occupancy is independent of the hand card: count once.
            count_in_zone = 0
            for j in range(live_slots):
                if batch_live[i, j] > 0:
                    count_in_zone += 1
            live_zone_has_room = count_in_zone < 3

            for h_idx in range(60):
                cid = batch_hand[i, h_idx]
                # SAFETY: skip empty slots and ids outside card_stats
                if cid <= 0 or cid >= card_stats.shape[0]:
                    continue

                card_type = card_stats[cid, 10]
                if card_type != 1:
                    # Live cards (type 2) can be set without checking hearts;
                    # requirements are checked later (Rule 8.3). Max 3 in zone.
                    if card_type == 2 and live_zone_has_room:
                        masks[i, 400 + h_idx] = True
                    continue

                # --- Member Play Actions (1-180) ---
                cost = card_stats[cid, 0]
                for slot in range(3):
                    # Rule: one play per slot per turn (ctx 51-53)
                    if batch_global_ctx[i, 51 + slot] > 0:
                        continue

                    # Baton Touch reduction applies only when the slot is
                    # actually occupied (id > 0); id 0 means empty and must
                    # not pull a "cost" from card_stats row 0.
                    effective_cost = cost
                    prev_cid = batch_stage[i, slot]
                    if prev_cid > 0 and prev_cid < card_stats.shape[0]:
                        effective_cost = cost - card_stats[prev_cid, 0]
                        if effective_cost < 0:
                            effective_cost = 0

                    if effective_cost <= available_energy:
                        masks[i, h_idx * 3 + slot + 1] = True

            # --- Activate Ability Actions (200-202) ---
            # For now every occupied, untapped stage slot is assumed activatable.
            for slot in range(3):
                cid = batch_stage[i, slot]
                if cid > 0 and not batch_tapped[i, slot]:
                    masks[i, 200 + slot] = True

        # --- Mandatory Choice Handling (phase 7+, also opened during Main) ---
        if phase >= 7 or phase == 4:
            # Hand selection/discard (500-559); avoids zero legal moves
            # while a choice is pending.
            for h_idx in range(60):
                if batch_hand[i, h_idx] > 0:
                    masks[i, 500 + h_idx] = True

            # Energy selection (600-611)
            energy_count = batch_global_ctx[i, 5]
            for e_idx in range(min(energy_count, 12)):
                masks[i, 600 + e_idx] = True

    return masks
1415
+
1416
+
1417
# Export for legacy/external compatibility: `encode_observations_vectorized`
# is kept as an alias of the standard observation encoder.
encode_observations_vectorized = isn.encode_observations_standard