# Author: nathbns — last commit aff979d ("Update app.py", verified)
import gradio as gr
import chess
import chess.svg
import io
import json
import numpy as np
import cv2
import os
import tempfile
from pathlib import Path
from preprocess import preprocess_image
from train import create_model
# Default class order, used when class_indices.json is missing or unusable.
# It is simply the alphabetical order of the class names.
_DEFAULT_PIECES = ['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White', 'Knight_Black',
                   'Knight_White', 'Pawn_Black', 'Pawn_White', 'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White']
# Load the class order generated by the training script so that the model's
# output indices map back to the right class names. The file is resolved next
# to this script so the app does not depend on the current working directory.
try:
    with open(Path(__file__).resolve().parent / 'class_indices.json', 'r', encoding='utf-8') as f:
        class_indices = json.load(f)
    # Invert the name -> index mapping into an index -> name list.
    PIECES = [None] * len(class_indices)
    for name, idx in class_indices.items():
        PIECES[idx] = name
    # Duplicate or missing indices would leave None holes that later break
    # the LABELS lookup; reject them here instead.
    if not PIECES or None in PIECES:
        raise ValueError("indices must cover 0..n-1 exactly once")
    print(f"Ordre des classes chargé: {PIECES}")
except FileNotFoundError:
    # The file is missing: fall back to the default order.
    PIECES = list(_DEFAULT_PIECES)
    print("Fichier class_indices.json non trouvé, utilisation ordre par défaut")
except (ValueError, TypeError, IndexError, AttributeError) as e:
    # Malformed JSON, a non-dict payload, or non-integer / out-of-range /
    # non-contiguous indices: fall back to the default order as well.
    PIECES = list(_DEFAULT_PIECES)
    print(f"Fichier class_indices.json invalide ({e}), utilisation ordre par défaut")
# Translate each classifier class name into its FEN symbol:
# '.' for an empty square, upper case for White pieces, lower case for Black.
LABELS = {'Empty': '.'}
LABELS.update({
    'Rook_White': 'R', 'Rook_Black': 'r',
    'Knight_White': 'N', 'Knight_Black': 'n',
    'Bishop_White': 'B', 'Bishop_Black': 'b',
    'Queen_White': 'Q', 'Queen_Black': 'q',
    'King_White': 'K', 'King_Black': 'k',
    'Pawn_White': 'P', 'Pawn_Black': 'p',
})
# Build the network architecture and load the trained weights into it.
# The weights file is resolved next to this script rather than against the
# current working directory, so the app can be launched from anywhere.
print("Loading model...")
model = create_model()
model.load_weights(str(Path(__file__).resolve().parent / 'model_weights.weights.h5'))
print("Model loaded!")
def classify_image(img):
    """Classify a single board tile into one of the class names in ``PIECES``.

    Args:
        img: tile array holding exactly 300 * 150 * 3 values; it is reshaped
            to ``(1, 300, 150, 3)`` before being fed to the model. Integer
            arrays are treated as 0-255 pixel values.

    Returns:
        The entry of ``PIECES`` with the highest predicted probability.
    """
    # Apply the same 1/255 rescaling as during training. Integer images are
    # always 0-255, so they are scaled unconditionally; relying only on
    # `max() > 1` would leave near-black uint8 tiles (max 0 or 1) unscaled.
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.integer) or img.max() > 1.0:
        img = img.astype(np.float32) / 255.0
    else:
        # Float input that already looks normalised to [0, 1].
        img = img.astype(np.float32)
    y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
    return PIECES[int(np.argmax(y_prob))]
def analyze_board(img):
    """Classify every square of a board image into an 8x8 grid of FEN letters.

    The image is split into an 8x8 grid. For each square, a tile twice the
    square height is cropped: the square plus the square above it, so that tall
    pieces overflowing upward stay visible. Tiles on the top row are
    zero-padded at the top to keep a constant height.

    Args:
        img: board image indexed ``img[row, col, ...]``, presumably already
            rectified to a top-down view by ``preprocess_image`` — TODO
            confirm. Because ``classify_image`` reshapes each tile to
            (300, 150, 3), this expects 3 channels and 150x150 px squares.

    Returns:
        A list of 8 rows (first row = top of the image), each a list of 8
        single-character labels from ``LABELS`` ('.' for empty squares).
    """
    square_h = img.shape[0] // 8
    square_w = img.shape[1] // 8
    board = []
    # Iterate exactly 8x8. Rows follow the image height (shape[0]) and columns
    # the width (shape[1]), so sizes that are not multiples of 8 no longer
    # produce an extra narrow tile.
    for rank in range(8):
        # Exclusive end row of the crop, one pixel above the square's bottom
        # edge. Presumably this matches the crops the classifier was trained
        # on; do not change it without retraining.
        y = (rank + 1) * square_h - 1
        row = []
        for col in range(8):
            x = col * square_w
            tile = img[max(0, y - 2 * square_h):y, x:x + square_w]
            missing = 2 * square_h - tile.shape[0]
            if missing > 0:
                # Pad with black rows above, in the tile's own dtype, so every
                # tile has the same height.
                pad = np.zeros((missing,) + tile.shape[1:], dtype=tile.dtype)
                tile = np.concatenate((pad, tile))
            row.append(LABELS[classify_image(tile)])
        board.append(row)
    _promote_missing_kings(board)
    return board


