kerzel committed on
Commit
da6b0fd
·
1 Parent(s): 25616b8

try to revert back to old code but keep change in interface

Browse files
Files changed (3) hide show
  1. app.py +128 -132
  2. clustering.py +37 -63
  3. utils.py +141 -138
app.py CHANGED
@@ -1,157 +1,153 @@
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
- from PIL import Image
5
- import logging
6
- import os # Import os for path checks
7
-
8
- # Placeholder imports for clustering and utils.
9
- # In a real scenario, these files (clustering.py, utils.py)
10
- # would contain your actual implementation.
11
- try:
12
- import clustering
13
- import utils
14
- except ImportError as e:
15
- logging.error(f"Error importing helper modules: {e}. Using dummy functions.")
16
- # Define dummy functions if imports fail, to allow the app to launch.
17
- class DummyClustering:
18
- def get_centroids(self, *args, **kwargs):
19
- logging.warning("Using dummy get_centroids. Provide actual clustering.py.")
20
- # Return some dummy centroids for demonstration
21
- # In a real scenario, you might want to raise an error or return an empty list
22
- # if clustering is critical for app functionality.
23
- return [(100, 100), (200, 200)]
24
-
25
- class DummyUtils:
26
- def prepare_classifier_input(self, *args, **kwargs):
27
- logging.warning("Using dummy prepare_classifier_input. Provide actual utils.py.")
28
- # Return dummy data for model input
29
- return np.zeros((1, 250, 250, 3)) # Example shape, adjust as per your model input
30
-
31
- def show_boxes(self, image, damage_sites, save_image=False, image_path=None):
32
- logging.warning("Using dummy show_boxes. Provide actual utils.py.")
33
- # Return the original image for dummy display
34
- # In a real app, this would draw boxes
35
- if image is None:
36
- return Image.new('RGB', (400, 400), color = 'red') # Placeholder if no image provided
37
- return image
38
-
39
- clustering = DummyClustering()
40
- utils = DummyUtils()
41
 
 
 
 
 
 
 
42
 
43
  from tensorflow import keras
44
 
45
- # --- Basic Setup ---
46
- logging.getLogger().setLevel(logging.INFO)
47
 
48
  # --- Constants and Model Loading ---
49
  IMAGE_PATH = "classified_damage_sites.png"
50
  CSV_PATH = "classified_damage_sites.csv"
51
 
52
- # Load models once at startup to improve performance
53
- model1 = None
54
- model2 = None
55
-
56
- try:
57
- # Check if model files exist before attempting to load
58
- if os.path.exists('rwthmaterials_dp800_network1_inclusion.h5'):
59
- model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
60
- logging.info("Model 1 loaded successfully.")
61
- else:
62
- logging.warning("Model 1 (rwthmaterials_dp800_network1_inclusion.h5) not found. Classification results may be inaccurate.")
63
-
64
- if os.path.exists('rwthmaterials_dp800_network2_damage.h5'):
65
- model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
66
- logging.info("Model 2 loaded successfully.")
67
- else:
68
- logging.warning("Model 2 (rwthmaterials_dp800_network2_damage.h5) not found. Classification results may be inaccurate.")
69
-
70
- except Exception as e:
71
- logging.error(f"Error loading models: {e}")
72
- # Models are set to None, and warnings/errors are logged.
73
- # The app will still attempt to launch.
74
-
75
- damage_classes = {3: "Martensite", 2: "Interface", 0: "Notch", 1: "Shadowing"}
76
- model1_windowsize = [250, 250]
77
- model2_windowsize = [100, 100]
78
-
79
- # --- Core Processing Function (Your original logic) ---
80
- def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
81
- """
82
- This function contains the core scientific logic for classifying damage sites.
83
- It returns the classified image and paths to the output files.
84
- """
85
- if SEM_image is None:
86
- raise gr.Error("Please upload an SEM Image before running classification.")
87
-
88
- if model1 is None or model2 is None:
89
- raise gr.Error("Models not loaded. Please ensure model files are present and valid.")
90
 
91
- damage_sites = {}
92
-
93
- # Step 1: Clustering to find damage centroids
94
- # Ensure clustering.get_centroids handles the case of no centroids found
95
- all_centroids = clustering.get_centroids(
96
- SEM_image,
97
- image_threshold=image_threshold,
98
- fill_holes=True,
99
- filter_close_centroids=True,
100
- )
101
-
102
- for c in all_centroids:
103
- damage_sites[(c[0], c[1])] = "Not Classified"
104
 
