# gatchimuchio's picture
# Update app.py
# 2c0ecb8 verified
from __future__ import annotations
import chess
import gradio as gr
from gradio_chessboard import Chessboard
# Starting-position FEN; used for the initial board, the initial state and resets.
START_FEN = chess.STARTING_FEN
# The four central squares scored by the center_reach axis (see _center_reach).
CENTER4 = [chess.D4, chess.E4, chess.D5, chess.E5]
# Piece values in pawn units for the material axis (see _material).
# The king is valued at 0.0, so it never contributes to the material balance.
PIECE_V = {
    chess.PAWN: 1.0,
    chess.KNIGHT: 3.2,
    chess.BISHOP: 3.3,
    chess.ROOK: 5.0,
    chess.QUEEN: 9.0,
    chess.KING: 0.0,
}
# HDS-ish evaluation "axes" (kept human-readable), in display order.
AXES = [
    "material",
    "king_safety",
    "piece_activity",
    "center_reach",
    "initiative",
    "tactical_tension",
]

# Weight of each axis when the axes are combined into one score (see score()).
W = dict(
    material=1.8,
    king_safety=1.4,
    piece_activity=1.0,
    center_reach=1.0,
    initiative=0.9,
    tactical_tension=1.1,
)
# -----------------------------
# Helpers
# -----------------------------
def clamp(x: float, lo: float = -1.0, hi: float = 1.0) -> float:
    """Limit ``x`` (converted with ``float()``) to the interval [lo, hi].

    Mirrors ``max(lo, min(hi, float(x)))`` exactly: the upper bound is applied
    first, and a bound that wins is returned as-is (so integer bounds stay
    integers).  A NaN input ends up at ``hi``.
    """
    value = float(x)
    capped = value if value < hi else hi
    return capped if capped > lo else lo
def side_name(color: bool) -> str:
    """Return "White" for ``chess.WHITE`` and "Black" for anything else."""
    if color == chess.WHITE:
        return "White"
    return "Black"
def square(sq: int) -> str:
    """Return the algebraic name of square index ``sq`` (e.g. ``"e4"``)."""
    name = chess.square_name(sq)
    return name
def uci_to_label(prev: chess.Board, mv: chess.Move) -> str:
    """Return a short label for ``mv``: its UCI text plus capture/castle tags.

    ``prev`` is the position *before* the move; it is needed to tell whether
    the move is a capture or castling.  The UCI text keeps the promotion
    suffix (e.g. ``e7e8q`` vs ``e7e8n``), so different promotions from the same
    squares no longer get identical labels.
    """
    tags = []
    if prev.is_capture(mv):
        tags.append(" (capture)")
    if prev.is_castling(mv):
        tags.append(" (castle)")
    return mv.uci() + "".join(tags)
def infer_move(prev_fen: str, new_fen: str) -> chess.Move | None:
    """Return the legal move that leads from ``prev_fen`` to ``new_fen``, or None.

    Only the first two FEN fields (piece placement and side to move) are
    compared.  Castling rights, the en-passant square and the move clocks are
    ignored because FEN writers disagree on them.  For example, python-chess's
    ``Board.fen()`` writes an en-passant square only when such a capture is
    legal, while the PGN/FEN standard records it after every double pawn
    push.  An exact string comparison would then reject an ordinary move.
    Two different legal moves from the same position never produce the same
    piece placement, so this still identifies the move uniquely.

    Returns None when ``new_fen`` is empty/None or no legal move reproduces it
    (e.g. the board was edited rather than played).
    """
    if not new_fen:
        return None
    target = new_fen.split()[:2]
    board = chess.Board(prev_fen)
    for mv in board.legal_moves:
        child = board.copy(stack=False)
        child.push(mv)
        if child.fen().split()[:2] == target:
            return mv
    return None
def count_attackers(board: chess.Board, color: bool, sq: int) -> int:
    """Return how many pieces of ``color`` attack square ``sq`` on ``board``."""
    attacking = board.attackers(color, sq)
    return len(attacking)
# -----------------------------
# Axis components (deterministic, engine-less)
# -----------------------------
def _material(board: chess.Board, me: bool) -> float:
    """Raw material balance for ``me`` in PIECE_V units (positive = ``me`` ahead).

    Not normalised; eval_axes() scales it.
    """
    balance = 0.0
    for piece_type, value in PIECE_V.items():
        mine = len(board.pieces(piece_type, me))
        theirs = len(board.pieces(piece_type, not me))
        balance += value * (mine - theirs)
    return balance
