| import gradio as gr |
| import chess |
| import torch |
| import numpy as np |
| import os |
| import io |
| import time |
| from collections import OrderedDict |
| from PIL import Image, ImageDraw, ImageFont |
| from RL_Chess_Alpha import AlphaChessNet, AlphaMCTS |
|
|
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
WEIGHTS_PATH = "alpha_merged_weights.pth"

# Single network instance; every GameSession's AlphaMCTS is built on it.
model = AlphaChessNet().to(DEVICE)
if os.path.exists(WEIGHTS_PATH):
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot run arbitrary code when loaded.
    state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
else:
    # Previously this case was silent; make it obvious that the engine is
    # running with the freshly constructed (untrained) parameters.
    print(f"Warning: weights file '{WEIGHTS_PATH}' not found; using untrained weights.")
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()
|
|
def load_piece_images(size=80, pieces_dir="pieces"):
    """Load the piece sprites drawn by ``render_board``.

    Args:
        size: Edge length in pixels each sprite is resized to. It should match
            the board's square size (``render_board`` uses 80 px squares).
        pieces_dir: Directory holding the ``*_png_128px.png`` sprite files.

    Returns:
        Dict mapping every python-chess piece symbol (uppercase = White,
        lowercase = Black) to an RGBA ``PIL.Image``. The value is ``None`` when
        the sprite file is missing or cannot be read; ``render_board`` skips
        such entries. Problems are reported with ``print`` rather than raised,
        so the app still starts without sprites.
    """
    mapping = {
        'P': 'w_pawn_png_128px', 'R': 'w_rook_png_128px', 'N': 'w_knight_png_128px',
        'B': 'w_bishop_png_128px', 'Q': 'w_queen_png_128px', 'K': 'w_king_png_128px',
        'p': 'b_pawn_png_128px', 'r': 'b_rook_png_128px', 'n': 'b_knight_png_128px',
        'b': 'b_bishop_png_128px', 'q': 'b_queen_png_128px', 'k': 'b_king_png_128px',
    }
    images = {}
    for code, name in mapping.items():
        path = os.path.join(pieces_dir, f"{name}.png")
        if not os.path.exists(path):
            print(f"File not found: {path}")
            images[code] = None
            continue
        try:
            # The context manager closes the underlying file handle; convert()
            # returns an independent image, so it stays valid after closing.
            with Image.open(path) as src:
                images[code] = src.convert("RGBA").resize((size, size), Image.LANCZOS)
        except Exception as e:  # best effort at startup: log and fall back to no sprite
            print(f"Error loading {path}: {e}")
            images[code] = None
    return images


piece_images = load_piece_images()
|
|
| |
| |
| |
class GameSession:
    """Mutable game state owned by a single browser session.

    Bundles the python-chess board, the square the human has clicked to
    select (``None`` while nothing is selected) and an ``AlphaMCTS`` searcher
    built on the shared module-level ``model`` and ``DEVICE``.
    """

    def __init__(self):
        self.board: chess.Board = chess.Board()
        # Square index of the human's current selection; None means "none".
        self.selected_sq = None
        self.mcts: AlphaMCTS = AlphaMCTS(model, DEVICE)
|
|
| |
# Upper bound on concurrently stored games; the least recently used one is
# evicted when a new session would exceed it.
MAX_SESSIONS = 100

# Game state per browser, keyed by Gradio's session hash, kept in
# least-recently-used order (oldest first).
sessions = OrderedDict()


def get_session(sid, *, max_sessions=MAX_SESSIONS):
    """Return the GameSession for *sid*, creating it on first use.

    Every lookup marks *sid* as most recently used. The UI timer looks up the
    session each tick, so games that are still open stay at the fresh end.
    Before a new session is created, the least recently used sessions are
    evicted so that at most *max_sessions* are stored. The old code evicted
    by creation order, even for active games, and allowed one extra entry.

    Args:
        sid: Session identifier (``gr.Request.session_hash``).
        max_sessions: Maximum number of stored sessions, counting the new one.

    Returns:
        The existing or newly created ``GameSession``.
    """
    session = sessions.get(sid)
    if session is not None:
        sessions.move_to_end(sid)
        return session
    # `sessions and` guards against popping from an empty dict when
    # max_sessions <= 0.
    while sessions and len(sessions) >= max_sessions:
        sessions.popitem(last=False)
    session = sessions[sid] = GameSession()
    return session
