Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib.figure import Figure | |
# Load the trained Keras classifier from a path relative to the working directory.
model = tf.keras.models.load_model(
    "MobileNetV2_UL_ML_c3_l0_acc97_auc100_20251012_161415.keras"
)

# Class labels; predire() looks them up by index into the model's output vector.
classes = [
    "Cubisme",
    "Expressionnisme",
    "Post-impressionnisme",
]
| # Prediction function with a custom matplotlib chart (draft 1, disabled: kept only as a string literal) | |
| """def predire(image): | |
| # Prédiction | |
| image_resized = tf.image.resize(image, (224, 224)) / 255.0 | |
| preds = model.predict(tf.expand_dims(image_resized, axis=0))[0] | |
| # Créer le graphique avec matplotlib | |
| fig = Figure(figsize=(10, 6)) | |
| ax = fig.add_subplot(111) | |
| # Trier par probabilité décroissante | |
| sorted_indices = np.argsort(preds)[::-1] | |
| sorted_classes = [classes[i] for i in sorted_indices] | |
| sorted_probs = [preds[i] for i in sorted_indices] | |
| # Définir les couleurs : vert si > 50%, bleu sinon | |
| colors = ['#2ecc71' if prob >= 0.5 else '#bdc3c7' for prob in sorted_probs] | |
| # Créer le bar plot horizontal | |
| bars = ax.barh(sorted_classes, sorted_probs, color=colors, edgecolor='black', linewidth=1.5) | |
| # Ajouter les pourcentages sur les barres | |
| for i, (bar, prob) in enumerate(zip(bars, sorted_probs)): | |
| width = bar.get_width() | |
| label_x = width + 0.02 if width < 0.9 else width - 0.02 | |
| ha = 'left' if width < 0.9 else 'right' | |
| text_color = 'black' if width < 0.9 else 'white' | |
| ax.text(label_x, bar.get_y() + bar.get_height()/2, | |
| f'{prob*100:.1f}%', | |
| ha=ha, va='center', fontsize=12, fontweight='bold', color=text_color) | |
| # Configuration du graphique | |
| ax.set_xlabel('Probabilité', fontsize=12, fontweight='bold') | |
| ax.set_xlim(0, 1.0) | |
| ax.set_title('Probabilités par mouvement pictural', fontsize=14, fontweight='bold', pad=20) | |
| ax.grid(axis='x', alpha=0.3, linestyle='--') | |
| ax.set_axisbelow(True) | |
| # Légende | |
| from matplotlib.patches import Patch | |
| legend_elements = [ | |
| Patch(facecolor='#2ecc71', edgecolor='black', label='≥ 50%'), | |
| Patch(facecolor='#bdc3c7', edgecolor='black', label='< 50%') | |
| ] | |
| ax.legend(handles=legend_elements, loc='upper right', fontsize=10) | |
| fig.tight_layout() | |
| return fig""" | |
| # Solution 2 (disabled): vertical matplotlib bars in a compact figure | |
| """def predire(image): | |
| image_resized = tf.image.resize(image, (224, 224)) / 255.0 | |
| preds = model.predict(tf.expand_dims(image_resized, axis=0))[0] | |
| sorted_indices = np.argsort(preds)[::-1] | |
| sorted_classes = [classes[i] for i in sorted_indices] | |
| sorted_probs = [preds[i] for i in sorted_indices] | |
| colors = ['#2ecc71' if prob >= 0.5 else '#bdc3c7' for prob in sorted_probs] | |
| fig = Figure(figsize=(4, 3)) # Format compact adapté mobile | |
| ax = fig.add_subplot(111) | |
| # Barres verticales | |
| bars = ax.bar(sorted_classes, sorted_probs, color=colors, edgecolor='black', linewidth=1.5) | |
| # Ajout pourcentages | |
| for bar, prob in zip(bars, sorted_probs): | |
| ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03, | |
| f"{prob*100:.1f}%", ha='center', va='bottom', fontsize=13, fontweight='bold', color='black') | |
| ax.set_ylabel('Probabilité', fontsize=15, fontweight='bold') | |
| ax.set_ylim(0, 1.0) | |
| ax.set_title('Probabilités par mouvement pictural', fontsize=16, fontweight='bold', pad=20) | |
| ax.grid(axis='y', alpha=0.15, linestyle='--') | |
| ax.set_axisbelow(True) | |
| # Titres inclinés à 45° | |
| ax.set_xticklabels(sorted_classes, rotation=45, ha='right', fontsize=15, fontweight='bold') | |
| fig.tight_layout() | |
| return fig""" | |
| # Solution 3 (disabled): build the chart with Plotly | |
| """import plotly.graph_objects as go | |
| def predire(image): | |
| image_resized = tf.image.resize(image, (224, 224)) / 255.0 | |
| preds = model.predict(tf.expand_dims(image_resized, axis=0))[0] | |
| sorted_indices = np.argsort(preds)[::-1] | |
| sorted_classes = [classes[i] for i in sorted_indices] | |
| sorted_probs = [preds[i] for i in sorted_indices] | |
| colors = ['#2ecc71' if prob >= 0.5 else '#bdc3c7' for prob in sorted_probs] | |
| fig = go.Figure(go.Bar( | |
| x=sorted_classes, | |
| y=sorted_probs, | |
| text=[f"{p*100:.1f}%" for p in sorted_probs], | |
| marker=dict(color=colors, line=dict(color='black', width=1)), | |
| textposition='auto', | |
| )) | |
| fig.update_layout( | |
| xaxis=dict(tickangle=45, tickfont=dict(size=17)), | |
| yaxis=dict(range=[0,1], title='Probabilité', tickfont=dict(size=17)), | |
| title="Probabilités par mouvement pictural", | |
| margin=dict(l=15, r=15, t=40, b=25), | |
| height=280, | |
| font=dict(size=17) | |
| ) | |
| return fig""" | |
| # Solution with a taller chart (ADJUST THE HEIGHT) (disabled) | |
| """def predire(image): | |
| image_resized = tf.image.resize(image, (224, 224)) / 255.0 | |
| preds = model.predict(tf.expand_dims(image_resized, axis=0))[0] | |
| sorted_indices = np.argsort(preds)[::-1] | |
| sorted_classes = [classes[i] for i in sorted_indices] | |
| sorted_probs = [preds[i] for i in sorted_indices] | |
| colors = ['#2ecc71' if prob >= 0.5 else '#bdc3c7' for prob in sorted_probs] | |
| fig = Figure(figsize=(12, 4.2)) # hauteur augmentée | |
| ax = fig.add_subplot(111) | |
| bars = ax.bar(sorted_classes, sorted_probs, color=colors, edgecolor='black', linewidth=1.5) | |
| for bar, prob in zip(bars, sorted_probs): | |
| ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.035, | |
| f"{prob*100:.1f}%", ha='center', va='bottom', fontsize=12, color='black', fontweight='bold') | |
| ax.set_ylabel('Probabilité', fontsize=14, fontweight='bold') | |
| ax.set_ylim(0, 1) | |
| ax.set_title("Probabilités par mouvement pictural", fontsize=14, fontweight='bold', pad=20) | |
| ax.set_xticklabels(sorted_classes, rotation=45, ha='right', fontsize=13, fontweight='bold') | |
| fig.tight_layout(pad=2.0) | |
| return fig""" | |
| # Approach 5 (disabled): text/HTML output only, but it raises an error | |
| """def predire(image): | |
| image_resized = tf.image.resize(image, (224, 224)) / 255.0 | |
| preds = model.predict(tf.expand_dims(image_resized, axis=0))[0] | |
| sorted_indices = np.argsort(preds)[::-1] | |
| sorted_classes = [classes[i] for i in sorted_indices] | |
| sorted_probs = [preds[i] for i in sorted_indices] | |
| # HTML pour Gradio : barres remplies & label | |
| bars = [] | |
| for cls, prob in zip(sorted_classes, sorted_probs): | |
| color = "#2ecc71" if prob >= 0.5 else "#bdc3c7" | |
| bars.append(f''' | |
| <div style="background:linear-gradient(90deg,{color} {prob*100:.1f}%,#fff {prob*100:.1f}%);padding:6px 0;margin:5px 0;border-radius:4px;"> | |
| <span style="padding-left:12px;font-size:15px;font-weight:bold;">{cls} — {prob*100:.1f} %</span> | |
| </div>''') | |
| return gr.HTML("".join(bars)) | |
| demo = gr.Interface( | |
| fn=predire, | |
| inputs=gr.Image(type="numpy", label="Importer une œuvre"), | |
| outputs=[gr.Image(label="Input"), gr.HTML(label="Résultats")], | |
| title="🎨 Classification de style pictural", | |
| examples=None, | |
| theme=gr.themes.Soft() | |
| )""" | |
| """import plotly.graph_objects as go | |
| def predire(image): | |
| image_resized = tf.image.resize(image, (224, 224)) / 255.0 | |
| preds = model.predict(tf.expand_dims(image_resized, axis=0))[0] | |
| sorted_indices = np.argsort(preds)[::-1] | |
| sorted_classes = [classes[i] for i in sorted_indices] | |
| sorted_probs = [preds[i] for i in sorted_indices] | |
| colors = ['#2ecc71' if p >= 0.5 else '#bdc3c7' for p in sorted_probs] | |
| fig = go.Figure(go.Bar( | |
| x=sorted_classes, | |
| y=sorted_probs, | |
| marker=dict(color=colors, line=dict(color='black', width=1)), | |
| text=[f"{p*100:.1f}%" for p in sorted_probs], | |
| textposition='auto' | |
| )) | |
| fig.update_layout( | |
| xaxis=dict(fixedrange=True, tickangle=45, tickfont=dict(size=15)), | |
| yaxis=dict(fixedrange=True, range=[0,1], title="Probabilité", tickfont=dict(size=14)), | |
| title=dict( | |
| text="Probabilités par mouvement pictural", | |
| y=0.95, # remonte le titre pour laisser de l’espace | |
| pad=dict(b=50) # espace (en px) sous le titre, ajuste à volonté | |
| ), | |
| margin=dict(l=20, r=20, t=32, b=46), | |
| height=500, | |
| font=dict(size=16) | |
| ) | |
| # Pour placer le texte à l’intérieur des barres | |
| fig.show(config={'displayModeBar': False}) | |
| return fig | |
| """ | |
| import plotly.graph_objects as go | |
def predire(image):
    """Classify an uploaded artwork and return a bar chart of class probabilities.

    Args:
        image: Image array handed over by the ``gr.Image(type="numpy")`` input,
            or ``None`` when the interface is submitted without an image.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar per entry of ``classes``,
        ordered from most to least probable, or ``None`` when no image was
        provided (nothing to plot).
    """
    # Without this guard a submit with no upload fails inside tf.image.resize.
    if image is None:
        return None

    # Resize to 224x224 and divide by 255 before inference.
    # NOTE(review): assumes the model expects inputs scaled to [0, 1] — confirm.
    image_resized = tf.image.resize(image, (224, 224)) / 255.0
    # Add a batch axis for predict(), then keep the single result row.
    preds = model.predict(tf.expand_dims(image_resized, axis=0))[0]

    # Order the classes from highest to lowest predicted probability.
    sorted_indices = np.argsort(preds)[::-1]
    sorted_classes = [classes[i] for i in sorted_indices]
    sorted_probs = [preds[i] for i in sorted_indices]

    # Green for probabilities >= 50 %, grey otherwise.
    colors = ['#2ecc71' if p >= 0.5 else '#bdc3c7' for p in sorted_probs]

    fig = go.Figure(go.Bar(
        x=sorted_classes,
        y=sorted_probs,
        marker=dict(color=colors, line=dict(color='black', width=1)),
        text=[f"{p*100:.1f}%" for p in sorted_probs],
        textposition='auto',
    ))
    fig.update_layout(
        # fixedrange disables zoom/pan on both axes.
        xaxis=dict(fixedrange=True, tickangle=45, tickfont=dict(size=15), automargin=True),
        yaxis=dict(fixedrange=True, range=[0, 1], title="Probabilité", tickfont=dict(size=14)),
        title=dict(
            text="Probabilités par <br>mouvement pictural",
            y=0.90,
            pad=dict(b=30),
        ),
        # t=0 removes the top margin; b=60 reserves space below the plot area.
        margin=dict(l=20, r=20, t=0, b=60),
        height=600,
        font=dict(size=13),
    )
    # Font used for the percentage labels drawn on the bars.
    fig.data[0].textfont = dict(color='black', size=14, family="Arial")

    # The figure is rendered by the Gradio Plot output. Do not call fig.show()
    # here: it would run Plotly's own renderer on the server for every request,
    # and its config argument never reaches the chart Gradio displays.
    return fig
# Gradio interface: one image upload in, one Plotly chart out.
image_input = gr.Image(type="numpy", label="Importer une œuvre")
chart_output = gr.Plot(label="Résultats de la classification")

demo = gr.Interface(
    fn=predire,
    inputs=image_input,
    outputs=chart_output,
    title="🎨 Classification de style pictural",
    description="Upload une image et découvre le mouvement pictural estimé par le CNN. Les barres vertes indiquent une probabilité supérieure ou égale à 50%.",
    examples=None,
    theme=gr.themes.Soft(),
)

# Start the web server for the interface.
demo.launch()