File size: 2,978 Bytes
cf2aacd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn.functional as F
from model import OthelloNet
from bitboard import get_bit, make_input_planes
import numpy as np

def load_dualist(model_path="dualist_model.pth", device="cpu"):
    """Build an ``OthelloNet`` and populate it from a saved checkpoint.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device used both as ``map_location`` when reading the
            checkpoint and as the target of ``model.to``.

    Returns:
        The network, moved to ``device`` and switched to eval mode.
    """
    net = OthelloNet(num_res_blocks=10, num_channels=256)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    loaded = torch.load(model_path, map_location=device)

    # A checkpoint may wrap the weights under "model_state_dict" or be the
    # bare state dict itself; accept either layout.
    weights = loaded["model_state_dict"] if "model_state_dict" in loaded else loaded
    net.load_state_dict(weights)

    net.to(device)
    net.eval()
    return net

def get_best_move(model, player_bb, opponent_bb, legal_moves_bb, device="cpu"):
    """Return the legal move the policy output scores highest, as a bitmask.

    Args:
        model: Callable network. It is called with the input planes and must
            return a two-item ``(policy, value)`` result; only the policy
            part is used here.
        player_bb: First bitboard handed to ``make_input_planes``
            (presumably the side to move -- confirm against ``bitboard``).
        opponent_bb: Second bitboard handed to ``make_input_planes``.
        legal_moves_bb: Bitboard of candidate moves; a square is considered
            only if its ``get_bit(row, col)`` mask intersects this value.
        device: Torch device the input tensor is moved to before inference.

    Returns:
        The ``get_bit(row, col)`` mask of the best-scoring legal square, or
        ``0`` when no legal square could be selected (treated as a pass).
    """
    input_tensor = make_input_planes(player_bb, opponent_bb).to(device)

    with torch.no_grad():
        policy_logits, _ = model(input_tensor)

    # Only the ranking of the policy outputs matters, so compare them
    # directly. Exponentiating first is monotonic but can underflow to 0 or
    # overflow to inf in float32, which turns distinct scores into ties and
    # makes the first legal square win arbitrarily.
    scores = policy_logits.squeeze(0).cpu().numpy()

    best_mask = 0
    best_score = None

    # Policy indices 0..63 are board squares; per the original notes index 64
    # is the pass move, which is never chosen here directly.
    for idx in range(64):
        # Policy index idx corresponds to square number 63 - idx.
        row, col = divmod(63 - idx, 8)
        mask = get_bit(row, col)
        if not legal_moves_bb & mask:
            continue

        score = scores[idx]
        if score != score:
            # NaN never compares greater than anything; skip it so one bad
            # output cannot be picked or block later candidates.
            continue

        # Strict '>' keeps the lowest policy index on exact ties.
        if best_score is None or score > best_score:
            best_score = score
            best_mask = mask

    return best_mask

if __name__ == "__main__":
    # Smoke test: load the checkpoint and ask for a move from the standard
    # Othello opening position. The literals below are plain 64-bit masks;
    # how a bit index maps to a board square depends on bitboard.get_bit.
    start_black = 0x0000_0008_1000_0000  # bits 28 and 35 set
    start_white = 0x0000_0010_0800_0000  # bits 27 and 36 set
    opening_moves = 0x0000_1020_0408_0000  # bits 19, 26, 37 and 44 set

    print("Dualist Inference Test")
    try:
        net = load_dualist()
        print("Model loaded successfully!")

        best = get_best_move(net, start_black, start_white, opening_moves)
        print(f"Best move found: {hex(best)}")

    except FileNotFoundError:
        # Raised when the default checkpoint path does not exist.
        print("Error: dualist_model.pth not found. Ensure it's in the same directory.")
    except Exception as e:
        # Top-level boundary for the demo: report any other failure.
        print(f"An error occurred: {e}")