Spaces:
Running
Running
| """ | |
| CUDA Kernels for GPU-Accelerated VectorEnv. | |
| This module contains CUDA kernel implementations for: | |
| - Environment reset | |
| - Game step (integrated with opponent, phases, scoring) | |
| - Observation encoding | |
| - Action mask computation | |
| All kernels are designed for the VectorGameStateGPU class. | |
| """ | |
| import numpy as np | |
try:
    from numba import cuda
    from numba.cuda.random import xoroshiro128p_normal_float32, xoroshiro128p_uniform_float32

    HAS_CUDA = True
except ImportError:
    HAS_CUDA = False

    # Pure-Python stand-ins so this module still imports, and its functions
    # stay callable as ordinary Python, when Numba is not installed.
    class _MockLocal:
        """Stand-in for ``cuda.local``; hands out zero-filled NumPy arrays."""

        def array(self, shape, dtype):
            return np.zeros(shape, dtype=dtype)

    class MockCuda:
        """Minimal replacement for the pieces of ``numba.cuda`` used in this file."""

        # ``cuda.local.array(...)`` is called by the device functions below.
        local = _MockLocal()

        def jit(self, *args, **kwargs):
            """No-op decorator supporting both ``@cuda.jit`` and ``@cuda.jit(...)``."""
            # Bare form: the decorated function itself arrives as the only argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(f):
                return f

            return decorator

        def grid(self, x):
            """Always report thread index 0."""
            return 0

    cuda = MockCuda()

    def xoroshiro128p_uniform_float32(rng, i):
        """Deterministic placeholder for the uniform RNG draw."""
        return 0.5

    def xoroshiro128p_normal_float32(rng, i):
        """Deterministic placeholder for the normal RNG draw."""
        return 0.0
| # ============================================================================ | |
| # CONSTANTS (Must match fast_logic.py) | |
| # ============================================================================ | |
# Slot indices into a player's global-context vector.
#   SC - own score (step_kernel copies batch_scores into it)
#   OS - compared against SC by C_LLD / C_CMP (presumably opponent score - confirm)
#   TR - trash card count
#   HD - hand card count
#   DI - slot 4; step_kernel stores the opponent's HD value here
#   EN - energy count
#   DK - deck card count
#   OT - slot 7; read by C_OPH, step_kernel stores the opponent's TR value here
#   PH - current phase
#   OD - slot 9; step_kernel stores the opponent's DK value here
(SC, OS, TR, HD, DI, EN, DK, OT, PH, OD) = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

# Opcodes (effects) understood by the bytecode interpreter.
O_DRAW = 10      # draw v cards
O_BLADES = 11    # register a continuous effect of kind 1
O_HEARTS = 12    # add v to the bonus and register a continuous effect of kind 2
O_RECOV_L = 13   # clear live-zone slot s
O_BOOST = 14     # add v to the bonus
O_RECOV_M = 15   # untap stage slot s
O_BUFF = 16      # register a continuous effect of kind 8
O_CHARGE = 17    # convert up to v deck cards into energy
O_TAP_O = 18     # tap opponent stage slot s
O_CHOOSE = 19    # no handler in resolve_bytecode_device
O_ADD_H = 20     # no handler in resolve_bytecode_device
# Control flow.
O_JUMP = 100     # unconditional relative jump by v
O_JUMP_F = 101   # relative jump by v when the last condition was false
O_RETURN = 999   # stop interpreting

# Conditions (every opcode >= 200 is treated as a condition).
C_TR1 = 200      # trash count == 1
C_CLR = 202      # no dedicated handler: evaluates to True
C_STG = 203      # at least v occupied stage slots
C_HND = 204      # hand count >= v
C_CTR = 206      # no dedicated handler: evaluates to True
C_LLD = 207      # SC > OS
C_GRP = 208      # no dedicated handler: evaluates to True
C_OPH = 210      # OT >= v (or OT > 0 when v <= 0)
C_ENR = 213      # energy count >= v
C_CMP = 220      # SC >= v (or SC > OS when v <= 0)
# Unique ID (UID) system: the low 20 bits of a UID carry the base card ID.
BASE_ID_MASK = (1 << 20) - 1


def get_base_id_device(uid: int) -> int:
    """Return the base card-definition ID held in the low 20 bits of ``uid``.

    Bits above ``BASE_ID_MASK`` are discarded (presumably they distinguish
    individual copies of a card - confirm against the code that builds UIDs).
    """
    base_id = uid & BASE_ID_MASK
    return base_id
| # ============================================================================ | |
| # DEVICE FUNCTIONS (Callable from kernels) | |
| # ============================================================================ | |
def check_deck_refresh_device(p_deck, p_trash, p_global_ctx, DK_idx, TR_idx):
    """Refill an exhausted deck from the trash pile.

    Does nothing while ``p_global_ctx[DK_idx]`` is positive. Otherwise every
    non-zero entry of ``p_trash`` (60 slots) is moved, in slot order, to the
    front of ``p_deck`` and cleared from the trash. No reordering is applied.

    The deck counter is set to the number of cards moved and the trash
    counter to 0, but only when at least one card was moved.
    """
    # The deck counter still reports cards: nothing to refresh.
    if p_global_ctx[DK_idx] > 0:
        return

    moved = 0
    for slot in range(60):
        card = p_trash[slot]
        if card > 0:
            p_deck[moved] = card
            p_trash[slot] = 0
            moved += 1

    # An empty trash leaves both counters exactly as they were.
    if moved > 0:
        p_global_ctx[DK_idx] = moved
        p_global_ctx[TR_idx] = 0
def move_to_trash_device(card_id, p_trash, p_global_ctx, TR_idx):
    """Store ``card_id`` in the first empty (zero) trash slot.

    Scans the 60 trash slots in order. On success the trash counter at
    ``p_global_ctx[TR_idx]`` is incremented. If no slot is free, nothing is
    written and the counter is left unchanged.
    """
    slot = 0
    while slot < 60:
        if p_trash[slot] == 0:
            p_trash[slot] = card_id
            p_global_ctx[TR_idx] += 1
            return
        slot += 1
