File size: 8,073 Bytes
c783a58
 
 
0732c11
c783a58
 
 
 
 
 
 
 
 
c789271
c783a58
 
 
 
 
 
 
 
c789271
c783a58
 
 
 
 
 
 
 
 
c789271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c783a58
c789271
c783a58
c789271
 
c783a58
c789271
c783a58
 
 
 
 
 
0732c11
c783a58
 
 
 
 
 
 
5f16693
c789271
 
 
c783a58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c789271
c783a58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292c059
 
 
 
 
c783a58
 
292c059
c783a58
292c059
 
c783a58
292c059
 
 
c783a58
 
 
 
 
 
 
292c059
c783a58
292c059
c783a58
 
 
292c059
 
 
 
 
 
 
c783a58
292c059
 
49346c5
292c059
 
 
 
 
49346c5
c783a58
 
 
 
 
 
 
 
292c059
c783a58
292c059
c783a58
292c059
c783a58
292c059
c783a58
 
 
292c059
 
 
c783a58
 
292c059
 
 
c783a58
 
 
292c059
c783a58
 
292c059
 
 
 
 
c783a58
292c059
 
 
 
c783a58
292c059
 
c783a58
292c059
 
 
c783a58
292c059
 
 
c783a58
292c059
 
 
c783a58
292c059
 
c789271
c783a58
 
292c059
 
 
 
 
c783a58
 
292c059
 
 
c783a58
 
 
292c059
 
c783a58
 
 
292c059
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import re
import chess
import gradio as gr
import spaces
from jinja2 import Template
from gradio_chessboard import Chessboard
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ============================================================================
# Model Loading
# ============================================================================

# Hugging Face Hub repository holding the fine-tuned chess model.
MODEL_ID = "nuriyev/chess-reasoner-grpo"

# Exact Hub commit to load. The model is loaded with trust_remote_code=True,
# so pinning keeps the executed code fixed; the tokenizer uses the same pin so
# its vocabulary/chat template cannot drift away from the weights.
MODEL_REVISION = "b7e531a630fd35065f9c8287f4bd21dff42f871b"

# Loading happens once at import time; the objects below are module globals
# used by get_model_move().
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    revision=MODEL_REVISION,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    revision=MODEL_REVISION,
)
model.eval()  # inference only
print("Model loaded!")


# ============================================================================
# Prompts (matching training exactly)
# ============================================================================

# Jinja2 template for the single user message sent to the model.
#
# Render variables:
#   FEN              - full FEN string; only its first field (piece placement)
#                      is used by the template.
#   side_to_move     - text inserted into the "Turn:" line.
#   legal_moves_uci  - text inserted into the "Legal Moves:" line.
#
# Board rendering: the placement field is expanded into a flat string
# (each digit becomes that many '.' characters, '/' separators are dropped).
# Because FEN lists rank 8 first, square <file f, rank r> sits at index
# (8 - r) * 8 + f. The grid is emitted one line per file (a..h), each line
# listing ranks 1..8 as "a1:R a2:P ...".
#
# The wording and whitespace are part of the model's expected input (see the
# section header above), so the literal must not be reformatted.
USER_PROMPT = Template("""You are an expert chess player.

Given a current game state, you must select the best legal next move. Think in 1-2 sentences, then output your chosen move.

## State

Board:
{% set fen_board = FEN.split()[0] %}
{%- set ns = namespace(board='') -%}
{%- for char in fen_board -%}
{%- if char in '12345678' -%}
{%- set ns.board = ns.board ~ '.' * (char|int) -%}
{%- elif char != '/' -%}
{%- set ns.board = ns.board ~ char -%}
{%- endif -%}
{%- endfor -%}
{#- Output coordinate grid by file -#}
{%- set files = 'abcdefgh' -%}
{% for f in range(8) %}
{%- for r in range(1, 9) -%}
{{ files[f] }}{{ r }}:{{ ns.board[(8-r)*8 + f] }}{% if r < 8 %} {% endif -%}
{%- endfor %}
{% endfor %}
Turn: It is your turn ({{ side_to_move }})
Legal Moves: {{ legal_moves_uci }}

## Output format

<reason>...brief thinking (1-2 first-person very short concise sentences, identifying threat or opportunity, then deciding on the best move to play next)...</reason>
<uci_move>...your_move...</uci_move>

NOTE: capital letters are white, lowercase are black.""")


# ============================================================================
# Model Inference
# ============================================================================

