trioskosmos commited on
Commit
e7b84fc
·
verified ·
1 Parent(s): 5206d7b

Upload ai/agents/neural_mcts.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/agents/neural_mcts.py +128 -0
ai/agents/neural_mcts.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+
6
+ # Add project root to path
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
8
+
9
+ import engine_rust
10
+
11
+ from ai.models.training_config import POLICY_SIZE
12
+ from ai.training.train import AlphaNet
13
+
14
+
15
+ class NeuralHeuristicAgent:
16
+ """
17
+ An agent that uses the ResNet (Intuition) to filter moves,
18
+ and MCTS (Calculation) to verify them.
19
+ """
20
+
21
+ def __init__(self, model_path, sims=100):
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ checkpoint = torch.load(model_path, map_location=self.device)
24
+ state_dict = (
25
+ checkpoint["model_state"] if isinstance(checkpoint, dict) and "model_state" in checkpoint else checkpoint
26
+ )
27
+
28
+ self.model = AlphaNet(policy_size=POLICY_SIZE).to(self.device)
29
+ self.model.load_state_dict(state_dict)
30
+ self.model.eval()
31
+
32
+ self.sims = sims
33
+
34
+ def get_action(self, game, db):
35
+ # 1. Get Logits from ResNet
36
+ encoded = game.encode_state(db)
37
+ state_tensor = torch.FloatTensor(encoded).unsqueeze(0).to(self.device)
38
+
39
+ with torch.no_grad():
40
+ logits, score_eval = self.model(state_tensor)
41
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
42
+
43
+ legal_actions = game.get_legal_action_ids()
44
+ if not legal_actions:
45
+ return 0
46
+ if len(legal_actions) == 1:
47
+ return int(legal_actions[0])
48
+
49
+ # 2. Run engine's fast MCTS (Random Rollout based)
50
+ # This provides a 'ground truth' sanity check.
51
+ mcts_suggestions = game.get_mcts_suggestions(self.sims, engine_rust.SearchHorizon.TurnEnd)
52
+ mcts_visits = {int(a): v for a, s, v in mcts_suggestions}
53
+ mcts_scores = {int(a): s for a, s, v in mcts_suggestions}
54
+
55
+ # 3. Combine Intuition (Probs) and Calculation (MCTS Win Rate)
56
+ # We calculate a combined score for each legal action
57
+ best_action = legal_actions[0]
58
+ max_score = -1e9
59
+
60
+ for aid in legal_actions:
61
+ aid = int(aid)
62
+ prior = probs[aid] if aid < len(probs) else 0.0
63
+
64
+ # Convert MCTS visits/score to a win probability [0, 1]
65
+ # MCTS score is usually total reward / visits.
66
+ # We'll use visits as a proxy for confidence.
67
+ win_prob = mcts_scores.get(aid, 0.0)
68
+ conf = mcts_visits.get(aid, 0) / (self.sims + 1)
69
+
70
+ # Strategy:
71
+ # If MCTS finds a move that is significantly better than PASS (0),
72
+ # we favor it even if ResNet is biased towards 0.
73
+
74
+ # Simple weighted sum
75
+ # Prior (0.3) + WinProb (0.7)
76
+ score = 0.3 * prior + 0.7 * win_prob
77
+
78
+ # Bonus for MCTS confidence
79
+ score += 0.2 * conf
80
+
81
+ if score > max_score:
82
+ max_score = score
83
+ best_action = aid
84
+
85
+ return best_action
86
+
87
+
88
+ class NeuralMCTSFullAgent:
89
+ """
90
+ AlphaZero-style agent that uses the Rust-implemented NeuralMCTS.
91
+ This is much faster than the Python hybrid because the entire
92
+ MCTS search and NN evaluation happens inside the Rust core.
93
+ """
94
+
95
+ def __init__(self, model_path, sims=100):
96
+ # We assume engine_rust has been compiled with ORT support.
97
+ # This will load the ONNX model once into a background session.
98
+ self.mcts = engine_rust.PyNeuralMCTS(model_path)
99
+ self.sims = sims
100
+
101
+ def get_action(self, game, db):
102
+ # suggestions: Vec<(action_id, score, visit_count)>
103
+ suggestions = self.mcts.get_suggestions(game, self.sims)
104
+ if not suggestions:
105
+ # Fallback to random or pass if something is wrong
106
+ return 0
107
+
108
+ # NeuralMCTS returns suggestions sorted by visit count descending
109
+ # so [0][0] is the most visited action.
110
+ return int(suggestions[0][0])
111
+
112
+
113
+ class HybridMCTSAgent:
114
+ """
115
+ The ultimate agent. It uses the Rust-implemented HybridMCTS
116
+ which blends Neural intuition with Heuristic calculation.
117
+ Target speed is <0.1s/move at 100 sims.
118
+ """
119
+
120
+ def __init__(self, model_path, sims=100, neural_weight=0.3):
121
+ self.mcts = engine_rust.PyHybridMCTS(model_path, neural_weight)
122
+ self.sims = sims
123
+
124
+ def get_action(self, game, db):
125
+ suggestions = self.mcts.get_suggestions(game, self.sims)
126
+ if not suggestions:
127
+ return 0
128
+ return int(suggestions[0][0])