import gradio as gr
import chess
import torch
import numpy as np
import os
import io
import time
from collections import OrderedDict
from PIL import Image, ImageDraw, ImageFont
from RL_Chess_Alpha import AlphaChessNet, AlphaMCTS

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
WEIGHTS_PATH = "alpha_merged_weights.pth"

# ==========================================
# 1. GLOBAL RESOURCES (Loaded once for all)
# ==========================================
model = AlphaChessNet().to(DEVICE)
if os.path.exists(WEIGHTS_PATH):
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
else:
    # Without the checkpoint the network keeps whatever AlphaChessNet() initialised it
    # with; surface that instead of silently serving an untrained engine.
    print(f"Warning: weights file not found: {WEIGHTS_PATH} (using untrained model)")
model.eval()


def load_piece_images():
    """Load the 12 piece sprites from ``pieces/`` and scale each to 80x80 pixels.

    Returns:
        dict mapping a python-chess piece symbol ('P', 'n', ...) to an RGBA PIL image,
        or to None when the file is missing or could not be decoded (the error is printed).
    """
    images = {}
    mapping = {
        'P': 'w_pawn_png_128px', 'R': 'w_rook_png_128px', 'N': 'w_knight_png_128px',
        'B': 'w_bishop_png_128px', 'Q': 'w_queen_png_128px', 'K': 'w_king_png_128px',
        'p': 'b_pawn_png_128px', 'r': 'b_rook_png_128px', 'n': 'b_knight_png_128px',
        'b': 'b_bishop_png_128px', 'q': 'b_queen_png_128px', 'k': 'b_king_png_128px',
    }
    for code, name in mapping.items():
        path = f"pieces/{name}.png"
        if not os.path.exists(path):
            print(f"File not found: {path}")
            images[code] = None
            continue
        try:
            # The context manager closes the file once the converted copy exists.
            with Image.open(path) as img:
                images[code] = img.convert("RGBA").resize((80, 80), Image.LANCZOS)
        except Exception as e:  # best-effort: one bad sprite must not stop the app
            print(f"Error loading {path}: {e}")
            images[code] = None
    return images


piece_images = load_piece_images()

# ==========================================
# 2. SESSION MANAGEMENT (One per browser tab)
# ==========================================


class GameSession:
    """Game state for a single browser tab: board, current selection, and search object."""

    def __init__(self):
        self.board = chess.Board()
        self.selected_sq = None  # square index picked by the human, or None
        self.mcts = AlphaMCTS(model, DEVICE)


# Maximum number of concurrently stored sessions; the least recently used is dropped.
MAX_SESSIONS = 100

# OrderedDict kept in least-recently-used order to bound memory use.
sessions = OrderedDict()


def get_session(sid):
    """Return the GameSession for tab ``sid``, creating it on first use.

    Every lookup marks the session as most recently used, so only idle tabs are
    evicted when the cap is reached (an evicted tab starts a fresh game next time).
    NOTE(review): not synchronised; Gradio may call this from several worker threads.

    Args:
        sid: Gradio ``request.session_hash`` of the calling tab.

    Returns:
        The GameSession associated with ``sid``.
    """
    session = sessions.get(sid)
    if session is not None:
        sessions.move_to_end(sid)  # refresh recency so active games are not evicted
        return session
    # Make room *before* inserting so the dict never exceeds MAX_SESSIONS.
    while len(sessions) >= MAX_SESSIONS:
        sessions.popitem(last=False)
    session = GameSession()
    sessions[sid] = session
    return session

# ==========================================
# 3. GAME LOGIC
# ==========================================
_SQUARE_PX = 80  # side of one board square in pixels (matches the 80x80 piece sprites)


def _status_for(board):
    """Return the status line for ``board``: the result if finished, else whose turn it is."""
    if board.is_game_over():
        return f"GameOver: {board.result()}"
    return f"{'White' if board.turn == chess.WHITE else 'Black'} to move"


def render_board(session, status=""):
    """Draw ``session.board`` as a 640x640 PIL image.

    The selected square is drawn in yellow and its legal destinations in green.

    Args:
        session: object exposing ``board`` (chess.Board) and ``selected_sq`` (int or None).
        status: text passed through unchanged.

    Returns:
        (image, status), matching the [board_img, status_txt] outputs.
    """
    sq = _SQUARE_PX
    img = Image.new('RGB', (8 * sq, 8 * sq), '#769656')
    draw = ImageDraw.Draw(img)
    selected = session.selected_sq

    # Collect the selected piece's destinations once instead of rescanning all
    # legal moves for each of the 64 squares.
    targets = set()
    if selected is not None:
        targets = {m.to_square for m in session.board.legal_moves if m.from_square == selected}

    for r in range(8):  # r = image row, 0 at the top (rank 8)
        for f in range(8):
            square = chess.square(f, 7 - r)
            if square == selected:
                color = '#f5f682'
            elif square in targets:
                color = '#acdf91'
            else:
                color = '#eeeed2' if (r + f) % 2 == 0 else '#769656'
            draw.rectangle([f * sq, r * sq, (f + 1) * sq - 1, (r + 1) * sq - 1], fill=color)

    for s, p in session.board.piece_map().items():
        f, r = chess.square_file(s), 7 - chess.square_rank(s)
        img_p = piece_images.get(p.symbol())
        if img_p:  # sprites that failed to load are None and are skipped
            img.paste(img_p, (f * sq, r * sq), img_p)  # RGBA sprite doubles as its own mask
    return img, status


