# utils.py — "try to fix utils.py with all the gradio fixes" (commit 7b779c6)
# NOTE: the Hugging Face blob-view header (author, "raw / history blame / 8.71 kB")
# was removed from this file — it was copy-paste residue, not Python code.
"""
Collection of various utils
"""
import numpy as np
import imageio.v3 as iio
from PIL import Image
# we may have very large images (e.g. panoramic SEM images), allow to read them w/o warnings
Image.MAX_IMAGE_PIXELS = 933120000
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import logging # ADDED for logging
import math
###
### load SEM images (Note: Not directly used with Gradio gr.Image(type="pil"))
###
def load_image(filename : str) -> np.ndarray :
    """Read an SEM image from disk as a floating-point array.

    Args:
        filename (str): Full path and name of the image file to read.

    Returns:
        np.ndarray: The image pixels as a NumPy array (imageio 'F' = 32-bit float mode).
    """
    # imageio v3 handles format detection; mode='F' requests float grayscale.
    return iio.imread(filename, mode='F')
###
### show SEM image with boxes in various colours around each damage site
###
def show_boxes(image : np.ndarray, damage_sites : dict, box_size = (250, 250),
               save_image = False, image_path : str = None) :
    """
    Render an SEM image with coloured boxes around identified damage sites.

    Args:
        image (np.ndarray or PIL.Image.Image): SEM image to be shown. PIL images
            and multi-channel (RGB/RGBA) arrays are converted to grayscale first.
        damage_sites (dict): Dictionary keyed by centroid coordinates (row, col),
            with the damage-class label string as value.
        box_size (sequence, optional): [height, width] of the rectangle drawn
            around each centroid. Defaults to (250, 250).
            FIX: the default was a mutable list, which is shared across calls.
        save_image (bool, optional): Save the annotated figure to disk. Defaults to False.
        image_path (str, optional): Full path and name of the output file to be saved.

    Returns:
        np.ndarray: The rendered figure as an (H, W, 3) uint8 RGB array.
    """
    logging.info(f"show_boxes: Input image type: {type(image)}") # Added logging
    # Ensure image is a NumPy array of appropriate type for matplotlib
    if isinstance(image, Image.Image):
        image_to_plot = np.array(image.convert('L')) # Convert to grayscale NumPy array
        logging.info("show_boxes: Converted PIL Image to grayscale NumPy array for plotting.")
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] in [3, 4]: # RGB or RGBA NumPy array
            image_to_plot = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale
            logging.info("show_boxes: Converted multi-channel NumPy array to grayscale for plotting.")
        else: # Assume grayscale already
            image_to_plot = image
            logging.info("show_boxes: Image is already a grayscale NumPy array.")
    else:
        logging.error("show_boxes: Unsupported image format received.")
        image_to_plot = np.zeros((100, 100), dtype=np.uint8) # Fallback to black image
    _, ax = plt.subplots(1)
    ax.imshow(image_to_plot, cmap='gray') # show image on correct axis
    ax.set_xticks([])
    ax.set_yticks([])
    # Colour per damage class; hoisted out of the loop (it was rebuilt every
    # iteration). Unknown labels fall back to black ('k').
    colour_map = {
        'Inclusion': 'b',
        'Interface': 'g',
        'Martensite': 'r',
        'Notch': 'y',
        'Shadowing': 'm',
        'Not Classified': 'k'
    }
    for key, label in damage_sites.items():
        # key is assumed to be (row, col) = (y, x) — TODO confirm against caller
        position = [key[0], key[1]]
        edgecolor = colour_map.get(label, 'k') # default: black
        # Float division so the rectangle is centred on the centroid
        half_box_w = box_size[1] / 2.0
        half_box_h = box_size[0] / 2.0
        rect_x = position[1] - half_box_w # x-coordinate of the anchor corner
        rect_y = position[0] - half_box_h # y-coordinate of the anchor corner
        rect = patches.Rectangle((rect_x, rect_y),
                                 box_size[1], box_size[0],
                                 linewidth=1, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)
    legend_elements = [
        Line2D([0], [0], color='b', lw=4, label='Inclusion'),
        Line2D([0], [0], color='g', lw=4, label='Interface'),
        Line2D([0], [0], color='r', lw=4, label='Martensite'),
        Line2D([0], [0], color='y', lw=4, label='Notch'),
        # FIX: legend previously said 'Shadow' while the class label is 'Shadowing'
        Line2D([0], [0], color='m', lw=4, label='Shadowing'),
        Line2D([0], [0], color='k', lw=4, label='Not Classified')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")
    fig = ax.figure
    fig.tight_layout(pad=0)
    if save_image and image_path:
        fig.savefig(image_path, dpi=1200, bbox_inches='tight')
    # Render the figure into an RGB ndarray so it can be returned to Gradio directly
    canvas = fig.canvas
    canvas.draw()
    data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(
        canvas.get_width_height()[::-1] + (4,))
    data = data[:, :, :3] # drop alpha, keep RGB only
    plt.close(fig) # release the figure to avoid matplotlib memory buildup
    return data
###
### cut out small images from panorama, append colour information
###
def prepare_classifier_input(panorama, centroids: list, window_size=(250, 250)) -> list:
    """
    Extract square image patches centered at each given centroid from a grayscale panoramic SEM image.

    Each extracted patch is resized to the specified window size and converted into a 3-channel
    (RGB-like) normalized image suitable for classification networks that expect colour input.

    Parameters
    ----------
    panorama : PIL.Image.Image or np.ndarray
        Input SEM image: a 2D array (H, W), a 3D array (H, W, C), or a PIL Image.
        Multi-channel input is reduced to grayscale.
    centroids : list of [int, int]
        List of (y, x) coordinates marking the centers of regions of interest.
    window_size : sequence of int, optional
        Size [height, width] of each extracted patch. Defaults to (250, 250).
        FIX: the default was a mutable list, which is shared across calls.

    Returns
    -------
    list of np.ndarray
        Extracted, normalized 3-channel float32 patches of shape (height, width, 3),
        with values scaled from [0, 255] to [-1, 1]. Centroids whose window would
        fall outside the image bounds are skipped.

    Raises
    ------
    ValueError
        If `panorama` is neither a PIL Image nor a NumPy array.
    """
    logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}") # Added logging
    # Check the NumPy case first: the two types are disjoint, so the order is
    # equivalent, and array inputs then never touch the PIL namespace.
    if isinstance(panorama, np.ndarray):
        if panorama.ndim == 3 and panorama.shape[2] in [3, 4]: # RGB or RGBA NumPy array
            panorama_array = np.mean(panorama, axis=2).astype(panorama.dtype) # Convert to grayscale
            logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
        else:
            panorama_array = panorama # Assume it's already grayscale 2D or (H,W,1)
            logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
    elif isinstance(panorama, Image.Image):
        # FIX: also convert RGBA (previously only RGB); otherwise the channel-0
        # slice below would silently use the red channel only.
        if panorama.mode in ('RGB', 'RGBA'):
            panorama_array = np.array(panorama.convert('L'))
            logging.info("prepare_classifier_input: Converted RGB PIL Image to grayscale NumPy array.")
        else:
            panorama_array = np.array(panorama)
            logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
    else:
        logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
        raise ValueError("Unsupported panorama format for classifier input.")
    # Ensure panorama_array has a channel dimension if it's 2D, for uniform slicing
    if panorama_array.ndim == 2:
        panorama_array = np.expand_dims(panorama_array, axis=-1) # (H, W, 1)
        logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
    H, W, _ = panorama_array.shape
    win_h, win_w = window_size
    images = []
    for (cy, cx) in centroids:
        # Ensure coordinates are integers (centroids may come in as floats)
        cy, cx = int(round(cy)), int(round(cx))
        x1 = int(cx - win_w / 2)
        y1 = int(cy - win_h / 2)
        x2 = x1 + win_w
        y2 = y1 + win_h
        # Skip if patch would go out of bounds
        if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
            logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.") # Added warning
            continue
        # Extract the patch and normalize [0, 255] -> [-1, 1]
        patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32)
        patch = patch * 2. / 255. - 1.
        # Replicate the grayscale channel to simulate RGB input
        patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
        images.append(patch_color)
    return images