from __future__ import annotations

import json
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.hub import cached_file


def _is_square(tok: str) -> bool:
    """Return True when *tok* is a two-character board square such as ``"e4"``."""
    return len(tok) == 2 and tok[0] in "abcdefgh" and tok[1] in "12345678"


def _resolve_file(name_or_path: str, filename: str) -> str:
    """Return a path to *filename* for the given model location.

    A file sitting inside a local directory is returned directly; anything
    else is delegated to ``transformers.utils.hub.cached_file``.
    """
    if isinstance(name_or_path, str) and os.path.isdir(name_or_path):
        candidate = os.path.join(name_or_path, filename)
        if os.path.exists(candidate):
            return candidate
    return cached_file(name_or_path, filename)


def _load_vocab(name_or_path: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Load ``vocab.json`` and return ``(token -> id, id -> token)`` mappings."""
    vocab_path = _resolve_file(name_or_path, "vocab.json")
    with open(vocab_path, "r", encoding="utf-8") as f:
        tok2id = json.load(f)
    id2tok = {int(i): t for t, i in tok2id.items()}
    return tok2id, id2tok


@dataclass
class TokenScheme:
    """Token strings and special ids used to encode moves as token sequences."""

    W: str  # side-to-move marker for White
    B: str  # side-to-move marker for Black
    pieces: Dict[str, str]  # piece letter -> piece token
    sep: Optional[str]  # separator token, or None when the vocab has none
    suffix: Dict[str, str]  # annotation key ("cap", "check", ...) -> token
    prom: Dict[str, str]  # promotion piece letter -> promotion token
    pad_id: int
    bos_id: int
    eos_id: int
    unk_id: int


def _detect_scheme(tok2id: Dict[str, int], config) -> TokenScheme:
    """Build a :class:`TokenScheme` from the tokens present in *tok2id*.

    The colour markers and all six piece letters are mandatory; the separator,
    annotation suffixes and promotion tokens are picked up only when present.
    Note that the token ``"B"`` serves both as the Black marker and as the
    bishop piece token.

    Raises:
        ValueError: if a colour marker or a piece token is missing.
    """
    if "W" not in tok2id or "B" not in tok2id:
        raise ValueError("Cannot find W/B tokens in vocab")

    pieces: Dict[str, str] = {}
    for p in ["P", "N", "B", "R", "Q", "K"]:
        if p not in tok2id:
            raise ValueError(f"Cannot find piece token {p} in vocab")
        pieces[p] = p

    sep = " " if " " in tok2id else None

    suffix_candidates = [
        ("cap", "(x)"),
        ("cap_check", "(x*)"),
        ("cap_mate", "(x+*)"),
        ("check", "(+)"),
        ("mate", "(+*)"),
        ("o", "(o)"),
        ("O", "(O)"),
    ]
    suffix = {key: tok for key, tok in suffix_candidates if tok in tok2id}

    prom_candidates = [("Q", "(Q)"), ("R", "(R)"), ("B", "(B)"), ("N", "(N)")]
    prom = {piece: tok for piece, tok in prom_candidates if tok in tok2id}

    return TokenScheme(
        W="W",
        B="B",
        pieces=pieces,
        sep=sep,
        suffix=suffix,
        prom=prom,
        pad_id=int(getattr(config, "pad_token_id", 0)),
        bos_id=int(getattr(config, "bos_token_id", 1)),
        eos_id=int(getattr(config, "eos_token_id", 2)),
        unk_id=int(getattr(config, "unk_token_id", 3)),
    )


class ChessConfig(PretrainedConfig):
    """Configuration for the chess transformer.

    ``n_inner`` defaults to ``3 * n_embd`` when not given. The four special
    token ids are forwarded to ``PretrainedConfig`` through ``kwargs``.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 85,
        n_embd: int = 128,
        n_layer: int = 5,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        tie_weights: bool = False,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        unk_token_id: int = 3,
        **kwargs,
    ):
        self.vocab_size = int(vocab_size)
        self.n_embd = int(n_embd)
        self.n_layer = int(n_layer)
        self.n_head = int(n_head)
        self.n_ctx = int(n_ctx)
        self.n_inner = int(n_inner) if n_inner is not None else 3 * int(n_embd)
        self.dropout = float(dropout)
        self.layer_norm_epsilon = float(layer_norm_epsilon)
        self.tie_weights = bool(tie_weights)
        kwargs["pad_token_id"] = pad_token_id
        kwargs["bos_token_id"] = bos_token_id
        kwargs["eos_token_id"] = eos_token_id
        kwargs["unk_token_id"] = unk_token_id
        super().__init__(**kwargs)


class MLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = F.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key padding mask."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        # Lower-triangular causal mask, kept out of the state dict.
        bias = torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
            1, 1, config.n_ctx, config.n_ctx
        )
        self.register_buffer("bias", bias, persistent=False)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply attention to ``x`` of shape ``(B, T, C)``.

        ``attention_mask``, when given, has shape ``(B, T)``; entries equal to
        zero mark key positions that must not be attended to.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Use the most negative finite value rather than -inf. A query whose
        # keys are all masked (e.g. left padding) would otherwise yield a NaN
        # softmax row, and that NaN leaks into every position in later layers
        # through 0 * NaN in ``att @ v``.
        mask_value = torch.finfo(att.dtype).min
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, mask_value)
        if attention_mask is not None:
            # Indexing (not .view) so sliced, non-contiguous masks also work.
            att = att.masked_fill(attention_mask[:, None, None, :] == 0, mask_value)

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.dropout(y)
        return y


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class ChessForCausalLM(PreTrainedModel):
    """Decoder-only language model over chess move tokens.

    ``generate`` is overridden for single-sequence input: it rebuilds the
    board from the token history, enumerates the legal moves, scores each
    move's token encoding under the model and appends the best one.
    """

    config_class = ChessConfig
    base_model_prefix = ""

    def __init__(self, config: ChessConfig):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if getattr(config, "tie_weights", False):
            self.lm_head.weight = self.wte.weight
        self.post_init()
        # Vocabulary and token scheme are loaded lazily on first use.
        self._tok2id = None
        self._id2tok = None
        self._scheme = None

    def _ensure_vocab(self):
        """Load ``vocab.json`` from the model's location if not already loaded.

        Raises:
            ValueError: if no model path can be determined.
        """
        if self._tok2id is None or self._id2tok is None:
            name_or_path = getattr(self.config, "_name_or_path", None) or getattr(
                self, "name_or_path", None
            )
            if not name_or_path:
                raise ValueError("Cannot resolve model path to load vocab.json")
            self._tok2id, self._id2tok = _load_vocab(name_or_path)

    def _get_scheme(self) -> TokenScheme:
        """Return the cached :class:`TokenScheme`, detecting it on first call."""
        if self._scheme is None:
            self._ensure_vocab()
            self._scheme = _detect_scheme(self._tok2id, self.config)
        return self._scheme

    def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True, **kwargs):
        """Run the model on ``input_ids`` of shape ``(B, T)``.

        Inputs longer than ``config.n_ctx`` are truncated to their last
        ``n_ctx`` positions (``attention_mask`` and ``labels`` likewise).
        When ``labels`` is given, the next-token cross-entropy loss is
        computed with ``ignore_index=-100``.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` and ``loss``, or the
            tuple ``(logits, loss)`` when ``return_dict`` is falsy.
        """
        B, T = input_ids.shape
        if T > self.config.n_ctx:
            input_ids = input_ids[:, -self.config.n_ctx :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.config.n_ctx :]
            if labels is not None:
                labels = labels[:, -self.config.n_ctx :]
            B, T = input_ids.shape

        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.wte(input_ids) + self.wpe(pos)
        x = self.drop(x)
        for block in self.h:
            x = block(x, attention_mask=attention_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            # NOTE(review): order is (logits, loss); kept as-is for existing callers.
            return (logits, loss)
        return CausalLMOutputWithPast(logits=logits, loss=loss)

    def _ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Map token ids to token strings; unknown ids become ``"[UNK]"``."""
        self._ensure_vocab()
        return [self._id2tok.get(int(i), "[UNK]") for i in ids]

    def _parse_history_to_board(self, input_ids_1d: List[int]):
        """Replay the move history encoded in *input_ids_1d* on a fresh board.

        Each move is expected as: colour marker, piece token, source square,
        destination square, then optional suffix tokens up to the next colour
        marker. Separator tokens between fields are skipped. Parsing stops
        silently at the first malformed or illegal move, so the returned
        ``chess.Board`` reflects the longest valid prefix.
        """
        import chess

        scheme = self._get_scheme()
        specials = {"[PAD]", "[BOS]", "[EOS]", "[UNK]"}
        toks = [t for t in self._ids_to_tokens(input_ids_1d) if t not in specials]
        n = len(toks)
        colours = (scheme.W, scheme.B)
        piece_toks = scheme.pieces.values()
        sep = scheme.sep

        def skip_sep(idx: int) -> int:
            # Advance past any run of separator tokens.
            while sep is not None and idx < n and toks[idx] == sep:
                idx += 1
            return idx

        board = chess.Board()
        i = 0
        while i < n:
            # Seek the next colour marker.
            while i < n and toks[i] not in colours:
                i += 1
            if i >= n:
                break
            i = skip_sep(i + 1)

            if i >= n or toks[i] not in piece_toks:
                break
            i = skip_sep(i + 1)

            if i >= n or not _is_square(toks[i]):
                break
            src = toks[i]
            i = skip_sep(i + 1)

            if i >= n or not _is_square(toks[i]):
                break
            dst = toks[i]
            i += 1

            # Everything up to the next colour marker is a suffix of this move.
            suffixes = []
            while i < n and toks[i] not in colours:
                if sep is None or toks[i] != sep:
                    suffixes.append(toks[i])
                i += 1

            uci = f"{src}{dst}"
            for piece, ptok in scheme.prom.items():
                if ptok in suffixes:
                    uci += piece.lower()
                    break

            try:
                mv = chess.Move.from_uci(uci)
            except ValueError:
                # Malformed UCI string: keep the position reached so far.
                break
            if mv not in board.legal_moves:
                break
            board.push(mv)
        return board

    def _move_to_ids(self, board, move_uci: str) -> List[int]:
        """Encode *move_uci*, played from *board*, as a list of token ids.

        Layout: colour marker, piece letter, source square, destination
        square, optional capture/check/mate suffix, optional promotion token,
        with separator tokens interleaved when the scheme defines one. Tokens
        missing from the vocabulary map to the scheme's unknown id. *board*
        is pushed and popped internally and is unchanged on return.
        """
        import chess

        scheme = self._get_scheme()
        self._ensure_vocab()
        tok2id = self._tok2id

        mv = chess.Move.from_uci(move_uci)
        color_tok = scheme.W if board.turn == chess.WHITE else scheme.B
        piece = board.piece_at(mv.from_square)
        pl = piece.symbol().upper() if piece is not None else "P"
        if pl not in scheme.pieces:
            pl = "P"
        src = chess.square_name(mv.from_square)
        dst = chess.square_name(mv.to_square)

        toks = [color_tok, pl]
        if scheme.sep is not None:
            toks += [scheme.sep, src, scheme.sep, dst]
        else:
            toks += [src, dst]

        is_capture = board.is_capture(mv)
        board.push(mv)
        is_mate = board.is_checkmate()
        is_check = board.is_check()
        board.pop()

        # NOTE(review): the scheme also detects "(o)"/"(O)" suffix tokens but
        # they are never emitted here; confirm whether castling moves should
        # carry them.
        if is_capture and is_mate:
            suffix_key = "cap_mate"
        elif is_capture and is_check:
            suffix_key = "cap_check"
        elif is_capture:
            suffix_key = "cap"
        elif is_mate:
            suffix_key = "mate"
        elif is_check:
            suffix_key = "check"
        else:
            suffix_key = None
        suffix_tok = scheme.suffix.get(suffix_key) if suffix_key is not None else None
        if suffix_tok is not None:
            toks.append(suffix_tok)

        if mv.promotion is not None:
            prom = chess.piece_symbol(mv.promotion).upper()
            if prom in scheme.prom:
                toks.append(scheme.prom[prom])

        if scheme.sep is not None:
            toks.append(scheme.sep)

        return [tok2id.get(t, scheme.unk_id) for t in toks]

    @torch.no_grad()
    def _score_candidates(self, prefix_ids, cand_ids_list, attention_mask, temperature, batch_size=64):
        """Return the summed log-probability of each candidate after the prefix.

        Args:
            prefix_ids: ``(1, T0)`` tensor of history token ids.
            cand_ids_list: list of candidate continuations (lists of ids).
            attention_mask: optional ``(1, T)`` mask for the prefix, or None.
            temperature: logits are divided by ``max(1e-6, temperature)``.
            batch_size: number of candidates scored per forward pass.

        Returns:
            A float32 tensor with one score per candidate, in input order.
        """
        device = prefix_ids.device
        scores = torch.empty(len(cand_ids_list), device=device, dtype=torch.float32)
        if not cand_ids_list:
            return scores
        pad_id = int(self.config.pad_token_id)

        # forward() keeps only the last n_ctx positions. Trim the prefix here
        # so prefix + longest candidate fits; otherwise the logit positions
        # computed below would no longer line up with the (shifted) output.
        longest = max(len(c) for c in cand_ids_list)
        room = max(int(self.config.n_ctx) - longest, 0)
        if prefix_ids.size(1) > room:
            prefix_ids = prefix_ids[:, prefix_ids.size(1) - room :]
        T0 = prefix_ids.size(1)

        if attention_mask is None:
            prefix_mask = torch.ones((1, T0), device=device, dtype=torch.long)
        else:
            prefix_mask = attention_mask
            if prefix_mask.size(1) != T0:
                # Keep the mask columns that belong to the retained prefix.
                prefix_mask = prefix_mask[:, max(prefix_mask.size(1) - T0, 0) :]

        inv_temp = 1.0 / float(max(1e-6, temperature))
        # With an empty prefix no logit predicts the first candidate token.
        skip = 1 if T0 == 0 else 0

        for start in range(0, len(cand_ids_list), batch_size):
            batch = cand_ids_list[start : start + batch_size]
            max_c = max(len(c) for c in batch)

            rows = []
            masks = []
            for c in batch:
                pad_len = max_c - len(c)
                # Right-pad so every candidate's tokens sit directly after the prefix.
                tail_ids = torch.tensor(
                    list(c) + [pad_id] * pad_len, device=device, dtype=torch.long
                ).unsqueeze(0)
                tail_mask = torch.tensor(
                    [1] * len(c) + [0] * pad_len, device=device, dtype=prefix_mask.dtype
                ).unsqueeze(0)
                rows.append(torch.cat([prefix_ids, tail_ids], dim=1))
                masks.append(torch.cat([prefix_mask, tail_mask], dim=1))

            input_ids = torch.cat(rows, dim=0)
            attn_mask = torch.cat(masks, dim=0)

            out = self.forward(input_ids=input_ids, attention_mask=attn_mask, return_dict=True)
            logp = torch.log_softmax(out.logits * inv_temp, dim=-1)

            for bi, c in enumerate(batch):
                targets = list(c)[skip:]
                if not targets:
                    scores[start + bi] = 0.0
                    continue
                # The logit at position p predicts the token at position p + 1.
                steps = logp[bi, T0 - 1 + skip : T0 - 1 + len(c)]
                tgt = torch.tensor(targets, device=device, dtype=torch.long).unsqueeze(1)
                scores[start + bi] = steps.gather(1, tgt).float().sum()

        return scores

    def generate(self, input_ids=None, attention_mask=None, max_new_tokens=16, temperature=1.0, do_sample=False, **kwargs):
        """Append the highest-scoring legal move to a single input sequence.

        For batched input, an unparsable history or a finished game, the call
        is delegated to the parent class's ``generate``.
        NOTE(review): that fallback relies on the base class providing
        ``generate``; confirm against the installed transformers version.

        Returns:
            ``input_ids`` with the chosen move's token ids appended, truncated
            to at most ``max_new_tokens`` new tokens.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        import chess  # noqa: F401  (fail early when python-chess is missing)

        if input_ids is None:
            raise ValueError("generate() requires input_ids")
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        if input_ids.size(0) != 1:
            return super().generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                **kwargs,
            )

        try:
            board = self._parse_history_to_board(input_ids[0].tolist())
        except Exception:
            # Best effort: any failure to rebuild the board falls back below.
            board = None

        if board is None or board.is_game_over():
            return super().generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                **kwargs,
            )

        legal = list(board.legal_moves)
        if not legal:
            return input_ids

        cand_ids_list = [self._move_to_ids(board, mv.uci()) for mv in legal]
        scores = self._score_candidates(
            prefix_ids=input_ids,
            cand_ids_list=cand_ids_list,
            attention_mask=attention_mask,
            temperature=float(temperature),
            batch_size=64,
        )
        best = int(torch.argmax(scores).item())
        best_ids = torch.tensor(
            cand_ids_list[best], device=input_ids.device, dtype=torch.long
        ).unsqueeze(0)
        if best_ids.size(1) > int(max_new_tokens):
            best_ids = best_ids[:, : int(max_new_tokens)]
        return torch.cat([input_ids, best_ids], dim=1)