|
|
import gradio as gr |
|
|
import tensorflow as tf |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-movement settings, keyed by the movement name displayed in the UI:
#   - "modele":  path of the saved Keras model for that movement's classifier
#   - "couleur": hex colour of that movement's bar in the results chart
MOUVEMENTS_CONFIG = {
    "Cubisme": dict(
        modele="Cubisme_MobileNetV2_UL_c2_l0_v96_20251013_114051.keras",
        couleur="#6A7B8C",
    ),
    "Expressionnisme": dict(
        modele="Expressionnisme_MobileNetV2_UL_c2_l0_v84_20251012_232500.keras",
        couleur="#C0412B",
    ),
    "Néo-classicisme": dict(
        modele="Neoclassicisme_MobileNetV2_UL_c2_l0_v88_20251013_163057.keras",
        couleur="#2F4E79",
    ),
    "Post-impressionnisme": dict(
        modele="Postimpressionnisme_MobileNetV2_UL_c2_l0_v89_20251013_111049.keras",
        couleur="#E4A725",
    ),
}
|
|
|
|
|
|
|
|
# Minimum probability (0-1) a movement must reach to be considered recognised
# and drawn in the results chart.
SEUIL_RECONNAISSANCE = 0.50

# Chart geometry: Plotly bar width (in x-axis position units) and figure
# height in pixels.
LARGEUR_BARRES = 0.40
HAUTEUR_GRAPHIQUE = 600
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load each movement's classifier. A model that fails to load is reported and
# left out, so the remaining models stay usable instead of aborting the app.
modeles_disponibles = {}
for mouvement, config in MOUVEMENTS_CONFIG.items():
    try:
        modeles_disponibles[mouvement] = tf.keras.models.load_model(config["modele"])
    except Exception as e:  # startup boundary: report and continue with the others
        print(f"✗ Erreur lors du chargement du modèle '{mouvement}': {e}")
    else:
        print(f"✓ Modèle '{mouvement}' chargé avec succès")

# Only offer movements whose model actually loaded (in configuration order), so
# the UI never lets the user select a movement that has no classifier behind it.
MOUVEMENTS_DISPONIBLES = [m for m in MOUVEMENTS_CONFIG if m in modeles_disponibles]
|
|
|
|
|
|
|
|
def _pretraiter_image(image):
    """Turn a Gradio numpy image into a single-image batch resized to 224x224.

    Pixel values are divided by 255. A 2-D (grayscale) or single-channel image
    is replicated to three channels and an alpha channel is dropped, so the
    batch always has three channels.
    """
    tableau = np.asarray(image)
    if tableau.ndim == 2:
        tableau = tableau[..., np.newaxis]
    if tableau.shape[-1] == 1:
        tableau = np.repeat(tableau, 3, axis=-1)
    elif tableau.shape[-1] == 4:
        tableau = tableau[..., :3]
    # NOTE(review): dividing by 255 assumes uint8 pixels, as Gradio provides for
    # type="numpy"; confirm this matches the preprocessing used at training time.
    image_resized = tf.image.resize(tableau, (224, 224)) / 255.0
    return tf.expand_dims(image_resized, axis=0)


def _construire_graphique(classes_triees, probs_triees):
    """Build the Plotly bar chart of recognised movements (one bar per movement)."""
    colors = [MOUVEMENTS_CONFIG[m]["couleur"] for m in classes_triees]

    fig = go.Figure(go.Bar(
        x=classes_triees,
        y=probs_triees,
        marker=dict(
            color=colors,
            line=dict(color="rgba(255,255,255,0.5)", width=1.5),
            cornerradius=5,
        ),
        text=[f"{p*100:.1f}%" for p in probs_triees],
        textposition="inside",
        insidetextanchor="middle",
        textfont=dict(color="white", size=14, family="Arial", weight="bold"),
        width=LARGEUR_BARRES,
    ))

    fig.update_layout(
        template="plotly_white",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(245,247,250,1)",
        autosize=True,
        xaxis=dict(
            fixedrange=True,
            tickangle=30,
            tickfont=dict(size=14, color="#333"),
            automargin=True,
            showline=False,
            zeroline=False,
        ),
        yaxis=dict(
            fixedrange=True,
            range=[0, 1],
            title="Probabilité",
            tickfont=dict(size=13, color="#555"),
            gridcolor="rgba(220,220,220,0.4)",
            zeroline=False,
        ),
        title=dict(
            text=f"🎨 Mouvements picturaux reconnus (≥ {SEUIL_RECONNAISSANCE*100:.0f}%)",
            y=0.93,
            x=0.5,
            xanchor="center",
            font=dict(size=18, color="#333", family="Arial Black"),
        ),
        margin=dict(l=40, r=20, t=70, b=70),
        height=HAUTEUR_GRAPHIQUE,
        font=dict(size=13, family="Arial"),
        bargap=0.25,
    )
    return fig


