# LovecaSim — ai/environments/vector_env.py
# Uploaded with huggingface_hub by trioskosmos (commit d980970, verified).
import os
from typing import List
import numpy as np
from numba import njit, prange
import ai.research.integrated_step_numba as isn
from engine.game.fast_logic import (
batch_apply_action,
resolve_bytecode,
)
@njit(cache=True)
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # New: Bytecode Maps
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> Index in map
    card_stats: np.ndarray,
    batch_trash: np.ndarray,  # Added
):
    """
    Step N game environments in parallel using JIT logic and Real Card Data.

    Applies one action per environment via batch_apply_action as player 0,
    mutating all batch buffers in place. The returned rewards/dones are
    always zero-filled placeholders (reward/termination are not computed
    here). NOTE: the positional argument order of batch_apply_action is
    part of its Numba signature and must not be changed.
    """
    # Score sync now handled internally by batch_apply_action
    batch_apply_action(
        actions,
        0,  # player_id
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        batch_trash,  # Added
        bytecode_map,
        bytecode_index,
        card_stats,
    )
    # Placeholder outputs: one (reward, done) pair per environment, all zero.
    rewards = np.zeros(actions.shape[0], dtype=np.float32)
    dones = np.zeros(actions.shape[0], dtype=np.bool_)
    return rewards, dones
class VectorGameState:
    """
    Manages a batch of independent GameStates for high-throughput training.

    All per-environment state lives in pre-allocated NumPy buffers so the
    Numba kernels (ai.research.integrated_step_numba) can mutate it in place
    without Python-level allocation.
    """

    def __init__(self, num_envs: int, opp_mode: int = 0, force_start_order: int = -1):
        """Allocate all batched state buffers and load card data.

        Args:
            num_envs: Number of parallel game environments.
            opp_mode: 0=Heuristic, 1=Random, 2=Solitaire (Pass Only).
            force_start_order: -1=Random, 0=P1 starts, 1=P2 starts.
        """
        self.num_envs = num_envs
        # opp_mode: 0=Heuristic, 1=Random, 2=Solitaire (Pass Only)
        self.opp_mode = opp_mode
        self.force_start_order = force_start_order  # -1=Random, 0=P1, 1=P2
        self.turn = 1
        # Batched state buffers - Player 0 (Agent)
        self.batch_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.batch_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)
        self.batch_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_continuous_vec = np.zeros((num_envs, 32, 10), dtype=np.int32)
        self.batch_continuous_ptr = np.zeros(num_envs, dtype=np.int32)
        self.batch_tapped = np.zeros((num_envs, 16), dtype=np.int32)  # Slots 0-2, Energy 3-15
        self.batch_live = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_opp_tapped = np.zeros((num_envs, 16), dtype=np.int32)
        self.batch_scores = np.zeros(num_envs, dtype=np.int32)
        # Batched state buffers - Opponent State (Player 1)
        self.opp_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.opp_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)  # Match Agent Shape
        self.opp_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.opp_tapped = np.zeros((num_envs, 16), dtype=np.int8)
        self.opp_live = np.zeros((num_envs, 50), dtype=np.int32)  # Added Opp Live
        self.opp_scores = np.zeros(num_envs, dtype=np.int32)
        # Trackers consumed by the integrated step to compute delta rewards
        # and detect phase transitions between successive calls.
        self.prev_scores = np.zeros(num_envs, dtype=np.int32)
        self.prev_opp_scores = np.zeros(num_envs, dtype=np.int32)
        self.prev_phases = np.zeros(num_envs, dtype=np.int32)
        self.episode_returns = np.zeros(num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(num_envs, dtype=np.int32)
        # Opponent Finite Deck Buffers
        self.opp_hand = np.zeros((num_envs, 60), dtype=np.int32)
        self.opp_deck = np.zeros((num_envs, 60), dtype=np.int32)
        # Optional pre-recorded scenario snapshots used to seed resets.
        import os
        if os.getenv("USE_SCENARIOS", "0") == "1":
            self._load_scenarios()
        # Scenario Reward Scaling
        self.scenario_reward_scale = float(os.getenv("SCENARIO_REWARD_SCALE", "1.0"))
        if os.getenv("USE_SCENARIOS", "0") == "1" and self.scenario_reward_scale != 1.0:
            print(f" [VectorEnv] Scenario Reward Scale: {self.scenario_reward_scale}")
        # Opponent history buffer (top 50 cards seen from the opponent).
        self.batch_opp_history = np.zeros((num_envs, 50), dtype=np.int32)
        # Pre-allocated context buffers (Extreme speed optimization)
        self.batch_flat_ctx = np.zeros((num_envs, 64), dtype=np.int32)
        self.batch_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)
        self.opp_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)  # Persistent Opponent Context
        self.batch_hand = np.zeros((num_envs, 60), dtype=np.int32)
        self.batch_deck = np.zeros((num_envs, 60), dtype=np.int32)
        self.batch_trash = np.zeros((num_envs, 60), dtype=np.int32)  # Added Trash
        self.opp_trash = np.zeros((num_envs, 60), dtype=np.int32)  # Added Opp Trash
        # Observation buffer sizing depends on the selected encoder mode.
        import os
        self.obs_mode = os.getenv("OBS_MODE", "STANDARD")
        if self.obs_mode == "COMPRESSED":
            self.obs_dim = 512
            self.action_space_dim = 2000
            print(" [VectorEnv] Observation Mode: COMPRESSED (512-dim)")
        elif self.obs_mode == "IMAX":
            self.obs_dim = 8192
            self.action_space_dim = 2000
            print(" [VectorEnv] Observation Mode: IMAX (8192-dim)")
        elif self.obs_mode == "ATTENTION":
            self.obs_dim = 2240
            self.action_space_dim = 512
            print(" [VectorEnv] Observation Mode: ATTENTION (2240-dim)")
        else:
            self.obs_dim = 2304
            self.action_space_dim = 2000
            print(" [VectorEnv] Observation Mode: STANDARD (2304-dim)")
        self.obs_buffer = np.zeros((self.num_envs, self.obs_dim), dtype=np.float32)
        # Terminal Obs Buffer for Auto-Reset
        self.terminal_obs_buffer = np.zeros((self.num_envs, self.obs_dim), dtype=np.float32)
        # Global Turn Counter (length-1 array so Numba kernels can mutate it)
        self.turn_number_ptr = np.zeros(1, dtype=np.int32)
        self.turn_number_ptr[0] = 1
        # Game Config (Turn Limits & Rewards)
        # 0: Turn Limit, 1: Step Limit, 2: Win Reward, 3: Lose Reward, 4: Score Scale, 5: Turn Penalty
        self.game_config = np.zeros(10, dtype=np.float32)
        self.game_config[0] = float(os.getenv("GAME_TURN_LIMIT", "100"))
        self.game_config[1] = float(os.getenv("GAME_STEP_LIMIT", "1000"))
        self.game_config[2] = float(os.getenv("GAME_REWARD_WIN", "100.0"))
        self.game_config[3] = float(os.getenv("GAME_REWARD_LOSE", "-100.0"))
        self.game_config[4] = float(os.getenv("GAME_REWARD_SCORE_SCALE", "50.0"))
        self.game_config[5] = float(os.getenv("GAME_REWARD_TURN_PENALTY", "-0.05"))
        print(
            f" [VectorEnv] Game Config: Turns={int(self.game_config[0])}, Steps={int(self.game_config[1])}, Win={self.game_config[2]}, Lose={self.game_config[3]}"
        )
        # Load Bytecode Map
        self._load_bytecode()
        # Check for Fixed Deck Override
        fixed_deck_path = os.getenv("USE_FIXED_DECK")
        if fixed_deck_path:
            self._load_fixed_deck_pool(fixed_deck_path)
        else:
            self._load_verified_deck_pool()
