Upload model.py with huggingface_hub
Browse files
model.py
CHANGED
|
@@ -452,17 +452,18 @@ class BT4(nn.Module):
|
|
| 452 |
|
| 453 |
return policy_logits, value_winner, value_q
|
| 454 |
|
| 455 |
-
def
|
| 456 |
"""
|
| 457 |
-
Predict a move from a
|
| 458 |
|
| 459 |
Args:
|
| 460 |
fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
|
| 461 |
T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
|
| 462 |
device: Device to run the model on (if None, uses model's device)
|
|
|
|
| 463 |
|
| 464 |
Returns:
|
| 465 |
-
UCI move string (e.g., 'e2e4')
|
| 466 |
"""
|
| 467 |
# Detect device from model if not provided
|
| 468 |
if device is None:
|
|
|
|
| 452 |
|
| 453 |
return policy_logits, value_winner, value_q
|
| 454 |
|
| 455 |
+
def get_move_from_history(self, fen_or_moves: Union[str, List[str]], T: float, device: str = None, **kwargs) -> str:
|
| 456 |
"""
|
| 457 |
+
Predict a move from a move history or FEN position.
|
| 458 |
|
| 459 |
Args:
|
| 460 |
fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
|
| 461 |
T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
|
| 462 |
device: Device to run the model on (if None, uses model's device)
|
| 463 |
+
return_probs: If True, returns a dictionary of move probabilities instead of a single move
|
| 464 |
|
| 465 |
Returns:
|
| 466 |
+
UCI move string (e.g., 'e2e4') or dictionary of move probabilities if return_probs=True
|
| 467 |
"""
|
| 468 |
# Detect device from model if not provided
|
| 469 |
if device is None:
|