# lcm-chess / generate.py
# MostLime's picture
# init upload
# b2c1dad verified
"""
generate.py — Play chess against the Liquid Chess Model (LCM).
Moves are entered in UCI format (e.g. e2e4, g1f3, e7e8q).
The model responds instantly with its chosen move.
Usage:
python generate.py
python generate.py --checkpoint model.safetensors --side black
python generate.py --temperature 0.5
Commands during play:
moves — list all legal moves
undo — take back the last two moves
resign — resign the game
quit — exit
Requirements:
pip install chess torch safetensors
"""
import argparse
import json
import sys
import torch
import torch.nn.functional as F
import chess
sys.path.insert(0, ".")
from config import ChessModelConfig
from model import ChessModel
# ══════════════════════════════════════════════════════════════════════════════
# LOADING
# ══════════════════════════════════════════════════════════════════════════════
def load_vocab(vocab_path: str) -> tuple[dict, dict]:
    """Load the token vocabulary and build its inverse mapping.

    Args:
        vocab_path: Path to a JSON file containing one object that maps each
            token string to its id.

    Returns:
        A ``(token_to_id, id_to_token)`` pair. ``id_to_token`` is
        ``token_to_id`` with keys and values swapped.
    """
    # Be explicit about the encoding so the load does not depend on the
    # platform's default locale encoding.
    with open(vocab_path, encoding="utf-8") as f:
        token_to_id = json.load(f)
    id_to_token = {token_id: token for token, token_id in token_to_id.items()}
    return token_to_id, id_to_token
def load_model(checkpoint_path: str, device: torch.device) -> tuple[ChessModel, ChessModelConfig]:
    """Instantiate a ChessModel with the default config and load its weights.

    Args:
        checkpoint_path: Path to the weights. A path ending in
            ``.safetensors`` is loaded with ``safetensors``; anything else is
            read with ``torch.load`` and must contain a ``"model"`` entry
            holding the state dict.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, config)`` with the model on ``device`` and in eval mode.
    """
    config = ChessModelConfig()
    model = ChessModel(config)

    uses_safetensors = checkpoint_path.endswith(".safetensors")
    if not uses_safetensors:
        # NOTE(review): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Strip "_orig_mod." from parameter names (presumably the prefix left
        # by torch.compile wrappers) so the keys match the plain model.
        state_dict = {}
        for name, tensor in checkpoint["model"].items():
            state_dict[name.replace("_orig_mod.", "")] = tensor
        model.load_state_dict(state_dict)
    else:
        from safetensors.torch import load_model as load_safetensors
        load_safetensors(model, checkpoint_path)

    model.to(device).eval()
    return model, config