def draw_cards_device(count, p_hand, p_deck, p_trash, p_global_ctx):
    """Move up to ``count`` cards from the deck into the hand.

    Before each draw the deck is refilled from the trash if its counter is
    exhausted. Drawing stops early once the deck counter is still <= 0 after
    that refresh. A draw takes the first non-zero deck slot and puts it in the
    first zero hand slot, updating the DK and HD counters. When the hand has
    no free slot, the card stays in the deck and the attempt is still consumed.
    """
    for _ in range(count):
        check_deck_refresh_device(p_deck, p_trash, p_global_ctx, DK, TR)
        if p_global_ctx[DK] <= 0:
            break

        # Locate the first occupied deck slot.
        deck_slot = -1
        for d in range(60):
            if p_deck[d] > 0:
                deck_slot = d
                break
        if deck_slot < 0:
            # The counter claims cards but the array holds none.
            continue
        card = p_deck[deck_slot]

        # Drop it into the first free hand slot, if there is one.
        for h in range(60):
            if p_hand[h] == 0:
                p_hand[h] = card
                p_deck[deck_slot] = 0
                p_global_ctx[DK] -= 1
                p_global_ctx[HD] += 1
                break
def resolve_bytecode_device(
    bytecode,
    flat_ctx,
    global_ctx,
    player_id,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_cont_vec,
    p_cont_ptr,
    p_tapped,
    p_live,
    opp_tapped,
    p_trash,
    bytecode_map,
    bytecode_index,
):
    """
    Interpret one ability program and apply its effects.

    ``bytecode`` is indexed as ``bytecode[row, col]``; columns 0..3 of a row
    are read as ``(op, v, a, s)``. Interpretation ends when the program
    counter leaves the program, on ``O_RETURN``, on a jump whose target is
    out of range, or after 500 executed instructions (loop guard).

    Opcodes >= 200 are conditions: they set the ``cond`` flag and never have
    side effects; unrecognised conditions evaluate to True. Other opcodes are
    effects, executed only while ``cond`` is True. ``cond`` starts as True and
    keeps its value until the next condition opcode.

    ``p_cont_ptr`` is the integer index of the next free continuous-effect
    slot; at most 32 slots are used. ``flat_ctx``, ``player_id``,
    ``p_energy_vec``, ``p_energy_count``, ``bytecode_map`` and
    ``bytecode_index`` are accepted but not used here.

    Returns ``(new_cont_ptr, status, bonus)``; ``status`` is always 0.
    """
    pc = 0
    next_effect = p_cont_ptr
    bonus = 0
    cond = True
    n_instr = bytecode.shape[0]
    executed = 0

    while pc < n_instr and executed < 500:
        executed += 1
        op = bytecode[pc, 0]
        v = bytecode[pc, 1]
        a = bytecode[pc, 2]
        s = bytecode[pc, 3]

        # No-op row.
        if op == 0:
            pc += 1
            continue

        if op == O_RETURN:
            break

        # Relative jumps; O_JUMP_F only jumps when the last condition failed.
        if op == O_JUMP or op == O_JUMP_F:
            if op == O_JUMP_F and cond:
                pc += 1
                continue
            target = pc + v
            if target < 0 or target >= n_instr:
                break
            pc = target
            continue

        if op >= 200:
            # ---- Conditions ----
            if op == C_TR1:
                cond = global_ctx[TR] == 1
            elif op == C_STG:
                occupied = 0
                for j in range(3):
                    if p_stage[j] != -1:
                        occupied += 1
                cond = occupied >= v
            elif op == C_HND:
                cond = global_ctx[HD] >= v
            elif op == C_LLD:
                cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_ENR:
                cond = global_ctx[EN] >= v
            elif op == C_CMP:
                if v > 0:
                    cond = global_ctx[SC] >= v
                else:
                    cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_OPH:
                if v > 0:
                    cond = global_ctx[OT] >= v
                else:
                    cond = global_ctx[OT] > 0
            else:
                cond = True
        elif cond:
            # ---- Effects ----
            if op == O_DRAW:
                draw_cards_device(v, p_hand, p_deck, p_trash, global_ctx)
            elif op == O_CHARGE:
                # Simplified: the charged cards are removed from the deck and
                # only the energy counter grows.
                amount = min(v, global_ctx[DK])
                for _ in range(amount):
                    for d in range(60):
                        if p_deck[d] > 0:
                            p_deck[d] = 0
                            global_ctx[DK] -= 1
                            global_ctx[EN] += 1
                            break
            elif op == O_HEARTS:
                bonus += v
                if next_effect < 32:
                    p_cont_vec[next_effect, 0] = 2
                    p_cont_vec[next_effect, 1] = v
                    p_cont_vec[next_effect, 5] = a
                    p_cont_vec[next_effect, 9] = 1
                    next_effect += 1
            elif op == O_BLADES:
                if next_effect < 32:
                    p_cont_vec[next_effect, 0] = 1
                    p_cont_vec[next_effect, 1] = v
                    p_cont_vec[next_effect, 2] = 4
                    p_cont_vec[next_effect, 3] = s
                    p_cont_vec[next_effect, 9] = 1
                    next_effect += 1
            elif op == O_BUFF:
                if next_effect < 32:
                    p_cont_vec[next_effect, 0] = 8
                    p_cont_vec[next_effect, 1] = v
                    p_cont_vec[next_effect, 2] = s
                    p_cont_vec[next_effect, 9] = 1
                    next_effect += 1
            elif op == O_RECOV_M:
                if 0 <= s < 3:
                    p_tapped[s] = 0
            elif op == O_RECOV_L:
                if 0 <= s < p_live.shape[0]:
                    p_live[s] = 0
            elif op == O_TAP_O:
                if 0 <= s < 3:
                    opp_tapped[s] = 1
            elif op == O_BOOST:
                bonus += v

        # Conditions and effects (executed or skipped) all advance by one row.
        pc += 1

    return next_effect, 0, bonus
