# RL-Chess / app.py
# Author: Gregniuki -- commit 0c38210 ("Update app.py", verified)
import gradio as gr
import chess
import torch
import numpy as np
import os
import io
import time
from collections import OrderedDict
from PIL import Image, ImageDraw, ImageFont
from RL_Chess_Alpha import AlphaChessNet, AlphaMCTS
# Run the network on the GPU when one is visible to torch, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# Checkpoint (a state_dict) that is loaded at startup only if the file exists.
WEIGHTS_PATH = "alpha_merged_weights.pth"
# ==========================================
# 1. GLOBAL RESOURCES (Loaded once for all)
# ==========================================
# Build the network once; every GameSession's AlphaMCTS is handed this same
# instance, so all browser tabs share one copy of the weights.
model = AlphaChessNet().to(DEVICE)
if os.path.exists(WEIGHTS_PATH):
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs (and is torch's default from 2.6 on).
    model.load_state_dict(
        torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
    )
else:
    # Make the fallback visible: without a checkpoint the engine plays with
    # whatever weights AlphaChessNet() initialises (presumably untrained).
    print(f"Weights file not found: {WEIGHTS_PATH} - using untrained model")
# Inference only: switch to eval mode whether or not a checkpoint was loaded.
model.eval()
def load_piece_images(size=80, directory="pieces"):
    """Load the twelve piece sprites used to draw the board.

    Args:
        size: Edge length in pixels each sprite is resized to. It should match
            the board's square size (80 px in ``render_board``).
        directory: Folder holding files named like ``w_pawn_png_128px.png``.

    Returns:
        A dict mapping python-chess piece symbols ('P', 'n', ...) to RGBA PIL
        images. A symbol whose file is missing or cannot be decoded maps to
        None, so callers must tolerate None entries.
    """
    mapping = {
        'P': 'w_pawn_png_128px', 'R': 'w_rook_png_128px', 'N': 'w_knight_png_128px',
        'B': 'w_bishop_png_128px', 'Q': 'w_queen_png_128px', 'K': 'w_king_png_128px',
        'p': 'b_pawn_png_128px', 'r': 'b_rook_png_128px', 'n': 'b_knight_png_128px',
        'b': 'b_bishop_png_128px', 'q': 'b_queen_png_128px', 'k': 'b_king_png_128px',
    }
    images = {}
    for code, name in mapping.items():
        path = os.path.join(directory, f"{name}.png")
        if not os.path.exists(path):
            print(f"File not found: {path}")
            images[code] = None
            continue
        try:
            # The context manager closes the file handle as soon as the
            # sprite has been decoded and converted.
            with Image.open(path) as src:
                images[code] = src.convert("RGBA").resize((size, size), Image.LANCZOS)
        except Exception as e:  # best effort: one broken sprite must not stop the app
            print(f"Error loading {path}: {e}")
            images[code] = None
    return images


piece_images = load_piece_images()
# ==========================================
# 2. SESSION MANAGEMENT (One per browser tab)
# ==========================================
class GameSession:
    """State of one game, kept per browser tab.

    Attributes (all plain instance attributes):
        board: the python-chess position being played, starting from the
            standard initial setup.
        selected_sq: square chosen by the user's first click of a two-click
            move, or None when nothing is selected.
        mcts: AlphaMCTS searcher built on the shared global ``model``.
    """

    def __init__(self):
        start_position = chess.Board()
        self.board = start_position
        self.selected_sq = None
        # One searcher per session (presumably so search state is not shared
        # between tabs -- confirm against AlphaMCTS).
        searcher = AlphaMCTS(model, DEVICE)
        self.mcts = searcher
# Sessions live in an OrderedDict used as an LRU cache keyed by Gradio's
# per-tab session hash: the least recently used game is evicted first, which
# bounds the memory held by abandoned tabs.
MAX_SESSIONS = 100
sessions = OrderedDict()


def get_session(sid, max_sessions=MAX_SESSIONS):
    """Return the GameSession for browser session *sid*, creating it on demand.

    Every lookup marks the session as most recently used, so an active tab is
    not evicted just because it was opened early. Before a new session is
    created, least recently used sessions are dropped until fewer than
    *max_sessions* remain, so at most *max_sessions* are ever stored.
    """
    if sid in sessions:
        sessions.move_to_end(sid)
        return sessions[sid]
    # `sessions and` also guards popitem() against an empty dict when
    # max_sessions <= 0.
    while sessions and len(sessions) >= max_sessions:
        sessions.popitem(last=False)
    session = GameSession()
    sessions[sid] = session
    return session
# ==========================================
# 3. GAME LOGIC
# ==========================================
def render_board(session, status=""):
    """Draw the position of *session* as a 640x640 PIL image, White at the bottom.

    The selected square (``session.selected_sq``) is tinted yellow, and the
    destination squares of its legal moves are tinted green. Pieces are pasted
    from the global ``piece_images`` sprites, using their alpha channel as the
    mask. Symbols without a sprite are skipped.

    Returns:
        ``(image, status)``. *status* is passed through unchanged, so the
        result can feed Gradio's ``[board_img, status_txt]`` outputs directly.
    """
    sq = 80
    size = 8 * sq
    img = Image.new('RGB', (size, size), '#769656')
    draw = ImageDraw.Draw(img)

    board = session.board
    selected = session.selected_sq
    # Collect highlight targets once instead of regenerating every legal move
    # for each of the 64 squares.
    targets = set()
    if selected is not None:
        targets = {m.to_square for m in board.legal_moves if m.from_square == selected}

    for r in range(8):          # r = image row, 0 is the top (rank 8)
        for f in range(8):      # f = file, 0 is the a-file
            square = chess.square(f, 7 - r)
            if square == selected:
                color = '#f5f682'
            elif square in targets:
                color = '#acdf91'
            elif (r + f) % 2 == 0:
                color = '#eeeed2'
            else:
                color = '#769656'
            # PIL rectangles include both corner points, so subtract 1 to
            # keep each square exactly sq x sq without overlapping neighbours.
            draw.rectangle([f * sq, r * sq, (f + 1) * sq - 1, (r + 1) * sq - 1], fill=color)

    for s, p in board.piece_map().items():
        f, r = chess.square_file(s), 7 - chess.square_rank(s)
        img_p = piece_images.get(p.symbol())
        if img_p:
            img.paste(img_p, (f * sq, r * sq), img_p)
    return img, status
def handle_click(evt: gr.SelectData, wm, bm, request: gr.Request):
    """Handle a click on the board image for the calling browser tab's game.

    Moves are entered with two clicks:
    - The first click selects one of the side-to-move's pieces.
    - The second click does one of three things:
      - plays the move to the clicked square if it is legal (a pawn reaching
        the last rank always promotes to a queen);
      - switches the selection if another own piece was clicked;
      - otherwise clears the selection.

    Clicks are refused while the side to move is AI-controlled.

    Args:
        evt: Gradio select event. ``evt.index`` is the (x, y) pixel clicked.
        wm: Controller of White, "Human" or "AI".
        bm: Controller of Black, "Human" or "AI".
        request: Gradio request. Only its per-tab ``session_hash`` is used.

    Returns:
        ``(image, status)`` as produced by ``render_board``.
    """
    session = get_session(request.session_hash)
    board = session.board

    # Map pixels to a square: 80 px per square, rank 8 at the top. Clamp so
    # that an edge pixel cannot yield file/rank 8, which chess.square() would
    # silently turn into a different square.
    x, y = evt.index
    file = min(max(int(x) // 80, 0), 7)
    rank = 7 - min(max(int(y) // 80, 0), 7)
    clicked_sq = chess.square(file, rank)

    ctrl = wm if board.turn == chess.WHITE else bm
    if ctrl != "Human":
        return render_board(session, "It is the AI's turn.")

    selected = session.selected_sq
    if selected == clicked_sq:
        session.selected_sq = None
    elif selected is None:
        piece = board.piece_at(clicked_sq)
        if piece and piece.color == board.turn:
            session.selected_sq = clicked_sq
    else:
        # Try the plain move first. If only the promotion form is legal
        # (pawn reaching the last rank), promote to a queen.
        candidates = (
            chess.Move(selected, clicked_sq),
            chess.Move(selected, clicked_sq, promotion=chess.QUEEN),
        )
        move = next((m for m in candidates if board.is_legal(m)), None)
        if move is not None:
            board.push(move)
            session.selected_sq = None
        else:
            # Clicking another own piece re-selects; anything else cancels.
            piece = board.piece_at(clicked_sq)
            if piece and piece.color == board.turn:
                session.selected_sq = clicked_sq
            else:
                session.selected_sq = None

    if board.is_game_over():
        # Report the final result rather than "<side> to move" when the game
        # has ended (e.g. the human just delivered mate).
        return render_board(session, f"GameOver: {board.result()}")
    return render_board(session, f"{'White' if board.turn == chess.WHITE else 'Black'} to move")
def ai_turn(sid, sims):
    """Let the engine play one move in the game of browser session *sid*.

    The engine runs ``int(sims)`` MCTS simulations (``training=False``) from the
    current position and plays the root child with the highest visit count
    ``N``.

    Returns:
        ``(image, status)`` from ``render_board``. The status names the move
        played and appends the result when that move ends the game.
    """
    session = get_session(sid)
    board = session.board
    if board.is_game_over():
        return render_board(session, f"GameOver: {board.result()}")
    # Defensive: is_game_over() should already cover "no legal moves", but
    # never hand the search an empty position.
    if not board.legal_moves:
        return render_board(session, "No legal moves.")

    root = session.mcts.search(board, simulations=int(sims), training=False)
    if not root.children:
        return render_board(session, "AI could not find a move.")

    # root.children maps moves to child nodes. Pick the most visited move.
    move = max(root.children, key=lambda m: root.children[m].N)
    san = board.san(move)  # SAN must be computed before the move is pushed
    board.push(move)
    session.selected_sq = None

    status = f"AI played {san}"
    if board.is_game_over():
        status += f" - GameOver: {board.result()}"
    return render_board(session, status)
def reset_all(request: gr.Request):
    """Start a fresh game for the calling browser tab.

    The position goes back to the initial setup, any selection is cleared, and
    the session gets a brand-new AlphaMCTS bound to the shared model. Returns
    ``(image, "Game Reset")`` from ``render_board``.
    """
    game = get_session(request.session_hash)
    game.board.reset()
    game.selected_sq = None
    # Replace the searcher too (presumably discarding search state from the
    # previous game -- confirm against AlphaMCTS).
    game.mcts = AlphaMCTS(model, DEVICE)
    return render_board(game, "Game Reset")
# ==========================================
# 4. GRADIO UI
# ==========================================
# Note: Theme is now passed in launch() in Gradio 6.0
with gr.Blocks() as demo:
    gr.Markdown("# ♟️ AlphaZero Tactical Engine - Skill level 3")
    with gr.Row():
        with gr.Column(scale=2):
            # Render a throwaway starting position so a board is visible as
            # soon as the page loads, before any per-tab session exists.
            dummy_session = GameSession()
            initial_img, _ = render_board(dummy_session)
            board_img = gr.Image(value=initial_img, width=640, height=640, interactive=False, label="Chess Board")
            status_txt = gr.Textbox(label="Status", value="Ready / White to move")
        with gr.Column(scale=1):
            w_mode = gr.Radio(["Human", "AI"], value="Human", label="White Player")
            b_mode = gr.Radio(["Human", "AI"], value="AI", label="Black Player")
            sims_val = gr.Slider(10, 1000, 400, label="Search Simulations")
            auto_go = gr.Checkbox(label="Enable Autoplay", value=True)
            reset_btn = gr.Button("Reset Game")

    def main_loop(is_auto, wm, bm, s, request: gr.Request):
        """Timer callback for the calling tab.

        Shows the result once the game is over. Otherwise, when autoplay is on
        and the side to move is AI-controlled, plays one engine move. In every
        other case it returns no-op ``gr.update()`` pairs.
        """
        sid = request.session_hash
        session = get_session(sid)
        if session.board.is_game_over():
            # Keep the final result visible even if the move that ended the
            # game was reported with an ordinary status message.
            return gr.update(), f"GameOver: {session.board.result()}"
        if not is_auto:
            return gr.update(), gr.update()
        current_ctrl = wm if session.board.turn == chess.WHITE else bm
        if current_ctrl == "AI":
            return ai_turn(sid, s)
        return gr.update(), gr.update()

    # Pass only the UI components to inputs; Gradio injects the gr.Request
    # argument automatically based on the type annotation.
    board_img.select(
        handle_click,
        inputs=[w_mode, b_mode],
        outputs=[board_img, status_txt]
    )
    reset_btn.click(
        reset_all,
        outputs=[board_img, status_txt]
    )
    # Poll once per second so AI moves are played without user interaction.
    timer = gr.Timer(1.0)
    timer.tick(
        main_loop,
        inputs=[auto_go, w_mode, b_mode, sims_val],
        outputs=[board_img, status_txt]
    )
if __name__ == "__main__":
    # Gradio 6 takes app-level options such as the theme in launch(),
    # not in gr.Blocks().
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)