Spaces:
Sleeping
Sleeping
| from inference_sdk import InferenceHTTPClient | |
| from ultralytics import YOLO | |
| import cv2 | |
| from stockfish import Stockfish | |
| import os | |
| import numpy as np | |
| import streamlit as st | |
# Roboflow hosted-inference client.
# The API key is read from the environment instead of being committed to source:
# a key embedded here is readable by anyone with access to this file, so the key
# that used to be hard-coded should be revoked and replaced.
# NOTE(review): CLIENT is not referenced anywhere in the visible code; remove it
# if nothing else in the project uses it.
CLIENT = InferenceHTTPClient(
    api_url="https://outline.roboflow.com",
    api_key=os.environ.get("ROBOFLOW_API_KEY"),
)
# Constants

# Maps the piece detector's class names to FEN piece letters
# (uppercase = white, lowercase = black).
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n",
    "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N",
    "white-bishop": "B", "white-queen": "Q", "white-king": "K",
}
GRID_BORDER = 0  # Border around the board inside the cropped image, in pixels (none)
GRID_SIZE = 224  # Side length of the square board image, in pixels (0 to 224)
BLOCK_SIZE = GRID_SIZE // 8  # Side length of one square: 224 // 8 = 28px
X_LABELS = list("abcdefgh")  # File labels for columns 0..7
Y_LABELS = list(range(8, 0, -1))  # Rank labels for rows 0..7: 8 down to 1
| # Functions | |
def get_grid_coordinate(pixel_x, pixel_y, perspective):
    """
    Return the board square (e.g. "e4") that contains a pixel of the board image.

    The board occupies GRID_SIZE x GRID_SIZE pixels starting GRID_BORDER pixels
    from the origin and is split into 8 x 8 squares of BLOCK_SIZE pixels.
    Pixel column 0 maps to file "a" and pixel row 0 (pixel_y in
    [0, BLOCK_SIZE)) maps to rank 8. In image coordinates (y growing downward)
    this means the top-left square is a8 and the bottom-left is a1 from White's
    point of view. With perspective "b" both axes are mirrored, so the top-left
    square becomes h1.

    Args:
        pixel_x: Horizontal pixel coordinate (int or float).
        pixel_y: Vertical pixel coordinate (int or float).
        perspective: "b" for Black's viewpoint; any other value means White.

    Returns:
        The square as file letter + rank digit, or the sentinel string
        "Pixel outside grid bounds" when the pixel is not on the board.
    """
    adjusted_x = pixel_x - GRID_BORDER
    adjusted_y = pixel_y - GRID_BORDER
    if not (0 <= adjusted_x < GRID_SIZE and 0 <= adjusted_y < GRID_SIZE):
        return "Pixel outside grid bounds"

    # int() so float coordinates still produce valid list indices.
    x_index = int(adjusted_x // BLOCK_SIZE)
    y_index = int(adjusted_y // BLOCK_SIZE)
    # Catches the leftover sliver when GRID_SIZE is not a multiple of 8.
    if not (0 <= x_index < len(X_LABELS) and 0 <= y_index < len(Y_LABELS)):
        return "Pixel outside grid bounds"

    if perspective == "b":
        # Black sees the board rotated 180 degrees: mirror both axes.
        x_index = len(X_LABELS) - 1 - x_index
        y_index = len(Y_LABELS) - 1 - y_index

    return f"{X_LABELS[x_index]}{Y_LABELS[y_index]}"
def predict_next_move(fen, stockfish):
    """
    Ask a Stockfish engine for its best move in a position.

    Args:
        fen: Position in FEN notation.
        stockfish: Engine wrapper exposing is_fen_valid, set_fen_position and
            get_best_move.

    Returns:
        A human-readable message: the move, an invalid-FEN notice, or a notice
        that the engine returned no move.
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"

    stockfish.set_fen_position(fen)
    move = stockfish.get_best_move()
    if not move:
        # No move returned (the original message attributes this to checkmate/stalemate).
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {move}"
def _load_models():
    """Return ``(piece_model, seg_model)``, loading the YOLO weights once per session."""
    # Streamlit re-executes this script on every widget interaction (e.g. the
    # radio buttons in main); keeping the models in session state avoids
    # reloading both weight files on each rerun. Per-session storage (rather than
    # a process-wide cache) avoids sharing one model instance between sessions.
    if "chess_yolo_models" not in st.session_state:
        st.session_state["chess_yolo_models"] = (
            YOLO("chessDetection3d.pt"),  # Replace with your trained model weights file
            YOLO("segmentation.pt"),
        )
    return st.session_state["chess_yolo_models"]


def _decode_image(image_file):
    """Decode an uploaded/captured image into a BGR array, or return None if it can't be read."""
    # Decoding in memory (instead of writing a fixed temp file path) keeps
    # concurrent sessions from overwriting each other's uploads.
    data = np.frombuffer(image_file.getvalue(), dtype=np.uint8)
    if data.size == 0:
        return None
    return cv2.imdecode(data, cv2.IMREAD_COLOR)  # None if the bytes aren't an image


def _crop_board(img, seg_model):
    """Locate the single chessboard in ``img`` and return ``(board_image, error_message)``.

    Exactly one element of the pair is None. ``board_image`` is the board region
    resized to GRID_SIZE x GRID_SIZE pixels, still in BGR channel order.
    """
    results = seg_model.predict(source=img)
    # predict() yields one Results object per input image, so the number of
    # detected boards is the number of boxes in that result, not len(results).
    boxes = results[0].boxes if results else None
    board_count = 0 if boxes is None else len(boxes)
    if board_count == 0:
        return None, "NO BOARD IN THE IMAGE"
    if board_count > 1:
        return None, "Multiple boards are there in the image, please take only at a time"

    height, width = img.shape[:2]
    x_min, y_min, x_max, y_max = (int(v) for v in boxes.xyxy[0].tolist())
    # Clamp to the frame so a box poking outside the image can't produce
    # negative (wrap-around) slice indices or an empty crop.
    x_min, y_min = max(x_min, 0), max(y_min, 0)
    x_max, y_max = min(x_max, width), min(y_max, height)
    if x_max <= x_min or y_max <= y_min:
        return None, "NO BOARD IN THE IMAGE"

    board_img = cv2.resize(img[y_min:y_max, x_min:x_max], (GRID_SIZE, GRID_SIZE))
    return board_img, None


def _detect_pieces(board_img, piece_model, perspective):
    """Detect pieces on the board crop and return an 8x8 grid of FEN piece letters.

    Row 0 of the grid is rank 8 and column 0 is file a (FEN order); empty squares
    are None. Detection runs with conf=0.7; classes missing from FEN_MAPPING are
    skipped, and a later detection on an occupied square overwrites the earlier one.
    """
    board = [[None] * 8 for _ in range(8)]
    results = piece_model.predict(source=board_img, save=False, save_txt=False, conf=0.7)
    for box in results[0].boxes:
        fen_piece = FEN_MAPPING.get(piece_model.names[int(box.cls[0])])
        if not fen_piece:
            continue
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        # Box centres are in image coordinates (y grows downward), matching
        # get_grid_coordinate, where pixel row 0 is rank 8 from White's view.
        square = get_grid_coordinate(int((x1 + x2) / 2), int((y1 + y2) / 2), perspective)
        if square == "Pixel outside grid bounds":
            continue
        file_index = ord(square[0]) - ord("a")
        rank = int(square[1])
        board[8 - rank][file_index] = fen_piece
    return board


def _board_to_fen(board):
    """Encode an 8x8 grid (row 0 = rank 8, None = empty) as the FEN piece-placement field."""
    fen_rows = []
    for row in board:
        parts = []
        empty_run = 0
        for cell in row:
            if cell is None:
                empty_run += 1
                continue
            if empty_run:
                parts.append(str(empty_run))
                empty_run = 0
            parts.append(cell)
        if empty_run:
            parts.append(str(empty_run))
        fen_rows.append("".join(parts))
    return "/".join(fen_rows)


def main():
    """Streamlit app: detect a chessboard in a photo, build its FEN and ask Stockfish for a move.

    Flow: capture/upload an image -> crop the board with the segmentation model ->
    detect pieces -> build a FEN string from the chosen perspective and side to
    move -> display Stockfish's recommended move.
    """
    st.title("Chessboard Position Detection and Move Prediction")

    # One path is used both for the permission fix and for launching the engine.
    stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
    # Only chmod when needed; changing permissions on every rerun is unnecessary
    # and may be disallowed in some deployments.
    if not os.access(stockfish_path, os.X_OK):
        os.chmod(stockfish_path, 0o755)

    # User captures an image with the camera or uploads one.
    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader(
        "Upload a chessboard image", type=["jpg", "jpeg", "png"]
    )
    if image_file is None:
        return

    img = _decode_image(image_file)
    if img is None:
        st.write("Could not read the image, please try another one.")
        return

    piece_model, seg_model = _load_models()
    board_img, error = _crop_board(img, seg_model)
    if error:
        st.write(error)
        return
    # OpenCV arrays are BGR; tell Streamlit so the colours aren't swapped.
    st.image(board_img, caption="Segmented Chessboard", use_container_width=True, channels="BGR")

    perspective = st.radio("Select perspective:", ["b (Black)", "w (White)"])[0].lower()
    board = _detect_pieces(board_img, piece_model, perspective)

    move_side = st.radio("Select the side to move:", ["w (White)", "b (Black)"])[0].lower()
    # Castling rights and the en-passant square can't be inferred from a single
    # image, so both are "-"; the fullmove number must be >= 1 per the FEN spec.
    fen_notation = f"{_board_to_fen(board)} {move_side} - - 0 1"
    st.subheader("Generated FEN Notation:")
    st.code(fen_notation)

    stockfish = Stockfish(
        path=stockfish_path,
        depth=10,
        parameters={"Threads": 2, "Minimum Thinking Time": 2},
    )
    next_move = predict_next_move(fen_notation, stockfish)
    st.subheader("Stockfish Recommended Move:")
    st.write(next_move)


if __name__ == "__main__":
    main()