"""Inference helpers for the Dualist Othello network.

Loads a trained ``OthelloNet`` checkpoint and picks the legal move that the
policy head scores highest for a bitboard position.
"""

import numpy as np
import torch
import torch.nn.functional as F  # noqa: F401 -- currently unused in this module

from bitboard import get_bit, make_input_planes
from model import OthelloNet

# The policy head outputs 65 entries: indices 0..63 are board squares and
# index 64 is "pass".
_NUM_SQUARES = 64


def _index_to_mask(idx):
    """Return the bitboard mask for policy index ``idx`` (0..63)."""
    # Policy index i addresses square 63 - i in row-major (row, col) order;
    # get_bit converts that coordinate into the bitboard's own convention.
    row, col = divmod(63 - idx, 8)
    return get_bit(row, col)


def load_dualist(
    model_path="dualist_model.pth",
    device="cpu",
    *,
    num_res_blocks=10,
    num_channels=256,
):
    """Load the Dualist Othello model, ready for inference.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. It is either
            a bare state dict or a dict holding one under ``"model_state_dict"``.
        device: Device the weights are mapped onto and the model is moved to.
        num_res_blocks: Passed to ``OthelloNet``; must match the checkpoint.
        num_channels: Passed to ``OthelloNet``; must match the checkpoint.

    Returns:
        The ``OthelloNet`` with weights loaded, on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    model = OthelloNet(num_res_blocks=num_res_blocks, num_channels=num_channels)
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. Consider weights_only=True (torch>=1.13) if the checkpoint
    # holds only tensors and plain containers.
    checkpoint = torch.load(model_path, map_location=device)
    # A checkpoint may wrap the weights under "model_state_dict" (presumably
    # next to other training state); otherwise it is the state dict itself.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def get_best_move(model, player_bb, opponent_bb, legal_moves_bb, device="cpu"):
    """Return the legal move the policy head scores highest, as a bitmask.

    Args:
        model: Network called as ``policy_logits, value = model(planes)``.
        player_bb: Bitboard of the player to move (first argument to
            ``make_input_planes``).
        opponent_bb: Bitboard of the other player.
        legal_moves_bb: Bitboard with a bit set for each legal move.
        device: Device the model lives on.

    Returns:
        The single-square mask of the chosen move. Returns ``0`` (pass) only
        when ``legal_moves_bb`` contains none of the 64 board squares. Ties are
        broken in favour of the lowest policy index.
    """
    # Input planes (described as 3x8x8); batch-dim handling is left to
    # make_input_planes -- TODO confirm it matches the model's expectations.
    input_tensor = make_input_planes(player_bb, opponent_bb).to(device)

    with torch.no_grad():
        policy_logits, _value = model(input_tensor)

    # Rank by the raw policy scores instead of exp() of them. exp is monotonic,
    # so the ranking is the same, but exp can saturate (overflow to inf or
    # underflow to 0) and turn distinct scores into ties. .float() also lets
    # bfloat16 outputs convert to NumPy.
    scores = policy_logits.squeeze(0).float().cpu().numpy()[:_NUM_SQUARES]
    # NaN never compares greater than anything; map it to -inf so a legal
    # square is still returned rather than an illegal pass.
    scores = np.where(np.isnan(scores), -np.inf, scores)

    legal = [i for i in range(_NUM_SQUARES) if legal_moves_bb & _index_to_mask(i)]
    if not legal:
        # No playable square: pass. The pass entry (index 64) is not consulted.
        return 0

    # max() returns the first maximal item, i.e. the lowest index among ties.
    best_idx = max(legal, key=lambda i: scores[i])
    return _index_to_mask(best_idx)


def main():
    """Smoke test: load the model and query it on the opening position."""
    print("Dualist Inference Test")
    try:
        model = load_dualist()
        print("Model loaded successfully!")

        # Standard Othello starting position. The set bits are:
        #   Black: 28 and 35
        #   White: 27 and 36
        #   Black's legal moves: 19, 26, 37 and 44
        # Each set is unchanged by the mirror b -> 63 - b, so the example is
        # valid whether the bitboard counts squares from the LSB or MSB end.
        black_bb = 0x0000000810000000
        white_bb = 0x0000001008000000
        legal_moves = 0x0000102004080000

        best = get_best_move(model, black_bb, white_bb, legal_moves)
        print(f"Best move found: {hex(best)}")
    except FileNotFoundError:
        print("Error: dualist_model.pth not found. Ensure it's in the same directory.")
    except Exception as e:  # top-level boundary of a demo script
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()