Spaces:
Build error
Build error
| from ultralytics import YOLO | |
| import cv2 | |
| from stockfish import Stockfish | |
| import os | |
| import numpy as np | |
| import streamlit as st | |
| import requests | |
# Constants
# Detector class name -> FEN piece letter (upper case = white, lower case = black).
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
}
GRID_BORDER = 10  # Border size in pixels around the 8x8 grid
GRID_SIZE = 204  # Effective grid size (10px to 214px)
# Each square is 25px. 8 * 25 = 200, so the last 4px of GRID_SIZE fall on no square.
BLOCK_SIZE = GRID_SIZE // 8
X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # File labels, left to right (a to h)
Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1]  # Rank labels listed from 8 down to 1


# Functions
def get_grid_coordinate(pixel_x, pixel_y):
    """Map a pixel position on the board image to an algebraic square name.

    The 8x8 grid starts GRID_BORDER pixels in from the origin and each square
    is BLOCK_SIZE pixels wide. Columns map left to right onto files a-h.
    Rows map onto ranks 1-8 with rank 1 nearest ``pixel_y == GRID_BORDER``,
    i.e. ``pixel_y`` is expected to grow upward from the bottom of the board
    (top-right is h8).

    Args:
        pixel_x: Horizontal pixel position (int or float).
        pixel_y: Vertical pixel position, measured upward (int or float).

    Returns:
        A square name such as ``"c2"``, or the string
        ``"Pixel outside grid bounds"`` when the point lies in the border or
        beyond the last full square.
    """
    adjusted_x = pixel_x - GRID_BORDER
    adjusted_y = pixel_y - GRID_BORDER
    if not (0 <= adjusted_x < GRID_SIZE and 0 <= adjusted_y < GRID_SIZE):
        return "Pixel outside grid bounds"

    # int() so float coordinates can be used as list indices.
    column = int(adjusted_x // BLOCK_SIZE)
    row = int(adjusted_y // BLOCK_SIZE)
    # GRID_SIZE is not an exact multiple of 8 squares, so the trailing pixels
    # give index 8 and must be rejected as well.
    if not (0 <= column < len(X_LABELS) and 0 <= row < len(Y_LABELS)):
        return "Pixel outside grid bounds"

    # Row 0 (closest to the bottom border) is rank 1.
    return f"{X_LABELS[column]}{row + 1}"
def predict_next_move(fen, stockfish):
    """
    Predict the next move using Stockfish.

    Args:
        fen: Full FEN string of the position to analyse.
        stockfish: Engine wrapper exposing ``is_fen_valid``,
            ``set_fen_position`` and ``get_best_move``.

    Returns:
        A user-facing message. It is "Invalid FEN notation!" when the engine
        rejects the FEN, and a no-move message when the engine returns no best
        move (checkmate/stalemate). Otherwise it contains the best move passed
        through ``transform_string``.
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"
    stockfish.set_fen_position(fen)

    best_move = stockfish.get_best_move()
    # Check before transforming: get_best_move() yields no move when the side
    # to move has none, and transform_string cannot handle that value.
    if not best_move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {transform_string(best_move)}"
| # def download_stockfish(): | |
| # url = "https://drive.google.com/file/d/18pkwBVc13fgKP3LzrTHE4yzhjyGJexlR/view?usp=sharing" # Replace with the actual link | |
| # file_name = "stockfish-windows-x86-64-avx2.exe" | |
| # if not os.path.exists(file_name): | |
| # print(f"Downloading {file_name}...") | |
| # response = requests.get(url, stream=True) | |
| # with open(file_name, "wb") as file: | |
| # for chunk in response.iter_content(chunk_size=1024): | |
| # if chunk: | |
| # file.write(chunk) | |
| # print(f"{file_name} downloaded successfully.") | |
def process_image(image_path):
    """Locate the chessboard in an image file and return it cropped.

    Runs the ``segmentation.pt`` YOLO model on ``image_path``. It takes the
    first result that has a mask and whose first box has confidence >= 0.8,
    and crops the input image to that box. The crop is written to
    ``output/cropped_segment.jpg`` and shown on the Streamlit page.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        The cropped image as loaded by ``cv2.imread`` (BGR channel order), or
        None when no board is detected, the image cannot be read, or the
        detected region is empty.
    """
    os.makedirs('output', exist_ok=True)

    segmentation_model = YOLO("segmentation.pt")
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8  # Confidence threshold
    )

    bbox = None
    for result in results:
        # A result without detections has empty boxes and no masks; indexing
        # conf[0] on it would raise IndexError.
        if len(result.boxes) == 0 or result.masks is None:
            continue
        if result.boxes.conf[0] >= 0.8:  # Filter results by confidence
            bbox = result.boxes.xyxy[0].cpu().numpy()
            break

    if bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    image = cv2.imread(image_path)
    if image is None:  # cv2.imread signals unreadable files with None
        print(f"Could not read image {image_path}")
        return None

    x1, y1, x2, y2 = (int(v) for v in bbox)
    cropped_segment = image[y1:y2, x1:x2]
    if cropped_segment.size == 0:  # cv2.imwrite rejects empty images
        print("Detected board region is empty.")
        return None

    cropped_image_path = 'output/cropped_segment.jpg'
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")
    # cv2 arrays are BGR; without channels="BGR" Streamlit would swap red/blue.
    st.image(cropped_segment, caption="Uploaded Image", use_column_width=True, channels="BGR")
    return cropped_segment
def transform_string(input_str):
    """Rotate a UCI move string by 180 degrees.

    Each file letter is mirrored across the board (a<->h, b<->g, ...) and each
    rank digit likewise (1<->8, 2<->7, ...). A trailing promotion piece
    (q, r, b or n), as UCI appends for pawn promotions, is kept unchanged.

    NOTE(review): the reason the engine's move is rotated is not visible in
    this function; confirm it matches how the board image is oriented.

    Args:
        input_str: Move such as "e2e4" or "e7e8q". Surrounding whitespace and
            letter case are ignored.

    Returns:
        The rotated move in lower case (e.g. "d7d5"), or "Invalid input" when
        the text is not file/rank/file/rank plus an optional promotion piece.
    """
    move = input_str.strip().lower()

    promotion = ""
    if len(move) == 5 and move[4] in "qrbn":
        move, promotion = move[:4], move[4]
    if len(move) != 4:
        return "Invalid input"

    files = "abcdefgh"
    ranks = "12345678"
    rotated = []
    for position, char in enumerate(move):
        # Even positions hold files, odd positions hold ranks.
        alphabet = files if position % 2 == 0 else ranks
        index = alphabet.find(char)
        if index == -1:
            return "Invalid input"
        rotated.append(alphabet[len(alphabet) - 1 - index])
    return "".join(rotated) + promotion
# Streamlit app
def _detect_board(image, model):
    """Detect pieces in a board image and place them on an 8x8 grid.

    Returns 8 rows ordered rank 8 first, each ordered file a first. Every cell
    is a FEN piece letter or None for an empty square. Detections whose class
    has no FEN letter, or whose centre lies outside the grid, are ignored.
    """
    height = image.shape[0]
    board = [[None] * 8 for _ in range(8)]
    results = model.predict(source=image, save=False, save_txt=False, conf=0.6)
    for box in results[0].boxes:
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        fen_piece = FEN_MAPPING.get(model.names[int(box.cls[0])])
        if not fen_piece:
            continue
        # Box centre, with Y flipped so it grows upward from the bottom edge,
        # matching get_grid_coordinate's rank-1-at-the-bottom convention.
        pixel_x = int((x1 + x2) / 2)
        pixel_y = int(height - (y1 + y2) / 2)
        grid_position = get_grid_coordinate(pixel_x, pixel_y)
        if grid_position == "Pixel outside grid bounds":
            continue
        file_index = ord(grid_position[0]) - ord('a')
        rank_index = int(grid_position[1]) - 1
        board[7 - rank_index][file_index] = fen_piece  # FEN lists rank 8 first
    return board


def _board_to_fen_placement(board):
    """Encode an 8x8 board (rank 8 first, None = empty) as a FEN placement field."""
    fen_rows = []
    for row in board:
        encoded = ""
        empty_run = 0
        for cell in row:
            if cell is None:
                empty_run += 1
                continue
            if empty_run:
                encoded += str(empty_run)
                empty_run = 0
            encoded += cell
        if empty_run:
            encoded += str(empty_run)
        fen_rows.append(encoded)
    return "/".join(fen_rows)


def main():
    """Run the Streamlit app.

    Takes a chessboard photo from the camera or an upload and crops the board
    with process_image. It detects pieces with the ``standard.pt`` YOLO model
    and shows the resulting FEN for the side chosen by the user. Finally it
    shows Stockfish's recommended move.
    """
    st.title("Chessboard Position Detection and Move Prediction")

    # Use one path for both the chmod and the engine so they cannot disagree.
    stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
    if os.path.exists(stockfish_path):
        # The bundled binary may lack its executable bit after checkout.
        os.chmod(stockfish_path, 0o755)

    # Render both inputs on every run so the uploader stays available after a
    # camera capture; a camera capture takes precedence.
    camera_image = st.camera_input("Capture a chessboard image")
    uploaded_image = st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])
    image_file = camera_image or uploaded_image
    if image_file is None:
        return

    # Save the image to a temporary file so process_image can read it by path.
    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
    with open(temp_file_path, "wb") as f:
        f.write(image_file.getbuffer())

    processed_image = process_image(temp_file_path)
    if processed_image is None:
        st.error("Failed to process the image. Please try again.")
        return

    # The grid constants (10px border + 204px grid + 10px) appear to assume a
    # 224x224 board image.
    processed_image = cv2.resize(processed_image, (224, 224))
    model = YOLO("standard.pt")  # Replace with your trained model weights file
    board = _detect_board(processed_image, model)
    position_fen = _board_to_fen_placement(board)

    move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
    move_side = "w" if move_side.startswith("w") else "b"
    # Castling rights and the en-passant square cannot be read from a single
    # image, so both are "-". The FEN full-move number starts at 1.
    fen_notation = f"{position_fen} {move_side} - - 0 1"
    st.subheader("Generated FEN Notation:")
    st.code(fen_notation)

    if not os.path.exists(stockfish_path):
        st.error(f"Stockfish binary not found at {stockfish_path}")
        return
    stockfish = Stockfish(
        path=stockfish_path,
        depth=15,
        parameters={"Threads": 2, "Minimum Thinking Time": 30}
    )
    next_move = predict_next_move(fen_notation, stockfish)
    st.subheader("Stockfish Recommended Move:")
    st.write(next_move)


if __name__ == "__main__":
    main()