File size: 8,184 Bytes
af3fe62
 
 
 
 
 
 
 
 
 
 
e360db8
af3fe62
 
e360db8
af3fe62
e360db8
 
 
 
af3fe62
 
 
 
 
 
 
 
 
 
 
e360db8
af3fe62
 
 
 
 
 
e360db8
af3fe62
 
 
 
 
 
da6b0fd
e360db8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af3fe62
 
e360db8
 
 
 
af3fe62
 
 
 
 
 
 
e360db8
af3fe62
 
e360db8
af3fe62
 
e360db8
 
 
 
af3fe62
 
 
 
e360db8
af3fe62
 
 
 
 
 
 
da6b0fd
af3fe62
 
 
 
 
 
e360db8
af3fe62
e360db8
af3fe62
e360db8
 
 
af3fe62
e360db8
af3fe62
 
e360db8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
    Before we can identify damage sites, we need to look for suitable regions in the image.
    Typically, damage sites appear as dark regions in the image. Rather than relying on a simple
    threshold alone, we combine thresholding with a clustering approach to identify regions that
    belong together and form damage site candidates.
"""

import numpy as np
import scipy.ndimage as ndi
from scipy.spatial import KDTree
from sklearn.cluster import DBSCAN
import logging
from PIL import Image # ADDED: Import PIL for image type checking/conversion


def get_centroids(image, image_threshold=20,
                  eps=1, min_samples=5, metric='euclidean',
                  min_size=20, fill_holes=False,
                  filter_close_centroids=False, filter_radius=50) -> list:
    """
    Determine centroids of clusters corresponding to potential damage sites.

    In a first step, the input image is reduced to a single grayscale channel and a
    threshold is applied to identify areas of potential damage sites, i.e. pixels darker
    than ``image_threshold``. Using DBSCAN, these agglomerations of pixels are grouped
    into clusters. For every cluster containing strictly more than ``min_size`` pixels,
    the mean pixel position is added to the returned list of potential damage sites.

    Sometimes, clusters may be found in very close proximity to each other; these can
    optionally be rejected to avoid classifying the same event multiple times (which may
    distort the statistics).

    DBSCAN documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html

    Args:
        image: Input SEM image (PIL Image or NumPy array). Colour images are reduced to
            one grayscale channel before thresholding (see ``_to_grayscale_array``).
        image_threshold (int, optional): Pixels with a value strictly below this threshold
            are damage-site candidates. Defaults to 20.
        eps (int, optional): DBSCAN ``eps``: the maximum distance between two samples for one
            to be considered as in the neighborhood of the other. Defaults to 1.
        min_samples (int, optional): DBSCAN ``min_samples``: the number of samples in a
            neighborhood for a point to be considered as a core point. Defaults to 5.
        metric (str, optional): DBSCAN ``metric``. Defaults to 'euclidean'.
        min_size (int, optional): A cluster is kept only if it contains strictly more than
            ``min_size`` pixels. Defaults to 20.
        fill_holes (bool, optional): Fill holes inside the candidate mask with
            ``scipy.ndimage.binary_fill_holes`` before clustering. Defaults to False.
        filter_close_centroids (bool, optional): Drop centroids that lie within
            ``filter_radius`` of another centroid (see ``_drop_close_centroids``).
            Defaults to False.
        filter_radius (float, optional): Distance (in pixels) at or below which two
            centroids are considered the same site. Defaults to 50.

    Returns:
        list: One ``[coord0, coord1]`` pair per accepted cluster, in ascending DBSCAN label
        order. ``coord0`` is the mean index along axis 0 of the image array (the row, i.e.
        the vertical position) and ``coord1`` the mean index along axis 1 (the column).
        Empty if no pixel falls below the threshold.

    Raises:
        ValueError: If ``image`` is neither a NumPy array nor a PIL Image.
    """
    logging.info(f"get_centroids: Input image type: {type(image)}")

    image_array = _to_grayscale_array(image)

    # Binary mask: True where a pixel is darker than the threshold.
    cluster_candidates_mask = image_array < image_threshold

    # Small holes (e.g. individual bright pixels inside a dark region) can split one
    # region into several clusters; binary_fill_holes closes them.
    # N.B. binary_fill_holes only works on binary data, hence it runs on the mask.
    if fill_holes:
        cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)

    # One row per candidate pixel holding its index along each image axis
    # (equivalent to np.transpose(np.nonzero(mask))).
    cluster_candidates = np.argwhere(cluster_candidates_mask)

    if cluster_candidates.size == 0:
        logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
        return []

    # DBSCAN assigns every candidate pixel a cluster label; pixels that are not
    # attributed to any cluster (e.g. too sparse) are labelled -1 ("noise").
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
    dbscan.fit(cluster_candidates)
    labels = dbscan.labels_

    # np.unique yields each label once in ascending order, so the centroid order (and
    # hence which member of a close pair survives filtering) is deterministic.
    cluster_labels = np.unique(labels)
    n_clusters = int(np.count_nonzero(cluster_labels != -1))
    n_noise = int(np.count_nonzero(labels == -1))
    logging.debug('# clusters {}, #noise {}'.format(n_clusters, n_noise))

    centroids = []
    for label in cluster_labels:
        if label == -1:  # noise is not a cluster
            continue
        cluster_points = cluster_candidates[labels == label]
        if len(cluster_points) > min_size:
            centre = cluster_points.mean(axis=0)
            centroids.append([centre[0], centre[1]])

    # KDTree filtering only makes sense with at least two centroids.
    if filter_close_centroids and len(centroids) > 1:
        centroids = _drop_close_centroids(centroids, filter_radius)

    return centroids


def _to_grayscale_array(image):
    """Convert ``image`` to a NumPy array with a single (grayscale) intensity channel.

    Args:
        image: PIL Image or NumPy array.

    Returns:
        numpy.ndarray: For NumPy input, ``(H, W, 1)`` arrays lose their singleton channel
        axis. ``(H, W, 3)`` and ``(H, W, 4)`` arrays are reduced to the mean of their first
        three (colour) channels, cast back to the input dtype. Any other array is returned
        unchanged. For PIL input, single-band intensity modes ('L', 'I', 'F', 'I;16*') are
        converted as-is; every other mode is first converted to 8-bit 'L'.

    Raises:
        ValueError: If ``image`` is neither a NumPy array nor a PIL Image.
    """
    # The NumPy branch comes first so array inputs never depend on PIL.
    if isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] == 1:
            # A singleton channel axis would make the mask 3-D; among other things,
            # binary_fill_holes could then never fill anything.
            logging.info("get_centroids: Dropped singleton channel axis of NumPy array.")
            return image[..., 0]
        if image.ndim == 3 and image.shape[2] in (3, 4):
            # Average only the colour channels: a 4th (alpha) channel encodes opacity,
            # not brightness, and must not shift pixels across the threshold.
            logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.")
            return np.mean(image[..., :3], axis=2).astype(image.dtype)
        logging.info("get_centroids: Image is already a suitable NumPy array.")
        return image

    if isinstance(image, Image.Image):
        # Single-band intensity modes can be thresholded directly. Any other mode
        # (RGB, RGBA, LA, P = palette indices, 1 = bilevel booleans, CMYK, ...) would
        # otherwise yield a multi-channel or non-intensity array.
        if image.mode in ('L', 'I', 'F') or image.mode.startswith('I;16'):
            logging.info("get_centroids: Converted PIL Image to NumPy array.")
            return np.array(image)
        logging.info(f"get_centroids: Converted {image.mode} PIL Image to grayscale NumPy array.")
        return np.array(image.convert('L'))

    logging.error("get_centroids: Unsupported image format received. Expected PIL Image or NumPy array.")
    raise ValueError("Unsupported image format. Expected PIL Image or NumPy array for thresholding.")


def _drop_close_centroids(centroids, radius):
    """Remove centroids that lie within ``radius`` of another centroid.

    Every pair at distance <= ``radius`` (as reported by ``KDTree.query_pairs``) loses its
    higher-indexed member; the lower-indexed one is kept.
    NOTE: pairs are judged independently, so in a chain A-B-C where only neighbours are
    close, both B and C are dropped even if C is far from A.

    Args:
        centroids (list): ``[coord0, coord1]`` pairs, at least two.
        radius (float): Proximity radius in pixels.

    Returns:
        list: A new list with the surviving centroids, in their original order.
    """
    pairs = KDTree(centroids).query_pairs(radius)
    indices_to_remove = {max(pair) for pair in pairs}
    kept = [centroid for index, centroid in enumerate(centroids) if index not in indices_to_remove]
    logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(kept)}")
    return kept