def _load_bytecode(self):
import json
try:
with open("data/cards_numba.json", "r") as f:
raw_map = json.load(f)
# Convert to numpy array
# Format: key "cardid_abidx" -> List[int]
# storage:
# 1. giant array of bytecodes (N, MaxLen, 4)
# 2. lookup index (CardID, AbIdx) -> Index in giant array
self.max_cards = 2000
self.max_abilities = 8
self.max_len = 128 # Max 128 instructions per ability for future expansion
# Count unique compiled entries
unique_entries = len(raw_map)
# (Index 0 is empty/nop)
self.bytecode_map = np.zeros((unique_entries + 1, self.max_len, 4), dtype=np.int32)
self.bytecode_index = np.full((self.max_cards, self.max_abilities), 0, dtype=np.int32)
idx_counter = 1
for key, bc_list in raw_map.items():
cid, aid = map(int, key.split("_"))
if cid < self.max_cards and aid < self.max_abilities:
# reshape list to (M, 4)
bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
length = min(bc_arr.shape[0], self.max_len)
self.bytecode_map[idx_counter, :length] = bc_arr[:length]
self.bytecode_index[cid, aid] = idx_counter
idx_counter += 1
print(f" [VectorEnv] Loaded {unique_entries} compiled abilities.")
# --- IMAX PRO VISION (Stride 80) ---
# Fixed Geography: No maps, no shifting. Dedicated space per ability.
# 0-19: Stats (Cost, Hearts, Traits, Live Reqs)
# 20-35: Ability 1 (Trig, Cond, Opts, 3 Effs)
# 36-47: Ability 2 (Trig, Cond, 3 Effs)
# 48-59: Ability 3 (Trig, Cond, 3 Effs)
# 60-71: Ability 4 (Trig, Cond, 3 Effs)
# 79: Location Signal (Runtime Only)
self.card_stats = np.zeros((self.max_cards, 80), dtype=np.int32)
try:
import json
import re
with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
db = json.load(f)
# We need to map Card ID (int) -> Stats
# cards_compiled.json is keyed by string integer "0", "1"...
count = 0
# Build character name to ID mapping for Baton Pass
name_to_id = {}
# First pass: collect all character names and their IDs
if "member_db" in db:
for cid_str, card in db["member_db"].items():
cid = int(cid_str)
if cid < self.max_cards:
# Store character name to ID mapping
name = card.get("name", "")
if name:
name_to_id[name] = cid
# Load Members
if "member_db" in db:
for cid_str, card in db["member_db"].items():
cid = int(cid_str)
if cid < self.max_cards:
# 0. Card Type (1=Member)
self.card_stats[cid, 10] = 1
# 1. Cost
self.card_stats[cid, 0] = card.get("cost", 0)
# 2. Blades
self.card_stats[cid, 1] = card.get("blades", 0)
# 3. Hearts (Sum of array elements > 0?)
# Actually just count non-zero hearts in array? Or sum of values?
# Usually 'hearts' is [points, points...]. Let's sum points.
h_arr = card.get("hearts", [])
self.card_stats[cid, 2] = sum(h_arr)
# 4. Store detailed hearts for Members too (indices 12-18)
# [Pn, Rd, Yl, Gr, Bl, Pu, All]
for r_idx in range(min(len(h_arr), 7)):
self.card_stats[cid, 12 + r_idx] = h_arr[r_idx]
# Store Character ID in index 19 for Baton Pass condition
name = card.get("name", "")
if name in name_to_id:
self.card_stats[cid, 19] = name_to_id[name]
# Infer Primary Color (for visualization/traits)
col = 0
for cidx, val in enumerate(h_arr):
if val > 0:
col = cidx + 1 # 1-based color
break
self.card_stats[cid, 3] = col
# 5. Volume/Draw Icons
self.card_stats[cid, 4] = card.get("volume_icons", 0)
self.card_stats[cid, 5] = card.get("draw_icons", 0)
# 6. Blade Hearts (flipped as yell)
bh = card.get("blade_hearts", [])
for b_idx in range(min(len(bh), 7)):
self.card_stats[cid, 40 + b_idx] = bh[b_idx]
# Live Card Stats
if "required_hearts" in card:
# Pack Required Hearts into 12-18 (Pink..Purple, All)
reqs = card.get("required_hearts", [])
for r_idx in range(min(len(reqs), 7)):
self.card_stats[cid, 12 + r_idx] = reqs[r_idx]
# --- FIXED GEOGRAPHY ABILITY PACKING ---
ab_list = card.get("abilities", [])
# Helper to pack an ability into a fixed block
def pack_ability_block(ab, base_idx, has_opts=False):
if not ab:
return
# Trigger (Base + 0)
self.card_stats[cid, base_idx] = ab.get("trigger", 0)
# Condition (Base + 1, 2)
conds = ab.get("conditions", [])
if conds:
self.card_stats[cid, base_idx + 1] = conds[0].get("type", 0)
self.card_stats[cid, base_idx + 2] = conds[0].get("params", {}).get("value", 0)
# Effects
effs = ab.get("effects", [])
eff_start = base_idx + 3
if has_opts: # Ability 1 has extra space for Options
eff_start = base_idx + 9 # Skip 6 slots for options
# Pack Options (from first effect)
if effs:
m_opts = effs[0].get("modal_options", [])
if len(m_opts) > 0 and len(m_opts[0]) > 0:
o = m_opts[0][0] # Opt 1
self.card_stats[cid, base_idx + 3] = o.get("effect_type", 0)
self.card_stats[cid, base_idx + 4] = o.get("value", 0)
self.card_stats[cid, base_idx + 5] = o.get("target", 0)
if len(m_opts) > 1 and len(m_opts[1]) > 0:
o = m_opts[1][0] # Opt 2
self.card_stats[cid, base_idx + 6] = o.get("effect_type", 0)
self.card_stats[cid, base_idx + 7] = o.get("value", 0)
self.card_stats[cid, base_idx + 8] = o.get("target", 0)
# Pack up to 3 Effects
for e_i in range(min(len(effs), 3)):
e = effs[e_i]
off = eff_start + (e_i * 3)
self.card_stats[cid, off] = e.get("effect_type", 0)
self.card_stats[cid, off + 1] = e.get("value", 0)
self.card_stats[cid, off + 2] = e.get("target", 0)
# Block 1: Ability 1 (Indices 20-35) [Has Options]
if len(ab_list) > 0:
pack_ability_block(ab_list[0], 20, has_opts=True)
# Block 2: Ability 2 (Indices 36-47)
if len(ab_list) > 1:
pack_ability_block(ab_list[1], 36)
# Block 3: Ability 3 (Indices 48-59)
if len(ab_list) > 2:
pack_ability_block(ab_list[2], 48)
# Block 4: Ability 4 (Indices 60-71)
if len(ab_list) > 3:
pack_ability_block(ab_list[3], 60)
# 7. Type
self.card_stats[cid, 10] = 1
# 8. Traits Bitmask (Groups & Units) -> Stores in Index 11
# Bits 0-4: Groups (Max 5)
# Bits 5-20: Units (Max 16)
mask = 0
groups = card.get("groups", [])
for g in groups:
try:
mask |= 1 << (int(g) % 20)
except:
pass
units = card.get("units", [])
for u in units:
try:
mask |= 1 << ((int(u) % 20) + 5)
except:
pass
self.card_stats[cid, 11] = mask
count += 1
# Load Lives
if "live_db" in db:
for cid_str, card in db["live_db"].items():
cid = int(cid_str)
if cid < self.max_cards:
# Type: Live=2
self.card_stats[cid, 10] = 2
# Required Hearts
reqs = card.get("required_hearts", [])
for r_idx in range(min(len(reqs), 7)):
self.card_stats[cid, 12 + r_idx] = reqs[r_idx]
# Score
self.card_stats[cid, 38] = card.get("score", 0)
# Store Character ID in index 19 for Baton Pass condition
name = card.get("name", "")
if name in name_to_id:
self.card_stats[cid, 19] = name_to_id[name]
count += 1
print(f" [VectorEnv] Loaded detailed stats/abilities for {count} cards.")
# --- RUNTIME PATCHING FOR BATON PASS CARDS ---
# Scan all cards for "バトンタッチして" condition and inject C_BATON opcode
print(" [VectorEnv] Starting runtime patching for Baton Pass cards...")
# Load the original bytecode map to scan for cards that need patching
with open("data/cards_numba.json", "r") as f:
raw_map = json.load(f)
# Regex pattern to detect Baton Pass condition
baton_pattern = re.compile(r"「(.+?)」からバトンタッチして")
patched_count = 0
idx_counter = 1 # Start from 1 since 0 is reserved for empty
# First pass: count how many patched bytecodes we'll need
baton_cards = []
for cid_str, card in {**db.get("member_db", {}), **db.get("live_db", {})}.items():
cid = int(cid_str)
if cid >= self.max_cards:
continue
# Check if this card has abilities with Baton Pass condition
ab_list = card.get("abilities", [])
for ab_idx, ability in enumerate(ab_list):
raw_text = ability.get("raw_text", "")
# Check if the raw text contains the Baton Pass pattern
match = baton_pattern.search(raw_text)
if match:
target_name = match.group(1)
# Get the target character ID
target_cid = name_to_id.get(target_name, -1)
if target_cid != -1:
original_key = f"{cid}_{ab_idx}"
if original_key in raw_map:
baton_cards.append((cid, ab_idx, target_cid, raw_map[original_key], target_name))
# Second pass: expand bytecode_map if needed and apply patches
for cid, ab_idx, target_cid, original_bytecode, target_name in baton_cards:
# Get the card object again to access the name
card = {}
if str(cid) in db.get("member_db", {}):
card = db["member_db"][str(cid)]
elif str(cid) in db.get("live_db", {}):
card = db["live_db"][str(cid)]
# This card has a Baton Pass condition that needs to be patched
print(
f" [VectorEnv] Patching Baton Pass for card {cid} ('{card.get('name', '')}') targeting '{target_name}' (ID: {target_cid})"
)
# Create new bytecode sequence with C_BATON condition prepended
# Format: [C_BATON, Target_Char_ID, 0, 0] + original_bytecode
# Prepend CHECK_BATON (231) opcode
new_bytecode = [231, target_cid, 0, 0] + original_bytecode # original_bytecode is already a list
# Find a free slot in the bytecode map for the patched version
if idx_counter < self.bytecode_map.shape[0]:
# Reshape the new bytecode to fit the map dimensions
bc_arr = np.array(new_bytecode, dtype=np.int32).reshape(-1, 4)
length = min(bc_arr.shape[0], self.max_len)
self.bytecode_map[idx_counter, :length] = bc_arr[:length]
# Update the bytecode index to point to the new patched version
self.bytecode_index[cid, ab_idx] = idx_counter
patched_count += 1
print(
f" [VectorEnv] Successfully patched ability {ab_idx} for card {cid}, new bytecode index: {idx_counter}"
)
idx_counter += 1
else:
print(f" [VectorEnv] Error: No more space in bytecode map for card {cid}")
print(f" [VectorEnv] Runtime patching completed. {patched_count} cards patched.")
except Exception as e:
print(f" [VectorEnv] Warning: Failed to load compiled stats: {e}")
except FileNotFoundError:
print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
self.bytecode_map = np.zeros((1, 64, 4), dtype=np.int32)
self.bytecode_index = np.zeros((1, 1), dtype=np.int32)
def _load_verified_deck_pool(self):
import json
try:
# Load Verified List
with open("data/verified_card_pool.json", "r", encoding="utf-8") as f:
verified_data = json.load(f)
# Load DB to map CardNo -> CardID
with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
db_data = json.load(f)
self.ability_member_ids = []
self.ability_live_ids = []
self.vanilla_member_ids = []
self.vanilla_live_ids = []
# Map numbers to IDs and types
member_no_map = {}
live_no_map = {}
for cid, cdata in db_data.get("member_db", {}).items():
member_no_map[cdata["card_no"]] = int(cid)
for cid, cdata in db_data.get("live_db", {}).items():
live_no_map[cdata["card_no"]] = int(cid)
# Check for list compatibility mode
if isinstance(verified_data, list):
print(" [VectorEnv] Loading Verified Pool from List (Compatibility Mode)")
for v_no in verified_data:
if v_no in member_no_map:
self.ability_member_ids.append(member_no_map[v_no])
elif v_no in live_no_map:
self.ability_live_ids.append(live_no_map[v_no])
else:
# 1. Primary Pool: Abilities (Categorized)
# Support both old keys (verified_abilities) and new keys (members)
source_members = verified_data.get("verified_abilities", []) + verified_data.get("members", [])
for v_no in source_members:
if v_no in member_no_map:
self.ability_member_ids.append(member_no_map[v_no])
source_lives = verified_data.get("verified_lives", []) + verified_data.get("lives", [])
for v_no in source_lives:
if v_no in live_no_map:
self.ability_live_ids.append(live_no_map[v_no])
# 2. Secondary Pool: Vanilla
for v_no in verified_data.get("vanilla_members", []):
if v_no in member_no_map:
self.vanilla_member_ids.append(member_no_map[v_no])
for v_no in verified_data.get("vanilla_lives", []):
if v_no in live_no_map:
self.vanilla_live_ids.append(live_no_map[v_no])
# Fallback/Warnings
if not self.ability_member_ids:
if self.vanilla_member_ids:
print(" [VectorEnv] Warning: No ability members. using vanilla members.")
self.ability_member_ids = self.vanilla_member_ids
else:
print(" [VectorEnv] Warning: No members found. Using ID 1.")
self.ability_member_ids = [1]
if not self.ability_live_ids:
if self.vanilla_live_ids:
print(" [VectorEnv] Warning: No ability lives. Using vanilla lives.")
self.ability_live_ids = self.vanilla_live_ids
else:
print(" [VectorEnv] Warning: No lives found. Using ID 999 (Dummy).")
self.ability_live_ids = [999]
print(
f" [VectorEnv] Pools: {len(self.ability_member_ids)} Ability Members, {len(self.ability_live_ids)} Ability Lives."
)
print(
f" [VectorEnv] Fallbacks: {len(self.vanilla_member_ids)} Vanilla Members, {len(self.vanilla_live_ids)} Vanilla Lives."
)
self.ability_member_ids = np.array(self.ability_member_ids, dtype=np.int32)
self.ability_live_ids = np.array(self.ability_live_ids, dtype=np.int32)
self.vanilla_member_ids = np.array(self.vanilla_member_ids, dtype=np.int32)
self.vanilla_live_ids = np.array(self.vanilla_live_ids, dtype=np.int32)
except Exception as e:
print(f" [VectorEnv] Deck Load Error: {e}")
self.ability_member_ids = np.array([1], dtype=np.int32)
self.ability_live_ids = np.array([999], dtype=np.int32)
self.vanilla_member_ids = np.array([], dtype=np.int32)
self.vanilla_live_ids = np.array([], dtype=np.int32)
def _load_fixed_deck_pool(self, deck_path: str):
import json
import re
print(f" [VectorEnv] Loading FIXED DECK from: {deck_path}")
try:
# 1. Load DB to map CardNo -> CardID
with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
db_data = json.load(f)
member_no_map = {}
live_no_map = {}
for cid, cdata in db_data.get("member_db", {}).items():
member_no_map[cdata["card_no"]] = int(cid)
for cid, cdata in db_data.get("live_db", {}).items():
live_no_map[cdata["card_no"]] = int(cid)
# 2. Parse Markdown
with open(deck_path, "r", encoding="utf-8") as f:
lines = f.readlines()
members = []
lives = []
for line in lines:
# Look for "4x [PL!-...]" - flexible for markdown bolding like **4x**
match = re.search(r"(\d+)x.*?\[(PL!-[^\]]+)\]", line)
if match:
count = int(match.group(1))
card_no = match.group(2)
if card_no in member_no_map:
for _ in range(count):
members.append(member_no_map[card_no])
elif card_no in live_no_map:
for _ in range(count):
lives.append(live_no_map[card_no])
# 3. Finalize
if len(members) != 48:
print(f" [VectorEnv] Warning: Fixed deck members count is {len(members)}, expected 48.")
if len(lives) != 12:
print(f" [VectorEnv] Warning: Fixed deck lives count is {len(lives)}, expected 12.")
self.ability_member_ids = np.array(members, dtype=np.int32)
self.ability_live_ids = np.array(lives, dtype=np.int32)
self.vanilla_member_ids = np.array([], dtype=np.int32)
self.vanilla_live_ids = np.array([], dtype=np.int32)
print(
f" [VectorEnv] Fixed Deck Loaded: {len(self.ability_member_ids)} members, {len(self.ability_live_ids)} lives."
)
except Exception as e:
print(f" [VectorEnv] Fixed Deck Load Error: {e}")
self._load_verified_deck_pool()
def _load_scenarios(self, path="data/scenarios.npz"):
try:
import numpy as np
data = np.load(path)
self.scenarios = {k: data[k] for k in data.files}
self.num_scenarios = len(self.scenarios["batch_hand"])
print(f" [VectorEnv] Loaded {self.num_scenarios} scenarios from {path}")
except Exception as e:
print(f" [VectorEnv] Failed to load scenarios: {e}")
self.scenarios = None
    def reset(self, indices: List[int] = None):
        """Reset specified environments (or all if indices is None).

        Heavy state re-initialization is delegated to the Numba reset kernel.
        If scenario mode is enabled, the freshly reset envs are then
        overwritten with randomly selected pre-recorded snapshots. Local
        reward/length trackers are cleared last.

        Returns:
            The new batched observations from get_observations().
        """
        if indices is None:
            # Full reset: hand every env index to the kernel in one call.
            indices_arr = np.arange(self.num_envs, dtype=np.int32)
        else:
            indices_arr = np.array(indices, dtype=np.int32)
        # Standalone reset kernel (shares reset logic with the integrated step).
        # NOTE: positional argument order must match the kernel signature.
        isn.reset_kernel_numba(
            indices_arr,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.opp_stage,
            self.opp_energy_vec,
            self.opp_energy_count,
            self.opp_tapped,
            self.opp_live,
            self.opp_scores,
            self.opp_global_ctx,
            self.opp_hand,
            self.opp_deck,
            self.batch_trash,
            self.opp_trash,
            self.batch_opp_history,
            self.ability_member_ids,
            self.ability_live_ids,
            int(self.force_start_order),
        )
        # Scenario Overwrite: replace freshly reset state with stored snapshots.
        if getattr(self, "scenarios", None) is not None and os.getenv("USE_SCENARIOS", "0") == "1":
            try:
                # Pick one random scenario per env being reset.
                num_reset = self.num_envs if indices is None else len(indices_arr)
                reset_indices = np.arange(self.num_envs) if indices is None else indices_arr
                scen_indices = np.random.randint(0, self.num_scenarios, size=num_reset)

                def load_field(name, target):
                    # Copy scenario field `name` into the reset rows of `target`,
                    # squeezing (N, 1) snapshot columns down to (N,) for 1-D targets.
                    if name in self.scenarios:
                        data = self.scenarios[name][scen_indices]
                        if target.ndim == 1 and data.ndim == 2 and data.shape[1] == 1:
                            data = data.ravel()
                        target[reset_indices] = data

                load_field("batch_hand", self.batch_hand)
                load_field("batch_deck", self.batch_deck)
                load_field("batch_stage", self.batch_stage)
                load_field("batch_energy_vec", self.batch_energy_vec)
                load_field("batch_energy_count", self.batch_energy_count)
                load_field("batch_continuous_vec", self.batch_continuous_vec)
                load_field("batch_continuous_ptr", self.batch_continuous_ptr)
                load_field("batch_tapped", self.batch_tapped)
                load_field("batch_live", self.batch_live)
                load_field("batch_scores", self.batch_scores)
                load_field("batch_flat_ctx", self.batch_flat_ctx)
                load_field("batch_global_ctx", self.batch_global_ctx)
                load_field("opp_hand", self.opp_hand)
                load_field("opp_deck", self.opp_deck)
                load_field("opp_stage", self.opp_stage)
                load_field("opp_energy_vec", self.opp_energy_vec)
                load_field("opp_energy_count", self.opp_energy_count)
                load_field("opp_tapped", self.opp_tapped)
                load_field("opp_live", self.opp_live)
                load_field("opp_scores", self.opp_scores)
                load_field("opp_global_ctx", self.opp_global_ctx)
            except Exception as e:
                print(f" [VectorEnv] Error loading scenario data: {e}")
        # Reset local trackers
        if indices is None:
            self.turn = 1
            self.prev_scores.fill(0)
            self.prev_opp_scores.fill(0)
            self.prev_phases.fill(0)
            self.episode_returns.fill(0)
            self.episode_lengths.fill(0)
        else:
            for idx in indices:
                self.prev_scores[idx] = 0
                self.prev_opp_scores[idx] = 0
                self.prev_phases[idx] = 0
                self.episode_returns[idx] = 0
                self.episode_lengths[idx] = 0
        # Return observations
        return self.get_observations()