105
- # Step 2: Model 1 to identify inclusions
106
- if len(all_centroids) > 0:
107
- try:
108
- images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
109
- y1_pred = model1.predict(np.asarray(images_model1, dtype=float))
110
- inclusions = np.where(y1_pred[:, 0] > model1_threshold)[0]
111
- for idx in inclusions:
112
- coord = all_centroids[idx]
113
- damage_sites[(coord[0], coord[1])] = "Inclusion"
114
- except Exception as e:
115
- logging.error(f"Error during Model 1 prediction: {e}")
116
 
117
- # Step 3: Model 2 to classify remaining damage types
118
- centroids_model2 = [list(k) for k, v in damage_sites.items() if v == "Not Classified"]
119
- if centroids_model2:
120
- try:
121
- images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2, window_size=model2_windowsize)
122
- y2_pred = model2.predict(np.asarray(images_model2, dtype=float))
123
- # Adjust the thresholding for damage_index to handle potential empty results
124
- damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
125
-
126
- for i in range(len(damage_index[0])):
127
- sample_idx = damage_index[0][i]
128
- class_idx = damage_index[1][i]
129
- label = damage_classes.get(class_idx, "Unknown")
130
- coord = centroids_model2[sample_idx]
131
- damage_sites[(coord[0], coord[1])] = label
132
- except Exception as e:
133
- logging.error(f"Error during Model 2 prediction: {e}")
134
 
135
- # Step 4: Draw boxes on image and save output image
136
- # The utils.show_boxes function is assumed to return a PIL Image object
137
- image_with_boxes = utils.show_boxes(SEM_image, damage_sites, save_image=True, image_path=IMAGE_PATH)
138
 