def _king_exposure(board: chess.Board, me: bool) -> float:
    """Count the squares next to ``me``'s king that the opponent attacks.

    Higher means more danger.  Each neighbouring square counts at most once,
    so the value is 0..8 (fewer neighbours on the edge).  Returns 10.0 when
    ``me`` has no king on the board.
    """
    king_sq = board.king(me)
    if king_sq is None:
        return 10.0
    enemy = not me
    neighbours = chess.SquareSet(chess.BB_KING_ATTACKS[king_sq])
    pressured = sum(1 for nsq in neighbours if board.is_attacked_by(enemy, nsq))
    return float(pressured)
def _castled(board: chess.Board, me: bool) -> bool:
    """Heuristic: is ``me``'s king on its g- or c-file back-rank square?

    Only the king's current square is checked, so a king that walked there
    also counts.  Returns False when ``me`` has no king.
    """
    king_sq = board.king(me)
    if king_sq is None:
        return False
    if me == chess.WHITE:
        castled_squares = [chess.G1, chess.C1]
    else:
        castled_squares = [chess.G8, chess.C8]
    return king_sq in castled_squares
def _minor_development(board: chess.Board, me: bool) -> int:
    """Count ``me``'s knights and bishops that are not on their starting squares.

    This is the "awake pieces" measure; it is normally 0..4.
    """
    if me == chess.WHITE:
        knight_home, bishop_home = {chess.B1, chess.G1}, {chess.C1, chess.F1}
    else:
        knight_home, bishop_home = {chess.B8, chess.G8}, {chess.C8, chess.F8}
    awake = 0
    for piece_type, home in ((chess.KNIGHT, knight_home), (chess.BISHOP, bishop_home)):
        awake += sum(1 for psq in board.pieces(piece_type, me) if psq not in home)
    return awake
def _center_reach(board: chess.Board, me: bool) -> int:
    """Return how many of the CENTER4 squares ``me`` attacks (0..4)."""
    return sum(1 for center_sq in CENTER4 if board.is_attacked_by(me, center_sq))
