trioskosmos committed on
Commit
2d88649
·
verified ·
1 Parent(s): 0e3f840

Upload ai/utils/obs_adapters_backup.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/obs_adapters_backup.py +185 -0
ai/utils/obs_adapters_backup.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from engine.game.game_state import GameState
4
+
5
+
6
class UnifiedObservationEncoder:
    """Translate the current GameState into various historic observation formats.

    Models trained in different eras of the project expect different
    observation layouts (8192 / 320 / 128 floats); this adapter re-encodes a
    live game state into whichever layout a given checkpoint was trained on.

    NOTE: annotations are kept as strings so the class can be imported and
    inspected without triggering the project-engine imports at class-creation
    time; ``GameState`` still resolves via the module-level import.
    """

    # Lazily-built VectorGameState, used only for its card_stats table.
    # Created once on the first 8192-dim encode and reused afterwards.
    _vgs_cache = None

    @staticmethod
    def encode(state: "GameState", dim: int, player_idx: "int | None" = None) -> np.ndarray:
        """Encode *state* from *player_idx*'s perspective into a ``dim``-float vector.

        Args:
            state: current two-player game state.
            dim: target observation size; one of 8192, 320, 128.
            player_idx: perspective player; defaults to ``state.current_player``.

        Returns:
            A 1-D ``np.float32`` array of length ``dim``.

        Raises:
            ValueError: if *dim* is not a supported layout.
        """
        if player_idx is None:
            player_idx = state.current_player

        if dim == 8192:
            return UnifiedObservationEncoder._encode_8192(state, player_idx)
        if dim == 320:
            return UnifiedObservationEncoder._encode_320(state, player_idx)
        if dim == 128:
            return UnifiedObservationEncoder._encode_128(state, player_idx)
        raise ValueError(f"Unsupported observation dimension: {dim}")

    @staticmethod
    def _encode_8192(state: "GameState", player_idx: int) -> np.ndarray:
        """Full-size encoding delegating to ai.vector_env.encode_observations_vectorized.

        Builds batch-of-1 input arrays that strictly match the vectorized
        encoder's signature: hand (1,60), stage/energy/tapped (1,3), scores
        (1,), opponent mirrors, card stats, global context (1,128), live
        (1,50), opponent history (1,50), turn number, output buffer (1,8192).
        """
        # Imported lazily so the other encodings work without ai.vector_env.
        from ai.vector_env import VectorGameState as VGS
        from ai.vector_env import encode_observations_vectorized

        p = state.players[player_idx]
        opp = state.players[1 - player_idx]

        # --- Allocations (batch size 1) ---
        batch_hand = np.zeros((1, 60), dtype=np.int32)
        batch_stage = np.full((1, 3), -1, dtype=np.int32)
        batch_energy_count = np.zeros((1, 3), dtype=np.int32)
        batch_tapped = np.zeros((1, 3), dtype=np.int32)
        # Active lives are not tracked by the legacy GameState — left as zeros.
        batch_live = np.zeros((1, 50), dtype=np.int32)

        opp_stage = np.full((1, 3), -1, dtype=np.int32)
        opp_tapped = np.zeros((1, 3), dtype=np.int32)
        opp_history = np.zeros((1, 50), dtype=np.int32)

        # Hand: vectorized copy of up to 60 card ids.
        h_len = min(len(p.hand), 60)
        batch_hand[0, :h_len] = p.hand[:h_len]

        # Stage slots (3 per player).
        for i in range(3):
            batch_stage[0, i] = p.stage[i]
            batch_energy_count[0, i] = p.stage_energy_count[i]
            batch_tapped[0, i] = 1 if p.tapped_members[i] else 0

            opp_stage[0, i] = opp.stage[i]
            opp_tapped[0, i] = 1 if opp.tapped_members[i] else 0

        # Scores = number of won lives.
        batch_scores = np.array([len(p.success_lives)], dtype=np.int32)
        opp_scores = np.array([len(opp.success_lives)], dtype=np.int32)

        # Global context (raw counts; normalization happens inside the encoder).
        g_ctx = np.zeros((1, 128), dtype=np.int32)
        g_ctx[0, 0] = len(p.success_lives)    # self score
        g_ctx[0, 1] = len(opp.success_lives)  # opponent score
        g_ctx[0, 2] = len(p.discard)          # trash size
        g_ctx[0, 3] = len(p.hand)             # hand size
        g_ctx[0, 5] = p.energy_count          # energy
        g_ctx[0, 6] = len(p.main_deck)        # deck size
        g_ctx[0, 8] = 5                       # phase: legacy default "Main"

        # Opponent history: top of the opponent's discard pile, newest first
        # (the list end is treated as the top of the pile).
        op_h_len = min(len(opp.discard), 50)
        for i in range(op_h_len):
            opp_history[0, i] = opp.discard[-(i + 1)]

        # Output buffer the vectorized encoder writes into.
        obs = np.zeros((1, 8192), dtype=np.float32)

        if UnifiedObservationEncoder._vgs_cache is None:
            UnifiedObservationEncoder._vgs_cache = VGS(1)
        vgs = UnifiedObservationEncoder._vgs_cache

        encode_observations_vectorized(
            1,
            batch_hand,
            batch_stage,
            batch_energy_count,
            batch_tapped,
            batch_scores,
            opp_scores,
            opp_stage,
            opp_tapped,
            vgs.card_stats,
            g_ctx,
            batch_live,
            opp_history,
            state.turn_number,
            obs,
        )
        return obs[0]

    @staticmethod
    def _encode_320(state: "GameState", player_idx: int) -> np.ndarray:
        """LEGACY 320-dim encoding (first speed-up era).

        Replicates the encoding from ai/vector_env_legacy.py exactly: this era
        only saw the player's own stage and score — the hand region [36:168]
        and all opponent features remain 0.0, as in legacy training.
        """
        obs = np.zeros(320, dtype=np.float32)
        p = state.players[player_idx]
        max_id_val = 2000.0  # standard card-id normalizer for VectorEnv

        # Phase [5] = 1.0 (mocking the Main-phase index from legacy VectorEnv).
        obs[5] = 1.0
        # Current-player flag [16].
        obs[16] = 1.0

        # Stage [168:204] — 3 slots * 12 features each.
        for i in range(3):
            cid = p.stage[i]
            base = 168 + i * 12
            if cid >= 0:
                obs[base] = 1.0                    # slot occupied
                obs[base + 1] = cid / max_id_val   # normalized card id
                # Legacy energy count was normalized by 5.0 and capped at 1.0.
                obs[base + 11] = min(p.stage_energy_count[i] / 5.0, 1.0)

        # Score [270]: self score, normalized by 5.0 in legacy.
        obs[270] = min(len(p.success_lives) / 5.0, 1.0)

        return obs

    @staticmethod
    def _encode_128(state: "GameState", player_idx: int) -> np.ndarray:
        """128-dim encoding: the normalized global-context vector alone.

        Uses the standard normalization constants from the AlphaZero era.
        """
        p = state.players[player_idx]
        opp = state.players[1 - player_idx]

        g_ctx = np.zeros(128, dtype=np.float32)
        g_ctx[0] = len(p.success_lives) / 3.0
        g_ctx[1] = len(opp.success_lives) / 3.0
        g_ctx[2] = len(p.discard) / 50.0
        g_ctx[3] = len(p.hand) / 50.0    # normalized to deck size
        g_ctx[5] = p.energy_count / 10.0
        g_ctx[6] = len(p.main_deck) / 50.0

        # Turn info.
        g_ctx[10] = state.turn_number / 20.0
        g_ctx[11] = 1.0 if state.current_player == player_idx else 0.0

        return g_ctx