Spaces:
Running
Running
| """ | |
| GPU-Native Vectorized Game Environment. | |
| This module provides VectorEnvGPU - a GPU-resident implementation using CuPy | |
| and Numba CUDA for maximum throughput. All game state arrays live in GPU VRAM, | |
| eliminating PCI-E transfer overhead during RL training. | |
| Usage: | |
| Set USE_GPU_ENV=1 to enable GPU environment in training. | |
| """ | |
| import json | |
| import os | |
| import time | |
| import numpy as np | |
# CUDA detection: probe for CuPy plus a Numba CUDA runtime that actually
# sees a device. HAS_CUDA gates every GPU code path below; the device RNG
# helper is only imported once a device is confirmed available.
# NOTE: if the numba.cuda.random import itself fails, the ImportError is
# swallowed here with HAS_CUDA already True -- presumably cannot happen when
# cuda.is_available() is True, but worth confirming.
HAS_CUDA = False
try:
    import cupy as cp
    from numba import cuda
    if cuda.is_available():
        HAS_CUDA = True
        from numba.cuda.random import create_xoroshiro128p_states
except ImportError:
    pass
# CPU fallback shims: NumPy-backed stand-ins for the cupy module, the
# numba.cuda module, and the device RNG factory, so this file stays
# importable (and minimally exercisable) on machines without CUDA.
if not HAS_CUDA:

    class _MockPool:
        """Stand-in for cupy's default memory pool; reports zero usage."""

        def used_bytes(self):
            return 0

    class MockCP:
        """Minimal cupy facade that delegates array creation to NumPy."""

        # dtype aliases mirror the cupy attributes used in this module
        int32 = np.int32
        int8 = np.int8
        float32 = np.float32
        bool_ = np.bool_

        def full(self, shape, val, dtype=None):
            return np.full(shape, val, dtype=dtype)

        def zeros(self, shape, dtype=None):
            return np.zeros(shape, dtype=dtype)

        def ones(self, shape, dtype=None):
            return np.ones(shape, dtype=dtype)

        def asnumpy(self, arr):
            # cupy.asnumpy copies device->host; here a plain copy suffices
            return np.array(arr)

        def array(self, arr, dtype=None):
            return np.array(arr, dtype=dtype)

        def asarray(self, arr, dtype=None):
            return np.asarray(arr, dtype=dtype)

        def arange(self, n, dtype=None):
            return np.arange(n, dtype=dtype)

        def get_default_memory_pool(self):
            return _MockPool()

    class MockCudaMod:
        """Minimal numba.cuda facade: no-ops with compatible signatures."""

        def to_device(self, arr):
            return arr

        def device_array(self, shape, dtype=None):
            return np.zeros(shape, dtype=dtype)

        def synchronize(self):
            pass

        def jit(self, *args, **kwargs):
            # identity decorator: leaves the wrapped function unchanged
            return lambda x: x

        def grid(self, x):
            return 0

    cp = MockCP()
    cuda = MockCudaMod()

    def create_xoroshiro128p_states(n, seed):
        # No device RNG on CPU; callers treat None as "no states".
        return None
