Update modeling_chessbot.py
Browse files- modeling_chessbot.py +1 -1
modeling_chessbot.py
CHANGED
|
@@ -613,7 +613,7 @@ class ChessBotModel(ChessBotPreTrainedModel):
|
|
| 613 |
# Find the move with the highest policy value that is legal
|
| 614 |
legal_moves_mask = - torch.ones_like(policy) * 999
|
| 615 |
for move in legal_moves:
|
| 616 |
-
legal_moves_mask[policy_index
|
| 617 |
policy = legal_moves_mask + policy
|
| 618 |
return policy_index[torch.argmax(policy).item()]
|
| 619 |
else:
|
|
|
|
| 613 |
# Find the move with the highest policy value that is legal
|
| 614 |
legal_moves_mask = - torch.ones_like(policy) * 999
|
| 615 |
for move in legal_moves:
|
| 616 |
+
legal_moves_mask[policy_index.index(move)] = 0
|
| 617 |
policy = legal_moves_mask + policy
|
| 618 |
return policy_index[torch.argmax(policy).item()]
|
| 619 |
else:
|