# LovecaSim / ai/research/fast_logic_gpu.py
# Uploaded by trioskosmos with huggingface_hub (commit 598726b, verified).
from engine.game.fast_logic import (
C_CLR,
C_CMP,
C_CTR,
C_ENR,
C_GRP,
C_HND,
C_LLD,
C_OPH,
C_STG,
C_TR1,
DK,
EN,
HD,
O_ADD_H,
O_BLADES,
O_BOOST,
O_BUFF,
O_CHARGE,
O_CHOOSE,
O_DRAW,
O_HEARTS,
O_JUMP,
O_JUMP_F,
O_RECOV_L,
O_RECOV_M,
O_RETURN,
O_TAP_O,
OS,
OT,
SC,
TR,
)
try:
    from numba import cuda
    from numba.cuda.random import xoroshiro128p_uniform_float32

    HAS_CUDA = True
except ImportError:
    # CPU fallback so the module can still be imported (and its functions
    # called as plain Python) when numba is not installed.
    HAS_CUDA = False

    class MockCuda:
        """Minimal stand-in for ``numba.cuda`` covering what this module uses."""

        def jit(self, *args, **kwargs):
            """Mimic ``cuda.jit`` as a no-op decorator.

            Handles both the bare form (``@cuda.jit``), where the decorated
            function is passed as the first positional argument, and the
            parameterised form (``@cuda.jit(device=True)``,
            ``@cuda.jit("signature")``), which must return a decorator.
            """
            if args and callable(args[0]):
                # Bare ``@cuda.jit``: hand the function back unchanged. Returning
                # an identity lambda here would replace the decorated kernel.
                return args[0]
            return lambda x: x

        def grid(self, x):
            """Always report thread index 0, so a direct call handles index 0 only."""
            return 0

    cuda = MockCuda()

    def xoroshiro128p_uniform_float32(rng, idx):
        """Deterministic stand-in for numba's RNG sampler: always 0.5."""
        return 0.5
@cuda.jit(device=True)
def _draw_one_device(p_hand, p_deck):
    """Move the first deck card (> 0) into the first empty hand slot (== 0).

    Returns True when a card moved. If the deck holds no card, or the first
    card is found but the hand has no free slot, both arrays are left
    untouched and False is returned.
    """
    for d_idx in range(p_deck.shape[0]):
        card = p_deck[d_idx]
        if card > 0:
            for h_idx in range(p_hand.shape[0]):
                if p_hand[h_idx] == 0:
                    p_hand[h_idx] = card
                    p_deck[d_idx] = 0
                    return True
            # Top card exists but the hand is full: it stays in the deck.
            return False
    return False


@cuda.jit(device=True)
def _discard_top_device(p_deck):
    """Zero out the first non-empty (> 0) deck slot; no-op on an empty deck."""
    for d_idx in range(p_deck.shape[0]):
        if p_deck[d_idx] > 0:
            p_deck[d_idx] = 0
            return


