kaupane commited on
Commit
a2ebcaf
·
verified ·
1 Parent(s): d3cac8c

Update utils/engine.py

Browse files

optimize zero gpu acquirement

Files changed (1) hide show
  1. utils/engine.py +32 -23
utils/engine.py CHANGED
@@ -440,12 +440,21 @@ class Engine:
440
  else:
441
  return move_uci, final_value
442
 
 
443
  def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
444
  """Get move from chessformer with optional search enhance"""
 
 
445
  if self.depth == 0:
446
- return self._raw_chessformer_move(board,return_perplexity)
 
 
 
447
  else:
448
- return self._search_enhanced_move(board,return_perplexity,verbose)
 
 
 
449
 
450
  def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
451
  """Get best move from stockfish"""
@@ -618,18 +627,13 @@ class Engine:
618
 
619
  return results
620
 
621
- @spaces.GPU
622
  def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
623
  if self.type == "stockfish":
624
  return self._stockfish_move(board, return_perplexity)
625
  elif self.type == "chessformer":
626
- self.device = torch.device("cuda")
627
- self.model.to(torch.device("cuda"))
628
  return self._chessformer_move(board, return_perplexity)
629
  else:
630
  raise ValueError(f"Invalid engine type: {self.type}")
631
- self.model.to(torch.device("cpu"))
632
- self.device = torch.device("cpu")
633
 
634
  def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
635
  if self.type == "stockfish":
@@ -640,6 +644,26 @@ class Engine:
640
  raise ValueError(f"Invalid engine type: {self.type}")
641
 
642
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  def analyze_position(self, board: chess.Board) -> Optional[float]:
644
  """
645
  Analyzes the given **single board** position using the engine.
@@ -687,22 +711,7 @@ class Engine:
687
 
688
 
689
  elif self.type == "chessformer":
690
- fen = board.fen()
691
-
692
- self.model = self.model.to(torch.device("cuda"))
693
- self.device = torch.device("cuda")
694
-
695
- count_tensor = self.compute_repetition([board.copy(stack=True)])
696
-
697
- with torch.no_grad():
698
- _, value = self.model([fen],count_tensor)
699
-
700
- self.model = self.model.to("cpu")
701
- self.device = torch.device("cpu")
702
-
703
- value = value.item()
704
- return value if board.turn == chess.WHITE else -value
705
-
706
  else:
707
  raise ValueError(f"Invalid engine type.")
708
 
 
440
  else:
441
  return move_uci, final_value
442
 
443
+ @spaces.GPU
444
  def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
445
  """Get move from chessformer with optional search enhance"""
446
+ self.device = torch.device("cuda")
447
+ self.model.to(torch.device("cuda"))
448
  if self.depth == 0:
449
+ result = self._raw_chessformer_move(board,return_perplexity)
450
+ self.model.to(torch.device("cpu"))
451
+ self.device = torch.device("cpu")
452
+ return result
453
  else:
454
+ result = self._search_enhanced_move(board,return_perplexity,verbose)
455
+ self.model.to(torch.device("cpu"))
456
+ self.device = torch.device("cpu")
457
+ return result
458
 
459
  def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
460
  """Get best move from stockfish"""
 
627
 
628
  return results
629
 
 
630
  def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
631
  if self.type == "stockfish":
632
  return self._stockfish_move(board, return_perplexity)
633
  elif self.type == "chessformer":
 
 
634
  return self._chessformer_move(board, return_perplexity)
635
  else:
636
  raise ValueError(f"Invalid engine type: {self.type}")
 
 
637
 
638
  def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
639
  if self.type == "stockfish":
 
644
  raise ValueError(f"Invalid engine type: {self.type}")
645
 
646
  @spaces.GPU
647
+ def _analyze_position_gpu(self, board: chess.Board) -> Optional[float]:
648
+ """
649
+ Only acquire ZeroGPU when model forward pass is needed.
650
+ """
651
+ fen = board.fen()
652
+
653
+ self.model = self.model.to(torch.device("cuda"))
654
+ self.device = torch.device("cuda")
655
+
656
+ count_tensor = self.compute_repetition([board.copy(stack=True)])
657
+
658
+ with torch.no_grad():
659
+ _, value = self.model([fen],count_tensor)
660
+
661
+ self.model = self.model.to("cpu")
662
+ self.device = torch.device("cpu")
663
+
664
+ value = value.item()
665
+ return value if board.turn == chess.WHITE else -value
666
+
667
  def analyze_position(self, board: chess.Board) -> Optional[float]:
668
  """
669
  Analyzes the given **single board** position using the engine.
 
711
 
712
 
713
  elif self.type == "chessformer":
714
+ return self._analyze_position_gpu(board)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  else:
716
  raise ValueError(f"Invalid engine type.")
717