|
|
""" |
|
|
Utility functions for the Chess Challenge. |
|
|
|
|
|
This module provides helper functions for: |
|
|
- Parameter counting and budget analysis |
|
|
- Model registration with Hugging Face |
|
|
- Move validation with python-chess |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Dict, Optional, TYPE_CHECKING |
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from src.model import ChessConfig |
|
|
|
|
|
|
|
|
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Return the total number of scalar parameters held by ``model``.

    Args:
        model: The PyTorch model to inspect.
        trainable_only: When True (the default), parameters whose
            ``requires_grad`` is False are left out of the count.

    Returns:
        Sum of ``numel()`` over the selected parameters (0 if there are none).
    """
    params = model.parameters()
    if trainable_only:
        # Lazily drop frozen parameters before summing.
        params = filter(lambda p: p.requires_grad, params)
    return sum(p.numel() for p in params)
|
|
|
|
|
|
|
|
def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
    """Count parameters broken down by model component.

    Every module that *directly* owns parameters (``parameters(recurse=False)``)
    gets an entry keyed by its qualified name from ``named_modules()``; the
    root module's name is ``""``. Modules that own parameters and also have
    children are included too, so no parameter is silently skipped.

    A parameter shared between modules (e.g. tied input/output embeddings) is
    attributed only to the first module that ``named_modules()`` yields for
    it. As a result the values sum to the same total as
    ``count_parameters(model, trainable_only=False)``, which is also
    de-duplicated.

    Args:
        model: The PyTorch model.

    Returns:
        Dictionary mapping component names to parameter counts. Components
        with no (un-shared) parameters are omitted.
    """
    counts: Dict[str, int] = {}
    # Identity of every parameter already attributed to a component; the
    # Parameter objects stay alive for the whole call, so ids are stable.
    seen_ids = set()
    for name, module in model.named_modules():
        param_count = 0
        for param in module.parameters(recurse=False):
            if id(param) in seen_ids:
                continue  # shared/tied weight already counted elsewhere
            seen_ids.add(id(param))
            param_count += param.numel()
        if param_count > 0:
            counts[name] = param_count
    return counts
|
|
|
|
|
|
|
|
def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
    """Estimate the parameter count implied by a configuration.

    Handy for sizing an architecture before instantiating it. Only the terms
    listed below are counted; the attention and FFN terms are the weight
    matrices only (no bias vectors are added for them).

    Per-layer terms (with d = n_embd):
        attention_qkv_per_layer  = 3 * d * d
        attention_proj_per_layer = d * d
        ffn_per_layer            = 2 * d * n_inner
        layernorm_per_layer      = 4 * d  (presumably two LayerNorms with
                                           weight + bias each)

    Args:
        config: Model configuration; reads ``vocab_size``, ``n_embd``,
            ``n_layer``, ``n_ctx``, ``n_inner`` and ``tie_weights``.

    Returns:
        Dictionary of per-component estimates plus ``"total_transformer_layers"``
        and ``"total"``. ``"lm_head"`` is 0 when ``tie_weights`` is set, in
        which case an extra ``"lm_head_note"`` entry holds a *string* (despite
        the ``Dict[str, int]`` annotation).
    """
    vocab = config.vocab_size
    width = config.n_embd
    depth = config.n_layer
    context = config.n_ctx
    hidden = config.n_inner

    qkv = 3 * width * width
    proj = width * width
    ffn = 2 * width * hidden
    norms = 4 * width

    estimates = {
        "token_embeddings": vocab * width,
        "position_embeddings": context * width,
        "attention_qkv_per_layer": qkv,
        "attention_proj_per_layer": proj,
        "ffn_per_layer": ffn,
        "layernorm_per_layer": norms,
        "final_layernorm": 2 * width,
        "total_transformer_layers": depth * (qkv + proj + ffn + norms),
    }

    if config.tie_weights:
        # Output projection reuses the token embedding matrix.
        estimates.update(lm_head=0, lm_head_note="Tied with token embeddings")
    else:
        estimates["lm_head"] = vocab * width

    estimates["total"] = sum(
        estimates[key]
        for key in (
            "token_embeddings",
            "position_embeddings",
            "total_transformer_layers",
            "final_layernorm",
            "lm_head",
        )
    )
    return estimates
