"""
generate.py — Play chess against the Liquid Chess Model (LCM).

Moves are entered in UCI format (e.g. e2e4, g1f3, e7e8q).
The model responds instantly with its chosen move.

Usage:
    python generate.py
    python generate.py --checkpoint model.safetensors --side black
    python generate.py --temperature 0.5

Commands during play:
    moves   — list all legal moves
    undo    — take back the last two moves
    resign  — resign the game
    quit    — exit

Requirements:
    pip install chess torch safetensors
"""
|
| |
|
import argparse
import json
import random
import sys

import chess
import torch
import torch.nn.functional as F

sys.path.insert(0, ".")
from config import ChessModelConfig
from model import ChessModel
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def load_vocab(vocab_path: str) -> tuple[dict[str, int], dict[int, str]]:
    """Load the token vocabulary from a JSON file.

    Args:
        vocab_path: Path to a JSON object mapping token strings to ids.

    Returns:
        A ``(token_to_id, id_to_token)`` pair; the second dict is the
        inverse of the first.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's
    # default locale encoding (which differs on e.g. Windows).
    with open(vocab_path, encoding="utf-8") as f:
        token_to_id = json.load(f)
    id_to_token = {idx: token for token, idx in token_to_id.items()}
    return token_to_id, id_to_token
|
| |
|
| |
|
def load_model(checkpoint_path: str, device: torch.device) -> tuple[ChessModel, ChessModelConfig]:
    """Build a ChessModel with the default config and load weights into it.

    Args:
        checkpoint_path: Weights file. Paths ending in ``.safetensors`` are
            read with the safetensors loader; anything else goes through
            ``torch.load`` and must contain a ``"model"`` entry holding the
            state dict.
        device: Device the model is moved to (and the map location used by
            ``torch.load``).

    Returns:
        ``(model, config)`` with the model on ``device`` and in eval mode.
    """
    cfg = ChessModelConfig()
    net = ChessModel(cfg)

    if not checkpoint_path.endswith(".safetensors"):
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        cleaned = {}
        for key, tensor in checkpoint["model"].items():
            # Strip "_orig_mod." from parameter names (presumably left behind
            # by a compiled-module wrapper at save time — confirm).
            cleaned[key.replace("_orig_mod.", "")] = tensor
        net.load_state_dict(cleaned)
    else:
        from safetensors.torch import load_model as load_safetensors
        load_safetensors(net, checkpoint_path)

    net.to(device).eval()
    return net, cfg
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _choose_token_id(
    logits: torch.Tensor,
    legal_ids: list[int],
    temperature: float,
    top_k: int,
) -> int:
    """Pick one vocabulary id among ``legal_ids`` from a 1-D logits vector."""
    neg_inf = float("-inf")
    # Everything that is not a legal move gets -inf so it can never be chosen.
    masked = torch.full_like(logits, neg_inf)
    masked[legal_ids] = logits[legal_ids]

    if temperature < 0.01:
        # Greedy decoding. Scaling by a positive temperature never changes the
        # argmax, so skip the division by a tiny number (overflow risk in low
        # precision) and skip top-k (it never removes the maximum).
        return int(torch.argmax(masked).item())

    masked = masked / temperature
    if top_k > 0:
        finite = masked[masked != neg_inf]
        # Never ask topk for more entries than actually exist.
        k = min(top_k, finite.numel())
        if k > 0:
            cutoff = torch.topk(finite, k).values[-1]
            masked[masked < cutoff] = neg_inf

    probs = F.softmax(masked, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


def get_model_move(
    model: ChessModel,
    config: ChessModelConfig,
    board: chess.Board,
    move_history: list[str],
    token_to_id: dict,
    id_to_token: dict,
    device: torch.device,
    temperature: float = 1.0,
    top_k: int = 0,
) -> str | None:
    """Return the model's chosen move in UCI format.

    Args:
        model: Called as ``model(input_ids)``; must return a pair whose first
            element is indexed ``[0, -1, :]`` to get the next-token logits
            (assumed layout: batch, position, vocab — confirm against model).
        config: Only ``config.max_seq_len`` is read, to truncate the context.
        board: Current position; supplies the side to move and legal moves.
        move_history: UCI moves played so far, oldest first.
        token_to_id: Vocabulary mapping token string -> id.
        id_to_token: Inverse vocabulary mapping id -> token string.
        device: Device on which the input tensor is created.
        temperature: Sampling temperature; values below 0.01 select the
            highest-scoring legal move deterministically.
        top_k: If positive, sample only among the ``top_k`` best legal moves.

    Returns:
        A legal move in UCI notation, or ``None`` when the side to move has
        no legal moves. If none of the legal moves is in the vocabulary, a
        uniformly random legal move is returned instead.
    """
    # Resolve the cheap exits first so no forward pass is wasted on them.
    legal_ucis = [m.uci() for m in board.legal_moves]
    if not legal_ucis:
        return None

    legal_ids = [token_to_id[u] for u in legal_ucis if u in token_to_id]
    if not legal_ids:
        return random.choice(legal_ucis)

    # Context = side-to-move token followed by the game so far.
    pov_token = "<W>" if board.turn == chess.WHITE else "<B>"
    token_ids = [token_to_id.get(pov_token, 1)]
    for uci in move_history:
        tid = token_to_id.get(uci)
        # Moves missing from the vocabulary are skipped, not replaced.
        if tid is not None:
            token_ids.append(tid)

    # Keep only the most recent tokens that fit in the context window.
    # NOTE(review): for long games this also drops the leading side-to-move
    # token — confirm this matches how the model was trained.
    token_ids = token_ids[-(config.max_seq_len - 1):]
    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        ntp_logits, _ = model(input_tensor)

    # Sample in float32 so softmax/multinomial behave the same regardless of
    # the precision the model runs in.
    logits = ntp_logits[0, -1, :].float()
    chosen_id = _choose_token_id(logits, legal_ids, temperature, top_k)
    return id_to_token.get(chosen_id, legal_ucis[0])
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Maps a piece letter (uppercase = White, lowercase = Black) to its Unicode
# chess glyph. The previous values were mis-encoded and all collapsed to the
# same character, making every piece look identical on the printed board.
PIECE_SYMBOLS = {
    'P': '♙', 'N': '♘', 'B': '♗', 'R': '♖', 'Q': '♕', 'K': '♔',
    'p': '♟', 'n': '♞', 'b': '♝', 'r': '♜', 'q': '♛', 'k': '♚',
}
|
| |
|
def print_board(board: chess.Board, player_is_white: bool) -> None:
    """Print the board to stdout from the player's point of view.

    Args:
        board: Position to draw.
        player_is_white: When true, rank 8 is printed first and files run
            a..h; otherwise rank 1 is printed first and files run h..a.
    """
    print()
    ranks = range(7, -1, -1) if player_is_white else range(8)
    files = range(8) if player_is_white else range(7, -1, -1)

    for rank in ranks:
        cells = []
        for file in files:
            piece = board.piece_at(chess.square(file, rank))
            # Empty squares are drawn as a middle dot.
            cells.append(PIECE_SYMBOLS.get(piece.symbol(), '?') if piece else '·')
        print(f" {rank + 1} " + " ".join(cells))

    # The rank label occupies three columns, so indent the file letters by
    # three to keep them under their squares; derive them from ``files`` so
    # the footer always matches the orientation used above.
    print("   " + " ".join("abcdefgh"[file] for file in files))
    print()
|
| |
|
| |
|
def print_status(board: chess.Board) -> None:
    """Print a one-off message describing the state of ``board``.

    Reports checkmate (with the winner), the draw conditions, or check.
    Prints nothing for an ordinary position.
    """
    if board.is_checkmate():
        # The side to move is the side that has been mated.
        winner = "Black" if board.turn == chess.WHITE else "White"
        banner = "=" * 38
        print(f"\n{banner}\n CHECKMATE — {winner} wins!\n{banner}\n")
    elif board.is_stalemate():
        print("\nStalemate — draw.")
    elif board.is_insufficient_material():
        print("\nInsufficient material — draw.")
    elif board.is_seventyfive_moves():
        # Checked before the 50-move rule: both hold at this point, and this
        # is the one that actually ends the game via is_game_over().
        print("\n75-move rule — draw.")
    elif board.is_fivefold_repetition():
        print("\nFivefold repetition — draw.")
    elif board.is_fifty_moves():
        print("\n50-move rule — draw.")
    elif board.is_check():
        print(" *** CHECK ***")
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def play(args):
    """Run an interactive game between the user and the model on stdin/stdout.

    Args:
        args: Parsed command-line namespace. Reads ``vocab``, ``checkpoint``,
            ``side``, ``temperature`` and ``top_k``.

    The loop alternates between prompting the user (UCI moves or one of the
    commands ``moves``, ``undo``, ``resign``, ``quit``) and asking the model
    for a move. It returns early on ``quit``, ``resign`` or end of input;
    otherwise it prints the final board, result and move list.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    token_to_id, id_to_token = load_vocab(args.vocab)
    model, config = load_model(args.checkpoint, device)

    player_is_white = args.side.lower() != "black"
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nLiquid Chess Model — {n_params / 1e6:.1f}M parameters")
    print(f"You are playing as {'White' if player_is_white else 'Black'}.")
    print("Enter moves in UCI format (e.g. e2e4, g1f3). Type 'moves', 'undo', 'resign', or 'quit'.\n")

    board = chess.Board()
    move_history: list[str] = []

    while not board.is_game_over():
        print_board(board, player_is_white)
        print_status(board)

        is_player_turn = (board.turn == chess.WHITE) == player_is_white

        if not is_player_turn:
            print("Model is thinking...")
            move_uci = get_model_move(
                model, config, board, move_history,
                token_to_id, id_to_token, device,
                temperature=args.temperature,
                top_k=args.top_k,
            )
            if move_uci is None:
                break
            board.push(chess.Move.from_uci(move_uci))
            move_history.append(move_uci)
            print(f"Model plays: {move_uci}")
            continue

        # Keep prompting until the user makes a legal move, undoes, or leaves.
        while True:
            try:
                raw = input("Your move: ").strip().lower()
            except EOFError:
                return

            if raw == "quit":
                print("Goodbye!")
                return
            if raw == "resign":
                print("You resigned.")
                return
            if raw == "moves":
                print("Legal moves:", ", ".join(sorted(m.uci() for m in board.legal_moves)))
                continue
            if raw == "undo":
                # An undo takes back one move from each side, so it needs at
                # least two moves on the board. Say so explicitly instead of
                # letting "undo" be parsed (and rejected) as a UCI move.
                if len(move_history) < 2:
                    print("Nothing to undo.")
                    continue
                board.pop()
                board.pop()
                del move_history[-2:]
                print("Undone.")
                break

            try:
                move = chess.Move.from_uci(raw)
            except ValueError:
                print(f"Invalid format: {raw} — use UCI (e.g. e2e4)")
                continue

            if move not in board.legal_moves:
                print(f"Illegal move: {raw}")
                continue

            board.push(move)
            move_history.append(move.uci())
            break

    print_board(board, player_is_white)
    print_status(board)
    print(f"Result: {board.result()}")
    print(f"Moves: {' '.join(move_history)}")
|
| |
|
| |
|
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the interactive chess script."""
    parser = argparse.ArgumentParser(description="Play chess against the Liquid Chess Model.")
    parser.add_argument("--checkpoint", default="model.safetensors",
                        help="Path to model checkpoint (.safetensors or .pt)")
    parser.add_argument("--vocab", default="vocab.json",
                        help="Path to vocab.json")
    parser.add_argument("--side", default="white", choices=["white", "black"],
                        help="Your color (default: white)")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature — lower is more deterministic (default: 1.0)")
    parser.add_argument("--top-k", type=int, default=0,
                        help="Top-k filtering — 0 disables (default: 0)")
    return parser


if __name__ == "__main__":
    play(build_parser().parse_args())