# LovecaSim / engine/game/fast_logic_backup.py
# Uploaded by trioskosmos ("Upload folder using huggingface_hub"), commit bb3fbf9 (verified).
import numpy as np
from numba import njit, prange
# =============================================================================
# HYPER-OPTIMIZED VM CORE (Production Version)
# =============================================================================
# Every index / opcode below is a plain module-level int; numba freezes global
# ints as compile-time constants inside the @njit functions of this module.

# --- flat_ctx slots (ContextIndex mappings, raw ints) ------------------------
SZ = 7   # source zone index (C_CTR tests flat_ctx[SZ] == 1)
TS = 12  # target slot (also substituted for an instruction slot operand of 10)
TI = 13  # target player id
CH = 15  # choice index (read by SELECT_MODE / LOOK_AND_CHOOSE)
CV = 20  # value
AT = 22  # attribute

# Long-form aliases, kept for backward compatibility.
CTX_VALUE, CTX_ATTR = CV, AT
CTX_TARGET_SLOT, CTX_TARGET_PLAYER_ID = TS, TI
CTX_SOURCE_ZONE_IDX, CTX_CHOICE_INDEX = SZ, CH

# --- global_ctx slots (GlobalContext mappings) -------------------------------
# SC/OS are compared by C_LLD and C_CMP; HD, DK and EN are the hand, deck and
# energy counters adjusted by DRAW / CHARGE; TR is tested by C_TR1 and OT by
# C_OPH.
SC, OS, TR, HD, DI, EN, DK, OT, PH = range(9)

# --- Control-flow opcodes ----------------------------------------------------
O_RETURN = 1  # stop execution
O_JUMP = 2    # relative jump by v
O_JUMP_F = 3  # relative jump by v when the current condition flag is false

# --- Effect opcodes ----------------------------------------------------------
O_DRAW = 10
O_BLADES = 11
O_HEARTS = 12
O_REDUCE_COST = 13
O_RECOV_L = 15
O_BOOST = 16
O_RECOV_M = 17
O_BUFF = 18
O_MOVE_MEMBER = 20
O_SWAP_CARDS = 21
O_SEARCH_DECK = 22
O_CHARGE = 23
O_ORDER_DECK = 28
O_SELECT_MODE = 30
O_TAP_O = 32
O_PLACE_UNDER = 33
O_LOOK_AND_CHOOSE = 41
O_ACTIVATE_MEMBER = 43
O_ADD_H = 44
O_REPLACE_EFFECT = 46
O_TRIGGER_REMOTE = 47
O_REDUCE_HEART_REQ = 48