|
|
|
|
|
|
|
|
def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
    """Print a formatted parameter budget report for ``config`` to stdout.

    The numbers come from ``estimate_parameters(config)``. The report shows
    the configuration values, a per-component breakdown, and how the
    estimated total compares with ``limit``.

    Args:
        config: Model configuration to analyse.
        limit: Parameter budget the estimated total is compared against.
    """
    est = estimate_parameters(config)
    total = est["total"]
    banner = "=" * 60

    print(banner)
    print("PARAMETER BUDGET ANALYSIS")
    print(banner)

    print("\nConfiguration:")
    for label, attr in (
        ("vocab_size (V)", "vocab_size"),
        ("n_embd (d)", "n_embd"),
        ("n_layer (L)", "n_layer"),
        ("n_head", "n_head"),
        ("n_ctx", "n_ctx"),
        ("n_inner", "n_inner"),
        ("tie_weights", "tie_weights"),
    ):
        print(f" {label} = {getattr(config, attr)}")

    print("\nParameter Breakdown:")
    for label, key in (
        ("Token Embeddings", "token_embeddings"),
        ("Position Embeddings", "position_embeddings"),
        ("Transformer Layers", "total_transformer_layers"),
        ("Final LayerNorm", "final_layernorm"),
    ):
        print(f" {label}: {est[key]:>10,}")

    # A tied head has no parameters of its own, so show a marker instead.
    lm_head_cell = f"{'(tied)':>10}" if config.tie_weights else f"{est['lm_head']:>10,}"
    print(f" LM Head: {lm_head_cell}")

    print(" " + "-" * 30)
    print(f" TOTAL: {total:>10,}")

    print("\nBudget Status:")
    print(f" Limit: {limit:>10,}")
    print(f" Used: {total:>10,}")
    print(f" Remaining:{limit - total:>10,}")

    if total <= limit:
        print(f"\n Within budget! ({total / limit * 100:.1f}% used)")
    else:
        print(f"\n OVER BUDGET by {total - limit:,} parameters!")

    print(banner)
|
|
|
|
|
|
|
|
def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
    """Validate an extended-UCI move against a position using python-chess.

    The dataset token is reduced to standard UCI (from-square, to-square and
    an optional lowercase promotion piece). That UCI move is then checked
    against ``board.legal_moves``. The colour and piece letters (characters
    0-1) and any trailing annotation such as "(x)" are not cross-checked.

    Args:
        move: Move in extended UCI format (e.g., "WPe2e4", "BNg8f6(x)",
            "WPe7e8=Q").
        board_fen: FEN string of the position to validate against. When
            omitted or empty, the standard starting position is used.

    Returns:
        True if the move is legal in that position. False if it is illegal
        or malformed: shorter than 6 characters, a dangling "=", or
        unparsable squares or promotion piece.

    Raises:
        ImportError: If python-chess is not installed.
        ValueError: If ``board_fen`` is not a valid FEN (raised by
            ``chess.Board``).
    """
    try:
        import chess
    except ImportError as err:
        raise ImportError("python-chess is required for move validation. "
                          "Install it with: pip install python-chess") from err

    # <colour><piece><from><to>... -- fewer than 6 chars cannot hold both squares.
    if len(move) < 6:
        return False

    from_sq, to_sq = move[2:4], move[4:6]

    # The promotion piece is the character right after the first "=". A
    # dangling "=" is malformed; indexing past it used to raise IndexError.
    _, sep, after = move.partition("=")
    promotion = after[:1].lower()
    if sep and not promotion:
        return False

    board = chess.Board(board_fen) if board_fen else chess.Board()

    try:
        candidate = chess.Move.from_uci(from_sq + to_sq + promotion)
    except ValueError:
        # Newer python-chess raises InvalidMoveError, a ValueError subclass;
        # older releases raise plain ValueError and lack InvalidMoveError, so
        # naming it here would itself fail with AttributeError.
        return False
    return candidate in board.legal_moves