def handle_click(evt: gr.SelectData, wm, bm, request: gr.Request):
    """Handle a click on the board image for the calling browser tab.

    The first click selects a piece of the side to move. A second click on a legal
    destination plays the move; pawns always promote to a queen. Clicking the selected
    square again clears the selection, and clicking another own piece switches to it.

    Args:
        evt: Gradio select event; ``evt.index`` holds the clicked (x, y) pixel.
        wm: controller of White ("Human" or "AI").
        bm: controller of Black ("Human" or "AI").
        request: injected by Gradio; ``session_hash`` identifies the tab.

    Returns:
        (image, status) for the [board_img, status_txt] outputs.
    """
    session = get_session(request.session_hash)
    board = session.board

    if board.is_game_over():
        session.selected_sq = None
        return render_board(session, _status_for(board))

    ctrl = wm if board.turn == chess.WHITE else bm
    if ctrl != "Human":
        return render_board(session, "It is the AI's turn.")

    x, y = evt.index
    # Clamp: a click on the last pixel row/column (640) would otherwise give file 8 or
    # rank -1, which chess.square() silently turns into a wrong (or out-of-range) square.
    file_idx = min(max(int(x) // _SQUARE_PX, 0), 7)
    rank_idx = 7 - min(max(int(y) // _SQUARE_PX, 0), 7)
    clicked_sq = chess.square(file_idx, rank_idx)

    selected = session.selected_sq
    piece = board.piece_at(clicked_sq)
    own_piece = piece is not None and piece.color == board.turn

    if selected == clicked_sq:
        session.selected_sq = None
    elif selected is None:
        if own_piece:
            session.selected_sq = clicked_sq
    else:
        move = chess.Move(selected, clicked_sq)
        if move not in board.legal_moves:
            # A promotion needs an explicit piece type; the UI always picks a queen.
            move = chess.Move(selected, clicked_sq, promotion=chess.QUEEN)
        if move in board.legal_moves:
            board.push(move)
            session.selected_sq = None
        else:
            # Not a legal destination: switch to another own piece, otherwise deselect.
            session.selected_sq = clicked_sq if own_piece else None

    # Report the result right away if this move ended the game.
    return render_board(session, _status_for(board))


def ai_turn(sid, sims):
    """Let the engine play one move for session ``sid``.

    Args:
        sid: Gradio ``session_hash`` of the tab.
        sims: number of MCTS simulations; cast to int because it comes from a UI slider.

    Returns:
        (image, status) for the [board_img, status_txt] outputs.
    """
    session = get_session(sid)
    board = session.board
    if board.is_game_over():
        return render_board(session, _status_for(board))

    # Defensive only: is_game_over() already covers positions without legal moves.
    if not board.legal_moves:
        return render_board(session, "No legal moves.")

    root = session.mcts.search(board, simulations=int(sims), training=False)
    if not root.children:
        return render_board(session, "AI could not find a move.")

    # children appears to map move -> node, with N presumably the visit count;
    # pick the most visited move.
    move = max(root.children.items(), key=lambda item: item[1].N)[0]

    # Gradio may run other handlers (Reset, or a click after switching this side to
    # "Human") while the search runs. Never push a move that is illegal in the
    # position as it stands now.
    if move not in board.legal_moves:
        return render_board(session, "Position changed during search; AI move discarded.")

    san = board.san(move)
    board.push(move)
    session.selected_sq = None
    return render_board(session, f"AI played {san}")


def reset_all(request: gr.Request):
    """Reset the calling tab to the initial position with a fresh search object."""
    sid = request.session_hash
    session = get_session(sid)
    session.board.reset()
    session.selected_sq = None
    session.mcts = AlphaMCTS(model, DEVICE)  # drop any state kept from the old game
    return render_board(session, "Game Reset")

# ==========================================
# 4. GRADIO UI
# ==========================================
# Note: Theme is now passed in launch() in Gradio 6.0
with gr.Blocks() as demo:
    gr.Markdown("# ♟️ AlphaZero Tactical Engine - Skill level 3")

    with gr.Row():
        with gr.Column(scale=2):
            # Throwaway session so the starting position is visible on page load.
            initial_img = render_board(GameSession())[0]
            board_img = gr.Image(
                value=initial_img,
                width=640,
                height=640,
                interactive=False,
                label="Chess Board",
            )
            status_txt = gr.Textbox(label="Status", value="Ready / White to move")

        with gr.Column(scale=1):
            w_mode = gr.Radio(["Human", "AI"], value="Human", label="White Player")
            b_mode = gr.Radio(["Human", "AI"], value="AI", label="Black Player")
            sims_val = gr.Slider(10, 1000, 400, label="Search Simulations")
            auto_go = gr.Checkbox(label="Enable Autoplay", value=True)
            reset_btn = gr.Button("Reset Game")

    def main_loop(is_auto, wm, bm, s, request: gr.Request):
        """Timer callback: let the engine move when autoplay is on and the AI is to play.

        Returns the output of ai_turn when it moves, otherwise two no-op updates so
        the board and status stay as they are.
        """
        sid = request.session_hash
        board = get_session(sid).board
        side_ctrl = wm if board.turn == chess.WHITE else bm
        engine_should_move = is_auto and not board.is_game_over() and side_ctrl == "AI"
        if engine_should_move:
            return ai_turn(sid, s)
        return gr.update(), gr.update()

    # Only UI components are listed as inputs; Gradio injects SelectData/Request itself.
    board_img.select(handle_click, inputs=[w_mode, b_mode], outputs=[board_img, status_txt])
    reset_btn.click(reset_all, outputs=[board_img, status_txt])

    timer = gr.Timer(1.0)
    timer.tick(
        main_loop,
        inputs=[auto_go, w_mode, b_mode, sims_val],
        outputs=[board_img, status_txt],
    )

if __name__ == "__main__":
    # Updated launch args for Gradio 6.0 and HF Spaces
    demo.launch(theme=gr.themes.Soft())