kerzel committed on
Commit
7b779c6
·
1 Parent(s): e360db8

try to fix utils.py with all the gradio fixes

Browse files
Files changed (1) hide show
  1. utils.py +86 -35
utils.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Collection of various utils
3
  """
4
 
5
  import numpy as np
@@ -12,16 +12,15 @@ Image.MAX_IMAGE_PIXELS = 933120000
12
  import matplotlib.pyplot as plt
13
  import matplotlib.patches as patches
14
  from matplotlib.lines import Line2D
15
-
16
-
17
  import math
18
 
19
 
20
  ###
21
- ### load SEM images
22
- ###
23
  def load_image(filename : str) -> np.ndarray :
24
- """Load an SEM image
25
 
26
  Args:
27
  filename (str): full path and name of the image file to be loaded
@@ -34,38 +33,66 @@ def load_image(filename : str) -> np.ndarray :
34
  return image
35
 
36
 
37
-
38
  ###
39
  ### show SEM image with boxes in various colours around each damage site
40
  ###
41
  def show_boxes(image : np.ndarray, damage_sites : dict, box_size = [250,250],
42
  save_image = False, image_path : str = None) :
43
- """_summary_
 
44
 
45
  Args:
46
- image (np.ndarray): SEM image to be shown
47
- damage_sites (dict): python dictionary using the coordinates as key (x,y), and the label as value
48
- box_size (list, optional): size of the rectangle drawn around each centroid. Defaults to [250,250].
49
- save_image (bool, optional): save the image with the boxes or not. Defaults to False.
50
- image_path (str, optional) : Full path and name of the output file to be saved
51
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  _, ax = plt.subplots(1)
54
- ax.imshow(image, cmap='gray') # show image on correct axis
55
  ax.set_xticks([])
56
  ax.set_yticks([])
57
 
58
  for key, label in damage_sites.items():
59
- position = [key[0], key[1]]
 
60
  edgecolor = {
61
  'Inclusion': 'b',
62
  'Interface': 'g',
63
  'Martensite': 'r',
64
  'Notch': 'y',
65
- 'Shadowing': 'm'
 
66
  }.get(label, 'k') # default: black
67
 
68
- rect = patches.Rectangle((position[1] - box_size[1] / 2., position[0] - box_size[0] / 2),
 
 
 
 
 
 
 
 
 
69
  box_size[1], box_size[0],
70
  linewidth=1, edgecolor=edgecolor, facecolor='none')
71
  ax.add_patch(rect)
@@ -95,13 +122,13 @@ def show_boxes(image : np.ndarray, damage_sites : dict, box_size = [250,250],
95
 
96
  plt.close(fig)
97
 
98
- return data
99
 
100
 
101
  ###
102
  ### cut out small images from panorama, append colour information
103
  ###
104
- def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=[250, 250]) -> list:
105
  """
106
  Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
107
 
@@ -110,8 +137,9 @@ def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=
110
 
111
  Parameters
112
  ----------
113
- panorama : np.ndarray
114
- Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data.
 
115
 
116
  centroids : list of [int, int]
117
  List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
@@ -126,14 +154,44 @@ def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=
126
  List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
127
  centroids that allow full window extraction within image bounds are used.
128
  """
129
- if panorama.ndim == 2:
130
- panorama = np.expand_dims(panorama, axis=-1) # (H, W, 1)
131
-
132
- H, W, _ = panorama.shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  win_h, win_w = window_size
134
  images = []
135
 
136
  for (cy, cx) in centroids:
 
 
 
137
  x1 = int(cx - win_w / 2)
138
  y1 = int(cy - win_h / 2)
139
  x2 = x1 + win_w
@@ -141,11 +199,12 @@ def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=
141
 
142
  # Skip if patch would go out of bounds
143
  if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
 
144
  continue
145
 
146
  # Extract and normalize patch
147
- patch = panorama[y1:y2, x1:x2, 0].astype(np.float32)
148
- patch = patch * 2. / 255. - 1.
149
 
150
  # Replicate grayscale channel to simulate RGB
151
  patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
@@ -153,11 +212,3 @@ def prepare_classifier_input(panorama: np.ndarray, centroids: list, window_size=
153
 
154
  return images
155
 
156
-
157
-
158
-
159
-
160
-
161
-
162
-
163
-
 
1
  """
2
+ Collection of various utils
3
  """
4
 
5
  import numpy as np
 
12
  import matplotlib.pyplot as plt
13
  import matplotlib.patches as patches
14
  from matplotlib.lines import Line2D
15
+ import logging # ADDED for logging
 
16
  import math
17
 
18
 
19
  ###
20
+ ### load SEM images (Note: Not directly used with Gradio gr.Image(type="pil"))
21
+ ###
22
  def load_image(filename : str) -> np.ndarray :
23
+ """Load an SEM image
24
 
25
  Args:
26
  filename (str): full path and name of the image file to be loaded
 
33
  return image
34
 
35
 
 
36
  ###
37
  ### show SEM image with boxes in various colours around each damage site
38
  ###
39
  def show_boxes(image : np.ndarray, damage_sites : dict, box_size = [250,250],
40
  save_image = False, image_path : str = None) :
41
+ """
42
+ Shows an SEM image with colored boxes around identified damage sites.
43
 
44
  Args:
45
+ image (np.ndarray): SEM image to be shown.
46
+ damage_sites (dict): Python dictionary using the coordinates as key (y, x), i.e. (row, column), and the label as value.
47
+ box_size (list, optional): Size of the rectangle drawn around each centroid. Defaults to [250,250].
48
+ save_image (bool, optional): Save the image with the boxes or not. Defaults to False.
49
+ image_path (str, optional) : Full path and name of the output file to be saved.
50
  """
