"""Alignment test for UnifiedObservationEncoder's 8192-wide observation layout.

A GameState is populated with distinct, easily traced zone sizes, encoded
from player 0's perspective, and each value is checked at the observation
index (and normalisation) this test expects. The score (5) is chosen to
differ from Phase.MAIN (4) so that an encoder reading the score where it
should read the phase shows up as the wrong one-hot slot being set.
"""

import os
import sys
import unittest

import numpy as np

# Add project root to path so the project packages below are importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))

from ai.obs_adapters import UnifiedObservationEncoder  # noqa: E402
from engine.game.enums import Phase  # noqa: E402
from engine.game.game_state import GameState  # noqa: E402


class TestObsAlignment(unittest.TestCase):
    """Checks feature offsets in the 8192-dim observation vector."""

    # Observation offsets exercised by this test.
    SCORE_START = 8000  # my score, normalised by 9
    VOLUMES_START = 7800  # zone sizes, see the table in the test body
    PHASE_START = 20  # one-hot phase: slot PHASE_START + phase value

    def test_encode_8192_mappings(self):
        """Encode a hand-built state and verify score, volume and phase slots."""
        # 1. Build a state whose counts are all distinct so each one is traceable.
        state = GameState()
        p = state.players[0]
        opp = state.players[1]

        # Score: 5 (distinct from Phase values 3, 4)
        p.success_lives = [1] * 5
        opp.success_lives = [1] * 2
        p.discard = [1] * 11
        p.hand = [1] * 4
        p.energy_zone = [1] * 5
        p.main_deck = [1] * 20
        opp.discard = [1] * 7
        opp.hand = [1] * 6
        opp.main_deck = [1] * 25

        state.phase = Phase.MAIN  # 4
        state.turn_number = 10

        # 2. Encode from player 0's point of view.
        obs = UnifiedObservationEncoder.encode(state, 8192, 0)

        # 3a. Score: obs[SCORE_START] = score / 9.0
        self.assertAlmostEqual(
            obs[self.SCORE_START], 5.0 / 9.0, places=4, msg="My Score Mismatch"
        )

        # 3b. Zone volumes: (offset from VOLUMES_START, count, normaliser, message)
        volume_checks = [
            (0, 20, 50.0, "My Deck Volume Mismatch"),
            (1, 25, 50.0, "Opp Deck Volume Mismatch"),
            (2, 4, 20.0, "My Hand Volume Mismatch"),
            (3, 11, 50.0, "My Trash Volume Mismatch"),
            (4, 6, 20.0, "Opp Hand Volume Mismatch"),
        ]
        for offset, count, scale, msg in volume_checks:
            with self.subTest(volume_offset=offset):
                self.assertAlmostEqual(
                    obs[self.VOLUMES_START + offset], count / scale, places=4, msg=msg
                )

        # Opp trash (7) is checked on the de-normalised count scale with a
        # +/-0.1 tolerance. assertAlmostEqual with delta also fails on NaN,
        # which a bare `abs(x - 7) > 0.1` comparison would let through.
        val_opp_trash = obs[self.VOLUMES_START + 5] * 50.0
        self.assertAlmostEqual(
            val_opp_trash,
            7.0,
            delta=0.1,
            msg=f"Opp Trash mismatch: Got {val_opp_trash}, Expected 7",
        )

        # 3c. One-hot phase at PHASE_START + phase. Phase.MAIN = 4 -> index 24.
        # If the encoder read the score (5) instead, index 25 would be set.
        phase_idx = self.PHASE_START + 4
        score_idx = self.PHASE_START + 5
        # Relative positions of hot slots, used only in failure messages.
        hot = np.where(obs[self.PHASE_START:self.PHASE_START + 10] == 1.0)[0]

        # Checked first so a score/phase mix-up is reported as such, and so a
        # vector with both 24 and 25 hot (not a valid one-hot) is rejected.
        self.assertNotEqual(
            obs[score_idx],
            1.0,
            msg=f"Phase encoding is reading Score index! Hot relative slots: {hot}",
        )
        self.assertEqual(
            obs[phase_idx],
            1.0,
            msg=f"Phase encoding failed. Index 24 not set. Found set at relative: {hot}",
        )


if __name__ == "__main__":
    unittest.main()