trioskosmos commited on
Commit
5a40dcb
·
verified ·
1 Parent(s): 5b7c0fb

Upload ai/tests/test_obs_alignment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/tests/test_obs_alignment.py +94 -0
ai/tests/test_obs_alignment.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import unittest
4
+
5
+ import numpy as np
6
+
7
+ # Add project root to path
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
9
+
10
+ from ai.obs_adapters import UnifiedObservationEncoder
11
+
12
+ from engine.game.enums import Phase
13
+ from engine.game.game_state import GameState
14
+
15
+
16
+ class TestObsAlignment(unittest.TestCase):
17
+ def test_encode_8192_mappings(self):
18
+ print("Verifying UnifiedObservationEncoder._encode_8192 alignment...")
19
+
20
+ # 1. Create Mock GameState
21
+ state = GameState()
22
+ p = state.players[0]
23
+ opp = state.players[1]
24
+
25
+ # Set known values to distinct integers to trace them
26
+ # Counts:
27
+ # Score: 5 (Distinct from Phase values 3, 4)
28
+
29
+ p.success_lives = [1] * 5
30
+ opp.success_lives = [1] * 2
31
+
32
+ p.discard = [1] * 11
33
+ p.hand = [1] * 4
34
+ p.energy_zone = [1] * 5
35
+ p.main_deck = [1] * 20
36
+ opp.discard = [1] * 7
37
+ opp.hand = [1] * 6
38
+ opp.main_deck = [1] * 25
39
+
40
+ state.phase = Phase.MAIN # 4
41
+ state.turn_number = 10
42
+
43
+ # 2. Encode
44
+ obs = UnifiedObservationEncoder.encode(state, 8192, 0)
45
+
46
+ # 3. Verify Output in Observation Vector
47
+
48
+ SCORE_START = 8000
49
+ # obs[SCORE_START] = score / 9.0
50
+ self.assertAlmostEqual(obs[SCORE_START], 5.0 / 9.0, places=4, msg="My Score Mismatch")
51
+
52
+ VOLUMES_START = 7800
53
+ # My Deck (DK=20) -> VOL+0
54
+ self.assertAlmostEqual(obs[VOLUMES_START], 20.0 / 50.0, places=4, msg="My Deck Volume Mismatch")
55
+
56
+ # Opp Deck (DK=25) -> VOL+1
57
+ self.assertAlmostEqual(obs[VOLUMES_START + 1], 25.0 / 50.0, places=4, msg="Opp Deck Volume Mismatch")
58
+
59
+ # My Hand (HD=4) -> VOL+2
60
+ self.assertAlmostEqual(obs[VOLUMES_START + 2], 4.0 / 20.0, places=4, msg="My Hand Volume Mismatch")
61
+
62
+ # My Trash (TR=11) -> VOL+3
63
+ self.assertAlmostEqual(obs[VOLUMES_START + 3], 11.0 / 50.0, places=4, msg="My Trash Volume Mismatch")
64
+
65
+ # Opp Hand (HD=6) -> VOL+4
66
+ self.assertAlmostEqual(obs[VOLUMES_START + 4], 6.0 / 20.0, places=4, msg="Opp Hand Volume Mismatch")
67
+
68
+ # Opp Trash (OT=7) -> VOL+5
69
+ val_opp_trash = obs[VOLUMES_START + 5] * 50.0
70
+ print(f"Observed Opp Trash (Volume): {val_opp_trash} (Expected 7)")
71
+ if abs(val_opp_trash - 7.0) > 0.1:
72
+ self.fail(f"Opp Trash mismatch: Got {val_opp_trash}, Expected 7")
73
+
74
+ # One-Hot Phase
75
+ # Index 20 + Phase. Phase.MAIN = 4 -> Index 24.
76
+ # Score is 5.
77
+ # If code reads [0] (Score=5), it sets Index 25.
78
+ # If code reads [8] (Phase=4), it sets Index 24.
79
+
80
+ if obs[24] == 1.0:
81
+ print("Phase correctly identified as MAIN (4).")
82
+ elif obs[25] == 1.0:
83
+ print("Phase identified as 5 (Score value!) -> Mismatch.")
84
+ self.fail("Phase encoding is reading Score index!")
85
+ else:
86
+ print("Phase not set at expected index 24?")
87
+ # Check if it's somewhere else
88
+ idx = np.where(obs[20:30] == 1.0)[0]
89
+ print(f"Phase set at relative index: {idx}")
90
+ self.fail(f"Phase encoding failed. Index 24 not set. Found set at relative: {idx}")
91
+
92
+
93
+ if __name__ == "__main__":
94
+ unittest.main()