Eric2mangel commited on
Commit
0b47556
·
verified ·
1 Parent(s): e7ec3dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -1
app.py CHANGED
@@ -2,6 +2,80 @@ import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Charger ton modèle
7
  #model = tf.keras.models.load_model("MobileNetV2_UL_ML_c3_l0_acc88_auc94_20251007_100138.h5")
@@ -27,4 +101,4 @@ demo = gr.Interface(
27
  description="Upload une image et découvre le mouvement pictural estimé par le CNN."
28
  )
29
 
30
- demo.launch()
 
2
  import tensorflow as tf
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
+ from matplotlib.figure import Figure
6
+
7
+ # Charger ton modèle
8
+ model = tf.keras.models.load_model("MobileNetV2_UL_ML_c3_l0_acc88_auc94_20251007_170519.keras")
9
+
10
+ # Classes
11
+ classes = ["Cubisme", "Expressionnisme", "Post-impressionnisme"]
12
+
13
+ # Fonction de prédiction avec graphique personnalisé
14
+ def predire(image):
15
+ # Prédiction
16
+ image_resized = tf.image.resize(image, (224, 224)) / 255.0
17
+ preds = model.predict(tf.expand_dims(image_resized, axis=0))[0]
18
+
19
+ # Créer le graphique avec matplotlib
20
+ fig = Figure(figsize=(10, 6))
21
+ ax = fig.add_subplot(111)
22
+
23
+ # Trier par probabilité décroissante
24
+ sorted_indices = np.argsort(preds)[::-1]
25
+ sorted_classes = [classes[i] for i in sorted_indices]
26
+ sorted_probs = [preds[i] for i in sorted_indices]
27
+
28
+ # Définir les couleurs : vert si > 50%, bleu sinon
29
+ colors = ['#2ecc71' if prob > 0.5 else '#3498db' for prob in sorted_probs]
30
+
31
+ # Créer le bar plot horizontal
32
+ bars = ax.barh(sorted_classes, sorted_probs, color=colors, edgecolor='black', linewidth=1.5)
33
+
34
+ # Ajouter les pourcentages sur les barres
35
+ for i, (bar, prob) in enumerate(zip(bars, sorted_probs)):
36
+ width = bar.get_width()
37
+ label_x = width + 0.02 if width < 0.9 else width - 0.02
38
+ ha = 'left' if width < 0.9 else 'right'
39
+ text_color = 'black' if width < 0.9 else 'white'
40
+ ax.text(label_x, bar.get_y() + bar.get_height()/2,
41
+ f'{prob*100:.1f}%',
42
+ ha=ha, va='center', fontsize=12, fontweight='bold', color=text_color)
43
+
44
+ # Configuration du graphique
45
+ ax.set_xlabel('Probabilité', fontsize=12, fontweight='bold')
46
+ ax.set_xlim(0, 1.0)
47
+ ax.set_title('Probabilités par mouvement pictural', fontsize=14, fontweight='bold', pad=20)
48
+ ax.grid(axis='x', alpha=0.3, linestyle='--')
49
+ ax.set_axisbelow(True)
50
+
51
+ # Légende
52
+ from matplotlib.patches import Patch
53
+ legend_elements = [
54
+ Patch(facecolor='#2ecc71', edgecolor='black', label='> 50%'),
55
+ Patch(facecolor='#3498db', edgecolor='black', label='≤ 50%')
56
+ ]
57
+ ax.legend(handles=legend_elements, loc='lower right', fontsize=10)
58
+
59
+ fig.tight_layout()
60
+
61
+ return fig
62
+
63
+ # Interface Gradio
64
+ demo = gr.Interface(
65
+ fn=predire,
66
+ inputs=gr.Image(type="numpy", label="Importer une œuvre"),
67
+ outputs=gr.Plot(label="Résultats de la classification"),
68
+ title="🎨 Classification de style pictural",
69
+ description="Upload une image et découvre le mouvement pictural estimé par le CNN. Les barres vertes indiquent une probabilité supérieure à 50%.",
70
+ examples=None,
71
+ theme=gr.themes.Soft()
72
+ )
73
+
74
+ demo.launch()
75
+ """import gradio as gr
76
+ import tensorflow as tf
77
+ import numpy as np
78
+ import matplotlib.pyplot as plt
79
 
80
  # Charger ton modèle
81
  #model = tf.keras.models.load_model("MobileNetV2_UL_ML_c3_l0_acc88_auc94_20251007_100138.h5")
 
101
  description="Upload une image et découvre le mouvement pictural estimé par le CNN."
102
  )
103
 
104
+ demo.launch()"""