Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,80 @@ import gradio as gr
|
|
| 2 |
import tensorflow as tf
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Charger ton modèle
|
| 7 |
#model = tf.keras.models.load_model("MobileNetV2_UL_ML_c3_l0_acc88_auc94_20251007_100138.h5")
|
|
@@ -27,4 +101,4 @@ demo = gr.Interface(
|
|
| 27 |
description="Upload une image et découvre le mouvement pictural estimé par le CNN."
|
| 28 |
)
|
| 29 |
|
| 30 |
-
demo.launch()
|
|
|
|
| 2 |
import tensorflow as tf
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
+
from matplotlib.figure import Figure
|
| 6 |
+
|
| 7 |
+
# Charger ton modèle
|
| 8 |
+
model = tf.keras.models.load_model("MobileNetV2_UL_ML_c3_l0_acc88_auc94_20251007_170519.keras")
|
| 9 |
+
|
| 10 |
+
# Classes
|
| 11 |
+
classes = ["Cubisme", "Expressionnisme", "Post-impressionnisme"]
|
| 12 |
+
|
| 13 |
+
# Fonction de prédiction avec graphique personnalisé
|
| 14 |
+
def predire(image):
|
| 15 |
+
# Prédiction
|
| 16 |
+
image_resized = tf.image.resize(image, (224, 224)) / 255.0
|
| 17 |
+
preds = model.predict(tf.expand_dims(image_resized, axis=0))[0]
|
| 18 |
+
|
| 19 |
+
# Créer le graphique avec matplotlib
|
| 20 |
+
fig = Figure(figsize=(10, 6))
|
| 21 |
+
ax = fig.add_subplot(111)
|
| 22 |
+
|
| 23 |
+
# Trier par probabilité décroissante
|
| 24 |
+
sorted_indices = np.argsort(preds)[::-1]
|
| 25 |
+
sorted_classes = [classes[i] for i in sorted_indices]
|
| 26 |
+
sorted_probs = [preds[i] for i in sorted_indices]
|
| 27 |
+
|
| 28 |
+
# Définir les couleurs : vert si > 50%, bleu sinon
|
| 29 |
+
colors = ['#2ecc71' if prob > 0.5 else '#3498db' for prob in sorted_probs]
|
| 30 |
+
|
| 31 |
+
# Créer le bar plot horizontal
|
| 32 |
+
bars = ax.barh(sorted_classes, sorted_probs, color=colors, edgecolor='black', linewidth=1.5)
|
| 33 |
+
|
| 34 |
+
# Ajouter les pourcentages sur les barres
|
| 35 |
+
for i, (bar, prob) in enumerate(zip(bars, sorted_probs)):
|
| 36 |
+
width = bar.get_width()
|
| 37 |
+
label_x = width + 0.02 if width < 0.9 else width - 0.02
|
| 38 |
+
ha = 'left' if width < 0.9 else 'right'
|
| 39 |
+
text_color = 'black' if width < 0.9 else 'white'
|
| 40 |
+
ax.text(label_x, bar.get_y() + bar.get_height()/2,
|
| 41 |
+
f'{prob*100:.1f}%',
|
| 42 |
+
ha=ha, va='center', fontsize=12, fontweight='bold', color=text_color)
|
| 43 |
+
|
| 44 |
+
# Configuration du graphique
|
| 45 |
+
ax.set_xlabel('Probabilité', fontsize=12, fontweight='bold')
|
| 46 |
+
ax.set_xlim(0, 1.0)
|
| 47 |
+
ax.set_title('Probabilités par mouvement pictural', fontsize=14, fontweight='bold', pad=20)
|
| 48 |
+
ax.grid(axis='x', alpha=0.3, linestyle='--')
|
| 49 |
+
ax.set_axisbelow(True)
|
| 50 |
+
|
| 51 |
+
# Légende
|
| 52 |
+
from matplotlib.patches import Patch
|
| 53 |
+
legend_elements = [
|
| 54 |
+
Patch(facecolor='#2ecc71', edgecolor='black', label='> 50%'),
|
| 55 |
+
Patch(facecolor='#3498db', edgecolor='black', label='≤ 50%')
|
| 56 |
+
]
|
| 57 |
+
ax.legend(handles=legend_elements, loc='lower right', fontsize=10)
|
| 58 |
+
|
| 59 |
+
fig.tight_layout()
|
| 60 |
+
|
| 61 |
+
return fig
|
| 62 |
+
|
| 63 |
+
# Interface Gradio
|
| 64 |
+
demo = gr.Interface(
|
| 65 |
+
fn=predire,
|
| 66 |
+
inputs=gr.Image(type="numpy", label="Importer une œuvre"),
|
| 67 |
+
outputs=gr.Plot(label="Résultats de la classification"),
|
| 68 |
+
title="🎨 Classification de style pictural",
|
| 69 |
+
description="Upload une image et découvre le mouvement pictural estimé par le CNN. Les barres vertes indiquent une probabilité supérieure à 50%.",
|
| 70 |
+
examples=None,
|
| 71 |
+
theme=gr.themes.Soft()
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
demo.launch()
|
| 75 |
+
"""import gradio as gr
|
| 76 |
+
import tensorflow as tf
|
| 77 |
+
import numpy as np
|
| 78 |
+
import matplotlib.pyplot as plt
|
| 79 |
|
| 80 |
# Charger ton modèle
|
| 81 |
#model = tf.keras.models.load_model("MobileNetV2_UL_ML_c3_l0_acc88_auc94_20251007_100138.h5")
|
|
|
|
| 101 |
description="Upload une image et découvre le mouvement pictural estimé par le CNN."
|
| 102 |
)
|
| 103 |
|
| 104 |
+
demo.launch()"""
|