import os
from typing import List, Optional

import numpy as np
from numba import njit, prange

import ai.research.integrated_step_numba as isn
from engine.game.fast_logic import (
    batch_apply_action,
    resolve_bytecode,
)


@njit(cache=True)
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # New: Bytecode Maps
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> Index in map
    card_stats: np.ndarray,
    batch_trash: np.ndarray,  # Added
):
    """
    Step N game environments in parallel using JIT logic and Real Card Data.

    All state arrays are handed to ``batch_apply_action`` as player 0 and are
    mutated in place by it.

    Returns:
        (rewards, dones): a float32 array and a bool array of length
        ``actions.shape[0]``. This function always returns them as all-zero /
        all-False; it computes no reward or termination itself.
    """
    # Score sync now handled internally by batch_apply_action
    # NOTE(review): batch_scores is passed before batch_live here, the reverse
    # of this function's own parameter order -- confirm against
    # batch_apply_action's signature.
    batch_apply_action(
        actions,
        0,  # player_id
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        batch_trash,  # Added
        bytecode_map,
        bytecode_index,
        card_stats,
    )

    rewards = np.zeros(actions.shape[0], dtype=np.float32)
    dones = np.zeros(actions.shape[0], dtype=np.bool_)
    return rewards, dones


class VectorGameState:
    """
    Manages a batch of independent GameStates for high-throughput training.

    Each game is stored as a row in a set of pre-allocated NumPy buffers
    (``batch_*`` for the agent / player 0, ``opp_*`` for the opponent /
    player 1). These are mutated in place by the Numba kernels in
    ``ai.research.integrated_step_numba``. Behaviour is configured through
    environment variables (OBS_MODE, USE_SCENARIOS, SCENARIO_REWARD_SCALE,
    GAME_*, USE_FIXED_DECK) read in ``__init__``.
    """

    # Observation-mode name -> integer code passed to isn.integrated_step_numba.
    # Names not listed here fall back to STANDARD (code 1). This matches the
    # obs_dim chosen in __init__ and the encoder chosen in get_observations.
    _OBS_MODE_CODES = {"IMAX": 0, "STANDARD": 1, "COMPRESSED": 2, "ATTENTION": 3}

    def __init__(self, num_envs: int, opp_mode: int = 0, force_start_order: int = -1):
        """
        Allocate all per-env buffers and load card data / deck pools.

        Args:
            num_envs: number of parallel games.
            opp_mode: 0=Heuristic, 1=Random, 2=Solitaire (Pass Only).
            force_start_order: -1=Random, 0=P1, 1=P2.
        """
        self.num_envs = num_envs
        # opp_mode: 0=Heuristic, 1=Random, 2=Solitaire (Pass Only)
        self.opp_mode = opp_mode
        self.force_start_order = force_start_order  # -1=Random, 0=P1, 1=P2
        self.turn = 1

        # Batched state buffers - Player 0 (Agent)
        self.batch_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.batch_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)
        self.batch_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_continuous_vec = np.zeros((num_envs, 32, 10), dtype=np.int32)
        self.batch_continuous_ptr = np.zeros(num_envs, dtype=np.int32)
        self.batch_tapped = np.zeros((num_envs, 16), dtype=np.int32)  # Slots 0-2, Energy 3-15
        self.batch_live = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_opp_tapped = np.zeros((num_envs, 16), dtype=np.int32)
        self.batch_scores = np.zeros(num_envs, dtype=np.int32)

        # Batched state buffers - Opponent State (Player 1)
        self.opp_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.opp_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)  # Match Agent Shape
        self.opp_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.opp_tapped = np.zeros((num_envs, 16), dtype=np.int8)
        self.opp_live = np.zeros((num_envs, 50), dtype=np.int32)  # Added Opp Live
        self.opp_scores = np.zeros(num_envs, dtype=np.int32)

        # State tracking for the integrated step
        self.prev_scores = np.zeros(num_envs, dtype=np.int32)
        self.prev_opp_scores = np.zeros(num_envs, dtype=np.int32)
        self.prev_phases = np.zeros(num_envs, dtype=np.int32)
        self.episode_returns = np.zeros(num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(num_envs, dtype=np.int32)

        # Opponent finite deck buffers
        self.opp_hand = np.zeros((num_envs, 60), dtype=np.int32)
        self.opp_deck = np.zeros((num_envs, 60), dtype=np.int32)

        use_scenarios = os.getenv("USE_SCENARIOS", "0") == "1"
        if use_scenarios:
            self._load_scenarios()

        # Scenario Reward Scaling
        self.scenario_reward_scale = float(os.getenv("SCENARIO_REWARD_SCALE", "1.0"))
        if use_scenarios and self.scenario_reward_scale != 1.0:
            print(f" [VectorEnv] Scenario Reward Scale: {self.scenario_reward_scale}")

        # Opponent history buffer
        self.batch_opp_history = np.zeros((num_envs, 50), dtype=np.int32)

        # Pre-allocated context buffers (avoid per-step allocation)
        self.batch_flat_ctx = np.zeros((num_envs, 64), dtype=np.int32)
        self.batch_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)
        self.opp_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)  # Persistent Opponent Context
        self.batch_hand = np.zeros((num_envs, 60), dtype=np.int32)
        self.batch_deck = np.zeros((num_envs, 60), dtype=np.int32)
        self.batch_trash = np.zeros((num_envs, 60), dtype=np.int32)  # Added Trash
        self.opp_trash = np.zeros((num_envs, 60), dtype=np.int32)  # Added Opp Trash

        # Observation mode selection; any unrecognised value behaves as STANDARD.
        self.obs_mode = os.getenv("OBS_MODE", "STANDARD")
        if self.obs_mode == "COMPRESSED":
            self.obs_dim = 512
            self.action_space_dim = 2000
            print(" [VectorEnv] Observation Mode: COMPRESSED (512-dim)")
        elif self.obs_mode == "IMAX":
            self.obs_dim = 8192
            self.action_space_dim = 2000
            print(" [VectorEnv] Observation Mode: IMAX (8192-dim)")
        elif self.obs_mode == "ATTENTION":
            self.obs_dim = 2240
            self.action_space_dim = 512
            print(" [VectorEnv] Observation Mode: ATTENTION (2240-dim)")
        else:
            self.obs_dim = 2304
            self.action_space_dim = 2000
            print(" [VectorEnv] Observation Mode: STANDARD (2304-dim)")

        self.obs_buffer = np.zeros((self.num_envs, self.obs_dim), dtype=np.float32)
        # Terminal Obs Buffer for Auto-Reset
        self.terminal_obs_buffer = np.zeros((self.num_envs, self.obs_dim), dtype=np.float32)

        # Global Turn Counter (a 1-element array so Numba code can update it in place)
        self.turn_number_ptr = np.zeros(1, dtype=np.int32)
        self.turn_number_ptr[0] = 1

        # Game Config (Turn Limits & Rewards)
        # 0: Turn Limit, 1: Step Limit, 2: Win Reward, 3: Lose Reward, 4: Score Scale, 5: Turn Penalty
        self.game_config = np.zeros(10, dtype=np.float32)
        self.game_config[0] = float(os.getenv("GAME_TURN_LIMIT", "100"))
        self.game_config[1] = float(os.getenv("GAME_STEP_LIMIT", "1000"))
        self.game_config[2] = float(os.getenv("GAME_REWARD_WIN", "100.0"))
        self.game_config[3] = float(os.getenv("GAME_REWARD_LOSE", "-100.0"))
        self.game_config[4] = float(os.getenv("GAME_REWARD_SCORE_SCALE", "50.0"))
        self.game_config[5] = float(os.getenv("GAME_REWARD_TURN_PENALTY", "-0.05"))
        print(
            f" [VectorEnv] Game Config: Turns={int(self.game_config[0])}, Steps={int(self.game_config[1])}, Win={self.game_config[2]}, Lose={self.game_config[3]}"
        )

        # Load Bytecode Map
        self._load_bytecode()

        # Check for Fixed Deck Override
        fixed_deck_path = os.getenv("USE_FIXED_DECK")
        if fixed_deck_path:
            self._load_fixed_deck_pool(fixed_deck_path)
        else:
            self._load_verified_deck_pool()

    def _load_bytecode(self):
        """
        Load compiled ability bytecode and the per-card static stats table.

        Sets:
          * ``bytecode_map``   -- int32 (entries + 1, max_len, 4); row 0 is the
            empty/no-op program.
          * ``bytecode_index`` -- int32 (max_cards, max_abilities); maps
            (card id, ability idx) to a row of ``bytecode_map`` (0 = none).
          * ``card_stats``     -- int32 (max_cards, 80) fixed-layout table.

        Bytecode comes from data/cards_numba.json, whose keys are
        "cardid_abidx" and whose values are flat lists of ints (4 per
        instruction). Stats come from data/cards_compiled.json. A missing
        bytecode file yields empty, correctly-sized tables. Any error while
        reading the card DB is reported, and card_stats may then be only
        partially filled.
        """
        import json

        self.max_cards = 2000
        self.max_abilities = 8
        self.max_len = 128  # Max 128 instructions per ability for future expansion

        try:
            with open("data/cards_numba.json", "r") as f:
                raw_map = json.load(f)
        except FileNotFoundError:
            print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
            # Size the fallback tables like the real ones, so that lookups by
            # card id stay in bounds. Also create card_stats, which later steps
            # read unconditionally.
            self.bytecode_map = np.zeros((1, self.max_len, 4), dtype=np.int32)
            self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)
            self.card_stats = np.zeros((self.max_cards, 80), dtype=np.int32)
            return

        # Row 0 is reserved as the empty/no-op program.
        unique_entries = len(raw_map)
        self.bytecode_map = np.zeros((unique_entries + 1, self.max_len, 4), dtype=np.int32)
        self.bytecode_index = np.full((self.max_cards, self.max_abilities), 0, dtype=np.int32)

        idx_counter = 1
        for key, bc_list in raw_map.items():
            cid, aid = map(int, key.split("_"))
            if cid < self.max_cards and aid < self.max_abilities:
                # Reshape the flat list to (M, 4) instructions; truncate to max_len.
                bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
                length = min(bc_arr.shape[0], self.max_len)
                self.bytecode_map[idx_counter, :length] = bc_arr[:length]
                self.bytecode_index[cid, aid] = idx_counter
                idx_counter += 1

        print(f" [VectorEnv] Loaded {unique_entries} compiled abilities.")

        # --- IMAX PRO VISION (Stride 80) ---
        # Fixed Geography: No maps, no shifting. Dedicated space per ability.
        # 0-19: Stats (0 cost, 1 blades, 2 heart sum, 3 primary color,
        #       4 volume icons, 5 draw icons, 10 card type, 11 traits mask,
        #       12-18 hearts / required hearts, 19 baton character id)
        # 20-35: Ability 1 (Trig, Cond, Opts, 3 Effs)
        # 36-47: Ability 2 (Trig, Cond, 3 Effs)
        # 38:    Live score (live cards only)
        # 40-46: Blade hearts (members)
        # 48-59: Ability 3 (Trig, Cond, 3 Effs)
        # 60-71: Ability 4 (Trig, Cond, 3 Effs)
        # 79: Location Signal (Runtime Only)
        # NOTE(review): the blocks as packed overlap. Ability 1's third effect
        # occupies 35-37 (overwritten by Ability 2's trigger/condition), and
        # Ability 2's effects occupy 39-47, clobbering blade hearts 40-46.
        # The layout is left unchanged because the consuming kernels are not
        # visible here.
        self.card_stats = np.zeros((self.max_cards, 80), dtype=np.int32)

        try:
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db = json.load(f)

            name_to_id = self._build_name_to_id(db)
            count = self._fill_card_stats(db, name_to_id)
            print(f" [VectorEnv] Loaded detailed stats/abilities for {count} cards.")

            self._patch_baton_pass(db, raw_map, name_to_id)
        except Exception as e:
            print(f" [VectorEnv] Warning: Failed to load compiled stats: {e}")

    def _build_name_to_id(self, db):
        """Map member card names to card ids; for duplicate names the last one in member_db wins."""
        name_to_id = {}
        for cid_str, card in db.get("member_db", {}).items():
            cid = int(cid_str)
            if cid < self.max_cards:
                name = card.get("name", "")
                if name:
                    name_to_id[name] = cid
        return name_to_id

    @staticmethod
    def _pack_ability_block(row, ab, base_idx, has_opts=False):
        """
        Write one ability into its fixed block of a card_stats row.

        Layout relative to ``base_idx``: +0 trigger, +1/+2 first condition
        (type, params.value), then up to 3 effects of (effect_type, value,
        target). With ``has_opts``, +3..+8 hold the first option of the first
        two modal option groups of the first effect, and the effects start at +9.
        """
        if not ab:
            return

        row[base_idx] = ab.get("trigger", 0)

        conds = ab.get("conditions", [])
        if conds:
            row[base_idx + 1] = conds[0].get("type", 0)
            row[base_idx + 2] = conds[0].get("params", {}).get("value", 0)

        effs = ab.get("effects", [])
        eff_start = base_idx + 3
        if has_opts:
            # Ability 1 reserves 6 slots for modal options before its effects.
            eff_start = base_idx + 9
            if effs:
                m_opts = effs[0].get("modal_options", [])
                for opt_i, group in enumerate(m_opts[:2]):
                    if len(group) > 0:
                        o = group[0]
                        off = base_idx + 3 + opt_i * 3
                        row[off] = o.get("effect_type", 0)
                        row[off + 1] = o.get("value", 0)
                        row[off + 2] = o.get("target", 0)

        for e_i, e in enumerate(effs[:3]):
            off = eff_start + e_i * 3
            row[off] = e.get("effect_type", 0)
            row[off + 1] = e.get("value", 0)
            row[off + 2] = e.get("target", 0)

    def _fill_card_stats(self, db, name_to_id):
        """Fill ``self.card_stats`` from the member and live DBs; return the number of cards written."""
        count = 0

        # Members
        for cid_str, card in db.get("member_db", {}).items():
            cid = int(cid_str)
            if cid >= self.max_cards:
                continue
            row = self.card_stats[cid]  # view: writes go straight into card_stats

            row[10] = 1  # Card Type (1=Member)
            row[0] = card.get("cost", 0)
            row[1] = card.get("blades", 0)

            # Hearts: total in [2], per color in 12-18 [Pn, Rd, Yl, Gr, Bl, Pu, All]
            h_arr = card.get("hearts", [])
            row[2] = sum(h_arr)
            for r_idx in range(min(len(h_arr), 7)):
                row[12 + r_idx] = h_arr[r_idx]

            # Character ID for the Baton Pass condition
            name = card.get("name", "")
            if name in name_to_id:
                row[19] = name_to_id[name]

            # Primary color: 1-based index of the first positive heart entry (0 = none)
            col = 0
            for cidx, val in enumerate(h_arr):
                if val > 0:
                    col = cidx + 1
                    break
            row[3] = col

            row[4] = card.get("volume_icons", 0)
            row[5] = card.get("draw_icons", 0)

            # Blade hearts (flipped as yell)
            bh = card.get("blade_hearts", [])
            for b_idx in range(min(len(bh), 7)):
                row[40 + b_idx] = bh[b_idx]

            # Required hearts, if present, replace the hearts stored in 12-18
            if "required_hearts" in card:
                reqs = card.get("required_hearts", [])
                for r_idx in range(min(len(reqs), 7)):
                    row[12 + r_idx] = reqs[r_idx]

            # Fixed-geography ability blocks (only the first 4 abilities are packed)
            ab_list = card.get("abilities", [])
            for ab, base_idx, has_opts in zip(ab_list, (20, 36, 48, 60), (True, False, False, False)):
                self._pack_ability_block(row, ab, base_idx, has_opts)

            # Traits bitmask in [11]: group g sets bit (g % 20); unit u sets bit (u % 20) + 5.
            # NOTE(review): the intended layout was bits 0-4 for groups and 5-20 for
            # units, but the modulo allows overlap -- confirm against consumers.
            mask = 0
            for g in card.get("groups", []):
                try:
                    mask |= 1 << (int(g) % 20)
                except (TypeError, ValueError, OverflowError):
                    pass  # skip non-numeric group ids
            for u in card.get("units", []):
                try:
                    mask |= 1 << ((int(u) % 20) + 5)
                except (TypeError, ValueError, OverflowError):
                    pass  # skip non-numeric unit ids
            row[11] = mask

            count += 1

        # Lives
        for cid_str, card in db.get("live_db", {}).items():
            cid = int(cid_str)
            if cid >= self.max_cards:
                continue
            row = self.card_stats[cid]

            row[10] = 2  # Type: Live=2
            reqs = card.get("required_hearts", [])
            for r_idx in range(min(len(reqs), 7)):
                row[12 + r_idx] = reqs[r_idx]
            row[38] = card.get("score", 0)

            name = card.get("name", "")
            if name in name_to_id:
                row[19] = name_to_id[name]

            count += 1

        return count

    def _patch_baton_pass(self, db, raw_map, name_to_id):
        """
        Prepend a CHECK_BATON (231) instruction to every Baton Pass ability.

        An ability qualifies when its raw_text contains 「<name>」からバトンタッチして,
        <name> is a known member, and the ability has compiled bytecode.
        The patched program is written into the ability's OWN bytecode slot.
        (Previously new slots were counted from 1, which overwrote the bytecode
        of unrelated abilities.)
        """
        import re

        # Scan all cards for "バトンタッチして" condition and inject C_BATON opcode
        print(" [VectorEnv] Starting runtime patching for Baton Pass cards...")

        baton_pattern = re.compile(r"「(.+?)」からバトンタッチして")
        member_db = db.get("member_db", {})
        live_db = db.get("live_db", {})

        # First pass: collect (cid, ability idx, target id, original bytecode, target name)
        baton_cards = []
        for cid_str, card in {**member_db, **live_db}.items():
            cid = int(cid_str)
            if cid >= self.max_cards:
                continue
            for ab_idx, ability in enumerate(card.get("abilities", [])):
                match = baton_pattern.search(ability.get("raw_text", ""))
                if not match:
                    continue
                target_name = match.group(1)
                target_cid = name_to_id.get(target_name, -1)
                original_key = f"{cid}_{ab_idx}"
                if target_cid != -1 and original_key in raw_map:
                    baton_cards.append((cid, ab_idx, target_cid, raw_map[original_key], target_name))

        # Second pass: rewrite each ability's own slot with the patched program
        patched_count = 0
        for cid, ab_idx, target_cid, original_bytecode, target_name in baton_cards:
            if str(cid) in member_db:
                card = member_db[str(cid)]
            elif str(cid) in live_db:
                card = live_db[str(cid)]
            else:
                card = {}

            print(
                f" [VectorEnv] Patching Baton Pass for card {cid} ('{card.get('name', '')}') targeting '{target_name}' (ID: {target_cid})"
            )

            # Format: [C_BATON, Target_Char_ID, 0, 0] + original_bytecode
            new_bytecode = [231, target_cid, 0, 0] + original_bytecode

            slot = int(self.bytecode_index[cid, ab_idx]) if ab_idx < self.max_abilities else 0
            if slot > 0:
                bc_arr = np.array(new_bytecode, dtype=np.int32).reshape(-1, 4)
                # The patched program is one instruction longer than the original,
                # so it fully covers the old contents of the slot.
                length = min(bc_arr.shape[0], self.max_len)
                self.bytecode_map[slot, :length] = bc_arr[:length]
                patched_count += 1
                print(f" [VectorEnv] Successfully patched ability {ab_idx} for card {cid}, bytecode index: {slot}")
            else:
                print(f" [VectorEnv] Error: No bytecode slot for ability {ab_idx} of card {cid}")

        print(f" [VectorEnv] Runtime patching completed. {patched_count} cards patched.")

    def _load_verified_deck_pool(self):
        """
        Build the card-id pools used for dealing decks from data/verified_card_pool.json.

        Sets ``ability_member_ids`` / ``ability_live_ids`` (primary pools) and
        ``vanilla_member_ids`` / ``vanilla_live_ids`` (secondary pools) as int32
        arrays. Card numbers are mapped to ids through data/cards_compiled.json,
        and unknown numbers are skipped. The verified file may be a flat list
        (compatibility mode) or a dict of named lists. An empty primary pool
        falls back to its vanilla pool, then to a placeholder id (1 for members,
        999 for lives). Any error uses those placeholders.
        """
        import json

        try:
            # Load Verified List
            with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
                verified_data = json.load(f)

            # Load DB to map CardNo -> CardID
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)

            self.ability_member_ids = []
            self.ability_live_ids = []
            self.vanilla_member_ids = []
            self.vanilla_live_ids = []

            member_no_map = {cdata["card_no"]: int(cid) for cid, cdata in db_data.get("member_db", {}).items()}
            live_no_map = {cdata["card_no"]: int(cid) for cid, cdata in db_data.get("live_db", {}).items()}

            if isinstance(verified_data, list):
                print(" [VectorEnv] Loading Verified Pool from List (Compatibility Mode)")
                for v_no in verified_data:
                    if v_no in member_no_map:
                        self.ability_member_ids.append(member_no_map[v_no])
                    elif v_no in live_no_map:
                        self.ability_live_ids.append(live_no_map[v_no])
            else:
                # 1. Primary pool: supports old (verified_*) and new (members/lives) keys
                source_members = verified_data.get("verified_abilities", []) + verified_data.get("members", [])
                self.ability_member_ids.extend(member_no_map[v] for v in source_members if v in member_no_map)

                source_lives = verified_data.get("verified_lives", []) + verified_data.get("lives", [])
                self.ability_live_ids.extend(live_no_map[v] for v in source_lives if v in live_no_map)

                # 2. Secondary pool: vanilla cards
                self.vanilla_member_ids.extend(
                    member_no_map[v] for v in verified_data.get("vanilla_members", []) if v in member_no_map
                )
                self.vanilla_live_ids.extend(
                    live_no_map[v] for v in verified_data.get("vanilla_lives", []) if v in live_no_map
                )

            # Fallbacks/Warnings
            if not self.ability_member_ids:
                if self.vanilla_member_ids:
                    print(" [VectorEnv] Warning: No ability members. using vanilla members.")
                    self.ability_member_ids = self.vanilla_member_ids
                else:
                    print(" [VectorEnv] Warning: No members found. Using ID 1.")
                    self.ability_member_ids = [1]

            if not self.ability_live_ids:
                if self.vanilla_live_ids:
                    print(" [VectorEnv] Warning: No ability lives. Using vanilla lives.")
                    self.ability_live_ids = self.vanilla_live_ids
                else:
                    print(" [VectorEnv] Warning: No lives found. Using ID 999 (Dummy).")
                    self.ability_live_ids = [999]

            print(
                f" [VectorEnv] Pools: {len(self.ability_member_ids)} Ability Members, {len(self.ability_live_ids)} Ability Lives."
            )
            print(
                f" [VectorEnv] Fallbacks: {len(self.vanilla_member_ids)} Vanilla Members, {len(self.vanilla_live_ids)} Vanilla Lives."
            )

            self.ability_member_ids = np.array(self.ability_member_ids, dtype=np.int32)
            self.ability_live_ids = np.array(self.ability_live_ids, dtype=np.int32)
            self.vanilla_member_ids = np.array(self.vanilla_member_ids, dtype=np.int32)
            self.vanilla_live_ids = np.array(self.vanilla_live_ids, dtype=np.int32)

        except Exception as e:
            print(f" [VectorEnv] Deck Load Error: {e}")
            self.ability_member_ids = np.array([1], dtype=np.int32)
            self.ability_live_ids = np.array([999], dtype=np.int32)
            self.vanilla_member_ids = np.array([], dtype=np.int32)
            self.vanilla_live_ids = np.array([], dtype=np.int32)

    def _load_fixed_deck_pool(self, deck_path: str):
        """
        Load a single fixed deck from a markdown file as the card pools.

        Lines like ``4x [PL!-...]`` (extra markdown such as ``**4x**`` is
        tolerated) add that many copies of the card. Cards are routed to the
        member or live pool by which DB contains the card number. Counts other
        than 48 members / 12 lives only produce a warning. Any error falls back
        to ``_load_verified_deck_pool``.
        """
        import json
        import re

        print(f" [VectorEnv] Loading FIXED DECK from: {deck_path}")
        try:
            # 1. Load DB to map CardNo -> CardID
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)

            member_no_map = {cdata["card_no"]: int(cid) for cid, cdata in db_data.get("member_db", {}).items()}
            live_no_map = {cdata["card_no"]: int(cid) for cid, cdata in db_data.get("live_db", {}).items()}

            # 2. Parse Markdown
            with open(deck_path, "r", encoding="utf-8") as f:
                lines = f.readlines()

            entry_pattern = re.compile(r"(\d+)x.*?\[(PL!-[^\]]+)\]")
            members = []
            lives = []
            for line in lines:
                match = entry_pattern.search(line)
                if not match:
                    continue
                count = int(match.group(1))
                card_no = match.group(2)
                if card_no in member_no_map:
                    members.extend([member_no_map[card_no]] * count)
                elif card_no in live_no_map:
                    lives.extend([live_no_map[card_no]] * count)

            # 3. Finalize
            if len(members) != 48:
                print(f" [VectorEnv] Warning: Fixed deck members count is {len(members)}, expected 48.")
            if len(lives) != 12:
                print(f" [VectorEnv] Warning: Fixed deck lives count is {len(lives)}, expected 12.")

            self.ability_member_ids = np.array(members, dtype=np.int32)
            self.ability_live_ids = np.array(lives, dtype=np.int32)
            self.vanilla_member_ids = np.array([], dtype=np.int32)
            self.vanilla_live_ids = np.array([], dtype=np.int32)

            print(
                f" [VectorEnv] Fixed Deck Loaded: {len(self.ability_member_ids)} members, {len(self.ability_live_ids)} lives."
            )

        except Exception as e:
            print(f" [VectorEnv] Fixed Deck Load Error: {e}")
            self._load_verified_deck_pool()

    def _load_scenarios(self, path="data/scenarios.npz"):
        """Load pre-built scenario states from an .npz file into ``self.scenarios`` (None on failure)."""
        try:
            data = np.load(path)
            self.scenarios = {k: data[k] for k in data.files}
            self.num_scenarios = len(self.scenarios["batch_hand"])
            print(f" [VectorEnv] Loaded {self.num_scenarios} scenarios from {path}")
        except Exception as e:
            print(f" [VectorEnv] Failed to load scenarios: {e}")
            self.scenarios = None

    def reset(self, indices: Optional[List[int]] = None):
        """
        Reset the given environments (all of them when ``indices`` is None).

        Each selected env is re-dealt by ``isn.reset_kernel_numba``. If
        USE_SCENARIOS=1 and scenarios were loaded, those envs are then
        overwritten with randomly sampled scenario states. The per-env
        reward/length trackers are zeroed; ``self.turn`` is reset only on a
        full reset.

        Returns:
            The batched observation array from ``get_observations()``.
        """
        if indices is None:
            indices_arr = np.arange(self.num_envs, dtype=np.int32)
        else:
            indices_arr = np.array(indices, dtype=np.int32)

        isn.reset_kernel_numba(
            indices_arr,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.opp_stage,
            self.opp_energy_vec,
            self.opp_energy_count,
            self.opp_tapped,
            self.opp_live,
            self.opp_scores,
            self.opp_global_ctx,
            self.opp_hand,
            self.opp_deck,
            self.batch_trash,
            self.opp_trash,
            self.batch_opp_history,
            self.ability_member_ids,
            self.ability_live_ids,
            int(self.force_start_order),
        )

        # Scenario Overwrite
        if getattr(self, "scenarios", None) is not None and os.getenv("USE_SCENARIOS", "0") == "1":
            try:
                scen_indices = np.random.randint(0, self.num_scenarios, size=len(indices_arr))
                scenario_fields = (
                    ("batch_hand", self.batch_hand),
                    ("batch_deck", self.batch_deck),
                    ("batch_stage", self.batch_stage),
                    ("batch_energy_vec", self.batch_energy_vec),
                    ("batch_energy_count", self.batch_energy_count),
                    ("batch_continuous_vec", self.batch_continuous_vec),
                    ("batch_continuous_ptr", self.batch_continuous_ptr),
                    ("batch_tapped", self.batch_tapped),
                    ("batch_live", self.batch_live),
                    ("batch_scores", self.batch_scores),
                    ("batch_flat_ctx", self.batch_flat_ctx),
                    ("batch_global_ctx", self.batch_global_ctx),
                    ("opp_hand", self.opp_hand),
                    ("opp_deck", self.opp_deck),
                    ("opp_stage", self.opp_stage),
                    ("opp_energy_vec", self.opp_energy_vec),
                    ("opp_energy_count", self.opp_energy_count),
                    ("opp_tapped", self.opp_tapped),
                    ("opp_live", self.opp_live),
                    ("opp_scores", self.opp_scores),
                    ("opp_global_ctx", self.opp_global_ctx),
                )
                for name, target in scenario_fields:
                    if name not in self.scenarios:
                        continue
                    data = self.scenarios[name][scen_indices]
                    # Scenarios may store 1-D fields as (K, 1) columns.
                    if target.ndim == 1 and data.ndim == 2 and data.shape[1] == 1:
                        data = data.ravel()
                    target[indices_arr] = data
            except Exception as e:
                print(f" [VectorEnv] Error loading scenario data: {e}")

        # Reset local trackers
        trackers = (
            self.prev_scores,
            self.prev_opp_scores,
            self.prev_phases,
            self.episode_returns,
            self.episode_lengths,
        )
        if indices is None:
            self.turn = 1
            for tracker in trackers:
                tracker.fill(0)
        else:
            for tracker in trackers:
                tracker[indices_arr] = 0

        return self.get_observations()

    def step(self, actions: np.ndarray):
        """Apply a batch of actions (cast to int32 if needed) via ``integrated_step``."""
        if actions.dtype != np.int32:
            actions = actions.astype(np.int32)
        return self.integrated_step(actions)

    def integrated_step(self, actions: np.ndarray):
        """
        Execute the optimized Numba integrated step for all envs.

        Observations are written into ``self.obs_buffer`` (returned directly, not
        copied). Reward scaling via SCENARIO_REWARD_SCALE applies only when
        USE_SCENARIOS=1.

        Returns:
            obs, rewards, dones, infos. ``infos`` is a list of dicts. For done envs
            it holds the terminal observation, the episode summary {"r": summed
            return, "l": length} and both terminal scores; otherwise it is empty.
        """
        term_scores_agent = np.zeros(self.num_envs, dtype=np.int32)
        term_scores_opp = np.zeros(self.num_envs, dtype=np.int32)

        # Kernel mode codes: 0=IMAX, 1=STANDARD, 2=COMPRESSED, 3=ATTENTION.
        # Unknown modes use STANDARD, matching the buffers allocated in __init__.
        obs_mode_code = self._OBS_MODE_CODES.get(self.obs_mode, 1)

        rewards, dones = isn.integrated_step_numba(
            self.num_envs,
            actions,
            self.batch_hand,
            self.batch_deck,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.opp_hand,
            self.opp_deck,
            self.opp_stage,
            self.opp_energy_vec,
            self.opp_energy_count,
            self.opp_tapped,
            self.opp_live,  # Added
            self.opp_scores,
            self.opp_global_ctx,
            self.card_stats,
            self.bytecode_map,
            self.bytecode_index,
            self.batch_opp_history,
            self.obs_buffer,
            self.prev_scores,
            self.prev_opp_scores,
            self.prev_phases,
            self.ability_member_ids,
            self.ability_live_ids,
            self.turn_number_ptr,
            self.terminal_obs_buffer,
            self.batch_trash,
            self.opp_trash,
            term_scores_agent,
            term_scores_opp,
            obs_mode_code,
            self.game_config,  # New Config
            int(self.opp_mode),
            int(self.force_start_order),
        )

        # Apply Scenario Reward Scaling
        if self.scenario_reward_scale != 1.0 and os.getenv("USE_SCENARIOS", "0") == "1":
            rewards *= self.scenario_reward_scale

        # Build infos and episode statistics in a single pass.
        infos = []
        for i in range(self.num_envs):
            if dones[i]:
                # The terminal step's reward is included in the episode return.
                final_return = self.episode_returns[i] + rewards[i]
                final_length = self.episode_lengths[i] + 1
                infos.append(
                    {
                        "terminal_observation": self.terminal_obs_buffer[i].copy(),
                        "episode": {"r": float(final_return), "l": int(final_length)},
                        "terminal_score_agent": int(term_scores_agent[i]),
                        "terminal_score_opp": int(term_scores_opp[i]),
                    }
                )
                # Reset accumulators for the next episode in this slot
                self.episode_returns[i] = 0
                self.episode_lengths[i] = 0
            else:
                # rewards[i] is the delta reward for this integrated step.
                self.episode_returns[i] += rewards[i]
                self.episode_lengths[i] += 1
                infos.append({})

        return self.obs_buffer, rewards, dones, infos

    def get_action_masks(self):
        """Return legal action masks (512 actions in ATTENTION mode, else the standard masker)."""
        mask_fn = compute_action_masks_attention if self.obs_mode == "ATTENTION" else compute_action_masks
        return mask_fn(
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_tapped,
            self.batch_global_ctx,
            self.batch_live,
            self.card_stats,
        )

    def get_observations(self):
        """Encode the current state into ``self.obs_buffer`` using the configured OBS_MODE encoder."""
        common = (
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.opp_scores,
            self.opp_stage,
            self.opp_tapped,
            self.card_stats,
            self.batch_global_ctx,
            self.batch_live,
            self.batch_opp_history,
        )
        if self.obs_mode == "COMPRESSED":
            return isn.encode_observations_compressed(*common, self.turn, self.obs_buffer)
        if self.obs_mode == "IMAX":
            return isn.encode_observations_imax(*common, self.turn, self.obs_buffer)
        if self.obs_mode == "ATTENTION":
            # The attention encoder additionally sees the opponent's global context.
            return isn.encode_observations_attention(*common, self.opp_global_ctx, self.turn, self.obs_buffer)
        return isn.encode_observations_standard(*common, self.turn, self.obs_buffer)


