File size: 2,531 Bytes
bfb1531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import tensorflow as tf
import numpy as np
import plotly.graph_objects as go

# === Charger les trois modèles binaires ===
model_cubisme = tf.keras.models.load_model("model_cubisme.keras")
model_expressionnisme = tf.keras.models.load_model("model_expressionnisme.keras")
model_postimp = tf.keras.models.load_model("model_postimpressionnisme.keras")

# === Liste des classes ===
classes = ["Cubisme", "Expressionnisme", "Post-impressionnisme"]

# === Fonction de prédiction ===
def predire(image):
    # Prétraitement
    image_resized = tf.image.resize(image, (224, 224)) / 255.0
    image_batch = tf.expand_dims(image_resized, axis=0)

    # Prédictions des trois modèles
    p_cubisme = float(model_cubisme.predict(image_batch)[0][0])
    p_expr = float(model_expressionnisme.predict(image_batch)[0][0])
    p_postimp = float(model_postimp.predict(image_batch)[0][0])

    probs = [p_cubisme, p_expr, p_postimp]

    # Tri (optionnel, pour classer les barres par probabilité décroissante)
    sorted_indices = np.argsort(probs)[::-1]
    sorted_classes = [classes[i] for i in sorted_indices]
    sorted_probs = [probs[i] for i in sorted_indices]
    colors = ['#2ecc71' if p >= 0.5 else '#bdc3c7' for p in sorted_probs]

    # === Construction du graphique ===
    fig = go.Figure(go.Bar(
        x=sorted_classes,
        y=sorted_probs,
        marker=dict(color=colors, line=dict(color='black', width=1)),
        text=[f"{p*100:.1f}%" for p in sorted_probs],
        textposition='auto'
    ))

    fig.update_layout(
        xaxis=dict(fixedrange=True, tickangle=45, tickfont=dict(size=15), automargin=True),
        yaxis=dict(fixedrange=True, range=[0, 1], title="Probabilité", tickfont=dict(size=14)),
        title=dict(
            text="Probabilités par mouvement pictural",
            y=0.90,
            pad=dict(b=30)
        ),
        margin=dict(l=20, r=20, t=0, b=60),
        height=600,
        font=dict(size=13)
    )

    fig.data[0].textfont = dict(color='black', size=14, family="Arial")
    return fig

# === Interface Gradio ===
demo = gr.Interface(
    fn=predire,
    inputs=gr.Image(type="numpy", label="Importer une œuvre"),
    outputs=gr.Plot(label="Résultats de la classification"),
    title="🎨 Classification de style pictural (3 CNN binaires)",
    description="Chaque CNN évalue indépendamment la probabilité d’appartenance à un mouvement pictural. Les barres vertes indiquent une probabilité ≥ 50 %.",
    theme=gr.themes.Soft()
)

demo.launch()