File size: 25,385 Bytes
481ffb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
"""
Evaluation script for the Chess Challenge.

This script evaluates a trained chess model by playing games against
Stockfish and computing ELO ratings.
"""

from __future__ import annotations

import argparse
import math
import os
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch


@dataclass
class GameResult:
    """Outcome record for one model-vs-opponent game.

    Captures the full move list, the PGN-style score, which side the
    model played, why the game ended, and how many illegal generation
    attempts the model made along the way.
    """
    moves: List[str]  # UCI move strings, in the order they were played
    result: str  # "1-0", "0-1", or "1/2-1/2"
    model_color: str  # "white" or "black"
    termination: str  # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
    illegal_move_count: int  # total rejected/illegal generation attempts


class ChessEvaluator:
    """
    Evaluator for chess models.

    This class handles playing games between a trained model and Stockfish,
    tracking results, and computing ELO ratings.
    """

    def __init__(
        self,
        model,
        tokenizer,
        stockfish_path: Optional[str] = None,
        stockfish_level: int = 1,
        max_retries: int = 3,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Initialize the evaluator.

        Args:
            model: The trained chess model.
            tokenizer: The chess tokenizer.
            stockfish_path: Path to Stockfish executable. If None, the
                system PATH is searched via ``shutil.which``.
            stockfish_level: Stockfish skill level (0-20).
            max_retries: Maximum retries for illegal moves.
            device: Device to run the model on.

        Raises:
            ImportError: If python-chess is not installed.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.max_retries = max_retries
        self.device = device

        # Initialize Stockfish
        try:
            import chess
            import chess.engine

            self.chess = chess

            if stockfish_path is None:
                # Try the PATH before giving up on engine-backed evaluation.
                import shutil
                stockfish_path = shutil.which("stockfish")

            if stockfish_path:
                self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
                self.engine.configure({"Skill Level": stockfish_level})
            else:
                print("WARNING: Stockfish not found. Install it for full evaluation.")
                self.engine = None

        except ImportError:
            raise ImportError(
                "python-chess is required for evaluation. "
                "Install it with: pip install python-chess"
            )

    def __del__(self):
        """Clean up the Stockfish engine process."""
        if hasattr(self, 'engine') and self.engine:
            try:
                self.engine.quit()
            except Exception:
                # __del__ may run during interpreter shutdown when the engine
                # transport is already torn down; never raise from a finalizer.
                pass

    def _convert_board_to_moves(self, board) -> str:
        """
        Convert the board's move history to the model's input format.

        Each move is rendered as "<color><piece><from><to>" (e.g. "WPe2e4")
        with optional suffixes: "=Q" for promotion, "(x)" capture, "(+)"
        check, "(+*)" checkmate, "(o)"/"(O)" king/queen-side castling.

        Args:
            board: A python-chess Board with a populated move_stack.

        Returns:
            Space-separated move strings for the whole game so far.
        """
        moves = []
        temp_board = self.chess.Board()

        for move in board.move_stack:
            # Get piece and color
            color = "W" if temp_board.turn == self.chess.WHITE else "B"
            piece = temp_board.piece_at(move.from_square)
            piece_letter = piece.symbol().upper() if piece else "P"

            # Get squares
            from_sq = self.chess.square_name(move.from_square)
            to_sq = self.chess.square_name(move.to_square)

            move_str = f"{color}{piece_letter}{from_sq}{to_sq}"

            # Add promotion
            if move.promotion:
                move_str += f"={self.chess.piece_symbol(move.promotion).upper()}"

            # Add capture suffix (must be checked before pushing the move)
            if temp_board.is_capture(move):
                move_str += "(x)"

            # Add check/checkmate suffix (checked after pushing the move)
            temp_board.push(move)
            if temp_board.is_checkmate():
                move_str = move_str.replace("(x)", "(x+*)") if "(x)" in move_str else move_str + "(+*)"
            elif temp_board.is_check():
                move_str = move_str.replace("(x)", "(x+)") if "(x)" in move_str else move_str + "(+)"

            # Handle castling: a king moving more than one file. The castling
            # suffix replaces any capture/check suffix.
            if piece_letter == "K" and abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
                if to_sq[0] == 'g':  # Kingside
                    move_str = move_str.split("(")[0] + "(o)"
                else:  # Queenside
                    move_str = move_str.split("(")[0] + "(O)"

            moves.append(move_str)

        return " ".join(moves)

    def _is_separator_token(self, token_str: str) -> bool:
        """
        Check if a token represents a separator (whitespace, EOS, etc.).

        This allows the evaluator to work with different tokenization strategies:
        - Move-level tokenizers: each move is one token, no separators generated
        - Character-level tokenizers: space character marks end of move
        - BPE/subword tokenizers: may generate partial moves

        Args:
            token_str: The decoded token string.

        Returns:
            True if this token indicates end of a move.
        """
        # Check for EOS token
        if hasattr(self.tokenizer, 'eos_token') and token_str == self.tokenizer.eos_token:
            return True

        # Check for pure whitespace (space, newline, etc.); the empty string
        # is NOT a separator.
        if token_str.strip() == "" and len(token_str) > 0:
            return True

        # Check if the token ends with whitespace (some tokenizers include trailing space)
        if token_str != token_str.rstrip():
            return True

        return False

    def _generate_move_tokens(
        self,
        input_ids: torch.Tensor,
        temperature: float = 0.7,
        top_k: int = 10,
        max_tokens: int = 20,
    ) -> str:
        """
        Generate tokens until a separator (whitespace/EOS) is encountered.

        This method supports different tokenization strategies:
        - For move-level tokenizers: generates one token (the full move)
        - For character/subword tokenizers: generates until whitespace

        Args:
            input_ids: The input token IDs.
            temperature: Sampling temperature.
            top_k: Top-k filtering parameter (0 disables filtering).
            max_tokens: Maximum tokens to generate for a single move.

        Returns:
            The generated move string (without trailing separator).
        """
        generated_tokens = []
        current_ids = input_ids.clone()

        for _ in range(max_tokens):
            with torch.no_grad():
                outputs = self.model(input_ids=current_ids)
                logits = outputs.logits[:, -1, :] / temperature

                # Apply top-k filtering: mask everything below the k-th logit
                if top_k > 0:
                    top_k_values = torch.topk(logits, min(top_k, logits.size(-1)))[0]
                    indices_to_remove = logits < top_k_values[..., -1, None]
                    logits[indices_to_remove] = float("-inf")

                # Sample
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # Shape: [1, 1]

            # Decode the token
            token_str = self.tokenizer.decode(next_token[0])

            # A separator marks the end of the move; don't include it
            if self._is_separator_token(token_str):
                break

            generated_tokens.append(next_token[0])  # Store [1] tensor

            # Append to input for next iteration (next_token is already [1, 1])
            current_ids = torch.cat([current_ids, next_token], dim=-1)

            # For move-level tokenizers, a single non-separator token is the full move
            # We can detect this by checking if the token looks like a complete move
            # (starts with W or B, has enough characters for a move)
            if len(token_str) >= 6 and token_str[0] in "WB":
                break

        # Decode all generated tokens together
        if generated_tokens:
            all_tokens = torch.cat(generated_tokens, dim=0)
            move_str = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
            return move_str.strip()

        return ""

    def _get_model_move(
        self,
        board,
        temperature: float = 0.7,
        top_k: int = 10,
    ) -> Tuple[Optional[str], int]:
        """
        Get the model's next move prediction.

        Generates tokens until a separator (whitespace/EOS) is produced and
        retries up to ``self.max_retries`` times if the decoded move is not
        legal in the given position.

        Args:
            board: The current python-chess Board.
            temperature: Sampling temperature.
            top_k: Top-k filtering parameter.

        Returns:
            Tuple of (UCI move string or None, number of retries used).
        """
        self.model.eval()

        # Convert board to input format
        moves_str = self._convert_board_to_moves(board)

        # Add BOS token if no moves yet
        if not moves_str:
            input_text = self.tokenizer.bos_token
        else:
            input_text = self.tokenizer.bos_token + " " + moves_str

        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.n_ctx - 10,  # Leave room for generated tokens
        ).to(self.device)

        # Try to generate a legal move
        for retry in range(self.max_retries):
            # Generate tokens until separator
            move_token = self._generate_move_tokens(
                inputs["input_ids"],
                temperature=temperature,
                top_k=top_k,
            )

            # Convert to UCI: characters 2-5 are from/to squares ("WPe2e4" -> "e2e4")
            if len(move_token) >= 6:
                uci_move = move_token[2:6]

                # Handle promotion ("=Q" suffix -> trailing "q" in UCI)
                if "=" in move_token:
                    promo_idx = move_token.index("=")
                    uci_move += move_token[promo_idx + 1].lower()

                try:
                    move = self.chess.Move.from_uci(uci_move)
                    if move in board.legal_moves:
                        return uci_move, retry
                except (ValueError, self.chess.InvalidMoveError):
                    pass

        return None, self.max_retries

    def _get_stockfish_move(self, board, time_limit: float = 0.1) -> str:
        """Get Stockfish's move for the current position as a UCI string."""
        if self.engine is None:
            raise RuntimeError("Stockfish engine not initialized")

        result = self.engine.play(board, self.chess.engine.Limit(time=time_limit))
        return result.move.uci()

    def play_game(
        self,
        model_color: str = "white",
        max_moves: int = 200,
        temperature: float = 0.7,
    ) -> GameResult:
        """
        Play a single game between the model and Stockfish.

        If no engine is available, the opponent plays random legal moves.

        Args:
            model_color: "white" or "black".
            max_moves: Maximum number of moves before draw.
            temperature: Sampling temperature for model.

        Returns:
            GameResult with the game details.
        """
        board = self.chess.Board()
        moves = []
        illegal_move_count = 0

        model_is_white = model_color == "white"

        while not board.is_game_over() and len(moves) < max_moves:
            is_model_turn = (board.turn == self.chess.WHITE) == model_is_white

            if is_model_turn:
                # Model's turn
                uci_move, retries = self._get_model_move(board, temperature)
                illegal_move_count += retries

                if uci_move is None:
                    # Model exhausted retries: score as a loss for the model
                    return GameResult(
                        moves=moves,
                        result="0-1" if model_is_white else "1-0",
                        model_color=model_color,
                        termination="illegal_move",
                        illegal_move_count=illegal_move_count + 1,
                    )

                move = self.chess.Move.from_uci(uci_move)
            else:
                # Stockfish's turn
                if self.engine:
                    uci_move = self._get_stockfish_move(board)
                    move = self.chess.Move.from_uci(uci_move)
                else:
                    # Random move if no engine
                    move = random.choice(list(board.legal_moves))

            board.push(move)
            moves.append(move.uci())

        # Determine result
        if board.is_checkmate():
            if board.turn == self.chess.WHITE:
                result = "0-1"  # Black wins
            else:
                result = "1-0"  # White wins
            termination = "checkmate"
        elif board.is_stalemate():
            result = "1/2-1/2"
            termination = "stalemate"
        elif board.is_insufficient_material():
            result = "1/2-1/2"
            termination = "insufficient_material"
        elif board.can_claim_draw():
            result = "1/2-1/2"
            termination = "draw_claim"
        elif len(moves) >= max_moves:
            result = "1/2-1/2"
            termination = "max_moves"
        else:
            result = "1/2-1/2"
            termination = "unknown"

        return GameResult(
            moves=moves,
            result=result,
            model_color=model_color,
            termination=termination,
            illegal_move_count=illegal_move_count,
        )

    def evaluate_legal_moves(
        self,
        n_positions: int = 1000,
        temperature: float = 0.7,
        verbose: bool = True,
    ) -> dict:
        """
        Evaluate the model's ability to generate legal moves.

        This evaluation only checks if the model generates legal moves,
        without playing full games. Useful as a first-pass evaluation.

        Args:
            n_positions: Number of positions to test.
            temperature: Sampling temperature.
            verbose: Whether to print progress.

        Returns:
            Dictionary with legal move statistics.
        """
        results = {
            "total_positions": 0,
            "legal_first_try": 0,
            "legal_with_retry": 0,
            "illegal_all_retries": 0,
            "positions": [],
        }

        # Generate random positions by playing random moves
        for i in range(n_positions):
            board = self.chess.Board()

            # Play random number of moves (5-40) to get varied positions
            n_random_moves = random.randint(5, 40)
            for _ in range(n_random_moves):
                if board.is_game_over():
                    break
                move = random.choice(list(board.legal_moves))
                board.push(move)

            if board.is_game_over():
                continue  # Skip terminal positions

            results["total_positions"] += 1

            # Test model's move generation
            uci_move, retries = self._get_model_move(board, temperature)

            position_result = {
                "fen": board.fen(),
                "move_number": len(board.move_stack),
                "legal": uci_move is not None,
                "retries": retries,
            }
            results["positions"].append(position_result)

            if uci_move is not None:
                if retries == 0:
                    results["legal_first_try"] += 1
                else:
                    results["legal_with_retry"] += 1
            else:
                results["illegal_all_retries"] += 1

            if verbose and (i + 1) % 100 == 0:
                legal_rate = (results["legal_first_try"] + results["legal_with_retry"]) / results["total_positions"]
                print(f"  Positions: {i + 1}/{n_positions} | Legal rate: {legal_rate:.1%}")

        # Calculate statistics
        total = results["total_positions"]
        if total > 0:
            results["legal_rate_first_try"] = results["legal_first_try"] / total
            results["legal_rate_with_retry"] = (results["legal_first_try"] + results["legal_with_retry"]) / total
            results["illegal_rate"] = results["illegal_all_retries"] / total
        else:
            results["legal_rate_first_try"] = 0
            results["legal_rate_with_retry"] = 0
            results["illegal_rate"] = 1

        return results

    @staticmethod
    def _estimate_elo(
        wins: int,
        losses: int,
        draws: int,
        opponent_elo: float = 1350.0,
    ) -> Optional[float]:
        """
        Estimate the model's ELO from a match score against a rated opponent.

        Uses the standard logistic relation: an expected score p against an
        opponent corresponds to a rating difference of 400*log10(p/(1-p)).
        The difference is clamped to +/-400 for perfect/zero scores (and in
        general), since small samples cannot support larger claims.

        Args:
            wins: Number of wins.
            losses: Number of losses.
            draws: Number of draws.
            opponent_elo: The opponent's assumed rating.

        Returns:
            Estimated ELO, or None when there were no decisive games at all
            (no wins and no losses), matching the previous behavior.
        """
        total = wins + losses + draws
        if total == 0 or (wins == 0 and losses == 0):
            return None

        score_ratio = (wins + 0.5 * draws) / total
        if score_ratio <= 0.0:
            return opponent_elo - 400.0
        if score_ratio >= 1.0:
            return opponent_elo + 400.0

        elo_diff = 400.0 * math.log10(score_ratio / (1.0 - score_ratio))
        return opponent_elo + max(-400.0, min(400.0, elo_diff))

    def evaluate(
        self,
        n_games: int = 100,
        temperature: float = 0.7,
        verbose: bool = True,
    ) -> dict:
        """
        Run a full win-rate evaluation of the model against Stockfish.

        Args:
            n_games: Number of games to play.
            temperature: Sampling temperature.
            verbose: Whether to print progress.

        Returns:
            Dictionary with evaluation metrics.
        """
        results = {
            "wins": 0,
            "losses": 0,
            "draws": 0,
            "illegal_moves": 0,
            "total_moves": 0,
            "games": [],
        }

        for i in range(n_games):
            # Alternate colors
            model_color = "white" if i % 2 == 0 else "black"

            game = self.play_game(
                model_color=model_color,
                temperature=temperature,
            )

            results["games"].append(game)
            results["total_moves"] += len(game.moves)
            results["illegal_moves"] += game.illegal_move_count

            # Count result
            if game.result == "1/2-1/2":
                results["draws"] += 1
            elif (game.result == "1-0" and model_color == "white") or \
                 (game.result == "0-1" and model_color == "black"):
                results["wins"] += 1
            else:
                results["losses"] += 1

            if verbose and (i + 1) % 10 == 0:
                print(f"  Games: {i + 1}/{n_games} | "
                      f"W: {results['wins']} L: {results['losses']} D: {results['draws']}")

        # Calculate statistics
        total = results["wins"] + results["losses"] + results["draws"]
        results["win_rate"] = results["wins"] / total if total > 0 else 0
        results["draw_rate"] = results["draws"] / total if total > 0 else 0
        results["loss_rate"] = results["losses"] / total if total > 0 else 0

        total_attempts = results["total_moves"] + results["illegal_moves"]

        # Average length counts both legal moves and illegal attempts so early illegal terminations
        # don't show as near-zero length games.
        results["avg_game_length"] = total_attempts / total if total > 0 else 0

        # Illegal move rate: illegal attempts over total attempts
        results["illegal_move_rate"] = results["illegal_moves"] / total_attempts if total_attempts > 0 else 0

        # Estimate ELO relative to the opponent.
        # Stockfish Level 1 is approximately 1350 ELO.
        # NOTE: the previous inline formula flipped the sign of the rating
        # difference for sub-50% scores (a losing record produced a *positive*
        # ELO delta); _estimate_elo uses the standard logistic relation.
        results["estimated_elo"] = self._estimate_elo(
            results["wins"], results["losses"], results["draws"], opponent_elo=1350.0
        )

        return results


