# gomoku-ai-code / gomoku_mcts.py
# Uploaded by lcccluck: "Upload Gomoku training and MCTS code" (commit 63cdefe, verified).
#!/usr/bin/env python3
"""Minimal Gomoku MCTS example.
This file is intentionally separate from gomoku_pg.py.
It uses the simpler AlphaZero-style recipe:
1. self-play with MCTS
2. policy/value targets from search
3. supervised update on policy + value heads
"""
from __future__ import annotations
import argparse
import math
import random
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import torch
from torch import nn
def choose_device(name: str) -> torch.device:
    """Resolve a device name, picking the best available backend for ``"auto"``.

    For ``"auto"`` the preference order is CUDA, then Apple MPS, then CPU.
    Any other name is passed to ``torch.device`` unchanged.
    """
    if name == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Older torch builds have no torch.backends.mps attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(name)
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch (plus all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
class GomokuEnv:
    """Mutable Gomoku game state used by self-play, evaluation and human play.

    ``board`` is an ``int8`` array of shape ``(board_size, board_size)`` holding
    ``1`` for the first player (X), ``-1`` for the second player (O) and ``0``
    for empty cells. Actions are flat row-major cell indices.
    """

    def __init__(self, board_size: int, win_length: int):
        """Validate the geometry and start a fresh game.

        Raises:
            ValueError: if ``board_size <= 1`` or ``win_length`` is not in
                ``(1, board_size]``.
        """
        if board_size <= 1:
            raise ValueError("board_size must be > 1")
        if not 1 < win_length <= board_size:
            raise ValueError("win_length must satisfy 1 < win_length <= board_size")
        self.board_size = board_size
        self.win_length = win_length
        self.reset()

    def reset(self) -> np.ndarray:
        """Clear the board, give the move to player 1, and return the new board."""
        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0
        return self.board

    def valid_moves(self) -> np.ndarray:
        """Return the flat indices of all empty cells."""
        return np.flatnonzero((self.board == 0).reshape(-1))

    def step(self, action: int) -> tuple[bool, int]:
        """Play ``action`` for the current player and pass the turn.

        Returns:
            ``(done, winner)``: ``winner`` is the player who completed a line,
            or 0 when there is no winner (game ongoing or draw).

        Raises:
            ValueError: if the game is already over, or (from
                ``apply_action_to_board``) the move is illegal.
        """
        # Stepping a finished game would silently overwrite winner/done.
        if self.done:
            raise ValueError("game is already over; call reset() first")
        next_board, next_player, done, winner = apply_action_to_board(
            board=self.board,
            current_player=self.current_player,
            action=action,
            win_length=self.win_length,
        )
        self.board = next_board
        self.current_player = next_player
        self.done = done
        self.winner = winner
        return done, winner

    def render(self) -> str:
        """Return a text picture of the board with 1-based row/column labels."""
        symbols = {1: "X", -1: "O", 0: "."}
        # Every row starts with a 2-wide label plus one space, so the header needs
        # the same 3-column prefix for its numbers to sit above the cells.
        label_pad = " " * 3
        header = label_pad + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
        rows = [header]
        for row_idx in range(self.board_size):
            row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
            rows.append(f"{row_idx + 1:2d} {row}")
        return "\n".join(rows)
def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
    """Convert a flat row-major action index into ``(row, col)``."""
    flat_index = int(action)
    return flat_index // board_size, flat_index % board_size
def coords_to_action(row: int, col: int, board_size: int) -> int:
    """Convert ``(row, col)`` into a flat row-major action index."""
    return col + board_size * row
def count_one_side(
    board: np.ndarray,
    row: int,
    col: int,
    dr: int,
    dc: int,
    player: int,
) -> int:
    """Count consecutive ``player`` stones starting one step away from ``(row, col)``.

    Walks in direction ``(dr, dc)`` and stops at the board edge or at the first
    cell not owned by ``player``. The starting cell itself is not counted. Both
    axes are bounds-checked against ``board.shape[0]``, so a square board is assumed.
    """
    size = board.shape[0]
    run = 0
    r, c = row + dr, col + dc
    while 0 <= r < size and 0 <= c < size and board[r, c] == player:
        run += 1
        r, c = r + dr, c + dc
    return run
def is_winning_move(
    board: np.ndarray,
    row: int,
    col: int,
    player: int,
    win_length: int,
) -> bool:
    """Return True if ``(row, col)`` lies on a line of at least ``win_length`` stones.

    Checks the four orientations through the cell (vertical, horizontal and both
    diagonals). The cell itself always counts as one stone, so callers place the
    stone on ``board`` before asking.
    """
    size = board.shape[0]

    def run_length(dr: int, dc: int) -> int:
        # Contiguous `player` stones strictly beyond (row, col) in direction (dr, dc).
        length = 0
        r, c = row + dr, col + dc
        while 0 <= r < size and 0 <= c < size and board[r, c] == player:
            length += 1
            r += dr
            c += dc
        return length

    return any(
        1 + run_length(dr, dc) + run_length(-dr, -dc) >= win_length
        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1))
    )
