|
|
from __future__ import annotations |
|
|
|
|
|
import chess |
|
|
import gradio as gr |
|
|
from gradio_chessboard import Chessboard |
|
|
|
|
|
# FEN of the standard starting position; the rest of the file uses it as the
# default board value and as the fallback "previous" position.
START_FEN = chess.STARTING_FEN

# The four central squares probed by the center-reach axis.
CENTER4 = [
    chess.D4,
    chess.E4,
    chess.D5,
    chess.E5,
]

# Value assigned to each piece type when computing the material balance.
# The king carries no value, so it never affects the balance.
PIECE_V = {
    chess.PAWN: 1.0,
    chess.KNIGHT: 3.2,
    chess.BISHOP: 3.3,
    chess.ROOK: 5.0,
    chess.QUEEN: 9.0,
    chess.KING: 0.0,
}
|
|
|
|
|
|
|
|
# Names of the evaluation axes, in the order they are scored and displayed.
AXES = [
    "material",
    "king_safety",
    "piece_activity",
    "center_reach",
    "initiative",
    "tactical_tension",
]

# Weight applied to each axis value when axes are combined into one score.
W = dict(
    material=1.8,
    king_safety=1.4,
    piece_activity=1.0,
    center_reach=1.0,
    initiative=0.9,
    tactical_tension=1.1,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clamp(x: float, lo: float = -1.0, hi: float = 1.0) -> float:
    """Limit ``x`` (converted to float) to the range ``[lo, hi]``.

    Args:
        x: Value to limit; it is passed through ``float()`` first.
        lo: Lower bound, -1.0 by default.
        hi: Upper bound, 1.0 by default.

    Returns:
        ``lo`` when the value is below it, ``hi`` when above it, otherwise
        the converted value.
    """
    # Apply the upper bound first, then the lower bound.
    capped = min(hi, float(x))
    return max(lo, capped)
|
|
|
|
|
|
|
|
def side_name(color: bool) -> str:
    """Return the display name ("White" or "Black") for a side colour."""
    if color == chess.WHITE:
        return "White"
    return "Black"
|
|
|
|
|
|
|
|
def square(sq: int) -> str:
    """Return the name of square index ``sq`` as given by ``chess.square_name``."""
    name = chess.square_name(sq)
    return name
|
|
|
|
|
|
|
|
def uci_to_label(prev: chess.Board, mv: chess.Move) -> str:
    """Build a short "from→to" label for ``mv`` played on ``prev``.

    Args:
        prev: Position before the move; used to classify the move.
        mv: The move to describe.

    Returns:
        The origin and destination square names joined by an arrow, followed
        by " (capture)" and/or " (castle)" when ``prev`` reports the move as
        such.
    """
    origin = square(mv.from_square)
    target = square(mv.to_square)

    # Collect the optional suffixes in the order they are displayed.
    suffixes = []
    if prev.is_capture(mv):
        suffixes.append(" (capture)")
    if prev.is_castling(mv):
        suffixes.append(" (castle)")

    return f"{origin}→{target}" + "".join(suffixes)
|
|
|
|
|
|
|
|
def infer_move(prev_fen: str, new_fen: str) -> chess.Move | None:
    """Find the legal move that turns position ``prev_fen`` into ``new_fen``.

    Positions are compared through their normalised EPD (piece placement,
    side to move, castling rights and en-passant square) rather than by raw
    FEN string equality. A FEN written by another serializer may spell the
    en-passant field or the move counters differently from python-chess
    (whose ``fen()``/``epd()`` by default only emit an en-passant square when
    a legal en-passant capture exists). Such a FEN still describes the same
    position and must still match.

    Args:
        prev_fen: FEN of the position before the move.
        new_fen: FEN of the position after the move.

    Returns:
        The matching legal move, or ``None`` when no legal move from
        ``prev_fen`` produces ``new_fen`` or ``new_fen`` cannot be parsed.

    Raises:
        ValueError: If ``prev_fen`` is not a valid FEN (raised by
            ``chess.Board``).
    """
    board = chess.Board(prev_fen)

    try:
        # Round-trip through python-chess so both sides of the comparison
        # use the same en-passant convention and ignore the move counters.
        target = chess.Board(new_fen).epd()
    except ValueError:
        # An unparseable target cannot correspond to any legal move.
        return None

    for mv in board.legal_moves:
        board.push(mv)
        matched = board.epd() == target
        board.pop()
        if matched:
            return mv
    return None
|
|
|
|
|
|
|
|
def count_attackers(board: chess.Board, color: bool, sq: int) -> int:
    """Return how many pieces of ``color`` attack square ``sq`` on ``board``."""
    attackers = board.attackers(color, sq)
    return len(attackers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _material(board: chess.Board, me: bool) -> float:
    """Return the material balance from the point of view of side ``me``.

    For every piece type in ``PIECE_V`` the difference between ``me``'s and
    the opponent's piece counts is multiplied by that type's value; the
    products are added up. Positive means ``me`` has more material.
    """
    them = not me
    balance = 0.0
    for piece_type, value in PIECE_V.items():
        mine = len(board.pieces(piece_type, me))
        theirs = len(board.pieces(piece_type, them))
        balance = balance + value * (mine - theirs)
    return balance
|
|
|
|
|
|
|
|
def _king_exposure(board: chess.Board, me: bool) -> float:
    """Count the squares next to ``me``'s king that the opponent attacks.

    Returns:
        The number of king-adjacent squares attacked by the other side, as a
        float, or ``10.0`` when ``me`` has no king on the board.
    """
    king_sq = board.king(me)
    if king_sq is None:
        # No king at all: report a value above any possible square count.
        return 10.0

    enemy = not me
    # Squares a king standing on king_sq would attack, i.e. its neighbours.
    ring = chess.SquareSet(chess.BB_KING_ATTACKS[king_sq])
    attacked = sum(1 for sq2 in ring if board.is_attacked_by(enemy, sq2))
    return float(attacked)
|
|
|
|
|
|
|
|
def _castled(board: chess.Board, me: bool) -> bool:
    """Report whether ``me``'s king stands on a typical castled square.

    This only inspects the king's current square (g1/c1 for White, g8/c8 for
    Black); it does not look at how the king got there.

    Returns:
        ``True`` when the king is on one of those squares, ``False``
        otherwise or when ``me`` has no king.
    """
    king_sq = board.king(me)
    if king_sq is None:
        return False

    if me == chess.WHITE:
        castled_squares = [chess.G1, chess.C1]
    else:
        castled_squares = [chess.G8, chess.C8]
    return king_sq in castled_squares
|
|
|
|
|
|
|
|
def _minor_development(board: chess.Board, me: bool) -> int:
    """Count ``me``'s knights and bishops that are off their starting squares.

    A knight counts when it is not on b1/g1 (b8/g8 for Black); a bishop
    counts when it is not on c1/f1 (c8/f8 for Black).
    """
    if me == chess.WHITE:
        knight_home = {chess.B1, chess.G1}
        bishop_home = {chess.C1, chess.F1}
    else:
        knight_home = {chess.B8, chess.G8}
        bishop_home = {chess.C8, chess.F8}

    # Each square holds at most one piece, so a set difference counts the
    # pieces standing anywhere other than a home square.
    knights_out = len(set(board.pieces(chess.KNIGHT, me)) - knight_home)
    bishops_out = len(set(board.pieces(chess.BISHOP, me)) - bishop_home)
    return knights_out + bishops_out
|
|
|
|
|
|
|
|
def _center_reach(board: chess.Board, me: bool) -> int:
    """Return how many of the ``CENTER4`` squares are attacked by ``me``."""
    return sum(1 for center_sq in CENTER4 if board.is_attacked_by(me, center_sq))
|
|
|
|
|
|
|
|
def _mobility(board: chess.Board, me: bool) -> int:
    """Return the number of legal moves ``me`` would have if it were to move.

    The count is taken on a copy of ``board`` whose side to move is forced to
    ``me``; the board passed in is not modified.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    legal = probe.legal_moves
    return legal.count()
|
|
|
|
|
|
|
|
def _initiative(board: chess.Board, me: bool) -> float:
    """Estimate how forcing ``me``'s available moves are.

    The legal moves are generated on a copy of ``board`` with the side to
    move forced to ``me``. The result blends the share of moves that give
    check (weight 0.65) with the share that are captures (weight 0.35).

    Args:
        board: Position to inspect; it is not modified.
        me: Side whose moves are examined.

    Returns:
        A value between 0.0 and 1.0, or 0.0 when ``me`` has no legal move.
    """
    probe = board.copy(stack=False)
    probe.turn = me

    total = 0
    checks = 0
    caps = 0
    for mv in probe.legal_moves:
        total += 1
        if probe.is_capture(mv):
            caps += 1
        # gives_check plays the move, tests for check and takes it back,
        # so no per-move board copy is needed.
        if probe.gives_check(mv):
            checks += 1

    if total == 0:
        return 0.0
    return 0.65 * (checks / total) + 0.35 * (caps / total)
|
|
|
|
|
|
|
|
def _tactical_tension(board: chess.Board, me: bool) -> float:
    """Return the number of legal captures available to ``me``, capped at 10.

    Moves are generated on a copy of ``board`` with the side to move forced
    to ``me``; the board passed in is not modified.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    capture_count = sum(1 for mv in probe.legal_moves if probe.is_capture(mv))
    return float(min(10, capture_count))
|
|
|
|
|
|
|
|
def eval_axes(board: chess.Board, me: bool) -> dict[str, float]:
    """Evaluate ``board`` for side ``me`` on every axis, each in [-1, 1].

    Each raw measurement is shifted and scaled, then clamped:

    * material: material balance / 10
    * king_safety: (4 - attacked king-adjacent squares) / 4
    * piece_activity: (developed minor pieces - 1.5) / 1.5
    * center_reach: (attacked center squares - 1.5) / 1.5
    * initiative: (initiative ratio - 0.12) / 0.25
    * tactical_tension: (capped capture count - 2) / 4

    Args:
        board: Position to evaluate; it is not modified.
        me: Side from whose point of view the axes are computed.

    Returns:
        A dict keyed by axis name holding the clamped values.
    """
    # Only measurements that feed an axis are computed; this function runs
    # for every candidate move, so unused work would be paid many times.
    mat = _material(board, me)
    kex = _king_exposure(board, me)
    dev = _minor_development(board, me)
    cen = _center_reach(board, me)
    ini = _initiative(board, me)
    ten = _tactical_tension(board, me)

    return {
        "material": clamp(mat / 10.0),
        "king_safety": clamp((4.0 - kex) / 4.0),
        "piece_activity": clamp((dev - 1.5) / 1.5),
        "center_reach": clamp((cen - 1.5) / 1.5),
        "initiative": clamp((ini - 0.12) / 0.25),
        "tactical_tension": clamp((ten - 2.0) / 4.0),
    }
|
|
|
|
|
|
|
|
def score(axes: dict[str, float]) -> float:
    """Combine axis values into one number using the weights in ``W``.

    Axes missing from ``axes`` contribute as 0.0.
    """
    weighted = (W[name] * axes.get(name, 0.0) for name in AXES)
    return sum(weighted)
|
|
|
|
|
|
|
|
def top_deltas(before: dict[str, float], after: dict[str, float], k: int = 3):
    """Rank the axes by how much their weighted value changed.

    Args:
        before: Axis values before the change; missing axes count as 0.0.
        after: Axis values after the change; missing axes count as 0.0.
        k: Number of leading rows to return separately.

    Returns:
        A pair ``(top_k, all_rows)``. Every row is a tuple
        ``(axis, delta, weighted_delta)``; rows are ordered by the absolute
        weighted delta, largest first.
    """
    deltas = [(name, after.get(name, 0.0) - before.get(name, 0.0)) for name in AXES]
    ranked = sorted(
        ((name, delta, delta * W[name]) for name, delta in deltas),
        key=lambda row: abs(row[2]),
        reverse=True,
    )
    return ranked[:k], ranked
|
|
|
|
|
|
|
|
def axis_pretty(a: str) -> str:
    """Return the display label for axis key ``a``.

    Unknown keys are returned unchanged.
    """
    labels = {
        "material": "Material",
        "king_safety": "King safety",
        "piece_activity": "Piece activity",
        "center_reach": "Center reach",
        "initiative": "Initiative",
        "tactical_tension": "Tactical tension",
    }
    if a in labels:
        return labels[a]
    return a
|
|
|
|
|
|
|
|
def axis_table_md(before: dict[str, float], after: dict[str, float]) -> str:
    """Render a Markdown table comparing axis values before and after.

    One row is produced per axis in ``AXES`` with the before value, the
    after value and their difference, each signed and with two decimals.
    Axes missing from either dict are shown as 0.0.
    """
    table = [
        "| Axis | Before | After | Δ |",
        "|---|---:|---:|---:|",
    ]
    for name in AXES:
        old = before.get(name, 0.0)
        new = after.get(name, 0.0)
        change = new - old
        table.append(f"| {axis_pretty(name)} | {old:+.2f} | {new:+.2f} | {change:+.2f} |")
    return "\n".join(table)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def risk_flags(prev: chess.Board, mv: chess.Move, after: chess.Board) -> list[str]:
    """Collect warning messages about the move ``mv``.

    Args:
        prev: Position before the move; its side to move is the mover.
        mv: The move that was played.
        after: Position after the move.

    Returns:
        Zero or more messages, in this order: the moved piece is attacked
        more often than it is defended, a queen moved within the first six
        full moves, and the mover's king exposure increased.
    """
    mover = prev.turn
    opponent = not mover
    warnings = []

    # 1) Is the piece now standing on the destination square under-defended?
    landed = after.piece_at(mv.to_square)
    if landed and landed.color == mover:
        attackers = count_attackers(after, opponent, mv.to_square)
        defenders = count_attackers(after, mover, mv.to_square)
        if attackers > defenders:
            warnings.append("Moved piece becomes a likely target (more attacked than defended).")

    # 2) Did a queen leave its square early in the game?
    moved = prev.piece_at(mv.from_square)
    if moved and moved.piece_type == chess.QUEEN and prev.fullmove_number <= 6:
        warnings.append("Early queen move: can invite tempo hits.")

    # 3) Are more squares around the mover's king attacked than before?
    exposure_after = _king_exposure(after, mover)
    exposure_before = _king_exposure(prev, mover)
    if exposure_after > exposure_before:
        warnings.append("King safety slightly worsens (more pressure near your king).")

    return warnings
|
|
|
|
|
|
|
|
def advantage_label(axes: dict[str, float], me: bool) -> str:
    """Translate the weighted score of ``axes`` into a verbal assessment.

    Args:
        axes: Axis values to score.
        me: Accepted for signature compatibility; not used by the body.

    Returns:
        "Ahead" (score >= 0.55), "Slightly better" (>= 0.20),
        "Roughly equal" (> -0.20), "Slightly worse" (> -0.55) or "Behind".
    """
    total = score(axes)
    if total >= 0.55:
        label = "Ahead"
    elif total >= 0.20:
        label = "Slightly better"
    elif total > -0.20:
        label = "Roughly equal"
    elif total > -0.55:
        label = "Slightly worse"
    else:
        label = "Behind"
    return label
|
|
|
|
|
|
|
|
def short_verdict(before_axes: dict[str, float], after_axes: dict[str, float], me: bool) -> str:
    """Summarise in one sentence how the assessment changed.

    Returns:
        A "State: ..." sentence when the label is the same before and after,
        otherwise a "State shift: before → after." sentence.
    """
    label_before = advantage_label(before_axes, me)
    label_after = advantage_label(after_axes, me)

    if label_before != label_after:
        return f"State shift: {label_before} → {label_after}."
    return f"State: {label_after} (no major swing)."
|
|
|
|
|
|
|
|
def causal_sentence(top3, flags: list[str]) -> str:
    """Explain a move in one sentence from its largest axis changes.

    Args:
        top3: Rows ``(axis, delta, weighted_delta)`` as produced by
            ``top_deltas``.
        flags: Risk messages; only the first one is mentioned.

    Returns:
        A sentence naming up to two improved axes ("structure" when none
        improved), up to two worsened axes as the cost when there are any,
        and the first risk flag when ``flags`` is not empty.
    """
    improved = [row for row in top3 if row[2] > 0]
    worsened = [row for row in top3 if row[2] < 0]

    if improved:
        because = ", ".join(axis_pretty(row[0]) for row in improved[:2])
    else:
        because = "structure"

    if worsened:
        cost = ", ".join(axis_pretty(row[0]) for row in worsened[:2])
    else:
        cost = ""

    if cost:
        sentence = f"Because {because} improves, at the cost of {cost}."
    else:
        sentence = f"Because {because} improves with limited downside."

    if flags:
        sentence = sentence + f" Boundary risk: {flags[0]}"
    return sentence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def candidate_score(prev: chess.Board, mv: chess.Move) -> tuple[float, dict[str, float], list[str]]:
    """Score the candidate move ``mv`` from the point of view of the mover.

    The move is played on a copy of ``prev`` and the resulting position is
    evaluated for the side that moved. Small bonuses are added for captures
    (0.06), checks (0.06) and castling (0.05); 0.04 is subtracted per risk
    flag.

    Args:
        prev: Position before the move; it is not modified.
        mv: Candidate move to score.

    Returns:
        A tuple ``(score, after_axes, flags)`` with the adjusted score, the
        axis values after the move and the risk messages for the move.
    """
    me = prev.turn

    # Only the position after the move is needed for the score, so the
    # position before the move is not evaluated here.
    b2 = prev.copy(stack=False)
    b2.push(mv)
    after_axes = eval_axes(b2, me)
    base = score(after_axes)

    bonus = 0.0
    if prev.is_capture(mv):
        bonus += 0.06
    if b2.is_check():
        bonus += 0.06
    if prev.is_castling(mv):
        bonus += 0.05

    flags = risk_flags(prev, mv, b2)
    penalty = 0.04 * len(flags)

    return base + bonus - penalty, after_axes, flags
|
|
|
|
|
|
|
|
def top3_candidates(prev: chess.Board) -> list[tuple[chess.Move, float, dict[str, float], list[str]]]:
    """Return the three best-scoring legal moves in position ``prev``.

    Returns:
        Up to three tuples ``(move, score, after_axes, flags)`` ordered by
        score, highest first. Fewer are returned when fewer legal moves
        exist.
    """
    scored = [(mv, *candidate_score(prev, mv)) for mv in prev.legal_moves]
    scored.sort(key=lambda row: row[1], reverse=True)
    return scored[:3]
|
|
|
|
|
|
|
|
def format_candidate(prev: chess.Board, mv: chess.Move, before_axes: dict[str, float], s: float, after_axes: dict[str, float], flags: list[str]) -> str:
    """Render one candidate move as a Markdown bullet with reason and risk.

    Args:
        prev: Position before the move; used for SAN and the move label.
        mv: The candidate move.
        before_axes: Axis values before the move.
        s: Score of the candidate, shown with two decimals.
        after_axes: Axis values after the move.
        flags: Risk messages for the move.

    Returns:
        Three lines: the move, a "Why" line naming the most improved of the
        top three axis changes ("Structure" when none improved) and a "Risk"
        line. The risk line shows the first flag when there is one;
        otherwise it names the most worsened of the top three axis changes,
        or "Low risk" when none worsened.
    """
    top3, _ = top_deltas(before_axes, after_axes, k=3)
    gains = [row for row in top3 if row[2] > 0]
    losses = [row for row in top3 if row[2] < 0]

    pro = axis_pretty(gains[0][0]) if gains else "Structure"

    # Single decision for the risk text: a concrete flag wins over the
    # axis-based fallback.
    if flags:
        risk = flags[0]
    else:
        risk = (axis_pretty(losses[0][0]) if losses else "Low risk") + "."

    return (
        f"- **{prev.san(mv)}** ({uci_to_label(prev, mv)})\n"
        f" - Why: **{pro}** improves (score≈{s:+.2f}).\n"
        f" - Risk: {risk}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_hds_trace(prev_fen: str, new_fen: str) -> str:
    """Build the Markdown report for the move leading from one FEN to another.

    Args:
        prev_fen: FEN of the position before the move.
        new_fen: FEN of the position after the move.

    Returns:
        When no legal move connects the two positions: a warning followed by
        the axis table of the previous position. Otherwise: a verdict line,
        the move, a one-sentence explanation, the before/after axis table
        and the top three candidate moves from the previous position.
    """
    before_board = chess.Board(prev_fen)
    played = infer_move(prev_fen, new_fen)

    # All axes are evaluated for the side that was to move in prev_fen.
    me = before_board.turn
    before_axes = eval_axes(before_board, me)

    if played is None:
        warning = (
            "⚠️ Could not infer the move (board might be edited, not moved).\n\n"
            "HDS Trace:\n\n"
        )
        return warning + axis_table_md(before_axes, before_axes)

    after_board = chess.Board(new_fen)
    after_axes = eval_axes(after_board, me)

    top3, _ = top_deltas(before_axes, after_axes, k=3)
    flags = risk_flags(before_board, played, after_board)

    verdict = short_verdict(before_axes, after_axes, me)
    mover = side_name(me)
    move_line = (
        f"Move: **{mover} plays {before_board.san(played)}** "
        f"({uci_to_label(before_board, played)})."
    )

    sections = [
        "### HDS Trace (deterministic)",
        f"- {verdict}",
        f"- {move_line}",
        f"- {causal_sentence(top3, flags)}",
        "\n**State vector (axes)**",
        axis_table_md(before_axes, after_axes),
        "\n### Top-3 candidate moves (with pros / risks)",
    ]

    # Alternatives are taken from the position before the move was played.
    for cand_move, cand_score, cand_axes, cand_flags in top3_candidates(before_board):
        sections.append(
            format_candidate(before_board, cand_move, before_axes, cand_score, cand_axes, cand_flags)
        )

    return "\n".join(sections)
|
|
|
|
|
|
|
|
def on_move(fen_after: str, prev_fen: str):
    """Handle a board move: build the report and return the new UI values.

    Args:
        fen_after: FEN of the board after the move.
        prev_fen: FEN stored from the previous call; the starting position
            is used when it is empty.

    Returns:
        A tuple ``(fen_after, fen_after, markdown)`` matching the handler's
        three outputs: board value, stored previous FEN and report.
    """
    baseline = prev_fen if prev_fen else START_FEN
    report = build_hds_trace(baseline, fen_after)
    return fen_after, fen_after, report
|
|
|
|
|
|
|
|
def on_reset():
    """Return the UI values for a fresh game.

    Returns:
        A tuple ``(START_FEN, START_FEN, markdown)``: the board value, the
        stored previous FEN and the placeholder report.
    """
    placeholder = "".join(
        [
            "### HDS Trace (deterministic)\n",
            "- State: Roughly equal.\n",
            "- Move: (none yet)\n",
            "- Goal: build a safe structure, then increase pressure.\n",
        ]
    )
    return START_FEN, START_FEN, placeholder
|
|
|
|
|
|
|
|
# Extra CSS passed to the Blocks app: hides the `footer` element and caps the
# width of `.gradio-container` at 980px.
CSS = """
footer { display: none !important; }
.gradio-container { max-width: 980px !important; }
"""

# Take the initial UI values from the reset handler so that the first page
# load and the Reset button always show exactly the same state.
_INITIAL_BOARD_FEN, _INITIAL_PREV_FEN, _INITIAL_TRACE_MD = on_reset()

with gr.Blocks(css=CSS, title="HDS Chess Trainer") as demo:
    # FEN of the position before the most recent move; on_move compares the
    # board's new FEN against it.
    prev_state = gr.State(_INITIAL_PREV_FEN)

    board = Chessboard(
        value=_INITIAL_BOARD_FEN,
        game_mode=True,
        orientation="white",
        show_label=False,
        label=None,
    )

    # Markdown area that displays the report returned by the handlers.
    log = gr.Markdown(value=_INITIAL_TRACE_MD)

    reset_btn = gr.Button("Reset")

    # Both handlers return (board value, previous FEN, report markdown).
    board.move(on_move, inputs=[board, prev_state], outputs=[board, prev_state, log])
    reset_btn.click(on_reset, inputs=None, outputs=[board, prev_state, log])

if __name__ == "__main__":
    demo.launch()
|
|
|