|
|
import base64 |
|
|
import json |
|
|
import os |
|
|
from typing import Dict, List, Optional, Any |
|
|
|
|
|
import chess |
|
|
import chess.engine |
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain.schema import SystemMessage, HumanMessage |
|
|
from langchain.tools import tool |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langchain_openai import ChatOpenAI |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from utils.prompt_manager import prompt_mgmt |
|
|
|
|
|
|
|
|
def encode_image_to_base64(image_path: str) -> str:
    """Return the raw bytes of ``image_path`` as base64-encoded text.

    The string is embedded in ``data:`` URLs of the multimodal chat
    messages built by ``ChessVisionAnalyzer``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
|
|
|
|
|
|
|
class ChessPiecePosition(BaseModel):
    """One detected piece and the square it stands on.

    Both fields are required. Values come straight from the vision LLMs'
    JSON output, so their format is not validated beyond being strings.
    """

    # Square name such as 'e4' or 'a1' (file letter followed by rank digit).
    square: str = Field(default=..., description="Chess square notation (e.g., 'e4', 'a1')")
    # Color and piece type joined by '_', e.g. 'white_king' or 'black_queen'.
    piece: str = Field(default=..., description="Piece type and color (e.g., 'white_king', 'black_queen')")
|
|
|
|
|
|
|
|
class ChessBoardAnalysis(BaseModel):
    """Piece placements detected on a single board image.

    Positions are stored as a plain list. When two entries name the same
    square, the later entry wins when the board is rendered by ``to_fen``.
    """

    positions: List[ChessPiecePosition] = Field(..., description="List of all piece positions on the board")

    def add_positions(self, positions: List[ChessPiecePosition]) -> None:
        """Append ``positions`` to this analysis (no de-duplication)."""
        # list.extend copies its argument first, so extending with
        # self.positions itself is safe (an append loop would never end).
        self.positions.extend(positions)

    def merge_with(self, other: Optional['ChessBoardAnalysis']) -> None:
        """Merge another analysis into this one, letting ``other`` win conflicts.

        Entries of this analysis whose square also appears in ``other`` are
        dropped, then all of ``other``'s entries are appended. Merging
        ``None`` (e.g. a failed LLM parse) is a no-op.
        """
        if other is None:
            return
        # Snapshot first: ``other`` may be ``self`` and we mutate in place.
        incoming = list(other.positions)
        overridden = {self._square_key(position.square) for position in incoming}
        kept = [position for position in self.positions
                if self._square_key(position.square) not in overridden]
        self.positions[:] = kept + incoming

    def to_fen(self, active_color: str) -> str:
        """Convert the analysis to a (simplified) FEN string.

        Only the piece placement and side to move come from the analysis:
        castling rights and en-passant target are always "-", the halfmove
        clock is 0 and the fullmove number is 1.

        :param active_color: "white" or "w" (case-insensitive) for white to
            move; any other value means black to move.
        :return: a six-field FEN string.
        """
        # board[0] is rank 8 and board[7] is rank 1, the order FEN lists ranks in.
        board = [['' for _ in range(8)] for _ in range(8)]
        for position in self.positions:
            coords = self._square_to_indices(position.square)
            if coords is None:
                # Malformed or off-board square reported by the LLM: ignore it.
                continue
            rank_idx, file_idx = coords
            board[rank_idx][file_idx] = self._piece_to_char(position.piece)

        piece_placement = '/'.join(self._row_to_fen(row) for row in board)
        active_color_char = 'w' if active_color.strip().lower() in ('white', 'w') else 'b'

        # These fields cannot be inferred from a single image.
        castling_rights = "-"
        en_passant = "-"
        halfmove_clock = 0
        fullmove_number = 1

        return ' '.join([
            piece_placement,
            active_color_char,
            castling_rights,
            en_passant,
            str(halfmove_clock),
            str(fullmove_number),
        ])

    @staticmethod
    def _square_key(square: str) -> str:
        """Normalize a square name for comparison (' E4' -> 'e4')."""
        return square.strip().lower()

    @staticmethod
    def _square_to_indices(square: str) -> Optional[tuple]:
        """Return (rank_idx, file_idx) into the FEN grid, or None if invalid."""
        key = square.strip().lower()
        if len(key) != 2 or key[0] not in 'abcdefgh' or key[1] not in '12345678':
            return None
        return 8 - int(key[1]), ord(key[0]) - ord('a')

    @staticmethod
    def _row_to_fen(row: List[str]) -> str:
        """Encode one rank: piece letters, with runs of empty squares as digits."""
        parts = []
        empty_run = 0
        for cell in row:
            if not cell:
                empty_run += 1
                continue
            if empty_run:
                parts.append(str(empty_run))
                empty_run = 0
            parts.append(cell)
        if empty_run:
            parts.append(str(empty_run))
        return ''.join(parts)

    def _piece_to_char(self, piece: str) -> str:
        """Convert 'color_type' (e.g. 'black_queen') to its FEN letter.

        Black pieces are lowercase; any other color is uppercase. Strings
        whose type part is unrecognized map to '', which ``to_fen`` treats
        as an empty square.
        """
        normalized = piece.strip().lower().replace(' ', '_')
        color, _, piece_type = normalized.partition('_')
        piece_chars = {
            'king': 'K', 'queen': 'Q', 'rook': 'R',
            'bishop': 'B', 'knight': 'N', 'pawn': 'P'
        }
        char = piece_chars.get(piece_type, '')
        return char.lower() if color == 'black' else char
|
|
|
|
|
|
|
|
class ChessVisionAnalyzer:
    """Read a chess position from a board image using two vision LLMs.

    ``llm1`` (OpenAI gpt-4.1) and ``llm2`` (Gemini 2.5 Flash) each transcribe
    the board. Squares on which they disagree are re-examined by both models
    ("arbitration") before the agreed placement is converted to FEN.
    """

    def __init__(self):
        self.llm1 = init_chat_model(model="openai:gpt-4.1", temperature=0.0)
        self.llm2 = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

    @staticmethod
    def _image_content(image_path: str) -> dict:
        """Build the ``image_url`` message part carrying ``image_path`` as a data URL.

        The MIME type is guessed from the file extension. When the guess is
        not an image type, it falls back to JPEG, so PNG/WebP inputs are no
        longer mislabelled.
        """
        import mimetypes

        mime_type = mimetypes.guess_type(image_path)[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"
        base64_image = encode_image_to_base64(image_path)
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{base64_image}", "detail": "high"
            }
        }

    def analyze_board_orientation(self, active_color: str, image_path: str) -> str:
        """Ask ``llm1`` how the board in the image is oriented.

        :param active_color: side to move, given to the model as a hint.
        :param image_path: path of the board image.
        :return: the model's free-text answer. It is forwarded verbatim into
            the piece-detection prompt.
        """
        messages = [
            SystemMessage(
                content=prompt_mgmt.render_template("chess_board_orientation", {})),
            HumanMessage(content=[
                {
                    "type": "text",
                    "text": f"Analyze this chess board image and return the chess board orientation. I know that the "
                            f"active color is {active_color}"
                },
                self._image_content(image_path),
            ])
        ]
        response = self.llm1.invoke(messages)
        return response.content

    def analyze_board_from_image(self, board_orientation: str, image_path: str, llm_no: int,
                                 squares: Optional[list] = None) -> "Optional[ChessBoardAnalysis]":
        """Ask one LLM to list every piece and its square.

        :param board_orientation: orientation text from ``analyze_board_orientation``.
        :param image_path: path of the board image.
        :param llm_no: 1 selects ``llm1``; any other value selects ``llm2``.
        :param squares: optional collection of piece names (e.g. 'white_pawn')
            the model should focus on. Despite the name, callers pass pieces.
        :return: the parsed analysis, or None if the reply could not be parsed.
        """
        squares_text = ""
        if squares:
            squares_text = (f"Focus only on these pieces {sorted(squares)} "
                            f"*** Important: make sure you detect correctly their position as they were challanged by another model. Take into account the board "
                            f"orientation."
                            )

        messages = [
            SystemMessage(
                content=prompt_mgmt.render_template("chess_board_detection", {})),
            HumanMessage(content=[
                {
                    "type": "text",
                    "text": f"""Analyze this chess board image and return the pieces positions.
{board_orientation}

