# Source: LovecaSim / ai / environments / vector_env_gpu.py
# Uploaded by trioskosmos with huggingface_hub (commit 250ac83, verified)
"""
GPU-Native Vectorized Game Environment.
This module provides VectorEnvGPU - a GPU-resident implementation using CuPy
and Numba CUDA for maximum throughput. All game state arrays live in GPU VRAM,
eliminating PCI-E transfer overhead during RL training.
Usage:
Set USE_GPU_ENV=1 to enable GPU environment in training.
"""
import json
import os
import time
import numpy as np
# CUDA detection
# HAS_CUDA becomes True only when CuPy and Numba import cleanly, Numba reports
# a usable device, and the device RNG helper is importable.
HAS_CUDA = False
try:
    import cupy as cp
    from numba import cuda

    if cuda.is_available():
        # Import the RNG helper before flipping the flag so a failed import
        # cannot leave HAS_CUDA=True with the helper undefined.
        from numba.cuda.random import create_xoroshiro128p_states

        HAS_CUDA = True
except ImportError:
    pass

# Mock objects for CPU fallback
# When CUDA is unavailable, `cp`, `cuda` and `create_xoroshiro128p_states`
# are rebound to NumPy-backed stand-ins so the module still imports and the
# environment can be constructed.
if not HAS_CUDA:

    class MockCP:
        """NumPy-backed stand-in for the subset of the CuPy API used here."""

        int32 = np.int32
        int8 = np.int8
        float32 = np.float32
        bool_ = np.bool_

        def full(self, shape, val, dtype=None):
            return np.full(shape, val, dtype=dtype)

        def zeros(self, shape, dtype=None):
            return np.zeros(shape, dtype=dtype)

        def ones(self, shape, dtype=None):
            return np.ones(shape, dtype=dtype)

        def asnumpy(self, arr):
            return np.array(arr)

        def array(self, arr, dtype=None):
            return np.array(arr, dtype=dtype)

        def asarray(self, arr, dtype=None):
            return np.asarray(arr, dtype=dtype)

        def arange(self, n, dtype=None):
            return np.arange(n, dtype=dtype)

        def get_default_memory_pool(self):
            class MockPool:
                def used_bytes(self):
                    return 0

            return MockPool()

    cp = MockCP()

    class MockCudaMod:
        """No-op stand-in for the subset of ``numba.cuda`` used here."""

        def to_device(self, arr):
            return arr

        def device_array(self, shape, dtype=None):
            return np.zeros(shape, dtype=dtype)

        def synchronize(self):
            pass

        def jit(self, *args, **kwargs):
            # Bare usage (`@cuda.jit` / `cuda.jit(func)`): hand the function
            # back unchanged instead of replacing it with an identity lambda.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]
            # Parameterised usage (`@cuda.jit(device=True)`, signature
            # strings, ...): return a pass-through decorator.
            return lambda x: x

        def grid(self, x):
            # A 1-D grid index is a scalar; higher ranks are index tuples.
            return 0 if x == 1 else (0,) * x

    cuda = MockCudaMod()

    def create_xoroshiro128p_states(n, seed):
        """CPU fallback: there are no device RNG states to create."""
        return None