def step_player_device(
    act_id,
    player_id,
    rng_state,
    i,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_tapped,
    p_live,
    p_scores,
    p_global_ctx,
    p_trash,
    p_continuous_vec,
    p_continuous_ptr,
    opp_tapped,
    card_stats,
    bytecode_map,
    bytecode_index,
):
    """
    Apply one action for a single player and return the bonus score it earned.

    Action encoding handled here:
      * ``0``        - pass: advance the phase (-1 -> 0 -> 4 -> 8), return 0.
      * ``1..180``   - play hand card ``(act_id - 1) // 3`` into stage slot
                       ``(act_id - 1) % 3``.
      * ``200..202`` - activate the ability of stage slot ``act_id - 200``.
      * ``400..459`` - put hand card ``act_id - 400`` into the first free
                       live-zone slot (of 50).
    Any other action id leaves the state untouched.

    Empty hand slots hold 0 while empty stage slots hold -1, so hand lookups
    test ``> 0`` and stage lookups test ``>= 0``.

    ``p_continuous_ptr`` is a one-element array view; the updated
    continuous-effect pointer is written back to element 0. ``rng_state``,
    ``i`` and ``p_scores`` are accepted but not used here.
    """
    bonus = 0

    if act_id == 0:
        # Pass -> next phase
        phase = p_global_ctx[PH]
        if phase == -1:
            p_global_ctx[PH] = 0
        elif phase == 0:
            p_global_ctx[PH] = 4  # Skip to Main
        elif phase == 4:
            p_global_ctx[PH] = 8  # Performance
        return 0

    # Member Play (1-180); hand_idx is 0..59 by construction.
    if 1 <= act_id <= 180:
        adj = act_id - 1
        hand_idx = adj // 3
        slot = adj % 3
        card_id = p_hand[hand_idx]
        # 0 marks an empty hand slot, so only strictly positive ids are cards.
        if card_id > 0:
            bid = get_base_id_device(card_id)
            if bid < card_stats.shape[0]:
                # Cost, reduced by the cost of a member already in the slot.
                cost = card_stats[bid, 0]
                effective_cost = cost
                prev_cid = p_stage[slot]
                if prev_cid >= 0:
                    prev_bid = get_base_id_device(prev_cid)
                    if prev_bid < card_stats.shape[0]:
                        prev_cost = card_stats[prev_bid, 0]
                        effective_cost = max(0, cost - prev_cost)

                # Pay by tapping owned energy (tapped[3..14]), at most 12.
                # NOTE(review): the play goes through even when fewer than
                # effective_cost energy could be tapped; legality is
                # presumably enforced by the action mask - confirm.
                owned_energy = min(p_global_ctx[EN], 12)
                paid = 0
                if effective_cost > 0:
                    for e_idx in range(owned_energy):
                        if 3 + e_idx < 16:
                            if p_tapped[3 + e_idx] == 0:
                                p_tapped[3 + e_idx] = 1
                                paid += 1
                                if paid >= effective_cost:
                                    break

                # Move to stage; a previous occupant is simply overwritten.
                p_stage[slot] = card_id
                p_hand[hand_idx] = 0
                p_global_ctx[HD] -= 1
                p_global_ctx[51 + slot] = 1  # Mark played

                # Resolve the card's ability program, if it has one.
                if bid < bytecode_index.shape[0]:
                    map_idx = bytecode_index[bid, 0]
                    if map_idx >= 0:
                        flat_ctx = cuda.local.array(64, dtype=np.int32)
                        for j in range(64):
                            flat_ctx[j] = 0
                        new_ptr, _, ab_bonus = resolve_bytecode_device(
                            bytecode_map[map_idx],
                            flat_ctx,
                            p_global_ctx,
                            player_id,
                            p_hand,
                            p_deck,
                            p_stage,
                            p_energy_vec,
                            p_energy_count,
                            p_continuous_vec,
                            p_continuous_ptr[0],
                            p_tapped,
                            p_live,
                            opp_tapped,
                            p_trash,
                            bytecode_map,
                            bytecode_index,
                        )
                        p_continuous_ptr[0] = new_ptr
                        bonus += ab_bonus

    # Activate Ability (200-202)
    elif 200 <= act_id <= 202:
        slot = act_id - 200
        card_id = p_stage[slot]
        if card_id >= 0 and p_tapped[slot] == 0:
            bid = get_base_id_device(card_id)
            if bid < bytecode_index.shape[0]:
                map_idx = bytecode_index[bid, 0]
                if map_idx >= 0:
                    flat_ctx = cuda.local.array(64, dtype=np.int32)
                    for j in range(64):
                        flat_ctx[j] = 0
                    new_ptr, _, ab_bonus = resolve_bytecode_device(
                        bytecode_map[map_idx],
                        flat_ctx,
                        p_global_ctx,
                        player_id,
                        p_hand,
                        p_deck,
                        p_stage,
                        p_energy_vec,
                        p_energy_count,
                        p_continuous_vec,
                        p_continuous_ptr[0],
                        p_tapped,
                        p_live,
                        opp_tapped,
                        p_trash,
                        bytecode_map,
                        bytecode_index,
                    )
                    p_continuous_ptr[0] = new_ptr
                    bonus += ab_bonus
                    # NOTE(review): the member is tapped only when an ability
                    # program actually ran - confirm against fast_logic.py.
                    p_tapped[slot] = 1

    # Set Live Card (400-459); hand_idx is 0..59 by construction.
    elif 400 <= act_id <= 459:
        hand_idx = act_id - 400
        card_id = p_hand[hand_idx]
        if card_id > 0:
            # Find an empty live-zone slot.
            for j in range(50):
                if p_live[j] == 0:
                    p_live[j] = card_id
                    p_hand[hand_idx] = 0
                    p_global_ctx[HD] -= 1
                    break

    return bonus
def resolve_live_device(
    live_id, p_stage, p_live, p_scores, p_global_ctx, p_deck, p_hand, p_trash, card_stats, p_cont_vec, p_cont_ptr
):
    """
    Decide whether a live card succeeds and return its score.

    The live card's seven heart requirements are read from
    ``card_stats[bid, 12..18]`` and its score from ``card_stats[bid, 38]``.
    A card whose requirements sum to <= 0 succeeds automatically. Otherwise
    the same seven columns are summed over the stage members (slots holding
    an id > 0), and the first six columns must each cover the requirement.
    The seventh column takes part in the total but is not compared.

    Returns the score on success and 0 on failure, or for a negative
    ``live_id`` or an out-of-range base id. Only ``live_id``, ``p_stage`` and
    ``card_stats`` are read; the remaining parameters are unused.
    """
    bid = get_base_id_device(live_id)
    if live_id < 0 or bid >= card_stats.shape[0]:
        return 0

    required = cuda.local.array(7, dtype=np.int32)
    provided = cuda.local.array(7, dtype=np.int32)

    # Load the requirements, zero the accumulator and total up in one pass.
    total_required = 0
    for c in range(7):
        required[c] = card_stats[bid, 12 + c]
        provided[c] = 0
        total_required += required[c]

    if total_required <= 0:
        # No requirements - auto-succeed
        return card_stats[bid, 38]

    # Hearts supplied by the members currently on stage.
    for slot in range(3):
        member = p_stage[slot]
        if member > 0:
            member_bid = get_base_id_device(member)
            if member_bid < card_stats.shape[0]:
                for c in range(7):
                    provided[c] += card_stats[member_bid, 12 + c]

    # Only the first six columns are checked (colors, not "All").
    for c in range(6):
        if provided[c] < required[c]:
            return 0

    return card_stats[bid, 38]