def step(self, actions: np.ndarray):
"""Apply a batch of actions across all environments using Optimized Integrated Step."""
# Ensure actions are int32
if actions.dtype != np.int32:
actions = actions.astype(np.int32)
return self.integrated_step(actions)
def integrated_step(self, actions: np.ndarray):
"""
Executes the optimized Numba Integrated Step.
Returns: obs, rewards, dones, infos (list of dicts)
"""
term_scores_agent = np.zeros(self.num_envs, dtype=np.int32)
term_scores_opp = np.zeros(self.num_envs, dtype=np.int32)
rewards, dones = isn.integrated_step_numba(
self.num_envs,
actions,
self.batch_hand,
self.batch_deck,
self.batch_stage,
self.batch_energy_vec,
self.batch_energy_count,
self.batch_continuous_vec,
self.batch_continuous_ptr,
self.batch_tapped,
self.batch_live,
self.batch_scores,
self.batch_flat_ctx,
self.batch_global_ctx,
self.opp_hand,
self.opp_deck,
self.opp_stage,
self.opp_energy_vec,
self.opp_energy_count,
self.opp_tapped,
self.opp_live, # Added
self.opp_scores,
self.opp_global_ctx,
self.card_stats,
self.bytecode_map,
self.bytecode_index,
self.batch_opp_history,
self.obs_buffer,
self.prev_scores,
self.prev_opp_scores,
self.prev_phases,
self.ability_member_ids,
self.ability_live_ids,
self.turn_number_ptr,
self.terminal_obs_buffer,
self.batch_trash,
self.opp_trash,
term_scores_agent,
term_scores_opp,
0
if self.obs_mode == "IMAX"
else (1 if self.obs_mode == "STANDARD" else (3 if self.obs_mode == "ATTENTION" else 2)),
self.game_config, # New Config
int(self.opp_mode),
int(self.force_start_order),
)
# Apply Scenario Reward Scaling
if self.scenario_reward_scale != 1.0 and os.getenv("USE_SCENARIOS", "0") == "1":
rewards *= self.scenario_reward_scale
# Construct Infos (minimal python overhead)
infos = []
for i in range(self.num_envs):
if dones[i]:
infos.append(
{
"terminal_observation": self.terminal_obs_buffer[i].copy(),
"episode": {"r": float(rewards[i]), "l": 10},
"terminal_score_agent": int(term_scores_agent[i]),
"terminal_score_opp": int(term_scores_opp[i]),
}
)
else:
# Accumulate rewards for ongoing episodes
# NOTE: rewards[i] is the delta reward for this specific integrated step.
self.episode_returns[i] += rewards[i]
self.episode_lengths[i] += 1
infos.append({})
# After loop, update terminal infos for done envs with the SUMMED returns
for i in range(self.num_envs):
if dones[i]:
# Add terminal reward to the return
final_return = self.episode_returns[i] + rewards[i]
final_length = self.episode_lengths[i] + 1
infos[i]["episode"] = {"r": float(final_return), "l": int(final_length)}
# Reset accumulators for the next episode in this slot
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
return self.obs_buffer, rewards, dones, infos
def get_action_masks(self):
"""Return legal action masks."""
if self.obs_mode == "ATTENTION":
return compute_action_masks_attention(
self.num_envs,
self.batch_hand,
self.batch_stage,
self.batch_tapped,
self.batch_global_ctx,
self.batch_live,
self.card_stats,
)
else:
return compute_action_masks(
self.num_envs,
self.batch_hand,
self.batch_stage,
self.batch_tapped,
self.batch_global_ctx,
self.batch_live,
self.card_stats,
)
def get_observations(self):
"""Return a batched observation for RL models."""
if self.obs_mode == "COMPRESSED":
return isn.encode_observations_compressed(
self.num_envs,
self.batch_hand,
self.batch_stage,
self.batch_energy_count,
self.batch_tapped,
self.batch_scores,
self.opp_scores,
self.opp_stage,
self.opp_tapped,
self.card_stats,
self.batch_global_ctx,
self.batch_live,
self.batch_opp_history,
self.turn,
self.obs_buffer,
)
elif self.obs_mode == "IMAX":
return isn.encode_observations_imax(
self.num_envs,
self.batch_hand,
self.batch_stage,
self.batch_energy_count,
self.batch_tapped,
self.batch_scores,
self.opp_scores,
self.opp_stage,
self.opp_tapped,
self.card_stats,
self.batch_global_ctx,
self.batch_live,
self.batch_opp_history,
self.turn,
self.obs_buffer,
)
elif self.obs_mode == "ATTENTION":
return isn.encode_observations_attention(
self.num_envs,
self.batch_hand,
self.batch_stage,
self.batch_energy_count,
self.batch_tapped,
self.batch_scores,
self.opp_scores,
self.opp_stage,
self.opp_tapped,
self.card_stats,
self.batch_global_ctx,
self.batch_live,
self.batch_opp_history,
self.opp_global_ctx,
self.turn,
self.obs_buffer,
)
else:
return isn.encode_observations_standard(
self.num_envs,
self.batch_hand,
self.batch_stage,
self.batch_energy_count,
self.batch_tapped,
self.batch_scores,
self.opp_scores,
self.opp_stage,
self.opp_tapped,
self.card_stats,
self.batch_global_ctx,
self.batch_live,
self.batch_opp_history,
self.turn,
self.obs_buffer,
)
@njit(cache=True)
def step_opponent_vectorized(
    opp_hand: np.ndarray,  # (N, 60)
    opp_deck: np.ndarray,  # (N, 60)
    opp_stage: np.ndarray,
    opp_energy_vec: np.ndarray,
    opp_energy_count: np.ndarray,
    opp_tapped: np.ndarray,
    opp_scores: np.ndarray,
    agent_tapped: np.ndarray,
    opp_global_ctx: np.ndarray,  # (N, 128)
    bytecode_map: np.ndarray,
    bytecode_index: np.ndarray,
):
    """
    Very simplified opponent step. Reuses agent bytecode but targets opponent buffers.

    For each environment: pick one random card from the opponent's hand,
    execute its first ability's bytecode (if any) against the opponent
    buffers, then draw from the deck when the hand drops below 5 cards.
    All buffers are mutated in place; nothing is returned.
    """
    num_envs = len(opp_hand)
    # Dummy buffers for context (reused per env)
    f_ctx = np.zeros(64, dtype=np.int32)
    # We use the passed Hand/Deck buffers directly!
    live = np.zeros(50, dtype=np.int32)  # Dummy live zone for opponent
    # Reusable dummies to avoid allocation in loop
    dummy_cont_vec = np.zeros((32, 10), dtype=np.int32)
    dummy_ptr = np.zeros(1, dtype=np.int32)  # Ref Array
    dummy_bonus = np.zeros(1, dtype=np.int32)  # Ref Array
    for i in range(num_envs):
        # RESET local context per environment
        f_ctx.fill(0)
        # 1. Select Random Legal Action from Hand
        # Scan hand for valid bytecodes
        # Use fixed array for Numba compatibility (no lists)
        candidates = np.zeros(60, dtype=np.int32)
        c_ptr = 0
        for j in range(60):  # Hand size
            cid = opp_hand[i, j]
            if cid > 0:
                candidates[c_ptr] = j  # Store Index in Hand
                c_ptr += 1
        if c_ptr == 0:
            # Empty hand: this opponent does nothing (and cannot draw) this step.
            continue
        # Pick one random index
        idx_choice = np.random.randint(0, c_ptr)
        hand_idx = candidates[idx_choice]
        act_id = opp_hand[i, hand_idx]
        # 2. Execute (only if the card has a compiled first ability)
        if act_id > 0 and act_id < bytecode_index.shape[0]:
            map_idx = bytecode_index[act_id, 0]
            if map_idx > 0:
                code_seq = bytecode_map[map_idx]
                opp_global_ctx[i, 0] = opp_scores[i]
                opp_global_ctx[i, 3] -= 1  # Decrement Hand Count (HD) after playing
                # Reset dummies
                dummy_ptr[0] = 0
                dummy_bonus[0] = 0
                # Pass Row Slices of Hand/Deck
                # Careful: slicing in loop might allocate. Pass full array + index?
                # resolve_bytecode expects 1D array.
                # We can't pass a slice 'opp_hand[i]' effectively if function modifies it in place?
                # Actually resolve_bytecode modifies it.
                # Numba slices are views, should work.
                resolve_bytecode(
                    code_seq,
                    f_ctx,
                    opp_global_ctx[i],
                    1,
                    opp_hand[i],
                    opp_deck[i],
                    opp_stage[i],
                    opp_energy_vec[i],
                    opp_energy_count[i],
                    dummy_cont_vec,
                    dummy_ptr,
                    opp_tapped[i],
                    live,
                    agent_tapped[i],
                    bytecode_map,
                    bytecode_index,
                    dummy_bonus,
                )
                # Neutralized: opp_scores[i] = opp_global_ctx[i, 0]
                # SC = 0; OS = 1; TR = 2; HD = 3; DI = 4; EN = 5; DK = 6; OT = 7
                # Resolve bytecode puts score in SC (index 0) for the current player?
                # Let's check fast_logic.py: it uses global_ctx[SC].
                # So opp_scores[i] = opp_global_ctx[i, 0] is correct if they are the "current player" in that call.
        # 3. Post-Play Cleanup (Draw to refill?)
        # If card played, act_id removed from hand by resolve_bytecode (Opcode 11/12/13 usually).
        # To simulate "Draw", we check if hand size < 5.
        # Count current hand
        cnt = 0
        for j in range(60):
            if opp_hand[i, j] > 0:
                cnt += 1
        if cnt < 5:
            # Draw top card from Deck
            # Find first card in Deck
            top_card = 0
            deck_idx = -1
            for j in range(60):
                if opp_deck[i, j] > 0:
                    top_card = opp_deck[i, j]
                    deck_idx = j
                    break
            if top_card > 0:
                # Move to Hand (First empty slot)
                for j in range(60):
                    if opp_hand[i, j] == 0:
                        opp_hand[i, j] = top_card
                        opp_deck[i, deck_idx] = 0  # Remove from deck
                        opp_global_ctx[i, 3] += 1  # Increment Hand Count (HD)
                        opp_global_ctx[i, 6] -= 1  # Decrement Deck Count (DK)
                        break