|
|
|
|
|
|
|
|
def convert_extended_uci_to_uci(move: str) -> str:
    """Convert an extended UCI token to standard UCI format.

    Extended tokens look like ``<colour><piece><from><to>[=<PIECE>][suffix]``.
    Standard UCI keeps only the two squares plus a lowercase promotion piece.
    The colour/piece prefix and any annotation suffix such as "(x)" are
    dropped.

    Args:
        move: Move in extended UCI format (e.g., "WPe2e4", "WPe7e8=Q(+)").

    Returns:
        Move in standard UCI format (e.g., "e2e4", "e7e8q"). Inputs shorter
        than 6 characters are returned unchanged. A trailing "=" with no
        piece after it is ignored rather than raising IndexError.
    """
    if len(move) < 6:
        return move

    from_sq, to_sq = move[2:4], move[4:6]

    # Character after the first "=" (empty when there is no "=", or when
    # "=" is the last character), lowercased as UCI expects.
    promotion = move.partition("=")[2][:1].lower()

    return from_sq + to_sq + promotion
|
|
|
|
|
|
|
|
def convert_uci_to_extended(
    uci_move: str,
    board_fen: str,
) -> str:
    """
    Convert a standard UCI move to the dataset's extended UCI format.

    The result is ``<colour><PIECE><from><to>``, then ``=<PIECE>`` for a
    promotion, then annotation suffixes:
    "(x)" capture, "(+)" check, "(x+)" capture giving check, "(+*)" checkmate,
    "(x+*)" capture giving checkmate, "(o)" castling with the king landing on
    g1/g8, "(O)" any other castling.

    Args:
        uci_move: Move in standard UCI format (e.g., "e2e4").
        board_fen: FEN string of the position *before* the move.

    Returns:
        Move in extended UCI format (e.g., "WPe2e4").

    Raises:
        ImportError: If python-chess is not installed.
        Errors raised by ``chess.Board`` / ``chess.Move.from_uci`` for a
        malformed FEN or UCI string propagate unchanged.

    NOTE(review): the move is never checked for legality. If ``from_square``
    is empty the piece letter silently defaults to "P", and ``board.push`` is
    applied to whatever move was given.
    """
    try:
        import chess
    except ImportError:
        raise ImportError("python-chess is required for move conversion.")

    board = chess.Board(board_fen)
    move = chess.Move.from_uci(uci_move)

    # Colour prefix is the side to move in the given position.
    color = "W" if board.turn == chess.WHITE else "B"

    # Upper-case piece letter on the origin square; "P" if the square is empty.
    piece = board.piece_at(move.from_square)
    piece_letter = piece.symbol().upper() if piece else "P"

    from_sq = chess.square_name(move.from_square)
    to_sq = chess.square_name(move.to_square)

    result = f"{color}{piece_letter}{from_sq}{to_sq}"

    # Promotion piece is written as "=Q", "=N", ... (upper case).
    if move.promotion:
        result += f"={chess.piece_symbol(move.promotion).upper()}"

    # Capture must be tested on the pre-move board.
    if board.is_capture(move):
        result += "(x)"

    # Check/mate are properties of the position *after* the move, so play it
    # temporarily, then undo it so the castling test below sees the original board.
    # Check/mate markers are merged into an existing "(x)" rather than appended.
    board.push(move)
    if board.is_checkmate():
        if "(x)" in result:
            result = result.replace("(x)", "(x+*)")
        else:
            result += "(+*)"
    elif board.is_check():
        if "(x)" in result:
            result = result.replace("(x)", "(x+)")
        else:
            result += "(+)"
    board.pop()

    # Castling: strip "(x)"/"(+)" and append the castling marker.
    # NOTE(review): only the exact substrings "(x)" and "(+)" are removed, so
    # a castle that gives check loses its check marker, while a castle that
    # mates ends up as "...(+*)(o)". Confirm this matches the dataset format.
    if board.is_castling(move):
        if move.to_square in [chess.G1, chess.G8]:
            result = result.replace("(x)", "").replace("(+)", "") + "(o)"
        else:
            result = result.replace("(x)", "").replace("(+)", "") + "(O)"

    return result
|
|
|