Spaces:
Running
Running
| from typing import List | |
| import numpy as np | |
| from engine.game.ai_compat import njit | |
| from engine.game.fast_logic import batch_apply_action | |
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # New: Bytecode Maps
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> Index in map
) -> None:
    """
    Step N game environments in parallel using JIT logic and Real Card Data.

    Copies each environment's score into column 0 of ``batch_global_ctx`` and
    then hands every buffer to ``batch_apply_action`` with player id 0.
    Nothing is returned; ``batch_global_ctx`` is written here, and the other
    buffers are passed through as-is.

    Args:
        actions: One action per environment; its length defines how many
            environments are synced.
        batch_scores: Per-environment score, indexed by environment.
        batch_global_ctx: Per-environment context rows; column 0 receives
            the score.
        bytecode_map, bytecode_index: Compiled ability tables, forwarded
            unchanged.
        (remaining ``batch_*`` arguments): State buffers forwarded unchanged.
    """
    # Sync individual scores to global_ctx before stepping (single vectorized
    # copy instead of a Python-level loop over environments).
    n = len(actions)
    batch_global_ctx[:n, 0] = batch_scores[:n]

    # NOTE(review): batch_scores is passed between batch_tapped and batch_live,
    # which differs from this function's own parameter order -- confirm this
    # matches batch_apply_action's signature.
    batch_apply_action(
        actions,
        0,  # player_id
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        bytecode_map,
        bytecode_index,
    )
