kerzel committed on
Commit
fd4e02c
·
1 Parent(s): d7afb8c

debug clustering

Browse files
Files changed (1) hide show
  1. clustering.py +63 -37
clustering.py CHANGED
@@ -5,21 +5,19 @@
5
  """
6
 
7
  import numpy as np
8
-
9
-
10
  import scipy.ndimage as ndi
11
  from scipy.spatial import KDTree
12
-
13
  from sklearn.cluster import DBSCAN
14
-
15
  import logging
 
16
 
17
 
18
- def get_centroids(image : np.ndarray, image_threshold = 20,
19
  eps=1, min_samples=5, metric='euclidean',
20
- min_size = 20, fill_holes = False,
21
- filter_close_centroids = False, filter_radius = 50) -> list:
22
- """ Determine centroids of clusters corresponding to potential damage sites.
 
23
  In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
24
  Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
25
  from pixels belonging to one cluster. If the number of pixels in a given cluster exceeds the threshold given by min_size, this cluster is added
@@ -31,27 +29,48 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
31
  DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
32
 
33
  Args:
34
- image (np.ndarray): Input SEM image
35
  image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
36
  eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
37
  min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
38
  metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
39
  min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
40
  fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
41
- filter_close_centroids (book optional): Filter cluster centroids within a given radius. Defaults to False
42
  filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
43
 
44
  Returns:
45
  list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
46
  """
47
 
48
-
49
  centroids = []
50
- #print('Threshold: ', image_threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # apply the threshold to identify regions of "dark" pixels
53
  # the result is a binary mask (true/false) that is true wherever a pixel is below the threshold
54
- cluster_candidates = image < image_threshold
55
 
56
  # sometimes the clusters have small holes in them, for example, individual pixels
57
  inside a region below the threshold. This may confuse the clustering algorithm later on
@@ -59,20 +78,22 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
59
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
60
  # N.B. the algorithm only works on binary data
61
  if fill_holes:
62
- cluster_candidates = ndi.binary_fill_holes(cluster_candidates)
63
-
64
- # apply the treshold to the image to identify regions of "dark" pixels
65
- #cluster_candidates = np.asarray(image < image_threshold).nonzero()
66
 
67
  # transform image format into a numpy array to pass on to DBSCAN clustering
68
- cluster_candidates = np.asarray(cluster_candidates).nonzero()
 
69
  cluster_candidates = np.transpose(cluster_candidates)
70
 
 
 
 
 
71
 
72
  # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
73
  # (e.g. they are too small, etc)
74
  # For the remaining pixels, a label is assigned to each pixel, indicating which cluster (or noise) it belongs to.
75
- dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
76
 
77
  dbscan.fit(cluster_candidates)
78
 
@@ -80,32 +101,37 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
80
  # Number of clusters in labels, ignoring noise if present.
81
  n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
82
  n_noise = list(labels).count(-1)
83
- logging.debug('# clusters {}, #noise {}'.format(n_clusters, n_noise))
84
 
85
 
86
  # now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise
87
  # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
88
  # in the cluster candidates
89
  for i in set(labels):
90
- if i>-1:
91
  # all points belonging to a given cluster
92
- cluster_points = cluster_candidates[labels==i, :]
93
  if len(cluster_points) > min_size:
94
- x_mean=np.mean(cluster_points, axis=0)[0]
95
- y_mean=np.mean(cluster_points, axis=0)[1]
96
- centroids.append([x_mean,y_mean])
97
 
98
- if filter_close_centroids:
99
  proximity_tree = KDTree(centroids)
100
  pairs = proximity_tree.query_pairs(filter_radius)
101
- for p in pairs:
102
- #print('pair: ', p, ' p[0]: ', p[0], ' p[1]:', p[1])
103
- #print('coords: ', proximity_tree.data[p[0]], ' ', proximity_tree.data[p[1]])
104
- coords_to_remove = [proximity_tree.data[p[0]][0], proximity_tree.data[p[0]][1]]
105
- try:
106
- idx = centroids.index(coords_to_remove)
107
- centroids.pop(idx)
108
- except ValueError:
109
- pass
110
-
111
- return centroids
 
 
 
 
 
 
5
  """
6
 
7
  import numpy as np
 
 
8
  import scipy.ndimage as ndi
9
  from scipy.spatial import KDTree
 
10
  from sklearn.cluster import DBSCAN
 
11
  import logging
12
+ from PIL import Image # Import PIL for type checking/conversion if necessary
13
 
14
 
