Eric2mangel's picture
Update app.py
28737e2
raw
history blame
13.3 kB
import gradio as gr
import tensorflow as tf
import numpy as np
import plotly.graph_objects as go
# ============================================================================
# CONFIGURATION GLOBALE - Modifiez uniquement cette section
# ============================================================================
MOUVEMENTS_CONFIG = {
"Cubisme": {
"modele": "Cubisme_MobileNetV2_UL_c2_l0_v96_20251013_114051.keras",
"couleur": "#2ecc71" # Vert
},
"Expressionnisme": {
"modele": "Expressionnisme_MobileNetV2_UL_c2_l0_v84_20251012_232500.keras",
"couleur": "#2ecc71" # Vert
},
"Post-impressionnisme": {
"modele": "Postimpressionnisme_MobileNetV2_UL_c2_l0_v89_20251013_111049.keras",
"couleur": "#2ecc71" # Vert
}
# Pour ajouter un nouveau mouvement, décommentez et modifiez :
# "Surréalisme": {
# "modele": "Surrealisme_MobileNetV2_v1.keras",
# "couleur": "#3498db" # Bleu
# }
}
# Seuil de reconnaissance (probabilité minimale pour afficher un mouvement)
SEUIL_RECONNAISSANCE = 0.5 # 50%
# Paramètres visuels du graphique
LARGEUR_BARRES = 0.4
HAUTEUR_GRAPHIQUE = 600
# ============================================================================
# FIN DE LA CONFIGURATION - Ne modifiez pas le code ci-dessous
# ============================================================================
# === Charger les modèles dynamiquement ===
modeles_disponibles = {}
for mouvement, config in MOUVEMENTS_CONFIG.items():
try:
modeles_disponibles[mouvement] = tf.keras.models.load_model(config["modele"])
print(f"✓ Modèle '{mouvement}' chargé avec succès")
except Exception as e:
print(f"✗ Erreur lors du chargement du modèle '{mouvement}': {e}")
# Liste des mouvements disponibles (pour l'interface)
MOUVEMENTS_DISPONIBLES = list(MOUVEMENTS_CONFIG.keys())
# === Fonction de prédiction ===
def predire(image, mouvements_selectionnes):
# Vérifier qu'au moins un mouvement est sélectionné
if not mouvements_selectionnes:
return None, gr.update(value="⚠️ **Veuillez sélectionner au moins un mouvement pictural à analyser.**", visible=True)
# Prétraitement
image_resized = tf.image.resize(image, (224, 224)) / 255.0
image_batch = tf.expand_dims(image_resized, axis=0)
# Prédictions uniquement pour les modèles sélectionnés
resultats = {}
for mouvement in mouvements_selectionnes:
modele = modeles_disponibles[mouvement]
prob = float(modele.predict(image_batch, verbose=0)[0][0])
resultats[mouvement] = prob
# Filtrer les mouvements reconnus (≥ seuil)
mouvements_reconnus = {m: p for m, p in resultats.items() if p >= SEUIL_RECONNAISSANCE}
# Si aucun mouvement n'atteint le seuil
if not mouvements_reconnus:
return None, gr.update(value=f"❌ **Aucun des mouvements picturaux sélectionnés n'a été reconnu** (seuil : {SEUIL_RECONNAISSANCE*100:.0f}%).", visible=True)
# Tri par probabilité décroissante
mouvements_tries = sorted(mouvements_reconnus.items(), key=lambda x: x[1], reverse=True)
classes_triees = [m for m, _ in mouvements_tries]
probs_triees = [p for _, p in mouvements_tries]
# Couleurs selon la configuration
colors = [MOUVEMENTS_CONFIG[m]["couleur"] for m in classes_triees]
# === Construction du graphique ===
fig = go.Figure(go.Bar(
x=classes_triees,
y=probs_triees,
marker=dict(color=colors, line=dict(color='black', width=1)),
text=[f"{p*100:.1f}%" for p in probs_triees],
textposition='auto',
width=LARGEUR_BARRES
))
fig.update_layout(
xaxis=dict(
fixedrange=True,
tickangle=45,
tickfont=dict(size=15),
automargin=True
),
yaxis=dict(
fixedrange=True,
range=[0, 1],
title="Probabilité",
tickfont=dict(size=14)
),
title=dict(
text=f"Mouvements picturaux<br>reconnus (≥ {SEUIL_RECONNAISSANCE*100:.0f}%)",
y=0.90,
pad=dict(b=30)
),
margin=dict(l=60, r=60, t=80, b=80),
height=HAUTEUR_GRAPHIQUE,
width=500,
font=dict(size=13)
)
fig.data[0].textfont = dict(color='black', size=14, family="Arial")
return fig, gr.update(visible=False)
# === Interface Gradio ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(f"# 🎨 Classification de style pictural ({len(MOUVEMENTS_DISPONIBLES)} mouvements)")
gr.Markdown(
f"Sélectionnez les mouvements picturaux à analyser. "
f"Seuls ceux atteignant une probabilité ≥ {SEUIL_RECONNAISSANCE*100:.0f}% seront affichés."
)
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="numpy", label="Importer une œuvre")
mouvements_checkbox = gr.CheckboxGroup(
choices=MOUVEMENTS_DISPONIBLES,
value=MOUVEMENTS_DISPONIBLES,
label="Mouvements à analyser",
info="Cochez les mouvements picturaux à tester"
)
analyser_btn = gr.Button("🔍 Analyser", variant="primary", size="lg")
with gr.Column(scale=1):
output_plot = gr.Plot(label="Résultats de la classification")
output_message = gr.Markdown(visible=False)
analyser_btn.click(
fn=predire,
inputs=[image_input, mouvements_checkbox],
outputs=[output_plot, output_message]
)
gr.Markdown(
"---\n"
f"**Note :** Chaque CNN évalue indépendamment la probabilité d'appartenance "
f"à un mouvement pictural. Les barres colorées indiquent une reconnaissance ≥ {SEUIL_RECONNAISSANCE*100:.0f}%."
)
demo.launch()
"""
VERSION 2 avec sélection des mouvements et affichage des graphiques uniquement pour les mouvements reconnus
import gradio as gr
import tensorflow as tf
import numpy as np
import plotly.graph_objects as go
# === Charger les trois modèles binaires ===
model_cubisme = tf.keras.models.load_model("Cubisme_MobileNetV2_UL_c2_l0_v96_20251013_114051.keras")
model_expressionnisme = tf.keras.models.load_model("Expressionnisme_MobileNetV2_UL_c2_l0_v84_20251012_232500.keras")
model_postimp = tf.keras.models.load_model("Postimpressionnisme_MobileNetV2_UL_c2_l0_v89_20251013_111049.keras")
# === Dictionnaire des modèles ===
modeles_disponibles = {
"Cubisme": model_cubisme,
"Expressionnisme": model_expressionnisme,
"Post-impressionnisme": model_postimp
}
# === Fonction de prédiction ===
def predire(image, mouvements_selectionnes):
# Vérifier qu'au moins un mouvement est sélectionné
if not mouvements_selectionnes:
return None, gr.update(value="⚠️ **Veuillez sélectionner au moins un mouvement pictural à analyser.**", visible=True)
# Prétraitement
image_resized = tf.image.resize(image, (224, 224)) / 255.0
image_batch = tf.expand_dims(image_resized, axis=0)
# Prédictions uniquement pour les modèles sélectionnés
resultats = {}
for mouvement in mouvements_selectionnes:
modele = modeles_disponibles[mouvement]
prob = float(modele.predict(image_batch, verbose=0)[0][0])
resultats[mouvement] = prob
# Filtrer les mouvements reconnus (≥ 50%)
mouvements_reconnus = {m: p for m, p in resultats.items() if p >= 0.5}
# Si aucun mouvement n'atteint 50%
if not mouvements_reconnus:
return None, gr.update(value="❌ **Aucun des mouvements picturaux sélectionnés n'a été reconnu.**", visible=True)
# Tri par probabilité décroissante
mouvements_tries = sorted(mouvements_reconnus.items(), key=lambda x: x[1], reverse=True)
classes_triees = [m for m, _ in mouvements_tries]
probs_triees = [p for _, p in mouvements_tries]
# Couleur verte pour tous (car tous sont ≥ 50%)
colors = ['#2ecc71'] * len(classes_triees)
# === Construction du graphique ===
fig = go.Figure(go.Bar(
x=classes_triees,
y=probs_triees,
marker=dict(color=colors, line=dict(color='black', width=1)),
text=[f"{p*100:.1f}%" for p in probs_triees],
textposition='auto',
width=0.4
))
fig.update_layout(
xaxis=dict(
fixedrange=True,
tickangle=45,
tickfont=dict(size=15),
automargin=True
),
yaxis=dict(
fixedrange=True,
range=[0, 1],
title="Probabilité",
tickfont=dict(size=14)
),
title=dict(
text="Mouvements picturaux<br>reconnus (≥ 50%)",
y=0.90,
pad=dict(b=30)
),
margin=dict(l=20, r=20, t=0, b=60),
height=600,
font=dict(size=13)
)
fig.data[0].textfont = dict(color='black', size=14, family="Arial")
return fig, gr.update(visible=False)
# === Interface Gradio ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎨 Classification de style pictural (3 CNN binaires)")
gr.Markdown(
"Sélectionnez les mouvements picturaux à analyser. "
"Seuls ceux atteignant une probabilité ≥ 50% seront affichés."
)
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="numpy", label="Importer une œuvre")
mouvements_checkbox = gr.CheckboxGroup(
choices=["Cubisme", "Expressionnisme", "Post-impressionnisme"],
value=["Cubisme", "Expressionnisme", "Post-impressionnisme"],
label="Mouvements à analyser",
info="Cochez les mouvements picturaux à tester"
)
analyser_btn = gr.Button("🔍 Analyser", variant="primary", size="lg")
with gr.Column(scale=1):
output_plot = gr.Plot(label="Résultats de la classification")
output_message = gr.Markdown(visible=False)
analyser_btn.click(
fn=predire,
inputs=[image_input, mouvements_checkbox],
outputs=[output_plot, output_message]
)
gr.Markdown(
"---\n"
"**Note :** Chaque CNN évalue indépendamment la probabilité d'appartenance "
"à un mouvement pictural. Les barres vertes indiquent une reconnaissance ≥ 50%."
)
demo.launch()"""
"""
# Première version
import gradio as gr
import tensorflow as tf
import numpy as np
import plotly.graph_objects as go
# === Charger les trois modèles binaires ===
model_cubisme = tf.keras.models.load_model("Cubisme_MobileNetV2_UL_c2_l0_v96_20251013_114051.keras")
model_expressionnisme = tf.keras.models.load_model("Expressionnisme_MobileNetV2_UL_c2_l0_v84_20251012_232500.keras")
model_postimp = tf.keras.models.load_model("Postimpressionnisme_MobileNetV2_UL_c2_l0_v89_20251013_111049.keras")
# === Liste des classes ===
classes = ["Cubisme", "Expressionnisme", "Post-impressionnisme"]
# === Fonction de prédiction ===
def predire(image):
# Prétraitement
image_resized = tf.image.resize(image, (224, 224)) / 255.0
image_batch = tf.expand_dims(image_resized, axis=0)
# Prédictions des trois modèles
p_cubisme = float(model_cubisme.predict(image_batch)[0][0])
p_expr = float(model_expressionnisme.predict(image_batch)[0][0])
p_postimp = float(model_postimp.predict(image_batch)[0][0])
probs = [p_cubisme, p_expr, p_postimp]
# Tri (optionnel, pour classer les barres par probabilité décroissante)
sorted_indices = np.argsort(probs)[::-1]
sorted_classes = [classes[i] for i in sorted_indices]
sorted_probs = [probs[i] for i in sorted_indices]
colors = ['#2ecc71' if p >= 0.5 else '#bdc3c7' for p in sorted_probs]
# === Construction du graphique ===
fig = go.Figure(go.Bar(
x=sorted_classes,
y=sorted_probs,
marker=dict(color=colors, line=dict(color='black', width=1)),
text=[f"{p*100:.1f}%" for p in sorted_probs],
textposition='auto'
))
fig.update_layout(
xaxis=dict(fixedrange=True, tickangle=45, tickfont=dict(size=15), automargin=True),
yaxis=dict(fixedrange=True, range=[0, 1], title="Probabilité", tickfont=dict(size=14)),
title=dict(
text="Probabilités par <br>mouvement pictural",
y=0.90,
pad=dict(b=30)
),
margin=dict(l=20, r=20, t=0, b=60),
height=600,
font=dict(size=13)
)
fig.data[0].textfont = dict(color='black', size=14, family="Arial")
return fig
# === Interface Gradio ===
demo = gr.Interface(
fn=predire,
inputs=gr.Image(type="numpy", label="Importer une œuvre"),
outputs=gr.Plot(label="Résultats de la classification"),
title="🎨 Classification de style pictural (3 CNN binaires)",
description="Chaque CNN évalue indépendamment la probabilité d’appartenance à un mouvement pictural. Les barres vertes indiquent une probabilité ≥ 50 %.",
theme=gr.themes.Soft()
)
demo.launch()"""