@spaces.GPU
def get_model_move(fen: str) -> "tuple[str | None, str, str]":
    """Ask the model for a move in the given position.

    The position is rendered into ``USER_PROMPT``, wrapped with the
    tokenizer's chat template, and sampled from the model. The reply is then
    parsed for ``<reason>...</reason>`` and ``<uci_move>...</uci_move>`` tags.

    Args:
        fen: Position to move from, as a FEN string.

    Returns:
        ``(uci_move, reasoning, raw_output)`` where

        * ``uci_move`` is the stripped text inside the ``<uci_move>`` tag, or
          ``None`` if the tag is missing or empty. The move is NOT validated
          here; callers must check legality.
        * ``reasoning`` is the stripped text inside the ``<reason>`` tag, or
          ``"No reasoning provided"`` if the tag is missing.
        * ``raw_output`` is the decoded generation up to the first
          ``<|im_end|>`` marker, stripped.
    """
    board = chess.Board(fen)
    turn = "white" if board.turn else "black"
    legal_moves_uci = ", ".join(move.uci() for move in board.legal_moves)

    messages = [
        {
            "role": "user",
            "content": USER_PROMPT.render(
                FEN=fen,
                side_to_move=turn,
                legal_moves_uci=legal_moves_uci,
            ),
        },
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # Some tokenizers define no pad token; fall back to EOS explicitly rather
    # than passing None through to generate().
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.8,
            top_k=20,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt). Special
    # tokens are kept so the end-of-turn marker can be used as a cut point.
    prompt_length = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=False)

    # Both tags are matched with DOTALL so content wrapped onto its own line
    # (e.g. "<uci_move>\ne2e4\n</uci_move>") is still recognised.
    think_match = re.search(r'<reason>(.*?)</reason>', generated, re.DOTALL)
    move_match = re.search(r'<uci_move>(.*?)</uci_move>', generated, re.DOTALL)

    reasoning = (
        think_match.group(1).strip() if think_match
        else "No reasoning provided"
    )
    # An empty tag is treated the same as a missing one.
    uci_move = (move_match.group(1).strip() or None) if move_match else None

    # Clean up raw output for display
    raw_output = generated.split('<|im_end|>')[0].strip()

    return uci_move, reasoning, raw_output


# ============================================================================
# Game Logic
# ============================================================================

def _turn_status(board: chess.Board) -> str:
    """Build the Markdown status line for the side to move, flagging check."""
    turn_str = "White (You)" if board.turn else "Black (AI)"
    status = f"**Turn:** {turn_str}"
    if board.is_check():
        status += " ⚠️ CHECK!"
    return status


def play_move(fen: str) -> tuple[str, str, str, str]:
    """Handle a board change and, when Black is to move, play the AI's reply.

    Args:
        fen: Current position as a FEN string.

    Returns:
        ``(new_fen, status, reasoning, raw_output)``.

        * Game already over: the unchanged FEN, a result banner, and empty
          reasoning/raw output.
        * White to move: the unchanged FEN and a turn status; the last two
          items are ``gr.update()`` placeholders so whatever the reasoning
          and raw-output boxes currently show is left as is.
        * Black to move: the model is queried. A legal move is pushed onto
          the board; a missing, unparsable or illegal move leaves the
          position unchanged and prefixes ``reasoning`` with a warning.
    """
    board = chess.Board(fen)

    # Check if game is over
    if board.is_game_over():
        result = get_game_result(board)
        return fen, f"🏁 {result}", "", ""

    # White's turn - just return current state, preserve reasoning display
    if board.turn:
        return fen, _turn_status(board), gr.update(), gr.update()

    # Black's turn (AI): ask the model and validate what it returned.
    uci_move, reasoning, raw_output = get_model_move(fen)

    if not uci_move:
        # Make a missing <uci_move> tag visible instead of silently stalling.
        reasoning = "⚠️ Model did not output a move. " + reasoning
    else:
        try:
            move = chess.Move.from_uci(uci_move)
        except ValueError:
            # from_uci rejects malformed strings with ValueError (or a
            # subclass of it), so nothing broader needs to be caught.
            reasoning = f"⚠️ Model output invalid move: {uci_move}. " + reasoning
        else:
            if move in board.legal_moves:
                board.push(move)
            else:
                reasoning = f"⚠️ Model suggested illegal move: {uci_move}. " + reasoning

    # Check if game is over after AI move
    if board.is_game_over():
        result = get_game_result(board)
        return board.fen(), f"🏁 {result}", reasoning, raw_output

    return board.fen(), _turn_status(board), reasoning, raw_output


