# LovecaSim / ai / environments / vector_env_legacy.py
# Uploaded by trioskosmos via huggingface_hub (revision 2badd2f, verified).
from typing import List, Optional

import numpy as np

from engine.game.ai_compat import njit
from engine.game.fast_logic import batch_apply_action
@njit(cache=True)
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # Bytecode lookup tables (see VectorGameState._load_bytecode)
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> row in map
):
    """
    Step N game environments in parallel using JIT logic and Real Card Data.

    All arrays are batched along axis 0 (one row per environment) and are
    mutated in place by the JIT kernel.
    """
    # Mirror each environment's score into slot 0 of its global context so
    # the kernel operates on up-to-date values.
    num_envs = actions.shape[0]
    for env_idx in range(num_envs):
        batch_global_ctx[env_idx, 0] = batch_scores[env_idx]

    # Delegate the actual rules logic to the shared JIT kernel.
    batch_apply_action(
        actions,
        0,  # player_id: this vectorized view always acts as player 0
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        bytecode_map,
        bytecode_index,
    )
class VectorGameState:
    """
    Manages a batch of independent GameStates for high-throughput training.

    All per-environment state lives in pre-allocated, batched numpy buffers
    (one row per environment) so the JIT-compiled step/encode kernels can
    operate on the whole batch without per-step Python allocation.
    """

    def __init__(self, num_envs: int):
        """Allocate batched state buffers and load card/bytecode data.

        Args:
            num_envs: Number of parallel game environments in the batch.
        """
        self.num_envs = num_envs
        self.turn = 1
        # Batched state buffers (-1 in stage = empty slot)
        self.batch_stage = np.full((num_envs, 3), -1, dtype=np.int32)
        self.batch_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)
        self.batch_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_continuous_vec = np.zeros((num_envs, 32, 10), dtype=np.int32)
        self.batch_continuous_ptr = np.zeros(num_envs, dtype=np.int32)
        self.batch_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_live = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_opp_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_scores = np.zeros(num_envs, dtype=np.int32)
        # Pre-allocated context buffers (Extreme speed optimization)
        self.batch_flat_ctx = np.zeros((num_envs, 64), dtype=np.int32)
        self.batch_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)
        self.batch_hand = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_deck = np.zeros((num_envs, 50), dtype=np.int32)
        # Pre-allocated observation buffer (SAVES ALLOCATION TIME)
        self.obs_buffer = np.zeros((num_envs, 320), dtype=np.float32)
        # Load Bytecode Map
        self._load_bytecode()
        self._load_verified_deck_pool()

    def _load_bytecode(self):
        """Load compiled ability bytecode into flat numpy lookup tables.

        Builds:
            bytecode_map:   (num_entries + 1, max_len, 4) int32; row 0 is
                            reserved as the empty/nop program.
            bytecode_index: (max_cards, max_abilities) int32 mapping
                            (card_id, ability_idx) -> row in bytecode_map
                            (0 means "no compiled ability").
        """
        import json

        # Fix: table dimensions are fixed BEFORE the file read so they exist
        # even when the data file is missing (the original set them inside
        # the try block, after the open() that could raise).
        self.max_cards = 2000
        self.max_abilities = 4
        self.max_len = 64  # Max 64 instructions per ability
        try:
            with open("data/cards_numba.json", "r") as f:
                raw_map = json.load(f)
        except FileNotFoundError:
            print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
            self.bytecode_map = np.zeros((1, self.max_len, 4), dtype=np.int32)
            # Fix: size the index (max_cards, max_abilities) instead of
            # (1, 1) so any (card_id, ability_idx) lookup stays in bounds
            # inside JIT code even without compiled data.
            self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)
            return

        # Key format: "cardid_abidx" -> flat list of ints (len divisible by 4).
        unique_entries = len(raw_map)
        # (Index 0 is empty/nop)
        self.bytecode_map = np.zeros((unique_entries + 1, self.max_len, 4), dtype=np.int32)
        self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)
        idx_counter = 1
        for key, bc_list in raw_map.items():
            cid, aid = map(int, key.split("_"))
            if cid < self.max_cards and aid < self.max_abilities:
                # Reshape the flat opcode list to (M, 4) instruction rows,
                # truncating programs longer than max_len.
                bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
                length = min(bc_arr.shape[0], self.max_len)
                self.bytecode_map[idx_counter, :length] = bc_arr[:length]
                self.bytecode_index[cid, aid] = idx_counter
                idx_counter += 1
        print(f" [VectorEnv] Loaded {unique_entries} compiled abilities.")

    def _load_verified_deck_pool(self):
        """Load the pool of verified card IDs used to build training decks.

        Falls back to a single-card pool ([1]) if either JSON file is
        missing or malformed, so training can still start.
        """
        import json
        try:
            # Load Verified List
            with open("verified_card_pool.json", "r", encoding="utf-8") as f:
                verified_data = json.load(f)
            # Load DB to map CardNo -> CardID
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)
            self.verified_card_ids = []
            # Map card numbers (e.g. print codes) to internal integer IDs.
            card_no_map = {}
            for cid, cdata in db_data["member_db"].items():
                card_no_map[cdata["card_no"]] = int(cid)
            for v_no in verified_data.get("verified_abilities", []):
                if v_no in card_no_map:
                    self.verified_card_ids.append(card_no_map[v_no])
            # Fallback
            if not self.verified_card_ids:
                print(" [VectorEnv] Warning: No verified cards found. Using ID 1.")
                self.verified_card_ids = [1]
            else:
                print(f" [VectorEnv] Loaded {len(self.verified_card_ids)} verified cards for training.")
            self.verified_card_ids = np.array(self.verified_card_ids, dtype=np.int32)
        except Exception as e:
            # Deliberate best-effort: any load failure falls back to ID 1.
            print(f" [VectorEnv] Deck Load Error: {e}")
            self.verified_card_ids = np.array([1], dtype=np.int32)

    def reset(self, indices: Optional[List[int]] = None):
        """Reset specified environments (or all if indices is None)."""
        if indices is None:
            indices = list(range(self.num_envs))
        # Reset is rare compared to step, so a Python loop is fine here.
        for i in indices:
            self.batch_stage[i].fill(-1)
            self.batch_energy_vec[i].fill(0)
            self.batch_energy_count[i].fill(0)
            self.batch_continuous_vec[i].fill(0)
            self.batch_continuous_ptr[i] = 0
            self.batch_tapped[i].fill(0)
            self.batch_live[i].fill(0)
            self.batch_opp_tapped[i].fill(0)
            self.batch_scores[i] = 0
            # Reset contexts
            self.batch_flat_ctx[i].fill(0)
            self.batch_global_ctx[i].fill(0)
            # Fix: clear hand and deck so cards from the previous episode
            # cannot leak into the new one (previously only hand[:5] was
            # rewritten, leaving slots 5-49 stale).
            self.batch_hand[i].fill(0)
            self.batch_deck[i].fill(0)
            # Initialize Deck with Verified Cards (Random 50)
            if len(self.verified_card_ids) > 0:
                self.batch_deck[i] = np.random.choice(self.verified_card_ids, 50)
            # Initialize Hand (Draw 5 from deck): copy the top 5 deck cards.
            # Deck pointer advancement is left to the opcode logic; the hand
            # array just needs to be populated for gameplay to start.
            self.batch_hand[i, :5] = self.batch_deck[i, :5]
        # NOTE(review): this resets the global turn counter even for a
        # partial reset — preserved from the original; confirm intended.
        self.turn = 1

    def step(self, actions: np.ndarray):
        """Apply a batch of actions across all environments in place.

        Args:
            actions: int array of shape (num_envs,), one action per env.
        """
        step_vectorized(
            actions,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_opp_tapped,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.bytecode_map,
            self.bytecode_index,
        )
        # Turn advancement is intentionally not handled here; in the real
        # VectorEnv it is managed by the engine rules.

    def get_observations(self):
        """Return a batched (num_envs, 320) observation for RL models."""
        return encode_observations_vectorized(
            self.num_envs,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.turn,
            self.obs_buffer,
        )