def run_opponent_turn_device(
    rng_state,
    i,
    opp_hand,
    opp_deck,
    opp_stage,
    opp_energy_vec,
    opp_energy_count,
    opp_tapped,
    opp_live,
    opp_scores,
    opp_global_ctx,
    opp_trash,
    p_tapped,
    opp_history,
    card_stats,
    bytecode_map,
    bytecode_index,
):
    """
    Simple heuristic opponent turn.

    1. For every empty stage slot (-1), play the first hand card that is a
       member (``card_stats[bid, 10] == 1``) whose cost does not exceed the
       opponent's energy counter. Energy is only compared, never spent.
    2. Take the first live card (``card_stats[bid, 10] == 2``) in hand and put
       it in the first free live-zone slot. The scan stops at that card even
       when the live zone is full.

    Every card played is pushed to the front of ``opp_history[i, 0..5]``,
    shifting older entries back by one. Empty hand slots hold 0 and are
    skipped. Abilities are not resolved; ``rng_state``, ``opp_deck``, the
    energy arrays, ``opp_tapped``, ``opp_scores``, ``opp_trash``,
    ``p_tapped``, ``bytecode_map`` and ``bytecode_index`` are unused.
    """
    n_cards = card_stats.shape[0]

    # Play up to 3 members in empty slots
    for slot in range(3):
        if opp_stage[slot] == -1:
            for h in range(60):
                cid = opp_hand[h]
                # 0 marks an empty hand slot - never treat it as card 0.
                if cid > 0:
                    bid = get_base_id_device(cid)
                    if bid < n_cards:
                        if card_stats[bid, 10] == 1:  # Member
                            if card_stats[bid, 0] <= opp_global_ctx[EN]:
                                opp_stage[slot] = cid
                                opp_hand[h] = 0
                                opp_global_ctx[HD] -= 1
                                # Update history (newest first)
                                for k in range(5, 0, -1):
                                    opp_history[i, k] = opp_history[i, k - 1]
                                opp_history[i, 0] = cid
                                break

    # Set a live card if possible
    for h in range(60):
        cid = opp_hand[h]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < n_cards:
                if card_stats[bid, 10] == 2:  # Live
                    for lz in range(50):
                        if opp_live[lz] == 0:
                            opp_live[lz] = cid
                            opp_hand[h] = 0
                            opp_global_ctx[HD] -= 1
                            # Update history (newest first)
                            for k in range(5, 0, -1):
                                opp_history[i, k] = opp_history[i, k - 1]
                            opp_history[i, 0] = cid
                            break
                    # Only the first live card in hand is ever considered.
                    break
| # ============================================================================ | |
| # MAIN KERNELS | |
| # ============================================================================ | |
def reset_kernel(
    indices,
    batch_stage,
    batch_energy_vec,
    batch_energy_count,
    batch_continuous_vec,
    batch_continuous_ptr,
    batch_tapped,
    batch_live,
    batch_scores,
    batch_flat_ctx,
    batch_global_ctx,
    batch_hand,
    batch_deck,
    batch_trash,
    batch_opp_history,
    opp_stage,
    opp_energy_vec,
    opp_energy_count,
    opp_tapped,
    opp_live,
    opp_scores,
    opp_global_ctx,
    opp_hand,
    opp_deck,
    opp_trash,
    ability_member_ids,
    ability_live_ids,
    rng_states,
    force_start_order,
    obs_buffer,
    card_stats,
):
    """
    CUDA kernel that re-initialises the environments listed in ``indices``.

    Thread ``t`` resets environment ``indices[t]``:
      * clears all agent and opponent state (stage slots become -1,
        everything else 0);
      * builds two 60-card decks, 48 members followed by 12 lives. The id
        lists are copied verbatim when they have exactly 48 / 12 entries;
        otherwise each card is sampled with replacement;
      * shuffles both decks (Fisher-Yates);
      * moves deck cards 0 and 1 into the live zone, then the next six
        non-zero deck cards into the hand;
      * sets HD=6, DK=52, EN=3, PH=4 and turn counter (slot 54) = 1 for both
        players, and stores the start order in agent context slot 10
        (random when ``force_start_order == -1``).

    ``obs_buffer`` and ``card_stats`` are accepted but not used here.
    """
    thread_id = cuda.grid(1)
    if thread_id >= indices.shape[0]:
        return
    env = indices[thread_id]

    # ---- Wipe per-environment state for both players ----
    for s in range(3):
        batch_stage[env, s] = -1
        opp_stage[env, s] = -1
        for e in range(32):
            batch_energy_vec[env, s, e] = 0
            opp_energy_vec[env, s, e] = 0
        batch_energy_count[env, s] = 0
        opp_energy_count[env, s] = 0

    for slot in range(32):
        for field in range(10):
            batch_continuous_vec[env, slot, field] = 0
    batch_continuous_ptr[env] = 0

    for t in range(16):
        batch_tapped[env, t] = 0
        opp_tapped[env, t] = 0

    for z in range(50):
        batch_live[env, z] = 0
        opp_live[env, z] = 0

    batch_scores[env] = 0
    opp_scores[env] = 0

    for c in range(64):
        batch_flat_ctx[env, c] = 0

    for c in range(128):
        batch_global_ctx[env, c] = 0
        opp_global_ctx[env, c] = 0

    for c in range(60):
        batch_trash[env, c] = 0
        opp_trash[env, c] = 0
        batch_hand[env, c] = 0
        opp_hand[env, c] = 0

    for h in range(6):
        batch_opp_history[env, h] = 0

    # ---- Build decks: members in 0..47, lives in 48..59 ----
    n_members = ability_member_ids.shape[0]
    n_lives = ability_live_ids.shape[0]
    exact_members = n_members == 48
    exact_lives = n_lives == 12

    for k in range(48):
        if exact_members:
            batch_deck[env, k] = ability_member_ids[k]
            opp_deck[env, k] = ability_member_ids[k]
        else:
            # Sample with replacement: agent first, then opponent.
            r = xoroshiro128p_uniform_float32(rng_states, env)
            batch_deck[env, k] = ability_member_ids[int(r * n_members) % n_members]
            r = xoroshiro128p_uniform_float32(rng_states, env)
            opp_deck[env, k] = ability_member_ids[int(r * n_members) % n_members]

    for k in range(12):
        if exact_lives:
            batch_deck[env, 48 + k] = ability_live_ids[k]
            opp_deck[env, 48 + k] = ability_live_ids[k]
        else:
            r = xoroshiro128p_uniform_float32(rng_states, env)
            batch_deck[env, 48 + k] = ability_live_ids[int(r * n_lives) % n_lives]
            r = xoroshiro128p_uniform_float32(rng_states, env)
            opp_deck[env, 48 + k] = ability_live_ids[int(r * n_lives) % n_lives]

    # ---- Shuffle both decks (Fisher-Yates) ----
    for k in range(59, 0, -1):
        span = k + 1

        r = xoroshiro128p_uniform_float32(rng_states, env)
        pick = int(r * span) % span
        held = batch_deck[env, k]
        batch_deck[env, k] = batch_deck[env, pick]
        batch_deck[env, pick] = held

        r = xoroshiro128p_uniform_float32(rng_states, env)
        pick = int(r * span) % span
        held = opp_deck[env, k]
        opp_deck[env, k] = opp_deck[env, pick]
        opp_deck[env, pick] = held

    # ---- The first two deck cards go to the live zone ----
    for z in range(2):
        batch_live[env, z] = batch_deck[env, z]
        batch_deck[env, z] = 0
        opp_live[env, z] = opp_deck[env, z]
        opp_deck[env, z] = 0

    # ---- Opening hands: the next six non-zero deck cards ----
    agent_drawn = 0
    opp_drawn = 0
    for k in range(2, 60):
        if agent_drawn < 6 and batch_deck[env, k] > 0:
            batch_hand[env, agent_drawn] = batch_deck[env, k]
            batch_deck[env, k] = 0
            agent_drawn += 1
        if opp_drawn < 6 and opp_deck[env, k] > 0:
            opp_hand[env, opp_drawn] = opp_deck[env, k]
            opp_deck[env, k] = 0
            opp_drawn += 1

    # ---- Initial counters (fixed values, not derived from the draws) ----
    batch_global_ctx[env, HD] = 6
    opp_global_ctx[env, HD] = 6
    batch_global_ctx[env, DK] = 52
    opp_global_ctx[env, DK] = 52
    batch_global_ctx[env, EN] = 3
    opp_global_ctx[env, EN] = 3
    batch_global_ctx[env, PH] = 4  # Start in Main phase (simplified)
    opp_global_ctx[env, PH] = 4
    batch_global_ctx[env, 54] = 1  # Turn 1
    opp_global_ctx[env, 54] = 1

    # ---- Start order ----
    if force_start_order == -1:
        r = xoroshiro128p_uniform_float32(rng_states, env)
        is_second = 1 if r > 0.5 else 0
    else:
        is_second = force_start_order
    batch_global_ctx[env, 10] = is_second