# --- Condition opcodes (every opcode >= 200 is a condition) ------------------
C_TR1 = 200
C_CLR = 202
C_STG = 203
C_HND = 204
C_CTR = 206
C_LLD = 207
C_GRP = 208
C_OPH = 210
C_ENR = 213
C_CMP = 220
@njit(cache=True)
def resolve_bytecode(
    bytecode,
    flat_ctx,
    global_ctx,
    player_id,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_cont_vec,
    out_cptr,  # Modified: Pass by ref array (size 1)
    p_tapped,
    p_live,
    opp_tapped,
    b_map,
    b_idx,
    out_bonus,  # Modified: Pass by ref array (size 1)
):
    """Execute one ability's bytecode, mutating the given player state in place.

    Each instruction is a row ``(op, v, a, s)`` of ``bytecode`` (columns 0-3).
    Opcodes >= 200 are conditions and only update an internal ``cond`` flag.
    Every other opcode is an effect and runs only while ``cond`` is true.
    Execution stops at ``O_RETURN``, at the end of the table, or after 500
    loop iterations (guard against runaway jumps).

    Args:
        bytecode: instruction table with one ``(op, v, a, s)`` row per
            instruction. A slot operand ``s == 10`` is replaced by
            ``flat_ctx[TS]``.
        flat_ctx: effect context. ``flat_ctx[TS]`` is the chosen target,
            ``flat_ctx[CH]`` is the chosen option, and ``flat_ctx[SZ]`` is
            tested by ``C_CTR``.
        global_ctx: counter vector indexed by SC/OS/TR/HD/EN/DK/OT. Entries
            ``10 + a`` and ``30 + a`` are read by C_CLR and C_GRP.
        player_id: not read here (kept for signature compatibility).
        p_hand, p_deck: card-id arrays where 0 marks an empty entry. Only the
            first 60 entries are scanned.
        p_stage: card id per stage slot (3 slots). -1 marks an empty slot.
        p_energy_vec, p_energy_count: cards stored under each stage slot
            (32 entries per slot) and their counts.
        p_cont_vec: continuous-effect table. Rows are appended at ``cptr``
            while ``cptr < 32``.
        out_cptr: one-element array. ``out_cptr[0]`` is read as the next free
            ``p_cont_vec`` row and written back on exit.
        p_tapped, opp_tapped: per-slot tapped flags (own side / opponent).
        p_live: not read here.
        b_map, b_idx: ability table. ``b_map[b_idx[card_id, 0]]`` is the
            bytecode run by ``O_TRIGGER_REMOTE``.
        out_bonus: one-element array. ``O_BOOST`` adds ``v`` to
            ``out_bonus[0]``.
    """
    ip = 0
    cptr = out_cptr[0]  # next free row of p_cont_vec
    cond = True
    blen = bytecode.shape[0]
    # SAFETY: Infinite loop protection (bounded number of loop iterations).
    safety_counter = 0
    while ip < blen and safety_counter < 500:
        safety_counter += 1
        op = bytecode[ip, 0]
        v = bytecode[ip, 1]
        a = bytecode[ip, 2]
        s = bytecode[ip, 3]

        if op == 0:  # NOP
            ip += 1
            continue
        if op == O_RETURN:
            break

        # Dynamic Target Handling (MEMBER_SELECT): slot operand 10 means
        # "use the slot chosen in flat_ctx".
        if s == 10:
            s = int(flat_ctx[TS])

        # ---- Control flow ---------------------------------------------------
        if op == O_JUMP:
            new_ip = ip + v
            if 0 <= new_ip < blen:
                ip = new_ip
            else:
                ip = blen
            continue
        if op == O_JUMP_F:
            if not cond:
                new_ip = ip + v
                if 0 <= new_ip < blen:
                    ip = new_ip
                else:
                    ip = blen
            else:
                ip += 1
            continue
        if op == O_SELECT_MODE:
            # v = number of options. The next v rows form a jump table whose
            # column 1 is an offset relative to that row.
            choice = int(flat_ctx[CH])
            if 0 <= choice < v:
                jump_ip = ip + 1 + choice
                if jump_ip < blen:
                    new_ip = jump_ip + bytecode[jump_ip, 1]
                    if 0 <= new_ip < blen:
                        ip = new_ip
                        continue
            # Invalid choice or jump target: skip over the jump table.
            ip += v + 1
            continue

        if op >= 200:
            # ---- Conditions: only update `cond` -----------------------------
            if op == C_TR1:
                cond = global_ctx[TR] == 1
            elif op == C_STG:
                ct = 0
                for i in range(3):
                    if p_stage[i] != -1:
                        ct += 1
                cond = ct >= v
            elif op == C_HND:
                cond = global_ctx[HD] >= v
            elif op == C_LLD:
                cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_CLR:
                if 0 <= a <= 5:
                    cond = global_ctx[10 + a] > 0
                else:
                    cond = False
            elif op == C_GRP:
                if 0 <= a <= 4:
                    cond = global_ctx[30 + a] >= v
                else:
                    cond = False
            elif op == C_ENR:
                cond = global_ctx[EN] >= v
            elif op == C_CTR:
                cond = flat_ctx[SZ] == 1
            elif op == C_CMP:
                if v > 0:
                    cond = global_ctx[SC] >= v
                else:
                    cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_OPH:
                ct = global_ctx[OT]
                if v > 0:
                    cond = ct >= v
                else:
                    cond = ct > 0
            else:
                # Unknown condition codes are treated as satisfied.
                cond = True
        elif cond:
            # ---- Effects (skipped while the last condition was false) -------
            if op == O_DRAW:
                # Counter-only draw, clamped to the cards left in the deck.
                if global_ctx[DK] >= v:
                    global_ctx[DK] -= v
                    global_ctx[HD] += v
                else:
                    t = global_ctx[DK]
                    global_ctx[DK] = 0
                    global_ctx[HD] += t
            elif op == O_CHARGE:
                # Counter-only deck -> energy move, clamped like O_DRAW.
                if global_ctx[DK] >= v:
                    global_ctx[DK] -= v
                    global_ctx[EN] += v
                else:
                    t = global_ctx[DK]
                    global_ctx[DK] = 0
                    global_ctx[EN] += t
            elif op == O_BLADES:
                if s >= 0 and cptr < 32:
                    p_cont_vec[cptr, 0] = 1
                    p_cont_vec[cptr, 1] = v
                    p_cont_vec[cptr, 2] = 4
                    p_cont_vec[cptr, 3] = s
                    p_cont_vec[cptr, 9] = 1
                    cptr += 1
            elif op == O_HEARTS:
                if cptr < 32:
                    p_cont_vec[cptr, 0] = 2
                    p_cont_vec[cptr, 1] = v
                    p_cont_vec[cptr, 5] = a
                    p_cont_vec[cptr, 9] = 1
                    cptr += 1
                # NOTE(review): the score bump is applied even when the
                # effect table is full; confirm this is intended.
                global_ctx[SC] += v
            elif op == O_REDUCE_COST:
                if cptr < 32:
                    p_cont_vec[cptr, 0] = 3
                    p_cont_vec[cptr, 1] = v
                    p_cont_vec[cptr, 2] = s
                    p_cont_vec[cptr, 9] = 1
                    cptr += 1
            elif op == O_REDUCE_HEART_REQ:
                if cptr < 32:
                    p_cont_vec[cptr, 0] = 48
                    p_cont_vec[cptr, 1] = v
                    p_cont_vec[cptr, 2] = s
                    p_cont_vec[cptr, 9] = 1
                    cptr += 1
            elif op == O_REPLACE_EFFECT:
                if cptr < 32:
                    p_cont_vec[cptr, 0] = 46
                    p_cont_vec[cptr, 1] = v
                    p_cont_vec[cptr, 2] = s
                    p_cont_vec[cptr, 9] = 1
                    cptr += 1
            elif op == O_RECOV_L or op == O_RECOV_M:
                pass  # Not implemented in this VM.
            elif op == O_ACTIVATE_MEMBER:
                if 0 <= s < 3:
                    p_tapped[s] = 0
            elif op == O_SWAP_CARDS:
                # Discard up to v cards from the front of the hand. The count
                # is checked before each removal so that v <= 0 moves nothing.
                removed = 0
                for h_idx in range(60):
                    if removed >= v:
                        break
                    if p_hand[h_idx] > 0:
                        p_hand[h_idx] = 0
                        global_ctx[HD] -= 1
                        removed += 1
                # Then move up to v cards from the top of the deck into hand.
                drawn = 0
                for d_idx in range(60):
                    if drawn >= v:
                        break
                    if p_deck[d_idx] > 0:
                        card_id = p_deck[d_idx]
                        p_deck[d_idx] = 0
                        for h_idx in range(60):
                            if p_hand[h_idx] == 0:
                                p_hand[h_idx] = card_id
                                break
                        global_ctx[DK] -= 1
                        global_ctx[HD] += 1
                        drawn += 1
                # If the deck runs out we simply stop (counters are per card).
            elif op == O_PLACE_UNDER:
                # Move up to v hand cards, scanning from the back of the hand,
                # under stage slot s.
                # NOTE(review): when s is not a valid slot, or the slot's 32
                # entries are full, the card still leaves the hand without
                # being stored anywhere. Confirm this is intended.
                placed = 0
                for h_idx in range(59, -1, -1):
                    if placed >= v:
                        break
                    if p_hand[h_idx] > 0:
                        cid = p_hand[h_idx]
                        p_hand[h_idx] = 0
                        global_ctx[HD] -= 1
                        if 0 <= s < 3:
                            for e_idx in range(32):
                                if p_energy_vec[s, e_idx] == 0:
                                    p_energy_vec[s, e_idx] = cid
                                    p_energy_count[s] += 1
                                    break
                        placed += 1
            elif op == O_MOVE_MEMBER:
                # Swap slot s with flat_ctx[TS]: member, tapped flag and the
                # cards stored under it. (If s came from operand 10 it equals
                # the destination, so nothing moves.)
                dest_slot = int(flat_ctx[TS])
                if 0 <= s < 3 and 0 <= dest_slot < 3 and s != dest_slot:
                    temp_id = p_stage[s]
                    p_stage[s] = p_stage[dest_slot]
                    p_stage[dest_slot] = temp_id
                    temp_tap = p_tapped[s]
                    p_tapped[s] = p_tapped[dest_slot]
                    p_tapped[dest_slot] = temp_tap
                    for e_idx in range(32):
                        temp_e = p_energy_vec[s, e_idx]
                        p_energy_vec[s, e_idx] = p_energy_vec[dest_slot, e_idx]
                        p_energy_vec[dest_slot, e_idx] = temp_e
                    temp_ec = p_energy_count[s]
                    p_energy_count[s] = p_energy_count[dest_slot]
                    p_energy_count[dest_slot] = temp_ec
            elif op == O_TAP_O:
                if 0 <= s < 3:
                    opp_tapped[s] = 1
            elif op == O_BUFF:
                if cptr < 32:
                    p_cont_vec[cptr, 0] = 8
                    p_cont_vec[cptr, 1] = v
                    p_cont_vec[cptr, 2] = s
                    p_cont_vec[cptr, 9] = 1
                    cptr += 1
            elif op == O_BOOST:
                out_bonus[0] += v
            elif op == O_LOOK_AND_CHOOSE:
                # Look at the top v deck cards, put the chosen one into hand,
                # and remove all looked-at cards from the deck. v <= 0 is a
                # no-op: np.full rejects negative sizes, and v == 0 would index
                # past the end of `indices` (numba does not bounds-check).
                if v > 0:
                    choice_idx = int(flat_ctx[CH])
                    if choice_idx < 0 or choice_idx >= v:
                        choice_idx = 0
                    indices = np.full(v, -1, dtype=np.int32)
                    ptr = 0
                    for d_idx in range(60):
                        if p_deck[d_idx] > 0:
                            indices[ptr] = d_idx
                            ptr += 1
                            if ptr >= v:
                                break
                    if ptr > 0:
                        if choice_idx >= ptr:
                            choice_idx = 0
                        chosen_card = p_deck[indices[choice_idx]]
                        for h_idx in range(60):
                            if p_hand[h_idx] == 0:
                                p_hand[h_idx] = chosen_card
                                global_ctx[HD] += 1
                                break
                        for k in range(ptr):
                            p_deck[indices[k]] = 0
                            global_ctx[DK] -= 1
            elif op == O_ORDER_DECK:
                # Reverse the order of the top v deck cards in place. Guarded
                # for the same reason as O_LOOK_AND_CHOOSE.
                if v > 0:
                    indices = np.full(v, -1, dtype=np.int32)
                    vals = np.full(v, 0, dtype=np.int32)
                    ptr = 0
                    for d_idx in range(60):
                        if p_deck[d_idx] > 0:
                            indices[ptr] = d_idx
                            vals[ptr] = p_deck[d_idx]
                            ptr += 1
                            if ptr >= v:
                                break
                    if ptr > 1:
                        for k in range(ptr // 2):
                            temp = vals[k]
                            vals[k] = vals[ptr - 1 - k]
                            vals[ptr - 1 - k] = temp
                        for k in range(ptr):
                            p_deck[indices[k]] = vals[k]
            elif op == O_ADD_H:
                # Counter-only; applies only when the deck has enough cards.
                if global_ctx[DK] >= v:
                    global_ctx[DK] -= v
                    global_ctx[HD] += v
            elif op == O_SEARCH_DECK:
                # Move the deck card at index flat_ctx[TS] into hand, falling
                # back to the top deck card if that index is invalid or empty.
                target_idx = int(flat_ctx[TS])
                src_idx = -1
                if 0 <= target_idx < 60 and p_deck[target_idx] > 0:
                    src_idx = target_idx
                else:
                    for d_idx in range(60):
                        if p_deck[d_idx] > 0:
                            src_idx = d_idx
                            break
                if src_idx >= 0:
                    # Only take the card out of the deck once a free hand slot
                    # is found, so a full hand cannot lose the card or leave
                    # DK out of sync.
                    for h_idx in range(60):
                        if p_hand[h_idx] == 0:
                            p_hand[h_idx] = p_deck[src_idx]
                            p_deck[src_idx] = 0
                            global_ctx[HD] += 1
                            global_ctx[DK] -= 1
                            break
            elif op == O_TRIGGER_REMOTE:
                # Run the ability of the member standing in slot s.
                # NOTE(review): recursion depth is not bounded. An ability that
                # (directly or indirectly) triggers itself recurses without
                # limit.
                if 0 <= s < 3:
                    target_card_id = p_stage[s]
                    if 0 < target_card_id < b_idx.shape[0]:
                        map_idx = b_idx[target_card_id, 0]
                        if map_idx >= 0:
                            # Hand the current pointer to the nested call and
                            # read back whatever it appended.
                            out_cptr[0] = cptr
                            resolve_bytecode(
                                b_map[map_idx],
                                flat_ctx,
                                global_ctx,
                                player_id,
                                p_hand,
                                p_deck,
                                p_stage,
                                p_energy_vec,
                                p_energy_count,
                                p_cont_vec,
                                out_cptr,
                                p_tapped,
                                p_live,
                                opp_tapped,
                                b_map,
                                b_idx,
                                out_bonus,
                            )
                            cptr = out_cptr[0]
        ip += 1
    out_cptr[0] = cptr
@njit
def batch_resolve_bytecode(
    batch_bytecode,
    batch_flat_ctx,
    batch_global_ctx,
    player_id,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_cont_vec,
    p_cont_ptr,
    p_tapped,
    p_live,
    opp_tapped,
    b_map,
    b_idx,
):
    """Run ``resolve_bytecode`` once for every environment in a batch.

    Entry ``i`` of each batched argument is passed as the single-environment
    argument of the same role. ``player_id``, ``b_map`` and ``b_idx`` are
    shared by all environments. ``p_cont_ptr[i]`` is updated in place. Any
    ``O_BOOST`` bonus goes into a scratch array and is discarded.

    NOTE: the decorator does not enable ``parallel=True``, so numba executes
    ``prange`` as an ordinary serial ``range``.
    """
    num_envs = batch_bytecode.shape[0]
    for i in prange(num_envs):
        # A length-1 slice is a view, so resolve_bytecode writes the updated
        # pointer straight back into p_cont_ptr[i].
        cptr_slice = p_cont_ptr[i : i + 1]
        # The batch API does not report the bonus; allocate a per-env scratch.
        dummy_bonus = np.zeros(1, dtype=np.int32)
        resolve_bytecode(
            batch_bytecode[i],
            batch_flat_ctx[i],
            batch_global_ctx[i],
            player_id,
            p_hand[i],
            p_deck[i],
            p_stage[i],
            p_energy_vec[i],
            p_energy_count[i],
            p_cont_vec[i],
            cptr_slice,
            p_tapped[i],
            p_live[i],
            opp_tapped[i],
            b_map,
            b_idx,
            dummy_bonus,
        )
@njit(cache=True)
def copy_state(s_stg, s_ev, s_ec, s_cv, d_stg, d_ev, d_ec, d_cv):
    """Copy stage, energy, energy-count and continuous-effect arrays in place.

    Each ``s_*`` source is written element-wise into the matching ``d_*``
    destination, which must already have a compatible shape. The destination
    objects are reused, not rebound.
    """
    d_stg[:] = s_stg
    d_ev[:] = s_ev
    d_ec[:] = s_ec
    d_cv[:] = s_cv
@njit
def batch_apply_action(
    actions,
    pid,
    p_stg,
    p_ev,
    p_ec,
    p_cv,
    p_cp,
    p_tap,
    p_sb,
    p_lr,
    o_tap,
    f_ctx_batch,
    g_ctx_batch,
    p_h,
    p_d,
    b_map,
    b_idx,
):
    """Apply one action per environment of a batch, mutating state in place.

    For each environment ``i``:
      1. copy the score ``p_sb[i]`` into ``g_ctx_batch[i, SC]``;
      2. if ``actions[i] > 0`` it is used as a card id. When that card has
         bytecode (``b_idx[card_id, 0] >= 0``), the bytecode is resolved with
         ``f_ctx_batch[i, SZ]`` set to 1 for the duration, and any ``O_BOOST``
         bonus is added to ``p_sb[i]``;
      3. the first hand entry equal to the card id is cleared. If one was
         found and the id is below 900, the card goes into the first empty
         (-1) stage slot (presumably ids >= 900 are not members; confirm);
      4. if fewer than 5 cards remain in hand, the first non-empty deck card
         is moved into the first empty hand slot.
    Non-positive actions only perform step 1.

    NOTE: the decorator does not enable ``parallel=True``, so numba executes
    ``prange`` as an ordinary serial ``range``.
    """
    for i in prange(actions.shape[0]):
        # Sync this env's score into global_ctx before stepping.
        g_ctx_batch[i, SC] = p_sb[i]
        act_id = actions[i]
        if act_id > 0:
            card_id = act_id

            # 1) Resolve the played card's ability, if it has one.
            if card_id < b_idx.shape[0]:
                map_idx = b_idx[card_id, 0]
                if map_idx >= 0:
                    code_seq = b_map[map_idx]
                    f_ctx_batch[i, SZ] = 1  # read by C_CTR during resolution
                    cptr_slice = p_cp[i : i + 1]  # view: pointer written back in place
                    delta_bonus = np.zeros(1, dtype=np.int32)
                    resolve_bytecode(
                        code_seq,
                        f_ctx_batch[i],
                        g_ctx_batch[i],
                        pid,
                        p_h[i],
                        p_d[i],
                        p_stg[i],
                        p_ev[i],
                        p_ec[i],
                        p_cv[i],
                        cptr_slice,
                        p_tap[i],
                        p_lr[i],
                        o_tap[i],
                        b_map,
                        b_idx,
                        delta_bonus,
                    )
                    p_sb[i] += delta_bonus[0]
                    f_ctx_batch[i, SZ] = 0

            # 2) Take the card out of the hand.
            found_h = False
            for h_idx in range(60):
                if p_h[i, h_idx] == card_id:
                    p_h[i, h_idx] = 0
                    g_ctx_batch[i, HD] -= 1
                    found_h = True
                    break

            # 3) Put it on the first empty stage slot.
            if found_h and card_id < 900:
                for s_idx in range(3):
                    if p_stg[i, s_idx] == -1:
                        p_stg[i, s_idx] = card_id
                        break

            # 4) Refill one card when the hand is below 5 cards.
            cnt = 0
            for h_idx in range(60):
                if p_h[i, h_idx] > 0:
                    cnt += 1
            if cnt < 5:
                top_card = 0
                deck_idx = -1
                for d_idx in range(60):
                    if p_d[i, d_idx] > 0:
                        top_card = p_d[i, d_idx]
                        deck_idx = d_idx
                        break
                if top_card > 0:
                    for h_idx in range(60):
                        if p_h[i, h_idx] == 0:
                            p_h[i, h_idx] = top_card
                            p_d[i, deck_idx] = 0
                            g_ctx_batch[i, HD] += 1
                            g_ctx_batch[i, DK] -= 1
                            break
@njit(cache=True)
def apply_action(aid, pid, p_stg, p_ev, p_ec, p_cv, p_cp, p_tap, p_sb, p_lr, o_tap, f_ctx, g_ctx, p_h, p_d):
    """Apply a single action for one environment (simulation fast path).

    Only action 1 has an effect. It runs the one-instruction program
    ``(O_BLADES, v=1, a=0, s=0)`` through ``resolve_bytecode`` with an empty
    (all-zero) ability table. Every other action id is a no-op.

    Returns:
        ``(cont_ptr, score)``: the updated continuous-effect pointer, and
        ``p_sb`` plus any ``O_BOOST`` bonus (always 0 for this program).
    """
    # Specialized fast-path for Action 1 (Simulation)
    if aid == 1:
        bc = np.zeros((1, 4), dtype=np.int32)
        bc[0, 0] = O_BLADES
        bc[0, 1] = 1  # v: blade amount
        bc[0, 3] = 0  # s: stage slot 0
        # Placeholder ability table; unused because the program has no
        # O_TRIGGER_REMOTE.
        d_map = np.zeros((1, 1, 4), dtype=np.int32)
        d_idx = np.zeros((1, 4), dtype=np.int32)
        cptr_arr = np.array([p_cp], dtype=np.int32)
        bn_arr = np.zeros(1, dtype=np.int32)
        resolve_bytecode(
            bc, f_ctx, g_ctx, pid, p_h, p_d, p_stg, p_ev, p_ec, p_cv, cptr_arr, p_tap, p_lr, o_tap, d_map, d_idx, bn_arr
        )
        return cptr_arr[0], p_sb + bn_arr[0]
    return p_cp, p_sb