kerzel committed on
Commit
25616b8
·
1 Parent(s): fd4e02c

utils by gemini

Browse files
Files changed (1) hide show
  1. utils.py +138 -141
utils.py CHANGED
@@ -1,163 +1,160 @@
1
- """
2
- Collection of various utils
3
- """
4
-
5
  import numpy as np
 
 
6
 
7
- import imageio.v3 as iio
8
- from PIL import Image
9
- # we may have very large images (e.g. panoramic SEM images), allow to read them w/o warnings
10
- Image.MAX_IMAGE_PIXELS = 933120000
11
-
12
- import matplotlib.pyplot as plt
13
- import matplotlib.patches as patches
14
- from matplotlib.lines import Line2D
15
-
16
-
17
- import math
18
-
19
-
20
###
### load SEM images
###
def load_image(filename: str) -> np.ndarray:
    """Load an SEM image from disk.

    Args:
        filename (str): full path and name of the image file to be loaded

    Returns:
        np.ndarray: the image decoded as a float ('F' mode) numpy ndarray
    """
    # mode='F' asks the PIL plugin for a 32-bit float grayscale decode
    return iio.imread(filename, mode='F')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
###
### show SEM image with boxes in various colours around each damage site
###
def show_boxes(image : np.ndarray, damage_sites : dict, box_size=(250, 250),
               save_image = False, image_path : str = None) :
    """Render the SEM image with a coloured rectangle around each damage site.

    Args:
        image (np.ndarray): SEM image to be shown
        damage_sites (dict): dictionary using the centroid coordinates as key (row, col),
            and the classification label as value
        box_size (tuple, optional): size [height, width] of the rectangle drawn around
            each centroid. Defaults to (250, 250). A tuple is used instead of the
            previous mutable-list default (mutable default argument anti-pattern).
        save_image (bool, optional): save the image with the boxes or not. Defaults to False.
        image_path (str, optional) : Full path and name of the output file to be saved

    Returns:
        np.ndarray: the rendered figure as an (H, W, 3) uint8 RGB array
    """
    _, ax = plt.subplots(1)
    ax.imshow(image, cmap='gray') # show image on correct axis
    ax.set_xticks([])
    ax.set_yticks([])

    for key, label in damage_sites.items():
        position = [key[0], key[1]]
        # colour per damage class; unknown labels fall back to black
        edgecolor = {
            'Inclusion': 'b',
            'Interface': 'g',
            'Martensite': 'r',
            'Notch': 'y',
            'Shadowing': 'm'
        }.get(label, 'k') # default: black

        # key is (row, col); matplotlib wants the (x, y) = (col, row) corner
        rect = patches.Rectangle((position[1] - box_size[1] / 2., position[0] - box_size[0] / 2),
                                 box_size[1], box_size[0],
                                 linewidth=1, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)

    # NOTE(review): legend says 'Shadow' while the colour map uses 'Shadowing' —
    # kept as-is since the legend text may be a deliberate short form.
    legend_elements = [
        Line2D([0], [0], color='b', lw=4, label='Inclusion'),
        Line2D([0], [0], color='g', lw=4, label='Interface'),
        Line2D([0], [0], color='r', lw=4, label='Martensite'),
        Line2D([0], [0], color='y', lw=4, label='Notch'),
        Line2D([0], [0], color='m', lw=4, label='Shadow'),
        Line2D([0], [0], color='k', lw=4, label='Not Classified')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")

    fig = ax.figure
    fig.tight_layout(pad=0)

    if save_image and image_path:
        fig.savefig(image_path, dpi=1200, bbox_inches='tight')

    # render the figure and grab the RGBA framebuffer as a numpy array
    canvas = fig.canvas
    canvas.draw()

    data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(
        canvas.get_width_height()[::-1] + (4,))
    data = data[:, :, :3]  # RGB only

    plt.close(fig)

    return data
100
 
101
###
### cut out small images from panorama, append colour information
###
def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=[250, 250]) -> list:
    """
    Cut square patches centred on each centroid out of a grayscale panoramic SEM image.

    Every patch is normalised to the range [-1, 1] and replicated into three
    identical channels so it can be fed to classification networks that expect
    RGB-like input.

    Parameters
    ----------
    panorama : np.ndarray
        Input SEM image, either (H, W) or (H, W, 1) grayscale data.
    centroids : list of [int, int]
        (y, x) centres of the regions of interest (typically damage sites
        found during preprocessing, e.g. clustering).
    window_size : list of int, optional
        [height, width] of each extracted patch. Defaults to [250, 250].

    Returns
    -------
    list of np.ndarray
        Normalised 3-channel patches of shape (height, width, 3). Centroids
        whose window would cross the image border are silently skipped.
    """
    if panorama.ndim == 2:
        panorama = panorama[..., np.newaxis]  # ensure (H, W, 1)

    height, width = panorama.shape[:2]
    win_h, win_w = window_size

    extracted = []
    for cy, cx in centroids:
        top = int(cy - win_h / 2)
        left = int(cx - win_w / 2)
        bottom = top + win_h
        right = left + win_w

        # only keep windows that lie fully inside the panorama
        if top < 0 or left < 0 or bottom > height or right > width:
            continue

        # normalise 0..255 grayscale to -1..1 float32
        window = panorama[top:bottom, left:right, 0].astype(np.float32)
        window = window * 2. / 255. - 1.

        # replicate the single grayscale channel to fake an RGB image
        extracted.append(np.stack([window, window, window], axis=-1))

    return extracted
