Spaces:
Sleeping
Sleeping
File size: 5,627 Bytes
6e2dee6 41c1ae6 6e2dee6 41c1ae6 6e2dee6 aff979d 41c1ae6 aff979d 41c1ae6 6e2dee6 aff979d eb1431e 6e2dee6 87254e2 eb1431e 6e2dee6 aff979d 41c1ae6 6e2dee6 aff979d 6e2dee6 aff979d 6e2dee6 aff979d 6e2dee6 aff979d 6e2dee6 aff979d 6e2dee6 aff979d 6e2dee6 eb1431e 6e2dee6 eb1431e 6e2dee6 aff979d 6e2dee6 aff979d 6e2dee6 eb1431e 6e2dee6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | import gradio as gr
import chess
import chess.svg
import io
import json
import numpy as np
import cv2
import os
import tempfile
from pathlib import Path
from preprocess import preprocess_image
from train import create_model
# Class order of the classifier's output layer, as a list index -> class name.
# The train script generates class_indices.json as {class_name: output_index};
# when the file is absent we fall back to a built-in (alphabetical) order.
try:
    with open('./class_indices.json', 'r', encoding='utf-8') as f:
        class_indices = json.load(f)
    # The indices must be exactly 0..n-1 (each once). Otherwise the inversion
    # below would either raise a bare IndexError or silently leave None holes
    # that only surface later as a KeyError in LABELS during a prediction.
    if sorted(class_indices.values()) != list(range(len(class_indices))):
        raise ValueError(
            "class_indices.json must map class names to the indices "
            f"0..{len(class_indices) - 1} exactly once each, got {class_indices}"
        )
    # Invert name -> index into index -> name.
    PIECES = [None] * len(class_indices)
    for name, idx in class_indices.items():
        PIECES[idx] = name
    print(f"Ordre des classes chargé: {PIECES}")
except FileNotFoundError:
    # File not generated (or not shipped): use the default class order.
    PIECES = ['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White', 'Knight_Black',
              'Knight_White', 'Pawn_Black', 'Pawn_White', 'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White']
    print("Fichier class_indices.json non trouvé, utilisation ordre par défaut")
# Classifier class name -> FEN piece letter ('.' marks an empty square).
# FEN uses upper case for White and lower case for Black.
LABELS = dict(
    Empty='.',
    Rook_White='R',
    Rook_Black='r',
    Knight_White='N',
    Knight_Black='n',
    Bishop_White='B',
    Bishop_Black='b',
    Queen_White='Q',
    Queen_Black='q',
    King_White='K',
    King_Black='k',
    Pawn_White='P',
    Pawn_Black='p',
)
def _load_model(weights_path='./model_weights.weights.h5'):
    """Build the classifier architecture and restore its trained weights.

    Args:
        weights_path: path of the saved weights file passed to ``load_weights``.

    Returns:
        The model returned by ``create_model()`` with the weights loaded.
    """
    print("Loading model...")
    net = create_model()
    net.load_weights(weights_path)
    print("Model loaded!")
    return net


# Loaded once at import time; every request reuses this instance.
model = _load_model()
def classify_image(img):
    """Classify one square crop into a single class name from ``PIECES``.

    The crop is rescaled to [0, 1] the same way as during training
    (rescale=1/255) and passed to the module-level ``model`` as a batch of one.

    Args:
        img: image array holding exactly 300 * 150 * 3 values (it is reshaped
            to (1, 300, 150, 3)). Integer pixels (e.g. the uint8 crops built by
            ``analyze_board``) are always treated as 0-255; float pixels are
            rescaled only if they exceed 1.0.

    Returns:
        The ``PIECES`` entry for the highest-scoring model output.
    """
    img = np.asarray(img)
    # Decide on the dtype first: a uint8 crop whose brightest pixel is <= 1
    # (e.g. an almost black, zero-padded square) is still on the 0-255 scale
    # and must be rescaled. The old max()-only test fed such crops unscaled.
    if np.issubdtype(img.dtype, np.integer) or img.max() > 1.0:
        img = img.astype(np.float32) / 255.0
    else:
        img = img.astype(np.float32)
    y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
    return PIECES[int(y_prob.argmax())]
