Spaces:
Build error
Build error
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| from PIL import Image | |
| import io | |
def initialize_centroids(data, K, rng=None):
    """Pick K distinct rows of ``data`` to serve as the initial centroids.

    Args:
        data: 2-D array with one sample per row.
        K: Number of centroids to draw. Rows are sampled without
            replacement, so NumPy raises ``ValueError`` if K exceeds the
            number of rows.
        rng: Optional source of randomness. ``None`` (the default) uses the
            global ``np.random`` state, so ``np.random.seed`` still
            controls the result. An integer seed or a
            ``np.random.Generator`` gives a reproducible draw that does not
            touch global state.

    Returns:
        Array of shape ``(K, data.shape[1])`` holding the chosen rows.
    """
    n_samples = data.shape[0]
    if rng is None:
        # Legacy path: keep honouring the global NumPy random state.
        indices = np.random.choice(n_samples, K, replace=False)
    else:
        # default_rng passes an existing Generator through unchanged.
        indices = np.random.default_rng(rng).choice(n_samples, K, replace=False)
    return data[indices]
def compute_distances(data, centroids):
    """Compute the Euclidean distance from every sample to every centroid.

    Args:
        data: Array of shape ``(N, D)``, one sample per row.
        centroids: Array of shape ``(K, D)``, one centroid per row.

    Returns:
        Array of shape ``(N, K)`` where entry ``[i, j]`` is the distance
        between ``data[i]`` and ``centroids[j]``.
    """
    data = np.asarray(data)
    centroids = np.asarray(centroids)
    # Promote to at least float32 before subtracting: with unsigned integer
    # inputs (e.g. raw uint8 pixels) the difference would wrap around
    # instead of going negative. Float inputs keep their original dtype.
    work_dtype = np.result_type(data.dtype, centroids.dtype, np.float32)
    data = data.astype(work_dtype, copy=False)
    centroids = centroids.astype(work_dtype, copy=False)
    # Broadcasting (N, 1, D) against (K, D) yields all pairwise differences.
    return np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
def update_centroids(data, labels, K, prev_centroids=None):
    """Recompute each centroid as the mean of the samples assigned to it.

    Args:
        data: Array of shape ``(N, D)``, one sample per row.
        labels: Length-N array giving the cluster index of each sample.
        K: Number of clusters.
        prev_centroids: Optional ``(K, D)`` array of the current centroids.
            When given, a cluster that received no samples keeps its
            previous centroid. When omitted, such a cluster's centroid is
            left at the origin (the historical behaviour).

    Returns:
        Float64 array of shape ``(K, D)`` with the updated centroids.
    """
    if prev_centroids is None:
        new_centroids = np.zeros((K, data.shape[1]))
    else:
        # Work on a copy so the caller's array is never modified.
        new_centroids = np.array(prev_centroids, dtype=np.float64)
    for k in range(K):
        cluster_points = data[labels == k]
        # Skip empty clusters: the mean of zero rows would be NaN.
        if len(cluster_points) > 0:
            new_centroids[k] = np.mean(cluster_points, axis=0)
    return new_centroids
def kmeans_from_scratch(image, K=4, max_iters=100, tol=1e-4):
    """Segment an image by K-means clustering of its pixel colours.

    Args:
        image: Pixel array of shape ``(H, W, C)`` or ``(H, W)``. Results
            are cast to uint8, so values are presumably expected in the
            0-255 range.
        K: Number of clusters; must be at least 1.
        max_iters: Upper bound on the number of update iterations.
        tol: Iteration stops once the total centroid movement (Frobenius
            norm of the change) drops below this value.

    Returns:
        A tuple ``(segmented_image, labels, centroids)``:
        ``segmented_image`` is a uint8 array shaped like ``image`` with
        every pixel replaced by its cluster centroid, ``labels`` is an
        ``(H, W)`` array of cluster indices, and ``centroids`` is a
        ``(K, C)`` uint8 array of cluster colours.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """
    if K < 1:
        raise ValueError("K must be at least 1")

    image = np.asarray(image)
    # Treat a 2-D (grayscale) image as having a single channel instead of
    # assuming three channels.
    channels = 1 if image.ndim == 2 else image.shape[-1]
    data = image.reshape((-1, channels)).astype(np.float32)

    centroids = initialize_centroids(data, K)
    for _ in range(max_iters):
        labels = np.argmin(compute_distances(data, centroids), axis=1)
        new_centroids = update_centroids(data, labels, K)

        # A cluster that attracted no pixels keeps its previous centroid
        # rather than collapsing to the origin, which would otherwise
        # surface as a spurious black entry in the palette.
        empty = np.bincount(labels, minlength=K) == 0
        new_centroids[empty] = centroids[empty]

        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break

    # Assign labels against the final centroids so the returned labels and
    # colours always agree, including when max_iters is 0 or the iteration
    # cap was reached.
    labels = np.argmin(compute_distances(data, centroids), axis=1)

    segmented_data = centroids[labels].astype(np.uint8)
    segmented_image = segmented_data.reshape(image.shape)
    return segmented_image, labels.reshape(image.shape[:2]), centroids.astype(np.uint8)
def generate_kmeans_segmented_image(image_path, k=3):
    """Run K-means segmentation on an image file for the Gradio app.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        k: Number of colour clusters.

    Returns:
        A tuple ``(original, segmented, comparison, caption)``: the image
        as opened, the segmented image, a three-panel comparison figure
        rendered to a PIL image, and a short caption string.
    """
    image = Image.open(image_path)
    # Let Pillow normalise every mode to 3-channel RGB. This replaces the
    # former RGB->BGR->RGB round trip and also handles palette and
    # gray+alpha images, which were previously misread or rejected.
    image_rgb = np.array(image.convert("RGB"))

    seg_img, _labels, centers = kmeans_from_scratch(image_rgb, K=k)

    # Palette strip: one 50-pixel-high band per cluster colour.
    colors_image = np.zeros((50 * k, 100, 3), dtype=np.uint8)
    for i, color in enumerate(centers):
        colors_image[i*50:(i+1)*50, :] = color

    panels = (
        (image_rgb, "Original Image"),
        (seg_img, f"K-Means (K={k})"),
        (colors_image, "Cluster Colors"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    buf.seek(0)
    comparison_image = Image.open(buf)
    # Decode now so the returned image is fully read from the buffer.
    comparison_image.load()
    return image, Image.fromarray(seg_img), comparison_image, f"K-Means clustering with K={k}"
if __name__ == "__main__":
    import sys

    # The input image may be given as the first command-line argument; the
    # original sample path is kept as the fallback.
    default_path = "/home/akshat/projects/CSL7360_Project/bird.jpeg"
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    _original, segmented, comparison, text = generate_kmeans_segmented_image(image_path, k=3)
    # Save output images instead of displaying them
    segmented.save("kmeans_segmented.png")
    comparison.save("kmeans_comparison.png")
    print(text)