class VectorEnvGPU:
"""
GPU-Resident Vectorized Game Environment.
All state arrays are CuPy arrays in GPU VRAM.
Observations and actions are passed as GPU tensors with zero-copy.
Args:
num_envs: Number of parallel environments
opp_mode: Opponent mode (0=Heuristic, 1=Random, 2=Solitaire)
force_start_order: -1=Random, 0=P1, 1=P2
seed: Seed used to create the per-environment GPU RNG states
"""
def __init__(self, num_envs: int = 4096, opp_mode: int = 0, force_start_order: int = -1, seed: int = 42):
    """Allocate all state buffers, read configuration from the environment, and load card data.

    Args:
        num_envs: Number of parallel environments (leading dimension of every buffer).
        opp_mode: Opponent mode (0=Heuristic, 1=Random, 2=Solitaire).
        force_start_order: -1=Random, 0=P1, 1=P2.
        seed: Seed passed to the RNG-state factory when CUDA is available.
    """
    self.num_envs = num_envs
    self.opp_mode = opp_mode  # 0=Heuristic, 1=Random, 2=Solitaire
    self.force_start_order = force_start_order
    self.seed = seed
    print(f" [VectorEnvGPU] Initializing {num_envs} environments. CUDA: {HAS_CUDA}")

    def per_env(*tail, dtype=cp.int32):
        # Zero-filled buffer whose leading dimension is the environment index.
        return cp.zeros((num_envs, *tail), dtype=dtype)

    # ---------------------------------------------------------
    # Agent-side state
    # ---------------------------------------------------------
    self.batch_stage = cp.full((num_envs, 3), -1, dtype=cp.int32)
    self.batch_energy_vec = per_env(3, 32)
    self.batch_energy_count = per_env(3)
    self.batch_continuous_vec = per_env(32, 10)
    self.batch_continuous_ptr = per_env()
    self.batch_tapped = per_env(16)
    self.batch_live = per_env(50)
    self.batch_opp_tapped = per_env(16)
    self.batch_scores = per_env()
    self.batch_flat_ctx = per_env(64)
    self.batch_global_ctx = per_env(128)
    self.batch_hand = per_env(60)
    self.batch_deck = per_env(60)
    self.batch_trash = per_env(60)
    self.batch_opp_history = per_env(6)

    # ---------------------------------------------------------
    # Opponent-side state
    # ---------------------------------------------------------
    self.opp_stage = cp.full((num_envs, 3), -1, dtype=cp.int32)
    self.opp_energy_vec = per_env(3, 32)
    self.opp_energy_count = per_env(3)
    self.opp_tapped = per_env(16, dtype=cp.int8)
    self.opp_live = per_env(50)
    self.opp_scores = per_env()
    self.opp_global_ctx = per_env(128)
    self.opp_hand = per_env(60)
    self.opp_deck = per_env(60)
    self.opp_trash = per_env(60)

    # ---------------------------------------------------------
    # Per-episode tracking
    # ---------------------------------------------------------
    self.prev_scores = per_env()
    self.prev_opp_scores = per_env()
    self.prev_phases = per_env()
    self.episode_returns = per_env(dtype=cp.float32)
    self.episode_lengths = per_env()

    # ---------------------------------------------------------
    # Observation layout (selected via the OBS_MODE env var)
    # ---------------------------------------------------------
    self.obs_mode = os.getenv("OBS_MODE", "STANDARD")
    obs_dims = {"COMPRESSED": 512, "IMAX": 8192, "ATTENTION": 2240}
    self.obs_dim = obs_dims.get(self.obs_mode, 2304)
    print(f" [VectorEnvGPU] Observation Mode: {self.obs_mode} ({self.obs_dim}-dim)")
    self.batch_obs = per_env(self.obs_dim, dtype=cp.float32)
    self.terminal_obs_buffer = per_env(self.obs_dim, dtype=cp.float32)

    # Step outputs
    self.rewards = per_env(dtype=cp.float32)
    self.dones = per_env(dtype=cp.bool_)
    self.term_scores_agent = per_env()
    self.term_scores_opp = per_env()

    # ---------------------------------------------------------
    # Game configuration (environment variables with defaults)
    # ---------------------------------------------------------
    self.scenario_reward_scale = float(os.getenv("SCENARIO_REWARD_SCALE", "1.0"))
    scenarios_on = os.getenv("USE_SCENARIOS", "0") == "1"
    if scenarios_on and self.scenario_reward_scale != 1.0:
        print(f" [VectorEnvGPU] Scenario Reward Scale: {self.scenario_reward_scale}")

    # Slot order is positional: entry i lands in game_config[i].
    config_slots = (
        ("GAME_TURN_LIMIT", "100"),
        ("GAME_STEP_LIMIT", "1000"),
        ("GAME_REWARD_WIN", "100.0"),
        ("GAME_REWARD_LOSE", "-100.0"),
        ("GAME_REWARD_SCORE_SCALE", "50.0"),
        ("GAME_REWARD_TURN_PENALTY", "-0.05"),
    )
    self.game_config = cp.zeros(10, dtype=cp.float32)
    for slot, (env_var, default) in enumerate(config_slots):
        self.game_config[slot] = float(os.getenv(env_var, default))

    # ---------------------------------------------------------
    # RNG states (only created when CUDA is available)
    # ---------------------------------------------------------
    self.rng_states = create_xoroshiro128p_states(num_envs, seed=seed) if HAS_CUDA else None

    # ---------------------------------------------------------
    # Kernel launch geometry: enough blocks to give each env a thread
    # ---------------------------------------------------------
    self.threads_per_block = 128
    self.blocks_per_grid = (num_envs + self.threads_per_block - 1) // self.threads_per_block

    # ---------------------------------------------------------
    # Static card data
    # ---------------------------------------------------------
    self._load_bytecode()
    self._load_card_stats()
    self._load_deck_pool()

    # Report memory-pool usage
    if HAS_CUDA:
        pool = cp.get_default_memory_pool()
        used_mb = pool.used_bytes() / 1024 / 1024
        print(f" [VectorEnvGPU] GPU VRAM used: {used_mb:.2f} MB")
