FD900 commited on
Commit
e0152b8
·
verified ·
1 Parent(s): c62c303

Update tools/chess_tools.py

Browse files
Files changed (1) hide show
  1. tools/chess_tools.py +88 -0
tools/chess_tools.py CHANGED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from smolagents import Tool
3
+ import requests
4
+
5
+ class ChessEngineLocatorTool(Tool):
6
+ name = "locate_chess_engine"
7
+ description = "Locates the installed Stockfish chess engine binary on the system."
8
+ inputs = {}
9
+ output_type = "string"
10
+
11
+ def forward(self) -> str:
12
+ engine_path = shutil.which("stockfish")
13
+ return engine_path or "Stockfish engine not found."
14
+
15
+
16
+ class ChessboardImageToFENTool(Tool):
17
+ name = "image_to_fen"
18
+ description = "Analyzes a chessboard image and returns its FEN representation."
19
+ inputs = {
20
+ "image_url": {
21
+ "type": "string",
22
+ "description": "A public image URL or a base64 data URL containing a chessboard photo."
23
+ }
24
+ }
25
+ output_type = "string"
26
+
27
+ def __init__(self, endpoint_url: str, hf_token: str, **kwargs):
28
+ super().__init__(**kwargs)
29
+ self.endpoint_url = endpoint_url
30
+ self.hf_token = hf_token
31
+
32
+ def forward(self, image_url: str) -> str:
33
+ headers = {
34
+ "Authorization": f"Bearer {self.hf_token}",
35
+ "Content-Type": "application/json"
36
+ }
37
+ payload = {
38
+ "inputs": [
39
+ {
40
+ "role": "user",
41
+ "content": [
42
+ {"type": "input_text", "text": "Extract chess piece positions from this image."},
43
+ {"type": "input_image", "image_url": image_url}
44
+ ]
45
+ },
46
+ {
47
+ "role": "user",
48
+ "content": [
49
+ {
50
+ "type": "input_text",
51
+ "text": (
52
+ "Using the extracted pieces, list each one with its position in format <piece><square> (e.g., Kd4).\n"
53
+ "Use uppercase for white, lowercase for black. Output only the lines with positions."
54
+ )
55
+ }
56
+ ]
57
+ }
58
+ ]
59
+ }
60
+
61
+ response = requests.post(self.endpoint_url, headers=headers, json=payload)
62
+ response.raise_for_status()
63
+ raw_text = response.json().get("generated_text", "")
64
+
65
+ squares = {}
66
+ for line in raw_text.splitlines():
67
+ line = line.strip()
68
+ if len(line) == 3:
69
+ squares[line[1:3]] = line[0]
70
+
71
+ fen_rows = []
72
+ for rank in range(8, 0, -1):
73
+ row = ""
74
+ empty = 0
75
+ for file in "abcdefgh":
76
+ pos = f"{file}{rank}"
77
+ if pos in squares:
78
+ if empty > 0:
79
+ row += str(empty)
80
+ empty = 0
81
+ row += squares[pos]
82
+ else:
83
+ empty += 1
84
+ if empty:
85
+ row += str(empty)
86
+ fen_rows.append(row)
87
+
88
+ return "/".join(fen_rows)