def apply_action_to_board(
    board: np.ndarray,
    current_player: int,
    action: int,
    win_length: int,
) -> tuple[np.ndarray, int, bool, int]:
    """Return the position reached when ``current_player`` plays ``action``.

    ``board`` is not modified; the stone is placed on a copy.

    Returns:
        ``(next_board, next_player, done, winner)``. ``next_player`` is always
        ``-current_player``. ``done`` is True after a winning move or when the
        board is full. ``winner`` is ``current_player`` on a win and 0 otherwise
        (ongoing game or draw).

    Raises:
        ValueError: if ``action`` is outside the board or the cell is occupied.
    """
    board_size = board.shape[0]
    # Reject out-of-range actions explicitly. A negative index would otherwise
    # wrap around through NumPy negative indexing and silently play a wrong cell.
    if not 0 <= int(action) < board_size * board_size:
        raise ValueError(f"action {action} is outside a {board_size}x{board_size} board")
    row, col = action_to_coords(action, board_size)
    if board[row, col] != 0:
        raise ValueError(f"illegal move at ({row}, {col})")
    next_board = board.copy()
    next_board[row, col] = current_player
    if is_winning_move(next_board, row, col, current_player, win_length):
        return next_board, -current_player, True, current_player
    if not np.any(next_board == 0):
        return next_board, -current_player, True, 0
    return next_board, -current_player, False, 0
def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
    """Encode a board as a ``(3, H, W)`` float32 tensor from the mover's viewpoint.

    Plane 0 marks ``current_player``'s stones, plane 1 the opponent's stones,
    and plane 2 the empty (legal) cells.
    """
    planes = np.stack(
        [board == current_player, board == -current_player, board == 0],
        axis=0,
    ).astype(np.float32)
    return torch.from_numpy(planes)