def _mobility(board: chess.Board, me: bool) -> int:
    """Count ``me``'s legal moves, as if it were ``me``'s turn.

    The turn is forced on a stack-less copy, so ``board`` is not modified.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    return sum(1 for _ in probe.legal_moves)
def _initiative(board: chess.Board, me: bool) -> float:
    """Crude initiative measure for ``me``: share of checking and capturing moves.

    The turn is forced to ``me`` on a stack-less copy (``board`` is untouched).
    Over all of ``me``'s legal moves the result is::

        0.65 * (checking moves / total) + 0.35 * (captures / total)

    It therefore lies in 0..1.  Returns 0.0 when ``me`` has no legal moves.
    """
    probe = board.copy(stack=False)
    probe.turn = me
    total = checks = caps = 0
    for mv in probe.legal_moves:
        total += 1
        if probe.is_capture(mv):
            caps += 1
        # gives_check() answers "does this move give check?" without building
        # a full board copy for every legal move.
        if probe.gives_check(mv):
            checks += 1
    if total == 0:
        return 0.0
    return 0.65 * (checks / total) + 0.35 * (caps / total)
def _tactical_tension(board: chess.Board, me: bool) -> float:
    """Return how many captures ``me`` could make right now, capped at 10.

    The turn is forced to ``me`` on a stack-less copy; ``board`` is untouched.
    Only available captures are counted (no hanging-piece detection).
    """
    probe = board.copy(stack=False)
    probe.turn = me
    capture_count = sum(1 for mv in probe.legal_moves if probe.is_capture(mv))
    return float(min(10, capture_count))
def eval_axes(board: chess.Board, me: bool) -> dict[str, float]:
    """Evaluate ``board`` from ``me``'s point of view on every axis in AXES.

    Each raw component is mapped linearly and clamped to [-1, 1].  The mapping
    is deterministic, so the same position always gives the same vector.

    Returns a dict with the same keys, in the same order, as AXES.
    """
    mat = _material(board, me)           # raw pawn units
    kex = _king_exposure(board, me)      # attacked king neighbours; lower is better
    dev = _minor_development(board, me)  # normally 0..4
    cen = _center_reach(board, me)       # 0..4
    ini = _initiative(board, me)         # 0..1
    ten = _tactical_tension(board, me)   # 0..10
    # _mobility() used to be computed here, but its value fed no axis; it was
    # a wasted legal-move generation on every evaluation.
    return {
        "material": clamp(mat / 10.0),                 # +/-10 pawns saturates
        "king_safety": clamp((4.0 - kex) / 4.0),       # 0 attacked -> +1, 8 -> -1
        "piece_activity": clamp((dev - 1.5) / 1.5),    # 0 -> -1, 3+ -> +1
        "center_reach": clamp((cen - 1.5) / 1.5),      # 0 -> -1, 3+ -> +1
        # NOTE: ini >= 0, so this axis bottoms out at -0.48.
        "initiative": clamp((ini - 0.12) / 0.25),
        "tactical_tension": clamp((ten - 2.0) / 4.0),  # 0 -> -0.5, 6+ -> +1
    }
def score(axes: dict[str, float]) -> float:
    """Combine axis values into one number: sum of W[axis] * value over AXES.

    Axes missing from ``axes`` count as 0.0.
    """
    weighted = (W[name] * axes.get(name, 0.0) for name in AXES)
    return sum(weighted)
def top_deltas(
    before: dict[str, float], after: dict[str, float], k: int = 3
) -> tuple[list[tuple[str, float, float]], list[tuple[str, float, float]]]:
    """Rank per-axis changes by the size of their weighted delta.

    Each row is ``(axis, delta, delta * W[axis])``; a missing axis counts as 0.0.
    Returns ``(top_k_rows, all_rows)``.  Both lists are sorted by
    ``abs(weighted delta)``, largest first, with ties kept in AXES order.
    """
    changes = []
    for axis in AXES:
        delta = after.get(axis, 0.0) - before.get(axis, 0.0)
        changes.append((axis, delta, delta * W[axis]))
    ranked = sorted(changes, key=lambda row: abs(row[2]), reverse=True)
    return ranked[:k], ranked
def axis_pretty(a: str) -> str:
    """Return the display label for axis key ``a`` (e.g. "king_safety" -> "King safety").

    Unknown keys are returned unchanged.
    """
    known_axes = frozenset(
        ("material", "king_safety", "piece_activity",
         "center_reach", "initiative", "tactical_tension")
    )
    if a not in known_axes:
        return a
    return a.replace("_", " ").capitalize()
def axis_table_md(before: dict[str, float], after: dict[str, float]) -> str:
    """Render a Markdown table of each axis: before, after and their difference.

    Values are signed with two decimals; a missing axis counts as 0.0.
    """
    header = ["| Axis | Before | After | Δ |", "|---|---:|---:|---:|"]
    body = []
    for axis in AXES:
        old = before.get(axis, 0.0)
        new = after.get(axis, 0.0)
        body.append(f"| {axis_pretty(axis)} | {old:+.2f} | {new:+.2f} | {new - old:+.2f} |")
    return "\n".join(header + body)
# -----------------------------
# Risk / explanation (HDS-ish causal trace)
# -----------------------------
def risk_flags(prev: chess.Board, mv: chess.Move, after: chess.Board) -> list[str]:
    """Return human-readable warnings about ``mv`` (played in ``prev``, giving ``after``).

    The possible warnings, always in this order, are:
      1. the moved piece ends up attacked more times than it is defended;
      2. a queen move while the fullmove number is still <= 6;
      3. more squares around the mover's king are attacked than before.
    """
    mover = prev.turn
    opponent = not mover

    landed = after.piece_at(mv.to_square)
    target_exposed = bool(landed) and landed.color == mover and (
        count_attackers(after, opponent, mv.to_square)
        > count_attackers(after, mover, mv.to_square)
    )
    origin = prev.piece_at(mv.from_square)
    early_queen = (
        bool(origin)
        and origin.piece_type == chess.QUEEN
        and prev.fullmove_number <= 6
    )
    king_worse = _king_exposure(after, mover) > _king_exposure(prev, mover)

    checks = (
        (target_exposed, "Moved piece becomes a likely target (more attacked than defended)."),
        (early_queen, "Early queen move: can invite tempo hits."),
        (king_worse, "King safety slightly worsens (more pressure near your king)."),
    )
    return [message for triggered, message in checks if triggered]
def advantage_label(axes: dict[str, float], me: bool) -> str:
    """Map the combined score of ``axes`` to a coarse verbal label.

    Bands:
      - >= 0.55: Ahead
      - >= 0.20: Slightly better
      - > -0.20: Roughly equal
      - > -0.55: Slightly worse
      - otherwise: Behind

    ``me`` is not used here; callers in this file pass axes that were already
    computed for ``me``.
    """
    value = score(axes)
    bands = (
        (lambda v: v >= 0.55, "Ahead"),
        (lambda v: v >= 0.20, "Slightly better"),
        (lambda v: v > -0.20, "Roughly equal"),
        (lambda v: v > -0.55, "Slightly worse"),
    )
    return next((label for in_band, label in bands if in_band(value)), "Behind")
def short_verdict(before_axes: dict[str, float], after_axes: dict[str, float], me: bool) -> str:
    """Describe in one line how the overall advantage label changed.

    Returns "State: <label> (no major swing)." when the label is unchanged,
    otherwise "State shift: <before> → <after>.".
    """
    before_label = advantage_label(before_axes, me)
    after_label = advantage_label(after_axes, me)
    if before_label == after_label:
        return f"State: {after_label} (no major swing)."
    # The two labels need a separator; without one they ran together
    # (e.g. "Slightly betterAhead").
    return f"State shift: {before_label} → {after_label}."
def causal_sentence(top3, flags: list[str]) -> str:
    """Build a one-line "because ... at the cost of ..." explanation.

    ``top3`` holds ``(axis, delta, weighted_delta)`` rows as produced by
    top_deltas().  Up to two improving axes (weighted delta > 0) and two
    worsening axes (< 0) are named, in the given order.  "structure" is the
    fallback when nothing improved.  If ``flags`` is non-empty, its first entry
    is appended as a boundary risk.
    """
    gains = [row[0] for row in top3 if row[2] > 0][:2]
    losses = [row[0] for row in top3 if row[2] < 0][:2]
    because = ", ".join(axis_pretty(a) for a in gains) if gains else "structure"
    cost = ", ".join(axis_pretty(a) for a in losses) if losses else ""
    tail = f", at the cost of {cost}." if cost else " with limited downside."
    sentence = f"Because {because} improves{tail}"
    if flags:
        # Only the first risk is surfaced, as a hard boundary.
        sentence = f"{sentence} Boundary risk: {flags[0]}"
    return sentence
# -----------------------------
# Candidate moves (Top-3)
# -----------------------------
def candidate_score(prev: chess.Board, mv: chess.Move) -> tuple[float, dict[str, float], list[str]]:
    """Score candidate ``mv`` for the side to move in ``prev``.

    The score is the combined axis score of the resulting position (from the
    mover's point of view) with these small deterministic adjustments:
      - +0.06 for a capture;
      - +0.06 for a check;
      - +0.05 for castling;
      - -0.04 per risk flag.

    Returns ``(score, after_axes, flags)``.
    """
    me = prev.turn
    # The pre-move axes used to be evaluated here for every candidate but were
    # never used; skipping them halves the evaluation work per legal move.
    after = prev.copy(stack=False)
    after.push(mv)
    after_axes = eval_axes(after, me)
    base = score(after_axes)

    bonus = 0.0
    if prev.is_capture(mv):
        bonus += 0.06
    if after.is_check():
        bonus += 0.06
    if prev.is_castling(mv):
        bonus += 0.05

    flags = risk_flags(prev, mv, after)
    penalty = 0.04 * len(flags)
    return base + bonus - penalty, after_axes, flags
def top3_candidates(prev: chess.Board) -> list[tuple[chess.Move, float, dict[str, float], list[str]]]:
    """Return the three best legal moves in ``prev`` according to candidate_score().

    Each entry is ``(move, score, after_axes, flags)``, highest score first.
    Scores are always from the mover's point of view.  Ties keep legal-move
    generation order because the sort is stable.
    """
    scored = [(mv, *candidate_score(prev, mv)) for mv in prev.legal_moves]
    return sorted(scored, key=lambda entry: entry[1], reverse=True)[:3]
def format_candidate(
    prev: chess.Board,
    mv: chess.Move,
    before_axes: dict[str, float],
    s: float,
    after_axes: dict[str, float],
    flags: list[str],
) -> str:
    """Render one candidate move as a Markdown bullet with two nested sub-bullets.

    The output looks like::

        - **<SAN>** (<label>)
          - Why: **<best improving axis>** improves (score≈<s>).
          - Risk: <first flag | worst-moving axis | "Low risk">

    ``prev`` is the position before the move and ``s`` is the candidate's
    score.  ``before_axes``/``after_axes`` are the axis values before and after
    the move; their top-3 weighted deltas pick the "Why" and fallback "Risk".
    """
    top3 = top_deltas(before_axes, after_axes, k=3)[0]
    gains = [row for row in top3 if row[2] > 0]
    losses = [row for row in top3 if row[2] < 0]
    pro = axis_pretty(gains[0][0]) if gains else "Structure"
    if flags:
        risk = flags[0]
    else:
        # No explicit risk flag: name the worst-moving axis, if any.
        risk = (axis_pretty(losses[0][0]) if losses else "Low risk") + "."
    # Sub-bullets need two spaces of indent to nest under "- " in Markdown;
    # with only one space they render as sibling bullets.
    return (
        f"- **{prev.san(mv)}** ({uci_to_label(prev, mv)})\n"
        f"  - Why: **{pro}** improves (score≈{s:+.2f}).\n"
        f"  - Risk: {risk}"
    )
# -----------------------------
# Main move handler
# -----------------------------
def build_hds_trace(prev_fen: str, new_fen: str) -> str:
    """Build the Markdown "HDS trace" that explains the move ``prev_fen`` -> ``new_fen``.

    Everything is evaluated from the point of view of the side that moved
    (the side to move in ``prev_fen``).  The trace contains:
      - a verdict line, the move line and a causal sentence;
      - a before/after axis table;
      - the top-3 candidate moves from the pre-move position.

    If no legal move from ``prev_fen`` reproduces ``new_fen`` (e.g. the board
    was edited rather than played), a warning plus an axis snapshot of
    ``prev_fen`` is returned instead.
    """
    prev = chess.Board(prev_fen)
    me = prev.turn
    before_axes = eval_axes(prev, me)

    mv = infer_move(prev_fen, new_fen)
    if mv is None:
        return (
            "⚠️ Could not infer the move (board might be edited, not moved).\n\n"
            "HDS Trace:\n\n"
            + axis_table_md(before_axes, before_axes)
        )

    # Replay the inferred move instead of re-parsing new_fen.  This gives the
    # position that follows from the move itself, independent of how the board
    # widget formats the non-placement FEN fields.
    after = prev.copy(stack=False)
    after.push(mv)
    after_axes = eval_axes(after, me)

    top3, _all_rows = top_deltas(before_axes, after_axes, k=3)
    flags = risk_flags(prev, mv, after)
    verdict = short_verdict(before_axes, after_axes, me)
    move_line = f"Move: **{side_name(me)} plays {prev.san(mv)}** ({uci_to_label(prev, mv)})."

    trace = [
        "### HDS Trace (deterministic)",
        f"- {verdict}",
        f"- {move_line}",
        f"- {causal_sentence(top3, flags)}",
        # State vector table (before/after).
        "\n**State vector (axes)**",
        axis_table_md(before_axes, after_axes),
        # Candidates are computed from the position *before* the played move.
        "\n### Top-3 candidate moves (with pros / risks)",
    ]
    for cmv, cs, cax, cfl in top3_candidates(prev):
        trace.append(format_candidate(prev, cmv, before_axes, cs, cax, cfl))
    return "\n".join(trace)
def on_move(fen_after: str, prev_fen: str):
    """Gradio handler for a move on the board.

    An empty or missing ``prev_fen`` falls back to START_FEN.  Returns
    ``(board value, new prev_state, trace Markdown)``: the post-move FEN goes
    back to the board and becomes the reference position for the next move.
    """
    reference_fen = prev_fen or START_FEN
    markdown = build_hds_trace(reference_fen, fen_after)
    return fen_after, fen_after, markdown
def on_reset():
    """Gradio handler for Reset: the start position for board and state, plus the intro trace."""
    intro = "\n".join([
        "### HDS Trace (deterministic)",
        "- State: Roughly equal.",
        "- Move: (none yet)",
        "- Goal: build a safe structure, then increase pressure.",
        "",
    ])
    return START_FEN, START_FEN, intro
# Page CSS: hide the Gradio footer and cap the container width at 980px.
CSS = (
    "\n"
    "footer { display: none !important; }\n"
    ".gradio-container { max-width: 980px !important; }\n"
)
# UI: the chessboard, a Markdown trace panel and a Reset button, in that order.
with gr.Blocks(css=CSS, title="HDS Chess Trainer") as demo:
    # FEN of the position before the next move; on_move compares against it.
    prev_state = gr.State(START_FEN)
    board = Chessboard(
        value=START_FEN,
        game_mode=True,
        orientation="white",  # fixed as requested
        show_label=False,
        label=None,
    )
    # Initial trace text (identical to the text on_reset returns).
    log = gr.Markdown(
        value=(
            "### HDS Trace (deterministic)\n"
            "- State: Roughly equal.\n"
            "- Move: (none yet)\n"
            "- Goal: build a safe structure, then increase pressure.\n"
        )
    )
    reset_btn = gr.Button("Reset")
    # A board move updates the board, the stored pre-move FEN and the trace.
    board.move(on_move, inputs=[board, prev_state], outputs=[board, prev_state, log])
    # Reset restores the start position in board and state and resets the trace.
    reset_btn.click(on_reset, inputs=None, outputs=[board, prev_state, log])
if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script.
    demo.launch()