def _load_bytecode(self):
    """Load compiled ability bytecode from ``data/cards_numba.json``.

    The JSON maps ``"<card_id>_<ability_id>"`` keys to flat integer lists,
    which are reshaped into rows of 4. Two arrays are published on ``self``:

    * ``bytecode_map``: one slot per loaded ability, truncated to 128 rows
      of 4 ints. Slot 0 is left zeroed.
    * ``bytecode_index``: ``(2000, 8)`` table mapping ``[card_id, ability_id]``
      to a slot in ``bytecode_map`` (0 where nothing was loaded).

    Loading is best-effort: on any failure a warning is printed and whatever
    was filled in so far (or the zeroed defaults) is used.
    """
    max_cards = 2000
    max_abilities = 8
    max_len = 128

    # Defaults used when the file is missing or unreadable.
    host_map = np.zeros((100, max_len, 4), dtype=np.int32)
    host_idx = np.zeros((max_cards, max_abilities), dtype=np.int32)
    try:
        # Explicit UTF-8, matching the other loaders in this class, so the
        # result does not depend on the platform's locale encoding.
        with open("data/cards_numba.json", "r", encoding="utf-8") as f:
            raw_map = json.load(f)

        unique_entries = len(raw_map)
        host_map = np.zeros((unique_entries + 1, max_len, 4), dtype=np.int32)
        host_idx = np.zeros((max_cards, max_abilities), dtype=np.int32)

        idx_counter = 1  # slot 0 stays empty so index 0 means "no bytecode"
        for key, bc_list in raw_map.items():
            cid, aid = map(int, key.split("_"))
            # Reject negative ids too: they would otherwise wrap around via
            # negative indexing and overwrite an unrelated table entry.
            if 0 <= cid < max_cards and 0 <= aid < max_abilities:
                bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
                length = min(bc_arr.shape[0], max_len)
                host_map[idx_counter, :length] = bc_arr[:length]
                host_idx[cid, aid] = idx_counter
                idx_counter += 1
        print(f" [VectorEnvGPU] Loaded {unique_entries} compiled abilities.")
    except FileNotFoundError:
        print(" [VectorEnvGPU] Warning: cards_numba.json not found.")
    except Exception as e:
        print(f" [VectorEnvGPU] Warning: Failed to load bytecode: {e}")

    self.bytecode_map = cp.asarray(host_map)
    self.bytecode_index = cp.asarray(host_idx)