def _promote_queen_if_king_missing(arr, king, queen):
    """If no ``king`` letter is on the board, relabel the last ``queen`` as the king.

    "Last" means the last match in a row-major scan. The board is modified in
    place. Presumably this corrects king/queen confusions of the classifier,
    since a legal position needs one king per side.
    """
    if any(king in row for row in arr):
        return
    last_queen = None
    for i, row in enumerate(arr):
        for j, cell in enumerate(row):
            if cell == queen:
                last_queen = (i, j)
    if last_queen is not None:
        arr[last_queen[0]][last_queen[1]] = king


def analyze_board(img):
    """Classify every square of a board image into an 8x8 grid of FEN letters.

    Assumes ``img`` has already been cropped to the board by the preprocessing
    step (TODO confirm against preprocess_image). The image is split into an
    8x8 grid. Each square is classified from a window one square wide and two
    squares tall that ends at the square's bottom, presumably so that tall
    pieces reaching into the square above are fully visible. Windows that would
    extend above the image are zero-padded at the top.

    Args:
        img: image array of shape (height, width[, channels]) with height and
            width of at least 8 pixels.

    Returns:
        A list of 8 rows (row 0 = top of the image), each a list of 8 values
        from ``LABELS`` ('.' for an empty square).

    Raises:
        ValueError: if the image is smaller than 8x8 pixels.
    """
    height, width = img.shape[:2]
    M = height // 8  # square height in pixels
    N = width // 8   # square width in pixels
    if M == 0 or N == 0:
        raise ValueError(f"Board image too small to split into 8x8 squares: {img.shape}")

    arr = []
    # Exactly 8 ranks and 8 files. The old loops ran over the image *width*
    # for ranks and produced a 9th, narrower column when the width was not a
    # multiple of 8.
    for rank in range(8):
        # Crop geometry kept identical to main.py: the window's exclusive
        # bottom edge is y = M-1, 2M-1, ..., 8M-1.
        y = M - 1 + rank * M
        row = []
        for col in range(8):
            x = col * N
            sub_img = img[max(0, y - 2 * M):y, x:x + N]
            if y - 2 * M < 0:
                # Pad at the top up to the full 2M height. Channels and dtype
                # follow the crop, so any channel count works.
                pad = np.zeros((2 * M - y,) + sub_img.shape[1:], dtype=sub_img.dtype)
                sub_img = np.concatenate((pad, sub_img))
            sub_img = sub_img.astype(np.uint8)
            row.append(LABELS[classify_image(sub_img)])
        arr.append(row)

    # King/queen adjustment: a side with no detected king gets its last queen
    # promoted to king.
    _promote_queen_if_king_missing(arr, king='K', queen='Q')
    _promote_queen_if_king_missing(arr, king='k', queen='q')
    return arr
def _castling_rights(board):
    """Return the FEN castling field implied by king and rook placement.

    ``board[0]`` is rank 8 (Black's back rank) and ``board[7]`` is rank 1
    (White's back rank), matching the order in which ``board_to_fen`` writes
    the rows. A right is only claimed when the king and the corresponding rook
    are still on their starting squares. Whether they have moved cannot be
    seen in a single image, so this is an upper bound.
    """
    def piece_at(row, col):
        if row < len(board) and col < len(board[row]):
            return board[row][col]
        return None

    rights = ''
    if piece_at(7, 4) == 'K':          # white king on e1
        if piece_at(7, 7) == 'R':      # rook on h1
            rights += 'K'
        if piece_at(7, 0) == 'R':      # rook on a1
            rights += 'Q'
    if piece_at(0, 4) == 'k':          # black king on e8
        if piece_at(0, 7) == 'r':      # rook on h8
            rights += 'k'
        if piece_at(0, 0) == 'r':      # rook on a8
            rights += 'q'
    return rights or '-'