def predire(image, mouvements_selectionnes):
    """Run each selected movement's classifier on *image* and chart the recognised ones.

    Args:
        image: Array from ``gr.Image(type="numpy")``, or ``None`` when no image
            has been uploaded yet.
        mouvements_selectionnes: Movement names (keys of ``MOUVEMENTS_CONFIG``)
            ticked in the checkbox group.

    Returns:
        A ``(figure, message_update)`` pair. ``figure`` is a Plotly bar chart of
        the movements whose probability is >= ``SEUIL_RECONNAISSANCE``, sorted
        by decreasing probability, or ``None`` when there is nothing to plot.
        ``message_update`` is a ``gr.update`` for the Markdown message: visible
        with an explanation on warnings/errors, hidden otherwise.
    """
    if not mouvements_selectionnes:
        return None, gr.update(value="⚠️ **Veuillez sélectionner au moins un mouvement pictural à analyser.**", visible=True)

    # Gradio passes None when the button is clicked before any upload.
    if image is None:
        return None, gr.update(value="⚠️ **Veuillez importer une œuvre à analyser.**", visible=True)

    # Movements whose model failed to load at startup are skipped (and reported)
    # instead of raising KeyError.
    utilisables = [m for m in mouvements_selectionnes if m in modeles_disponibles]
    ignores = [m for m in mouvements_selectionnes if m not in modeles_disponibles]
    if not utilisables:
        return None, gr.update(value="❌ **Aucun modèle n'est disponible pour les mouvements sélectionnés.**", visible=True)
    note_ignores = (
        f"ℹ️ Modèle(s) non chargé(s), ignoré(s) : {', '.join(ignores)}" if ignores else ""
    )

    image_batch = _pretraiter_image(image)

    # Calling the model directly avoids Model.predict's per-call overhead for a
    # single small batch. NOTE(review): assumes each model outputs one
    # probability per image, shape (1, 1) here - confirm against the saved models.
    resultats = {
        m: float(modeles_disponibles[m](image_batch, training=False)[0][0])
        for m in utilisables
    }

    mouvements_reconnus = {m: p for m, p in resultats.items() if p >= SEUIL_RECONNAISSANCE}

    if not mouvements_reconnus:
        message = f"❌ **Aucun des mouvements picturaux sélectionnés n'a été reconnu** (seuil : {SEUIL_RECONNAISSANCE*100:.0f}%)."
        if note_ignores:
            message = f"{message}\n\n{note_ignores}"
        return None, gr.update(value=message, visible=True)

    mouvements_tries = sorted(mouvements_reconnus.items(), key=lambda x: x[1], reverse=True)
    classes_triees = [m for m, _ in mouvements_tries]
    probs_triees = [p for _, p in mouvements_tries]

    fig = _construire_graphique(classes_triees, probs_triees)

    if note_ignores:
        return fig, gr.update(value=note_ignores, visible=True)
    return fig, gr.update(visible=False)
|
|
|
|
|
|
|
|
# Gradio interface: image upload and movement checkboxes on the left, results
# chart (or an explanatory message) on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# 🎨 Classification de style pictural ({len(MOUVEMENTS_DISPONIBLES)} mouvements)")
    gr.Markdown(
        f"Sélectionnez les mouvements picturaux à analyser. "
        f"Seuls ceux atteignant une probabilité ≥ {SEUIL_RECONNAISSANCE*100:.0f}% seront affichés."
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="numpy",
                label="Importer une œuvre",
                height=400,
                width=400,
                elem_classes=["image-fixe"],
            )
            mouvements_checkbox = gr.CheckboxGroup(
                choices=MOUVEMENTS_DISPONIBLES,
                value=MOUVEMENTS_DISPONIBLES,
                label="Mouvements à analyser",
                info="Cochez les mouvements picturaux à tester",
            )
            analyser_btn = gr.Button("🔍 Analyser", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_plot = gr.Plot(label="Résultats de la classification")
            # Hidden until predire() needs to show a warning or explanation.
            output_message = gr.Markdown(visible=False)

    analyser_btn.click(
        fn=predire,
        inputs=[image_input, mouvements_checkbox],
        outputs=[output_plot, output_message],
    )

    gr.Markdown(
        "---\n"
        f"**Note :** Chaque CNN évalue indépendamment la probabilité d'appartenance "
        f"à un mouvement pictural. Les barres colorées indiquent une reconnaissance ≥ {SEUIL_RECONNAISSANCE*100:.0f}%."
    )


# Launch only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or by tests) does not start a server.
if __name__ == "__main__":
    demo.launch()
|
|
|