@njit(cache=True)
def step_opponent_vectorized(
    opp_hand: np.ndarray,  # (N, 60)
    opp_deck: np.ndarray,  # (N, 60)
    opp_stage: np.ndarray,
    opp_energy_vec: np.ndarray,
    opp_energy_count: np.ndarray,
    opp_tapped: np.ndarray,
    opp_scores: np.ndarray,
    agent_tapped: np.ndarray,
    opp_global_ctx: np.ndarray,  # (N, 128)
    bytecode_map: np.ndarray,
    bytecode_index: np.ndarray,
):
    """
    Very simplified opponent step. Reuses agent bytecode but targets opponent buffers.

    For each env, it picks a uniformly random non-empty hand slot and, if that
    card has ability-0 bytecode, resolves it as player 1 against the
    opponent's row views. Afterwards, if fewer than 5 cards remain in hand, it
    moves the first non-empty deck card into the first empty hand slot.
    Mutates the opponent arrays in place and returns nothing.
    """
    num_envs = len(opp_hand)

    # Scratch buffers allocated once and reused for every env.
    f_ctx = np.zeros(64, dtype=np.int32)
    live = np.zeros(50, dtype=np.int32)  # Dummy live zone for opponent
    dummy_cont_vec = np.zeros((32, 10), dtype=np.int32)
    dummy_ptr = np.zeros(1, dtype=np.int32)  # Ref Array
    dummy_bonus = np.zeros(1, dtype=np.int32)  # Ref Array
    # Only entries [0, c_ptr) are read each iteration, so reuse is safe.
    candidates = np.zeros(60, dtype=np.int32)

    for i in range(num_envs):
        f_ctx.fill(0)

        # 1. Select a random non-empty hand slot
        c_ptr = 0
        for j in range(60):
            if opp_hand[i, j] > 0:
                candidates[c_ptr] = j  # Store Index in Hand
                c_ptr += 1

        if c_ptr == 0:
            continue

        hand_idx = candidates[np.random.randint(0, c_ptr)]
        act_id = opp_hand[i, hand_idx]

        # 2. Execute
        if act_id > 0 and act_id < bytecode_index.shape[0]:
            map_idx = bytecode_index[act_id, 0]
            if map_idx > 0:
                code_seq = bytecode_map[map_idx]
                opp_global_ctx[i, 0] = opp_scores[i]
                opp_global_ctx[i, 3] -= 1  # Decrement Hand Count (HD)

                dummy_ptr[0] = 0
                dummy_bonus[0] = 0

                # Row slices are views, so resolve_bytecode's in-place writes land
                # in the opponent buffers.
                resolve_bytecode(
                    code_seq,
                    f_ctx,
                    opp_global_ctx[i],
                    1,
                    opp_hand[i],
                    opp_deck[i],
                    opp_stage[i],
                    opp_energy_vec[i],
                    opp_energy_count[i],
                    dummy_cont_vec,
                    dummy_ptr,
                    opp_tapped[i],
                    live,
                    agent_tapped[i],
                    bytecode_map,
                    bytecode_index,
                    dummy_bonus,
                )
                # Score write-back is intentionally neutralized:
                # opp_scores[i] = opp_global_ctx[i, 0]
                # (ctx layout: SC = 0; OS = 1; TR = 2; HD = 3; DI = 4; EN = 5; DK = 6; OT = 7)

        # 3. Post-play cleanup: draw one card if the hand holds fewer than 5.
        # NOTE(review): this does not remove the played card from the hand; that is
        # presumably done by resolve_bytecode -- confirm.
        cnt = 0
        for j in range(60):
            if opp_hand[i, j] > 0:
                cnt += 1

        if cnt < 5:
            top_card = 0
            deck_idx = -1
            for j in range(60):
                if opp_deck[i, j] > 0:
                    top_card = opp_deck[i, j]
                    deck_idx = j
                    break

            if top_card > 0:
                for j in range(60):
                    if opp_hand[i, j] == 0:
                        opp_hand[i, j] = top_card
                        opp_deck[i, deck_idx] = 0  # Remove from deck
                        opp_global_ctx[i, 3] += 1  # Increment Hand Count (HD)
                        opp_global_ctx[i, 6] -= 1  # Decrement Deck Count (DK)
                        break