def get_game_result(board: chess.Board) -> str:
    """Describe how the game on ``board`` ended.

    Checkmate is tested first, then the draw conditions in a fixed order;
    the first condition that holds decides the message. If none holds, a
    generic ``"Game Over"`` is returned.
    """
    if board.is_checkmate():
        # The side to move is the one that has been mated.
        victor = "Black" if board.turn else "White"
        return f"Checkmate! {victor} wins!"

    # (board predicate name, message) pairs, probed lazily in this order.
    draw_reasons = (
        ("is_stalemate", "Stalemate - Draw"),
        ("is_insufficient_material", "Draw - Insufficient material"),
        ("is_fifty_moves", "Draw - 50 move rule"),
        ("is_repetition", "Draw - Repetition"),
    )
    for predicate_name, message in draw_reasons:
        if getattr(board, predicate_name)():
            return message

    return "Game Over"


def reset_game() -> tuple[str, str, str, str]:
    """Return the UI values for a fresh game.

    Returns:
        ``(fen, status, reasoning, raw_output)``: the standard starting
        position, a status saying it is White's turn, and empty strings for
        the reasoning and raw-output fields.
    """
    initial_status = "**Turn:** White (You)"
    cleared = ""
    return (chess.STARTING_FEN, initial_status, cleared, cleared)


def ai_plays_first() -> tuple[str, str, str, str]:
    """Feed the standard starting position to ``play_move``.

    Returns:
        Whatever ``play_move`` returns for the starting position.

    NOTE(review): the starting FEN has White to move, and ``play_move`` only
    queries the model when Black is to move. As wired, this looks like it
    returns the unchanged starting position without an AI move - confirm
    whether the AI is meant to open as White here.
    """
    opening_fen = chess.STARTING_FEN
    return play_move(opening_fen)


# ============================================================================
# Gradio Interface
# ============================================================================

with gr.Blocks(title="♟️ Chess Reasoner") as demo:
    # --- Page header -------------------------------------------------------
    gr.Markdown("""
    # ♟️ Chess Reasoner
    Play chess against a reasoning AI! You play as **White** - click on pieces to move them.
    """)

    # --- Main layout: board on the left, controls/output on the right ------
    with gr.Row():
        with gr.Column(scale=2):
            board = Chessboard(
                value=chess.STARTING_FEN,
                label="",
                game_mode=True,
            )

        with gr.Column(scale=1):
            # Whose turn it is / game result banner.
            status = gr.Markdown(value="**Turn:** White (You)")

            with gr.Row():
                reset_btn = gr.Button("🔄 New Game", variant="primary")
                ai_first_btn = gr.Button("🤖 AI First")

            # Parsed <reason> text from the model (expanded by default).
            with gr.Accordion("🧠 AI Reasoning", open=True):
                reasoning = gr.Textbox(
                    label="Thinking",
                    lines=3,
                    interactive=False,
                )

            # Unparsed model reply (collapsed by default).
            with gr.Accordion("📝 Raw Output", open=False):
                raw_output = gr.Textbox(
                    label="Model Output",
                    lines=5,
                    interactive=False,
                )

    # --- Footer ------------------------------------------------------------
    gr.Markdown("""
    ---
    **Model:** [nuriyev/chess-reasoner-grpo](https://huggingface.co/nuriyev/chess-reasoner-grpo) • Fine-tuned from Qwen3-4B-Instruct
    """)

    # --- Event wiring ------------------------------------------------------
    # Every handler returns (fen, status, reasoning, raw_output), so all
    # three events write to the same four components in that order.

    # Any change of the board value is passed to play_move.
    board.change(
        fn=play_move,
        inputs=[board],
        outputs=[board, status, reasoning, raw_output],
    )

    # "New Game" restores the starting position and clears the text boxes.
    reset_btn.click(
        fn=reset_game,
        outputs=[board, status, reasoning, raw_output],
    )

    # "AI First" delegates to ai_plays_first.
    ai_first_btn.click(
        fn=ai_plays_first,
        outputs=[board, status, reasoning, raw_output],
    )

# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    # ssr_mode=False: presumably turns off Gradio's server-side rendering —
    # confirm against the installed Gradio version.
    demo.launch(ssr_mode=False)