class PolicyValueNet(nn.Module):
    """Small fully-convolutional policy/value network.

    Input is a ``(batch, 3, H, W)`` tensor as produced by ``encode_state``.
    ``forward`` returns ``(policy_logits, value)``, where ``policy_logits`` has
    shape ``(batch, H * W)`` and ``value`` has shape ``(batch,)`` in ``[-1, 1]``.
    The value head pools globally, so the network accepts any board size.
    """

    def __init__(self, channels: int = 64):
        super().__init__()
        # Three 3x3 conv + ReLU stages. Layer order (and therefore state_dict
        # keys trunk.0/2/4) must stay fixed for saved checkpoints to load.
        layers: list[nn.Module] = []
        in_channels = 3
        for _ in range(3):
            layers.append(nn.Conv2d(in_channels, channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            in_channels = channels
        self.trunk = nn.Sequential(*layers)
        self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        shared = self.trunk(x)
        logits = torch.flatten(self.policy_head(shared), start_dim=1)
        value = self.value_head(shared).squeeze(-1)
        return logits, value
def masked_logits(logits: torch.Tensor, board: np.ndarray) -> torch.Tensor:
    """Set the logits of occupied cells to ``-1e9`` so softmax gives them ~zero mass.

    ``logits`` must be flat, with one entry per board cell in row-major order.
    """
    occupied = torch.as_tensor(
        (board != 0).reshape(-1), device=logits.device, dtype=torch.bool
    )
    return logits.masked_fill(occupied, -1e9)
def evaluate_policy_value(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
) -> tuple[np.ndarray, float]:
    """Run the network on one position and return ``(move_probs, value)``.

    ``move_probs`` is a flat NumPy array over all cells, with occupied cells
    masked to (numerically) zero probability. ``value`` is the value-head output
    for ``current_player`` as a Python float. The forward pass runs under
    ``torch.no_grad``.
    """
    batch = encode_state(board, current_player).unsqueeze(0).to(device)
    with torch.no_grad():
        raw_logits, raw_value = policy(batch)
    legal_logits = masked_logits(raw_logits.squeeze(0), board)
    move_probs = torch.softmax(legal_logits, dim=0).cpu().numpy()
    return move_probs, float(raw_value.item())
@dataclass
class MCTSNode:
    """One position in the PUCT search tree.

    Edge statistics (``priors``, ``visit_counts``, ``value_sums``) are dicts
    keyed by flat action index, one entry per legal move once the node is
    expanded. Edge values are accumulated from the perspective of this node's
    ``current_player``. Children are created lazily by ``child_for_action``.
    """

    board: np.ndarray
    current_player: int
    win_length: int
    done: bool = False
    winner: int = 0
    priors: dict[int, float] = field(default_factory=dict)
    visit_counts: dict[int, int] = field(default_factory=dict)
    value_sums: dict[int, float] = field(default_factory=dict)
    children: dict[int, "MCTSNode"] = field(default_factory=dict)
    expanded: bool = False

    def expand(
        self,
        priors: np.ndarray,
        add_noise: bool = False,
        dirichlet_alpha: float = 0.3,
        noise_eps: float = 0.25,
    ) -> None:
        """Create one edge per empty cell, with priors renormalized over legal moves.

        Args:
            priors: flat move probabilities over every cell (illegal ones ignored).
            add_noise: if True, mix in Dirichlet(``dirichlet_alpha``) noise with
                weight ``noise_eps``, for exploration at the root.
        """
        legal_actions = np.flatnonzero((self.board == 0).reshape(-1))
        legal_priors = priors[legal_actions]
        total_prob = float(np.sum(legal_priors))
        if total_prob <= 0.0:
            # No usable prior mass on legal moves: fall back to uniform.
            legal_priors = np.full(
                len(legal_actions), 1.0 / max(len(legal_actions), 1), dtype=np.float32
            )
        else:
            legal_priors = legal_priors / total_prob
        if add_noise and len(legal_actions) > 0:
            noise = np.random.dirichlet([dirichlet_alpha] * len(legal_actions))
            legal_priors = (1.0 - noise_eps) * legal_priors + noise_eps * noise
        # Plain zip: both arrays always have len(legal_actions) entries, and the
        # `strict=` keyword would require Python 3.10+.
        self.priors = {
            int(action): float(prior)
            for action, prior in zip(legal_actions, legal_priors)
        }
        self.visit_counts = dict.fromkeys(self.priors, 0)
        self.value_sums = dict.fromkeys(self.priors, 0.0)
        self.expanded = True

    def q_value(self, action: int) -> float:
        """Mean backed-up value of ``action`` (0.0 while unvisited)."""
        visits = self.visit_counts[action]
        if visits == 0:
            return 0.0
        return self.value_sums[action] / visits

    def select_action(self, c_puct: float) -> int:
        """Return the edge maximizing ``Q + U`` (PUCT); ties go to the first edge.

        ``U = c_puct * prior * sqrt(total_visits + 1) / (1 + visits)``. The ``+ 1``
        inside the square root lets priors break ties before any visit.
        Returns -1 if the node has no edges.
        """
        total_visits = sum(self.visit_counts.values())
        sqrt_total = math.sqrt(total_visits + 1.0)
        best_action = -1
        best_score = -float("inf")
        for action, prior in self.priors.items():
            q = self.q_value(action)
            u = c_puct * prior * sqrt_total / (1.0 + self.visit_counts[action])
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action
        return best_action

    def child_for_action(self, action: int) -> "MCTSNode":
        """Return the (cached) child node reached by playing ``action``."""
        child = self.children.get(action)
        if child is not None:
            return child
        next_board, next_player, done, winner = apply_action_to_board(
            board=self.board,
            current_player=self.current_player,
            action=action,
            win_length=self.win_length,
        )
        child = MCTSNode(
            board=next_board,
            current_player=next_player,
            win_length=self.win_length,
            done=done,
            winner=winner,
        )
        self.children[action] = child
        return child
def terminal_value(winner: int, current_player: int) -> float:
    """Score a finished game for ``current_player``: +1 win, -1 loss, 0 draw."""
    if winner == 0:  # draw
        return 0.0
    won = winner == current_player
    return 1.0 if won else -1.0
def sample_from_visits(visits: np.ndarray, temperature: float) -> tuple[int, np.ndarray]:
    """Sample a flat action from visit counts sharpened by ``temperature``.

    The sampling distribution is ``counts ** (1 / temperature)``, normalized.
    ``temperature <= 1e-6`` means greedy: all mass goes to the argmax, with ties
    going to the lowest index. If every count is zero, the counts are treated as
    all ones, i.e. uniform over every cell (including any occupied ones).

    Returns:
        ``(action, probs)``: the sampled flat index and the float32 distribution
        it was drawn from, reshaped to ``visits.shape``.
    """
    flat = visits.reshape(-1).astype(np.float64)
    if np.all(flat == 0):
        flat = np.ones_like(flat)
    if temperature <= 1e-6:
        probs = np.zeros_like(flat, dtype=np.float64)
        probs[int(np.argmax(flat))] = 1.0
    else:
        # Scale counts into [0, 1] before exponentiating. counts ** (1/T) overflows
        # float64 to inf for small T (e.g. 120 ** 200), and inf / inf is NaN,
        # which np.random.choice rejects. Dividing by the max leaves the
        # normalized distribution unchanged.
        adjusted = np.power(flat / np.max(flat), 1.0 / temperature)
        probs = adjusted / np.sum(adjusted)
    action = int(np.random.choice(len(probs), p=probs))
    return action, probs.reshape(visits.shape).astype(np.float32)
def choose_mcts_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    num_simulations: int,
    c_puct: float,
    temperature: float,
    add_root_noise: bool,
    dirichlet_alpha: float,
    noise_eps: float,
) -> tuple[int, np.ndarray]:
    """Run PUCT MCTS from ``board`` and pick a move from the root visit counts.

    The root is expanded with the network priors (optionally mixed with
    Dirichlet noise). Each simulation descends with ``MCTSNode.select_action``
    until it reaches an unexpanded or terminal node. It evaluates that node
    (exact result if terminal, otherwise the value head and priors), then backs
    the value up the path, negating it at each ply because players alternate.

    Returns:
        ``(action, probs)``: the chosen flat action and the float32 distribution
        of shape ``board.shape`` it was sampled from (see ``sample_from_visits``).
    """
    root = MCTSNode(board=board.copy(), current_player=current_player, win_length=win_length)
    priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
    root.expand(
        priors,
        add_noise=add_root_noise,
        dirichlet_alpha=dirichlet_alpha,
        noise_eps=noise_eps,
    )
    for _ in range(num_simulations):
        node = root
        path: list[tuple[MCTSNode, int]] = []
        # Selection: follow PUCT through expanded, non-terminal nodes.
        while node.expanded and not node.done:
            action = node.select_action(c_puct)
            path.append((node, action))
            node = node.child_for_action(action)
        # Evaluation: `value` is from the viewpoint of node.current_player.
        if node.done:
            value = terminal_value(node.winner, node.current_player)
        else:
            priors, value = evaluate_policy_value(policy, node.board, node.current_player, device)
            node.expand(priors)
        # Backup: each parent moves for the opponent of its child, so flip first.
        for parent, action in reversed(path):
            value = -value
            parent.visit_counts[action] += 1
            parent.value_sums[action] += value
    # With num_simulations == 0 no edge is ever visited. sample_from_visits would
    # then go uniform over *all* cells and could return an occupied one, so use
    # the root priors, which only cover legal moves.
    if sum(root.visit_counts.values()) > 0:
        edge_weights = root.visit_counts
    else:
        edge_weights = root.priors
    visits = np.zeros(board.shape, dtype=np.float32)
    for action, weight in edge_weights.items():
        row, col = action_to_coords(action, board.shape[0])
        visits[row, col] = float(weight)
    action, visit_probs = sample_from_visits(visits, temperature=temperature)
    return action, visit_probs
def choose_ai_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[int, np.ndarray | None]:
    """Pick the AI's move.

    With ``agent == "policy"``, play the raw network's most probable legal move
    and return ``None`` for the search distribution. With any other agent, run
    MCTS with ``mcts_sims`` simulations, no root noise and near-zero (greedy)
    temperature, and return its move distribution alongside the move.
    """
    if agent != "policy":
        return choose_mcts_action(
            policy=policy,
            board=board,
            current_player=current_player,
            win_length=win_length,
            device=device,
            num_simulations=mcts_sims,
            c_puct=c_puct,
            temperature=1e-6,
            add_root_noise=False,
            dirichlet_alpha=0.3,
            noise_eps=0.25,
        )
    move_probs, _value = evaluate_policy_value(policy, board, current_player, device)
    return int(np.argmax(move_probs)), None
def self_play_game(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    mcts_sims: int,
    c_puct: float,
    temperature: float,
    temperature_drop_moves: int,
    dirichlet_alpha: float,
    noise_eps: float,
) -> tuple[list[tuple[torch.Tensor, np.ndarray, float]], int, int]:
    """Play one MCTS self-play game and turn it into training examples.

    The first ``temperature_drop_moves`` moves are sampled at ``temperature``;
    later moves use near-zero temperature (greedy). Root Dirichlet noise is
    always enabled.

    Returns:
        ``(examples, winner, num_moves)``. Each example is
        ``(encoded_state, search_policy, outcome)``: ``search_policy`` is the
        flattened search distribution at that position, and ``outcome`` is
        +1 / -1 / 0 from the viewpoint of the player who was to move.
    """
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()
    trajectory: list[tuple[torch.Tensor, np.ndarray, int]] = []
    moves_played = 0
    while not env.done:
        exploring = moves_played < temperature_drop_moves
        action, search_policy = choose_mcts_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            num_simulations=mcts_sims,
            c_puct=c_puct,
            temperature=temperature if exploring else 1e-6,
            add_root_noise=True,
            dirichlet_alpha=dirichlet_alpha,
            noise_eps=noise_eps,
        )
        mover = env.current_player
        trajectory.append((encode_state(env.board, mover), search_policy.reshape(-1), mover))
        env.step(action)
        moves_played += 1

    def outcome_for(player: int) -> float:
        # Final result seen from `player`'s side.
        if env.winner == 0:
            return 0.0
        return 1.0 if player == env.winner else -1.0

    examples = [(state, pi, outcome_for(mover)) for state, pi, mover in trajectory]
    return examples, env.winner, moves_played
