import numpy as np import gradio as gr import cv2 import colorsys def generate_distinct_colors(k): colors = [] for i in range(k): hue = i / k saturation = 0.8 value = 0.8 rgb = colorsys.hsv_to_rgb(hue, saturation, value) colors.append(np.array(rgb) * 255) return np.array(colors, dtype=np.uint8) def k_means_segmentation(image, k, max_iters=100, tol=1e-4): pixels = image.reshape(-1, 3).astype(np.float32) np.random.seed(42) centroids = pixels[np.random.choice(pixels.shape[0], k, replace=False)] for iteration in range(max_iters): distances = np.linalg.norm(pixels[:, np.newaxis] - centroids, axis=2) labels = np.argmin(distances, axis=1) new_centroids = np.array([ pixels[labels == i].mean(axis=0) if np.sum(labels == i) > 0 else centroids[i] for i in range(k) ]) if np.linalg.norm(new_centroids - centroids) < tol: print(f"K-Means convergé après {iteration + 1} itérations.") break centroids = new_centroids distinct_colors = generate_distinct_colors(k) segmented_pixels = distinct_colors[labels].astype(np.uint8) segmented_image = segmented_pixels.reshape(image.shape) return segmented_image def mean_shift_segmentation(image, bandwidth=30, max_iters=20, tol=1e-3, max_image_size=200): """ Version optimisée de la segmentation Mean Shift Args: image: Image d'entrée BGR bandwidth: Rayon de recherche max_iters: Nombre maximum d'itérations tol: Seuil de convergence max_image_size: Taille maximale de l'image (le plus grand côté) """ # Redimensionnement de l'image si nécessaire h, w = image.shape[:2] scale = max_image_size / max(h, w) if scale < 1: new_size = (int(w * scale), int(h * scale)) image = cv2.resize(image, new_size) # Conversion en LAB lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) # Préparation des données pixels = lab_image.reshape(-1, 3).astype(np.float32) # Sous-échantillonnage pour les centres initiaux (1 pixel sur 10) step = 10 centers = pixels[::step].copy() # Boucle principale de Mean Shift for iteration in range(max_iters): # Calcul vectorisé des distances entre tous les pixels et les centres distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2)) # Pour chaque centre, trouver les pixels dans son rayon et calculer la nouvelle position new_centers = [] for i in range(len(centers)): in_bandwidth = distances[:, i] < bandwidth if np.sum(in_bandwidth) > 0: new_centers.append(np.mean(pixels[in_bandwidth], axis=0)) else: new_centers.append(centers[i]) new_centers = np.array(new_centers) # Vérifier la convergence center_shifts = np.sqrt(np.sum((centers - new_centers) ** 2, axis=1)) centers = new_centers if np.all(center_shifts < tol): print(f"Mean Shift convergé après {iteration + 1} itérations.") break # Attribution des labels distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2)) labels = np.argmin(distances, axis=1) # Génération des couleurs distinctes n_clusters = len(centers) distinct_colors = np.zeros((n_clusters, 3), dtype=np.uint8) for i in range(n_clusters): hue = i / n_clusters saturation = 0.8 value = 0.8 rgb = colorsys.hsv_to_rgb(hue, saturation, value) distinct_colors[i] = np.array(rgb) * 255 # Création de l'image segmentée segmented_pixels = distinct_colors[labels] segmented_image = segmented_pixels.reshape(image.shape) # Redimensionnement au format original si nécessaire if scale < 1: segmented_image = cv2.resize(segmented_image, (w, h)) return segmented_image def segment_image(image, k=None): """ Fonction principale qui gère la segmentation en fonction des paramètres """ if image is None: return None # Conversion de l'image au format BGR si nécessaire if len(image.shape) == 2: # Image en niveaux de gris image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) elif image.shape[2] == 4: # Image avec canal alpha image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR) try: if k is not None and k > 0: return k_means_segmentation(image, int(k)) else: return mean_shift_segmentation(image) except Exception as e: print(f"Erreur lors de la segmentation: {str(e)}") return None def segment_image(image, k=None): """ Fonction principale qui gère la segmentation en fonction des paramètres """ if image is None: return None # Conversion de l'image au format BGR si nécessaire if len(image.shape) == 2: # Image en niveaux de gris image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) elif image.shape[2] == 4: # Image avec canal alpha image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR) try: if k is not None and k > 0: return k_means_segmentation(image, int(k)) else: return mean_shift_segmentation(image) except Exception as e: print(f"Erreur lors de la segmentation: {str(e)}") return None with gr.Blocks(title="Art par Segmentation d'Images", theme=gr.themes.Soft()) as app: gr.Markdown(""" # 🎨 Studio de Segmentation Artistique Transformez vos photos en œuvres d'art segmentées ! ### Mode d'emploi : 1. Téléchargez une image 2. Choisissez votre style : - **K-means** : Entrez un nombre K pour définir le nombre exact de couleurs - **Mean Shift** : Laissez K vide pour une segmentation automatique 3. Cliquez sur "Transformer" et admirez le résultat ! """) with gr.Row(): with gr.Column(scale=1): input_image = gr.Image( label="Image Originale", type="numpy" ) k_input = gr.Number( label="Nombre de segments (K)", minimum=0, step=1 ) segment_btn = gr.Button( "🎨 Transformer", variant="primary", scale=0 ) with gr.Column(scale=1): output_image = gr.Image( label="Résultat", type="numpy" ) # Exemples d'utilisation gr.Examples( examples=[ ["MEN-Denim-id_00000089-17_4_full.png", 5], ["ben.jpg", None, ""], ], inputs=[input_image, k_input], outputs=output_image, fn=segment_image, cache_examples=True ) # Configuration des événements segment_btn.click( fn=segment_image, inputs=[input_image, k_input], outputs=output_image ) # Message d'aide supplémentaire gr.Markdown(""" ### 💡 Conseils : - Pour K-means : essayez des valeurs entre 3 et 10 pour des résultats intéressants - Pour Mean Shift : idéal pour les images complexes avec beaucoup de détails - Les images de taille moyenne fonctionnent le mieux """) # Lancement de l'application app.launch(share=True)