class VectorGameState:
    """
    Manages a batch of independent GameStates for high-throughput training.

    All state lives in pre-allocated NumPy buffers whose first axis is the
    environment index. The buffers are mutated in place by ``reset`` and
    ``step``.
    """

    def __init__(self, num_envs: int):
        """Allocate all per-environment buffers and load the card data tables.

        Args:
            num_envs: Number of environments held in the batch.
        """
        self.num_envs = num_envs
        self.turn = 1

        # Batched state buffers
        self.batch_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.batch_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)
        self.batch_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_continuous_vec = np.zeros((num_envs, 32, 10), dtype=np.int32)
        self.batch_continuous_ptr = np.zeros(num_envs, dtype=np.int32)
        self.batch_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_live = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_opp_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_scores = np.zeros(num_envs, dtype=np.int32)

        # Pre-allocated context buffers (Extreme speed optimization)
        self.batch_flat_ctx = np.zeros((num_envs, 64), dtype=np.int32)
        self.batch_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)
        self.batch_hand = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_deck = np.zeros((num_envs, 50), dtype=np.int32)

        # Pre-allocated observation buffer (SAVES ALLOCATION TIME)
        self.obs_buffer = np.zeros((num_envs, 320), dtype=np.float32)

        # Load Bytecode Map
        self._load_bytecode()
        self._load_verified_deck_pool()

    def _load_bytecode(self):
        """Build ``bytecode_map`` / ``bytecode_index`` from data/cards_numba.json.

        File format: key ``"cardid_abidx"`` -> flat list of ints, reshaped to
        rows of 4.

        Storage:
            1. ``bytecode_map``: (N + 1, max_len, 4); row 0 is the empty/nop
               entry.
            2. ``bytecode_index``: (max_cards, max_abilities) -> row in
               ``bytecode_map`` (0 when the card/ability has no entry).

        If the file is missing, a warning is printed and an all-nop table is
        used. The index keeps its full (max_cards, max_abilities) shape so it
        can still be indexed by any in-range card id.
        """
        import json

        # Set the sizing constants first so they exist on the fallback path too.
        self.max_cards = 2000
        self.max_abilities = 4
        self.max_len = 64  # Max 64 instructions per ability

        try:
            with open("data/cards_numba.json", "r", encoding="utf-8") as f:
                raw_map = json.load(f)
        except FileNotFoundError:
            print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
            self.bytecode_map = np.zeros((1, self.max_len, 4), dtype=np.int32)
            self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)
            return

        # (Index 0 is empty/nop)
        self.bytecode_map = np.zeros((len(raw_map) + 1, self.max_len, 4), dtype=np.int32)
        self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)

        next_row = 1
        for key, bc_list in raw_map.items():
            cid, aid = map(int, key.split("_"))
            # Skip ids outside the table; the lower bound stops negative ids
            # from silently wrapping around to the end of the index.
            if not (0 <= cid < self.max_cards and 0 <= aid < self.max_abilities):
                continue
            # reshape list to (M, 4); instructions beyond max_len are dropped
            bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
            length = min(bc_arr.shape[0], self.max_len)
            self.bytecode_map[next_row, :length] = bc_arr[:length]
            self.bytecode_index[cid, aid] = next_row
            next_row += 1

        # Report what was actually stored, not the raw entry count.
        print(f" [VectorEnv] Loaded {next_row - 1} compiled abilities.")

    def _load_verified_deck_pool(self):
        """Populate ``verified_card_ids`` with the ids of verified cards.

        Reads card numbers from the ``"verified_abilities"`` list in
        verified_card_pool.json and maps them to integer ids through the
        ``"member_db"`` section of data/cards_compiled.json. Best-effort: on
        any load/parse problem, or when no card matches, the pool falls back
        to the single id 1.
        """
        import json

        try:
            # Load Verified List
            with open("verified_card_pool.json", "r", encoding="utf-8") as f:
                verified_data = json.load(f)
            # Load DB to map CardNo -> CardID
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)

            card_no_map = {
                cdata["card_no"]: int(cid)
                for cid, cdata in db_data["member_db"].items()
            }
            card_ids = [
                card_no_map[v_no]
                for v_no in verified_data.get("verified_abilities", [])
                if v_no in card_no_map
            ]

            # Fallback
            if not card_ids:
                print(" [VectorEnv] Warning: No verified cards found. Using ID 1.")
                card_ids = [1]
            else:
                print(f" [VectorEnv] Loaded {len(card_ids)} verified cards for training.")
            self.verified_card_ids = np.array(card_ids, dtype=np.int32)
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            # Missing/unreadable file, invalid JSON, or unexpected schema.
            print(f" [VectorEnv] Deck Load Error: {e}")
            self.verified_card_ids = np.array([1], dtype=np.int32)

    def reset(self, indices: List[int] = None):
        """Reset specified environments (or all if indices is None).

        Clears every state/context buffer row for the selected environments,
        deals each a fresh random deck drawn (with replacement) from the
        verified pool, and copies the first 5 deck cards into an otherwise
        empty hand. ``turn`` is a single counter shared by the whole batch and
        is set back to 1.

        Args:
            indices: Environment indices to reset; ``None`` resets all.
        """
        if indices is None:
            idx = np.arange(self.num_envs)
        else:
            idx = np.fromiter(indices, dtype=np.intp)

        self.batch_stage[idx] = -1
        self.batch_energy_vec[idx] = 0
        self.batch_energy_count[idx] = 0
        self.batch_continuous_vec[idx] = 0
        self.batch_continuous_ptr[idx] = 0
        self.batch_tapped[idx] = 0
        self.batch_live[idx] = 0
        self.batch_opp_tapped[idx] = 0
        self.batch_scores[idx] = 0

        # Reset contexts
        self.batch_flat_ctx[idx] = 0
        self.batch_global_ctx[idx] = 0

        # Clear the hand so cards from the previous episode do not leak into
        # the new one (only the first 5 slots are re-dealt below).
        self.batch_hand[idx] = 0

        # Initialize Deck with Verified Cards (random, one row per env)
        if len(self.verified_card_ids) > 0:
            deck_len = self.batch_deck.shape[1]
            self.batch_deck[idx] = np.random.choice(
                self.verified_card_ids, size=(idx.size, deck_len)
            )

        # Initialize Hand (Draw 5 from deck)
        # Simple simulation: copy top 5 deck cards to hand; the deck itself is
        # not shifted (infinite deck / pointer logic via opcodes is assumed).
        self.batch_hand[idx, :5] = self.batch_deck[idx, :5]

        self.turn = 1

    def step(self, actions: np.ndarray):
        """Apply a batch of actions across all environments.

        Forwards the action array and all state buffers to
        ``step_vectorized``. ``turn`` is not advanced here.
        """
        step_vectorized(
            actions,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_opp_tapped,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.bytecode_map,
            self.bytecode_index,
        )
        # Simplified turn advancement
        # In real VectorEnv, this would be managed by the engine rules

    def get_observations(self):
        """Return a batched observation for RL models.

        Returns whatever ``encode_observations_vectorized`` returns when given
        the pre-allocated ``obs_buffer``.
        """
        return encode_observations_vectorized(
            self.num_envs,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.turn,
            self.obs_buffer,
        )
