kerzel commited on
Commit
e360db8
·
1 Parent(s): da6b0fd

fix in clustering to operate on image

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. clustering.py +63 -37
app.py CHANGED
@@ -147,7 +147,7 @@ with gr.Blocks() as app:
147
  gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link] (https://doi.org/10.1016/j.matdes.2023.112108')
148
  gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
149
 
150
- image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
151
  with gr.Row():
152
  with gr.Column(scale=1):
153
  image_input = gr.Image(type="pil", label="Upload SEM Image")
 
147
  gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link] (https://doi.org/10.1016/j.matdes.2023.112108')
148
  gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
149
 
150
+ #image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
151
  with gr.Row():
152
  with gr.Column(scale=1):
153
  image_input = gr.Image(type="pil", label="Upload SEM Image")
clustering.py CHANGED
@@ -5,21 +5,19 @@
5
  """
6
 
7
  import numpy as np
8
-
9
-
10
  import scipy.ndimage as ndi
11
  from scipy.spatial import KDTree
12
-
13
  from sklearn.cluster import DBSCAN
14
-
15
  import logging
 
16
 
17
 
18
- def get_centroids(image : np.ndarray, image_threshold = 20,
19
  eps=1, min_samples=5, metric='euclidean',
20
- min_size = 20, fill_holes = False,
21
- filter_close_centroids = False, filter_radius = 50) -> list:
22
- """ Determine centroids of clusters corresponding to potential damage sites.
 
23
  In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
24
  Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
25
  from pixels belonging to one cluster. If the number of pixels in a given cluster excees the threshold given by min_size, this cluster is added
@@ -31,27 +29,50 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
31
  DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
32
 
33
  Args:
34
- image (np.ndarray): Input SEM image
35
  image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
36
  eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
37
  min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
38
  metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
39
  min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
40
  fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
41
- filter_close_centroids (book optional): Filter cluster centroids within a given radius. Defaults to False
42
  filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
43
 
44
  Returns:
45
  list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
46
  """
47
 
48
-
49
  centroids = []
50
- #print('Threshold: ', image_threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # apply the threshold to identify regions of "dark" pixels
53
- # the result is a binary mask (true/false) whether a given pixel is above or below the threshold
54
- cluster_candidates = image < image_threshold
 
 
55
 
56
  # sometimes the clusters have small holes in them, for example, individual pixels
57
  # inside a region below the threshold. This may confuse the clustering algorith later on
@@ -59,20 +80,21 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
59
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
60
  # N.B. the algorith only works on binay data
61
  if fill_holes:
62
- cluster_candidates = ndi.binary_fill_holes(cluster_candidates)
63
-
64
- # apply the treshold to the image to identify regions of "dark" pixels
65
- #cluster_candidates = np.asarray(image < image_threshold).nonzero()
66
 
67
  # transform image format into a numpy array to pass on to DBSCAN clustering
68
- cluster_candidates = np.asarray(cluster_candidates).nonzero()
69
  cluster_candidates = np.transpose(cluster_candidates)
70
 
 
 
 
 
71
 
72
  # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
73
  # (e.g. they are too small, etc)
74
  # For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
75
- dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
76
 
77
  dbscan.fit(cluster_candidates)
78
 
@@ -87,25 +109,29 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
87
  # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
88
  # in the cluster candidates
89
  for i in set(labels):
90
- if i>-1:
91
  # all points belonging to a given cluster
92
- cluster_points = cluster_candidates[labels==i, :]
93
  if len(cluster_points) > min_size:
94
- x_mean=np.mean(cluster_points, axis=0)[0]
95
- y_mean=np.mean(cluster_points, axis=0)[1]
96
- centroids.append([x_mean,y_mean])
97
 
98
- if filter_close_centroids:
99
  proximity_tree = KDTree(centroids)
100
  pairs = proximity_tree.query_pairs(filter_radius)
101
- for p in pairs:
102
- #print('pair: ', p, ' p[0]: ', p[0], ' p[1]:', p[1])
103
- #print('coords: ', proximity_tree.data[p[0]], ' ', proximity_tree.data[p[1]])
104
- coords_to_remove = [proximity_tree.data[p[0]][0], proximity_tree.data[p[0]][1]]
105
- try:
106
- idx = centroids.index(coords_to_remove)
107
- centroids.pop(idx)
108
- except ValueError:
109
- pass
110
-
111
- return centroids
 
 
 
 
 
5
  """
6
 
7
  import numpy as np
 
 
8
  import scipy.ndimage as ndi
9
  from scipy.spatial import KDTree
 
10
  from sklearn.cluster import DBSCAN
 
11
  import logging
12
+ from PIL import Image # ADDED: Import PIL for image type checking/conversion
13
 
14
 
15
+ def get_centroids(image, image_threshold=20, # Removed type hint np.ndarray as it can also be PIL.Image.Image initially
16
  eps=1, min_samples=5, metric='euclidean',
17
+ min_size=20, fill_holes=False,
18
+ filter_close_centroids=False, filter_radius=50) -> list:
19
+ """
20
+ Determine centroids of clusters corresponding to potential damage sites.
21
  In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
22
  Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
23
  from pixels belonging to one cluster. If the number of pixels in a given cluster excees the threshold given by min_size, this cluster is added
 
29
  DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
30
 
31
  Args:
32
+ image: Input SEM image (PIL Image or NumPy array).
33
  image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
34
  eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
35
  min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
36
  metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
37
  min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
38
  fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
39
+ filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False
40
  filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
41
 
42
  Returns:
43
  list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
44
  """
45
 
 
46
  centroids = []
47
+ logging.info(f"get_centroids: Input image type: {type(image)}") # Added logging
48
+
49
+ # --- MINIMAL FIX START ---
50
+ # Convert PIL Image to NumPy array if necessary
51
+ if isinstance(image, Image.Image):
52
+ # Convert to grayscale first for thresholding, assuming it's a single-channel operation
53
+ if image.mode == 'RGB': # Handle RGB images by converting to grayscale 'L' mode
54
+ image_array = np.array(image.convert('L'))
55
+ logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.") # Added logging
56
+ else: # Handle other PIL modes (like 'L' for grayscale)
57
+ image_array = np.array(image)
58
+ logging.info("get_centroids: Converted PIL Image to NumPy array.") # Added logging
59
+ elif isinstance(image, np.ndarray):
60
+ # If it's already a NumPy array, ensure it's grayscale if it was multi-channel
61
+ if image.ndim == 3 and image.shape[2] in [3, 4]: # RGB or RGBA NumPy array
62
+ image_array = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale by averaging channels
63
+ logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.") # Added logging
64
+ else: # Assume it's already a suitable grayscale NumPy array
65
+ image_array = image
66
+ logging.info("get_centroids: Image is already a suitable NumPy array.") # Added logging
67
+ else:
68
+ logging.error("get_centroids: Unsupported image format received. Expected PIL Image or NumPy array.") # Added logging
69
+ raise ValueError("Unsupported image format. Expected PIL Image or NumPy array for thresholding.")
70
 
71
  # apply the threshold to identify regions of "dark" pixels
72
+ # The result is a binary mask (true/false) whether a given pixel is above or below the threshold
73
+ # Now using 'image_array' instead of 'image'
74
+ cluster_candidates_mask = image_array < image_threshold
75
+ # --- MINIMAL FIX END ---
76
 
77
  # sometimes the clusters have small holes in them, for example, individual pixels
78
  # inside a region below the threshold. This may confuse the clustering algorith later on
 
80
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
81
  # N.B. the algorith only works on binay data
82
  if fill_holes:
83
+ cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)
 
 
 
