File size: 7,485 Bytes
3a87f82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import numpy as np
import gradio as gr
import cv2
import colorsys


def generate_distinct_colors(k):
  colors = []
  for i in range(k):
    hue = i / k
    saturation = 0.8
    value = 0.8
    rgb = colorsys.hsv_to_rgb(hue, saturation, value)
    colors.append(np.array(rgb) * 255)
  return np.array(colors, dtype=np.uint8)

def k_means_segmentation(image, k, max_iters=100, tol=1e-4):
  pixels = image.reshape(-1, 3).astype(np.float32)
  np.random.seed(42)
  centroids = pixels[np.random.choice(pixels.shape[0], k, replace=False)]

  for iteration in range(max_iters):
    distances = np.linalg.norm(pixels[:, np.newaxis] - centroids, axis=2)
    labels = np.argmin(distances, axis=1)

    new_centroids = np.array([
        pixels[labels == i].mean(axis=0) if np.sum(labels == i) > 0
        else centroids[i]
        for i in range(k)
    ])

    if np.linalg.norm(new_centroids - centroids) < tol:
      print(f"K-Means convergé après {iteration + 1} itérations.")
      break

    centroids = new_centroids
  distinct_colors = generate_distinct_colors(k)
  segmented_pixels = distinct_colors[labels].astype(np.uint8)
  segmented_image = segmented_pixels.reshape(image.shape)

  return segmented_image

def mean_shift_segmentation(image, bandwidth=30, max_iters=20, tol=1e-3, max_image_size=200):
    """
    Version optimisée de la segmentation Mean Shift
    
    Args:
        image: Image d'entrée BGR
        bandwidth: Rayon de recherche
        max_iters: Nombre maximum d'itérations
        tol: Seuil de convergence
        max_image_size: Taille maximale de l'image (le plus grand côté)
    """
    # Redimensionnement de l'image si nécessaire
    h, w = image.shape[:2]
    scale = max_image_size / max(h, w)
    if scale < 1:
        new_size = (int(w * scale), int(h * scale))
        image = cv2.resize(image, new_size)
    
    # Conversion en LAB
    lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    
    # Préparation des données
    pixels = lab_image.reshape(-1, 3).astype(np.float32)
    
    # Sous-échantillonnage pour les centres initiaux (1 pixel sur 10)
    step = 10
    centers = pixels[::step].copy()
    
    # Boucle principale de Mean Shift
    for iteration in range(max_iters):
        # Calcul vectorisé des distances entre tous les pixels et les centres
        distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2))
        
        # Pour chaque centre, trouver les pixels dans son rayon et calculer la nouvelle position
        new_centers = []
        for i in range(len(centers)):
            in_bandwidth = distances[:, i] < bandwidth
            if np.sum(in_bandwidth) > 0:
                new_centers.append(np.mean(pixels[in_bandwidth], axis=0))
            else:
                new_centers.append(centers[i])
        
        new_centers = np.array(new_centers)
        
        # Vérifier la convergence
        center_shifts = np.sqrt(np.sum((centers - new_centers) ** 2, axis=1))
        centers = new_centers
        
        if np.all(center_shifts < tol):
            print(f"Mean Shift convergé après {iteration + 1} itérations.")
            break
    
    # Attribution des labels
    distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2))
    labels = np.argmin(distances, axis=1)
    
    # Génération des couleurs distinctes
    n_clusters = len(centers)
    distinct_colors = np.zeros((n_clusters, 3), dtype=np.uint8)
    
    for i in range(n_clusters):
        hue = i / n_clusters
        saturation = 0.8
        value = 0.8
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        distinct_colors[i] = np.array(rgb) * 255
    
    # Création de l'image segmentée
    segmented_pixels = distinct_colors[labels]
    segmented_image = segmented_pixels.reshape(image.shape)
    
    # Redimensionnement au format original si nécessaire
    if scale < 1:
        segmented_image = cv2.resize(segmented_image, (w, h))
    
    return segmented_image

def segment_image(image, k=None):
    """
    Fonction principale qui gère la segmentation en fonction des paramètres
    """
    if image is None:
        return None
    
    # Conversion de l'image au format BGR si nécessaire
    if len(image.shape) == 2:  # Image en niveaux de gris
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:  # Image avec canal alpha
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    
    try:
        if k is not None and k > 0:
            return k_means_segmentation(image, int(k))
        else:
            return mean_shift_segmentation(image)
    except Exception as e:
        print(f"Erreur lors de la segmentation: {str(e)}")
        return None
    

def segment_image(image, k=None):
    """
    Fonction principale qui gère la segmentation en fonction des paramètres
    """
    if image is None:
        return None
    
    # Conversion de l'image au format BGR si nécessaire
    if len(image.shape) == 2:  # Image en niveaux de gris
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:  # Image avec canal alpha
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    
    try:
        if k is not None and k > 0:
            return k_means_segmentation(image, int(k))
        else:
            return mean_shift_segmentation(image)
    except Exception as e:
        print(f"Erreur lors de la segmentation: {str(e)}")
        return None
    

with gr.Blocks(title="Art par Segmentation d'Images", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🎨 Studio de Segmentation Artistique
    
    Transformez vos photos en œuvres d'art segmentées !
    
    ### Mode d'emploi :
    1. Téléchargez une image
    2. Choisissez votre style :
       - **K-means** : Entrez un nombre K pour définir le nombre exact de couleurs
       - **Mean Shift** : Laissez K vide pour une segmentation automatique
    3. Cliquez sur "Transformer" et admirez le résultat !
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Image Originale",
                type="numpy"
            )
            k_input = gr.Number(
                label="Nombre de segments (K)",
                minimum=0,
                step=1
            )
            segment_btn = gr.Button(
                "🎨 Transformer", 
                variant="primary",
                scale=0
            )
            
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Résultat",
                type="numpy"
            )
            
    # Exemples d'utilisation
    gr.Examples(
        examples=[
            ["MEN-Denim-id_00000089-17_4_full.png", 5],
            ["ben.jpg", None, ""],

        ],
        inputs=[input_image, k_input],
        outputs=output_image,
        fn=segment_image,
        cache_examples=True
    )
    
    # Configuration des événements
    segment_btn.click(
        fn=segment_image,
        inputs=[input_image, k_input],
        outputs=output_image
    )
    
    # Message d'aide supplémentaire
    gr.Markdown("""
    ### 💡 Conseils :
    - Pour K-means : essayez des valeurs entre 3 et 10 pour des résultats intéressants
    - Pour Mean Shift : idéal pour les images complexes avec beaucoup de détails
    - Les images de taille moyenne fonctionnent le mieux
    """)

# Lancement de l'application
app.launch(share=True)