File size: 8,707 Bytes
da6b0fd
7b779c6
da6b0fd
 
af3fe62
 
da6b0fd
 
 
 
 
 
 
 
7b779c6
da6b0fd
 
 
 
7b779c6
 
da6b0fd
7b779c6
af3fe62
 
da6b0fd
af3fe62
 
da6b0fd
af3fe62
da6b0fd
af3fe62
da6b0fd
af3fe62
da6b0fd
 
 
 
 
 
7b779c6
 
af3fe62
25616b8
7b779c6
 
 
 
 
da6b0fd
7b779c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af3fe62
da6b0fd
7b779c6
da6b0fd
 
 
 
7b779c6
 
da6b0fd
 
 
 
 
7b779c6
 
da6b0fd
 
7b779c6
 
 
 
 
 
 
 
 
 
da6b0fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b779c6
da6b0fd
 
 
 
 
7b779c6
af3fe62
da6b0fd
67d93b2
da6b0fd
 
 
 
 
7b779c6
 
 
67d93b2
da6b0fd
 
 
 
 
 
 
 
 
 
 
 
 
7b779c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da6b0fd
 
af3fe62
da6b0fd
7b779c6
 
 
da6b0fd
 
 
 
af3fe62
da6b0fd
 
7b779c6
da6b0fd
af3fe62
da6b0fd
7b779c6
 
