LovecaSim / ai /research /cuda_kernels.py
trioskosmos's picture
Upload ai/research/cuda_kernels.py with huggingface_hub
191b8ab verified
"""
CUDA Kernels for GPU-Accelerated VectorEnv.
This module contains CUDA kernel implementations for:
- Environment reset
- Game step (integrated with opponent, phases, scoring)
- Observation encoding
- Action mask computation
All kernels are designed for the VectorGameStateGPU class.
"""
import numpy as np
try:
from numba import cuda
from numba.cuda.random import xoroshiro128p_normal_float32, xoroshiro128p_uniform_float32
HAS_CUDA = True
except ImportError:
HAS_CUDA = False
# Mock for type checking
class MockCuda:
def jit(self, *args, **kwargs):
def decorator(f):
return f
return decorator
def grid(self, x):
return 0
cuda = MockCuda()
def xoroshiro128p_uniform_float32(rng, i):
return 0.5
# ============================================================================
# CONSTANTS (Must match fast_logic.py)
# ============================================================================
SC = 0
OS = 1
TR = 2
HD = 3
DI = 4
EN = 5
DK = 6
OT = 7
PH = 8
OD = 9
# Opcodes
O_DRAW = 10
O_BLADES = 11
O_HEARTS = 12
O_RECOV_L = 13
O_BOOST = 14
O_RECOV_M = 15
O_BUFF = 16
O_CHARGE = 17
O_TAP_O = 18
O_CHOOSE = 19
O_ADD_H = 20
O_RETURN = 999
O_JUMP = 100
O_JUMP_F = 101
# Conditions
C_TR1 = 200
C_CLR = 202
C_STG = 203
C_HND = 204
C_CTR = 206
C_LLD = 207
C_GRP = 208
C_OPH = 210
C_ENR = 213
C_CMP = 220
# Unique ID (UID) System
BASE_ID_MASK = 0xFFFFF
@cuda.jit(device=True)
def get_base_id_device(uid: int) -> int:
"""Extract the base card definition ID (0-1999) from a UID."""
return uid & BASE_ID_MASK
# ============================================================================
# DEVICE FUNCTIONS (Callable from kernels)
# ============================================================================
@cuda.jit(device=True)
def check_deck_refresh_device(p_deck, p_trash, p_global_ctx, DK_idx, TR_idx):
"""Shuffle trash back into deck if deck is empty."""
if p_global_ctx[DK_idx] <= 0:
# Count trash
tr_count = 0
for t in range(60):
if p_trash[t] > 0:
tr_count += 1
if tr_count > 0:
# Move trash to deck
d_ptr = 0
for t in range(60):
if p_trash[t] > 0:
p_deck[d_ptr] = p_trash[t]
p_trash[t] = 0
d_ptr += 1
p_global_ctx[DK_idx] = d_ptr
p_global_ctx[TR_idx] = 0
@cuda.jit(device=True)
def move_to_trash_device(card_id, p_trash, p_global_ctx, TR_idx):
"""Move a card to trash zone."""
for t in range(60):
if p_trash[t] == 0:
p_trash[t] = card_id
p_global_ctx[TR_idx] += 1
break
@cuda.jit(device=True)
def draw_cards_device(count, p_hand, p_deck, p_trash, p_global_ctx):
"""Draw cards from deck to hand."""
for _ in range(count):
check_deck_refresh_device(p_deck, p_trash, p_global_ctx, DK, TR)
if p_global_ctx[DK] <= 0:
break
# Find top card in deck
top_card = 0
d_idx_found = -1
for d in range(60):
if p_deck[d] > 0:
top_card = p_deck[d]
d_idx_found = d
break
if top_card > 0:
# Find empty hand slot
for h in range(60):
if p_hand[h] == 0:
p_hand[h] = top_card
p_deck[d_idx_found] = 0
p_global_ctx[DK] -= 1
p_global_ctx[HD] += 1
break
@cuda.jit(device=True)
def resolve_bytecode_device(
bytecode,
flat_ctx,
global_ctx,
player_id,
p_hand,
p_deck,
p_stage,
p_energy_vec,
p_energy_count,
p_cont_vec,
p_cont_ptr,
p_tapped,
p_live,
opp_tapped,
p_trash,
bytecode_map,
bytecode_index,
):
"""
GPU Device function for resolving bytecode.
Returns (new_cont_ptr, status, bonus).
"""
ip = 0
cptr = p_cont_ptr
bonus = 0
cond = True
blen = bytecode.shape[0]
safety_counter = 0
while ip < blen and safety_counter < 500:
safety_counter += 1
op = bytecode[ip, 0]
v = bytecode[ip, 1]
a = bytecode[ip, 2]
s = bytecode[ip, 3]
if op == 0:
ip += 1
continue
if op == O_RETURN:
break
# Jumps
if op == O_JUMP:
new_ip = ip + v
if 0 <= new_ip < blen:
ip = new_ip
else:
break
continue
if op == O_JUMP_F:
if not cond:
new_ip = ip + v
if 0 <= new_ip < blen:
ip = new_ip
else:
break
continue
ip += 1
continue
# Conditions (op >= 200)
if op >= 200:
if op == C_TR1:
cond = global_ctx[TR] == 1
elif op == C_STG:
ct = 0
for j in range(3):
if p_stage[j] != -1:
ct += 1
cond = ct >= v
elif op == C_HND:
cond = global_ctx[HD] >= v
elif op == C_LLD:
cond = global_ctx[SC] > global_ctx[OS]
elif op == C_ENR:
cond = global_ctx[EN] >= v
elif op == C_CMP:
if v > 0:
cond = global_ctx[SC] >= v
else:
cond = global_ctx[SC] > global_ctx[OS]
elif op == C_OPH:
cond = global_ctx[OT] >= v if v > 0 else global_ctx[OT] > 0
else:
cond = True
ip += 1
else:
# Effects
if cond:
if op == O_DRAW:
draw_cards_device(v, p_hand, p_deck, p_trash, global_ctx)
elif op == O_CHARGE:
# Move cards from deck to energy (simplified)
amt = min(v, global_ctx[DK])
for _ in range(amt):
for d in range(60):
if p_deck[d] > 0:
p_deck[d] = 0
global_ctx[DK] -= 1
global_ctx[EN] += 1
break
elif op == O_HEARTS:
# Add hearts (points)
bonus += v
# Register continuous effect
if cptr < 32:
p_cont_vec[cptr, 0] = 2
p_cont_vec[cptr, 1] = v
p_cont_vec[cptr, 5] = a
p_cont_vec[cptr, 9] = 1
cptr += 1
elif op == O_BLADES:
if cptr < 32:
p_cont_vec[cptr, 0] = 1
p_cont_vec[cptr, 1] = v
p_cont_vec[cptr, 2] = 4
p_cont_vec[cptr, 3] = s
p_cont_vec[cptr, 9] = 1
cptr += 1
elif op == O_RECOV_M:
if 0 <= s < 3:
p_tapped[s] = 0
elif op == O_RECOV_L:
if 0 <= s < p_live.shape[0]:
p_live[s] = 0
elif op == O_TAP_O:
if 0 <= s < 3:
opp_tapped[s] = 1
elif op == O_BUFF:
if cptr < 32:
p_cont_vec[cptr, 0] = 8
p_cont_vec[cptr, 1] = v
p_cont_vec[cptr, 2] = s
p_cont_vec[cptr, 9] = 1
cptr += 1
elif op == O_BOOST:
bonus += v
ip += 1
return cptr, 0, bonus
@cuda.jit(device=True)
def step_player_device(
act_id,
player_id,
rng_state,
i,
p_hand,
p_deck,
p_stage,
p_energy_vec,
p_energy_count,
p_tapped,
p_live,
p_scores,
p_global_ctx,
p_trash,
p_continuous_vec,
p_continuous_ptr,
opp_tapped,
card_stats,
bytecode_map,
bytecode_index,
):
"""
Device function for single player step.
Returns bonus score from this action.
"""
bonus = 0
if act_id == 0:
# Pass -> Next Phase
ph = p_global_ctx[PH]
if ph == -1:
p_global_ctx[PH] = 0
elif ph == 0:
p_global_ctx[PH] = 4 # Skip to Main
elif ph == 4:
p_global_ctx[PH] = 8 # Performance
return 0
# Member Play (1-180)
if 1 <= act_id <= 180:
adj = act_id - 1
hand_idx = adj // 3
slot = adj % 3
if hand_idx < 60:
card_id = p_hand[hand_idx]
if card_id >= 0:
bid = get_base_id_device(card_id)
if bid < card_stats.shape[0]:
# Cost calculation
cost = card_stats[bid, 0]
effective_cost = cost
prev_cid = p_stage[slot]
if prev_cid >= 0:
prev_bid = get_base_id_device(prev_cid)
if prev_bid < card_stats.shape[0]:
prev_cost = card_stats[prev_bid, 0]
effective_cost = max(0, cost - prev_cost)
# Pay cost by tapping energy
ec = min(p_global_ctx[EN], 12)
paid = 0
if effective_cost > 0:
for e_idx in range(ec):
if 3 + e_idx < 16:
if p_tapped[3 + e_idx] == 0:
p_tapped[3 + e_idx] = 1
paid += 1
if paid >= effective_cost:
break
# Move to stage
p_stage[slot] = card_id
p_hand[hand_idx] = 0
p_global_ctx[HD] -= 1
p_global_ctx[51 + slot] = 1 # Mark played
# Resolve auto-ability
bid = get_base_id_device(card_id)
if bid < bytecode_index.shape[0]:
map_idx = bytecode_index[bid, 0]
if map_idx >= 0:
flat_ctx = cuda.local.array(64, dtype=np.int32)
for j in range(64):
flat_ctx[j] = 0
new_ptr, _, ab_bonus = resolve_bytecode_device(
bytecode_map[map_idx],
flat_ctx,
p_global_ctx,
player_id,
p_hand,
p_deck,
p_stage,
p_energy_vec,
p_energy_count,
p_continuous_vec,
p_continuous_ptr[0],
p_tapped,
p_live,
opp_tapped,
p_trash,
bytecode_map,
bytecode_index,
)
p_continuous_ptr[0] = new_ptr
bonus += ab_bonus
# Activate Ability (200-202)
elif 200 <= act_id <= 202:
slot = act_id - 200
card_id = p_stage[slot]
if card_id >= 0 and p_tapped[slot] == 0:
bid = get_base_id_device(card_id)
if bid < bytecode_index.shape[0]:
map_idx = bytecode_index[bid, 0]
if map_idx >= 0:
flat_ctx = cuda.local.array(64, dtype=np.int32)
for j in range(64):
flat_ctx[j] = 0
new_ptr, _, ab_bonus = resolve_bytecode_device(
bytecode_map[map_idx],
flat_ctx,
p_global_ctx,
player_id,
p_hand,
p_deck,
p_stage,
p_energy_vec,
p_energy_count,
p_continuous_vec,
p_continuous_ptr[0],
p_tapped,
p_live,
opp_tapped,
p_trash,
bytecode_map,
bytecode_index,
)
p_continuous_ptr[0] = new_ptr
bonus += ab_bonus
p_tapped[slot] = 1
# Set Live Card (400-459)
elif 400 <= act_id <= 459:
hand_idx = act_id - 400
if hand_idx < 60:
card_id = p_hand[hand_idx]
if card_id > 0:
# Find empty live zone slot
for j in range(50):
if p_live[j] == 0:
p_live[j] = card_id
p_hand[hand_idx] = 0
p_global_ctx[HD] -= 1
break
return bonus
@cuda.jit(device=True)
def resolve_live_device(
live_id, p_stage, p_live, p_scores, p_global_ctx, p_deck, p_hand, p_trash, card_stats, p_cont_vec, p_cont_ptr
):
"""
Device function to resolve a live card.
Returns the score value if successful, 0 otherwise.
"""
bid = get_base_id_device(live_id)
if live_id < 0 or bid >= card_stats.shape[0]:
return 0
# Get required hearts from card_stats (indices 12-18)
required = cuda.local.array(7, dtype=np.int32)
for c in range(7):
required[c] = card_stats[bid, 12 + c]
total_required = 0
for c in range(7):
total_required += required[c]
if total_required <= 0:
# No requirements - auto-succeed
return card_stats[bid, 38] # Score value
# Calculate provided hearts from stage members
provided = cuda.local.array(7, dtype=np.int32)
for c in range(7):
provided[c] = 0
for slot in range(3):
cid = p_stage[slot]
if cid > 0:
s_bid = get_base_id_device(cid)
if s_bid < card_stats.shape[0]:
for c in range(7):
provided[c] += card_stats[s_bid, 12 + c]
# Check if requirements met
for c in range(6): # Colors (not All)
if required[c] > provided[c]:
return 0 # Failed
# All requirements met
return card_stats[bid, 38]
@cuda.jit(device=True)
def run_opponent_turn_device(
rng_state,
i,
opp_hand,
opp_deck,
opp_stage,
opp_energy_vec,
opp_energy_count,
opp_tapped,
opp_live,
opp_scores,
opp_global_ctx,
opp_trash,
p_tapped,
opp_history,
card_stats,
bytecode_map,
bytecode_index,
):
"""
Simple heuristic opponent turn.
Plays members if possible, activates abilities, sets lives.
"""
# Play up to 3 members in empty slots
for slot in range(3):
if opp_stage[slot] == -1:
# Find playable member in hand
for h in range(60):
cid = opp_hand[h]
if cid >= 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
ctype = card_stats[bid, 10]
if ctype == 1: # Member
cost = card_stats[bid, 0]
if cost <= opp_global_ctx[EN]:
# Play it
opp_stage[slot] = cid
opp_hand[h] = 0
opp_global_ctx[HD] -= 1
# Update History
for k in range(5, 0, -1):
opp_history[i, k] = opp_history[i, k - 1]
opp_history[i, 0] = cid
break
# Set a live card if possible
for h in range(60):
cid = opp_hand[h]
if cid >= 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
ctype = card_stats[bid, 10]
if ctype == 2: # Live
for lz in range(50):
if opp_live[lz] == 0:
opp_live[lz] = cid
opp_hand[h] = 0
opp_global_ctx[HD] -= 1
# Update History
for k in range(5, 0, -1):
opp_history[i, k] = opp_history[i, k - 1]
opp_history[i, 0] = cid
break
break
# ============================================================================
# MAIN KERNELS
# ============================================================================
@cuda.jit
def reset_kernel(
indices,
batch_stage,
batch_energy_vec,
batch_energy_count,
batch_continuous_vec,
batch_continuous_ptr,
batch_tapped,
batch_live,
batch_scores,
batch_flat_ctx,
batch_global_ctx,
batch_hand,
batch_deck,
batch_trash,
batch_opp_history,
opp_stage,
opp_energy_vec,
opp_energy_count,
opp_tapped,
opp_live,
opp_scores,
opp_global_ctx,
opp_hand,
opp_deck,
opp_trash,
ability_member_ids,
ability_live_ids,
rng_states,
force_start_order,
obs_buffer,
card_stats,
):
"""
CUDA Kernel to reset environments.
"""
tid = cuda.grid(1)
if tid >= indices.shape[0]:
return
i = indices[tid]
# Clear agent state
for j in range(3):
batch_stage[i, j] = -1
for j in range(3):
for k in range(32):
batch_energy_vec[i, j, k] = 0
batch_energy_count[i, j] = 0
for j in range(32):
for k in range(10):
batch_continuous_vec[i, j, k] = 0
batch_continuous_ptr[i] = 0
for j in range(16):
batch_tapped[i, j] = 0
for j in range(50):
batch_live[i, j] = 0
batch_scores[i] = 0
for j in range(64):
batch_flat_ctx[i, j] = 0
for j in range(128):
batch_global_ctx[i, j] = 0
for j in range(60):
batch_trash[i, j] = 0
for j in range(6):
batch_opp_history[i, j] = 0
# Clear opponent state
for j in range(3):
opp_stage[i, j] = -1
for j in range(3):
for k in range(32):
opp_energy_vec[i, j, k] = 0
opp_energy_count[i, j] = 0
for j in range(16):
opp_tapped[i, j] = 0
for j in range(50):
opp_live[i, j] = 0
opp_scores[i] = 0
for j in range(128):
opp_global_ctx[i, j] = 0
for j in range(60):
opp_trash[i, j] = 0
# Generate deck
n_members = ability_member_ids.shape[0]
n_lives = ability_live_ids.shape[0]
# Members (0-47)
for k in range(48):
if n_members == 48:
batch_deck[i, k] = ability_member_ids[k]
opp_deck[i, k] = ability_member_ids[k]
else:
# Random pick using RNG
r = xoroshiro128p_uniform_float32(rng_states, i)
idx = int(r * n_members) % n_members
batch_deck[i, k] = ability_member_ids[idx]
r = xoroshiro128p_uniform_float32(rng_states, i)
idx = int(r * n_members) % n_members
opp_deck[i, k] = ability_member_ids[idx]
# Lives (48-59)
for k in range(12):
if n_lives == 12:
batch_deck[i, 48 + k] = ability_live_ids[k]
opp_deck[i, 48 + k] = ability_live_ids[k]
else:
r = xoroshiro128p_uniform_float32(rng_states, i)
idx = int(r * n_lives) % n_lives
batch_deck[i, 48 + k] = ability_live_ids[idx]
r = xoroshiro128p_uniform_float32(rng_states, i)
idx = int(r * n_lives) % n_lives
opp_deck[i, 48 + k] = ability_live_ids[idx]
# Shuffle decks (Fisher-Yates)
for k in range(59, 0, -1):
r = xoroshiro128p_uniform_float32(rng_states, i)
j = int(r * (k + 1)) % (k + 1)
tmp = batch_deck[i, k]
batch_deck[i, k] = batch_deck[i, j]
batch_deck[i, j] = tmp
r = xoroshiro128p_uniform_float32(rng_states, i)
j = int(r * (k + 1)) % (k + 1)
tmp = opp_deck[i, k]
opp_deck[i, k] = opp_deck[i, j]
opp_deck[i, j] = tmp
# Place 2 cards in Live Zone
batch_live[i, 0] = batch_deck[i, 0]
batch_live[i, 1] = batch_deck[i, 1]
batch_deck[i, 0] = 0
batch_deck[i, 1] = 0
opp_live[i, 0] = opp_deck[i, 0]
opp_live[i, 1] = opp_deck[i, 1]
opp_deck[i, 0] = 0
opp_deck[i, 1] = 0
# Draw hand (6 cards)
for j in range(60):
batch_hand[i, j] = 0
opp_hand[i, j] = 0
drawn = 0
for k in range(2, 60):
if batch_deck[i, k] > 0 and drawn < 6:
batch_hand[i, drawn] = batch_deck[i, k]
batch_deck[i, k] = 0
drawn += 1
drawn_o = 0
for k in range(2, 60):
if opp_deck[i, k] > 0 and drawn_o < 6:
opp_hand[i, drawn_o] = opp_deck[i, k]
opp_deck[i, k] = 0
drawn_o += 1
# Set initial global context
batch_global_ctx[i, HD] = 6
batch_global_ctx[i, DK] = 52
batch_global_ctx[i, EN] = 3
batch_global_ctx[i, PH] = 4 # Start in Main phase (simplified)
batch_global_ctx[i, 54] = 1 # Turn 1
opp_global_ctx[i, HD] = 6
opp_global_ctx[i, DK] = 52
opp_global_ctx[i, EN] = 3
opp_global_ctx[i, PH] = 4
opp_global_ctx[i, 54] = 1
# Start order
if force_start_order == -1:
r = xoroshiro128p_uniform_float32(rng_states, i)
is_second = 1 if r > 0.5 else 0
else:
is_second = force_start_order
batch_global_ctx[i, 10] = is_second
@cuda.jit
def step_kernel(
num_envs,
actions,
batch_hand,
batch_deck,
batch_stage,
batch_energy_vec,
batch_energy_count,
batch_continuous_vec,
batch_continuous_ptr,
batch_tapped,
batch_live,
batch_scores,
batch_flat_ctx,
batch_global_ctx,
opp_hand,
opp_deck,
opp_stage,
opp_energy_vec,
opp_energy_count,
opp_tapped,
opp_live,
opp_scores,
opp_global_ctx,
card_stats,
bytecode_map,
bytecode_index,
obs_buffer,
rewards,
dones,
prev_scores,
prev_opp_scores,
prev_phases,
terminal_obs_buffer,
batch_trash,
opp_trash,
batch_opp_history,
term_scores_agent,
term_scores_opp,
ability_member_ids,
ability_live_ids,
rng_states,
game_config,
opp_mode,
force_start_order,
):
"""
Main integrated step kernel.
Processes one environment per thread.
"""
i = cuda.grid(1)
if i >= num_envs:
return
# Config
CFG_TURN_LIMIT = int(game_config[0])
CFG_STEP_LIMIT = int(game_config[1])
CFG_REWARD_WIN = game_config[2]
CFG_REWARD_LOSE = game_config[3]
CFG_REWARD_SCALE = game_config[4]
CFG_REWARD_TURN_PENALTY = game_config[5]
act_id = actions[i]
ph = int(batch_global_ctx[i, PH])
# Sync score to context
batch_global_ctx[i, SC] = batch_scores[i]
# Increment step counter
batch_global_ctx[i, 58] += 1
# Get continuous pointer slice
cont_ptr_arr = batch_continuous_ptr[i : i + 1]
score_arr = batch_scores[i : i + 1]
# Execute action
bonus = step_player_device(
act_id,
0,
rng_states,
i,
batch_hand[i],
batch_deck[i],
batch_stage[i],
batch_energy_vec[i],
batch_energy_count[i],
batch_tapped[i],
batch_live[i],
score_arr,
batch_global_ctx[i],
batch_trash[i],
batch_continuous_vec[i],
cont_ptr_arr,
opp_tapped[i],
card_stats,
bytecode_map,
bytecode_index,
)
batch_scores[i] += bonus
# Handle turn end (Pass in Main Phase)
if act_id == 0 and ph == 4:
# Run opponent turn
run_opponent_turn_device(
rng_states,
i,
opp_hand[i],
opp_deck[i],
opp_stage[i],
opp_energy_vec[i],
opp_energy_count[i],
opp_tapped[i],
opp_live[i],
opp_scores[i : i + 1],
opp_global_ctx[i],
opp_trash[i],
batch_tapped[i],
batch_opp_history,
card_stats,
bytecode_map,
bytecode_index,
)
# Resolve lives for both players
agent_live_score = 0
opp_live_score = 0
for z in range(10):
lid = batch_live[i, z]
if lid > 0:
s = resolve_live_device(
lid,
batch_stage[i],
batch_live[i],
batch_scores[i : i + 1],
batch_global_ctx[i],
batch_deck[i],
batch_hand[i],
batch_trash[i],
card_stats,
batch_continuous_vec[i],
cont_ptr_arr,
)
agent_live_score += s
# Clear used live
if s > 0:
move_to_trash_device(lid, batch_trash[i], batch_global_ctx[i], TR)
batch_live[i, z] = 0
for z in range(10):
lid = opp_live[i, z]
if lid > 0:
s = resolve_live_device(
lid,
opp_stage[i],
opp_live[i],
opp_scores[i : i + 1],
opp_global_ctx[i],
opp_deck[i],
opp_hand[i],
opp_trash[i],
card_stats,
batch_continuous_vec[i],
cont_ptr_arr,
)
opp_live_score += s
if s > 0:
move_to_trash_device(lid, opp_trash[i], opp_global_ctx[i], TR)
opp_live[i, z] = 0
# Scoring comparison
if agent_live_score > 0 and opp_live_score == 0:
batch_scores[i] += 1
elif agent_live_score == 0 and opp_live_score > 0:
opp_scores[i] += 1
elif agent_live_score > 0 and opp_live_score > 0:
if agent_live_score > opp_live_score:
batch_scores[i] += 1
elif opp_live_score > agent_live_score:
opp_scores[i] += 1
else:
# Tie - both score
batch_scores[i] += 1
opp_scores[i] += 1
# Next turn setup
batch_global_ctx[i, 54] += 1
opp_global_ctx[i, 54] += 1
# Untap and energy
for j in range(16):
batch_tapped[i, j] = 0
if j < opp_tapped.shape[1]:
opp_tapped[i, j] = 0
batch_global_ctx[i, EN] = min(batch_global_ctx[i, EN] + 1, 12)
opp_global_ctx[i, EN] = min(opp_global_ctx[i, EN] + 1, 12)
# Draw card
draw_cards_device(1, batch_hand[i], batch_deck[i], batch_trash[i], batch_global_ctx[i])
draw_cards_device(1, opp_hand[i], opp_deck[i], opp_trash[i], opp_global_ctx[i])
# Calculate rewards
current_score = batch_scores[i]
score_diff = float(current_score) - float(prev_scores[i])
opp_score_diff = float(opp_scores[i]) - float(prev_opp_scores[i])
r = (score_diff * CFG_REWARD_SCALE) - (opp_score_diff * CFG_REWARD_SCALE)
r += CFG_REWARD_TURN_PENALTY
win = current_score >= 3
lose = opp_scores[i] >= 3
if win:
r += CFG_REWARD_WIN
if lose:
r += CFG_REWARD_LOSE
rewards[i] = r
# Sync Opp Stats to Agent Context (for Attention features)
batch_global_ctx[i, 4] = opp_global_ctx[i, 3] # HD
batch_global_ctx[i, 9] = opp_global_ctx[i, 6] # DK
batch_global_ctx[i, 7] = opp_global_ctx[i, 2] # TR
# Check done
is_done = win or lose or batch_global_ctx[i, 54] >= CFG_TURN_LIMIT or batch_global_ctx[i, 58] >= CFG_STEP_LIMIT
dones[i] = is_done
if is_done:
term_scores_agent[i] = batch_scores[i]
term_scores_opp[i] = opp_scores[i]
# Note: Auto-reset should be called separately
# Update prev scores
prev_scores[i] = batch_scores[i]
prev_opp_scores[i] = opp_scores[i]
@cuda.jit
def compute_action_masks_kernel(
num_envs,
batch_hand,
batch_stage,
batch_tapped,
batch_global_ctx,
batch_live,
card_stats,
masks, # Output: (N, 2000)
):
"""
Compute legal action masks on GPU.
"""
i = cuda.grid(1)
if i >= num_envs:
return
# Reset all to False
for a in range(2000):
masks[i, a] = False
ph = batch_global_ctx[i, PH]
# Action 0: Pass is always legal in Main Phase
if ph == 4:
masks[i, 0] = True
# Member Play (1-180): HandIdx * 3 + Slot + 1
for h_idx in range(60):
cid = batch_hand[i, h_idx]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
ctype = card_stats[bid, 10]
cost = card_stats[bid, 0]
if ctype == 1: # Member
for slot in range(3):
# Check if slot empty or can upgrade
old_cid = batch_stage[i, slot]
effective_cost = cost
if old_cid >= 0:
old_bid = get_base_id_device(old_cid)
if old_bid < card_stats.shape[0]:
effective_cost = max(0, cost - card_stats[old_bid, 0])
# Check energy
available_energy = 0
for e in range(12):
if batch_tapped[i, 3 + e] == 0:
available_energy += 1
if available_energy >= effective_cost:
action_id = h_idx * 3 + slot + 1
if action_id < 181:
masks[i, action_id] = True
# Activate Ability (200-202)
for slot in range(3):
cid = batch_stage[i, slot]
if cid > 0 and batch_tapped[i, slot] == 0:
masks[i, 200 + slot] = True
# Set Live (400-459)
for h_idx in range(60):
cid = batch_hand[i, h_idx]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
ctype = card_stats[bid, 10]
if ctype == 2: # Live
# Check if there's an empty live zone slot
for lz_idx in range(50):
if batch_live[i, lz_idx] == 0:
if h_idx < 60: # This check is redundant due to outer loop
masks[i, 400 + h_idx] = True
break # Only need one empty slot to make it legal
@cuda.jit
def encode_observations_kernel(
num_envs,
batch_hand,
batch_stage,
batch_energy_count,
batch_tapped,
batch_scores,
opp_scores,
opp_stage,
opp_tapped,
card_stats,
batch_global_ctx,
batch_live,
turn_number,
obs_buffer,
):
"""
Encode observations on GPU (STANDARD mode).
"""
i = cuda.grid(1)
if i >= num_envs:
return
obs_dim = obs_buffer.shape[1]
# Clear observation
for j in range(obs_dim):
obs_buffer[i, j] = 0.0
# Metadata
obs_buffer[i, 0] = float(batch_scores[i]) / 3.0
obs_buffer[i, 1] = float(opp_scores[i]) / 3.0
obs_buffer[i, 2] = float(batch_global_ctx[i, EN]) / 12.0
obs_buffer[i, 3] = float(batch_global_ctx[i, HD]) / 60.0
obs_buffer[i, 4] = float(batch_global_ctx[i, DK]) / 60.0
obs_buffer[i, 5] = float(batch_global_ctx[i, 54]) / 100.0 # Turn
offset = 10
# Stage (3 slots x 20 features)
for slot in range(3):
cid = batch_stage[i, slot]
base = offset + slot * 20
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
obs_buffer[i, base] = 1.0 # Presence
obs_buffer[i, base + 1] = float(cid) / 2000.0
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 # Cost
obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0 # Blades
obs_buffer[i, base + 4] = float(card_stats[bid, 2]) / 10.0 # Hearts
obs_buffer[i, base + 5] = 1.0 if batch_tapped[i, slot] > 0 else 0.0
offset += 60
# Opponent Stage
for slot in range(3):
cid = opp_stage[i, slot]
base = offset + slot * 20
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
obs_buffer[i, base] = 1.0
obs_buffer[i, base + 1] = float(cid) / 2000.0
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0
obs_buffer[i, base + 4] = float(card_stats[bid, 2]) / 10.0
offset += 60
# Hand (up to 20 cards shown)
h_count = 0
for h_idx in range(60):
cid = batch_hand[i, h_idx]
if cid > 0 and h_count < 20:
base = offset + h_count * 20
if base + 10 < obs_dim:
obs_buffer[i, base] = 1.0
obs_buffer[i, base + 1] = float(cid) / 2000.0
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
obs_buffer[i, base + 3] = float(card_stats[bid, 10]) # Type
h_count += 1
offset += 400
# Live zone (up to 10 cards)
l_count = 0
for l_idx in range(50):
cid = batch_live[i, l_idx]
if cid > 0 and l_count < 10:
base = offset + l_count * 10
if base + 5 < obs_dim:
obs_buffer[i, base] = 1.0
obs_buffer[i, base + 1] = float(cid) / 2000.0
l_count += 1
@cuda.jit
def encode_observations_attention_kernel(
num_envs,
batch_hand,
batch_stage,
batch_energy_count,
batch_tapped,
batch_scores,
opp_scores,
opp_stage,
opp_tapped,
card_stats,
batch_global_ctx,
batch_live,
batch_opp_history,
opp_global_ctx, # Added
turn_number,
obs_buffer,
):
"""
Encode observations for Attention Architecture (2240-dim).
"""
i = cuda.grid(1)
if i >= num_envs:
return
# Constants
FEAT = 64
MAX_HAND = 15 # +1 overflow
# Offsets
HAND_START = 0
HAND_OVER_START = HAND_START + (MAX_HAND * FEAT) # 960
STAGE_START = HAND_OVER_START + FEAT # 1024
LIVE_START = STAGE_START + (3 * FEAT) # 1216
LIVE_SUCC_START = LIVE_START + (3 * FEAT) # 1408
OPP_STAGE_START = LIVE_SUCC_START + (3 * FEAT) # 1600
OPP_HIST_START = OPP_STAGE_START + (3 * FEAT) # 1792
GLOBAL_START = OPP_HIST_START + (6 * FEAT) # 2176
# Clear buffer
for k in range(2240):
obs_buffer[i, k] = 0.0
# --- A. HAND (16 slots) ---
hand_count = 0
for j in range(60):
cid = batch_hand[i, j]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
if hand_count < 16:
base = HAND_START + hand_count * FEAT
obs_buffer[i, base + 0] = 1.0 # Presence
obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0 # Type
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0 # Cost
obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0 # Blades
obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New)
obs_buffer[i, base + 6] = 0.2 # Location: Hand
# Hearts (8-14)
for k in range(7):
if 12 + k < card_stats.shape[1]:
obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0
# Group (22-28)
raw_group = card_stats[bid, 11]
obs_buffer[i, base + 22 + (raw_group % 7)] = 1.0
# Context
obs_buffer[i, base + 58] = float(hand_count) / 10.0
obs_buffer[i, base + 59] = 1.0 # Mine
hand_count += 1
# --- B. MY STAGE (3 slots) ---
for slot in range(3):
cid = batch_stage[i, slot]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
base = STAGE_START + slot * FEAT
obs_buffer[i, base + 0] = 1.0
obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0
obs_buffer[i, base + 4] = 1.0 if batch_tapped[i, slot] > 0 else 0.0
obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New)
obs_buffer[i, base + 6] = 0.4 # Location: Stage
for k in range(7):
if 12 + k < card_stats.shape[1]:
obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0
raw_group = card_stats[bid, 11]
obs_buffer[i, base + 22 + (raw_group % 7)] = 1.0
obs_buffer[i, base + 58] = float(slot) / 10.0
obs_buffer[i, base + 59] = 1.0
# --- C. LIVE ZONE (6 slots) ---
live_count = 0
for j in range(50):
cid = batch_live[i, j]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0] and live_count < 6:
base = LIVE_START + live_count * FEAT
obs_buffer[i, base + 0] = 1.0
obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New)
obs_buffer[i, base + 6] = 0.6 # Location: Live
for k in range(7):
if 12 + k < card_stats.shape[1]:
obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0
obs_buffer[i, base + 58] = float(live_count) / 10.0
obs_buffer[i, base + 59] = 1.0
live_count += 1
# --- D. OPP STAGE (3 slots) ---
for slot in range(3):
cid = opp_stage[i, slot]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
base = OPP_STAGE_START + slot * FEAT
obs_buffer[i, base + 0] = 1.0
obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0
obs_buffer[i, base + 4] = 1.0 if opp_tapped[i, slot] > 0 else 0.0
obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New)
obs_buffer[i, base + 6] = 0.8 # Location: Opp Stage
for k in range(7):
if 12 + k < card_stats.shape[1]:
obs_buffer[i, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0
obs_buffer[i, base + 58] = float(slot) / 10.0
obs_buffer[i, base + 59] = -1.0
# --- E. OPP HISTORY (6 slots) ---
for h in range(6):
cid = batch_opp_history[i, h]
if cid > 0:
bid = get_base_id_device(cid)
if bid < card_stats.shape[0]:
base = OPP_HIST_START + h * FEAT
obs_buffer[i, base + 0] = 1.0
obs_buffer[i, base + 1] = float(card_stats[bid, 10]) / 2.0
obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
obs_buffer[i, base + 5] = float(cid) / 2000.0 # Card ID (New)
obs_buffer[i, base + 6] = 1.0 # Location: History
obs_buffer[i, base + 58] = float(h) / 10.0
obs_buffer[i, base + 59] = -1.0
# --- F. GLOBAL SCALARS ---
obs_buffer[i, GLOBAL_START + 0] = float(batch_scores[i]) / 10.0
obs_buffer[i, GLOBAL_START + 1] = float(opp_scores[i]) / 10.0
obs_buffer[i, GLOBAL_START + 2] = float(batch_global_ctx[i, 54]) / 20.0 # Turn from Context
obs_buffer[i, GLOBAL_START + 3] = float(batch_global_ctx[i, 8]) / 10.0
obs_buffer[i, GLOBAL_START + 4] = float(batch_global_ctx[i, 5]) / 10.0
obs_buffer[i, GLOBAL_START + 5] = float(batch_global_ctx[i, 6]) / 40.0
obs_buffer[i, GLOBAL_START + 6] = float(hand_count) / 15.0
# Opponent Resources (New)
obs_buffer[i, GLOBAL_START + 7] = float(opp_global_ctx[i, 5]) / 10.0 # Opp Energy
obs_buffer[i, GLOBAL_START + 8] = float(batch_global_ctx[i, 4]) / 10.0 # Opp Hand (from ctx[4])
obs_buffer[i, GLOBAL_START + 9] = float(batch_global_ctx[i, 9]) / 40.0 # Opp Deck (from ctx[9])
obs_buffer[i, GLOBAL_START + 10] = float(batch_global_ctx[i, 7]) / 10.0 # Opp Trash (from ctx[7])