Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import chess | |
| import chess.svg | |
| import io | |
| import json | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from preprocess import preprocess_image | |
| from train import create_model | |
# Default class order of the model's output layer, used when
# class_indices.json (generated by train) cannot be used.
DEFAULT_PIECES = [
    'Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White',
    'Knight_Black', 'Knight_White', 'Pawn_Black', 'Pawn_White',
    'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White',
]


def load_class_order(path='./class_indices.json'):
    """Return the class names ordered by model output index.

    Reads the {class_name: index} JSON mapping at `path` (written by train)
    and inverts it into an index -> name list.

    Falls back to a copy of DEFAULT_PIECES when the file is missing,
    unreadable, not valid JSON, or does not assign its names to exactly the
    indices 0..n-1 (a gap or duplicate would silently map predictions to the
    wrong piece).
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            class_indices = json.load(f)
    except FileNotFoundError:
        print("Fichier class_indices.json non trouvé, utilisation ordre par défaut")
        return list(DEFAULT_PIECES)
    except (OSError, ValueError) as err:  # json.JSONDecodeError is a ValueError
        print(f"Lecture de {path} impossible ({err}), utilisation ordre par défaut")
        return list(DEFAULT_PIECES)

    valid = (
        isinstance(class_indices, dict)
        and all(type(idx) is int for idx in class_indices.values())
        and sorted(class_indices.values()) == list(range(len(class_indices)))
    )
    if not valid:
        print(f"Contenu de {path} invalide, utilisation ordre par défaut")
        return list(DEFAULT_PIECES)

    # Invert name -> index into a list indexed by model output position.
    pieces = sorted(class_indices, key=class_indices.get)
    print(f"Ordre des classes chargé: {pieces}")
    return pieces


PIECES = load_class_order()
# Maps each classifier class name to its FEN symbol: uppercase letters for
# white pieces, lowercase for black, and '.' for an empty square.
LABELS = dict([
    ('Empty', '.'),
    ('Rook_White', 'R'), ('Rook_Black', 'r'),
    ('Knight_White', 'N'), ('Knight_Black', 'n'),
    ('Bishop_White', 'B'), ('Bishop_Black', 'b'),
    ('Queen_White', 'Q'), ('Queen_Black', 'q'),
    ('King_White', 'K'), ('King_Black', 'k'),
    ('Pawn_White', 'P'), ('Pawn_Black', 'p'),
])
def _load_model(weights_path='./model_weights.weights.h5'):
    """Build the network via create_model() and restore its trained weights."""
    print("Loading model...")
    net = create_model()
    net.load_weights(weights_path)
    print("Model loaded!")
    return net


# Loaded once at import time; classify_image() uses this module-level model.
model = _load_model()
def classify_image(img):
    """Classify a single square crop into one of the class names in PIECES.

    img: image array that is reshaped to (1, 300, 150, 3) before prediction,
        so it must hold exactly 300*150*3 values.

    Pixels are normalized like the training data (rescale=1/255).
    Integer-typed input (e.g. uint8) is always divided by 255. Float input
    is divided only when its maximum exceeds 1.0, i.e. when it still looks
    like 0-255 data.

    Returns the entry of PIECES at the arg-max of the model's output.
    """
    # Check the dtype first: with the max-only heuristic, a very dark uint8
    # crop (max pixel value <= 1) was mistaken for already-normalized data
    # and reached the model without the 1/255 rescale.
    if np.issubdtype(img.dtype, np.integer) or img.max() > 1.0:
        img = img.astype(np.float32) / 255.0
    else:
        img = img.astype(np.float32)
    y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
    return PIECES[int(y_prob.argmax())]
def _promote_last_queen_if_no_king(board, king, queen):
    """Relabel the last `queen` (row-major order) as `king` if no `king` was found.

    King/queen detection adjustment: a position needs a king of each colour.
    Presumably the classifier sometimes reads a king as the same-coloured
    queen, so one queen is turned back into the missing king.
    Mutates `board` in place.
    """
    if any(king in row for row in board):
        return
    last_queen = None
    for i, row in enumerate(board):
        for j, cell in enumerate(row):
            if cell == queen:
                last_queen = (i, j)
    if last_queen is not None:
        i, j = last_queen
        board[i][j] = king


def analyze_board(img):
    """Classify the 64 squares of a board image into FEN symbols.

    img: image array of the board, indexed [row, column, ...]. It is cut
        into an 8x8 grid of img.shape[0] // 8 by img.shape[1] // 8 squares.

    Returns an 8x8 list of lists of LABELS values ('.' for empty). Row 0 is
    the top of the image.

    Each square is classified from a window two squares tall (the square plus
    the one above, so tall pieces fit). Windows on the top ranks are
    zero-padded above. The window ends one pixel above the square's bottom
    edge; this geometry presumably mirrors the crops used to train the
    model (see main.py), so keep it unchanged.
    """
    square_h = img.shape[0] // 8
    square_w = img.shape[1] // 8
    board = []
    # Iterate exactly 8x8. Previously rows were stepped over img.shape[1]
    # (the width), and columns over the full width. A width that is not a
    # multiple of 8 then produced a 9th, narrower column.
    for rank in range(8):
        y = (rank + 1) * square_h - 1  # exclusive bottom edge of the window
        top = y - 2 * square_h
        row = []
        for file in range(8):
            x = file * square_w
            crop = img[max(0, top):y, x:x + square_w]
            if top < 0:
                pad = np.zeros((-top, square_w) + img.shape[2:], dtype=img.dtype)
                crop = np.concatenate((pad, crop))
            piece = classify_image(crop.astype(np.uint8))
            row.append(LABELS[piece])
        board.append(row)

    _promote_last_queen_if_no_king(board, 'K', 'Q')
    _promote_last_queen_if_no_king(board, 'k', 'q')
    return board
def _castling_rights(board):
    """Return the FEN castling field implied by the piece placement.

    A right is listed only when the king and the corresponding rook both
    stand on their initial squares. Row 0 of `board` is rank 8 and row 7 is
    rank 1. Returns '-' when no right is possible.
    """
    def at(row, col):
        try:
            return board[row][col]
        except IndexError:
            return None

    rights = ''
    if at(7, 4) == 'K':
        if at(7, 7) == 'R':
            rights += 'K'
        if at(7, 0) == 'R':
            rights += 'Q'
    if at(0, 4) == 'k':
        if at(0, 7) == 'r':
            rights += 'k'
        if at(0, 0) == 'r':
            rights += 'q'
    return rights or '-'


def board_to_fen(board):
    """Convert a grid of FEN piece symbols into a full FEN string.

    board: rows from rank 8 (first) to rank 1, each a sequence of FEN
        letters or '.' for an empty square.

    Returns '<placement> w <castling> - 0 1'. White is always to move, with
    no en-passant square and fresh clocks. Castling rights are derived from
    the placement rather than always claiming 'KQkq', which was invalid
    whenever a king or rook had left its home square.
    """
    ranks = []
    for row in board:
        parts = []
        empty = 0
        for cell in row:
            if cell == '.':
                empty += 1
                continue
            # Flush the run of empty squares before writing a piece.
            if empty:
                parts.append(str(empty))
                empty = 0
            parts.append(cell)
        if empty:
            parts.append(str(empty))
        ranks.append(''.join(parts))
    return f"{'/'.join(ranks)} w {_castling_rights(board)} - 0 1"
def _save_upload_as_jpeg(image_input, path):
    """Write an uploaded image to `path` as a JPEG file.

    image_input: a NumPy array (treated as RGB, or RGBA when it has 4
        channels; 2-D arrays are written as grayscale) or a PIL-style image
        exposing `.mode`, `.convert()` and `.save()`.

    Raises OSError if OpenCV reports that the write failed.
    """
    if isinstance(image_input, np.ndarray):
        if image_input.ndim == 2:
            bgr = image_input
        elif image_input.shape[2] == 4:
            bgr = cv2.cvtColor(image_input, cv2.COLOR_RGBA2BGR)
        else:
            bgr = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(path, bgr):
            raise OSError(f"cv2.imwrite failed for {path}")
    else:
        # JPEG cannot store alpha or palette modes (e.g. transparent PNG
        # uploads), so normalize to RGB before saving.
        if image_input.mode != 'RGB':
            image_input = image_input.convert('RGB')
        image_input.save(path, format='JPEG')


def analyze_chess_image(image_input):
    """Gradio callback: turn an uploaded chess board picture into a FEN string.

    image_input: the uploaded image. This is a PIL image from a gr.Image
        component with type="pil", a NumPy RGB array, or None.

    Returns (status_text, board_svg). On success this is the FEN and its
    SVG rendering; on failure it is an error message and None.
    """
    if image_input is None:
        return "❌ No image provided", None
    temp_path = None
    try:
        # preprocess_image() takes a file path, so persist the upload first.
        # mkstemp + close releases our handle before OpenCV/PIL reopen the
        # path. Writing while a NamedTemporaryFile is still open fails on
        # Windows.
        fd, temp_path = tempfile.mkstemp(suffix='.jpg')
        os.close(fd)
        _save_upload_as_jpeg(image_input, temp_path)
        # preprocess_image() uses the LAPS model.
        img = preprocess_image(temp_path, save=False)
        # Same pipeline as main.py.
        arr = analyze_board(img)
        fen = board_to_fen(arr)
        board = chess.Board(fen)
        board_svg = chess.svg.board(board=board, size=400)
        return fen, board_svg
    except Exception as e:
        # UI boundary: log the full traceback server-side, show a short message.
        import traceback
        print(traceback.format_exc())
        return f"Error: {str(e)}", None
    finally:
        # Always remove the temp file, including when the pipeline raised
        # (previously it leaked on every error). Cleanup is best-effort.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
# Build Gradio interface: the upload controls go on the left, and the FEN
# result plus the rendered board go on the right.
_INTRO_MARKDOWN = """
# ♟️ YOCO: You Only Look Once
Upload a chess board image to automatically detect all pieces and get the FEN notation.
"""

with gr.Blocks(theme=gr.themes.Soft(), title="Chess Board picture -> FEN notation") as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        # Input column: the board picture and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload chess board image")
            submit_btn = gr.Button("Analyze Board", variant="primary", size="lg")
        # Output column: the FEN text and the SVG board.
        with gr.Column():
            status_output = gr.Textbox(lines=2, interactive=False, label="Result")
            board_output = gr.HTML(label="Board Visualization")
    # The callback returns (status, svg), matching the two outputs in order.
    submit_btn.click(
        fn=analyze_chess_image,
        inputs=[image_input],
        outputs=[status_output, board_output],
    )

if __name__ == "__main__":
    demo.launch()