|
|
| |
| |
| |
def render_board(session, status=""):
    """Draw the position of *session* as a 640x640 PIL image.

    Squares are 80 px with White at the bottom (image row 0 is rank 8).
    The selected square is tinted yellow. The destination squares of the
    selected piece's legal moves are tinted green. Piece sprites come from
    the module-level ``piece_images``; symbols without a sprite are skipped.

    Args:
        session: Object exposing ``board`` (``chess.Board``) and
            ``selected_sq`` (square index or ``None``).
        status: Text passed through unchanged.

    Returns:
        ``(image, status)``, ready for a ``[gr.Image, gr.Textbox]`` output pair.
    """
    size = 640
    sq = 80
    img = Image.new('RGB', (size, size), '#769656')
    draw = ImageDraw.Draw(img)

    selected = session.selected_sq
    # Generate legal moves once per render rather than once per square.
    targets = set()
    if selected is not None:
        targets = {
            move.to_square
            for move in session.board.legal_moves
            if move.from_square == selected
        }

    for r in range(8):
        for f in range(8):
            square = chess.square(f, 7 - r)
            if square == selected:
                color = '#f5f682'
            elif square in targets:
                color = '#acdf91'
            elif (r + f) % 2 == 0:
                color = '#eeeed2'
            else:
                color = '#769656'
            draw.rectangle([f * sq, r * sq, (f + 1) * sq, (r + 1) * sq], fill=color)

    for s, p in session.board.piece_map().items():
        img_p = piece_images.get(p.symbol())
        if img_p:
            f, r = chess.square_file(s), 7 - chess.square_rank(s)
            # The RGBA sprite doubles as its own mask so transparency is kept.
            img.paste(img_p, (f * sq, r * sq), img_p)

    return img, status
