File size: 33,443 Bytes
e8a0fd8
 
 
 
 
 
 
 
 
 
e50a529
e8a0fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e50a529
e8a0fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ebcaf
e8a0fd8
 
a2ebcaf
 
e8a0fd8
a2ebcaf
 
 
 
e8a0fd8
a2ebcaf
 
 
 
e8a0fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e50a529
 
a2ebcaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8a0fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2ebcaf
e8a0fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
"""An engine class to provide a universal way to interact with both chessformer and stockfish"""
import torch
import chess
import math
import chess.engine
import multiprocessing
from dataclasses import dataclass, field
from functools import partial
import time
import os
import spaces

try:
    from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
except ImportError:
    from mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
from torch.distributions import Categorical
from typing import Optional, Tuple, List, Union

@dataclass
class ChessformerConfig:
    """Settings for an ``Engine`` of type ``"chessformer"``.

    This dataclass performs no validation itself; the constraints noted on
    each field are the ones checked by ``Engine.__init__``.
    """
    # Model to run. Defaults to None only so the dataclass can be built
    # incrementally; Engine.__init__ rejects a None model.
    chessformer: Optional[torch.nn.Module] = None
    # Device to run on (a torch.device or a device string). When None the
    # engine picks CUDA if available, otherwise CPU.
    device: Optional[Union[torch.device, str]] = None
    # Sampling temperature; must be > 0.
    temperature: float = 0.5
    # Search depth in plies; 0 disables the tree search. Must be >= 0.
    depth: int = 2
    # Number of candidate moves kept at the root; must be > 0.
    top_k: int = 8
    # Multiplier applied to top_k at each deeper ply; must be in (0, 1].
    decay_rate: float = 0.6
    # Largest number of boards accepted by one batched call; must be > 0.
    max_batch_size: int = 896

@dataclass
class StockfishConfig:
    """Settings for an ``Engine`` of type ``"stockfish"``."""
    # Filesystem path of the Stockfish executable to launch over UCI.
    engine_path: str = "/usr/games/stockfish"
    # Search depth limit handed to the engine for every analysis.
    depth: int = 16


def _stockfish_worker(board_fen: str, engine_path: str, depth: int) -> Optional[Tuple[str, float]]:
    """
    Analyse one position with a throw-away Stockfish process.

    Intended as a multiprocessing target: everything it needs arrives as plain
    picklable arguments and the engine is started and stopped inside the call.

    Args:
        board_fen: FEN of the position. Move history is not part of a FEN, so
            draw claims cannot be detected here; the caller is expected to
            check game termination on its own board object.
        engine_path: Path of the Stockfish executable.
        depth: Depth limit for the analysis.

    Returns:
        ``(best_move_uci, score)`` where ``score`` is the evaluation from the
        side to move's point of view squashed into [-1, 1], or ``None`` when
        the engine fails or reports no usable score / principal variation.
    """
    sf = None
    try:
        sf = chess.engine.SimpleEngine.popen_uci(engine_path)
        # Rebuilding the board from a FEN discards the move history.
        position = chess.Board(board_fen)
        analysis = sf.analyse(position, chess.engine.Limit(depth=depth))

        score = analysis.get("score")
        line = analysis.get("pv")
        if score is None or not line:
            print(f"Warning: Stockfish analysis failed for FEN: {board_fen}")
            return None

        first_move = line[0].uci()
        mover_score = score.pov(position.turn)

        if mover_score.is_mate():
            # Forced mates are mapped onto a large fixed centipawn value.
            centipawns = 10000.0 if mover_score.mate() > 0 else -10000.0
        elif mover_score.cp is not None:
            centipawns = float(mover_score.cp)
        else:
            print(f"Warning: Stockfish score object lacks cp/mate for FEN: {board_fen}")
            return None

        # Logistic squash of centipawns into the open interval (-1, 1).
        squashed = 2 / (1 + math.exp(-0.004 * centipawns)) - 1
        return first_move, squashed

    except (chess.engine.EngineError, chess.engine.EngineTerminatedError, FileNotFoundError, ValueError) as e:
        print(f"Stockfish worker error for FEN {board_fen}: {e}")
        return None
    finally:
        # Always stop the child process, whichever path was taken above.
        if sf:
            sf.quit()

def _compute_repetition_single(board: chess.Board) -> int:
    """
    Count how many times the current position has occurred in the game.

    The move stack is unwound until it is empty or an irreversible move is
    reached, comparing each earlier position with the current one.

    WARNING: this pops moves off ``board`` and does not restore them; pass a
    copy if the history must survive.

    Returns:
        The occurrence count, never less than 1. Falls back to 1 if anything
        raises while unwinding.
    """
    target_key = board._transposition_key()
    # The position on the board is its own first occurrence.
    occurrences = 1
    try:
        while board.move_stack:
            undone = board.pop()  # destructive: history is consumed here
            if board.is_irreversible(undone):
                # Nothing before an irreversible move can repeat the position.
                break
            if board._transposition_key() == target_key:
                occurrences += 1
    except Exception as e:
        print(f"Error occurred during repetition count for board {board.fen()}: {e}")
        return 1 # fallback to 1
    return max(1, occurrences)
    
