File size: 13,340 Bytes
7f9c451
 
12471a6
2c0ecb8
12471a6
38d6b68
2c0ecb8
 
7f9c451
2c0ecb8
12471a6
7f9c451
12471a6
7f9c451
 
12471a6
 
38d6b68
2c0ecb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9c451
2c0ecb8
 
 
 
 
7f9c451
 
2c0ecb8
 
7f9c451
 
2c0ecb8
 
7f9c451
12471a6
2c0ecb8
 
 
 
 
 
 
12471a6
 
2c0ecb8
7f9c451
38d6b68
7f9c451
 
 
2c0ecb8
7f9c451
12471a6
 
2c0ecb8
 
7f9c451
 
2c0ecb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9c451
2c0ecb8
 
7f9c451
2c0ecb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9c451
2c0ecb8
 
 
7f9c451
2c0ecb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d6b68
7f9c451
2c0ecb8
 
 
 
 
 
7f9c451
2c0ecb8
 
 
7f9c451
2c0ecb8
 
 
 
 
7f9c451
 
2c0ecb8
 
7f9c451
2c0ecb8
 
7f9c451
2c0ecb8
 
 
7f9c451
2c0ecb8
 
 
 
 
 
7f9c451
2c0ecb8
 
 
7f9c451
2c0ecb8
 
 
 
 
7f9c451
2c0ecb8
 
7f9c451
2c0ecb8
 
 
 
 
7f9c451
 
 
2c0ecb8
 
 
 
 
 
 
7f9c451
38d6b68
12471a6
 
7f9c451
2c0ecb8
38d6b68
 
2c0ecb8
12471a6
 
7f9c451
 
 
2c0ecb8
7f9c451
 
 
 
2c0ecb8
 
 
 
 
 
 
12471a6
38d6b68
7f9c451
 
2c0ecb8
 
38d6b68
 
12471a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
from __future__ import annotations

import chess
import gradio as gr
from gradio_chessboard import Chessboard

# Standard initial position; used for the initial board, state and resets.
START_FEN = chess.STARTING_FEN
# The four central squares scored by the "center_reach" axis.
CENTER4 = [chess.D4, chess.E4, chess.D5, chess.E5]

# Static piece values in pawn units. The king is 0 so it never affects material.
# Insertion order matters slightly: _material accumulates over these items in order.
PIECE_V = {
    chess.PAWN: 1.0,
    chess.KNIGHT: 3.2,
    chess.BISHOP: 3.3,
    chess.ROOK: 5.0,
    chess.QUEEN: 9.0,
    chess.KING: 0.0,
}

# HDS-ish evaluation "axes", in display order (kept human-readable).
AXES = [
    "material",
    "king_safety",
    "piece_activity",
    "center_reach",
    "initiative",
    "tactical_tension",
]

# Per-axis weight applied by score() / top_deltas(); one entry per AXES name.
W = dict(
    [
        ("material", 1.8),
        ("king_safety", 1.4),
        ("piece_activity", 1.0),
        ("center_reach", 1.0),
        ("initiative", 0.9),
        ("tactical_tension", 1.1),
    ]
)

# -----------------------------
# Helpers
# -----------------------------
def clamp(x: float, lo: float = -1.0, hi: float = 1.0) -> float:
    """Coerce ``x`` to float and limit it to the range [lo, hi].

    Equivalent to ``max(lo, min(hi, float(x)))``: the upper bound is applied
    first, then the lower bound, and a bound is returned as-is when hit.
    """
    value = float(x)
    capped = value if value < hi else hi
    return capped if capped > lo else lo


def side_name(color: bool) -> str:
    """Return "White" for ``chess.WHITE`` and "Black" for anything else."""
    if color == chess.WHITE:
        return "White"
    return "Black"


def square(sq: int) -> str:
    """Return the algebraic name of square index ``sq`` (0 -> "a1", 63 -> "h8")."""
    names = chess.SQUARE_NAMES
    return names[sq]


