trioskosmos committed on
Commit
5c0c5f6
·
verified ·
1 Parent(s): 37004b0

Upload ai/models/network.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/models/network.py +480 -0
ai/models/network.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Neural Network for AlphaZero-style training.
3
+
4
+ This module provides a simple neural network architecture for policy and value
5
+ prediction. For a production system, you would use a more sophisticated
6
+ architecture (e.g., ResNet with attention) and train on GPU with PyTorch/TensorFlow.
7
+ """
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Tuple
11
+
12
+ import numpy as np
13
+
14
+
15
@dataclass
class NetworkConfig:
    """Hyper-parameters describing the shape and optimiser settings of the network.

    Field order is significant: the dataclass can be constructed positionally.
    """

    # Length of the observation vector fed to the first layer.
    # Feature-based encoding (32 floats per card slot); matches GameState.get_observation.
    input_size: int = 800
    # Width of every shared hidden layer.
    hidden_size: int = 256
    # Number of shared hidden layers in front of the two heads.
    num_hidden_layers: int = 3
    # Number of policy outputs (matches GameState.get_legal_actions).
    action_size: int = 1000
    # Step size used by the plain SGD update.
    learning_rate: float = 0.001
    # Coefficient of the L2 penalty applied to weight matrices (not biases).
    l2_reg: float = 0.0001
26
+
27
+
28
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    The input is clipped to ``[-500, 500]`` before exponentiation so that
    ``np.exp`` cannot overflow for extreme values.
    """
    clipped = np.clip(x, -500, 500)
    denominator = 1 + np.exp(-clipped)
    return 1 / denominator
30
+
31
+
32
def relu(x: np.ndarray) -> np.ndarray:
    """Element-wise rectified linear unit: negative entries become zero."""
    rectified = np.maximum(0, x)
    return rectified