# Engine class, designed to be used in the Evaluator class and app.py
class Engine:
    def __init__(self,
                 type: str,
                 chessformer_config: Optional[ChessformerConfig]=None,
                 stockfish_config: Optional[StockfishConfig]=None):
        """
        Build an engine of the requested kind and validate its configuration.

        Args:
            type: ``"chessformer"`` or ``"stockfish"``.
            chessformer_config: Required when ``type == "chessformer"``.
            stockfish_config: Required when ``type == "stockfish"``.

        Raises:
            ValueError: Unknown ``type``, missing config, missing model,
                out-of-range hyper-parameter, or a Stockfish binary that cannot
                be started.
        """
        self.type = type
        if type == "chessformer":
            if chessformer_config is None:
                raise ValueError("ChessformerConfig must be provided for chessformer engine.")

            self.config = chessformer_config
            if self.config.chessformer is None:
                raise ValueError("ChessFormer model must be provided in config.")

            # Validate hyper-parameters before touching the model so that an
            # invalid config has no side effects on it.
            if not (self.config.temperature > 0):
                raise ValueError("Temperature must be greater than 0.")
            if not (self.config.top_k > 0):
                raise ValueError("Top-k must be greater than 0.")
            if not (self.config.depth >= 0):
                raise ValueError("Depth must be greater than or equal to 0.")
            if not (0.0 < self.config.decay_rate <= 1.0):
                raise ValueError("Decay rate must be in range (0.0,1.0].")
            if not (self.config.max_batch_size > 0):
                raise ValueError("Max batch size must be an integer greater than 0.")

            # Resolve the device: explicit torch.device, device string, or
            # auto-detect when nothing was configured.
            if self.config.device is None:
                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            elif isinstance(self.config.device, str):
                self.device = torch.device(self.config.device)
            else:
                self.device = self.config.device

            self.model = self.config.chessformer
            self.model.to(self.device)
            self.model.eval()

            self.temperature = self.config.temperature
            self.top_k = self.config.top_k
            self.initial_k = self.top_k
            self.depth = self.config.depth
            self.decay_rate = self.config.decay_rate
            self.max_batch_size = self.config.max_batch_size
            # _batch_chessformer_move reads self.top_p, which the config does
            # not define. Default to 1.0 (nucleus filtering disabled) so the
            # attribute always exists; a config carrying top_p is honoured.
            self.top_p = getattr(self.config, "top_p", 1.0)
        elif type == "stockfish":
            if stockfish_config is None:
                raise ValueError("StockfishConfig must be provided for stockfish engine.")

            self.config = stockfish_config
            if self.config.engine_path is None:
                raise ValueError("Engine path must be provided in config.")
            self.engine_path = self.config.engine_path
            self.depth = self.config.depth
            # Start and immediately stop the engine once to fail fast on a bad path.
            try:
                with chess.engine.SimpleEngine.popen_uci(self.engine_path):
                    pass
            except (FileNotFoundError, chess.engine.EngineError) as e:
                raise ValueError(f"Invalid engine path or engine not found: {e}") from e
        else:
            raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
    
    def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
        """
        Build an additive legality mask for a batch of boards.

        Returns:
            A float32 tensor of shape ``(len(boards), len(UCI_MOVE_TO_IDX))`` on
            ``self.device`` holding 0 at the indices of legal moves and ``-inf``
            everywhere else. Rows of boards that are over (draw claims counted)
            are left entirely ``-inf``.
        """
        vocab_size = len(UCI_MOVE_TO_IDX)
        mask = torch.full((len(boards), vocab_size), -torch.inf, dtype=torch.float32, device=self.device)
        for row, position in enumerate(boards):
            if position.is_game_over(claim_draw=True):
                # Finished game: keep the whole row masked out.
                continue
            allowed = [UCI_MOVE_TO_IDX[mv.uci()] for mv in position.legal_moves]
            if allowed:
                mask[row, allowed] = 0
            # Index 0 is used as the draw-claim slot.
            # NOTE(review): is_game_over(claim_draw=True) above already seems to
            # treat claimable draws as game over, so this branch looks
            # unreachable - confirm against python-chess.
            if position.can_claim_draw():
                mask[row, 0] = 0

        return mask

    def compute_repetition(self, boards: List[chess.Board]) -> torch.Tensor:
        """
        Compute the repetition count of every board, in parallel when worthwhile.

        Each board is copied (with its move stack) before counting because
        ``_compute_repetition_single`` consumes the history of the board it is
        given; the caller's boards are left untouched.

        Returns:
            A long tensor of shape ``(len(boards),)`` on ``self.device``. If the
            computation fails for any reason, a tensor of ones is returned
            instead (every position treated as a first occurrence).
        """
        bs = len(boards)
        # os.cpu_count() may return None when the count cannot be determined.
        cpu_total = os.cpu_count() or 1
        num_workers = min(bs, max(1, cpu_total // 2))
        if bs < num_workers * 2: # avoid overhead for very small batches per worker
            num_workers = max(1, bs // 2)

        try:
            if num_workers > 1 and bs > 1:
                board_copies = [board.copy(stack=True) for board in boards]
                with multiprocessing.Pool(processes=num_workers) as pool:
                    counts = pool.map(_compute_repetition_single, board_copies)
            else:
                # Run sequentially if only one worker needed or very small batch
                counts = [_compute_repetition_single(b.copy(stack=True)) for b in boards]

            return torch.tensor(counts, dtype=torch.long, device=self.device) # (B,)
        except Exception as e:
            print(f"Error during batch repetition computation: {e}")
            # Best-effort fallback: report a count of 1 for every board.
            return torch.ones((bs,), dtype=torch.long, device=self.device)

    def _raw_chessformer_move(self, board: chess.Board, return_perplexity: bool=False) -> Optional[Union[Tuple[str, float], Tuple[str, float, float]]]:
        """
        Sample a move directly from the model's policy head, without search.

        The policy logits are restricted to legal moves (plus the draw-claim
        slot at index 0 when a draw can be claimed), then to the ``top_k`` best
        of those, scaled by the temperature and sampled.

        Args:
            board: Position to move in.
            return_perplexity: When True, also return the perplexity of the
                legal-move distribution (before top-k and temperature).

        Returns:
            ``(move_uci, value)`` or ``(move_uci, value, perplexity)``, where
            ``value`` is the model's value output for the position; ``None``
            when there is no legal move and no draw to claim.
        """
        legal_moves = list(board.legal_moves)
        can_claim_draw = board.can_claim_draw()
        if not legal_moves and not can_claim_draw:
            # Nothing to choose from; skip the forward pass entirely.
            return None

        count_tensor = self.compute_repetition([board])

        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            move_logits, value = self.model([board.fen()], count_tensor)
        move_logits = move_logits.squeeze(0) # remove batch dimension since it will always be 1
        value = value.squeeze(0).item()

        # Additive mask: 0 for allowed entries, -inf for everything else.
        legal_move_ids = [UCI_MOVE_TO_IDX[move.uci()] for move in legal_moves]
        invalid_mask = torch.full_like(move_logits, -torch.inf)
        if legal_move_ids:
            invalid_mask[legal_move_ids] = 0
        if can_claim_draw:
            invalid_mask[0] = 0
        move_logits = move_logits + invalid_mask

        if return_perplexity:
            probs = torch.softmax(move_logits, dim=-1)
            perplexity = torch.exp(-torch.sum(probs*torch.log(probs+1e-8))).item()

        # Keep only the top-k entries. If fewer than k moves are allowed, the
        # extra picks are already -inf and stay excluded.
        k = min(self.top_k, move_logits.size(-1))
        top_k_ids = torch.topk(move_logits, k, dim=-1).indices
        top_k_mask = torch.full_like(move_logits, -torch.inf)
        top_k_mask[top_k_ids] = 0
        move_logits = (move_logits + top_k_mask) / self.temperature

        # Sample
        move_id = Categorical(logits=move_logits).sample().item()
        move = IDX_TO_UCI_MOVE[move_id]
        if return_perplexity:
            return move, value, perplexity
        return move, value

    def _search_enhanced_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Optional[Union[Tuple[str, float], Tuple[str, float, float]]]:
        """
        Choose a move by combining the policy head with a shallow tree search.

        Outline:
          1. Expand the tree level by level (batched forward pass per level),
             keeping the ``k`` most likely moves per node, where ``k`` starts
             at ``initial_k`` and shrinks by ``decay_rate`` each ply. Every
             leaf records the root move it descends from, the estimated
             probability of reaching it, and a value from White's perspective
             (game result for finished games, model value at max depth).
          2. Average leaf values per root move, weighted by probability.
          3. Blend the root policy logits with the (norm-matched) search
             values; the more peaked the root policy, the less the search
             counts.
          4. Restrict to the root top-k candidates, apply temperature, sample.

        Args:
            board: Position to move in.
            return_perplexity: When True, also return the perplexity of the
                final sampling distribution.
            verbose: Print a debug table. Only honoured together with
                ``return_perplexity`` and a depth greater than 0.

        Returns:
            ``(move_uci, value)`` or ``(move_uci, value, perplexity)``; ``value``
            is the probability-weighted leaf value from the point of view of the
            side to move at the root. When the search depth is 0 or the root is
            already over (draw claims counted) no tree can be built and the
            result of ``_raw_chessformer_move`` is returned instead (which may
            be ``None``).

        NOTE(review): tree levels are sent to the model in a single batch and
        are not limited by ``max_batch_size``.
        """
        # Without a non-terminal root and at least one ply of search there is no
        # root policy to refine; fall back to plain policy sampling.
        if self.depth == 0 or board.is_game_over(claim_draw=True):
            return self._raw_chessformer_move(board, return_perplexity)

        # Step 1: Build search tree level by level
        current_boards = [board] # boards of the current level, batched for inference
        board_probs = [1] # estimated probability of reaching each board
        board_to_root_move = [None] # root move each board descends from (root has none)

        # (root_move, prob, value from White's perspective) for every leaf
        terminal_leaves = []

        for depth in range(self.depth + 1):
            if not current_boards:
                break
            k = max(1, int(self.initial_k * (self.decay_rate ** depth)))

            fens = [b.fen() for b in current_boards]
            reps = self.compute_repetition(current_boards)

            with torch.no_grad():
                logits, values = self.model(fens, reps)

            next_boards = []
            next_board_probs = []
            next_board_to_root_move = []

            for board_idx, current_board in enumerate(current_boards):
                board_prob = board_probs[board_idx]
                parent_root_move = board_to_root_move[board_idx]

                # Finished game: the leaf value is the actual result.
                if current_board.is_game_over(claim_draw=True):
                    outcome = current_board.outcome(claim_draw=True)
                    if outcome.winner == chess.WHITE:
                        game_value = 1.0
                    elif outcome.winner == chess.BLACK:
                        game_value = -1.0
                    else:
                        game_value = 0.0
                    terminal_leaves.append((parent_root_move, board_prob, game_value))
                    continue

                # Max depth: use the value already produced by this level's
                # forward pass (side-to-move perspective) instead of running
                # the model a second time on the same boards.
                if depth == self.depth:
                    leaf_value = values[board_idx].item()
                    if current_board.turn != chess.WHITE:
                        leaf_value = -leaf_value
                    terminal_leaves.append((parent_root_move, board_prob, leaf_value))
                    continue

                # Otherwise expand the k most likely legal moves.
                invalid_mask = self.get_invalid_mask([current_board])[0]
                masked_logits = logits[board_idx] + invalid_mask
                num_allowed = int(torch.sum(invalid_mask == 0).item())

                top_k_values, top_k_indices = torch.topk(masked_logits, k=min(k, num_allowed))
                top_k_probs = torch.softmax(top_k_values, dim=0)
                if depth == 0:
                    # Remember the root policy for the final blending step.
                    initial_masked_logits = masked_logits
                    initial_invalid_mask = invalid_mask
                    initial_top_k_indices = top_k_indices

                # One host transfer per node instead of one per candidate.
                for move_prob, move_id in zip(top_k_probs.tolist(), top_k_indices.tolist()):
                    move_uci = IDX_TO_UCI_MOVE[move_id]
                    root_move = parent_root_move if parent_root_move is not None else move_uci

                    if move_uci == "<claim_draw>":
                        # A valid claim ends the game as a draw; an invalid one
                        # (should not happen) is simply dropped.
                        if current_board.can_claim_draw():
                            terminal_leaves.append((root_move, board_prob * move_prob, 0.0))
                        continue

                    new_board = current_board.copy()
                    new_board.push(chess.Move.from_uci(move_uci))

                    next_boards.append(new_board)
                    next_board_probs.append(board_prob * move_prob)
                    next_board_to_root_move.append(root_move)

            current_boards = next_boards
            board_probs = next_board_probs
            board_to_root_move = next_board_to_root_move

        # Step 2: Aggregate all leaves per root move using probability weights
        root_move_weighted_values = {}
        root_move_total_probs = {}
        for root_move, prob, value in terminal_leaves:
            if root_move not in root_move_weighted_values:
                root_move_weighted_values[root_move] = 0.0
                root_move_total_probs[root_move] = 0.0
            root_move_weighted_values[root_move] += prob * value
            root_move_total_probs[root_move] += prob

        final_value = sum(root_move_weighted_values.values())
        final_value = final_value if board.turn == chess.WHITE else -final_value

        root_move_values = {}
        for root_move, total_prob in root_move_total_probs.items():
            if total_prob > 0:
                root_move_values[root_move] = root_move_weighted_values[root_move] / total_prob
            else:
                root_move_values[root_move] = 0

        # Step 3: Confidence-based weighting with search results.
        # confidence = 1 - normalised entropy of the root policy over the
        # allowed moves.
        initial_probs = torch.softmax(initial_masked_logits, dim=0)
        entropy = -torch.sum(initial_probs * torch.log(initial_probs + 1e-8)).item()
        max_entropy = math.log(torch.sum(initial_invalid_mask == 0).item())
        confidence = 1.0 - (entropy / max_entropy) if max_entropy > 0 else 1.0
        # Guard against rounding pushing the weight outside [0, 1].
        confidence = min(1.0, max(0.0, confidence))

        # Only the root candidates can be sampled, so blend on those entries
        # only. This also avoids 0 * -inf = NaN on masked entries when the
        # confidence is exactly 0.
        candidate_initial = initial_masked_logits[initial_top_k_indices]
        if root_move_values:
            search_adjustment_logits = torch.zeros_like(initial_masked_logits)
            for move_uci, search_value in root_move_values.items():
                search_adjustment_logits[UCI_MOVE_TO_IDX[move_uci]] += search_value
            # flip value according to perspective
            if board.turn != chess.WHITE:
                search_adjustment_logits = -search_adjustment_logits
            search_adjustment_logits = search_adjustment_logits - search_adjustment_logits.mean()

            # Rescale the search term so its norm over the candidates matches
            # that of the policy logits.
            initial_valid_norm = torch.norm(candidate_initial) + 1e-8
            search_valid_norm = torch.norm(search_adjustment_logits[initial_top_k_indices]) + 1e-8
            normalized_search = search_adjustment_logits * initial_valid_norm / search_valid_norm

            candidate_logits = (confidence * candidate_initial
                                + (1 - confidence) * normalized_search[initial_top_k_indices])
        else:
            normalized_search = None
            candidate_logits = candidate_initial

        # Step 4: top-k filtering (everything else stays -inf) and temperature
        adjusted_logits = torch.full_like(initial_masked_logits, -torch.inf)
        adjusted_logits[initial_top_k_indices] = candidate_logits
        adjusted_logits = adjusted_logits / self.temperature

        dist = Categorical(logits=adjusted_logits)
        move_idx = dist.sample().item()
        move_uci = IDX_TO_UCI_MOVE[move_idx]

        if return_perplexity:
            final_probs = torch.softmax(adjusted_logits, dim=0)
            perplexity = torch.exp(-torch.sum(final_probs * torch.log(final_probs + 1e-8))).item()

            if verbose and self.depth > 0:
                print(f"\n--- Search Enhanced Move Debug Info ({board.fen()}) ---")
                print(f"Confidence: {confidence:.4f}")

                print("\nMove Analysis (Initial Top-K Candidates):")
                print(f"{'Move':<8} {'Initial Logit':<15} {'Search Adj. Logit':<19} {'Final Adj. Logit':<18} {'Final Prob':<12}")
                print(f"{'-'*8:<8} {'-'*15:<15} {'-'*19:<19} {'-'*18:<18} {'-'*12:<12}")

                for idx in initial_top_k_indices:
                    move_uci_v = IDX_TO_UCI_MOVE[idx.item()]
                    initial_logit = initial_masked_logits[idx].item()

                    search_adj_logit_val = normalized_search[idx].item() if root_move_values else 0.0

                    final_adj_logit = adjusted_logits[idx].item()
                    final_prob_val = final_probs[idx].item()

                    print(f"{move_uci_v:<8} {initial_logit:<15.4f} {search_adj_logit_val:<19.4f} {final_adj_logit:<18.4f} {final_prob_val:<12.4f}")

                print(f"\nPerplexity: {perplexity:.4f}")
                print(f"Predicted Value (White's POV): {final_value:.4f}")

                print("\nLeaf Node Values (Root Move, Probability, Value from White's POV):")
                for rm, prob, val in terminal_leaves:
                    print(f"  Root Move: {rm:<8}, Prob: {prob:<.4f}, Value: {val:<.4f}")
                print("--------------------------------------------------")

            return move_uci, final_value, perplexity
        else:
            return move_uci, final_value

    @spaces.GPU
    def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
        """
        Pick a move with the ChessFormer model, using search when depth > 0.

        The model is moved to CUDA for the duration of the call and always
        moved back to CPU afterwards, even when inference raises, so a failed
        call does not leave ``self.model`` / ``self.device`` pointing at the GPU.

        Returns:
            Whatever ``_raw_chessformer_move`` (depth 0) or
            ``_search_enhanced_move`` (depth > 0) returns.
        """
        try:
            cuda = torch.device("cuda")
            self.device = cuda
            self.model.to(cuda)
            if self.depth == 0:
                return self._raw_chessformer_move(board, return_perplexity)
            return self._search_enhanced_move(board, return_perplexity, verbose)
        finally:
            cpu = torch.device("cpu")
            self.model.to(cpu)
            self.device = cpu

    def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Optional[Union[Tuple[str, float], Tuple[str, float, None]]]:
        """
        Get the best move for ``board`` from a freshly started Stockfish process.

        The engine is opened in a ``with`` block so the child process is shut
        down on every path, including errors and early returns.

        Args:
            board: Position to analyse (with history, so draw claims work).
            return_perplexity: When True a third element, always ``None``, is
                appended so the result has the same arity as the ChessFormer
                variant.

        Returns:
            ``(move_uci, score)`` (or ``(move_uci, score, None)``), where
            ``score`` is the evaluation from the side to move's perspective
            squashed into [-1, 1]. The move is ``"<claim_draw>"`` when a draw
            can be claimed and the score is below the loss threshold. ``None``
            when the engine fails or returns no usable score / line.
        """
        try:
            with chess.engine.SimpleEngine.popen_uci(self.engine_path) as engine:
                info = engine.analyse(board, chess.engine.Limit(depth=self.depth))
        except (chess.engine.EngineError, chess.engine.EngineTerminatedError) as e:
            print(f"Stockfish error: {e}")
            return None

        # Below this (side-to-move) score a claimable draw is preferred.
        loss_threshold = -0.4

        score_obj = info.get("score")
        pv = info.get("pv")
        if score_obj is None or not pv:
            # Invalid analysis result
            return None

        # Score from the point of view of the side to move.
        relative_score = score_obj.relative
        if relative_score.is_mate():
            # Forced mates are mapped onto a large fixed centipawn value.
            cp = 10000.0 if relative_score.mate() > 0 else -10000.0
        elif relative_score.cp is not None:
            cp = float(relative_score.cp)
        else:
            return None

        # Logistic squash of centipawns into (-1, 1).
        normalized_score = 2 / (1+math.exp(-0.004*cp)) - 1

        if board.can_claim_draw() and normalized_score < loss_threshold:
            best_move_uci = "<claim_draw>"
        else:
            best_move_uci = pv[0].uci()

        if return_perplexity:
            return best_move_uci, normalized_score, None
        return best_move_uci, normalized_score

    def _batch_chessformer_move(self, boards: List[chess.Board]) -> List[Optional[Tuple[str, float]]]:
        """
        Sample one move per board with a single batched forward pass.

        Logits are masked to legal moves, optionally nucleus (top-p) filtered,
        scaled by the temperature and sampled.

        Args:
            boards: Boards to move in; at most ``max_batch_size`` of them.

        Returns:
            One entry per board: ``(move_uci, value)`` with the model's value
            output for that board, or ``None`` when the game is already over,
            the sampled index is unknown, or sampling failed.

        Raises:
            ValueError: If more than ``max_batch_size`` boards are given.
        """
        bs = len(boards)
        if bs > self.max_batch_size:
            raise ValueError(f"num boards ({bs}) exceeded max batch size ({self.max_batch_size}).")
        if bs == 0:
            return []
        fens = [board.fen() for board in boards]

        count_tensor = self.compute_repetition(boards).to(self.device) # shape (bs,)

        with torch.no_grad():
            move_logits, values = self.model(fens, count_tensor)

        # Apply legality mask (0 for legal moves, -inf otherwise)
        move_logits = move_logits + self.get_invalid_mask(boards)

        # Rows with no allowed move belong to finished games.
        all_masked = torch.all(torch.isinf(move_logits), dim=-1)
        if all_masked.any():
            # An all -inf row normalises to NaN and makes Categorical reject
            # the whole batch. Give those rows placeholder uniform logits; their
            # samples are discarded below.
            move_logits = move_logits.masked_fill(all_masked.unsqueeze(-1), 0.0)

        # Apply top-p filtering. The attribute is optional: when it was never
        # set, fall back to 1.0, which disables the filter.
        top_p = getattr(self, "top_p", 1.0)
        if 0.0 < top_p < 1.0: # Apply only if top_p is strictly between 0 and 1
            sorted_logits, sorted_indices = torch.sort(move_logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first entry crossing the threshold is kept,
            # and never drop the most likely move.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = torch.zeros_like(move_logits, dtype=torch.bool).scatter_(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            move_logits[indices_to_remove] = -torch.inf

        # Apply temperature
        temp = self.temperature if self.temperature > 0 else 1.0
        move_logits = move_logits / temp

        # Sample moves. Invalid logits (e.g. NaN from the model) surface as
        # ValueError when the distribution is built, so guard that as well.
        try:
            sampled_indices = Categorical(logits=move_logits).sample().tolist()
        except (RuntimeError, ValueError) as e:
            print(f"Error sampling moves: {e}. Checking logit values...")
            for i in range(bs):
                print(f"Board {i} logits sum: {torch.logsumexp(move_logits[i], dim=-1)}")
            return [None] * bs # indicate failure for every board

        results = []
        for i, move_id in enumerate(sampled_indices):
            if all_masked[i]:
                results.append(None) # Game already over
                continue

            move_uci = IDX_TO_UCI_MOVE.get(move_id)
            if move_uci is None:
                print(f"Warning: Sampled move ID {move_id} not in IDX_TO_UCI_MOVE map")
                results.append(None)
                continue

            results.append((move_uci, values[i].item()))

        return results

    def _batch_stockfish_move(self, boards: List[chess.Board], allow_claim_draw: bool=False) -> List[Tuple[str, float]]:
        """
        Get a Stockfish move for every board.

        Args:
            boards: Positions to analyse.
            allow_claim_draw: When True the boards are analysed one after the
                other through ``_stockfish_move`` so their move history (needed
                for draw claims) is kept. When False only FENs are sent to a
                pool of worker processes, which is faster but history-free.

        Returns:
            One entry per board, ``(move_uci, score)`` or ``None``. In the
            parallel path ``None`` marks finished games and failed analyses; if
            the pool itself fails, the error is printed and the entries not yet
            filled stay ``None``.
        """
        if allow_claim_draw:
            # Sequential path: keeps each board's history intact.
            return [self._stockfish_move(position) for position in boards]

        # Parallel path: no draw-claim logic, so FENs are sufficient.
        total = len(boards)
        cpus = os.cpu_count()
        pool_size = min(total, max(1, cpus // 2 if cpus else 1))
        if total < pool_size * 2:
            # Too few boards per worker to be worth the full pool.
            pool_size = max(1, total // 2)

        all_fens = [position.fen() for position in boards]
        analyse_fen = partial(_stockfish_worker,
                              engine_path=self.engine_path,
                              depth=self.depth)
        results: List[Optional[Tuple[str,float]]] = [None] * total

        # Only games that are still running are sent to the engine.
        live_slots = [slot for slot, position in enumerate(boards)
                      if not position.is_game_over(claim_draw=True)]
        live_fens = [all_fens[slot] for slot in live_slots]
        if not live_fens:
            # All games are over
            return results # list of None

        try:
            if pool_size > 1 and len(live_fens) > 1:
                with multiprocessing.Pool(processes=pool_size) as pool:
                    outcomes = pool.map(analyse_fen, live_fens)
            else:
                outcomes = [analyse_fen(fen) for fen in live_fens]

            # Scatter the answers back to their original positions.
            for slot, outcome in zip(live_slots, outcomes):
                results[slot] = outcome
        except Exception as e:
            print(f"Error during batch Stockfish move: {e}")

        return results

    def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
        """
        Pick a move for a single board with the configured backend.

        Returns whatever the backend-specific method returns; raises
        ``ValueError`` when ``self.type`` is neither known backend.
        """
        engine_kind = self.type
        if engine_kind == "stockfish":
            return self._stockfish_move(board, return_perplexity)
        if engine_kind == "chessformer":
            return self._chessformer_move(board, return_perplexity)
        raise ValueError(f"Invalid engine type: {self.type}")
        
    def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
        """
        Pick one move per board with the configured backend.

        Returns whatever the backend-specific batch method returns; raises
        ``ValueError`` when ``self.type`` is neither known backend.
        """
        engine_kind = self.type
        if engine_kind == "stockfish":
            return self._batch_stockfish_move(boards)
        if engine_kind == "chessformer":
            return self._batch_chessformer_move(boards)
        raise ValueError(f"Invalid engine type: {self.type}")

    @spaces.GPU
    def _analyze_position_gpu(self, board: chess.Board) -> Optional[float]:
        """
        Evaluate one position with the model's value head on the GPU.

        Only acquires ZeroGPU for the model forward pass. The model is moved
        back to CPU in a ``finally`` block so that a failure during inference
        does not leave ``self.model`` / ``self.device`` pointing at the GPU.

        Returns:
            The model's value for the position, converted to White's
            perspective (the raw value is negated when Black is to move).
        """
        fen = board.fen()

        try:
            self.model = self.model.to(torch.device("cuda"))
            self.device = torch.device("cuda")

            # compute_repetition works on its own copies of the boards.
            count_tensor = self.compute_repetition([board])

            with torch.no_grad():
                _, value = self.model([fen], count_tensor)
            value = value.item()
        finally:
            self.model = self.model.to("cpu")
            self.device = torch.device("cpu")

        return value if board.turn == chess.WHITE else -value
        
    def analyze_position(self, board: chess.Board) -> Optional[float]:
        """
        Evaluate a **single board** position with the configured backend.

        For Stockfish the engine score is squashed into [-1, 1] with a logistic
        curve (forced mates count as +/-10000 centipawns); for ChessFormer the
        model's value estimate is used. In both cases the result is expressed
        from White's perspective.

        Returns:
            The evaluation as a float, or ``None`` if the Stockfish analysis
            failed or produced no usable score.

        Raises:
            ValueError: If ``self.type`` is not a known backend.
        """
        if self.type == "stockfish":
            try:
                # The with-block shuts the engine down even if analyse() raises.
                with chess.engine.SimpleEngine.popen_uci(self.engine_path) as engine:
                    info = engine.analyse(board, chess.engine.Limit(depth=self.depth))
            except Exception as e:
                print(f"Stockfish error: {e}")
                return None

            score_obj = info.get("score")
            if score_obj is None:
                # The engine returned no evaluation at all.
                return None

            # Score from the point of view of the side to move.
            relative_score = score_obj.relative
            if relative_score.is_mate():
                cp = 10000.0 if relative_score.mate() > 0 else -10000.0
            elif relative_score.cp is not None:
                cp = float(relative_score.cp)
            else:
                return None

            normalized_score = 2 / (1+math.exp(-0.004*cp)) - 1
            # Convert from side-to-move to White's perspective.
            return normalized_score if board.turn == chess.WHITE else -normalized_score

        elif self.type == "chessformer":
            return self._analyze_position_gpu(board)
        else:
            raise ValueError("Invalid engine type.")


def test_search_enhanced_move(model_path, device):
    """Smoke-test search-enhanced ChessFormer move selection.

    Loads a checkpoint from ``model_path``, builds a ``ChessFormerModel`` from
    its stored config, and for every search configuration creates an
    ``Engine`` of type "chessformer". Each test position is then passed to
    ``Engine._chessformer_move`` and the selected move, value and perplexity
    are printed. A failure on one position is printed with its traceback and
    does not stop the remaining positions.

    Args:
        model_path: Path to a checkpoint holding "config" and
            "model_state_dict" entries.
        device: Device used as ``map_location`` for loading, as the target of
            ``model.to`` and as the ``device`` of the ``ChessformerConfig``.
    """
    import sys
    import traceback

    print("\n--- Testing Search-Enhanced ChessFormer ---")

    # Make the working directory importable so `model` resolves, without
    # piling up duplicate entries when this function is called repeatedly.
    if "./" not in sys.path:
        sys.path.append("./")
    from model import ChessFormerModel

    # Load the trained model
    checkpoint = torch.load(model_path, map_location=device)
    config = checkpoint["config"]
    model = ChessFormerModel(**config)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)

    # Test different search configurations
    test_configs = [
        #{"depth": 0, "top_k": 8, "decay_rate": 0.6, "temperature": 0.2},  # No search (baseline)
        #{"depth": 1, "top_k": 8, "decay_rate": 0.6, "temperature": 0.2},  # Shallow search
        {"depth": 8, "top_k": 8, "decay_rate": 0.5, "temperature": 0.5},  # Medium search
    ]

    # Test positions
    test_positions = [
        #chess.Board(),  # Starting position
        #chess.Board("r1bqkb1r/pppp1ppp/2n2n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4"),  # Italian Game
        #chess.Board("rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq c6 0 2"),  # Sicilian Defense
        #chess.Board("r1bq1rk1/ppp2ppp/2n2n2/2bpp3/2B1P3/3P1N2/PPP2PPP/RNBQ1RK1 w - - 0 6"),  # Complex middlegame
        chess.Board("r1b1k2r/1p2bpp1/2p1p1np/2N1P3/1q1P4/5N2/B1Q2PPP/R3R1K1 w kq - 0 19"), # blunder: c2e4
        chess.Board("rn1qk2r/1b2bpp1/1pp1pn1p/p7/3P4/2PB1N2/PP1NQPPP/R1B1R1K1 w kq - 2 12"), # blunder: e2e6
    ]

    for i, cfg in enumerate(test_configs):
        print(f"\n--- Test Configuration {i+1}: Depth={cfg['depth']}, Top-K={cfg['top_k']}, Decay={cfg['decay_rate']}, Temp={cfg['temperature']} ---")
        chessformer_config = ChessformerConfig(
            chessformer=model,
            device=device,
            temperature=cfg['temperature'],
            depth=cfg['depth'],
            top_k=cfg['top_k'],
            decay_rate=cfg['decay_rate'],
        )
        engine = Engine(type="chessformer", chessformer_config=chessformer_config)

        for j, board in enumerate(test_positions):
            print(f"\n--- Analyzing Position {j+1}: {board.fen()} ---")
            try:
                move, value, perplexity = engine._chessformer_move(board, return_perplexity=True, verbose=True)
                print(f"Selected Move: {move}, Predicted Value (White's POV): {value:.4f}, Perplexity: {perplexity:.4f}")
            except Exception as e:
                # Keep going: one bad position should not abort the whole run.
                print(f"Error analyzing position {board.fen()}: {e}")
                traceback.print_exc()

if __name__ == "__main__":
    # Script entry point: run the search smoke test on CPU against the
    # default checkpoint path.
    device = torch.device("cpu")
    model_path = "./ckpts/chessformer-sl_01.pth"
    test_search_enhanced_move(model_path=model_path, device=device)