Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import cv2 | |
| import chess | |
| import chess.svg | |
| import cairosvg | |
| from stockfish import Stockfish | |
| from PIL import Image | |
| import imageio | |
| import io | |
| import os | |
| import subprocess | |
import shutil

# Square classifier used by predict_board(); loaded once when the module is imported.
model = tf.keras.models.load_model("chess_model.keras")

# Label for each output index of `model` (index = argmax of the prediction vector).
# NOTE(review): the list is alphabetical, which would match a dataset built from
# class-named folders sorted by name — confirm against the training pipeline.
CLASS_NAMES = [
    'black_bishop', 'black_king', 'black_knight', 'black_pawn',
    'black_queen', 'black_rook', 'empty', 'white_bishop',
    'white_king', 'white_knight', 'white_pawn', 'white_queen',
    'white_rook'
]

# Locate the Stockfish binary on PATH. shutil.which searches PATH in-process, so
# this also works on hosts without an external `which` command (where
# subprocess.run would raise FileNotFoundError instead of using the fallback).
# The fallback is presumably where Debian/Ubuntu's apt package installs it.
stockfish_path = shutil.which("stockfish") or "/usr/games/stockfish"

# Shared engine instance used by get_best_moves(); searches to depth 15.
stockfish = Stockfish(path=stockfish_path, depth=15)
def detect_and_crop_board(image_path, size=400):
    """Locate the chessboard in an image file and return it as a square RGB crop.

    The board is taken to be the bounding box of the external contour with the
    largest area in the Canny edge map. If no contour encloses any area (no
    edges at all, or only open line fragments), the whole image is used instead
    of crashing.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        size: Side length in pixels of the returned square crop (default 400).

    Returns:
        A ``(size, size, 3)`` array in RGB channel order.

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"could not read image: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # max() keeps the first contour among equal areas, like the original strict '>' scan.
    board_contour = max(contours, key=cv2.contourArea, default=None)
    if board_contour is None or cv2.contourArea(board_contour) <= 0:
        # Nothing usable to crop to: fall back to the full image.
        board = img_rgb
    else:
        x, y, w, h = cv2.boundingRect(board_contour)
        board = img_rgb[y:y + h, x:x + w]
    return cv2.resize(board, (size, size))
def split_board_into_squares(board_img):
    """Cut a board image into its 64 squares, ordered a1, b1, ..., h1, a2, ..., h8.

    Assuming the board is shown from White's side, the bottom image row is
    rank 1 and the leftmost column is the a-file. The result is then indexed
    ``rank * 8 + file`` (a1 = 0, h1 = 7, a8 = 56, h8 = 63).

    The square size is derived from the image's own height and width, divided
    by 8, so boards of any resolution work. Remainder pixels at the bottom or
    right edge are ignored.

    Args:
        board_img: Image array whose first two axes are height and width.

    Returns:
        A list of 64 array slices (views into ``board_img``).
    """
    square_h = board_img.shape[0] // 8
    square_w = board_img.shape[1] // 8
    return [
        board_img[row * square_h:(row + 1) * square_h,
                  col * square_w:(col + 1) * square_w]
        # Image row 7 (bottom) first, so rank 1 comes first.
        for row in range(7, -1, -1)
        for col in range(8)
    ]
def predict_board(squares, flipped=False):
    """Classify every board square with the piece model.

    Args:
        squares: The 64 square images in a1, b1, ..., h8 order (index
            ``rank * 8 + file``), as produced by ``split_board_into_squares``.
        flipped: Pass True when the screenshot shows the board from Black's
            side. The results are then rotated 180 degrees so they are still
            in a1..h8 order.

    Returns:
        A list of ``{"class": <name from CLASS_NAMES>, "confidence": float}``
        dicts, one per square, in a1..h8 order.
    """
    # Resize to the model's 64x64 input and apply the EfficientNet preprocessing.
    batch = np.array([
        tf.keras.applications.efficientnet.preprocess_input(
            cv2.resize(square, (64, 64)).astype(np.float32)
        )
        for square in squares
    ])
    preds = model.predict(batch, verbose=0)

    predictions = [
        {"class": CLASS_NAMES[int(np.argmax(p))], "confidence": float(np.max(p))}
        for p in preds
    ]

    # Reversing an a1..h8 list maps square i to 63 - i, i.e. rotates the board
    # 180 degrees. That is only correct for a board viewed from Black's side.
    # Doing it unconditionally would turn every White-side screenshot upside down
    # by the time predictions_to_fen reads index rank * 8 + file.
    if flipped:
        predictions.reverse()
    return predictions
def predictions_to_fen(predictions, side_to_move="w"):
    """Build a FEN string from 64 per-square piece classifications.

    Args:
        predictions: 64 dicts whose ``"class"`` key holds one of the labels in
            ``piece_map`` below, ordered a1, b1, ..., h1, a2, ..., h8
            (index ``rank * 8 + file``, both counted from 0).
        side_to_move: ``"w"`` or ``"b"``, written into the FEN's active-colour
            field (default ``"w"``).

    Returns:
        A full FEN string. Castling and en-passant fields are always ``-`` and
        the clocks are ``0 1``, because none of that is visible in a single image.

    Raises:
        ValueError: If there are not exactly 64 predictions, or *side_to_move*
            is not ``"w"`` or ``"b"``.
        KeyError: If a prediction carries an unknown class label.
    """
    if len(predictions) != 64:
        raise ValueError(f"expected 64 square predictions, got {len(predictions)}")
    if side_to_move not in ("w", "b"):
        raise ValueError(f"side_to_move must be 'w' or 'b', got {side_to_move!r}")

    piece_map = {
        'white_king': 'K', 'white_queen': 'Q',
        'white_rook': 'R', 'white_bishop': 'B',
        'white_knight': 'N', 'white_pawn': 'P',
        'black_king': 'k', 'black_queen': 'q',
        'black_rook': 'r', 'black_bishop': 'b',
        'black_knight': 'n', 'black_pawn': 'p',
        'empty': None
    }

    fen_rows = []
    # FEN lists rank 8 first, each rank from the a-file to the h-file.
    for rank in range(7, -1, -1):
        parts = []
        empty_run = 0
        for file_idx in range(8):
            piece = piece_map[predictions[rank * 8 + file_idx]['class']]
            if piece is None:
                empty_run += 1
                continue
            # A run of empty squares is written as its length.
            if empty_run:
                parts.append(str(empty_run))
                empty_run = 0
            parts.append(piece)
        if empty_run:
            parts.append(str(empty_run))
        fen_rows.append("".join(parts))

    return "/".join(fen_rows) + f" {side_to_move} - - 0 1"
def get_best_moves(fen, num_moves=3):
    """Ask the shared Stockfish engine for its top candidate moves.

    Side effect: sets the module-level ``stockfish`` engine's current position
    to *fen*.

    Args:
        fen: Position to analyse, as a FEN string.
        num_moves: Number of candidate lines to request.

    Returns:
        A list of dicts with the keys ``"move"``, ``"centipawn"`` and ``"mate"``.
        They are copied from the ``"Move"``, ``"Centipawn"`` and ``"Mate"``
        fields of each entry returned by ``get_top_moves``.
    """
    stockfish.set_fen_position(fen)
    candidates = stockfish.get_top_moves(num_moves)
    return [
        {
            "move": line["Move"],
            "centipawn": line["Centipawn"],
            "mate": line["Mate"],
        }
        for line in candidates
    ]
def create_gif(fen, moves, output_path=None):
    """Render an animated GIF previewing each candidate move from *fen*.

    The animation shows the starting position for 10 frames. Then, for each
    entry in *moves*, it shows the position after that single move (highlighted
    as the last move) for 15 frames. Every candidate is applied to the starting
    position on its own; the move is undone before the next one is shown.

    Args:
        fen: FEN string of the starting position.
        moves: Sequence of dicts carrying a UCI move string under ``"move"``.
        output_path: Where to write the GIF. When omitted, a fresh,
            uniquely-named file is created in the system temp directory.
            Concurrent calls therefore cannot overwrite each other's output,
            and no POSIX-only ``/tmp`` path is hard-coded.

    Returns:
        The path of the written GIF.
    """
    import tempfile

    def _render(position, lastmove=None):
        # Rasterise python-chess's 400 px SVG rendering into an RGB array.
        svg = chess.svg.board(board=position, size=400, lastmove=lastmove)
        png = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
        return np.array(Image.open(io.BytesIO(png)).convert('RGB'))

    board = chess.Board(fen)
    frames = [_render(board)] * 10
    for move_data in moves:
        move = chess.Move.from_uci(move_data['move'])
        board.push(move)
        frames.extend([_render(board, lastmove=move)] * 15)
        board.pop()  # back to the starting position for the next candidate

    if output_path is None:
        fd, output_path = tempfile.mkstemp(prefix="chess_moves_", suffix=".gif")
        os.close(fd)  # only the unique name is needed; imageio reopens the file
    imageio.mimsave(output_path, frames, fps=10)
    return output_path
def analyze(image, output_type):
    """Run the screenshot -> FEN -> engine pipeline for the Gradio UI.

    Args:
        image: The uploaded screenshot as a numpy array (the input component
            uses ``type="numpy"``), or None if nothing was uploaded.
        output_type: ``"Text"`` for text only; any other value (the UI offers
            ``"Video"``) also renders a GIF of the candidate moves.

    Returns:
        ``(result_text, gif_path)``, where *gif_path* is None in text mode.

    Raises:
        gr.Error: If no image was uploaded, or the recognised board is not a
            legal chess position. Gradio shows the message to the user.
    """
    import tempfile

    if image is None:
        # Gradio passes None when "Analyze" is clicked before anything is uploaded.
        raise gr.Error("Please upload a screenshot first.")

    # detect_and_crop_board reads from disk, so stage the upload in a
    # per-call temporary file. A fixed path could be clobbered by a concurrent
    # request, and it was never cleaned up.
    fd, temp_path = tempfile.mkstemp(prefix="chess_input_", suffix=".png")
    os.close(fd)
    try:
        Image.fromarray(image).save(temp_path)
        board_img = detect_and_crop_board(temp_path)
    finally:
        os.remove(temp_path)

    squares = split_board_into_squares(board_img)
    predictions = predict_board(squares)
    fen = predictions_to_fen(predictions)

    # Don't hand an impossible position (e.g. a missing king from a
    # misclassified square) to the shared engine.
    # NOTE(review): such input may make the Stockfish process misbehave or crash.
    if not chess.Board(fen).is_valid():
        raise gr.Error(
            f"Could not recognise a legal position (FEN: {fen}). "
            "Make sure the screenshot shows the whole board."
        )

    moves = get_best_moves(fen)

    lines = [f"FEN: {fen}", "", "Best Moves:"]
    for i, move in enumerate(moves, start=1):
        score = move['centipawn']
        if move['mate']:
            score = f"Mate in {move['mate']}"
        lines.append(f"{i}. {move['move']} | Score: {score}")
    result = "\n".join(lines) + "\n"

    if output_type == "Text":
        return result, None
    return result, create_gif(fen, moves)
def toggle_gif_visibility(choice):
    """Show the animation panel only while the "Video" output type is selected."""
    show_animation = choice == "Video"
    return gr.update(visible=show_animation)


with gr.Blocks(title="Chess Analyzer") as app:
    gr.Markdown("# Chess Analyzer")
    gr.Markdown("Upload a Chess.com screenshot to get the best moves")

    # Screenshot upload; analyze() receives it as a numpy array.
    with gr.Row():
        image_input = gr.Image(label="Upload Screenshot", type="numpy")

    # Output-mode selector and the button that triggers analysis.
    with gr.Row():
        output_type = gr.Radio(
            choices=["Text", "Video"],
            value="Text",
            label="Output Type",
        )
        analyze_btn = gr.Button("Analyze", variant="primary")

    # Results. The animation panel starts hidden because the default mode is "Text".
    with gr.Row():
        text_output = gr.Textbox(label="Analysis Results", lines=8)
        gif_output = gr.Image(label="Move Animation", visible=False)

    output_type.change(
        fn=toggle_gif_visibility,
        inputs=output_type,
        outputs=gif_output,
    )
    analyze_btn.click(
        fn=analyze,
        inputs=[image_input, output_type],
        outputs=[text_output, gif_output],
    )

app.launch()