@njit(cache=True)
def resolve_auto_phases(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_tapped: np.ndarray,
    single_step: bool = False,
):
    """
    Automatically advance each environment through the non-interactive
    phases until it reaches the Main Phase (4) or hits an unhandled phase.
    Includes the Turn Start draw (Phase 3).

    Phase codes stored in batch_global_ctx[:, 8]:
        0 = MULLIGAN, 1 = ACTIVE, 2 = ENERGY, 3 = DRAW, 4 = MAIN,
        8 = LIVE_RESULT (treated as a turn start, like 0).

    Parameters
    ----------
    num_envs : number of environments to process.
    batch_hand, batch_deck : (num_envs, 60) card-id arrays; 0 = empty slot.
    batch_global_ctx : per-env scalar context. Indices touched here:
        3 = hand count (HD), 5 = energy count (EN), 6 = deck count (DK),
        8 = phase, 51-53 = per-slot played-this-turn flags, 54 = turn counter.
    batch_tapped : per-env tapped flags (members at 0-2, energy at 3-15).
    single_step : if True, advance each env by at most one phase (debug aid).

    Mutates all batch arrays in place; returns nothing.
    """
    for i in range(num_envs):
        # SAFETY: cap transitions per env so an unexpected phase value
        # cannot spin forever (10 covers the full 0->4 chain with slack).
        max_iters = 1 if single_step else 10
        for _ in range(max_iters):
            ph = int(batch_global_ctx[i, 8])
            # MULLIGAN (0) or LIVE_RESULT (8) -> ACTIVE (1): turn start.
            if ph == 0 or ph == 8:
                # Reset per-slot "played this turn" flags (indices 51-53).
                batch_global_ctx[i, 51:54] = 0
                # Untap everything: members (0-2) and energy (3-15).
                batch_tapped[i, 0:16] = 0
                # Energy ramp: first turn jumps to 3, then +1 per turn up to 12.
                cur_ec = batch_global_ctx[i, 5]
                if cur_ec == 0:
                    batch_global_ctx[i, 5] = 3
                elif cur_ec < 12:
                    batch_global_ctx[i, 5] = cur_ec + 1
                # Increment turn counter (index 54).
                batch_global_ctx[i, 54] += 1
                batch_global_ctx[i, 8] = 1
                continue
            # ACTIVE (1) -> ENERGY (2): nothing interactive to do.
            if ph == 1:
                batch_global_ctx[i, 8] = 2
                continue
            # ENERGY (2) -> DRAW (3).
            if ph == 2:
                batch_global_ctx[i, 8] = 3
                continue
            # DRAW (3) -> MAIN (4): draw the top card of the deck.
            if ph == 3:
                # Find the first non-empty deck slot (the "top" card).
                top_card = 0
                deck_idx = -1
                for d_idx in range(60):
                    if batch_deck[i, d_idx] > 0:
                        top_card = batch_deck[i, d_idx]
                        deck_idx = d_idx
                        break
                # Empty deck: skip the draw and proceed to MAIN anyway
                # (keeps benchmark games running instead of deck-out).
                if top_card == 0:
                    batch_global_ctx[i, 8] = 4
                    continue
                # Place the drawn card in the first empty hand slot and
                # resync the hand/deck counters.
                for h_idx in range(60):
                    if batch_hand[i, h_idx] == 0:
                        batch_hand[i, h_idx] = top_card
                        batch_deck[i, deck_idx] = 0
                        # Recount hand from scratch (index 3 = HD) rather
                        # than trusting the running counter.
                        batch_global_ctx[i, 3] = 0
                        for k in range(60):
                            if batch_hand[i, k] > 0:
                                batch_global_ctx[i, 3] += 1
                        batch_global_ctx[i, 6] -= 1  # deck count (DK)
                        break
                batch_global_ctx[i, 8] = 4
                continue
            # MAIN (4): stop and hand control to the agent.
            if ph == 4:
                break
            # Unhandled phase value: bail out to avoid an infinite loop.
            break
