---
license: mit
datasets:
- yp-edu/stockfish-debug
name: yp-edu/gpt2-stockfish-debug
results:
- task: train
  metrics:
  - name: train-loss
    type: loss
    value: 0.151
    verified: false
  - name: eval-loss
    type: loss
    value: 0.138
    verified: false
widget:
- text: "FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\nMOVE:"
  example_title: "Init Board"
- text: "FEN: r2q1rk1/1p3ppp/4bb2/p2p4/5B2/1P1P4/1PPQ1PPP/R3R1K1 w - - 1 17\nMOVE:"
  example_title: "Middle Board"
- text: "FEN: 4r1k1/1p1b1ppp/8/8/3P4/2P5/1q3PPP/6K1 b - - 0 28\nMOVE:"
  example_title: "Checkmate Possible"
---
# Model Card for gpt2-stockfish-debug

See my [blog post](https://yp-edu.github.io/projects/training-gpt2-on-stockfish-games) for additional details.
## Training Details

The model was trained for one epoch on the [yp-edu/stockfish-debug](https://huggingface.co/datasets/yp-edu/stockfish-debug) dataset, with no hyperparameter tuning. The samples have the form:

```json
{"prompt": "FEN: {fen}\nMOVE:", "completion": " {move}"}
```
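Concretely, a single sample for the starting position would serialize to a JSON line like the one below (a minimal sketch of the template; the dataset's actual preprocessing may differ in details):

```python
import json

# One concrete sample for the initial board, with the move in UCI notation
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
move = "e2e4"
sample = {"prompt": f"FEN: {fen}\nMOVE:", "completion": f" {move}"}
print(json.dumps(sample))
```

Note the leading space in the completion: it keeps the move as a separate token sequence from the `MOVE:` tag under GPT-2's byte-pair tokenizer.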
Two simple possible extensions:

- Expand the FEN string: `r2qk3/...` -> `r11qk111/...` or equivalent, so that each square maps to exactly one character
- Condition on the game result (ELO is not available in the dataset):

  ```json
  {"prompt": "RES: {res}\nFEN: {fen}\nMOVE:", "completion": " {move}"}
  ```
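The first extension is a small string transformation: every digit in the board field of the FEN encodes a run of empty squares, so replacing each digit `n` with `n` ones yields a fixed one-character-per-square encoding. A sketch (the helper name is mine, not from the dataset):

```python
import re


def expand_fen_board(fen: str) -> str:
    """Replace each digit run in the FEN board field with that many '1's,
    so every square maps to exactly one character."""
    board, rest = fen.split(" ", 1)  # only the board field contains run-lengths
    expanded = re.sub(r"\d", lambda m: "1" * int(m.group()), board)
    return f"{expanded} {rest}"


print(expand_fen_board("r2qk3/8/8/8/8/8/8/8 w - - 0 1"))
# The board field becomes "r11qk111/11111111/..."
```

Splitting off the board field first matters: the halfmove and fullmove counters at the end of the FEN are also digits and must not be expanded.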

## Use the Model

The following code requires `python-chess` (in addition to `transformers`), which you can install with `pip install python-chess`.
```python
import chess
from transformers import AutoModelForCausalLM, AutoTokenizer


def next_move(model, tokenizer, fen):
    input_ids = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
    input_ids = {k: v.to(model.device) for k, v in input_ids.items()}
    out = model.generate(
        **input_ids,
        max_new_tokens=10,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.1,
    )
    out_str = tokenizer.batch_decode(out)[0]
    # The text after "MOVE:" is the predicted move in UCI notation
    return out_str.split("MOVE:")[-1].replace("<|endoftext|>", "").strip()


board = chess.Board()
model = AutoModelForCausalLM.from_pretrained("yp-edu/gpt2-stockfish-debug")
tokenizer = AutoTokenizer.from_pretrained("yp-edu/gpt2-stockfish-debug")  # or "gpt2"
tokenizer.pad_token = tokenizer.eos_token
for i in range(100):
    fen = board.fen()
    move_uci = next_move(model, tokenizer, fen)
    try:
        print(move_uci)
        move = chess.Move.from_uci(move_uci)  # raises InvalidMoveError on malformed UCI
        if move not in board.legal_moves:
            raise chess.IllegalMoveError(move_uci)
        board.push(move)
        outcome = board.outcome()
        if outcome is not None:
            print(board)
            print(outcome.result())
            break
    except (chess.InvalidMoveError, chess.IllegalMoveError):
        print(board)
        print("Illegal move", i)
        break
else:
    print(board)
```