VJyzCELERY committed
Commit 5e96bc9 · 1 Parent(s): 4631366
app.py CHANGED
@@ -2,10 +2,10 @@ import gradio as gr
2
  import zipfile
3
  import os
4
  import torch
5
- from src.dataloader import ImageDataset,collate_fn
6
  from src.model import Classifier,Config,CNNFeatureExtractor,ClassicalFeatureExtractor,load
7
  from torch.utils.data import Subset
8
- from src.trainer import ModelTrainer
9
  import torch
10
  import os
11
  import numpy as np
@@ -17,9 +17,6 @@ import matplotlib.pyplot as plt
17
  import shutil
18
  import pandas as pd
19
  from sklearn.model_selection import train_test_split
20
- from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
21
- from sklearn.metrics import classification_report
22
- from torch.utils.data import DataLoader
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  def unzip_dataset(zip_file):
@@ -82,13 +79,22 @@ def plot(datas, labels, xlabel, ylabel, title, figsize=(16, 8)):
82
 
83
  class TrainingInterrupted(Exception):
84
  pass
85
  def stop_training():
86
  global training_interrupt
87
  training_interrupt = True
88
  return "Training stopped."
89
 
90
  def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visualize_every=5):
91
  global training_interrupt
92
  training_interrupt = False
93
  global cnn_history
94
  global classic_history
@@ -163,17 +169,17 @@ def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visual
163
  print(e)
164
  return
165
 
166
- def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
167
  img_width,img_height,fc_num_layers,
168
  in_channels,conv_hidden_dim,dropout,
169
  classical_downsample,
170
- hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
171
  canny_sigma,canny_low,canny_high,
172
  gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
173
  harris_block_size,harris_ksize,harris_k,
174
- shi_max_corners,shi_quality_level,shi_min_distance,
175
  lbp_P,lbp_R,
176
- gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma):
177
  config = Config()
178
  global training_interrupt
179
  training_interrupt = False
@@ -190,10 +196,10 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
190
  config.dropout=dropout
191
  # Classical Config
192
  config.classical_downsample=int(classical_downsample)
193
- config.hog_orientations=int(hog_orientations)
194
- config.hog_pixels_per_cell=(int(hog_pixels_per_cell),int(hog_pixels_per_cell))
195
- config.hog_cells_per_block=(int(hog_cells_per_block),int(hog_cells_per_block))
196
- config.hog_block_norm=hog_block_norm
197
  config.canny_sigma=int(canny_sigma)
198
  config.canny_low=canny_low
199
  config.canny_high=canny_high
@@ -203,9 +209,6 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
203
  config.harris_block_size=int(harris_block_size)
204
  config.harris_ksize=int(harris_ksize)
205
  config.harris_k=harris_k
206
- config.shi_max_corners=int(shi_max_corners)
207
- config.shi_quality_level=shi_quality_level
208
- config.shi_min_distance=int(shi_min_distance)
209
  config.lbp_P=int(lbp_P)
210
  config.lbp_R=int(lbp_R)
211
  config.gabor_ksize=int(gabor_ksize)
@@ -213,12 +216,14 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
213
  config.gabor_theta=int(gabor_theta)
214
  config.gabor_lambda=int(gabor_lambda)
215
  config.gabor_gamma=gabor_gamma
216
  cnn_history_plots=[]
217
  classical_history_plots=[]
218
  cnn_plotted=False
219
  try:
220
  dataset = ImageDataset(DATASET_PATH,config.img_size)
221
  labels = [item['id'] for item in dataset.data]
222
  train_idx, validation_idx = train_test_split(np.arange(len(dataset)),
223
  test_size=0.2,
224
  random_state=SEED,
@@ -226,6 +231,12 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
226
  stratify=labels)
227
  train_dataset = Subset(dataset, train_idx)
228
  val_dataset = Subset(dataset, validation_idx)
229
  cnnbackbone = CNNFeatureExtractor(config).to(device)
230
  cnnmodel = Classifier(cnnbackbone,train_dataset.dataset.classes,config).to(device)
231
  classicbackbone = ClassicalFeatureExtractor(config)
@@ -235,11 +246,17 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
235
  cnn_plotted=True
236
  cnn_history_plots.append(plot([cnn_history['train_acc'],cnn_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
237
  cnn_history_plots.append(plot([cnn_history['train_loss'],cnn_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
238
-
239
  yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
240
  classical_history_plots.append(plot([classic_history['train_acc'],classic_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
241
  classical_history_plots.append(plot([classic_history['train_loss'],classic_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
242
-
243
  yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
244
 
245
  except RuntimeError as e:
@@ -314,6 +331,7 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
314
  epochs= gr.Number(value=20,label="Epochs",interactive=True,precision=0)
315
  seed=gr.Number(value=42,label='Seed',interactive=True,precision=0)
316
  vis_every=gr.Number(value=5,label='Visualize Validation Every (Epochs)',interactive=True,precision=0)
317
  with gr.Row():
318
  img_width=gr.Number(value=128,label='Image Width',interactive=True,precision=0)
319
  img_height=gr.Number(value=128,label='Image Height',interactive=True,precision=0)
@@ -328,11 +346,12 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
328
  with gr.Accordion(label='Classical Feature Extractor Settings',open=False):
329
  with gr.Row():
330
  classical_downsample = gr.Number(value=1,label='Classical Extractor Downsampling Amount',interactive=True,precision=0)
331
- with gr.Row():
332
- hog_orientations = gr.Number(value=9,label='HoG Orientations',interactive=True,precision=0)
333
- hog_pixels_per_cell = gr.Number(value=16,label='HoG pixels per cell',interactive=True,precision=0)
334
- hog_cells_per_block = gr.Number(value=2,label='HoG cells per block',interactive=True,precision=0)
335
- hog_block_norm = gr.Dropdown(['L2-Hys'],value='L2-Hys',label='HoG Block Normalization Method',interactive=True)
336
  with gr.Row():
337
  canny_sigma = gr.Number(value=1.0,label='Canny Sigma Value',interactive=True)
338
  canny_low = gr.Number(value=100,label='Canny Low Threshold',interactive=True,precision=0)
@@ -345,19 +364,17 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
345
  harris_block_size = gr.Number(value=2,label='Harris Corner Block Size',interactive=True,precision=0)
346
  harris_ksize = gr.Number(value=3,label='Harris Corner Kernel Size',interactive=True,precision=0)
347
  harris_k = gr.Slider(minimum=0.01, maximum=0.1, value=0.04, step=0.005, label='Harris Corner K value',interactive=True)
348
- with gr.Row():
349
- shi_max_corners = gr.Number(value=100,label='Shi-Tomasi Max Corners',interactive=True,precision=0)
350
- shi_quality_level = gr.Number(value=0.01,label='Shi-Tomasi Quality Level',interactive=True)
351
- shi_min_distance = gr.Number(value=10,label='Shi-Tomasi Min Distance',interactive=True,precision=0)
352
  with gr.Row():
353
  lbp_P = gr.Number(value=8,label='LBP P Value',interactive=True,precision=0)
354
  lbp_R = gr.Number(value=1,label='LBP R Value',interactive=True,precision=0)
355
  with gr.Row():
356
  gabor_ksize = gr.Number(value=21,label="Gabor Kernel Size",interactive=True,precision=0)
357
  gabor_sigma = gr.Number(value=5,label="Gabor Sigma",interactive=True,precision=0)
358
- gabor_theta = gr.Number(value=0,label="Gabor Theta",interactive=True,precision=0)
359
  gabor_lambda = gr.Number(value=10,label="Gabor Lambda",interactive=True,precision=0)
360
  gabor_gamma = gr.Number(value=0.5,label="Gabor Gamma",interactive=True)
361
  with gr.Column():
362
  train_btn = gr.Button("Train Model",variant='secondary',interactive=True)
363
  stop_btn = gr.Button("Stop Training")
@@ -375,17 +392,17 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
375
  classical_plots = gr.Gallery(label="CNN Training Performance",interactive=False,object_fit='fill',columns=1)
376
  stop_btn.click(fn=stop_training, inputs=[], outputs=[])
377
  train_btn.click(fn=train_model,
378
- inputs=[zip_file,batch_size,lr,epochs,seed,vis_every,
379
  img_width,img_height,fc_num_layers,
380
  in_channels,conv_hidden_dim,dropout,
381
  classical_downsample,
382
- hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
383
  canny_sigma,canny_low,canny_high,
384
  gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
385
  harris_block_size,harris_ksize,harris_k,
386
- shi_max_corners,shi_quality_level,shi_min_distance,
387
  lbp_P,lbp_R,
388
- gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma],
389
  outputs=[cnn_log,cnn_fig,classical_log,classical_fig,cnn_plots,classical_plots]
390
  )
391
  def make_figure_from_image(img):
 
2
  import zipfile
3
  import os
4
  import torch
5
+ from src.dataloader import ImageDataset,collate_fn,AugmentedSubset,simple_augment
6
  from src.model import Classifier,Config,CNNFeatureExtractor,ClassicalFeatureExtractor,load
7
  from torch.utils.data import Subset
8
+ from src.trainer import ModelTrainer,model_evaluation
9
  import torch
10
  import os
11
  import numpy as np
 
17
  import shutil
18
  import pandas as pd
19
  from sklearn.model_selection import train_test_split
20
 
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  def unzip_dataset(zip_file):
 
79
 
80
  class TrainingInterrupted(Exception):
81
  pass
82
+ cnntrainer=None
83
+ classictrainer=None
84
  def stop_training():
85
  global training_interrupt
86
  training_interrupt = True
87
+ if cnntrainer is not None:
88
+ cnntrainer.interrupt=True
89
+ if classictrainer is not None:
90
+ classictrainer.interrupt=True
91
  return "Training stopped."
92
 
93
+
94
+
95
  def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visualize_every=5):
96
  global training_interrupt
97
+ global cnntrainer,classictrainer
98
  training_interrupt = False
99
  global cnn_history
100
  global classic_history
 
169
  print(e)
170
  return
171
 
172
+ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,use_augment,
173
  img_width,img_height,fc_num_layers,
174
  in_channels,conv_hidden_dim,dropout,
175
  classical_downsample,
176
+ # hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
177
  canny_sigma,canny_low,canny_high,
178
  gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
179
  harris_block_size,harris_ksize,harris_k,
180
  lbp_P,lbp_R,
181
+ gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma,
182
+ sobel_ksize):
183
  config = Config()
184
  global training_interrupt
185
  training_interrupt = False
 
196
  config.dropout=dropout
197
  # Classical Config
198
  config.classical_downsample=int(classical_downsample)
199
+ # config.hog_orientations=int(hog_orientations)
200
+ # config.hog_pixels_per_cell=(int(hog_pixels_per_cell),int(hog_pixels_per_cell))
201
+ # config.hog_cells_per_block=(int(hog_cells_per_block),int(hog_cells_per_block))
202
+ # config.hog_block_norm=hog_block_norm
203
  config.canny_sigma=int(canny_sigma)
204
  config.canny_low=canny_low
205
  config.canny_high=canny_high
 
209
  config.harris_block_size=int(harris_block_size)
210
  config.harris_ksize=int(harris_ksize)
211
  config.harris_k=harris_k
212
  config.lbp_P=int(lbp_P)
213
  config.lbp_R=int(lbp_R)
214
  config.gabor_ksize=int(gabor_ksize)
 
216
  config.gabor_theta=int(gabor_theta)
217
  config.gabor_lambda=int(gabor_lambda)
218
  config.gabor_gamma=gabor_gamma
219
+ config.sobel_ksize=int(sobel_ksize)
220
  cnn_history_plots=[]
221
  classical_history_plots=[]
222
  cnn_plotted=False
223
  try:
224
  dataset = ImageDataset(DATASET_PATH,config.img_size)
225
  labels = [item['id'] for item in dataset.data]
226
+ classes_name = dataset.classes
227
  train_idx, validation_idx = train_test_split(np.arange(len(dataset)),
228
  test_size=0.2,
229
  random_state=SEED,
 
231
  stratify=labels)
232
  train_dataset = Subset(dataset, train_idx)
233
  val_dataset = Subset(dataset, validation_idx)
234
+ if use_augment:
235
+ train_dataset = AugmentedSubset(train_dataset,simple_augment)
236
+ val_dataset = AugmentedSubset(val_dataset,None)
237
+ else:
238
+ train_dataset = AugmentedSubset(train_dataset,None)
239
+ val_dataset = AugmentedSubset(val_dataset,None)
240
  cnnbackbone = CNNFeatureExtractor(config).to(device)
241
  cnnmodel = Classifier(cnnbackbone,train_dataset.dataset.classes,config).to(device)
242
  classicbackbone = ClassicalFeatureExtractor(config)
 
246
  cnn_plotted=True
247
  cnn_history_plots.append(plot([cnn_history['train_acc'],cnn_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
248
  cnn_history_plots.append(plot([cnn_history['train_loss'],cnn_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
249
+ cm,cr,roc = model_evaluation(cnnmodel,val_dataset,device,BATCH_SIZE,0,classes_name)
250
+ cnn_history_plots.append(fig_to_image(cm))
251
+ cnn_history_plots.append(fig_to_image(cr))
252
+ cnn_history_plots.append(fig_to_image(roc))
253
  yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
254
  classical_history_plots.append(plot([classic_history['train_acc'],classic_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
255
  classical_history_plots.append(plot([classic_history['train_loss'],classic_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
256
+ cm,cr,roc = model_evaluation(classicmodel,val_dataset,device,BATCH_SIZE,0,classes_name)
257
+ classical_history_plots.append(fig_to_image(cm))
258
+ classical_history_plots.append(fig_to_image(cr))
259
+ classical_history_plots.append(fig_to_image(roc))
260
  yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
261
 
262
  except RuntimeError as e:
 
331
  epochs= gr.Number(value=20,label="Epochs",interactive=True,precision=0)
332
  seed=gr.Number(value=42,label='Seed',interactive=True,precision=0)
333
  vis_every=gr.Number(value=5,label='Visualize Validation Every (Epochs)',interactive=True,precision=0)
334
+ use_augment = gr.Checkbox(value=True,label='Use data augmentation for train data')
335
  with gr.Row():
336
  img_width=gr.Number(value=128,label='Image Width',interactive=True,precision=0)
337
  img_height=gr.Number(value=128,label='Image Height',interactive=True,precision=0)
 
346
  with gr.Accordion(label='Classical Feature Extractor Settings',open=False):
347
  with gr.Row():
348
  classical_downsample = gr.Number(value=1,label='Classical Extractor Downsampling Amount',interactive=True,precision=0)
349
+ # Deprecated
350
+ # with gr.Row():
351
+ # hog_orientations = gr.Number(value=9,label='HoG Orientations',interactive=True,precision=0)
352
+ # hog_pixels_per_cell = gr.Number(value=16,label='HoG pixels per cell',interactive=True,precision=0)
353
+ # hog_cells_per_block = gr.Number(value=2,label='HoG cells per block',interactive=True,precision=0)
354
+ # hog_block_norm = gr.Dropdown(['L2-Hys'],value='L2-Hys',label='HoG Block Normalization Method',interactive=True)
355
  with gr.Row():
356
  canny_sigma = gr.Number(value=1.0,label='Canny Sigma Value',interactive=True)
357
  canny_low = gr.Number(value=100,label='Canny Low Threshold',interactive=True,precision=0)
 
364
  harris_block_size = gr.Number(value=2,label='Harris Corner Block Size',interactive=True,precision=0)
365
  harris_ksize = gr.Number(value=3,label='Harris Corner Kernel Size',interactive=True,precision=0)
366
  harris_k = gr.Slider(minimum=0.01, maximum=0.1, value=0.04, step=0.005, label='Harris Corner K value',interactive=True)
367
  with gr.Row():
368
  lbp_P = gr.Number(value=8,label='LBP P Value',interactive=True,precision=0)
369
  lbp_R = gr.Number(value=1,label='LBP R Value',interactive=True,precision=0)
370
  with gr.Row():
371
  gabor_ksize = gr.Number(value=21,label="Gabor Kernel Size",interactive=True,precision=0)
372
  gabor_sigma = gr.Number(value=5,label="Gabor Sigma",interactive=True,precision=0)
373
+ gabor_theta = gr.Number(value=0,label="Gabor Theta",interactive=True,precision=0,info="This currently does nothing")
374
  gabor_lambda = gr.Number(value=10,label="Gabor Lambda",interactive=True,precision=0)
375
  gabor_gamma = gr.Number(value=0.5,label="Gabor Gamma",interactive=True)
376
+ with gr.Row():
377
+ sobel_ksize = gr.Number(value=3,label="Sobel Kernel Size",interactive=True,precision=0)
378
  with gr.Column():
379
  train_btn = gr.Button("Train Model",variant='secondary',interactive=True)
380
  stop_btn = gr.Button("Stop Training")
 
392
  classical_plots = gr.Gallery(label="CNN Training Performance",interactive=False,object_fit='fill',columns=1)
393
  stop_btn.click(fn=stop_training, inputs=[], outputs=[])
394
  train_btn.click(fn=train_model,
395
+ inputs=[zip_file,batch_size,lr,epochs,seed,vis_every,use_augment,
396
  img_width,img_height,fc_num_layers,
397
  in_channels,conv_hidden_dim,dropout,
398
  classical_downsample,
399
+ # hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
400
  canny_sigma,canny_low,canny_high,
401
  gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
402
  harris_block_size,harris_ksize,harris_k,
403
  lbp_P,lbp_R,
404
+ gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma,
405
+ sobel_ksize],
406
  outputs=[cnn_log,cnn_fig,classical_log,classical_fig,cnn_plots,classical_plots]
407
  )
408
  def make_figure_from_image(img):
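
Note on the new stop mechanism in this commit: stop_training() now also flips an interrupt flag on the module-level cnntrainer/classictrainer objects, and ModelTrainer checks that flag inside its batch and epoch loops (see src/trainer.py below). A minimal, self-contained sketch of the same cooperative-interrupt pattern, using hypothetical names rather than the repo classes:

class Trainer:
    def __init__(self):
        self.interrupt = False                 # set from outside, e.g. a Stop button handler

    def fit(self, epochs, steps_per_epoch=100):
        for epoch in range(1, epochs + 1):
            for _ in range(steps_per_epoch):   # stand-in for the DataLoader batch loop
                if self.interrupt:             # same per-batch check ModelTrainer now performs
                    return                     # fit() exits early, like the new early returns
            yield epoch                        # training is a generator, as in ModelTrainer.fit

trainer = Trainer()
for epoch in trainer.fit(5):
    print("finished epoch", epoch)
    if epoch == 2:
        trainer.interrupt = True               # what stop_training() does to each trainer
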
src/__pycache__/dataloader.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/dataloader.cpython-312.pyc and b/src/__pycache__/dataloader.cpython-312.pyc differ
 
src/__pycache__/model.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/model.cpython-312.pyc and b/src/__pycache__/model.cpython-312.pyc differ
 
src/__pycache__/trainer.cpython-312.pyc CHANGED
Binary files a/src/__pycache__/trainer.cpython-312.pyc and b/src/__pycache__/trainer.cpython-312.pyc differ
 
src/dataloader.py CHANGED
@@ -1,4 +1,4 @@
1
- from torch.utils.data import Dataset
2
  import torch
3
  import os
4
  import numpy as np
@@ -37,4 +37,25 @@ class ImageDataset(Dataset):
37
  img = img.astype(np.float32) / 255.0
38
  return img,label
39
 
40
 
1
+ from torch.utils.data import Subset,Dataset
2
  import torch
3
  import os
4
  import numpy as np
 
37
  img = img.astype(np.float32) / 255.0
38
  return img,label
39
 
40
+ def simple_augment(img):
41
+ if np.random.rand() > 0.5:
42
+ img = cv2.flip(img, 1)
43
 
44
+ angle = np.random.uniform(-15, 15)
45
+ h, w = img.shape[:2]
46
+ M = cv2.getRotationMatrix2D((w/2, h/2), angle, 1.0)
47
+ img = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
48
+
49
+ return img
50
+
51
+
52
+ class AugmentedSubset(Subset):
53
+ def __init__(self, subset, augment_fn=None):
54
+ super().__init__(subset.dataset, subset.indices)
55
+ self.augment_fn = augment_fn
56
+
57
+ def __getitem__(self, idx):
58
+ img, label = super().__getitem__(idx)
59
+ if self.augment_fn:
60
+ img = self.augment_fn(img)
61
+ return img, label
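
The new AugmentedSubset/simple_augment pair is what app.py now wraps the train/validation splits with. A minimal sketch of how it behaves, using a synthetic stand-in dataset (ToyImages below is hypothetical; the real code passes an ImageDataset split produced by train_test_split):

import numpy as np
from torch.utils.data import Dataset, Subset
from src.dataloader import AugmentedSubset, simple_augment

class ToyImages(Dataset):
    # Stand-in for ImageDataset: returns an (H, W, 3) float32 image in [0, 1] plus a label.
    def __len__(self):
        return 10
    def __getitem__(self, idx):
        rng = np.random.default_rng(idx)
        return rng.random((64, 64, 3), dtype=np.float32), idx % 2

base = ToyImages()
train_set = AugmentedSubset(Subset(base, range(8)), simple_augment)  # random flip + small rotation
val_set = AugmentedSubset(Subset(base, range(8, 10)), None)          # augment_fn=None leaves images untouched

img, label = train_set[0]
print(img.shape, label)   # (64, 64, 3) 0
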
src/model.py CHANGED
@@ -4,6 +4,8 @@ import cv2
4
  import numpy as np
5
  from dataclasses import dataclass
6
  from skimage.feature import hog,local_binary_pattern
7
  import matplotlib.pyplot as plt
8
  import os
9
  import io
@@ -14,7 +16,7 @@ class Config:
14
  img_size=(256,256)
15
  in_channels=3
16
  fc_num_layers=3
17
- conv_hidden_dim=3
18
  conv_kernel_size=3
19
  dropout=0.2
20
  classical_downsample=1
@@ -39,10 +41,6 @@ class Config:
39
  harris_ksize = 3
40
  harris_k = 0.04
41
 
42
- # Shi-Tomasi corners
43
- shi_max_corners = 100
44
- shi_quality_level = 0.01
45
- shi_min_distance = 10
46
 
47
  # LBP
48
  lbp_P = 8
@@ -58,6 +56,7 @@ class Config:
58
  # Sobel
59
  sobel_ksize=3
60
 
 
61
  class CNNFeatureExtractor(nn.Module):
62
  def __init__(self,config : Config):
63
  super().__init__()
@@ -67,16 +66,16 @@ class CNNFeatureExtractor(nn.Module):
67
  self.img_size = config.img_size
68
  out_channel = 32
69
  for i in range(config.conv_hidden_dim):
70
- layers.append(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=config.conv_kernel_size,stride=1,padding=1))
71
  layers.append(nn.BatchNorm2d(out_channel))
72
  layers.append(nn.ReLU())
73
- layers.append(nn.MaxPool2d(2))
74
  in_channel=out_channel
75
  out_channel*=2
76
  self.layers = nn.Sequential(*layers)
77
  def get_device(self):
78
  return next(self.parameters()).device
79
- def forward(self,x):
80
  if isinstance(x, list):
81
  if isinstance(x[0], np.ndarray):
82
  x = np.stack(x, axis=0)
@@ -129,7 +128,7 @@ class CNNFeatureExtractor(nn.Module):
129
  conv_layers = [
130
  (name, module)
131
  for name, module in self.named_modules()
132
- if isinstance(module, nn.Conv2d)
133
  ]
134
 
135
  all_layer_images = []
@@ -251,7 +250,7 @@ class CNNFeatureExtractor(nn.Module):
251
  plt.close(fig)
252
 
253
  return all_layer_images
254
-
255
  class ClassicalFeatureExtractor(nn.Module):
256
  def __init__(self, config : Config):
257
  super().__init__()
@@ -260,128 +259,103 @@ class ClassicalFeatureExtractor(nn.Module):
260
  self.num_downsample = config.classical_downsample
261
  self.config = config
262
  self.device = 'cpu'
263
-
 
264
  def get_device(self):
265
  return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
266
 
267
-
 
 
268
  def extract_features(self, img,visualize=False,**kwargs):
269
  cfg = self.config
270
-
271
  # Convert to grayscale
272
  gray = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
273
-
274
  for _ in range(self.num_downsample):
275
  gray = cv2.pyrDown(gray)
276
-
277
  gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)
278
  valid_H, valid_W = gray.shape[:2]
279
 
280
- def render_subplots(items, max_cols=8, figsize_per_cell=3):
281
- n = len(items)
282
- cols = min(max_cols, n)
283
- rows = int(np.ceil(n / cols))
284
-
285
- fig, axes = plt.subplots(
286
- rows, cols,
287
- figsize=(cols * figsize_per_cell, rows * figsize_per_cell)
288
- )
289
-
290
- axes = np.atleast_2d(axes)
291
-
292
- for idx, (img, title, cmap) in enumerate(items):
293
- r = idx // cols
294
- c = idx % cols
295
- ax = axes[r, c]
296
- ax.imshow(img, cmap=cmap)
297
- ax.set_title(title, fontsize=9)
298
- ax.axis("off")
299
-
300
- # Hide unused axes
301
- for idx in range(n, rows * cols):
302
- r = idx // cols
303
- c = idx % cols
304
- axes[r, c].axis("off")
305
-
306
- plt.tight_layout()
307
- return fig
308
 
309
  feature_list = []
310
  vis_items=[]
311
- # figs = []
312
- H, W = gray.shape
313
- cell_h, cell_w = cfg.hog_pixels_per_cell
314
- block_h, block_w = cfg.hog_cells_per_block
315
-
316
- min_h = cell_h * block_h
317
- min_w = cell_w * block_w
318
- use_hog = (H > 2*min_h) and (W > 2*min_w)
319
- # 1. HOG
320
- if use_hog:
321
- hog_descriptors, hog_image = hog(
322
- gray,
323
- orientations=cfg.hog_orientations,
324
- pixels_per_cell=cfg.hog_pixels_per_cell,
325
- cells_per_block=cfg.hog_cells_per_block,
326
- block_norm=cfg.hog_block_norm,
327
- visualize=True,
328
- feature_vector=False
329
- )
330
-
331
- hog_cells = hog_descriptors.mean(axis=(2, 3))
332
 
333
- cell_h, cell_w = cfg.hog_pixels_per_cell
334
- hog_pixel = np.repeat(
335
- np.repeat(hog_cells, cell_h, axis=0),
336
- cell_w, axis=1
337
- )
338
- hog_pixel = hog_pixel[:gray.shape[0], :gray.shape[1]]
339
- hog_energy = np.sum(hog_pixel, axis=2)
340
- dominant_bin = np.argmax(hog_pixel, axis=2)
341
- dominant_strength = np.max(hog_pixel, axis=2)
342
- dominant_weighted = dominant_bin * dominant_strength
343
- valid_H, valid_W = hog_pixel.shape[:2]
344
- if visualize:
345
- # figs.append(plot_feature(hog_energy, "HOG Energy"))
346
- # figs.append(plot_feature(dominant_bin, "HOG Dominant Bin",cmap='hsv'))
347
- # figs.append(plot_feature(dominant_weighted, "HOG Weighted Dominant Bin"))
348
- # figs.append(plot_feature(hog_image[:valid_H, :valid_W], f"HoG"))
349
- vis_items.append((hog_energy, "HOG Energy",'gray'))
350
- vis_items.append((dominant_bin, "HOG Dominant Bin",'hsv'))
351
- vis_items.append((dominant_weighted, "HOG Weighted Dominant Bin",'gray'))
352
- vis_items.append((hog_image[:valid_H, :valid_W], f"HoG",'gray'))
353
- for b in range(hog_pixel.shape[2]):
354
- feature_list.append(hog_pixel[:, :, b])
355
 
356
 
357
  # 2. Canny edges
358
  edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
359
- # feature_list.append(edges.ravel())
360
  feature_list.append(edges[:valid_H, :valid_W])
361
  if visualize:
362
- # figs.append(plot_feature(edges[:valid_H, :valid_W], "Canny Edge"))
363
  vis_items.append((edges[:valid_H, :valid_W], "Canny Edge", "gray"))
364
  # 3. Harris corners
365
  harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
366
  harris = cv2.dilate(harris, None)
367
  harris = np.clip(harris, 0, 1)
368
- # feature_list.append(harris.ravel())
369
  feature_list.append(harris[:valid_H, :valid_W])
370
  if visualize:
371
- # figs.append(plot_feature(harris[:valid_H, :valid_W], "Harris Corner"))
372
  vis_items.append((harris[:valid_H, :valid_W], "Harris Corner", "gray"))
373
- # # 4. Shi-Tomasi corners
374
- # shi_corners = np.zeros_like(gray, dtype=np.float32)
375
- # keypoints = cv2.goodFeaturesToTrack(gray, maxCorners=cfg.shi_max_corners, qualityLevel=cfg.shi_quality_level, minDistance=cfg.shi_min_distance)
376
- # if keypoints is not None:
377
- # for kp in keypoints:
378
- # x, y = kp.ravel()
379
- # shi_corners[int(y), int(x)] = 1.0
380
- # # feature_list.append(shi_corners.ravel())
381
- # feature_list.append(shi_corners[:valid_H, :valid_W])
382
- # if visualize:
383
- # figs.append(plot_feature(shi_corners[:valid_H, :valid_W], "Shi-Tomasi Corner"))
384
- # 5. LBP
385
  lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
386
  lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
387
  # feature_list.append(lbp.ravel())
@@ -389,15 +363,7 @@ class ClassicalFeatureExtractor(nn.Module):
389
  if visualize:
390
  # figs.append(plot_feature(lbp[:valid_H, :valid_W], "LBP"))
391
  vis_items.append((lbp[:valid_H, :valid_W], "LBP", "gray"))
392
- # 6. Gabor filter
393
- # g_kernel = cv2.getGaborKernel((cfg.gabor_ksize, cfg.gabor_ksize), cfg.gabor_sigma, cfg.gabor_theta, cfg.gabor_lambda, cfg.gabor_gamma)
394
- # gabor_feat = cv2.filter2D(gray, cv2.CV_32F, g_kernel)
395
- # gabor_feat = (gabor_feat - gabor_feat.min()) / (gabor_feat.max() - gabor_feat.min() + 1e-8)
396
- # # feature_list.append(gabor_feat.ravel())
397
- # feature_list.append(gabor_feat[:valid_H, :valid_W])
398
- # if visualize:
399
- # figs.append(plot_feature(gabor_feat[:valid_H, :valid_W], "Gabor Filter"))
400
-
401
  for theta in [0, np.pi/4, np.pi/2]:
402
  kernel = cv2.getGaborKernel(
403
  (cfg.gabor_ksize, cfg.gabor_ksize),
@@ -409,9 +375,8 @@ class ClassicalFeatureExtractor(nn.Module):
409
  g /= g.max() + 1e-8
410
  feature_list.append(g[:valid_H, :valid_W])
411
  if visualize:
412
- # figs.append(plot_feature(g[:valid_H, :valid_W], "Gabor Filter"))
413
  vis_items.append((g[:valid_H, :valid_W], f"Gabor θ={theta:.2f}", "gray"))
414
- # 7. Sobel
415
  sobelx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
416
  sobely = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
417
 
@@ -424,11 +389,9 @@ class ClassicalFeatureExtractor(nn.Module):
424
  feature_list.append(sobelx[:valid_H, :valid_W])
425
  feature_list.append(sobely[:valid_H, :valid_W])
426
  if visualize:
427
- # figs.append(plot_feature(sobelx[:valid_H, :valid_W], "Sobel X"))
428
- # figs.append(plot_feature(sobely[:valid_H, :valid_W], "Sobel Y"))
429
  vis_items.append((sobelx[:valid_H, :valid_W], "Sobel X",'gray'))
430
  vis_items.append((sobely[:valid_H, :valid_W], "Sobel Y",'gray'))
431
- # 8. Laplacian
432
  lap = cv2.Laplacian(gray, cv2.CV_32F)
433
  lap = np.abs(lap)
434
  lap /= lap.max() + 1e-8
@@ -436,10 +399,9 @@ class ClassicalFeatureExtractor(nn.Module):
436
  feature_list.append(lap[:valid_H, :valid_W])
437
 
438
  if visualize:
439
- # figs.append(plot_feature(lap[:valid_H, :valid_W], "Laplacian"))
440
  vis_items.append((lap[:valid_H, :valid_W], "Laplacian",'gray'))
441
 
442
- # 9. Gradient Magnitude
443
  gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
444
  gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
445
 
@@ -449,37 +411,38 @@ class ClassicalFeatureExtractor(nn.Module):
449
  feature_list.append(grad_mag[:valid_H, :valid_W])
450
 
451
  if visualize:
452
- # figs.append(plot_feature(grad_mag[:valid_H, :valid_W], "Gradient Magnitude"))
453
  vis_items.append((grad_mag[:valid_H, :valid_W], "Gradient Magnitude",'gray'))
454
 
455
  # Stack all features along channel axis
456
  features = np.stack(feature_list, axis=0)
457
- # features = np.concatenate(feature_list).astype(np.float32)
458
  if visualize:
459
- return features.astype(np.float32),[render_subplots(vis_items, max_cols=8)]
460
  return features.astype(np.float32)
461
 
462
 
463
- def forward(self, x):
464
  if isinstance(x, torch.Tensor):
465
  x = x.cpu().numpy()
 
466
  if isinstance(x, np.ndarray):
467
- if x.ndim == 3:
468
- x = np.expand_dims(x, 0)
469
  elif x.ndim != 4:
470
- raise ValueError(f"Expected input of shape HWC or BHWC, got {x.shape}")
471
- elif isinstance(x, list):
472
- x = np.stack(x, axis=0)
473
-
474
- batch_features = []
475
  for img in x:
476
- if img.ndim != 3 or img.shape[2] != 3:
477
  img = np.repeat(img[:, :, None], 3, axis=2)
478
- feat = self.extract_features(img)
479
- batch_features.append(feat)
480
- batch_features = np.stack(batch_features, axis=0)
481
- batch_features = torch.from_numpy(batch_features).float().to(self.get_device())
482
- return batch_features
483
 
484
  def visualize(self, img, show_original=True,show=True):
485
  if img.ndim != 3 or img.shape[2] != 3:
@@ -517,10 +480,14 @@ class ClassicalFeatureExtractor(nn.Module):
517
 
518
 
519
  def output(self):
520
- """Return dummy output to compute in_features for FC head"""
521
- dummy_img = np.zeros((1, self.img_size[1],self.img_size[0], 3), dtype=np.float32)
522
- feat = self.forward(dummy_img)
523
- return feat
524
 
525
 
526
 
@@ -530,21 +497,20 @@ class FullyConnectedHead(nn.Module):
530
  num_classes = len(classes)
531
  self.classes = classes
532
  layers = []
533
- out_features=256
534
- for i in range(config.fc_num_layers):
535
- layers.append(nn.Linear(in_features,out_features))
536
- layers.append(nn.BatchNorm1d(out_features))
537
  layers.append(nn.ReLU())
538
  layers.append(nn.Dropout(config.dropout))
539
- in_features=out_features
540
- out_features=out_features // 2
541
- if out_features <= num_classes:
542
- break
543
  layers.append(nn.Linear(in_features,num_classes))
544
  self.layers = nn.Sequential(*layers)
545
  def get_device(self):
546
  return next(self.parameters()).device
547
- def forward(self,x : torch.Tensor):
548
  x=x.to(self.get_device())
549
  return self.layers(x)
550
 
@@ -568,15 +534,15 @@ class Classifier(nn.Module):
568
  target_size = self.config.img_size
569
  x = cv2.resize(x, target_size)
570
  logits = self.forward(x)
571
- probs = torch.softmax(logits, dim=1)
572
  pred_idx = torch.argmax(probs, dim=1).item()
573
 
574
  return self.classes[pred_idx]
575
 
576
- def forward(self,x):
577
- feat = self.backbone(x)
578
- feat = self.flatten(feat)
579
- return self.head(feat)
580
  def visualize_feature(self,img,return_img=True,**kwargs):
581
  target_size = self.config.img_size
582
  img = cv2.resize(img, target_size)
 
4
  import numpy as np
5
  from dataclasses import dataclass
6
  from skimage.feature import hog,local_binary_pattern
7
+ import itertools
8
+ import torch.nn.functional as F
9
  import matplotlib.pyplot as plt
10
  import os
11
  import io
 
16
  img_size=(256,256)
17
  in_channels=3
18
  fc_num_layers=3
19
+ conv_hidden_dim=2
20
  conv_kernel_size=3
21
  dropout=0.2
22
  classical_downsample=1
 
41
  harris_ksize = 3
42
  harris_k = 0.04
43
 
 
44
 
45
  # LBP
46
  lbp_P = 8
 
56
  # Sobel
57
  sobel_ksize=3
58
 
59
+
60
  class CNNFeatureExtractor(nn.Module):
61
  def __init__(self,config : Config):
62
  super().__init__()
 
66
  self.img_size = config.img_size
67
  out_channel = 32
68
  for i in range(config.conv_hidden_dim):
69
+ layers.append(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=config.conv_kernel_size,stride=1,padding=config.conv_kernel_size // 2))
70
  layers.append(nn.BatchNorm2d(out_channel))
71
  layers.append(nn.ReLU())
72
+ layers.append(nn.MaxPool2d((2,2)))
73
  in_channel=out_channel
74
  out_channel*=2
75
  self.layers = nn.Sequential(*layers)
76
  def get_device(self):
77
  return next(self.parameters()).device
78
+ def forward(self,x,**kwargs):
79
  if isinstance(x, list):
80
  if isinstance(x[0], np.ndarray):
81
  x = np.stack(x, axis=0)
 
128
  conv_layers = [
129
  (name, module)
130
  for name, module in self.named_modules()
131
+ if isinstance(module, nn.ReLU)
132
  ]
133
 
134
  all_layer_images = []
 
250
  plt.close(fig)
251
 
252
  return all_layer_images
253
+
254
  class ClassicalFeatureExtractor(nn.Module):
255
  def __init__(self, config : Config):
256
  super().__init__()
 
259
  self.num_downsample = config.classical_downsample
260
  self.config = config
261
  self.device = 'cpu'
262
+ self.convolution=None
263
+
264
  def get_device(self):
265
  return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
266
 
267
+ def render_subplots(self,items, max_cols=8, figsize_per_cell=3):
268
+ n = len(items)
269
+ cols = min(max_cols, n)
270
+ rows = int(np.ceil(n / cols))
271
+ fig, axes = plt.subplots(
272
+ rows, cols,
273
+ figsize=(cols * figsize_per_cell, rows * figsize_per_cell)
274
+ )
275
+ axes = np.atleast_2d(axes)
276
+ for idx, (img, title, cmap) in enumerate(items):
277
+ r = idx // cols
278
+ c = idx % cols
279
+ ax = axes[r, c]
280
+ ax.imshow(img, cmap=cmap)
281
+ ax.set_title(title, fontsize=9)
282
+ ax.axis("off")
283
+ for idx in range(n, rows * cols):
284
+ r = idx // cols
285
+ c = idx % cols
286
+ axes[r, c].axis("off")
287
+
288
+ plt.tight_layout()
289
+ return fig
290
+
291
  def extract_features(self, img,visualize=False,**kwargs):
292
  cfg = self.config
 
293
  # Convert to grayscale
294
  gray = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
 
295
  for _ in range(self.num_downsample):
296
  gray = cv2.pyrDown(gray)
 
297
  gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)
298
  valid_H, valid_W = gray.shape[:2]
299
 
 
 
300
 
301
  feature_list = []
302
  vis_items=[]
303
+ # DEPRECATED
304
+ # H, W = gray.shape
305
+ # cell_h, cell_w = cfg.hog_pixels_per_cell
306
+ # block_h, block_w = cfg.hog_cells_per_block
307
+
308
+ # min_h = cell_h * block_h
309
+ # min_w = cell_w * block_w
310
+ # use_hog = False
311
+ # # 1. HOG
312
+ # if use_hog:
313
+ # hog_descriptors, hog_image = hog(
314
+ # gray,
315
+ # orientations=cfg.hog_orientations,
316
+ # pixels_per_cell=cfg.hog_pixels_per_cell,
317
+ # cells_per_block=cfg.hog_cells_per_block,
318
+ # block_norm=cfg.hog_block_norm,
319
+ # visualize=True,
320
+ # feature_vector=False
321
+ # )
322
+
323
+ # hog_cells = hog_descriptors.mean(axis=(2, 3))
324
 
325
+ # cell_h, cell_w = cfg.hog_pixels_per_cell
326
+ # hog_pixel = np.repeat(
327
+ # np.repeat(hog_cells, cell_h, axis=0),
328
+ # cell_w, axis=1
329
+ # )
330
+ # hog_pixel = hog_pixel[:gray.shape[0], :gray.shape[1]]
331
+ # hog_energy = np.sum(hog_pixel, axis=2)
332
+ # dominant_bin = np.argmax(hog_pixel, axis=2)
333
+ # dominant_strength = np.max(hog_pixel, axis=2)
334
+ # dominant_weighted = dominant_bin * dominant_strength
335
+ # valid_H, valid_W = hog_pixel.shape[:2]
336
+ # if visualize:
337
+ # vis_items.append((hog_energy, "HOG Energy",'gray'))
338
+ # vis_items.append((dominant_bin, "HOG Dominant Bin",'hsv'))
339
+ # vis_items.append((dominant_weighted, "HOG Weighted Dominant Bin",'gray'))
340
+ # vis_items.append((hog_image[:valid_H, :valid_W], f"HoG",'gray'))
341
+ # for b in range(hog_pixel.shape[2]):
342
+ # feature_list.append(hog_pixel[:, :, b])
 
 
343
 
344
 
345
  # 2. Canny edges
346
  edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
 
347
  feature_list.append(edges[:valid_H, :valid_W])
348
  if visualize:
 
349
  vis_items.append((edges[:valid_H, :valid_W], "Canny Edge", "gray"))
350
  # 3. Harris corners
351
  harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
352
  harris = cv2.dilate(harris, None)
353
  harris = np.clip(harris, 0, 1)
 
354
  feature_list.append(harris[:valid_H, :valid_W])
355
  if visualize:
 
356
  vis_items.append((harris[:valid_H, :valid_W], "Harris Corner", "gray"))
357
+
358
+ # 4. LBP
 
 
359
  lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
360
  lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
361
  # feature_list.append(lbp.ravel())
 
363
  if visualize:
364
  # figs.append(plot_feature(lbp[:valid_H, :valid_W], "LBP"))
365
  vis_items.append((lbp[:valid_H, :valid_W], "LBP", "gray"))
366
+ # 5. Gabor filter
 
 
367
  for theta in [0, np.pi/4, np.pi/2]:
368
  kernel = cv2.getGaborKernel(
369
  (cfg.gabor_ksize, cfg.gabor_ksize),
 
375
  g /= g.max() + 1e-8
376
  feature_list.append(g[:valid_H, :valid_W])
377
  if visualize:
 
378
  vis_items.append((g[:valid_H, :valid_W], f"Gabor θ={theta:.2f}", "gray"))
379
+ # 6. Sobel
380
  sobelx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
381
  sobely = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
382
 
 
389
  feature_list.append(sobelx[:valid_H, :valid_W])
390
  feature_list.append(sobely[:valid_H, :valid_W])
391
  if visualize:
 
 
392
  vis_items.append((sobelx[:valid_H, :valid_W], "Sobel X",'gray'))
393
  vis_items.append((sobely[:valid_H, :valid_W], "Sobel Y",'gray'))
394
+ # 7. Laplacian
395
  lap = cv2.Laplacian(gray, cv2.CV_32F)
396
  lap = np.abs(lap)
397
  lap /= lap.max() + 1e-8
 
399
  feature_list.append(lap[:valid_H, :valid_W])
400
 
401
  if visualize:
 
402
  vis_items.append((lap[:valid_H, :valid_W], "Laplacian",'gray'))
403
 
404
+ # 8. Gradient Magnitude
405
  gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
406
  gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
407
 
 
411
  feature_list.append(grad_mag[:valid_H, :valid_W])
412
 
413
  if visualize:
 
414
  vis_items.append((grad_mag[:valid_H, :valid_W], "Gradient Magnitude",'gray'))
415
 
416
  # Stack all features along channel axis
417
  features = np.stack(feature_list, axis=0)
 
418
  if visualize:
419
+ return features.astype(np.float32),[self.render_subplots(vis_items, max_cols=8)]
420
  return features.astype(np.float32)
421
 
422
 
423
+ def forward(self, x, **kwargs):
424
+ if isinstance(x, list):
425
+ x = np.stack(x, axis=0)
426
+
427
  if isinstance(x, torch.Tensor):
428
  x = x.cpu().numpy()
429
+
430
  if isinstance(x, np.ndarray):
431
+ if x.ndim == 3:
432
+ x = x[None]
433
  elif x.ndim != 4:
434
+ raise ValueError(
435
+ f"Expected input of shape HWC or BHWC, got {x.shape}"
436
+ )
437
+ feats = []
 
438
  for img in x:
439
+ if img.shape[2] != 3:
440
  img = np.repeat(img[:, :, None], 3, axis=2)
441
+ feats.append(self.extract_features(img))
442
+
443
+ feats = np.stack(feats, axis=0)
444
+ feats = torch.from_numpy(feats).float().to(self.get_device())
445
+ return feats
446
 
447
  def visualize(self, img, show_original=True,show=True):
448
  if img.ndim != 3 or img.shape[2] != 3:
 
480
 
481
 
482
  def output(self):
483
+ dummy = np.zeros(
484
+ (self.img_size[1], self.img_size[0], 3),
485
+ dtype=np.float32
486
+ )
487
+
488
+ feats = self.forward(dummy)
489
+
490
+ return feats
491
 
492
 
493
 
 
497
  num_classes = len(classes)
498
  self.classes = classes
499
  layers = []
500
+ hidden_dim =1024
501
+ for _ in range(config.fc_num_layers):
502
+ layers.append(nn.Linear(in_features, hidden_dim))
503
+ layers.append(nn.BatchNorm1d(hidden_dim))
504
  layers.append(nn.ReLU())
505
  layers.append(nn.Dropout(config.dropout))
506
+
507
+ in_features = hidden_dim
508
+ hidden_dim = max(hidden_dim // 2, num_classes * 2)
 
509
  layers.append(nn.Linear(in_features,num_classes))
510
  self.layers = nn.Sequential(*layers)
511
  def get_device(self):
512
  return next(self.parameters()).device
513
+ def forward(self,x : torch.Tensor,**kwargs):
514
  x=x.to(self.get_device())
515
  return self.layers(x)
516
 
 
534
  target_size = self.config.img_size
535
  x = cv2.resize(x, target_size)
536
  logits = self.forward(x)
537
+ probs = torch.softmax(logits,dim=1)
538
  pred_idx = torch.argmax(probs, dim=1).item()
539
 
540
  return self.classes[pred_idx]
541
 
542
+ def forward(self,x,**kwargs):
543
+ feat = self.backbone(x,**kwargs)
544
+ feat = self.flatten(feat,**kwargs)
545
+ return self.head(feat,**kwargs)
546
  def visualize_feature(self,img,return_img=True,**kwargs):
547
  target_size = self.config.img_size
548
  img = cv2.resize(img, target_size)
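
Quick sanity check on the CNNFeatureExtractor change above: with the new conv_hidden_dim=2 default, each block is Conv2d (padding=kernel_size//2, which preserves spatial size for the default odd kernel) -> BatchNorm -> ReLU -> MaxPool2d(2), with channels going 3 -> 32 -> 64. A rough sketch of the resulting feature-map size for a hypothetical 128x128 input:

# channels start at 32 and double per block; MaxPool2d(2) halves H and W per block
h = w = 128                       # example input size (img_size is configurable)
channels, out = 3, 32
for _ in range(2):                # conv_hidden_dim=2 blocks
    channels, out = out, out * 2  # conv changes the channel count, padding keeps H and W
    h, w = h // 2, w // 2         # MaxPool2d(2)
print(channels, h, w, channels * h * w)   # 64 32 32 65536 features after flattening
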
src/trainer.py CHANGED
@@ -10,12 +10,149 @@ import random
10
  import numpy as np
11
  import torch.nn as nn
12
  import time
13
-
14
  def seed_worker(worker_id):
15
  worker_seed = torch.initial_seed() % 2**32
16
  np.random.seed(worker_seed)
17
  random.seed(worker_seed)
18
 
 
 
19
  class ModelTrainer:
20
  def __init__(self,model : Classifier,train_set : ImageDataset,val_set : ImageDataset = None, batch_size=32,lr = 1e-3,device='cpu',return_fig=False, seed=None):
21
  g = torch.Generator()
@@ -50,6 +187,9 @@ class ModelTrainer:
50
  self.optim.zero_grad()
51
  self.criterion = nn.CrossEntropyLoss()
52
  self.return_fig=return_fig
53
 
54
  def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
55
 
@@ -114,13 +254,15 @@ class ModelTrainer:
114
  return None
115
 
116
 
117
- def train_one_epoch(self):
118
  self.model.train()
119
  total_loss = 0
120
  train_pbar = tqdm(self.train_loader, desc="Training",leave=False)
121
  correct = 0
122
  total = 0
123
  for imgs, labels in train_pbar:
124
  labels = labels.to(self.device)
125
 
126
  # Forward
@@ -146,18 +288,30 @@ class ModelTrainer:
146
  val_losses=[]
147
  val_accuracies=[]
148
  for epoch in range(1, epochs + 1):
149
- train_loss,train_acc = self.train_one_epoch()
150
  train_losses.append(train_loss)
151
  train_accuracies.append(train_acc)
152
  if self.val_loader is not None:
153
  val_loss,val_acc,fig=self.validate(epoch, visualize=(epoch % visualize_every == 0 or epoch == 1))
154
  val_losses.append(val_loss)
155
  val_accuracies.append(val_acc)
156
  print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
157
  yield train_loss,train_acc,val_loss,val_acc,fig
158
  else:
159
  print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
160
  yield train_loss,train_acc,None,None,None
161
  yield train_losses,train_accuracies,val_losses,val_accuracies,None
162
 
163
  def validate(self,epoch, visualize=False):
@@ -177,6 +331,8 @@ class ModelTrainer:
177
  fig = None
178
  with torch.no_grad():
179
  for imgs, labels in val_pbar:
180
  labels = labels.to(self.device)
181
 
182
  outputs = self.model(imgs)
 
10
  import numpy as np
11
  import torch.nn as nn
12
  import time
13
+ from sklearn.metrics import (
14
+ confusion_matrix,
15
+ classification_report,
16
+ roc_curve,
17
+ auc
18
+ )
19
+ from sklearn.preprocessing import label_binarize
20
  def seed_worker(worker_id):
21
  worker_seed = torch.initial_seed() % 2**32
22
  np.random.seed(worker_seed)
23
  random.seed(worker_seed)
24
 
25
+ def model_evaluation(model, val_set, device,batch_size=32,num_workers=0, class_names=None):
26
+
27
+ model.eval()
28
+ all_preds = []
29
+ all_probs = []
30
+ all_labels = []
31
+ val_loader = DataLoader(
32
+ val_set,
33
+ batch_size=batch_size,
34
+ shuffle=False,
35
+ num_workers=num_workers
36
+ )
37
+ with torch.no_grad():
38
+ for images, labels in val_loader:
39
+ if images.ndim == 4 and images.shape[-1] in (1, 3):
40
+ images = images.permute(0, 3, 1, 2)
41
+ images = images.to(device)
42
+ labels = labels.to(device)
43
+ logits = model(images)
44
+ probs = torch.softmax(logits, dim=1)
45
+ preds = torch.argmax(probs, dim=1)
46
+
47
+ all_preds.append(preds.cpu().numpy())
48
+ all_probs.append(probs.cpu().numpy())
49
+ all_labels.append(labels.cpu().numpy())
50
+
51
+ y_true = np.concatenate(all_labels)
52
+ y_pred = np.concatenate(all_preds)
53
+ y_prob = np.concatenate(all_probs)
54
+
55
+ num_classes = y_prob.shape[1]
56
+
57
+ if class_names is None:
58
+ class_names = [f"Class {i}" for i in range(num_classes)]
59
+
60
+ cm = confusion_matrix(y_true, y_pred)
61
+
62
+ cm_fig, ax = plt.subplots(figsize=(6, 6))
63
+ im = ax.imshow(cm)
64
+
65
+ ax.set_title("Confusion Matrix")
66
+ ax.set_xlabel("Predicted")
67
+ ax.set_ylabel("True")
68
+ ax.set_xticks(range(num_classes))
69
+ ax.set_yticks(range(num_classes))
70
+ ax.set_xticklabels(class_names, rotation=75)
71
+ ax.set_yticklabels(class_names)
72
+
73
+ for i in range(num_classes):
74
+ for j in range(num_classes):
75
+ ax.text(j, i, cm[i, j], ha="center", va="center")
76
+
77
+ plt.tight_layout()
78
+
79
+ report = classification_report(
80
+ y_true, y_pred,
81
+ target_names=class_names,
82
+ output_dict=True
83
+ )
84
+
85
+ cr_fig, ax = plt.subplots(figsize=(12, 8))
86
+ ax.axis("off")
87
+
88
+ table_data = []
89
+ headers = ["Class", "Precision", "Recall", "F1", "Support"]
90
+
91
+ for cls in class_names:
92
+ row = report[cls]
93
+ table_data.append([
94
+ cls,
95
+ f"{row['precision']:.3f}",
96
+ f"{row['recall']:.3f}",
97
+ f"{row['f1-score']:.3f}",
98
+ int(row['support'])
99
+ ])
100
+
101
+ accuracy = report["accuracy"]
102
+ macro_avg = report["macro avg"]
103
+ weighted_avg = report["weighted avg"]
104
+
105
+ table_data.append([
106
+ "Accuracy",
107
+ f"{accuracy:.3f}",
108
+ "",
109
+ "",
110
+ ""
111
+ ])
112
+
113
+ table_data.append([
114
+ "Macro Avg",
115
+ f"{macro_avg['precision']:.3f}",
116
+ f"{macro_avg['recall']:.3f}",
117
+ f"{macro_avg['f1-score']:.3f}",
118
+ f"{int(macro_avg['support'])}" if 'support' in macro_avg else ""
119
+ ])
120
+
121
+ table_data.append([
122
+ "Weighted Avg",
123
+ f"{weighted_avg['precision']:.3f}",
124
+ f"{weighted_avg['recall']:.3f}",
125
+ f"{weighted_avg['f1-score']:.3f}",
126
+ f"{int(weighted_avg['support'])}" if 'support' in weighted_avg else ""
127
+ ])
128
+
129
+ table = ax.table(
130
+ cellText=table_data,
131
+ colLabels=headers,
132
+ loc="center"
133
+ )
134
+
135
+ table.scale(1, 2)
136
+ ax.set_title("Classification Report")
137
+
138
+ y_true_bin = label_binarize(y_true, classes=list(range(num_classes)))
139
+
140
+ roc_fig, ax = plt.subplots(figsize=(6, 6))
141
+
142
+ for i in range(num_classes):
143
+ fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
144
+ roc_auc = auc(fpr, tpr)
145
+ ax.plot(fpr, tpr, label=f"{class_names[i]} (AUC={roc_auc:.3f})")
146
+
147
+ ax.plot([0, 1], [0, 1], linestyle="--")
148
+ ax.set_xlabel("False Positive Rate")
149
+ ax.set_ylabel("True Positive Rate")
150
+ ax.set_title("ROC-AUC Curve")
151
+ ax.legend()
152
+ ax.grid(True)
153
+
154
+ return cm_fig, cr_fig, roc_fig
155
+
156
  class ModelTrainer:
157
  def __init__(self,model : Classifier,train_set : ImageDataset,val_set : ImageDataset = None, batch_size=32,lr = 1e-3,device='cpu',return_fig=False, seed=None):
158
  g = torch.Generator()
 
187
  self.optim.zero_grad()
188
  self.criterion = nn.CrossEntropyLoss()
189
  self.return_fig=return_fig
190
+ self.best_model_state = None
191
+ self.best_val_acc = 0.0
192
+ self.interrupt=False
193
 
194
  def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
195
 
 
254
  return None
255
 
256
 
257
+ def train_one_epoch(self,epoch):
258
  self.model.train()
259
  total_loss = 0
260
  train_pbar = tqdm(self.train_loader, desc="Training",leave=False)
261
  correct = 0
262
  total = 0
263
  for imgs, labels in train_pbar:
264
+ if self.interrupt:
265
+ break
266
  labels = labels.to(self.device)
267
 
268
  # Forward
 
288
  val_losses=[]
289
  val_accuracies=[]
290
  for epoch in range(1, epochs + 1):
291
+ train_loss,train_acc = self.train_one_epoch(epoch)
292
+ if self.interrupt:
293
+ return
294
  train_losses.append(train_loss)
295
  train_accuracies.append(train_acc)
296
  if self.val_loader is not None:
297
  val_loss,val_acc,fig=self.validate(epoch, visualize=(epoch % visualize_every == 0 or epoch == 1))
298
+ if self.interrupt:
299
+ return
300
  val_losses.append(val_loss)
301
  val_accuracies.append(val_acc)
302
  print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
303
+ if val_acc > self.best_val_acc:
304
+ print(f"New best model found at epoch {epoch} (Val Acc: {val_acc:.4f})")
305
+ self.best_val_acc = val_acc
306
+ self.best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}
307
+
308
  yield train_loss,train_acc,val_loss,val_acc,fig
309
  else:
310
  print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
311
  yield train_loss,train_acc,None,None,None
312
+ if self.best_model_state is not None:
313
+ self.model.load_state_dict(self.best_model_state)
314
+ print(f"Best model (Val Acc: {self.best_val_acc:.4f}) loaded into trainer.model")
315
  yield train_losses,train_accuracies,val_losses,val_accuracies,None
316
 
317
  def validate(self,epoch, visualize=False):
 
331
  fig = None
332
  with torch.no_grad():
333
  for imgs, labels in val_pbar:
334
+ if self.interrupt:
335
+ break
336
  labels = labels.to(self.device)
337
 
338
  outputs = self.model(imgs)
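
For reference, a hedged sketch of how the new model_evaluation helper is driven (app.py above calls it with the trained Classifier and the validation split; dummy stand-ins are used here only so the call shape is clear):

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from src.trainer import model_evaluation

num_classes, n = 3, 30
# Dummy stand-ins: swap in the real Classifier and validation dataset.
model = nn.Sequential(nn.Flatten(), nn.Linear(8 * 8 * 3, num_classes))
images = torch.rand(n, 8, 8, 3)                   # BHWC, channels-last like the image datasets
labels = torch.arange(n) % num_classes            # every class present, so the report lines up
val_set = TensorDataset(images, labels)

cm_fig, cr_fig, roc_fig = model_evaluation(
    model, val_set, device="cpu", batch_size=8, num_workers=0,
    class_names=[f"class_{i}" for i in range(num_classes)],
)
cm_fig.savefig("confusion_matrix.png")            # the app instead converts the figures to gallery images
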