def uci_to_label(prev: chess.Board, mv: chess.Move) -> str:
    """Return a compact label for ``mv``: its UCI string plus capture/castle tags.

    Args:
        prev: Position *before* ``mv`` is played; needed to detect captures
            (including en passant) and castling.
        mv: The move to describe.

    Returns:
        e.g. "e2e4", "e5d6 (capture)", "e1g1 (castle)", "e7e8q".
    """
    # Move.uci() keeps the promotion suffix ("e7e8q"); plain from/to square
    # names would silently drop which piece was promoted to.
    tags = []
    if prev.is_capture(mv):
        tags.append(" (capture)")
    if prev.is_castling(mv):
        tags.append(" (castle)")
    return mv.uci() + "".join(tags)


def infer_move(prev_fen: str, new_fen: str) -> chess.Move | None:
    """Find the legal move that transforms ``prev_fen`` into ``new_fen``.

    An exact FEN match is preferred. If none is found, fall back to matching
    only the piece placement and the side to move. This tolerates FENs written
    by other libraries, which can differ in the en passant field (python-chess
    only writes it when an ep capture is actually legal) or in move clocks.
    Distinct legal moves always produce distinct piece placements, so the
    fallback cannot pick the wrong move.

    Returns:
        The matching move, or None if ``new_fen`` is empty or unparsable, or
        if no legal move reaches it (e.g. the board was edited by hand).

    Raises:
        ValueError: if ``prev_fen`` is not a valid FEN (from ``chess.Board``).
    """
    b = chess.Board(prev_fen)
    if not new_fen:
        return None
    try:
        target = chess.Board(new_fen)
    except ValueError:
        target = None  # still allow an exact-string match below

    loose_match = None
    for mv in b.legal_moves:
        bb = b.copy(stack=False)
        bb.push(mv)
        if bb.fen() == new_fen:
            return mv
        if (
            loose_match is None
            and target is not None
            and bb.turn == target.turn
            and bb.board_fen() == target.board_fen()
        ):
            loose_match = mv
    return loose_match


def count_attackers(board: chess.Board, color: bool, sq: int) -> int:
    """Return how many pieces of ``color`` attack square ``sq`` on ``board``."""
    attacking = board.attackers(color, sq)
    return len(attacking)


# -----------------------------
# Axis components (deterministic, engine-less)
# -----------------------------
def _material(board: chess.Board, me: bool) -> float:
    """Raw (unnormalized) material balance for ``me`` in PIECE_V pawn units.

    Positive means ``me`` has more material than the opponent.
    """
    opponent = not me
    total = 0.0
    for piece_type, value in PIECE_V.items():
        diff = len(board.pieces(piece_type, me)) - len(board.pieces(piece_type, opponent))
        total += value * diff
    return total


def _king_exposure(board: chess.Board, me: bool) -> float:
    """Count squares adjacent to ``me``'s king that the opponent attacks.

    Higher means a more exposed king (0..8). Returns 10.0 when ``me`` has no
    king on the board, i.e. worse than any real count.
    """
    king_sq = board.king(me)
    if king_sq is None:
        return 10.0
    enemy = not me
    ring = chess.SquareSet(chess.BB_KING_ATTACKS[king_sq])
    attacked = sum(1 for neighbour in ring if board.is_attacked_by(enemy, neighbour))
    return float(attacked)


def _castled(board: chess.Board, me: bool) -> bool:
    """Heuristic: is ``me``'s king on a castled-king square (g/c file, home rank)?

    This only checks the king's square; it does not verify that castling
    actually happened.
    """
    king_sq = board.king(me)
    if king_sq is None:
        return False
    castled_squares = [chess.G1, chess.C1] if me == chess.WHITE else [chess.G8, chess.C8]
    return king_sq in castled_squares


def _minor_development(board: chess.Board, me: bool) -> int:
    """Count ``me``'s knights and bishops that are off their starting squares.

    Usually 0..4 ("awake" minor pieces); promoted minors are counted too.
    """
    if me == chess.WHITE:
        knight_home, bishop_home = {chess.B1, chess.G1}, {chess.C1, chess.F1}
    else:
        knight_home, bishop_home = {chess.B8, chess.G8}, {chess.C8, chess.F8}
    moved_knights = [s for s in board.pieces(chess.KNIGHT, me) if s not in knight_home]
    moved_bishops = [s for s in board.pieces(chess.BISHOP, me) if s not in bishop_home]
    return len(moved_knights) + len(moved_bishops)