# ══════════════════════════════════════════════════════════════════════════════
# INFERENCE
# ══════════════════════════════════════════════════════════════════════════════
def get_model_move(
    model: ChessModel,
    config: ChessModelConfig,
    board: chess.Board,
    move_history: list[str],
    token_to_id: dict,
    id_to_token: dict,
    device: torch.device,
    temperature: float = 1.0,
    top_k: int = 0,
) -> str | None:
    """Return the model's chosen move in UCI format.

    The prompt is a point-of-view token (``<W>`` or ``<B>`` for the side to
    move) followed by the tokens of ``move_history``. Only the logits of
    legal moves are kept, so the returned move is always legal.

    Args:
        model: Model called as ``model(tokens)``; the first element of its
            return value is indexed as ``[0, -1, :]`` to get next-token logits.
        config: Supplies ``max_seq_len``, used to cap the prompt length.
        board: Current position; provides the side to move and legal moves.
        move_history: UCI moves played so far. Moves missing from the
            vocabulary are skipped.
        token_to_id: Vocabulary mapping token string to id.
        id_to_token: Inverse vocabulary mapping id to token string.
        device: Device on which the input tensor is created.
        temperature: Sampling temperature. Values below 0.01 select the
            highest-scoring legal move deterministically.
        top_k: If positive, sample only among the ``top_k`` best legal moves.

    Returns:
        The chosen move as a UCI string, or ``None`` if there are no legal
        moves. If no legal move is in the vocabulary, a uniformly random
        legal move is returned.
    """
    # Resolve the legal moves first: if there is nothing to choose from, or
    # nothing the model can express, there is no point running inference.
    legal_ucis = [m.uci() for m in board.legal_moves]
    if not legal_ucis:
        return None
    legal_ids = [token_to_id[u] for u in legal_ucis if u in token_to_id]
    if not legal_ids:
        import random
        return random.choice(legal_ucis)

    pov_token = "<W>" if board.turn == chess.WHITE else "<B>"
    token_ids = [token_to_id.get(pov_token, 1)]
    token_ids.extend(
        tid for uci in move_history
        if (tid := token_to_id.get(uci)) is not None
    )
    # Keep only the most recent tokens that fit the context window.
    # NOTE(review): once the history exceeds the window this also drops the
    # leading POV token; confirm that matches how the model was trained.
    token_ids = token_ids[-(config.max_seq_len - 1):]

    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        ntp_logits, _ = model(input_tensor)
    logits = ntp_logits[0, -1, :]

    # Everything that is not a legal move gets -inf, i.e. zero probability.
    masked = torch.full_like(logits, float("-inf"))
    masked[legal_ids] = logits[legal_ids]

    if temperature < 0.01:
        # Greedy choice. Positive scaling and top-k filtering never change the
        # argmax, so skip them; this also avoids dividing by a near-zero
        # temperature, which can overflow in reduced precision.
        chosen_id = torch.argmax(masked).item()
    else:
        masked = masked / max(temperature, 1e-6)
        if top_k > 0:
            k = min(top_k, len(legal_ids))
            kth_best = torch.topk(masked[legal_ids], k).values[-1]
            masked[masked < kth_best] = float("-inf")
        probs = F.softmax(masked, dim=-1)
        chosen_id = torch.multinomial(probs, num_samples=1).item()

    return id_to_token.get(chosen_id, legal_ucis[0])
# ══════════════════════════════════════════════════════════════════════════════
# DISPLAY
# ══════════════════════════════════════════════════════════════════════════════
# Unicode chess glyph for each piece letter looked up via piece.symbol().
# Upper-case letters map to the solid glyphs, lower-case to the outlined ones.
PIECE_SYMBOLS = {
    'P': '♟', 'N': '♞', 'B': '♝', 'R': '♜', 'Q': '♛', 'K': '♚',
    'p': '♙', 'n': '♘', 'b': '♗', 'r': '♖', 'q': '♕', 'k': '♔',
}
def print_board(board: chess.Board, player_is_white: bool):
    """Print the board to stdout, oriented from the player's side.

    Args:
        board: Position to draw.
        player_is_white: If true, rank 8 is printed first and files run a-h;
            otherwise rank 1 is printed first and files run h-a.
    """
    ranks = range(7, -1, -1) if player_is_white else range(8)
    files = range(8) if player_is_white else range(7, -1, -1)

    print()
    for rank in ranks:
        cells = []
        for file in files:
            piece = board.piece_at(chess.square(file, rank))
            # Empty squares are drawn as a middle dot.
            cells.append(PIECE_SYMBOLS.get(piece.symbol(), '?') if piece else '·')
        print(f" {rank + 1} " + "".join(cell + " " for cell in cells))
    # The rank label above occupies three columns, so indent the file letters
    # by the same amount to line them up under their squares.
    print("   " + " ".join("abcdefgh"[file] for file in files))
    print()
def print_status(board: chess.Board):
    """Print a one-off status line for the current position, if any applies.

    Reports checkmate (with the winner), the draw conditions, or check.
    Prints nothing for an ordinary position.

    Args:
        board: Position to describe.
    """
    if board.is_checkmate():
        # The side to move is the one that has been mated.
        winner = "Black" if board.turn == chess.WHITE else "White"
        print(f"\n{'='*38}\n CHECKMATE — {winner} wins!\n{'='*38}\n")
    elif board.is_stalemate():
        print("\nStalemate — draw.")
    elif board.is_insufficient_material():
        print("\nInsufficient material — draw.")
    elif board.is_seventyfive_moves():
        # Checked before the 50-move rule, which is also true in this case.
        print("\n75-move rule — draw.")
    elif board.is_fivefold_repetition():
        print("\nFivefold repetition — draw.")
    elif board.is_fifty_moves():
        # NOTE(review): the 50-move rule is only a claimable draw; by default
        # board.is_game_over() does not end the game on it.
        print("\n50-move rule — draw.")
    elif board.is_check():
        print(" *** CHECK ***")
