trioskosmos committed on
Commit
2badd2f
·
verified ·
1 Parent(s): 250ac83

Upload ai/environments/vector_env_legacy.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/environments/vector_env_legacy.py +297 -0
ai/environments/vector_env_legacy.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+
5
+ from engine.game.ai_compat import njit
6
+ from engine.game.fast_logic import batch_apply_action
7
+
8
+
9
@njit(cache=True)
def step_vectorized(
    actions: np.ndarray,
    batch_stage: np.ndarray,
    batch_energy_vec: np.ndarray,
    batch_energy_count: np.ndarray,
    batch_continuous_vec: np.ndarray,
    batch_continuous_ptr: np.ndarray,
    batch_tapped: np.ndarray,
    batch_live: np.ndarray,
    batch_opp_tapped: np.ndarray,
    batch_scores: np.ndarray,
    batch_flat_ctx: np.ndarray,
    batch_global_ctx: np.ndarray,
    batch_hand: np.ndarray,
    batch_deck: np.ndarray,
    # Bytecode lookup tables
    bytecode_map: np.ndarray,  # (GlobalOpMapSize, MaxBytecodeLen, 4)
    bytecode_index: np.ndarray,  # (NumCards, NumAbilities) -> Index in map
):
    """
    Apply one action per environment to the whole batch of game states.

    Each environment's score is first copied into column 0 of its row in
    ``batch_global_ctx``; all buffers are then handed to
    ``batch_apply_action`` with the player id fixed to 0. The buffers are
    passed by reference and nothing is returned.
    """
    # Mirror the per-environment score into the global context before stepping.
    num_actions = len(actions)
    for env_idx in range(num_actions):
        batch_global_ctx[env_idx, 0] = batch_scores[env_idx]

    # NOTE: batch_scores is deliberately passed before batch_live here, which
    # differs from this function's own parameter order.
    batch_apply_action(
        actions,
        0,  # player_id
        batch_stage,
        batch_energy_vec,
        batch_energy_count,
        batch_continuous_vec,
        batch_continuous_ptr,
        batch_tapped,
        batch_scores,
        batch_live,
        batch_opp_tapped,
        batch_flat_ctx,
        batch_global_ctx,
        batch_hand,
        batch_deck,
        bytecode_map,
        bytecode_index,
    )
