trioskosmos commited on
Commit
250ac83
·
verified ·
1 Parent(s): dc702ee

Upload ai/environments/vector_env_gpu.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/environments/vector_env_gpu.py +891 -0
ai/environments/vector_env_gpu.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPU-Native Vectorized Game Environment.
3
+
4
+ This module provides VectorEnvGPU - a GPU-resident implementation using CuPy
5
+ and Numba CUDA for maximum throughput. All game state arrays live in GPU VRAM,
6
+ eliminating PCI-E transfer overhead during RL training.
7
+
8
+ Usage:
9
+ Set USE_GPU_ENV=1 to enable GPU environment in training.
10
+ """
11
+
12
+ import json
13
+ import os
14
+ import time
15
+
16
+ import numpy as np
17
+
18
+ # CUDA detection
19
+ HAS_CUDA = False
20
+ try:
21
+ import cupy as cp
22
+ from numba import cuda
23
+
24
+ if cuda.is_available():
25
+ HAS_CUDA = True
26
+ from numba.cuda.random import create_xoroshiro128p_states
27
+ except ImportError:
28
+ pass
29
+
30
+ # Mock objects for CPU fallback
31
+ if not HAS_CUDA:
32
+
33
+ class MockCP:
34
+ int32 = np.int32
35
+ int8 = np.int8
36
+ float32 = np.float32
37
+ bool_ = np.bool_
38
+
39
+ def full(self, shape, val, dtype=None):
40
+ return np.full(shape, val, dtype=dtype)
41
+
42
+ def zeros(self, shape, dtype=None):
43
+ return np.zeros(shape, dtype=dtype)
44
+
45
+ def ones(self, shape, dtype=None):
46
+ return np.ones(shape, dtype=dtype)
47
+
48
+ def asnumpy(self, arr):
49
+ return np.array(arr)
50
+
51
+ def array(self, arr, dtype=None):
52
+ return np.array(arr, dtype=dtype)
53
+
54
+ def asarray(self, arr, dtype=None):
55
+ return np.asarray(arr, dtype=dtype)
56
+
57
+ def arange(self, n, dtype=None):
58
+ return np.arange(n, dtype=dtype)
59
+
60
+ def get_default_memory_pool(self):
61
+ class MockPool:
62
+ def used_bytes(self):
63
+ return 0
64
+
65
+ return MockPool()
66
+
67
+ cp = MockCP()
68
+
69
+ class MockCudaMod:
70
+ def to_device(self, arr):
71
+ return arr
72
+
73
+ def device_array(self, shape, dtype=None):
74
+ return np.zeros(shape, dtype=dtype)
75
+
76
+ def synchronize(self):
77
+ pass
78
+
79
+ def jit(self, *args, **kwargs):
80
+ return lambda x: x
81
+
82
+ def grid(self, x):
83
+ return 0
84
+
85
+ cuda = MockCudaMod()
86
+
87
+ def create_xoroshiro128p_states(n, seed):
88
+ return None
89
+
90
+
91
+ class VectorEnvGPU:
92
+ """
93
+ GPU-Resident Vectorized Game Environment.
94
+
95
+ All state arrays are CuPy arrays in GPU VRAM.
96
+ Observations and actions are passed as GPU tensors with zero-copy.
97
+
98
+ Args:
99
+ num_envs: Number of parallel environments
100
+ opp_mode: Opponent mode (0=Heuristic, 1=Random)
101
+ force_start_order: -1=Random, 0=P1, 1=P2
102
+ """
103
+
104
+ def __init__(self, num_envs: int = 4096, opp_mode: int = 0, force_start_order: int = -1, seed: int = 42):
105
+ self.num_envs = num_envs
106
+ self.opp_mode = opp_mode # 0=Heuristic, 1=Random, 2=Solitaire
107
+ self.force_start_order = force_start_order
108
+ self.seed = seed
109
+
110
+ print(f" [VectorEnvGPU] Initializing {num_envs} environments. CUDA: {HAS_CUDA}")
111
+
112
+ # =========================================================
113
+ # AGENT STATE (GPU-Resident)
114
+ # =========================================================
115
+ self.batch_stage = cp.full((num_envs, 3), -1, dtype=cp.int32)
116
+ self.batch_energy_vec = cp.zeros((num_envs, 3, 32), dtype=cp.int32)
117
+ self.batch_energy_count = cp.zeros((num_envs, 3), dtype=cp.int32)
118
+ self.batch_continuous_vec = cp.zeros((num_envs, 32, 10), dtype=cp.int32)
119
+ self.batch_continuous_ptr = cp.zeros(num_envs, dtype=cp.int32)
120
+ self.batch_tapped = cp.zeros((num_envs, 16), dtype=cp.int32)
121
+ self.batch_live = cp.zeros((num_envs, 50), dtype=cp.int32)
122
+ self.batch_opp_tapped = cp.zeros((num_envs, 16), dtype=cp.int32)
123
+ self.batch_scores = cp.zeros(num_envs, dtype=cp.int32)
124
+
125
+ self.batch_flat_ctx = cp.zeros((num_envs, 64), dtype=cp.int32)
126
+ self.batch_global_ctx = cp.zeros((num_envs, 128), dtype=cp.int32)
127
+
128
+ self.batch_hand = cp.zeros((num_envs, 60), dtype=cp.int32)
129
+ self.batch_deck = cp.zeros((num_envs, 60), dtype=cp.int32)
130
+ self.batch_trash = cp.zeros((num_envs, 60), dtype=cp.int32)
131
+ self.batch_opp_history = cp.zeros((num_envs, 6), dtype=cp.int32)
132
+
133
+ # =========================================================
134
+ # OPPONENT STATE (GPU-Resident)
135
+ # =========================================================
136
+ self.opp_stage = cp.full((num_envs, 3), -1, dtype=cp.int32)
137
+ self.opp_energy_vec = cp.zeros((num_envs, 3, 32), dtype=cp.int32)
138
+ self.opp_energy_count = cp.zeros((num_envs, 3), dtype=cp.int32)
139
+ self.opp_tapped = cp.zeros((num_envs, 16), dtype=cp.int8)
140
+ self.opp_live = cp.zeros((num_envs, 50), dtype=cp.int32)
141
+ self.opp_scores = cp.zeros(num_envs, dtype=cp.int32)
142
+ self.opp_global_ctx = cp.zeros((num_envs, 128), dtype=cp.int32)
143
+ self.opp_hand = cp.zeros((num_envs, 60), dtype=cp.int32)
144
+ self.opp_deck = cp.zeros((num_envs, 60), dtype=cp.int32)
145
+ self.opp_trash = cp.zeros((num_envs, 60), dtype=cp.int32)
146
+
147
+ # =========================================================
148
+ # TRACKING STATE
149
+ # =========================================================
150
+ self.prev_scores = cp.zeros(num_envs, dtype=cp.int32)
151
+ self.prev_opp_scores = cp.zeros(num_envs, dtype=cp.int32)
152
+ self.prev_phases = cp.zeros(num_envs, dtype=cp.int32)
153
+ self.episode_returns = cp.zeros(num_envs, dtype=cp.float32)
154
+ self.episode_lengths = cp.zeros(num_envs, dtype=cp.int32)
155
+
156
+ # =========================================================
157
+ # OBSERVATION MODE
158
+ # =========================================================
159
+ self.obs_mode = os.getenv("OBS_MODE", "STANDARD")
160
+ if self.obs_mode == "COMPRESSED":
161
+ self.obs_dim = 512
162
+ elif self.obs_mode == "IMAX":
163
+ self.obs_dim = 8192
164
+ elif self.obs_mode == "ATTENTION":
165
+ self.obs_dim = 2240
166
+ else:
167
+ self.obs_dim = 2304
168
+ print(f" [VectorEnvGPU] Observation Mode: {self.obs_mode} ({self.obs_dim}-dim)")
169
+
170
+ self.batch_obs = cp.zeros((num_envs, self.obs_dim), dtype=cp.float32)
171
+ self.terminal_obs_buffer = cp.zeros((num_envs, self.obs_dim), dtype=cp.float32)
172
+
173
+ # Rewards and Dones
174
+ self.rewards = cp.zeros(num_envs, dtype=cp.float32)
175
+ self.dones = cp.zeros(num_envs, dtype=cp.bool_)
176
+ self.term_scores_agent = cp.zeros(num_envs, dtype=cp.int32)
177
+ self.term_scores_opp = cp.zeros(num_envs, dtype=cp.int32)
178
+
179
+ # =========================================================
180
+ # GAME CONFIG
181
+ # =========================================================
182
+ self.scenario_reward_scale = float(os.getenv("SCENARIO_REWARD_SCALE", "1.0"))
183
+ if os.getenv("USE_SCENARIOS", "0") == "1" and self.scenario_reward_scale != 1.0:
184
+ print(f" [VectorEnvGPU] Scenario Reward Scale: {self.scenario_reward_scale}")
185
+
186
+ self.game_config = cp.zeros(10, dtype=cp.float32)
187
+ self.game_config[0] = float(os.getenv("GAME_TURN_LIMIT", "100"))
188
+ self.game_config[1] = float(os.getenv("GAME_STEP_LIMIT", "1000"))
189
+ self.game_config[2] = float(os.getenv("GAME_REWARD_WIN", "100.0"))
190
+ self.game_config[3] = float(os.getenv("GAME_REWARD_LOSE", "-100.0"))
191
+ self.game_config[4] = float(os.getenv("GAME_REWARD_SCORE_SCALE", "50.0"))
192
+ self.game_config[5] = float(os.getenv("GAME_REWARD_TURN_PENALTY", "-0.05"))
193
+
194
+ # =========================================================
195
+ # GPU RNG
196
+ # =========================================================
197
+ if HAS_CUDA:
198
+ self.rng_states = create_xoroshiro128p_states(num_envs, seed=seed)
199
+ else:
200
+ self.rng_states = None
201
+
202
+ # =========================================================
203
+ # KERNEL CONFIGURATION
204
+ # =========================================================
205
+ self.threads_per_block = 128
206
+ self.blocks_per_grid = (num_envs + self.threads_per_block - 1) // self.threads_per_block
207
+
208
+ # =========================================================
209
+ # LOAD DATA
210
+ # =========================================================
211
+ self._load_bytecode()
212
+ self._load_card_stats()
213
+ self._load_deck_pool()
214
+
215
+ # Memory stats
216
+ if HAS_CUDA:
217
+ mempool = cp.get_default_memory_pool()
218
+ used_mb = mempool.used_bytes() / 1024 / 1024
219
+ print(f" [VectorEnvGPU] GPU VRAM used: {used_mb:.2f} MB")
220
+
221
+ def _load_bytecode(self):
222
+ """Load compiled bytecode to GPU."""
223
+ host_map = np.zeros((100, 128, 4), dtype=np.int32)
224
+ host_idx = np.zeros((2000, 8), dtype=np.int32)
225
+
226
+ try:
227
+ with open("data/cards_numba.json", "r") as f:
228
+ raw_map = json.load(f)
229
+
230
+ max_cards = 2000
231
+ max_abilities = 8
232
+ max_len = 128
233
+
234
+ unique_entries = len(raw_map)
235
+ host_map = np.zeros((unique_entries + 1, max_len, 4), dtype=np.int32)
236
+ host_idx = np.full((max_cards, max_abilities), 0, dtype=np.int32)
237
+
238
+ idx_counter = 1
239
+ for key, bc_list in raw_map.items():
240
+ cid, aid = map(int, key.split("_"))
241
+ if cid < max_cards and aid < max_abilities:
242
+ bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
243
+ length = min(bc_arr.shape[0], max_len)
244
+ host_map[idx_counter, :length] = bc_arr[:length]
245
+ host_idx[cid, aid] = idx_counter
246
+ idx_counter += 1
247
+
248
+ print(f" [VectorEnvGPU] Loaded {unique_entries} compiled abilities.")
249
+ except FileNotFoundError:
250
+ print(" [VectorEnvGPU] Warning: cards_numba.json not found.")
251
+ except Exception as e:
252
+ print(f" [VectorEnvGPU] Warning: Failed to load bytecode: {e}")
253
+
254
+ self.bytecode_map = cp.asarray(host_map)
255
+ self.bytecode_index = cp.asarray(host_idx)
256
+
257
+ def _load_card_stats(self):
258
+ """Load card statistics to GPU."""
259
+ host_stats = np.zeros((2000, 80), dtype=np.int32)
260
+
261
+ try:
262
+ with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
263
+ db = json.load(f)
264
+
265
+ count = 0
266
+ if "member_db" in db:
267
+ for cid_str, card in db["member_db"].items():
268
+ cid = int(cid_str)
269
+ if cid < 2000:
270
+ host_stats[cid, 0] = card.get("cost", 0)
271
+ host_stats[cid, 1] = card.get("blades", 0)
272
+ host_stats[cid, 2] = sum(card.get("hearts", []))
273
+ host_stats[cid, 10] = 1 # Type: Member
274
+
275
+ # Hearts breakdown
276
+ h_arr = card.get("hearts", [])
277
+ for r_idx in range(min(len(h_arr), 7)):
278
+ host_stats[cid, 12 + r_idx] = h_arr[r_idx]
279
+
280
+ # Traits
281
+ mask = 0
282
+ for g in card.get("groups", []):
283
+ try:
284
+ mask |= 1 << (int(g) % 20)
285
+ except:
286
+ pass
287
+ host_stats[cid, 11] = mask
288
+ count += 1
289
+
290
+ if "live_db" in db:
291
+ for cid_str, card in db["live_db"].items():
292
+ cid = int(cid_str)
293
+ if cid < 2000:
294
+ host_stats[cid, 10] = 2 # Type: Live
295
+ reqs = card.get("required_hearts", [])
296
+ for r_idx in range(min(len(reqs), 7)):
297
+ host_stats[cid, 12 + r_idx] = reqs[r_idx]
298
+ host_stats[cid, 38] = card.get("score", 0)
299
+ count += 1
300
+
301
+ print(f" [VectorEnvGPU] Loaded stats for {count} cards.")
302
+ except Exception as e:
303
+ print(f" [VectorEnvGPU] Warning: Failed to load card stats: {e}")
304
+
305
+ self.card_stats = cp.asarray(host_stats)
306
+
307
+ def _load_deck_pool(self):
308
+ """Load verified card pool for deck generation."""
309
+ ability_member_ids = []
310
+ ability_live_ids = []
311
+
312
+ try:
313
+ with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
314
+ verified_data = json.load(f)
315
+
316
+ with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
317
+ db_data = json.load(f)
318
+
319
+ member_no_map = {}
320
+ live_no_map = {}
321
+ for cid, cdata in db_data.get("member_db", {}).items():
322
+ member_no_map[cdata["card_no"]] = int(cid)
323
+ for cid, cdata in db_data.get("live_db", {}).items():
324
+ live_no_map[cdata["card_no"]] = int(cid)
325
+
326
+ if isinstance(verified_data, list):
327
+ for v_no in verified_data:
328
+ if v_no in member_no_map:
329
+ ability_member_ids.append(member_no_map[v_no])
330
+ elif v_no in live_no_map:
331
+ ability_live_ids.append(live_no_map[v_no])
332
+ else:
333
+ source_members = verified_data.get("verified_abilities", []) + verified_data.get("members", [])
334
+ for v_no in source_members:
335
+ if v_no in member_no_map:
336
+ ability_member_ids.append(member_no_map[v_no])
337
+
338
+ source_lives = verified_data.get("verified_lives", []) + verified_data.get("lives", [])
339
+ for v_no in source_lives:
340
+ if v_no in live_no_map:
341
+ ability_live_ids.append(live_no_map[v_no])
342
+
343
+ if not ability_member_ids:
344
+ for v_no in verified_data.get("vanilla_members", []):
345
+ if v_no in member_no_map:
346
+ ability_member_ids.append(member_no_map[v_no])
347
+ if not ability_live_ids:
348
+ for v_no in verified_data.get("vanilla_lives", []):
349
+ if v_no in live_no_map:
350
+ ability_live_ids.append(live_no_map[v_no])
351
+
352
+ if not ability_member_ids:
353
+ ability_member_ids = [1]
354
+ if not ability_live_ids:
355
+ ability_live_ids = [999]
356
+
357
+ print(f" [VectorEnvGPU] Deck Pool: {len(ability_member_ids)} members, {len(ability_live_ids)} lives")
358
+ except Exception as e:
359
+ print(f" [VectorEnvGPU] Deck Load Error: {e}")
360
+ ability_member_ids = [1]
361
+ ability_live_ids = [999]
362
+
363
+ self.ability_member_ids = cp.array(ability_member_ids, dtype=cp.int32)
364
+ self.ability_live_ids = cp.array(ability_live_ids, dtype=cp.int32)
365
+
366
+ # =========================================================
367
+ # PYTORCH INTERFACE
368
+ # =========================================================
369
+
370
+ def get_observations_tensor(self):
371
+ """Return observations as PyTorch CUDA tensor (zero-copy)."""
372
+ import torch
373
+
374
+ return torch.as_tensor(self.batch_obs, device="cuda")
375
+
376
+ def get_action_masks_tensor(self):
377
+ """Return action masks as PyTorch CUDA tensor."""
378
+ import torch
379
+
380
+ masks = self.get_action_masks()
381
+ return torch.as_tensor(masks, device="cuda")
382
+
383
+ def get_rewards_tensor(self):
384
+ """Return rewards as PyTorch CUDA tensor."""
385
+ import torch
386
+
387
+ return torch.as_tensor(self.rewards, device="cuda")
388
+
389
+ def get_dones_tensor(self):
390
+ """Return dones as PyTorch CUDA tensor."""
391
+ import torch
392
+
393
+ return torch.as_tensor(self.dones, device="cuda")
394
+
395
+ # =========================================================
396
+ # ENVIRONMENT INTERFACE
397
+ # =========================================================
398
+
399
+ def reset(self, indices=None):
400
+ """Reset environments."""
401
+ if not HAS_CUDA:
402
+ # CPU fallback
403
+ self.batch_stage.fill(-1)
404
+ self.batch_scores.fill(0)
405
+ self.batch_global_ctx.fill(0)
406
+ self.batch_hand.fill(0)
407
+ self.batch_deck.fill(0)
408
+ return self.batch_obs
409
+
410
+ from ai.cuda_kernels import encode_observations_attention_kernel, encode_observations_kernel, reset_kernel
411
+
412
+ if indices is None:
413
+ indices_gpu = cp.arange(self.num_envs, dtype=cp.int32)
414
+ else:
415
+ indices_gpu = cp.array(indices, dtype=cp.int32)
416
+
417
+ blocks = (len(indices_gpu) + self.threads_per_block - 1) // self.threads_per_block
418
+
419
+ reset_kernel[blocks, self.threads_per_block](
420
+ indices_gpu,
421
+ self.batch_stage,
422
+ self.batch_energy_vec,
423
+ self.batch_energy_count,
424
+ self.batch_continuous_vec,
425
+ self.batch_continuous_ptr,
426
+ self.batch_tapped,
427
+ self.batch_live,
428
+ self.batch_scores,
429
+ self.batch_flat_ctx,
430
+ self.batch_global_ctx,
431
+ self.batch_hand,
432
+ self.batch_deck,
433
+ self.batch_trash,
434
+ self.batch_opp_history,
435
+ self.opp_stage,
436
+ self.opp_energy_vec,
437
+ self.opp_energy_count,
438
+ self.opp_tapped,
439
+ self.opp_live,
440
+ self.opp_scores,
441
+ self.opp_global_ctx,
442
+ self.opp_hand,
443
+ self.opp_deck,
444
+ self.opp_trash,
445
+ self.ability_member_ids,
446
+ self.ability_live_ids,
447
+ self.rng_states,
448
+ self.force_start_order,
449
+ self.batch_obs,
450
+ self.card_stats,
451
+ )
452
+
453
+ # Encode initial observations
454
+ if self.obs_mode == "ATTENTION":
455
+ encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
456
+ self.num_envs,
457
+ self.batch_hand,
458
+ self.batch_stage,
459
+ self.batch_energy_count,
460
+ self.batch_tapped,
461
+ self.batch_scores,
462
+ self.opp_scores,
463
+ self.opp_stage,
464
+ self.opp_tapped,
465
+ self.card_stats,
466
+ self.batch_global_ctx,
467
+ self.batch_live,
468
+ self.batch_opp_history,
469
+ self.opp_global_ctx,
470
+ 1,
471
+ self.batch_obs,
472
+ )
473
+ else:
474
+ encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
475
+ self.num_envs,
476
+ self.batch_hand,
477
+ self.batch_stage,
478
+ self.batch_energy_count,
479
+ self.batch_tapped,
480
+ self.batch_scores,
481
+ self.opp_scores,
482
+ self.opp_stage,
483
+ self.opp_tapped,
484
+ self.card_stats,
485
+ self.batch_global_ctx,
486
+ self.batch_live,
487
+ 1,
488
+ self.batch_obs,
489
+ )
490
+
491
+ # Reset tracking
492
+ if indices is None:
493
+ self.prev_scores.fill(0)
494
+ self.prev_opp_scores.fill(0)
495
+ self.episode_returns.fill(0)
496
+ self.episode_lengths.fill(0)
497
+ else:
498
+ self.prev_scores[indices_gpu] = 0
499
+ self.prev_opp_scores[indices_gpu] = 0
500
+ self.episode_returns[indices_gpu] = 0
501
+ self.episode_lengths[indices_gpu] = 0
502
+
503
+ return self.batch_obs
504
+
505
+ def step(self, actions):
506
+ """
507
+ Step all environments.
508
+
509
+ Args:
510
+ actions: CuPy array or PyTorch tensor of actions
511
+
512
+ Returns:
513
+ obs, rewards, dones, infos
514
+ """
515
+ if not HAS_CUDA:
516
+ # Fallback
517
+ return self.batch_obs, self.rewards, self.dones, [{}] * self.num_envs
518
+
519
+ import torch
520
+ from ai.cuda_kernels import (
521
+ encode_observations_attention_kernel,
522
+ encode_observations_kernel,
523
+ reset_kernel,
524
+ step_kernel,
525
+ )
526
+
527
+ # Convert to CuPy if needed
528
+ if isinstance(actions, torch.Tensor):
529
+ actions_gpu = cp.asarray(actions.cpu().numpy(), dtype=cp.int32)
530
+ elif isinstance(actions, np.ndarray):
531
+ actions_gpu = cp.asarray(actions, dtype=cp.int32)
532
+ else:
533
+ actions_gpu = actions
534
+
535
+ # 1. Step kernel
536
+ step_kernel[self.blocks_per_grid, self.threads_per_block](
537
+ self.num_envs,
538
+ actions_gpu,
539
+ self.batch_hand,
540
+ self.batch_deck,
541
+ self.batch_stage,
542
+ self.batch_energy_vec,
543
+ self.batch_energy_count,
544
+ self.batch_continuous_vec,
545
+ self.batch_continuous_ptr,
546
+ self.batch_tapped,
547
+ self.batch_live,
548
+ self.batch_scores,
549
+ self.batch_flat_ctx,
550
+ self.batch_global_ctx,
551
+ self.opp_hand,
552
+ self.opp_deck,
553
+ self.opp_stage,
554
+ self.opp_energy_vec,
555
+ self.opp_energy_count,
556
+ self.opp_tapped,
557
+ self.opp_live,
558
+ self.opp_scores,
559
+ self.opp_global_ctx,
560
+ self.card_stats,
561
+ self.bytecode_map,
562
+ self.bytecode_index,
563
+ self.batch_obs,
564
+ self.rewards,
565
+ self.dones,
566
+ self.prev_scores,
567
+ self.prev_opp_scores,
568
+ self.prev_phases,
569
+ self.terminal_obs_buffer,
570
+ self.batch_trash,
571
+ self.opp_trash,
572
+ self.batch_opp_history,
573
+ self.term_scores_agent,
574
+ self.term_scores_opp,
575
+ self.ability_member_ids,
576
+ self.ability_live_ids,
577
+ self.rng_states,
578
+ self.game_config,
579
+ self.opp_mode,
580
+ self.force_start_order,
581
+ )
582
+
583
+ # Apply Scenario Reward Scaling
584
+ if self.scenario_reward_scale != 1.0 and os.getenv("USE_SCENARIOS", "0") == "1":
585
+ self.rewards *= self.scenario_reward_scale
586
+
587
+ # 2. Update Episodic Returns/Lengths (Vectorized GPU)
588
+ self.episode_returns += self.rewards
589
+ self.episode_lengths += 1
590
+
591
+ # 3. Handle Auto-Reset (High Performance)
592
+ dones_cpu = cp.asnumpy(self.dones)
593
+
594
+ # Pre-allocate infos list (reused or created)
595
+ infos = [{} for _ in range(self.num_envs)]
596
+
597
+ if np.any(dones_cpu):
598
+ done_indices = np.where(dones_cpu)[0]
599
+ done_indices_gpu = cp.array(done_indices, dtype=cp.int32)
600
+
601
+ # A. Capture Terminal Observations (from UNRESET state)
602
+ # Efficient Device-to-Device copy
603
+ # NOTE: step_kernel leaves env in finished state, so batch_obs has terminal state.
604
+ # We must encode it?
605
+ # Actually, step_kernel calls encode at end? No, step_kernel does NOT encode obs in my implementation.
606
+ # I removed the Python-side encode calls from previous impl?
607
+ # Wait, step_kernel logic in my head vs file.
608
+ # In ai/cuda_kernels.py, step_kernel does NOT call encode.
609
+ # So batch_obs is STALE (from previous step)!
610
+ # We MUST encode the terminal state first.
611
+
612
+ # Encode CURRENT state (Terminal) for ALL envs? Or just done?
613
+ # Usually we encode all envs at end of step.
614
+ # BUT we need to reset done envs and encode AGAIN.
615
+
616
+ # OPTIMIZATION:
617
+ # 1. Encode ALL envs (Next state for running, Terminal for done).
618
+ turn_num = 1 # Dummy, kernels use ctx
619
+ if self.obs_mode == "ATTENTION":
620
+ encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
621
+ self.num_envs,
622
+ self.batch_hand,
623
+ self.batch_stage,
624
+ self.batch_energy_count,
625
+ self.batch_tapped,
626
+ self.batch_scores,
627
+ self.opp_scores,
628
+ self.opp_stage,
629
+ self.opp_tapped,
630
+ self.card_stats,
631
+ self.batch_global_ctx,
632
+ self.batch_live,
633
+ self.batch_opp_history,
634
+ self.opp_global_ctx,
635
+ turn_num,
636
+ self.batch_obs,
637
+ )
638
+ else:
639
+ encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
640
+ self.num_envs,
641
+ self.batch_hand,
642
+ self.batch_stage,
643
+ self.batch_energy_count,
644
+ self.batch_tapped,
645
+ self.batch_scores,
646
+ self.opp_scores,
647
+ self.opp_stage,
648
+ self.opp_tapped,
649
+ self.card_stats,
650
+ self.batch_global_ctx,
651
+ self.batch_live,
652
+ turn_num,
653
+ self.batch_obs,
654
+ )
655
+
656
+ # 2. For Done Envs: Copy encoded terminal state to buffer
657
+ # We can use fancy indexing copy on GPU
658
+ self.terminal_obs_buffer[done_indices_gpu] = self.batch_obs[done_indices_gpu]
659
+
660
+ # 3. Fetch Terminal Info Metrics (Bulk D2H)
661
+ final_returns = cp.asnumpy(self.episode_returns[done_indices_gpu])
662
+ final_lengths = cp.asnumpy(self.episode_lengths[done_indices_gpu])
663
+ term_obs_cpu = cp.asnumpy(self.terminal_obs_buffer[done_indices_gpu])
664
+ term_scores_ag = cp.asnumpy(self.term_scores_agent[done_indices_gpu])
665
+ term_scores_op = cp.asnumpy(self.term_scores_opp[done_indices_gpu])
666
+
667
+ # 4. Populate Infos (CPU Loop over SMALL subset)
668
+ for k, idx in enumerate(done_indices):
669
+ infos[idx] = {
670
+ "terminal_observation": term_obs_cpu[k],
671
+ "episode": {"r": float(final_returns[k]), "l": int(final_lengths[k])},
672
+ "terminal_score_agent": int(term_scores_ag[k]),
673
+ "terminal_score_opp": int(term_scores_op[k]),
674
+ }
675
+
676
+ # 5. Reset Done Envs
677
+ # Reset accumulators
678
+ self.episode_returns[done_indices_gpu] = 0
679
+ self.episode_lengths[done_indices_gpu] = 0
680
+
681
+ # Launch Reset Kernel
682
+ blocks_reset = (len(done_indices) + self.threads_per_block - 1) // self.threads_per_block
683
+ reset_kernel[blocks_reset, self.threads_per_block](
684
+ done_indices_gpu,
685
+ self.batch_stage,
686
+ self.batch_energy_vec,
687
+ self.batch_energy_count,
688
+ self.batch_continuous_vec,
689
+ self.batch_continuous_ptr,
690
+ self.batch_tapped,
691
+ self.batch_live,
692
+ self.batch_scores,
693
+ self.batch_flat_ctx,
694
+ self.batch_global_ctx,
695
+ self.batch_hand,
696
+ self.batch_deck,
697
+ self.batch_trash,
698
+ self.batch_opp_history,
699
+ self.opp_stage,
700
+ self.opp_energy_vec,
701
+ self.opp_energy_count,
702
+ self.opp_tapped,
703
+ self.opp_live,
704
+ self.opp_scores,
705
+ self.opp_global_ctx,
706
+ self.opp_hand,
707
+ self.opp_deck,
708
+ self.opp_trash,
709
+ self.ability_member_ids,
710
+ self.ability_live_ids,
711
+ self.rng_states,
712
+ self.force_start_order,
713
+ self.batch_obs,
714
+ self.card_stats,
715
+ )
716
+
717
+ # 6. Re-Encode Reset Envs (to get initial state)
718
+ # We assume reset_kernel updates state but NOT obs.
719
+ # We need to re-run encode kernel ONLY for done indices?
720
+ # Or run global encode again? Global is waste.
721
+ # We need an encode kernel that takes indices.
722
+ # The current kernel takes `num_envs` and assumes `0..N`.
723
+ # We can reuse the global kernel if we are clever or modify it.
724
+ # Modifying kernel to accept indices is best.
725
+ # However, for now, to save complexity, we can re-run global encode.
726
+ # It's redundant for non-done envs but correct.
727
+ # Better: Reset modifies batch_obs directly? No, reset_kernel doesn't encode.
728
+
729
+ # Let's re-run global encode. It's fast (GPU) compared to CPU loop.
730
+ if self.obs_mode == "ATTENTION":
731
+ encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
732
+ self.num_envs,
733
+ self.batch_hand,
734
+ self.batch_stage,
735
+ self.batch_energy_count,
736
+ self.batch_tapped,
737
+ self.batch_scores,
738
+ self.opp_scores,
739
+ self.opp_stage,
740
+ self.opp_tapped,
741
+ self.card_stats,
742
+ self.batch_global_ctx,
743
+ self.batch_live,
744
+ self.batch_opp_history,
745
+ self.opp_global_ctx,
746
+ turn_num,
747
+ self.batch_obs,
748
+ )
749
+ else:
750
+ encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
751
+ self.num_envs,
752
+ self.batch_hand,
753
+ self.batch_stage,
754
+ self.batch_energy_count,
755
+ self.batch_tapped,
756
+ self.batch_scores,
757
+ self.opp_scores,
758
+ self.opp_stage,
759
+ self.opp_tapped,
760
+ self.card_stats,
761
+ self.batch_global_ctx,
762
+ self.batch_live,
763
+ turn_num,
764
+ self.batch_obs,
765
+ )
766
+
767
+ else:
768
+ # No resets needed. Just encode once to get next states.
769
+ # Encode observations
770
+ turn_num = 1
771
+ if self.obs_mode == "ATTENTION":
772
+ encode_observations_attention_kernel[self.blocks_per_grid, self.threads_per_block](
773
+ self.num_envs,
774
+ self.batch_hand,
775
+ self.batch_stage,
776
+ self.batch_energy_count,
777
+ self.batch_tapped,
778
+ self.batch_scores,
779
+ self.opp_scores,
780
+ self.opp_stage,
781
+ self.opp_tapped,
782
+ self.card_stats,
783
+ self.batch_global_ctx,
784
+ self.batch_live,
785
+ self.batch_opp_history,
786
+ self.opp_global_ctx,
787
+ turn_num,
788
+ self.batch_obs,
789
+ )
790
+ else:
791
+ encode_observations_kernel[self.blocks_per_grid, self.threads_per_block](
792
+ self.num_envs,
793
+ self.batch_hand,
794
+ self.batch_stage,
795
+ self.batch_energy_count,
796
+ self.batch_tapped,
797
+ self.batch_scores,
798
+ self.opp_scores,
799
+ self.opp_stage,
800
+ self.opp_tapped,
801
+ self.card_stats,
802
+ self.batch_global_ctx,
803
+ self.batch_live,
804
+ turn_num,
805
+ self.batch_obs,
806
+ )
807
+
808
+ return self.batch_obs, self.rewards, self.dones, infos
809
+
810
+ def get_observations(self):
811
+ """Return observation buffer (CuPy array)."""
812
+ return self.batch_obs
813
+
814
+ def get_action_masks(self):
815
+ """Compute and return action masks (CuPy array)."""
816
+ if not HAS_CUDA:
817
+ return cp.ones((self.num_envs, 2000), dtype=cp.bool_)
818
+
819
+ from ai.cuda_kernels import compute_action_masks_kernel
820
+
821
+ masks = cp.zeros((self.num_envs, 2000), dtype=cp.bool_)
822
+
823
+ compute_action_masks_kernel[self.blocks_per_grid, self.threads_per_block](
824
+ self.num_envs,
825
+ self.batch_hand,
826
+ self.batch_stage,
827
+ self.batch_tapped,
828
+ self.batch_global_ctx,
829
+ self.batch_live,
830
+ self.card_stats,
831
+ masks,
832
+ )
833
+
834
+ return masks
835
+
836
+
837
+ # ============================================================================
838
+ # BENCHMARK
839
+ # ============================================================================
840
+
841
+
842
+ def benchmark_gpu_env(num_envs=4096, steps=1000):
843
+ """Benchmark GPU environment throughput."""
844
+ print("\n=== GPU Environment Benchmark ===")
845
+ print(f"Environments: {num_envs}")
846
+ print(f"Steps: {steps}")
847
+
848
+ env = VectorEnvGPU(num_envs=num_envs)
849
+ env.reset()
850
+
851
+ # Warmup
852
+ for _ in range(10):
853
+ actions = cp.zeros(num_envs, dtype=cp.int32)
854
+ env.step(actions)
855
+
856
+ if HAS_CUDA:
857
+ cuda.synchronize()
858
+
859
+ # Benchmark
860
+ start = time.time()
861
+ for _ in range(steps):
862
+ actions = cp.zeros(num_envs, dtype=cp.int32) # Pass action
863
+ env.step(actions)
864
+
865
+ if HAS_CUDA:
866
+ cuda.synchronize()
867
+
868
+ elapsed = time.time() - start
869
+ total_steps = num_envs * steps
870
+ sps = total_steps / elapsed
871
+
872
+ print("\nResults:")
873
+ print(f" Total Steps: {total_steps:,}")
874
+ print(f" Time: {elapsed:.2f}s")
875
+ print(f" Throughput: {sps:,.0f} steps/sec")
876
+
877
+ return sps
878
+
879
+
880
+ if __name__ == "__main__":
881
+ # Quick test
882
+ env = VectorEnvGPU(num_envs=128)
883
+ obs = env.reset()
884
+ print(f"Observation shape: {obs.shape}")
885
+
886
+ actions = cp.zeros(128, dtype=cp.int32)
887
+ obs, rewards, dones, infos = env.step(actions)
888
+ print(f"Step completed. Rewards shape: {rewards.shape}")
889
+
890
+ # Benchmark
891
+ benchmark_gpu_env(num_envs=1024, steps=100)