def step_kernel(
    num_envs,
    actions,
    batch_hand,
    batch_deck,
    batch_stage,
    batch_energy_vec,
    batch_energy_count,
    batch_continuous_vec,
    batch_continuous_ptr,
    batch_tapped,
    batch_live,
    batch_scores,
    batch_flat_ctx,
    batch_global_ctx,
    opp_hand,
    opp_deck,
    opp_stage,
    opp_energy_vec,
    opp_energy_count,
    opp_tapped,
    opp_live,
    opp_scores,
    opp_global_ctx,
    card_stats,
    bytecode_map,
    bytecode_index,
    obs_buffer,
    rewards,
    dones,
    prev_scores,
    prev_opp_scores,
    prev_phases,
    terminal_obs_buffer,
    batch_trash,
    opp_trash,
    batch_opp_history,
    term_scores_agent,
    term_scores_opp,
    ability_member_ids,
    ability_live_ids,
    rng_states,
    game_config,
    opp_mode,
    force_start_order,
):
    """
    Main integrated step kernel; one thread advances one environment.

    Per environment:
      1. apply the agent's action via ``step_player_device`` and add the
         returned bonus to its score;
      2. if the action was a pass taken while the phase was 4, finish the
         turn: run the heuristic opponent, resolve live-zone slots 0..9 for
         both players, award at most one point per player from the live
         comparison, advance the turn counter (slot 54), untap everything,
         grant one energy (capped at 12) and draw one card each;
      3. write the reward, the done flag and, on termination, the terminal
         scores, then refresh ``prev_scores`` / ``prev_opp_scores``.

    ``game_config`` is read as [turn limit, step limit, win reward, lose
    reward, score scale, per-step turn penalty]. An episode ends when either
    score reaches 3 or the turn / step (slot 58) limit is hit. Resetting a
    finished environment is left to the caller.

    NOTE(review): nothing in this kernel writes the OS slot or moves PH back
    to 4 after a pass has advanced it to 8 - confirm that is handled elsewhere.
    """
    env = cuda.grid(1)
    if env >= num_envs:
        return

    # Config
    turn_limit = int(game_config[0])
    step_limit = int(game_config[1])
    reward_win = game_config[2]
    reward_lose = game_config[3]
    reward_scale = game_config[4]
    turn_penalty = game_config[5]

    action = actions[env]
    agent_ctx = batch_global_ctx[env]
    rival_ctx = opp_global_ctx[env]

    # The phase is sampled before the action, which may change it.
    phase_before = int(agent_ctx[PH])

    agent_ctx[SC] = batch_scores[env]  # Sync score to context
    agent_ctx[58] += 1  # Step counter

    # One-element views so device functions can write results back.
    cont_ptr_view = batch_continuous_ptr[env : env + 1]
    score_view = batch_scores[env : env + 1]

    # Execute action
    bonus = step_player_device(
        action,
        0,
        rng_states,
        env,
        batch_hand[env],
        batch_deck[env],
        batch_stage[env],
        batch_energy_vec[env],
        batch_energy_count[env],
        batch_tapped[env],
        batch_live[env],
        score_view,
        agent_ctx,
        batch_trash[env],
        batch_continuous_vec[env],
        cont_ptr_view,
        opp_tapped[env],
        card_stats,
        bytecode_map,
        bytecode_index,
    )
    batch_scores[env] += bonus

    # Turn end: the agent passed while in the Main phase.
    if action == 0 and phase_before == 4:
        run_opponent_turn_device(
            rng_states,
            env,
            opp_hand[env],
            opp_deck[env],
            opp_stage[env],
            opp_energy_vec[env],
            opp_energy_count[env],
            opp_tapped[env],
            opp_live[env],
            opp_scores[env : env + 1],
            rival_ctx,
            opp_trash[env],
            batch_tapped[env],
            batch_opp_history,
            card_stats,
            bytecode_map,
            bytecode_index,
        )

        # Resolve lives for both players (first 10 live-zone slots only).
        agent_live_score = 0
        opp_live_score = 0

        for z in range(10):
            live_card = batch_live[env, z]
            if live_card > 0:
                gained = resolve_live_device(
                    live_card,
                    batch_stage[env],
                    batch_live[env],
                    batch_scores[env : env + 1],
                    agent_ctx,
                    batch_deck[env],
                    batch_hand[env],
                    batch_trash[env],
                    card_stats,
                    batch_continuous_vec[env],
                    cont_ptr_view,
                )
                agent_live_score += gained
                # A successful live is trashed and removed from the zone.
                if gained > 0:
                    move_to_trash_device(live_card, batch_trash[env], agent_ctx, TR)
                    batch_live[env, z] = 0

        for z in range(10):
            live_card = opp_live[env, z]
            if live_card > 0:
                gained = resolve_live_device(
                    live_card,
                    opp_stage[env],
                    opp_live[env],
                    opp_scores[env : env + 1],
                    rival_ctx,
                    opp_deck[env],
                    opp_hand[env],
                    opp_trash[env],
                    card_stats,
                    batch_continuous_vec[env],
                    cont_ptr_view,
                )
                opp_live_score += gained
                if gained > 0:
                    move_to_trash_device(live_card, opp_trash[env], rival_ctx, TR)
                    opp_live[env, z] = 0

        # Scoring comparison: the higher live total takes a point, a tie
        # between two positive totals gives both players a point.
        if agent_live_score > 0 and opp_live_score == 0:
            batch_scores[env] += 1
        elif agent_live_score == 0 and opp_live_score > 0:
            opp_scores[env] += 1
        elif agent_live_score > 0 and opp_live_score > 0:
            if agent_live_score >= opp_live_score:
                batch_scores[env] += 1
            if opp_live_score >= agent_live_score:
                opp_scores[env] += 1

        # Next turn setup
        agent_ctx[54] += 1
        rival_ctx[54] += 1

        # Untap everything
        for j in range(16):
            batch_tapped[env, j] = 0
            if j < opp_tapped.shape[1]:
                opp_tapped[env, j] = 0

        # One more energy each, capped at 12
        agent_ctx[EN] = min(agent_ctx[EN] + 1, 12)
        rival_ctx[EN] = min(rival_ctx[EN] + 1, 12)

        # Draw card
        draw_cards_device(1, batch_hand[env], batch_deck[env], batch_trash[env], agent_ctx)
        draw_cards_device(1, opp_hand[env], opp_deck[env], opp_trash[env], rival_ctx)

    # Calculate rewards from the score deltas since the previous step.
    current_score = batch_scores[env]
    score_diff = float(current_score) - float(prev_scores[env])
    opp_score_diff = float(opp_scores[env]) - float(prev_opp_scores[env])

    r = (score_diff * reward_scale) - (opp_score_diff * reward_scale)
    r += turn_penalty

    win = current_score >= 3
    lose = opp_scores[env] >= 3
    if win:
        r += reward_win
    if lose:
        r += reward_lose
    rewards[env] = r

    # Sync opponent stats into the agent context (for Attention features).
    agent_ctx[DI] = rival_ctx[HD]  # slot 4 <- opponent hand count
    agent_ctx[OD] = rival_ctx[DK]  # slot 9 <- opponent deck count
    agent_ctx[OT] = rival_ctx[TR]  # slot 7 <- opponent trash count

    # Check done
    is_done = win or lose or agent_ctx[54] >= turn_limit or agent_ctx[58] >= step_limit
    dones[env] = is_done

    if is_done:
        term_scores_agent[env] = batch_scores[env]
        term_scores_opp[env] = opp_scores[env]
        # Note: Auto-reset should be called separately

    # Update prev scores
    prev_scores[env] = batch_scores[env]
    prev_opp_scores[env] = opp_scores[env]
