# Agents-Course-Assignment / my_tool_chess_board.py
# krzsam's picture
# commit 86cbfce
from smolagents import Tool
from PIL import Image
import os
import cv2
import numpy as np
import math
import numpy
from my_train_chess_pieces_recognition import ChessPiecesRecognition
import traceback
# Based on: https://github.com/kratos606/chessboard-recogniser/tree/main
class ChessBoard(Tool):
    """Tool that slices a chess-board image into squares and classifies each one.

    Processing steps performed by ``forward``:
      - PIL image -> OpenCV BGR array
      - board image -> 64 grayscale tiles resized to 32x32
        (``__extract_pieces_from_image_board``)
      - tiles -> piece labels via ``ChessPiecesRecognition.classify_pieces``

    TODO: constructing a FEN string from the piece labels is not implemented.
    """

    name = "_my_tool_chess_board"
    # The text is kept flush-left so the runtime string carries no leading
    # whitespace on its lines.
    description = """
Process an image of a chess board and return board position as a list of chess pieces
To invoke the tool use code as below
<code>
chess_pieces = _my_tool_chess_board(img=loaded_image)
</code>
"""
    inputs = {
        "img": {
            "type": "image",
            "description": "image of chess board to extract board position",
        }
    }
    output_type = "string"
    is_initialized = False

    def __init__(self, _chess_board_model_name, _chess_board_model_dir):
        """Create the tool and its piece classifier.

        Args:
            _chess_board_model_name: passed through to ``ChessPiecesRecognition``.
            _chess_board_model_dir: passed through to ``ChessPiecesRecognition``.
        """
        # Initialise the Tool base class before setting our own state.
        super().__init__()
        print("***KS*** ChessBoard initializing ...")
        self.recognition = ChessPiecesRecognition(_chess_board_model_name, _chess_board_model_dir)
        self.is_initialized = True

    def __gradientx(self, img):
        """Return the x-direction Sobel derivative of ``img`` (float32, 31x31 kernel)."""
        return cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=31)

    def __gradienty(self, img):
        """Return the y-direction Sobel derivative of ``img`` (float32, 31x31 kernel)."""
        return cv2.Sobel(img, cv2.CV_32F, 0, 1, ksize=31)

    def __checkMatch(self, lineset):
        """Check that the spacings between lines end in a run of five similar values.

        Walks the successive differences of ``lineset``; a spacing within 5 of
        the current reference spacing extends the run, anything else resets the
        run and becomes the new reference.

        Returns:
            True when the run length is exactly 5 after the last spacing.
        """
        reference = 0
        streak = 0
        for spacing in np.diff(lineset):
            if abs(spacing - reference) < 5:
                streak += 1
            else:
                streak = 0
                reference = spacing
        return streak == 5

    def __pruneLines(self, lineset, image_dim, margin=20):
        """Drop border lines and keep the first run of evenly spaced lines.

        Args:
            lineset: candidate line positions along one axis.
            image_dim: image size along that axis.
            margin: lines within this distance of either edge are discarded.

        Returns:
            The slice of lines covering the first run of five consecutive
            similar spacings (tolerance 5), or all remaining lines when no
            such run exists.
        """
        # Remove lines near the margins
        lineset = [x for x in lineset if margin < x < image_dim - margin]
        if not lineset:
            return lineset

        reference = 0
        streak = 0
        start_pos = 0
        for i, spacing in enumerate(np.diff(lineset)):
            if abs(spacing - reference) < 5:
                streak += 1
                if streak == 5:
                    # Diff i lies between lines i and i+1, so the slice must
                    # end at i+2 to include line i+1.
                    return lineset[start_pos:i + 2]
            else:
                streak = 0
                reference = spacing
                start_pos = i
        return lineset

    def __skeletonize_1d(self, arr):
        """Return a copy of ``arr`` with everything except local peaks zeroed.

        The forward pass zeroes any element not greater than its right
        neighbour; the backward pass zeroes any element smaller than its
        left neighbour.
        """
        _arr = arr.copy()
        for i in range(len(_arr) - 1):
            if _arr[i] <= _arr[i + 1]:
                _arr[i] = 0
        for i in range(len(_arr) - 1, 0, -1):
            if _arr[i - 1] > _arr[i]:
                _arr[i] = 0
        return _arr

    def __getChessLines(self, hdx, hdy, hdx_thresh, hdy_thresh, image_shape):
        """Locate candidate grid lines from the two 1-D gradient signals.

        Args:
            hdx: 1-D signal indexed by image column.
            hdy: 1-D signal indexed by image row.
            hdx_thresh: threshold applied to ``hdx``.
            hdy_thresh: threshold applied to ``hdy``.
            image_shape: shape of the image; index 1 bounds the x lines and
                index 0 bounds the y lines.

        Returns:
            ``(lines_x, lines_y, is_match)`` where ``is_match`` is True only
            when both axes have exactly 7 lines that pass ``__checkMatch``.
        """
        # Generate Gaussian window
        window_size = 21
        sigma = 8.0
        gausswin = cv2.getGaussianKernel(window_size, sigma, cv2.CV_64F).flatten()

        # Threshold signals
        hdx_thresh_binary = np.where(hdx > hdx_thresh, 1.0, 0.0)
        hdy_thresh_binary = np.where(hdy > hdy_thresh, 1.0, 0.0)

        # Blur signals using convolution with Gaussian window
        blur_x = np.convolve(hdx_thresh_binary, gausswin, mode='same')
        blur_y = np.convolve(hdy_thresh_binary, gausswin, mode='same')

        # Skeletonize signals
        skel_x = self.__skeletonize_1d(blur_x)
        skel_y = self.__skeletonize_1d(blur_y)

        # Find line positions
        lines_x = np.where(skel_x > 0)[0].tolist()
        lines_y = np.where(skel_y > 0)[0].tolist()

        # Prune lines
        lines_x = self.__pruneLines(lines_x, image_shape[1])
        lines_y = self.__pruneLines(lines_y, image_shape[0])

        # Check if lines match expected pattern
        is_match = (
            len(lines_x) == 7
            and len(lines_y) == 7
            and self.__checkMatch(lines_x)
            and self.__checkMatch(lines_y)
        )
        return lines_x, lines_y, is_match

    def __getChessTiles(self, img, lines_x, lines_y):
        """Cut ``img`` into an 8x8 grid of equally sized tiles.

        The given lines are extended by one mean step on each side to form
        the outer edges; the image is padded with replicated border pixels
        when those edges fall outside it.

        Returns:
            List of 64 tiles, ordered row by row.
        """
        stepx = int(round(np.mean(np.diff(lines_x))))
        stepy = int(round(np.mean(np.diff(lines_y))))

        # Pad the image if the extrapolated outer edges fall outside it
        padl_x = max(0, stepx - lines_x[0])
        padr_x = max(0, lines_x[-1] + stepx - img.shape[1] + 1)
        padl_y = max(0, stepy - lines_y[0])
        padr_y = max(0, lines_y[-1] + stepy - img.shape[0] + 1)
        img_padded = cv2.copyMakeBorder(img, padl_y, padr_y, padl_x, padr_x, cv2.BORDER_REPLICATE)

        setsx = [lines_x[0] - stepx + padl_x] + [x + padl_x for x in lines_x] + [lines_x[-1] + stepx + padl_x]
        setsy = [lines_y[0] - stepy + padl_y] + [y + padl_y for y in lines_y] + [lines_y[-1] + stepy + padl_y]

        squares = []
        for j in range(8):
            for i in range(8):
                x1 = setsx[i]
                y1 = setsy[j]
                # Every tile is exactly one step wide and high, regardless of
                # where the next detected line sits.
                squares.append(img_padded[y1:y1 + stepy, x1:x1 + stepx])
        return squares

    # Image(CV2) --> Image(CV2)[]
    def __extract_pieces_from_image_board(self, image):
        """Detect the board grid in a BGR image and return its 64 tiles.

        Returns:
            List of grayscale tiles resized to 32x32, or an empty list when
            no image is given or no matching grid is found.
        """
        if image is None:
            print("Image not provided")
            # Return the same type as the no-match path below.
            return []

        # Convert to grayscale
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Preprocessing
        equ = cv2.equalizeHist(gray)
        norm_image = equ.astype(np.float32) / 255.0

        # Compute the gradients
        grad_x = self.__gradientx(norm_image)
        grad_y = self.__gradienty(norm_image)

        # Split each gradient into its positive and negative parts
        Dx_pos = np.clip(grad_x, 0, None)
        Dx_neg = np.clip(-grad_x, 0, None)
        Dy_pos = np.clip(grad_y, 0, None)
        Dy_neg = np.clip(-grad_y, 0, None)

        # Per column / per row: product of the summed positive and negative
        # gradients, normalised by the squared length of the summed axis.
        hough_Dx = (np.sum(Dx_pos, axis=0) * np.sum(Dx_neg, axis=0)) / (norm_image.shape[0] ** 2)
        hough_Dy = (np.sum(Dy_pos, axis=1) * np.sum(Dy_neg, axis=1)) / (norm_image.shape[1] ** 2)

        # Adaptive thresholding: try 1/5 .. 4/5 of the peak until a grid matches
        is_match = False
        lines_x = []
        lines_y = []
        for a in range(1, 5):
            threshold_x = np.max(hough_Dx) * (a / 5.0)
            threshold_y = np.max(hough_Dy) * (a / 5.0)
            lines_x, lines_y, is_match = self.__getChessLines(
                hough_Dx, hough_Dy, threshold_x, threshold_y, norm_image.shape
            )
            if is_match:
                break

        if not is_match:
            return []

        squares = self.__getChessTiles(gray, lines_x, lines_y)
        return [cv2.resize(square, (32, 32), interpolation=cv2.INTER_AREA) for square in squares]

    def __detect_chess_pieces(self, images):
        """Classify the tile images with the configured recognition model."""
        return self.recognition.classify_pieces(images)

    def forward(self, img: Image.Image) -> str:
        """Run the full pipeline on a PIL image.

        Returns:
            The result of ``ChessPiecesRecognition.classify_pieces`` for the
            extracted tiles, or an empty string when no board grid is found
            or any step raises.
        """
        pieces_list = ""
        try:
            print(f"***KS*** Analyzing chess board image for image: {img}")
            cv2_image = cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
            # Image(CV2) -> Image(CV2)(32x32) []
            squares_resized = self.__extract_pieces_from_image_board(cv2_image)
            if not squares_resized:
                # Nothing to classify; do not hand an empty batch to the model.
                print("***KS*** No chess board grid detected in image")
                return pieces_list
            # Image(CV2)(32x32) [] -> str[]
            pieces_list = self.__detect_chess_pieces(squares_resized)
            print(f"***KS*** Pieces list: {pieces_list}")
        except Exception as ex:
            # Tool boundary: report the failure and fall back to the empty result.
            print(traceback.format_exc())
            print(f"***KS*** Exception invoking ChessBoard: {ex}")
        return pieces_list