import numpy as np

# Bitboard Constants
# Squares are stored row-major with the most significant bit first:
#   (row, col) <-> bit index 63 - (row * 8 + col)
# so (0, 0) is bit 63 and (7, 7) is bit 0. print_board labels (0, 0) as "A1".
# NOTE(review): this layout was intended to follow Edax; not verified here.
BOARD_SIZE = 8
FULL_MASK = 0xFFFFFFFFFFFFFFFF


def popcount(x):
    """Count the set bits of *x* viewed as a 64-bit bitboard.

    Only the low 64 bits are counted. A negative value (e.g. a bitboard
    stored as a signed int64) is therefore counted in two's complement,
    not by the bits of its absolute value.
    """
    return bin(int(x) & FULL_MASK).count("1")


def bit_to_row_col(bit_mask):
    """Convert a single-bit mask to its (row, col) square.

    Returns (-1, -1) when no on-board bit is set. If several bits are set,
    the square of the most significant one (the first square in row-major
    order) is returned. Bits above bit 63 are ignored.
    """
    bit_mask = int(bit_mask) & FULL_MASK
    if bit_mask == 0:
        return -1, -1
    # Highest set bit index 63..0 -> row-major square index 0..63.
    square = 63 - (bit_mask.bit_length() - 1)
    return divmod(square, BOARD_SIZE)


def get_bit(row, col):
    """Return a bitmask with only the bit for square (row, col) set.

    Raises:
        ValueError: if row or col is outside 0..7. Without this check an
            off-board column silently wraps onto the next row (for example
            (0, 8) would alias (1, 0)).
    """
    if not (0 <= row < BOARD_SIZE and 0 <= col < BOARD_SIZE):
        raise ValueError(
            f"square ({row}, {col}) is off the {BOARD_SIZE}x{BOARD_SIZE} board"
        )
    return 1 << (63 - (row * BOARD_SIZE + col))


def _bitboard_to_plane(bb):
    """Unpack a 64-bit bitboard into an (8, 8) uint8 array of 0/1 values."""
    # Big-endian bytes + MSB-first unpacking put bit 63 at flat index 0,
    # i.e. exactly the (row, col) mapping used by get_bit. Masking keeps
    # negative (two's complement) inputs working and drops bits above 63.
    raw = (int(bb) & FULL_MASK).to_bytes(8, "big")
    bits = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
    return bits.reshape(BOARD_SIZE, BOARD_SIZE)


def make_input_planes(player_bb, opponent_bb):
    """Convert two bitboards into the neural-network input tensor.

    Plane 0: 1.0 where the player has a piece, else 0.0.
    Plane 1: 1.0 where the opponent has a piece, else 0.0.
    Plane 2: constant 1.0. A "valid moves" plane is a possible alternative,
             but the network input here uses the constant plane.

    Returns:
        A float32 torch tensor of shape (1, 3, 8, 8) (batch dimension first).
    """
    # torch is only needed here, so it is imported lazily; the other helpers
    # in this module depend on numpy alone.
    import torch

    planes = np.empty((3, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
    planes[0] = _bitboard_to_plane(player_bb)
    planes[1] = _bitboard_to_plane(opponent_bb)
    planes[2] = 1.0
    # planes is local and not reused, so sharing its memory avoids a copy.
    return torch.from_numpy(planes).unsqueeze(0)


def print_board(black_bb, white_bb):
    """Print the board with 'B' for black, 'W' for white and '.' for empty.

    Columns are labelled A-H and rows 1-8, row 0 printed first. A square
    set in both bitboards is shown as 'B'.
    """
    # Two leading spaces so each letter sits above its column: every row
    # line starts with a two-character "N " prefix.
    print("  A B C D E F G H")
    for r in range(BOARD_SIZE):
        cells = []
        for c in range(BOARD_SIZE):
            mask = get_bit(r, c)
            if black_bb & mask:
                cells.append("B ")
            elif white_bb & mask:
                cells.append("W ")
            else:
                cells.append(". ")
        print(f"{r + 1} " + "".join(cells))