af3fe62
da6b0fd
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Collection of various utils
"""

import numpy as np

import imageio.v3 as iio
from PIL import Image
# Panoramic SEM images can exceed Pillow's default pixel-count limit; raise it so
# such images open without decompression-bomb warnings.
Image.MAX_IMAGE_PIXELS = 933_120_000

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import logging # ADDED for logging
import math


###
### load SEM images (Note: Not directly used with Gradio gr.Image(type="pil"))
###
def load_image(filename : str) -> np.ndarray :
    """Read an SEM image file from disk into a NumPy array.

    The file is decoded with ``imageio.v3.imread`` using ``mode='F'``
    (Pillow's single-channel floating-point mode).

    Args:
        filename (str): full path and name of the image file to be loaded

    Returns:
        np.ndarray: the decoded image data
    """
    return iio.imread(filename, mode='F')


###
### show SEM image with boxes in various colours around each damage site
###
def show_boxes(image : np.ndarray, damage_sites : dict, box_size = (250, 250),
               save_image = False, image_path : str = None) -> np.ndarray :
    """
    Render an SEM image with a coloured box around each damage site.

    Args:
        image (np.ndarray | PIL.Image.Image): SEM image to be shown. PIL images are
            converted to grayscale; (H, W, 3) / (H, W, 4) arrays are reduced to
            grayscale by averaging their R, G, B channels (alpha is ignored).
            Any other type is logged as an error and replaced by a black
            100x100 placeholder.
        damage_sites (dict): maps a (row, column) pixel coordinate, i.e. (y, x),
            to the damage label at that site. Unknown labels get a black box.
        box_size (sequence of two numbers, optional): (height, width) of the box
            drawn around each site. Defaults to (250, 250).
        save_image (bool, optional): Save the figure to ``image_path``. Defaults to False.
        image_path (str, optional): Full path and name of the output file; needed
            when ``save_image`` is True.

    Returns:
        np.ndarray: the rendered figure as a (height, width, 3) uint8 RGB array.
    """
    logging.info(f"show_boxes: Input image type: {type(image)}")

    # imshow with cmap='gray' needs a single-channel intensity array.
    if isinstance(image, Image.Image):
        image_to_plot = np.array(image.convert('L'))
        logging.info("show_boxes: Converted PIL Image to grayscale NumPy array for plotting.")
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] in (3, 4):
            # Average R, G, B only: an alpha channel carries no intensity and
            # would otherwise bias the grayscale values.
            image_to_plot = image[:, :, :3].mean(axis=2).astype(image.dtype)
            logging.info("show_boxes: Converted multi-channel NumPy array to grayscale for plotting.")
        else:
            image_to_plot = image
            logging.info("show_boxes: Image is already a grayscale NumPy array.")
    else:
        # Best effort: draw on a black placeholder instead of failing.
        logging.error("show_boxes: Unsupported image format received.")
        image_to_plot = np.zeros((100, 100), dtype=np.uint8)

    # One table drives both the box colours and the legend so they cannot drift
    # apart: (label used in damage_sites, text shown in the legend, colour).
    damage_classes = (
        ('Inclusion', 'Inclusion', 'b'),
        ('Interface', 'Interface', 'g'),
        ('Martensite', 'Martensite', 'r'),
        ('Notch', 'Notch', 'y'),
        ('Shadowing', 'Shadow', 'm'),
        ('Not Classified', 'Not Classified', 'k'),
    )
    edgecolors = {label: colour for label, _, colour in damage_classes}

    _, ax = plt.subplots(1)
    ax.imshow(image_to_plot, cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])

    box_h, box_w = box_size[0], box_size[1]
    for key, label in damage_sites.items():
        row, col = key[0], key[1]
        # Rectangle is anchored at its minimum (x, y) data corner; shift by half
        # the box so that the box is centred on the site.
        rect = patches.Rectangle((col - box_w / 2.0, row - box_h / 2.0),
                                 box_w, box_h,
                                 linewidth=1,
                                 edgecolor=edgecolors.get(label, 'k'),  # unknown label -> black
                                 facecolor='none')
        ax.add_patch(rect)

    legend_elements = [Line2D([0], [0], color=colour, lw=4, label=legend_text)
                       for _, legend_text, colour in damage_classes]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")

    fig = ax.figure
    fig.tight_layout(pad=0)

    if save_image and image_path:
        fig.savefig(image_path, dpi=1200, bbox_inches='tight')
    elif save_image:
        logging.warning("show_boxes: save_image=True but no image_path given; figure not saved.")

    fig.canvas.draw()
    # buffer_rgba() already exposes the (rows, cols, 4) shape of the rendered
    # buffer. get_width_height() reports logical pixels, which can differ from
    # the buffer size on canvases with a device pixel ratio != 1.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy, so callers get a writable, contiguous array that no
    # longer references the canvas closed below.
    data = rgba[:, :, :3].copy()

    plt.close(fig)

    return data


###
### cut out small images from panorama, append colour information
###
def prepare_classifier_input(panorama, centroids: list, window_size=(250, 250)) -> list:
    """
    Extract patches centred on each centroid of a panoramic SEM image.

    Each patch is cut out at exactly ``window_size`` (no resizing). Its values
    are scaled from [0, 255] to [-1, 1], and the single channel is replicated
    three times for classifiers that expect colour input.

    Parameters
    ----------
    panorama : PIL.Image.Image or np.ndarray
        Input SEM image. Multi-band PIL images (e.g. RGB, RGBA, LA) and palette
        ('P') images are converted with ``convert('L')``. Single-band PIL images
        are used as they are. (H, W, 3) and (H, W, 4) arrays are averaged over
        their first three channels (alpha ignored). Other 2-D or 3-D arrays are
        used as they are; only channel 0 of a 3-D array is sampled.

    centroids : iterable of (y, x)
        Centres of the regions of interest, in (row, column) pixel coordinates.
        Non-integer values are rounded to the nearest pixel.

    window_size : sequence of two ints, optional
        (height, width) of each patch. Defaults to (250, 250). The centroid
        lands at patch index (height // 2, width // 2).

    Returns
    -------
    list of np.ndarray
        float32 arrays of shape (height, width, 3). Centroids whose window
        would extend past the image border are skipped, and a warning is logged.

    Raises
    ------
    ValueError
        If ``panorama`` is neither a NumPy array nor a PIL Image, or if it is
        not 2-D or 3-D.

    Notes
    -----
    NOTE(review): the scaling assumes 8-bit intensities. 16-bit or float
    images with values outside [0, 255] are not rescaled; confirm this is
    intended.
    """
    logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}")

    # Check the ndarray case first so the NumPy path does not depend on PIL.
    if isinstance(panorama, np.ndarray):
        if panorama.ndim == 3 and panorama.shape[2] in (3, 4):
            # Average R, G, B only: alpha carries no intensity and would bias the result.
            panorama_array = panorama[:, :, :3].mean(axis=2).astype(panorama.dtype)
            logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
        else:
            panorama_array = panorama
            logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
    elif isinstance(panorama, Image.Image):
        # Reduce multi-band and palette images to luminance. Otherwise channel 0
        # below would be the red channel (RGBA) or a palette index (P).
        if panorama.mode == 'P' or len(panorama.getbands()) > 1:
            panorama_array = np.array(panorama.convert('L'))
            logging.info(f"prepare_classifier_input: Converted {panorama.mode} PIL Image to grayscale NumPy array.")
        else:
            panorama_array = np.array(panorama)
            logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
    else:
        logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
        raise ValueError("Unsupported panorama format for classifier input.")

    # Normalise to (H, W, C) so that the slicing below is uniform.
    if panorama_array.ndim == 2:
        panorama_array = np.expand_dims(panorama_array, axis=-1)
        logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
    elif panorama_array.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D panorama, got shape {panorama_array.shape}.")

    H, W = panorama_array.shape[:2]
    win_h, win_w = (int(v) for v in window_size)
    # Integer offsets from the centroid to the window's first row and column.
    # Unlike int(c - w / 2), which truncates toward zero, this centres odd
    # windows consistently, both in the interior and near the border.
    off_h, off_w = win_h // 2, win_w // 2

    images = []
    for cy, cx in centroids:
        cy, cx = int(round(cy)), int(round(cx))
        y1, x1 = cy - off_h, cx - off_w
        y2, x2 = y1 + win_h, x1 + win_w

        if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
            logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.")
            continue

        patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32)
        patch = patch * 2. / 255. - 1.  # [0, 255] -> [-1, 1]
        # Replicate the grayscale channel to mimic an RGB input.
        images.append(np.repeat(patch[:, :, np.newaxis], 3, axis=2))

    return images