def load_model_from_hub(model_id: str, device: str = "auto"):
    """
    Load a model from the Hugging Face Hub.

    Args:
        model_id: Model ID on Hugging Face Hub.
        device: Device to load the model on.

    Returns:
        Tuple of (model, tokenizer).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Importing registers the custom config/model classes with transformers.
    from src.model import ChessConfig, ChessForCausalLM  # noqa: F401

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, device_map=device
    )

    return loaded_model, loaded_tokenizer


def main():
    """Main evaluation function: parse args, load the model, run evaluations."""
    parser = argparse.ArgumentParser(description="Evaluate a chess model")

    parser.add_argument(
        "--model_path", type=str, required=True,
        help="Path to the model or Hugging Face model ID"
    )
    parser.add_argument(
        "--mode", type=str, default="legal", choices=["legal", "winrate", "both"],
        help="Evaluation mode: 'legal' for legal move rate, 'winrate' for games, 'both' for both"
    )
    parser.add_argument(
        "--stockfish_path", type=str, default=None,
        help="Path to Stockfish executable"
    )
    parser.add_argument(
        "--stockfish_level", type=int, default=1,
        help="Stockfish skill level (0-20)"
    )
    parser.add_argument(
        "--n_positions", type=int, default=500,
        help="Number of positions for legal move evaluation"
    )
    parser.add_argument(
        "--n_games", type=int, default=100,
        help="Number of games to play for win rate evaluation"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7,
        help="Sampling temperature"
    )

    args = parser.parse_args()

    print("=" * 60)
    print("CHESS CHALLENGE - EVALUATION")
    print("=" * 60)

    # Load model
    print(f"\nLoading model from: {args.model_path}")

    # A Hub model ID looks like "user/repo" and does not exist on the local
    # filesystem. Check the filesystem first: absolute local paths (e.g.
    # "/home/me/ckpt") also contain "/" and must not be routed to the Hub.
    if ("/" in args.model_path
            and not args.model_path.startswith(".")
            and not os.path.exists(args.model_path)):
        # Assume Hugging Face model ID
        model, tokenizer = load_model_from_hub(args.model_path)
    else:
        # Local path
        from transformers import AutoModelForCausalLM
        from src.tokenizer import ChessTokenizer
        from src.model import ChessConfig, ChessForCausalLM

        tokenizer = ChessTokenizer.from_pretrained(args.model_path)
        model = AutoModelForCausalLM.from_pretrained(args.model_path)

    # Create evaluator
    print(f"\nSetting up evaluator...")
    evaluator = ChessEvaluator(
        model=model,
        tokenizer=tokenizer,
        stockfish_path=args.stockfish_path,
        stockfish_level=args.stockfish_level,
    )

    # Run legal move evaluation
    if args.mode in ["legal", "both"]:
        print(f"\n" + "=" * 60)
        print("PHASE 1: LEGAL MOVE EVALUATION")
        print("=" * 60)
        print(f"Testing {args.n_positions} random positions...")

        legal_results = evaluator.evaluate_legal_moves(
            n_positions=args.n_positions,
            temperature=args.temperature,
            verbose=True,
        )

        print("\n" + "-" * 40)
        print("LEGAL MOVE RESULTS")
        print("-" * 40)
        print(f"  Positions tested:     {legal_results['total_positions']}")
        print(f"  Legal (1st try):      {legal_results['legal_first_try']} ({legal_results['legal_rate_first_try']:.1%})")
        print(f"  Legal (with retry):   {legal_results['legal_first_try'] + legal_results['legal_with_retry']} ({legal_results['legal_rate_with_retry']:.1%})")
        print(f"  Always illegal:       {legal_results['illegal_all_retries']} ({legal_results['illegal_rate']:.1%})")

    # Run win rate evaluation
    if args.mode in ["winrate", "both"]:
        print(f"\n" + "=" * 60)
        print("PHASE 2: WIN RATE EVALUATION")
        print("=" * 60)
        print(f"Playing {args.n_games} games against Stockfish (Level {args.stockfish_level})...")

        winrate_results = evaluator.evaluate(
            n_games=args.n_games,
            temperature=args.temperature,
            verbose=True,
        )

        print("\n" + "-" * 40)
        print("WIN RATE RESULTS")
        print("-" * 40)
        print(f"  Wins:   {winrate_results['wins']}")
        print(f"  Losses: {winrate_results['losses']}")
        print(f"  Draws:  {winrate_results['draws']}")
        print(f"\n  Win Rate:  {winrate_results['win_rate']:.1%}")
        print(f"  Draw Rate: {winrate_results['draw_rate']:.1%}")
        print(f"  Loss Rate: {winrate_results['loss_rate']:.1%}")
        print(f"\n  Avg Game Length: {winrate_results['avg_game_length']:.1f} moves")
        print(f"  Illegal Move Rate: {winrate_results['illegal_move_rate']:.2%}")

        # "is not None" rather than truthiness: a numeric ELO must always be
        # reported, even in the unlikely case it evaluates as falsy.
        if winrate_results["estimated_elo"] is not None:
            print(f"\n  Estimated ELO: {winrate_results['estimated_elo']:.0f}")

    print("\n" + "=" * 60)
    print("EVALUATION COMPLETE")
    print("=" * 60)

# Script entry point: run the full evaluation CLI when executed directly.
if __name__ == "__main__":
    main()