Commit
·
5e96bc9
1
Parent(s):
4631366
Update
Browse files- app.py +49 -32
- src/__pycache__/dataloader.cpython-312.pyc +0 -0
- src/__pycache__/model.cpython-312.pyc +0 -0
- src/__pycache__/trainer.cpython-312.pyc +0 -0
- src/dataloader.py +22 -1
- src/model.py +119 -153
- src/trainer.py +159 -3
app.py
CHANGED
|
@@ -2,10 +2,10 @@ import gradio as gr
|
|
| 2 |
import zipfile
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
-
from src.dataloader import ImageDataset,collate_fn
|
| 6 |
from src.model import Classifier,Config,CNNFeatureExtractor,ClassicalFeatureExtractor,load
|
| 7 |
from torch.utils.data import Subset
|
| 8 |
-
from src.trainer import ModelTrainer
|
| 9 |
import torch
|
| 10 |
import os
|
| 11 |
import numpy as np
|
|
@@ -17,9 +17,6 @@ import matplotlib.pyplot as plt
|
|
| 17 |
import shutil
|
| 18 |
import pandas as pd
|
| 19 |
from sklearn.model_selection import train_test_split
|
| 20 |
-
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
| 21 |
-
from sklearn.metrics import classification_report
|
| 22 |
-
from torch.utils.data import DataLoader
|
| 23 |
|
| 24 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
def unzip_dataset(zip_file):
|
|
@@ -82,13 +79,22 @@ def plot(datas, labels, xlabel, ylabel, title, figsize=(16, 8)):
|
|
| 82 |
|
| 83 |
class TrainingInterrupted(Exception):
|
| 84 |
pass
|
|
|
|
|
|
|
| 85 |
def stop_training():
|
| 86 |
global training_interrupt
|
| 87 |
training_interrupt = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
return "Training stopped."
|
| 89 |
|
|
|
|
|
|
|
| 90 |
def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visualize_every=5):
|
| 91 |
global training_interrupt
|
|
|
|
| 92 |
training_interrupt = False
|
| 93 |
global cnn_history
|
| 94 |
global classic_history
|
|
@@ -163,17 +169,17 @@ def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visual
|
|
| 163 |
print(e)
|
| 164 |
return
|
| 165 |
|
| 166 |
-
def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
|
| 167 |
img_width,img_height,fc_num_layers,
|
| 168 |
in_channels,conv_hidden_dim,dropout,
|
| 169 |
classical_downsample,
|
| 170 |
-
hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
|
| 171 |
canny_sigma,canny_low,canny_high,
|
| 172 |
gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
|
| 173 |
harris_block_size,harris_ksize,harris_k,
|
| 174 |
-
shi_max_corners,shi_quality_level,shi_min_distance,
|
| 175 |
lbp_P,lbp_R,
|
| 176 |
-
gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma
|
|
|
|
| 177 |
config = Config()
|
| 178 |
global training_interrupt
|
| 179 |
training_interrupt = False
|
|
@@ -190,10 +196,10 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
|
|
| 190 |
config.dropout=dropout
|
| 191 |
# Classical Config
|
| 192 |
config.classical_downsample=int(classical_downsample)
|
| 193 |
-
config.hog_orientations=int(hog_orientations)
|
| 194 |
-
config.hog_pixels_per_cell=(int(hog_pixels_per_cell),int(hog_pixels_per_cell))
|
| 195 |
-
config.hog_cells_per_block=(int(hog_cells_per_block),int(hog_cells_per_block))
|
| 196 |
-
config.hog_block_norm=hog_block_norm
|
| 197 |
config.canny_sigma=int(canny_sigma)
|
| 198 |
config.canny_low=canny_low
|
| 199 |
config.canny_high=canny_high
|
|
@@ -203,9 +209,6 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
|
|
| 203 |
config.harris_block_size=int(harris_block_size)
|
| 204 |
config.harris_ksize=int(harris_ksize)
|
| 205 |
config.harris_k=harris_k
|
| 206 |
-
config.shi_max_corners=int(shi_max_corners)
|
| 207 |
-
config.shi_quality_level=shi_quality_level
|
| 208 |
-
config.shi_min_distance=int(shi_min_distance)
|
| 209 |
config.lbp_P=int(lbp_P)
|
| 210 |
config.lbp_R=int(lbp_R)
|
| 211 |
config.gabor_ksize=int(gabor_ksize)
|
|
@@ -213,12 +216,14 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
|
|
| 213 |
config.gabor_theta=int(gabor_theta)
|
| 214 |
config.gabor_lambda=int(gabor_lambda)
|
| 215 |
config.gabor_gamma=gabor_gamma
|
|
|
|
| 216 |
cnn_history_plots=[]
|
| 217 |
classical_history_plots=[]
|
| 218 |
cnn_plotted=False
|
| 219 |
try:
|
| 220 |
dataset = ImageDataset(DATASET_PATH,config.img_size)
|
| 221 |
labels = [item['id'] for item in dataset.data]
|
|
|
|
| 222 |
train_idx, validation_idx = train_test_split(np.arange(len(dataset)),
|
| 223 |
test_size=0.2,
|
| 224 |
random_state=SEED,
|
|
@@ -226,6 +231,12 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
|
|
| 226 |
stratify=labels)
|
| 227 |
train_dataset = Subset(dataset, train_idx)
|
| 228 |
val_dataset = Subset(dataset, validation_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
cnnbackbone = CNNFeatureExtractor(config).to(device)
|
| 230 |
cnnmodel = Classifier(cnnbackbone,train_dataset.dataset.classes,config).to(device)
|
| 231 |
classicbackbone = ClassicalFeatureExtractor(config)
|
|
@@ -235,11 +246,17 @@ def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
|
|
| 235 |
cnn_plotted=True
|
| 236 |
cnn_history_plots.append(plot([cnn_history['train_acc'],cnn_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
|
| 237 |
cnn_history_plots.append(plot([cnn_history['train_loss'],cnn_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
| 239 |
yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
|
| 240 |
classical_history_plots.append(plot([classic_history['train_acc'],classic_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
|
| 241 |
classical_history_plots.append(plot([classic_history['train_loss'],classic_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
| 243 |
yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
|
| 244 |
|
| 245 |
except RuntimeError as e:
|
|
@@ -314,6 +331,7 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
|
|
| 314 |
epochs= gr.Number(value=20,label="Epochs",interactive=True,precision=0)
|
| 315 |
seed=gr.Number(value=42,label='Seed',interactive=True,precision=0)
|
| 316 |
vis_every=gr.Number(value=5,label='Visualize Validation Every (Epochs)',interactive=True,precision=0)
|
|
|
|
| 317 |
with gr.Row():
|
| 318 |
img_width=gr.Number(value=128,label='Image Width',interactive=True,precision=0)
|
| 319 |
img_height=gr.Number(value=128,label='Image Height',interactive=True,precision=0)
|
|
@@ -328,11 +346,12 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
|
|
| 328 |
with gr.Accordion(label='Classical Feature Extractor Settings',open=False):
|
| 329 |
with gr.Row():
|
| 330 |
classical_downsample = gr.Number(value=1,label='Classical Extractor Downsampling Amount',interactive=True,precision=0)
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
| 336 |
with gr.Row():
|
| 337 |
canny_sigma = gr.Number(value=1.0,label='Canny Sigma Value',interactive=True)
|
| 338 |
canny_low = gr.Number(value=100,label='Canny Low Threshold',interactive=True,precision=0)
|
|
@@ -345,19 +364,17 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
|
|
| 345 |
harris_block_size = gr.Number(value=2,label='Harris Corner Block Size',interactive=True,precision=0)
|
| 346 |
harris_ksize = gr.Number(value=3,label='Harris Corner Kernel Size',interactive=True,precision=0)
|
| 347 |
harris_k = gr.Slider(minimum=0.01, maximum=0.1, value=0.04, step=0.005, label='Harris Corner K value',interactive=True)
|
| 348 |
-
with gr.Row():
|
| 349 |
-
shi_max_corners = gr.Number(value=100,label='Shi-Tomasi Max Corners',interactive=True,precision=0)
|
| 350 |
-
shi_quality_level = gr.Number(value=0.01,label='Shi-Tomasi Quality Level',interactive=True)
|
| 351 |
-
shi_min_distance = gr.Number(value=10,label='Shi-Tomasi Min Distance',interactive=True,precision=0)
|
| 352 |
with gr.Row():
|
| 353 |
lbp_P = gr.Number(value=8,label='LBP P Value',interactive=True,precision=0)
|
| 354 |
lbp_R = gr.Number(value=1,label='LBP R Value',interactive=True,precision=0)
|
| 355 |
with gr.Row():
|
| 356 |
gabor_ksize = gr.Number(value=21,label="Gabor Kernel Size",interactive=True,precision=0)
|
| 357 |
gabor_sigma = gr.Number(value=5,label="Gabor Sigma",interactive=True,precision=0)
|
| 358 |
-
gabor_theta = gr.Number(value=0,label="Gabor Theta",interactive=True,precision=0)
|
| 359 |
gabor_lambda = gr.Number(value=10,label="Gabor Lambda",interactive=True,precision=0)
|
| 360 |
gabor_gamma = gr.Number(value=0.5,label="Gabor Gamma",interactive=True)
|
|
|
|
|
|
|
| 361 |
with gr.Column():
|
| 362 |
train_btn = gr.Button("Train Model",variant='secondary',interactive=True)
|
| 363 |
stop_btn = gr.Button("Stop Training")
|
|
@@ -375,17 +392,17 @@ with gr.Blocks(title="Object Classifier Playground") as demo:
|
|
| 375 |
classical_plots = gr.Gallery(label="CNN Training Performance",interactive=False,object_fit='fill',columns=1)
|
| 376 |
stop_btn.click(fn=stop_training, inputs=[], outputs=[])
|
| 377 |
train_btn.click(fn=train_model,
|
| 378 |
-
inputs=[zip_file,batch_size,lr,epochs,seed,vis_every,
|
| 379 |
img_width,img_height,fc_num_layers,
|
| 380 |
in_channels,conv_hidden_dim,dropout,
|
| 381 |
classical_downsample,
|
| 382 |
-
hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
|
| 383 |
canny_sigma,canny_low,canny_high,
|
| 384 |
gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
|
| 385 |
harris_block_size,harris_ksize,harris_k,
|
| 386 |
-
shi_max_corners,shi_quality_level,shi_min_distance,
|
| 387 |
lbp_P,lbp_R,
|
| 388 |
-
gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma
|
|
|
|
| 389 |
outputs=[cnn_log,cnn_fig,classical_log,classical_fig,cnn_plots,classical_plots]
|
| 390 |
)
|
| 391 |
def make_figure_from_image(img):
|
|
|
|
| 2 |
import zipfile
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
+
from src.dataloader import ImageDataset,collate_fn,AugmentedSubset,simple_augment
|
| 6 |
from src.model import Classifier,Config,CNNFeatureExtractor,ClassicalFeatureExtractor,load
|
| 7 |
from torch.utils.data import Subset
|
| 8 |
+
from src.trainer import ModelTrainer,model_evaluation
|
| 9 |
import torch
|
| 10 |
import os
|
| 11 |
import numpy as np
|
|
|
|
| 17 |
import shutil
|
| 18 |
import pandas as pd
|
| 19 |
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 22 |
def unzip_dataset(zip_file):
|
|
|
|
| 79 |
|
| 80 |
class TrainingInterrupted(Exception):
|
| 81 |
pass
|
| 82 |
+
cnntrainer=None
|
| 83 |
+
classictrainer=None
|
| 84 |
def stop_training():
|
| 85 |
global training_interrupt
|
| 86 |
training_interrupt = True
|
| 87 |
+
if cnntrainer is not None:
|
| 88 |
+
cnntrainer.interrupt=True
|
| 89 |
+
if classictrainer is not None:
|
| 90 |
+
classictrainer.interrupt=True
|
| 91 |
return "Training stopped."
|
| 92 |
|
| 93 |
+
|
| 94 |
+
|
| 95 |
def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visualize_every=5):
|
| 96 |
global training_interrupt
|
| 97 |
+
global cnntrainer,classictrainer
|
| 98 |
training_interrupt = False
|
| 99 |
global cnn_history
|
| 100 |
global classic_history
|
|
|
|
| 169 |
print(e)
|
| 170 |
return
|
| 171 |
|
| 172 |
+
def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,use_augment,
|
| 173 |
img_width,img_height,fc_num_layers,
|
| 174 |
in_channels,conv_hidden_dim,dropout,
|
| 175 |
classical_downsample,
|
| 176 |
+
# hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
|
| 177 |
canny_sigma,canny_low,canny_high,
|
| 178 |
gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
|
| 179 |
harris_block_size,harris_ksize,harris_k,
|
|
|
|
| 180 |
lbp_P,lbp_R,
|
| 181 |
+
gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma,
|
| 182 |
+
sobel_ksize):
|
| 183 |
config = Config()
|
| 184 |
global training_interrupt
|
| 185 |
training_interrupt = False
|
|
|
|
| 196 |
config.dropout=dropout
|
| 197 |
# Classical Config
|
| 198 |
config.classical_downsample=int(classical_downsample)
|
| 199 |
+
# config.hog_orientations=int(hog_orientations)
|
| 200 |
+
# config.hog_pixels_per_cell=(int(hog_pixels_per_cell),int(hog_pixels_per_cell))
|
| 201 |
+
# config.hog_cells_per_block=(int(hog_cells_per_block),int(hog_cells_per_block))
|
| 202 |
+
# config.hog_block_norm=hog_block_norm
|
| 203 |
config.canny_sigma=int(canny_sigma)
|
| 204 |
config.canny_low=canny_low
|
| 205 |
config.canny_high=canny_high
|
|
|
|
| 209 |
config.harris_block_size=int(harris_block_size)
|
| 210 |
config.harris_ksize=int(harris_ksize)
|
| 211 |
config.harris_k=harris_k
|
|
|
|
|
|
|
|
|
|
| 212 |
config.lbp_P=int(lbp_P)
|
| 213 |
config.lbp_R=int(lbp_R)
|
| 214 |
config.gabor_ksize=int(gabor_ksize)
|
|
|
|
| 216 |
config.gabor_theta=int(gabor_theta)
|
| 217 |
config.gabor_lambda=int(gabor_lambda)
|
| 218 |
config.gabor_gamma=gabor_gamma
|
| 219 |
+
config.sobel_ksize=int(sobel_ksize)
|
| 220 |
cnn_history_plots=[]
|
| 221 |
classical_history_plots=[]
|
| 222 |
cnn_plotted=False
|
| 223 |
try:
|
| 224 |
dataset = ImageDataset(DATASET_PATH,config.img_size)
|
| 225 |
labels = [item['id'] for item in dataset.data]
|
| 226 |
+
classes_name = dataset.classes
|
| 227 |
train_idx, validation_idx = train_test_split(np.arange(len(dataset)),
|
| 228 |
test_size=0.2,
|
| 229 |
random_state=SEED,
|
|
|
|
| 231 |
stratify=labels)
|
| 232 |
train_dataset = Subset(dataset, train_idx)
|
| 233 |
val_dataset = Subset(dataset, validation_idx)
|
| 234 |
+
if use_augment:
|
| 235 |
+
train_dataset = AugmentedSubset(train_dataset,simple_augment)
|
| 236 |
+
val_dataset = AugmentedSubset(val_dataset,None)
|
| 237 |
+
else:
|
| 238 |
+
train_dataset = AugmentedSubset(train_dataset,None)
|
| 239 |
+
val_dataset = AugmentedSubset(val_dataset,None)
|
| 240 |
cnnbackbone = CNNFeatureExtractor(config).to(device)
|
| 241 |
cnnmodel = Classifier(cnnbackbone,train_dataset.dataset.classes,config).to(device)
|
| 242 |
classicbackbone = ClassicalFeatureExtractor(config)
|
|
|
|
| 246 |
cnn_plotted=True
|
| 247 |
cnn_history_plots.append(plot([cnn_history['train_acc'],cnn_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
|
| 248 |
cnn_history_plots.append(plot([cnn_history['train_loss'],cnn_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
|
| 249 |
+
cm,cr,roc = model_evaluation(cnnmodel,val_dataset,device,BATCH_SIZE,0,classes_name)
|
| 250 |
+
cnn_history_plots.append(fig_to_image(cm))
|
| 251 |
+
cnn_history_plots.append(fig_to_image(cr))
|
| 252 |
+
cnn_history_plots.append(fig_to_image(roc))
|
| 253 |
yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
|
| 254 |
classical_history_plots.append(plot([classic_history['train_acc'],classic_history['val_acc']],['Training Accuracy','Validation Accuracy'],'Epochs','Accuracy (%)','Training vs Validation Accuracy'))
|
| 255 |
classical_history_plots.append(plot([classic_history['train_loss'],classic_history['val_loss']],['Training Loss','Validation Loss'],'Epochs','Loss','Training vs Validation Loss'))
|
| 256 |
+
cm,cr,roc = model_evaluation(classicmodel,val_dataset,device,BATCH_SIZE,0,classes_name)
|
| 257 |
+
classical_history_plots.append(fig_to_image(cm))
|
| 258 |
+
classical_history_plots.append(fig_to_image(cr))
|
| 259 |
+
classical_history_plots.append(fig_to_image(roc))
|
| 260 |
yield cnn_text,cnn_fig,classic_text,classic_fig,cnn_history_plots,classical_history_plots
|
| 261 |
|
| 262 |
except RuntimeError as e:
|
|
|
|
| 331 |
epochs= gr.Number(value=20,label="Epochs",interactive=True,precision=0)
|
| 332 |
seed=gr.Number(value=42,label='Seed',interactive=True,precision=0)
|
| 333 |
vis_every=gr.Number(value=5,label='Visualize Validation Every (Epochs)',interactive=True,precision=0)
|
| 334 |
+
use_augment = gr.Checkbox(value=True,label='Use data augmentation for train data')
|
| 335 |
with gr.Row():
|
| 336 |
img_width=gr.Number(value=128,label='Image Width',interactive=True,precision=0)
|
| 337 |
img_height=gr.Number(value=128,label='Image Height',interactive=True,precision=0)
|
|
|
|
| 346 |
with gr.Accordion(label='Classical Feature Extractor Settings',open=False):
|
| 347 |
with gr.Row():
|
| 348 |
classical_downsample = gr.Number(value=1,label='Classical Extractor Downsampling Amount',interactive=True,precision=0)
|
| 349 |
+
# Deprecated
|
| 350 |
+
# with gr.Row():
|
| 351 |
+
# hog_orientations = gr.Number(value=9,label='HoG Orientations',interactive=True,precision=0)
|
| 352 |
+
# hog_pixels_per_cell = gr.Number(value=16,label='HoG pixels per cell',interactive=True,precision=0)
|
| 353 |
+
# hog_cells_per_block = gr.Number(value=2,label='HoG cells per block',interactive=True,precision=0)
|
| 354 |
+
# hog_block_norm = gr.Dropdown(['L2-Hys'],value='L2-Hys',label='HoG Block Normalization Method',interactive=True)
|
| 355 |
with gr.Row():
|
| 356 |
canny_sigma = gr.Number(value=1.0,label='Canny Sigma Value',interactive=True)
|
| 357 |
canny_low = gr.Number(value=100,label='Canny Low Threshold',interactive=True,precision=0)
|
|
|
|
| 364 |
harris_block_size = gr.Number(value=2,label='Harris Corner Block Size',interactive=True,precision=0)
|
| 365 |
harris_ksize = gr.Number(value=3,label='Harris Corner Kernel Size',interactive=True,precision=0)
|
| 366 |
harris_k = gr.Slider(minimum=0.01, maximum=0.1, value=0.04, step=0.005, label='Harris Corner K value',interactive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
with gr.Row():
|
| 368 |
lbp_P = gr.Number(value=8,label='LBP P Value',interactive=True,precision=0)
|
| 369 |
lbp_R = gr.Number(value=1,label='LBP R Value',interactive=True,precision=0)
|
| 370 |
with gr.Row():
|
| 371 |
gabor_ksize = gr.Number(value=21,label="Gabor Kernel Size",interactive=True,precision=0)
|
| 372 |
gabor_sigma = gr.Number(value=5,label="Gabor Sigma",interactive=True,precision=0)
|
| 373 |
+
gabor_theta = gr.Number(value=0,label="Gabor Theta",interactive=True,precision=0,info="This current does nothing")
|
| 374 |
gabor_lambda = gr.Number(value=10,label="Gabor Lambda",interactive=True,precision=0)
|
| 375 |
gabor_gamma = gr.Number(value=0.5,label="Gabor Gamma",interactive=True)
|
| 376 |
+
with gr.Row():
|
| 377 |
+
sobel_ksize = gr.Number(value=3,label="Sobel Kernel Size",interactive=True,precision=0)
|
| 378 |
with gr.Column():
|
| 379 |
train_btn = gr.Button("Train Model",variant='secondary',interactive=True)
|
| 380 |
stop_btn = gr.Button("Stop Training")
|
|
|
|
| 392 |
classical_plots = gr.Gallery(label="CNN Training Performance",interactive=False,object_fit='fill',columns=1)
|
| 393 |
stop_btn.click(fn=stop_training, inputs=[], outputs=[])
|
| 394 |
train_btn.click(fn=train_model,
|
| 395 |
+
inputs=[zip_file,batch_size,lr,epochs,seed,vis_every,use_augment,
|
| 396 |
img_width,img_height,fc_num_layers,
|
| 397 |
in_channels,conv_hidden_dim,dropout,
|
| 398 |
classical_downsample,
|
| 399 |
+
# hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
|
| 400 |
canny_sigma,canny_low,canny_high,
|
| 401 |
gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
|
| 402 |
harris_block_size,harris_ksize,harris_k,
|
|
|
|
| 403 |
lbp_P,lbp_R,
|
| 404 |
+
gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma,
|
| 405 |
+
sobel_ksize],
|
| 406 |
outputs=[cnn_log,cnn_fig,classical_log,classical_fig,cnn_plots,classical_plots]
|
| 407 |
)
|
| 408 |
def make_figure_from_image(img):
|
src/__pycache__/dataloader.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/dataloader.cpython-312.pyc and b/src/__pycache__/dataloader.cpython-312.pyc differ
|
|
|
src/__pycache__/model.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/model.cpython-312.pyc and b/src/__pycache__/model.cpython-312.pyc differ
|
|
|
src/__pycache__/trainer.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/trainer.cpython-312.pyc and b/src/__pycache__/trainer.cpython-312.pyc differ
|
|
|
src/dataloader.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from torch.utils.data import Dataset
|
| 2 |
import torch
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
@@ -37,4 +37,25 @@ class ImageDataset(Dataset):
|
|
| 37 |
img = img.astype(np.float32) / 255.0
|
| 38 |
return img,label
|
| 39 |
|
|
|
|
|
|
|
|
|
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Subset,Dataset
|
| 2 |
import torch
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
|
|
| 37 |
img = img.astype(np.float32) / 255.0
|
| 38 |
return img,label
|
| 39 |
|
| 40 |
+
def simple_augment(img):
|
| 41 |
+
if np.random.rand() > 0.5:
|
| 42 |
+
img = cv2.flip(img, 1)
|
| 43 |
|
| 44 |
+
angle = np.random.uniform(-15, 15)
|
| 45 |
+
h, w = img.shape[:2]
|
| 46 |
+
M = cv2.getRotationMatrix2D((w/2, h/2), angle, 1.0)
|
| 47 |
+
img = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
|
| 48 |
+
|
| 49 |
+
return img
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class AugmentedSubset(Subset):
|
| 53 |
+
def __init__(self, subset, augment_fn=None):
|
| 54 |
+
super().__init__(subset.dataset, subset.indices)
|
| 55 |
+
self.augment_fn = augment_fn
|
| 56 |
+
|
| 57 |
+
def __getitem__(self, idx):
|
| 58 |
+
img, label = super().__getitem__(idx)
|
| 59 |
+
if self.augment_fn:
|
| 60 |
+
img = self.augment_fn(img)
|
| 61 |
+
return img, label
|
src/model.py
CHANGED
|
@@ -4,6 +4,8 @@ import cv2
|
|
| 4 |
import numpy as np
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from skimage.feature import hog,local_binary_pattern
|
|
|
|
|
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import os
|
| 9 |
import io
|
|
@@ -14,7 +16,7 @@ class Config:
|
|
| 14 |
img_size=(256,256)
|
| 15 |
in_channels=3
|
| 16 |
fc_num_layers=3
|
| 17 |
-
conv_hidden_dim=
|
| 18 |
conv_kernel_size=3
|
| 19 |
dropout=0.2
|
| 20 |
classical_downsample=1
|
|
@@ -39,10 +41,6 @@ class Config:
|
|
| 39 |
harris_ksize = 3
|
| 40 |
harris_k = 0.04
|
| 41 |
|
| 42 |
-
# Shi-Tomasi corners
|
| 43 |
-
shi_max_corners = 100
|
| 44 |
-
shi_quality_level = 0.01
|
| 45 |
-
shi_min_distance = 10
|
| 46 |
|
| 47 |
# LBP
|
| 48 |
lbp_P = 8
|
|
@@ -58,6 +56,7 @@ class Config:
|
|
| 58 |
# Sobel
|
| 59 |
sobel_ksize=3
|
| 60 |
|
|
|
|
| 61 |
class CNNFeatureExtractor(nn.Module):
|
| 62 |
def __init__(self,config : Config):
|
| 63 |
super().__init__()
|
|
@@ -67,16 +66,16 @@ class CNNFeatureExtractor(nn.Module):
|
|
| 67 |
self.img_size = config.img_size
|
| 68 |
out_channel = 32
|
| 69 |
for i in range(config.conv_hidden_dim):
|
| 70 |
-
layers.append(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=config.conv_kernel_size,stride=1,padding=
|
| 71 |
layers.append(nn.BatchNorm2d(out_channel))
|
| 72 |
layers.append(nn.ReLU())
|
| 73 |
-
layers.append(nn.MaxPool2d(2))
|
| 74 |
in_channel=out_channel
|
| 75 |
out_channel*=2
|
| 76 |
self.layers = nn.Sequential(*layers)
|
| 77 |
def get_device(self):
|
| 78 |
return next(self.parameters()).device
|
| 79 |
-
def forward(self,x):
|
| 80 |
if isinstance(x, list):
|
| 81 |
if isinstance(x[0], np.ndarray):
|
| 82 |
x = np.stack(x, axis=0)
|
|
@@ -129,7 +128,7 @@ class CNNFeatureExtractor(nn.Module):
|
|
| 129 |
conv_layers = [
|
| 130 |
(name, module)
|
| 131 |
for name, module in self.named_modules()
|
| 132 |
-
if isinstance(module, nn.
|
| 133 |
]
|
| 134 |
|
| 135 |
all_layer_images = []
|
|
@@ -251,7 +250,7 @@ class CNNFeatureExtractor(nn.Module):
|
|
| 251 |
plt.close(fig)
|
| 252 |
|
| 253 |
return all_layer_images
|
| 254 |
-
|
| 255 |
class ClassicalFeatureExtractor(nn.Module):
|
| 256 |
def __init__(self, config : Config):
|
| 257 |
super().__init__()
|
|
@@ -260,128 +259,103 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 260 |
self.num_downsample = config.classical_downsample
|
| 261 |
self.config = config
|
| 262 |
self.device = 'cpu'
|
| 263 |
-
|
|
|
|
| 264 |
def get_device(self):
|
| 265 |
return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
|
| 266 |
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
def extract_features(self, img,visualize=False,**kwargs):
|
| 269 |
cfg = self.config
|
| 270 |
-
|
| 271 |
# Convert to grayscale
|
| 272 |
gray = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
|
| 273 |
-
|
| 274 |
for _ in range(self.num_downsample):
|
| 275 |
gray = cv2.pyrDown(gray)
|
| 276 |
-
|
| 277 |
gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)
|
| 278 |
valid_H, valid_W = gray.shape[:2]
|
| 279 |
|
| 280 |
-
def render_subplots(items, max_cols=8, figsize_per_cell=3):
|
| 281 |
-
n = len(items)
|
| 282 |
-
cols = min(max_cols, n)
|
| 283 |
-
rows = int(np.ceil(n / cols))
|
| 284 |
-
|
| 285 |
-
fig, axes = plt.subplots(
|
| 286 |
-
rows, cols,
|
| 287 |
-
figsize=(cols * figsize_per_cell, rows * figsize_per_cell)
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
axes = np.atleast_2d(axes)
|
| 291 |
-
|
| 292 |
-
for idx, (img, title, cmap) in enumerate(items):
|
| 293 |
-
r = idx // cols
|
| 294 |
-
c = idx % cols
|
| 295 |
-
ax = axes[r, c]
|
| 296 |
-
ax.imshow(img, cmap=cmap)
|
| 297 |
-
ax.set_title(title, fontsize=9)
|
| 298 |
-
ax.axis("off")
|
| 299 |
-
|
| 300 |
-
# Hide unused axes
|
| 301 |
-
for idx in range(n, rows * cols):
|
| 302 |
-
r = idx // cols
|
| 303 |
-
c = idx % cols
|
| 304 |
-
axes[r, c].axis("off")
|
| 305 |
-
|
| 306 |
-
plt.tight_layout()
|
| 307 |
-
return fig
|
| 308 |
|
| 309 |
feature_list = []
|
| 310 |
vis_items=[]
|
| 311 |
-
#
|
| 312 |
-
H, W = gray.shape
|
| 313 |
-
cell_h, cell_w = cfg.hog_pixels_per_cell
|
| 314 |
-
block_h, block_w = cfg.hog_cells_per_block
|
| 315 |
-
|
| 316 |
-
min_h = cell_h * block_h
|
| 317 |
-
min_w = cell_w * block_w
|
| 318 |
-
use_hog =
|
| 319 |
-
# 1. HOG
|
| 320 |
-
if use_hog:
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
vis_items.append((dominant_weighted, "HOG Weighted Dominant Bin",'gray'))
|
| 352 |
-
vis_items.append((hog_image[:valid_H, :valid_W], f"HoG",'gray'))
|
| 353 |
-
for b in range(hog_pixel.shape[2]):
|
| 354 |
-
feature_list.append(hog_pixel[:, :, b])
|
| 355 |
|
| 356 |
|
| 357 |
# 2. Canny edges
|
| 358 |
edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
|
| 359 |
-
# feature_list.append(edges.ravel())
|
| 360 |
feature_list.append(edges[:valid_H, :valid_W])
|
| 361 |
if visualize:
|
| 362 |
-
# figs.append(plot_feature(edges[:valid_H, :valid_W], "Canny Edge"))
|
| 363 |
vis_items.append((edges[:valid_H, :valid_W], "Canny Edge", "gray"))
|
| 364 |
# 3. Harris corners
|
| 365 |
harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
|
| 366 |
harris = cv2.dilate(harris, None)
|
| 367 |
harris = np.clip(harris, 0, 1)
|
| 368 |
-
# feature_list.append(harris.ravel())
|
| 369 |
feature_list.append(harris[:valid_H, :valid_W])
|
| 370 |
if visualize:
|
| 371 |
-
# figs.append(plot_feature(harris[:valid_H, :valid_W], "Harris Corner"))
|
| 372 |
vis_items.append((harris[:valid_H, :valid_W], "Harris Corner", "gray"))
|
| 373 |
-
|
| 374 |
-
#
|
| 375 |
-
# keypoints = cv2.goodFeaturesToTrack(gray, maxCorners=cfg.shi_max_corners, qualityLevel=cfg.shi_quality_level, minDistance=cfg.shi_min_distance)
|
| 376 |
-
# if keypoints is not None:
|
| 377 |
-
# for kp in keypoints:
|
| 378 |
-
# x, y = kp.ravel()
|
| 379 |
-
# shi_corners[int(y), int(x)] = 1.0
|
| 380 |
-
# # feature_list.append(shi_corners.ravel())
|
| 381 |
-
# feature_list.append(shi_corners[:valid_H, :valid_W])
|
| 382 |
-
# if visualize:
|
| 383 |
-
# figs.append(plot_feature(shi_corners[:valid_H, :valid_W], "Shi-Tomasi Corner"))
|
| 384 |
-
# 5. LBP
|
| 385 |
lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
|
| 386 |
lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
|
| 387 |
# feature_list.append(lbp.ravel())
|
|
@@ -389,15 +363,7 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 389 |
if visualize:
|
| 390 |
# figs.append(plot_feature(lbp[:valid_H, :valid_W], "LBP"))
|
| 391 |
vis_items.append((lbp[:valid_H, :valid_W], "LBP", "gray"))
|
| 392 |
-
#
|
| 393 |
-
# g_kernel = cv2.getGaborKernel((cfg.gabor_ksize, cfg.gabor_ksize), cfg.gabor_sigma, cfg.gabor_theta, cfg.gabor_lambda, cfg.gabor_gamma)
|
| 394 |
-
# gabor_feat = cv2.filter2D(gray, cv2.CV_32F, g_kernel)
|
| 395 |
-
# gabor_feat = (gabor_feat - gabor_feat.min()) / (gabor_feat.max() - gabor_feat.min() + 1e-8)
|
| 396 |
-
# # feature_list.append(gabor_feat.ravel())
|
| 397 |
-
# feature_list.append(gabor_feat[:valid_H, :valid_W])
|
| 398 |
-
# if visualize:
|
| 399 |
-
# figs.append(plot_feature(gabor_feat[:valid_H, :valid_W], "Gabor Filter"))
|
| 400 |
-
|
| 401 |
for theta in [0, np.pi/4, np.pi/2]:
|
| 402 |
kernel = cv2.getGaborKernel(
|
| 403 |
(cfg.gabor_ksize, cfg.gabor_ksize),
|
|
@@ -409,9 +375,8 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 409 |
g /= g.max() + 1e-8
|
| 410 |
feature_list.append(g[:valid_H, :valid_W])
|
| 411 |
if visualize:
|
| 412 |
-
# figs.append(plot_feature(g[:valid_H, :valid_W], "Gabor Filter"))
|
| 413 |
vis_items.append((g[:valid_H, :valid_W], f"Gabor θ={theta:.2f}", "gray"))
|
| 414 |
-
#
|
| 415 |
sobelx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
|
| 416 |
sobely = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
|
| 417 |
|
|
@@ -424,11 +389,9 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 424 |
feature_list.append(sobelx[:valid_H, :valid_W])
|
| 425 |
feature_list.append(sobely[:valid_H, :valid_W])
|
| 426 |
if visualize:
|
| 427 |
-
# figs.append(plot_feature(sobelx[:valid_H, :valid_W], "Sobel X"))
|
| 428 |
-
# figs.append(plot_feature(sobely[:valid_H, :valid_W], "Sobel Y"))
|
| 429 |
vis_items.append((sobelx[:valid_H, :valid_W], "Sobel X",'gray'))
|
| 430 |
vis_items.append((sobely[:valid_H, :valid_W], "Sobel Y",'gray'))
|
| 431 |
-
#
|
| 432 |
lap = cv2.Laplacian(gray, cv2.CV_32F)
|
| 433 |
lap = np.abs(lap)
|
| 434 |
lap /= lap.max() + 1e-8
|
|
@@ -436,10 +399,9 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 436 |
feature_list.append(lap[:valid_H, :valid_W])
|
| 437 |
|
| 438 |
if visualize:
|
| 439 |
-
# figs.append(plot_feature(lap[:valid_H, :valid_W], "Laplacian"))
|
| 440 |
vis_items.append((lap[:valid_H, :valid_W], "Laplacian",'gray'))
|
| 441 |
|
| 442 |
-
#
|
| 443 |
gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
|
| 444 |
gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
|
| 445 |
|
|
@@ -449,37 +411,38 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 449 |
feature_list.append(grad_mag[:valid_H, :valid_W])
|
| 450 |
|
| 451 |
if visualize:
|
| 452 |
-
# figs.append(plot_feature(grad_mag[:valid_H, :valid_W], "Gradient Magnitude"))
|
| 453 |
vis_items.append((grad_mag[:valid_H, :valid_W], "Gradient Magnitude",'gray'))
|
| 454 |
|
| 455 |
# Stack all features along channel axis
|
| 456 |
features = np.stack(feature_list, axis=0)
|
| 457 |
-
# features = np.concatenate(feature_list).astype(np.float32)
|
| 458 |
if visualize:
|
| 459 |
-
return features.astype(np.float32),[render_subplots(vis_items, max_cols=8)]
|
| 460 |
return features.astype(np.float32)
|
| 461 |
|
| 462 |
|
| 463 |
-
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
| 464 |
if isinstance(x, torch.Tensor):
|
| 465 |
x = x.cpu().numpy()
|
|
|
|
| 466 |
if isinstance(x, np.ndarray):
|
| 467 |
-
if x.ndim == 3:
|
| 468 |
-
x =
|
| 469 |
elif x.ndim != 4:
|
| 470 |
-
raise ValueError(
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
batch_features = []
|
| 475 |
for img in x:
|
| 476 |
-
if img.
|
| 477 |
img = np.repeat(img[:, :, None], 3, axis=2)
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
return
|
| 483 |
|
| 484 |
def visualize(self, img, show_original=True,show=True):
|
| 485 |
if img.ndim != 3 or img.shape[2] != 3:
|
|
@@ -517,10 +480,14 @@ class ClassicalFeatureExtractor(nn.Module):
|
|
| 517 |
|
| 518 |
|
| 519 |
def output(self):
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
|
| 525 |
|
| 526 |
|
|
@@ -530,21 +497,20 @@ class FullyConnectedHead(nn.Module):
|
|
| 530 |
num_classes = len(classes)
|
| 531 |
self.classes = classes
|
| 532 |
layers = []
|
| 533 |
-
|
| 534 |
-
for
|
| 535 |
-
layers.append(nn.Linear(in_features,
|
| 536 |
-
layers.append(nn.BatchNorm1d(
|
| 537 |
layers.append(nn.ReLU())
|
| 538 |
layers.append(nn.Dropout(config.dropout))
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
break
|
| 543 |
layers.append(nn.Linear(in_features,num_classes))
|
| 544 |
self.layers = nn.Sequential(*layers)
|
| 545 |
def get_device(self):
|
| 546 |
return next(self.parameters()).device
|
| 547 |
-
def forward(self,x : torch.Tensor):
|
| 548 |
x=x.to(self.get_device())
|
| 549 |
return self.layers(x)
|
| 550 |
|
|
@@ -568,15 +534,15 @@ class Classifier(nn.Module):
|
|
| 568 |
target_size = self.config.img_size
|
| 569 |
x = cv2.resize(x, target_size)
|
| 570 |
logits = self.forward(x)
|
| 571 |
-
probs = torch.softmax(logits,
|
| 572 |
pred_idx = torch.argmax(probs, dim=1).item()
|
| 573 |
|
| 574 |
return self.classes[pred_idx]
|
| 575 |
|
| 576 |
-
def forward(self,x):
|
| 577 |
-
feat = self.backbone(x)
|
| 578 |
-
feat = self.flatten(feat)
|
| 579 |
-
return self.head(feat)
|
| 580 |
def visualize_feature(self,img,return_img=True,**kwargs):
|
| 581 |
target_size = self.config.img_size
|
| 582 |
img = cv2.resize(img, target_size)
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from skimage.feature import hog,local_binary_pattern
|
| 7 |
+
import itertools
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import os
|
| 11 |
import io
|
|
|
|
| 16 |
img_size=(256,256)
|
| 17 |
in_channels=3
|
| 18 |
fc_num_layers=3
|
| 19 |
+
conv_hidden_dim=2
|
| 20 |
conv_kernel_size=3
|
| 21 |
dropout=0.2
|
| 22 |
classical_downsample=1
|
|
|
|
| 41 |
harris_ksize = 3
|
| 42 |
harris_k = 0.04
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# LBP
|
| 46 |
lbp_P = 8
|
|
|
|
| 56 |
# Sobel
|
| 57 |
sobel_ksize=3
|
| 58 |
|
| 59 |
+
|
| 60 |
class CNNFeatureExtractor(nn.Module):
|
| 61 |
def __init__(self,config : Config):
|
| 62 |
super().__init__()
|
|
|
|
| 66 |
self.img_size = config.img_size
|
| 67 |
out_channel = 32
|
| 68 |
for i in range(config.conv_hidden_dim):
|
| 69 |
+
layers.append(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=config.conv_kernel_size,stride=1,padding=config.conv_kernel_size // 2))
|
| 70 |
layers.append(nn.BatchNorm2d(out_channel))
|
| 71 |
layers.append(nn.ReLU())
|
| 72 |
+
layers.append(nn.MaxPool2d((2,2)))
|
| 73 |
in_channel=out_channel
|
| 74 |
out_channel*=2
|
| 75 |
self.layers = nn.Sequential(*layers)
|
| 76 |
def get_device(self):
|
| 77 |
return next(self.parameters()).device
|
| 78 |
+
def forward(self,x,**kwargs):
|
| 79 |
if isinstance(x, list):
|
| 80 |
if isinstance(x[0], np.ndarray):
|
| 81 |
x = np.stack(x, axis=0)
|
|
|
|
| 128 |
conv_layers = [
|
| 129 |
(name, module)
|
| 130 |
for name, module in self.named_modules()
|
| 131 |
+
if isinstance(module, nn.ReLU)
|
| 132 |
]
|
| 133 |
|
| 134 |
all_layer_images = []
|
|
|
|
| 250 |
plt.close(fig)
|
| 251 |
|
| 252 |
return all_layer_images
|
| 253 |
+
|
| 254 |
class ClassicalFeatureExtractor(nn.Module):
|
| 255 |
def __init__(self, config : Config):
|
| 256 |
super().__init__()
|
|
|
|
| 259 |
self.num_downsample = config.classical_downsample
|
| 260 |
self.config = config
|
| 261 |
self.device = 'cpu'
|
| 262 |
+
self.convolution=None
|
| 263 |
+
|
| 264 |
def get_device(self):
|
| 265 |
return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
|
| 266 |
|
| 267 |
+
def render_subplots(self, items, max_cols=8, figsize_per_cell=3):
    """Lay out (image, title, cmap) triples on a subplot grid.

    Panels fill row-major with at most ``max_cols`` per row; leftover grid
    cells get their axes hidden so the figure shows only real panels.

    Returns the matplotlib figure (caller is responsible for closing it).
    """
    count = len(items)
    n_cols = min(max_cols, count)
    n_rows = int(np.ceil(count / n_cols))
    fig, grid = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * figsize_per_cell, n_rows * figsize_per_cell),
    )
    # atleast_2d normalizes the single-axes / 1-D-array cases; flat order is
    # row-major, which matches the original (idx // cols, idx % cols) walk.
    panels = list(np.atleast_2d(grid).flat)
    for panel, (image, caption, colormap) in zip(panels, items):
        panel.imshow(image, cmap=colormap)
        panel.set_title(caption, fontsize=9)
        panel.axis("off")
    for spare in panels[count:]:
        spare.axis("off")
    plt.tight_layout()
    return fig
|
| 290 |
+
|
| 291 |
def extract_features(self, img,visualize=False,**kwargs):
|
| 292 |
cfg = self.config
|
|
|
|
| 293 |
# Convert to grayscale
|
| 294 |
gray = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
|
|
|
|
| 295 |
for _ in range(self.num_downsample):
|
| 296 |
gray = cv2.pyrDown(gray)
|
|
|
|
| 297 |
gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)
|
| 298 |
valid_H, valid_W = gray.shape[:2]
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
feature_list = []
|
| 302 |
vis_items=[]
|
| 303 |
+
# DEPRECATED
|
| 304 |
+
# H, W = gray.shape
|
| 305 |
+
# cell_h, cell_w = cfg.hog_pixels_per_cell
|
| 306 |
+
# block_h, block_w = cfg.hog_cells_per_block
|
| 307 |
+
|
| 308 |
+
# min_h = cell_h * block_h
|
| 309 |
+
# min_w = cell_w * block_w
|
| 310 |
+
# use_hog = False
|
| 311 |
+
# # 1. HOG
|
| 312 |
+
# if use_hog:
|
| 313 |
+
# hog_descriptors, hog_image = hog(
|
| 314 |
+
# gray,
|
| 315 |
+
# orientations=cfg.hog_orientations,
|
| 316 |
+
# pixels_per_cell=cfg.hog_pixels_per_cell,
|
| 317 |
+
# cells_per_block=cfg.hog_cells_per_block,
|
| 318 |
+
# block_norm=cfg.hog_block_norm,
|
| 319 |
+
# visualize=True,
|
| 320 |
+
# feature_vector=False
|
| 321 |
+
# )
|
| 322 |
+
|
| 323 |
+
# hog_cells = hog_descriptors.mean(axis=(2, 3))
|
| 324 |
|
| 325 |
+
# cell_h, cell_w = cfg.hog_pixels_per_cell
|
| 326 |
+
# hog_pixel = np.repeat(
|
| 327 |
+
# np.repeat(hog_cells, cell_h, axis=0),
|
| 328 |
+
# cell_w, axis=1
|
| 329 |
+
# )
|
| 330 |
+
# hog_pixel = hog_pixel[:gray.shape[0], :gray.shape[1]]
|
| 331 |
+
# hog_energy = np.sum(hog_pixel, axis=2)
|
| 332 |
+
# dominant_bin = np.argmax(hog_pixel, axis=2)
|
| 333 |
+
# dominant_strength = np.max(hog_pixel, axis=2)
|
| 334 |
+
# dominant_weighted = dominant_bin * dominant_strength
|
| 335 |
+
# valid_H, valid_W = hog_pixel.shape[:2]
|
| 336 |
+
# if visualize:
|
| 337 |
+
# vis_items.append((hog_energy, "HOG Energy",'gray'))
|
| 338 |
+
# vis_items.append((dominant_bin, "HOG Dominant Bin",'hsv'))
|
| 339 |
+
# vis_items.append((dominant_weighted, "HOG Weighted Dominant Bin",'gray'))
|
| 340 |
+
# vis_items.append((hog_image[:valid_H, :valid_W], f"HoG",'gray'))
|
| 341 |
+
# for b in range(hog_pixel.shape[2]):
|
| 342 |
+
# feature_list.append(hog_pixel[:, :, b])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
|
| 345 |
# 2. Canny edges
|
| 346 |
edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
|
|
|
|
| 347 |
feature_list.append(edges[:valid_H, :valid_W])
|
| 348 |
if visualize:
|
|
|
|
| 349 |
vis_items.append((edges[:valid_H, :valid_W], "Canny Edge", "gray"))
|
| 350 |
# 3. Harris corners
|
| 351 |
harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
|
| 352 |
harris = cv2.dilate(harris, None)
|
| 353 |
harris = np.clip(harris, 0, 1)
|
|
|
|
| 354 |
feature_list.append(harris[:valid_H, :valid_W])
|
| 355 |
if visualize:
|
|
|
|
| 356 |
vis_items.append((harris[:valid_H, :valid_W], "Harris Corner", "gray"))
|
| 357 |
+
|
| 358 |
+
# 4. LBP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
|
| 360 |
lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
|
| 361 |
# feature_list.append(lbp.ravel())
|
|
|
|
| 363 |
if visualize:
|
| 364 |
# figs.append(plot_feature(lbp[:valid_H, :valid_W], "LBP"))
|
| 365 |
vis_items.append((lbp[:valid_H, :valid_W], "LBP", "gray"))
|
| 366 |
+
# 5. Gabor filter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
for theta in [0, np.pi/4, np.pi/2]:
|
| 368 |
kernel = cv2.getGaborKernel(
|
| 369 |
(cfg.gabor_ksize, cfg.gabor_ksize),
|
|
|
|
| 375 |
g /= g.max() + 1e-8
|
| 376 |
feature_list.append(g[:valid_H, :valid_W])
|
| 377 |
if visualize:
|
|
|
|
| 378 |
vis_items.append((g[:valid_H, :valid_W], f"Gabor θ={theta:.2f}", "gray"))
|
| 379 |
+
# 6. Sobel
|
| 380 |
sobelx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
|
| 381 |
sobely = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
|
| 382 |
|
|
|
|
| 389 |
feature_list.append(sobelx[:valid_H, :valid_W])
|
| 390 |
feature_list.append(sobely[:valid_H, :valid_W])
|
| 391 |
if visualize:
|
|
|
|
|
|
|
| 392 |
vis_items.append((sobelx[:valid_H, :valid_W], "Sobel X",'gray'))
|
| 393 |
vis_items.append((sobely[:valid_H, :valid_W], "Sobel Y",'gray'))
|
| 394 |
+
# 7. Laplacian
|
| 395 |
lap = cv2.Laplacian(gray, cv2.CV_32F)
|
| 396 |
lap = np.abs(lap)
|
| 397 |
lap /= lap.max() + 1e-8
|
|
|
|
| 399 |
feature_list.append(lap[:valid_H, :valid_W])
|
| 400 |
|
| 401 |
if visualize:
|
|
|
|
| 402 |
vis_items.append((lap[:valid_H, :valid_W], "Laplacian",'gray'))
|
| 403 |
|
| 404 |
+
# 8. Gradient Magnitude
|
| 405 |
gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
|
| 406 |
gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
|
| 407 |
|
|
|
|
| 411 |
feature_list.append(grad_mag[:valid_H, :valid_W])
|
| 412 |
|
| 413 |
if visualize:
|
|
|
|
| 414 |
vis_items.append((grad_mag[:valid_H, :valid_W], "Gradient Magnitude",'gray'))
|
| 415 |
|
| 416 |
# Stack all features along channel axis
|
| 417 |
features = np.stack(feature_list, axis=0)
|
|
|
|
| 418 |
if visualize:
|
| 419 |
+
return features.astype(np.float32),[self.render_subplots(vis_items, max_cols=8)]
|
| 420 |
return features.astype(np.float32)
|
| 421 |
|
| 422 |
|
| 423 |
+
def forward(self, x, **kwargs):
    """Extract classical features for a batch of images.

    Accepts a list of HWC arrays, a torch tensor, or a numpy array of shape
    HWC (promoted to a batch of one) or BHWC. Each image is run through
    ``extract_features`` and the per-image feature maps are stacked into a
    single float tensor on this module's device.

    Raises:
        ValueError: if a numpy input is neither 3-D (HWC) nor 4-D (BHWC).
    """
    if isinstance(x, list):
        x = np.stack(x, axis=0)

    if isinstance(x, torch.Tensor):
        x = x.cpu().numpy()

    if isinstance(x, np.ndarray):
        if x.ndim == 3:
            # Single image: promote to a batch of one.
            x = x[None]
        elif x.ndim != 4:
            raise ValueError(
                f"Expected input of shape HWC or BHWC, got {x.shape}"
            )
    feats = []
    for img in x:
        if img.shape[2] != 3:
            # Broadcast a single-channel image to 3 channels by repeating the
            # existing channel axis. BUG FIX: the previous code did
            # np.repeat(img[:, :, None], 3, axis=2), which inserted an extra
            # axis and produced a 4-D (H, W, 3, 1) array, breaking the
            # RGB->gray conversion inside extract_features.
            img = np.repeat(img[:, :, :1], 3, axis=2)
        feats.append(self.extract_features(img))

    feats = np.stack(feats, axis=0)
    feats = torch.from_numpy(feats).float().to(self.get_device())
    return feats
|
| 446 |
|
| 447 |
def visualize(self, img, show_original=True,show=True):
|
| 448 |
if img.ndim != 3 or img.shape[2] != 3:
|
|
|
|
| 480 |
|
| 481 |
|
| 482 |
def output(self):
    """Probe the extractor with an all-zero full-size dummy image.

    Builds a black H x W x 3 float32 image from ``self.img_size`` (stored as
    (W, H)) and returns whatever ``forward`` produces for it — useful for
    discovering the output feature shape.
    """
    width, height = self.img_size
    blank = np.zeros((height, width, 3), dtype=np.float32)
    return self.forward(blank)
|
| 491 |
|
| 492 |
|
| 493 |
|
|
|
|
| 497 |
num_classes = len(classes)
|
| 498 |
self.classes = classes
|
| 499 |
layers = []
|
| 500 |
+
hidden_dim =1024
|
| 501 |
+
for _ in range(config.fc_num_layers):
|
| 502 |
+
layers.append(nn.Linear(in_features, hidden_dim))
|
| 503 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 504 |
layers.append(nn.ReLU())
|
| 505 |
layers.append(nn.Dropout(config.dropout))
|
| 506 |
+
|
| 507 |
+
in_features = hidden_dim
|
| 508 |
+
hidden_dim = max(hidden_dim // 2, num_classes * 2)
|
|
|
|
| 509 |
layers.append(nn.Linear(in_features,num_classes))
|
| 510 |
self.layers = nn.Sequential(*layers)
|
| 511 |
def get_device(self):
|
| 512 |
return next(self.parameters()).device
|
| 513 |
+
def forward(self, x: torch.Tensor, **kwargs):
    """Apply the fully-connected head; the input is first moved to the
    device the head's parameters live on."""
    on_device = x.to(self.get_device())
    return self.layers(on_device)
|
| 516 |
|
|
|
|
| 534 |
target_size = self.config.img_size
|
| 535 |
x = cv2.resize(x, target_size)
|
| 536 |
logits = self.forward(x)
|
| 537 |
+
probs = torch.softmax(logits,dim=1)
|
| 538 |
pred_idx = torch.argmax(probs, dim=1).item()
|
| 539 |
|
| 540 |
return self.classes[pred_idx]
|
| 541 |
|
| 542 |
+
def forward(self, x, **kwargs):
    """Classifier pipeline: backbone features -> flatten -> head logits.

    ``kwargs`` are forwarded unchanged to every stage.
    """
    return self.head(
        self.flatten(self.backbone(x, **kwargs), **kwargs),
        **kwargs,
    )
|
| 546 |
def visualize_feature(self,img,return_img=True,**kwargs):
|
| 547 |
target_size = self.config.img_size
|
| 548 |
img = cv2.resize(img, target_size)
|
src/trainer.py
CHANGED
|
@@ -10,12 +10,149 @@ import random
|
|
| 10 |
import numpy as np
|
| 11 |
import torch.nn as nn
|
| 12 |
import time
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def seed_worker(worker_id):
|
| 15 |
worker_seed = torch.initial_seed() % 2**32
|
| 16 |
np.random.seed(worker_seed)
|
| 17 |
random.seed(worker_seed)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class ModelTrainer:
|
| 20 |
def __init__(self,model : Classifier,train_set : ImageDataset,val_set : ImageDataset = None, batch_size=32,lr = 1e-3,device='cpu',return_fig=False, seed=None):
|
| 21 |
g = torch.Generator()
|
|
@@ -50,6 +187,9 @@ class ModelTrainer:
|
|
| 50 |
self.optim.zero_grad()
|
| 51 |
self.criterion = nn.CrossEntropyLoss()
|
| 52 |
self.return_fig=return_fig
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
|
| 55 |
|
|
@@ -114,13 +254,15 @@ class ModelTrainer:
|
|
| 114 |
return None
|
| 115 |
|
| 116 |
|
| 117 |
-
def train_one_epoch(self):
|
| 118 |
self.model.train()
|
| 119 |
total_loss = 0
|
| 120 |
train_pbar = tqdm(self.train_loader, desc="Training",leave=False)
|
| 121 |
correct = 0
|
| 122 |
total = 0
|
| 123 |
for imgs, labels in train_pbar:
|
|
|
|
|
|
|
| 124 |
labels = labels.to(self.device)
|
| 125 |
|
| 126 |
# Forward
|
|
@@ -146,18 +288,30 @@ class ModelTrainer:
|
|
| 146 |
val_losses=[]
|
| 147 |
val_accuracies=[]
|
| 148 |
for epoch in range(1, epochs + 1):
|
| 149 |
-
train_loss,train_acc = self.train_one_epoch()
|
|
|
|
|
|
|
| 150 |
train_losses.append(train_loss)
|
| 151 |
train_accuracies.append(train_acc)
|
| 152 |
if self.val_loader is not None:
|
| 153 |
val_loss,val_acc,fig=self.validate(epoch, visualize=(epoch % visualize_every == 0 or epoch == 1))
|
|
|
|
|
|
|
| 154 |
val_losses.append(val_loss)
|
| 155 |
val_accuracies.append(val_acc)
|
| 156 |
print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
yield train_loss,train_acc,val_loss,val_acc,fig
|
| 158 |
else:
|
| 159 |
print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
|
| 160 |
yield train_loss,train_acc,None,None,None
|
|
|
|
|
|
|
|
|
|
| 161 |
yield train_losses,train_accuracies,val_losses,val_accuracies,None
|
| 162 |
|
| 163 |
def validate(self,epoch, visualize=False):
|
|
@@ -177,6 +331,8 @@ class ModelTrainer:
|
|
| 177 |
fig = None
|
| 178 |
with torch.no_grad():
|
| 179 |
for imgs, labels in val_pbar:
|
|
|
|
|
|
|
| 180 |
labels = labels.to(self.device)
|
| 181 |
|
| 182 |
outputs = self.model(imgs)
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import torch.nn as nn
|
| 12 |
import time
|
| 13 |
+
from sklearn.metrics import (
|
| 14 |
+
confusion_matrix,
|
| 15 |
+
classification_report,
|
| 16 |
+
roc_curve,
|
| 17 |
+
auc
|
| 18 |
+
)
|
| 19 |
+
from sklearn.preprocessing import label_binarize
|
| 20 |
def seed_worker(worker_id):
    """DataLoader worker-init hook: reseed NumPy and ``random`` in the worker.

    Derives a 32-bit seed from torch's per-worker initial seed so the Python
    and NumPy RNGs track torch's seeding. ``worker_id`` itself is unused —
    torch already differentiates workers via ``initial_seed()``.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
|
| 24 |
|
| 25 |
+
def model_evaluation(model, val_set, device, batch_size=32, num_workers=0, class_names=None):
    """Evaluate a classifier on ``val_set`` and render summary figures.

    Runs the model over the validation set under ``torch.no_grad`` and builds
    three matplotlib figures: a confusion matrix, a classification-report
    table, and per-class one-vs-rest ROC curves with AUC.

    Args:
        model: torch module mapping a BCHW float batch to class logits.
        val_set: dataset yielding (image, label) pairs.
        device: device the model lives on; batches are moved there.
        batch_size: evaluation batch size.
        num_workers: DataLoader worker count.
        class_names: optional display names, one per class; defaults to
            "Class i".

    Returns:
        (cm_fig, cr_fig, roc_fig) matplotlib figures.
    """
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []
    val_loader = DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    with torch.no_grad():
        for images, labels in val_loader:
            # Dataset yields BHWC; the model expects BCHW.
            if images.ndim == 4 and images.shape[-1] in (1, 3):
                images = images.permute(0, 3, 1, 2)
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.append(preds.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    y_true = np.concatenate(all_labels)
    y_pred = np.concatenate(all_preds)
    y_prob = np.concatenate(all_probs)

    num_classes = y_prob.shape[1]

    if class_names is None:
        class_names = [f"Class {i}" for i in range(num_classes)]
    class_ids = list(range(num_classes))

    # FIX: pin `labels` so the matrix is num_classes x num_classes even when
    # some class is absent from the validation split — the annotation loop
    # below indexes the full range and would otherwise go out of bounds.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    cm_fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm)

    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(class_names, rotation=75)
    ax.set_yticklabels(class_names)

    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(j, i, cm[i, j], ha="center", va="center")

    plt.tight_layout()

    # FIX: pass `labels` so target_names always lines up with the report rows
    # even when a class is missing from y_true/y_pred.
    report = classification_report(
        y_true, y_pred,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    cr_fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis("off")

    table_data = []
    headers = ["Class", "Precision", "Recall", "F1", "Support"]

    for cls in class_names:
        row = report[cls]
        table_data.append([
            cls,
            f"{row['precision']:.3f}",
            f"{row['recall']:.3f}",
            f"{row['f1-score']:.3f}",
            int(row['support'])
        ])

    accuracy = report["accuracy"]
    macro_avg = report["macro avg"]
    weighted_avg = report["weighted avg"]

    table_data.append([
        "Accuracy",
        f"{accuracy:.3f}",
        "",
        "",
        ""
    ])

    table_data.append([
        "Macro Avg",
        f"{macro_avg['precision']:.3f}",
        f"{macro_avg['recall']:.3f}",
        f"{macro_avg['f1-score']:.3f}",
        f"{int(macro_avg['support'])}" if 'support' in macro_avg else ""
    ])

    table_data.append([
        "Weighted Avg",
        f"{weighted_avg['precision']:.3f}",
        f"{weighted_avg['recall']:.3f}",
        f"{weighted_avg['f1-score']:.3f}",
        f"{int(weighted_avg['support'])}" if 'support' in weighted_avg else ""
    ])

    table = ax.table(
        cellText=table_data,
        colLabels=headers,
        loc="center"
    )

    table.scale(1, 2)
    ax.set_title("Classification Report")

    y_true_bin = label_binarize(y_true, classes=class_ids)
    # FIX: for binary problems label_binarize returns a single column; the
    # per-class ROC loop below would then crash on y_true_bin[:, 1]. Expand
    # to one indicator column per class.
    if num_classes == 2 and y_true_bin.shape[1] == 1:
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

    roc_fig, ax = plt.subplots(figsize=(6, 6))

    for i in range(num_classes):
        fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{class_names[i]} (AUC={roc_auc:.3f})")

    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC-AUC Curve")
    ax.legend()
    ax.grid(True)

    return cm_fig, cr_fig, roc_fig
|
| 155 |
+
|
| 156 |
class ModelTrainer:
|
| 157 |
def __init__(self,model : Classifier,train_set : ImageDataset,val_set : ImageDataset = None, batch_size=32,lr = 1e-3,device='cpu',return_fig=False, seed=None):
|
| 158 |
g = torch.Generator()
|
|
|
|
| 187 |
self.optim.zero_grad()
|
| 188 |
self.criterion = nn.CrossEntropyLoss()
|
| 189 |
self.return_fig=return_fig
|
| 190 |
+
self.best_model_state = None
|
| 191 |
+
self.best_val_acc = 0.0
|
| 192 |
+
self.interrupt=False
|
| 193 |
|
| 194 |
def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
|
| 195 |
|
|
|
|
| 254 |
return None
|
| 255 |
|
| 256 |
|
| 257 |
+
def train_one_epoch(self,epoch):
|
| 258 |
self.model.train()
|
| 259 |
total_loss = 0
|
| 260 |
train_pbar = tqdm(self.train_loader, desc="Training",leave=False)
|
| 261 |
correct = 0
|
| 262 |
total = 0
|
| 263 |
for imgs, labels in train_pbar:
|
| 264 |
+
if self.interrupt:
|
| 265 |
+
break
|
| 266 |
labels = labels.to(self.device)
|
| 267 |
|
| 268 |
# Forward
|
|
|
|
| 288 |
val_losses=[]
|
| 289 |
val_accuracies=[]
|
| 290 |
for epoch in range(1, epochs + 1):
|
| 291 |
+
train_loss,train_acc = self.train_one_epoch(epoch)
|
| 292 |
+
if self.interrupt:
|
| 293 |
+
return
|
| 294 |
train_losses.append(train_loss)
|
| 295 |
train_accuracies.append(train_acc)
|
| 296 |
if self.val_loader is not None:
|
| 297 |
val_loss,val_acc,fig=self.validate(epoch, visualize=(epoch % visualize_every == 0 or epoch == 1))
|
| 298 |
+
if self.interrupt:
|
| 299 |
+
return
|
| 300 |
val_losses.append(val_loss)
|
| 301 |
val_accuracies.append(val_acc)
|
| 302 |
print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
|
| 303 |
+
if val_acc > self.best_val_acc:
|
| 304 |
+
print(f"New best model found at epoch {epoch} (Val Acc: {val_acc:.4f})")
|
| 305 |
+
self.best_val_acc = val_acc
|
| 306 |
+
self.best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}
|
| 307 |
+
|
| 308 |
yield train_loss,train_acc,val_loss,val_acc,fig
|
| 309 |
else:
|
| 310 |
print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
|
| 311 |
yield train_loss,train_acc,None,None,None
|
| 312 |
+
if self.best_model_state is not None:
|
| 313 |
+
self.model.load_state_dict(self.best_model_state)
|
| 314 |
+
print(f"Best model (Val Acc: {self.best_val_acc:.4f}) loaded into trainer.model")
|
| 315 |
yield train_losses,train_accuracies,val_losses,val_accuracies,None
|
| 316 |
|
| 317 |
def validate(self,epoch, visualize=False):
|
|
|
|
| 331 |
fig = None
|
| 332 |
with torch.no_grad():
|
| 333 |
for imgs, labels in val_pbar:
|
| 334 |
+
if self.interrupt:
|
| 335 |
+
break
|
| 336 |
labels = labels.to(self.device)
|
| 337 |
|
| 338 |
outputs = self.model(imgs)
|