Spaces:
Running
Running
Upload ai/utils/verify_parity.py with huggingface_hub
Browse files- ai/utils/verify_parity.py +72 -0
ai/utils/verify_parity.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
# Add project root to path
|
| 8 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
| 9 |
+
|
| 10 |
+
import engine_rust
|
| 11 |
+
|
| 12 |
+
from ai.models.training_config import POLICY_SIZE
|
| 13 |
+
from ai.training.train import AlphaNet
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _resolve_project_path(relative_path: str) -> str:
    """Return *relative_path* if it exists from the CWD, else anchor it at the project root."""
    if os.path.exists(relative_path):
        return relative_path
    # This file lives at <root>/ai/utils/, so the root is three levels up
    # (the same computation the module uses for its sys.path bootstrap).
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    return os.path.join(project_root, relative_path)


def verify_parity(num_sims: int = 100, top_k: int = 5) -> None:
    """Print PyTorch policy/value outputs next to Rust MCTS suggestions.

    Loads the PyTorch ``AlphaNet`` checkpoint and the ONNX-backed Rust
    ``PyNeuralMCTS``, initialises a fixed game position, runs the PyTorch
    model on the Rust-produced observation, and prints the top actions from
    both sides for manual comparison. Nothing is asserted automatically.

    Args:
        num_sims: Number of simulations passed to ``get_suggestions``.
        top_k: Number of top-ranked actions printed for each side.

    Raises:
        FileNotFoundError: If the card database file cannot be found
            (checkpoint loading may raise similarly for a missing model file).
    """
    # 1. Setup Models
    model_path_pt = _resolve_project_path("ai/models/alphanet_best.pt")
    model_path_onnx = _resolve_project_path("ai/models/alphanet.onnx")

    device = torch.device("cpu")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path_pt, map_location=device)
    # Accept both a wrapped checkpoint ({"model_state": ...}) and a bare state dict.
    state_dict = checkpoint["model_state"] if "model_state" in checkpoint else checkpoint

    model_pt = AlphaNet(policy_size=POLICY_SIZE)
    model_pt.load_state_dict(state_dict)
    model_pt.eval()

    nmcts_rust = engine_rust.PyNeuralMCTS(model_path_onnx)

    # 2. Setup Game
    cards_path = _resolve_project_path("engine/data/cards_compiled.json")
    with open(cards_path, "r", encoding="utf-8") as f:
        db_content = f.read()
    db = engine_rust.PyCardDatabase(db_content)
    game = engine_rust.PyGameState(db)

    # Simple init: both players get identical decks and lives.
    p0_deck = [124, 127] * 20
    p1_deck = [124, 127] * 20
    p0_lives = [1024, 1025, 1027]
    p1_lives = [1024, 1025, 1027]
    game.initialize_game(p0_deck, p1_deck, [20000] * 10, [20000] * 10, p0_lives, p1_lives)

    # 3. Encoding: the observation comes from the Rust engine and is fed
    # unchanged to the PyTorch model, so both sides see the same input.
    obs_rust = np.array(game.get_observation(), dtype=np.float32)

    # 4. Compare Inference
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension; [0] below strips it again.
        logits_pt, val_pt = model_pt(torch.from_numpy(obs_rust).unsqueeze(0))
        logits_pt = logits_pt.numpy()[0]
        val_pt = val_pt.numpy()[0][0]

    print(f"Observation Size: {len(obs_rust)}")
    print(f"Value (Python): {val_pt:.6f}")

    # Note: We need a way to get raw evaluate results from Rust for deep parity.
    # For now, the MCTS suggestions are used as a proxy.
    suggestions = nmcts_rust.get_suggestions(game, num_sims)

    print("\nTop Python Actions (Logits):")
    # argsort is ascending, so reverse it to rank the highest logits first.
    top_pt = np.argsort(logits_pt)[::-1][:top_k]
    for i in top_pt:
        print(f"  Action {i:4d}: {logits_pt[i]:.4f}")

    # Report the simulation count actually used, not a hard-coded label.
    print(f"\nRust Suggestions ({num_sims} simulations):")
    for action, score, visits in suggestions[:top_k]:
        print(f"  Action {action:4d}: Score {score:.4f}, Visits {visits}")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
    # Run the parity report only when executed as a script, not on import.
    verify_parity()
|