File size: 5,842 Bytes
bb3fbf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import json
import os
import sys
from typing import List, Optional

# Ensure the directory two levels above this file is importable, so the
# `engine` package and the compiled Rust extension resolve no matter which
# working directory the tests are launched from.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

try:
    import engine_rust
except ImportError:
    # Try importing from backend if not in root (common in some envs)
    try:
        from backend import engine_rust
    except ImportError as err:
        # Chain explicitly so a broken (rather than merely missing) build still
        # surfaces its underlying import error as the direct cause.
        raise ImportError("Could not import engine_rust. Make sure the Rust extension is built.") from err

from engine.game.enums import Phase


class AbilityTestContext:
    """
    Helper for writing Rust engine tests.

    Wraps an ``engine_rust.PyGameState`` built from the compiled card database
    and provides a high-level API for state setup, action execution, and
    verification.

    Card instances are addressed by UIDs that pack a base card ID into the low
    20 bits and an instance index above them (see ``mk_uid``).
    """

    # Low 20 bits of a UID hold the base card ID; the instance index sits above.
    BASE_ID_MASK = 0xFFFFF
    # Left shift applied to the instance index when packing a UID.
    _INSTANCE_SHIFT = 20
    # Instance-index stride between players so UIDs this helper creates for
    # player 1 never collide with player 0's (same offset setup_game uses for P1).
    _PLAYER_INSTANCE_STRIDE = 1000
    # Base card ID used for every energy card this helper creates.
    _ENERGY_CARD_ID = 40001
    # Slots addressable per hand card in the play-member action encoding.
    _SLOTS_PER_HAND_CARD = 3

    def __init__(self, compiled_data_path: str = "data/cards_compiled.json"):
        """
        Load the compiled card database and create a fresh game state.

        Args:
            compiled_data_path: Path to the compiled cards JSON. If it does not
                exist as given, it is retried relative to ``PROJECT_ROOT``.

        Raises:
            FileNotFoundError: If the file exists at neither location.
        """
        if not os.path.exists(compiled_data_path):
            # Test runners are often launched from outside the project root.
            alt_path = os.path.join(PROJECT_ROOT, compiled_data_path)
            if not os.path.exists(alt_path):
                raise FileNotFoundError(f"Compiled data not found: {compiled_data_path}")
            compiled_data_path = alt_path

        with open(compiled_data_path, "r", encoding="utf-8") as f:
            self.json_data = f.read()
        # Parsed copy for Python-side lookups (find_card_id); the engine gets the raw text.
        self.db_raw = json.loads(self.json_data)
        self.db = engine_rust.PyCardDatabase(self.json_data)

        self.gs = engine_rust.PyGameState(self.db)

    def _player_instance_base(self, player_idx: int, base: int) -> int:
        """Return the first instance index of a range starting at ``base`` for ``player_idx``."""
        return base + player_idx * self._PLAYER_INSTANCE_STRIDE

    def mk_uid(self, base_id: int, instance_idx: int) -> int:
        """
        Create a unique ID from a base card ID and an instance index.

        Raises:
            ValueError: If ``base_id`` does not fit in ``BASE_ID_MASK`` (it would
                otherwise silently overwrite the instance bits) or if
                ``instance_idx`` is negative.
        """
        if not 0 <= base_id <= self.BASE_ID_MASK:
            raise ValueError(f"base_id {base_id} does not fit in mask {self.BASE_ID_MASK:#x}")
        if instance_idx < 0:
            raise ValueError(f"instance_idx must be non-negative, got {instance_idx}")
        return base_id | (instance_idx << self._INSTANCE_SHIFT)

    def find_card_id(self, card_no: str, db_type: Optional[str] = None) -> int:
        """
        Find the internal (base) ID for a card number.

        Args:
            card_no: Value of the ``"card_no"`` field to look for.
            db_type: Restrict the search to one database section. By default
                ``member_db``, ``live_db`` and ``energy_db`` are searched in order.

        Returns:
            The key of the first matching entry, converted to ``int``.

        Raises:
            ValueError: If no entry matches.
        """
        dbs = [db_type] if db_type else ["member_db", "live_db", "energy_db"]
        for db_name in dbs:
            for cid, card in self.db_raw.get(db_name, {}).items():
                if card.get("card_no") == card_no:
                    return int(cid)
        raise ValueError(f"Card {card_no} not found in {dbs}")

    def setup_game(self, p0_deck_nos: Optional[List[str]] = None, p1_deck_nos: Optional[List[str]] = None):
        """
        Initialize the game with specific main decks given as card numbers.

        A missing or empty deck becomes a 40-card dummy deck of base ID 1.
        Each player also receives 10 energy UIDs and 3 life UIDs.
        """

        def nos_to_uids(nos, offset=0):
            if not nos:
                return [self.mk_uid(1, i + offset) for i in range(40)]  # Default dummy deck
            uids = []
            counts = {}
            for no in nos:
                base = self.find_card_id(no)
                # Copies of the same card get consecutive instance indices.
                count = counts.get(base, 0)
                uids.append(self.mk_uid(base, count + offset))
                counts[base] = count + 1
            return uids

        p0_main = nos_to_uids(p0_deck_nos, self._player_instance_base(0, 0))
        p1_main = nos_to_uids(p1_deck_nos, self._player_instance_base(1, 0))  # Offset instance IDs for P1

        # Default Energy and Lives
        p0_energy = [self.mk_uid(self._ENERGY_CARD_ID, i) for i in range(10)]
        p1_energy = [self.mk_uid(self._ENERGY_CARD_ID, 100 + i) for i in range(10)]
        p0_lives = [self.mk_uid(1, 200 + i) for i in range(3)]
        p1_lives = [self.mk_uid(1, 300 + i) for i in range(3)]

        self.gs.initialize_game(p0_main, p1_main, p0_energy, p1_energy, p0_lives, p1_lives)

    def skip_mulligan(self):
        """Skip the mulligan phases for both players by taking action 0 in each."""
        if int(self.gs.phase) == -1:  # MULLIGAN_P1
            self.gs.step(0)
        if int(self.gs.phase) == 0:  # MULLIGAN_P2
            self.gs.step(0)

    def reach_main_phase(self, max_steps: int = 20):
        """
        Skip mulligans, then pass (action 0) until the phase reaches ``Phase.MAIN``.

        Args:
            max_steps: Safety cap on pass actions so a stuck engine cannot loop
                forever. Hitting the cap returns quietly; use
                ``assert_phase(Phase.MAIN)`` to verify the outcome.
        """
        self.skip_mulligan()
        target = int(Phase.MAIN)
        for _ in range(max_steps):
            if int(self.gs.phase) >= target:
                break
            self.gs.step(0)  # Pass/End Phase

    def set_hand(self, player_idx: int, card_nos: List[str]):
        """
        Directly replace a player's hand with the given card numbers.

        Hand cards use instance indices starting at 500, offset per player so
        the same card in both hands still gets distinct UIDs.
        """
        first_idx = self._player_instance_base(player_idx, 500)
        uids = [self.mk_uid(self.find_card_id(no), first_idx + i) for i, no in enumerate(card_nos)]
        self.gs.set_hand_cards(player_idx, uids)

    def set_energy(self, player_idx: int, count: int, tapped_count: int = 0):
        """
        Directly replace a player's energy zone with ``count`` energy cards.

        The first ``tapped_count`` entries are marked tapped. Energy cards use
        instance indices starting at 600, offset per player.

        Raises:
            ValueError: If ``tapped_count`` is outside ``[0, count]``; otherwise
                ``tapped_energy`` would not have one entry per energy card.
        """
        if not 0 <= tapped_count <= count:
            raise ValueError(f"tapped_count must be between 0 and count ({count}), got {tapped_count}")
        first_idx = self._player_instance_base(player_idx, 600)
        p = self.gs.get_player(player_idx)
        p.energy_zone = [self.mk_uid(self._ENERGY_CARD_ID, first_idx + i) for i in range(count)]
        p.tapped_energy = [True] * tapped_count + [False] * (count - tapped_count)
        self.gs.set_player(player_idx, p)

    def play_member(self, hand_idx: int, slot_idx: int):
        """
        Play a member from hand to a specific slot.

        The action ID is ``1 + hand_idx * 3 + slot_idx``.

        Raises:
            ValueError: If ``hand_idx`` is negative or ``slot_idx`` is not in
                0-2; such values would alias a different card/slot action.
        """
        if hand_idx < 0:
            raise ValueError(f"hand_idx must be non-negative, got {hand_idx}")
        if not 0 <= slot_idx < self._SLOTS_PER_HAND_CARD:
            raise ValueError(f"slot_idx must be in [0, {self._SLOTS_PER_HAND_CARD}), got {slot_idx}")
        self.gs.step(1 + hand_idx * self._SLOTS_PER_HAND_CARD + slot_idx)

    def get_legal_actions(self) -> List[int]:
        """Get the list of legal action IDs."""
        return list(self.gs.get_legal_action_ids())

    # String annotation so the class body does not evaluate the enum at definition time.
    def assert_phase(self, expected_phase: "Phase"):
        """Assert the current phase."""
        assert int(self.gs.phase) == int(expected_phase), f"Expected phase {expected_phase}, got {self.gs.phase}"

    def assert_legal_action(self, action_id: int):
        """Assert that an action is currently legal."""
        legal = self.get_legal_actions()
        assert action_id in legal, f"Action {action_id} is not legal. Legal: {legal}"

    def log(self, msg: str):
        """Print a message with a ``[TEST]`` prefix."""
        print(f"[TEST] {msg}")

    def print_rule_log(self, limit: int = 10):
        """Print the last ``limit`` entries of the engine's rule log."""
        entries = self.gs.rule_log
        start = max(0, len(entries) - limit)
        for i in range(start, len(entries)):
            print(f"  {entries[i]}")