File size: 7,187 Bytes
e6adbeb
d75dae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6adbeb
 
 
 
 
d75dae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import tempfile
from typing import Optional

import cv2
import numpy as np
from smolagents import Tool
from ultralytics import YOLO

# Maps detector class names of the form "<color>-<piece>" to FEN piece letters:
# lowercase letters for black pieces, uppercase letters for white pieces.
FEN_MAPPING = {
    f"{color}-{piece}": letter.upper() if color == "white" else letter
    for color in ("black", "white")
    for piece, letter in (
        ("pawn", "p"),
        ("rook", "r"),
        ("knight", "n"),
        ("bishop", "b"),
        ("queen", "q"),
        ("king", "k"),
    )
}

# Geometry of the playing grid inside the resized board image.
GRID_BORDER = 10  # Border size in pixels before the first square starts
GRID_SIZE = 204  # Effective grid size in pixels (spans pixel 10 to pixel 214)
BLOCK_SIZE = GRID_SIZE // 8  # Edge length of one square (integer division -> 25px)

X_LABELS = list("abcdefgh")  # File labels along the x-axis, a to h
Y_LABELS = list(range(8, 0, -1))  # Rank labels in reversed order, 8 down to 1


class ChessBoardRecognizerTool(Tool):
    """
    smolagents tool that reads a chess board image and returns its position as FEN.

    Two YOLO models are used: ``segmentation.pt`` locates and crops the board, and
    ``standard.pt`` detects the individual pieces on the cropped, resized board.
    """

    name = "chess_board_recognizer"
    description = "Recognizes the state of chess board from image and returns the position representation in Forsyth-Edwards notation (FEN)"
    inputs = {
        "image_path": {
            "type": "string",
            "description": "The path of the chess board image file"
        },
        "is_white_turn": {
            "type": "boolean",
            "description": "Whether it is white's turn to move; defaults to white's turn if not provided",
            "nullable": True

        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()

        # Model weights live in <package parent>/data next to this module's directory.
        parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

        self.model_std = YOLO(os.path.join(parent_dir, "data", "standard.pt"))
        self.model_seg = YOLO(os.path.join(parent_dir, "data", "segmentation.pt"))

    def forward(self, image_path: str, is_white_turn: Optional[bool] = None) -> str:
        """
        Recognize the board in ``image_path`` and return a FEN string.

        Args:
            image_path: Path of the chess board image file.
            is_white_turn: Side to move. ``None`` (the default) or a truthy value
                means white to move; a falsy value means black to move.

        Returns:
            A FEN string such as ``"8/8/8/8/8/8/8/8 w - - 0 1"``. Castling and
            en-passant fields are always ``-``, and the clocks are fixed at ``0 1``.
            If no board can be located, a human-readable error message is returned
            instead, so the declared ``output_type`` of string always holds.
        """
        processed_image = self._process_image(image_path)
        if processed_image is None:
            return "Could not locate a chess board in the image; no FEN position was produced."

        # The grid constants assume a 224x224 board image.
        processed_image = cv2.resize(processed_image, (224, 224))
        height = processed_image.shape[0]

        results = self.model_std.predict(source=processed_image, save=False, save_txt=False, conf=0.6)

        # board[0] is rank 8 and board[r][0] is file a, matching FEN row order.
        # None marks an empty square.
        board = [[None] * 8 for _ in range(8)]

        for box in results[0].boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            class_name = self.model_std.names[int(box.cls[0])]

            fen_piece = FEN_MAPPING.get(class_name)
            if not fen_piece:
                continue  # Not a piece class we know how to encode.

            # Use the bounding-box centre; flip y so the origin is bottom-left,
            # which is what _get_grid_coordinate expects.
            pixel_x = int((x1 + x2) / 2)
            pixel_y = int(height - (y1 + y2) / 2)

            square = self._get_grid_coordinate(pixel_x, pixel_y)
            if square == "Pixel outside grid bounds":
                continue

            file_index = ord(square[0]) - ord("a")  # 0 for file a ... 7 for file h
            rank_index = int(square[1]) - 1  # 0 for rank 1 ... 7 for rank 8
            board[7 - rank_index][file_index] = fen_piece  # FEN lists rank 8 first

        # Previously any explicit value (even True) produced "b"; only a falsy
        # value should mean black to move.
        side_to_move = "w" if is_white_turn is None or is_white_turn else "b"

        return f"{self._board_to_fen(board)} {side_to_move} - - 0 1"

    @staticmethod
    def _board_to_fen(board):
        """
        Encode an 8x8 board as the FEN piece-placement field.

        Args:
            board: Eight rows, rank 8 first. Each row has eight cells, file a first.
                A cell holds a FEN piece letter, or None for an empty square.

        Returns:
            Rows joined with "/", where each run of empty squares is written as
            its length (e.g. "rnbqkbnr/pppppppp/8/...").
        """
        fen_rows = []
        for row in board:
            parts = []
            empty_run = 0
            for cell in row:
                if cell is None:
                    empty_run += 1
                    continue
                if empty_run:
                    parts.append(str(empty_run))
                    empty_run = 0
                parts.append(cell)
            if empty_run:
                parts.append(str(empty_run))
            fen_rows.append("".join(parts))
        return "/".join(fen_rows)

    def _get_grid_coordinate(self, pixel_x, pixel_y):
        """
        Map an integer pixel position to the name of the board square containing it.

        The grid starts GRID_BORDER pixels in from the image edge and spans GRID_SIZE
        pixels, split into squares of BLOCK_SIZE pixels. ``pixel_y`` is measured
        upward from the bottom edge. Bottom-left is therefore a1, top-left is a8 and
        top-right is h8.

        Args:
            pixel_x: Horizontal pixel position (integer), measured from the left.
            pixel_y: Vertical pixel position (integer), measured from the bottom.

        Returns:
            A square name such as ``"e4"``, or the sentinel string
            ``"Pixel outside grid bounds"``.
        """
        adjusted_x = pixel_x - GRID_BORDER
        adjusted_y = pixel_y - GRID_BORDER

        if not (0 <= adjusted_x < GRID_SIZE and 0 <= adjusted_y < GRID_SIZE):
            return "Pixel outside grid bounds"

        x_index = adjusted_x // BLOCK_SIZE  # 0 for file a ... 7 for file h
        y_index = adjusted_y // BLOCK_SIZE  # 0 for rank 1 ... 7 for rank 8

        # GRID_SIZE is not a multiple of 8 (8 * 25 = 200 < 204). The last few
        # pixels therefore yield index 8 and must be rejected too.
        if x_index >= len(X_LABELS) or y_index >= len(Y_LABELS):
            return "Pixel outside grid bounds"

        return f"{X_LABELS[x_index]}{y_index + 1}"

    def _process_image(self, image_path):
        """
        Locate the chess board with the segmentation model and crop the image to it.

        As a side effect, the crop is written to a new temporary ``.jpg`` file. That
        file is not deleted, and its path is printed.

        Args:
            image_path: Path of the input image file.

        Returns:
            The cropped image array as read by ``cv2.imread``, or None if no
            confident detection was found, the image could not be read, or the crop
            is empty.
        """
        results = self.model_seg.predict(
            source=image_path,
            conf=0.8  # Confidence threshold
        )

        bbox = None
        for result in results:
            # A result may hold no detections or no masks at all. Indexing
            # boxes.conf[0] unguarded would raise IndexError in that case.
            if len(result.boxes) == 0 or result.masks is None:
                continue
            if result.boxes.conf[0] >= 0.8:  # Filter results by confidence
                bbox = result.boxes.xyxy[0].cpu().numpy()  # Get the bounding box coordinates
                break

        if bbox is None:
            print("No segmentation mask with confidence above 0.8 found.")
            return None

        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None
            print(f"Could not read image file {image_path}.")
            return None

        x1, y1, x2, y2 = (int(v) for v in bbox)
        # Clamp at 0 so a slightly negative coordinate cannot wrap around via
        # Python's negative indexing.
        cropped_segment = image[max(y1, 0):y2, max(x1, 0):x2]
        if cropped_segment.size == 0:
            print("Detected board bounding box is empty. Skip cropping the image")
            return None

        # mkstemp creates the file safely; close our descriptor so cv2 can write it.
        fd, cropped_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        cv2.imwrite(cropped_image_path, cropped_segment)

        print(f"Cropped segmented image saved to {cropped_image_path}")

        # Return the cropped image
        return cropped_segment