def board_to_fen(board):
    """Convert a grid of piece letters into a full FEN string.

    Args:
        board: rows from the top of the image (rank 8) to the bottom (rank 1).
            Each row is a sequence of FEN piece letters, with '.' for empty.

    Returns:
        A FEN string with White to move, no en-passant square and move counters
        "0 1". Castling rights are derived from piece placement instead of
        always claiming "KQkq", which produced invalid FENs whenever a king or
        rook was off its home square.
    """
    ranks = []
    for row in board:
        parts = []
        empty = 0
        for cell in row:
            if cell == '.':
                empty += 1
                continue
            # Flush the run of empty squares before writing a piece.
            if empty:
                parts.append(str(empty))
                empty = 0
            parts.append(cell)
        if empty:
            parts.append(str(empty))
        ranks.append(''.join(parts))
    placement = '/'.join(ranks)
    return f"{placement} w {_castling_rights(board)} - 0 1"
def _write_array_as_jpg(path, image):
    """Write an RGB, RGBA or grayscale numpy image to ``path`` with OpenCV.

    OpenCV expects BGR channel order, so colour images are converted first.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure. It signals failure via
            its return value rather than an exception.
    """
    if image.ndim == 2:
        bgr = image
    elif image.shape[2] == 4:
        bgr = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    else:
        bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(path, bgr):
        raise OSError(f"Could not write temporary image {path}")


def analyze_chess_image(image_input):
    """Gradio callback: turn an uploaded board photo into a FEN and an SVG board.

    Args:
        image_input: the uploaded image, either a numpy array (RGB/RGBA or
            grayscale) or an object with a PIL-like ``mode``/``convert``/``save``
            API. ``None`` when nothing was uploaded.

    Returns:
        ``(text, svg)``: on success the FEN string and an SVG rendering of the
        board; on failure a user-facing message and ``None``.
    """
    if image_input is None:
        return "❌ No image provided", None
    temp_path = None
    try:
        # Reserve a temporary .jpg path, then close the handle before writing,
        # so the file can be reopened by path on every OS (incl. Windows).
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            temp_path = tmp.name
        if isinstance(image_input, np.ndarray):
            _write_array_as_jpg(temp_path, image_input)
        else:
            # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads),
            # so anything that is not plain RGB/grayscale is converted to RGB.
            if image_input.mode not in ("RGB", "L"):
                image_input = image_input.convert("RGB")
            image_input.save(temp_path)
        # preprocess_image() runs the LAPS board-detection model.
        img = preprocess_image(temp_path, save=False)
        # Same pipeline as main.py.
        arr = analyze_board(img)
        fen = board_to_fen(arr)
        board = chess.Board(fen)
        board_svg = chess.svg.board(board=board, size=400)
        return fen, board_svg
    except Exception as e:
        # UI boundary: log the full traceback, show a short message to the user.
        import traceback
        print(traceback.format_exc())
        return f"Error: {str(e)}", None
    finally:
        # Always remove the temporary file, including when analysis failed
        # (previously it leaked on every error). Cleanup is best-effort.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
# Web UI: the upload widget and button on the left, the results on the right.
with gr.Blocks(title="Chess Board picture -> FEN notation", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ♟️ YOCO: You Only Look Once
Upload a chess board image to automatically detect all pieces and get the FEN notation.
""")
    with gr.Row():
        # Left column: the image to analyze and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload chess board image")
            submit_btn = gr.Button("Analyze Board", variant="primary", size="lg")
        # Right column: FEN text (or error message) and the rendered SVG board.
        with gr.Column():
            status_output = gr.Textbox(interactive=False, lines=2, label="Result")
            board_output = gr.HTML(label="Board Visualization")
    # Clicking the button runs the whole pipeline on the uploaded image.
    submit_btn.click(
        analyze_chess_image,
        inputs=[image_input],
        outputs=[status_output, board_output],
    )

if __name__ == "__main__":
    demo.launch()
|