@njit(parallel=True, cache=True)
def compute_action_masks_attention(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_stage: np.ndarray,
    batch_tapped: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_live: np.ndarray,
    card_stats: np.ndarray,
):
    """
    Compute legal action masks for ATTENTION mode (512 actions).

    Mapping:
    - 0: Pass
    - 1-45: Play Member (15 hand idx * 3 slots)
    - 46-60: Set Live (15 hand idx)
    - 61-63: Activate Ability (3 slots)
    - 64-69: Mulligan Select (6 cards)
    - 100-299: Choice Actions (Not fully implemented yet)

    batch_global_ctx indices read: 5 = energy count, 8 = phase,
    51-53 = per-slot played-this-turn flags, 120-125 = mulligan-selected flags.
    card_stats columns read: 0 = cost, 10 = card type (1 = Member, 2 = Live).

    Returns a (num_envs, 512) boolean array; True = action is legal.
    """
    masks = np.zeros((num_envs, 512), dtype=np.bool_)
    masks[:, 0] = True  # Pass always legal
    for i in prange(num_envs):
        phase = batch_global_ctx[i, 8]
        # --- Mulligan (Phase Includes -1, 0) ---
        if phase <= 0:
            # Allow pass (0) to finish (redundant with the global init above,
            # kept for clarity)
            masks[i, 0] = True
            # Allow select mulligan (64-69) for first 6 cards.
            # ONE-WAY: once selected (ctx[120+h] flag set), mask it off.
            for h_idx in range(6):
                if batch_hand[i, h_idx] > 0:
                    if batch_global_ctx[i, 120 + h_idx] == 0:
                        masks[i, 64 + h_idx] = True
            continue
        # --- Main Phase (4) ---
        if phase == 4:
            # Available energy = total energy minus tapped energy slots
            # (energy tapped-flags live at batch_tapped[:, 3..14]).
            ec = batch_global_ctx[i, 5]
            tapped_count = 0
            for e_idx in range(min(ec, 12)):
                if batch_tapped[i, 3 + e_idx] > 0:
                    tapped_count += 1
            available_energy = ec - tapped_count
            # 1. Play Actions (1-45) & Set Live (46-60)
            # Hand limit for this mode is 15 primary indices
            for h_idx in range(15):
                cid = batch_hand[i, h_idx]
                # SAFETY: skip empty slots and ids outside card_stats.
                if cid <= 0 or cid >= card_stats.shape[0]:
                    continue
                is_member = card_stats[cid, 10] == 1
                # NOTE(review): is_live is computed but never used below —
                # set-live applies to any card (Rule 8.3); kept for parity.
                is_live = card_stats[cid, 10] == 2
                if is_member:
                    # Play to Slot 0-2 (Actions 1-45); Base = 1 + h_idx * 3
                    cost = card_stats[cid, 0]
                    for slot in range(3):
                        # Rule: one play per slot per turn (flags 51-53).
                        if batch_global_ctx[i, 51 + slot] > 0:
                            continue
                        # Effective Cost (Baton Touch): subtract the cost of
                        # the member already occupying the slot, floor at 0.
                        effective_cost = cost
                        prev_cid = batch_stage[i, slot]
                        if prev_cid > 0 and prev_cid < card_stats.shape[0]:
                            effective_cost = max(0, cost - card_stats[prev_cid, 0])
                        if effective_cost <= available_energy:
                            masks[i, 1 + h_idx * 3 + slot] = True
                # Set Live (Actions 46-60)
                # Rule 8.3 & 8.2.2: ANY card can be set.
                # Limit: at most 3 cards currently in the live zone.
                live_count = 0
                for lx in range(6):  # Check full 6 capacity (3 pending + 3 success)
                    if batch_live[i, lx] > 0:
                        live_count += 1
                if live_count < 3:
                    masks[i, 46 + h_idx] = True
            # 2. Activate Abilities (61-63): any untapped staged member.
            for slot in range(3):
                cid = batch_stage[i, slot]
                if cid > 0 and not batch_tapped[i, slot]:
                    masks[i, 61 + slot] = True
        # --- Choice Handling (Phase 7+) ---
        # phase == 4 is included as a fallback so a pending choice never
        # leaves the agent with zero legal moves.
        if phase >= 7 or phase == 4:
            # Allow hand selection (100-159)
            for h_idx in range(60):
                if batch_hand[i, h_idx] > 0:
                    masks[i, 100 + h_idx] = True
            # Allow energy selection (160-171)
            ec_val = batch_global_ctx[i, 5]
            for e_idx in range(min(ec_val, 12)):
                masks[i, 160 + e_idx] = True
    return masks