@cuda.jit(device=True)
def resolve_bytecode_device(
    bytecode,
    flat_ctx,
    global_ctx,
    player_id,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_cont_vec,
    p_cont_ptr,
    p_tapped,
    p_live,
    opp_tapped,
):
    """GPU device function that interprets one ability's bytecode.

    Intended as the CUDA counterpart of
    engine/game/fast_logic.py:resolve_bytecode.

    Each row of ``bytecode`` is read as ``(op, v, a, s)``:

    * ``op == 0`` is a no-op and ``O_RETURN`` stops execution.
    * ``O_JUMP`` jumps by ``v``. ``O_JUMP_F`` jumps by ``v`` only when the
      current condition is False. A jump target outside ``[0, len)`` ends
      execution.
    * ``op >= 200`` are condition opcodes. They set the condition flag
      consulted by the following effect opcodes.
    * Any other op is an effect. It only runs while the condition flag is True.

    At most 500 instructions are executed, as a guard against jump loops.

    Arrays are mutated in place. In ``p_hand`` a slot equal to 0 counts as
    empty; in ``p_deck`` a slot greater than 0 counts as a card.
    ``player_id``, ``p_energy_vec`` and ``p_energy_count`` are accepted for
    signature parity but are not read.

    Args:
        p_cont_ptr: Index of the next free row in ``p_cont_vec``, passed by
            value. The updated index is returned.

    Returns:
        Tuple ``(new_cont_ptr, 0, bonus)``. The middle status element is
        always 0. ``bonus`` is the sum of ``v`` over executed ``O_BOOST``
        ops.
    """
    ip = 0
    cptr = p_cont_ptr
    bonus = 0
    cond = True
    blen = bytecode.shape[0]
    # Capacity of the continuous-effect buffer, taken from the array itself
    # rather than hard-coded, so writes can never run past its end.
    cont_cap = p_cont_vec.shape[0]

    # SAFETY: infinite-loop protection for malformed jump sequences.
    safety_counter = 0
    while ip < blen and safety_counter < 500:
        safety_counter += 1
        op = bytecode[ip, 0]
        v = bytecode[ip, 1]
        a = bytecode[ip, 2]
        s = bytecode[ip, 3]

        if op == 0:
            ip += 1
            continue
        if op == O_RETURN:
            break

        # Relative jumps; an out-of-range target terminates the program.
        if op == O_JUMP:
            new_ip = ip + v
            if 0 <= new_ip < blen:
                ip = new_ip
            else:
                ip = blen
            continue
        if op == O_JUMP_F:
            if not cond:
                new_ip = ip + v
                if 0 <= new_ip < blen:
                    ip = new_ip
                else:
                    ip = blen
                continue
            ip += 1
            continue

        if op >= 200:
            # ---- Condition opcodes: update `cond` only. ----
            if op == C_TR1:
                cond = global_ctx[TR] == 1
            elif op == C_STG:
                # Stage slots equal to -1 are empty.
                ct = 0
                for st_idx in range(3):
                    if p_stage[st_idx] != -1:
                        ct += 1
                cond = ct >= v
            elif op == C_HND:
                cond = global_ctx[HD] >= v
            elif op == C_LLD:
                cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_CLR:
                # Colour counters live at global_ctx[10..15].
                if 0 <= a <= 5:
                    cond = global_ctx[10 + a] > 0
                else:
                    cond = False
            elif op == C_GRP:
                # Group counters live at global_ctx[30..34].
                if 0 <= a <= 4:
                    cond = global_ctx[30 + a] >= v
                else:
                    cond = False
            elif op == C_ENR:
                cond = global_ctx[EN] >= v
            elif op == C_CTR:
                cond = flat_ctx[7] == 1  # SZ=7 (Hand=1)
            elif op == C_CMP:
                if v > 0:
                    cond = global_ctx[SC] >= v
                else:
                    cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_OPH:
                ct = global_ctx[OT]
                if v > 0:
                    cond = ct >= v
                else:
                    cond = ct > 0
            else:
                cond = True
            ip += 1
        else:
            # ---- Effect opcodes: run only while the condition holds. ----
            if cond:
                if op == O_DRAW or op == O_CHOOSE or op == O_ADD_H:
                    # O_CHOOSE ("look v, add 1") is simplified to drawing one
                    # card; O_DRAW / O_ADD_H move v cards from deck to hand.
                    draw_amt = v
                    if op == O_CHOOSE:
                        draw_amt = 1
                    if global_ctx[DK] >= draw_amt:
                        global_ctx[DK] -= draw_amt
                        global_ctx[HD] += draw_amt
                        for _ in range(draw_amt):
                            _draw_one_device(p_hand, p_deck)
                    else:
                        # Not enough counted cards: draw whatever remains.
                        t = global_ctx[DK]
                        if t > 0:
                            for _ in range(t):
                                _draw_one_device(p_hand, p_deck)
                            global_ctx[DK] = 0
                            global_ctx[HD] += t
                elif op == O_CHARGE:
                    # Only the EN counter is credited. The charged cards are
                    # simply removed from the deck; no energy slot is targeted
                    # (POC simplification).
                    if global_ctx[DK] >= v:
                        global_ctx[DK] -= v
                        global_ctx[EN] += v
                        for _ in range(v):
                            _discard_top_device(p_deck)
                    else:
                        t = global_ctx[DK]
                        global_ctx[DK] = 0
                        global_ctx[EN] += t
                        for _ in range(t):
                            _discard_top_device(p_deck)
                elif op == O_BLADES:
                    if s >= 0 and cptr < cont_cap:
                        p_cont_vec[cptr, 0] = 1
                        p_cont_vec[cptr, 1] = v
                        p_cont_vec[cptr, 2] = 4
                        p_cont_vec[cptr, 3] = s
                        p_cont_vec[cptr, 9] = 1
                        cptr += 1
                elif op == O_HEARTS:
                    if cptr < cont_cap:
                        p_cont_vec[cptr, 0] = 2
                        p_cont_vec[cptr, 1] = v
                        p_cont_vec[cptr, 5] = a
                        p_cont_vec[cptr, 9] = 1
                        cptr += 1
                    # Immediate scoring for vectorised RL; uses the named score
                    # slot, like every other score access in this file.
                    global_ctx[SC] += v
                elif op == O_RECOV_L:
                    if 0 <= s < p_live.shape[0]:
                        p_live[s] = 0
                elif op == O_RECOV_M:
                    if 0 <= s < 3:
                        p_tapped[s] = 0
                elif op == O_TAP_O:
                    if 0 <= s < 3:
                        opp_tapped[s] = 1
                elif op == O_BUFF:
                    if cptr < cont_cap:
                        p_cont_vec[cptr, 0] = 8
                        p_cont_vec[cptr, 1] = v
                        p_cont_vec[cptr, 2] = s
                        p_cont_vec[cptr, 9] = 1
                        cptr += 1
                elif op == O_BOOST:
                    bonus += v
            ip += 1
    return cptr, 0, bonus