@njit(cache=True)
def resolve_auto_phases(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_tapped: np.ndarray,
    single_step: bool = False,
):
    """
    Advance each env through the non-interactive phases until it reaches Main (4).

    The phase id is stored in ``batch_global_ctx[i, 8]``:
      * 0 or 8 -> 1 (turn start): clears the slot-played flags (ctx 51-53),
        untaps ``batch_tapped[i, 0:16]``, grows the energy count (ctx 5: 0 -> 3,
        otherwise +1 up to 12) and increments the turn counter (ctx 54).
      * 1 -> 2 and 2 -> 3: plain advances.
      * 3 -> 4 (draw): moves the first non-empty deck card into the first empty
        hand slot, recounts the hand size (ctx 3) and decrements the deck count
        (ctx 6). An empty deck simply skips the draw.
      * Any other phase, including 4, stops processing for that env.

    Args:
        single_step: if True, perform at most one transition per env; otherwise
            up to 10 (a safety bound against runaway loops).
    """
    for i in range(num_envs):
        max_iters = 1 if single_step else 10
        for _ in range(max_iters):
            ph = int(batch_global_ctx[i, 8])

            # 0 (MULLIGAN) or 8 (LIVE_RESULT) -> 1 (ACTIVE)
            if ph == 0 or ph == 8:
                batch_global_ctx[i, 51:54] = 0  # Reset Slot Played Flags
                batch_tapped[i, 0:16] = 0  # Untap Members 0-2, Energy 3-15
                cur_ec = batch_global_ctx[i, 5]
                if cur_ec == 0:
                    batch_global_ctx[i, 5] = 3
                elif cur_ec < 12:
                    batch_global_ctx[i, 5] = cur_ec + 1
                batch_global_ctx[i, 54] += 1  # Turn Counter
                batch_global_ctx[i, 8] = 1
                continue

            # ACTIVE (1) -> ENERGY (2)
            if ph == 1:
                batch_global_ctx[i, 8] = 2
                continue

            # ENERGY (2) -> DRAW (3)
            if ph == 2:
                batch_global_ctx[i, 8] = 3
                continue

            # DRAW (3) -> MAIN (4)
            if ph == 3:
                deck_idx = -1
                for d_idx in range(60):
                    if batch_deck[i, d_idx] > 0:
                        deck_idx = d_idx
                        break

                # No deck replenishing happens here: an empty deck just skips the draw.
                if deck_idx >= 0:
                    top_card = batch_deck[i, deck_idx]
                    for h_idx in range(60):
                        if batch_hand[i, h_idx] == 0:
                            batch_hand[i, h_idx] = top_card
                            batch_deck[i, deck_idx] = 0
                            batch_global_ctx[i, 3] = 0
                            for k in range(60):
                                if batch_hand[i, k] > 0:
                                    batch_global_ctx[i, 3] += 1
                            batch_global_ctx[i, 6] -= 1
                            break

                batch_global_ctx[i, 8] = 4
                continue

            # Main (4) or any unhandled phase: stop and let the agent act.
            break


