File size: 5,627 Bytes
6e2dee6
 
 
 
41c1ae6
6e2dee6
 
 
 
 
 
 
41c1ae6
6e2dee6
aff979d
41c1ae6
 
 
 
 
 
 
 
 
aff979d
41c1ae6
 
 
6e2dee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aff979d
eb1431e
6e2dee6
87254e2
eb1431e
6e2dee6
 
 
aff979d
 
41c1ae6
 
 
 
 
6e2dee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aff979d
6e2dee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aff979d
6e2dee6
 
 
 
aff979d
6e2dee6
 
 
 
 
 
 
aff979d
6e2dee6
 
 
 
 
 
aff979d
6e2dee6
 
 
aff979d
6e2dee6
 
eb1431e
6e2dee6
 
 
 
eb1431e
6e2dee6
 
 
aff979d
6e2dee6
aff979d
6e2dee6
 
 
 
 
 
eb1431e
 
6e2dee6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import gradio as gr
import chess
import chess.svg
import io
import json
import numpy as np
import cv2
import os
import tempfile
from pathlib import Path

from preprocess import preprocess_image
from train import create_model

# Load the class order written by the training script. The JSON maps
# class name -> model output index; invert it so PIECES[i] is the class name
# predicted by output unit i.
try:
    with open('./class_indices.json', 'r', encoding='utf-8') as f:
        class_indices = json.load(f)
    # Invert to index -> name.
    PIECES = [None] * len(class_indices)
    for name, idx in class_indices.items():
        PIECES[idx] = name
    print(f"Ordre des classes chargé: {PIECES}")
except FileNotFoundError:
    # Only a missing file is handled here; a malformed JSON file still raises.
    # Fallback: the class names in alphabetical order.
    # NOTE(review): presumably this matches the order used at training time — confirm.
    PIECES = ['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White', 'Knight_Black',
              'Knight_White', 'Pawn_Black', 'Pawn_White', 'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White']
    print("Fichier class_indices.json non trouvé, utilisation ordre par défaut")

# Classifier class name -> FEN piece letter.
# Uppercase letters are White pieces, lowercase are Black, '.' marks an empty square.
LABELS = dict(
    Empty='.',
    Rook_White='R',
    Rook_Black='r',
    Knight_White='N',
    Knight_Black='n',
    Bishop_White='B',
    Bishop_Black='b',
    Queen_White='Q',
    Queen_Black='q',
    King_White='K',
    King_Black='k',
    Pawn_White='P',
    Pawn_Black='p',
)

# Build the network (architecture from train.create_model) and load the
# trained weights once, at import time. classify_image() reuses this
# module-level `model` for every prediction.
print("Loading model...")
model = create_model()
model.load_weights('./model_weights.weights.h5')
print("Model loaded!")


def classify_image(img):
    """Predict the class name of a single cropped board square.

    The crop is normalised like the training data (rescale=1/255) unless its
    maximum is already <= 1.0. It is then sent to the model as a one-item
    batch reshaped to (1, 300, 150, 3).

    Returns:
        The entry of PIECES with the highest predicted score.
    """
    needs_rescale = img.max() > 1.0
    pixels = img.astype(np.float32)
    if needs_rescale:
        pixels = pixels / 255.0

    batch = pixels.reshape(1, 300, 150, 3)
    scores = model.predict(batch, verbose=0)
    best = scores.argmax()
    return PIECES[best]


def _square_crop(img, row, col, sq_h, sq_w):
    """Return the classifier input crop for board square (row, col).

    The crop is two squares tall (2 * sq_h rows) and ends one pixel row above
    the bottom edge of the target square, so a tall piece standing on the
    square is contained. Rows that would lie above the image (first two board
    rows) are zero-padded so every crop has the same height.
    """
    bottom = (row + 1) * sq_h - 1  # exclusive end row of the slice
    top = bottom - 2 * sq_h
    crop = img[max(0, top):bottom, col * sq_w:(col + 1) * sq_w]
    if top < 0:
        # Pad with the crop's own dtype: casting to uint8 would truncate
        # float images (e.g. values in [0, 1] would all become 0).
        pad = np.zeros((-top,) + crop.shape[1:], dtype=crop.dtype)
        crop = np.concatenate((pad, crop))
    return crop


def _promote_missing_kings(arr):
    """Relabel a queen as the king for any colour whose king was not detected.

    Works around king/queen confusions of the classifier: if a colour has no
    'K'/'k' on the board but has at least one 'Q'/'q', the last such queen in
    row-major order becomes the king. Mutates ``arr`` in place.
    """
    for king, queen in (('K', 'Q'), ('k', 'q')):
        if any(king in row for row in arr):
            continue
        last_queen = None
        for i, row in enumerate(arr):
            for j, cell in enumerate(row):
                if cell == queen:
                    last_queen = (i, j)
        if last_queen is not None:
            i, j = last_queen
            arr[i][j] = king


