thomas-schweich/pawn-18h-push / scripts /play_vs_stockfish.py
thomas-schweich's picture
download
raw
15.2 kB
#!/usr/bin/env python3
"""PAWN model vs Stockfish UCI engine.
Plays N games, alternating colors. PAWN's outcome token is set to the
"I win" token for whichever side it's playing (WHITE_CHECKMATES for
white, BLACK_CHECKMATES for black), conditioning the model to play
toward victory.
Usage:
play_vs_stockfish.py \\
--backbone thomas-schweich/pawn-base \\
--adapter /workspace/logs/trial_0011_v4/bottleneck_.../checkpoints/best \\
--strategy bottleneck \\
--stockfish stockfish --elo 1350 --movetime-ms 5 \\
--games 20 --max-ply 200
Writes JSONL game records to --output if specified.
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import time
from pathlib import Path
sys.path.insert(0, "/opt/pawn")
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
import chess_engine
from pawn.config import (
CLMConfig, WHITE_CHECKMATES, BLACK_CHECKMATES,
)
from pawn.model import PAWNCLM
from pawn import model as model_module
from pawn.checkpoint import load_backbone_weights
from pawn.gpu import configure_gpu, apply_gpu_config
# --------------------------------------------------------------------------
# Stockfish UCI wrapper
# --------------------------------------------------------------------------
class Stockfish:
    """Minimal line-oriented UCI client around an engine subprocess.

    The constructor launches the executable, performs the ``uci`` handshake,
    sets the ``Hash`` and ``Threads`` options (plus the strength limit when
    requested) and blocks until the engine answers ``readyok``.
    """

    def __init__(self, path: str, elo: int, limit_strength: bool = True,
                 movetime_ms: int = 5, threads: int = 1, hash_mb: int = 16):
        """Start the engine process and configure it.

        Args:
            path: Executable to launch.
            elo: Value sent as ``UCI_Elo``; only used when ``limit_strength``.
            limit_strength: When true, send ``UCI_LimitStrength`` and
                ``UCI_Elo``.
            movetime_ms: Per-move budget passed to ``go movetime``.
            threads: Value for the ``Threads`` option.
            hash_mb: Value for the ``Hash`` option.

        Raises:
            RuntimeError: If the engine closes its output during the
                handshake.
        """
        self.proc = subprocess.Popen(
            [path],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            bufsize=1,  # line-buffered text pipes
        )
        self.movetime_ms = movetime_ms
        self._send("uci")
        self._wait("uciok")
        self._send(f"setoption name Hash value {hash_mb}")
        self._send(f"setoption name Threads value {threads}")
        if limit_strength:
            self._send("setoption name UCI_LimitStrength value true")
            self._send(f"setoption name UCI_Elo value {elo}")
        self._send("ucinewgame")
        self._send("isready")
        self._wait("readyok")

    def _send(self, cmd: str) -> None:
        """Write one command line to the engine and flush it immediately."""
        self.proc.stdin.write(cmd + "\n")
        self.proc.stdin.flush()

    def _wait(self, token: str) -> list[str]:
        """Read engine output until a line starting with ``token`` arrives.

        Returns:
            Every non-blank line read, stripped, ending with the matching
            line.

        Raises:
            RuntimeError: If the engine's output hits end-of-file first.
        """
        out = []
        while True:
            raw = self.proc.stdout.readline()
            # readline() yields "" only at EOF (a blank line is "\n"), which
            # means the engine exited; without this check the loop would
            # spin forever.
            if raw == "":
                raise RuntimeError(
                    f"Stockfish closed its output while waiting for {token!r}"
                )
            line = raw.strip()
            if not line:
                continue
            out.append(line)
            if line.startswith(token):
                return out

    def new_game(self) -> None:
        """Send ``ucinewgame`` and block until the engine reports ready."""
        self._send("ucinewgame")
        self._send("isready")
        self._wait("readyok")

    def get_move(self, uci_position: str) -> str:
        """Ask the engine for its move in the given position.

        Args:
            uci_position: A complete UCI position command, e.g.
                ``'position startpos moves e2e4 e7e5'``.

        Returns:
            The second field of the ``bestmove`` reply, e.g. ``'e2e4'``.

        Raises:
            RuntimeError: If the engine exits before replying, or the
                ``bestmove`` line carries no move field.
        """
        self._send(uci_position)
        self._send(f"go movetime {self.movetime_ms}")
        # _wait returns as soon as a line starts with "bestmove", so that
        # line is always the last one collected.
        best_line = self._wait("bestmove")[-1]
        parts = best_line.split()  # "bestmove e2e4 [ponder e7e5]"
        if len(parts) < 2:
            raise RuntimeError(f"Malformed bestmove line: {best_line!r}")
        return parts[1]

    def close(self) -> None:
        """Ask the engine to quit; kill it if it does not exit within 2 s."""
        try:
            self._send("quit")
            self.proc.wait(timeout=2)
        except Exception:
            # Best-effort shutdown: a broken pipe or a timeout both end here.
            self.proc.kill()
            self.proc.wait()  # reap the killed process
# --------------------------------------------------------------------------
# PAWN model loading
# --------------------------------------------------------------------------
def load_pawn_with_adapter(
    backbone_ref: str, adapter_path: str | None,
    strategy: str, device: str,
) -> nn.Module:
    """Build the PAWN backbone and, optionally, wrap it with an adapter.

    Args:
        backbone_ref: Passed to ``load_backbone_weights``; a HF repo or a
            local directory.
        adapter_path: Directory holding ``adapter.safetensors`` and
            ``config.json``. ``None`` returns the bare backbone.
        strategy: One of ``"bottleneck"``, ``"unfreeze"``, ``"sparse"`` or
            ``"rosa"``; only consulted when ``adapter_path`` is given.
        device: Device the modules are moved to.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: If ``strategy`` is not one of the names above (raised
            after the adapter files have been read).
    """
    weights, backbone_cfg = load_backbone_weights(backbone_ref, device)
    clm_config = CLMConfig(**backbone_cfg) if backbone_cfg else CLMConfig()
    backbone = PAWNCLM(clm_config).to(device)
    backbone.load_state_dict(weights)
    backbone.eval()

    if adapter_path is None:
        # No adapter requested: hand back the plain backbone.
        return backbone

    adapter_dir = Path(adapter_path)
    adapter_state = load_file(str(adapter_dir / "adapter.safetensors"))
    adapter_cfg = json.loads((adapter_dir / "config.json").read_text())

    # Each branch only constructs the module; the adapter weights are loaded
    # once, non-strictly, after the dispatch.
    if strategy == "bottleneck":
        from pawn.adapters.bottleneck import BottleneckCLM
        layer_ids = tuple(adapter_cfg["adapter_layers"])
        model = BottleneckCLM(
            backbone,
            bottleneck_dim=adapter_cfg["bottleneck_dim"],
            layers=layer_ids,
            adapt_attn=adapter_cfg.get("adapt_attn", True),
            adapt_ffn=adapter_cfg.get("adapt_ffn", True),
        ).to(device)
    elif strategy == "unfreeze":
        # The checkpoint overrides backbone parameters directly.
        model = backbone
    elif strategy == "sparse":
        from pawn.adapters.sparse import SparseCLM
        from pawn.adapters.lora import ATTN_PRESETS
        layer_ids = tuple(adapter_cfg.get("adapter_layers") or range(8))
        mask_density = adapter_cfg.get("density", 0.2)
        targets = ATTN_PRESETS[adapter_cfg.get("sparse_targets", "qkvo")]
        model = SparseCLM(
            backbone, density=mask_density,
            attn_targets=targets,
            adapt_ffn=adapter_cfg.get("sparse_ffn", True),
            layers=layer_ids,
        ).to(device)
    elif strategy == "rosa":
        # Rebuilt as a bottleneck wrapper with both attn and ffn adapted.
        from pawn.adapters.bottleneck import BottleneckCLM
        layer_ids = tuple(adapter_cfg.get("adapter_layers") or (4, 5, 6, 7))
        model = BottleneckCLM(
            backbone,
            bottleneck_dim=adapter_cfg["bottleneck_dim"],
            layers=layer_ids,
            adapt_attn=True, adapt_ffn=True,
        ).to(device)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    model.load_state_dict(adapter_state, strict=False)
    model.eval()
    return model
# --------------------------------------------------------------------------
# Move generation
# --------------------------------------------------------------------------
@torch.no_grad()
def pawn_pick_move(
    model: nn.Module,
    token_sequence: list[int],
    legal_token_ids: list[int],
    device: str,
    temperature: float = 0.0,
) -> int:
    """Run one forward pass and pick a move among the legal tokens.

    Args:
        model: Adapter wrapper exposing ``forward_hidden``/``project_head``,
            a wrapper exposing ``bb``, or a raw backbone callable as
            ``model(ids, attn_mask, hidden_only=False)``.
        token_sequence: Tokens so far; fed to the model as a batch of one.
        legal_token_ids: Vocabulary indices the choice is restricted to.
        device: Device on which the input tensors are created.
        temperature: ``<= 0`` takes the argmax among legal tokens; ``> 0``
            samples from ``softmax(logits / temperature)`` masked to legal.

    Returns:
        The chosen token id, always a member of ``legal_token_ids``.

    Raises:
        ValueError: If ``legal_token_ids`` is empty. With every logit masked
            to ``-inf`` the argmax would be an arbitrary (illegal) index and
            sampling would operate on NaN probabilities.
    """
    if not legal_token_ids:
        raise ValueError("pawn_pick_move called with no legal moves")

    ids = torch.tensor([token_sequence], dtype=torch.long, device=device)
    if hasattr(model, "forward_hidden"):
        # Adapter wrappers: project only the last position through the head.
        hidden = model.forward_hidden(ids)
        last_hidden = hidden[:, -1:, :]
        logits = model.project_head(last_hidden)[0, 0]  # (V,)
    elif hasattr(model, "bb"):
        # Wrapper exposing its backbone as ``bb``: run the layer stack by hand.
        bb = model.bb
        x = bb.embed(ids)
        seq_len = ids.shape[1]
        rope_cos = bb.rope_cos[:, :, :seq_len, :]
        rope_sin = bb.rope_sin[:, :, :seq_len, :]
        # causal is baked into SDPA via is_causal=True in Attention
        layer_mask = None
        for layer in bb.layers:
            x = layer(x, rope_cos, rope_sin, layer_mask)
        x = bb.final_norm(x[:, -1:, :])
        logits = bb.lm_head(x)[0, 0]
    else:
        # Raw backbone: plain forward over the whole sequence (no KV cache,
        # since the sequence is rebuilt for every move).
        attn_mask = torch.ones_like(ids, dtype=torch.bool, device=device)
        out, _ = model(ids, attn_mask, hidden_only=False)
        logits = out[0, -1]
    logits = logits.float()

    # Additive mask: 0 on legal tokens, -inf everywhere else.
    legal_mask = torch.full_like(logits, float("-inf"))
    legal_idx = torch.tensor(legal_token_ids, dtype=torch.long, device=device)
    legal_mask[legal_idx] = 0.0
    logits = logits + legal_mask

    if temperature <= 0:
        return int(logits.argmax().item())
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, 1).item())
# --------------------------------------------------------------------------
# Game loop
# --------------------------------------------------------------------------
def play_one_game(
    pawn_model, pawn_device, pawn_temp,
    stockfish: Stockfish, pawn_is_white: bool, max_ply: int,
    vocab: dict,
) -> dict:
    """Play one PAWN-vs-Stockfish game and return a record of it.

    The model's token sequence starts with the outcome token for PAWN's own
    side (``WHITE_CHECKMATES`` or ``BLACK_CHECKMATES``); every move played by
    either side is appended to it.

    Args:
        pawn_model: Model passed through to ``pawn_pick_move``.
        pawn_device: Device passed through to ``pawn_pick_move``.
        pawn_temp: Sampling temperature passed through to ``pawn_pick_move``.
        stockfish: Engine wrapper; ``new_game`` is called before play starts.
        pawn_is_white: Whether PAWN has the white pieces.
        max_ply: Play stops once ``gs.ply()`` reaches this value.
        vocab: Must contain ``"move_to_token"``, a mapping from UCI move
            strings to token ids.

    Returns:
        Dict with keys ``pawn_color``, ``plies``, ``termination``,
        ``winner`` (``"white"``, ``"black"`` or ``None``), ``pawn_result``
        (``"win"``, ``"loss"`` or ``"draw"``) and ``moves``. Only checkmate
        sets a winner, so error terminations and the ply cap are reported
        as ``"draw"``.
    """
    gs = chess_engine.PyGameState()
    stockfish.new_game()
    outcome_token = WHITE_CHECKMATES if pawn_is_white else BLACK_CHECKMATES
    token_seq = [outcome_token]
    termination = None
    winner = None

    while True:
        # Check for game over BEFORE the ply cap, so a checkmate (or other
        # termination) reached on the final allowed ply is recorded as such
        # instead of being misreported as "max_ply" with no winner.
        if gs.is_game_over():
            termination = gs.check_termination()
            if gs.is_checkmate():
                # Side to move is in checkmate, so the OTHER side wins.
                winner = "black" if gs.is_white_to_move() else "white"
            break
        if gs.ply() >= max_ply:
            break

        legal_tokens = list(gs.legal_move_tokens())
        pawn_turn = (pawn_is_white == gs.is_white_to_move())
        if pawn_turn:
            try:
                tok = pawn_pick_move(
                    pawn_model, token_seq, legal_tokens,
                    pawn_device, temperature=pawn_temp,
                )
            except Exception as e:
                termination = f"pawn_error:{e}"
                break
            if tok not in legal_tokens:
                termination = "pawn_illegal"
                break
        else:
            try:
                uci = stockfish.get_move(gs.uci_position())
            except Exception as e:
                termination = f"stockfish_error:{e}"
                break
            # Convert UCI to token via vocab
            tok = vocab["move_to_token"].get(uci)
            if tok is None:
                termination = f"bad_uci:{uci}"
                break
            if tok not in legal_tokens:
                termination = f"stockfish_illegal:{uci}"
                break
        gs.make_move(tok)
        token_seq.append(tok)

    if termination is None:
        termination = "max_ply"

    # Translate the winner into a result from PAWN's point of view.
    if winner is None:
        pawn_result = "draw"
    elif (winner == "white") == pawn_is_white:
        pawn_result = "win"
    else:
        pawn_result = "loss"

    return {
        "pawn_color": "white" if pawn_is_white else "black",
        "plies": gs.ply(),
        "termination": termination,
        "winner": winner,
        "pawn_result": pawn_result,
        "moves": gs.move_history(),
    }
# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
def main():
    """Parse CLI options, play the requested games and report the score.

    Loads the PAWN model (optionally with an adapter), starts Stockfish,
    plays ``--games`` games, prints one line per game plus a summary with a
    logistic Elo estimate, and writes a JSONL file (one summary record
    followed by one record per game) when ``--output`` is given.
    """
    import math

    p = argparse.ArgumentParser()
    p.add_argument("--backbone", default="thomas-schweich/pawn-base",
                   help="HF repo ID or local directory")
    p.add_argument("--adapter", default=None,
                   help="Path to adapter checkpoint dir (with adapter.safetensors)")
    p.add_argument("--strategy", default="bottleneck",
                   choices=["bottleneck", "unfreeze", "rosa", "sparse", "none"])
    p.add_argument("--stockfish", default="stockfish")
    p.add_argument("--elo", type=int, default=1350)
    p.add_argument("--movetime-ms", type=int, default=5)
    p.add_argument("--games", type=int, default=10)
    p.add_argument("--max-ply", type=int, default=200)
    p.add_argument("--temperature", type=float, default=0.0,
                   help="Softmax temp for PAWN move sampling. 0=argmax.")
    p.add_argument("--device", default="cuda")
    p.add_argument("--output", default=None, help="Optional JSONL file")
    p.add_argument("--alternate", dest="alternate", action="store_true",
                   default=True,
                   help="Alternate PAWN colors each game (default on)")
    # A store_true flag that defaults to True can never be switched off, so
    # provide the explicit negative form.
    p.add_argument("--no-alternate", dest="alternate", action="store_false",
                   help="PAWN plays white in every game")
    args = p.parse_args()
    # The summary divides by the number of games; reject an empty run before
    # doing any expensive loading.
    if args.games < 1:
        p.error("--games must be at least 1")

    # Configure GPU / SDPA
    gpu_cfg = configure_gpu(
        device=args.device, sdpa_math=True, no_compile=True,
    )
    apply_gpu_config(gpu_cfg, model_module, lambda x: x)

    # Load PAWN
    print(f"Loading backbone: {args.backbone}")
    adapter_path = args.adapter if args.strategy != "none" else None
    model = load_pawn_with_adapter(
        args.backbone, adapter_path, args.strategy, args.device,
    )
    vocab = chess_engine.export_move_vocabulary()

    # Start Stockfish
    print(f"Starting Stockfish: elo={args.elo} movetime={args.movetime_ms}ms")
    sf = Stockfish(
        args.stockfish, elo=args.elo, limit_strength=True,
        movetime_ms=args.movetime_ms,
    )

    # Play games
    print(f"Playing {args.games} games...\n")
    records = []
    t0 = time.time()
    w = l = d = 0
    try:
        for i in range(args.games):
            pawn_is_white = (i % 2 == 0) if args.alternate else True
            g0 = time.time()
            rec = play_one_game(
                model, args.device, args.temperature,
                sf, pawn_is_white, args.max_ply, vocab,
            )
            dt = time.time() - g0
            rec["seconds"] = round(dt, 2)
            records.append(rec)
            if rec["pawn_result"] == "win":
                w += 1
            elif rec["pawn_result"] == "loss":
                l += 1
            else:
                d += 1
            print(
                f"Game {i+1}/{args.games} pawn={rec['pawn_color']:5s} "
                f"result={rec['pawn_result']:4s} plies={rec['plies']:3d} "
                f"term={str(rec['termination']):25s} {dt:5.1f}s"
            )
    finally:
        # Shut the engine down even when a game raises, so the subprocess
        # is not left running.
        sf.close()

    # Summary
    total_s = time.time() - t0
    n = len(records)
    score = w + 0.5 * d
    print("\n=== Summary ===")
    print(f" Games: {n} Wall time: {total_s:.1f}s "
          f"({total_s/n:.1f}s/game avg)")
    print(f" PAWN: W={w} L={l} D={d} "
          f"score={score}/{n} = {score/n*100:.1f}%")

    # Elo estimate via logistic: E = 1/(1+10^((R_op - R_pawn)/400))
    # → R_pawn = R_op - 400 * log10(1/E - 1)
    # The score fraction is clamped to [0.001, 0.999] so that a 0% or 100%
    # result still yields a finite estimate.
    p_score = max(min(score / n, 0.999), 0.001)
    elo_delta = -400 * math.log10(1 / p_score - 1)
    print(f" Implied Elo delta: {elo_delta:+.0f} (opponent {args.elo})")
    print(f" Estimated PAWN Elo: {args.elo + elo_delta:.0f}")

    if args.output:
        out = Path(args.output)
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("w") as f:
            summary = {
                "type": "summary",
                "games": n, "wins": w, "losses": l, "draws": d,
                "score": score, "score_pct": score / n,
                "opponent_elo": args.elo,
                "pawn_elo_estimate": args.elo + elo_delta,
                "adapter": args.adapter,
                "backbone": args.backbone,
                "strategy": args.strategy,
                "movetime_ms": args.movetime_ms,
                "temperature": args.temperature,
                "max_ply": args.max_ply,
            }
            f.write(json.dumps(summary) + "\n")
            for r in records:
                f.write(json.dumps({"type": "game", **r}) + "\n")
        print(f" Wrote {args.output}")
# Run the match only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Xet Storage Details

Size:
15.2 kB
·
Xet hash:
df6a60e64ba0d314b2053bb8259bc5bb9e77e148ad1e93b143977e06828f7925

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.