def compute_action_masks_kernel(
    num_envs,
    batch_hand,
    batch_stage,
    batch_tapped,
    batch_global_ctx,
    batch_live,
    card_stats,
    masks,  # Output: (N, 2000)
):
    """
    Compute legal action masks on GPU; one thread fills one row of ``masks``.

    Legal actions:
      * ``0``        - pass, only while the phase is 4 (Main).
      * ``1..180``   - ``hand_idx * 3 + slot + 1``: play a member card from
                       the hand into a stage slot when the untapped energy
                       covers its cost, less the cost of a member already
                       occupying that slot.
      * ``200..202`` - activate the untapped member in stage slot 0..2.
      * ``400..459`` - set a live card from hand slot 0..59, provided the
                       live zone has a free slot.

    Available energy is counted over the energy slots the player owns, i.e.
    the first ``min(ctx[EN], 12)`` entries of ``batch_tapped[env, 3:]``.
    This is the same range ``step_player_device`` taps when paying a cost.
    """
    env = cuda.grid(1)
    if env >= num_envs:
        return

    # Reset all to False
    for a in range(2000):
        masks[env, a] = False

    # Action 0: Pass is legal in the Main phase
    if batch_global_ctx[env, PH] == 4:
        masks[env, 0] = True

    n_cards = card_stats.shape[0]

    # Untapped energy among the owned energy slots. This does not depend on
    # the card or slot being examined, so it is computed once.
    owned_energy = min(batch_global_ctx[env, EN], 12)
    available_energy = 0
    for e in range(owned_energy):
        if batch_tapped[env, 3 + e] == 0:
            available_energy += 1

    # One free live-zone slot is enough to make any live card settable.
    live_has_room = False
    for lz_idx in range(50):
        if batch_live[env, lz_idx] == 0:
            live_has_room = True
            break

    for h_idx in range(60):
        cid = batch_hand[env, h_idx]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < n_cards:
                ctype = card_stats[bid, 10]
                if ctype == 1:
                    # Member Play (1-180): HandIdx * 3 + Slot + 1
                    cost = card_stats[bid, 0]
                    for slot in range(3):
                        # Replacing a member discounts its cost.
                        effective_cost = cost
                        old_cid = batch_stage[env, slot]
                        if old_cid >= 0:
                            old_bid = get_base_id_device(old_cid)
                            if old_bid < n_cards:
                                effective_cost = max(0, cost - card_stats[old_bid, 0])
                        if available_energy >= effective_cost:
                            masks[env, h_idx * 3 + slot + 1] = True
                elif ctype == 2:
                    # Set Live (400-459)
                    if live_has_room:
                        masks[env, 400 + h_idx] = True

    # Activate Ability (200-202)
    for slot in range(3):
        if batch_stage[env, slot] > 0 and batch_tapped[env, slot] == 0:
            masks[env, 200 + slot] = True