def _load_card_stats(self):
    """Load per-card statistics from ``data/cards_compiled.json``.

    Builds a ``(2000, 80)`` int32 table indexed by card id and publishes it
    as ``self.card_stats``. Columns written here:

    * 0: cost, 1: blades, 2: sum of hearts (members only)
    * 10: card type (1 = member, 2 = live)
    * 11: bitmask of group ids, one bit per ``group % 20`` (members only)
    * 12..18: up to 7 heart values (members) or required hearts (lives)
    * 38: score (lives only)

    Loading is best-effort: on failure a warning is printed and the table
    keeps whatever was filled in so far.
    """
    max_cards = 2000
    host_stats = np.zeros((max_cards, 80), dtype=np.int32)
    try:
        with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
            db = json.load(f)

        count = 0
        if "member_db" in db:
            for cid_str, card in db["member_db"].items():
                cid = int(cid_str)
                # Negative ids would wrap around via negative indexing.
                if 0 <= cid < max_cards:
                    hearts = card.get("hearts", [])
                    host_stats[cid, 0] = card.get("cost", 0)
                    host_stats[cid, 1] = card.get("blades", 0)
                    host_stats[cid, 2] = sum(hearts)
                    host_stats[cid, 10] = 1  # Type: Member
                    # Hearts breakdown
                    for r_idx, value in enumerate(hearts[:7]):
                        host_stats[cid, 12 + r_idx] = value
                    # Traits
                    mask = 0
                    for g in card.get("groups", []):
                        try:
                            mask |= 1 << (int(g) % 20)
                        except (TypeError, ValueError, OverflowError):
                            # Group tags that are not integers carry no
                            # trait bit; skip them without hiding other
                            # errors (the original used a bare except).
                            continue
                    host_stats[cid, 11] = mask
                    count += 1

        if "live_db" in db:
            for cid_str, card in db["live_db"].items():
                cid = int(cid_str)
                if 0 <= cid < max_cards:
                    host_stats[cid, 10] = 2  # Type: Live
                    reqs = card.get("required_hearts", [])
                    for r_idx, value in enumerate(reqs[:7]):
                        host_stats[cid, 12 + r_idx] = value
                    host_stats[cid, 38] = card.get("score", 0)
                    count += 1
        print(f" [VectorEnvGPU] Loaded stats for {count} cards.")
    except Exception as e:
        print(f" [VectorEnvGPU] Warning: Failed to load card stats: {e}")

    self.card_stats = cp.asarray(host_stats)
def _load_deck_pool(self):
    """Build the member/live card-id pools used for deck generation.

    Card numbers listed in ``data/verified_card_pool.json`` are translated to
    integer ids through the ``card_no`` fields of ``data/cards_compiled.json``.
    The verified file may be a flat list of card numbers or a dict with
    ``verified_abilities``/``members``/``verified_lives``/``lives`` lists
    (falling back to ``vanilla_members``/``vanilla_lives`` when a pool is
    empty). Empty pools, or any load error, fall back to ``[1]`` for members
    and ``[999]`` for lives. Results are stored as int32 arrays in
    ``self.ability_member_ids`` and ``self.ability_live_ids``.
    """

    def resolve(card_numbers, id_by_number):
        # Keep only numbers present in the lookup, preserving input order.
        return [id_by_number[no] for no in card_numbers if no in id_by_number]

    member_pool = []
    live_pool = []
    try:
        with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
            verified_data = json.load(f)
        with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
            db_data = json.load(f)

        member_no_map = {entry["card_no"]: int(cid) for cid, entry in db_data.get("member_db", {}).items()}
        live_no_map = {entry["card_no"]: int(cid) for cid, entry in db_data.get("live_db", {}).items()}

        if isinstance(verified_data, list):
            # Flat list: a number counts as a member first, as a live only
            # when it is not a known member.
            member_pool = resolve(verified_data, member_no_map)
            live_pool = [
                live_no_map[no] for no in verified_data if no not in member_no_map and no in live_no_map
            ]
        else:
            listed_members = verified_data.get("verified_abilities", []) + verified_data.get("members", [])
            member_pool = resolve(listed_members, member_no_map)
            listed_lives = verified_data.get("verified_lives", []) + verified_data.get("lives", [])
            live_pool = resolve(listed_lives, live_no_map)

            # Nothing verified: try the "vanilla" lists instead.
            if not member_pool:
                member_pool = resolve(verified_data.get("vanilla_members", []), member_no_map)
            if not live_pool:
                live_pool = resolve(verified_data.get("vanilla_lives", []), live_no_map)

        # Last-resort placeholders so the pools are never empty.
        member_pool = member_pool or [1]
        live_pool = live_pool or [999]
        print(f" [VectorEnvGPU] Deck Pool: {len(member_pool)} members, {len(live_pool)} lives")
    except Exception as e:
        print(f" [VectorEnvGPU] Deck Load Error: {e}")
        member_pool = [1]
        live_pool = [999]

    self.ability_member_ids = cp.array(member_pool, dtype=cp.int32)
    self.ability_live_ids = cp.array(live_pool, dtype=cp.int32)
