"""Clustering of potential damage sites in SEM images.

Thresholds an input image, clusters the below-threshold pixels with DBSCAN,
and returns the centroids of sufficiently large clusters.
"""

import logging

import numpy as np
import scipy.ndimage as ndi
from scipy.spatial import KDTree
from sklearn.cluster import DBSCAN
from PIL import Image  # Import PIL for type checking/conversion if necessary


def get_centroids(image, image_threshold=20,
                  eps=1, min_samples=5, metric='euclidean',
                  min_size=20, fill_holes=False,
                  filter_close_centroids=False, filter_radius=50) -> list:
    """
    Determine centroids of clusters corresponding to potential damage sites.

    In a first step, a threshold is applied to the input image to identify areas of
    potential damage sites. Using DBSCAN, these agglomerations of pixels are fitted
    into clusters. Then, the mean x/y values are determined from pixels belonging to
    one cluster. If the number of pixels in a given cluster exceeds the threshold
    given by min_size, this cluster is added to the result list.

    DBSCAN documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html

    Args:
        image: Input SEM image. Can be a PIL Image or NumPy array.
        image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
        eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
        min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
        metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
        min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
        fill_holes (bool, optional): Fill small holes in damage site clusters using binary_fill_holes. Defaults to False.
        filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False.
        filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50.

    Returns:
        list: list of (x, y) coordinates of the centroids of the clusters of accepted damage site candidates.

    Raises:
        ValueError: If the image is neither a PIL Image nor a NumPy array.
    """
    centroids = []
    logging.info(f"get_centroids: Input image type: {type(image)}")

    # Convert PIL Image to NumPy array if necessary
    if isinstance(image, Image.Image):
        # Convert to grayscale if it's an RGB image, as thresholding is usually on a single channel
        if image.mode == 'RGB':
            image_array = np.array(image.convert('L'))
            logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.")
        else:
            image_array = np.array(image)
            logging.info("get_centroids: Converted PIL Image to NumPy array.")
    elif isinstance(image, np.ndarray):
        # Ensure it's grayscale if it's a multi-channel numpy array
        if image.ndim == 3 and image.shape[2] in [3, 4]:  # RGB or RGBA
            # Convert to grayscale by averaging channels
            image_array = np.mean(image, axis=2).astype(image.dtype)
            logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.")
        else:
            image_array = image
            logging.info("get_centroids: Image is already a NumPy array.")
    else:
        logging.error("get_centroids: Unsupported image format received.")
        raise ValueError("Unsupported image format. Expected PIL Image or NumPy array.")

    # apply the threshold to identify regions of "dark" pixels
    # the result is a binary mask (true/false) whether a given pixel is above or below the threshold
    cluster_candidates_mask = image_array < image_threshold

    # sometimes the clusters have small holes in them, for example, individual pixels
    # inside a region below the threshold. This may confuse the clustering algorithm later on.
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
    # N.B. the algorithm only works on binary data
    if fill_holes:
        cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)

    # transform the mask into an (n_points, 2) coordinate array to pass on to DBSCAN clustering
    cluster_candidates = np.asarray(cluster_candidates_mask).nonzero()
    cluster_candidates = np.transpose(cluster_candidates)

    # Handle case where no candidates are found after thresholding
    if cluster_candidates.size == 0:
        logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
        return []

    # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster
    # are labelled as "-1", i.e. "noise" (e.g. they are too small, etc).
    # For the remaining pixels, a label is assigned to each pixel, indicating to which cluster
    # (or noise) they belong.
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
    dbscan.fit(cluster_candidates)
    # FIX: labels was referenced below but never assigned in the captured diff
    labels = dbscan.labels_

    # Number of clusters in labels, ignoring noise if present.
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    n_noise = list(labels).count(-1)
    logging.info(f'# clusters {n_clusters}, #noise {n_noise}')

    # now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise.
    # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear
    # as often as we have pixels in the cluster candidates
    for i in set(labels):
        if i > -1:  # Ensure it's not noise
            # all points belonging to a given cluster
            cluster_points = cluster_candidates[labels == i, :]
            if len(cluster_points) > min_size:
                x_mean = np.mean(cluster_points, axis=0)[0]
                y_mean = np.mean(cluster_points, axis=0)[1]
                centroids.append([x_mean, y_mean])

    # Only filter if there's more than one centroid (KDTree needs non-empty data)
    if filter_close_centroids and len(centroids) > 1:
        proximity_tree = KDTree(centroids)
        pairs = proximity_tree.query_pairs(filter_radius)

        # Use a set to mark indices for removal to avoid modifying the list during iteration
        indices_to_remove = set()
        for p1_idx, p2_idx in pairs:
            # Decide which one to remove. For simplicity, remove the one with the higher index.
            # This ensures you don't try to remove an index that might have already been removed.
            indices_to_remove.add(max(p1_idx, p2_idx))

        # Rebuild the centroids list, excluding the marked ones
        centroids = [centroid for i, centroid in enumerate(centroids) if i not in indices_to_remove]
        logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")

    return centroids