55
+
56
+
57
class VectorGameState:
    """
    Manages a batch of independent GameStates for high-throughput training.

    All per-environment state lives in pre-allocated NumPy buffers whose first
    axis is the environment index. ``step`` forwards those buffers to
    ``step_vectorized`` and ``get_observations`` encodes them into a shared,
    pre-allocated observation buffer.
    """

    def __init__(self, num_envs: int):
        """Allocate the state buffers for ``num_envs`` environments and load card data."""
        self.num_envs = num_envs
        # A single turn counter is shared by every environment in the batch.
        self.turn = 1

        # Batched state buffers
        self.batch_stage = np.full((num_envs, 3), -1, dtype=np.int32)  # -1 marks an empty slot
        self.batch_energy_vec = np.zeros((num_envs, 3, 32), dtype=np.int32)
        self.batch_energy_count = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_continuous_vec = np.zeros((num_envs, 32, 10), dtype=np.int32)
        self.batch_continuous_ptr = np.zeros(num_envs, dtype=np.int32)
        self.batch_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_live = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_opp_tapped = np.zeros((num_envs, 3), dtype=np.int32)
        self.batch_scores = np.zeros(num_envs, dtype=np.int32)

        # Pre-allocated context buffers (avoids per-step allocation)
        self.batch_flat_ctx = np.zeros((num_envs, 64), dtype=np.int32)
        self.batch_global_ctx = np.zeros((num_envs, 128), dtype=np.int32)
        self.batch_hand = np.zeros((num_envs, 50), dtype=np.int32)
        self.batch_deck = np.zeros((num_envs, 50), dtype=np.int32)

        # Pre-allocated observation buffer, reused by every get_observations() call
        self.obs_buffer = np.zeros((num_envs, 320), dtype=np.float32)

        # Load the bytecode tables and the pool of cards decks are drawn from
        self._load_bytecode()
        self._load_verified_deck_pool()

    def _load_bytecode(self):
        """
        Build the bytecode lookup tables from ``data/cards_numba.json``.

        The JSON maps keys of the form ``"<card_id>_<ability_idx>"`` to a flat
        list of ints, which is reshaped to rows of 4. Two arrays are produced:

        * ``bytecode_map``  -- (entries + 1, max_len, 4); row 0 stays all-zero
          and serves as the empty/nop entry.
        * ``bytecode_index`` -- (max_cards, max_abilities); holds the row of
          ``bytecode_map`` for each (card, ability), 0 when there is none.

        If the file is missing, a warning is printed and empty tables are used.
        """
        import json

        # Set the limits up front so they exist on the fallback path as well.
        self.max_cards = 2000
        self.max_abilities = 4
        self.max_len = 64  # Max 64 instructions per ability

        try:
            with open("data/cards_numba.json", "r", encoding="utf-8") as f:
                raw_map = json.load(f)
        except FileNotFoundError:
            print(" [VectorEnv] Warning: data/cards_numba.json not found. Using empty map.")
            self.bytecode_map = np.zeros((1, self.max_len, 4), dtype=np.int32)
            # Keep the index full-sized so any valid card id still resolves to
            # the nop entry instead of indexing outside a (1, 1) array.
            self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)
            return

        unique_entries = len(raw_map)
        # (Index 0 is empty/nop)
        self.bytecode_map = np.zeros((unique_entries + 1, self.max_len, 4), dtype=np.int32)
        self.bytecode_index = np.zeros((self.max_cards, self.max_abilities), dtype=np.int32)

        idx_counter = 1
        for key, bc_list in raw_map.items():
            cid, aid = map(int, key.split("_"))
            # Skip ids outside the table; negatives would otherwise wrap around.
            if 0 <= cid < self.max_cards and 0 <= aid < self.max_abilities:
                # reshape flat list to (M, 4)
                bc_arr = np.array(bc_list, dtype=np.int32).reshape(-1, 4)
                # Abilities longer than max_len are truncated.
                length = min(bc_arr.shape[0], self.max_len)
                self.bytecode_map[idx_counter, :length] = bc_arr[:length]
                self.bytecode_index[cid, aid] = idx_counter
                idx_counter += 1

        # Report what was actually stored, not the raw number of JSON entries.
        print(f" [VectorEnv] Loaded {idx_counter - 1} compiled abilities.")

    def _load_verified_deck_pool(self):
        """
        Populate ``verified_card_ids`` with the card ids decks are sampled from.

        Card numbers listed under ``"verified_abilities"`` in
        ``verified_card_pool.json`` are translated to integer ids through the
        ``"member_db"`` section of ``data/cards_compiled.json``. Best effort:
        on any error, or when no card matches, the pool falls back to ``[1]``.
        """
        import json

        try:
            # Load Verified List
            with open("verified_card_pool.json", "r", encoding="utf-8") as f:
                verified_data = json.load(f)

            # Load DB to map CardNo -> CardID
            with open("data/cards_compiled.json", "r", encoding="utf-8") as f:
                db_data = json.load(f)

            # Map numbers to IDs
            card_no_map = {cdata["card_no"]: int(cid) for cid, cdata in db_data["member_db"].items()}

            self.verified_card_ids = [
                card_no_map[v_no]
                for v_no in verified_data.get("verified_abilities", [])
                if v_no in card_no_map
            ]

            # Fallback
            if not self.verified_card_ids:
                print(" [VectorEnv] Warning: No verified cards found. Using ID 1.")
                self.verified_card_ids = [1]
            else:
                print(f" [VectorEnv] Loaded {len(self.verified_card_ids)} verified cards for training.")

            self.verified_card_ids = np.array(self.verified_card_ids, dtype=np.int32)

        except Exception as e:
            # Deliberately broad: a missing or malformed pool must not prevent
            # the environment from being constructed.
            print(f" [VectorEnv] Deck Load Error: {e}")
            self.verified_card_ids = np.array([1], dtype=np.int32)

    def reset(self, indices: List[int] = None):
        """
        Reset specified environments (or all if indices is None).

        Every state buffer row of the selected environments is restored to its
        initial value, a new 50-card deck is sampled (with replacement) from
        the verified pool, and the first 5 deck cards are copied into the hand.
        The shared turn counter is set back to 1.
        """
        if indices is None:
            indices = list(range(self.num_envs))

        # A plain loop is fine here: reset is rare compared to step.
        for i in indices:
            self.batch_stage[i].fill(-1)
            self.batch_energy_vec[i].fill(0)
            self.batch_energy_count[i].fill(0)
            self.batch_continuous_vec[i].fill(0)
            self.batch_continuous_ptr[i] = 0
            self.batch_tapped[i].fill(0)
            self.batch_live[i].fill(0)
            self.batch_opp_tapped[i].fill(0)
            self.batch_scores[i] = 0

            # Reset contexts
            self.batch_flat_ctx[i].fill(0)
            self.batch_global_ctx[i].fill(0)

            # Clear the whole hand first; otherwise slots beyond the opening
            # five would still hold cards from the previous episode.
            self.batch_hand[i].fill(0)

            # Initialize Deck with Verified Cards (Random 50)
            if len(self.verified_card_ids) > 0:
                self.batch_deck[i] = np.random.choice(self.verified_card_ids, 50)

            # Initialize Hand: copy the top 5 deck cards. The deck itself is
            # not shifted; the 'hand' array only needs to be populated for
            # gameplay to start.
            self.batch_hand[i, :5] = self.batch_deck[i, :5]

        self.turn = 1

    def step(self, actions: np.ndarray):
        """Apply a batch of actions across all environments (buffers are updated in place)."""
        step_vectorized(
            actions,
            self.batch_stage,
            self.batch_energy_vec,
            self.batch_energy_count,
            self.batch_continuous_vec,
            self.batch_continuous_ptr,
            self.batch_tapped,
            self.batch_live,
            self.batch_opp_tapped,
            self.batch_scores,
            self.batch_flat_ctx,
            self.batch_global_ctx,
            self.batch_hand,
            self.batch_deck,
            self.bytecode_map,
            self.bytecode_index,
        )
        # Simplified: the turn counter is not advanced here.
        # In real VectorEnv, this would be managed by the engine rules.

    def get_observations(self):
        """
        Return a batched observation for RL models.

        The result is ``self.obs_buffer`` itself, overwritten on every call;
        copy it if the values must survive the next call.
        """
        return encode_observations_vectorized(
            self.num_envs,
            self.batch_stage,
            self.batch_energy_count,
            self.batch_tapped,
            self.batch_scores,
            self.turn,
            self.obs_buffer,
        )