51
+ logging.info(f"show_boxes: Input image type: {type(image)}") # Added logging
52
+
53
+ # Ensure image is a NumPy array of appropriate type for matplotlib
54
+ if isinstance(image, Image.Image):
55
+ image_to_plot = np.array(image.convert('L')) # Convert to grayscale NumPy array
56
+ logging.info("show_boxes: Converted PIL Image to grayscale NumPy array for plotting.")
57
+ elif isinstance(image, np.ndarray):
58
+ if image.ndim == 3 and image.shape[2] in [3,4]: # RGB or RGBA NumPy array
59
+ image_to_plot = np.mean(image, axis=2).astype(image.dtype) # Convert to grayscale
60
+ logging.info("show_boxes: Converted multi-channel NumPy array to grayscale for plotting.")
61
+ else: # Assume grayscale already
62
+ image_to_plot = image
63
+ logging.info("show_boxes: Image is already a grayscale NumPy array.")
64
+ else:
65
+ logging.error("show_boxes: Unsupported image format received.")
66
+ image_to_plot = np.zeros((100,100), dtype=np.uint8) # Fallback to black image
67
+
68
 
69
  _, ax = plt.subplots(1)
70
+ ax.imshow(image_to_plot, cmap='gray') # show image on correct axis
71
  ax.set_xticks([])
72
  ax.set_yticks([])
73
 
74
  for key, label in damage_sites.items():
75
+ position = [key[0], key[1]] # Assuming key[0] is y (row) and key[1] is x (column)
76
+
77
  edgecolor = {
78
  'Inclusion': 'b',
79
  'Interface': 'g',
80
  'Martensite': 'r',
81
  'Notch': 'y',
82
+ 'Shadowing': 'm',
83
+ 'Not Classified': 'k' # Added Not Classified for completeness
84
  }.get(label, 'k') # default: black
85
 
86
+ # Ensure box_size elements are floats for division
87
+ half_box_w = box_size[1] / 2.0
88
+ half_box_h = box_size[0] / 2.0
89
+
90
+ # x-coordinate of the bottom-left corner
91
+ rect_x = position[1] - half_box_w
92
+ # y-coordinate of the rectangle anchor; imshow's default origin='upper' inverts the y-axis, so this is the top edge of the box on screen
93
+ rect_y = position[0] - half_box_h
94
+
95
+ rect = patches.Rectangle((rect_x, rect_y),
96
  box_size[1], box_size[0],
97
  linewidth=1, edgecolor=edgecolor, facecolor='none')
98
  ax.add_patch(rect)
 
122
 
123
  plt.close(fig)
124
 
125
+ return data
126
 
127
 
128
  ###
129
  ### cut out small images from panorama, append colour information
130
  ###
131
+ def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250]) -> list: # Removed np.ndarray type hint for panorama
132
  """
133
  Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
134
 
 
137
 
138
  Parameters
139
  ----------
140
+ panorama : PIL.Image.Image or np.ndarray
141
+ Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data,
142
+ or a PIL Image object.
143
 
144
  centroids : list of [int, int]
145
  List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
 
154
  List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
155
  centroids that allow full window extraction within image bounds are used.
156
  """
157
+ logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}") # Added logging
158
+
159
+ # --- MINIMAL FIX START ---
160
+ # Convert PIL Image to NumPy array if necessary
161
+ if isinstance(panorama, Image.Image):
162
+ # Convert to grayscale NumPy array as your original code expects this structure for processing
163
+ if panorama.mode == 'RGB':
164
+ panorama_array = np.array(panorama.convert('L'))
165
+ logging.info("prepare_classifier_input: Converted RGB PIL Image to grayscale NumPy array.")
166
+ else:
167
+ panorama_array = np.array(panorama)
168
+ logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
169
+ elif isinstance(panorama, np.ndarray):
170
+ # Ensure it's treated as a grayscale array for consistency with original logic
171
+ if panorama.ndim == 3 and panorama.shape[2] in [3, 4]: # RGB or RGBA NumPy array
172
+ panorama_array = np.mean(panorama, axis=2).astype(panorama.dtype) # Convert to grayscale
173
+ logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
174
+ else:
175
+ panorama_array = panorama # Assume it's already grayscale 2D or (H,W,1)
176
+ logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
177
+ else:
178
+ logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
179
+ raise ValueError("Unsupported panorama format for classifier input.")
180
+
181
+ # Now, ensure panorama_array has a channel dimension if it's 2D for consistency
182
+ if panorama_array.ndim == 2:
183
+ panorama_array = np.expand_dims(panorama_array, axis=-1) # (H, W, 1)
184
+ logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
185
+ # --- MINIMAL FIX END ---
186
+
187
+ H, W, _ = panorama_array.shape # Use panorama_array here
188
  win_h, win_w = window_size
189
  images = []
190
 
191
  for (cy, cx) in centroids:
192
+ # Ensure coordinates are integers
193
+ cy, cx = int(round(cy)), int(round(cx))
194
+
195
  x1 = int(cx - win_w / 2)
196
  y1 = int(cy - win_h / 2)
197
  x2 = x1 + win_w
 
199
 
200
  # Skip if patch would go out of bounds
201
  if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
202
+ logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.") # Added warning
203
  continue
204
 
205
  # Extract and normalize patch
206
+ patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32) # Use panorama_array
207
+ patch = patch * 2. / 255. - 1. # Keep your original normalization
208
 
209
  # Replicate grayscale channel to simulate RGB
210
  patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
 
212
 
213
  return images
214