Spaces:
Running
Running
File size: 5,842 Bytes
bb3fbf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import json
import os
import sys
from typing import List, Optional
# Ensure the project root (two directory levels above this file) is on
# sys.path so that project packages can be imported below.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

try:
    import engine_rust
except ImportError:
    # Try importing from backend if not in root (common in some envs)
    try:
        from backend import engine_rust
    except ImportError as exc:
        # Chain the original failure explicitly so the underlying loader
        # error stays visible in the traceback.
        raise ImportError(
            "Could not import engine_rust. Make sure the Rust extension is built."
        ) from exc

# Imported after the sys.path adjustment above.
from engine.game.enums import Phase
class AbilityTestContext:
    """
    Helper for writing Rust engine tests.

    Provides a high-level API for state setup, action execution, and
    verification on top of the ``engine_rust`` bindings.
    """

    # UID layout: the low BASE_ID_BITS bits hold the base card ID, the bits
    # above them hold the instance index.
    BASE_ID_BITS = 20
    BASE_ID_MASK = (1 << BASE_ID_BITS) - 1  # 0xFFFFF

    # Base card IDs used for the default decks / zones built by this helper.
    DUMMY_CARD_ID = 1
    DEFAULT_ENERGY_ID = 40001

    # Instance-index ranges used to keep helper-generated UIDs apart.
    P1_DECK_INSTANCE_OFFSET = 1000
    P1_ENERGY_INSTANCE_BASE = 100
    P0_LIVES_INSTANCE_BASE = 200
    P1_LIVES_INSTANCE_BASE = 300
    HAND_INSTANCE_BASE = 500
    ENERGY_INSTANCE_BASE = 600

    DEFAULT_DECK_SIZE = 40
    DEFAULT_ENERGY_COUNT = 10
    DEFAULT_LIVES_COUNT = 3

    # Numeric phase values checked by skip_mulligan().
    _PHASE_MULLIGAN_P1 = -1
    _PHASE_MULLIGAN_P2 = 0

    # Upper bound on engine steps taken by reach_main_phase().
    MAX_PHASE_STEPS = 20

    # Number of member slots encoded per hand index in a play-member action ID.
    MEMBER_SLOTS = 3

    def __init__(self, compiled_data_path: str = "data/cards_compiled.json"):
        """Load the compiled card data and create a fresh engine game state.

        Args:
            compiled_data_path: Path to the compiled card JSON. If it does not
                exist as given, it is retried relative to ``PROJECT_ROOT``.

        Raises:
            FileNotFoundError: If the file exists at neither location.
        """
        if not os.path.exists(compiled_data_path):
            # Try alternative path for test execution environments
            alt_path = os.path.join(PROJECT_ROOT, compiled_data_path)
            if not os.path.exists(alt_path):
                raise FileNotFoundError(f"Compiled data not found: {compiled_data_path}")
            compiled_data_path = alt_path

        with open(compiled_data_path, "r", encoding="utf-8") as f:
            self.json_data = f.read()

        # Keep both the parsed JSON (for Python-side lookups) and the engine
        # database built from the same raw text.
        self.db_raw = json.loads(self.json_data)
        self.db = engine_rust.PyCardDatabase(self.json_data)
        self.gs = engine_rust.PyGameState(self.db)

    def mk_uid(self, base_id: int, instance_idx: int) -> int:
        """Create a Unique ID from base ID and instance index.

        Raises:
            ValueError: If ``base_id`` does not fit in ``BASE_ID_BITS`` bits or
                either argument is negative; such values would silently
                produce a UID that aliases a different card/instance.
        """
        if not 0 <= base_id <= self.BASE_ID_MASK:
            raise ValueError(
                f"base_id {base_id} does not fit in {self.BASE_ID_BITS} bits"
            )
        if instance_idx < 0:
            raise ValueError(f"instance_idx must be non-negative, got {instance_idx}")
        return base_id | (instance_idx << self.BASE_ID_BITS)

    def find_card_id(self, card_no: str, db_type: Optional[str] = None) -> int:
        """Find the internal ID for a card number.

        Args:
            card_no: Value matched against each card's ``"card_no"`` field.
            db_type: Name of a single database section to search; when omitted
                ``member_db``, ``live_db`` and ``energy_db`` are searched in
                that order.

        Returns:
            The key of the first matching card, converted to ``int``.

        Raises:
            ValueError: If no card in the searched sections matches.
        """
        dbs = [db_type] if db_type else ["member_db", "live_db", "energy_db"]
        for db_name in dbs:
            for cid, card in self.db_raw.get(db_name, {}).items():
                if card.get("card_no") == card_no:
                    return int(cid)
        raise ValueError(f"Card {card_no} not found in {dbs}")

    def _deck_uids(self, nos: Optional[List[str]], offset: int = 0) -> List[int]:
        """Turn a list of card numbers into deck UIDs.

        A missing or empty list yields the default dummy deck. Repeated copies
        of the same card get consecutive instance indices starting at
        ``offset``.
        """
        if not nos:
            # Default dummy deck
            return [
                self.mk_uid(self.DUMMY_CARD_ID, i + offset)
                for i in range(self.DEFAULT_DECK_SIZE)
            ]

        uids = []
        seen = {}  # base card ID -> number of copies emitted so far
        for no in nos:
            base = self.find_card_id(no)
            copy_idx = seen.get(base, 0)
            uids.append(self.mk_uid(base, copy_idx + offset))
            seen[base] = copy_idx + 1
        return uids

    def setup_game(
        self,
        p0_deck_nos: Optional[List[str]] = None,
        p1_deck_nos: Optional[List[str]] = None,
    ) -> None:
        """Initialize the game with specific decks (card numbers).

        Players without a deck list get the default dummy deck. Energy and
        lives are always filled with default cards.
        """
        p0_main = self._deck_uids(p0_deck_nos, 0)
        # Offset instance IDs for P1
        p1_main = self._deck_uids(p1_deck_nos, self.P1_DECK_INSTANCE_OFFSET)

        # Default Energy and Lives
        energy_range = range(self.DEFAULT_ENERGY_COUNT)
        lives_range = range(self.DEFAULT_LIVES_COUNT)
        p0_energy = [self.mk_uid(self.DEFAULT_ENERGY_ID, i) for i in energy_range]
        p1_energy = [
            self.mk_uid(self.DEFAULT_ENERGY_ID, self.P1_ENERGY_INSTANCE_BASE + i)
            for i in energy_range
        ]
        p0_lives = [
            self.mk_uid(self.DUMMY_CARD_ID, self.P0_LIVES_INSTANCE_BASE + i)
            for i in lives_range
        ]
        p1_lives = [
            self.mk_uid(self.DUMMY_CARD_ID, self.P1_LIVES_INSTANCE_BASE + i)
            for i in lives_range
        ]

        self.gs.initialize_game(p0_main, p1_main, p0_energy, p1_energy, p0_lives, p1_lives)

    def skip_mulligan(self) -> None:
        """Skip mulligan phases for both players."""
        # int() mirrors how the phase is compared elsewhere in this class, so
        # the check does not depend on the phase object comparing equal to int.
        if int(self.gs.phase) == self._PHASE_MULLIGAN_P1:
            self.gs.step(0)
        if int(self.gs.phase) == self._PHASE_MULLIGAN_P2:
            self.gs.step(0)

    def reach_main_phase(self) -> None:
        """Advance through Active, Energy, Draw phases to reach Main Phase.

        Raises:
            RuntimeError: If the phase is still before ``Phase.MAIN`` after
                ``MAX_PHASE_STEPS`` steps, instead of silently returning in
                the wrong phase.
        """
        self.skip_mulligan()
        main_phase = int(Phase.MAIN)
        for _ in range(self.MAX_PHASE_STEPS):
            if int(self.gs.phase) >= main_phase:
                return
            self.gs.step(0)  # Pass/End Phase
        if int(self.gs.phase) < main_phase:
            raise RuntimeError(
                f"Main phase not reached after {self.MAX_PHASE_STEPS} steps; "
                f"current phase: {self.gs.phase}"
            )

    def set_hand(self, player_idx: int, card_nos: List[str]) -> None:
        """Directly set a player's hand."""
        # NOTE(review): instance indices do not depend on player_idx, so the
        # same card placed in both players' hands gets the same UID - confirm
        # whether the engine requires UIDs to be unique across players.
        uids = [
            self.mk_uid(self.find_card_id(no), self.HAND_INSTANCE_BASE + i)
            for i, no in enumerate(card_nos)
        ]
        self.gs.set_hand_cards(player_idx, uids)

    def set_energy(self, player_idx: int, count: int, tapped_count: int = 0) -> None:
        """Directly set a player's energy zone.

        The first ``tapped_count`` energy cards are marked tapped, the rest
        untapped.

        Raises:
            ValueError: If ``tapped_count`` is negative or exceeds ``count``,
                which would make the tapped flags and the energy zone differ
                in length.
        """
        if not 0 <= tapped_count <= count:
            raise ValueError(
                f"tapped_count must be between 0 and count ({count}), got {tapped_count}"
            )
        p = self.gs.get_player(player_idx)
        # NOTE(review): as in set_hand, these UIDs are identical for both players.
        p.energy_zone = [
            self.mk_uid(self.DEFAULT_ENERGY_ID, self.ENERGY_INSTANCE_BASE + i)
            for i in range(count)
        ]
        p.tapped_energy = [True] * tapped_count + [False] * (count - tapped_count)
        # The modified player object is written back explicitly.
        self.gs.set_player(player_idx, p)

    def play_member(self, hand_idx: int, slot_idx: int) -> None:
        """Play a member from hand to a specific slot.

        Raises:
            ValueError: If ``hand_idx`` is negative or ``slot_idx`` is outside
                ``0..MEMBER_SLOTS - 1``; such values would encode to the
                action ID of a different (hand, slot) pair.
        """
        if hand_idx < 0:
            raise ValueError(f"hand_idx must be non-negative, got {hand_idx}")
        if not 0 <= slot_idx < self.MEMBER_SLOTS:
            raise ValueError(
                f"slot_idx must be in range 0..{self.MEMBER_SLOTS - 1}, got {slot_idx}"
            )
        # Action ID: 1 + hand_idx * 3 + slot_idx
        action = 1 + hand_idx * self.MEMBER_SLOTS + slot_idx
        self.gs.step(action)

    def get_legal_actions(self) -> List[int]:
        """Get the list of legal action IDs."""
        return list(self.gs.get_legal_action_ids())

    def assert_phase(self, expected_phase: Phase) -> None:
        """Assert the current phase.

        Raises AssertionError explicitly so the check still runs under
        ``python -O``.
        """
        if int(self.gs.phase) != int(expected_phase):
            raise AssertionError(f"Expected phase {expected_phase}, got {self.gs.phase}")

    def assert_legal_action(self, action_id: int) -> None:
        """Assert that an action is currently legal.

        Raises AssertionError explicitly so the check still runs under
        ``python -O``.
        """
        legal = self.get_legal_actions()
        if action_id not in legal:
            raise AssertionError(f"Action {action_id} is not legal. Legal: {legal}")

    def log(self, msg: str) -> None:
        """Helper for logging if needed."""
        print(f"[TEST] {msg}")

    def print_rule_log(self, limit: int = 10) -> None:
        """Print the recent entries from the engine's rule log.

        Args:
            limit: Maximum number of trailing entries to print.
        """
        # Read the attribute once and work on a local copy.
        entries = list(self.gs.rule_log)
        start = max(0, len(entries) - limit)
        for entry in entries[start:]:
            print(f" {entry}")
|