class VectorEnvGPU:
    """
    GPU-Resident Vectorized Game Environment.

    All state arrays are CuPy arrays living in GPU VRAM; observations and
    actions are exchanged as GPU tensors to avoid PCI-E transfers during
    training. Kernels are imported lazily from ai.cuda_kernels.

    Args:
        num_envs: Number of parallel environments.
        opp_mode: Opponent mode (0=Heuristic, 1=Random; __init__ also
            mentions 2=Solitaire -- confirm against step_kernel).
        force_start_order: -1=Random, 0=P1 starts, 1=P2 starts.
        seed: Seed for the per-environment device RNG states.
    """

    def __init__(self, num_envs: int = 4096, opp_mode: int = 0, force_start_order: int = -1, seed: int = 42):
        self.num_envs = num_envs
        self.opp_mode = opp_mode  # 0=Heuristic, 1=Random, 2=Solitaire
        self.force_start_order = force_start_order
        self.seed = seed
        print(f" [VectorEnvGPU] Initializing {num_envs} environments. CUDA: {HAS_CUDA}")
        # =========================================================
        # AGENT STATE (GPU-Resident)
        # =========================================================
        # Per-env board/zone buffers. Trailing dims are fixed capacities
        # (3 stage slots, 60-card zones, etc.) -- presumably dictated by the
        # game rules; confirm against the kernel implementations.
        self.batch_stage = cp.full((num_envs, 3), -1, dtype=cp.int32)  # -1 = empty slot
        self.batch_energy_vec = cp.zeros((num_envs, 3, 32), dtype=cp.int32)
        self.batch_energy_count = cp.zeros((num_envs, 3), dtype=cp.int32)
        self.batch_continuous_vec = cp.zeros((num_envs, 32, 10), dtype=cp.int32)
        self.batch_continuous_ptr = cp.zeros(num_envs, dtype=cp.int32)
        self.batch_tapped = cp.zeros((num_envs, 16), dtype=cp.int32)
        self.batch_live = cp.zeros((num_envs, 50), dtype=cp.int32)
        self.batch_opp_tapped = cp.zeros((num_envs, 16), dtype=cp.int32)
        self.batch_scores = cp.zeros(num_envs, dtype=cp.int32)
        self.batch_flat_ctx = cp.zeros((num_envs, 64), dtype=cp.int32)
        self.batch_global_ctx = cp.zeros((num_envs, 128), dtype=cp.int32)
        self.batch_hand = cp.zeros((num_envs, 60), dtype=cp.int32)
        self.batch_deck = cp.zeros((num_envs, 60), dtype=cp.int32)
        self.batch_trash = cp.zeros((num_envs, 60), dtype=cp.int32)
        self.batch_opp_history = cp.zeros((num_envs, 6), dtype=cp.int32)
        # =========================================================
        # OPPONENT STATE (GPU-Resident)
        # =========================================================
        self.opp_stage = cp.full((num_envs, 3), -1, dtype=cp.int32)
        self.opp_energy_vec = cp.zeros((num_envs, 3, 32), dtype=cp.int32)
        self.opp_energy_count = cp.zeros((num_envs, 3), dtype=cp.int32)
        # NOTE(review): int8 here vs int32 for batch_tapped above -- confirm
        # the kernels accept the narrower dtype intentionally.
        self.opp_tapped = cp.zeros((num_envs, 16), dtype=cp.int8)
        self.opp_live = cp.zeros((num_envs, 50), dtype=cp.int32)
        self.opp_scores = cp.zeros(num_envs, dtype=cp.int32)
        self.opp_global_ctx = cp.zeros((num_envs, 128), dtype=cp.int32)
        self.opp_hand = cp.zeros((num_envs, 60), dtype=cp.int32)
        self.opp_deck = cp.zeros((num_envs, 60), dtype=cp.int32)
        self.opp_trash = cp.zeros((num_envs, 60), dtype=cp.int32)
        # =========================================================
        # TRACKING STATE
        # =========================================================
        # Previous-step snapshots consumed by step_kernel (reward deltas),
        # plus per-episode accumulators surfaced through `infos`.
        self.prev_scores = cp.zeros(num_envs, dtype=cp.int32)
        self.prev_opp_scores = cp.zeros(num_envs, dtype=cp.int32)
        self.prev_phases = cp.zeros(num_envs, dtype=cp.int32)
        self.episode_returns = cp.zeros(num_envs, dtype=cp.float32)
        self.episode_lengths = cp.zeros(num_envs, dtype=cp.int32)
        # =========================================================
        # OBSERVATION MODE
        # =========================================================
        # OBS_MODE selects the encoder kernel and the flat observation width.
        self.obs_mode = os.getenv("OBS_MODE", "STANDARD")
        if self.obs_mode == "COMPRESSED":
            self.obs_dim = 512
        elif self.obs_mode == "IMAX":
            self.obs_dim = 8192
        elif self.obs_mode == "ATTENTION":
            self.obs_dim = 2240
        else:
            self.obs_dim = 2304  # STANDARD
        print(f" [VectorEnvGPU] Observation Mode: {self.obs_mode} ({self.obs_dim}-dim)")
        self.batch_obs = cp.zeros((num_envs, self.obs_dim), dtype=cp.float32)
        # Snapshot of the final observation of finished episodes, exposed as
        # infos["terminal_observation"] in step().
        self.terminal_obs_buffer = cp.zeros((num_envs, self.obs_dim), dtype=cp.float32)
        # Rewards and Dones (persistent buffers overwritten by step_kernel)
        self.rewards = cp.zeros(num_envs, dtype=cp.float32)
        self.dones = cp.zeros(num_envs, dtype=cp.bool_)
        self.term_scores_agent = cp.zeros(num_envs, dtype=cp.int32)
        self.term_scores_opp = cp.zeros(num_envs, dtype=cp.int32)
        # =========================================================
        # GAME CONFIG
        # =========================================================
        self.scenario_reward_scale = float(os.getenv("SCENARIO_REWARD_SCALE", "1.0"))
        if os.getenv("USE_SCENARIOS", "0") == "1" and self.scenario_reward_scale != 1.0:
            print(f" [VectorEnvGPU] Scenario Reward Scale: {self.scenario_reward_scale}")
        # Packed float config consumed by step_kernel; slots 6-9 are unused here.
        self.game_config = cp.zeros(10, dtype=cp.float32)
        self.game_config[0] = float(os.getenv("GAME_TURN_LIMIT", "100"))
        self.game_config[1] = float(os.getenv("GAME_STEP_LIMIT", "1000"))
        self.game_config[2] = float(os.getenv("GAME_REWARD_WIN", "100.0"))
        self.game_config[3] = float(os.getenv("GAME_REWARD_LOSE", "-100.0"))
        self.game_config[4] = float(os.getenv("GAME_REWARD_SCORE_SCALE", "50.0"))
        self.game_config[5] = float(os.getenv("GAME_REWARD_TURN_PENALTY", "-0.05"))
        # =========================================================
        # GPU RNG
        # =========================================================
        # One xoroshiro128p state per environment for device-side randomness.
        if HAS_CUDA:
            self.rng_states = create_xoroshiro128p_states(num_envs, seed=seed)
        else:
            self.rng_states = None
        # =========================================================
        # KERNEL CONFIGURATION
        # =========================================================
        self.threads_per_block = 128
        # Ceil-divide so every environment is covered by a thread.
        self.blocks_per_grid = (num_envs + self.threads_per_block - 1) // self.threads_per_block
        # =========================================================
        # LOAD DATA
        # =========================================================
        self._load_bytecode()
        self._load_card_stats()
        self._load_deck_pool()
        # Memory stats
        if HAS_CUDA:
            mempool = cp.get_default_memory_pool()
            used_mb = mempool.used_bytes() / 1024 / 1024
            print(f" [VectorEnvGPU] GPU VRAM used: {used_mb:.2f} MB")

    def _load_bytecode(self):
        """Load compiled ability bytecode (data/cards_numba.json) to GPU.

        Builds two device arrays:
          - bytecode_map: (n_abilities + 1, 128, 4) int32 instruction rows;
            row 0 is reserved as the "no ability" sentinel.
          - bytecode_index: (2000, 8) int32 mapping (card_id, ability_id)
            to a row of bytecode_map (0 = no ability).
        Falls back to empty tables if the file is missing or malformed.
        """
        host_map = np.zeros((100, 128, 4), dtype=np.int32)
        host_idx = np.zeros((2000, 8), dtype=np.int32)
        try:
            with open("data/cards_numba.json", "r") as f:
                raw_map = json.load(f)
            max_cards = 2000
            max_abilities = 8
            max_len = 128
            unique_entries = len(raw_map)
            # Re-allocate now that the real entry count is known.
            host_map = np.zeros((unique_entries + 1, max_len, 4), dtype=np.int32)
            host_idx = np.full((max_cards, max_abilities), 0, dtype=np.int32)
            idx_counter = 1  # row 0 stays the "no ability" sentinel
            for key, bc_list in raw_map.items():
                # Keys have the form "<card_id>_<ability_id>".
                cid, aid = map(int, key.split("_"))
                if cid < max_cards and aid < max_abilities:
                    bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
                    length = min(bc_arr.shape[0], max_len)  # truncate oversized programs
                    host_map[idx_counter, :length] = bc_arr[:length]
                    host_idx[cid, aid] = idx_counter
                    idx_counter += 1
            print(f" [VectorEnvGPU] Loaded {unique_entries} compiled abilities.")
        except FileNotFoundError:
            print(" [VectorEnvGPU] Warning: cards_numba.json not found.")
        except Exception as e:
            print(f" [VectorEnvGPU] Warning: Failed to load bytecode: {e}")
        self.bytecode_map = cp.asarray(host_map)
        self.bytecode_index = cp.asarray(host_idx)

    def _load_card_stats(self):
        """Load the per-card stats table (data/cards_compiled.json) to GPU.

        card_stats is (2000, 80) int32. Columns written here:
          0 cost, 1 blades, 2 total hearts, 10 type (1=Member, 2=Live),
          11 trait bitmask, 12.. hearts/requirements (up to 7 entries),
          38 live score. Other columns are left at zero.
        """
        host_stats = np.zeros((2000, 80), dtype=np.int32)
        try:
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db = json.load(f)
            count = 0
            if "member_db" in db:
                for cid_str, card in db["member_db"].items():
                    cid = int(cid_str)
                    if cid < 2000:
                        host_stats[cid, 0] = card.get("cost", 0)
                        host_stats[cid, 1] = card.get("blades", 0)
                        host_stats[cid, 2] = sum(card.get("hearts", []))
                        host_stats[cid, 10] = 1  # Type: Member
                        # Hearts breakdown (first 7 entries)
                        h_arr = card.get("hearts", [])
                        for r_idx in range(min(len(h_arr), 7)):
                            host_stats[cid, 12 + r_idx] = h_arr[r_idx]
                        # Traits folded into a bitmask (group id mod 20)
                        mask = 0
                        for g in card.get("groups", []):
                            try:
                                mask |= 1 << (int(g) % 20)
                            except:  # noqa: E722 -- best-effort: skip non-numeric group ids
                                pass
                        host_stats[cid, 11] = mask
                        count += 1
            if "live_db" in db:
                for cid_str, card in db["live_db"].items():
                    cid = int(cid_str)
                    if cid < 2000:
                        host_stats[cid, 10] = 2  # Type: Live
                        reqs = card.get("required_hearts", [])
                        for r_idx in range(min(len(reqs), 7)):
                            host_stats[cid, 12 + r_idx] = reqs[r_idx]
                        host_stats[cid, 38] = card.get("score", 0)
                        count += 1
            print(f" [VectorEnvGPU] Loaded stats for {count} cards.")
        except Exception as e:
            print(f" [VectorEnvGPU] Warning: Failed to load card stats: {e}")
        self.card_stats = cp.asarray(host_stats)

    def _load_deck_pool(self):
        """Load the verified card pool used for on-device deck generation.

        Maps card_no strings from data/verified_card_pool.json to integer
        card ids via data/cards_compiled.json, then uploads the member/live
        id pools to GPU. Any failure falls back to the sentinel pools
        [1] / [999] so deck generation always has something to draw from.
        """
        ability_member_ids = []
        ability_live_ids = []
        try:
            with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
                verified_data = json.load(f)
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)
            # card_no -> integer card id lookup tables
            member_no_map = {}
            live_no_map = {}
            for cid, cdata in db_data.get("member_db", {}).items():
                member_no_map[cdata["card_no"]] = int(cid)
            for cid, cdata in db_data.get("live_db", {}).items():
                live_no_map[cdata["card_no"]] = int(cid)
            if isinstance(verified_data, list):
                # Legacy format: a flat list of card_no strings.
                for v_no in verified_data:
                    if v_no in member_no_map:
                        ability_member_ids.append(member_no_map[v_no])
                    elif v_no in live_no_map:
                        ability_live_ids.append(live_no_map[v_no])
            else:
                # Dict format: explicit member/live lists with vanilla fallbacks.
                source_members = verified_data.get("verified_abilities", []) + verified_data.get("members", [])
                for v_no in source_members:
                    if v_no in member_no_map:
                        ability_member_ids.append(member_no_map[v_no])
                source_lives = verified_data.get("verified_lives", []) + verified_data.get("lives", [])
                for v_no in source_lives:
                    if v_no in live_no_map:
                        ability_live_ids.append(live_no_map[v_no])
                if not ability_member_ids:
                    for v_no in verified_data.get("vanilla_members", []):
                        if v_no in member_no_map:
                            ability_member_ids.append(member_no_map[v_no])
                if not ability_live_ids:
                    for v_no in verified_data.get("vanilla_lives", []):
                        if v_no in live_no_map:
                            ability_live_ids.append(live_no_map[v_no])
            # Last-resort sentinels so the pools are never empty.
            if not ability_member_ids:
                ability_member_ids = [1]
            if not ability_live_ids:
                ability_live_ids = [999]
            print(f" [VectorEnvGPU] Deck Pool: {len(ability_member_ids)} members, {len(ability_live_ids)} lives")
        except Exception as e:
            print(f" [VectorEnvGPU] Deck Load Error: {e}")
            ability_member_ids = [1]
            ability_live_ids = [999]
        self.ability_member_ids = cp.array(ability_member_ids, dtype=cp.int32)
        self.ability_live_ids = cp.array(ability_live_ids, dtype=cp.int32)

    # =========================================================
    # PYTORCH INTERFACE
    # =========================================================
    def get_observations_tensor(self):
        """Return observations as PyTorch CUDA tensor (zero-copy)."""
        import torch
        return torch.as_tensor(self.batch_obs, device="cuda")

    def get_action_masks_tensor(self):
        """Return action masks as PyTorch CUDA tensor."""
        import torch
        masks = self.get_action_masks()
        return torch.as_tensor(masks, device="cuda")

    def get_rewards_tensor(self):
        """Return rewards as PyTorch CUDA tensor."""
        import torch
        return torch.as_tensor(self.rewards, device="cuda")

    def get_dones_tensor(self):
        """Return dones as PyTorch CUDA tensor."""
        import torch
        return torch.as_tensor(self.dones, device="cuda")

    # =========================================================
    # ENVIRONMENT INTERFACE
    # =========================================================
    def reset(self, indices=None):
        """Reset environments (all of them, or only `indices`).

        Launches reset_kernel to re-deal device state, then encodes fresh
        observations for every env (the encode kernels always cover the
        full 0..num_envs range).

        Args:
            indices: Optional sequence of env indices to reset; None = all.

        Returns:
            The persistent (num_envs, obs_dim) CuPy observation buffer.
        """
        if not HAS_CUDA:
            # CPU fallback: zero the principal buffers and return stale obs.
            self.batch_stage.fill(-1)
            self.batch_scores.fill(0)
            self.batch_global_ctx.fill(0)
            self.batch_hand.fill(0)
            self.batch_deck.fill(0)
            return self.batch_obs
        from ai.cuda_kernels import encode_observations_attention_kernel, encode_observations_kernel, reset_kernel
        if indices is None:
            indices_gpu = cp.arange(self.num_envs, dtype=cp.int32)
        else:
            indices_gpu = cp.array(indices, dtype=cp.int32)
        # One thread per env being reset (ceil-divide for the grid).
        blocks = (len(indices_gpu) + self.threads_per_block - 1) // self.threads_per_block
        reset_kernel[blocks, self.threads_per_block](
            indices_gpu,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.batch_trash,
            self.batch_opp_history,
            self.opp_stage,
            self.opp_energy_vec,
            self.opp_energy_count,
            self.opp_tapped,
            self.opp_live,
            self.opp_scores,
            self.opp_global_ctx,
            self.opp_hand,
            self.opp_deck,
            self.opp_trash,
            self.ability_member_ids,
            self.ability_live_ids,
            self.rng_states,
            self.force_start_order,
            self.batch_obs,
            self.card_stats,
        )
        # Encode initial observations (all envs, regardless of `indices`).
        if self.obs_mode == "ATTENTION":
            encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
                self.num_envs,
                self.batch_hand,
                self.batch_stage,
                self.batch_energy_count,
                self.batch_tapped,
                self.batch_scores,
                self.opp_scores,
                self.opp_stage,
                self.opp_tapped,
                self.card_stats,
                self.batch_global_ctx,
                self.batch_live,
                self.batch_opp_history,
                self.opp_global_ctx,
                1,
                self.batch_obs,
            )
        else:
            encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
                self.num_envs,
                self.batch_hand,
                self.batch_stage,
                self.batch_energy_count,
                self.batch_tapped,
                self.batch_scores,
                self.opp_scores,
                self.opp_stage,
                self.opp_tapped,
                self.card_stats,
                self.batch_global_ctx,
                self.batch_live,
                1,
                self.batch_obs,
            )
        # Reset the episode-tracking accumulators for the affected envs.
        if indices is None:
            self.prev_scores.fill(0)
            self.prev_opp_scores.fill(0)
            self.episode_returns.fill(0)
            self.episode_lengths.fill(0)
        else:
            self.prev_scores[indices_gpu] = 0
            self.prev_opp_scores[indices_gpu] = 0
            self.episode_returns[indices_gpu] = 0
            self.episode_lengths[indices_gpu] = 0
        return self.batch_obs

    def step(self, actions):
        """
        Step all environments by one action, auto-resetting finished ones.

        Args:
            actions: (num_envs,) action ids as a CuPy array, NumPy array,
                or PyTorch tensor.

        Returns:
            (obs, rewards, dones, infos) -- obs/rewards/dones are the
            persistent CuPy buffers (overwritten every step); infos is a
            list of per-env dicts, non-empty only for envs that finished
            this step (terminal_observation, episode r/l, final scores).
        """
        if not HAS_CUDA:
            # Fallback: no simulation; return the persistent buffers as-is.
            return self.batch_obs, self.rewards, self.dones, [{}] * self.num_envs
        import torch
        from ai.cuda_kernels import (
            encode_observations_attention_kernel,
            encode_observations_kernel,
            reset_kernel,
            step_kernel,
        )
        # Convert actions to a CuPy int32 array if needed.
        # NOTE(review): the torch path round-trips through host memory via
        # .cpu().numpy(); a device-to-device handoff would avoid that copy
        # for CUDA tensors -- confirm and optimize if it shows in profiles.
        if isinstance(actions, torch.Tensor):
            actions_gpu = cp.asarray(actions.cpu().numpy(), dtype=cp.int32)
        elif isinstance(actions, np.ndarray):
            actions_gpu = cp.asarray(actions, dtype=cp.int32)
        else:
            actions_gpu = actions
        # 1. Step kernel: applies actions, advances opponents, writes
        #    rewards/dones and terminal score buffers in place.
        step_kernel[self.blocks_per_grid, self.threads_per_block](
            self.num_envs,
            actions_gpu,
            self.batch_hand,
            self.batch_deck,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.opp_hand,
            self.opp_deck,
            self.opp_stage,
            self.opp_energy_vec,
            self.opp_energy_count,
            self.opp_tapped,
            self.opp_live,
            self.opp_scores,
            self.opp_global_ctx,
            self.card_stats,
            self.bytecode_map,
            self.bytecode_index,
            self.batch_obs,
            self.rewards,
            self.dones,
            self.prev_scores,
            self.prev_opp_scores,
            self.prev_phases,
            self.terminal_obs_buffer,
            self.batch_trash,
            self.opp_trash,
            self.batch_opp_history,
            self.term_scores_agent,
            self.term_scores_opp,
            self.ability_member_ids,
            self.ability_live_ids,
            self.rng_states,
            self.game_config,
            self.opp_mode,
            self.force_start_order,
        )
        # Apply scenario reward scaling (only when scenarios are enabled).
        if self.scenario_reward_scale != 1.0 and os.getenv("USE_SCENARIOS", "0") == "1":
            self.rewards *= self.scenario_reward_scale
        # 2. Update episodic returns/lengths (vectorized on GPU).
        self.episode_returns += self.rewards
        self.episode_lengths += 1
        # 3. Handle auto-reset. Only the done flags cross to the host here.
        dones_cpu = cp.asnumpy(self.dones)
        # Pre-allocate infos list; entries stay empty for running envs.
        infos = [{} for _ in range(self.num_envs)]
        if np.any(dones_cpu):
            done_indices = np.where(dones_cpu)[0]
            done_indices_gpu = cp.array(done_indices, dtype=cp.int32)
            # A. Encode the CURRENT state for ALL envs. step_kernel does not
            #    write batch_obs itself, so this single launch produces the
            #    next-state obs for running envs and the terminal obs for
            #    the finished ones.
            turn_num = 1  # Dummy, kernels use ctx
            if self.obs_mode == "ATTENTION":
                encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
                    self.num_envs,
                    self.batch_hand,
                    self.batch_stage,
                    self.batch_energy_count,
                    self.batch_tapped,
                    self.batch_scores,
                    self.opp_scores,
                    self.opp_stage,
                    self.opp_tapped,
                    self.card_stats,
                    self.batch_global_ctx,
                    self.batch_live,
                    self.batch_opp_history,
                    self.opp_global_ctx,
                    turn_num,
                    self.batch_obs,
                )
            else:
                encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
                    self.num_envs,
                    self.batch_hand,
                    self.batch_stage,
                    self.batch_energy_count,
                    self.batch_tapped,
                    self.batch_scores,
                    self.opp_scores,
                    self.opp_stage,
                    self.opp_tapped,
                    self.card_stats,
                    self.batch_global_ctx,
                    self.batch_live,
                    turn_num,
                    self.batch_obs,
                )
            # B. Snapshot terminal observations (device-to-device fancy copy).
            self.terminal_obs_buffer[done_indices_gpu] = self.batch_obs[done_indices_gpu]
            # C. Fetch terminal metrics in bulk (one D2H transfer per array).
            final_returns = cp.asnumpy(self.episode_returns[done_indices_gpu])
            final_lengths = cp.asnumpy(self.episode_lengths[done_indices_gpu])
            term_obs_cpu = cp.asnumpy(self.terminal_obs_buffer[done_indices_gpu])
            term_scores_ag = cp.asnumpy(self.term_scores_agent[done_indices_gpu])
            term_scores_op = cp.asnumpy(self.term_scores_opp[done_indices_gpu])
            # D. Populate infos (CPU loop, but only over the finished subset).
            for k, idx in enumerate(done_indices):
                infos[idx] = {
                    "terminal_observation": term_obs_cpu[k],
                    "episode": {"r": float(final_returns[k]), "l": int(final_lengths[k])},
                    "terminal_score_agent": int(term_scores_ag[k]),
                    "terminal_score_opp": int(term_scores_op[k]),
                }
            # E. Reset the finished envs on device.
            self.episode_returns[done_indices_gpu] = 0
            self.episode_lengths[done_indices_gpu] = 0
            blocks_reset = (len(done_indices) + self.threads_per_block - 1) // self.threads_per_block
            reset_kernel[blocks_reset, self.threads_per_block](
                done_indices_gpu,
                self.batch_stage,
                self.batch_energy_vec,
                self.batch_energy_count,
                self.batch_continuous_vec,
                self.batch_continuous_ptr,
                self.batch_tapped,
                self.batch_live,
                self.batch_scores,
                self.batch_flat_ctx,
                self.batch_global_ctx,
                self.batch_hand,
                self.batch_deck,
                self.batch_trash,
                self.batch_opp_history,
                self.opp_stage,
                self.opp_energy_vec,
                self.opp_energy_count,
                self.opp_tapped,
                self.opp_live,
                self.opp_scores,
                self.opp_global_ctx,
                self.opp_hand,
                self.opp_deck,
                self.opp_trash,
                self.ability_member_ids,
                self.ability_live_ids,
                self.rng_states,
                self.force_start_order,
                self.batch_obs,
                self.card_stats,
            )
            # F. Re-encode so the freshly reset envs report their INITIAL
            #    observation (reset_kernel does not encode). The encode
            #    kernels have no index-list variant, so this covers all envs;
            #    redundant for running envs but correct, and still cheap
            #    relative to any host-side alternative.
            if self.obs_mode == "ATTENTION":
                encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
                    self.num_envs,
                    self.batch_hand,
                    self.batch_stage,
                    self.batch_energy_count,
                    self.batch_tapped,
                    self.batch_scores,
                    self.opp_scores,
                    self.opp_stage,
                    self.opp_tapped,
                    self.card_stats,
                    self.batch_global_ctx,
                    self.batch_live,
                    self.batch_opp_history,
                    self.opp_global_ctx,
                    turn_num,
                    self.batch_obs,
                )
            else:
                encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
                    self.num_envs,
                    self.batch_hand,
                    self.batch_stage,
                    self.batch_energy_count,
                    self.batch_tapped,
                    self.batch_scores,
                    self.opp_scores,
                    self.opp_stage,
                    self.opp_tapped,
                    self.card_stats,
                    self.batch_global_ctx,
                    self.batch_live,
                    turn_num,
                    self.batch_obs,
                )
        else:
            # No resets needed: a single encode yields next-state observations.
            turn_num = 1
            if self.obs_mode == "ATTENTION":
                encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
                    self.num_envs,
                    self.batch_hand,
                    self.batch_stage,
                    self.batch_energy_count,
                    self.batch_tapped,
                    self.batch_scores,
                    self.opp_scores,
                    self.opp_stage,
                    self.opp_tapped,
                    self.card_stats,
                    self.batch_global_ctx,
                    self.batch_live,
                    self.batch_opp_history,
                    self.opp_global_ctx,
                    turn_num,
                    self.batch_obs,
                )
            else:
                encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
                    self.num_envs,
                    self.batch_hand,
                    self.batch_stage,
                    self.batch_energy_count,
                    self.batch_tapped,
                    self.batch_scores,
                    self.opp_scores,
                    self.opp_stage,
                    self.opp_tapped,
                    self.card_stats,
                    self.batch_global_ctx,
                    self.batch_live,
                    turn_num,
                    self.batch_obs,
                )
        return self.batch_obs, self.rewards, self.dones, infos

    def get_observations(self):
        """Return observation buffer (CuPy array)."""
        return self.batch_obs

    def get_action_masks(self):
        """Compute and return per-env legal-action masks.

        Returns a (num_envs, 2000) bool CuPy array; the 2000-wide action
        space matches the card_stats row count. On CPU fallback every
        action is reported legal.
        """
        if not HAS_CUDA:
            return cp.ones((self.num_envs, 2000), dtype=cp.bool_)
        from ai.cuda_kernels import compute_action_masks_kernel
        masks = cp.zeros((self.num_envs, 2000), dtype=cp.bool_)
        compute_action_masks_kernel[self.blocks_per_grid, self.threads_per_block](
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_tapped,
            self.batch_global_ctx,
            self.batch_live,
            self.card_stats,
            masks,
        )
        return masks
