Spaces:
Running
Running
import json
import os
import sys
from typing import List, Optional
# Make the project root (two directories above this file) importable, so that
# project packages such as `engine` and `backend` resolve no matter which
# directory the tests are launched from.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if PROJECT_ROOT not in sys.path:
    # Prepend so the project's modules win over same-named installed ones.
    sys.path.insert(0, PROJECT_ROOT)
# Load the compiled Rust extension. It is normally importable from the project
# root; some environments only expose it inside the `backend` package.
try:
    import engine_rust
except ImportError:
    try:
        from backend import engine_rust
    except ImportError as err:
        # Chain explicitly so the traceback shows the underlying import failure
        # as the cause, instead of "during handling of the above exception...".
        raise ImportError(
            "Could not import engine_rust. Make sure the Rust extension is built."
        ) from err
| from engine.game.enums import Phase | |
class AbilityTestContext:
    """
    Helper for writing Rust engine tests.

    Provides a high-level API for state setup, action execution, and
    verification on top of ``engine_rust.PyGameState``.
    """

    # A UID packs the base card ID into the low bits and the instance index
    # into the bits above them (see ``mk_uid``).
    _INSTANCE_SHIFT = 20
    BASE_ID_MASK = (1 << _INSTANCE_SHIFT) - 1  # 0xFFFFF

    # Base IDs used for the default energy cards and for placeholder cards
    # (dummy deck and default lives).
    _DEFAULT_ENERGY_ID = 40001
    _DUMMY_CARD_ID = 1

    # Number of member slots encoded per hand card in the play-member action.
    _MEMBER_SLOTS = 3

    # Raw phase values of the two mulligan phases.
    _PHASE_MULLIGAN_P1 = -1
    _PHASE_MULLIGAN_P2 = 0

    def __init__(self, compiled_data_path: str = "data/cards_compiled.json"):
        """Load the compiled card data and create a fresh game state.

        Args:
            compiled_data_path: Path to the compiled card JSON. If it does not
                exist as given, it is retried relative to ``PROJECT_ROOT``.

        Raises:
            FileNotFoundError: If the file exists at neither location.
        """
        if not os.path.exists(compiled_data_path):
            # Try alternative path for test execution environments.
            alt_path = os.path.join(PROJECT_ROOT, compiled_data_path)
            if not os.path.exists(alt_path):
                raise FileNotFoundError(f"Compiled data not found: {compiled_data_path}")
            compiled_data_path = alt_path
        with open(compiled_data_path, "r", encoding="utf-8") as f:
            self.json_data = f.read()
        # Parsed copy for Python-side lookups (find_card_id); the engine
        # receives the raw JSON text.
        self.db_raw = json.loads(self.json_data)
        self.db = engine_rust.PyCardDatabase(self.json_data)
        self.gs = engine_rust.PyGameState(self.db)

    def mk_uid(self, base_id: int, instance_idx: int) -> int:
        """Create a unique ID from a base card ID and an instance index.

        The base ID occupies the bits covered by ``BASE_ID_MASK``. The instance
        index is shifted above them, so copies of one card get distinct UIDs.

        Raises:
            ValueError: If ``base_id`` is outside ``0..BASE_ID_MASK`` or
                ``instance_idx`` is negative. Either would otherwise produce a
                UID whose base and instance fields overlap.
        """
        if not 0 <= base_id <= self.BASE_ID_MASK:
            raise ValueError(
                f"base_id {base_id} does not fit in BASE_ID_MASK {self.BASE_ID_MASK:#x}"
            )
        if instance_idx < 0:
            raise ValueError(f"instance_idx must be non-negative, got {instance_idx}")
        return base_id | (instance_idx << self._INSTANCE_SHIFT)

    def find_card_id(self, card_no: str, db_type: Optional[str] = None) -> int:
        """Return the internal (base) ID of the card whose ``card_no`` matches.

        Args:
            card_no: Card number, compared against each entry's ``"card_no"``.
            db_type: Section of the compiled data to search. By default
                ``member_db``, ``live_db`` and ``energy_db`` are searched in
                that order, and the first match wins.

        Raises:
            ValueError: If no matching card is found.
        """
        dbs = [db_type] if db_type else ["member_db", "live_db", "energy_db"]
        for db_name in dbs:
            for cid, card in self.db_raw.get(db_name, {}).items():
                if card.get("card_no") == card_no:
                    return int(cid)
        raise ValueError(f"Card {card_no} not found in {dbs}")

    def setup_game(
        self,
        p0_deck_nos: Optional[List[str]] = None,
        p1_deck_nos: Optional[List[str]] = None,
    ):
        """Initialize the game with specific decks given as card numbers.

        A missing or empty deck is replaced by 40 placeholder cards. Repeated
        card numbers get consecutive instance indices. Player 1's deck
        instance indices are offset by 1000 so they differ from player 0's.
        Both players also get 10 default energy cards and 3 placeholder lives.
        """

        def nos_to_uids(nos, offset=0):
            if not nos:
                # Default dummy deck.
                return [self.mk_uid(self._DUMMY_CARD_ID, i + offset) for i in range(40)]
            uids = []
            counts = {}
            for no in nos:
                base = self.find_card_id(no)
                count = counts.get(base, 0)
                uids.append(self.mk_uid(base, count + offset))
                counts[base] = count + 1
            return uids

        p0_main = nos_to_uids(p0_deck_nos, 0)
        p1_main = nos_to_uids(p1_deck_nos, 1000)  # Offset instance IDs for P1
        # Default Energy and Lives
        p0_energy = [self.mk_uid(self._DEFAULT_ENERGY_ID, i) for i in range(10)]
        p1_energy = [self.mk_uid(self._DEFAULT_ENERGY_ID, 100 + i) for i in range(10)]
        p0_lives = [self.mk_uid(self._DUMMY_CARD_ID, 200 + i) for i in range(3)]
        p1_lives = [self.mk_uid(self._DUMMY_CARD_ID, 300 + i) for i in range(3)]
        self.gs.initialize_game(p0_main, p1_main, p0_energy, p1_energy, p0_lives, p1_lives)

    def skip_mulligan(self):
        """Skip the mulligan phases for both players by submitting action 0 in each."""
        # int() for consistency with the other phase checks in this class.
        if int(self.gs.phase) == self._PHASE_MULLIGAN_P1:
            self.gs.step(0)
        if int(self.gs.phase) == self._PHASE_MULLIGAN_P2:
            self.gs.step(0)

    def reach_main_phase(self, max_steps: int = 20) -> bool:
        """Advance through Active, Energy, Draw phases to reach Main Phase.

        Skips the mulligans, then submits action 0 (pass / end phase) until the
        phase value is at least ``Phase.MAIN`` or ``max_steps`` actions have
        been taken.

        Returns:
            True if Main Phase (or later) was reached. Hitting the step limit is
            not an error here; check the result or use ``assert_phase``.
        """
        self.skip_mulligan()
        main = int(Phase.MAIN)
        for _ in range(max_steps):
            if int(self.gs.phase) >= main:
                break
            self.gs.step(0)  # Pass/End Phase
        return int(self.gs.phase) >= main

    def set_hand(self, player_idx: int, card_nos: List[str]):
        """Directly set a player's hand from card numbers.

        Cards get instance indices 500, 501, ... in list order.
        """
        # NOTE(review): the 500+ range does not depend on player_idx, so the
        # same card in both players' hands gets the same UID. Confirm the engine
        # does not require UIDs to be globally unique.
        uids = [
            self.mk_uid(self.find_card_id(no), 500 + i)
            for i, no in enumerate(card_nos)
        ]
        self.gs.set_hand_cards(player_idx, uids)

    def set_energy(self, player_idx: int, count: int, tapped_count: int = 0):
        """Directly set a player's energy zone.

        Args:
            player_idx: Player whose energy zone is replaced.
            count: Number of default energy cards (instance indices 600+).
            tapped_count: How many of them, starting from the first, are tapped.

        Raises:
            ValueError: If ``tapped_count`` is not within ``0..count``. Otherwise
                the tapped-flag list would not line up with the energy zone.
        """
        if not 0 <= tapped_count <= count:
            raise ValueError(
                f"tapped_count must be between 0 and count ({count}), got {tapped_count}"
            )
        p = self.gs.get_player(player_idx)
        p.energy_zone = [self.mk_uid(self._DEFAULT_ENERGY_ID, 600 + i) for i in range(count)]
        p.tapped_energy = [True] * tapped_count + [False] * (count - tapped_count)
        self.gs.set_player(player_idx, p)

    def play_member(self, hand_idx: int, slot_idx: int):
        """Play a member from hand to a specific slot.

        The action ID is ``1 + hand_idx * 3 + slot_idx``.

        Raises:
            ValueError: If ``hand_idx`` is negative or ``slot_idx`` is outside
                ``0..2``. Such values would silently encode a different card's
                or slot's action.
        """
        if hand_idx < 0:
            raise ValueError(f"hand_idx must be non-negative, got {hand_idx}")
        if not 0 <= slot_idx < self._MEMBER_SLOTS:
            raise ValueError(
                f"slot_idx must be in range 0..{self._MEMBER_SLOTS - 1}, got {slot_idx}"
            )
        self.gs.step(1 + hand_idx * self._MEMBER_SLOTS + slot_idx)

    def get_legal_actions(self) -> List[int]:
        """Get the list of legal action IDs."""
        return list(self.gs.get_legal_action_ids())

    def assert_phase(self, expected_phase: Phase):
        """Assert that the current phase equals ``expected_phase`` (compared as ints)."""
        assert int(self.gs.phase) == int(expected_phase), f"Expected phase {expected_phase}, got {self.gs.phase}"

    def assert_legal_action(self, action_id: int):
        """Assert that an action is currently legal."""
        legal = self.get_legal_actions()
        assert action_id in legal, f"Action {action_id} is not legal. Legal: {legal}"

    def log(self, msg: str):
        """Print ``msg`` with a ``[TEST]`` prefix."""
        print(f"[TEST] {msg}")

    def print_rule_log(self, limit: int = 10):
        """Print the last ``limit`` entries of the engine's rule log."""
        entries = self.gs.rule_log
        start = max(0, len(entries) - limit)
        for i in range(start, len(entries)):
            print(f"  {entries[i]}")