import json
import os
import sys
from typing import List, Optional

# Ensure project root is in path
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

try:
    import engine_rust
except ImportError:
    # Try importing from backend if not in root (common in some envs)
    try:
        from backend import engine_rust
    except ImportError as err:
        raise ImportError("Could not import engine_rust. Make sure the Rust extension is built.") from err

from engine.game.enums import Phase


class AbilityTestContext:
    """
    Helper for writing Rust engine tests.
    Provides a high-level API for state setup, action execution, and verification.

    Attributes set by ``__init__``:
        json_data: Raw text of the compiled card data file.
        db_raw: ``json_data`` parsed into Python objects (used for card lookups).
        db: ``engine_rust.PyCardDatabase`` built from ``json_data``.
        gs: ``engine_rust.PyGameState`` built from ``db``.
        BASE_ID_MASK: ``0xFFFFF``, the low 20 bits of a UID (see ``mk_uid``).
    """

    # Raw phase values checked by skip_mulligan.
    _PHASE_MULLIGAN_P1 = -1
    _PHASE_MULLIGAN_P2 = 0

    # Base card IDs used for the default/dummy zones built by this helper.
    _DUMMY_CARD_ID = 1
    _ENERGY_CARD_ID = 40001

    def __init__(self, compiled_data_path: str = "data/cards_compiled.json"):
        """Load the compiled card data and create a fresh engine game state.

        Args:
            compiled_data_path: Path to the compiled card JSON. If it does not
                exist as given, it is retried relative to ``PROJECT_ROOT``.

        Raises:
            FileNotFoundError: If the file exists at neither location.
        """
        if not os.path.exists(compiled_data_path):
            # Try alternative path for test execution environments
            alt_path = os.path.join(PROJECT_ROOT, compiled_data_path)
            if os.path.exists(alt_path):
                compiled_data_path = alt_path
            else:
                raise FileNotFoundError(f"Compiled data not found: {compiled_data_path}")

        with open(compiled_data_path, "r", encoding="utf-8") as f:
            self.json_data = f.read()

        self.db_raw = json.loads(self.json_data)
        self.db = engine_rust.PyCardDatabase(self.json_data)
        self.gs = engine_rust.PyGameState(self.db)
        self.BASE_ID_MASK = 0xFFFFF

    def mk_uid(self, base_id: int, instance_idx: int) -> int:
        """Create a Unique ID from base ID and instance index.

        The instance index is placed above the low 20 bits, i.e. the bits
        covered by ``BASE_ID_MASK``; ``base_id`` is OR-ed in unmasked.
        """
        return base_id | (instance_idx << 20)

    def find_card_id(self, card_no: str, db_type: Optional[str] = None) -> int:
        """Find the internal ID for a card number.

        Args:
            card_no: Card number to match against each entry's ``"card_no"``.
            db_type: Name of a single database section to search. When omitted,
                ``member_db``, ``live_db`` and ``energy_db`` are searched in
                that order.

        Returns:
            The key of the first matching entry, converted to ``int``.

        Raises:
            ValueError: If no entry in the searched sections matches.
        """
        dbs = [db_type] if db_type else ["member_db", "live_db", "energy_db"]
        for db_name in dbs:
            for cid, card in self.db_raw.get(db_name, {}).items():
                if card.get("card_no") == card_no:
                    return int(cid)
        raise ValueError(f"Card {card_no} not found in {dbs}")

    def setup_game(
        self,
        p0_deck_nos: Optional[List[str]] = None,
        p1_deck_nos: Optional[List[str]] = None,
    ):
        """Initialize the game with specific decks (card numbers).

        Args:
            p0_deck_nos: Card numbers for player 0's main deck. ``None`` or an
                empty list selects a default 40-card dummy deck.
            p1_deck_nos: Same, for player 1.

        Raises:
            ValueError: If a card number cannot be resolved by ``find_card_id``.
        """

        def nos_to_uids(nos, offset=0):
            if not nos:
                # Default dummy deck
                return [self.mk_uid(self._DUMMY_CARD_ID, i + offset) for i in range(40)]
            uids = []
            counts = {}
            for no in nos:
                base = self.find_card_id(no)
                # Number repeated copies of the same card so each UID is distinct.
                count = counts.get(base, 0)
                uids.append(self.mk_uid(base, count + offset))
                counts[base] = count + 1
            return uids

        p0_main = nos_to_uids(p0_deck_nos, 0)
        p1_main = nos_to_uids(p1_deck_nos, 1000)  # Offset instance IDs for P1

        # Default Energy and Lives
        p0_energy = [self.mk_uid(self._ENERGY_CARD_ID, i) for i in range(10)]
        p1_energy = [self.mk_uid(self._ENERGY_CARD_ID, 100 + i) for i in range(10)]
        p0_lives = [self.mk_uid(self._DUMMY_CARD_ID, 200 + i) for i in range(3)]
        p1_lives = [self.mk_uid(self._DUMMY_CARD_ID, 300 + i) for i in range(3)]

        self.gs.initialize_game(p0_main, p1_main, p0_energy, p1_energy, p0_lives, p1_lives)

    def skip_mulligan(self):
        """Skip mulligan phases for both players.

        Sends action ``0`` once for each mulligan phase the game is currently
        in. The phase is converted with ``int()`` before comparing, matching
        the other phase checks in this class.
        """
        if int(self.gs.phase) == self._PHASE_MULLIGAN_P1:
            self.gs.step(0)
        if int(self.gs.phase) == self._PHASE_MULLIGAN_P2:
            self.gs.step(0)

    def reach_main_phase(self, max_steps: int = 20):
        """Advance through Active, Energy, Draw phases to reach Main Phase.

        Args:
            max_steps: Upper bound on the number of pass actions sent after the
                mulligan is skipped. The loop stops silently when the bound is
                hit; use ``assert_phase`` afterwards to verify the result.
        """
        self.skip_mulligan()
        steps = 0
        while int(self.gs.phase) < int(Phase.MAIN) and steps < max_steps:
            self.gs.step(0)  # Pass/End Phase
            steps += 1

    def set_hand(self, player_idx: int, card_nos: List[str]):
        """Directly set a player's hand.

        Args:
            player_idx: Index of the player whose hand is replaced.
            card_nos: Card numbers to put in the hand, in order.

        Raises:
            ValueError: If a card number cannot be resolved by ``find_card_id``.
        """
        uids = [
            self.mk_uid(self.find_card_id(no), 500 + i)  # Instance ID 500+
            for i, no in enumerate(card_nos)
        ]
        self.gs.set_hand_cards(player_idx, uids)

    def set_energy(self, player_idx: int, count: int, tapped_count: int = 0):
        """Directly set a player's energy zone.

        Args:
            player_idx: Index of the player whose energy zone is replaced.
            count: Number of energy cards to place in the zone.
            tapped_count: How many of those cards (the first ones) are tapped.

        Raises:
            ValueError: If ``count`` is negative or ``tapped_count`` is outside
                ``0..count``; otherwise the tapped flags would not line up
                one-to-one with the energy cards.
        """
        if count < 0:
            raise ValueError(f"count must be non-negative, got {count}")
        if not 0 <= tapped_count <= count:
            raise ValueError(f"tapped_count must be between 0 and {count}, got {tapped_count}")

        p = self.gs.get_player(player_idx)
        p.energy_zone = [self.mk_uid(self._ENERGY_CARD_ID, 600 + i) for i in range(count)]  # Instance ID 600+
        p.tapped_energy = [True] * tapped_count + [False] * (count - tapped_count)
        # Write the modified player object back to the game state.
        self.gs.set_player(player_idx, p)

    def play_member(self, hand_idx: int, slot_idx: int):
        """Play a member from hand to a specific slot.

        Args:
            hand_idx: Index of the card in the hand.
            slot_idx: Target slot index.
        """
        # Action ID: 1 + hand_idx * 3 + slot_idx
        action = 1 + hand_idx * 3 + slot_idx
        self.gs.step(action)

    def get_legal_actions(self) -> List[int]:
        """Get the list of legal action IDs."""
        return list(self.gs.get_legal_action_ids())

    def assert_phase(self, expected_phase: Phase):
        """Assert the current phase.

        Raises:
            AssertionError: If the current phase differs from ``expected_phase``.
        """
        assert int(self.gs.phase) == int(expected_phase), f"Expected phase {expected_phase}, got {self.gs.phase}"

    def assert_legal_action(self, action_id: int):
        """Assert that an action is currently legal.

        Raises:
            AssertionError: If ``action_id`` is not among the legal action IDs.
        """
        legal = self.get_legal_actions()
        assert action_id in legal, f"Action {action_id} is not legal. Legal: {legal}"

    def log(self, msg: str):
        """Helper for logging if needed."""
        print(f"[TEST] {msg}")

    def print_rule_log(self, limit: int = 10):
        """Print the recent entries from the engine's rule log.

        Args:
            limit: Maximum number of trailing entries to print. Zero or a
                negative value prints nothing.
        """
        log = self.gs.rule_log
        # Computing the start index explicitly (rather than log[-limit:])
        # keeps limit == 0 from printing the whole log.
        start = max(0, len(log) - limit)
        for entry in log[start:]:
            print(f"  {entry}")