# =========================================================
# PYTORCH INTERFACE
# =========================================================
def get_observations_tensor(self):
    """Expose ``self.batch_obs`` as a PyTorch tensor on the CUDA device.

    ``torch.as_tensor`` is used so that no copy is made when the buffer can
    be shared.
    """
    from torch import as_tensor

    obs_buffer = self.batch_obs
    return as_tensor(obs_buffer, device="cuda")
def get_action_masks_tensor(self):
    """Compute the current action masks and hand them to PyTorch on the CUDA device."""
    from torch import as_tensor

    return as_tensor(self.get_action_masks(), device="cuda")
def get_rewards_tensor(self):
    """Expose ``self.rewards`` as a PyTorch tensor on the CUDA device."""
    from torch import as_tensor

    reward_buffer = self.rewards
    return as_tensor(reward_buffer, device="cuda")
def get_dones_tensor(self):
    """Expose ``self.dones`` as a PyTorch tensor on the CUDA device."""
    from torch import as_tensor

    done_flags = self.dones
    return as_tensor(done_flags, device="cuda")
# =========================================================
# ENVIRONMENT INTERFACE
# =========================================================
def reset(self, indices=None):
    """Reset all environments, or only the ones listed in ``indices``.

    Args:
        indices: Optional sequence/array of environment indices to reset.
            ``None`` resets every environment. An empty selection is a no-op.

    Returns:
        ``self.batch_obs`` (the shared observation buffer, re-encoded for all
        environments on the CUDA path).
    """
    if not HAS_CUDA:
        # CPU fallback: clear the core state buffers. Honour `indices` so a
        # partial reset does not wipe environments that were not requested.
        cleared = (
            (self.batch_stage, -1),
            (self.batch_scores, 0),
            (self.batch_global_ctx, 0),
            (self.batch_hand, 0),
            (self.batch_deck, 0),
        )
        if indices is None:
            for buf, value in cleared:
                buf.fill(value)
        else:
            rows = np.asarray(indices, dtype=np.int32)
            for buf, value in cleared:
                buf[rows] = value
        return self.batch_obs

    from ai.cuda_kernels import encode_observations_attention_kernel, encode_observations_kernel, reset_kernel

    if indices is None:
        indices_gpu = cp.arange(self.num_envs, dtype=cp.int32)
    else:
        indices_gpu = cp.array(indices, dtype=cp.int32)

    if len(indices_gpu) == 0:
        # Nothing to reset; a zero-block kernel launch is not a valid
        # CUDA launch configuration.
        return self.batch_obs

    # One thread per environment being reset.
    blocks = (len(indices_gpu) + self.threads_per_block - 1) // self.threads_per_block
    reset_kernel[blocks, self.threads_per_block](
        indices_gpu,
        self.batch_stage,
        self.batch_energy_vec,
        self.batch_energy_count,
        self.batch_continuous_vec,
        self.batch_continuous_ptr,
        self.batch_tapped,
        self.batch_live,
        self.batch_scores,
        self.batch_flat_ctx,
        self.batch_global_ctx,
        self.batch_hand,
        self.batch_deck,
        self.batch_trash,
        self.batch_opp_history,
        self.opp_stage,
        self.opp_energy_vec,
        self.opp_energy_count,
        self.opp_tapped,
        self.opp_live,
        self.opp_scores,
        self.opp_global_ctx,
        self.opp_hand,
        self.opp_deck,
        self.opp_trash,
        self.ability_member_ids,
        self.ability_live_ids,
        self.rng_states,
        self.force_start_order,
        self.batch_obs,
        self.card_stats,
    )

    # Encode initial observations. The encode kernels cover all
    # environments (they take num_envs, not an index list).
    if self.obs_mode == "ATTENTION":
        encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.opp_scores,
            self.opp_stage,
            self.opp_tapped,
            self.card_stats,
            self.batch_global_ctx,
            self.batch_live,
            self.batch_opp_history,
            self.opp_global_ctx,
            1,
            self.batch_obs,
        )
    else:
        encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.opp_scores,
            self.opp_stage,
            self.opp_tapped,
            self.card_stats,
            self.batch_global_ctx,
            self.batch_live,
            1,
            self.batch_obs,
        )

    # Reset tracking for the affected environments.
    if indices is None:
        self.prev_scores.fill(0)
        self.prev_opp_scores.fill(0)
        self.episode_returns.fill(0)
        self.episode_lengths.fill(0)
    else:
        self.prev_scores[indices_gpu] = 0
        self.prev_opp_scores[indices_gpu] = 0
        self.episode_returns[indices_gpu] = 0
        self.episode_lengths[indices_gpu] = 0
    return self.batch_obs
