"""
generate.py — Play chess against the Liquid Chess Model (LCM).

Moves are entered in UCI format (e.g. e2e4, g1f3, e7e8q).
The model responds instantly with its chosen move.

Usage:
    python generate.py
    python generate.py --checkpoint model.safetensors --side black
    python generate.py --temperature 0.5

Commands during play:
    moves   — list all legal moves
    undo    — take back the last two moves
    resign  — resign the game
    quit    — exit

Requirements:
    pip install chess torch safetensors
"""

import argparse
import json
import random
import sys

import torch
import torch.nn.functional as F

import chess

sys.path.insert(0, ".")
from config import ChessModelConfig
from model import ChessModel


# ══════════════════════════════════════════════════════════════════════════════
# LOADING
# ══════════════════════════════════════════════════════════════════════════════

def load_vocab(vocab_path: str) -> tuple[dict, dict]:
    """Load the token vocabulary from a JSON file.

    Args:
        vocab_path: Path to a JSON object mapping token string -> id.

    Returns:
        ``(token_to_id, id_to_token)``: the mapping as read, and its inverse.
    """
    with open(vocab_path, encoding="utf-8") as f:
        token_to_id = json.load(f)
    return token_to_id, {v: k for k, v in token_to_id.items()}


def load_model(checkpoint_path: str, device: torch.device) -> tuple[ChessModel, ChessModelConfig]:
    """Build a ChessModel with the default config and load weights into it.

    Paths ending in ``.safetensors`` are loaded with safetensors. Any other
    path is treated as a ``torch.save`` checkpoint whose ``"model"`` entry is
    the state dict; ``_orig_mod.`` fragments (the attribute name used by
    ``torch.compile`` wrappers) are stripped from its keys before loading.

    Args:
        checkpoint_path: Path to a ``.safetensors`` or ``.pt`` checkpoint.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, config)`` with the model on *device* and in eval mode.
    """
    config = ChessModelConfig()
    model = ChessModel(config)
    if checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_model as load_safetensors
        load_safetensors(model, checkpoint_path)
    else:
        # SECURITY: weights_only=False unpickles arbitrary objects, which can
        # execute code. Only load .pt checkpoints from sources you trust.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        state = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
        model.load_state_dict(state)
    model.to(device).eval()
    return model, config


# ══════════════════════════════════════════════════════════════════════════════
# INFERENCE
# ══════════════════════════════════════════════════════════════════════════════

def get_model_move(
    model: ChessModel,
    config: ChessModelConfig,
    board: chess.Board,
    move_history: list[str],
    token_to_id: dict,
    id_to_token: dict,
    device: torch.device,
    temperature: float = 1.0,
    top_k: int = 0,
) -> str | None:
    """Return the model's chosen move in UCI format.

    The prompt is a point-of-view token followed by the ids of the moves in
    *move_history* that exist in the vocabulary (unknown moves are skipped).
    It is truncated to its last ``config.max_seq_len - 1`` tokens. The logits
    at the final position are restricted to the legal moves and divided by
    *temperature*. They are optionally top-k filtered, then either sampled or,
    when ``temperature < 0.01``, taken greedily.

    Args:
        temperature: Softmax temperature; below 0.01 selects the argmax.
        top_k: Keep only the k highest-scoring legal moves; 0 disables.

    Returns:
        A legal move in UCI notation. If no legal move is in the vocabulary,
        a uniformly random legal move is returned instead. Returns ``None``
        when the position has no legal moves.
    """
    # NOTE(review): both literals below are empty strings, so White and Black
    # resolve to the same POV id (default 1 if "" is absent). They look like
    # special tokens whose bracketed names were stripped; confirm the real
    # token names against vocab.json.
    pov_id = token_to_id.get("" if board.turn == chess.WHITE else "", 1)
    history_ids = [tid for tid in map(token_to_id.get, move_history) if tid is not None]
    # NOTE(review): once the history is long enough, this slice also drops the
    # POV token at index 0. Confirm that matches how the model was trained.
    token_ids = ([pov_id] + history_ids)[-(config.max_seq_len - 1):]

    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        ntp_logits, _ = model(input_tensor)
    # Scores for the token following the last prompt position (batch item 0).
    logits = ntp_logits[0, -1, :]

    legal_ucis = [m.uci() for m in board.legal_moves]
    if not legal_ucis:
        return None
    legal_ids = [token_to_id[u] for u in legal_ucis if u in token_to_id]
    if not legal_ids:
        return random.choice(legal_ucis)

    neg_inf = float("-inf")
    # Illegal (and out-of-vocabulary) moves get -inf, so they have zero
    # probability after softmax and can never win the argmax.
    mask = torch.full_like(logits, neg_inf)
    mask[legal_ids] = logits[legal_ids]
    mask = mask / max(temperature, 1e-6)

    if top_k > 0:
        finite = mask[mask != neg_inf]
        # Clamp k to the values actually present. A legal move whose logit is
        # itself -inf (or duplicate ids) would otherwise make topk request
        # more values than exist and raise.
        k = min(top_k, finite.numel())
        if k > 0:
            top_vals, _ = torch.topk(finite, k)
            mask[mask < top_vals[-1]] = neg_inf

    if temperature < 0.01:
        chosen_id = torch.argmax(mask).item()
    else:
        chosen_id = torch.multinomial(F.softmax(mask, dim=-1), num_samples=1).item()
    return id_to_token.get(chosen_id, legal_ucis[0])


