import gradio as gr
import tensorflow as tf
import numpy as np
import plotly.graph_objects as go
# ============================================================================
# CONFIGURATION GLOBALE - Modifiez uniquement cette section
# ============================================================================
MOUVEMENTS_CONFIG = {
"Cubisme": {
"modele": "Cubisme_MobileNetV2_UL_c2_l0_v96_20251013_114051.keras",
"couleur": "#6A7B8C" # gris bleuté
},
"Expressionnisme": {
"modele": "Expressionnisme_MobileNetV2_UL_c2_l0_v84_20251012_232500.keras",
"couleur": "#C0412B" # Rouge orangé profond
},
"Néo-classicisme": {
"modele": "Neoclassicisme_MobileNetV2_UL_c2_l0_v88_20251013_163057.keras",
"couleur": "#2F4E79" # Bleu Empire
},
"Post-impressionnisme": {
"modele": "Postimpressionnisme_MobileNetV2_UL_c2_l0_v89_20251013_111049.keras",
"couleur": "#E4A725" # Jaune doré
}
}
# Seuil de reconnaissance (probabilité minimale pour afficher un mouvement)
SEUIL_RECONNAISSANCE = 0.5 # 50%
# Paramètres visuels du graphique
LARGEUR_BARRES = 0.4
HAUTEUR_GRAPHIQUE = 600
# ============================================================================
# FIN DE LA CONFIGURATION - Ne modifiez pas le code ci-dessous
# ============================================================================
# === Charger les modèles dynamiquement ===
modeles_disponibles = {}
for mouvement, config in MOUVEMENTS_CONFIG.items():
try:
modeles_disponibles[mouvement] = tf.keras.models.load_model(config["modele"])
print(f"✓ Modèle '{mouvement}' chargé avec succès")
except Exception as e:
print(f"✗ Erreur lors du chargement du modèle '{mouvement}': {e}")
# Liste des mouvements disponibles (pour l'interface)
MOUVEMENTS_DISPONIBLES = list(MOUVEMENTS_CONFIG.keys())
# === Fonction de prédiction ===
def predire(image, mouvements_selectionnes):
# Vérifier qu'au moins un mouvement est sélectionné
if not mouvements_selectionnes:
return None, gr.update(value="⚠️ **Veuillez sélectionner au moins un mouvement pictural à analyser.**", visible=True)
# Prétraitement
image_resized = tf.image.resize(image, (224, 224)) / 255.0
image_batch = tf.expand_dims(image_resized, axis=0)
# Prédictions uniquement pour les modèles sélectionnés
resultats = {}
for mouvement in mouvements_selectionnes:
modele = modeles_disponibles[mouvement]
prob = float(modele.predict(image_batch, verbose=0)[0][0])
resultats[mouvement] = prob
# Filtrer les mouvements reconnus (≥ seuil)
mouvements_reconnus = {m: p for m, p in resultats.items() if p >= SEUIL_RECONNAISSANCE}
# Si aucun mouvement n'atteint le seuil
if not mouvements_reconnus:
return None, gr.update(value=f"❌ **Aucun des mouvements picturaux sélectionnés n'a été reconnu** (seuil : {SEUIL_RECONNAISSANCE*100:.0f}%).", visible=True)
# Tri par probabilité décroissante
mouvements_tries = sorted(mouvements_reconnus.items(), key=lambda x: x[1], reverse=True)
classes_triees = [m for m, _ in mouvements_tries]
probs_triees = [p for _, p in mouvements_tries]
# Couleurs selon la configuration
colors = [MOUVEMENTS_CONFIG[m]["couleur"] for m in classes_triees]
# === Construction du graphique ===
"""fig = go.Figure(go.Bar(
x=classes_triees,
y=probs_triees,
marker=dict(color=colors, line=dict(color='black', width=1)),
text=[f"{p*100:.1f}%" for p in probs_triees],
textposition='auto',
width=LARGEUR_BARRES
))"""
"""fig = go.Figure(go.Bar(
x=classes_triees,
y=probs_triees,
marker=dict(color=colors, line=dict(color='black', width=1)),
text=[f"{p*100:.1f}%" for p in probs_triees],
textposition='inside', # <-- place le texte au centre
insidetextanchor='middle', # <-- centre verticalement
width=LARGEUR_BARRES
))
fig.update_layout(
xaxis=dict(
fixedrange=True,
tickangle=45,
tickfont=dict(size=15),
automargin=True
),
yaxis=dict(
fixedrange=True,
range=[0, 1],
title="Probabilité",
tickfont=dict(size=14)
),
title=dict(
text=f"Mouvements picturaux
reconnus (≥ {SEUIL_RECONNAISSANCE*100:.0f}%)",
y=0.90,
pad=dict(b=30)
),
margin=dict(l=60, r=60, t=80, b=80),
height=HAUTEUR_GRAPHIQUE,
width=500,
font=dict(size=13)
)
#fig.data[0].textfont = dict(color='black', size=14, family="Arial")
fig.data[0].textfont = dict(color='white', size=14, family="Arial") # <-- texte blanc"""
# === Construction du graphique ===
fig = go.Figure(go.Bar(
x=classes_triees,
y=probs_triees,
marker=dict(
color=colors,
line=dict(color='rgba(255,255,255,0.6)', width=1.5),
),
text=[f"{p*100:.1f}%" for p in probs_triees],
textposition='inside',
insidetextanchor='middle',
width=LARGEUR_BARRES,
#hovertemplate='%{x}
Probabilité : %{y:.1%}',
))
# === Amélioration du design général ===
"""fig.update_layout(
template="plotly_white", # base moderne claire
paper_bgcolor="rgba(0,0,0,0)", # fond transparent
plot_bgcolor="rgba(245,247,250,1)", # gris bleuté clair moderne
xaxis=dict(
fixedrange=True,
tickangle=45,
tickfont=dict(size=14, color="#333"),
automargin=True,
showline=False,
zeroline=False
),
yaxis=dict(
fixedrange=True,
range=[0, 1],
title="Probabilité",
tickfont=dict(size=13, color="#555"),
gridcolor="rgba(220,220,220,0.4)",
zeroline=False
),
title=dict(
text=f"🎨 Mouvements picturaux reconnus (≥ {SEUIL_RECONNAISSANCE*100:.0f}%)",
y=0.93,
x=0.5, # centré horizontalement
xanchor='center',
font=dict(size=18, color="#333", family="Arial Black")
),
margin=dict(l=60, r=60, t=80, b=80),
height=HAUTEUR_GRAPHIQUE,
#width=550,
autosize=True,
responsive=True,
font=dict(size=13, family="Arial"),
)"""
fig.update_layout(
template="plotly_white",
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(245,247,250,1)",
autosize=True, # <-- autorise le redimensionnement
xaxis=dict(
fixedrange=True,
tickangle=30, # angle moins prononcé (plus lisible sur mobile)
tickfont=dict(size=14, color="#333"),
automargin=True,
showline=False,
zeroline=False
),
yaxis=dict(
fixedrange=True,
range=[0, 1],
title="Probabilité",
tickfont=dict(size=13, color="#555"),
gridcolor="rgba(220,220,220,0.4)",
zeroline=False
),
title=dict(
text=f"🎨 Mouvements picturaux reconnus (≥ {SEUIL_RECONNAISSANCE*100:.0f}%)",
y=0.93,
x=0.5,
xanchor='center',
font=dict(size=18, color="#333", family="Arial Black")
),
# marges légèrement réduites pour éviter le rognage sur petit écran
margin=dict(l=40, r=20, t=70, b=70),
height=HAUTEUR_GRAPHIQUE, # garde la hauteur souhaitée
# NE PAS préciser `width` ici (laissons le navigateur décider)
font=dict(size=13, family="Arial"),
bargap=0.25
)
# Texte blanc à l’intérieur des barres
fig.data[0].textfont = dict(color='white', size=14, family="Arial", weight="bold")
# Petits arrondis sur les barres
fig.update_traces(marker_line_width=1.5, marker_line_color="rgba(255,255,255,0.5)",
marker=dict(cornerradius=5))
return fig, gr.update(visible=False)
# === Interface Gradio ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(f"# 🎨 Classification de style pictural ({len(MOUVEMENTS_DISPONIBLES)} mouvements)")
gr.Markdown(
f"Sélectionnez les mouvements picturaux à analyser. "
f"Seuls ceux atteignant une probabilité ≥ {SEUIL_RECONNAISSANCE*100:.0f}% seront affichés."
)
with gr.Row():
with gr.Column(scale=1):
#image_input = gr.Image(type="numpy", label="Importer une œuvre")
image_input = gr.Image(
type="numpy",
label="Importer une œuvre",
height=400, # hauteur fixe
width=400, # largeur fixe
elem_classes=["image-fixe"] # optionnel si tu veux styliser
)
mouvements_checkbox = gr.CheckboxGroup(
choices=MOUVEMENTS_DISPONIBLES,
value=MOUVEMENTS_DISPONIBLES,
label="Mouvements à analyser",
info="Cochez les mouvements picturaux à tester"
)
analyser_btn = gr.Button("🔍 Analyser", variant="primary", size="lg")
with gr.Column(scale=1):
output_plot = gr.Plot(label="Résultats de la classification")
output_message = gr.Markdown(visible=False)
analyser_btn.click(
fn=predire,
inputs=[image_input, mouvements_checkbox],
outputs=[output_plot, output_message]
)
gr.Markdown(
"---\n"
f"**Note :** Chaque CNN évalue indépendamment la probabilité d'appartenance "
f"à un mouvement pictural. Les barres colorées indiquent une reconnaissance ≥ {SEUIL_RECONNAISSANCE*100:.0f}%."
)
demo.launch()
"""
# Première version
import gradio as gr
import tensorflow as tf
import numpy as np
import plotly.graph_objects as go
# === Charger les trois modèles binaires ===
model_cubisme = tf.keras.models.load_model("Cubisme_MobileNetV2_UL_c2_l0_v96_20251013_114051.keras")
model_expressionnisme = tf.keras.models.load_model("Expressionnisme_MobileNetV2_UL_c2_l0_v84_20251012_232500.keras")
model_postimp = tf.keras.models.load_model("Postimpressionnisme_MobileNetV2_UL_c2_l0_v89_20251013_111049.keras")
# === Liste des classes ===
classes = ["Cubisme", "Expressionnisme", "Post-impressionnisme"]
# === Fonction de prédiction ===
def predire(image):
# Prétraitement
image_resized = tf.image.resize(image, (224, 224)) / 255.0
image_batch = tf.expand_dims(image_resized, axis=0)
# Prédictions des trois modèles
p_cubisme = float(model_cubisme.predict(image_batch)[0][0])
p_expr = float(model_expressionnisme.predict(image_batch)[0][0])
p_postimp = float(model_postimp.predict(image_batch)[0][0])
probs = [p_cubisme, p_expr, p_postimp]
# Tri (optionnel, pour classer les barres par probabilité décroissante)
sorted_indices = np.argsort(probs)[::-1]
sorted_classes = [classes[i] for i in sorted_indices]
sorted_probs = [probs[i] for i in sorted_indices]
colors = ['#2ecc71' if p >= 0.5 else '#bdc3c7' for p in sorted_probs]
# === Construction du graphique ===
fig = go.Figure(go.Bar(
x=sorted_classes,
y=sorted_probs,
marker=dict(color=colors, line=dict(color='black', width=1)),
text=[f"{p*100:.1f}%" for p in sorted_probs],
textposition='auto'
))
fig.update_layout(
xaxis=dict(fixedrange=True, tickangle=45, tickfont=dict(size=15), automargin=True),
yaxis=dict(fixedrange=True, range=[0, 1], title="Probabilité", tickfont=dict(size=14)),
title=dict(
text="Probabilités par
mouvement pictural",
y=0.90,
pad=dict(b=30)
),
margin=dict(l=20, r=20, t=0, b=60),
height=600,
font=dict(size=13)
)
fig.data[0].textfont = dict(color='black', size=14, family="Arial")
return fig
# === Interface Gradio ===
demo = gr.Interface(
fn=predire,
inputs=gr.Image(type="numpy", label="Importer une œuvre"),
outputs=gr.Plot(label="Résultats de la classification"),
title="🎨 Classification de style pictural (3 CNN binaires)",
description="Chaque CNN évalue indépendamment la probabilité d’appartenance à un mouvement pictural. Les barres vertes indiquent une probabilité ≥ 50 %.",
theme=gr.themes.Soft()
)
demo.launch()"""