File size: 2,312 Bytes
b1ed5c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import sys

import numpy as np
import torch

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import engine_rust

from ai.models.training_config import POLICY_SIZE
from ai.training.train import AlphaNet


def verify_parity():
    # 1. Setup Models
    model_path_pt = "ai/models/alphanet_best.pt"
    model_path_onnx = "ai/models/alphanet.onnx"

    device = torch.device("cpu")
    checkpoint = torch.load(model_path_pt, map_location=device)
    state_dict = checkpoint["model_state"] if "model_state" in checkpoint else checkpoint

    model_pt = AlphaNet(policy_size=POLICY_SIZE)
    model_pt.load_state_dict(state_dict)
    model_pt.eval()

    nmcts_rust = engine_rust.PyNeuralMCTS(model_path_onnx)

    # 2. Setup Game
    with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f:
        db_content = f.read()
    db = engine_rust.PyCardDatabase(db_content)
    game = engine_rust.PyGameState(db)

    # Simple init
    p0_deck = [124, 127] * 20
    p1_deck = [124, 127] * 20
    p0_lives = [1024, 1025, 1027]
    p1_lives = [1024, 1025, 1027]
    game.initialize_game(p0_deck, p1_deck, [20000] * 10, [20000] * 10, p0_lives, p1_lives)

    # 3. Compare Encoding
    obs_rust = np.array(game.get_observation(), dtype=np.float32)

    # 4. Compare Inference
    with torch.no_grad():
        logits_pt, val_pt = model_pt(torch.from_numpy(obs_rust).unsqueeze(0))
        logits_pt = logits_pt.numpy()[0]
        val_pt = val_pt.numpy()[0][0]

    print(f"Observation Size: {len(obs_rust)}")
    print(f"Value (Python): {val_pt:.6f}")

    # Note: We need a way to get raw evaluate results from Rust for deep parity
    # But for now, we'll check suggestions as a proxy
    num_sims = 100
    suggestions = nmcts_rust.get_suggestions(game, num_sims)

    print("\nTop Python Actions (Logits):")
    top_pt = np.argsort(logits_pt)[::-1][:5]
    for i in top_pt:
        print(f"  Action {i:4d}: {logits_pt[i]:.4f}")

    print("\nRust Suggestions (1 simulation):")
    for action, score, visits in suggestions[:5]:
        print(f"  Action {action:4d}: Score {score:.4f}, Visits {visits}")


if __name__ == "__main__":
    verify_parity()