try to revert back to old code but keep change in interface
Browse files- app.py +128 -132
- clustering.py +37 -63
- utils.py +141 -138
app.py
CHANGED
|
@@ -1,157 +1,153 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
-
from PIL import Image
|
| 5 |
-
import logging
|
| 6 |
-
import os # Import os for path checks
|
| 7 |
-
|
| 8 |
-
# Placeholder imports for clustering and utils.
|
| 9 |
-
# In a real scenario, these files (clustering.py, utils.py)
|
| 10 |
-
# would contain your actual implementation.
|
| 11 |
-
try:
|
| 12 |
-
import clustering
|
| 13 |
-
import utils
|
| 14 |
-
except ImportError as e:
|
| 15 |
-
logging.error(f"Error importing helper modules: {e}. Using dummy functions.")
|
| 16 |
-
# Define dummy functions if imports fail, to allow the app to launch.
|
| 17 |
-
class DummyClustering:
|
| 18 |
-
def get_centroids(self, *args, **kwargs):
|
| 19 |
-
logging.warning("Using dummy get_centroids. Provide actual clustering.py.")
|
| 20 |
-
# Return some dummy centroids for demonstration
|
| 21 |
-
# In a real scenario, you might want to raise an error or return an empty list
|
| 22 |
-
# if clustering is critical for app functionality.
|
| 23 |
-
return [(100, 100), (200, 200)]
|
| 24 |
-
|
| 25 |
-
class DummyUtils:
|
| 26 |
-
def prepare_classifier_input(self, *args, **kwargs):
|
| 27 |
-
logging.warning("Using dummy prepare_classifier_input. Provide actual utils.py.")
|
| 28 |
-
# Return dummy data for model input
|
| 29 |
-
return np.zeros((1, 250, 250, 3)) # Example shape, adjust as per your model input
|
| 30 |
-
|
| 31 |
-
def show_boxes(self, image, damage_sites, save_image=False, image_path=None):
|
| 32 |
-
logging.warning("Using dummy show_boxes. Provide actual utils.py.")
|
| 33 |
-
# Return the original image for dummy display
|
| 34 |
-
# In a real app, this would draw boxes
|
| 35 |
-
if image is None:
|
| 36 |
-
return Image.new('RGB', (400, 400), color = 'red') # Placeholder if no image provided
|
| 37 |
-
return image
|
| 38 |
-
|
| 39 |
-
clustering = DummyClustering()
|
| 40 |
-
utils = DummyUtils()
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
from tensorflow import keras
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
|
| 48 |
# --- Constants and Model Loading ---
|
| 49 |
IMAGE_PATH = "classified_damage_sites.png"
|
| 50 |
CSV_PATH = "classified_damage_sites.csv"
|
| 51 |
|
| 52 |
-
# Load models once at startup to improve performance
|
| 53 |
-
model1 = None
|
| 54 |
-
model2 = None
|
| 55 |
-
|
| 56 |
-
try:
|
| 57 |
-
# Check if model files exist before attempting to load
|
| 58 |
-
if os.path.exists('rwthmaterials_dp800_network1_inclusion.h5'):
|
| 59 |
-
model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
|
| 60 |
-
logging.info("Model 1 loaded successfully.")
|
| 61 |
-
else:
|
| 62 |
-
logging.warning("Model 1 (rwthmaterials_dp800_network1_inclusion.h5) not found. Classification results may be inaccurate.")
|
| 63 |
-
|
| 64 |
-
if os.path.exists('rwthmaterials_dp800_network2_damage.h5'):
|
| 65 |
-
model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
|
| 66 |
-
logging.info("Model 2 loaded successfully.")
|
| 67 |
-
else:
|
| 68 |
-
logging.warning("Model 2 (rwthmaterials_dp800_network2_damage.h5) not found. Classification results may be inaccurate.")
|
| 69 |
-
|
| 70 |
-
except Exception as e:
|
| 71 |
-
logging.error(f"Error loading models: {e}")
|
| 72 |
-
# Models are set to None, and warnings/errors are logged.
|
| 73 |
-
# The app will still attempt to launch.
|
| 74 |
-
|
| 75 |
-
damage_classes = {3: "Martensite", 2: "Interface", 0: "Notch", 1: "Shadowing"}
|
| 76 |
-
model1_windowsize = [250, 250]
|
| 77 |
-
model2_windowsize = [100, 100]
|
| 78 |
-
|
| 79 |
-
# --- Core Processing Function (Your original logic) ---
|
| 80 |
-
def damage_classification(SEM_image, image_threshold, model1_threshold, model2_threshold):
|
| 81 |
-
"""
|
| 82 |
-
This function contains the core scientific logic for classifying damage sites.
|
| 83 |
-
It returns the classified image and paths to the output files.
|
| 84 |
-
"""
|
| 85 |
-
if SEM_image is None:
|
| 86 |
-
raise gr.Error("Please upload an SEM Image before running classification.")
|
| 87 |
-
|
| 88 |
-
if model1 is None or model2 is None:
|
| 89 |
-
raise gr.Error("Models not loaded. Please ensure model files are present and valid.")
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
# Step 1: Clustering to find damage centroids
|
| 94 |
-
# Ensure clustering.get_centroids handles the case of no centroids found
|
| 95 |
-
all_centroids = clustering.get_centroids(
|
| 96 |
-
SEM_image,
|
| 97 |
-
image_threshold=image_threshold,
|
| 98 |
-
fill_holes=True,
|
| 99 |
-
filter_close_centroids=True,
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
for c in all_centroids:
|
| 103 |
-
damage_sites[(c[0], c[1])] = "Not Classified"
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
logging.error(f"Error during Model 1 prediction: {e}")
|
| 116 |
|
| 117 |
-
# Step 3: Model 2 to classify remaining damage types
|
| 118 |
-
centroids_model2 = [list(k) for k, v in damage_sites.items() if v == "Not Classified"]
|
| 119 |
-
if centroids_model2:
|
| 120 |
-
try:
|
| 121 |
-
images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2, window_size=model2_windowsize)
|
| 122 |
-
y2_pred = model2.predict(np.asarray(images_model2, dtype=float))
|
| 123 |
-
# Adjust the thresholding for damage_index to handle potential empty results
|
| 124 |
-
damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
|
| 125 |
-
|
| 126 |
-
for i in range(len(damage_index[0])):
|
| 127 |
-
sample_idx = damage_index[0][i]
|
| 128 |
-
class_idx = damage_index[1][i]
|
| 129 |
-
label = damage_classes.get(class_idx, "Unknown")
|
| 130 |
-
coord = centroids_model2[sample_idx]
|
| 131 |
-
damage_sites[(coord[0], coord[1])] = label
|
| 132 |
-
except Exception as e:
|
| 133 |
-
logging.error(f"Error during Model 2 prediction: {e}")
|
| 134 |
|
| 135 |
-
# Step 4: Draw boxes on image and save output image
|
| 136 |
-
# The utils.show_boxes function is assumed to return a PIL Image object
|
| 137 |
-
image_with_boxes = utils.show_boxes(SEM_image, damage_sites, save_image=True, image_path=IMAGE_PATH)
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
-
# Log file paths to ensure they are correct
|
| 145 |
-
logging.info(f"Generated Image Path: {IMAGE_PATH}")
|
| 146 |
-
logging.info(f"Generated CSV Path: {CSV_PATH}")
|
| 147 |
|
| 148 |
-
return
|
| 149 |
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
with gr.Blocks() as app:
|
| 152 |
-
gr.Markdown(
|
| 153 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
|
|
|
| 155 |
with gr.Row():
|
| 156 |
with gr.Column(scale=1):
|
| 157 |
image_input = gr.Image(type="pil", label="Upload SEM Image")
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
# our own helper tools
|
| 6 |
+
import clustering
|
| 7 |
+
import utils
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
logging.getLogger().setLevel(logging.INFO)
|
| 11 |
|
| 12 |
from tensorflow import keras
|
| 13 |
|
| 14 |
+
#image_threshold = 20
|
| 15 |
+
|
| 16 |
|
| 17 |
# --- Constants and Model Loading ---
|
| 18 |
IMAGE_PATH = "classified_damage_sites.png"
|
| 19 |
CSV_PATH = "classified_damage_sites.csv"
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
model1_windowsize = [250,250]
|
| 23 |
+
#model1_threshold = 0.7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
model1 = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.h5')
|
| 26 |
+
model1.compile()
|
| 27 |
+
|
| 28 |
+
damage_classes = {3: "Martensite",2: "Interface",0:"Notch",1:"Shadowing"}
|
| 29 |
+
|
| 30 |
+
model2_windowsize = [100,100]
|
| 31 |
+
#model2_threshold = 0.5
|
| 32 |
+
|
| 33 |
+
model2 = keras.models.load_model('rwthmaterials_dp800_network2_damage.h5')
|
| 34 |
+
model2.compile()
|
|
|
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
##
|
| 39 |
+
## Function to do the actual damage classification
|
| 40 |
+
##
|
| 41 |
+
def damage_classification(SEM_image,image_threshold, model1_threshold, model2_threshold):
|
| 42 |
+
|
| 43 |
+
damage_sites = {}
|
| 44 |
+
##
|
| 45 |
+
## clustering
|
| 46 |
+
##
|
| 47 |
+
logging.debug('---------------: clustering :=====================')
|
| 48 |
+
all_centroids = clustering.get_centroids(SEM_image, image_threshold=image_threshold,
|
| 49 |
+
fill_holes=True, filter_close_centroids=True)
|
| 50 |
+
|
| 51 |
+
for i in range(len(all_centroids)) :
|
| 52 |
+
key = (all_centroids[i][0],all_centroids[i][1])
|
| 53 |
+
damage_sites[key] = 'Not Classified'
|
| 54 |
+
|
| 55 |
+
##
|
| 56 |
+
## Inclusions vs the rest
|
| 57 |
+
##
|
| 58 |
+
logging.debug('---------------: prepare model 1 :=====================')
|
| 59 |
+
images_model1 = utils.prepare_classifier_input(SEM_image, all_centroids, window_size=model1_windowsize)
|
| 60 |
+
|
| 61 |
+
logging.debug('---------------: run model 1 :=====================')
|
| 62 |
+
y1_pred = model1.predict(np.asarray(images_model1, float))
|
| 63 |
+
|
| 64 |
+
logging.debug('---------------: model1 threshold :=====================')
|
| 65 |
+
inclusions = y1_pred[:,0].reshape(len(y1_pred),1)
|
| 66 |
+
inclusions = np.where(inclusions > model1_threshold)
|
| 67 |
+
|
| 68 |
+
logging.debug('---------------: model 1 update dict :=====================')
|
| 69 |
+
for i in range(len(inclusions[0])):
|
| 70 |
+
centroid_id = inclusions[0][i]
|
| 71 |
+
coordinates = all_centroids[centroid_id]
|
| 72 |
+
key = (coordinates[0], coordinates[1])
|
| 73 |
+
damage_sites[key] = 'Inclusion'
|
| 74 |
+
logging.debug('Damage sites after model 1')
|
| 75 |
+
logging.debug(damage_sites)
|
| 76 |
+
|
| 77 |
+
##
|
| 78 |
+
## Martensite cracking, etc
|
| 79 |
+
##
|
| 80 |
+
logging.debug('---------------: prepare model 2 :=====================')
|
| 81 |
+
centroids_model2 = []
|
| 82 |
+
for key, value in damage_sites.items():
|
| 83 |
+
if value == 'Not Classified':
|
| 84 |
+
coordinates = list([key[0],key[1]])
|
| 85 |
+
centroids_model2.append(coordinates)
|
| 86 |
+
logging.debug('Centroids model 2')
|
| 87 |
+
logging.debug(centroids_model2)
|
| 88 |
+
|
| 89 |
+
logging.debug('---------------: prepare model 2 :=====================')
|
| 90 |
+
images_model2 = utils.prepare_classifier_input(SEM_image, centroids_model2, window_size=model2_windowsize)
|
| 91 |
+
logging.debug('Images model 2')
|
| 92 |
+
logging.debug(images_model2)
|
| 93 |
+
|
| 94 |
+
logging.debug('---------------: run model 2 :=====================')
|
| 95 |
+
y2_pred = model2.predict(np.asarray(images_model2, float))
|
| 96 |
+
|
| 97 |
+
damage_index = np.asarray(y2_pred > model2_threshold).nonzero()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
for i in range(len(damage_index[0])):
|
| 101 |
+
index = damage_index[0][i]
|
| 102 |
+
identified_class = damage_index[1][i]
|
| 103 |
+
label = damage_classes[identified_class]
|
| 104 |
+
coordinates = centroids_model2[index]
|
| 105 |
+
#print('Damage {} \t identified as {}, \t coordinates {}'.format(i, label, coordinates))
|
| 106 |
+
key = (coordinates[0], coordinates[1])
|
| 107 |
+
damage_sites[key] = label
|
| 108 |
+
|
| 109 |
+
##
|
| 110 |
+
## show the damage sites on the image
|
| 111 |
+
##
|
| 112 |
+
logging.debug("-----------------: final damage sites :=================")
|
| 113 |
+
logging.debug(damage_sites)
|
| 114 |
+
|
| 115 |
+
image_path = 'classified_damage_sites.png'
|
| 116 |
+
image = utils.show_boxes(SEM_image, damage_sites,
|
| 117 |
+
save_image=True,
|
| 118 |
+
image_path=image_path)
|
| 119 |
+
|
| 120 |
+
##
|
| 121 |
+
## export data
|
| 122 |
+
##
|
| 123 |
+
csv_path = 'classified_damage_sites.csv'
|
| 124 |
+
cols = ['x', 'y', 'damage_type']
|
| 125 |
+
|
| 126 |
+
data = []
|
| 127 |
+
for key, value in damage_sites.items():
|
| 128 |
+
data.append([key[0], key[1], value])
|
| 129 |
+
|
| 130 |
+
df = pd.DataFrame(columns=cols, data=data)
|
| 131 |
+
|
| 132 |
+
df.to_csv(csv_path)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
return image, image_path, csv_path
|
| 136 |
|
| 137 |
+
## ---------------------------------------------------------------------------------------------------------------
|
| 138 |
+
## main app interface
|
| 139 |
+
## -----------------------------------------------------------------------------------------------------------------
|
| 140 |
with gr.Blocks() as app:
|
| 141 |
+
gr.Markdown('# Damage Classification in Dual Phase Steels')
|
| 142 |
+
gr.Markdown('This app classifies damage types in dual phase steels. Two models are used. The first model is used to identify inclusions in the steel. The second model is used to identify the remaining damage types: Martensite cracking, Interface Decohesion, Notch effect and Shadows.')
|
| 143 |
+
|
| 144 |
+
gr.Markdown('The models used in this app are based on the following papers:')
|
| 145 |
+
gr.Markdown('Kusche, C., Reclik, T., Freund, M., Al-Samman, T., Kerzel, U., & Korte-Kerzel, S. (2019). Large-area, high-resolution characterisation and classification of damage mechanisms in dual-phase steel using deep learning. PloS one, 14(5), e0216493. [Link](https://doi.org/10.1371/journal.pone.0216493)')
|
| 146 |
+
#gr.Markdown('Medghalchi, S., Kusche, C. F., Karimi, E., Kerzel, U., & Korte-Kerzel, S. (2020). Damage analysis in dual-phase steel using deep learning: transfer from uniaxial to biaxial straining conditions by image data augmentation. Jom, 72, 4420-4430. [Link](https://link.springer.com/article/10.1007/s11837-020-04404-0)')
|
| 147 |
+
gr.Markdown('Setareh Medghalchi, Ehsan Karimi, Sang-Hyeok Lee, Benjamin Berkels, Ulrich Kerzel, Sandra Korte-Kerzel, Three-dimensional characterisation of deformation-induced damage in dual phase steel using deep learning, Materials & Design, Volume 232, 2023, 112108, ISSN 0264-1275, [link] (https://doi.org/10.1016/j.matdes.2023.112108')
|
| 148 |
+
gr.Markdown('Original data and code, including the network weights, can be found at Zenodo [link](https://zenodo.org/records/8065752)')
|
| 149 |
|
| 150 |
+
image_input = gr.Image(value='data/X4-Aligned_cropped_upperleft_small.png', label='Example SEM Image (DP800 steel)',)
|
| 151 |
with gr.Row():
|
| 152 |
with gr.Column(scale=1):
|
| 153 |
image_input = gr.Image(type="pil", label="Upload SEM Image")
|
clustering.py
CHANGED
|
@@ -5,19 +5,21 @@
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import numpy as np
|
|
|
|
|
|
|
| 8 |
import scipy.ndimage as ndi
|
| 9 |
from scipy.spatial import KDTree
|
|
|
|
| 10 |
from sklearn.cluster import DBSCAN
|
|
|
|
| 11 |
import logging
|
| 12 |
-
from PIL import Image # Import PIL for type checking/conversion if necessary
|
| 13 |
|
| 14 |
|
| 15 |
-
def get_centroids(image, image_threshold=20,
|
| 16 |
eps=1, min_samples=5, metric='euclidean',
|
| 17 |
-
min_size=20, fill_holes=False,
|
| 18 |
-
filter_close_centroids=False, filter_radius=50) -> list:
|
| 19 |
-
"""
|
| 20 |
-
Determine centroids of clusters corresponding to potential damage sites.
|
| 21 |
In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
|
| 22 |
Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
|
| 23 |
from pixels belonging to one cluster. If the number of pixels in a given cluster excees the threshold given by min_size, this cluster is added
|
|
@@ -29,48 +31,27 @@ def get_centroids(image, image_threshold=20,
|
|
| 29 |
DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
|
| 30 |
|
| 31 |
Args:
|
| 32 |
-
image: Input SEM image
|
| 33 |
image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
|
| 34 |
eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
|
| 35 |
min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
|
| 36 |
metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
|
| 37 |
min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
|
| 38 |
fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
|
| 39 |
-
filter_close_centroids (
|
| 40 |
filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
|
| 44 |
"""
|
| 45 |
|
| 46 |
-
centroids = []
|
| 47 |
-
logging.info(f"get_centroids: Input image type: {type(image)}")
|
| 48 |
-
|
| 49 |
-
# Convert PIL Image to NumPy array if necessary
|
| 50 |
-
if isinstance(image, Image.Image):
|
| 51 |
-
# Convert to grayscale if it's an RGB image, as thresholding is usually on single channel
|
| 52 |
-
if image.mode == 'RGB':
|
| 53 |
-
image_array = np.array(image.convert('L'))
|
| 54 |
-
logging.info("get_centroids: Converted RGB PIL Image to grayscale NumPy array.")
|
| 55 |
-
else:
|
| 56 |
-
image_array = np.array(image)
|
| 57 |
-
logging.info("get_centroids: Converted PIL Image to NumPy array.")
|
| 58 |
-
elif isinstance(image, np.ndarray):
|
| 59 |
-
# Ensure it's grayscale if it's a multi-channel numpy array
|
| 60 |
-
if image.ndim == 3 and image.shape[2] in [3, 4]: # RGB or RGBA
|
| 61 |
-
image_array = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale by averaging channels
|
| 62 |
-
logging.info("get_centroids: Converted multi-channel NumPy array to grayscale NumPy array.")
|
| 63 |
-
else:
|
| 64 |
-
image_array = image
|
| 65 |
-
logging.info("get_centroids: Image is already a NumPy array.")
|
| 66 |
-
else:
|
| 67 |
-
logging.error("get_centroids: Unsupported image format received.")
|
| 68 |
-
raise ValueError("Unsupported image format. Expected PIL Image or NumPy array.")
|
| 69 |
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# apply the threshold to identify regions of "dark" pixels
|
| 72 |
# the result is a binary mask (true/false) whether a given pixel is above or below the threshold
|
| 73 |
-
|
| 74 |
|
| 75 |
# sometimes the clusters have small holes in them, for example, individual pixels
|
| 76 |
# inside a region below the threshold. This may confuse the clustering algorith later on
|
|
@@ -78,22 +59,20 @@ def get_centroids(image, image_threshold=20,
|
|
| 78 |
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
|
| 79 |
# N.B. the algorith only works on binay data
|
| 80 |
if fill_holes:
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# transform image format into a numpy array to pass on to DBSCAN clustering
|
| 84 |
-
|
| 85 |
-
cluster_candidates = np.asarray(cluster_candidates_mask).nonzero()
|
| 86 |
cluster_candidates = np.transpose(cluster_candidates)
|
| 87 |
|
| 88 |
-
# Handle case where no candidates are found after thresholding
|
| 89 |
-
if cluster_candidates.size == 0:
|
| 90 |
-
logging.warning("No cluster candidates found after thresholding. Returning empty centroids list.")
|
| 91 |
-
return []
|
| 92 |
|
| 93 |
# run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
|
| 94 |
# (e.g. they are too small, etc)
|
| 95 |
# For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
|
| 96 |
-
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric=
|
| 97 |
|
| 98 |
dbscan.fit(cluster_candidates)
|
| 99 |
|
|
@@ -101,37 +80,32 @@ def get_centroids(image, image_threshold=20,
|
|
| 101 |
# Number of clusters in labels, ignoring noise if present.
|
| 102 |
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
|
| 103 |
n_noise = list(labels).count(-1)
|
| 104 |
-
logging.
|
| 105 |
|
| 106 |
|
| 107 |
# now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise
|
| 108 |
# we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
|
| 109 |
# in the cluster candidates
|
| 110 |
for i in set(labels):
|
| 111 |
-
if i
|
| 112 |
# all points belonging to a given cluster
|
| 113 |
-
cluster_points = cluster_candidates[labels
|
| 114 |
if len(cluster_points) > min_size:
|
| 115 |
-
x_mean
|
| 116 |
-
y_mean
|
| 117 |
-
centroids.append([x_mean,
|
| 118 |
|
| 119 |
-
if filter_close_centroids
|
| 120 |
proximity_tree = KDTree(centroids)
|
| 121 |
pairs = proximity_tree.query_pairs(filter_radius)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
logging.info(f"Filtered {len(indices_to_remove)} close centroids. Remaining: {len(centroids)}")
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
return centroids
|
| 137 |
-
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
import scipy.ndimage as ndi
|
| 11 |
from scipy.spatial import KDTree
|
| 12 |
+
|
| 13 |
from sklearn.cluster import DBSCAN
|
| 14 |
+
|
| 15 |
import logging
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
+
def get_centroids(image : np.ndarray, image_threshold = 20,
|
| 19 |
eps=1, min_samples=5, metric='euclidean',
|
| 20 |
+
min_size = 20, fill_holes = False,
|
| 21 |
+
filter_close_centroids = False, filter_radius = 50) -> list:
|
| 22 |
+
""" Determine centroids of clusters corresponding to potential damage sites.
|
|
|
|
| 23 |
In a first step, a threshold is applied to the input image to identify areas of potential damage sites.
|
| 24 |
Using DBSCAN, these agglomerations of pixels are fitted into clusters. Then, the mean x/y values are determined
|
| 25 |
from pixels belonging to one cluster. If the number of pixels in a given cluster excees the threshold given by min_size, this cluster is added
|
|
|
|
| 31 |
DBScan documentation: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
|
| 32 |
|
| 33 |
Args:
|
| 34 |
+
image (np.ndarray): Input SEM image
|
| 35 |
image_threshold (int, optional): Threshold to be applied to the image to identify candidates for damage sites. Defaults to 20.
|
| 36 |
eps (int, optional): parameter eps of DBSCAN: The maximum distance between two samples for one to be considered as in the neighborhood of the other. Defaults to 1.
|
| 37 |
min_samples (int, optional): parameter min_samples of DBSCAN: The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. Defaults to 5.
|
| 38 |
metric (str, optional): parameter metric of DBSCAN. Defaults to 'euclidean'.
|
| 39 |
min_size (int, optional): Minimum number of pixels in a cluster for the damage site candidate to be considered in the final list. Defaults to 20.
|
| 40 |
fill_holes (bool, optional): Fill small holes in damage sites clusters using binary_fill_holes. Defaults to False.
|
| 41 |
+
filter_close_centroids (book optional): Filter cluster centroids within a given radius. Defaults to False
|
| 42 |
filter_radius (float, optional): Radius within which centroids are considered to be the same. Defaults to 50
|
| 43 |
|
| 44 |
Returns:
|
| 45 |
list: list of (x,y) coordinates of the centroids of the clusters of accepted damage site candidates.
|
| 46 |
"""
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
centroids = []
|
| 50 |
+
#print('Threshold: ', image_threshold)
|
| 51 |
|
| 52 |
# apply the threshold to identify regions of "dark" pixels
|
| 53 |
# the result is a binary mask (true/false) whether a given pixel is above or below the threshold
|
| 54 |
+
cluster_candidates = image < image_threshold
|
| 55 |
|
| 56 |
# sometimes the clusters have small holes in them, for example, individual pixels
|
| 57 |
# inside a region below the threshold. This may confuse the clustering algorith later on
|
|
|
|
| 59 |
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.binary_fill_holes.html
|
| 60 |
# N.B. the algorith only works on binay data
|
| 61 |
if fill_holes:
|
| 62 |
+
cluster_candidates = ndi.binary_fill_holes(cluster_candidates)
|
| 63 |
+
|
| 64 |
+
# apply the treshold to the image to identify regions of "dark" pixels
|
| 65 |
+
#cluster_candidates = np.asarray(image < image_threshold).nonzero()
|
| 66 |
|
| 67 |
# transform image format into a numpy array to pass on to DBSCAN clustering
|
| 68 |
+
cluster_candidates = np.asarray(cluster_candidates).nonzero()
|
|
|
|
| 69 |
cluster_candidates = np.transpose(cluster_candidates)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# run the DBSCAN clustering algorithm, candidate sites that are not attributed to a cluster are labelled as "-1", i.e. "noise"
|
| 73 |
# (e.g. they are too small, etc)
|
| 74 |
# For the remaining pixels, a label is assigned to each pixel, indicating to which cluster (or noise) they belong to.
|
| 75 |
+
dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric='euclidean')
|
| 76 |
|
| 77 |
dbscan.fit(cluster_candidates)
|
| 78 |
|
|
|
|
| 80 |
# Number of clusters in labels, ignoring noise if present.
|
| 81 |
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
|
| 82 |
n_noise = list(labels).count(-1)
|
| 83 |
+
logging.debug('# clusters {}, #noise {}'.format(n_clusters, n_noise))
|
| 84 |
|
| 85 |
|
| 86 |
# now loop over all labels found by DBSCAN, i.e. all identified clusters and the noise
|
| 87 |
# we use "set" here, as the labels are attributed to individual pixels, i.e. they appear as often as we have pixels
|
| 88 |
# in the cluster candidates
|
| 89 |
for i in set(labels):
|
| 90 |
+
if i>-1:
|
| 91 |
# all points belonging to a given cluster
|
| 92 |
+
cluster_points = cluster_candidates[labels==i, :]
|
| 93 |
if len(cluster_points) > min_size:
|
| 94 |
+
x_mean=np.mean(cluster_points, axis=0)[0]
|
| 95 |
+
y_mean=np.mean(cluster_points, axis=0)[1]
|
| 96 |
+
centroids.append([x_mean,y_mean])
|
| 97 |
|
| 98 |
+
if filter_close_centroids:
|
| 99 |
proximity_tree = KDTree(centroids)
|
| 100 |
pairs = proximity_tree.query_pairs(filter_radius)
|
| 101 |
+
for p in pairs:
|
| 102 |
+
#print('pair: ', p, ' p[0]: ', p[0], ' p[1]:', p[1])
|
| 103 |
+
#print('coords: ', proximity_tree.data[p[0]], ' ', proximity_tree.data[p[1]])
|
| 104 |
+
coords_to_remove = [proximity_tree.data[p[0]][0], proximity_tree.data[p[0]][1]]
|
| 105 |
+
try:
|
| 106 |
+
idx = centroids.index(coords_to_remove)
|
| 107 |
+
centroids.pop(idx)
|
| 108 |
+
except ValueError:
|
| 109 |
+
pass
|
| 110 |
+
|
| 111 |
+
return centroids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
|
@@ -1,160 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
-
from PIL import Image, ImageDraw
|
| 3 |
-
import logging
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
Args:
|
| 10 |
-
|
| 11 |
-
centroids (list): List of (x,y) coordinates of damage site centroids.
|
| 12 |
-
window_size (list): [height, width] of the square window to extract around each centroid.
|
| 13 |
|
| 14 |
Returns:
|
| 15 |
-
np.ndarray:
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
# Convert PIL Image to NumPy array if necessary
|
| 20 |
-
if isinstance(image, Image.Image):
|
| 21 |
-
# Convert to RGB first to ensure 3 channels for consistent model input
|
| 22 |
-
image_array = np.array(image.convert('RGB'))
|
| 23 |
-
logging.info("prepare_classifier_input: Converted PIL Image to RGB NumPy array.")
|
| 24 |
-
elif isinstance(image, np.ndarray):
|
| 25 |
-
# Ensure it's a 3-channel array for consistency if it's already NumPy
|
| 26 |
-
if image.ndim == 2: # Grayscale NumPy array
|
| 27 |
-
image_array = np.stack([image, image, image], axis=-1) # Convert to 3 channels
|
| 28 |
-
logging.info("prepare_classifier_input: Converted grayscale NumPy array to 3-channel.")
|
| 29 |
-
elif image.ndim == 3 and image.shape[2] == 4: # RGBA NumPy array
|
| 30 |
-
image_array = image[:, :, :3] # Drop alpha channel
|
| 31 |
-
logging.info("prepare_classifier_input: Dropped alpha channel from RGBA NumPy array.")
|
| 32 |
-
else: # Already RGB or similar 3-channel NumPy array
|
| 33 |
-
image_array = image
|
| 34 |
-
logging.info("prepare_classifier_input: Image is already a suitable NumPy array.")
|
| 35 |
-
else:
|
| 36 |
-
logging.error("prepare_classifier_input: Unsupported image format received. Expected PIL Image or NumPy array.")
|
| 37 |
-
raise ValueError("Unsupported image format for classifier input.")
|
| 38 |
-
|
| 39 |
-
if not centroids:
|
| 40 |
-
logging.warning("No centroids provided for prepare_classifier_input. Returning empty array.")
|
| 41 |
-
return np.empty((0, window_size[0], window_size[1], image_array.shape[2]), dtype=np.float32)
|
| 42 |
-
|
| 43 |
-
patches = []
|
| 44 |
-
img_height, img_width, _ = image_array.shape # Get dimensions from the now-guaranteed NumPy array
|
| 45 |
-
half_window_h, half_window_w = window_size[0] // 2, window_size[1] // 2
|
| 46 |
-
|
| 47 |
-
for c_y, c_x in centroids: # Centroids are (y, x) from clustering
|
| 48 |
-
# Ensure coordinates are integers
|
| 49 |
-
c_y, c_x = int(round(c_y)), int(round(c_x))
|
| 50 |
-
|
| 51 |
-
# Calculate bounding box for the patch
|
| 52 |
-
# Handle boundary conditions by clamping coordinates
|
| 53 |
-
y1 = max(0, c_y - half_window_h)
|
| 54 |
-
y2 = min(img_height, c_y + half_window_h)
|
| 55 |
-
x1 = max(0, c_x - half_window_w)
|
| 56 |
-
x2 = min(img_width, c_x + half_window_w)
|
| 57 |
-
|
| 58 |
-
# Extract patch
|
| 59 |
-
patch = image_array[y1:y2, x1:x2, :]
|
| 60 |
-
|
| 61 |
-
# Pad if the patch is smaller than window_size (due to boundary clamping)
|
| 62 |
-
if patch.shape[0] != window_size[0] or patch.shape[1] != window_size[1]:
|
| 63 |
-
padded_patch = np.zeros((window_size[0], window_size[1], image_array.shape[2]), dtype=patch.dtype)
|
| 64 |
-
padded_patch[0:patch.shape[0], 0:patch.shape[1], :] = patch
|
| 65 |
-
patch = padded_patch
|
| 66 |
-
|
| 67 |
-
patches.append(patch)
|
| 68 |
-
|
| 69 |
-
# Normalize pixel values if your model expects it (e.g., to 0-1)
|
| 70 |
-
# This is a common step, adjust if your model's training pre-processing was different
|
| 71 |
-
# Assuming images are 0-255, converting to float 0-1
|
| 72 |
-
return np.array(patches, dtype=np.float32) / 255.0
|
| 73 |
|
|
|
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
Args:
|
| 80 |
-
image:
|
| 81 |
-
damage_sites (dict):
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
"""
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
if image is None:
|
| 91 |
-
logging.warning("show_boxes received no image. Returning a blank image.")
|
| 92 |
-
img = Image.new('RGB', (500, 500), color = 'black')
|
| 93 |
-
else:
|
| 94 |
-
# Ensure image is a PIL Image for drawing operations
|
| 95 |
-
if isinstance(image, np.ndarray):
|
| 96 |
-
# Convert NumPy array to PIL Image. Assuming input is 0-255.
|
| 97 |
-
if image.dtype == np.float32 and np.max(image) <= 1.0: # If normalized 0-1 float
|
| 98 |
-
image_for_pil = (image * 255).astype(np.uint8)
|
| 99 |
-
else: # Assume 0-255 uint8
|
| 100 |
-
image_for_pil = image.astype(np.uint8)
|
| 101 |
-
|
| 102 |
-
if image_for_pil.ndim == 2: # Grayscale numpy
|
| 103 |
-
img = Image.fromarray(image_for_pil, mode='L').convert("RGB")
|
| 104 |
-
elif image_for_pil.ndim == 3 and image_for_pil.shape[2] in [3,4]: # RGB or RGBA
|
| 105 |
-
img = Image.fromarray(image_for_pil).convert("RGB")
|
| 106 |
-
else:
|
| 107 |
-
logging.error("Unsupported numpy image format for show_boxes.")
|
| 108 |
-
img = Image.new('RGB', (500, 500), color = 'black') # Fallback
|
| 109 |
-
else: # Assume it's already a PIL Image
|
| 110 |
-
img = image.copy().convert("RGB") # Use a copy to avoid modifying original
|
| 111 |
-
|
| 112 |
-
draw = ImageDraw.Draw(img)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
"Not Classified": "gray", # Should ideally not appear on final image
|
| 122 |
-
"Unknown": "white"
|
| 123 |
-
}
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
x1 = max(0, center_x - box_size)
|
| 134 |
-
y1 = max(0, center_y - box_size)
|
| 135 |
-
x2 = min(img.width, center_x + box_size)
|
| 136 |
-
y2 = min(img.height, center_y + box_size)
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
text_offset_x = 5
|
| 145 |
-
text_offset_y = -15
|
| 146 |
-
try:
|
| 147 |
-
draw.text((x1 + text_offset_x, y1 + text_offset_y), label, fill=outline_color)
|
| 148 |
-
except Exception as e:
|
| 149 |
-
logging.warning(f"Could not draw text label '{label}': {e}")
|
| 150 |
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
return img
|
| 160 |
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Collection of various utils
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
import imageio.v3 as iio
|
| 8 |
+
from PIL import Image
|
| 9 |
+
# we may have very large images (e.g. panoramic SEM images), allow to read them w/o warnings
|
| 10 |
+
Image.MAX_IMAGE_PIXELS = 933120000
|
| 11 |
+
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import matplotlib.patches as patches
|
| 14 |
+
from matplotlib.lines import Line2D
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
###
|
| 21 |
+
### load SEM images
|
| 22 |
+
###
|
| 23 |
+
def load_image(filename : str) -> np.ndarray :
    """Read an SEM image from disk into a numpy array.

    Args:
        filename (str): full path and name of the image file to be loaded

    Returns:
        np.ndarray: the image content as a numpy ndarray
    """
    # mode='F' asks the backend for a floating-point (grayscale) representation.
    return iio.imread(filename, mode='F')
|
| 35 |
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
###
|
| 39 |
+
### show SEM image with boxes in various colours around each damage site
|
| 40 |
+
###
|
| 41 |
+
def show_boxes(image : np.ndarray, damage_sites : dict, box_size = None,
               save_image = False, image_path : str = None) :
    """Render the SEM image with a coloured box around each classified damage site.

    A rectangle of size ``box_size`` is drawn centred on every damage-site
    coordinate; its edge colour encodes the predicted class.  A legend mapping
    colours to class names is attached, and the annotated figure is rasterised
    and returned as a plain RGB array suitable for display in the web interface.

    Args:
        image (np.ndarray): SEM image to be shown.
        damage_sites (dict): dictionary using the coordinates as key (y, x),
            and the classifier label as value.
        box_size (list, optional): [height, width] of the rectangle drawn
            around each centroid. Defaults to [250, 250].
        save_image (bool, optional): save the annotated figure to disk or not.
            Defaults to False.
        image_path (str, optional): full path and name of the output file to
            be saved; only used when ``save_image`` is True.

    Returns:
        np.ndarray: the annotated figure as an (H, W, 3) uint8 RGB array.
    """
    # None-sentinel instead of a mutable default argument; [250, 250] is the
    # effective default, unchanged from the previous signature.
    if box_size is None:
        box_size = [250, 250]

    _, ax = plt.subplots(1)
    ax.imshow(image, cmap='gray')  # show image on correct axis
    ax.set_xticks([])
    ax.set_yticks([])

    for key, label in damage_sites.items():
        position = [key[0], key[1]]
        edgecolor = {
            'Inclusion': 'b',
            'Interface': 'g',
            'Martensite': 'r',
            'Notch': 'y',
            'Shadowing': 'm'
        }.get(label, 'k')  # default: black for unrecognised labels

        # Rectangle wants the anchor corner as (x, y); keys are stored (y, x).
        rect = patches.Rectangle((position[1] - box_size[1] / 2., position[0] - box_size[0] / 2),
                                 box_size[1], box_size[0],
                                 linewidth=1, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)

    # NOTE(review): the colour map above uses the label 'Shadowing' while the
    # legend entry reads 'Shadow' — confirm which spelling the classifier emits.
    legend_elements = [
        Line2D([0], [0], color='b', lw=4, label='Inclusion'),
        Line2D([0], [0], color='g', lw=4, label='Interface'),
        Line2D([0], [0], color='r', lw=4, label='Martensite'),
        Line2D([0], [0], color='y', lw=4, label='Notch'),
        Line2D([0], [0], color='m', lw=4, label='Shadow'),
        Line2D([0], [0], color='k', lw=4, label='Not Classified')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")

    fig = ax.figure
    fig.tight_layout(pad=0)

    if save_image and image_path:
        fig.savefig(image_path, dpi=1200, bbox_inches='tight')

    # Rasterise the figure so the caller receives a plain numpy RGB array.
    canvas = fig.canvas
    canvas.draw()

    data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(
        canvas.get_width_height()[::-1] + (4,))
    data = data[:, :, :3]  # drop the alpha channel, keep RGB only

    plt.close(fig)  # release the figure to avoid leaking memory across calls

    return data
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
###
|
| 102 |
+
### cut out small images from panorama, append colour information
|
| 103 |
+
###
|
| 104 |
+
def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=[250, 250]) -> list:
    """
    Cut square patches out of a grayscale panoramic SEM image, one per centroid.

    Each patch is centred on a damage-site centroid, rescaled from [0, 255] to
    [-1, 1], and replicated to three identical channels so it can be fed to a
    classification network that expects colour input.

    Parameters
    ----------
    panorama : np.ndarray
        Input SEM image, either 2D (H, W) or 3D (H, W, 1) grayscale data.

    centroids : list of [int, int]
        (y, x) coordinates of the regions of interest — typically damage sites
        found by an earlier clustering step.

    window_size : list of int, optional
        [height, width] of each extracted patch. Defaults to [250, 250].

    Returns
    -------
    list of np.ndarray
        Normalised 3-channel patches of shape (height, width, 3).  Centroids
        whose full window would fall outside the image are silently dropped.
    """
    # Promote 2D grayscale input to a single-channel 3D array.
    if panorama.ndim == 2:
        panorama = panorama[:, :, np.newaxis]

    height, width = panorama.shape[0], panorama.shape[1]
    win_h, win_w = window_size

    patches_out = []
    for centre_y, centre_x in centroids:
        top = int(centre_y - win_h / 2)
        left = int(centre_x - win_w / 2)
        bottom = top + win_h
        right = left + win_w

        # Only keep centroids whose entire window lies inside the image.
        if top >= 0 and left >= 0 and bottom <= height and right <= width:
            window = panorama[top:bottom, left:right, 0].astype(np.float32)
            window = window * 2. / 255. - 1.  # rescale [0, 255] -> [-1, 1]

            # Replicate the grayscale channel to mimic an RGB image.
            patches_out.append(np.repeat(window[:, :, np.newaxis], 3, axis=2))

    return patches_out
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
|
|
|
|
| 163 |
|