@cuda.jit
def step_kernel(
    rng_states,
    batch_stage,  # (N, 3)
    batch_energy_vec,  # (N, 3, 32)
    batch_energy_count,  # (N, 3)
    batch_continuous_vec,  # (N, 32, 10)
    batch_continuous_ptr,  # (N,)
    batch_tapped,  # (N, 3)
    batch_live,  # (N, 50)
    batch_opp_tapped,  # (N, 3)
    batch_scores,  # (N,)
    batch_flat_ctx,  # (N, 64)
    batch_global_ctx,  # (N, 128)
    batch_hand,  # (N, 60)
    batch_deck,  # (N, 60)
    bytecode_map,  # (MapSize, 64, 4)
    bytecode_index,  # (MaxCards, 4)
    actions,  # (N,)
):
    """Main CUDA kernel: advance every environment in the batch by one step.

    Thread ``i`` (``cuda.grid(1)``) handles environment ``i``. Threads with
    ``i >= batch_global_ctx.shape[0]`` return immediately. All batch arrays
    are updated in place.

    When ``actions[i] > 0``, the action value is used directly as the card
    id:

    1. If ``bytecode_index[card_id, 0]`` is a valid map index, that ability's
       bytecode is resolved by ``resolve_bytecode_device`` with source zone =
       hand.
    2. The first copy of the card is removed from hand. If one was found and
       ``card_id < 900``, it is placed in the first empty stage slot (-1).
    3. If the hand then holds fewer than 5 cards, exactly one card is drawn.

    Actions ``<= 0`` only sync the score into ``batch_global_ctx``.
    """
    i = cuda.grid(1)
    if i >= batch_global_ctx.shape[0]:
        return

    # Sync score so score-based bytecode conditions see the current value.
    batch_global_ctx[i, SC] = batch_scores[i]

    act_id = actions[i]
    hand_w = batch_hand.shape[1]
    deck_w = batch_deck.shape[1]

    # 1. Apply action
    if act_id > 0:
        card_id = act_id

        if card_id < bytecode_index.shape[0]:
            # Only ability slot 0 of the card is resolved.
            map_idx = bytecode_index[card_id, 0]
            if map_idx >= 0:
                code_seq = bytecode_map[map_idx]
                # Source zone = Hand (1); flat_ctx[7] is the SZ slot read by C_CTR.
                batch_flat_ctx[i, 7] = 1

                # The continuous-effect pointer is passed by value; the updated
                # pointer comes back as the first tuple element.
                nc, _status, bn = resolve_bytecode_device(
                    code_seq,
                    batch_flat_ctx[i],
                    batch_global_ctx[i],
                    0,  # Player ID
                    batch_hand[i],
                    batch_deck[i],
                    batch_stage[i],
                    batch_energy_vec[i],
                    batch_energy_count[i],
                    batch_continuous_vec[i],
                    batch_continuous_ptr[i],
                    batch_tapped[i],
                    batch_live[i],
                    batch_opp_tapped[i],
                )

                batch_continuous_ptr[i] = nc
                # The resolver adds O_HEARTS scores into global_ctx directly;
                # bn is the O_BOOST bonus on top of that.
                # NOTE(review): the bonus is folded into batch_scores, so the
                # score sync at the top of the next step makes it permanent -
                # confirm that is intended.
                batch_scores[i] = batch_global_ctx[i, SC] + bn

                # Reset SZ
                batch_flat_ctx[i, 7] = 0

        # Remove the first copy of the played card from hand.
        found = False
        for h_idx in range(hand_w):
            if batch_hand[i, h_idx] == card_id:
                batch_hand[i, h_idx] = 0
                batch_global_ctx[i, HD] -= 1
                found = True
                break

        # Place on stage. Assumes ids below 900 are member cards - TODO confirm.
        if found and card_id < 900:
            for s_idx in range(3):
                if batch_stage[i, s_idx] == -1:
                    batch_stage[i, s_idx] = card_id
                    break

        # Refill: if the hand holds fewer than 5 cards, draw ONE card (the
        # first non-empty deck slot) into the first empty hand slot.
        # NOTE(review): confirm whether refill should also run for pass
        # actions (act_id <= 0).
        h_cnt = 0
        for h_idx in range(hand_w):
            if batch_hand[i, h_idx] > 0:
                h_cnt += 1
        if h_cnt < 5:
            for d_idx in range(deck_w):
                top_card = batch_deck[i, d_idx]
                if top_card > 0:
                    for h_idx in range(hand_w):
                        if batch_hand[i, h_idx] == 0:
                            batch_hand[i, h_idx] = top_card
                            batch_deck[i, d_idx] = 0
                            batch_global_ctx[i, HD] += 1
                            batch_global_ctx[i, DK] -= 1
                            break
                    # Only the top card is considered, even if the hand is full.
                    break

    # 2. Opponent (random) simulation, drawing from the XOROSHIRO RNG.
    if rng_states is not None:
        r = xoroshiro128p_uniform_float32(rng_states, i)
        if r > 0.8:
            # TODO: randomly tap an agent member; currently a no-op.
            pass