def encode_observations_vectorized(
    num_envs: int,
    batch_stage: np.ndarray,  # (N, 3)
    batch_energy_count: np.ndarray,  # (N, 3)
    batch_tapped: np.ndarray,  # (N, 3)
    batch_scores: np.ndarray,  # (N,)
    turn_number: int,
    observations: np.ndarray,  # (N, 320)
):
    """Encode the batched game state into the pre-allocated observation buffer.

    The whole buffer is zeroed, then the first ``num_envs`` rows are filled
    using whole-column NumPy operations.

    Layout written per row (all other columns stay 0.0):
        5          phase flag (always set: Main phase)
        16         current-player flag (always player 0)
        168..203   self stage, 3 slots x 12 features; for an occupied slot
                   (card id >= 0): +0 present, +1 card id / 2000, +2 tapped,
                   +3 mock power (0.5), +11 min(energy / 5, 1)
        270        min(score / 5, 1)

    Args:
        num_envs: Number of leading rows to encode.
        batch_stage: Card id per stage slot; negative means empty.
        batch_energy_count: Energy count per stage slot.
        batch_tapped: Non-zero where the slot is tapped.
        batch_scores: Score per environment.
        turn_number: Currently unused.
        observations: Output buffer, overwritten in place.

    Returns:
        The same ``observations`` array that was passed in.
    """
    max_id_val = 2000.0  # Normalization constant
    phase_col = 5  # GameState logic: phase_val = int(phase) + 2. Main is 3. So 5.
    player_col = 16  # Current Player [16:18] - always player 0 in this view
    stage_base = 168  # SELF STAGE [168:204]
    slot_width = 12
    score_col = 270  # SCORES [270:272]

    # Reset buffer (extremely fast on pre-allocated)
    observations.fill(0.0)

    rows = observations[:num_envs]
    stage = batch_stage[:num_envs]
    energy = batch_energy_count[:num_envs]
    tapped = batch_tapped[:num_envs]

    # --- 1. METADATA [0:36] ---
    rows[:, phase_col] = 1.0
    rows[:, player_col] = 1.0

    # --- 2. HAND [36:168] --- not encoded here; left at 0.0.

    # --- 3. SELF STAGE [168:204] (3 slots * 12 features) ---
    for slot in range(3):
        base = stage_base + slot * slot_width
        occupied = stage[:, slot] >= 0
        rows[:, base] = occupied
        rows[:, base + 1] = np.where(occupied, stage[:, slot] / max_id_val, 0.0)
        rows[:, base + 2] = occupied & (tapped[:, slot] != 0)
        # Mock attribute: no card database is available here, so use a
        # fixed default power for every occupied slot.
        rows[:, base + 3] = np.where(occupied, 0.5, 0.0)
        # Energy Count
        rows[:, base + 11] = np.where(
            occupied, np.minimum(energy[:, slot] / 5.0, 1.0), 0.0
        )

    # --- 4. OPPONENT STAGE [204:240] --- not tracked yet.
    # --- 5. LIVE ZONE [240:270] --- not tracked yet.

    # --- 6. SCORES [270:272] ---
    rows[:, score_col] = np.minimum(batch_scores[:num_envs] / 5.0, 1.0)

    return observations