Spaces:
Sleeping
Sleeping
File size: 7,187 Bytes
e6adbeb d75dae7 e6adbeb d75dae7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
import tempfile
from typing import Optional
import cv2
import numpy as np
from smolagents import Tool
from ultralytics import YOLO
# Maps detector class names ("<colour>-<piece>") to FEN piece letters:
# lowercase letters for black pieces, uppercase letters for white pieces.
FEN_MAPPING = {
    f"{colour}-{piece}": letter if colour == "black" else letter.upper()
    for colour in ("black", "white")
    for piece, letter in (
        ("pawn", "p"),
        ("rook", "r"),
        ("knight", "n"),
        ("bishop", "b"),
        ("queen", "q"),
        ("king", "k"),
    )
}
# Board geometry, in pixels.
GRID_BORDER = 10  # width of the border around the playing grid
GRID_SIZE = 204  # side length of the playing grid (spans 10px to 214px)
BLOCK_SIZE = GRID_SIZE // 8  # integer side length of one square (~25px)

# Axis labels for the squares.
X_LABELS = list("abcdefgh")  # files, a to h
Y_LABELS = list(range(8, 0, -1))  # ranks in reversed order, 8 down to 1
class ChessBoardRecognizerTool(Tool):
    """Tool that reads a chess position from an image and returns it as FEN.

    The image is first passed to a segmentation model to locate and crop the
    board; the crop is resized to 224x224 and a detection model then finds the
    pieces. Each detected piece is assigned to the square containing the
    centre of its bounding box.
    """

    name = "chess_board_recognizer"
    description = "Recognizes the state of chess board from image and returns the position representation in Forsyth-Edwards notation (FEN)"
    inputs = {
        "image_path": {
            "type": "string",
            "description": "The path of the chess board image file"
        },
        "is_white_turn": {
            "type": "boolean",
            "description": "Optionally white's turn on the chess board if value not provided",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Load the piece-detection and board-segmentation models.

        Model weights are read from the ``data`` directory that sits next to
        this file's parent directory.
        """
        super().__init__()
        parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
        self.model_std = YOLO(f"{parent_dir}/data/standard.pt")
        self.model_seg = YOLO(f"{parent_dir}/data/segmentation.pt")

    def forward(self, image_path: str, is_white_turn: Optional[bool] = None) -> str:
        """Return the FEN string for the board shown in ``image_path``.

        Args:
            image_path: Path of the chess board image file.
            is_white_turn: ``True`` or ``None`` (not provided) means white is
                to move; ``False`` means black is to move.

        Returns:
            ``"<placement> <w|b> - - 0 1"``. Castling rights and the en
            passant square are always reported as ``-`` because they are not
            derived from the image. If no board can be located, a
            human-readable message is returned instead of a FEN string.
        """
        processed_image = self._process_image(image_path)
        if processed_image is None:
            # Without a cropped board there is nothing to run detection on.
            return "Could not detect a chess board in the image, so no FEN could be generated."

        processed_image = cv2.resize(processed_image, (224, 224))
        height = processed_image.shape[0]
        results = self.model_std.predict(source=processed_image, save=False, save_txt=False, conf=0.6)

        # board[0] is rank 8 and board[7] is rank 1 (FEN order); None = empty.
        board = [[None] * 8 for _ in range(8)]
        # Confidence of the detection currently occupying each square, so a
        # weaker overlapping detection cannot overwrite a stronger one.
        best_confidence = {}

        for box in results[0].boxes:
            class_name = self.model_std.names[int(box.cls[0])]
            fen_piece = FEN_MAPPING.get(class_name)
            if not fen_piece:
                continue

            # Locate the piece by the centre of its bounding box.
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            pixel_x = int((x1 + x2) / 2)
            # Image rows grow downwards; flip so that y grows from rank 1 up.
            pixel_y = int(height - (y1 + y2) / 2)

            grid_position = self._get_grid_coordinate(pixel_x, pixel_y)
            if grid_position == "Pixel outside grid bounds":
                continue

            file_index = ord(grid_position[0]) - ord("a")  # column, 0-7
            rank_index = int(grid_position[1]) - 1  # row, 0-7 (0 = rank 1)

            confidence = float(box.conf[0])
            square = (rank_index, file_index)
            if confidence < best_confidence.get(square, -1.0):
                continue
            best_confidence[square] = confidence
            board[7 - rank_index][file_index] = fen_piece

        fen_str = "/".join(self._encode_rank(row) for row in board)
        # Only an explicit False selects black; None (not provided) and True
        # both mean white to move.
        b_or_w_turn = "b" if is_white_turn is False else "w"
        return f"{fen_str} {b_or_w_turn} - - 0 1"

    @staticmethod
    def _encode_rank(row):
        """Encode one board row as FEN, collapsing runs of empty squares."""
        parts = []
        empty_run = 0
        for cell in row:
            if cell is None:
                empty_run += 1
                continue
            if empty_run:
                parts.append(str(empty_run))
                empty_run = 0
            parts.append(cell)
        if empty_run:
            parts.append(str(empty_run))
        return "".join(parts)

    def _get_grid_coordinate(self, pixel_x, pixel_y):
        """Map a pixel to its square name, e.g. ``"e4"``.

        The grid is ``GRID_SIZE`` pixels wide and starts ``GRID_BORDER``
        pixels in from the image edge. ``pixel_y`` is measured from the
        bottom, so the bottom-left square is a1 and the top-right is h8.

        Returns:
            The square name, or the string ``"Pixel outside grid bounds"``
            when the pixel lies in the border or beyond the grid.
        """
        adjusted_x = int(pixel_x) - GRID_BORDER
        adjusted_y = int(pixel_y) - GRID_BORDER
        if not (0 <= adjusted_x < GRID_SIZE and 0 <= adjusted_y < GRID_SIZE):
            return "Pixel outside grid bounds"

        # Scale proportionally instead of dividing by a truncated square size
        # (204 // 8 == 25): that would push the last 4 pixels of the grid to
        # index 8 and wrongly reject them.
        squares = len(X_LABELS)
        file_index = adjusted_x * squares // GRID_SIZE
        rank_index = adjusted_y * squares // GRID_SIZE
        return f"{X_LABELS[file_index]}{rank_index + 1}"

    def _process_image(self, image_path):
        """Locate the board in the image and return it cropped.

        Runs the segmentation model, takes the first result whose top
        detection has confidence >= 0.8 and a mask, and crops the image to
        that detection's bounding box. The crop is also written to a
        temporary ``.jpg`` file whose path is printed.

        Returns:
            The cropped image array, or ``None`` when no board is found, the
            image cannot be read, or the crop is empty.
        """
        results = self.model_seg.predict(
            source=image_path,
            conf=0.8  # Confidence threshold
        )

        bbox = None
        for result in results:
            boxes = result.boxes
            # A result may carry no detections or no masks at all.
            if boxes is None or len(boxes) == 0 or result.masks is None:
                continue
            if boxes.conf[0] >= 0.8:
                bbox = boxes.xyxy[0].cpu().numpy()
                break

        if bbox is None:
            print("No segmentation mask with confidence above 0.8 found.")
            return None

        image = cv2.imread(image_path)
        if image is None:
            print(f"Could not read image from {image_path}")
            return None

        x1, y1, x2, y2 = (max(0, int(value)) for value in bbox)
        cropped_segment = image[y1:y2, x1:x2]
        if cropped_segment.size == 0:
            print("Bounding box produced an empty crop. Skip cropping the image")
            return None

        # delete=False keeps the file after the handle is closed so the path
        # reserved here stays valid for cv2.imwrite and for later inspection.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            cropped_image_path = tmp_file.name
        cv2.imwrite(cropped_image_path, cropped_segment)
        print(f"Cropped segmented image saved to {cropped_image_path}")

        return cropped_segment
|