File size: 3,201 Bytes
4bb934b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import io

def initialize_centroids(data, K):
    """Seed K-means by sampling K distinct rows of ``data`` uniformly at random.

    Sampling is without replacement, so K must not exceed ``data.shape[0]``.
    NumPy's global random state is used, so ``np.random.seed`` makes the
    choice reproducible.
    """
    n_points = data.shape[0]
    chosen_rows = np.random.choice(n_points, K, replace=False)
    return data[chosen_rows]

def compute_distances(data, centroids):
    """Return an (N, K) array of Euclidean distances.

    Entry ``[i, j]`` is the distance between row ``i`` of ``data`` (N x D) and
    row ``j`` of ``centroids`` (K x D). Broadcasting builds an (N, K, D)
    difference tensor before reducing over the last axis.
    """
    pairwise_diff = data[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    return np.linalg.norm(pairwise_diff, axis=-1)

def update_centroids(data, labels, K):
    """Recompute each centroid as the mean of the rows assigned to it.

    Args:
        data: (N, D) array of points.
        labels: length-N array of cluster indices in ``range(K)``.
        K: number of clusters.

    Returns:
        A (K, D) float64 array. A cluster with no assigned points keeps the
        all-zeros row it was initialised with.
    """
    centroids = np.zeros((K, data.shape[1]))
    for cluster_id in range(K):
        members = data[labels == cluster_id]
        if members.shape[0] == 0:
            continue  # empty cluster: leave its centroid at the origin
        centroids[cluster_id] = members.mean(axis=0)
    return centroids

def kmeans_from_scratch(image, K=4, max_iters=100, tol=1e-4):
    """Segment an image by K-means clustering of its pixel colours.

    Args:
        image: array of shape (H, W, C) or (H, W). Each pixel is treated as a
            C-dimensional point (C = 1 for a 2-D image). Values are expected
            in the 0-255 range because the result is converted to uint8.
        K: number of clusters; must not exceed the number of pixels.
        max_iters: maximum number of assign/update iterations.
        tol: stop early once the norm of the total centroid shift is below this.

    Returns:
        ``(segmented_image, labels, centroids)``:
        - ``segmented_image``: uint8 image of the same shape as ``image``, with
          every pixel replaced by its cluster's centroid colour.
        - ``labels``: per-pixel cluster index, shape ``image.shape[:2]``.
        - ``centroids``: the K centroid colours as a uint8 (K, C) array.
    """
    # Derive the channel count from the input instead of assuming 3, so
    # grayscale (2-D) and 4-channel images are not silently mangled by reshape.
    channels = image.shape[2] if image.ndim == 3 else 1
    data = image.reshape((-1, channels)).astype(np.float32)

    centroids = initialize_centroids(data, K)

    for _ in range(max_iters):
        labels = np.argmin(compute_distances(data, centroids), axis=1)
        new_centroids = update_centroids(data, labels, K)

        # update_centroids puts an empty cluster at the origin (black). Keep
        # its previous centroid instead, so it stays inside the data's colour range.
        empty = np.bincount(labels, minlength=K) == 0
        new_centroids[empty] = centroids[empty]

        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break

    # Assign with the final centroids so labels, pixels and palette agree.
    # This also guarantees `labels` exists when max_iters == 0.
    labels = np.argmin(compute_distances(data, centroids), axis=1)

    # Round (not truncate) when converting the float centroids back to 8-bit.
    palette = np.clip(np.rint(centroids), 0, 255).astype(np.uint8)
    segmented_image = palette[labels].reshape(image.shape)

    return segmented_image, labels.reshape(image.shape[:2]), palette

def generate_kmeans_segmented_image(image_path, k=3):
    """Segment an image with K-means and build the outputs shown by the Gradio app.

    Args:
        image_path: anything accepted by ``PIL.Image.open`` (path or file object).
        k: number of colour clusters passed to ``kmeans_from_scratch``.

    Returns:
        A 4-tuple ``(original, segmented, comparison, caption)``:
        - the PIL image as opened;
        - the segmented image as a PIL image;
        - a PNG-rendered figure showing the original, the segmentation and the
          cluster palette side by side;
        - a short text description.
    """
    image = Image.open(image_path)

    # Normalise every PIL mode (L, LA, P, RGBA, 1, CMYK, ...) to 3-channel RGB.
    # Palette ("P") images must be expanded here. Taking the raw array would
    # cluster palette indices as if they were gray levels.
    image_rgb = np.array(image.convert("RGB"))

    seg_img, _, centers = kmeans_from_scratch(image_rgb, K=k)

    # Palette swatch: one 50-px-tall stripe per cluster centre.
    colors_image = np.zeros((50 * k, 100, 3), dtype=np.uint8)
    for i, color in enumerate(centers):
        colors_image[i * 50:(i + 1) * 50, :] = color

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        panels = (
            (image_rgb, "Original Image"),
            (seg_img, f"K-Means (K={k})"),
            (colors_image, "Cluster Colors"),
        )
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Release the figure even if rendering fails. Otherwise repeated
        # calls from the app would accumulate open pyplot figures.
        plt.close(fig)

    buf.seek(0)
    comparison_image = Image.open(buf)
    comparison_image.load()  # decode now so format errors surface here

    return image, Image.fromarray(seg_img), comparison_image, f"K-Means clustering with K={k}"

if __name__ == "__main__":
    import argparse

    # Let the input image and K come from the command line. The defaults
    # reproduce the previous hard-coded behaviour.
    parser = argparse.ArgumentParser(description="Segment an image with K-means clustering.")
    parser.add_argument(
        "image_path",
        nargs="?",
        default="/home/akshat/projects/CSL7360_Project/bird.jpeg",
        help="input image (defaults to the original sample path)",
    )
    parser.add_argument("-k", type=int, default=3, help="number of clusters (default: 3)")
    args = parser.parse_args()

    original, segmented, comparison, text = generate_kmeans_segmented_image(args.image_path, k=args.k)

    # Save output images instead of displaying them
    segmented.save("kmeans_segmented.png")
    comparison.save("kmeans_comparison.png")
    print(text)