# Extracted from a "Spaces" file-viewer page (status shown: Build error).
# File size: 3,201 Bytes; revision 4bb934b.
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import io
def initialize_centroids(data, K):
    """Randomly choose K distinct data points as the initial centroids.

    Args:
        data: 2-D array of shape (n_samples, n_features).
        K: Number of centroids to draw.

    Returns:
        Array of shape (K, n_features) made of K rows of ``data`` sampled
        without replacement from NumPy's global random state.

    Raises:
        ValueError: If K is negative or larger than the number of rows.
    """
    n_samples = data.shape[0]
    # Fail early with a message that names both quantities; sampling without
    # replacement cannot succeed when K exceeds the population.
    if K < 0 or K > n_samples:
        raise ValueError(
            f"K must be between 0 and the number of samples "
            f"(got K={K}, n_samples={n_samples})"
        )
    indices = np.random.choice(n_samples, K, replace=False)
    return data[indices]
def compute_distances(data, centroids):
    """Compute the Euclidean distance between each data point and each centroid.

    Args:
        data: 2-D array of shape (n_samples, n_features).
        centroids: 2-D array of shape (n_centroids, n_features).

    Returns:
        Array of shape (n_samples, n_centroids) where entry ``[i, j]`` is the
        distance from ``data[i]`` to ``centroids[j]``.
    """
    data = np.asarray(data)
    centroids = np.asarray(centroids)
    # Unsigned integer inputs (e.g. raw uint8 pixels) wrap around on
    # subtraction, so promote to a floating dtype first. Including float32 in
    # the promotion keeps float inputs at the dtype they would have had anyway.
    work_dtype = np.result_type(data.dtype, centroids.dtype, np.float32)
    # Broadcasting (n, 1, d) - (k, d) yields every pairwise difference.
    diff = (
        data[:, np.newaxis].astype(work_dtype, copy=False)
        - centroids.astype(work_dtype, copy=False)
    )
    return np.linalg.norm(diff, axis=2)
def update_centroids(data, labels, K, previous_centroids=None):
    """Update centroids as the mean of the points assigned to each cluster.

    Args:
        data: 2-D array of shape (n_samples, n_features).
        labels: 1-D array of cluster indices, one per row of ``data``.
        K: Number of clusters.
        previous_centroids: Optional array of shape (K, n_features). When
            given, a cluster that received no points keeps its previous
            centroid; when omitted, such a cluster is left at all zeros.

    Returns:
        Float array of shape (K, n_features) with the new centroids.
    """
    new_centroids = np.zeros((K, data.shape[1]))
    for k in range(K):
        cluster_points = data[labels == k]
        if len(cluster_points) > 0:
            new_centroids[k] = np.mean(cluster_points, axis=0)
        elif previous_centroids is not None:
            # An empty cluster has no mean; carry the old centroid forward
            # rather than collapsing it to the zero vector.
            new_centroids[k] = previous_centroids[k]
    return new_centroids
def kmeans_from_scratch(image, K=4, max_iters=100, tol=1e-4):
    """Apply K-means clustering from scratch to segment the image.

    Args:
        image: Array of shape (H, W, C), or (H, W) for a single channel.
        K: Number of clusters.
        max_iters: Maximum number of assignment/update iterations.
        tol: Stop once the centroids move less than this (Frobenius norm of
            the change between two consecutive iterations).

    Returns:
        Tuple ``(segmented_image, labels, centroids)``:
        ``segmented_image`` is a uint8 array with the shape of ``image`` in
        which every pixel is replaced by its cluster colour, ``labels`` is an
        (H, W) array of cluster indices, and ``centroids`` is a uint8 array of
        shape (K, C) holding the cluster colours.
    """
    image = np.asarray(image)
    # Flatten to one row per pixel using the real channel count instead of
    # assuming three channels.
    channels = image.shape[-1] if image.ndim >= 3 else 1
    data = image.reshape((-1, channels)).astype(np.float32)

    centroids = initialize_centroids(data, K)
    for _ in range(max_iters):
        labels = np.argmin(compute_distances(data, centroids), axis=1)
        new_centroids = update_centroids(data, labels, K)
        # A cluster that attracted no pixels would otherwise be reset to
        # black; keep its previous centroid instead.
        empty = np.bincount(labels, minlength=K) == 0
        new_centroids[empty] = centroids[empty]
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break

    # Assign once more so the labels match the final centroids, including
    # when the loop hit max_iters or never ran (max_iters == 0).
    labels = np.argmin(compute_distances(data, centroids), axis=1)

    # Round to nearest rather than truncating, and clip so out-of-range
    # values cannot wrap around in the uint8 cast.
    palette = np.clip(np.rint(centroids), 0, 255).astype(np.uint8)
    segmented_image = palette[labels].reshape(image.shape)
    return segmented_image, labels.reshape(image.shape[:2]), palette
def generate_kmeans_segmented_image(image_path, k=3):
    """Process image with K-means for Gradio app.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.
        k: Number of clusters.

    Returns:
        Tuple ``(original, segmented, comparison, caption)``: the opened PIL
        image, the segmented image as a PIL image, a PIL image of a
        three-panel figure (original, segmentation, cluster colours), and a
        caption string.
    """
    # Load the pixels inside the with-block so the image stays usable after
    # the underlying file handle is closed.
    with Image.open(image_path) as image:
        image.load()
        # convert("RGB") normalises grayscale, palette and alpha images to
        # three channels in one step.
        image_rgb = np.array(image.convert("RGB"))

    seg_img, _, centers = kmeans_from_scratch(image_rgb, K=k)

    # One 50-pixel-tall swatch per cluster colour, stacked vertically.
    swatch_height = 50
    colors_image = np.zeros((swatch_height * k, 100, 3), dtype=np.uint8)
    for i, color in enumerate(centers):
        colors_image[i * swatch_height:(i + 1) * swatch_height, :] = color

    panels = (
        (image_rgb, "Original Image"),
        (seg_img, f"K-Means (K={k})"),
        (colors_image, "Cluster Colors"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (picture, title) in zip(axes, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')
        # Lay out this figure explicitly rather than whichever figure pyplot
        # currently considers active.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)

    buf.seek(0)
    comparison_image = Image.open(buf)
    # Decode now so the returned image does not rely on lazy reads.
    comparison_image.load()
    return image, Image.fromarray(seg_img), comparison_image, f"K-Means clustering with K={k}"
if __name__ == "__main__":
    import sys

    # Optional command-line overrides: [image_path] [k]. Without arguments the
    # script falls back to the original demo image and k=3.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/akshat/projects/CSL7360_Project/bird.jpeg"
    )
    num_clusters = int(sys.argv[2]) if len(sys.argv) > 2 else 3
    original, segmented, comparison, text = generate_kmeans_segmented_image(
        image_path, k=num_clusters
    )
    # Save output images instead of displaying them
    segmented.save("kmeans_segmented.png")
    comparison.save("kmeans_comparison.png")
    print(text)