def encode_observations_kernel(
    num_envs,
    batch_hand,
    batch_stage,
    batch_energy_count,
    batch_tapped,
    batch_scores,
    opp_scores,
    opp_stage,
    opp_tapped,
    card_stats,
    batch_global_ctx,
    batch_live,
    turn_number,
    obs_buffer,
):
    """
    Encode observations on GPU (STANDARD mode); one thread fills one row.

    Row layout (float features, zero where nothing is written):
      * 0..5     - scores / 3, energy / 12, hand / 60, deck / 60, turn / 100
      * 10..69   - own stage, 3 slots x 20 features
      * 70..129  - opponent stage, 3 slots x 20 features
      * 130..529 - hand, first 20 cards x 20 features
      * 530..    - live zone, first 10 cards x 10 features

    Hand and live entries are skipped when they would run past the row
    width. The metadata and stage writes are not bounds-checked, so the row
    is assumed to be wide enough for them.
    ``batch_energy_count``, ``opp_tapped`` and ``turn_number`` are unused.
    """
    env = cuda.grid(1)
    if env >= num_envs:
        return

    obs_dim = obs_buffer.shape[1]
    n_cards = card_stats.shape[0]

    # Section offsets within a row.
    STAGE_BASE = 10
    OPP_STAGE_BASE = STAGE_BASE + 60
    HAND_BASE = OPP_STAGE_BASE + 60
    LIVE_BASE = HAND_BASE + 400

    # Clear observation
    for j in range(obs_dim):
        obs_buffer[env, j] = 0.0

    # Metadata
    obs_buffer[env, 0] = float(batch_scores[env]) / 3.0
    obs_buffer[env, 1] = float(opp_scores[env]) / 3.0
    obs_buffer[env, 2] = float(batch_global_ctx[env, EN]) / 12.0
    obs_buffer[env, 3] = float(batch_global_ctx[env, HD]) / 60.0
    obs_buffer[env, 4] = float(batch_global_ctx[env, DK]) / 60.0
    obs_buffer[env, 5] = float(batch_global_ctx[env, 54]) / 100.0  # Turn

    # Both stages (3 slots x 20 features each)
    for slot in range(3):
        own = batch_stage[env, slot]
        if own > 0:
            bid = get_base_id_device(own)
            if bid < n_cards:
                base = STAGE_BASE + slot * 20
                obs_buffer[env, base] = 1.0  # Presence
                obs_buffer[env, base + 1] = float(own) / 2000.0
                obs_buffer[env, base + 2] = float(card_stats[bid, 0]) / 10.0  # Cost
                obs_buffer[env, base + 3] = float(card_stats[bid, 1]) / 5.0  # Blades
                obs_buffer[env, base + 4] = float(card_stats[bid, 2]) / 10.0  # Hearts
                obs_buffer[env, base + 5] = 1.0 if batch_tapped[env, slot] > 0 else 0.0

        foe = opp_stage[env, slot]
        if foe > 0:
            bid = get_base_id_device(foe)
            if bid < n_cards:
                base = OPP_STAGE_BASE + slot * 20
                obs_buffer[env, base] = 1.0
                obs_buffer[env, base + 1] = float(foe) / 2000.0
                obs_buffer[env, base + 2] = float(card_stats[bid, 0]) / 10.0
                obs_buffer[env, base + 3] = float(card_stats[bid, 1]) / 5.0
                obs_buffer[env, base + 4] = float(card_stats[bid, 2]) / 10.0

    # Hand (up to 20 cards shown)
    shown = 0
    for h_idx in range(60):
        cid = batch_hand[env, h_idx]
        if cid > 0 and shown < 20:
            base = HAND_BASE + shown * 20
            if base + 10 < obs_dim:
                obs_buffer[env, base] = 1.0
                obs_buffer[env, base + 1] = float(cid) / 2000.0
                bid = get_base_id_device(cid)
                if bid < n_cards:
                    obs_buffer[env, base + 2] = float(card_stats[bid, 0]) / 10.0
                    obs_buffer[env, base + 3] = float(card_stats[bid, 10])  # Type
            shown += 1

    # Live zone (up to 10 cards)
    lives_shown = 0
    for l_idx in range(50):
        cid = batch_live[env, l_idx]
        if cid > 0 and lives_shown < 10:
            base = LIVE_BASE + lives_shown * 10
            if base + 5 < obs_dim:
                obs_buffer[env, base] = 1.0
                obs_buffer[env, base + 1] = float(cid) / 2000.0
            lives_shown += 1