def _encode_observations(self, turn_num=1):
    """Run the observation-encoding kernel over all environments into ``self.batch_obs``.

    Uses the attention encoder when ``self.obs_mode`` is ``"ATTENTION"`` and
    the standard encoder otherwise.

    Args:
        turn_num: Value forwarded as the kernels' turn argument. ``step``
            passes 1 as a placeholder (the original code notes that the
            kernels read the turn from the context arrays).
    """
    from ai.cuda_kernels import encode_observations_attention_kernel, encode_observations_kernel

    if self.obs_mode == "ATTENTION":
        encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.opp_scores,
            self.opp_stage,
            self.opp_tapped,
            self.card_stats,
            self.batch_global_ctx,
            self.batch_live,
            self.batch_opp_history,
            self.opp_global_ctx,
            turn_num,
            self.batch_obs,
        )
    else:
        encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
            self.num_envs,
            self.batch_hand,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.opp_scores,
            self.opp_stage,
            self.opp_tapped,
            self.card_stats,
            self.batch_global_ctx,
            self.batch_live,
            turn_num,
            self.batch_obs,
        )


def step(self, actions):
    """
    Step all environments, auto-resetting the ones that finished.

    Args:
        actions: CuPy array, NumPy array, or PyTorch tensor of actions
            (one per environment). CUDA torch tensors are consumed in place
            without a host round-trip.

    Returns:
        obs, rewards, dones, infos. ``obs``/``rewards``/``dones`` are the
        environment's own buffers. ``infos`` is a list with one dict per
        environment; for environments that finished this step the dict holds
        ``terminal_observation``, ``episode`` (``r``/``l``),
        ``terminal_score_agent`` and ``terminal_score_opp``.
    """
    if not HAS_CUDA:
        # Fallback. Build a distinct dict per environment: `[{}] * n` would
        # alias one dict n times, so a write to one info would show up in all.
        return self.batch_obs, self.rewards, self.dones, [{} for _ in range(self.num_envs)]

    import torch

    from ai.cuda_kernels import reset_kernel, step_kernel

    # Convert to CuPy if needed
    if isinstance(actions, torch.Tensor):
        if actions.is_cuda:
            # Zero-copy hand-off through __cuda_array_interface__; only a
            # dtype/layout conversion (if needed) allocates, and it stays on
            # the device. Assumes torch and CuPy share the default stream.
            actions_gpu = cp.ascontiguousarray(cp.asarray(actions.detach()), dtype=cp.int32)
        else:
            actions_gpu = cp.asarray(actions.detach().numpy(), dtype=cp.int32)
    elif isinstance(actions, np.ndarray):
        actions_gpu = cp.asarray(actions, dtype=cp.int32)
    else:
        actions_gpu = actions

    # 1. Step kernel
    step_kernel[self.blocks_per_grid, self.threads_per_block](
        self.num_envs,
        actions_gpu,
        self.batch_hand,
        self.batch_deck,
        self.batch_stage,
        self.batch_energy_vec,
        self.batch_energy_count,
        self.batch_continuous_vec,
        self.batch_continuous_ptr,
        self.batch_tapped,
        self.batch_live,
        self.batch_scores,
        self.batch_flat_ctx,
        self.batch_global_ctx,
        self.opp_hand,
        self.opp_deck,
        self.opp_stage,
        self.opp_energy_vec,
        self.opp_energy_count,
        self.opp_tapped,
        self.opp_live,
        self.opp_scores,
        self.opp_global_ctx,
        self.card_stats,
        self.bytecode_map,
        self.bytecode_index,
        self.batch_obs,
        self.rewards,
        self.dones,
        self.prev_scores,
        self.prev_opp_scores,
        self.prev_phases,
        self.terminal_obs_buffer,
        self.batch_trash,
        self.opp_trash,
        self.batch_opp_history,
        self.term_scores_agent,
        self.term_scores_opp,
        self.ability_member_ids,
        self.ability_live_ids,
        self.rng_states,
        self.game_config,
        self.opp_mode,
        self.force_start_order,
    )

    # Apply Scenario Reward Scaling
    if self.scenario_reward_scale != 1.0 and os.getenv("USE_SCENARIOS", "0") == "1":
        self.rewards *= self.scenario_reward_scale

    # 2. Update Episodic Returns/Lengths (Vectorized GPU)
    self.episode_returns += self.rewards
    self.episode_lengths += 1

    # 3. Encode the post-step state for ALL envs. For running envs this is
    # the next observation; for finished envs it is the terminal
    # observation (captured below, before they are reset).
    dones_cpu = cp.asnumpy(self.dones)
    infos = [{} for _ in range(self.num_envs)]
    turn_num = 1  # Dummy, kernels use ctx
    self._encode_observations(turn_num)

    # 4. Handle Auto-Reset
    done_indices = np.flatnonzero(dones_cpu)
    if done_indices.size:
        done_indices_gpu = cp.array(done_indices, dtype=cp.int32)

        # A. Keep the terminal observations (device-to-device gather/scatter).
        terminal_obs = self.batch_obs[done_indices_gpu]
        self.terminal_obs_buffer[done_indices_gpu] = terminal_obs

        # B. Fetch terminal metrics in bulk (device-to-host).
        final_returns = cp.asnumpy(self.episode_returns[done_indices_gpu])
        final_lengths = cp.asnumpy(self.episode_lengths[done_indices_gpu])
        term_obs_cpu = cp.asnumpy(terminal_obs)
        term_scores_ag = cp.asnumpy(self.term_scores_agent[done_indices_gpu])
        term_scores_op = cp.asnumpy(self.term_scores_opp[done_indices_gpu])

        # C. Populate infos (CPU loop over the finished subset only).
        for k, idx in enumerate(done_indices):
            infos[idx] = {
                "terminal_observation": term_obs_cpu[k],
                "episode": {"r": float(final_returns[k]), "l": int(final_lengths[k])},
                "terminal_score_agent": int(term_scores_ag[k]),
                "terminal_score_opp": int(term_scores_op[k]),
            }

        # D. Reset accumulators and game state of the finished envs.
        # NOTE(review): prev_scores / prev_opp_scores are not cleared here,
        # unlike in reset(); presumably step_kernel maintains them — confirm.
        self.episode_returns[done_indices_gpu] = 0
        self.episode_lengths[done_indices_gpu] = 0

        blocks_reset = (len(done_indices) + self.threads_per_block - 1) // self.threads_per_block
        reset_kernel[blocks_reset, self.threads_per_block](
            done_indices_gpu,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.batch_trash,
            self.batch_opp_history,
            self.opp_stage,
            self.opp_energy_vec,
            self.opp_energy_count,
            self.opp_tapped,
            self.opp_live,
            self.opp_scores,
            self.opp_global_ctx,
            self.opp_hand,
            self.opp_deck,
            self.opp_trash,
            self.ability_member_ids,
            self.ability_live_ids,
            self.rng_states,
            self.force_start_order,
            self.batch_obs,
            self.card_stats,
        )

        # E. Re-encode so the reset envs expose their initial observation.
        # The encode kernels take num_envs rather than an index list, so
        # this covers all envs; redundant for the running ones but correct.
        self._encode_observations(turn_num)

    return self.batch_obs, self.rewards, self.dones, infos
