Spaces:
Running
Running
Upload ai/tests/test_obs_alignment.py with huggingface_hub
Browse files
ai/tests/test_obs_alignment.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import unittest
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Add project root to path
|
| 8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
|
| 9 |
+
|
| 10 |
+
from ai.obs_adapters import UnifiedObservationEncoder
|
| 11 |
+
|
| 12 |
+
from engine.game.enums import Phase
|
| 13 |
+
from engine.game.game_state import GameState
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestObsAlignment(unittest.TestCase):
|
| 17 |
+
def test_encode_8192_mappings(self):
|
| 18 |
+
print("Verifying UnifiedObservationEncoder._encode_8192 alignment...")
|
| 19 |
+
|
| 20 |
+
# 1. Create Mock GameState
|
| 21 |
+
state = GameState()
|
| 22 |
+
p = state.players[0]
|
| 23 |
+
opp = state.players[1]
|
| 24 |
+
|
| 25 |
+
# Set known values to distinct integers to trace them
|
| 26 |
+
# Counts:
|
| 27 |
+
# Score: 5 (Distinct from Phase values 3, 4)
|
| 28 |
+
|
| 29 |
+
p.success_lives = [1] * 5
|
| 30 |
+
opp.success_lives = [1] * 2
|
| 31 |
+
|
| 32 |
+
p.discard = [1] * 11
|
| 33 |
+
p.hand = [1] * 4
|
| 34 |
+
p.energy_zone = [1] * 5
|
| 35 |
+
p.main_deck = [1] * 20
|
| 36 |
+
opp.discard = [1] * 7
|
| 37 |
+
opp.hand = [1] * 6
|
| 38 |
+
opp.main_deck = [1] * 25
|
| 39 |
+
|
| 40 |
+
state.phase = Phase.MAIN # 4
|
| 41 |
+
state.turn_number = 10
|
| 42 |
+
|
| 43 |
+
# 2. Encode
|
| 44 |
+
obs = UnifiedObservationEncoder.encode(state, 8192, 0)
|
| 45 |
+
|
| 46 |
+
# 3. Verify Output in Observation Vector
|
| 47 |
+
|
| 48 |
+
SCORE_START = 8000
|
| 49 |
+
# obs[SCORE_START] = score / 9.0
|
| 50 |
+
self.assertAlmostEqual(obs[SCORE_START], 5.0 / 9.0, places=4, msg="My Score Mismatch")
|
| 51 |
+
|
| 52 |
+
VOLUMES_START = 7800
|
| 53 |
+
# My Deck (DK=20) -> VOL+0
|
| 54 |
+
self.assertAlmostEqual(obs[VOLUMES_START], 20.0 / 50.0, places=4, msg="My Deck Volume Mismatch")
|
| 55 |
+
|
| 56 |
+
# Opp Deck (DK=25) -> VOL+1
|
| 57 |
+
self.assertAlmostEqual(obs[VOLUMES_START + 1], 25.0 / 50.0, places=4, msg="Opp Deck Volume Mismatch")
|
| 58 |
+
|
| 59 |
+
# My Hand (HD=4) -> VOL+2
|
| 60 |
+
self.assertAlmostEqual(obs[VOLUMES_START + 2], 4.0 / 20.0, places=4, msg="My Hand Volume Mismatch")
|
| 61 |
+
|
| 62 |
+
# My Trash (TR=11) -> VOL+3
|
| 63 |
+
self.assertAlmostEqual(obs[VOLUMES_START + 3], 11.0 / 50.0, places=4, msg="My Trash Volume Mismatch")
|
| 64 |
+
|
| 65 |
+
# Opp Hand (HD=6) -> VOL+4
|
| 66 |
+
self.assertAlmostEqual(obs[VOLUMES_START + 4], 6.0 / 20.0, places=4, msg="Opp Hand Volume Mismatch")
|
| 67 |
+
|
| 68 |
+
# Opp Trash (OT=7) -> VOL+5
|
| 69 |
+
val_opp_trash = obs[VOLUMES_START + 5] * 50.0
|
| 70 |
+
print(f"Observed Opp Trash (Volume): {val_opp_trash} (Expected 7)")
|
| 71 |
+
if abs(val_opp_trash - 7.0) > 0.1:
|
| 72 |
+
self.fail(f"Opp Trash mismatch: Got {val_opp_trash}, Expected 7")
|
| 73 |
+
|
| 74 |
+
# One-Hot Phase
|
| 75 |
+
# Index 20 + Phase. Phase.MAIN = 4 -> Index 24.
|
| 76 |
+
# Score is 5.
|
| 77 |
+
# If code reads [0] (Score=5), it sets Index 25.
|
| 78 |
+
# If code reads [8] (Phase=4), it sets Index 24.
|
| 79 |
+
|
| 80 |
+
if obs[24] == 1.0:
|
| 81 |
+
print("Phase correctly identified as MAIN (4).")
|
| 82 |
+
elif obs[25] == 1.0:
|
| 83 |
+
print("Phase identified as 5 (Score value!) -> Mismatch.")
|
| 84 |
+
self.fail("Phase encoding is reading Score index!")
|
| 85 |
+
else:
|
| 86 |
+
print("Phase not set at expected index 24?")
|
| 87 |
+
# Check if it's somewhere else
|
| 88 |
+
idx = np.where(obs[20:30] == 1.0)[0]
|
| 89 |
+
print(f"Phase set at relative index: {idx}")
|
| 90 |
+
self.fail(f"Phase encoding failed. Index 24 not set. Found set at relative: {idx}")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
unittest.main()
|