File size: 3,345 Bytes
5a40dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import sys
import unittest

import numpy as np

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))

from ai.obs_adapters import UnifiedObservationEncoder

from engine.game.enums import Phase
from engine.game.game_state import GameState


class TestObsAlignment(unittest.TestCase):
    def test_encode_8192_mappings(self):
        print("Verifying UnifiedObservationEncoder._encode_8192 alignment...")

        # 1. Create Mock GameState
        state = GameState()
        p = state.players[0]
        opp = state.players[1]

        # Set known values to distinct integers to trace them
        # Counts:
        # Score: 5 (Distinct from Phase values 3, 4)

        p.success_lives = [1] * 5
        opp.success_lives = [1] * 2

        p.discard = [1] * 11
        p.hand = [1] * 4
        p.energy_zone = [1] * 5
        p.main_deck = [1] * 20
        opp.discard = [1] * 7
        opp.hand = [1] * 6
        opp.main_deck = [1] * 25

        state.phase = Phase.MAIN  # 4
        state.turn_number = 10

        # 2. Encode
        obs = UnifiedObservationEncoder.encode(state, 8192, 0)

        # 3. Verify Output in Observation Vector

        SCORE_START = 8000
        # obs[SCORE_START] = score / 9.0
        self.assertAlmostEqual(obs[SCORE_START], 5.0 / 9.0, places=4, msg="My Score Mismatch")

        VOLUMES_START = 7800
        # My Deck (DK=20) -> VOL+0
        self.assertAlmostEqual(obs[VOLUMES_START], 20.0 / 50.0, places=4, msg="My Deck Volume Mismatch")

        # Opp Deck (DK=25) -> VOL+1
        self.assertAlmostEqual(obs[VOLUMES_START + 1], 25.0 / 50.0, places=4, msg="Opp Deck Volume Mismatch")

        # My Hand (HD=4) -> VOL+2
        self.assertAlmostEqual(obs[VOLUMES_START + 2], 4.0 / 20.0, places=4, msg="My Hand Volume Mismatch")

        # My Trash (TR=11) -> VOL+3
        self.assertAlmostEqual(obs[VOLUMES_START + 3], 11.0 / 50.0, places=4, msg="My Trash Volume Mismatch")

        # Opp Hand (HD=6) -> VOL+4
        self.assertAlmostEqual(obs[VOLUMES_START + 4], 6.0 / 20.0, places=4, msg="Opp Hand Volume Mismatch")

        # Opp Trash (OT=7) -> VOL+5
        val_opp_trash = obs[VOLUMES_START + 5] * 50.0
        print(f"Observed Opp Trash (Volume): {val_opp_trash} (Expected 7)")
        if abs(val_opp_trash - 7.0) > 0.1:
            self.fail(f"Opp Trash mismatch: Got {val_opp_trash}, Expected 7")

        # One-Hot Phase
        # Index 20 + Phase. Phase.MAIN = 4 -> Index 24.
        # Score is 5.
        # If code reads [0] (Score=5), it sets Index 25.
        # If code reads [8] (Phase=4), it sets Index 24.

        if obs[24] == 1.0:
            print("Phase correctly identified as MAIN (4).")
        elif obs[25] == 1.0:
            print("Phase identified as 5 (Score value!) -> Mismatch.")
            self.fail("Phase encoding is reading Score index!")
        else:
            print("Phase not set at expected index 24?")
            # Check if it's somewhere else
            idx = np.where(obs[20:30] == 1.0)[0]
            print(f"Phase set at relative index: {idx}")
            self.fail(f"Phase encoding failed. Index 24 not set. Found set at relative: {idx}")


if __name__ == "__main__":
    unittest.main()