@njit(cache=True)
def encode_observations_vectorized(
    num_envs: int,
    batch_stage: np.ndarray,  # (N, 3)
    batch_energy_count: np.ndarray,  # (N, 3)
    batch_tapped: np.ndarray,  # (N, 3)
    batch_scores: np.ndarray,  # (N,)
    turn_number: int,
    observations: np.ndarray,  # (N, 320)
):
    """Encode batched game state into the pre-allocated observation buffer.

    Per-env layout (320 floats):
        [0:36]    metadata (phase / current-player one-hots)
        [36:168]  hand         (not tracked in this vector env; stays 0)
        [168:204] self stage   (3 slots x 12 features)
        [204:240] opponent stage (not tracked yet; stays 0)
        [240:270] live zone      (not tracked yet; stays 0)
        [270:272] scores

    Returns the same `observations` array, filled in place.
    """
    # Zeroing the pre-allocated buffer is cheap and clears stale data.
    observations.fill(0.0)
    id_norm = 2000.0  # card-id normalization constant

    for env in range(num_envs):
        # --- METADATA [0:36] ---
        # Phase one-hot: this vector env always reports Main phase.
        # (GameState convention: phase_val = int(phase) + 2; Main=3 -> 5.)
        observations[env, 5] = 1.0
        # Current player one-hot at [16:18]; this view is always player 0.
        observations[env, 16] = 1.0

        # --- SELF STAGE [168:204]: 3 slots x 12 features ---
        for slot in range(3):
            card_id = batch_stage[env, slot]
            offset = 168 + slot * 12
            if card_id >= 0:
                observations[env, offset] = 1.0  # slot occupied
                observations[env, offset + 1] = card_id / id_norm
                # Buffer is pre-zeroed, so only the tapped case needs a write.
                if batch_tapped[env, slot]:
                    observations[env, offset + 2] = 1.0
                # Placeholder power until real card attribute arrays
                # (e.g. member costs) are passed into the JIT kernel.
                observations[env, offset + 3] = 0.5
                # Energy count, saturating at 1.0 for 5+ energy.
                observations[env, offset + 11] = min(batch_energy_count[env, slot] / 5.0, 1.0)

        # --- SCORES [270:272] ---
        observations[env, 270] = min(batch_scores[env] / 5.0, 1.0)
    return observations