---
library_name: chess-autocomplete
tags:
- chess
- pytorch
- safetensors
- onnx
license: apache-2.0
---
# Alfredvc/chess-autocomplete-v1
This repository contains one chess-autocomplete model variant staged for inference.
## Variant
- Repository: `Alfredvc/chess-autocomplete-v1`
- Architecture: `ChessTransformer`
- Dimensions: `768` hidden, `12` heads, `12` blocks
- Maximum half moves: `600`
- Input representation: `Discrete`
- Norm / MLP: `layernorm` / `swiglu`
- Native input tokenizer: `RealizableMoveTokenizer` with `4169` ids
- Native output tokenizer: `RealizableMoveTokenizer` with `4135` ids
- Metadata: Metadata tokens are part of the input token stream.
## Interface
This is a metadata-token model. Inputs must begin with the metadata prefix:
```text
[time_control_token, white_elo_token, black_elo_token, GAME_START, ...moves]
```
Use `TIME_CONTROL_MISSING_WORD` and `RATING_MISSING_WORD` when metadata is not
available.
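A minimal sketch of assembling that prefix, assuming each metadata value has already been encoded as a token. The helper `build_input_words` and the integer constants are hypothetical placeholders. In real code, use `protocol.TIME_CONTROL_MISSING_WORD`, `protocol.RATING_MISSING_WORD`, and `protocol.GAME_START` from `chess_autocomplete`.

```python
from typing import List, Optional, Sequence

# Hypothetical placeholder values; substitute the constants from
# chess_autocomplete.protocol in real code.
TIME_CONTROL_MISSING_WORD = 1
RATING_MISSING_WORD = 2
GAME_START = 3


def build_input_words(
    moves: Sequence[int],
    time_control: Optional[int] = None,
    white_elo: Optional[int] = None,
    black_elo: Optional[int] = None,
) -> List[int]:
    """Return [time_control, white_elo, black_elo, GAME_START, *moves].

    Metadata arguments are already-encoded tokens; None falls back to the
    matching "missing" word.
    """
    prefix = [
        TIME_CONTROL_MISSING_WORD if time_control is None else time_control,
        RATING_MISSING_WORD if white_elo is None else white_elo,
        RATING_MISSING_WORD if black_elo is None else black_elo,
        GAME_START,
    ]
    return prefix + list(moves)
```

With the placeholder values, `build_input_words([])` yields the four-word prefix used in the examples below.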
The native PyTorch model returns logits over the output tokenizer vocabulary
(`4135` ids). The ONNX artifacts wrap that model and return
`bin_logits` over raw 16-bit move words (`65536` ids). The two output
interfaces are not interchangeable: index PyTorch logits by output token id and
ONNX logits by raw move word.
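The output tokenizer links the two spaces: it maps a raw move word to a native output id. A minimal sketch of placing native logits into the 65536-word space, assuming you have a hypothetical `output_id_to_word` table (native id to raw word). This illustrates the relationship only; it is not necessarily how the published ONNX wrapper is built.

```python
import math
from typing import List, Sequence

NUM_BIN_WORDS = 65536  # size of the raw 16-bit move-word space


def native_to_bin_logits(
    native_logits: Sequence[float],
    output_id_to_word: Sequence[int],
) -> List[float]:
    """Scatter per-token logits to their raw 16-bit word positions.

    output_id_to_word is a hypothetical lookup (native id -> raw word).
    Words without a native token stay at -inf, so they never win a max.
    """
    if len(native_logits) != len(output_id_to_word):
        raise ValueError("logits and id->word table must have the same length")
    bin_logits = [-math.inf] * NUM_BIN_WORDS
    for token_id, word in enumerate(output_id_to_word):
        bin_logits[word] = float(native_logits[token_id])
    return bin_logits
```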
## PyTorch
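```python
import torch
from chess_autocomplete import protocol
from chess_autocomplete.huggingface import load_model_repo

# Load the model and tokenizers from the local copy of this repository.
loaded = load_model_repo(".")

# Metadata prefix with every field marked missing, followed by GAME_START.
raw_input = torch.tensor(
    [[
        protocol.TIME_CONTROL_MISSING_WORD,
        protocol.RATING_MISSING_WORD,
        protocol.RATING_MISSING_WORD,
        protocol.GAME_START,
    ]],
    dtype=torch.long,
)

# Map raw words to native input token ids, then run a forward pass.
input_ids = loaded.input_tokenizer.batch_encode(raw_input)
with torch.no_grad():
    # logits[0, -1] holds next-move scores over the 4135 output ids.
    logits, _ = loaded.model(input_ids)
```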
The PyTorch weights are stored in `model.safetensors` and loaded strictly into
`chess_autocomplete.models.ChessTransformer`.
## ONNX Runtime
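```python
import numpy as np
import onnxruntime as ort
from chess_autocomplete import protocol

# model.onnx is the FP32 artifact; model-bf16.onnx needs BF16 operator support.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Raw 16-bit words go straight in; no input tokenizer call is needed.
bin_moves = np.asarray(
    [[
        protocol.TIME_CONTROL_MISSING_WORD,
        protocol.RATING_MISSING_WORD,
        protocol.RATING_MISSING_WORD,
        protocol.GAME_START,
    ]],
    dtype=np.int32,
)

# bin_logits has shape [batch, 65536]: next-move scores per raw word.
bin_logits = session.run(["bin_logits"], {"bin_moves": bin_moves})[0]
```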
Two ONNX files are published:
- `model.onnx`: FP32 compatibility artifact.
- `model-bf16.onnx`: artifact with BF16 floating-point weights, for runtimes
  with BF16 operator support.
Both ONNX artifacts use the `bin_logits_v1` interface: `bin_moves` input with
shape `[batch, time]` and `bin_logits` output with shape `[batch, 65536]`.
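Because `bin_moves` is a dense `[batch, time]` tensor, every row in a batch must have the same length, and every entry must be a raw 16-bit word. A small pre-flight check, assuming the batch is built as nested Python lists before conversion with `np.asarray(..., dtype=np.int32)`. The name `validate_bin_moves` is hypothetical. This card does not specify how to pad ragged batches, so the sketch rejects them rather than guessing a padding word.

```python
from typing import Sequence, Tuple

MAX_WORD = 0xFFFF  # raw move words are 16-bit


def validate_bin_moves(batch: Sequence[Sequence[int]]) -> Tuple[int, int]:
    """Check a batch against the bin_logits_v1 input contract.

    Returns (batch, time). Raises ValueError for an empty batch, ragged
    rows, or words outside the 16-bit range.
    """
    if not batch:
        raise ValueError("batch must contain at least one sequence")
    time = len(batch[0])
    for row in batch:
        if len(row) != time:
            raise ValueError("all sequences in a batch must have the same length")
        for word in row:
            if not 0 <= word <= MAX_WORD:
                raise ValueError(f"word {word} is outside the 16-bit range")
    return len(batch), time
```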
## Converting Logits To Moves
The model predicts move tokens, not SAN strings. Do not take an unconstrained
argmax over the full vocabulary: the top-scoring entry may be illegal in the
current position. Instead, score only the legal moves in the current position
and choose from that set.
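Restricting to legal moves also yields a proper probability distribution, useful for sampling or for showing a confidence next to each suggestion. A minimal sketch in plain Python, assuming you have already computed each legal move's index into the logits vector (a native token id for PyTorch, a raw word for ONNX). `legal_move_distribution` is a hypothetical helper.

```python
import math
from typing import Dict, Hashable, Mapping, Sequence


def legal_move_distribution(
    logits: Sequence[float],
    legal_index: Mapping[Hashable, int],
) -> Dict[Hashable, float]:
    """Softmax over legal moves only; illegal entries get no probability mass.

    legal_index maps each legal move to its position in `logits`.
    """
    if not legal_index:
        raise ValueError("no legal moves in this position")
    scores = {move: float(logits[i]) for move, i in legal_index.items()}
    peak = max(scores.values())  # subtract the max so exp() cannot overflow
    weights = {move: math.exp(s - peak) for move, s in scores.items()}
    total = sum(weights.values())
    return {move: w / total for move, w in weights.items()}
```

For example, with logits `[0.0, 2.0, 5.0, 1.0]` and legal moves at indices 1 and 3, the high score at index 2 is ignored and the two legal moves split the probability mass.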
For PyTorch, logits are over the native output tokenizer vocabulary:
```python
import chess
from chess_autocomplete.chess_utils import Board

board = Board()
# Apply any moves already played, e.g.:
# board.push(chess.Move.from_uci("e2e4"))

# Next-move scores: last time step of the first batch item.
next_logits = logits[0, -1]

legal = []
for move in board.board.legal_moves:
    raw_bin_word = board.encode(move)  # raw 16-bit move word
    token_id = loaded.output_tokenizer.encode(raw_bin_word)  # native output id
    legal.append((float(next_logits[token_id]), move))

score, best_move = max(legal, key=lambda item: item[0])
print(best_move.uci())
```
For ONNX `bin_logits_v1`, logits are already indexed by raw 16-bit move word:
```python
import chess
from chess_autocomplete.chess_utils import Board

board = Board()
# Apply any moves already played, e.g.:
# board.push(chess.Move.from_uci("e2e4"))

# Next-move scores for the first batch item, indexed by raw word.
next_logits = bin_logits[0]

legal = []
for move in board.board.legal_moves:
    raw_bin_word = board.encode(move)  # raw 16-bit move word, used directly
    legal.append((float(next_logits[raw_bin_word]), move))

score, best_move = max(legal, key=lambda item: item[0])
print(best_move.uci())
```
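For autocomplete-style suggestions you may want a short ranked list rather than a single best move. Both loops above build `legal` as a list of `(score, move)` pairs, so a top-k selection applies to either backend. A minimal sketch; `top_k_legal` is a hypothetical helper built on the standard library's `heapq.nlargest`.

```python
import heapq
from typing import Hashable, Iterable, List, Tuple


def top_k_legal(
    legal: Iterable[Tuple[float, Hashable]], k: int = 3
) -> List[Tuple[float, Hashable]]:
    """Return the k highest-scoring (score, move) pairs, best first."""
    # key= compares scores only, so moves never need to be orderable.
    return heapq.nlargest(k, legal, key=lambda item: item[0])
```

Usage: `for score, move in top_k_legal(legal): print(move.uci(), score)`.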
Call `board.push(best_move)` after selecting a move so the next prediction is
decoded against the updated legal move set.
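The model input needs the same bookkeeping: assuming moves are fed back as raw words after the metadata prefix, the chosen move's word is appended before the next forward pass. A minimal greedy loop, assuming the ONNX interface (scores indexed by raw word) and that re-running the full sequence each step is acceptable. `continue_game` and its three callables are hypothetical stand-ins for the session call, `board.board.legal_moves` combined with `board.encode`, and `board.push`. For PyTorch, `score_fn` would also apply the input and output tokenizers.

```python
from typing import Callable, Dict, List, Sequence


def continue_game(
    words: List[int],
    n_moves: int,
    score_fn: Callable[[Sequence[int]], Sequence[float]],
    legal_fn: Callable[[], Dict[object, int]],
    push_fn: Callable[[object], None],
) -> List[object]:
    """Greedily extend a game by up to n_moves, choosing only legal moves.

    score_fn(words) returns scores indexed by raw word (hypothetical);
    legal_fn() returns {move: raw_word} for the current position;
    push_fn(move) applies the move to the board.
    """
    chosen = []
    for _ in range(n_moves):
        legal = legal_fn()
        if not legal:
            break  # no legal moves: checkmate or stalemate
        scores = score_fn(words)
        move = max(legal, key=lambda m: scores[legal[m]])
        push_fn(move)  # keep the board in sync with the chosen move
        words = words + [legal[move]]  # and extend the model input
        chosen.append(move)
    return chosen
```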
## Validation
| Artifact | Validation | Status | Backend | Precision | Sample shape |
| --- | --- | --- | --- | --- | --- |
| model.safetensors | write | pass | safetensors.torch.save_file | | |
| model.safetensors | strict_load | pass | safetensors.torch.load_file | | |
| model.onnx | export | pass | torch.onnx | fp32 | [2, 2] |
| model.onnx | runtime | pass | onnxruntime.CPUExecutionProvider | fp32 | [2, 2] |
| model-bf16.onnx | export | pass | torch.onnx | bf16 | [2, 2] |
| model-bf16.onnx | onnx_checker_and_initializer_dtype | pass | onnx.checker | bf16 | |
## Known Limitations
This model is trained for chess move autocomplete and is not a general chess
engine. Loading through Hugging Face Transformers `AutoModel` or
`trust_remote_code` is not supported. Metadata-aware variants encode metadata as
input tokens; no separate metadata tensor path is supported. Some ONNX Runtime
CPU builds do not execute the BF16 MatMul graph; use `model.onnx` for broad
compatibility, or `model-bf16.onnx` on a backend with BF16 operator support.
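Since BF16 support depends on the runtime build, one option is to try `model-bf16.onnx` first and fall back to `model.onnx`. A sketch, assuming an unsupported graph surfaces as a Python exception either when the session is created or on the first run, which is why it probes with a real call. `first_working_session` is a hypothetical helper; the caller supplies how to open and probe a session.

```python
from typing import Callable, Sequence, Tuple, TypeVar

S = TypeVar("S")


def first_working_session(
    paths: Sequence[str],
    open_session: Callable[[str], S],
    probe: Callable[[S], object],
) -> Tuple[str, S]:
    """Return (path, session) for the first artifact that loads and runs.

    Any exception from open_session or probe moves on to the next path.
    """
    errors = []
    for path in paths:
        try:
            session = open_session(path)
            probe(session)  # e.g. one run on a prefix-only input
            return path, session
        except Exception as exc:  # broad on purpose: failure modes vary by build
            errors.append(f"{path}: {exc}")
    raise RuntimeError("no artifact could be executed:\n" + "\n".join(errors))
```

For example, pass `["model-bf16.onnx", "model.onnx"]` with `open_session=lambda p: ort.InferenceSession(p, providers=["CPUExecutionProvider"])` and `probe=lambda s: s.run(["bin_logits"], {"bin_moves": bin_moves})` to prefer BF16 and fall back to FP32.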