34
+
35
+
36
def softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x`` computed along its last axis.

    For a 1-D vector this is the usual softmax. For a batch of shape
    ``(batch, n)`` every row is normalised independently, so each row sums
    to one (reducing over the whole array would instead make the entire
    batch sum to one).

    Args:
        x: Array of logits with at least one dimension.

    Returns:
        Array of the same shape as ``x`` holding probabilities.
    """
    # Subtracting the per-row maximum keeps np.exp from overflowing and does
    # not change the result.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
39
+
40
+
41
def tanh(x: np.ndarray) -> np.ndarray:
    """Element-wise hyperbolic tangent (thin wrapper around ``np.tanh``)."""
    squashed = np.tanh(x)
    return squashed
43
+
44
+
45
class SimpleNetwork:
    """
    Simple feedforward neural network for policy and value prediction.

    Architecture:
    - Input layer (observation)
    - Hidden layers with ReLU
    - Policy head (softmax over actions)
    - Value head (tanh for [-1, 1])

    ``forward`` accepts either one observation of shape ``(input_size,)`` or a
    batch of shape ``(batch, input_size)``; ``train_step`` requires a batch.
    """

    def __init__(self, config: NetworkConfig = None):
        """Create a network with freshly initialised weights.

        Args:
            config: Network hyper-parameters; a default ``NetworkConfig`` is
                used when omitted.
        """
        self.config = config or NetworkConfig()
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize weights using He initialization; biases start at zero."""
        config = self.config

        # Shared layers
        self.hidden_weights = []
        self.hidden_biases = []

        in_size = config.input_size
        for _ in range(config.num_hidden_layers):
            std = np.sqrt(2.0 / in_size)
            self.hidden_weights.append(np.random.randn(in_size, config.hidden_size) * std)
            self.hidden_biases.append(np.zeros(config.hidden_size))
            in_size = config.hidden_size

        # Both heads read the output of the last shared layer. Using ``in_size``
        # (rather than ``hidden_size``) keeps the shapes valid even when the
        # network is configured with zero hidden layers.
        std = np.sqrt(2.0 / in_size)

        # Policy head
        self.policy_weight = np.random.randn(in_size, config.action_size) * std
        self.policy_bias = np.zeros(config.action_size)

        # Value head
        self.value_weight = np.random.randn(in_size, 1) * std
        self.value_bias = np.zeros(1)

    def forward(self, observation: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        Forward pass.

        Args:
            observation: Input features, shape ``(input_size,)`` or
                ``(batch, input_size)``.

        Returns:
            (policy probabilities, value). For a single observation the policy
            has shape ``(action_size,)`` and the value is a scalar; for a batch
            they have shapes ``(batch, action_size)`` and ``(batch,)``.

        Side effects:
            Stores ``self.activations``, ``self.last_policy_logits`` and
            ``self.last_value`` for use by ``train_step``.
        """
        # Store activations for backward pass
        self.activations = [observation]

        x = observation
        for w, b in zip(self.hidden_weights, self.hidden_biases, strict=False):
            x = relu(x @ w + b)
            self.activations.append(x)

        # Policy head. Normalise each row independently so that a batched
        # input yields one probability distribution per sample.
        policy_logits = x @ self.policy_weight + self.policy_bias
        policy = np.apply_along_axis(softmax, -1, policy_logits)

        # Value head. The matmul leaves a trailing axis of length 1; drop that
        # axis (not the batch axis) so batches keep one value per sample.
        value = tanh(x @ self.value_weight + self.value_bias)
        value = value[0] if value.ndim == 1 else value[..., 0]

        self.last_policy_logits = policy_logits
        self.last_value = value

        return policy, value

    @staticmethod
    def _mask_policy(policy: np.ndarray, legal) -> np.ndarray:
        """Zero out illegal actions in ``policy`` and renormalise.

        Falls back to a uniform distribution over the legal actions when the
        network assigns them no probability mass, and to an all-zero vector
        when there are no legal actions at all (instead of dividing by zero).
        """
        legal = np.asarray(legal)
        masked = policy * legal
        total = masked.sum()
        if total > 0:
            return masked / total

        # Fall back to uniform over legal
        uniform = legal.astype(np.float32)
        n_legal = uniform.sum()
        if n_legal > 0:
            uniform /= n_legal
        return uniform

    def predict(self, state) -> Tuple[np.ndarray, float]:
        """Get the legal-action-masked policy and the value for a game state.

        ``state`` must provide ``get_observation()`` and ``get_legal_actions()``;
        the latter is multiplied element-wise with the policy, so it is assumed
        to be a 0/1 mask of length ``action_size`` (TODO confirm against GameState).
        """
        policy, value = self.forward(state.get_observation())
        return self._mask_policy(policy, state.get_legal_actions()), value

    def predict_batch(self, states) -> list:
        """Get ``(masked policy, value)`` pairs for a batch of game states.

        Returns an empty list for empty input; otherwise one tuple per state,
        in the same order as ``states``.
        """
        if not states:
            return []

        obs = np.array([s.get_observation() for s in states])
        policies, values = self.forward(obs)

        return [
            (self._mask_policy(policy, state.get_legal_actions()), value)
            for state, policy, value in zip(states, policies, values)
        ]

    def train_step(
        self, observations: np.ndarray, target_policies: np.ndarray, target_values: np.ndarray
    ) -> Tuple[float, float, float]:
        """
        One training step (Vectorized): forward pass, manual backprop, SGD update.

        Args:
            observations: Batch of observations (batch_size, input_size)
            target_policies: Target policy distributions (batch_size, action_size)
            target_values: Target values (batch_size,)

        Returns:
            (total_loss, policy_loss, value_loss), computed before the update.
            The L2 penalty affects the update but is not included in the losses.
        """
        config = self.config
        batch_size = len(observations)
        # Flatten so a (B, 1) target cannot broadcast against the (B,)
        # predictions into a (B, B) matrix.
        target_values = np.asarray(target_values).reshape(-1)

        # 1. Forward pass: pred_policy is (B, action_size), pred_value is (B,)
        pred_policy, pred_value = self.forward(observations)

        # 2. Losses: mean cross-entropy for the policy, MSE for the value
        policy_loss = -np.mean(np.sum(target_policies * np.log(pred_policy + 1e-8), axis=1))
        value_loss = np.mean((pred_value - target_values) ** 2)
        total_loss = policy_loss + value_loss

        # 3. Backward pass
        # Softmax + cross-entropy: dL/dlogits = (pred - target) / B
        d_policy_logits = (pred_policy - target_policies) / batch_size

        # MSE through tanh: 2 * (pred - target) / B * (1 - tanh^2), as (B, 1)
        d_value_out = 2 * (pred_value - target_values) / batch_size
        d_value_pre_tanh = (d_value_out * (1 - pred_value**2)).reshape(-1, 1)

        # Head gradients; hidden_out is the last shared activation, (B, H)
        hidden_out = self.activations[-1]
        grad_policy_w = hidden_out.T @ d_policy_logits  # (H, A)
        grad_policy_b = np.sum(d_policy_logits, axis=0)
        grad_value_w = hidden_out.T @ d_value_pre_tanh  # (H, 1)
        grad_value_b = np.sum(d_value_pre_tanh, axis=0)

        # Error arriving at the last shared layer from both heads, (B, H)
        d_hidden = d_policy_logits @ self.policy_weight.T + d_value_pre_tanh @ self.value_weight.T

        num_layers = len(self.hidden_weights)
        grads_w = [None] * num_layers
        grads_b = [None] * num_layers

        # activations[0] is the input and activations[k + 1] is the output of
        # hidden layer k, so layer k's input is activations[k].
        for layer_idx in reversed(range(num_layers)):
            # ReLU derivative: pass the error only where the unit was active
            mask = (self.activations[layer_idx + 1] > 0).astype(np.float32)
            d_hidden = d_hidden * mask

            grads_w[layer_idx] = self.activations[layer_idx].T @ d_hidden  # (In, Out)
            grads_b[layer_idx] = np.sum(d_hidden, axis=0)

            if layer_idx > 0:
                # Propagate to previous layer (uses the not-yet-updated weights)
                d_hidden = d_hidden @ self.hidden_weights[layer_idx].T

        # 4. Apply gradients (SGD + L2): w -= lr * (grad + l2 * w); biases are
        # not regularised.
        lr = config.learning_rate
        l2 = config.l2_reg
        for i in range(num_layers):
            self.hidden_weights[i] -= lr * (grads_w[i] + l2 * self.hidden_weights[i])
            self.hidden_biases[i] -= lr * grads_b[i]

        self.policy_weight -= lr * (grad_policy_w + l2 * self.policy_weight)
        self.policy_bias -= lr * grad_policy_b

        self.value_weight -= lr * (grad_value_w + l2 * self.value_weight)
        self.value_bias -= lr * grad_value_b

        return total_loss, policy_loss, value_loss

    @staticmethod
    def _as_object_array(arrays) -> np.ndarray:
        """Pack ``arrays`` into a 1-D object array with one element per array.

        ``np.array(list_of_arrays, dtype=object)`` stacks equally shaped arrays
        into a single multi-dimensional object array, which turns the stored
        numbers into Python objects. Filling a pre-sized array avoids that.
        """
        packed = np.empty(len(arrays), dtype=object)
        for i, arr in enumerate(arrays):
            packed[i] = arr
        return packed

    def save(self, filepath: str) -> None:
        """Save network weights to an ``.npz`` archive.

        ``np.savez`` appends ``.npz`` to ``filepath`` when it is missing. The
        per-layer lists are stored as object arrays (which are pickled) because
        layer shapes may differ.
        """
        np.savez(
            filepath,
            hidden_weights=self._as_object_array(self.hidden_weights),
            hidden_biases=self._as_object_array(self.hidden_biases),
            policy_weight=self.policy_weight,
            policy_bias=self.policy_bias,
            value_weight=self.value_weight,
            value_bias=self.value_bias,
        )

    def load(self, filepath: str) -> None:
        """Load network weights from a file written by ``save``.

        If ``filepath`` does not exist but ``filepath + ".npz"`` does, the
        latter is used, mirroring the suffix ``np.savez`` adds on save.

        SECURITY: the per-layer lists are stored as pickled object arrays, so
        this uses ``allow_pickle=True``. Unpickling can execute arbitrary code;
        only load checkpoint files from trusted sources.
        """
        import os

        path = filepath
        if isinstance(filepath, (str, os.PathLike)):
            path = os.fspath(filepath)
            if not os.path.exists(path) and os.path.exists(path + ".npz"):
                path += ".npz"

        def _numeric(arr) -> np.ndarray:
            # Checkpoints whose layers were stacked into one object array yield
            # object-dtype slices; convert them back to real float arrays so
            # ufuncs such as np.exp work on the loaded network.
            arr = np.asarray(arr)
            return arr.astype(np.float64) if arr.dtype == object else arr

        # The context manager closes the underlying archive file.
        with np.load(path, allow_pickle=True) as data:
            self.hidden_weights = [_numeric(w) for w in data["hidden_weights"]]
            self.hidden_biases = [_numeric(b) for b in data["hidden_biases"]]
            self.policy_weight = data["policy_weight"]
            self.policy_bias = data["policy_bias"]
            self.value_weight = data["value_weight"]
            self.value_bias = data["value_bias"]
281
+
282
+
283
class NeuralMCTS:
    """MCTS that uses a neural network for policy and value with parallel search.

    Simulations are run in batches: every simulation of a batch selects a leaf
    first (a virtual loss is added along the way so that later selections in
    the same batch tend to pick different paths), then all leaves are evaluated
    by the network in one call, and finally the results are backed up.
    """

    def __init__(
        self, network: SimpleNetwork, num_simulations: int = 100, batch_size: int = 8, virtual_loss: float = 3.0
    ):
        """
        Args:
            network: Evaluator providing ``predict`` (and optionally
                ``predict_batch``).
            num_simulations: Target number of simulations per ``search`` call.
            batch_size: Number of simulations evaluated together.
            virtual_loss: Amount added to a child's ``virtual_loss`` while a
                simulation through it is pending.
        """
        self.network = network
        self.num_simulations = num_simulations
        self.batch_size = batch_size
        self.c_puct = 1.4  # exploration constant handed to MCTSNode.select_child
        self.virtual_loss = virtual_loss
        self.root = None

    def get_policy_value(self, state) -> Tuple[np.ndarray, float]:
        """Get policy and value from neural network"""
        return self.network.predict(state)

    def search(self, state) -> np.ndarray:
        """Run MCTS with neural network guidance (Parallel).

        Returns:
            Visit-count distribution over actions as a float32 array whose
            length equals ``len(state.get_legal_actions())``; all zeros if the
            root has no visited children.
        """
        from ai.mcts import MCTSNode

        # Initial root expansion (always blocking)
        policy, _ = self.get_policy_value(state)
        self.root = MCTSNode()
        self.root.expand(state, policy)

        # Whole batches only: when num_simulations is not a multiple of
        # batch_size slightly more simulations are run, which is accepted.
        num_batches = (self.num_simulations + self.batch_size - 1) // self.batch_size
        for _ in range(num_batches):
            self._simulate_batch(state, self.batch_size)

        # The policy target needs a fixed-size array; children are keyed by
        # action index.
        action_size = len(state.get_legal_actions())
        visits = np.zeros(action_size, dtype=np.float32)
        for action, child in self.root.children.items():
            visits[action] = child.visit_count

        total_visits = visits.sum()
        if total_visits > 0:
            visits /= total_visits

        return visits

    def _simulate_batch(self, root_state, batch_size) -> None:
        """Run a batch of MCTS simulations parallelized via Virtual Loss"""
        simulations = []  # one (path, leaf_state) pair per simulation
        request_states = []  # non-terminal leaf states awaiting evaluation

        # 1. Selection phase
        for _ in range(batch_size):
            node = self.root
            state = root_state.copy()
            path = [node]

            while node.is_expanded() and not state.is_terminal():
                action, child = node.select_child(self.c_puct)

                # Apply Virtual Loss immediately so subsequent selections in
                # this batch diverge
                child.virtual_loss += self.virtual_loss

                state = state.step(action)
                node = child
                path.append(node)

            simulations.append((path, state))
            if not state.is_terminal():
                request_states.append(state)

        # 2. Evaluation phase (batched when the network supports it)
        responses = []
        if request_states:
            if hasattr(self.network, "predict_batch"):
                responses = self.network.predict_batch(request_states)
            else:
                responses = [self.network.predict(s) for s in request_states]
        # Responses are in the order the non-terminal leaves were collected,
        # which is the order they are consumed below.
        pending = iter(responses)

        # 3. Expansion & backpropagation phase
        for path, state in simulations:
            leaf = path[-1]

            if state.is_terminal():
                # NOTE(review): the terminal reward is taken from the root
                # player's perspective while the network value below is
                # presumably from the leaf's perspective; both are then
                # sign-flipped identically on the way up. Confirm this matches
                # the convention MCTSNode.select_child expects.
                value = state.get_reward(root_state.current_player)
            else:
                policy, value = next(pending)
                # Several simulations of one batch can end at the same leaf;
                # expand it only once.
                if not leaf.is_expanded():
                    leaf.expand(state, policy)

            # Backpropagate from the leaf to the root, flipping the sign at
            # every level.
            for node in reversed(path):
                node.visit_count += 1
                node.value_sum += value

                # Virtual loss was only ever added to selected children, never
                # to the root, so only non-root nodes give it back.
                if node is not self.root:
                    node.virtual_loss -= self.virtual_loss

                value = -value
401
+
402
+
403
def train_network(network: SimpleNetwork, training_data: list, epochs: int = 10, batch_size: int = 32) -> None:
    """
    Train network on self-play data.

    Args:
        network: Network to train (updated in place via ``train_step``)
        training_data: List of ``(states, policies, winner, r0, r1)`` tuples,
            one per game. ``states`` and ``policies`` are parallel sequences of
            observations and policy targets; ``r0`` and ``r1`` are the value
            targets used for even- and odd-indexed positions respectively.
            ``winner`` is not used here.
        epochs: Number of training epochs
        batch_size: Batch size for training

    Prints a progress line per epoch. Does nothing beyond printing when the
    data contains no positions.
    """
    print(f"Training on {len(training_data)} games...")

    # Flatten data with rewards
    all_states = []
    all_policies = []
    all_values = []

    for states, policies, _winner, r0, r1 in training_data:
        for i, (s, p) in enumerate(zip(states, policies, strict=False)):
            all_states.append(s)
            all_policies.append(p)
            # Value from the perspective of the player who made the move,
            # using the actual calculated reward (with score shaping).
            # NOTE(review): assumes players strictly alternate, starting with
            # player 0 at index 0 - confirm against the self-play recorder.
            all_values.append(r0 if i % 2 == 0 else r1)

    n_samples = len(all_states)
    if n_samples == 0:
        # Nothing to train on; also avoids dividing by zero batches below.
        print("No training samples; skipping training.")
        return

    all_states = np.array(all_states)
    all_policies = np.array(all_policies)
    all_values = np.array(all_values)

    num_batches = (n_samples + batch_size - 1) // batch_size

    for epoch in range(epochs):
        # Shuffle data
        indices = np.random.permutation(n_samples)
        total_loss = 0.0

        for start in range(0, n_samples, batch_size):
            batch_idx = indices[start : start + batch_size]
            loss, _, _ = network.train_step(all_states[batch_idx], all_policies[batch_idx], all_values[batch_idx])
            total_loss += loss

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / num_batches:.4f}")
454
+
455
+
456
if __name__ == "__main__":
    # Manual smoke test: one prediction and one training step on a fresh game.
    from engine.game.game_state import initialize_game

    print("Testing neural network...")
    config = NetworkConfig()
    network = SimpleNetwork(config)

    # Forward pass on the initial game state
    state = initialize_game()
    policy, value = network.predict(state)

    print(f"Policy shape: {policy.shape}")
    print(f"Policy sum: {policy.sum():.4f}")
    print(f"Value: {value:.4f}")

    # Single-sample training step against a hand-made target
    target_p = np.zeros(config.action_size)
    target_p[0], target_p[1] = 0.8, 0.2
    target_v = 0.5

    batch_obs = state.get_observation().reshape(1, -1)
    batch_policy = target_p.reshape(1, -1)
    batch_value = np.array([target_v])

    loss, p_loss, v_loss = network.train_step(batch_obs, batch_policy, batch_value)
    print(f"Training loss: {loss:.4f} (policy: {p_loss:.4f}, value: {v_loss:.4f})")