def _center_reach(board: chess.Board, me: bool) -> int:
    """Return how many of the four CENTER4 squares ``me`` attacks (0..4)."""
    return sum(1 for target in CENTER4 if board.is_attacked_by(me, target))


def _mobility(board: chess.Board, me: bool) -> int:
    """Number of legal moves ``me`` would have if it were ``me``'s turn.

    Works on a copy with the side to move forced to ``me``; the caller's
    board is not modified.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    return sum(1 for _ in probe.legal_moves)


def _initiative(board: chess.Board, me: bool) -> float:
    """Crude initiative in [0, 1] for ``me``: a blend of check and capture ratios.

    With ``me`` forced to move (on a copy), returns
    ``0.65 * checking_moves/legal_moves + 0.35 * captures/legal_moves``,
    or 0.0 when ``me`` has no legal moves.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    moves = list(probe.legal_moves)
    if not moves:
        return 0.0
    n_caps = 0
    n_checks = 0
    for move in moves:
        if probe.is_capture(move):
            n_caps += 1
        child = probe.copy(stack=False)
        child.push(move)
        if child.is_check():
            n_checks += 1
    n = len(moves)
    return 0.65 * (n_checks / n) + 0.35 * (n_caps / n)


def _tactical_tension(board: chess.Board, me: bool) -> float:
    """Number of legal captures ``me`` has right now, capped at 10.

    Only available captures are counted (with ``me`` forced to move on a
    copy); undefended/hanging pieces are not evaluated separately.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    capture_count = sum(1 for move in probe.legal_moves if probe.is_capture(move))
    return float(min(10, capture_count))


def eval_axes(board: chess.Board, me: bool) -> dict[str, float]:
    """Evaluate ``board`` from ``me``'s point of view on every axis in AXES.

    Each raw feature is mapped linearly and clamped into [-1, 1], so the result
    is deterministic and bounded.

    Returns:
        Dict keyed by axis name (same keys as AXES).
    """
    mat = _material(board, me)                  # raw pawn units
    kex = _king_exposure(board, me)             # attacked king-ring squares (lower is better)
    dev = _minor_development(board, me)         # usually 0..4
    cen = _center_reach(board, me)              # 0..4
    ini = _initiative(board, me)                # 0..1
    ten = _tactical_tension(board, me)          # 0..10
    # Mobility (_mobility) does not feed any axis, so it is deliberately not
    # computed here: it would cost a full legal-move generation per call, and
    # this function runs once per candidate move.

    # normalize each axis to [-1, 1] (deterministic, stable)
    axes = {
        "material": clamp(mat / 10.0),                    # ±10 pawns range
        "king_safety": clamp((4.0 - kex) / 4.0),          # kex 0..8 => safety roughly
        "piece_activity": clamp((dev - 1.5) / 1.5),       # dev 0..4
        "center_reach": clamp((cen - 1.5) / 1.5),         # cen 0..4
        "initiative": clamp((ini - 0.12) / 0.25),         # ini 0..1
        "tactical_tension": clamp((ten - 2.0) / 4.0),     # ten 0..10
    }
    return axes


def score(axes: dict[str, float]) -> float:
    """Weighted sum of the axis values, using W; axes missing from ``axes`` count as 0."""
    weighted = (W[name] * axes.get(name, 0.0) for name in AXES)
    return sum(weighted)


def top_deltas(before: dict[str, float], after: dict[str, float], k: int = 3):
    """Rank per-axis changes between two axis dicts by weighted magnitude.

    Returns:
        ``(top_k, all_rows)`` where each row is ``(axis, delta, weighted_delta)``
        and ``all_rows`` is sorted by ``abs(weighted_delta)``, largest first
        (ties keep AXES order).
    """
    deltas = [(name, after.get(name, 0.0) - before.get(name, 0.0)) for name in AXES]
    ranked = sorted(
        [(name, delta, delta * W[name]) for name, delta in deltas],
        key=lambda row: abs(row[2]),
        reverse=True,
    )
    return ranked[:k], ranked


def axis_pretty(a: str) -> str:
    """Return the human-readable label for axis key ``a`` (unknown keys pass through)."""
    labels = {
        "material": "Material",
        "king_safety": "King safety",
        "piece_activity": "Piece activity",
        "center_reach": "Center reach",
        "initiative": "Initiative",
        "tactical_tension": "Tactical tension",
    }
    if a in labels:
        return labels[a]
    return a


def axis_table_md(before: dict[str, float], after: dict[str, float]) -> str:
    """Render a Markdown table of every axis: before, after and the change.

    Values are shown with an explicit sign and two decimals; missing axes count as 0.
    """
    header = ["| Axis | Before | After | Δ |", "|---|---:|---:|---:|"]
    body = []
    for name in AXES:
        old = before.get(name, 0.0)
        new = after.get(name, 0.0)
        change = new - old
        body.append(f"| {axis_pretty(name)} | {old:+.2f} | {new:+.2f} | {change:+.2f} |")
    return "\n".join(header + body)


# -----------------------------
# Risk / explanation (HDS-ish causal trace)
# -----------------------------
def risk_flags(prev: chess.Board, mv: chess.Move, after: chess.Board) -> list[str]:
    """Return simple human-readable risk warnings for the mover after ``mv``.

    Args:
        prev: Position before the move (its side to move is the mover).
        mv: The move played.
        after: Position after the move.

    Checks, in order:
      1. the moved piece is attacked by more enemy pieces than it is defended by;
      2. a queen moved within the first 6 full moves;
      3. the mover's king ring is attacked on more squares than before.
    """
    mover = prev.turn
    enemy = not mover
    warnings: list[str] = []

    landed = after.piece_at(mv.to_square)
    if landed and landed.color == mover:
        attackers = count_attackers(after, enemy, mv.to_square)
        defenders = count_attackers(after, mover, mv.to_square)
        if attackers > defenders:
            warnings.append("Moved piece becomes a likely target (more attacked than defended).")

    origin = prev.piece_at(mv.from_square)
    is_early_queen = (
        origin is not None
        and origin.piece_type == chess.QUEEN
        and prev.fullmove_number <= 6
    )
    if is_early_queen:
        warnings.append("Early queen move: can invite tempo hits.")

    if _king_exposure(after, mover) > _king_exposure(prev, mover):
        warnings.append("King safety slightly worsens (more pressure near your king).")

    return warnings


def advantage_label(axes: dict[str, float], me: bool) -> str:
    """Map the combined ``score(axes)`` to a coarse verbal label.

    ``axes`` is expected to already be from ``me``'s perspective; ``me`` itself
    is not used. Bands: >=0.55 Ahead, >=0.20 Slightly better, >-0.20 Roughly
    equal, >-0.55 Slightly worse, otherwise Behind.
    """
    total = score(axes)
    # (threshold, inclusive?, label) checked top-down; first match wins.
    bands = (
        (0.55, True, "Ahead"),
        (0.20, True, "Slightly better"),
        (-0.20, False, "Roughly equal"),
        (-0.55, False, "Slightly worse"),
    )
    for threshold, inclusive, label in bands:
        hit = total >= threshold if inclusive else total > threshold
        if hit:
            return label
    return "Behind"


def short_verdict(before_axes: dict[str, float], after_axes: dict[str, float], me: bool) -> str:
    """One-line summary of how the advantage label changed across a move.

    Returns "State: <label> (no major swing)." when the label is unchanged,
    otherwise "State shift: <before> → <after>.".
    """
    before_label = advantage_label(before_axes, me)
    after_label = advantage_label(after_axes, me)
    if before_label == after_label:
        return f"State: {after_label} (no major swing)."
    # An explicit separator is required; without it the labels run together
    # ("Roughly equalSlightly better").
    return f"State shift: {before_label} → {after_label}."


def causal_sentence(top3, flags: list[str]) -> str:
    """Build a "Because X improves, at the cost of Y." line from ranked deltas.

    ``top3`` rows are ``(axis, delta, weighted_delta)`` as produced by
    ``top_deltas``. Up to two positive and two negative weighted changes are
    named. When ``flags`` is non-empty, its first entry is appended as a
    boundary risk.
    """
    gains = [row[0] for row in top3 if row[2] > 0][:2]
    losses = [row[0] for row in top3 if row[2] < 0][:2]

    because = ", ".join(map(axis_pretty, gains)) if gains else "structure"
    cost = ", ".join(map(axis_pretty, losses)) if losses else ""

    sentence = (
        f"Because {because} improves, at the cost of {cost}."
        if cost
        else f"Because {because} improves with limited downside."
    )
    if flags:
        sentence += f" Boundary risk: {flags[0]}"
    return sentence


# -----------------------------
# Candidate moves (Top-3)
# -----------------------------
def candidate_score(prev: chess.Board, mv: chess.Move) -> tuple[float, dict[str, float], list[str]]:
    """Score candidate move ``mv`` in ``prev`` from the side-to-move's perspective.

    The score is ``score(eval_axes(after, mover))``, plus small fixed bonuses
    (capture +0.06, check +0.06, castling +0.05), minus 0.04 per risk flag.

    Returns:
        ``(total_score, after_axes, flags)``.
    """
    me = prev.turn

    # Only the post-move position contributes to the score, so the pre-move
    # axes are not evaluated here. eval_axes is costly and this runs once per
    # legal move.
    b2 = prev.copy(stack=False)
    b2.push(mv)
    after_axes = eval_axes(b2, me)

    base = score(after_axes)

    # small deterministic bonuses
    bonus = 0.0
    if prev.is_capture(mv):
        bonus += 0.06
    if b2.is_check():
        bonus += 0.06
    if prev.is_castling(mv):
        bonus += 0.05

    # small deterministic risk penalty
    flags = risk_flags(prev, mv, b2)
    penalty = 0.04 * len(flags)

    return base + bonus - penalty, after_axes, flags


def top3_candidates(prev: chess.Board, k: int = 3) -> list[tuple[chess.Move, float, dict[str, float], list[str]]]:
    """Return the ``k`` (default 3) best legal moves in ``prev`` by ``candidate_score``.

    Each entry is ``(move, score, after_axes, flags)``, with scores from the
    mover's point of view. Entries are ordered by descending score; ties keep
    python-chess legal-move order because the sort is stable.

    Cost note: every candidate runs a full ``eval_axes``, which itself walks a
    legal-move list, so the work grows roughly with the square of the number
    of legal moves.
    """
    scored = [(mv, *candidate_score(prev, mv)) for mv in prev.legal_moves]
    scored.sort(key=lambda row: row[1], reverse=True)
    return scored[:k]


def format_candidate(prev: chess.Board, mv: chess.Move, before_axes: dict[str, float], s: float, after_axes: dict[str, float], flags: list[str]) -> str:
    """Render one candidate move as a Markdown bullet with a "Why" and a "Risk" line.

    Args:
        prev: Position before the candidate move (used for SAN and labels).
        mv: Candidate move.
        before_axes: Axes of ``prev`` from the mover's perspective.
        s: Candidate score (from ``candidate_score``).
        after_axes: Axes after ``mv`` from the mover's perspective.
        flags: Risk flags for ``mv``.

    The "Why" names the largest positive weighted axis change among the top 3
    (or "Structure"). The "Risk" shows the first flag when there is one.
    Otherwise it names the largest negative weighted change among the top 3,
    or "Low risk" if there is none.
    """
    top3 = top_deltas(before_axes, after_axes, k=3)[0]
    pos = [t for t in top3 if t[2] > 0]
    neg = [t for t in top3 if t[2] < 0]

    pro = axis_pretty(pos[0][0]) if pos else "Structure"
    if flags:
        risk_text = flags[0]
    else:
        risk_text = (axis_pretty(neg[0][0]) if neg else "Low risk") + "."
    return (
        f"- **{prev.san(mv)}** ({uci_to_label(prev, mv)})\n"
        f"  - Why: **{pro}** improves (score≈{s:+.2f}).\n"
        f"  - Risk: {risk_text}"
    )


# -----------------------------
# Main move handler
# -----------------------------
def build_hds_trace(prev_fen: str, new_fen: str) -> str:
    """Build the Markdown "HDS trace" explaining the move from ``prev_fen`` to ``new_fen``.

    The trace contains:
      - a one-line state verdict and the move in SAN/UCI,
      - a causal "because / at the cost of" sentence with the first risk flag,
      - a before/after axis table,
      - the top-3 candidate moves from the position *before* the move.

    All axes are evaluated from the perspective of the side that moved. If no
    legal move explains the transition (e.g. the board was edited), only a
    warning and a snapshot table of the pre-move axes are returned.
    """
    prev = chess.Board(prev_fen)
    mv = infer_move(prev_fen, new_fen)

    me = prev.turn
    before_axes = eval_axes(prev, me)

    # If we can't infer (board edited), just show eval snapshot
    if mv is None:
        return (
            "⚠️ Could not infer the move (board might be edited, not moved).\n\n"
            "HDS Trace:\n\n"
            + axis_table_md(before_axes, before_axes)
        )

    # Derive the post-move position from the inferred move instead of
    # re-parsing new_fen, so `after` is exactly the position `mv` produces
    # (the same one risk_flags/SAN refer to), regardless of how the UI
    # serialised en passant or move-clock fields.
    after = prev.copy(stack=False)
    after.push(mv)
    after_axes = eval_axes(after, me)

    top3, _all = top_deltas(before_axes, after_axes, k=3)
    flags = risk_flags(prev, mv, after)

    verdict1 = short_verdict(before_axes, after_axes, me)
    mover = side_name(me)
    move_line = f"Move: **{mover} plays {prev.san(mv)}** ({uci_to_label(prev, mv)})."

    # HDS-ish trace block
    trace = []
    trace.append("### HDS Trace (deterministic)")
    trace.append(f"- {verdict1}")
    trace.append(f"- {move_line}")
    trace.append(f"- {causal_sentence(top3, flags)}")

    # State vector table (before/after)
    trace.append("\n**State vector (axes)**")
    trace.append(axis_table_md(before_axes, after_axes))

    # Candidate moves (computed from the position *before* your move)
    trace.append("\n### Top-3 candidate moves (with pros / risks)")
    for (cmv, cs, cax, cfl) in top3_candidates(prev):
        trace.append(format_candidate(prev, cmv, before_axes, cs, cax, cfl))

    return "\n".join(trace)


def on_move(fen_after: str, prev_fen: str):
    """Gradio handler for a board move.

    Args:
        fen_after: FEN reported by the board after the user's move.
        prev_fen: FEN stored in state before the move; falls back to START_FEN when empty.

    Returns:
        ``(board_value, new_prev_state, trace_markdown)``; the board and the
        stored state both become ``fen_after``.
    """
    start = prev_fen or START_FEN
    trace_md = build_hds_trace(start, fen_after)
    return fen_after, fen_after, trace_md


def on_reset():
    """Gradio handler for the Reset button.

    Returns:
        ``(board_value, prev_state, intro_markdown)``, with the board and
        state both set to START_FEN.
    """
    intro_lines = [
        "### HDS Trace (deterministic)",
        "- State: Roughly equal.",
        "- Move: (none yet)",
        "- Goal: build a safe structure, then increase pressure.",
    ]
    intro_md = "\n".join(intro_lines) + "\n"
    return START_FEN, START_FEN, intro_md


# Page CSS: hide the Gradio footer and cap the app width at 980px.
CSS = """
footer { display: none !important; }
.gradio-container { max-width: 980px !important; }
"""

# UI layout and event wiring (built at import time; launched under __main__).
with gr.Blocks(css=CSS, title="HDS Chess Trainer") as demo:
    # FEN of the position before the most recent move; on_move reads it to
    # infer which move was played and then overwrites it with the new FEN.
    prev_state = gr.State(START_FEN)

    # NOTE(review): game_mode=True presumably restricts the board to legal
    # moves — confirm against the gradio_chessboard docs.
    board = Chessboard(
        value=START_FEN,
        game_mode=True,
        orientation="white",   # fixed as requested
        show_label=False,
        label=None,
    )

    # Trace output panel. The initial text duplicates the Markdown returned
    # by on_reset(); keep the two in sync if either changes.
    log = gr.Markdown(
        value=(
            "### HDS Trace (deterministic)\n"
            "- State: Roughly equal.\n"
            "- Move: (none yet)\n"
            "- Goal: build a safe structure, then increase pressure.\n"
        )
    )

    reset_btn = gr.Button("Reset")

    # Each board move: (new FEN, previous FEN) -> (board, stored FEN, trace).
    board.move(on_move, inputs=[board, prev_state], outputs=[board, prev_state, log])
    # Reset puts the board, the stored FEN and the trace back to the start.
    reset_btn.click(on_reset, inputs=None, outputs=[board, prev_state, log])

if __name__ == "__main__":
    demo.launch()