dualist / inference.py
brandonlanexyz's picture
Initial upload of Dualist Othello AI (Iteration 652)
cf2aacd verified
import torch
import torch.nn.functional as F
from model import OthelloNet
from bitboard import get_bit, make_input_planes
import numpy as np
def load_dualist(model_path="dualist_model.pth", device="cpu"):
    """
    Build an OthelloNet, restore its weights from ``model_path`` and return it
    in evaluation mode on ``device``.

    The loaded checkpoint may either wrap the weights under the
    "model_state_dict" key or be the state dict itself.

    NOTE: ``torch.load`` deserializes the file; only load checkpoints that
    come from a trusted source.
    """
    net = OthelloNet(num_res_blocks=10, num_channels=256)
    loaded = torch.load(model_path, map_location=device)

    # Unwrap a wrapped checkpoint; otherwise treat the object as the weights.
    weights = loaded["model_state_dict"] if "model_state_dict" in loaded else loaded
    net.load_state_dict(weights)

    net.to(device)
    net.eval()
    return net
def get_best_move(model, player_bb, opponent_bb, legal_moves_bb, device="cpu"):
    """
    Pick the legal move the policy output scores highest and return its bitmask.

    Args:
        model: Callable invoked as ``model(planes)``; must return a pair whose
            first element is the policy output. The second element is ignored.
        player_bb: Bitboard forwarded to ``make_input_planes`` as the first
            argument (the side to move).
        opponent_bb: Bitboard forwarded to ``make_input_planes`` as the second
            argument.
        legal_moves_bb: Integer bitboard with one bit set per legal move.
        device: Torch device the input planes are moved to before inference.

    Returns:
        The mask produced by ``get_bit`` for the best-scoring legal square,
        or ``0`` (pass / no move) when no legal square could be selected --
        i.e. ``legal_moves_bb`` has no matching bit, or every legal square
        has a NaN score.
    """
    # 1. Prepare the network input.
    input_tensor = make_input_planes(player_bb, opponent_bb).to(device)

    # 2. Forward pass; only the policy output is needed here.
    with torch.no_grad():
        policy_logits, _ = model(input_tensor)

    # 3. Select the best legal square.
    # The policy output is indexed 0..63 for squares (index 64 is the pass
    # move, which is only implied by returning 0 below).
    # Scores are compared directly instead of through exp(): exp is monotonic,
    # so the ranking is the same, but exp can underflow to 0.0 / overflow to
    # inf in float32 and collapse distinct scores into ties.
    scores = policy_logits.squeeze(0).cpu().numpy()

    best_mask = 0
    best_score = None
    for idx in range(64):
        # Policy index -> (row, col); index 0 maps to (7, 7), index 63 to (0, 0).
        row, col = divmod(63 - idx, 8)
        mask = get_bit(row, col)
        if not (legal_moves_bb & mask):
            continue

        score = scores[idx]
        if np.isnan(score):
            # A NaN score can never compare as "better"; skip it.
            continue

        # Strict '>' keeps the lowest-index square on exact ties.
        if best_score is None or score > best_score:
            best_score = score
            best_mask = mask

    return best_mask
if __name__ == "__main__":
    # Smoke test: load the default checkpoint and ask for a move from the
    # standard Othello starting position.
    print("Dualist Inference Test")
    try:
        model = load_dualist()
        print("Model loaded successfully!")

        # Starting position as raw bitboards:
        #   Black occupies bits 28 and 35, White occupies bits 27 and 36.
        black_bb = 0x0000000810000000
        white_bb = 0x0000001008000000
        # Black's four opening moves: bits 19, 26, 37 and 44.
        legal_moves_bb = 0x0000102004080000

        chosen_move = get_best_move(model, black_bb, white_bb, legal_moves_bb)
        print(f"Best move found: {hex(chosen_move)}")
    except FileNotFoundError:
        print("Error: dualist_model.pth not found. Ensure it's in the same directory.")
    except Exception as e:
        # Top-level boundary of the demo script: report and exit quietly.
        print(f"An error occurred: {e}")