# Image_Segmentation_/experiments/kmeans_segmenter.py
# Uploaded by AJain1234 using huggingface_hub (commit 0f9608b, verified).
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import io
def initialize_centroids(data, K):
    """Pick K distinct rows of ``data`` at random to serve as starting centroids.

    Rows are drawn without replacement through NumPy's global random state,
    so seeding ``np.random`` makes the selection reproducible.
    """
    n_points = data.shape[0]
    chosen_rows = np.random.choice(n_points, size=K, replace=False)
    return data[chosen_rows]
def compute_distances(data, centroids):
    """Return the Euclidean distance from every data point to every centroid.

    For 2-D inputs with one point / one centroid per row, entry ``[i, j]`` of
    the result is the distance between point ``i`` and centroid ``j``.
    """
    # The extra axis after the point axis lets the subtraction broadcast each
    # point against all centroids at once.
    offsets = data[:, np.newaxis] - centroids
    return np.linalg.norm(offsets, axis=2)
def update_centroids(data, labels, K):
    """Recompute each centroid as the mean of the points labelled with it.

    Returns a ``(K, data.shape[1])`` float64 array. A cluster that currently
    owns no points keeps an all-zero row.
    """
    n_features = data.shape[1]
    centers = np.zeros((K, n_features))
    for cluster_id in range(K):
        members = data[labels == cluster_id]
        if len(members) == 0:
            # Nothing assigned to this cluster: leave its row at zero.
            continue
        centers[cluster_id] = members.mean(axis=0)
    return centers
def kmeans_from_scratch(image, K=4, max_iters=100, tol=1e-4):
    """Segment an image by K-means clustering of its pixel values.

    Args:
        image: Array of shape ``(H, W, C)`` or ``(H, W)``. Every pixel is one
            C-dimensional sample (C = 1 for a 2-D image).
        K: Number of clusters; must be between 1 and the number of pixels.
        max_iters: Upper bound on the number of update iterations.
        tol: Stop once the norm of the centroid movement between two
            consecutive iterations drops below this value.

    Returns:
        A tuple ``(segmented_image, labels, centroids)``:
        the image with every pixel replaced by its cluster centroid (cast to
        uint8, same shape as ``image``), the per-pixel cluster index with
        shape ``image.shape[:2]``, and the centroids cast to uint8.

    Raises:
        ValueError: If ``image`` is not 2-D or 3-D, or ``K`` is out of range.
    """
    image = np.asarray(image)
    if image.ndim not in (2, 3):
        raise ValueError(
            f"image must have shape (H, W) or (H, W, C), got {image.shape}"
        )

    # Derive the channel count from the input instead of assuming RGB.
    n_channels = image.shape[2] if image.ndim == 3 else 1
    data = image.reshape((-1, n_channels)).astype(np.float32)
    if not 1 <= K <= data.shape[0]:
        raise ValueError(
            f"K must be between 1 and the number of pixels ({data.shape[0]}), got {K}"
        )

    centroids = initialize_centroids(data, K)
    for _ in range(max_iters):
        labels = np.argmin(compute_distances(data, centroids), axis=1)
        new_centroids = update_centroids(data, labels, K)

        # update_centroids leaves the row of an empty cluster at zero, which
        # would inject a colour that is not in the image. Keep the previous
        # centroid for such clusters instead.
        empty = np.bincount(labels, minlength=K) == 0
        new_centroids[empty] = centroids[empty]

        shift = np.linalg.norm(new_centroids - centroids)
        if shift < tol:
            # Converged: `labels` already corresponds to `centroids`.
            break
        centroids = new_centroids
    else:
        # Iteration budget exhausted (or max_iters == 0): `labels` is stale
        # or was never computed, so assign against the final centroids.
        labels = np.argmin(compute_distances(data, centroids), axis=1)

    segmented_image = centroids[labels].astype(np.uint8).reshape(image.shape)
    return (
        segmented_image,
        labels.reshape(image.shape[:2]),
        centroids.astype(np.uint8),
    )
def generate_kmeans_segmented_image(image_path, k=3):
    """Segment an image with K-means and render a comparison figure for the Gradio app.

    Args:
        image_path: Path (or file object) that ``PIL.Image.open`` can read.
        k: Number of clusters. Coerced to ``int`` so values such as ``3.0``
            coming from a UI widget are accepted.

    Returns:
        A tuple ``(original, segmented, comparison, caption)``: the image as
        opened by PIL, the segmented image as a PIL image, a PIL image of a
        three-panel figure (original / segmentation / cluster colours), and a
        short caption string.
    """
    k = int(k)
    image = Image.open(image_path)

    # Let PIL normalise every mode (grayscale, palette, RGBA, ...) to an
    # H x W x 3 uint8 array instead of guessing the layout from the array rank.
    image_rgb = np.array(image.convert("RGB"))

    seg_img, _, centers = kmeans_from_scratch(image_rgb, K=k)

    # Palette strip: one 50-pixel-high band per cluster colour.
    colors_image = np.zeros((50 * k, 100, 3), dtype=np.uint8)
    for i, color in enumerate(centers):
        colors_image[i * 50:(i + 1) * 50, :] = color

    panels = (
        (image_rgb, "Original Image"),
        (seg_img, f"K-Means (K={k})"),
        (colors_image, "Cluster Colors"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

    buf.seek(0)
    comparison_image = Image.open(buf)
    # Decode now so the returned image no longer relies on lazy reads.
    comparison_image.load()

    return image, Image.fromarray(seg_img), comparison_image, f"K-Means clustering with K={k}"
if __name__ == "__main__":
    import sys

    # Usage: python kmeans_segmenter.py [IMAGE_PATH] [K]
    # Both arguments are optional; the defaults reproduce the original
    # hard-coded development run.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/akshat/projects/CSL7360_Project/bird.jpeg"
    )
    k = int(sys.argv[2]) if len(sys.argv) > 2 else 3

    _, segmented, comparison, text = generate_kmeans_segmented_image(image_path, k=k)
    # Save output images instead of displaying them
    segmented.save("kmeans_segmented.png")
    comparison.save("kmeans_comparison.png")
    print(text)