{squares_text}
Return the positions of the pieces in JSON format.
Use the following schema for each piece:
[{{
"square": "chess notation (e.g., 'e4', 'a1')",
"piece": "color_piece (e.g., 'white_king', 'black_queen')"
}},...

{{
"square": "chess notation (e.g., 'e4', 'a1')",
"piece": "color_piece (e.g., 'white_king', 'black_queen')"
}}
]
Very Important: Return only this list!
"""

                },
                self._image_content(image_path),
            ])
        ]
        llm = self.llm1 if llm_no == 1 else self.llm2
        response = llm.invoke(messages)
        return self._parse_llm_response(response.content)

    def analyze_board(self, active_color: str, file_reference: str) -> str:
        """Detect the position in ``file_reference`` and return it as FEN.

        :param active_color: side to move ("white"/"black").
        :param file_reference: path of the board image.
        :return: FEN of the consensus position.
        :raises ValueError: if neither model produced a parseable answer.
        """
        board_orientation = self.analyze_board_orientation(active_color, file_reference)
        first_analysis_res = self.analyze_board_from_image(board_orientation, file_reference, 1)
        second_analysis_res = self.analyze_board_from_image(board_orientation, file_reference, 2)

        result = self.compare_analyses(first_analysis_res, second_analysis_res)
        if result["conflicts"]:
            result = self.arbitrate_conflicts(result, board_orientation, file_reference, 3)

        consensus = result.get("consensus")
        if consensus is None:
            raise ValueError(f"Could not read a chess position from image {file_reference!r}")
        # Both branches (with or without arbitration) must return the FEN.
        return consensus.to_fen(active_color)

    def _parse_llm_response(self, response) -> "Optional[ChessBoardAnalysis]":
        """Parse an LLM reply into a ChessBoardAnalysis, or None on failure.

        Accepts bare JSON or JSON inside a Markdown code fence. The payload
        may be a list of {"square", "piece"} objects or a dict wrapping that
        list under "positions". Items missing a square or piece are skipped.
        """
        try:
            if isinstance(response, list):
                # LangChain message content may be a list of str / {"text": ...} parts.
                response = "".join(
                    part if isinstance(part, str) else part.get("text", "")
                    for part in response
                )
            json_str = response.strip()
            if "```json" in json_str:
                json_str = json_str.split("```json")[1].split("```")[0].strip()
            elif "```" in json_str:
                json_str = json_str.split("```")[1].split("```")[0].strip()

            data = json.loads(json_str)
            print(data)
            if isinstance(data, dict):
                data = data.get("positions")

            positions = [
                ChessPiecePosition(**item)
                for item in data
                if isinstance(item, dict) and item.get("square") and item.get("piece")
            ]
            return ChessBoardAnalysis(positions=positions)
        # JSONDecodeError and pydantic's ValidationError both subclass ValueError.
        except (ValueError, TypeError, AttributeError) as e:
            print(f"Failed to parse LLM response: {e}")
            return None

    def compare_analyses(self, analysis_1: "Optional[ChessBoardAnalysis]",
                         analysis_2: "Optional[ChessBoardAnalysis]") -> dict:
        """Compare two analyses square by square.

        :return: dict with
            - "conflicts": list of {"square", "analysis_1", "analysis_2"} where
              the two analyses disagree (a missing piece is None),
            - "consensus": ChessBoardAnalysis of agreed placements, or None if
              both analyses are None,
            - "need_arbitration": True when there are conflicts.
            If exactly one analysis is None, the other is used as the consensus
            (a copy), because there is nothing to cross-check it against.
        """
        if analysis_1 is None and analysis_2 is None:
            return {"conflicts": [], "consensus": None, "need_arbitration": False}
        if analysis_1 is None or analysis_2 is None:
            available = analysis_2 if analysis_1 is None else analysis_1
            print("Only one analysis available; using it without cross-check")
            return {
                "conflicts": [],
                # Copy so later merges never alias (or mutate) the source analysis.
                "consensus": ChessBoardAnalysis(positions=list(available.positions)),
                "need_arbitration": False,
            }

        pieces_1 = {pos.square: pos.piece for pos in analysis_1.positions}
        pieces_2 = {pos.square: pos.piece for pos in analysis_2.positions}

        conflicts = []
        consensus = []
        # Sorted for a deterministic order of conflicts/consensus.
        for square in sorted(set(pieces_1) | set(pieces_2)):
            piece_1 = pieces_1.get(square)
            piece_2 = pieces_2.get(square)
            if piece_1 == piece_2:
                if piece_1:
                    consensus.append(ChessPiecePosition(square=square, piece=piece_1))
            else:
                conflicts.append({
                    "square": square,
                    "analysis_1": piece_1,
                    "analysis_2": piece_2
                })

        return {
            "conflicts": conflicts,
            "consensus": ChessBoardAnalysis(positions=consensus),
            "need_arbitration": bool(conflicts)
        }

    @staticmethod
    def _pieces_in_conflict(conflicts: List[dict]) -> set:
        """Collect every piece name either model reported on a contested square."""
        pieces = set()
        for conflict in conflicts:
            for key in ("analysis_1", "analysis_2"):
                if conflict[key] is not None:
                    pieces.add(conflict[key])
        return pieces

    def arbitrate_conflicts(self, state: dict, board_orientation: str, image_path: str, depth: int = 1) -> dict:
        """Re-ask both models about contested pieces until they agree.

        :param state: result dict from ``compare_analyses`` (or a previous round).
        :param depth: number of additional rounds allowed after this one. When
            conflicts remain at depth 0, ``llm2``'s focused answer is taken
            as ground truth.
        :return: a ``compare_analyses``-style dict whose consensus also contains
            the placements already agreed in ``state``.
        """
        print(f"Arbitrating conflicts with depth {depth}")

        conflict_pieces = self._pieces_in_conflict(state.get("conflicts") or [])
        print("Pieces with conflicts:", conflict_pieces)

        first_analysis_res = self.analyze_board_from_image(board_orientation, image_path, 1, conflict_pieces)
        second_analysis_res = self.analyze_board_from_image(board_orientation, image_path, 2, conflict_pieces)
        result = self.compare_analyses(first_analysis_res, second_analysis_res)

        previous = state.get("consensus")
        if result["consensus"] is None:
            # Both focused re-reads failed to parse: keep what was already agreed.
            result["consensus"] = previous
        else:
            result["consensus"].merge_with(previous)

        if result["conflicts"]:
            if depth > 0:
                result = self.arbitrate_conflicts(result, board_orientation, image_path, depth - 1)
            else:
                print("Arbitrage completed with conflicts. took llm2 as ground truth")
                # Conflicts only exist when both analyses parsed, so this is not None.
                result["consensus"].merge_with(second_analysis_res)

        return result
|
|
|
|
|
|
|
|
class ChessEngineAnalyzer:
    """Wrapper around a UCI engine (Stockfish by default) driven via python-chess.

    The engine subprocess is started in ``__init__``. Call ``close()``, or use
    the instance as a context manager, to shut it down.
    """

    def __init__(self, stockfish_path: Optional[str] = "stockfish"):
        # Callers may forward an unset environment variable (None or ""); fall
        # back to "stockfish" on PATH instead of failing inside popen_uci.
        self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path or "stockfish")

    def __enter__(self) -> "ChessEngineAnalyzer":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def analyze_position(self, fen: str, depth: int = 18) -> Dict[str, Any]:
        """Analyze a position with the engine.

        :param fen: position in FEN (``chess.Board`` raises ValueError if invalid).
        :param depth: search depth limit.
        :return: dict with "best_move" (UCI string, or None when the engine
            gave no principal variation), "evaluation" (str of the score),
            "depth" and the raw "analysis" info dict.
        :raises ValueError: if the FEN is invalid or the position does not
            have exactly one king per side.
        """
        board = chess.Board(fen)

        # Boards read from images can lack a king or contain extra ones. UCI
        # engines are not required to handle such positions (Stockfish can
        # crash on them), so reject them with a clear error instead.
        bad_kings = (chess.STATUS_NO_WHITE_KING | chess.STATUS_NO_BLACK_KING
                     | chess.STATUS_TOO_MANY_KINGS)
        if board.status() & bad_kings:
            raise ValueError(f"Position must have exactly one king per side: {fen}")

        info = self.engine.analyse(board, chess.engine.Limit(depth=depth))

        best_move = info.get("pv", [])[0] if info.get("pv") else None
        evaluation = info.get("score", chess.engine.PovScore(chess.engine.Cp(0), chess.WHITE))

        return {
            "best_move": best_move.uci() if best_move else None,
            "evaluation": str(evaluation),
            "depth": depth,
            "analysis": info
        }

    def close(self):
        """Terminate the engine subprocess."""
        self.engine.quit()
|
|
|
|
|
|
|
|
class ChessMoveExplainer:
    """Produce a plain-language explanation of an engine-recommended move."""

    def __init__(self):
        self.llm = ChatOpenAI(model="gpt-4")

    @staticmethod
    def _build_prompt(fen: str, move: str, san_move: str, analysis: Dict) -> str:
        """Assemble the explanation request sent to the LLM."""
        return f"""
