Commit f96cfc0 by Alfredvc (verified) · 1 parent: 22b7d27

Document logits to move decoding

Files changed (1)
  1. README.md +48 -0
README.md CHANGED
@@ -88,6 +88,54 @@ bin_logits = session.run(["bin_logits"], {"bin_moves": bin_moves})[0]
  The ONNX artifact uses the `bin_logits_v1` interface: `bin_moves` input with
  shape `[batch, time]` and `bin_logits` output with shape `[batch, 65536]`.
 
+ ## Converting Logits To Moves
+
+ The model predicts move tokens, not SAN strings. Do not take an unconstrained
+ argmax over the full vocabulary. Score the legal moves in the current board
+ position and choose from that legal set.
+
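The difference between an unconstrained argmax and a legal-move argmax can be seen in a toy sketch. The helper name `best_legal_id` and the made-up scores are assumptions for illustration only; they are not part of the `chess_autocomplete` API.

```python
from typing import Iterable, Sequence


def best_legal_id(logits: Sequence[float], legal_ids: Iterable[int]) -> int:
    """Return the legal vocabulary index with the highest logit."""
    legal_ids = list(legal_ids)
    if not legal_ids:
        raise ValueError("no legal moves: the game is over")
    return max(legal_ids, key=lambda i: logits[i])


# Toy vocabulary of 6 tokens; only ids 1, 3 and 4 are legal in this position.
toy_logits = [0.2, 1.5, 4.0, 0.9, 1.1, -3.0]
toy_legal = [1, 3, 4]

# The global argmax lands on id 2, which is not a legal move here.
unconstrained = max(range(len(toy_logits)), key=lambda i: toy_logits[i])
constrained = best_legal_id(toy_logits, toy_legal)
print(unconstrained, constrained)  # 2 1
```

Restricting the comparison to legal indices costs one lookup per legal move, which is far smaller than the full vocabulary in any chess position.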
+ For PyTorch, logits are over the native output tokenizer vocabulary:
+
+ ```python
+ from chess_autocomplete.chess_utils import Board
+
+ board = Board()
+ # Apply any moves already played:
+ # board.push(chess.Move.from_uci("e2e4"))
+
+ next_logits = logits[0, -1]
+ legal = []
+ for move in board.board.legal_moves:
+     raw_bin_word = board.encode(move)
+     token_id = loaded.output_tokenizer.encode(raw_bin_word)
+     legal.append((float(next_logits[token_id]), move))
+
+ score, best_move = max(legal, key=lambda item: item[0])
+ print(best_move.uci())
+ ```
+
+ For ONNX `bin_logits_v1`, logits are already indexed by raw 16-bit move word:
+
+ ```python
+ from chess_autocomplete.chess_utils import Board
+
+ board = Board()
+ # Apply any moves already played:
+ # board.push(chess.Move.from_uci("e2e4"))
+
+ next_logits = bin_logits[0]
+ legal = []
+ for move in board.board.legal_moves:
+     raw_bin_word = board.encode(move)
+     legal.append((float(next_logits[raw_bin_word]), move))
+
+ score, best_move = max(legal, key=lambda item: item[0])
+ print(best_move.uci())
+ ```
+
+ Call `board.push(best_move)` after selecting a move so the next prediction is
+ decoded against the updated legal move set.
+
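The select-then-push cycle can be sketched as a generic greedy loop. Everything here is a hypothetical illustration: `decode_moves`, `score_fn`, `legal_fn`, and `push_fn` are placeholder names, not `chess_autocomplete` APIs. In a real setup, `score_fn` would presumably run the model on the moves played so far, `legal_fn` would pair each legal move with its logit index, and `push_fn` would wrap `board.push`.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

Move = TypeVar("Move")


def decode_moves(
    score_fn: Callable[[List[Move]], Sequence[float]],
    legal_fn: Callable[[], List[Tuple[int, Move]]],
    push_fn: Callable[[Move], None],
    max_moves: int,
) -> List[Move]:
    """Greedy decoding: pick the best legal move, push it, then repeat."""
    history: List[Move] = []
    for _ in range(max_moves):
        legal = legal_fn()  # recomputed after every push
        if not legal:  # no legal moves left: stop decoding
            break
        logits = score_fn(history)
        _, best = max(legal, key=lambda pair: logits[pair[0]])
        push_fn(best)  # update the position before the next prediction
        history.append(best)
    return history


# Toy stand-in for a game: moves are the integers 0..3, each playable once.
played: List[int] = []
toy_scores = [0.1, 0.7, 0.3, 0.5]

result = decode_moves(
    score_fn=lambda history: toy_scores,
    legal_fn=lambda: [(m, m) for m in range(4) if m not in played],
    push_fn=played.append,
    max_moves=10,
)
print(result)  # [1, 3, 2, 0]
```

Because `legal_fn` is called again on every iteration, a move that was legal one step earlier is never scored against a stale position.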
  ## Validation
 
  | Artifact | Validation | Status | Backend | Sample shape |