@njit(parallel=True, cache=True)
def compute_action_masks_attention(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_stage: np.ndarray,
    batch_tapped: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_live: np.ndarray,
    card_stats: np.ndarray,
):
    """
    Compute legal action masks for ATTENTION mode (512 actions).

    Mapping:
    - 0: Pass
    - 1-45: Play Member (15 hand idx * 3 slots)
    - 46-60: Set Live (15 hand idx)
    - 61-63: Activate Ability (3 slots)
    - 64-69: Mulligan Select (6 cards)
    - 100-299: Choice Actions (Not fully implemented yet)
    """
    masks = np.zeros((num_envs, 512), dtype=np.bool_)
    masks[:, 0] = True  # Pass always legal

    for i in prange(num_envs):
        phase = batch_global_ctx[i, 8]

        # --- Mulligan (Phase Includes -1, 0) ---
        if phase <= 0:
            # Allow pass (0) to finish
            masks[i, 0] = True
            # Allow select mulligan (64-69) for first 6 cards
            # ONE-WAY: If already selected (flag=1), mask it.
            for h_idx in range(6):
                if batch_hand[i, h_idx] > 0:
                    if batch_global_ctx[i, 120 + h_idx] == 0:
                        masks[i, 64 + h_idx] = True
            continue

        # --- Main Phase (4) ---
        if phase == 4:
            ec = batch_global_ctx[i, 5]
            tapped_count = 0
            for e_idx in range(min(ec, 12)):
                if batch_tapped[i, 3 + e_idx] > 0:
                    tapped_count += 1
            available_energy = ec - tapped_count

            # 1. Play Actions (1-45) & Set Live (46-60)
            # Hand limit for this mode is 15 primary indices
            for h_idx in range(15):
                cid = batch_hand[i, h_idx]
                if cid <= 0 or cid >= card_stats.shape[0]:
                    continue

                is_member = card_stats[cid, 10] == 1
                is_live = card_stats[cid, 10] == 2

                if is_member:
                    # Play to Slot 0-2 (Actions 1-45)
                    # Base = 1 + h_idx * 3
                    cost = card_stats[cid, 0]
                    for slot in range(3):
                        # One play per slot per turn check
                        if batch_global_ctx[i, 51 + slot] > 0:
                            continue

                        # Effective Cost (Baton Touch)
                        effective_cost = cost
                        prev_cid = batch_stage[i, slot]
                        if prev_cid > 0 and prev_cid < card_stats.shape[0]:
                            effective_cost = max(0, cost - card_stats[prev_cid, 0])

                        if effective_cost <= available_energy:
                            masks[i, 1 + h_idx * 3 + slot] = True

                # Set Live (Actions 46-60)
                # Rule 8.3 & 8.2.2: ANY card can be set.
                # Limit 3 cards in zone
                # NOTE(review): nesting of this section was reconstructed from
                # collapsed source (placed per hand card, outside `if is_member`) -- confirm.
                live_count = 0
                for lx in range(6):  # Check full 6 capacity (3 pending + 3 success)
                    if batch_live[i, lx] > 0:
                        live_count += 1

                if live_count < 3:
                    masks[i, 46 + h_idx] = True

            # 2.