@njit(parallel=True, cache=True)
def compute_action_masks(
    num_envs: int,
    batch_hand: np.ndarray,
    batch_stage: np.ndarray,
    batch_tapped: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_live: np.ndarray,
    card_stats: np.ndarray,
):
    """
    Compute legal action masks using Python-compatible action IDs:
    - 0: Pass (always legal in Main Phase)
    - 1-180: Play Member from Hand (HandIdx * 3 + Slot + 1)
    - 200-202: Activate Ability (Slot)
    - 300-305: Mulligan Select (first 6 hand cards)
    - 400-459: Set Live Card (HandIdx)
    - 500-559 / 600-611: Choice fallbacks (hand / energy selection)

    batch_global_ctx indices read: 5 = energy count, 8 = phase,
    51-53 = per-slot played-this-turn flags, 120-125 = mulligan-selected flags.
    card_stats columns read: 0 = cost, 10 = card type (1 = Member, 2 = Live).

    Returns a (num_envs, 2000) boolean array; True = action is legal.
    """
    masks = np.zeros((num_envs, 2000), dtype=np.bool_)
    # Action 0 (Pass) is always legal
    masks[:, 0] = True
    for i in prange(num_envs):
        phase = batch_global_ctx[i, 8]
        # Mulligan Phases (-1, 0)
        if phase == -1 or phase == 0:
            masks[i, 0] = True  # Pass to finalize
            # Only allow selection if the card exists AND isn't already selected (One-way)
            for h_idx in range(6):  # Only first 6 cards are mull-able (Parity)
                if batch_hand[i, h_idx] > 0:
                    selected = batch_global_ctx[i, 120 + h_idx]
                    if selected == 0:
                        masks[i, 300 + h_idx] = True
            continue
        # Only compute member/ability actions in Main Phase (4)
        if phase == 4:
            # Calculate available untapped energy
            # (energy tapped-flags live at batch_tapped[:, 3..14]).
            ec = batch_global_ctx[i, 5]  # EC at index 5
            tapped_count = 0
            for e_idx in range(min(ec, 12)):
                if batch_tapped[i, 3 + e_idx] > 0:
                    tapped_count += 1
            available_energy = ec - tapped_count
            # --- Member Play Actions (1-180) ---
            # Action ID = HandIdx * 3 + Slot + 1
            for h_idx in range(60):
                cid = batch_hand[i, h_idx]
                # CRITICAL SAFETY: card_stats shape check
                if cid <= 0 or cid >= card_stats.shape[0]:
                    continue
                # Check if this is a Member card (Type 1)
                if card_stats[cid, 10] != 1:
                    # Check if this is a Live card (Type 2) for play actions 400-459
                    if card_stats[cid, 10] == 2:
                        # Action ID = 400 + h_idx
                        action_id = 400 + h_idx
                        # --- RULE ACCURACY: Live cards can be set without checking hearts ---
                        # Requirements are checked during Performance phase (Rule 8.3)
                        # We allow setting if zone limit not reached (max 3 in zone)
                        count_in_zone = 0
                        for j in range(50):
                            if batch_live[i, j] > 0:
                                count_in_zone += 1
                        if count_in_zone < 3:
                            masks[i, action_id] = True
                    # Non-member cards never reach the member-cost logic below.
                    continue
                # Member cost in card_stats[cid, 0]
                cost = card_stats[cid, 0]
                for slot in range(3):
                    action_id = h_idx * 3 + slot + 1
                    # Rule: One play per slot per turn (Indices 51-53)
                    if batch_global_ctx[i, 51 + slot] > 0:
                        continue
                    # Calculate effective cost (Baton Touch reduction)
                    effective_cost = cost
                    prev_cid = batch_stage[i, slot]
                    # SAFETY: Check cid range to avoid out-of-bounds card_stats access.
                    # NOTE(review): this uses `prev_cid >= 0` (includes the empty-slot
                    # id 0, reading card_stats row 0 — presumably a zero-cost padding
                    # row), while the attention variant uses `prev_cid > 0`. Confirm
                    # which is intended and align the two.
                    if prev_cid >= 0 and prev_cid < card_stats.shape[0]:
                        prev_cost = card_stats[prev_cid, 0]
                        effective_cost = cost - prev_cost
                        if effective_cost < 0:
                            effective_cost = 0
                    if effective_cost <= available_energy:
                        masks[i, action_id] = True
            # --- Activate Ability Actions (200-202) ---
            for slot in range(3):
                cid = batch_stage[i, slot]
                if cid > 0 and not batch_tapped[i, slot]:
                    # Check if card has an activated ability
                    # For now, assume all untapped members can activate
                    masks[i, 200 + slot] = True
        # --- Mandatory Choice Handling (Phase 7, 8 & Fallback) ---
        if phase >= 7 or phase == 4:
            # Allow hand selection/discard actions (500-559) if hand has cards.
            # This prevents Zero Legal Moves when a choice is pending.
            for h_idx in range(60):
                if batch_hand[i, h_idx] > 0:
                    masks[i, 500 + h_idx] = True
            # Allow energy selection actions (600-611) if energy exists
            energy_count = batch_global_ctx[i, 5]
            for e_idx in range(min(energy_count, 12)):
                masks[i, 600 + e_idx] = True
    return masks
# Export for legacy/external compatibility: older callers import
# `encode_observations_vectorized` from this module; it now delegates to
# the standard encoder in ai.research.integrated_step_numba.
encode_observations_vectorized = isn.encode_observations_standard