Spaces:
Sleeping
Sleeping
| """ | |
| generate.py — Play chess against the Liquid Chess Model (LCM). | |
| Moves are entered in UCI format (e.g. e2e4, g1f3, e7e8q). | |
| The model responds instantly with its chosen move. | |
| Usage: | |
| python generate.py | |
| python generate.py --checkpoint model.safetensors --side black | |
| python generate.py --temperature 0.5 | |
| Commands during play: | |
| moves — list all legal moves | |
| undo — take back the last two moves | |
| resign — resign the game | |
| quit — exit | |
| Requirements: | |
| pip install chess torch safetensors | |
| """ | |
| import argparse | |
| import json | |
| import sys | |
| import torch | |
| import torch.nn.functional as F | |
| import chess | |
| sys.path.insert(0, ".") | |
| from config import ChessModelConfig | |
| from model import ChessModel | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # LOADING | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def load_vocab(vocab_path: str) -> tuple[dict, dict]:
    """Load the token vocabulary and build its inverse mapping.

    Args:
        vocab_path: Path to a JSON file containing a single object that maps
            token strings to token ids.

    Returns:
        A ``(token_to_id, id_to_token)`` pair, where ``id_to_token`` is the
        key/value inversion of ``token_to_id``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(vocab_path, encoding="utf-8") as f:
        token_to_id = json.load(f)
    id_to_token = {token_id: token for token, token_id in token_to_id.items()}
    return token_to_id, id_to_token
def load_model(checkpoint_path: str, device: torch.device) -> tuple[ChessModel, ChessModelConfig]:
    """Build a ChessModel with the default config and load weights into it.

    Args:
        checkpoint_path: Weights file. Paths ending in ``.safetensors`` are
            loaded with safetensors; anything else goes through ``torch.load``.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, config)`` with the model on ``device`` and in eval mode.
    """
    config = ChessModelConfig()
    model = ChessModel(config)

    use_torch_loader = not checkpoint_path.endswith(".safetensors")
    if use_torch_loader:
        # NOTE: weights_only=False permits full unpickling, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        cleaned_state = {}
        for param_name, tensor in checkpoint["model"].items():
            # Strip "_orig_mod." from parameter names (presumably left behind
            # by torch.compile wrappers -- confirm against the training code).
            cleaned_state[param_name.replace("_orig_mod.", "")] = tensor
        model.load_state_dict(cleaned_state)
    else:
        # Imported lazily so safetensors is only needed for this format.
        from safetensors.torch import load_model as load_safetensors
        load_safetensors(model, checkpoint_path)

    model.to(device).eval()
    return model, config
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # INFERENCE | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def get_model_move(
    model: ChessModel,
    config: ChessModelConfig,
    board: chess.Board,
    move_history: list[str],
    token_to_id: dict,
    id_to_token: dict,
    device: torch.device,
    temperature: float = 1.0,
    top_k: int = 0,
) -> str | None:
    """Return the model's chosen move in UCI format.

    The model's next-token logits are restricted to the tokens of currently
    legal moves, then a move is either picked greedily or sampled.

    Args:
        model: Called as ``model(input_ids)``; the first element of its
            return value is used as the next-token logits.
        config: Supplies ``max_seq_len`` for truncating the input.
        board: Current position; its side to move selects the POV token.
        move_history: UCI moves played so far. Moves missing from the
            vocabulary are skipped.
        token_to_id: Token string -> id mapping.
        id_to_token: Id -> token string mapping.
        device: Device the input tensor is created on.
        temperature: Sampling temperature. Values below 0.01 select the
            highest-scoring legal move deterministically.
        top_k: If > 0, sample only among the ``top_k`` best legal moves.

    Returns:
        A UCI move string, or ``None`` when the position has no legal moves.
        If no legal move is in the vocabulary, a uniformly random legal move
        is returned instead.
    """
    # Resolve the candidate set first so no forward pass is spent on a
    # position where the model's output could not be used anyway.
    legal_ucis = [m.uci() for m in board.legal_moves]
    if not legal_ucis:
        return None

    legal_ids = [token_to_id[u] for u in legal_ucis if u in token_to_id]
    if not legal_ids:
        import random
        return random.choice(legal_ucis)

    # Input sequence: point-of-view token followed by the known history moves.
    pov_id = token_to_id.get("<W>" if board.turn == chess.WHITE else "<B>", 1)
    token_ids = [pov_id]
    for uci in move_history:
        tid = token_to_id.get(uci)
        if tid is not None:
            token_ids.append(tid)

    # NOTE(review): this keeps only the most recent tokens, so the POV token
    # at index 0 is dropped once the history is long enough -- confirm that
    # this matches how the model was trained.
    token_ids = token_ids[-(config.max_seq_len - 1):]

    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        ntp_logits, _ = model(input_tensor)
    # Logits for the position after the last input token
    # (assumes ntp_logits is (batch, seq, vocab) -- TODO confirm).
    logits = ntp_logits[0, -1, :]

    # Everything that is not a legal move gets -inf so it can never be chosen.
    masked = torch.full_like(logits, float("-inf"))
    masked[legal_ids] = logits[legal_ids]

    if temperature < 0.01:
        # Greedy: argmax is invariant under positive scaling, so skip the
        # division entirely; dividing by ~1e-6 can overflow low-precision
        # logits to +/-inf and make the pick arbitrary.
        chosen_id = torch.argmax(masked).item()
    else:
        masked = masked / temperature
        if top_k > 0:
            finite = masked[masked != float("-inf")]
            top_vals, _ = torch.topk(finite, min(top_k, len(legal_ids)))
            # Drop everything scoring below the k-th best legal move.
            masked[masked < top_vals[-1]] = float("-inf")
        probs = F.softmax(masked, dim=-1)
        chosen_id = torch.multinomial(probs, num_samples=1).item()

    return id_to_token.get(chosen_id, legal_ucis[0])
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # DISPLAY | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
# Piece letter -> Unicode chess glyph used when drawing the board.
# Upper-case letters map to the filled glyphs, lower-case to the outlined ones.
PIECE_SYMBOLS = {
    'P': '♟',
    'N': '♞',
    'B': '♝',
    'R': '♜',
    'Q': '♛',
    'K': '♚',
    'p': '♙',
    'n': '♘',
    'b': '♗',
    'r': '♖',
    'q': '♕',
    'k': '♔',
}
def print_board(board: chess.Board, player_is_white: bool):
    """Print the board to stdout, oriented from the player's side.

    Args:
        board: Position to draw.
        player_is_white: When true, rank 8 is printed first and files run
            a..h; otherwise rank 1 is printed first and files run h..a.
    """
    if player_is_white:
        rank_order = range(7, -1, -1)
        file_order = range(8)
        footer = " a b c d e f g h"
    else:
        rank_order = range(8)
        file_order = range(7, -1, -1)
        footer = " h g f e d c b a"

    print()
    for rank in rank_order:
        cells = []
        for file in file_order:
            piece = board.piece_at(chess.square(file, rank))
            if not piece:
                glyph = '·'  # empty square
            else:
                glyph = PIECE_SYMBOLS.get(piece.symbol(), '?')
            cells.append(glyph + " ")
        # Rank label first, then one glyph per file.
        print(f" {rank + 1} " + "".join(cells))
    print(footer)
    print()
def print_status(board: chess.Board):
    """Print a one-off status line for the current position, if any applies.

    Reports checkmate (with the winner), the draw conditions, or check.
    Prints nothing for an ordinary position.

    Args:
        board: Position to describe.
    """
    if board.is_checkmate():
        # The side to move is the one that has been mated.
        winner = "Black" if board.turn == chess.WHITE else "White"
        print(f"\n{'='*38}\n CHECKMATE — {winner} wins!\n{'='*38}\n")
    elif board.is_stalemate():
        print("\nStalemate — draw.")
    elif board.is_insufficient_material():
        print("\nInsufficient material — draw.")
    # Board.is_game_over() also ends the game on the 75-move rule and on
    # fivefold repetition, so report those explicitly. The 75-move test must
    # precede the 50-move one, which is also true in that situation.
    elif board.is_seventyfive_moves():
        print("\n75-move rule — draw.")
    elif board.is_fivefold_repetition():
        print("\nFivefold repetition — draw.")
    elif board.is_fifty_moves():
        print("\n50-move rule — draw.")
    elif board.is_check():
        print(" *** CHECK ***")
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # GAME LOOP | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def play(args):
    """Run an interactive game between the user and the model on stdin/stdout.

    Args:
        args: Parsed command-line namespace. Reads ``vocab``, ``checkpoint``,
            ``side``, ``temperature`` and ``top_k``.

    The loop ends when the game is over, the user types ``quit`` or
    ``resign``, stdin is closed, or the model has no move to return. The
    result and move list are only printed when the loop itself ends; ``quit``,
    ``resign`` and end-of-input return immediately.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    token_to_id, id_to_token = load_vocab(args.vocab)
    model, config = load_model(args.checkpoint, device)
    player_is_white = args.side.lower() != "black"

    print(f"\nLiquid Chess Model — {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")
    print(f"You are playing as {'White' if player_is_white else 'Black'}.")
    print("Enter moves in UCI format (e.g. e2e4, g1f3). Type 'moves', 'undo', 'resign', or 'quit'.\n")

    board = chess.Board()
    move_history = []

    while not board.is_game_over():
        print_board(board, player_is_white)
        print_status(board)

        is_player_turn = (board.turn == chess.WHITE) == player_is_white
        if is_player_turn:
            # Keep prompting until a command changes the position or exits.
            while True:
                try:
                    raw = input("Your move: ").strip().lower()
                except EOFError:
                    return

                if raw == "quit":
                    print("Goodbye!")
                    return
                if raw == "resign":
                    print("You resigned.")
                    return
                if raw == "moves":
                    print("Legal moves:", ", ".join(sorted(m.uci() for m in board.legal_moves)))
                    continue
                if raw == "undo":
                    # Undo removes one move from each side so that it is the
                    # player's turn again; that needs at least two moves.
                    if len(move_history) < 2:
                        print("Nothing to undo.")
                        continue
                    board.pop()
                    board.pop()
                    move_history = move_history[:-2]
                    print("Undone.")
                    break

                try:
                    move = chess.Move.from_uci(raw)
                except ValueError:
                    print(f"Invalid format: {raw} — use UCI (e.g. e2e4)")
                    continue

                if move in board.legal_moves:
                    board.push(move)
                    move_history.append(raw)
                    break
                print(f"Illegal move: {raw}")
        else:
            print("Model is thinking...")
            move_uci = get_model_move(
                model, config, board, move_history,
                token_to_id, id_to_token, device,
                temperature=args.temperature,
                top_k=args.top_k,
            )
            if move_uci is None:
                break
            board.push(chess.Move.from_uci(move_uci))
            move_history.append(move_uci)
            print(f"Model plays: {move_uci}")

    # Final position and summary.
    print_board(board, player_is_white)
    print_status(board)
    print(f"Result: {board.result()}")
    print(f"Moves: {' '.join(move_history)}")
if __name__ == "__main__":
    # Command-line entry point: parse the options, then start a game.
    parser = argparse.ArgumentParser(description="Play chess against the Liquid Chess Model.")
    parser.add_argument(
        "--checkpoint",
        default="model.safetensors",
        help="Path to model checkpoint (.safetensors or .pt)",
    )
    parser.add_argument(
        "--vocab",
        default="vocab.json",
        help="Path to vocab.json",
    )
    parser.add_argument(
        "--side",
        default="white",
        choices=["white", "black"],
        help="Your color (default: white)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature — lower is more deterministic (default: 1.0)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=0,
        help="Top-k filtering — 0 disables (default: 0)",
    )
    cli_args = parser.parse_args()
    play(cli_args)