Spaces:
Running on Zero
Running on Zero
Update utils/engine.py
Browse files- utils/engine.py +11 -3
utils/engine.py
CHANGED
|
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
|
| 8 |
from functools import partial
|
| 9 |
import time
|
| 10 |
import os
|
|
|
|
| 11 |
|
| 12 |
try:
|
| 13 |
from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
|
|
@@ -162,7 +163,7 @@ class Engine:
|
|
| 162 |
raise ValueError(f"Invalid engine path or engine not found: {e}")
|
| 163 |
else:
|
| 164 |
raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
|
| 165 |
-
|
| 166 |
def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
|
| 167 |
bs = len(boards)
|
| 168 |
possible_moves = len(UCI_MOVE_TO_IDX)
|
|
@@ -617,13 +618,16 @@ class Engine:
|
|
| 617 |
|
| 618 |
return results
|
| 619 |
|
|
|
|
| 620 |
def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
|
| 621 |
if self.type == "stockfish":
|
| 622 |
return self._stockfish_move(board, return_perplexity)
|
| 623 |
elif self.type == "chessformer":
|
|
|
|
| 624 |
return self._chessformer_move(board, return_perplexity)
|
| 625 |
else:
|
| 626 |
raise ValueError(f"Invalid engine type: {self.type}")
|
|
|
|
| 627 |
|
| 628 |
def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
|
| 629 |
if self.type == "stockfish":
|
|
@@ -632,7 +636,8 @@ class Engine:
|
|
| 632 |
return self._batch_chessformer_move(boards)
|
| 633 |
else:
|
| 634 |
raise ValueError(f"Invalid engine type: {self.type}")
|
| 635 |
-
|
|
|
|
| 636 |
def analyze_position(self, board: chess.Board) -> Optional[float]:
|
| 637 |
"""
|
| 638 |
Analyzes the given **single board** position using the engine.
|
|
@@ -683,8 +688,11 @@ class Engine:
|
|
| 683 |
fen = board.fen()
|
| 684 |
count_tensor = self.compute_repetition([board.copy(stack=True)])
|
| 685 |
|
|
|
|
| 686 |
with torch.no_grad():
|
| 687 |
-
_, value =
|
|
|
|
|
|
|
| 688 |
|
| 689 |
value = value.item()
|
| 690 |
return value if board.turn == chess.WHITE else -value
|
|
|
|
| 8 |
from functools import partial
|
| 9 |
import time
|
| 10 |
import os
|
| 11 |
+
import spaces
|
| 12 |
|
| 13 |
try:
|
| 14 |
from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
|
|
|
|
| 163 |
raise ValueError(f"Invalid engine path or engine not found: {e}")
|
| 164 |
else:
|
| 165 |
raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
|
| 166 |
+
|
| 167 |
def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
|
| 168 |
bs = len(boards)
|
| 169 |
possible_moves = len(UCI_MOVE_TO_IDX)
|
|
|
|
| 618 |
|
| 619 |
return results
|
| 620 |
|
| 621 |
+
@spaces.GPU
|
| 622 |
def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
|
| 623 |
if self.type == "stockfish":
|
| 624 |
return self._stockfish_move(board, return_perplexity)
|
| 625 |
elif self.type == "chessformer":
|
| 626 |
+
self.model.to("cuda")
|
| 627 |
return self._chessformer_move(board, return_perplexity)
|
| 628 |
else:
|
| 629 |
raise ValueError(f"Invalid engine type: {self.type}")
|
| 630 |
+
self.model.to("cpu")
|
| 631 |
|
| 632 |
def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
|
| 633 |
if self.type == "stockfish":
|
|
|
|
| 636 |
return self._batch_chessformer_move(boards)
|
| 637 |
else:
|
| 638 |
raise ValueError(f"Invalid engine type: {self.type}")
|
| 639 |
+
|
| 640 |
+
@spaces.GPU
|
| 641 |
def analyze_position(self, board: chess.Board) -> Optional[float]:
|
| 642 |
"""
|
| 643 |
Analyzes the given **single board** position using the engine.
|
|
|
|
| 688 |
fen = board.fen()
|
| 689 |
count_tensor = self.compute_repetition([board.copy(stack=True)])
|
| 690 |
|
| 691 |
+
gpu_model = self.model.to("cuda")
|
| 692 |
with torch.no_grad():
|
| 693 |
+
_, value = gpu_model([fen],count_tensor)
|
| 694 |
+
|
| 695 |
+
self.model = self.model.to("cpu")
|
| 696 |
|
| 697 |
value = value.item()
|
| 698 |
return value if board.turn == chess.WHITE else -value
|