| | """ |
| | Utility functions for the Chess Challenge. |
| | |
| | This module provides helper functions for: |
| | - Parameter counting and budget analysis |
| | - Model registration with Hugging Face |
| | - Move validation with python-chess |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Dict, Optional, TYPE_CHECKING |
| |
|
| | import torch.nn as nn |
| |
|
| | if TYPE_CHECKING: |
| | from src.model import ChessConfig |
| |
|
| |
|
| | def count_parameters(model: nn.Module, trainable_only: bool = True) -> int: |
| | """ |
| | Count the number of parameters in a model. |
| | |
| | Args: |
| | model: The PyTorch model. |
| | trainable_only: If True, only count trainable parameters. |
| | |
| | Returns: |
| | Total number of parameters. |
| | """ |
| | if trainable_only: |
| | return sum(p.numel() for p in model.parameters() if p.requires_grad) |
| | return sum(p.numel() for p in model.parameters()) |
| |
|
| |
|
| | def count_parameters_by_component(model: nn.Module) -> Dict[str, int]: |
| | """ |
| | Count parameters broken down by model component. |
| | |
| | Args: |
| | model: The PyTorch model. |
| | |
| | Returns: |
| | Dictionary mapping component names to parameter counts. |
| | """ |
| | counts = {} |
| | for name, module in model.named_modules(): |
| | if len(list(module.children())) == 0: |
| | param_count = sum(p.numel() for p in module.parameters(recurse=False)) |
| | if param_count > 0: |
| | counts[name] = param_count |
| | return counts |
| |
|
| |
|
| | def estimate_parameters(config: "ChessConfig") -> Dict[str, int]: |
| | """ |
| | Estimate the parameter count for a given configuration. |
| | |
| | This is useful for planning your architecture before building the model. |
| | |
| | Args: |
| | config: Model configuration. |
| | |
| | Returns: |
| | Dictionary with estimated parameter counts by component. |
| | """ |
| | V = config.vocab_size |
| | d = config.n_embd |
| | L = config.n_layer |
| | n_ctx = config.n_ctx |
| | n_inner = config.n_inner |
| | |
| | estimates = { |
| | "token_embeddings": V * d, |
| | "position_embeddings": n_ctx * d, |
| | "attention_qkv_per_layer": 3 * d * d, |
| | "attention_proj_per_layer": d * d, |
| | "ffn_per_layer": 2 * d * n_inner, |
| | "layernorm_per_layer": 4 * d, |
| | "final_layernorm": 2 * d, |
| | } |
| | |
| | |
| | per_layer = ( |
| | estimates["attention_qkv_per_layer"] + |
| | estimates["attention_proj_per_layer"] + |
| | estimates["ffn_per_layer"] + |
| | estimates["layernorm_per_layer"] |
| | ) |
| | |
| | estimates["total_transformer_layers"] = L * per_layer |
| | |
| | |
| | if config.tie_weights: |
| | estimates["lm_head"] = 0 |
| | estimates["lm_head_note"] = "Tied with token embeddings" |
| | else: |
| | estimates["lm_head"] = V * d |
| | |
| | |
| | estimates["total"] = ( |
| | estimates["token_embeddings"] + |
| | estimates["position_embeddings"] + |
| | estimates["total_transformer_layers"] + |
| | estimates["final_layernorm"] + |
| | estimates["lm_head"] |
| | ) |
| | |
| | return estimates |
| |
|
| |
|
| | def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None: |
| | """ |
| | Print a formatted parameter budget analysis. |
| | |
| | Args: |
| | config: Model configuration. |
| | limit: Parameter limit to compare against. |
| | """ |
| | estimates = estimate_parameters(config) |
| | |
| | print("=" * 60) |
| | print("PARAMETER BUDGET ANALYSIS") |
| | print("=" * 60) |
| | print(f"\nConfiguration:") |
| | print(f" vocab_size (V) = {config.vocab_size}") |
| | print(f" n_embd (d) = {config.n_embd}") |
| | print(f" n_layer (L) = {config.n_layer}") |
| | print(f" n_head = {config.n_head}") |
| | print(f" n_ctx = {config.n_ctx}") |
| | print(f" n_inner = {config.n_inner}") |
| | print(f" tie_weights = {config.tie_weights}") |
| | |
| | print(f"\nParameter Breakdown:") |
| | print(f" Token Embeddings: {estimates['token_embeddings']:>10,}") |
| | print(f" Position Embeddings: {estimates['position_embeddings']:>10,}") |
| | print(f" Transformer Layers: {estimates['total_transformer_layers']:>10,}") |
| | print(f" Final LayerNorm: {estimates['final_layernorm']:>10,}") |
| | |
| | if config.tie_weights: |
| | print(f" LM Head: {'(tied)':>10}") |
| | else: |
| | print(f" LM Head: {estimates['lm_head']:>10,}") |
| | |
| | print(f" " + "-" * 30) |
| | print(f" TOTAL: {estimates['total']:>10,}") |
| | |
| | print(f"\nBudget Status:") |
| | print(f" Limit: {limit:>10,}") |
| | print(f" Used: {estimates['total']:>10,}") |
| | print(f" Remaining:{limit - estimates['total']:>10,}") |
| | |
| | if estimates['total'] <= limit: |
| | print(f"\n Within budget! ({estimates['total'] / limit * 100:.1f}% used)") |
| | else: |
| | print(f"\n OVER BUDGET by {estimates['total'] - limit:,} parameters!") |
| | |
| | print("=" * 60) |
| |
|
| |
|
| | def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool: |
| | """ |
| | Validate a move using python-chess. |
| | |
| | This function converts the dataset's extended UCI format to standard UCI |
| | and validates it against the current board state. |
| | |
| | Args: |
| | move: Move in extended UCI format (e.g., "WPe2e4", "BNg8f6(x)"). |
| | board_fen: FEN string of the current board state (optional). |
| | |
| | Returns: |
| | True if the move is legal, False otherwise. |
| | """ |
| | try: |
| | import chess |
| | except ImportError: |
| | raise ImportError("python-chess is required for move validation. " |
| | "Install it with: pip install python-chess") |
| | |
| | |
| | |
| | |
| | |
| | if len(move) < 6: |
| | return False |
| | |
| | |
| | color = move[0] |
| | piece = move[1] |
| | from_sq = move[2:4] |
| | to_sq = move[4:6] |
| | |
| | |
| | promotion = None |
| | if "=" in move: |
| | promo_idx = move.index("=") |
| | promotion = move[promo_idx + 1].lower() |
| | |
| | |
| | board = chess.Board(board_fen) if board_fen else chess.Board() |
| | |
| | |
| | uci_move = from_sq + to_sq |
| | if promotion: |
| | uci_move += promotion |
| | |
| | try: |
| | move_obj = chess.Move.from_uci(uci_move) |
| | return move_obj in board.legal_moves |
| | except (ValueError, chess.InvalidMoveError): |
| | return False |
| |
|
| |
|
| | def convert_extended_uci_to_uci(move: str) -> str: |
| | """ |
| | Convert extended UCI format to standard UCI format. |
| | |
| | Args: |
| | move: Move in extended UCI format (e.g., "WPe2e4"). |
| | |
| | Returns: |
| | Move in standard UCI format (e.g., "e2e4"). |
| | """ |
| | if len(move) < 6: |
| | return move |
| | |
| | |
| | from_sq = move[2:4] |
| | to_sq = move[4:6] |
| | |
| | |
| | promotion = "" |
| | if "=" in move: |
| | promo_idx = move.index("=") |
| | promotion = move[promo_idx + 1].lower() |
| | |
| | return from_sq + to_sq + promotion |
| |
|
| |
|
| | def convert_uci_to_extended( |
| | uci_move: str, |
| | board_fen: str, |
| | ) -> str: |
| | """ |
| | Convert standard UCI format to extended UCI format. |
| | |
| | Args: |
| | uci_move: Move in standard UCI format (e.g., "e2e4"). |
| | board_fen: FEN string of the current board state. |
| | |
| | Returns: |
| | Move in extended UCI format (e.g., "WPe2e4"). |
| | """ |
| | try: |
| | import chess |
| | except ImportError: |
| | raise ImportError("python-chess is required for move conversion.") |
| | |
| | board = chess.Board(board_fen) |
| | move = chess.Move.from_uci(uci_move) |
| | |
| | |
| | color = "W" if board.turn == chess.WHITE else "B" |
| | |
| | |
| | piece = board.piece_at(move.from_square) |
| | piece_letter = piece.symbol().upper() if piece else "P" |
| | |
| | |
| | from_sq = chess.square_name(move.from_square) |
| | to_sq = chess.square_name(move.to_square) |
| | |
| | result = f"{color}{piece_letter}{from_sq}{to_sq}" |
| | |
| | |
| | if move.promotion: |
| | result += f"={chess.piece_symbol(move.promotion).upper()}" |
| | |
| | |
| | if board.is_capture(move): |
| | result += "(x)" |
| | |
| | |
| | board.push(move) |
| | if board.is_checkmate(): |
| | if "(x)" in result: |
| | result = result.replace("(x)", "(x+*)") |
| | else: |
| | result += "(+*)" |
| | elif board.is_check(): |
| | if "(x)" in result: |
| | result = result.replace("(x)", "(x+)") |
| | else: |
| | result += "(+)" |
| | board.pop() |
| | |
| | |
| | if board.is_castling(move): |
| | if move.to_square in [chess.G1, chess.G8]: |
| | result = result.replace("(x)", "").replace("(+)", "") + "(o)" |
| | else: |
| | result = result.replace("(x)", "").replace("(+)", "") + "(O)" |
| | |
| | return result |
| |
|