def analyze_board(img):
    """Classify every square of a rectified board image.

    Args:
        img: board image whose height and width each span 8 squares, indexed
            as img[y, x, ...]. NOTE(review): assumed (H, W, 3) as returned by
            preprocess_image — confirm.

    Returns:
        8x8 list of single-character labels (values of LABELS: FEN letters,
        '.' for empty). Row 0 is the top row of the image.

    Raises:
        ValueError: if the image is smaller than 8 pixels in either dimension.
    """
    sq_h = img.shape[0] // 8
    sq_w = img.shape[1] // 8
    if sq_h == 0 or sq_w == 0:
        raise ValueError(f"Board image too small to split into 8x8 squares: {img.shape}")

    # Exactly 8 rows over the image HEIGHT and 8 columns over the WIDTH.
    # Non-square images and sizes not divisible by 8 therefore never yield
    # missing rows or an extra sliver row/column.
    arr = [
        [LABELS[classify_image(_square_crop(img, row, col, sq_h, sq_w))] for col in range(8)]
        for row in range(8)
    ]

    _promote_missing_kings(arr)
    return arr


def _castling_rights(board):
    """Return the FEN castling field implied by king/rook home squares ('-' if none).

    Row 7 of ``board`` is rank 1 (White's home rank) and row 0 is rank 8,
    which matches the rank order board_to_fen writes. A right is claimed only
    if the king is on the e-file and the matching rook is on its corner
    square. Boards that are not 8x8 get '-'.
    """
    if len(board) != 8 or any(len(row) != 8 for row in board):
        return '-'
    rank1, rank8 = board[7], board[0]
    rights = []
    if rank1[4] == 'K':
        if rank1[7] == 'R':
            rights.append('K')
        if rank1[0] == 'R':
            rights.append('Q')
    if rank8[4] == 'k':
        if rank8[7] == 'r':
            rights.append('k')
        if rank8[0] == 'r':
            rights.append('q')
    return ''.join(rights) or '-'


def board_to_fen(board):
    """Convert a grid of piece letters into a complete FEN string.

    Args:
        board: sequence of rows of single-character labels (FEN piece letters,
            '.' for an empty square), e.g. the output of analyze_board. Row 0
            is written first, i.e. it is treated as rank 8.

    Returns:
        A FEN string. Side to move ('w'), en passant ('-') and the clocks
        ('0 1') are fixed defaults, since they cannot be read from a picture.
        Castling rights are claimed only for king/rook pairs standing on their
        home squares, so the FEN never asserts an impossible right.
    """
    from itertools import groupby

    ranks = []
    for row in board:
        parts = []
        for is_empty, run in groupby(row, key=lambda cell: cell == '.'):
            cells = list(run)
            # A run of empty squares collapses to its length; pieces are kept verbatim.
            parts.append(str(len(cells)) if is_empty else ''.join(cells))
        ranks.append(''.join(parts))

    placement = '/'.join(ranks)
    return f"{placement} w {_castling_rights(board)} - 0 1"


def analyze_chess_image(image_input):
    """Gradio callback: turn an uploaded board picture into FEN plus an SVG diagram.

    Args:
        image_input: the uploaded image, either a PIL image or an RGB numpy
            array, or None when nothing was uploaded.

    Returns:
        ``(text, svg)``. On success these are the FEN string and the rendered
        board SVG. On failure they are a user-facing error message and None.
    """
    if image_input is None:
        return "❌ No image provided", None

    temp_path = None
    try:
        # preprocess_image() takes a file path, so persist the upload first.
        # Create and close the temp file before writing to it by name:
        # reopening a still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            temp_path = tmp.name

        if isinstance(image_input, np.ndarray):
            cv2.imwrite(temp_path, cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
        else:
            if getattr(image_input, "mode", "RGB") not in ("RGB", "L"):
                # JPEG cannot store alpha/palette modes (e.g. RGBA PNG uploads).
                image_input = image_input.convert("RGB")
            image_input.save(temp_path)

        img = preprocess_image(temp_path, save=False)

        # Same pipeline as main.py: classify squares, then build the FEN.
        arr = analyze_board(img)
        fen = board_to_fen(arr)

        board = chess.Board(fen)
        board_svg = chess.svg.board(board=board, size=400)
        return fen, board_svg

    except Exception as e:
        # UI boundary: log the full traceback and show the message to the user.
        import traceback
        print(traceback.format_exc())
        return f"Error: {str(e)}", None

    finally:
        # Always remove the temp file, including when processing failed.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass  # best-effort cleanup; never mask the real result


# Build Gradio interface.
# Layout: a header, then a row with the upload/button column on the left and
# the result column on the right. Statement order inside the `with` blocks
# determines where each component appears.
with gr.Blocks(title="Chess Board picture -> FEN notation", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ♟️ YOCO: You Only Look Once
    
    Upload a chess board image to automatically detect all pieces and get the FEN notation.
    """)
    
    with gr.Row():
        # Left column: image upload and the button that triggers analysis.
        with gr.Column():
            # type="pil": the callback receives the upload as a PIL image (or None).
            image_input = gr.Image(label="Upload chess board image", type="pil")
            submit_btn = gr.Button("Analyze Board", size="lg", variant="primary")
        
        # Right column: FEN text (or error message) and the rendered board SVG.
        with gr.Column():
            status_output = gr.Textbox(label="Result", interactive=False, lines=2)
            board_output = gr.HTML(label="Board Visualization")
    
    # analyze_chess_image returns (text, svg); the two values map positionally
    # onto status_output and board_output.
    submit_btn.click(
        fn=analyze_chess_image,
        inputs=image_input,
        outputs=[status_output, board_output]
    )

# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()