139
- # Step 5: Export CSV file
140
- data = [[x, y, label] for (x, y), label in damage_sites.items()]
141
- df = pd.DataFrame(data, columns=["x", "y", "damage_type"])
142
- df.to_csv(CSV_PATH, index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- # Log file paths to ensure they are correct
145
- logging.info(f"Generated Image Path: {IMAGE_PATH}")
146
- logging.info(f"Generated CSV Path: {CSV_PATH}")
147
 
148
- return image_with_boxes, IMAGE_PATH, CSV_PATH
149
 
150
- # --- Gradio Interface Definition ---
 
 
151
  with gr.Blocks() as app:
152
- gr.Markdown("# Damage Classification in Dual Phase Steels")
153
- gr.Markdown("Upload a Scanning Electron Microscope (SEM) image and set the thresholds to classify material damage.")
 
 
 
 
 
 
154
 
 
155
  with gr.Row():
156
  with gr.Column(scale=1):
157
  image_input = gr.Image(type="pil", label="Upload SEM Image")
 
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ # our own helper tools
6
+ import clustering
7
+ import utils
8
+
9
+ import logging
10
+ logging.getLogger().setLevel(logging.INFO)
11
 
12
  from tensorflow import keras
13
 
14
+ #image_threshold = 20
15
+
16
 
17
  # --- Constants and Model Loading ---
18
  IMAGE_PATH = "classified_damage_sites.png"
19
  CSV_PATH = "classified_damage_sites.csv"
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ model1_windowsize = [250,250]
23
+ #model1_threshold = 0.7
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
26
+ model1.compile()
27
+
28
+ damage_classes = {3: "Martensite",2: "Interface",0:"Notch",1:"Shadowing"}
29
+
30
+ model2_windowsize = [100,100]
31
+ #model2_threshold = 0.5
32
+
33
+ model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
34
+ model2.compile()
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
 
37
 
38
+ ##
39
+ ## Function to do the actual damage classification
40
+ ##
41
def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
    """Classify damage sites in an SEM image of dual-phase steel.

    Pipeline: (1) cluster dark pixels to find candidate damage-site
    centroids, (2) model 1 separates inclusions from all other sites,
    (3) model 2 assigns each remaining site one of the damage classes in
    ``damage_classes``, (4) the annotated image and a CSV of
    (x, y, damage_type) rows are written to disk.

    Relies on module-level state: ``clustering``, ``utils``, ``model1``,
    ``model2``, ``model1_windowsize``, ``model2_windowsize`` and
    ``damage_classes``.

    Args:
        SEM_image: input SEM image as delivered by the Gradio image widget.
        image_threshold: grey-value threshold used by the clustering step.
        model1_threshold: probability cut-off for the inclusion classifier.
        model2_threshold: probability cut-off for the damage-type classifier.

    Returns:
        tuple: (annotated image, path of the saved PNG, path of the saved CSV).
    """
    damage_sites = {}

    ##
    ## clustering: candidate centroids from clusters of dark pixels
    ##
    logging.debug('---------------: clustering :=====================')
    all_centroids = clustering.get_centroids(SEM_image, image_threshold=image_threshold,
                                             fill_holes=True, filter_close_centroids=True)

    # every centroid starts out unclassified
    for i in range(len(all_centroids)):
        key = (all_centroids[i][0], all_centroids[i][1])
        damage_sites[key] = 'Not Classified'

    ##
    ## Inclusions vs the rest
    ##
    logging.debug('---------------: prepare model 1 :=====================')
    images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)

    logging.debug('---------------: run model 1 :=====================')
    y1_pred = model1.predict(np.asarray(images_model1, float))

    logging.debug('---------------: model1 threshold :=====================')
    # column 0 holds the inclusion probability; keep indices above the cut
    inclusions = y1_pred[:, 0].reshape(len(y1_pred), 1)
    inclusions = np.where(inclusions > model1_threshold)

    logging.debug('---------------: model 1 update dict :=====================')
    for i in range(len(inclusions[0])):
        centroid_id = inclusions[0][i]
        coordinates = all_centroids[centroid_id]
        key = (coordinates[0], coordinates[1])
        damage_sites[key] = 'Inclusion'
    logging.debug('Damage sites after model 1')
    logging.debug(damage_sites)

    ##
    ## Martensite cracking, etc — only sites model 1 left unclassified
    ##
    logging.debug('---------------: prepare model 2 :=====================')
    centroids_model2 = []
    for key, value in damage_sites.items():
        if value == 'Not Classified':
            coordinates = list([key[0], key[1]])
            centroids_model2.append(coordinates)
    logging.debug('Centroids model 2')
    logging.debug(centroids_model2)

    logging.debug('---------------: prepare model 2 :=====================')
    images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2, window_size=model2_windowsize)
    logging.debug('Images model 2')
    logging.debug(images_model2)

    logging.debug('---------------: run model 2 :=====================')
    y2_pred = model2.predict(np.asarray(images_model2, float))

    # (sample, class) index pairs where the prediction clears the threshold
    damage_index = np.asarray(y2_pred > model2_threshold).nonzero()

    for i in range(len(damage_index[0])):
        index = damage_index[0][i]
        identified_class = damage_index[1][i]
        label = damage_classes[identified_class]
        coordinates = centroids_model2[index]
        key = (coordinates[0], coordinates[1])
        damage_sites[key] = label

    ##
    ## show the damage sites on the image
    ##
    logging.debug("-----------------: final damage sites :=================")
    logging.debug(damage_sites)

    image_path = 'classified_damage_sites.png'
    image = utils.show_boxes(SEM_image, damage_sites,
                             save_image=True,
                             image_path=image_path)

    ##
    ## export data
    ##
    csv_path = 'classified_damage_sites.csv'
    cols = ['x', 'y', 'damage_type']

    data = []
    for key, value in damage_sites.items():
        data.append([key[0], key[1], value])

    df = pd.DataFrame(columns=cols, data=data)
    df.to_csv(csv_path)

    return image, image_path, csv_path
136
 
137
+ ## ---------------------------------------------------------------------------------------------------------------
138
+ ## main app interface
139
+ ## -----------------------------------------------------------------------------------------------------------------
140
  with gr.Blocks() as app:
141
+ gr.Markdown('# Damage Classification in Dual Phase Steels')
142
+ gr.Markdown('This app classifies damage types in dual phase steels. Two models are used. The first model is used to identify inclusions in the steel. The second model is used to identify the remaining damage types: Martensite cracking, Interface Decohesion, Notch effect and Shadows.')
143
+
144
+ gr.Markdown('The models used in this app are based on the following papers:')
145
+ gr.Markdown('Kusche, C., Reclik, T., Freund, M., Al-Samman, T., Kerzel, U., & Korte-Kerzel, S. (2019). Large-area, high-resolution characterisation and classification of damage mechanisms in dual-phase steel using deep learning. PloS one, 14(5), e0216493. [Link](https://doi.org/10.1371/journal.pone.0216493)')
146
+ #gr.Markdown('Medghalchi, S., Kusche, C. F., Karimi, E., Kerzel, U., & Korte-Kerzel, S. (2020). Damage analysis in dual-phase steel using deep learning: transfer from uniaxial to biaxial straining conditions by image data augmentation. Jom, 72, 4420-4430. [Link](https://link.springer.com/article/10.1007/s11837-020-04404-0)')
147
+ gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link](https://doi.org/10.1016/j.matdes.2023.112108)')
148
+ gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
149
 
150
+ image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
151
  with gr.Row():
152
  with gr.Column(scale=1):
153
  image_input = gr.Image(type="pil", label="Upload SEM Image")
clustering.py CHANGED
@@ -5,19 +5,21 @@
5
  """
6
 
7
  import numpy as np
 
 
8
  import scipy.ndimage as ndi
9
  from scipy.spatial import KDTree
 
10
  from sklearn.cluster import DBSCAN
 
11
  import logging
12
- from PIL import Image # Import PIL for type checking/conversion if necessary
13
 
14
 
15
- def get_centroids(image, image_threshold=20,
16
  eps=1, min_samples=5, metric='euclidean',
17
- min_size=20, fill_holes=False,
18
- filter_close_centroids=False, filter_radius=50) -> list:
19
- """
20
- Determine centroids of clusters corresponding to potential damage sites.
21
  In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
22
  Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
23
  from pixels belonging to one cluster. If the number of pixels in a given cluster exceeds the threshold given by min_size, this cluster is added
@@ -29,48 +31,27 @@ def get_centroids(image, image_threshold=20,
29
  DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
30
 
31
  Args:
32
- image: Input SEM image. Can be a PIL Image or NumPy array.
33
  image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
34
  eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
35
  min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
36
  metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
37
  min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
38
  fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
39
- filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False
40
  filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
41
 
42
  Returns:
43
  list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
44
  """
45
 
46
- centroids = []
47
- logging.info(f"get_centroids: Input image type: {type(image)}")
48
-
49
- # Convert PIL Image to NumPy array if necessary
50
- if isinstance(image, Image.Image):
51
- # Convert to grayscale if it's an RGB image, as thresholding is usually on single channel
52
- if image.mode == 'RGB':
53
- image_array = np.array(image.convert('L'))
54
- logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.")
55
- else:
56
- image_array = np.array(image)
57
- logging.info("get_centroids: Converted PIL Image to NumPy array.")
58
- elif isinstance(image, np.ndarray):
59
- # Ensure it's grayscale if it's a multi-channel numpy array
60
- if image.ndim == 3 and image.shape[2] in [3, 4]: # RGB or RGBA
61
- image_array = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale by averaging channels
62
- logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.")
63
- else:
64
- image_array = image
65
- logging.info("get_centroids: Image is already a NumPy array.")
66
- else:
67
- logging.error("get_centroids: Unsupported image format received.")
68
- raise ValueError("Unsupported image format. Expected PIL Image or NumPy array.")
69
 
 
 
70
 
71
  # apply the threshold to identify regions of "dark" pixels
72
  # the result is a binary mask (true/false) whether a given pixel is above or below the threshold
73
- cluster_candidates_mask = image_array < image_threshold # FIXED: Use image_array here
74
 
75
  # sometimes the clusters have small holes in them, for example, individual pixels
76
  # inside a region below the threshold. This may confuse the clustering algorith later on
@@ -78,22 +59,20 @@ def get_centroids(image, image_threshold=20,
78
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
79
  # N.B. the algorith only works on binay data
80
  if fill_holes:
81
- cluster_candidates_mask = ndi.binary_fill_holes(cluster_candidates_mask)
 
 
 
82
 
83
  # transform image format into a numpy array to pass on to DBSCAN clustering
84
- # Use the mask directly to get non-zero coordinates
85
- cluster_candidates = np.asarray(cluster_candidates_mask).nonzero()
86
  cluster_candidates = np.transpose(cluster_candidates)
87
 
88
- # Handle case where no candidates are found after thresholding
89
- if cluster_candidates.size == 0:
90
- logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
91
- return []
92
 
93
  # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
94
  # (e.g. they are too small, etc)
95
  # For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
96
- dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=metric) # Use metric parameter
97
 
98
  dbscan.fit(cluster_candidates)
99
 
@@ -101,37 +80,32 @@ def get_centroids(image, image_threshold=20,
101
  # Number of clusters in labels, ignoring noise if present.
102
  n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
103
  n_noise = list(labels).count(-1)
104
- logging.info(f'# clusters {n_clusters}, #noise {n_noise}')
105
 
106
 
107
  # now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise
108
  # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
109
  # in the cluster candidates
110
  for i in set(labels):
111
- if i > -1: # Ensure it's not noise
112
  # all points belonging to a given cluster
113
- cluster_points = cluster_candidates[labels == i, :]
114
  if len(cluster_points) > min_size:
115
- x_mean = np.mean(cluster_points, axis=0)[0]
116
- y_mean = np.mean(cluster_points, axis=0)[1]
117
- centroids.append([x_mean, y_mean])
118
 
119
- if filter_close_centroids and len(centroids) > 1: # Only filter if there's more than one centroid
120
  proximity_tree = KDTree(centroids)
121
  pairs = proximity_tree.query_pairs(filter_radius)
122
-
123
- # Use a set to mark indices for removal to avoid modifying list during iteration
124
- indices_to_remove = set()
125
- for p1_idx, p2_idx in pairs:
126
- # Decide which one to remove. For simplicity, remove the one with the higher index
127
- # This ensures you don't try to remove an index that might have already been removed
128
- indices_to_remove.add(max(p1_idx, p2_idx))
129
-
130
- # Rebuild the centroids list, excluding the marked ones
131
- filtered_centroids = [centroid for i, centroid in enumerate(centroids) if i not in indices_to_remove]
132
- centroids = filtered_centroids
133
- logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")
134
-
135
-
136
- return centroids
137
-
 
5
  """
6
 
7
  import numpy as np
8
+
9
+
10
  import scipy.ndimage as ndi
11
  from scipy.spatial import KDTree
12
+
13
  from sklearn.cluster import DBSCAN
14
+
15
  import logging
 
16
 
17
 
18
+ def get_centroids(image : np.ndarray, image_threshold = 20,
19
  eps=1, min_samples=5, metric='euclidean',
20
+ min_size = 20, fill_holes = False,
21
+ filter_close_centroids = False, filter_radius = 50) -> list:
22
+ """ Determine centroids of clusters corresponding to potential damage sites.
 
23
  In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
24
  Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
25
  from pixels belonging to one cluster. If the number of pixels in a given cluster exceeds the threshold given by min_size, this cluster is added
 
31
  DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
32
 
33
  Args:
34
+ image (np.ndarray): Input SEM image
35
  image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
36
  eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
37
  min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
38
  metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
39
  min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
40
  fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
41
+ filter_close_centroids (bool, optional): Filter cluster centroids within a given radius. Defaults to False
42
  filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
43
 
44
  Returns:
45
  list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
46
  """
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ centroids = []
50
+ #print('Threshold: ', image_threshold)
51
 
52
  # apply the threshold to identify regions of "dark" pixels
53
  # the result is a binary mask (true/false) whether a given pixel is above or below the threshold
54
+ cluster_candidates = image < image_threshold
55
 
56
  # sometimes the clusters have small holes in them, for example, individual pixels
57
  # inside a region below the threshold. This may confuse the clustering algorithm later on
 
59
  # https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
60
  # N.B. the algorithm only works on binary data
61
  if fill_holes:
62
+ cluster_candidates = ndi.binary_fill_holes(cluster_candidates)
63
+
64
+ # apply the threshold to the image to identify regions of "dark" pixels
65
+ #cluster_candidates = np.asarray(image < image_threshold).nonzero()
66
 
67
  # transform image format into a numpy array to pass on to DBSCAN clustering
68
+ cluster_candidates = np.asarray(cluster_candidates).nonzero()
 
69
  cluster_candidates = np.transpose(cluster_candidates)
70
 
 
 
 
 
71
 
72
  # run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
73
  # (e.g. they are too small, etc)
74
  # For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
75
+ dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
76
 
77
  dbscan.fit(cluster_candidates)
78
 
 
80
  # Number of clusters in labels, ignoring noise if present.
81
  n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
82
  n_noise = list(labels).count(-1)
83
+ logging.debug('# clusters {}, #noise {}'.format(n_clusters, n_noise))
84
 
85
 
86
  # now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise
87
  # we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
88
  # in the cluster candidates
89
  for i in set(labels):
90
+ if i>-1:
91
  # all points belonging to a given cluster
92
+ cluster_points = cluster_candidates[labels==i, :]
93
  if len(cluster_points) > min_size:
94
+ x_mean=np.mean(cluster_points, axis=0)[0]
95
+ y_mean=np.mean(cluster_points, axis=0)[1]
96
+ centroids.append([x_mean,y_mean])
97
 
98
+ if filter_close_centroids:
99
  proximity_tree = KDTree(centroids)
100
  pairs = proximity_tree.query_pairs(filter_radius)
101
+ for p in pairs:
102
+ #print('pair: ', p, ' p[0]: ', p[0], ' p[1]:', p[1])
103
+ #print('coords: ', proximity_tree.data[p[0]], ' ', proximity_tree.data[p[1]])
104
+ coords_to_remove = [proximity_tree.data[p[0]][0], proximity_tree.data[p[0]][1]]
105
+ try:
106
+ idx = centroids.index(coords_to_remove)
107
+ centroids.pop(idx)
108
+ except ValueError:
109
+ pass
110
+
111
+ return centroids
 
 
 
 
 
utils.py CHANGED
@@ -1,160 +1,163 @@
 
 
 
 
1
  import numpy as np
2
- from PIL import Image, ImageDraw
3
- import logging
4
 
5
- def prepare_classifier_input(image, centroids, window_size):
6
- """
7
- Extracts image patches around centroids and prepares them as input for Keras models.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  Args:
10
- image: The input SEM image (PIL Image or NumPy array).
11
- centroids (list): List of (x,y) coordinates of damage site centroids.
12
- window_size (list): [height, width] of the square window to extract around each centroid.
13
 
14
  Returns:
15
- np.ndarray: A batch of image patches, ready for model prediction.
16
  """
17
- logging.info(f"prepare_classifier_input: Input image type: {type(image)}")
18
-
19
- # Convert PIL Image to NumPy array if necessary
20
- if isinstance(image, Image.Image):
21
- # Convert to RGB first to ensure 3 channels for consistent model input
22
- image_array = np.array(image.convert('RGB'))
23
- logging.info("prepare_classifier_input: Converted PIL Image to RGB NumPy array.")
24
- elif isinstance(image, np.ndarray):
25
- # Ensure it's a 3-channel array for consistency if it's already NumPy
26
- if image.ndim == 2: # Grayscale NumPy array
27
- image_array = np.stack([image, image, image], axis=-1) # Convert to 3 channels
28
- logging.info("prepare_classifier_input: Converted grayscale NumPy array to 3-channel.")
29
- elif image.ndim == 3 and image.shape[2] == 4: # RGBA NumPy array
30
- image_array = image[:, :, :3] # Drop alpha channel
31
- logging.info("prepare_classifier_input: Dropped alpha channel from RGBA NumPy array.")
32
- else: # Already RGB or similar 3-channel NumPy array
33
- image_array = image
34
- logging.info("prepare_classifier_input: Image is already a suitable NumPy array.")
35
- else:
36
- logging.error("prepare_classifier_input: Unsupported image format received. Expected PIL Image or NumPy array.")
37
- raise ValueError("Unsupported image format for classifier input.")
38
-
39
- if not centroids:
40
- logging.warning("No centroids provided for prepare_classifier_input. Returning empty array.")
41
- return np.empty((0, window_size[0], window_size[1], image_array.shape[2]), dtype=np.float32)
42
-
43
- patches = []
44
- img_height, img_width, _ = image_array.shape # Get dimensions from the now-guaranteed NumPy array
45
- half_window_h, half_window_w = window_size[0] // 2, window_size[1] // 2
46
-
47
- for c_y, c_x in centroids: # Centroids are (y, x) from clustering
48
- # Ensure coordinates are integers
49
- c_y, c_x = int(round(c_y)), int(round(c_x))
50
-
51
- # Calculate bounding box for the patch
52
- # Handle boundary conditions by clamping coordinates
53
- y1 = max(0, c_y - half_window_h)
54
- y2 = min(img_height, c_y + half_window_h)
55
- x1 = max(0, c_x - half_window_w)
56
- x2 = min(img_width, c_x + half_window_w)
57
-
58
- # Extract patch
59
- patch = image_array[y1:y2, x1:x2, :]
60
-
61
- # Pad if the patch is smaller than window_size (due to boundary clamping)
62
- if patch.shape[0] != window_size[0] or patch.shape[1] != window_size[1]:
63
- padded_patch = np.zeros((window_size[0], window_size[1], image_array.shape[2]), dtype=patch.dtype)
64
- padded_patch[0:patch.shape[0], 0:patch.shape[1], :] = patch
65
- patch = padded_patch
66
-
67
- patches.append(patch)
68
-
69
- # Normalize pixel values if your model expects it (e.g., to 0-1)
70
- # This is a common step, adjust if your model's training pre-processing was different
71
- # Assuming images are 0-255, converting to float 0-1
72
- return np.array(patches, dtype=np.float32) / 255.0
73
 
 
74
 
75
- def show_boxes(image, damage_sites, save_image=False, image_path="output_image.png"):
76
- """
77
- Draws bounding boxes or markers on the image based on the classified damage sites.
 
 
 
 
 
78
 
79
  Args:
80
- image: The input SEM image (PIL Image or NumPy array).
81
- damage_sites (dict): Dictionary with (x,y) coordinates as keys and classification labels as values.
82
- save_image (bool, optional): Whether to save the image to disk. Defaults to False.
83
- image_path (str, optional): Path to save the image. Defaults to "output_image.png".
 
 
84
 
85
- Returns:
86
- PIL.Image.Image: The image with drawn boxes/markers.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  """
88
- logging.info(f"show_boxes: Input image type: {type(image)}")
89
-
90
- if image is None:
91
- logging.warning("show_boxes received no image. Returning a blank image.")
92
- img = Image.new('RGB', (500, 500), color = 'black')
93
- else:
94
- # Ensure image is a PIL Image for drawing operations
95
- if isinstance(image, np.ndarray):
96
- # Convert NumPy array to PIL Image. Assuming input is 0-255.
97
- if image.dtype == np.float32 and np.max(image) <= 1.0: # If normalized 0-1 float
98
- image_for_pil = (image * 255).astype(np.uint8)
99
- else: # Assume 0-255 uint8
100
- image_for_pil = image.astype(np.uint8)
101
-
102
- if image_for_pil.ndim == 2: # Grayscale numpy
103
- img = Image.fromarray(image_for_pil, mode='L').convert("RGB")
104
- elif image_for_pil.ndim == 3 and image_for_pil.shape[2] in [3,4]: # RGB or RGBA
105
- img = Image.fromarray(image_for_pil).convert("RGB")
106
- else:
107
- logging.error("Unsupported numpy image format for show_boxes.")
108
- img = Image.new('RGB', (500, 500), color = 'black') # Fallback
109
- else: # Assume it's already a PIL Image
110
- img = image.copy().convert("RGB") # Use a copy to avoid modifying original
111
-
112
- draw = ImageDraw.Draw(img)
113
 
114
- # Define some colors for drawing boxes
115
- colors = {
116
- "Inclusion": "red",
117
- "Martensite": "blue",
118
- "Interface": "green",
119
- "Notch": "yellow",
120
- "Shadowing": "purple",
121
- "Not Classified": "gray", # Should ideally not appear on final image
122
- "Unknown": "white"
123
- }
124
 
125
- for (x, y), label in damage_sites.items():
126
- # Centroid coordinates from clustering (y,x) might be float
127
- center_x = int(round(y)) # Note: (y,x) from clustering means y is row (height), x is column (width)
128
- center_y = int(round(x)) # PIL expects (x, y) for drawing, so swap
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- box_size = 10 # Smaller box for clarity
131
-
132
- # Calculate box corners, clamping to image boundaries
133
- x1 = max(0, center_x - box_size)
134
- y1 = max(0, center_y - box_size)
135
- x2 = min(img.width, center_x + box_size)
136
- y2 = min(img.height, center_y + box_size)
137
 
138
- fill_color = colors.get(label, "white")
139
- outline_color = "black"
 
 
 
140
 
141
- draw.rectangle([x1, y1, x2, y2], fill=fill_color, outline=outline_color, width=2)
142
-
143
- # Draw text label slightly offset
144
- text_offset_x = 5
145
- text_offset_y = -15
146
- try:
147
- draw.text((x1 + text_offset_x, y1 + text_offset_y), label, fill=outline_color)
148
- except Exception as e:
149
- logging.warning(f"Could not draw text label '{label}': {e}")
150
 
 
 
 
151
 
152
- if save_image and image_path:
153
- try:
154
- img.save(image_path)
155
- logging.info(f"Image saved to {image_path}")
156
- except Exception as e:
157
- logging.error(f"Could not save image to {image_path}: {e}")
 
 
 
 
 
 
158
 
159
- return img
160
 
 
1
+ """
2
+ Collection of various utils
3
+ """
4
+
5
  import numpy as np
 
 
6
 
7
+ import imageio.v3 as iio
8
+ from PIL import Image
9
+ # we may have very large images (e.g. panoramic SEM images), allow to read them w/o warnings
10
+ Image.MAX_IMAGE_PIXELS = 933120000
11
+
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as patches
14
+ from matplotlib.lines import Line2D
15
+
16
+
17
+ import math
18
+
19
+
20
+ ###
21
+ ### load SEM images
22
+ ###
23
def load_image(filename: str) -> np.ndarray:
    """Read an SEM image from disk.

    Args:
        filename (str): full path and name of the image file to be loaded

    Returns:
        np.ndarray: the image contents as a numpy ndarray
    """
    # mode='F' requests a single-channel floating-point (grayscale) image
    return iio.imread(filename, mode='F')
35
 
36
+
37
+
38
+ ###
39
+ ### show SEM image with boxes in various colours around each damage site
40
+ ###
41
###
### show SEM image with boxes in various colours around each damage site
###
def show_boxes(image: np.ndarray, damage_sites: dict, box_size=(250, 250),
               save_image=False, image_path: str = None):
    """Render the SEM image with a coloured rectangle around each damage site.

    Args:
        image (np.ndarray): SEM image to be shown
        damage_sites (dict): python dictionary using the coordinates as key (x,y),
            and the label as value
        box_size (sequence, optional): size [height, width] of the rectangle drawn
            around each centroid. Defaults to (250, 250).
        save_image (bool, optional): save the image with the boxes or not.
            Defaults to False.
        image_path (str, optional): full path and name of the output file to be
            saved; only used when save_image is True.

    Returns:
        np.ndarray: the rendered figure as an (H, W, 3) RGB uint8 array.
    """
    # colour per damage class; any unknown label falls back to black ('k').
    # Hoisted out of the loop so the mapping is built only once.
    edge_colours = {
        'Inclusion': 'b',
        'Interface': 'g',
        'Martensite': 'r',
        'Notch': 'y',
        'Shadowing': 'm'
    }

    _, ax = plt.subplots(1)
    ax.imshow(image, cmap='gray')  # show image on correct axis
    ax.set_xticks([])
    ax.set_yticks([])

    for key, label in damage_sites.items():
        position = [key[0], key[1]]
        edgecolor = edge_colours.get(label, 'k')  # default: black

        # Rectangle expects the corner as (x, y); the stored coordinates are
        # (row, column), hence the index swap below.
        rect = patches.Rectangle((position[1] - box_size[1] / 2., position[0] - box_size[0] / 2),
                                 box_size[1], box_size[0],
                                 linewidth=1, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)

    legend_elements = [
        Line2D([0], [0], color='b', lw=4, label='Inclusion'),
        Line2D([0], [0], color='g', lw=4, label='Interface'),
        Line2D([0], [0], color='r', lw=4, label='Martensite'),
        Line2D([0], [0], color='y', lw=4, label='Notch'),
        Line2D([0], [0], color='m', lw=4, label='Shadow'),
        Line2D([0], [0], color='k', lw=4, label='Not Classified')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")

    fig = ax.figure
    fig.tight_layout(pad=0)

    if save_image and image_path:
        fig.savefig(image_path, dpi=1200, bbox_inches='tight')

    # rasterize the canvas and hand back a plain RGB array so callers can
    # display the result without needing matplotlib themselves
    canvas = fig.canvas
    canvas.draw()
    data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(
        canvas.get_width_height()[::-1] + (4,))
    data = data[:, :, :3]  # RGB only, alpha channel dropped

    plt.close(fig)

    return data
99
+
100
+
101
+ ###
102
+ ### cut out small images from panorama, append colour information
103
+ ###
104
###
### cut out small images from panorama, append colour information
###
def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=(250, 250)) -> list:
    """
    Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.

    Each extracted patch is converted into a 3-channel (RGB-like) normalized image suitable for use
    with classification neural networks that expect color input.

    Parameters
    ----------
    panorama : np.ndarray
        Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data.

    centroids : list of [int, int]
        List of (y, x) coordinates marking the centers of regions of interest. These are typically damage
        sites identified in preprocessing (e.g., clustering).

    window_size : sequence of int, optional
        Size [height, width] of each extracted image patch. Defaults to (250, 250).
        (A tuple default avoids the shared-mutable-default pitfall.)

    Returns
    -------
    list of np.ndarray
        List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
        centroids that allow full window extraction within image bounds are used.
    """
    if panorama.ndim == 2:
        # rebinding the local only; the caller's array is left untouched
        panorama = np.expand_dims(panorama, axis=-1)  # (H, W, 1)

    H, W, _ = panorama.shape
    win_h, win_w = window_size
    images = []

    for (cy, cx) in centroids:
        x1 = int(cx - win_w / 2)
        y1 = int(cy - win_h / 2)
        x2 = x1 + win_w
        y2 = y1 + win_h

        # Skip if patch would go out of bounds
        if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
            continue

        # Extract and map the 8-bit grayscale range [0, 255] onto [-1, 1]
        patch = panorama[y1:y2, x1:x2, 0].astype(np.float32)
        patch = patch * 2. / 255. - 1.

        # Replicate grayscale channel to simulate RGB
        patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
        images.append(patch_color)

    return images
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
 
 
163