| def encode_observations_attention_kernel( | |
| num_envs, | |
| batch_hand, | |
| batch_stage, | |
| batch_energy_count, | |
| batch_tapped, | |
| batch_scores, | |
| opp_scores, | |
| opp_stage, | |
| opp_tapped, | |
| card_stats, | |
| batch_global_ctx, | |
| batch_live, | |
| batch_opp_history, | |
| opp_global_ctx, # Added | |
| turn_number, | |
| obs_buffer, | |
| ): | |
| """ | |
| Encode observations for Attention Architecture (2240-dim). | |
| """ | |
| i = cuda.grid(1) | |
| if i >= num_envs: | |
| return | |
| # Constants | |
| FEAT = 64 | |
| MAX_HAND = 15 # +1 overflow | |
| # Offsets | |
| HAND_START = 0 | |
| HAND_OVER_START = HAND_START + (MAX_HAND * FEAT) # 960 | |
| STAGE_START = HAND_OVER_START + FEAT # 1024 | |
| LIVE_START = STAGE_START + (3 * FEAT) # 1216 | |
| LIVE_SUCC_START = LIVE_START + (3 * FEAT) # 1408 | |
| OPP_STAGE_START = LIVE_SUCC_START + (3 * FEAT) # 1600 | |
| OPP_HIST_START = OPP_STAGE_START + (3 * FEAT) # 1792 | |
| GLOBAL_START = OPP_HIST_START + (6 * FEAT) # 2176 | |
| # Clear buffer | |
| for k in range(2240): | |
| obs_buffer[i, k] = 0.0 | |
| # --- A. HAND (16 slots) --- | |
| hand_count = 0 | |
| for j in range(60): | |
| cid = batch_hand[i, j] | |
| if cid > 0: | |
| bid = get_base_id_device(cid) | |
| if bid < card_stats.shape[0]: | |
| if hand_count < 16: | |
| base = HAND_START + hand_count * FEAT | |
| obs_buffer[i, base + 0] = 1.0 # Presence | |
| obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0 # Type | |
| obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 # Cost | |
| obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0 # Blades | |
| obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New) | |
| obs_buffer[i, base + 6] = 0.2 # Location: Hand | |
| # Hearts (8-14) | |
| for k in range(7): | |
| if 12 + k < card_stats.shape[1]: | |
| obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0 | |
| # Group (22-28) | |
| raw_group = card_stats[bid, 11] | |
| obs_buffer[i, base + 22 + (raw_group % 7)] = 1.0 | |
| # Context | |
| obs_buffer[i, base + 58] = float(hand_count) / 10.0 | |
| obs_buffer[i, base + 59] = 1.0 # Mine | |
| hand_count += 1 | |
| # --- B. MY STAGE (3 slots) --- | |
| for slot in range(3): | |
| cid = batch_stage[i, slot] | |
| if cid > 0: | |
| bid = get_base_id_device(cid) | |
| if bid < card_stats.shape[0]: | |
| base = STAGE_START + slot * FEAT | |
| obs_buffer[i, base + 0] = 1.0 | |
| obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0 | |
| obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 | |
| obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0 | |
| obs_buffer[i, base + 4] = 1.0 if batch_tapped[i, slot] > 0 else 0.0 | |
| obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New) | |
| obs_buffer[i, base + 6] = 0.4 # Location: Stage | |
| for k in range(7): | |
| if 12 + k < card_stats.shape[1]: | |
| obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0 | |
| raw_group = card_stats[bid, 11] | |
| obs_buffer[i, base + 22 + (raw_group % 7)] = 1.0 | |
| obs_buffer[i, base + 58] = float(slot) / 10.0 | |
| obs_buffer[i, base + 59] = 1.0 | |
| # --- C. LIVE ZONE (6 slots) --- | |
| live_count = 0 | |
| for j in range(50): | |
| cid = batch_live[i, j] | |
| if cid > 0: | |
| bid = get_base_id_device(cid) | |
| if bid < card_stats.shape[0] and live_count < 6: | |
| base = LIVE_START + live_count * FEAT | |
| obs_buffer[i, base + 0] = 1.0 | |
| obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0 | |
| obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 | |
| obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New) | |
| obs_buffer[i, base + 6] = 0.6 # Location: Live | |
| for k in range(7): | |
| if 12 + k < card_stats.shape[1]: | |
| obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0 | |
| obs_buffer[i, base + 58] = float(live_count) / 10.0 | |
| obs_buffer[i, base + 59] = 1.0 | |
| live_count += 1 | |
| # --- D. OPP STAGE (3 slots) --- | |
| for slot in range(3): | |
| cid = opp_stage[i, slot] | |
| if cid > 0: | |
| bid = get_base_id_device(cid) | |
| if bid < card_stats.shape[0]: | |
| base = OPP_STAGE_START + slot * FEAT | |
| obs_buffer[i, base + 0] = 1.0 | |
| obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0 | |
| obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 | |
| obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0 | |
| obs_buffer[i, base + 4] = 1.0 if opp_tapped[i, slot] > 0 else 0.0 | |
| obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New) | |
| obs_buffer[i, base + 6] = 0.8 # Location: Opp Stage | |
| for k in range(7): | |
| if 12 + k < card_stats.shape[1]: | |
| obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0 | |
| obs_buffer[i, base + 58] = float(slot) / 10.0 | |
| obs_buffer[i, base + 59] = -1.0 | |
| # --- E. OPP HISTORY (6 slots) --- | |
| for h in range(6): | |
| cid = batch_opp_history[i, h] | |
| if cid > 0: | |
| bid = get_base_id_device(cid) | |
| if bid < card_stats.shape[0]: | |
| base = OPP_HIST_START + h * FEAT | |
| obs_buffer[i, base + 0] = 1.0 | |
| obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0 | |
| obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 | |
| obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New) | |
| obs_buffer[i, base + 6] = 1.0 # Location: History | |
| obs_buffer[i, base + 58] = float(h) / 10.0 | |
| obs_buffer[i, base + 59] = -1.0 | |
| # --- F. GLOBAL SCALARS --- | |
| obs_buffer[i, GLOBAL_START + 0] = float(batch_scores[i]) / 10.0 | |
| obs_buffer[i, GLOBAL_START + 1] = float(opp_scores[i]) / 10.0 | |
| obs_buffer[i, GLOBAL_START + 2] = float(batch_global_ctx[i, 54]) / 20.0 # Turn from Context | |
| obs_buffer[i, GLOBAL_START + 3] = float(batch_global_ctx[i, 8]) / 10.0 | |
| obs_buffer[i, GLOBAL_START + 4] = float(batch_global_ctx[i, 5]) / 10.0 | |
| obs_buffer[i, GLOBAL_START + 5] = float(batch_global_ctx[i, 6]) / 40.0 | |
| obs_buffer[i, GLOBAL_START + 6] = float(hand_count) / 15.0 | |
| # Opponent Resources (New) | |
| obs_buffer[i, GLOBAL_START + 7] = float(opp_global_ctx[i, 5]) / 10.0 # Opp Energy | |
| obs_buffer[i, GLOBAL_START + 8] = float(batch_global_ctx[i, 4]) / 10.0 # Opp Hand (from ctx[4]) | |
| obs_buffer[i, GLOBAL_START + 9] = float(batch_global_ctx[i, 9]) / 40.0 # Opp Deck (from ctx[9]) | |
| obs_buffer[i, GLOBAL_START + 10] = float(batch_global_ctx[i, 7]) / 10.0 # Opp Trash (from ctx[7]) | |