import sys

import numpy as np
import torch
import torch.nn.functional as F

from bitboard import get_bit, make_input_planes
from model import OthelloNet
|
| |
|
def load_dualist(
    model_path="dualist_model.pth",
    device="cpu",
    num_res_blocks=10,
    num_channels=256,
):
    """
    Load the Dualist Othello model and put it in inference (eval) mode.

    Args:
        model_path: Path to a file written with ``torch.save``. It may hold
            either a bare ``state_dict`` or a dict that stores the weights
            under the ``"model_state_dict"`` key.
        device: Device the checkpoint is mapped onto and the model is moved to.
        num_res_blocks: Number of residual blocks of the ``OthelloNet`` to
            build; must match the architecture the checkpoint was saved from.
        num_channels: Channel width of the ``OthelloNet`` to build; must match
            the checkpoint.

    Returns:
        The ``OthelloNet`` instance, on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the stored weights do not match the constructed
            architecture (raised by ``load_state_dict``).
    """
    model = OthelloNet(num_res_blocks=num_res_blocks, num_channels=num_channels)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; consider weights_only=True
    # (torch >= 1.13) once it is confirmed the checkpoint holds only tensors
    # and plain containers.
    checkpoint = torch.load(model_path, map_location=device)

    # Training checkpoints wrap the weights alongside other state; plain
    # exports are the state_dict itself. The isinstance guard keeps `in` from
    # raising on a non-mapping checkpoint, so load_state_dict reports the error.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
|
| |
|
def get_best_move(model, player_bb, opponent_bb, legal_moves_bb, device="cpu"):
    """
    Return the legal move the policy network rates highest.

    Args:
        model: Network called as ``model(x)`` that returns a
            ``(policy_logits, value)`` pair; only the policy part is used.
        player_bb: Bitboard of the side to move.
        opponent_bb: Bitboard of the opponent.
        legal_moves_bb: Bitboard whose set bits mark the legal moves.
        device: Device the input planes are moved to; should match the
            device ``model`` lives on.

    Returns:
        The single-square mask (as produced by ``get_bit``) of the chosen
        move, or ``0`` when none of the 64 squares is set in
        ``legal_moves_bb`` (the player must pass). Ties go to the lowest
        policy index.
    """
    input_tensor = make_input_planes(player_bb, opponent_bb).to(device)

    with torch.no_grad():
        policy_logits, _value = model(input_tensor)

    # exp() of the policy output. If the head emits log-probabilities these are
    # probabilities; either way exp is monotonic, so the ranking is unchanged.
    # squeeze(0) drops the batch axis -- assumes make_input_planes builds a
    # batch of one (TODO confirm).
    probs = torch.exp(policy_logits).squeeze(0).cpu().numpy()

    best_mask = 0
    best_prob = None
    for i in range(64):
        # Policy index i corresponds to square 63 - i, split row-major into
        # (row, col) = divmod(63 - i, 8).
        mask = get_bit(*divmod(63 - i, 8))
        if not legal_moves_bb & mask:
            continue
        # The first legal square is always accepted, so a legal move is
        # returned even if the policy contains NaN (every comparison with NaN
        # is False). Strict '>' keeps the lowest index on ties.
        if best_prob is None or probs[i] > best_prob:
            best_prob = probs[i]
            best_mask = mask

    return best_mask
|
| |
|
if __name__ == "__main__":
    # Smoke test: load the default checkpoint and ask for a move in the
    # opening position (two discs per side in the centre, Black to move).
    # Failures are reported on stderr with a non-zero exit status so callers
    # and CI can detect them.
    print("Dualist Inference Test")
    try:
        model = load_dualist()
        print("Model loaded successfully!")

        black_bb = 0x0000000810000000
        white_bb = 0x0000001008000000
        # Black's four legal replies in this position.
        legal_moves = 0x0000102004080000

        best = get_best_move(model, black_bb, white_bb, legal_moves)
        print(f"Best move found: {hex(best)}")

    except FileNotFoundError:
        print(
            "Error: dualist_model.pth not found. Ensure it's in the same directory.",
            file=sys.stderr,
        )
        sys.exit(1)
    except Exception as e:
        # Top-level boundary of the script: report any failure and exit non-zero.
        print(f"An error occurred: {e}", file=sys.stderr)
        sys.exit(1)
|
| |
|