# ══════════════════════════════════════════════════════════════════════════════
# DISPLAY
# ══════════════════════════════════════════════════════════════════════════════

# python-chess piece symbol -> display glyph. White (uppercase) pieces use the
# filled glyphs and black (lowercase) pieces the outlined ones.
PIECE_SYMBOLS = {
    'P': '♟', 'N': '♞', 'B': '♝', 'R': '♜', 'Q': '♛', 'K': '♚',
    'p': '♙', 'n': '♘', 'b': '♗', 'r': '♖', 'q': '♕', 'k': '♔',
}


def print_board(board: chess.Board, player_is_white: bool):
    """Print *board* with Unicode glyphs, with the player's side at the bottom.

    For White, rank 8 is printed first and files run a→h; for Black, rank 1
    is printed first and files run h→a. Empty squares are shown as ``·``.
    """
    print()
    ranks = range(7, -1, -1) if player_is_white else range(8)
    files = range(8) if player_is_white else range(7, -1, -1)
    for rank in ranks:
        row = f" {rank + 1} "
        for file in files:
            piece = board.piece_at(chess.square(file, rank))
            row += (PIECE_SYMBOLS.get(piece.symbol(), '?') if piece else '·') + " "
        print(row)
    # Three leading spaces so the letters line up with the glyph columns,
    # which start after the 3-character " N " rank prefix.
    print("   a b c d e f g h" if player_is_white else "   h g f e d c b a")
    print()


# Messages for the drawn terminations board.outcome() can report.
_DRAW_MESSAGES = {
    chess.Termination.STALEMATE: "Stalemate — draw.",
    chess.Termination.INSUFFICIENT_MATERIAL: "Insufficient material — draw.",
    chess.Termination.SEVENTYFIVE_MOVES: "75-move rule — draw.",
    chess.Termination.FIVEFOLD_REPETITION: "Fivefold repetition — draw.",
    chess.Termination.FIFTY_MOVES: "50-move rule — draw.",
    chess.Termination.THREEFOLD_REPETITION: "Threefold repetition — draw.",
}


def print_status(board: chess.Board):
    """Print a game-over banner, or a check warning while the game is live.

    The end-of-game message comes from ``board.outcome()``. That function
    recognises exactly the terminations ``board.is_game_over()`` uses, so a
    draw is only announced when the game loop actually stops. The previous
    ``is_fifty_moves()`` check announced a draw while play continued.
    """
    outcome = board.outcome()
    if outcome is None:
        if board.is_check():
            print(" *** CHECK ***")
        return
    if outcome.termination == chess.Termination.CHECKMATE:
        winner = "White" if outcome.winner == chess.WHITE else "Black"
        print(f"\n{'='*38}\n CHECKMATE — {winner} wins!\n{'='*38}\n")
        return
    message = _DRAW_MESSAGES.get(outcome.termination, f"Game over — {outcome.result()}.")
    print(f"\n{message}")