| # ============================================================================ | |
| # BENCHMARK | |
| # ============================================================================ | |
def benchmark_gpu_env(num_envs=4096, steps=1000):
    """Measure raw environment throughput and return steps per second."""
    print("\n=== GPU Environment Benchmark ===")
    print(f"Environments: {num_envs}")
    print(f"Steps: {steps}")
    env = VectorEnvGPU(num_envs=num_envs)
    env.reset()

    def _drive(iterations):
        # A fresh zero-action buffer per iteration, matching the allocation
        # cost a real training loop would pay each step.
        for _ in range(iterations):
            env.step(cp.zeros(num_envs, dtype=cp.int32))

    # Warmup pass (kernel JIT compilation, memory-pool growth).
    _drive(10)
    if HAS_CUDA:
        cuda.synchronize()

    # Timed pass.
    started = time.time()
    _drive(steps)
    if HAS_CUDA:
        cuda.synchronize()
    elapsed = time.time() - started

    total = num_envs * steps
    throughput = total / elapsed
    print("\nResults:")
    print(f" Total Steps: {total:,}")
    print(f" Time: {elapsed:.2f}s")
    print(f" Throughput: {throughput:,.0f} steps/sec")
    return throughput
if __name__ == "__main__":
    # Smoke test: one reset and one step on a small batch, then a benchmark.
    env = VectorEnvGPU(num_envs=128)
    obs = env.reset()
    print(f"Observation shape: {obs.shape}")
    zero_actions = cp.zeros(128, dtype=cp.int32)
    obs, rewards, dones, infos = env.step(zero_actions)
    print(f"Step completed. Rewards shape: {rewards.shape}")
    benchmark_gpu_env(num_envs=1024, steps=100)