84
 
85
  # transform image format into a numpy array to pass on to DBSCAN clustering
86
+ cluster_candidates = np.asarray(cluster_candidates_mask).nonzero()
87
  cluster_candidates = np.transpose(cluster_candidates)
88
 
89
+ # Handle case where no candidates are found after thresholding
90
+ if cluster_candidates.size == 0: # Added check for empty array
91
+ logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
92
+ return []
93
 
94
  # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
95
  # (e.g. they are too small, etc)
96
  # For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
97
+ dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric) # Used 'metric' parameter
98
 
99
  dbscan.fit(cluster_candidates)
100
 
 
109
  # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
110
  # in the cluster candidates
111
  for i in set(labels):
112
+ if i > -1: # Ensure it's not noise
113
  # all points belonging to a given cluster
114
+ cluster_points = cluster_candidates[labels == i, :]
115
  if len(cluster_points) > min_size:
116
+ x_mean = np.mean(cluster_points, axis=0)[0]
117
+ y_mean = np.mean(cluster_points, axis=0)[1]
118
+ centroids.append([x_mean, y_mean])
119
 
120
+ if filter_close_centroids and len(centroids) > 1: # Only filter if there's more than one centroid
121
  proximity_tree = KDTree(centroids)
122
  pairs = proximity_tree.query_pairs(filter_radius)
123
+
124
+ # Use a set to mark indices for removal to avoid modifying list during iteration
125
+ indices_to_remove = set()
126
+ for p1_idx, p2_idx in pairs:
127
+ # Decide which one to remove. For simplicity, remove the one with the higher index
128
+ # This ensures you don't try to remove an index that might have already been removed
129
+ indices_to_remove.add(max(p1_idx, p2_idx))
130
+
131
+ # Rebuild the centroids list, excluding the marked ones
132
+ filtered_centroids = [centroid for i, centroid in enumerate(centroids) if i not in indices_to_remove]
133
+ centroids = filtered_centroids
134
+ logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")
135
+
136
+
137
+ return centroids