def train_batch(
    policy: PolicyValueNet,
    optimizer: torch.optim.Optimizer,
    batch: list[tuple[torch.Tensor, np.ndarray, float]],
    device: torch.device,
    value_coef: float,
) -> tuple[float, float, float]:
    """Take one optimizer step on a batch of ``(state, search_policy, outcome)`` examples.

    The loss is the cross-entropy between the search-policy targets and the
    network's policy, plus ``value_coef`` times the mean squared error between
    the value head and the game outcomes. Gradients are clipped to a global
    norm of 1.0.

    Returns:
        ``(total_loss, policy_loss, value_loss)`` as Python floats, computed
        before the parameter update.
    """
    states, target_pis, target_zs = zip(*batch)
    state_batch = torch.stack(list(states)).to(device)
    pi_batch = torch.tensor(np.stack(target_pis), dtype=torch.float32, device=device)
    z_batch = torch.tensor(list(target_zs), dtype=torch.float32, device=device)

    logits, values = policy(state_batch)
    policy_loss = -(pi_batch * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
    value_loss = nn.functional.mse_loss(values, z_batch)
    total_loss = policy_loss + value_coef * value_loss

    optimizer.zero_grad(set_to_none=True)
    total_loss.backward()
    nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
    optimizer.step()
    return float(total_loss.item()), float(policy_loss.item()), float(value_loss.item())
def save_checkpoint(path: Path, policy: PolicyValueNet, args: argparse.Namespace) -> None:
    """Write the weights and the settings needed to rebuild the game/network to ``path``.

    Stored keys: ``state_dict``, ``channels``, ``board_size``, ``win_length``.
    """
    payload = dict(
        state_dict=policy.state_dict(),
        channels=args.channels,
        board_size=args.board_size,
        win_length=args.win_length,
    )
    torch.save(payload, path)
def last_checkpoint_path(base_path: Path) -> Path:
    """Return the sibling path used for periodic saves (``model.pt`` -> ``model_last.pt``)."""
    last_name = "{}_last{}".format(base_path.stem, base_path.suffix)
    return base_path.with_name(last_name)
def load_checkpoint(path: Path, map_location: torch.device) -> dict:
    """Load a checkpoint written by ``save_checkpoint``.

    Raises:
        RuntimeError: if the file does not hold a dict with a ``state_dict`` key.
    """
    # NOTE(security): torch.load unpickles the file (fully, unless the installed
    # torch defaults to weights_only=True). Only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_location)
    if not (isinstance(checkpoint, dict) and "state_dict" in checkpoint):
        raise RuntimeError(f"{path} is not a compatible gomoku_mcts checkpoint")
    return checkpoint
def resolve_game_config(
    checkpoint_path: Path,
    arg_board_size: int | None,
    arg_win_length: int | None,
    arg_channels: int,
    device: torch.device,
) -> tuple[PolicyValueNet, int, int]:
    """Build an eval-mode network from a checkpoint and decide the game geometry.

    For each setting, the first truthy value wins: the value stored in the
    checkpoint, then the CLI argument, then the built-in default (15x15 board,
    five in a row). Channels come from the checkpoint when present, since they
    must match the saved weights.

    Returns:
        ``(policy, board_size, win_length)``.
    """
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)

    def first_truthy(*candidates):
        # Same semantics as `a or b or c`: first truthy value, else the last one.
        chosen = candidates[-1]
        for candidate in candidates:
            if candidate:
                chosen = candidate
                break
        return int(chosen)

    board_size = first_truthy(checkpoint.get("board_size"), arg_board_size, 15)
    win_length = first_truthy(checkpoint.get("win_length"), arg_win_length, 5)
    channels = first_truthy(checkpoint.get("channels"), arg_channels)

    policy = PolicyValueNet(channels=channels).to(device)
    policy.load_state_dict(checkpoint["state_dict"])
    policy.eval()
    return policy, board_size, win_length