|
|
def handle_click(evt: gr.SelectData, wm, bm, request: gr.Request):
    """Handle a click on the board image for the requesting browser session.

    Click protocol for the human-controlled side to move:
      * clicking one of its own pieces selects that square;
      * clicking the selected square again clears the selection;
      * clicking a legal destination plays the move; a pawn reaching the last
        rank always promotes to a queen;
      * clicking another own piece moves the selection there, and any other
        click clears it.

    Args:
        evt: Gradio select event; ``evt.index`` is the clicked pixel ``(x, y)``
            on the 640x640 board image.
        wm: Controller of White, "Human" or "AI".
        bm: Controller of Black, "Human" or "AI".
        request: Gradio request; its ``session_hash`` identifies the game.

    Returns:
        ``(image, status)`` as produced by ``render_board``.
    """
    square_px = 80  # must match the square size used by render_board
    session = get_session(request.session_hash)
    board = session.board

    # Same wording as ai_turn; previously a finished game still said "X to move".
    if board.is_game_over():
        session.selected_sq = None
        return render_board(session, f"GameOver: {board.result()}")

    ctrl = wm if board.turn == chess.WHITE else bm
    if ctrl != "Human":
        return render_board(session, "It is the AI's turn.")

    # Clamp to 0..7. A click on the far edge (x or y == 640) would otherwise
    # give file/rank 8 or -1, which maps to the wrong square or an invalid
    # square index.
    x, y = evt.index
    file_idx = min(max(int(x) // square_px, 0), 7)
    rank_idx = 7 - min(max(int(y) // square_px, 0), 7)  # image row 0 is rank 8
    clicked_sq = chess.square(file_idx, rank_idx)

    selected = session.selected_sq
    if selected == clicked_sq:
        session.selected_sq = None
    elif selected is None:
        piece = board.piece_at(clicked_sq)
        if piece and piece.color == board.turn:
            session.selected_sq = clicked_sq
    else:
        move = chess.Move(selected, clicked_sq)
        if move not in board.legal_moves:
            # Auto-queen: a promoting pawn move is only legal with a promotion piece.
            move = chess.Move(selected, clicked_sq, promotion=chess.QUEEN)
        if move in board.legal_moves:
            board.push(move)
            session.selected_sq = None
        else:
            piece = board.piece_at(clicked_sq)
            session.selected_sq = clicked_sq if piece and piece.color == board.turn else None

    if board.is_game_over():
        status = f"GameOver: {board.result()}"
    else:
        status = f"{'White' if board.turn == chess.WHITE else 'Black'} to move"
    return render_board(session, status)
|
|
def ai_turn(sid, sims):
    """Let the engine make one move in the game of session *sid*.

    Args:
        sid: Gradio session hash identifying the game (see ``get_session``).
        sims: MCTS simulation budget. It is cast with ``int`` because the UI
            slider value may arrive as a float.

    Returns:
        ``(image, status)`` from ``render_board``. The status reports one of:
        the move played in SAN, the game result, or why no move was made.
    """
    session = get_session(sid)
    board = session.board
    # python-chess's is_game_over() includes checkmate and stalemate, so past
    # this guard the side to move has at least one legal move. The former
    # explicit "no legal moves" check was unreachable.
    if board.is_game_over():
        return render_board(session, f"GameOver: {board.result()}")

    root = session.mcts.search(board, simulations=int(sims), training=False)
    if not root.children:
        return render_board(session, "AI could not find a move.")

    # Most-visited root child; ties go to the first one in iteration order.
    move = max(root.children, key=lambda m: root.children[m].N)

    # Board.push() does not validate legality. The board may also have changed
    # while the search was running (e.g. "Reset Game" handled by another event),
    # so never apply a move that is no longer legal in the current position.
    if move not in board.legal_moves:
        return render_board(session, "Position changed during search; move discarded.")

    san = board.san(move)
    board.push(move)
    session.selected_sq = None
    return render_board(session, f"AI played {san}")
|
|
def reset_all(request: gr.Request):
    """Start a new game for the requesting browser session.

    Clears the selection, puts the board back to its initial state and swaps
    in a brand-new ``AlphaMCTS`` searcher. The searcher is presumably replaced
    so that no search state from the previous game carries over.

    Returns:
        ``(image, "Game Reset")`` from ``render_board``.
    """
    session = get_session(request.session_hash)
    session.selected_sq = None
    session.board.reset()
    session.mcts = AlphaMCTS(model, DEVICE)
    return render_board(session, "Game Reset")
|
|
| |
| |
| |
| |
with gr.Blocks() as demo:
    gr.Markdown("# ♟️ AlphaZero Tactical Engine - Skill level 3")

    with gr.Row():
        # Left column: the clickable board and a status line.
        with gr.Column(scale=2):
            # Throwaway session used only to draw the starting position, so the
            # board is visible before any event has fired. It is never stored.
            dummy_session = GameSession()
            initial_img, _ = render_board(dummy_session)

            board_img = gr.Image(
                value=initial_img,
                width=640,
                height=640,
                interactive=False,
                label="Chess Board",
            )
            status_txt = gr.Textbox(label="Status", value="Ready / White to move")

        # Right column: controller of each side, search budget, autoplay, reset.
        with gr.Column(scale=1):
            w_mode = gr.Radio(["Human", "AI"], value="Human", label="White Player")
            b_mode = gr.Radio(["Human", "AI"], value="AI", label="Black Player")
            sims_val = gr.Slider(10, 1000, 400, label="Search Simulations")
            auto_go = gr.Checkbox(label="Enable Autoplay", value=True)
            reset_btn = gr.Button("Reset Game")

    def main_loop(is_auto, wm, bm, s, request: gr.Request):
        """Timer callback: play an engine move when autoplay is on and the AI is to move.

        Returns the ``(image, status)`` pair from ``ai_turn``. When there is
        nothing to do, it returns two no-op ``gr.update()`` values.
        """
        sid = request.session_hash
        board = get_session(sid).board
        ai_to_move = (wm if board.turn == chess.WHITE else bm) == "AI"
        if is_auto and not board.is_game_over() and ai_to_move:
            return ai_turn(sid, s)
        # Nothing to do on this tick: leave both outputs untouched.
        return gr.update(), gr.update()

    board_img.select(
        fn=handle_click,
        inputs=[w_mode, b_mode],
        outputs=[board_img, status_txt],
    )

    reset_btn.click(
        fn=reset_all,
        outputs=[board_img, status_txt],
    )

    # Polls once per second so AI moves are played without user interaction.
    timer = gr.Timer(1.0)
    timer.tick(
        fn=main_loop,
        inputs=[auto_go, w_mode, b_mode, sims_val],
        outputs=[board_img, status_txt],
    )
|
|
if __name__ == "__main__":
    # Serve the UI using Gradio's built-in "Soft" theme.
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)