kaupane commited on
Commit
e50a529
·
verified ·
1 Parent(s): 37176be

Update utils/engine.py

Browse files
Files changed (1) hide show
  1. utils/engine.py +11 -3
utils/engine.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
8
  from functools import partial
9
  import time
10
  import os
 
11
 
12
  try:
13
  from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
@@ -162,7 +163,7 @@ class Engine:
162
  raise ValueError(f"Invalid engine path or engine not found: {e}")
163
  else:
164
  raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
165
-
166
  def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
167
  bs = len(boards)
168
  possible_moves = len(UCI_MOVE_TO_IDX)
@@ -617,13 +618,16 @@ class Engine:
617
 
618
  return results
619
 
 
620
  def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
621
  if self.type == "stockfish":
622
  return self._stockfish_move(board, return_perplexity)
623
  elif self.type == "chessformer":
 
624
  return self._chessformer_move(board, return_perplexity)
625
  else:
626
  raise ValueError(f"Invalid engine type: {self.type}")
 
627
 
628
  def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
629
  if self.type == "stockfish":
@@ -632,7 +636,8 @@ class Engine:
632
  return self._batch_chessformer_move(boards)
633
  else:
634
  raise ValueError(f"Invalid engine type: {self.type}")
635
-
 
636
  def analyze_position(self, board: chess.Board) -> Optional[float]:
637
  """
638
  Analyzes the given **single board** position using the engine.
@@ -683,8 +688,11 @@ class Engine:
683
  fen = board.fen()
684
  count_tensor = self.compute_repetition([board.copy(stack=True)])
685
 
 
686
  with torch.no_grad():
687
- _, value = self.model([fen],count_tensor)
 
 
688
 
689
  value = value.item()
690
  return value if board.turn == chess.WHITE else -value
 
8
  from functools import partial
9
  import time
10
  import os
11
+ import spaces
12
 
13
  try:
14
  from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
 
163
  raise ValueError(f"Invalid engine path or engine not found: {e}")
164
  else:
165
  raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
166
+
167
  def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
168
  bs = len(boards)
169
  possible_moves = len(UCI_MOVE_TO_IDX)
 
618
 
619
  return results
620
 
621
+ @spaces.GPU
622
  def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
623
  if self.type == "stockfish":
624
  return self._stockfish_move(board, return_perplexity)
625
  elif self.type == "chessformer":
626
+ self.model.to("cuda")
627
  return self._chessformer_move(board, return_perplexity)
628
  else:
629
  raise ValueError(f"Invalid engine type: {self.type}")
630
+ self.model.to("cpu")
631
 
632
  def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
633
  if self.type == "stockfish":
 
636
  return self._batch_chessformer_move(boards)
637
  else:
638
  raise ValueError(f"Invalid engine type: {self.type}")
639
+
640
+ @spaces.GPU
641
  def analyze_position(self, board: chess.Board) -> Optional[float]:
642
  """
643
  Analyzes the given **single board** position using the engine.
 
688
  fen = board.fen()
689
  count_tensor = self.compute_repetition([board.copy(stack=True)])
690
 
691
+ gpu_model = self.model.to("cuda")
692
  with torch.no_grad():
693
+ _, value = gpu_model([fen],count_tensor)
694
+
695
+ self.model = self.model.to("cpu")
696
 
697
  value = value.item()
698
  return value if board.turn == chess.WHITE else -value