| | import numpy as np
|
| |
|
| |
|
| |
|
# Side length of the square board (the helpers below hard-code 8 as well).
BOARD_SIZE = 8
# Mask with all 64 low bits set, i.e. every square of an 8x8 bitboard.
FULL_MASK = (1 << 64) - 1
|
| |
|
def popcount(x):
    """Return the number of 1 bits in the binary form of ``x``.

    Intended for non-negative 64-bit bitboards. For a negative ``x`` this
    counts the 1 bits of ``abs(x)`` (the sign is rendered as ``-``, not as
    two's-complement bits).
    """
    binary_digits = format(x, "b")
    return binary_digits.count("1")
|
| |
|
def bit_to_row_col(bit_mask):
    """Map a single-bit mask to its (row, col) square.

    Uses the same layout as ``get_bit``: bit 63 is square (0, 0) and bit 0
    is square (7, 7). If more than one bit is set, the most significant set
    bit determines the result. An empty mask yields ``(-1, -1)``.
    """
    if bit_mask == 0:
        return -1, -1

    # 64 - bit_length() == 63 - (index of the highest set bit), i.e. the
    # row-major square number counted from the top-left corner.
    square_number = 64 - bit_mask.bit_length()
    return divmod(square_number, 8)
|
| |
|
def get_bit(row, col):
    """Return an int with only the bit for square (row, col) set.

    Squares are numbered row-major from the most significant bit: (0, 0)
    is bit 63 and (7, 7) is bit 0.
    """
    return 1 << (63 - 8 * row - col)
|
| |
|
def make_input_planes(player_bb, opponent_bb):
    """Encode two bitboards as a batched network input tensor.

    Returns a float32 torch tensor of shape (1, 3, 8, 8), i.e. a batch of one
    3x8x8 sample:

    * plane 0 -- 1.0 where ``player_bb`` has a piece, 0.0 elsewhere
    * plane 1 -- 1.0 where ``opponent_bb`` has a piece, 0.0 elsewhere
    * plane 2 -- constant 1.0 on every square. It carries no position
      information; some implementations put a valid-moves plane here instead.

    Squares use the same layout as ``get_bit``: bit 63 is (0, 0) and bit 0 is
    (7, 7). Only the low 64 bits of each bitboard are read.
    """
    # Local import: importing this module does not pull in torch.
    import torch

    planes = np.empty((3, 8, 8), dtype=np.float32)
    planes[0] = _bitboard_to_plane(player_bb)
    planes[1] = _bitboard_to_plane(opponent_bb)
    planes[2] = 1.0
    # from_numpy wraps the buffer instead of copying; `planes` is local, so sharing is safe.
    return torch.from_numpy(planes).unsqueeze(0)


def _bitboard_to_plane(bb):
    """Unpack the low 64 bits of a bitboard into an 8x8 float32 0/1 array."""
    # Masking to 64 bits matches testing each square with `bb & get_bit(r, c)`
    # (also for negative ints) and makes to_bytes() valid; int() accepts
    # NumPy integer scalars too.
    raw = (int(bb) & 0xFFFFFFFFFFFFFFFF).to_bytes(8, "big")
    # Big-endian bytes + unpackbits' MSB-first order puts bit 63 first, then
    # row-major down to bit 0 -- exactly the get_bit square order.
    bits = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
    return bits.reshape(8, 8).astype(np.float32)
|
| |
|
def print_board(black_bb, white_bb):
    """Print the board to stdout: a column header, then one line per row.

    Each square is drawn as ``B`` (set in ``black_bb``), ``W`` (set in
    ``white_bb``) or ``.`` (empty), each followed by a space. If a square
    is set in both bitboards, ``B`` is shown. Rows are labelled 1-8 and
    squares are located with ``get_bit``.
    """
    print(" A B C D E F G H")
    for row_idx in range(8):
        cells = []
        for col_idx in range(8):
            square = get_bit(row_idx, col_idx)
            if black_bb & square:
                cells.append("B ")
            elif white_bb & square:
                cells.append("W ")
            else:
                cells.append(". ")
        print(f"{row_idx+1} " + "".join(cells))
|
| |
|