fix in clustering to operate on image
Browse files- app.py +1 -1
- clustering.py +63 -37
app.py
CHANGED
|
@@ -147,7 +147,7 @@ with gr.Blocks() as app:
|
|
| 147 |
gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link] (https://doi.org/10.1016/j.matdes.2023.112108')
|
| 148 |
gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
|
| 149 |
|
| 150 |
-
image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
|
| 151 |
with gr.Row():
|
| 152 |
with gr.Column(scale=1):
|
| 153 |
image_input = gr.Image(type="pil", label="Upload SEM Image")
|
|
|
|
| 147 |
gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link] (https://doi.org/10.1016/j.matdes.2023.112108')
|
| 148 |
gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
|
| 149 |
|
| 150 |
+
#image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
|
| 151 |
with gr.Row():
|
| 152 |
with gr.Column(scale=1):
|
| 153 |
image_input = gr.Image(type="pil", label="Upload SEM Image")
|
clustering.py
CHANGED
|
@@ -5,21 +5,19 @@
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
-
|
| 9 |
-
|
| 10 |
import scipy.ndimage as ndi
|
| 11 |
from scipy.spatial import KDTree
|
| 12 |
-
|
| 13 |
from sklearn.cluster import DBSCAN
|
| 14 |
-
|
| 15 |
import logging
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
-
def get_centroids(image
|
| 19 |
eps=1, min_samples=5, metric='euclidean',
|
| 20 |
-
min_size
|
| 21 |
-
filter_close_centroids
|
| 22 |
-
"""
|
|
|
|
| 23 |
In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
|
| 24 |
Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
|
| 25 |
from pixels belonging to one cluster. If the number of pixels in a given cluster excees the threshold given by min_size, this cluster is added
|
|
@@ -31,27 +29,50 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
|
|
| 31 |
DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
|
| 32 |
|
| 33 |
Args:
|
| 34 |
-
image
|
| 35 |
image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
|
| 36 |
eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
|
| 37 |
min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
|
| 38 |
metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
|
| 39 |
min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
|
| 40 |
fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
|
| 41 |
-
filter_close_centroids (
|
| 42 |
filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
|
| 43 |
|
| 44 |
Returns:
|
| 45 |
list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
|
| 46 |
"""
|
| 47 |
|
| 48 |
-
|
| 49 |
centroids = []
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# apply the threshold to identify regions of "dark" pixels
|
| 53 |
-
#
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# sometimes the clusters have small holes in them, for example, individual pixels
|
| 57 |
# inside a region below the threshold. This may confuse the clustering algorith later on
|
|
@@ -59,20 +80,21 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
|
|
| 59 |
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
|
| 60 |
# N.B. the algorith only works on binay data
|
| 61 |
if fill_holes:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# apply the treshold to the image to identify regions of "dark" pixels
|
| 65 |
-
#cluster_candidates = np.asarray(image < image_threshold).nonzero()
|
| 66 |
|
| 67 |
# transform image format into a numpy array to pass on to DBSCAN clustering
|
| 68 |
-
cluster_candidates = np.asarray(
|
| 69 |
cluster_candidates = np.transpose(cluster_candidates)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
|
| 73 |
# (e.g. they are too small, etc)
|
| 74 |
# For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
|
| 75 |
-
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='
|
| 76 |
|
| 77 |
dbscan.fit(cluster_candidates)
|
| 78 |
|
|
@@ -87,25 +109,29 @@ def get_centroids(image : np.ndarray, image_threshold = 20,
|
|
| 87 |
# we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
|
| 88 |
# in the cluster candidates
|
| 89 |
for i in set(labels):
|
| 90 |
-
if i
|
| 91 |
# all points belonging to a given cluster
|
| 92 |
-
cluster_points = cluster_candidates[labels==i, :]
|
| 93 |
if len(cluster_points) > min_size:
|
| 94 |
-
x_mean=np.mean(cluster_points, axis=0)[0]
|
| 95 |
-
y_mean=np.mean(cluster_points, axis=0)[1]
|
| 96 |
-
centroids.append([x_mean,y_mean])
|
| 97 |
|
| 98 |
-
if filter_close_centroids:
|
| 99 |
proximity_tree = KDTree(centroids)
|
| 100 |
pairs = proximity_tree.query_pairs(filter_radius)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import numpy as np
|
|
|
|
|
|
|
| 8 |
import scipy.ndimage as ndi
|
| 9 |
from scipy.spatial import KDTree
|
|
|
|
| 10 |
from sklearn.cluster import DBSCAN
|
|
|
|
| 11 |
import logging
|
| 12 |
+
from PIL import Image # ADDED: Import PIL for image type checking/conversion
|
| 13 |
|
| 14 |
|
| 15 |
+
def get_centroids(image, image_threshold=20, # Removed type hint np.ndarray as it can also be PIL.Image.Image initially
|
| 16 |
eps=1, min_samples=5, metric='euclidean',
|
| 17 |
+
min_size=20, fill_holes=False,
|
| 18 |
+
filter_close_centroids=False, filter_radius=50) -> list:
|
| 19 |
+
"""
|
| 20 |
+
Determine centroids of clusters corresponding to potential damage sites.
|
| 21 |
In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
|
| 22 |
Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
|
| 23 |
from pixels belonging to one cluster. If the number of pixels in a given cluster excees the threshold given by min_size, this cluster is added
|
|
|
|
| 29 |
DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
|
| 30 |
|
| 31 |
Args:
|
| 32 |
+
image: Input SEM image (PIL Image or NumPy array).
|
| 33 |
image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
|
| 34 |
eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
|
| 35 |
min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
|
| 36 |
metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
|
| 37 |
min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
|
| 38 |
fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
|
| 39 |
+
filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False
|
| 40 |
filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
|
| 44 |
"""
|
| 45 |
|
|
|
|
| 46 |
centroids = []
|
| 47 |
+
logging.info(f"get_centroids: Input image type: {type(image)}") # Added logging
|
| 48 |
+
|
| 49 |
+
# --- MINIMAL FIX START ---
|
| 50 |
+
# Convert PIL Image to NumPy array if necessary
|
| 51 |
+
if isinstance(image, Image.Image):
|
| 52 |
+
# Convert to grayscale first for thresholding, assuming it's a single-channel operation
|
| 53 |
+
if image.mode == 'RGB': # Handle RGB images by converting to grayscale 'L' mode
|
| 54 |
+
image_array = np.array(image.convert('L'))
|
| 55 |
+
logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.") # Added logging
|
| 56 |
+
else: # Handle other PIL modes (like 'L' for grayscale)
|
| 57 |
+
image_array = np.array(image)
|
| 58 |
+
logging.info("get_centroids: Converted PIL Image to NumPy array.") # Added logging
|
| 59 |
+
elif isinstance(image, np.ndarray):
|
| 60 |
+
# If it's already a NumPy array, ensure it's grayscale if it was multi-channel
|
| 61 |
+
if image.ndim == 3 and image.shape[2] in [3, 4]: # RGB or RGBA NumPy array
|
| 62 |
+
image_array = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale by averaging channels
|
| 63 |
+
logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.") # Added logging
|
| 64 |
+
else: # Assume it's already a suitable grayscale NumPy array
|
| 65 |
+
image_array = image
|
| 66 |
+
logging.info("get_centroids: Image is already a suitable NumPy array.") # Added logging
|
| 67 |
+
else:
|
| 68 |
+
logging.error("get_centroids: Unsupported image format received. Expected PIL Image or NumPy array.") # Added logging
|
| 69 |
+
raise ValueError("Unsupported image format. Expected PIL Image or NumPy array for thresholding.")
|
| 70 |
|
| 71 |
# apply the threshold to identify regions of "dark" pixels
|
| 72 |
+
# The result is a binary mask (true/false) whether a given pixel is above or below the threshold
|
| 73 |
+
# Now using 'image_array' instead of 'image'
|
| 74 |
+
cluster_candidates_mask = image_array < image_threshold
|
| 75 |
+
# --- MINIMAL FIX END ---
|
| 76 |
|
| 77 |
# sometimes the clusters have small holes in them, for example, individual pixels
|
| 78 |
# inside a region below the threshold. This may confuse the clustering algorith later on
|
|
|
|
| 80 |
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
|
| 81 |
# N.B. the algorith only works on binay data
|
| 82 |
if fill_holes:
|
| 83 |
+
cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# transform image format into a numpy array to pass on to DBSCAN clustering
|
| 86 |
+
cluster_candidates = np.asarray(cluster_candidates_mask).nonzero()
|
| 87 |
cluster_candidates = np.transpose(cluster_candidates)
|
| 88 |
|
| 89 |
+
# Handle case where no candidates are found after thresholding
|
| 90 |
+
if cluster_candidates.size == 0: # Added check for empty array
|
| 91 |
+
logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
|
| 92 |
+
return []
|
| 93 |
|
| 94 |
# run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
|
| 95 |
# (e.g. they are too small, etc)
|
| 96 |
# For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
|
| 97 |
+
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric) # Used 'metric' parameter
|
| 98 |
|
| 99 |
dbscan.fit(cluster_candidates)
|
| 100 |
|
|
|
|
| 109 |
# we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
|
| 110 |
# in the cluster candidates
|
| 111 |
for i in set(labels):
|
| 112 |
+
if i > -1: # Ensure it's not noise
|
| 113 |
# all points belonging to a given cluster
|
| 114 |
+
cluster_points = cluster_candidates[labels == i, :]
|
| 115 |
if len(cluster_points) > min_size:
|
| 116 |
+
x_mean = np.mean(cluster_points, axis=0)[0]
|
| 117 |
+
y_mean = np.mean(cluster_points, axis=0)[1]
|
| 118 |
+
centroids.append([x_mean, y_mean])
|
| 119 |
|
| 120 |
+
if filter_close_centroids and len(centroids) > 1: # Only filter if there's more than one centroid
|
| 121 |
proximity_tree = KDTree(centroids)
|
| 122 |
pairs = proximity_tree.query_pairs(filter_radius)
|
| 123 |
+
|
| 124 |
+
# Use a set to mark indices for removal to avoid modifying list during iteration
|
| 125 |
+
indices_to_remove = set()
|
| 126 |
+
for p1_idx, p2_idx in pairs:
|
| 127 |
+
# Decide which one to remove. For simplicity, remove the one with the higher index
|
| 128 |
+
# This ensures you don't try to remove an index that might have already been removed
|
| 129 |
+
indices_to_remove.add(max(p1_idx, p2_idx))
|
| 130 |
+
|
| 131 |
+
# Rebuild the centroids list, excluding the marked ones
|
| 132 |
+
filtered_centroids = [centroid for i, centroid in enumerate(centroids) if i not in indices_to_remove]
|
| 133 |
+
centroids = filtered_centroids
|
| 134 |
+
logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
return centroids
|