-
156
 
 
 
 
 
 
 
 
157
 
 
 
158
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
-
 
 
 
 
 
162
 
 
163
 
 
 
 
 
 
1
  import numpy as np
2
+ from PIL import Image, ImageDraw
3
+ import logging
4
 
5
def prepare_classifier_input(image, centroids, window_size):
    """
    Extracts image patches around centroids and prepares them as input for Keras models.

    Args:
        image: The input SEM image (PIL Image or NumPy array).
        centroids (list): List of (y, x) coordinates of damage site centroids
            (row, column order, as produced by the clustering step).
        window_size (list): [height, width] of the window to extract around each centroid.

    Returns:
        np.ndarray: A float32 batch of shape (N, height, width, channels) with
        pixel values scaled to [0, 1], ready for model prediction.

    Raises:
        ValueError: If `image` is neither a PIL Image nor a NumPy array.
    """
    # lazy %-style logging args instead of f-strings (no work when logging is off)
    logging.info("prepare_classifier_input: Input image type: %s", type(image))

    # Normalise the input to a 3-channel NumPy array. The ndarray case is
    # checked first so pure-NumPy callers never touch the PIL name at all.
    if isinstance(image, np.ndarray):
        if image.ndim == 2:  # grayscale -> replicate to 3 channels
            image_array = np.stack([image, image, image], axis=-1)
            logging.info("prepare_classifier_input: Converted grayscale NumPy array to 3-channel.")
        elif image.ndim == 3 and image.shape[2] == 4:  # RGBA -> drop alpha
            image_array = image[:, :, :3]
            logging.info("prepare_classifier_input: Dropped alpha channel from RGBA NumPy array.")
        else:  # already RGB or similar 3-channel array
            image_array = image
            logging.info("prepare_classifier_input: Image is already a suitable NumPy array.")
    elif isinstance(image, Image.Image):
        # convert to RGB first to ensure 3 channels for consistent model input
        image_array = np.array(image.convert('RGB'))
        logging.info("prepare_classifier_input: Converted PIL Image to RGB NumPy array.")
    else:
        logging.error("prepare_classifier_input: Unsupported image format received. Expected PIL Image or NumPy array.")
        raise ValueError("Unsupported image format for classifier input.")

    # len() instead of truthiness: `if not centroids` raises ValueError when
    # callers pass a multi-element NumPy array of centroids.
    if len(centroids) == 0:
        logging.warning("No centroids provided for prepare_classifier_input. Returning empty array.")
        return np.empty((0, window_size[0], window_size[1], image_array.shape[2]), dtype=np.float32)

    patches = []
    img_height, img_width, _ = image_array.shape
    half_window_h, half_window_w = window_size[0] // 2, window_size[1] // 2

    for c_y, c_x in centroids:  # centroids are (y, x) from clustering
        # centroid coordinates may be float; snap to pixel grid
        c_y, c_x = int(round(c_y)), int(round(c_x))

        # clamp the patch bounds to the image borders
        y1 = max(0, c_y - half_window_h)
        y2 = min(img_height, c_y + half_window_h)
        x1 = max(0, c_x - half_window_w)
        x2 = min(img_width, c_x + half_window_w)

        patch = image_array[y1:y2, x1:x2, :]

        # Pad if the patch is smaller than window_size (due to boundary clamping).
        # NOTE(review): padding anchors the clipped content at the top-left, so
        # border patches are NOT centred on their centroid — confirm the model
        # was trained the same way before changing this.
        if patch.shape[0] != window_size[0] or patch.shape[1] != window_size[1]:
            padded_patch = np.zeros((window_size[0], window_size[1], image_array.shape[2]), dtype=patch.dtype)
            padded_patch[0:patch.shape[0], 0:patch.shape[1], :] = patch
            patch = padded_patch

        patches.append(patch)

    # Scale 0-255 pixel values to float 0-1 for the model (adjust here if the
    # model's training pre-processing was different).
    return np.array(patches, dtype=np.float32) / 255.0