# ══════════════════════════════════════════════════════════════════════════════
# GAME LOOP
# ══════════════════════════════════════════════════════════════════════════════

def _handle_player_turn(board: chess.Board, move_history: list[str]) -> bool:
    """Prompt until the user moves or undoes; mutates *board*/*move_history*.

    Returns:
        ``False`` if the user quit, resigned, or closed stdin; ``True`` after a
        legal move was played or the last two moves were taken back.
    """
    while True:
        try:
            raw = input("Your move: ").strip().lower()
        except EOFError:
            return False
        if not raw:
            continue
        if raw == "quit":
            print("Goodbye!")
            return False
        if raw == "resign":
            print("You resigned.")
            return False
        if raw == "moves":
            print("Legal moves:", ", ".join(sorted(m.uci() for m in board.legal_moves)))
            continue
        if raw == "undo":
            # Without this guard, "undo" with too little history would fall
            # through to UCI parsing and be reported as an invalid move.
            if len(move_history) < 2:
                print("Nothing to undo.")
                continue
            # Take back the model's reply and the player's previous move, so it
            # is the player's turn again.
            board.pop()
            board.pop()
            del move_history[-2:]
            print("Undone.")
            return True
        try:
            move = chess.Move.from_uci(raw)
        except ValueError:
            print(f"Invalid format: {raw} — use UCI (e.g. e2e4)")
            continue
        if move not in board.legal_moves:
            print(f"Illegal move: {raw}")
            continue
        board.push(move)
        move_history.append(move.uci())
        return True


def play(args):
    """Run an interactive game between the user (stdin) and the model.

    *args* is the parsed CLI namespace. It must provide ``vocab``,
    ``checkpoint``, ``side``, ``temperature`` and ``top_k``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    token_to_id, id_to_token = load_vocab(args.vocab)
    model, config = load_model(args.checkpoint, device)
    player_is_white = args.side.lower() != "black"

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nLiquid Chess Model — {n_params/1e6:.1f}M parameters")
    print(f"You are playing as {'White' if player_is_white else 'Black'}.")
    print("Enter moves in UCI format (e.g. e2e4, g1f3). Type 'moves', 'undo', 'resign', or 'quit'.\n")

    board = chess.Board()
    move_history: list[str] = []

    while not board.is_game_over():
        print_board(board, player_is_white)
        print_status(board)

        is_player_turn = (board.turn == chess.WHITE) == player_is_white
        if is_player_turn:
            if not _handle_player_turn(board, move_history):
                return
            continue

        print("Model is thinking...")
        move_uci = get_model_move(
            model, config, board, move_history, token_to_id, id_to_token, device,
            temperature=args.temperature, top_k=args.top_k,
        )
        if move_uci is None:
            break
        # push_uci validates legality, unlike push(Move.from_uci(...)), so a
        # bad vocab lookup fails loudly instead of corrupting the board.
        board.push_uci(move_uci)
        move_history.append(move_uci)
        print(f"Model plays: {move_uci}")

    print_board(board, player_is_white)
    print_status(board)
    print(f"Result: {board.result()}")
    print(f"Moves: {' '.join(move_history)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Play chess against the Liquid Chess Model.")
    parser.add_argument("--checkpoint", default="model.safetensors",
                        help="Path to model checkpoint (.safetensors or .pt)")
    parser.add_argument("--vocab", default="vocab.json", help="Path to vocab.json")
    # type=str.lower runs before the choices check, so "Black" is accepted too.
    parser.add_argument("--side", default="white", type=str.lower, choices=["white", "black"],
                        help="Your color (default: white)")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature — lower is more deterministic (default: 1.0)")
    parser.add_argument("--top-k", type=int, default=0,
                        help="Top-k filtering — 0 disables (default: 0)")
    play(parser.parse_args())