def _promote_missing_kings(board):
    """Relabel a queen as the king for any side whose king was not detected.

    Works around king/queen confusions of the classifier: a side with no 'K'
    (or 'k') has its last queen in row-major order turned into the king.
    Mutates ``board`` in place.
    """
    for king, queen in (('K', 'Q'), ('k', 'q')):
        cells = [(i, j, cell) for i, row in enumerate(board) for j, cell in enumerate(row)]
        if any(cell == king for _, _, cell in cells):
            continue
        queens = [(i, j) for i, j, cell in cells if cell == queen]
        if queens:
            i, j = queens[-1]
            board[i][j] = king
def board_to_fen(board):
    """Convert a grid of piece letters into a FEN string.

    Args:
        board: rows of single-character cells ('.' = empty square, otherwise
            a FEN piece letter). The first row is written first, i.e. as
            rank 8, and each row runs from file a to file h.

    Returns:
        A FEN string with white to move, no en-passant square and clocks
        ``0 1``. Castling rights are derived from the placement: a right is
        listed only when that side's king and the matching rook are on their
        starting squares ('-' if none).
    """
    ranks = []
    for row in board:
        parts = []
        empty = 0
        for cell in row:
            if cell == '.':
                empty += 1
                continue
            if empty:
                parts.append(str(empty))
                empty = 0
            parts.append(cell)
        if empty:
            parts.append(str(empty))
        ranks.append(''.join(parts))
    return f"{'/'.join(ranks)} w {_castling_rights(board)} - 0 1"


def _castling_rights(board):
    """Return the FEN castling field consistent with the piece placement."""
    if len(board) != 8 or any(len(row) != 8 for row in board):
        return '-'
    # board[7] is rank 1 (White's back rank), board[0] is rank 8; index 4 is
    # the e-file, 7 the h-file (king side) and 0 the a-file (queen side).
    white, black = board[7], board[0]
    rights = ''
    if white[4] == 'K':
        if white[7] == 'R':
            rights += 'K'
        if white[0] == 'R':
            rights += 'Q'
    if black[4] == 'k':
        if black[7] == 'r':
            rights += 'k'
        if black[0] == 'r':
            rights += 'q'
    return rights or '-'
def analyze_chess_image(image_input):
    """Gradio callback: turn an uploaded board picture into FEN text and an SVG diagram.

    Args:
        image_input: the uploaded image, either a PIL image (anything with a
            ``save(path)`` method) or an RGB numpy array; ``None`` when
            nothing was uploaded.

    Returns:
        ``(text, svg)``: the FEN string and the rendered board SVG on success,
        or a user-facing error message and ``None`` on failure.
    """
    if image_input is None:
        return "❌ No image provided", None
    temp_path = None
    try:
        # preprocess_image() takes a file path, so write the upload to a
        # temporary file first. The path is recorded before writing so the
        # `finally` clause can clean it up even if the write fails.
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            temp_path = tmp.name
            if isinstance(image_input, np.ndarray):
                # OpenCV writes BGR, the incoming array is RGB.
                cv2.imwrite(temp_path, cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
            else:
                image_input.save(temp_path)
        # preprocess_image() runs the LAPS model (per the original comment)
        # and returns the board image consumed by analyze_board().
        img = preprocess_image(temp_path, save=False)
        arr = analyze_board(img)
        fen = board_to_fen(arr)
        # Render the detected position.
        board = chess.Board(fen)
        board_svg = chess.svg.board(board=board, size=400)
        return fen, board_svg
    except Exception as e:
        # UI boundary: log the traceback and show the error instead of
        # crashing the app.
        import traceback
        print(traceback.format_exc())
        return f"Error: {str(e)}", None
    finally:
        # Always remove the temporary upload, on success and on failure.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                # Best-effort cleanup; the file may already be gone.
                pass
# Build the Gradio interface: an image upload and an "Analyze Board" button in
# the left column, the FEN text and the rendered board SVG in the right column.
with gr.Blocks(title="Chess Board picture -> FEN notation", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ♟️ YOCO: You Only Look Once
Upload a chess board image to automatically detect all pieces and get the FEN notation.
""")
    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio hand analyze_chess_image a PIL image.
            image_input = gr.Image(label="Upload chess board image", type="pil")
            submit_btn = gr.Button("Analyze Board", size="lg", variant="primary")
        with gr.Column():
            status_output = gr.Textbox(label="Result", interactive=False, lines=2)
            board_output = gr.HTML(label="Board Visualization")
    # analyze_chess_image returns (text, svg), matching the two outputs in order.
    submit_btn.click(
        fn=analyze_chess_image,
        inputs=image_input,
        outputs=[status_output, board_output]
    )
# Launch the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()