Chess position FEN: {fen}
Recommended move: {san_move} ({move})
Engine evaluation: {analysis['evaluation']}
Analysis depth: {analysis['depth']}

Explain this move recommendation in simple terms. Consider:
1. Why this move is strong
2. What threats it creates or prevents
3. The strategic implications
4. Alternative moves and why they're inferior
5. Keep it concise but informative for an intermediate player
"""

    def explain_move(self, fen: str, move: str, analysis: Dict) -> str:
        """Ask the LLM to explain ``move`` (UCI notation) in position ``fen``.

        :param analysis: engine result; its "evaluation" and "depth" entries
            are quoted in the prompt.
        :return: the LLM's reply text.
        """
        uci_move = chess.Move.from_uci(move)
        san_move = chess.Board(fen).san(uci_move)
        prompt = self._build_prompt(fen, move, san_move, analysis)
        reply = self.llm.invoke([HumanMessage(content=prompt)])
        return reply.content
|
|
|
|
|
|
|
|
@tool
def chess_analysis_tool(active_color: str, file_reference: str) -> str:
    """
    Tool for analyzing a chess board images and recommending moves
    :param active_color: The color that should execute the next move
    :param file_reference: the reference of the image to be analyzed
    :return: the recommended move along with an analysis
    """
    # NOTE: the docstring above doubles as the LangChain tool description,
    # so it is kept verbatim.
    vision_analyzer = ChessVisionAnalyzer()
    # Start the engine before the slow LLM calls so a bad engine path fails
    # fast. Fall back to "stockfish" on PATH when CHESS_ENGINE_PATH is unset.
    engine_analyzer = ChessEngineAnalyzer(os.getenv("CHESS_ENGINE_PATH") or "stockfish")
    try:
        move_explainer = ChessMoveExplainer()
        fen = vision_analyzer.analyze_board(active_color, file_reference)
        print(f"Got fen {fen}")
        analysis_result = engine_analyzer.analyze_position(fen)
        print(f"Got analysis reslut {analysis_result}")
    finally:
        # Always terminate the engine subprocess, even if a step above raised.
        engine_analyzer.close()

    best_move = analysis_result["best_move"]
    if best_move is None:
        # The engine returned no principal variation; presumably the game is
        # already over (checkmate/stalemate), so there is no move to explain.
        return f"No move could be recommended for position {fen}: the engine returned no best move."

    explanation = move_explainer.explain_move(fen, best_move, analysis_result)
    return explanation
|
|
|