243
+
244
+
245
@njit(cache=True)
def encode_observations_vectorized(
    num_envs: int,
    batch_stage: np.ndarray,  # (N, 3)
    batch_energy_count: np.ndarray,  # (N, 3)
    batch_tapped: np.ndarray,  # (N, 3)
    batch_scores: np.ndarray,  # (N,)
    turn_number: int,
    observations: np.ndarray,  # (N, 320)
):
    """
    Encode the batched game state into the ``observations`` buffer.

    The buffer is zeroed, filled in place for the first ``num_envs`` rows and
    returned. Only the metadata flags, the own-stage slots and the own score
    are written; every other column stays 0.0. ``turn_number`` is accepted but
    not used by the encoding.
    """
    id_scale = 2000.0  # Normalization constant for card ids
    # Start from a clean buffer; it is pre-allocated and reused across calls.
    observations.fill(0.0)

    for env in range(num_envs):
        # --- 1. METADATA [0:36] ---
        # The phase is fixed to Main in the vector env; it is one-hot encoded
        # at index phase_val = int(phase) + 2 = 5.
        observations[env, 5] = 1.0

        # Current Player [16:18] - always player 0 in this vector view
        observations[env, 16] = 1.0

        # --- 2. HAND [36:168] ---
        # Not encoded here. Left at 0.0.

        # --- 3. SELF STAGE [168:204] (3 slots * 12 features) ---
        for slot in range(3):
            card_id = batch_stage[env, slot]
            if card_id < 0:
                # Empty slot: all 12 features remain 0.0.
                continue

            offset = 168 + 12 * slot
            observations[env, offset] = 1.0  # slot occupied
            observations[env, offset + 1] = card_id / id_scale
            if batch_tapped[env, slot]:
                observations[env, offset + 2] = 1.0
            else:
                observations[env, offset + 2] = 0.0

            # Mock attribute: the card DB is not available inside the JIT yet,
            # so a default power is used.
            observations[env, offset + 3] = 0.5

            # Energy count, scaled by 5 and capped at 1.0
            energy_ratio = batch_energy_count[env, slot] / 5.0
            observations[env, offset + 11] = min(energy_ratio, 1.0)

        # --- 4. OPPONENT STAGE [204:240] ---
        # Not tracked in partial vector env yet.

        # --- 5. LIVE ZONE [240:270] ---
        # Not tracked in partial vector env yet.

        # --- 6. SCORES [270:272] ---
        score_ratio = batch_scores[env] / 5.0
        observations[env, 270] = min(score_ratio, 1.0)

    return observations