trioskosmos commited on
Commit
b1ed5c6
·
verified ·
1 Parent(s): 0dd780c

Upload ai/utils/verify_parity.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/verify_parity.py +72 -0
ai/utils/verify_parity.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ # Add project root to path
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
9
+
10
+ import engine_rust
11
+
12
+ from ai.models.training_config import POLICY_SIZE
13
+ from ai.training.train import AlphaNet
14
+
15
+
16
+ def verify_parity():
17
+ # 1. Setup Models
18
+ model_path_pt = "ai/models/alphanet_best.pt"
19
+ model_path_onnx = "ai/models/alphanet.onnx"
20
+
21
+ device = torch.device("cpu")
22
+ checkpoint = torch.load(model_path_pt, map_location=device)
23
+ state_dict = checkpoint["model_state"] if "model_state" in checkpoint else checkpoint
24
+
25
+ model_pt = AlphaNet(policy_size=POLICY_SIZE)
26
+ model_pt.load_state_dict(state_dict)
27
+ model_pt.eval()
28
+
29
+ nmcts_rust = engine_rust.PyNeuralMCTS(model_path_onnx)
30
+
31
+ # 2. Setup Game
32
+ with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f:
33
+ db_content = f.read()
34
+ db = engine_rust.PyCardDatabase(db_content)
35
+ game = engine_rust.PyGameState(db)
36
+
37
+ # Simple init
38
+ p0_deck = [124, 127] * 20
39
+ p1_deck = [124, 127] * 20
40
+ p0_lives = [1024, 1025, 1027]
41
+ p1_lives = [1024, 1025, 1027]
42
+ game.initialize_game(p0_deck, p1_deck, [20000] * 10, [20000] * 10, p0_lives, p1_lives)
43
+
44
+ # 3. Compare Encoding
45
+ obs_rust = np.array(game.get_observation(), dtype=np.float32)
46
+
47
+ # 4. Compare Inference
48
+ with torch.no_grad():
49
+ logits_pt, val_pt = model_pt(torch.from_numpy(obs_rust).unsqueeze(0))
50
+ logits_pt = logits_pt.numpy()[0]
51
+ val_pt = val_pt.numpy()[0][0]
52
+
53
+ print(f"Observation Size: {len(obs_rust)}")
54
+ print(f"Value (Python): {val_pt:.6f}")
55
+
56
+ # Note: We need a way to get raw evaluate results from Rust for deep parity
57
+ # But for now, we'll check suggestions as a proxy
58
+ num_sims = 100
59
+ suggestions = nmcts_rust.get_suggestions(game, num_sims)
60
+
61
+ print("\nTop Python Actions (Logits):")
62
+ top_pt = np.argsort(logits_pt)[::-1][:5]
63
+ for i in top_pt:
64
+ print(f" Action {i:4d}: {logits_pt[i]:.4f}")
65
+
66
+ print("\nRust Suggestions (1 simulation):")
67
+ for action, score, visits in suggestions[:5]:
68
+ print(f" Action {action:4d}: Score {score:.4f}, Visits {visits}")
69
+
70
+
71
+ if __name__ == "__main__":
72
+ verify_parity()