Activate Abilities (61-63) for slot in range(3): cid = batch_stage[i, slot] if cid > 0 and not batch_tapped[i, slot]: masks[i, 61 + slot] = True # --- Choice Handling (Phase 7+) --- if phase >= 7 or phase == 4: # Allow hand selection (100-159) for h_idx in range(60): if batch_hand[i, h_idx] > 0: masks[i, 100 + h_idx] = True # Allow energy selection (160-171) ec_val = batch_global_ctx[i, 5] for e_idx in range(min(ec_val, 12)): masks[i, 160 + e_idx] = True return masks @njit(parallel=True, cache=True) def compute_action_masks( num_envs: int, batch_hand: np.ndarray, batch_stage: np.ndarray, batch_tapped: np.ndarray, batch_global_ctx: np.ndarray, batch_live: np.ndarray, card_stats: np.ndarray, ): """ Compute legal action masks using Python-compatible action IDs: - 0: Pass (always legal in Main Phase) - 1-180: Play Member from Hand (HandIdx * 3 + Slot + 1) - 200-202: Activate Ability (Slot) - 400-459: Set Live Card (HandIdx) """ masks = np.zeros((num_envs, 2000), dtype=np.bool_) # Action 0 (Pass) is always legal masks[:, 0] = True for i in prange(num_envs): phase = batch_global_ctx[i, 8] # Mulligan Phases (-1, 0) # Mulligan Phases (-1, 0) if phase == -1 or phase == 0: masks[i, 0] = True # Pass to finalize # Only allow selection if the card exists AND isn't already selected (One-way) for h_idx in range(6): # Only first 6 cards are mull-able (Parity) if batch_hand[i, h_idx] > 0: selected = batch_global_ctx[i, 120 + h_idx] if selected == 0: masks[i, 300 + h_idx] = True continue # Only compute member/ability actions in Main Phase (4) if phase == 4: # Calculate available untapped energy ec = batch_global_ctx[i, 5] # EC at index 5 tapped_count = 0 for e_idx in range(min(ec, 12)): if batch_tapped[i, 3 + e_idx] > 0: tapped_count += 1 available_energy = ec - tapped_count # --- Member Play Actions (1-180) --- # Action ID = HandIdx * 3 + Slot + 1 for h_idx in range(60): cid = batch_hand[i, h_idx] # CRITICAL SAFETY: card_stats shape check if cid <= 0 or cid >= card_stats.shape[0]: 
continue # Check if this is a Member card (Type 1) if card_stats[cid, 10] != 1: # Check if this is a Live card (Type 2) for play actions 400-459 if card_stats[cid, 10] == 2: # Action ID = 400 + h_idx action_id = 400 + h_idx # --- RULE ACCURACY: Live cards can be set without checking hearts --- # Requirements are checked during Performance phase (Rule 8.3) # We allow setting if hand size limit not reached (max 3 in zone) count_in_zone = 0 for j in range(50): if batch_live[i, j] > 0: count_in_zone += 1 if count_in_zone < 3: masks[i, action_id] = True continue # Member cost in card_stats[cid, 0] cost = card_stats[cid, 0] for slot in range(3): action_id = h_idx * 3 + slot + 1 # Rule: One play per slot per turn (Indices 51-53) if batch_global_ctx[i, 51 + slot] > 0: continue # Calculate effective cost (Baton Touch reduction) effective_cost = cost prev_cid = batch_stage[i, slot] # SAFETY: Check cid range to avoid out-of-bounds card_stats access if prev_cid >= 0 and prev_cid < card_stats.shape[0]: prev_cost = card_stats[prev_cid, 0] effective_cost = cost - prev_cost if effective_cost < 0: effective_cost = 0 if effective_cost <= available_energy: masks[i, action_id] = True # --- Activate Ability Actions (200-202) --- for slot in range(3): cid = batch_stage[i, slot] if cid > 0 and not batch_tapped[i, slot]: # Check if card has an activated ability # For now, assume all untapped members can activate masks[i, 200 + slot] = True # --- Mandatory Choice Handling (Phase 7, 8 & Fallback) --- if phase >= 7 or phase == 4: # Allow hand selection/discard actions (500-559) if hand has cards # This prevents Zero Legal Moves when a choice is pending. 
for h_idx in range(60): if batch_hand[i, h_idx] > 0: masks[i, 500 + h_idx] = True # Allow energy selection actions (600-611) if energy exists energy_count = batch_global_ctx[i, 5] for e_idx in range(min(energy_count, 12)): masks[i, 600 + e_idx] = True return masks # Export for legacy/external compatibility encode_observations_vectorized = isn.encode_observations_standard