Spaces:
Sleeping
Sleeping
Update utils/engine.py
Browse filesoptimize zero gpu acquirement
- utils/engine.py +32 -23
utils/engine.py
CHANGED
|
@@ -440,12 +440,21 @@ class Engine:
|
|
| 440 |
else:
|
| 441 |
return move_uci, final_value
|
| 442 |
|
|
|
|
| 443 |
def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
|
| 444 |
"""Get move from chessformer with optional search enhance"""
|
|
|
|
|
|
|
| 445 |
if self.depth == 0:
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
| 447 |
else:
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
| 449 |
|
| 450 |
def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
|
| 451 |
"""Get best move from stockfish"""
|
|
@@ -618,18 +627,13 @@ class Engine:
|
|
| 618 |
|
| 619 |
return results
|
| 620 |
|
| 621 |
-
@spaces.GPU
|
| 622 |
def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
|
| 623 |
if self.type == "stockfish":
|
| 624 |
return self._stockfish_move(board, return_perplexity)
|
| 625 |
elif self.type == "chessformer":
|
| 626 |
-
self.device = torch.device("cuda")
|
| 627 |
-
self.model.to(torch.device("cuda"))
|
| 628 |
return self._chessformer_move(board, return_perplexity)
|
| 629 |
else:
|
| 630 |
raise ValueError(f"Invalid engine type: {self.type}")
|
| 631 |
-
self.model.to(torch.device("cpu"))
|
| 632 |
-
self.device = torch.device("cpu")
|
| 633 |
|
| 634 |
def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
|
| 635 |
if self.type == "stockfish":
|
|
@@ -640,6 +644,26 @@ class Engine:
|
|
| 640 |
raise ValueError(f"Invalid engine type: {self.type}")
|
| 641 |
|
| 642 |
@spaces.GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
def analyze_position(self, board: chess.Board) -> Optional[float]:
|
| 644 |
"""
|
| 645 |
Analyzes the given **single board** position using the engine.
|
|
@@ -687,22 +711,7 @@ class Engine:
|
|
| 687 |
|
| 688 |
|
| 689 |
elif self.type == "chessformer":
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
self.model = self.model.to(torch.device("cuda"))
|
| 693 |
-
self.device = torch.device("cuda")
|
| 694 |
-
|
| 695 |
-
count_tensor = self.compute_repetition([board.copy(stack=True)])
|
| 696 |
-
|
| 697 |
-
with torch.no_grad():
|
| 698 |
-
_, value = self.model([fen],count_tensor)
|
| 699 |
-
|
| 700 |
-
self.model = self.model.to("cpu")
|
| 701 |
-
self.device = torch.device("cpu")
|
| 702 |
-
|
| 703 |
-
value = value.item()
|
| 704 |
-
return value if board.turn == chess.WHITE else -value
|
| 705 |
-
|
| 706 |
else:
|
| 707 |
raise ValueError(f"Invalid engine type.")
|
| 708 |
|
|
|
|
| 440 |
else:
|
| 441 |
return move_uci, final_value
|
| 442 |
|
| 443 |
+
@spaces.GPU
|
| 444 |
def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
|
| 445 |
"""Get move from chessformer with optional search enhance"""
|
| 446 |
+
self.device = torch.device("cuda")
|
| 447 |
+
self.model.to(torch.device("cuda"))
|
| 448 |
if self.depth == 0:
|
| 449 |
+
result = self._raw_chessformer_move(board,return_perplexity)
|
| 450 |
+
self.model.to(torch.device("cpu"))
|
| 451 |
+
self.device = torch.device("cpu")
|
| 452 |
+
return result
|
| 453 |
else:
|
| 454 |
+
result = self._search_enhanced_move(board,return_perplexity,verbose)
|
| 455 |
+
self.model.to(torch.device("cpu"))
|
| 456 |
+
self.device = torch.device("cpu")
|
| 457 |
+
return result
|
| 458 |
|
| 459 |
def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
|
| 460 |
"""Get best move from stockfish"""
|
|
|
|
| 627 |
|
| 628 |
return results
|
| 629 |
|
|
|
|
| 630 |
def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
|
| 631 |
if self.type == "stockfish":
|
| 632 |
return self._stockfish_move(board, return_perplexity)
|
| 633 |
elif self.type == "chessformer":
|
|
|
|
|
|
|
| 634 |
return self._chessformer_move(board, return_perplexity)
|
| 635 |
else:
|
| 636 |
raise ValueError(f"Invalid engine type: {self.type}")
|
|
|
|
|
|
|
| 637 |
|
| 638 |
def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
|
| 639 |
if self.type == "stockfish":
|
|
|
|
| 644 |
raise ValueError(f"Invalid engine type: {self.type}")
|
| 645 |
|
| 646 |
@spaces.GPU
|
| 647 |
+
def _analyze_position_gpu(self, board: chess.Board) -> Optional[float]:
|
| 648 |
+
"""
|
| 649 |
+
Only acquire ZeroGPU when model forward pass is needed.
|
| 650 |
+
"""
|
| 651 |
+
fen = board.fen()
|
| 652 |
+
|
| 653 |
+
self.model = self.model.to(torch.device("cuda"))
|
| 654 |
+
self.device = torch.device("cuda")
|
| 655 |
+
|
| 656 |
+
count_tensor = self.compute_repetition([board.copy(stack=True)])
|
| 657 |
+
|
| 658 |
+
with torch.no_grad():
|
| 659 |
+
_, value = self.model([fen],count_tensor)
|
| 660 |
+
|
| 661 |
+
self.model = self.model.to("cpu")
|
| 662 |
+
self.device = torch.device("cpu")
|
| 663 |
+
|
| 664 |
+
value = value.item()
|
| 665 |
+
return value if board.turn == chess.WHITE else -value
|
| 666 |
+
|
| 667 |
def analyze_position(self, board: chess.Board) -> Optional[float]:
|
| 668 |
"""
|
| 669 |
Analyzes the given **single board** position using the engine.
|
|
|
|
| 711 |
|
| 712 |
|
| 713 |
elif self.type == "chessformer":
|
| 714 |
+
return self._analyze_position_gpu(board)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
else:
|
| 716 |
raise ValueError(f"Invalid engine type.")
|
| 717 |
|