73
 
 
 
 
 
 
 
74
 
75
def show_boxes(image, damage_sites, save_image=False, image_path="output_image.png"):
    """
    Draws bounding boxes or markers on the image based on the classified damage sites.

    Args:
        image: The input SEM image (PIL Image or NumPy array).
        damage_sites (dict): Dictionary with centroid coordinates as keys
            ((row, col) order, as produced by clustering) and classification
            labels as values.
        save_image (bool, optional): Whether to save the image to disk. Defaults to False.
        image_path (str, optional): Path to save the image. Defaults to "output_image.png".

    Returns:
        PIL.Image.Image: The image with drawn boxes/markers.
    """
    logging.info(f"show_boxes: Input image type: {type(image)}")

    if image is None:
        logging.warning("show_boxes received no image. Returning a blank image.")
        img = Image.new('RGB', (500, 500), color = 'black')
    elif isinstance(image, np.ndarray):
        # Convert NumPy array to PIL Image, assuming 0-255 data unless it
        # looks like a normalised 0-1 float image.
        if image.dtype == np.float32 and np.max(image) <= 1.0:
            pixel_data = (image * 255).astype(np.uint8)
        else:
            pixel_data = image.astype(np.uint8)

        if pixel_data.ndim == 2:  # grayscale numpy
            img = Image.fromarray(pixel_data, mode='L').convert("RGB")
        elif pixel_data.ndim == 3 and pixel_data.shape[2] in [3, 4]:  # RGB or RGBA
            img = Image.fromarray(pixel_data).convert("RGB")
        else:
            logging.error("Unsupported numpy image format for show_boxes.")
            img = Image.new('RGB', (500, 500), color = 'black')  # fallback
    else:
        # assume a PIL Image; copy so the caller's original is untouched
        img = image.copy().convert("RGB")

    draw = ImageDraw.Draw(img)

    # marker fill colour per classification label
    colors = {
        "Inclusion": "red",
        "Martensite": "blue",
        "Interface": "green",
        "Notch": "yellow",
        "Shadowing": "purple",
        "Not Classified": "gray",  # should ideally not appear on final image
        "Unknown": "white"
    }

    box_size = 10  # half-extent of the marker square

    for (row, col), label in damage_sites.items():
        # keys come from clustering as (y, x); PIL drawing wants (x, y)
        center_x = int(round(col))
        center_y = int(round(row))

        # marker corners, clamped to the image boundaries
        x1 = max(0, center_x - box_size)
        y1 = max(0, center_y - box_size)
        x2 = min(img.width, center_x + box_size)
        y2 = min(img.height, center_y + box_size)

        outline_color = "black"
        draw.rectangle([x1, y1, x2, y2],
                       fill=colors.get(label, "white"),
                       outline=outline_color, width=2)

        # label text slightly offset from the marker's top-left corner
        try:
            draw.text((x1 + 5, y1 + -15), label, fill=outline_color)
        except Exception as e:
            logging.warning(f"Could not draw text label '{label}': {e}")

    if save_image and image_path:
        try:
            img.save(image_path)
            logging.info(f"Image saved to {image_path}")
        except Exception as e:
            logging.error(f"Could not save image to {image_path}: {e}")

    return img
160