MaximeMuhlethaler commited on
Commit
8cee25e
·
verified ·
1 Parent(s): 1ab595d

Upload utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils.py +305 -0
utils.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for the Chess Challenge.
3
+
4
+ This module provides helper functions for:
5
+ - Parameter counting and budget analysis
6
+ - Model registration with Hugging Face
7
+ - Move validation with python-chess
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Dict, Optional, TYPE_CHECKING
13
+
14
+ import torch.nn as nn
15
+
16
+ if TYPE_CHECKING:
17
+ from src.model import ChessConfig
18
+
19
+
20
+ def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
21
+ """
22
+ Count the number of parameters in a model.
23
+
24
+ Args:
25
+ model: The PyTorch model.
26
+ trainable_only: If True, only count trainable parameters.
27
+
28
+ Returns:
29
+ Total number of parameters.
30
+ """
31
+ if trainable_only:
32
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
33
+ return sum(p.numel() for p in model.parameters())
34
+
35
+
36
+ def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
37
+ """
38
+ Count parameters broken down by model component.
39
+
40
+ Args:
41
+ model: The PyTorch model.
42
+
43
+ Returns:
44
+ Dictionary mapping component names to parameter counts.
45
+ """
46
+ counts = {}
47
+ for name, module in model.named_modules():
48
+ if len(list(module.children())) == 0: # Leaf module
49
+ param_count = sum(p.numel() for p in module.parameters(recurse=False))
50
+ if param_count > 0:
51
+ counts[name] = param_count
52
+ return counts
53
+
54
+
55
+ def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
56
+ """
57
+ Estimate the parameter count for a given configuration.
58
+
59
+ This is useful for planning your architecture before building the model.
60
+
61
+ Args:
62
+ config: Model configuration.
63
+
64
+ Returns:
65
+ Dictionary with estimated parameter counts by component.
66
+ """
67
+ V = config.vocab_size
68
+ d = config.n_embd
69
+ L = config.n_layer
70
+ n_ctx = config.n_ctx
71
+ n_inner = config.n_inner
72
+
73
+ estimates = {
74
+ "token_embeddings": V * d,
75
+ "position_embeddings": n_ctx * d,
76
+ "attention_qkv_per_layer": 3 * d * d,
77
+ "attention_proj_per_layer": d * d,
78
+ "ffn_per_layer": 2 * d * n_inner,
79
+ "layernorm_per_layer": 4 * d, # 2 LayerNorms, each with weight and bias
80
+ "final_layernorm": 2 * d,
81
+ }
82
+
83
+ # Calculate totals
84
+ per_layer = (
85
+ estimates["attention_qkv_per_layer"] +
86
+ estimates["attention_proj_per_layer"] +
87
+ estimates["ffn_per_layer"] +
88
+ estimates["layernorm_per_layer"]
89
+ )
90
+
91
+ estimates["total_transformer_layers"] = L * per_layer
92
+
93
+ # LM head (tied with embeddings by default)
94
+ if config.tie_weights:
95
+ estimates["lm_head"] = 0
96
+ estimates["lm_head_note"] = "Tied with token embeddings"
97
+ else:
98
+ estimates["lm_head"] = V * d
99
+
100
+ # Grand total
101
+ estimates["total"] = (
102
+ estimates["token_embeddings"] +
103
+ estimates["position_embeddings"] +
104
+ estimates["total_transformer_layers"] +
105
+ estimates["final_layernorm"] +
106
+ estimates["lm_head"]
107
+ )
108
+
109
+ return estimates
110
+
111
+
112
+ def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
113
+ """
114
+ Print a formatted parameter budget analysis.
115
+
116
+ Args:
117
+ config: Model configuration.
118
+ limit: Parameter limit to compare against.
119
+ """
120
+ estimates = estimate_parameters(config)
121
+
122
+ print("=" * 60)
123
+ print("PARAMETER BUDGET ANALYSIS")
124
+ print("=" * 60)
125
+ print(f"\nConfiguration:")
126
+ print(f" vocab_size (V) = {config.vocab_size}")
127
+ print(f" n_embd (d) = {config.n_embd}")
128
+ print(f" n_layer (L) = {config.n_layer}")
129
+ print(f" n_head = {config.n_head}")
130
+ print(f" n_ctx = {config.n_ctx}")
131
+ print(f" n_inner = {config.n_inner}")
132
+ print(f" tie_weights = {config.tie_weights}")
133
+
134
+ print(f"\nParameter Breakdown:")
135
+ print(f" Token Embeddings: {estimates['token_embeddings']:>10,}")
136
+ print(f" Position Embeddings: {estimates['position_embeddings']:>10,}")
137
+ print(f" Transformer Layers: {estimates['total_transformer_layers']:>10,}")
138
+ print(f" Final LayerNorm: {estimates['final_layernorm']:>10,}")
139
+
140
+ if config.tie_weights:
141
+ print(f" LM Head: {'(tied)':>10}")
142
+ else:
143
+ print(f" LM Head: {estimates['lm_head']:>10,}")
144
+
145
+ print(f" " + "-" * 30)
146
+ print(f" TOTAL: {estimates['total']:>10,}")
147
+
148
+ print(f"\nBudget Status:")
149
+ print(f" Limit: {limit:>10,}")
150
+ print(f" Used: {estimates['total']:>10,}")
151
+ print(f" Remaining:{limit - estimates['total']:>10,}")
152
+
153
+ if estimates['total'] <= limit:
154
+ print(f"\n Within budget! ({estimates['total'] / limit * 100:.1f}% used)")
155
+ else:
156
+ print(f"\n OVER BUDGET by {estimates['total'] - limit:,} parameters!")
157
+
158
+ print("=" * 60)
159
+
160
+
161
+ def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
162
+ """
163
+ Validate a move using python-chess.
164
+
165
+ This function converts the dataset's extended UCI format to standard UCI
166
+ and validates it against the current board state.
167
+
168
+ Args:
169
+ move: Move in extended UCI format (e.g., "WPe2e4", "BNg8f6(x)").
170
+ board_fen: FEN string of the current board state (optional).
171
+
172
+ Returns:
173
+ True if the move is legal, False otherwise.
174
+ """
175
+ try:
176
+ import chess
177
+ except ImportError:
178
+ raise ImportError("python-chess is required for move validation. "
179
+ "Install it with: pip install python-chess")
180
+
181
+ # Parse the extended UCI format
182
+ # Format: [W|B][Piece][from_sq][to_sq][suffix]
183
+ # Example: WPe2e4, BNg8f6(x), WKe1g1(o)
184
+
185
+ if len(move) < 6:
186
+ return False
187
+
188
+ # Extract components
189
+ color = move[0] # W or B
190
+ piece = move[1] # P, N, B, R, Q, K
191
+ from_sq = move[2:4] # e.g., "e2"
192
+ to_sq = move[4:6] # e.g., "e4"
193
+
194
+ # Check for promotion
195
+ promotion = None
196
+ if "=" in move:
197
+ promo_idx = move.index("=")
198
+ promotion = move[promo_idx + 1].lower()
199
+
200
+ # Create board
201
+ board = chess.Board(board_fen) if board_fen else chess.Board()
202
+
203
+ # Build UCI move string
204
+ uci_move = from_sq + to_sq
205
+ if promotion:
206
+ uci_move += promotion
207
+
208
+ try:
209
+ move_obj = chess.Move.from_uci(uci_move)
210
+ return move_obj in board.legal_moves
211
+ except (ValueError, chess.InvalidMoveError):
212
+ return False
213
+
214
+
215
+ def convert_extended_uci_to_uci(move: str) -> str:
216
+ """
217
+ Convert extended UCI format to standard UCI format.
218
+
219
+ Args:
220
+ move: Move in extended UCI format (e.g., "WPe2e4").
221
+
222
+ Returns:
223
+ Move in standard UCI format (e.g., "e2e4").
224
+ """
225
+ if len(move) < 6:
226
+ return move
227
+
228
+ # Extract squares
229
+ from_sq = move[2:4]
230
+ to_sq = move[4:6]
231
+
232
+ # Check for promotion
233
+ promotion = ""
234
+ if "=" in move:
235
+ promo_idx = move.index("=")
236
+ promotion = move[promo_idx + 1].lower()
237
+
238
+ return from_sq + to_sq + promotion
239
+
240
+
241
+ def convert_uci_to_extended(
242
+ uci_move: str,
243
+ board_fen: str,
244
+ ) -> str:
245
+ """
246
+ Convert standard UCI format to extended UCI format.
247
+
248
+ Args:
249
+ uci_move: Move in standard UCI format (e.g., "e2e4").
250
+ board_fen: FEN string of the current board state.
251
+
252
+ Returns:
253
+ Move in extended UCI format (e.g., "WPe2e4").
254
+ """
255
+ try:
256
+ import chess
257
+ except ImportError:
258
+ raise ImportError("python-chess is required for move conversion.")
259
+
260
+ board = chess.Board(board_fen)
261
+ move = chess.Move.from_uci(uci_move)
262
+
263
+ # Get color
264
+ color = "W" if board.turn == chess.WHITE else "B"
265
+
266
+ # Get piece
267
+ piece = board.piece_at(move.from_square)
268
+ piece_letter = piece.symbol().upper() if piece else "P"
269
+
270
+ # Build extended UCI
271
+ from_sq = chess.square_name(move.from_square)
272
+ to_sq = chess.square_name(move.to_square)
273
+
274
+ result = f"{color}{piece_letter}{from_sq}{to_sq}"
275
+
276
+ # Add promotion
277
+ if move.promotion:
278
+ result += f"={chess.piece_symbol(move.promotion).upper()}"
279
+
280
+ # Add suffix for captures
281
+ if board.is_capture(move):
282
+ result += "(x)"
283
+
284
+ # Add suffix for check/checkmate
285
+ board.push(move)
286
+ if board.is_checkmate():
287
+ if "(x)" in result:
288
+ result = result.replace("(x)", "(x+*)")
289
+ else:
290
+ result += "(+*)"
291
+ elif board.is_check():
292
+ if "(x)" in result:
293
+ result = result.replace("(x)", "(x+)")
294
+ else:
295
+ result += "(+)"
296
+ board.pop()
297
+
298
+ # Handle castling notation
299
+ if board.is_castling(move):
300
+ if move.to_square in [chess.G1, chess.G8]: # Kingside
301
+ result = result.replace("(x)", "").replace("(+)", "") + "(o)"
302
+ else: # Queenside
303
+ result = result.replace("(x)", "").replace("(+)", "") + "(O)"
304
+
305
+ return result