def get_observations(self):
"""Return observation buffer (CuPy array)."""
return self.batch_obs
def get_action_masks(self):
    """Build a fresh ``(num_envs, 2000)`` boolean action-mask array.

    Without CUDA every action is reported as allowed. With CUDA the mask
    starts all-False and is filled in by ``compute_action_masks_kernel``.
    """
    mask_shape = (self.num_envs, 2000)
    if not HAS_CUDA:
        return cp.ones(mask_shape, dtype=cp.bool_)

    from ai.cuda_kernels import compute_action_masks_kernel

    masks = cp.zeros(mask_shape, dtype=cp.bool_)
    launch = compute_action_masks_kernel[self.blocks_per_grid, self.threads_per_block]
    launch(
        self.num_envs,
        self.batch_hand,
        self.batch_stage,
        self.batch_tapped,
        self.batch_global_ctx,
        self.batch_live,
        self.card_stats,
        masks,
    )
    return masks
# ============================================================================
# BENCHMARK
# ============================================================================
def benchmark_gpu_env(num_envs=4096, steps=1000):
    """Benchmark GPU environment throughput.

    Creates a ``VectorEnvGPU``, runs 10 warm-up steps, then times ``steps``
    calls to ``env.step`` with an all-zero action array.

    Args:
        num_envs: Number of parallel environments to create.
        steps: Number of timed ``step`` calls.

    Returns:
        Throughput in environment-steps per second (``inf`` if the timed
        section took no measurable time).
    """
    print("\n=== GPU Environment Benchmark ===")
    print(f"Environments: {num_envs}")
    print(f"Steps: {steps}")
    env = VectorEnvGPU(num_envs=num_envs)
    env.reset()

    # Allocate the action buffer once so the timed loop measures stepping,
    # not per-iteration allocation.
    actions = cp.zeros(num_envs, dtype=cp.int32)

    # Warmup
    for _ in range(10):
        env.step(actions)
    if HAS_CUDA:
        cuda.synchronize()

    # Benchmark. perf_counter is monotonic and high-resolution, unlike
    # time.time(), which can be coarse or jump with clock adjustments.
    start = time.perf_counter()
    for _ in range(steps):
        env.step(actions)
    if HAS_CUDA:
        # Wait for queued GPU work so the elapsed time covers execution.
        cuda.synchronize()
    elapsed = time.perf_counter() - start

    total_steps = num_envs * steps
    # Guard against a zero interval (e.g. steps=0 on the CPU fallback).
    sps = total_steps / elapsed if elapsed > 0 else float("inf")
    print("\nResults:")
    print(f" Total Steps: {total_steps:,}")
    print(f" Time: {elapsed:.2f}s")
    print(f" Throughput: {sps:,.0f} steps/sec")
    return sps
if __name__ == "__main__":
    # Smoke test: build a small environment, reset it, and take one step
    # with all-zero actions.
    env = VectorEnvGPU(num_envs=128)
    obs = env.reset()
    print(f"Observation shape: {obs.shape}")

    actions = cp.zeros(128, dtype=cp.int32)
    step_result = env.step(actions)
    obs, rewards, dones, infos = step_result
    print(f"Step completed. Rewards shape: {rewards.shape}")

    # Then a short throughput benchmark.
    benchmark_gpu_env(num_envs=1024, steps=100)