Ynvers commited on
Commit
3a87f82
·
1 Parent(s): eab3be8
Files changed (4) hide show
  1. MEN-Denim-id_00000089-17_4_full.png +0 -0
  2. ben.jpg +0 -0
  3. image.webp +0 -0
  4. mode_app.py +232 -0
MEN-Denim-id_00000089-17_4_full.png ADDED
ben.jpg ADDED
image.webp ADDED
mode_app.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import cv2
4
+ import colorsys
5
+
6
+
7
+ def generate_distinct_colors(k):
8
+ colors = []
9
+ for i in range(k):
10
+ hue = i / k
11
+ saturation = 0.8
12
+ value = 0.8
13
+ rgb = colorsys.hsv_to_rgb(hue, saturation, value)
14
+ colors.append(np.array(rgb) * 255)
15
+ return np.array(colors, dtype=np.uint8)
16
+
17
+ def k_means_segmentation(image, k, max_iters=100, tol=1e-4):
18
+ pixels = image.reshape(-1, 3).astype(np.float32)
19
+ np.random.seed(42)
20
+ centroids = pixels[np.random.choice(pixels.shape[0], k, replace=False)]
21
+
22
+ for iteration in range(max_iters):
23
+ distances = np.linalg.norm(pixels[:, np.newaxis] - centroids, axis=2)
24
+ labels = np.argmin(distances, axis=1)
25
+
26
+ new_centroids = np.array([
27
+ pixels[labels == i].mean(axis=0) if np.sum(labels == i) > 0
28
+ else centroids[i]
29
+ for i in range(k)
30
+ ])
31
+
32
+ if np.linalg.norm(new_centroids - centroids) < tol:
33
+ print(f"K-Means convergé après {iteration + 1} itérations.")
34
+ break
35
+
36
+ centroids = new_centroids
37
+ distinct_colors = generate_distinct_colors(k)
38
+ segmented_pixels = distinct_colors[labels].astype(np.uint8)
39
+ segmented_image = segmented_pixels.reshape(image.shape)
40
+
41
+ return segmented_image
42
+
43
+ def mean_shift_segmentation(image, bandwidth=30, max_iters=20, tol=1e-3, max_image_size=200):
44
+ """
45
+ Version optimisée de la segmentation Mean Shift
46
+
47
+ Args:
48
+ image: Image d'entrée BGR
49
+ bandwidth: Rayon de recherche
50
+ max_iters: Nombre maximum d'itérations
51
+ tol: Seuil de convergence
52
+ max_image_size: Taille maximale de l'image (le plus grand côté)
53
+ """
54
+ # Redimensionnement de l'image si nécessaire
55
+ h, w = image.shape[:2]
56
+ scale = max_image_size / max(h, w)
57
+ if scale < 1:
58
+ new_size = (int(w * scale), int(h * scale))
59
+ image = cv2.resize(image, new_size)
60
+
61
+ # Conversion en LAB
62
+ lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
63
+
64
+ # Préparation des données
65
+ pixels = lab_image.reshape(-1, 3).astype(np.float32)
66
+
67
+ # Sous-échantillonnage pour les centres initiaux (1 pixel sur 10)
68
+ step = 10
69
+ centers = pixels[::step].copy()
70
+
71
+ # Boucle principale de Mean Shift
72
+ for iteration in range(max_iters):
73
+ # Calcul vectorisé des distances entre tous les pixels et les centres
74
+ distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2))
75
+
76
+ # Pour chaque centre, trouver les pixels dans son rayon et calculer la nouvelle position
77
+ new_centers = []
78
+ for i in range(len(centers)):
79
+ in_bandwidth = distances[:, i] < bandwidth
80
+ if np.sum(in_bandwidth) > 0:
81
+ new_centers.append(np.mean(pixels[in_bandwidth], axis=0))
82
+ else:
83
+ new_centers.append(centers[i])
84
+
85
+ new_centers = np.array(new_centers)
86
+
87
+ # Vérifier la convergence
88
+ center_shifts = np.sqrt(np.sum((centers - new_centers) ** 2, axis=1))
89
+ centers = new_centers
90
+
91
+ if np.all(center_shifts < tol):
92
+ print(f"Mean Shift convergé après {iteration + 1} itérations.")
93
+ break
94
+
95
+ # Attribution des labels
96
+ distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2))
97
+ labels = np.argmin(distances, axis=1)
98
+
99
+ # Génération des couleurs distinctes
100
+ n_clusters = len(centers)
101
+ distinct_colors = np.zeros((n_clusters, 3), dtype=np.uint8)
102
+
103
+ for i in range(n_clusters):
104
+ hue = i / n_clusters
105
+ saturation = 0.8
106
+ value = 0.8
107
+ rgb = colorsys.hsv_to_rgb(hue, saturation, value)
108
+ distinct_colors[i] = np.array(rgb) * 255
109
+
110
+ # Création de l'image segmentée
111
+ segmented_pixels = distinct_colors[labels]
112
+ segmented_image = segmented_pixels.reshape(image.shape)
113
+
114
+ # Redimensionnement au format original si nécessaire
115
+ if scale < 1:
116
+ segmented_image = cv2.resize(segmented_image, (w, h))
117
+
118
+ return segmented_image
119
+
120
+ def segment_image(image, k=None):
121
+ """
122
+ Fonction principale qui gère la segmentation en fonction des paramètres
123
+ """
124
+ if image is None:
125
+ return None
126
+
127
+ # Conversion de l'image au format BGR si nécessaire
128
+ if len(image.shape) == 2: # Image en niveaux de gris
129
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
130
+ elif image.shape[2] == 4: # Image avec canal alpha
131
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
132
+
133
+ try:
134
+ if k is not None and k > 0:
135
+ return k_means_segmentation(image, int(k))
136
+ else:
137
+ return mean_shift_segmentation(image)
138
+ except Exception as e:
139
+ print(f"Erreur lors de la segmentation: {str(e)}")
140
+ return None
141
+
142
+
143
+ def segment_image(image, k=None):
144
+ """
145
+ Fonction principale qui gère la segmentation en fonction des paramètres
146
+ """
147
+ if image is None:
148
+ return None
149
+
150
+ # Conversion de l'image au format BGR si nécessaire
151
+ if len(image.shape) == 2: # Image en niveaux de gris
152
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
153
+ elif image.shape[2] == 4: # Image avec canal alpha
154
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
155
+
156
+ try:
157
+ if k is not None and k > 0:
158
+ return k_means_segmentation(image, int(k))
159
+ else:
160
+ return mean_shift_segmentation(image)
161
+ except Exception as e:
162
+ print(f"Erreur lors de la segmentation: {str(e)}")
163
+ return None
164
+
165
+
166
+ with gr.Blocks(title="Art par Segmentation d'Images", theme=gr.themes.Soft()) as app:
167
+ gr.Markdown("""
168
+ # 🎨 Studio de Segmentation Artistique
169
+
170
+ Transformez vos photos en œuvres d'art segmentées !
171
+
172
+ ### Mode d'emploi :
173
+ 1. Téléchargez une image
174
+ 2. Choisissez votre style :
175
+ - **K-means** : Entrez un nombre K pour définir le nombre exact de couleurs
176
+ - **Mean Shift** : Laissez K vide pour une segmentation automatique
177
+ 3. Cliquez sur "Transformer" et admirez le résultat !
178
+ """)
179
+
180
+ with gr.Row():
181
+ with gr.Column(scale=1):
182
+ input_image = gr.Image(
183
+ label="Image Originale",
184
+ type="numpy"
185
+ )
186
+ k_input = gr.Number(
187
+ label="Nombre de segments (K)",
188
+ minimum=0,
189
+ step=1
190
+ )
191
+ segment_btn = gr.Button(
192
+ "🎨 Transformer",
193
+ variant="primary",
194
+ scale=0
195
+ )
196
+
197
+ with gr.Column(scale=1):
198
+ output_image = gr.Image(
199
+ label="Résultat",
200
+ type="numpy"
201
+ )
202
+
203
+ # Exemples d'utilisation
204
+ gr.Examples(
205
+ examples=[
206
+ ["MEN-Denim-id_00000089-17_4_full.png", 5],
207
+ ["ben.jpg", None, ""],
208
+
209
+ ],
210
+ inputs=[input_image, k_input],
211
+ outputs=output_image,
212
+ fn=segment_image,
213
+ cache_examples=True
214
+ )
215
+
216
+ # Configuration des événements
217
+ segment_btn.click(
218
+ fn=segment_image,
219
+ inputs=[input_image, k_input],
220
+ outputs=output_image
221
+ )
222
+
223
+ # Message d'aide supplémentaire
224
+ gr.Markdown("""
225
+ ### 💡 Conseils :
226
+ - Pour K-means : essayez des valeurs entre 3 et 10 pour des résultats intéressants
227
+ - Pour Mean Shift : idéal pour les images complexes avec beaucoup de détails
228
+ - Les images de taille moyenne fonctionnent le mieux
229
+ """)
230
+
231
+ # Lancement de l'application
232
+ app.launch(share=True)