Buckets:
| #!/usr/bin/env python3 | |
| """PAWN model vs Stockfish UCI engine. | |
| Plays N games, alternating colors. PAWN's outcome token is set to the | |
| "I win" token for whichever side it's playing (WHITE_CHECKMATES for | |
| white, BLACK_CHECKMATES for black), conditioning the model to play | |
| toward victory. | |
| Usage: | |
| play_vs_stockfish.py \\ | |
| --backbone thomas-schweich/pawn-base \\ | |
| --adapter /workspace/logs/trial_0011_v4/bottleneck_.../checkpoints/best \\ | |
| --strategy bottleneck \\ | |
| --stockfish stockfish --elo 1350 --movetime-ms 5 \\ | |
| --games 20 --max-ply 200 | |
| Writes JSONL game records to --output if specified. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import subprocess | |
| import sys | |
| import time | |
| from pathlib import Path | |
| sys.path.insert(0, "/opt/pawn") | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from safetensors.torch import load_file | |
| import chess_engine | |
| from pawn.config import ( | |
| CLMConfig, WHITE_CHECKMATES, BLACK_CHECKMATES, | |
| ) | |
| from pawn.model import PAWNCLM | |
| from pawn import model as model_module | |
| from pawn.checkpoint import load_backbone_weights | |
| from pawn.gpu import configure_gpu, apply_gpu_config | |
| # -------------------------------------------------------------------------- | |
| # Stockfish UCI wrapper | |
| # -------------------------------------------------------------------------- | |
class Stockfish:
    """Minimal synchronous driver for a UCI chess engine such as Stockfish.

    The engine runs as a subprocess and is spoken to line by line over its
    stdin/stdout pipes. Construction performs the UCI handshake
    (``uci`` -> ``uciok``), sets hash size, thread count and optionally a
    strength limit, then waits for ``readyok``.

    Instances can be used as context managers so the subprocess is shut down
    even if the caller raises.
    """

    def __init__(self, path: str, elo: int, limit_strength: bool = True,
                 movetime_ms: int = 5, threads: int = 1, hash_mb: int = 16):
        """Launch the engine and complete the UCI handshake.

        Args:
            path: Engine executable passed to ``subprocess.Popen``.
            elo: Value sent as ``UCI_Elo`` when *limit_strength* is true.
            limit_strength: Whether to enable ``UCI_LimitStrength``.
            movetime_ms: Search time per move used by ``get_move``.
            threads: Value for the ``Threads`` option.
            hash_mb: Value for the ``Hash`` option (MB).

        Raises:
            EOFError: If the engine exits before the handshake finishes.
        """
        self.proc = subprocess.Popen(
            [path],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            bufsize=1,  # line-buffered text pipes
        )
        self.movetime_ms = movetime_ms
        try:
            self._send("uci")
            self._wait("uciok")
            self._send(f"setoption name Hash value {hash_mb}")
            self._send(f"setoption name Threads value {threads}")
            if limit_strength:
                self._send("setoption name UCI_LimitStrength value true")
                self._send(f"setoption name UCI_Elo value {elo}")
            self._send("ucinewgame")
            self._send("isready")
            self._wait("readyok")
        except BaseException:
            # Don't leak the engine process if the handshake fails.
            self.close()
            raise

    def __enter__(self) -> "Stockfish":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def _send(self, cmd: str) -> None:
        """Write one command line to the engine and flush it."""
        self.proc.stdin.write(cmd + "\n")
        self.proc.stdin.flush()

    def _wait(self, token: str) -> list[str]:
        """Read engine output until a line starting with *token*.

        Blank lines are skipped. Returns every non-blank (stripped) line read,
        the terminating line included, so the match is always ``out[-1]``.

        Raises:
            EOFError: If the engine closes its stdout (e.g. it crashed or
                exited) before *token* arrives. ``readline()`` returns ``""``
                forever at EOF, so without this check the loop would spin.
        """
        out: list[str] = []
        while True:
            raw = self.proc.stdout.readline()
            if raw == "":
                raise EOFError(
                    f"engine closed its output while waiting for {token!r}"
                )
            line = raw.strip()
            if not line:
                continue
            out.append(line)
            # A stripped line whose first word equals token also starts with
            # it, so a single startswith() covers both original checks.
            if line.startswith(token):
                return out

    def new_game(self) -> None:
        """Tell the engine a new game starts and wait until it is ready."""
        self._send("ucinewgame")
        self._send("isready")
        self._wait("readyok")

    def get_move(self, uci_position: str) -> str:
        """Search a position and return the engine's chosen move in UCI form.

        *uci_position* is a full ``position`` command such as
        ``'position startpos moves e2e4 e7e5'``. The search runs for
        ``self.movetime_ms`` milliseconds. Returns the word after
        ``bestmove`` in the engine's reply.
        """
        self._send(uci_position)
        self._send(f"go movetime {self.movetime_ms}")
        best_line = self._wait("bestmove")[-1]
        return best_line.split()[1]  # "bestmove e2e4 [ponder e7e5]"

    def close(self) -> None:
        """Ask the engine to quit; kill it if it does not exit within 2 s.

        The process is always reaped and the pipes closed. Calling this again,
        or after the engine already died, is harmless.
        """
        try:
            self._send("quit")
            self.proc.wait(timeout=2)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            # OSError (e.g. BrokenPipeError): engine already gone.
            # ValueError: stdin already closed by a previous close().
            # TimeoutExpired: engine ignored "quit".
            self.proc.kill()
            self.proc.wait()  # reap so no zombie process is left behind
        finally:
            for stream in (self.proc.stdin, self.proc.stdout):
                try:
                    stream.close()
                except OSError:
                    pass
| # -------------------------------------------------------------------------- | |
| # PAWN model loading | |
| # -------------------------------------------------------------------------- | |
def load_pawn_with_adapter(
    backbone_ref: str, adapter_path: str | None,
    strategy: str, device: str,
) -> nn.Module:
    """Build the PAWN model to play with: backbone plus optional adapter.

    Args:
        backbone_ref: HF repo ID or local directory, passed to
            ``load_backbone_weights``.
        adapter_path: Directory containing ``adapter.safetensors`` and
            ``config.json``. ``None`` returns the bare backbone.
        strategy: ``"bottleneck"``, ``"unfreeze"``, ``"sparse"`` or
            ``"rosa"``; selects how the adapter weights are applied.
        device: Torch device the model is moved to.

    Returns:
        The model, in eval mode.

    Raises:
        ValueError: For an unknown *strategy*, or when none of the adapter
            checkpoint's tensors match the constructed model.
    """
    state_dict, model_config = load_backbone_weights(backbone_ref, device)
    cfg = CLMConfig(**model_config) if model_config else CLMConfig()
    backbone = PAWNCLM(cfg).to(device)
    backbone.load_state_dict(state_dict)
    backbone.eval()
    if adapter_path is None:
        return backbone  # bare backbone

    adapter_dir = Path(adapter_path)
    adapter_state = load_file(str(adapter_dir / "adapter.safetensors"))
    adapter_cfg = json.loads(
        (adapter_dir / "config.json").read_text(encoding="utf-8")
    )

    if strategy == "bottleneck":
        from pawn.adapters.bottleneck import BottleneckCLM
        layers = tuple(adapter_cfg["adapter_layers"])
        model = BottleneckCLM(
            backbone,
            bottleneck_dim=adapter_cfg["bottleneck_dim"],
            layers=layers,
            adapt_attn=adapter_cfg.get("adapt_attn", True),
            adapt_ffn=adapter_cfg.get("adapt_ffn", True),
        ).to(device)
        _load_adapter_state(model, adapter_state, strategy)
    elif strategy == "unfreeze":
        # Fine-tuned backbone weights are loaded straight into the backbone.
        _load_adapter_state(backbone, adapter_state, strategy)
        model = backbone
    elif strategy == "sparse":
        from pawn.adapters.sparse import SparseCLM
        from pawn.adapters.lora import ATTN_PRESETS
        layers = tuple(adapter_cfg.get("adapter_layers") or range(8))
        density = adapter_cfg.get("density", 0.2)
        attn_targets = ATTN_PRESETS[adapter_cfg.get("sparse_targets", "qkvo")]
        model = SparseCLM(
            backbone, density=density,
            attn_targets=attn_targets,
            adapt_ffn=adapter_cfg.get("sparse_ffn", True),
            layers=layers,
        ).to(device)
        _load_adapter_state(model, adapter_state, strategy)
    elif strategy == "rosa":
        # NOTE(review): "rosa" checkpoints are rebuilt as a BottleneckCLM with
        # both attention and FFN adapters — confirm this matches training.
        from pawn.adapters.bottleneck import BottleneckCLM
        layers = tuple(adapter_cfg.get("adapter_layers") or (4, 5, 6, 7))
        model = BottleneckCLM(
            backbone,
            bottleneck_dim=adapter_cfg["bottleneck_dim"],
            layers=layers,
            adapt_attn=True, adapt_ffn=True,
        ).to(device)
        _load_adapter_state(model, adapter_state, strategy)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    model.eval()
    return model


def _load_adapter_state(module: nn.Module, adapter_state: dict,
                        strategy: str) -> None:
    """Load *adapter_state* into *module* non-strictly, reporting mismatches.

    ``strict=False`` is kept, so keys the checkpoint lacks are tolerated as
    before. Keys the module does not recognise, however, are silently dropped
    by PyTorch. If *every* key is dropped, the adapter was not applied at all
    (wrong ``--strategy`` or checkpoint), so this raises instead of letting a
    bare backbone be benchmarked under the adapter's name.

    Raises:
        ValueError: If none of the checkpoint's tensors were accepted.
    """
    result = module.load_state_dict(adapter_state, strict=False)
    # Wrappers that override load_state_dict may not return PyTorch's usual
    # (missing_keys, unexpected_keys) result; treat that as "no information".
    unexpected = list(getattr(result, "unexpected_keys", ()) or ())
    if adapter_state and len(unexpected) == len(adapter_state):
        raise ValueError(
            f"none of the {len(adapter_state)} adapter tensors match the "
            f"{strategy!r} model; check --strategy and --adapter"
        )
    if unexpected:
        print(
            f"WARNING: {len(unexpected)} adapter tensors ignored by the "
            f"{strategy!r} model, e.g. {unexpected[:3]}",
            file=sys.stderr,
        )
| # -------------------------------------------------------------------------- | |
| # Move generation | |
| # -------------------------------------------------------------------------- | |
def pawn_pick_move(
    model: nn.Module,
    token_sequence: list[int],
    legal_token_ids: list[int],
    device: str,
    temperature: float = 0.0,
) -> int:
    """Run one forward pass over *token_sequence* and pick a legal next token.

    - ``temperature <= 0``: argmax among *legal_token_ids*.
    - ``temperature > 0``: sample from ``softmax(logits / temperature)``
      restricted to *legal_token_ids*.

    The forward pass runs under ``torch.no_grad()``. This is pure inference,
    and without it every move would build (and discard) an autograd graph for
    any parameter with ``requires_grad=True``, such as adapter weights.

    Raises:
        ValueError: If *legal_token_ids* is empty, because every logit would
            be ``-inf`` and no legal move could be chosen.
    """
    if not legal_token_ids:
        raise ValueError("pawn_pick_move called with no legal moves")

    ids = torch.tensor([token_sequence], dtype=torch.long, device=device)
    with torch.no_grad():
        if hasattr(model, "forward_hidden"):
            # Adapter wrappers: hidden states, then project only the last step.
            hidden = model.forward_hidden(ids)
            logits = model.project_head(hidden[:, -1:, :])[0, 0]  # (V,)
        elif hasattr(model, "bb"):
            # Wrapper exposing its backbone as ``.bb``: run the layers by hand.
            bb = model.bb
            x = bb.embed(ids)
            seq_len = ids.shape[1]
            rope_cos = bb.rope_cos[:, :, :seq_len, :]
            rope_sin = bb.rope_sin[:, :, :seq_len, :]
            # No explicit mask is passed; causal masking is assumed to happen
            # inside each layer's attention (SDPA is_causal=True) — not
            # verifiable from this file.
            for layer in bb.layers:
                x = layer(x, rope_cos, rope_sin, None)
            x = bb.final_norm(x[:, -1:, :])
            logits = bb.lm_head(x)[0, 0]
        else:
            # Raw backbone: full forward with an all-true attention mask.
            attn_mask = torch.ones_like(ids, dtype=torch.bool, device=device)
            out, _ = model(ids, attn_mask, hidden_only=False)
            logits = out[0, -1]
    logits = logits.float()

    # Additive mask: 0 for legal tokens, -inf for everything else.
    legal_mask = torch.full_like(logits, float("-inf"))
    legal_idx = torch.tensor(legal_token_ids, dtype=torch.long, device=device)
    legal_mask[legal_idx] = 0.0
    logits = logits + legal_mask

    if temperature <= 0:
        return int(logits.argmax().item())
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, 1).item())
| # -------------------------------------------------------------------------- | |
| # Game loop | |
| # -------------------------------------------------------------------------- | |
def play_one_game(
    pawn_model, pawn_device, pawn_temp,
    stockfish: Stockfish, pawn_is_white: bool, max_ply: int,
    vocab: dict,
) -> dict:
    """Play one game of PAWN against *stockfish* and return a result record.

    PAWN's token sequence starts with the outcome token for its own colour
    (``WHITE_CHECKMATES`` or ``BLACK_CHECKMATES``). It then records every move
    token played by either side.

    The game stops when:

    - the position is game-over;
    - *max_ply* plies have been played without the game ending; or
    - either side fails to produce a legal move.

    Game-over is checked before the ply limit, so a mate or draw delivered on
    the last allowed ply is scored as such rather than as ``"max_ply"``.
    Games stopped by an error have ``winner=None`` and are reported as a draw;
    inspect ``termination`` to tell them apart.

    Returns:
        dict with keys ``pawn_color``, ``plies``, ``termination``, ``winner``
        (``"white"``, ``"black"`` or ``None``), ``pawn_result`` (``"win"``,
        ``"loss"`` or ``"draw"``) and ``moves``.
    """
    gs = chess_engine.PyGameState()
    stockfish.new_game()
    outcome_token = WHITE_CHECKMATES if pawn_is_white else BLACK_CHECKMATES
    token_seq = [outcome_token]
    winner = None

    while True:
        if gs.is_game_over():
            termination = gs.check_termination()
            if gs.is_checkmate():
                # Side to move is in checkmate, so the OTHER side wins.
                winner = "black" if gs.is_white_to_move() else "white"
            break
        if gs.ply() >= max_ply:
            termination = "max_ply"
            break

        if pawn_is_white == gs.is_white_to_move():
            legal_tokens = list(gs.legal_move_tokens())
            try:
                tok = pawn_pick_move(
                    pawn_model, token_seq, legal_tokens,
                    pawn_device, temperature=pawn_temp,
                )
            except Exception as e:
                termination = f"pawn_error:{e}"
                break
            if tok not in legal_tokens:
                termination = "pawn_illegal"
                break
        else:
            try:
                uci = stockfish.get_move(gs.uci_position())
            except Exception as e:
                termination = f"stockfish_error:{e}"
                break
            # Convert the engine's UCI move to a PAWN token via the vocabulary.
            tok = vocab["move_to_token"].get(uci)
            if tok is None:
                termination = f"bad_uci:{uci}"
                break
            if tok not in gs.legal_move_tokens():
                termination = f"stockfish_illegal:{uci}"
                break

        gs.make_move(tok)
        token_seq.append(tok)

    if winner is None:
        pawn_result = "draw"
    else:
        pawn_won = (winner == "white") == pawn_is_white
        pawn_result = "win" if pawn_won else "loss"

    return {
        "pawn_color": "white" if pawn_is_white else "black",
        "plies": gs.ply(),
        "termination": termination,
        "winner": winner,
        "pawn_result": pawn_result,
        "moves": gs.move_history(),
    }
| # -------------------------------------------------------------------------- | |
| # Main | |
| # -------------------------------------------------------------------------- | |
def main():
    """Play PAWN against Stockfish from the command line and report the score.

    Steps:

    1. Parse CLI options.
    2. Load the PAWN backbone (plus an optional adapter).
    3. Start a strength-limited Stockfish.
    4. Play ``--games`` games, alternating PAWN's colour unless
       ``--no-alternate`` is given.
    5. Print one line per game and a summary with a logistic Elo estimate.
    6. Optionally write a JSONL file: a summary line followed by one line per
       game.
    """
    import math  # only needed for the Elo estimate

    p = argparse.ArgumentParser()
    p.add_argument("--backbone", default="thomas-schweich/pawn-base",
                   help="HF repo ID or local directory")
    p.add_argument("--adapter", default=None,
                   help="Path to adapter checkpoint dir (with adapter.safetensors)")
    p.add_argument("--strategy", default="bottleneck",
                   choices=["bottleneck", "unfreeze", "rosa", "sparse", "none"])
    p.add_argument("--stockfish", default="stockfish")
    p.add_argument("--elo", type=int, default=1350)
    p.add_argument("--movetime-ms", type=int, default=5)
    p.add_argument("--games", type=int, default=10)
    p.add_argument("--max-ply", type=int, default=200)
    p.add_argument("--temperature", type=float, default=0.0,
                   help="Softmax temp for PAWN move sampling. 0=argmax.")
    p.add_argument("--device", default="cuda")
    p.add_argument("--output", default=None, help="Optional JSONL file")
    # "store_true" with default=True can never be switched off on its own, so
    # --no-alternate writes False to the same destination.
    p.add_argument("--alternate", action="store_true", default=True,
                   help="Alternate PAWN colors each game (default on)")
    p.add_argument("--no-alternate", dest="alternate", action="store_false",
                   help="Give PAWN the white pieces in every game")
    args = p.parse_args()
    if args.games < 1:
        # The summary divides by the number of games played.
        p.error("--games must be at least 1")

    # Configure GPU / SDPA
    gpu_cfg = configure_gpu(
        device=args.device, sdpa_math=True, no_compile=True,
    )
    apply_gpu_config(gpu_cfg, model_module, lambda x: x)

    # Load PAWN
    print(f"Loading backbone: {args.backbone}")
    adapter_path = args.adapter if args.strategy != "none" else None
    model = load_pawn_with_adapter(
        args.backbone, adapter_path, args.strategy, args.device,
    )
    vocab = chess_engine.export_move_vocabulary()

    # Start Stockfish
    print(f"Starting Stockfish: elo={args.elo} movetime={args.movetime_ms}ms")
    sf = Stockfish(
        args.stockfish, elo=args.elo, limit_strength=True,
        movetime_ms=args.movetime_ms,
    )

    # Play games
    print(f"Playing {args.games} games...\n")
    records = []
    wins = losses = draws = 0
    t0 = time.perf_counter()  # monotonic clock for durations
    try:
        for i in range(args.games):
            pawn_is_white = (i % 2 == 0) if args.alternate else True
            g0 = time.perf_counter()
            rec = play_one_game(
                model, args.device, args.temperature,
                sf, pawn_is_white, args.max_ply, vocab,
            )
            dt = time.perf_counter() - g0
            rec["seconds"] = round(dt, 2)
            records.append(rec)
            result = rec["pawn_result"]
            if result == "win":
                wins += 1
            elif result == "loss":
                losses += 1
            else:
                draws += 1
            print(
                f"Game {i+1}/{args.games} pawn={rec['pawn_color']:5s} "
                f"result={rec['pawn_result']:4s} plies={rec['plies']:3d} "
                f"term={str(rec['termination']):25s} {dt:5.1f}s"
            )
    finally:
        # Always shut the engine down, even on Ctrl-C or an unexpected error.
        sf.close()

    # Summary
    total_s = time.perf_counter() - t0
    n = len(records)
    score = wins + 0.5 * draws
    print("\n=== Summary ===")
    print(f" Games: {n} Wall time: {total_s:.1f}s "
          f"({total_s/n:.1f}s/game avg)")
    print(f" PAWN: W={wins} L={losses} D={draws} "
          f"score={score}/{n} = {score/n*100:.1f}%")
    # Elo estimate via logistic: E = 1/(1+10^((R_op - R_pawn)/400))
    # -> R_pawn = R_op - 400 * log10(1/E - 1)
    # The score is clamped away from 0 and 1, where the formula diverges.
    p_score = max(min(score / n, 0.999), 0.001)
    elo_delta = -400 * math.log10(1 / p_score - 1)
    print(f" Implied Elo delta: {elo_delta:+.0f} (opponent {args.elo})")
    print(f" Estimated PAWN Elo: {args.elo + elo_delta:.0f}")

    if args.output:
        out = Path(args.output)
        out.parent.mkdir(parents=True, exist_ok=True)
        summary = {
            "type": "summary",
            "games": n, "wins": wins, "losses": losses, "draws": draws,
            "score": score, "score_pct": score / n,
            "opponent_elo": args.elo,
            "pawn_elo_estimate": args.elo + elo_delta,
            "adapter": args.adapter,
            "backbone": args.backbone,
            "strategy": args.strategy,
            "movetime_ms": args.movetime_ms,
            "temperature": args.temperature,
            "max_ply": args.max_ply,
        }
        with out.open("w", encoding="utf-8") as f:
            f.write(json.dumps(summary) + "\n")
            for r in records:
                f.write(json.dumps({"type": "game", **r}) + "\n")
        print(f" Wrote {args.output}")


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 15.2 kB
- Xet hash:
- df6a60e64ba0d314b2053bb8259bc5bb9e77e148ad1e93b143977e06828f7925
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.