trioskosmos commited on
Commit
191b8ab
·
verified ·
1 Parent(s): 886cd06

Upload ai/research/cuda_kernels.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/research/cuda_kernels.py +1368 -0
ai/research/cuda_kernels.py ADDED
@@ -0,0 +1,1368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CUDA Kernels for GPU-Accelerated VectorEnv.
3
+
4
+ This module contains CUDA kernel implementations for:
5
+ - Environment reset
6
+ - Game step (integrated with opponent, phases, scoring)
7
+ - Observation encoding
8
+ - Action mask computation
9
+
10
+ All kernels are designed for the VectorGameStateGPU class.
11
+ """
12
+
13
+ import numpy as np
14
+
15
try:
    from numba import cuda
    from numba.cuda.random import xoroshiro128p_normal_float32, xoroshiro128p_uniform_float32

    HAS_CUDA = True
except ImportError:
    HAS_CUDA = False

    class MockCuda:
        """Minimal stand-in for ``numba.cuda`` so this module imports without Numba.

        Only ``jit`` and ``grid`` are provided; decorated functions are returned
        unchanged as plain Python functions.
        """

        def jit(self, *args, **kwargs):
            """Support both ``@cuda.jit`` and ``@cuda.jit(...)`` decorator forms."""
            # Bare form: Python passes the decorated function as the only
            # positional argument, so hand it straight back. Returning the
            # inner decorator here would replace every kernel with a wrapper
            # that expects a function argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(f):
                return f

            return decorator

        def grid(self, x):
            """Return a fixed thread index of 0."""
            return 0

    cuda = MockCuda()

    def xoroshiro128p_uniform_float32(rng, i):
        """Deterministic placeholder for the Numba uniform sampler."""
        return 0.5

    def xoroshiro128p_normal_float32(rng, i):
        """Deterministic placeholder so both imported sampler names exist."""
        return 0.0
38
+
39
+
40
# ============================================================================
# CONSTANTS (Must match fast_logic.py)
# ============================================================================
# Column indices into one player's global-context row. The notes describe how
# the kernels in this file use each column.
SC = 0  # own score; step_kernel copies batch_scores[i] here every step
OS = 1  # compared against SC by C_LLD / C_CMP (presumably opponent score - confirm)
TR = 2  # trash counter, incremented by move_to_trash_device
HD = 3  # hand counter
DI = 4  # step_kernel writes the opponent's HD value here
EN = 5  # energy counter
DK = 6  # deck counter
OT = 7  # step_kernel writes the opponent's TR value here
PH = 8  # current phase code
OD = 9  # step_kernel writes the opponent's DK value here

# Opcodes
# Effect opcodes read from column 0 of a bytecode row. O_CHOOSE and O_ADD_H
# have no handler in resolve_bytecode_device and are stepped over there.
O_DRAW = 10
O_BLADES = 11
O_HEARTS = 12
O_RECOV_L = 13
O_BOOST = 14
O_RECOV_M = 15
O_BUFF = 16
O_CHARGE = 17
O_TAP_O = 18
O_CHOOSE = 19
O_ADD_H = 20
O_RETURN = 999  # stops interpretation
O_JUMP = 100  # unconditional relative jump
O_JUMP_F = 101  # relative jump taken only when the condition flag is false

# Conditions
# Any opcode >= 200 (other than O_RETURN) is treated as a condition.
# C_CLR, C_CTR and C_GRP have no dedicated branch in resolve_bytecode_device,
# where unhandled conditions evaluate to True.
C_TR1 = 200
C_CLR = 202
C_STG = 203
C_HND = 204
C_CTR = 206
C_LLD = 207
C_GRP = 208
C_OPH = 210
C_ENR = 213
C_CMP = 220

# Unique ID (UID) System
# The low 20 bits of a UID hold the base card-definition ID.
BASE_ID_MASK = 0xFFFFF
84
+
85
+
86
@cuda.jit(device=True)
def get_base_id_device(uid: int) -> int:
    """Return the base card-definition ID held in the low bits of ``uid``.

    The result is ``uid`` masked with ``BASE_ID_MASK``.
    """
    base_id = uid & BASE_ID_MASK
    return base_id
90
+
91
+
92
+ # ============================================================================
93
+ # DEVICE FUNCTIONS (Callable from kernels)
94
+ # ============================================================================
95
+
96
+
97
@cuda.jit(device=True)
def check_deck_refresh_device(p_deck, p_trash, p_global_ctx, DK_idx, TR_idx):
    """Refill an empty deck from the trash pile.

    Nothing happens unless the deck counter ``p_global_ctx[DK_idx]`` is <= 0.
    Otherwise every non-zero trash entry is moved, in slot order and without
    any randomisation, to the front of ``p_deck``. If at least one card was
    moved, the deck counter becomes the number of cards moved and the trash
    counter ``p_global_ctx[TR_idx]`` is reset to 0.
    """
    if p_global_ctx[DK_idx] > 0:
        return

    moved = 0
    for slot in range(60):
        card = p_trash[slot]
        if card > 0:
            p_deck[moved] = card
            p_trash[slot] = 0
            moved += 1

    # Counters are left untouched when the trash was empty as well.
    if moved > 0:
        p_global_ctx[DK_idx] = moved
        p_global_ctx[TR_idx] = 0
118
+
119
+
120
@cuda.jit(device=True)
def move_to_trash_device(card_id, p_trash, p_global_ctx, TR_idx):
    """Store ``card_id`` in the first free (zero) trash slot.

    The trash counter ``p_global_ctx[TR_idx]`` is incremented when a slot is
    found. If all 60 slots are occupied the card is not stored and the
    counter is unchanged.
    """
    slot = 0
    while slot < 60:
        if p_trash[slot] == 0:
            p_trash[slot] = card_id
            p_global_ctx[TR_idx] += 1
            return
        slot += 1
128
+
129
+
130
@cuda.jit(device=True)
def draw_cards_device(count, p_hand, p_deck, p_trash, p_global_ctx):
    """Move up to ``count`` cards from the deck into free hand slots.

    Before each draw the deck is refilled from the trash if its counter is
    empty; drawing stops as soon as the deck counter is still <= 0 afterwards.
    A drawn card is the first non-zero deck entry and goes into the first
    zero hand slot, updating the DK and HD counters. If the hand has no free
    slot the card stays in the deck.
    """
    for _draw in range(count):
        check_deck_refresh_device(p_deck, p_trash, p_global_ctx, DK, TR)

        if p_global_ctx[DK] <= 0:
            break

        # Locate the first occupied deck slot.
        deck_slot = -1
        for d in range(60):
            if p_deck[d] > 0:
                deck_slot = d
                break

        if deck_slot < 0:
            continue

        card = p_deck[deck_slot]

        # Place the card into the first empty hand slot, if there is one.
        for h in range(60):
            if p_hand[h] != 0:
                continue
            p_hand[h] = card
            p_deck[deck_slot] = 0
            p_global_ctx[DK] -= 1
            p_global_ctx[HD] += 1
            break
157
+
158
+
159
@cuda.jit(device=True)
def resolve_bytecode_device(
    bytecode,
    flat_ctx,
    global_ctx,
    player_id,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_cont_vec,
    p_cont_ptr,
    p_tapped,
    p_live,
    opp_tapped,
    p_trash,
    bytecode_map,
    bytecode_index,
):
    """
    Interpret one card-ability bytecode program on the GPU.

    Each row of ``bytecode`` is read as ``(op, v, a, s)``. Rows with
    ``op >= 200`` update a condition flag; the remaining non-jump rows are
    effects that run only while the flag is true. Interpretation stops at
    ``O_RETURN``, at a jump whose target is outside the program, at the end
    of the program, or after 500 executed rows.

    ``flat_ctx``, ``player_id``, ``p_energy_vec``, ``p_energy_count``,
    ``bytecode_map`` and ``bytecode_index`` are accepted but not read here.

    Returns ``(new_cont_ptr, status, bonus)``: the next free row index in
    ``p_cont_vec``, a status that is always 0, and the sum of ``v`` over the
    executed ``O_HEARTS`` and ``O_BOOST`` rows.
    """
    ip = 0
    cptr = p_cont_ptr
    bonus = 0
    cond = True
    blen = bytecode.shape[0]
    # Bounds the number of executed rows so a jump cycle cannot loop forever.
    safety_counter = 0

    while ip < blen and safety_counter < 500:
        safety_counter += 1
        op = bytecode[ip, 0]
        v = bytecode[ip, 1]
        a = bytecode[ip, 2]
        s = bytecode[ip, 3]

        # Opcode 0 is a no-op row.
        if op == 0:
            ip += 1
            continue
        # O_RETURN (999) is tested before the ``op >= 200`` condition range.
        if op == O_RETURN:
            break

        # Jumps
        # ``v`` is a relative offset; an out-of-range target ends the program.
        if op == O_JUMP:
            new_ip = ip + v
            if 0 <= new_ip < blen:
                ip = new_ip
            else:
                break
            continue

        # Jump-if-false: taken only when the last condition evaluated false.
        if op == O_JUMP_F:
            if not cond:
                new_ip = ip + v
                if 0 <= new_ip < blen:
                    ip = new_ip
                else:
                    break
                continue
            ip += 1
            continue

        # Conditions (op >= 200)
        if op >= 200:
            if op == C_TR1:
                cond = global_ctx[TR] == 1
            elif op == C_STG:
                # Count stage slots that are not the empty marker (-1).
                ct = 0
                for j in range(3):
                    if p_stage[j] != -1:
                        ct += 1
                cond = ct >= v
            elif op == C_HND:
                cond = global_ctx[HD] >= v
            elif op == C_LLD:
                cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_ENR:
                cond = global_ctx[EN] >= v
            elif op == C_CMP:
                # Positive ``v``: score threshold; otherwise SC vs OS comparison.
                if v > 0:
                    cond = global_ctx[SC] >= v
                else:
                    cond = global_ctx[SC] > global_ctx[OS]
            elif op == C_OPH:
                cond = global_ctx[OT] >= v if v > 0 else global_ctx[OT] > 0
            else:
                # Conditions without a dedicated branch always pass.
                cond = True
            ip += 1
        else:
            # Effects
            # Skipped (but still stepped over) while the condition flag is false.
            if cond:
                if op == O_DRAW:
                    draw_cards_device(v, p_hand, p_deck, p_trash, global_ctx)
                elif op == O_CHARGE:
                    # Move cards from deck to energy (simplified)
                    # The removed card IDs are discarded; only counters change.
                    amt = min(v, global_ctx[DK])
                    for _ in range(amt):
                        for d in range(60):
                            if p_deck[d] > 0:
                                p_deck[d] = 0
                                global_ctx[DK] -= 1
                                global_ctx[EN] += 1
                                break
                elif op == O_HEARTS:
                    # Add hearts (points)
                    bonus += v
                    # Register continuous effect
                    # Entries are dropped silently once 32 rows are in use.
                    if cptr < 32:
                        p_cont_vec[cptr, 0] = 2
                        p_cont_vec[cptr, 1] = v
                        p_cont_vec[cptr, 5] = a
                        p_cont_vec[cptr, 9] = 1
                        cptr += 1
                elif op == O_BLADES:
                    if cptr < 32:
                        p_cont_vec[cptr, 0] = 1
                        p_cont_vec[cptr, 1] = v
                        p_cont_vec[cptr, 2] = 4
                        p_cont_vec[cptr, 3] = s
                        p_cont_vec[cptr, 9] = 1
                        cptr += 1
                elif op == O_RECOV_M:
                    # Clear the tapped flag of own stage slot ``s``.
                    if 0 <= s < 3:
                        p_tapped[s] = 0
                elif op == O_RECOV_L:
                    # Zero live-zone slot ``s``.
                    if 0 <= s < p_live.shape[0]:
                        p_live[s] = 0
                elif op == O_TAP_O:
                    # Set the tapped flag of opponent stage slot ``s``.
                    if 0 <= s < 3:
                        opp_tapped[s] = 1
                elif op == O_BUFF:
                    if cptr < 32:
                        p_cont_vec[cptr, 0] = 8
                        p_cont_vec[cptr, 1] = v
                        p_cont_vec[cptr, 2] = s
                        p_cont_vec[cptr, 9] = 1
                        cptr += 1
                elif op == O_BOOST:
                    bonus += v
            ip += 1

    return cptr, 0, bonus
303
+
304
+
305
@cuda.jit(device=True)
def _resolve_card_ability_device(
    card_id,
    player_id,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_tapped,
    p_live,
    p_global_ctx,
    p_trash,
    p_continuous_vec,
    p_continuous_ptr,
    opp_tapped,
    bytecode_map,
    bytecode_index,
):
    """Run the bytecode program mapped to ``card_id`` and return its bonus.

    Returns 0 without side effects when the card has no mapped program.
    ``p_continuous_ptr[0]`` is updated with the pointer returned by the
    interpreter.
    """
    bid = get_base_id_device(card_id)
    if bid >= bytecode_index.shape[0]:
        return 0

    map_idx = bytecode_index[bid, 0]
    if map_idx < 0:
        return 0

    # Scratch context handed to the interpreter, zeroed for every call.
    flat_ctx = cuda.local.array(64, dtype=np.int32)
    for j in range(64):
        flat_ctx[j] = 0

    new_ptr, _, ab_bonus = resolve_bytecode_device(
        bytecode_map[map_idx],
        flat_ctx,
        p_global_ctx,
        player_id,
        p_hand,
        p_deck,
        p_stage,
        p_energy_vec,
        p_energy_count,
        p_continuous_vec,
        p_continuous_ptr[0],
        p_tapped,
        p_live,
        opp_tapped,
        p_trash,
        bytecode_map,
        bytecode_index,
    )
    p_continuous_ptr[0] = new_ptr
    return ab_bonus


@cuda.jit(device=True)
def step_player_device(
    act_id,
    player_id,
    rng_state,
    i,
    p_hand,
    p_deck,
    p_stage,
    p_energy_vec,
    p_energy_count,
    p_tapped,
    p_live,
    p_scores,
    p_global_ctx,
    p_trash,
    p_continuous_vec,
    p_continuous_ptr,
    opp_tapped,
    card_stats,
    bytecode_map,
    bytecode_index,
):
    """
    Apply one player action and return the bonus score it produced.

    Action encoding:
      * ``0``        - pass: advance the phase (-1 -> 0 -> 4 -> 8).
      * ``1..180``   - play hand card ``(act_id - 1) // 3`` to stage slot
                       ``(act_id - 1) % 3`` and resolve its mapped ability.
      * ``200..202`` - activate the ability of the untapped member in stage
                       slot ``act_id - 200``.
      * ``400..459`` - move hand card ``act_id - 400`` to the first empty
                       live-zone slot.

    Any other action id is ignored. ``rng_state``, ``i`` and ``p_scores`` are
    accepted but not used here.
    """
    bonus = 0

    if act_id == 0:
        # Pass -> Next Phase
        ph = p_global_ctx[PH]
        if ph == -1:
            p_global_ctx[PH] = 0
        elif ph == 0:
            p_global_ctx[PH] = 4  # Skip to Main
        elif ph == 4:
            p_global_ctx[PH] = 8  # Performance
        return 0

    # Member Play (1-180)
    if 1 <= act_id <= 180:
        adj = act_id - 1
        hand_idx = adj // 3  # always 0..59 for this action range
        slot = adj % 3

        card_id = p_hand[hand_idx]
        # Empty hand slots hold 0, so only a strictly positive id is a card.
        # Accepting 0 would place a phantom card on stage and decrement HD.
        if card_id > 0:
            bid = get_base_id_device(card_id)
            if bid < card_stats.shape[0]:
                # Cost calculation: replacing a member only costs the difference.
                cost = card_stats[bid, 0]
                effective_cost = cost
                prev_cid = p_stage[slot]
                if prev_cid >= 0:
                    prev_bid = get_base_id_device(prev_cid)
                    if prev_bid < card_stats.shape[0]:
                        prev_cost = card_stats[prev_bid, 0]
                        effective_cost = max(0, cost - prev_cost)

                # Pay cost by tapping energy. Energy flags live in
                # p_tapped[3:15]; affordability is not re-checked here.
                ec = min(p_global_ctx[EN], 12)
                paid = 0
                if effective_cost > 0:
                    for e_idx in range(ec):
                        if 3 + e_idx < 16:
                            if p_tapped[3 + e_idx] == 0:
                                p_tapped[3 + e_idx] = 1
                                paid += 1
                                if paid >= effective_cost:
                                    break

                # Move to stage
                p_stage[slot] = card_id
                p_hand[hand_idx] = 0
                p_global_ctx[HD] -= 1
                p_global_ctx[51 + slot] = 1  # Mark played

                # Resolve auto-ability
                bonus += _resolve_card_ability_device(
                    card_id,
                    player_id,
                    p_hand,
                    p_deck,
                    p_stage,
                    p_energy_vec,
                    p_energy_count,
                    p_tapped,
                    p_live,
                    p_global_ctx,
                    p_trash,
                    p_continuous_vec,
                    p_continuous_ptr,
                    opp_tapped,
                    bytecode_map,
                    bytecode_index,
                )

    # Activate Ability (200-202)
    elif 200 <= act_id <= 202:
        slot = act_id - 200
        card_id = p_stage[slot]
        if card_id >= 0 and p_tapped[slot] == 0:
            bonus += _resolve_card_ability_device(
                card_id,
                player_id,
                p_hand,
                p_deck,
                p_stage,
                p_energy_vec,
                p_energy_count,
                p_tapped,
                p_live,
                p_global_ctx,
                p_trash,
                p_continuous_vec,
                p_continuous_ptr,
                opp_tapped,
                bytecode_map,
                bytecode_index,
            )
            # NOTE(review): the member is tapped for every activation, whether
            # or not a bytecode program exists - confirm against fast_logic.py.
            p_tapped[slot] = 1

    # Set Live Card (400-459)
    elif 400 <= act_id <= 459:
        hand_idx = act_id - 400  # always 0..59 for this action range
        card_id = p_hand[hand_idx]
        if card_id > 0:
            # Find empty live zone slot
            for j in range(50):
                if p_live[j] == 0:
                    p_live[j] = card_id
                    p_hand[hand_idx] = 0
                    p_global_ctx[HD] -= 1
                    break

    return bonus
466
+
467
+
468
@cuda.jit(device=True)
def resolve_live_device(
    live_id, p_stage, p_live, p_scores, p_global_ctx, p_deck, p_hand, p_trash, card_stats, p_cont_vec, p_cont_ptr
):
    """
    Decide whether a live card succeeds and return the score it awards.

    The seven requirement values come from ``card_stats[bid, 12:19]``. A card
    whose requirements sum to <= 0 succeeds automatically. Otherwise the same
    columns of every stage member are summed, and the first six totals must
    each cover the matching requirement. Returns ``card_stats[bid, 38]`` on
    success and 0 on failure or for an invalid ``live_id``.
    """
    bid = get_base_id_device(live_id)
    if live_id < 0 or bid >= card_stats.shape[0]:
        return 0

    required = cuda.local.array(7, dtype=np.int32)
    provided = cuda.local.array(7, dtype=np.int32)

    # Load the requirements and zero the accumulator in one sweep.
    total_required = 0
    for c in range(7):
        required[c] = card_stats[bid, 12 + c]
        provided[c] = 0
        total_required += required[c]

    if total_required <= 0:
        # Nothing is required, so the live succeeds automatically.
        return card_stats[bid, 38]

    # Accumulate what the members currently on stage provide.
    n_cards = card_stats.shape[0]
    for slot in range(3):
        member = p_stage[slot]
        if member <= 0:
            continue
        s_bid = get_base_id_device(member)
        if s_bid >= n_cards:
            continue
        for c in range(7):
            provided[c] += card_stats[s_bid, 12 + c]

    # Only the first six entries are compared; the seventh is not checked.
    for c in range(6):
        if required[c] > provided[c]:
            return 0

    return card_stats[bid, 38]
513
+
514
+
515
@cuda.jit(device=True)
def run_opponent_turn_device(
    rng_state,
    i,
    opp_hand,
    opp_deck,
    opp_stage,
    opp_energy_vec,
    opp_energy_count,
    opp_tapped,
    opp_live,
    opp_scores,
    opp_global_ctx,
    opp_trash,
    p_tapped,
    opp_history,
    card_stats,
    bytecode_map,
    bytecode_index,
):
    """
    Play one scripted opponent turn.

    For each empty stage slot, the first member card in hand whose cost does
    not exceed the EN counter is moved onto the stage (no energy is tapped).
    Afterwards the first live card found in hand is handled: it is moved to
    the first empty live-zone slot if one exists, and the search stops either
    way. Every card moved is pushed onto the front of ``opp_history[i]``.
    """
    n_cards = card_stats.shape[0]

    # Fill empty stage slots with members.
    for slot in range(3):
        if opp_stage[slot] != -1:
            continue

        for h in range(60):
            cid = opp_hand[h]
            if cid < 0:
                continue
            bid = get_base_id_device(cid)
            if bid >= n_cards:
                continue

            # Type 1 is a member; it must be payable with the current energy.
            if card_stats[bid, 10] == 1 and card_stats[bid, 0] <= opp_global_ctx[EN]:
                opp_stage[slot] = cid
                opp_hand[h] = 0
                opp_global_ctx[HD] -= 1
                # Shift the history right by one and record the new card.
                for k in range(5, 0, -1):
                    opp_history[i, k] = opp_history[i, k - 1]
                opp_history[i, 0] = cid
                break

    # Set at most one live card.
    for h in range(60):
        cid = opp_hand[h]
        if cid < 0:
            continue
        bid = get_base_id_device(cid)
        if bid >= n_cards:
            continue

        if card_stats[bid, 10] == 2:  # Live
            for lz in range(50):
                if opp_live[lz] == 0:
                    opp_live[lz] = cid
                    opp_hand[h] = 0
                    opp_global_ctx[HD] -= 1
                    for k in range(5, 0, -1):
                        opp_history[i, k] = opp_history[i, k - 1]
                    opp_history[i, 0] = cid
                    break
            # Stop after the first live card, placed or not.
            break
581
+
582
+
583
+ # ============================================================================
584
+ # MAIN KERNELS
585
+ # ============================================================================
586
+
587
+
588
@cuda.jit
def reset_kernel(
    indices,
    batch_stage,
    batch_energy_vec,
    batch_energy_count,
    batch_continuous_vec,
    batch_continuous_ptr,
    batch_tapped,
    batch_live,
    batch_scores,
    batch_flat_ctx,
    batch_global_ctx,
    batch_hand,
    batch_deck,
    batch_trash,
    batch_opp_history,
    opp_stage,
    opp_energy_vec,
    opp_energy_count,
    opp_tapped,
    opp_live,
    opp_scores,
    opp_global_ctx,
    opp_hand,
    opp_deck,
    opp_trash,
    ability_member_ids,
    ability_live_ids,
    rng_states,
    force_start_order,
    obs_buffer,
    card_stats,
):
    """
    CUDA kernel that re-initialises the environments listed in ``indices``.

    One thread handles one entry of ``indices``. For that environment both
    players' state is cleared, two 60-card decks (48 members + 12 lives) are
    built and shuffled, the first two deck cards go to the live zone, six
    cards are dealt to each hand, and the starting counters are written.
    ``obs_buffer`` and ``card_stats`` are accepted but not used here.
    """
    tid = cuda.grid(1)
    if tid >= indices.shape[0]:
        return

    env = indices[tid]

    # --- Clear per-player state (agent and opponent side by side) ---------
    for s in range(3):
        batch_stage[env, s] = -1
        opp_stage[env, s] = -1
        for e in range(32):
            batch_energy_vec[env, s, e] = 0
            opp_energy_vec[env, s, e] = 0
        batch_energy_count[env, s] = 0
        opp_energy_count[env, s] = 0

    # Continuous effects are only cleared for the agent.
    for row in range(32):
        for col in range(10):
            batch_continuous_vec[env, row, col] = 0
    batch_continuous_ptr[env] = 0

    for t in range(16):
        batch_tapped[env, t] = 0
        opp_tapped[env, t] = 0

    for z in range(50):
        batch_live[env, z] = 0
        opp_live[env, z] = 0

    batch_scores[env] = 0
    opp_scores[env] = 0

    for c in range(64):
        batch_flat_ctx[env, c] = 0

    for c in range(128):
        batch_global_ctx[env, c] = 0
        opp_global_ctx[env, c] = 0

    for c in range(60):
        batch_trash[env, c] = 0
        opp_trash[env, c] = 0
        batch_hand[env, c] = 0
        opp_hand[env, c] = 0

    for c in range(6):
        batch_opp_history[env, c] = 0

    # --- Build decks ------------------------------------------------------
    n_members = ability_member_ids.shape[0]
    n_lives = ability_live_ids.shape[0]

    # Members occupy deck slots 0-47. An exact list of 48 is copied verbatim;
    # any other length is sampled with replacement (agent first, then opponent).
    exact_members = n_members == 48
    for k in range(48):
        if exact_members:
            batch_deck[env, k] = ability_member_ids[k]
            opp_deck[env, k] = ability_member_ids[k]
        else:
            u = xoroshiro128p_uniform_float32(rng_states, env)
            pick = int(u * n_members) % n_members
            batch_deck[env, k] = ability_member_ids[pick]
            u = xoroshiro128p_uniform_float32(rng_states, env)
            pick = int(u * n_members) % n_members
            opp_deck[env, k] = ability_member_ids[pick]

    # Lives occupy deck slots 48-59, with the same rule for a list of 12.
    exact_lives = n_lives == 12
    for k in range(12):
        if exact_lives:
            batch_deck[env, 48 + k] = ability_live_ids[k]
            opp_deck[env, 48 + k] = ability_live_ids[k]
        else:
            u = xoroshiro128p_uniform_float32(rng_states, env)
            pick = int(u * n_lives) % n_lives
            batch_deck[env, 48 + k] = ability_live_ids[pick]
            u = xoroshiro128p_uniform_float32(rng_states, env)
            pick = int(u * n_lives) % n_lives
            opp_deck[env, 48 + k] = ability_live_ids[pick]

    # Shuffle decks (Fisher-Yates), interleaving agent and opponent draws.
    for k in range(59, 0, -1):
        u = xoroshiro128p_uniform_float32(rng_states, env)
        j = int(u * (k + 1)) % (k + 1)
        batch_deck[env, k], batch_deck[env, j] = batch_deck[env, j], batch_deck[env, k]

        u = xoroshiro128p_uniform_float32(rng_states, env)
        j = int(u * (k + 1)) % (k + 1)
        opp_deck[env, k], opp_deck[env, j] = opp_deck[env, j], opp_deck[env, k]

    # --- Deal -------------------------------------------------------------
    # The first two deck cards of each player go to the live zone.
    for z in range(2):
        batch_live[env, z] = batch_deck[env, z]
        batch_deck[env, z] = 0
        opp_live[env, z] = opp_deck[env, z]
        opp_deck[env, z] = 0

    # The next six non-zero deck cards form the opening hand.
    dealt = 0
    for k in range(2, 60):
        if dealt >= 6:
            break
        if batch_deck[env, k] > 0:
            batch_hand[env, dealt] = batch_deck[env, k]
            batch_deck[env, k] = 0
            dealt += 1

    dealt_opp = 0
    for k in range(2, 60):
        if dealt_opp >= 6:
            break
        if opp_deck[env, k] > 0:
            opp_hand[env, dealt_opp] = opp_deck[env, k]
            opp_deck[env, k] = 0
            dealt_opp += 1

    # --- Initial counters ---------------------------------------------------
    batch_global_ctx[env, HD] = 6
    opp_global_ctx[env, HD] = 6
    batch_global_ctx[env, DK] = 52
    opp_global_ctx[env, DK] = 52
    batch_global_ctx[env, EN] = 3
    opp_global_ctx[env, EN] = 3
    batch_global_ctx[env, PH] = 4  # Start in Main phase (simplified)
    opp_global_ctx[env, PH] = 4
    batch_global_ctx[env, 54] = 1  # Turn 1
    opp_global_ctx[env, 54] = 1

    # Start order: -1 requests a random draw, anything else is used as given.
    if force_start_order == -1:
        u = xoroshiro128p_uniform_float32(rng_states, env)
        is_second = 1 if u > 0.5 else 0
    else:
        is_second = force_start_order
    batch_global_ctx[env, 10] = is_second
768
+
769
+
770
@cuda.jit
def step_kernel(
    num_envs,
    actions,
    batch_hand,
    batch_deck,
    batch_stage,
    batch_energy_vec,
    batch_energy_count,
    batch_continuous_vec,
    batch_continuous_ptr,
    batch_tapped,
    batch_live,
    batch_scores,
    batch_flat_ctx,
    batch_global_ctx,
    opp_hand,
    opp_deck,
    opp_stage,
    opp_energy_vec,
    opp_energy_count,
    opp_tapped,
    opp_live,
    opp_scores,
    opp_global_ctx,
    card_stats,
    bytecode_map,
    bytecode_index,
    obs_buffer,
    rewards,
    dones,
    prev_scores,
    prev_opp_scores,
    prev_phases,
    terminal_obs_buffer,
    batch_trash,
    opp_trash,
    batch_opp_history,
    term_scores_agent,
    term_scores_opp,
    ability_member_ids,
    ability_live_ids,
    rng_states,
    game_config,
    opp_mode,
    force_start_order,
):
    """
    Main integrated step kernel.
    Processes one environment per thread.

    Per environment: applies the agent action, and when that action is a pass
    made in phase 4 also runs the opponent turn, resolves both live zones,
    awards round points and prepares the next turn. Rewards, done flags,
    terminal scores and the previous-score buffers are then written.

    ``game_config`` is read as: [0] turn limit, [1] step limit, [2] win reward,
    [3] lose reward, [4] score-difference scale, [5] per-step reward term.

    Several parameters (e.g. ``obs_buffer``, ``terminal_obs_buffer``,
    ``prev_phases``, ``opp_mode``) are not referenced in this kernel body.
    """
    i = cuda.grid(1)
    if i >= num_envs:
        return

    # Config
    CFG_TURN_LIMIT = int(game_config[0])
    CFG_STEP_LIMIT = int(game_config[1])
    CFG_REWARD_WIN = game_config[2]
    CFG_REWARD_LOSE = game_config[3]
    CFG_REWARD_SCALE = game_config[4]
    CFG_REWARD_TURN_PENALTY = game_config[5]

    act_id = actions[i]
    # Phase is captured before the action, which may change it.
    ph = int(batch_global_ctx[i, PH])

    # Sync score to context
    batch_global_ctx[i, SC] = batch_scores[i]

    # Increment step counter
    batch_global_ctx[i, 58] += 1

    # Get continuous pointer slice
    # One-element slices let device functions update the scalars in place.
    cont_ptr_arr = batch_continuous_ptr[i : i + 1]
    score_arr = batch_scores[i : i + 1]

    # Execute action
    bonus = step_player_device(
        act_id,
        0,
        rng_states,
        i,
        batch_hand[i],
        batch_deck[i],
        batch_stage[i],
        batch_energy_vec[i],
        batch_energy_count[i],
        batch_tapped[i],
        batch_live[i],
        score_arr,
        batch_global_ctx[i],
        batch_trash[i],
        batch_continuous_vec[i],
        cont_ptr_arr,
        opp_tapped[i],
        card_stats,
        bytecode_map,
        bytecode_index,
    )
    batch_scores[i] += bonus

    # Handle turn end (Pass in Main Phase)
    # NOTE(review): that pass moved PH from 4 to 8 and nothing in this kernel
    # sets it back - confirm the phase is restored elsewhere before the next turn.
    if act_id == 0 and ph == 4:
        # Run opponent turn
        run_opponent_turn_device(
            rng_states,
            i,
            opp_hand[i],
            opp_deck[i],
            opp_stage[i],
            opp_energy_vec[i],
            opp_energy_count[i],
            opp_tapped[i],
            opp_live[i],
            opp_scores[i : i + 1],
            opp_global_ctx[i],
            opp_trash[i],
            batch_tapped[i],
            batch_opp_history,
            card_stats,
            bytecode_map,
            bytecode_index,
        )

        # Resolve lives for both players
        agent_live_score = 0
        opp_live_score = 0

        # Only the first 10 live-zone slots are examined.
        for z in range(10):
            lid = batch_live[i, z]
            if lid > 0:
                s = resolve_live_device(
                    lid,
                    batch_stage[i],
                    batch_live[i],
                    batch_scores[i : i + 1],
                    batch_global_ctx[i],
                    batch_deck[i],
                    batch_hand[i],
                    batch_trash[i],
                    card_stats,
                    batch_continuous_vec[i],
                    cont_ptr_arr,
                )
                agent_live_score += s
                # Clear used live
                # Lives that score 0 stay in the live zone.
                if s > 0:
                    move_to_trash_device(lid, batch_trash[i], batch_global_ctx[i], TR)
                    batch_live[i, z] = 0

        for z in range(10):
            lid = opp_live[i, z]
            if lid > 0:
                # The agent's continuous-effect arrays are passed here as well;
                # resolve_live_device does not read them.
                s = resolve_live_device(
                    lid,
                    opp_stage[i],
                    opp_live[i],
                    opp_scores[i : i + 1],
                    opp_global_ctx[i],
                    opp_deck[i],
                    opp_hand[i],
                    opp_trash[i],
                    card_stats,
                    batch_continuous_vec[i],
                    cont_ptr_arr,
                )
                opp_live_score += s
                if s > 0:
                    move_to_trash_device(lid, opp_trash[i], opp_global_ctx[i], TR)
                    opp_live[i, z] = 0

        # Scoring comparison
        # The side with the higher live total gains one point; a non-zero tie
        # gives both sides a point.
        if agent_live_score > 0 and opp_live_score == 0:
            batch_scores[i] += 1
        elif agent_live_score == 0 and opp_live_score > 0:
            opp_scores[i] += 1
        elif agent_live_score > 0 and opp_live_score > 0:
            if agent_live_score > opp_live_score:
                batch_scores[i] += 1
            elif opp_live_score > agent_live_score:
                opp_scores[i] += 1
            else:
                # Tie - both score
                batch_scores[i] += 1
                opp_scores[i] += 1

        # Next turn setup
        # Column 54 is the turn counter.
        batch_global_ctx[i, 54] += 1
        opp_global_ctx[i, 54] += 1

        # Untap and energy
        for j in range(16):
            batch_tapped[i, j] = 0
            if j < opp_tapped.shape[1]:
                opp_tapped[i, j] = 0

        # Energy grows by one per turn, capped at 12.
        batch_global_ctx[i, EN] = min(batch_global_ctx[i, EN] + 1, 12)
        opp_global_ctx[i, EN] = min(opp_global_ctx[i, EN] + 1, 12)

        # Draw card
        draw_cards_device(1, batch_hand[i], batch_deck[i], batch_trash[i], batch_global_ctx[i])
        draw_cards_device(1, opp_hand[i], opp_deck[i], opp_trash[i], opp_global_ctx[i])

    # Calculate rewards
    # Reward is the scaled change in score difference since the previous step.
    current_score = batch_scores[i]
    score_diff = float(current_score) - float(prev_scores[i])
    opp_score_diff = float(opp_scores[i]) - float(prev_opp_scores[i])

    r = (score_diff * CFG_REWARD_SCALE) - (opp_score_diff * CFG_REWARD_SCALE)
    # Added as-is every step, so its sign is decided by the configuration.
    r += CFG_REWARD_TURN_PENALTY

    # A score of 3 ends the game; both flags can be true in the same step.
    win = current_score >= 3
    lose = opp_scores[i] >= 3

    if win:
        r += CFG_REWARD_WIN
    if lose:
        r += CFG_REWARD_LOSE

    rewards[i] = r

    # Sync Opp Stats to Agent Context (for Attention features)
    batch_global_ctx[i, 4] = opp_global_ctx[i, 3]  # HD
    batch_global_ctx[i, 9] = opp_global_ctx[i, 6]  # DK
    batch_global_ctx[i, 7] = opp_global_ctx[i, 2]  # TR

    # Check done
    # Column 54 is the turn counter and column 58 the step counter.
    is_done = win or lose or batch_global_ctx[i, 54] >= CFG_TURN_LIMIT or batch_global_ctx[i, 58] >= CFG_STEP_LIMIT
    dones[i] = is_done

    if is_done:
        term_scores_agent[i] = batch_scores[i]
        term_scores_opp[i] = opp_scores[i]
        # Note: Auto-reset should be called separately

    # Update prev scores
    prev_scores[i] = batch_scores[i]
    prev_opp_scores[i] = opp_scores[i]
1008
+
1009
+
1010
@cuda.jit
def compute_action_masks_kernel(
    num_envs,
    batch_hand,
    batch_stage,
    batch_tapped,
    batch_global_ctx,
    batch_live,
    card_stats,
    masks,  # Output: (N, 2000)
):
    """
    Compute legal action masks on GPU, one environment per thread.

    Mask layout (matches the action ids handled by ``step_player_device``):
      * ``0``        - pass, legal only in phase 4.
      * ``1..180``   - play hand card ``h`` to stage slot ``s`` at index
                       ``h * 3 + s + 1``; legal when the card is a member and
                       the untapped energy covers the (upgrade-reduced) cost.
      * ``200..202`` - activate the untapped member in stage slot ``s``.
      * ``400..459`` - set live card ``h``; legal when the live zone has a
                       free slot.
    """
    i = cuda.grid(1)
    if i >= num_envs:
        return

    # Reset all to False
    for a in range(2000):
        masks[i, a] = False

    ph = batch_global_ctx[i, PH]

    # Action 0: Pass is always legal in Main Phase
    if ph == 4:
        masks[i, 0] = True

    # Untapped energy, computed once because it does not depend on the card.
    # Only the first min(EN, 12) energy flags are counted: those are the only
    # ones step_player_device can tap to pay a cost, so counting all 12 would
    # mark unaffordable plays as legal.
    owned_energy = min(batch_global_ctx[i, EN], 12)
    available_energy = 0
    for e in range(owned_energy):
        if batch_tapped[i, 3 + e] == 0:
            available_energy += 1

    # Whether any live-zone slot is free; also independent of the card.
    live_has_room = False
    for lz_idx in range(50):
        if batch_live[i, lz_idx] == 0:
            live_has_room = True
            break

    n_cards = card_stats.shape[0]

    for h_idx in range(60):
        cid = batch_hand[i, h_idx]
        if cid <= 0:
            continue
        bid = get_base_id_device(cid)
        if bid >= n_cards:
            continue

        ctype = card_stats[bid, 10]

        if ctype == 1:  # Member
            # Member Play (1-180): HandIdx * 3 + Slot + 1
            cost = card_stats[bid, 0]
            for slot in range(3):
                # Replacing an existing member only costs the difference.
                old_cid = batch_stage[i, slot]
                effective_cost = cost
                if old_cid >= 0:
                    old_bid = get_base_id_device(old_cid)
                    if old_bid < n_cards:
                        effective_cost = max(0, cost - card_stats[old_bid, 0])

                if available_energy >= effective_cost:
                    # h_idx <= 59 and slot <= 2 keep the index within 1..180.
                    masks[i, h_idx * 3 + slot + 1] = True

        elif ctype == 2:  # Live
            # Set Live (400-459)
            if live_has_room:
                masks[i, 400 + h_idx] = True

    # Activate Ability (200-202)
    for slot in range(3):
        cid = batch_stage[i, slot]
        if cid > 0 and batch_tapped[i, slot] == 0:
            masks[i, 200 + slot] = True
1088
+
1089
+
1090
@cuda.jit
def encode_observations_kernel(
    num_envs,
    batch_hand,
    batch_stage,
    batch_energy_count,
    batch_tapped,
    batch_scores,
    opp_scores,
    opp_stage,
    opp_tapped,
    card_stats,
    batch_global_ctx,
    batch_live,
    turn_number,
    obs_buffer,
):
    """
    Encode observations on GPU (STANDARD mode), one environment per thread.

    Row layout written into ``obs_buffer[i]``:
      * 0-5      scalar metadata (scores, energy, hand, deck, turn)
      * 10-69    own stage, 3 slots x 20 features
      * 70-129   opponent stage, 3 slots x 20 features
      * 130-529  hand, up to 20 cards x 20 features
      * 530-629  live zone, up to 10 cards x 10 features

    ``batch_energy_count``, ``opp_tapped`` and ``turn_number`` are accepted
    but not read here.
    """
    i = cuda.grid(1)
    if i >= num_envs:
        return

    obs_dim = obs_buffer.shape[1]
    n_cards = card_stats.shape[0]

    # Start from an all-zero row.
    for col in range(obs_dim):
        obs_buffer[i, col] = 0.0

    # Scalar metadata
    obs_buffer[i, 0] = float(batch_scores[i]) / 3.0
    obs_buffer[i, 1] = float(opp_scores[i]) / 3.0
    obs_buffer[i, 2] = float(batch_global_ctx[i, EN]) / 12.0
    obs_buffer[i, 3] = float(batch_global_ctx[i, HD]) / 60.0
    obs_buffer[i, 4] = float(batch_global_ctx[i, DK]) / 60.0
    obs_buffer[i, 5] = float(batch_global_ctx[i, 54]) / 100.0  # Turn

    own_stage_off = 10
    opp_stage_off = own_stage_off + 60
    hand_off = opp_stage_off + 60
    live_off = hand_off + 400

    # Own stage (3 slots x 20 features)
    for slot in range(3):
        cid = batch_stage[i, slot]
        if cid <= 0:
            continue
        bid = get_base_id_device(cid)
        if bid >= n_cards:
            continue
        base = own_stage_off + slot * 20
        obs_buffer[i, base] = 1.0  # Presence
        obs_buffer[i, base + 1] = float(cid) / 2000.0
        obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0  # Cost
        obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0  # Blades
        obs_buffer[i, base + 4] = float(card_stats[bid, 2]) / 10.0  # Hearts
        if batch_tapped[i, slot] > 0:
            obs_buffer[i, base + 5] = 1.0
        else:
            obs_buffer[i, base + 5] = 0.0

    # Opponent stage (same features, without the tapped flag)
    for slot in range(3):
        cid = opp_stage[i, slot]
        if cid <= 0:
            continue
        bid = get_base_id_device(cid)
        if bid >= n_cards:
            continue
        base = opp_stage_off + slot * 20
        obs_buffer[i, base] = 1.0
        obs_buffer[i, base + 1] = float(cid) / 2000.0
        obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
        obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0
        obs_buffer[i, base + 4] = float(card_stats[bid, 2]) / 10.0

    # Hand (up to 20 cards shown)
    shown = 0
    for h_idx in range(60):
        cid = batch_hand[i, h_idx]
        if cid <= 0 or shown >= 20:
            continue
        base = hand_off + shown * 20
        # Skip the write when the row is too short for this card.
        if base + 10 < obs_dim:
            obs_buffer[i, base] = 1.0
            obs_buffer[i, base + 1] = float(cid) / 2000.0
            bid = get_base_id_device(cid)
            if bid < n_cards:
                obs_buffer[i, base + 2] = float(card_stats[bid, 0]) / 10.0
                obs_buffer[i, base + 3] = float(card_stats[bid, 10])  # Type
        shown += 1

    # Live zone (up to 10 cards)
    lives_shown = 0
    for l_idx in range(50):
        cid = batch_live[i, l_idx]
        if cid <= 0 or lives_shown >= 10:
            continue
        base = live_off + lives_shown * 10
        if base + 5 < obs_dim:
            obs_buffer[i, base] = 1.0
            obs_buffer[i, base + 1] = float(cid) / 2000.0
        lives_shown += 1
1188
+
1189
+
1190
@cuda.jit(device=True)
def _write_card_core(obs_buffer, env, base, cid, bid, card_stats, location, position, owner):
    """Write the fields shared by every card slot of the attention layout."""
    obs_buffer[env, base + 0] = 1.0  # Presence
    obs_buffer[env, base + 1] = float(card_stats[bid, 10]) / 2.0  # Type
    obs_buffer[env, base + 2] = float(card_stats[bid, 0]) / 10.0  # Cost
    obs_buffer[env, base + 5] = float(cid) / 2000.0  # Card ID
    obs_buffer[env, base + 6] = location  # Location code of the section
    obs_buffer[env, base + 58] = float(position) / 10.0  # Index within the section
    obs_buffer[env, base + 59] = owner  # 1.0 = mine, -1.0 = opponent


@cuda.jit(device=True)
def _write_card_hearts(obs_buffer, env, base, bid, card_stats):
    """Copy card_stats columns 12-18 (when present) into slot columns 8-14."""
    for k in range(7):
        if 12 + k < card_stats.shape[1]:
            obs_buffer[env, base + 8 + k] = float(card_stats[bid, 12 + k]) / 5.0


@cuda.jit(device=True)
def _write_card_group(obs_buffer, env, base, bid, card_stats):
    """Set the one-hot group flag (slot columns 22-28) from card_stats column 11."""
    raw_group = card_stats[bid, 11]
    obs_buffer[env, base + 22 + (raw_group % 7)] = 1.0


@cuda.jit
def encode_observations_attention_kernel(
    num_envs,
    batch_hand,
    batch_stage,
    batch_energy_count,
    batch_tapped,
    batch_scores,
    opp_scores,
    opp_stage,
    opp_tapped,
    card_stats,
    batch_global_ctx,
    batch_live,
    batch_opp_history,
    opp_global_ctx,  # Added
    turn_number,
    obs_buffer,
):
    """
    Encode observations for Attention Architecture (2240-dim).

    One thread fills the observation row of one environment. The row is a
    sequence of 64-column card slots followed by one 64-column block of
    global scalars:

    * 0-1023: hand, 15 slots + 1 overflow slot;
    * 1024-1215: own stage, 3 slots;
    * 1216-1599: live zone; up to 6 cards are packed consecutively from
      ``LIVE_START``, so cards 4-6 occupy the ``LIVE_SUCC_START`` region;
    * 1600-1791: opponent stage, 3 slots;
    * 1792-2175: opponent history, 6 slots;
    * 2176-2239: global scalars (columns 0-10 of the block are written).

    A card is encoded only when its id is positive and its base id indexes a
    row of ``card_stats``. Within a slot, column 6 carries a location code
    (0.2 hand, 0.4 stage, 0.6 live, 0.8 opponent stage, 1.0 history) and
    column 59 an owner flag (1.0 own, -1.0 opponent).

    ``batch_energy_count`` and ``turn_number`` are accepted but never read;
    the turn is taken from ``batch_global_ctx[:, 54]``. Only the first 2240
    columns of ``obs_buffer`` are cleared and written.
    """
    i = cuda.grid(1)
    if i >= num_envs:
        return

    # Constants
    FEAT = 64
    MAX_HAND = 15  # +1 overflow
    HAND_SLOTS = MAX_HAND + 1  # 16 encoded hand cards in total
    LIVE_SLOTS = 6
    HIST_SLOTS = 6

    # Offsets
    HAND_START = 0
    HAND_OVER_START = HAND_START + (MAX_HAND * FEAT)  # 960
    STAGE_START = HAND_OVER_START + FEAT  # 1024
    LIVE_START = STAGE_START + (3 * FEAT)  # 1216
    LIVE_SUCC_START = LIVE_START + (3 * FEAT)  # 1408
    OPP_STAGE_START = LIVE_SUCC_START + (3 * FEAT)  # 1600
    OPP_HIST_START = OPP_STAGE_START + (3 * FEAT)  # 1792
    GLOBAL_START = OPP_HIST_START + (HIST_SLOTS * FEAT)  # 2176
    OBS_DIM = GLOBAL_START + FEAT  # 2240

    num_cards = card_stats.shape[0]

    # Clear buffer (derived from the layout so the two cannot drift apart)
    for k in range(OBS_DIM):
        obs_buffer[i, k] = 0.0

    # --- A. HAND (16 slots) ---
    # hand_count keeps counting past the encoded slots so the global
    # hand-size scalar below reflects every valid card in hand.
    hand_count = 0
    for j in range(60):
        cid = batch_hand[i, j]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < num_cards:
                if hand_count < HAND_SLOTS:
                    base = HAND_START + hand_count * FEAT
                    _write_card_core(
                        obs_buffer, i, base, cid, bid, card_stats, 0.2, hand_count, 1.0
                    )  # Location: Hand, Mine
                    obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0  # Blades
                    _write_card_hearts(obs_buffer, i, base, bid, card_stats)  # Hearts (8-14)
                    _write_card_group(obs_buffer, i, base, bid, card_stats)  # Group (22-28)

                hand_count += 1

    # --- B. MY STAGE (3 slots) ---
    for slot in range(3):
        cid = batch_stage[i, slot]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < num_cards:
                base = STAGE_START + slot * FEAT
                _write_card_core(
                    obs_buffer, i, base, cid, bid, card_stats, 0.4, slot, 1.0
                )  # Location: Stage, Mine
                obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0  # Blades
                obs_buffer[i, base + 4] = 1.0 if batch_tapped[i, slot] > 0 else 0.0
                _write_card_hearts(obs_buffer, i, base, bid, card_stats)
                _write_card_group(obs_buffer, i, base, bid, card_stats)

    # --- C. LIVE ZONE (6 slots) ---
    live_count = 0
    for j in range(50):
        cid = batch_live[i, j]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < num_cards and live_count < LIVE_SLOTS:
                base = LIVE_START + live_count * FEAT
                _write_card_core(
                    obs_buffer, i, base, cid, bid, card_stats, 0.6, live_count, 1.0
                )  # Location: Live, Mine
                _write_card_hearts(obs_buffer, i, base, bid, card_stats)
                live_count += 1

    # --- D. OPP STAGE (3 slots) ---
    for slot in range(3):
        cid = opp_stage[i, slot]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < num_cards:
                base = OPP_STAGE_START + slot * FEAT
                _write_card_core(
                    obs_buffer, i, base, cid, bid, card_stats, 0.8, slot, -1.0
                )  # Location: Opp Stage, Opponent
                obs_buffer[i, base + 3] = float(card_stats[bid, 1]) / 5.0  # Blades
                obs_buffer[i, base + 4] = 1.0 if opp_tapped[i, slot] > 0 else 0.0
                _write_card_hearts(obs_buffer, i, base, bid, card_stats)

    # --- E. OPP HISTORY (6 slots) ---
    for h in range(HIST_SLOTS):
        cid = batch_opp_history[i, h]
        if cid > 0:
            bid = get_base_id_device(cid)
            if bid < num_cards:
                base = OPP_HIST_START + h * FEAT
                _write_card_core(
                    obs_buffer, i, base, cid, bid, card_stats, 1.0, h, -1.0
                )  # Location: History, Opponent

    # --- F. GLOBAL SCALARS ---
    obs_buffer[i, GLOBAL_START + 0] = float(batch_scores[i]) / 10.0
    obs_buffer[i, GLOBAL_START + 1] = float(opp_scores[i]) / 10.0
    obs_buffer[i, GLOBAL_START + 2] = float(batch_global_ctx[i, 54]) / 20.0  # Turn from Context
    obs_buffer[i, GLOBAL_START + 3] = float(batch_global_ctx[i, 8]) / 10.0
    obs_buffer[i, GLOBAL_START + 4] = float(batch_global_ctx[i, 5]) / 10.0
    obs_buffer[i, GLOBAL_START + 5] = float(batch_global_ctx[i, 6]) / 40.0
    obs_buffer[i, GLOBAL_START + 6] = float(hand_count) / 15.0

    # Opponent Resources (New)
    obs_buffer[i, GLOBAL_START + 7] = float(opp_global_ctx[i, 5]) / 10.0  # Opp Energy
    obs_buffer[i, GLOBAL_START + 8] = float(batch_global_ctx[i, 4]) / 10.0  # Opp Hand (from ctx[4])
    obs_buffer[i, GLOBAL_START + 9] = float(batch_global_ctx[i, 9]) / 40.0  # Opp Deck (from ctx[9])
    obs_buffer[i, GLOBAL_START + 10] = float(batch_global_ctx[i, 7]) / 10.0  # Opp Trash (from ctx[7])