# ══════════════════════════════════════════════════════════════════════════════
# GAME LOOP
# ══════════════════════════════════════════════════════════════════════════════
def play(args):
    """Run an interactive terminal game between the user and the model.

    Loads the vocabulary and model, then alternates between reading the
    user's move from stdin and asking the model for a reply until the game is
    over or the user quits or resigns (or stdin reaches EOF).

    Args:
        args: Parsed command-line namespace. Reads ``vocab``, ``checkpoint``,
            ``side``, ``temperature`` and ``top_k``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    token_to_id, id_to_token = load_vocab(args.vocab)
    model, config = load_model(args.checkpoint, device)
    player_is_white = args.side.lower() != "black"

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\nLiquid Chess Model — {n_params/1e6:.1f}M parameters")
    print(f"You are playing as {'White' if player_is_white else 'Black'}.")
    print("Enter moves in UCI format (e.g. e2e4, g1f3). Type 'moves', 'undo', 'resign', or 'quit'.\n")

    board = chess.Board()
    move_history = []

    while not board.is_game_over():
        print_board(board, player_is_white)
        print_status(board)

        is_player_turn = (board.turn == chess.WHITE) == player_is_white
        if is_player_turn:
            # Keep prompting until a move is played, an undo succeeds, or the
            # user leaves the game.
            while True:
                try:
                    raw = input("Your move: ").strip().lower()
                except EOFError:
                    return
                if raw == "quit":
                    print("Goodbye!")
                    return
                if raw == "resign":
                    print("You resigned.")
                    return
                if raw == "moves":
                    print("Legal moves:", ", ".join(sorted(m.uci() for m in board.legal_moves)))
                    continue
                if raw == "undo":
                    # An undo takes back the model's reply and the player's
                    # move, so it needs at least two moves on the board.
                    if len(move_history) < 2:
                        print("Nothing to undo.")
                        continue
                    board.pop()
                    board.pop()
                    move_history = move_history[:-2]
                    print("Undone.")
                    break
                try:
                    move = chess.Move.from_uci(raw)
                except ValueError:
                    print(f"Invalid format: {raw} — use UCI (e.g. e2e4)")
                    continue
                if move not in board.legal_moves:
                    print(f"Illegal move: {raw}")
                    continue
                board.push(move)
                move_history.append(raw)
                break
        else:
            print("Model is thinking...")
            move_uci = get_model_move(
                model, config, board, move_history,
                token_to_id, id_to_token, device,
                temperature=args.temperature,
                top_k=args.top_k,
            )
            if move_uci is None:
                break
            board.push(chess.Move.from_uci(move_uci))
            move_history.append(move_uci)
            print(f"Model plays: {move_uci}")

    # Show the final position and the outcome.
    print_board(board, player_is_white)
    print_status(board)
    print(f"Result: {board.result()}")
    print(f"Moves: {' '.join(move_history)}")
# Script entry point: parse the command-line options and start a game.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Play chess against the Liquid Chess Model.")
    parser.add_argument("--checkpoint", default="model.safetensors",
                        help="Path to model checkpoint (.safetensors or .pt)")
    parser.add_argument("--vocab", default="vocab.json",
                        help="Path to vocab.json")
    parser.add_argument("--side", default="white", choices=["white", "black"],
                        help="Your color (default: white)")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature — lower is more deterministic (default: 1.0)")
    parser.add_argument("--top-k", type=int, default=0,
                        help="Top-k filtering — 0 disables (default: 0)")
    play(parser.parse_args())