VJyzCELERY committed on
Commit
4631366
·
1 Parent(s): f83df97

Added better visualization

Browse files
Files changed (3) hide show
  1. app.py +4 -3
  2. src/__pycache__/model.cpython-312.pyc +0 -0
  3. src/model.py +288 -65
app.py CHANGED
@@ -396,7 +396,7 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
396
  plt.show()
397
 
398
  return fig
399
- def predict_image(upload,show_original,max_channels):
400
  img = cv2.cvtColor(cv2.imread(upload),cv2.COLOR_BGR2RGB)
401
  model_base_path = "./trained_model"
402
  classic_model_path =os.path.join(model_base_path,'classic_model.pt')
@@ -412,7 +412,7 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
412
  return "No CNN Model trained",None,None,None
413
  cnn_predict = cnn_model.predict(img)
414
  classic_predict = classic_model.predict(img)
415
- cnn_features = cnn_model.visualize_feature(img,max_channels=max_channels)
416
  classical_features = classic_model.visualize_feature(img,show_original=show_original)
417
  return None,make_figure_from_image(img),cnn_predict,classic_predict,cnn_features,classical_features
418
 
@@ -423,6 +423,7 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
423
  gr.Markdown("# CNN Settings")
424
  with gr.Accordion(open=False):
425
  cnn_max_channel_visual = gr.Number(value=8,precision=0,label='Max CNN Channels to Preview',interactive=True)
 
426
  with gr.Column():
427
  gr.Markdown("# Classical Settings")
428
  with gr.Accordion(open=False):
@@ -439,7 +440,7 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
439
 
440
  prediction_btn.click(
441
  fn=predict_image,
442
- inputs=[image_upload,classic_show_original,cnn_max_channel_visual],
443
  outputs=[verbose,image_preview,cnn_prediction,classical_prediction,cnn_features,classical_features]
444
  )
445
 
 
396
  plt.show()
397
 
398
  return fig
399
+ def predict_image(upload,show_original,max_channels,cnn_couple_channel):
400
  img = cv2.cvtColor(cv2.imread(upload),cv2.COLOR_BGR2RGB)
401
  model_base_path = "./trained_model"
402
  classic_model_path =os.path.join(model_base_path,'classic_model.pt')
 
412
  return "No CNN Model trained",None,None,None
413
  cnn_predict = cnn_model.predict(img)
414
  classic_predict = classic_model.predict(img)
415
+ cnn_features = cnn_model.visualize_feature(img,max_channels=max_channels,couple=cnn_couple_channel)
416
  classical_features = classic_model.visualize_feature(img,show_original=show_original)
417
  return None,make_figure_from_image(img),cnn_predict,classic_predict,cnn_features,classical_features
418
 
 
423
  gr.Markdown("# CNN Settings")
424
  with gr.Accordion(open=False):
425
  cnn_max_channel_visual = gr.Number(value=8,precision=0,label='Max CNN Channels to Preview',interactive=True)
426
+ cnn_couple_channel = gr.Checkbox(value=False,label='Couple Channels into RGB')
427
  with gr.Column():
428
  gr.Markdown("# Classical Settings")
429
  with gr.Accordion(open=False):
 
440
 
441
  prediction_btn.click(
442
  fn=predict_image,
443
+ inputs=[image_upload,classic_show_original,cnn_max_channel_visual,cnn_couple_channel],
444
  outputs=[verbose,image_preview,cnn_prediction,classical_prediction,cnn_features,classical_features]
445
  )
446
 
src/__pycache__/model.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/model.cpython-312.pyc and b/src/__pycache__/model.cpython-312.pyc differ
 
src/model.py CHANGED
@@ -55,6 +55,9 @@ class Config:
55
  gabor_lambda = 10
56
  gabor_gamma = 0.5
57
 
 
 
 
58
  class CNNFeatureExtractor(nn.Module):
59
  def __init__(self,config : Config):
60
  super().__init__()
@@ -105,51 +108,148 @@ class CNNFeatureExtractor(nn.Module):
105
  out = self(x)
106
 
107
  return out
108
- def visualize(self, input_image, max_channels=8,show=True):
 
 
 
 
 
 
 
109
  self.eval()
110
  device = self.get_device()
111
 
112
  if isinstance(input_image, np.ndarray):
113
- x = torch.from_numpy(input_image).permute(2, 0, 1).float().unsqueeze(0).to(device) # HWC -> CHW -> B
114
  elif isinstance(input_image, torch.Tensor):
115
  x = input_image.unsqueeze(0).to(device) if input_image.ndim == 3 else input_image.to(device)
116
  else:
117
  raise TypeError("input_image must be np.ndarray or torch.Tensor")
118
 
119
- conv_layers = [(name, module) for name, module in self.named_modules() if isinstance(module, nn.Conv2d)]
 
 
 
 
 
120
  all_layer_images = []
121
 
122
  for name, layer in conv_layers:
123
  activations = []
124
 
125
  def hook_fn(module, input, output):
126
- activations.append(output.cpu().detach())
127
 
128
  handle = layer.register_forward_hook(hook_fn)
129
  _ = self(x)
130
  handle.remove()
131
 
132
- act = activations[0][0]
133
- num_channels = min(act.shape[0], max_channels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- fig, axes = plt.subplots(1, num_channels, figsize=(3*num_channels, 3))
136
- if num_channels == 1:
137
- axes = [axes]
138
 
139
- for i in range(num_channels):
140
- axes[i].imshow(act[i], cmap='gray')
141
- axes[i].axis('off')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- fig.suptitle(f'Layer: {name}', fontsize=14)
144
  if show:
145
  plt.show()
146
 
147
  buf = io.BytesIO()
148
- fig.savefig(buf, format='png')
149
  buf.seek(0)
150
  img = Image.open(buf).convert("RGB")
151
  all_layer_images.append(np.array(img))
152
- plt.close(fig)
 
153
  return all_layer_images
154
 
155
  class ClassicalFeatureExtractor(nn.Module):
@@ -159,74 +259,204 @@ class ClassicalFeatureExtractor(nn.Module):
159
  self.hog_orientations = config.hog_orientations
160
  self.num_downsample = config.classical_downsample
161
  self.config = config
162
- self.feature_names = ['HoG','Canny Edge','Harris Corner','Shi-Tomasi corners','LBP','Gabor Filters']
163
  self.device = 'cpu'
164
 
165
  def get_device(self):
166
  return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
167
 
168
 
169
- def extract_features(self, img):
170
  cfg = self.config
171
 
172
  # Convert to grayscale
173
- min_h = cfg.hog_pixels_per_cell[0] * cfg.hog_cells_per_block[0]
174
- min_w = cfg.hog_pixels_per_cell[1] * cfg.hog_cells_per_block[1]
175
  gray = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
176
 
177
  for _ in range(self.num_downsample):
178
- h, w = gray.shape
179
- if h <= min_h or w <= min_w:
180
- break
181
  gray = cv2.pyrDown(gray)
182
 
183
  gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- feature_list = []
 
 
 
 
 
 
 
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  # 1. HOG
188
- _, hog_image = hog(
189
- gray,
190
- orientations=cfg.hog_orientations,
191
- pixels_per_cell=cfg.hog_pixels_per_cell,
192
- cells_per_block=cfg.hog_cells_per_block,
193
- block_norm=cfg.hog_block_norm,
194
- visualize=True
195
- )
196
- feature_list.append(hog_image)
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  # 2. Canny edges
199
  edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
200
- feature_list.append(edges)
201
-
 
 
 
202
  # 3. Harris corners
203
  harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
204
  harris = cv2.dilate(harris, None)
205
  harris = np.clip(harris, 0, 1)
206
- feature_list.append(harris)
207
-
208
- # 4. Shi-Tomasi corners
209
- shi_corners = np.zeros_like(gray, dtype=np.float32)
210
- keypoints = cv2.goodFeaturesToTrack(gray, maxCorners=cfg.shi_max_corners, qualityLevel=cfg.shi_quality_level, minDistance=cfg.shi_min_distance)
211
- if keypoints is not None:
212
- for kp in keypoints:
213
- x, y = kp.ravel()
214
- shi_corners[int(y), int(x)] = 1.0
215
- feature_list.append(shi_corners)
216
-
 
 
 
 
 
217
  # 5. LBP
218
  lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
219
  lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
220
- feature_list.append(lbp)
221
-
 
 
 
222
  # 6. Gabor filter
223
- g_kernel = cv2.getGaborKernel((cfg.gabor_ksize, cfg.gabor_ksize), cfg.gabor_sigma, cfg.gabor_theta, cfg.gabor_lambda, cfg.gabor_gamma)
224
- gabor_feat = cv2.filter2D(gray, cv2.CV_32F, g_kernel)
225
- gabor_feat = (gabor_feat - gabor_feat.min()) / (gabor_feat.max() - gabor_feat.min() + 1e-8)
226
- feature_list.append(gabor_feat)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  # Stack all features along channel axis
229
- features = np.stack(feature_list, axis=2)
 
 
 
230
  return features.astype(np.float32)
231
 
232
 
@@ -248,15 +478,13 @@ class ClassicalFeatureExtractor(nn.Module):
248
  feat = self.extract_features(img)
249
  batch_features.append(feat)
250
  batch_features = np.stack(batch_features, axis=0)
251
- return torch.from_numpy(batch_features).float().to(self.get_device())
 
252
 
253
  def visualize(self, img, show_original=True,show=True):
254
  if img.ndim != 3 or img.shape[2] != 3:
255
  img = np.repeat(img[:, :, None], 3, axis=2)
256
 
257
- feature_stack = self.extract_features(img)
258
- num_channels = feature_stack.shape[2]
259
-
260
  outputs = []
261
 
262
  def fig_to_pil(fig):
@@ -279,15 +507,10 @@ class ClassicalFeatureExtractor(nn.Module):
279
  if show:
280
  plt.show()
281
  outputs.append(fig_to_pil(fig))
282
-
283
- for c in range(num_channels):
284
- fig = plt.figure(figsize=(4, 4))
285
-
286
- plt.imshow(feature_stack[:, :, c], cmap="gray")
287
- plt.title(f"Feature {self.feature_names[c]}")
288
- plt.axis("off")
289
- if show:
290
- plt.show()
291
  outputs.append(fig_to_pil(fig))
292
 
293
  return outputs
 
55
  gabor_lambda = 10
56
  gabor_gamma = 0.5
57
 
58
+ # Sobel
59
+ sobel_ksize=3
60
+
61
  class CNNFeatureExtractor(nn.Module):
62
  def __init__(self,config : Config):
63
  super().__init__()
 
108
  out = self(x)
109
 
110
  return out
111
def visualize(
    self,
    input_image,
    max_channels=8,
    couple=False,
    show=True,
    **kwargs
):
    """Render the activations of every ``nn.Conv2d`` layer as images.

    Args:
        input_image: ``np.ndarray`` (permuted from HWC to CHW here) or
            ``torch.Tensor``; a 3-D tensor gets a batch axis prepended,
            any other tensor is used as-is.
        max_channels: Upper bound on the number of activation channels
            drawn per layer. Values below 1 are clamped to 1 so that the
            subplot grid is never empty.
        couple: When true, consecutive channel triplets are min-max scaled
            per channel and shown as one RGB tile; channels left over
            (within the ``max_channels`` budget) are shown as grayscale
            tiles.
        show: When true, ``plt.show()`` is called for each layer figure.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one RGB ``np.ndarray`` (the rendered figure) per
        convolution layer that produced an output, in ``named_modules``
        order.

    Raises:
        TypeError: If ``input_image`` is neither an ndarray nor a tensor.
    """
    self.eval()
    device = self.get_device()

    if isinstance(input_image, np.ndarray):
        # HWC -> CHW, then add the batch axis.
        x = torch.from_numpy(input_image).permute(2, 0, 1).float().unsqueeze(0).to(device)
    elif isinstance(input_image, torch.Tensor):
        x = input_image.unsqueeze(0).to(device) if input_image.ndim == 3 else input_image.to(device)
    else:
        raise TypeError("input_image must be np.ndarray or torch.Tensor")

    # A zero/negative budget would yield a 0-column grid and a division by zero.
    max_channels = max(1, int(max_channels))

    conv_layers = [
        (name, module)
        for name, module in self.named_modules()
        if isinstance(module, nn.Conv2d)
    ]

    # Capture every conv output in ONE forward pass instead of re-running
    # the whole network once per layer.
    captured = {}

    def make_hook(layer_name):
        def hook_fn(module, inputs, output):
            # Keep only the first output if a layer fires more than once.
            if layer_name not in captured:
                captured[layer_name] = output.detach().cpu()
        return hook_fn

    handles = [
        layer.register_forward_hook(make_hook(name))
        for name, layer in conv_layers
    ]
    try:
        with torch.no_grad():  # visualisation only; no autograd graph needed
            self(x)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for handle in handles:
            handle.remove()

    def unit_range(t):
        # Min-max scale to [0, 1]; epsilon guards constant maps.
        return (t - t.min()) / (t.max() - t.min() + 1e-8)

    all_layer_images = []

    for name, _ in conv_layers:
        if name not in captured:
            # Layer is registered on the module but was not used in forward().
            continue

        act = captured[name][0]  # first sample of the batch: (C, H, W)
        num_available = act.shape[0]

        if couple:
            num_rgb = min(num_available // 3, max_channels // 3)
            rem = min(num_available - num_rgb * 3, max_channels - num_rgb * 3)

            tiles = []
            for i in range(num_rgb):
                triplet = act[i * 3:(i + 1) * 3]
                rgb = torch.stack([unit_range(ch) for ch in triplet])
                tiles.append(
                    (rgb.permute(1, 2, 0).numpy(), f"RGB {i*3}-{i*3+2}", None)
                )

            start = num_rgb * 3
            for j in range(rem):
                tiles.append(
                    (unit_range(act[start + j]).numpy(), f"Ch {start + j}", "gray")
                )

            max_cols = 4
            figure_title = f"Layer: {name} (Coupled RGB + Grayscale)"
        else:
            tiles = [
                (act[idx].numpy(), None, "gray")
                for idx in range(min(num_available, max_channels))
            ]
            max_cols = 8
            figure_title = f"Layer: {name}"

        cols = min(max_cols, len(tiles))
        rows = int(np.ceil(len(tiles) / cols))

        # squeeze=False always returns a 2-D axes array, whatever the grid shape.
        fig, axes = plt.subplots(
            rows, cols,
            figsize=(3 * cols, 3 * rows),
            squeeze=False
        )
        flat_axes = axes.ravel()  # row-major, same order as idx // cols, idx % cols

        for ax, (image, tile_title, cmap) in zip(flat_axes, tiles):
            ax.imshow(image, cmap=cmap)
            if tile_title is not None:
                ax.set_title(tile_title, fontsize=9)

        # Hide the frame on used and unused cells alike.
        for ax in flat_axes:
            ax.axis("off")

        fig.suptitle(figure_title, fontsize=14)
        fig.tight_layout()

        if show:
            plt.show()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
        all_layer_images.append(np.array(img))
        plt.close(fig)

    return all_layer_images
254
 
255
  class ClassicalFeatureExtractor(nn.Module):
 
259
  self.hog_orientations = config.hog_orientations
260
  self.num_downsample = config.classical_downsample
261
  self.config = config
 
262
  self.device = 'cpu'
263
 
264
  def get_device(self):
265
  return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
266
 
267
 
268
def extract_features(self, img, visualize=False, **kwargs):
    """Compute a stack of hand-crafted feature maps for one RGB image.

    The image is converted to 8-bit grayscale, downsampled
    ``self.num_downsample`` times with ``cv2.pyrDown`` and Gaussian
    blurred. The following maps are then appended, in this order:
    per-orientation HOG maps (only when the image is large enough), Canny
    edges, Harris corners, LBP, three Gabor responses (theta = 0, pi/4,
    pi/2), |Sobel x|, |Sobel y|, |Laplacian| and gradient magnitude.
    When HOG is used, every map is cropped to the area covered by whole
    HOG cells so that all maps share one shape.

    Args:
        img: RGB image array; it is multiplied by 255 and cast to uint8
            (presumably values are in [0, 1] -- confirm against callers).
        visualize: When true, also build one matplotlib figure previewing
            the feature maps.
        **kwargs: Accepted and ignored.

    Returns:
        ``features`` as a float32 array stacked along axis 0
        (``(num_maps, H, W)``). When ``visualize`` is true, the tuple
        ``(features, [figure])`` is returned instead.
    """
    cfg = self.config

    # Convert to grayscale
    gray = cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

    for _ in range(self.num_downsample):
        gray = cv2.pyrDown(gray)

    gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)

    def render_subplots(items, max_cols=8, figsize_per_cell=3):
        # Lay out (image, title, cmap) triples on a grid and return the figure.
        n = len(items)
        cols = min(max_cols, n)
        rows = int(np.ceil(n / cols))

        fig, axes = plt.subplots(
            rows, cols,
            figsize=(cols * figsize_per_cell, rows * figsize_per_cell),
            squeeze=False
        )
        flat_axes = axes.ravel()

        for ax, (tile, title, cmap) in zip(flat_axes, items):
            ax.imshow(tile, cmap=cmap)
            ax.set_title(title, fontsize=9)

        # Hide the frame on used and unused axes alike.
        for ax in flat_axes:
            ax.axis("off")

        plt.tight_layout()
        return fig

    def scaled_abs(response):
        # |response| divided by its maximum; epsilon guards all-zero maps.
        magnitude = np.abs(response)
        magnitude /= magnitude.max() + 1e-8
        return magnitude

    feature_list = []
    vis_items = []

    H, W = gray.shape
    valid_H, valid_W = H, W
    cell_h, cell_w = cfg.hog_pixels_per_cell
    block_h, block_w = cfg.hog_cells_per_block

    # HOG is skipped unless the image is larger than two blocks per axis.
    use_hog = (H > 2 * cell_h * block_h) and (W > 2 * cell_w * block_w)

    # 1. HOG
    if use_hog:
        # Only ask skimage to render the HOG image when it will be shown;
        # the rendering is costly and is not part of the feature stack.
        hog_result = hog(
            gray,
            orientations=cfg.hog_orientations,
            pixels_per_cell=cfg.hog_pixels_per_cell,
            cells_per_block=cfg.hog_cells_per_block,
            block_norm=cfg.hog_block_norm,
            visualize=visualize,
            feature_vector=False
        )
        if visualize:
            hog_descriptors, hog_image = hog_result
        else:
            hog_descriptors, hog_image = hog_result, None

        # Average over the cells of each block, then expand every block
        # value to a cell-sized patch of pixels.
        hog_cells = hog_descriptors.mean(axis=(2, 3))
        hog_pixel = np.repeat(
            np.repeat(hog_cells, cell_h, axis=0),
            cell_w, axis=1
        )
        hog_pixel = hog_pixel[:H, :W]
        valid_H, valid_W = hog_pixel.shape[:2]

        if visualize:
            # Summary maps are display-only, so compute them only here.
            dominant_bin = np.argmax(hog_pixel, axis=2)
            dominant_strength = np.max(hog_pixel, axis=2)
            vis_items.append((np.sum(hog_pixel, axis=2), "HOG Energy", 'gray'))
            vis_items.append((dominant_bin, "HOG Dominant Bin", 'hsv'))
            vis_items.append((dominant_bin * dominant_strength, "HOG Weighted Dominant Bin", 'gray'))
            vis_items.append((hog_image[:valid_H, :valid_W], "HoG", 'gray'))

        for b in range(hog_pixel.shape[2]):
            feature_list.append(hog_pixel[:, :, b])

    def add_map(feature_map, title):
        # Crop to the common valid area, store it, and queue it for preview.
        cropped = feature_map[:valid_H, :valid_W]
        feature_list.append(cropped)
        if visualize:
            vis_items.append((cropped, title, "gray"))

    # 2. Canny edges
    edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
    add_map(edges, "Canny Edge")

    # 3. Harris corners
    harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
    harris = cv2.dilate(harris, None)
    harris = np.clip(harris, 0, 1)
    add_map(harris, "Harris Corner")

    # 4. LBP
    lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
    lbp_max = lbp.max()
    if lbp_max != 0:
        lbp = lbp / lbp_max
    add_map(lbp, "LBP")

    # 5. Gabor filters at three orientations
    for theta in [0, np.pi/4, np.pi/2]:
        kernel = cv2.getGaborKernel(
            (cfg.gabor_ksize, cfg.gabor_ksize),
            cfg.gabor_sigma, theta,
            cfg.gabor_lambda, cfg.gabor_gamma
        )
        g = scaled_abs(cv2.filter2D(gray, cv2.CV_32F, kernel))
        add_map(g, f"Gabor θ={theta:.2f}")

    # 6. Sobel -- the raw derivatives are reused for the gradient magnitude below.
    gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
    gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
    add_map(scaled_abs(gx), "Sobel X")
    add_map(scaled_abs(gy), "Sobel Y")

    # 7. Laplacian
    add_map(scaled_abs(cv2.Laplacian(gray, cv2.CV_32F)), "Laplacian")

    # 8. Gradient magnitude
    grad_mag = np.sqrt(gx**2 + gy**2)
    grad_mag /= grad_mag.max() + 1e-8
    add_map(grad_mag, "Gradient Magnitude")

    # Stack all features along the leading (channel) axis
    features = np.stack(feature_list, axis=0).astype(np.float32)
    if visualize:
        return features, [render_subplots(vis_items, max_cols=8)]
    return features
461
 
462
 
 
478
  feat = self.extract_features(img)
479
  batch_features.append(feat)
480
  batch_features = np.stack(batch_features, axis=0)
481
+ batch_features = torch.from_numpy(batch_features).float().to(self.get_device())
482
+ return batch_features
483
 
484
  def visualize(self, img, show_original=True,show=True):
485
  if img.ndim != 3 or img.shape[2] != 3:
486
  img = np.repeat(img[:, :, None], 3, axis=2)
487
 
 
 
 
488
  outputs = []
489
 
490
  def fig_to_pil(fig):
 
507
  if show:
508
  plt.show()
509
  outputs.append(fig_to_pil(fig))
510
+ feature_stack,figs = self.extract_features(img,visualize=True)
511
+ if show:
512
+ plt.show()
513
+ for fig in figs:
 
 
 
 
 
514
  outputs.append(fig_to_pil(fig))
515
 
516
  return outputs