Image_Segmentation_ / experiments /enhanced_kmeans_segmenter.py
AJain1234's picture
Upload folder using huggingface_hub
0f9608b verified
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
def slic_kmeans(image_path, K=100, m=10, max_iter=10):
"""
Perform superpixel segmentation using enhanced K-means with LAB+XY.
Args:
image (np.ndarray): RGB input image.
K (int): Number of superpixels.
m (float): Compactness factor.
max_iter (int): Number of iterations.
Returns:
segmented_img: The segmented image with cluster colors.
labels: Cluster label for each pixel.
"""
jpg_image = Image.open(image_path)
image = np.array(jpg_image)
h, w = image.shape[:2]
S = int(np.sqrt(h * w / K)) # grid interval
# Convert to LAB color space
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB).astype(np.float32)
# Create 5D feature vector [L, a, b, x, y]
X, Y = np.meshgrid(np.arange(w), np.arange(h))
features = np.dstack((lab, X, Y)).reshape((-1, 5))
# Initialize cluster centers on grid
centers = []
for y in range(S // 2, h, S):
for x in range(S // 2, w, S):
center = features[y * w + x]
centers.append(center)
centers = np.array(centers)
labels = np.full((h * w,), -1, dtype=np.int32)
distances = np.full((h * w,), np.inf)
for iteration in tqdm(range(max_iter)):
for idx, center in enumerate(centers):
l, a, b, cx, cy = center
x_start, x_end = max(0, int(cx - S)), min(w, int(cx + S))
y_start, y_end = max(0, int(cy - S)), min(h, int(cy + S))
for y in range(y_start, y_end):
for x in range(x_start, x_end):
i = y * w + x
fp = features[i]
dc = np.linalg.norm(fp[:3] - center[:3]) # LAB distance
ds = np.linalg.norm(fp[3:] - center[3:]) # XY distance
D = np.sqrt(dc**2 + (ds / S)**2 * m**2)
if D < distances[i]:
distances[i] = D
labels[i] = idx
# Update cluster centers
new_centers = np.zeros_like(centers)
count = np.zeros(len(centers))
for i in range(h * w):
lbl = labels[i]
new_centers[lbl] += features[i]
count[lbl] += 1
for i in range(len(centers)):
if count[i] > 0:
new_centers[i] /= count[i]
centers = new_centers
# Recolor image based on cluster centers
segmented_img = np.zeros((h, w, 3), dtype=np.uint8)
for i in range(h * w):
lbl = labels[i]
lab_val = centers[lbl][:3]
lab_pixel = np.uint8([[lab_val]])
rgb_pixel = cv2.cvtColor(lab_pixel, cv2.COLOR_LAB2RGB)[0][0]
segmented_img[i // w, i % w] = rgb_pixel
return jpg_image, Image.fromarray(segmented_img), labels.reshape((h, w)), centers
# img_path = "/home/akshat/projects/CSL7360_Project/bird.jpeg"
# image = cv2.imread(img_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# _,seg_img, labels, centers = slic_kmeans(image, K=2, m=20)
# seg_img.save("enhaned_kmeans_segmented.png")
# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(image)
# plt.title("Original Image")
# plt.axis("off")
# plt.subplot(1, 2, 2)
# plt.imshow(seg_img)
# plt.title("SLIC-like K-Means Segmentation")
# plt.axis("off")
# plt.tight_layout()
# plt.show()