15
+ def get_centroids(image, image_threshold=20,
16
  eps=1, min_samples=5, metric='euclidean',
17
+ min_size=20, fill_holes=False,
18
+ filter_close_centroids=False, filter_radius=50) -> list:
19
+ """
20
+ Determine centroids of clusters corresponding to potential damage sites.
21
  In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
22
  Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
23
  from pixels belonging to one cluster. If the number of pixels in a given cluster exceeds the threshold given by min_size, this cluster is added
 
29
  DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
30
 
31
  Args:
32
+ image: Input SEM image. Can be a PIL Image or NumPy array.
33
  image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
34
  eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
35
  min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
36
  metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
37
  min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
38
  fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
39
+ filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False
40
  filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
41
 
42
  Returns:
43
  list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
44
  """
45
 
 
46
  centroids = []
47
+ logging.info(f"get_centroids: Input image type: {type(image)}")
48
+
49
+ # Convert PIL Image to NumPy array if necessary
50
+ if isinstance(image, Image.Image):
51
+ # Convert to grayscale if it's an RGB image, as thresholding is usually on single channel
52
+ if image.mode == 'RGB':
53
+ image_array = np.array(image.convert('L'))
54
+ logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.")
55
+ else:
56
+ image_array = np.array(image)
57
+ logging.info("get_centroids: Converted PIL Image to NumPy array.")
58
+ elif isinstance(image, np.ndarray):
59
+ # Ensure it's grayscale if it's a multi-channel numpy array
60
+ if image.ndim == 3 and image.shape[2] in [3, 4]: # RGB or RGBA
61
+ image_array = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale by averaging channels
62
+ logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.")
63
+ else:
64
+ image_array = image
65
+ logging.info("get_centroids: Image is already a NumPy array.")
66
+ else:
67
+ logging.error("get_centroids: Unsupported image format received.")
68
+ raise ValueError("Unsupported image format. Expected PIL Image or NumPy array.")
69
+
70
 
71
  # apply the threshold to identify regions of "dark" pixels
72
  # the result is a binary mask (true/false) that is true wherever a pixel is below the threshold
73
+ cluster_candidates_mask = image_array < image_threshold # FIXED: Use image_array here
74
 
75
  # sometimes the clusters have small holes in them, for example, individual pixels
76
  inside a region below the threshold. This may confuse the clustering algorithm later on
 
78
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
79
  # N.B. the algorithm only works on binary data
80
  if fill_holes:
81
+ cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)
 
 
 
82
 
83
  # transform image format into a numpy array to pass on to DBSCAN clustering
84
+ # Use the mask directly to get non-zero coordinates
85
+ cluster_candidates = np.asarray(cluster_candidates_mask).nonzero()
86
  cluster_candidates = np.transpose(cluster_candidates)
87
 
88
+ # Handle case where no candidates are found after thresholding
89
+ if cluster_candidates.size == 0:
90
+ logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
91
+ return []
92
 
93
  # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
94
  # (e.g. they are too small, etc)
95
  # For the remaining pixels, a label is assigned to each pixel, indicating which cluster (or noise) it belongs to.
96
+ dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric) # Use metric parameter
97
 
98
  dbscan.fit(cluster_candidates)
99
 
 
101
  # Number of clusters in labels, ignoring noise if present.
102
  n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
103
  n_noise = list(labels).count(-1)
104
+ logging.info(f'# clusters {n_clusters}, #noise {n_noise}')
105
 
106
 
107
  # now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise
108
  # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
109
  # in the cluster candidates
110
  for i in set(labels):
111
+ if i > -1: # Ensure it's not noise
112
  # all points belonging to a given cluster
113
+ cluster_points = cluster_candidates[labels == i, :]
114
  if len(cluster_points) > min_size:
115
+ x_mean = np.mean(cluster_points, axis=0)[0]
116
+ y_mean = np.mean(cluster_points, axis=0)[1]
117
+ centroids.append([x_mean, y_mean])
118
 
119
+ if filter_close_centroids and len(centroids) > 1: # Only filter if there's more than one centroid
120
  proximity_tree = KDTree(centroids)
121
  pairs = proximity_tree.query_pairs(filter_radius)
122
+
123
+ # Use a set to mark indices for removal to avoid modifying list during iteration
124
+ indices_to_remove = set()
125
+ for p1_idx, p2_idx in pairs:
126
+ # Decide which one to remove. For simplicity, remove the one with the higher index
127
+ # This ensures you don't try to remove an index that might have already been removed
128
+ indices_to_remove.add(max(p1_idx, p2_idx))
129
+
130
+ # Rebuild the centroids list, excluding the marked ones
131
+ filtered_centroids = [centroid for i, centroid in enumerate(centroids) if i not in indices_to_remove]
132
+ centroids = filtered_centroids
133
+ logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")
134
+
135
+
136
+ return centroids
137
+