def play_vs_random_once(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    policy_player: int,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> int:
    """Play one game of the model (as ``policy_player``) against uniform-random moves.

    Returns the winner: 1, -1, or 0 for a draw.
    """
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()
    while not env.done:
        if env.current_player != policy_player:
            # Random opponent: any empty cell, uniformly.
            env.step(int(np.random.choice(env.valid_moves())))
            continue
        ai_action, _ = choose_ai_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        env.step(ai_action)
    return env.winner
def evaluate_vs_random(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    games: int,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[float, int, int, int]:
    """Play ``games`` games against a random opponent and tally the results.

    The model plays first (X) in the first ``games // 2`` games and second
    (O) in the rest.

    Returns:
        ``(win_rate, wins, draws, losses)``; ``win_rate`` is 0.0 when
        ``games`` is 0.
    """
    tally = {"win": 0, "draw": 0, "loss": 0}
    first_half = games // 2
    for game_idx in range(games):
        policy_player = 1 if game_idx < first_half else -1
        winner = play_vs_random_once(
            policy=policy,
            board_size=board_size,
            win_length=win_length,
            device=device,
            policy_player=policy_player,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        if winner == 0:
            tally["draw"] += 1
        elif winner == policy_player:
            tally["win"] += 1
        else:
            tally["loss"] += 1
    wins = tally["win"]
    return wins / max(games, 1), wins, tally["draw"], tally["loss"]
def train(args: argparse.Namespace) -> None:
    """CLI ``train``: AlphaZero-style self-play training loop.

    Each iteration:
      1. plays ``args.games_per_iter`` MCTS self-play games and appends their
         examples to a replay buffer bounded at ``args.buffer_size``;
      2. once the buffer holds at least ``args.batch_size`` examples, runs
         ``args.train_steps`` gradient steps on random mini-batches;
      3. prints losses and game statistics, optionally evaluating against a
         random player (every ``args.eval_every`` iterations) and saving a
         ``*_last`` checkpoint (every ``args.save_every`` iterations).
    The final weights are written to ``args.checkpoint``.
    """
    set_seed(args.seed)
    device = choose_device(args.device)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None:
        if args.init_checkpoint.exists():
            checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
            policy.load_state_dict(checkpoint["state_dict"])
        else:
            # Warn rather than silently ignoring the flag: a typo in the path
            # would otherwise quietly start training from random weights.
            print(f"warning: --init-checkpoint {args.init_checkpoint} not found; starting from scratch")
    # Create the output directory up front so the first save cannot fail after
    # a long stretch of self-play.
    args.checkpoint.parent.mkdir(parents=True, exist_ok=True)
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    replay_buffer: deque[tuple[torch.Tensor, np.ndarray, float]] = deque(maxlen=args.buffer_size)
    print(f"device={device} board={args.board_size} win={args.win_length}")
    for iteration in range(1, args.iterations + 1):
        policy.eval()
        winners: list[int] = []
        lengths: list[int] = []
        for _ in range(args.games_per_iter):
            examples, winner, moves = self_play_game(
                policy=policy,
                board_size=args.board_size,
                win_length=args.win_length,
                device=device,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
                temperature=args.temperature,
                temperature_drop_moves=args.temperature_drop_moves,
                dirichlet_alpha=args.dirichlet_alpha,
                noise_eps=args.noise_eps,
            )
            replay_buffer.extend(examples)
            winners.append(winner)
            lengths.append(moves)
        losses: list[tuple[float, float, float]] = []
        # Train only once the buffer can fill a whole batch.
        if len(replay_buffer) >= args.batch_size:
            policy.train()
            for _ in range(args.train_steps):
                batch = random.sample(replay_buffer, args.batch_size)
                losses.append(
                    train_batch(
                        policy=policy,
                        optimizer=optimizer,
                        batch=batch,
                        device=device,
                        value_coef=args.value_coef,
                    )
                )
        avg_loss = float(np.mean([x[0] for x in losses])) if losses else 0.0
        avg_policy_loss = float(np.mean([x[1] for x in losses])) if losses else 0.0
        avg_value_loss = float(np.mean([x[2] for x in losses])) if losses else 0.0
        p1_wins = sum(1 for x in winners if x == 1)
        p2_wins = sum(1 for x in winners if x == -1)
        draws = sum(1 for x in winners if x == 0)
        avg_len = float(np.mean(lengths)) if lengths else 0.0
        message = (
            f"iter={iteration:5d} loss={avg_loss:7.4f} policy={avg_policy_loss:7.4f} "
            f"value={avg_value_loss:7.4f} p1={p1_wins:3d} p2={p2_wins:3d} draw={draws:3d} "
            f"avg_len={avg_len:6.2f} buffer={len(replay_buffer):6d}"
        )
        if args.eval_every > 0 and iteration % args.eval_every == 0:
            policy.eval()
            win_rate, wins, eval_draws, eval_losses = evaluate_vs_random(
                policy=policy,
                board_size=args.board_size,
                win_length=args.win_length,
                device=device,
                games=args.eval_games,
                agent="mcts",
                mcts_sims=args.eval_mcts_sims,
                c_puct=args.c_puct,
            )
            message += f" random_win_rate={win_rate:.3f} ({wins}/{eval_draws}/{eval_losses})"
        print(message)
        if args.save_every > 0 and iteration % args.save_every == 0:
            checkpoint_path = last_checkpoint_path(args.checkpoint)
            save_checkpoint(checkpoint_path, policy, args)
            print(f"saved checkpoint to {checkpoint_path}")
    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")
def evaluate(args: argparse.Namespace) -> None:
    """CLI ``eval``: load a checkpoint and report results against a random player."""
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    rate, n_wins, n_draws, n_losses = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    report = (
        f"device={device}",
        f"agent={args.agent} mcts_sims={args.mcts_sims}",
        f"win_rate={rate:.3f} wins={n_wins} draws={n_draws} losses={n_losses}",
    )
    for line in report:
        print(line)
def ask_human_move(env: GomokuEnv) -> int:
    """Prompt on stdin until the user enters a legal move, and return its action index.

    Input is 1-based ``row col`` (a comma may separate them). The prompt repeats
    on the wrong number of tokens, non-integers, out-of-range coordinates, or an
    occupied cell.
    """
    while True:
        tokens = input("your move (row col): ").strip().replace(",", " ").split()
        if len(tokens) != 2:
            print("please enter: row col")
            continue
        try:
            row = int(tokens[0]) - 1
            col = int(tokens[1]) - 1
        except ValueError:
            print("row and col must be integers")
            continue
        on_board = 0 <= row < env.board_size and 0 <= col < env.board_size
        if not on_board:
            print("move out of range")
        elif env.board[row, col] != 0:
            print("that position is occupied")
        else:
            return coords_to_action(row, col, env.board_size)
def play(args: argparse.Namespace) -> None:
    """CLI ``play``: text-mode game between a human (via stdin) and the model."""
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    human_symbol, ai_symbol = ("X", "O") if human_player == 1 else ("O", "X")
    print(f"device={device}")
    print(f"human={human_symbol} ai={ai_symbol} agent={args.agent} mcts_sims={args.mcts_sims}")
    while not env.done:
        print()
        print(env.render())
        print()
        if env.current_player != human_player:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            ai_row, ai_col = action_to_coords(action, env.board_size)
            print(f"ai move: {ai_row + 1} {ai_col + 1}")
        else:
            action = ask_human_move(env)
        env.step(action)
    print()
    print(env.render())
    # human_player is never 0, so the two keys cannot collide.
    print({0: "draw", human_player: "you win"}.get(env.winner, "ai wins"))
def gui(args: argparse.Namespace) -> None:
    """CLI ``gui``: play against the model in a pygame window.

    Left click places a stone, R restarts, and Esc or closing the window quits.
    The AI replies synchronously right after each human move, and opens the
    game when the human plays second.

    Raises:
        SystemExit: if pygame is not installed.
    """
    try:
        import pygame
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "pygame is not installed. Install it with: "
            "python -m pip install pygame"
        ) from exc
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    last_search_visits: np.ndarray | None = None

    pygame.init()
    # Always shut pygame down, even if search or drawing raises.
    try:
        pygame.display.set_caption("Gomoku MCTS")
        font = pygame.font.SysFont("Arial", 24)
        small_font = pygame.font.SysFont("Arial", 18)
        cell_size = args.cell_size
        padding = 40
        status_height = 80
        board_pixels = board_size * cell_size
        screen = pygame.display.set_mode(
            (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
        )
        clock = pygame.time.Clock()
        background = (236, 196, 122)
        line_color = (80, 55, 20)
        black_stone = (20, 20, 20)
        white_stone = (245, 245, 245)
        accent = (180, 40, 40)

        def board_to_screen(row: int, col: int) -> tuple[int, int]:
            # Pixel centre of cell (row, col).
            x = padding + col * cell_size + cell_size // 2
            y = padding + row * cell_size + cell_size // 2
            return x, y

        def mouse_to_action(pos: tuple[int, int]) -> int | None:
            # Map a click to an empty cell's action, or None if off-board/occupied.
            x, y = pos
            if x < padding or y < padding:
                return None
            col = (x - padding) // cell_size
            row = (y - padding) // cell_size
            if not (0 <= row < env.board_size and 0 <= col < env.board_size):
                return None
            if env.board[row, col] != 0:
                return None
            return coords_to_action(row, col, env.board_size)

        def ai_step() -> None:
            # Let the AI move if the game is live and it is the AI's turn.
            nonlocal last_search_visits
            if env.done or env.current_player == human_player:
                return
            action, visits = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            last_search_visits = visits
            env.step(action)

        def restart() -> None:
            nonlocal last_search_visits
            env.reset()
            last_search_visits = None
            if env.current_player != human_player:
                ai_step()

        def status_text() -> str:
            if env.done:
                if env.winner == 0:
                    return "Draw. Press R to restart."
                if env.winner == human_player:
                    return "You win. Press R to restart."
                return "AI wins. Press R to restart."
            if env.current_player == human_player:
                return "Your turn. Left click to place."
            return "AI is thinking..."

        if env.current_player != human_player:
            ai_step()
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                    elif event.key == pygame.K_r:
                        restart()
                elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                    if env.done or env.current_player != human_player:
                        continue
                    action = mouse_to_action(event.pos)
                    if action is None:
                        continue
                    env.step(action)
                    ai_step()
            screen.fill(background)
            for idx in range(board_size + 1):
                x = padding + idx * cell_size
                y = padding + idx * cell_size
                pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2)
                pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2)
            for row in range(env.board_size):
                for col in range(env.board_size):
                    stone = int(env.board[row, col])
                    if stone == 0:
                        continue
                    x, y = board_to_screen(row, col)
                    color = black_stone if stone == 1 else white_stone
                    pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
                    pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)
            for idx in range(board_size):
                label = small_font.render(str(idx + 1), True, line_color)
                screen.blit(label, (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8))
                screen.blit(label, (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2))
            info = (
                f"{board_size}x{board_size} connect={win_length} device={device} "
                f"agent={args.agent} sims={args.mcts_sims}"
            )
            screen.blit(small_font.render(info, True, line_color), (padding, padding + board_pixels + 16))
            screen.blit(font.render(status_text(), True, accent), (padding, padding + board_pixels + 42))
            if last_search_visits is not None and args.agent == "mcts":
                # NOTE(review): choose_ai_action searches at temperature 1e-6, so the
                # returned distribution is one-hot and this always shows 1 rather
                # than a visit count. Showing real counts would need
                # choose_mcts_action to expose them.
                peak = int(np.max(last_search_visits))
                screen.blit(
                    small_font.render(f"peak_visits={peak}", True, line_color),
                    (padding + 420, padding + board_pixels + 16),
                )
            pygame.display.flip()
            clock.tick(args.fps)
    finally:
        pygame.quit()
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI with ``train``, ``eval``, ``play`` and ``gui`` subcommands.

    Each subcommand stores its handler in ``args.func``. ``play`` and ``gui``
    leave ``--board-size``/``--win-length`` unset (None) instead of 15/5.
    """
    parser = argparse.ArgumentParser(description="Minimal Gomoku MCTS example")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None:
        board_default, win_default = (None, None) if defaults_from_checkpoint else (15, 5)
        subparser.add_argument("--board-size", type=int, default=board_default)
        subparser.add_argument("--win-length", type=int, default=win_default)
        subparser.add_argument("--channels", type=int, default=64)
        subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
        subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_mcts.pt"))

    def add_inference_arguments(subparser: argparse.ArgumentParser) -> None:
        subparser.add_argument("--agent", choices=["policy", "mcts"], default="mcts")
        subparser.add_argument("--mcts-sims", type=int, default=120)
        subparser.add_argument("--c-puct", type=float, default=1.5)

    train_parser = subparsers.add_parser("train", help="MCTS self-play training")
    add_common_arguments(train_parser)
    # (flag, type, default), registered in this order.
    train_options = (
        ("--iterations", int, 200),
        ("--games-per-iter", int, 8),
        ("--train-steps", int, 32),
        ("--batch-size", int, 64),
        ("--buffer-size", int, 20000),
        ("--lr", float, 1e-3),
        ("--weight-decay", float, 1e-4),
        ("--value-coef", float, 1.0),
        ("--mcts-sims", int, 64),
        ("--eval-mcts-sims", int, 120),
        ("--c-puct", float, 1.5),
        ("--temperature", float, 1.0),
        ("--temperature-drop-moves", int, 8),
        ("--dirichlet-alpha", float, 0.3),
        ("--noise-eps", float, 0.25),
        ("--eval-every", int, 10),
        ("--eval-games", int, 20),
        ("--save-every", int, 10),
        ("--seed", int, 42),
        ("--init-checkpoint", Path, None),
    )
    for flag, kind, default in train_options:
        train_parser.add_argument(flag, type=kind, default=default)
    train_parser.set_defaults(func=train)

    eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
    add_common_arguments(eval_parser)
    add_inference_arguments(eval_parser)
    eval_parser.add_argument("--games", type=int, default=40)
    eval_parser.set_defaults(func=evaluate)

    play_parser = subparsers.add_parser("play", help="play against the model")
    add_common_arguments(play_parser, defaults_from_checkpoint=True)
    add_inference_arguments(play_parser)
    play_parser.add_argument("--human-first", action="store_true")
    play_parser.set_defaults(func=play)

    gui_parser = subparsers.add_parser("gui", help="pygame GUI")
    add_common_arguments(gui_parser, defaults_from_checkpoint=True)
    add_inference_arguments(gui_parser)
    gui_parser.add_argument("--human-first", action="store_true")
    gui_parser.add_argument("--cell-size", type=int, default=48)
    gui_parser.add_argument("--fps", type=int, default=30)
    gui_parser.set_defaults(func=gui)
    return parser
def main() -> None:
    """Parse the command line and dispatch to the selected subcommand's handler."""
    args = build_parser().parse_args()
    args.func(args)


if __name__ == "__main__":
    main()