trioskosmos commited on
Commit
598726b
·
verified ·
1 Parent(s): 20eb74b

Upload ai/research/fast_logic_gpu.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/research/fast_logic_gpu.py +409 -0
ai/research/fast_logic_gpu.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from engine.game.fast_logic import (
2
+ C_CLR,
3
+ C_CMP,
4
+ C_CTR,
5
+ C_ENR,
6
+ C_GRP,
7
+ C_HND,
8
+ C_LLD,
9
+ C_OPH,
10
+ C_STG,
11
+ C_TR1,
12
+ DK,
13
+ EN,
14
+ HD,
15
+ O_ADD_H,
16
+ O_BLADES,
17
+ O_BOOST,
18
+ O_BUFF,
19
+ O_CHARGE,
20
+ O_CHOOSE,
21
+ O_DRAW,
22
+ O_HEARTS,
23
+ O_JUMP,
24
+ O_JUMP_F,
25
+ O_RECOV_L,
26
+ O_RECOV_M,
27
+ O_RETURN,
28
+ O_TAP_O,
29
+ OS,
30
+ OT,
31
+ SC,
32
+ TR,
33
+ )
34
+
35
+ try:
36
+ from numba import cuda
37
+ from numba.cuda.random import xoroshiro128p_uniform_float32
38
+
39
+ HAS_CUDA = True
40
+ except ImportError:
41
+ HAS_CUDA = False
42
+
43
+ class MockCuda:
44
+ def jit(self, *args, **kwargs):
45
+ return lambda x: x
46
+
47
+ def grid(self, x):
48
+ return 0
49
+
50
+ cuda = MockCuda()
51
+
52
+ def xoroshiro128p_uniform_float32(rng, idx):
53
+ return 0.5
54
+
55
+
56
+ @cuda.jit(device=True)
57
+ def resolve_bytecode_device(
58
+ bytecode,
59
+ flat_ctx,
60
+ global_ctx,
61
+ player_id,
62
+ p_hand,
63
+ p_deck,
64
+ p_stage,
65
+ p_energy_vec,
66
+ p_energy_count,
67
+ p_cont_vec,
68
+ p_cont_ptr,
69
+ p_tapped,
70
+ p_live,
71
+ opp_tapped,
72
+ ):
73
+ """
74
+ GPU Device function for resolving bytecode.
75
+ Equivalent to engine/game/fast_logic.py:resolve_bytecode but optimized for CUDA.
76
+ """
77
+ ip = 0
78
+ cptr = p_cont_ptr
79
+ bonus = 0
80
+ cond = True
81
+ blen = bytecode.shape[0]
82
+
83
+ # SAFETY: Infinite loop protection
84
+ safety_counter = 0
85
+
86
+ while ip < blen and safety_counter < 500:
87
+ safety_counter += 1
88
+ op = bytecode[ip, 0]
89
+ v = bytecode[ip, 1]
90
+ a = bytecode[ip, 2]
91
+ s = bytecode[ip, 3]
92
+
93
+ if op == 0:
94
+ ip += 1
95
+ continue
96
+ if op == O_RETURN:
97
+ break
98
+
99
+ # Jumps with safety checks
100
+ if op == O_JUMP:
101
+ new_ip = ip + v
102
+ if 0 <= new_ip < blen:
103
+ ip = new_ip
104
+ else:
105
+ ip = blen # Exit
106
+ continue
107
+
108
+ if op == O_JUMP_F:
109
+ if not cond:
110
+ new_ip = ip + v
111
+ if 0 <= new_ip < blen:
112
+ ip = new_ip
113
+ else:
114
+ ip = blen # Exit
115
+ continue
116
+ ip += 1
117
+ continue
118
+
119
+ if op >= 200:
120
+ if op == C_TR1:
121
+ cond = global_ctx[TR] == 1
122
+ elif op == C_STG:
123
+ ct = 0
124
+ for i in range(3):
125
+ if p_stage[i] != -1:
126
+ ct += 1
127
+ cond = ct >= v
128
+ elif op == C_HND:
129
+ cond = global_ctx[HD] >= v
130
+ elif op == C_LLD:
131
+ cond = global_ctx[SC] > global_ctx[OS]
132
+ elif op == C_CLR:
133
+ if 0 <= a <= 5:
134
+ cond = global_ctx[10 + a] > 0
135
+ else:
136
+ cond = False
137
+ elif op == C_GRP:
138
+ if 0 <= a <= 4:
139
+ cond = global_ctx[30 + a] >= v
140
+ else:
141
+ cond = False
142
+ elif op == C_ENR:
143
+ cond = global_ctx[EN] >= v
144
+ elif op == C_CTR:
145
+ cond = flat_ctx[7] == 1 # SZ=7 (Hand=1)
146
+ elif op == C_CMP:
147
+ if v > 0:
148
+ cond = global_ctx[SC] >= v
149
+ else:
150
+ cond = global_ctx[SC] > global_ctx[OS]
151
+ elif op == C_OPH:
152
+ ct = global_ctx[OT]
153
+ if v > 0:
154
+ cond = ct >= v
155
+ else:
156
+ cond = ct > 0
157
+ else:
158
+ cond = True
159
+ ip += 1
160
+ else:
161
+ if cond:
162
+ if op == O_DRAW or op == O_CHOOSE or op == O_ADD_H:
163
+ # Draw v cards logic (O_CHOOSE is Look v add 1, simplified to Draw 1)
164
+ # O_ADD_H is add v from deck
165
+ draw_amt = v
166
+ if op == O_CHOOSE:
167
+ draw_amt = 1
168
+
169
+ if global_ctx[DK] >= draw_amt:
170
+ global_ctx[DK] -= draw_amt
171
+ global_ctx[HD] += draw_amt
172
+
173
+ # Perform actual card movement
174
+ for _ in range(draw_amt):
175
+ # 1. Find top card
176
+ top_card = 0
177
+ d_idx_found = -1
178
+ for d_idx in range(60):
179
+ if p_deck[d_idx] > 0:
180
+ top_card = p_deck[d_idx]
181
+ d_idx_found = d_idx
182
+ break
183
+
184
+ if top_card > 0:
185
+ # 2. Find empty hand slot
186
+ for h_idx in range(60):
187
+ if p_hand[h_idx] == 0:
188
+ p_hand[h_idx] = top_card
189
+ p_deck[d_idx_found] = 0
190
+ break
191
+
192
+ else:
193
+ # Draw remaining deck? (Simplified: just draw what we can)
194
+ t = global_ctx[DK]
195
+ if t > 0:
196
+ # Draw t cards
197
+ for _ in range(t):
198
+ top_card = 0
199
+ d_idx_found = -1
200
+ for d_idx in range(60):
201
+ if p_deck[d_idx] > 0:
202
+ top_card = p_deck[d_idx]
203
+ d_idx_found = d_idx
204
+ break
205
+ if top_card > 0:
206
+ for h_idx in range(60):
207
+ if p_hand[h_idx] == 0:
208
+ p_hand[h_idx] = top_card
209
+ p_deck[d_idx_found] = 0
210
+ break
211
+
212
+ global_ctx[DK] = 0
213
+ global_ctx[HD] += t
214
+
215
+ elif op == O_CHARGE:
216
+ if global_ctx[DK] >= v:
217
+ global_ctx[DK] -= v
218
+ global_ctx[EN] += v
219
+ # Move v cards from Deck to "Energy" (which is virtual count or zone?)
220
+ # Logic usually says Charge = move to energy zone.
221
+ # In fast_logic, we have p_energy_vec (3 slots x 32).
222
+ # But Charge typically goes to specific member energy?
223
+ # Or global energy? The global context EN is just a count.
224
+ # For POC, we just consume from deck. Real logic needs target slot.
225
+
226
+ for _ in range(v):
227
+ # Remove from deck
228
+ for d_idx in range(60):
229
+ if p_deck[d_idx] > 0:
230
+ p_deck[d_idx] = 0
231
+ break
232
+ else:
233
+ t = global_ctx[DK]
234
+ global_ctx[DK] = 0
235
+ global_ctx[EN] += t
236
+ for _ in range(t):
237
+ for d_idx in range(60):
238
+ if p_deck[d_idx] > 0:
239
+ p_deck[d_idx] = 0
240
+ break
241
+
242
+ elif op == O_BLADES:
243
+ if s >= 0 and cptr < 32:
244
+ p_cont_vec[cptr, 0] = 1
245
+ p_cont_vec[cptr, 1] = v
246
+ p_cont_vec[cptr, 2] = 4
247
+ p_cont_vec[cptr, 3] = s
248
+ p_cont_vec[cptr, 9] = 1
249
+ cptr += 1
250
+ elif op == O_HEARTS:
251
+ if cptr < 32:
252
+ p_cont_vec[cptr, 0] = 2
253
+ p_cont_vec[cptr, 1] = v
254
+ p_cont_vec[cptr, 5] = a
255
+ p_cont_vec[cptr, 9] = 1
256
+ cptr += 1
257
+ global_ctx[0] += v # SC = 0. Immediate scoring for Vectorized RL.
258
+ elif op == O_RECOV_L:
259
+ if 0 <= s < p_live.shape[0]:
260
+ p_live[s] = 0
261
+ elif op == O_RECOV_M:
262
+ if 0 <= s < 3:
263
+ p_tapped[s] = 0
264
+ elif op == O_TAP_O:
265
+ if 0 <= s < 3:
266
+ opp_tapped[s] = 1
267
+ elif op == O_BUFF:
268
+ if cptr < 32:
269
+ p_cont_vec[cptr, 0] = 8
270
+ p_cont_vec[cptr, 1] = v
271
+ p_cont_vec[cptr, 2] = s
272
+ p_cont_vec[cptr, 9] = 1
273
+ cptr += 1
274
+ elif op == O_BOOST:
275
+ bonus += v
276
+ ip += 1
277
+
278
+ return cptr, 0, bonus
279
+
280
+
281
+ @cuda.jit
282
+ def step_kernel(
283
+ rng_states,
284
+ batch_stage, # (N, 3)
285
+ batch_energy_vec, # (N, 3, 32)
286
+ batch_energy_count, # (N, 3)
287
+ batch_continuous_vec, # (N, 32, 10)
288
+ batch_continuous_ptr, # (N,)
289
+ batch_tapped, # (N, 3)
290
+ batch_live, # (N, 50)
291
+ batch_opp_tapped, # (N, 3)
292
+ batch_scores, # (N,)
293
+ batch_flat_ctx, # (N, 64)
294
+ batch_global_ctx, # (N, 128)
295
+ batch_hand, # (N, 60)
296
+ batch_deck, # (N, 60)
297
+ bytecode_map, # (MapSize, 64, 4)
298
+ bytecode_index, # (MaxCards, 4)
299
+ actions, # (N,)
300
+ ):
301
+ """
302
+ Main CUDA Kernel for Stepping N Environments.
303
+ """
304
+ i = cuda.grid(1)
305
+
306
+ if i >= batch_global_ctx.shape[0]:
307
+ return
308
+
309
+ # Sync Score
310
+ batch_global_ctx[i, SC] = batch_scores[i]
311
+
312
+ act_id = actions[i]
313
+
314
+ # 1. Apply Action
315
+ if act_id > 0:
316
+ card_id = act_id
317
+
318
+ # Check Bounds
319
+ if card_id < bytecode_index.shape[0]:
320
+ # Assume Ability 0
321
+ map_idx = bytecode_index[card_id, 0]
322
+
323
+ if map_idx >= 0:
324
+ code_seq = bytecode_map[map_idx]
325
+
326
+ # Set Source Zone to Hand (1) -> mapped to index 7 in flat_ctx?
327
+ # In fast_logic.py: SZ = 7.
328
+ batch_flat_ctx[i, 7] = 1
329
+
330
+ # Execute
331
+ nc, st, bn = resolve_bytecode_device(
332
+ code_seq,
333
+ batch_flat_ctx[i],
334
+ batch_global_ctx[i],
335
+ 0, # Player ID
336
+ batch_hand[i],
337
+ batch_deck[i],
338
+ batch_stage[i],
339
+ batch_energy_vec[i],
340
+ batch_energy_count[i],
341
+ batch_continuous_vec[i],
342
+ batch_continuous_ptr[i], # Passed as scalar? No, ptr[i] is scalar, but device func expects ref?
343
+ # fast_logic expects 'p_cont_ptr' as int,
344
+ # returns new ptr.
345
+ # Wait, resolve_bytecode returns (cptr, ...).
346
+ # So we pass VALUE of ptr.
347
+ batch_tapped[i],
348
+ batch_live[i],
349
+ batch_opp_tapped[i],
350
+ )
351
+
352
+ # Update State
353
+ batch_continuous_ptr[i] = nc
354
+ batch_scores[i] = batch_global_ctx[i, SC] + bn # SC updated inside + bonus?
355
+ # Actually resolve_bytecode updates SC in global_ctx for O_HEARTS.
356
+ # So we just take global_ctx[SC].
357
+
358
+ # Reset SZ
359
+ batch_flat_ctx[i, 7] = 0
360
+
361
+ # Remove Card from Hand
362
+ found = False
363
+ for h_idx in range(60):
364
+ if batch_hand[i, h_idx] == card_id:
365
+ batch_hand[i, h_idx] = 0
366
+ batch_global_ctx[i, 3] -= 1 # HD
367
+ found = True
368
+ break
369
+
370
+ # Place on Stage (if Member)
371
+ if found and card_id < 900:
372
+ for s_idx in range(3):
373
+ if batch_stage[i, s_idx] == -1:
374
+ batch_stage[i, s_idx] = card_id
375
+ break
376
+
377
+ # Draw Logic (Refill to 5)
378
+ # Count Hand
379
+ h_cnt = 0
380
+ for h_idx in range(60):
381
+ if batch_hand[i, h_idx] > 0:
382
+ h_cnt += 1
383
+
384
+ if h_cnt < 5:
385
+ # Draw top card
386
+ top_card = 0
387
+ d_idx_found = -1
388
+ for d_idx in range(60):
389
+ if batch_deck[i, d_idx] > 0:
390
+ top_card = batch_deck[i, d_idx]
391
+ d_idx_found = d_idx
392
+ break
393
+
394
+ if top_card > 0:
395
+ for h_idx in range(60):
396
+ if batch_hand[i, h_idx] == 0:
397
+ batch_hand[i, h_idx] = top_card
398
+ batch_deck[i, d_idx_found] = 0
399
+ batch_global_ctx[i, 3] += 1
400
+ batch_global_ctx[i, 6] -= 1
401
+ break
402
+
403
+ # 2. Opponent (Random) Simulation
404
+ # Use XOROSHIRO RNG
405
+ if rng_states is not None:
406
+ r = xoroshiro128p_uniform_float32(rng_states, i)
407
+ if r > 0.8:
408
+ # Randomly tap an agent member?
409
+ pass