VJyzCELERY committed on
Commit
c3d45c0
·
1 Parent(s): 5a5e816

Commit to hf space

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ /dataset
2
+ /trained_model
3
+ *.pt
4
+ streamlitapp.py
app.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import zipfile
3
+ import os
4
+ import torch
5
+ from src.dataloader import ImageDataset,collate_fn
6
+ from src.model import Classifier,Config,CNNFeatureExtractor,ClassicalFeatureExtractor,load
7
+ from torch.utils.data import Subset
8
+ from src.trainer import ModelTrainer
9
+ import torch
10
+ import os
11
+ import numpy as np
12
+ import time
13
+ import cv2
14
+ from PIL import Image
15
+ import io
16
+ import matplotlib.pyplot as plt
17
+ import shutil
18
+ import pandas as pd
19
+ from sklearn.model_selection import train_test_split
20
+ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
21
+ from sklearn.metrics import classification_report
22
+ from torch.utils.data import DataLoader
23
+
24
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def unzip_dataset(zip_file):
    """Extract an uploaded dataset zip next to the app and return its root folder.

    The archive is expanded into ``./<archive-stem>``. If everything in the
    archive is wrapped in a single top-level directory, descend into it so the
    returned path directly contains one sub-folder per class.
    """
    archive_stem = os.path.splitext(os.path.basename(zip_file.name))[0]
    extract_root = os.path.join(".", archive_stem)

    os.makedirs(extract_root, exist_ok=True)

    with zipfile.ZipFile(zip_file.name, 'r') as archive:
        archive.extractall(extract_root)

    # Unwrap a single top-level folder (common zip layout).
    entries = os.listdir(extract_root)
    if len(entries) == 1 and os.path.isdir(os.path.join(extract_root, entries[0])):
        extract_root = os.path.join(extract_root, entries[0])

    print(f"Dataset extracted to: {extract_root}")
    class_names = [
        entry for entry in os.listdir(extract_root)
        if os.path.isdir(os.path.join(extract_root, entry))
    ]
    print(f"Detected classes: {class_names}")

    # Log a small sample per class so the user can sanity-check the upload.
    for class_name in class_names:
        class_dir = os.path.join(extract_root, class_name)
        files = os.listdir(class_dir)
        print(f"Class '{class_name}' has {len(files)} images. Sample: {files[:3]}")

    return extract_root
47
+
48
# Per-run metric history for the CNN model; reset at the start of train().
cnn_history={
    "train_acc":[],
    "train_loss":[],
    "val_acc":[],
    "val_loss":[]
}

# Per-run metric history for the classical-features model.
classic_history={
    "train_acc":[],
    "train_loss":[],
    "val_acc":[],
    "val_loss":[]
}

# Cooperative stop flag: set by stop_training(), polled inside train().
training_interrupt = False
63
+
64
def fig_to_image(fig):
    """Render a Matplotlib figure to an RGB numpy array and close the figure."""
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    rgb = np.array(Image.open(buffer).convert("RGB"))
    # Close the source figure so repeated calls do not leak pyplot state.
    plt.close(fig)
    return rgb
72
+
73
def plot(datas, labels, xlabel, ylabel, title, figsize=(16, 8)):
    """Draw one epoch-indexed line per series and return the chart as an RGB array."""
    fig, ax = plt.subplots(figsize=figsize)
    for series, series_label in zip(datas, labels):
        # Epochs are 1-based on the x axis.
        ax.plot(range(1, len(series) + 1), series, label=series_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    return fig_to_image(fig)
82
+
83
class TrainingInterrupted(Exception):
    """Raised inside the training loop when the user presses "Stop Training"."""
    pass


def stop_training():
    """Signal the running training generator to abort at its next checkpoint."""
    global training_interrupt
    training_interrupt = True
    return "Training stopped."
89
+
90
def train(cnn,classic,train_set,val_set,batch_size,lr,epochs,device="cpu",visualize_every=5):
    """Train the CNN model first, then the classical model, streaming progress.

    Generator yielding (cnn_log_text, cnn_figures, classic_log_text,
    classic_figures, cnn_done) after every epoch so the Gradio UI can update.
    The module-level stop flag set by stop_training() aborts the run via
    TrainingInterrupted, which is caught here and ends the generator quietly.
    """
    global training_interrupt
    training_interrupt = False
    global cnn_history
    global classic_history
    cnn_done=False
    # Reset the shared history dicts so repeated runs start from scratch.
    cnn_history={
        "train_acc":[],
        "train_loss":[],
        "val_acc":[],
        "val_loss":[]
    }

    classic_history={
        "train_acc":[],
        "train_loss":[],
        "val_acc":[],
        "val_loss":[]
    }
    try:
        if training_interrupt:
            raise TrainingInterrupted("Training was interrupted!")
        cnntrainer = ModelTrainer(cnn,train_set,val_set,batch_size,lr,device=device,return_fig=True)
        classictrainer = ModelTrainer(classic,train_set,val_set,batch_size,lr,device=device,return_fig=True)
        cnn_text=""
        classic_text=""
        cnn_fig=None
        all_cnn_fig = []
        all_classic_fig= []
        classic_fig=None
        start_time = time.time()
        # --- Phase 1: CNN model ---
        for i,(cnn_train_loss,cnn_train_acc,cnn_val_loss,cnn_val_acc,cnn_fig) in enumerate(cnntrainer.train(epochs,visualize_every=visualize_every)):
            if training_interrupt:
                raise TrainingInterrupted("Training was interrupted!")
            if i == epochs:
                break
            if cnn_fig is not None:
                # The trainer produced validation visualizations this epoch;
                # stamp them with the epoch number and convert for the gallery.
                for fig in cnn_fig:
                    fig.suptitle(f"Epoch {i+1}", fontsize=16)
                    all_cnn_fig.append(fig_to_image(fig))
            cnn_text+= f"Epochs {i+1} : Train Loss: {cnn_train_loss:.4f}, Train Acc: {cnn_train_acc:.4f}, Val Loss: {cnn_val_loss:.4f}, Val Acc: {cnn_val_acc:.4f}\n"
            cnn_history['train_acc'].append(cnn_train_acc)
            cnn_history['train_loss'].append(cnn_train_loss)
            cnn_history['val_acc'].append(cnn_val_acc)
            cnn_history['val_loss'].append(cnn_val_loss)

            yield cnn_text,all_cnn_fig,classic_text,all_classic_fig,cnn_done
        # CNN phase finished: report elapsed time once.
        cnn_done=True
        dt = time.time()-start_time
        cnn_fig=None
        cnn_text+=f'Time taken : {dt:.2f} seconds\n'
        yield cnn_text,all_cnn_fig,classic_text,all_classic_fig,cnn_done
        start_time = time.time()
        # --- Phase 2: classical-features model ---
        for i,(classic_train_loss,classic_train_acc,classic_val_loss,classic_val_acc,classic_fig) in enumerate(classictrainer.train(epochs,visualize_every=visualize_every)):
            if training_interrupt:
                raise TrainingInterrupted("Training was interrupted!")
            if i == epochs:
                break
            if classic_fig is not None:
                for fig in classic_fig:
                    fig.suptitle(f"Epoch {i+1}", fontsize=16)
                    all_classic_fig.append(fig_to_image(fig))
            classic_history['train_acc'].append(classic_train_acc)
            classic_history['train_loss'].append(classic_train_loss)
            classic_history['val_acc'].append(classic_val_acc)
            classic_history['val_loss'].append(classic_val_loss)
            classic_text+= f"Epochs {i+1} : Train Loss: {classic_train_loss:.4f}, Train Acc: {classic_train_acc:.4f}, Val Loss: {classic_val_loss:.4f}, Val Acc: {classic_val_acc:.4f}\n"
            yield cnn_text,all_cnn_fig,classic_text,all_classic_fig,cnn_done
        dt = time.time()-start_time
        classic_fig=None
        classic_text+=f'Time taken : {dt:.2f} seconds\n'
        yield cnn_text,all_cnn_fig,classic_text,all_classic_fig,cnn_done
    except TrainingInterrupted as e:
        # User-requested stop: log and end the generator without re-raising.
        print(e)
        return
165
+
166
def train_model(zip_file,batch_size,lr,epochs,seed,vis_every,
                img_width,img_height,fc_num_layers,
                in_channels,conv_hidden_dim,dropout,
                classical_downsample,
                hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
                canny_sigma,canny_low,canny_high,
                gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
                harris_block_size,harris_ksize,harris_k,
                shi_max_corners,shi_quality_level,shi_min_distance,
                lbp_P,lbp_R,
                gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma):
    """Gradio callback: build both models from the UI settings and stream training.

    Yields (cnn_log, cnn_figures, classic_log, classic_figures,
    cnn_history_plots, classical_history_plots) tuples matching the six
    output components wired to the "Train Model" button. Cleans up the
    extracted dataset and saves both models when training ends.
    """
    config = Config()
    global training_interrupt
    training_interrupt = False
    BATCH_SIZE = batch_size
    DATASET_PATH = unzip_dataset(zip_file)
    SEED = seed
    EPOCHS = epochs
    LR = lr
    # Image / classifier-head settings shared by both models.
    config.img_size = (int(img_width), int(img_height))
    config.fc_num_layers = int(fc_num_layers)
    # CNN Config
    config.in_channels = int(in_channels)
    config.conv_hidden_dim = int(conv_hidden_dim)
    config.dropout = dropout
    # Classical Config
    config.classical_downsample = int(classical_downsample)
    config.hog_orientations = int(hog_orientations)
    # HoG expects square (cell, cell) / (block, block) tuples.
    config.hog_pixels_per_cell = (int(hog_pixels_per_cell), int(hog_pixels_per_cell))
    config.hog_cells_per_block = (int(hog_cells_per_block), int(hog_cells_per_block))
    config.hog_block_norm = hog_block_norm
    config.canny_sigma = int(canny_sigma)
    config.canny_low = canny_low
    config.canny_high = canny_high
    config.gaussian_ksize = (int(gaussian_ksize), int(gaussian_ksize))
    config.gaussian_sigmaX = gaussian_sigmaX
    config.gaussian_sigmaY = gaussian_sigmaY
    config.harris_block_size = int(harris_block_size)
    config.harris_ksize = int(harris_ksize)
    config.harris_k = harris_k
    config.shi_max_corners = int(shi_max_corners)
    config.shi_quality_level = shi_quality_level
    config.shi_min_distance = int(shi_min_distance)
    config.lbp_P = int(lbp_P)
    config.lbp_R = int(lbp_R)
    config.gabor_ksize = int(gabor_ksize)
    config.gabor_sigma = int(gabor_sigma)
    config.gabor_theta = int(gabor_theta)
    config.gabor_lambda = int(gabor_lambda)
    config.gabor_gamma = gabor_gamma
    cnn_history_plots = []
    classical_history_plots = []
    cnn_plotted = False
    # FIX: pre-bind names referenced in except/finally so an early failure
    # (e.g. inside ImageDataset) cannot raise NameError there.
    cnn_text = classic_text = ""
    cnn_fig = classic_fig = None
    cnnmodel = classicmodel = None
    try:
        dataset = ImageDataset(DATASET_PATH, config.img_size)
        labels = [item['id'] for item in dataset.data]
        # Stratified 80/20 split keeps class proportions in both subsets.
        train_idx, validation_idx = train_test_split(np.arange(len(dataset)),
                                                     test_size=0.2,
                                                     random_state=SEED,
                                                     shuffle=True,
                                                     stratify=labels)
        train_dataset = Subset(dataset, train_idx)
        val_dataset = Subset(dataset, validation_idx)
        cnnbackbone = CNNFeatureExtractor(config).to(device)
        cnnmodel = Classifier(cnnbackbone, train_dataset.dataset.classes, config).to(device)
        classicbackbone = ClassicalFeatureExtractor(config)
        classicmodel = Classifier(classicbackbone, train_dataset.dataset.classes, config).to(device)
        for cnn_text, cnn_fig, classic_text, classic_fig, cnn_done in train(cnnmodel, classicmodel, train_dataset, val_dataset, BATCH_SIZE, LR, EPOCHS, device, visualize_every=vis_every):
            if cnn_done and not cnn_plotted:
                cnn_history_plots.append(plot([cnn_history['train_acc'], cnn_history['val_acc']], ['Training Accuracy', 'Validation Accuracy'], 'Epochs', 'Accuracy (%)', 'Training vs Validation Accuracy'))
                cnn_history_plots.append(plot([cnn_history['train_loss'], cnn_history['val_loss']], ['Training Loss', 'Validation Loss'], 'Epochs', 'Loss', 'Training vs Validation Loss'))
                # FIX: mark the CNN plots as rendered. Previously this flag was
                # never set, so the two plots were appended again on every
                # subsequent epoch of classical training, flooding the gallery
                # with duplicates.
                cnn_plotted = True

            yield cnn_text, cnn_fig, classic_text, classic_fig, cnn_history_plots, classical_history_plots
        classical_history_plots.append(plot([classic_history['train_acc'], classic_history['val_acc']], ['Training Accuracy', 'Validation Accuracy'], 'Epochs', 'Accuracy (%)', 'Training vs Validation Accuracy'))
        classical_history_plots.append(plot([classic_history['train_loss'], classic_history['val_loss']], ['Training Loss', 'Validation Loss'], 'Epochs', 'Loss', 'Training vs Validation Loss'))

        yield cnn_text, cnn_fig, classic_text, classic_fig, cnn_history_plots, classical_history_plots

    except RuntimeError as e:
        # Typically CUDA OOM or a shape mismatch; surface the partial state.
        print(e)
        yield cnn_text, cnn_fig, classic_text, classic_fig, cnn_history_plots, classical_history_plots
        return
    finally:
        if os.path.exists(DATASET_PATH):
            shutil.rmtree(DATASET_PATH)
            print(f"Temporary dataset folder '{DATASET_PATH}' removed.")

        # FIX: only save models that were actually built, and make sure the
        # output directory exists before writing to it.
        if cnnmodel is not None and classicmodel is not None:
            os.makedirs('trained_model', exist_ok=True)
            cnnmodel.save(os.path.join('trained_model', 'cnn_model.pt'))
            classicmodel.save(os.path.join('trained_model', 'classic_model.pt'))
255
+
256
+
257
# Static HTML shown on the "Introduction" tab (rendered via gr.HTML below).
intro_html = """
<div style="
border-left:6px solid #2563eb;
border-right:6px solid #2563eb;
padding:16px;
border-radius:8px;
font-size:16px;
line-height:1.6;
text-align: justify;
text-justify: inter-word;
">
<h1 style="margin-top:0;">Welcome to the Object Classifier Playground!</h1>
<p>
Object Classification is a field of computer vision where we train computer to learn to classify or identify what a model is.
In traditional Object Classification, the task usually consist of feature extraction and classification model.
For feature extraction there has been several methods of extracting a feature using certain algorithm. These algorithm consist of algorithm such as Corner Detection, Edge Detection, Local Binary Pattern (LBP) or even Histogram of Gradient (HoG).
There are a lot of means of feature extraction. After feature extraction, the feature will be passed to machine learning algorithm specifically classifier model.
One such model is the SVM, k Nearest Neighbor or Naive Bayes which will learn to distinguish object categories based on said features.
</p>
<p>
With the advancement of deep learning, object classification task has been significantly simplified. Now with deep learning, we barely use feature extraction algorithm anymore.
The reason is not because feature extraction has became obsolete in deep learning, instead the process itself has become part of the learning process. With deep learning, we use a model called Convolutional Neural Network (CNN).
A convolutional network consist of two main layers, the convolution layer and the fully connected layer. The convolution layer apply filter on the image with a filter usually called convolutional kernel where the value of each cells in the convolutional kernel is random initially.
</p>
<img src="https://raw.githubusercontent.com/VJyzCELERY/ClassicalObjectClassifier/refs/heads/main/assets/conv-illus.jpg"></img>
<p>
For more detail on how convolutional neural network work, you can refer to this <a href="https://viso.ai/deep-learning/convolution-operations/">link</a>.
</p>
<p>
In reality, what this convolution operation does is extract features to be processed on for a machine learning or another deep learning model. The Convolution by itself does not result in an object classification directly. So even deep learning model such as CNN
still does the traditional feature extraction then classification pipeline. However, the strength in this model is the convolution layer learns what weight it needs to use to get the best feature possible. Usually in a single convolution layer could result in tens or hundreds of feature channels.
</p>
<p>
In this program, although we will not discuss too deep about what traditional feature extraction is nor the fully inner workings of CNN, we will instead have a playground to demonstrate what feature extraction both perform and how they differ from
one and another.
</p>
<h2 style="margin-top:0;">The Model Architecture!</h2>
<p>
The model architecture used in this program will follow a CNN architecture where it will consist of Convolution layer and Fully Connected Layer as a classifier. However, we will instead make it so that the feature extraction layer or the convolution layer be replacable with a traditional feature extraction algorithm.
This is done because in theory they should be able to perform just as well or a little worse as it is basically what Convolution Layer does as convolution layer is able to extract a lot more features and trainable and specific features.
</p>
<p>
For more detail you can refer to : https://github.com/VJyzCELERY/ClassicalObjectClassifier which will include a paper to explain the code and it's method.
</p>

</div>
"""
304
+
305
with gr.Blocks(title="Object Classifier Playground") as demo:
    with gr.Tab("Introduction"):
        gr.HTML(intro_html)
    with gr.Tab("Training"):
        # --- Run-level settings: dataset upload and optimizer hyper-parameters ---
        with gr.Row():
            zip_file = gr.File(label='Upload Dataset in Zip',file_types=['.zip'],file_count='single',interactive=True)
            batch_size = gr.Number(value=32,label='Batch Size',interactive=True,precision=0)
            lr = gr.Number(value=1e-3,label='Learning Rate',interactive=True)
            epochs= gr.Number(value=20,label="Epochs",interactive=True,precision=0)
            seed=gr.Number(value=42,label='Seed',interactive=True,precision=0)
            vis_every=gr.Number(value=5,label='Visualize Validation Every (Epochs)',interactive=True,precision=0)
        # --- Shared image / classifier-head settings ---
        with gr.Row():
            img_width=gr.Number(value=128,label='Image Width',interactive=True,precision=0)
            img_height=gr.Number(value=128,label='Image Height',interactive=True,precision=0)
            fc_num_layers = gr.Number(value=3,label="Fully Connected Layer Depth",interactive=True,precision=0)
            dropout = gr.Slider(minimum=0,maximum=1,value=0.2,step=0.05,label='Fully Connected Layer Dropout',interactive=True)
        gr.Markdown("# CNN Feature Extractor Configuration")
        with gr.Accordion(label="CNN Settings",open=False):
            with gr.Row():
                in_channels = gr.Number(value=3,label='Input Color Channel Amount',interactive=True,precision=0)
                conv_hidden_dim = gr.Number(value=3,label='Conv Hidden Dim',interactive=True,precision=0)
        gr.Markdown("# Classical Feature Extractor Configuration")
        with gr.Accordion(label='Classical Feature Extractor Settings',open=False):
            with gr.Row():
                classical_downsample = gr.Number(value=1,label='Classical Extractor Downsampling Amount',interactive=True,precision=0)
            with gr.Row():
                hog_orientations = gr.Number(value=9,label='HoG Orientations',interactive=True,precision=0)
                hog_pixels_per_cell = gr.Number(value=16,label='HoG pixels per cell',interactive=True,precision=0)
                hog_cells_per_block = gr.Number(value=2,label='HoG cells per block',interactive=True,precision=0)
                hog_block_norm = gr.Dropdown(['L2-Hys'],value='L2-Hys',label='HoG Block Normalization Method',interactive=True)
            with gr.Row():
                canny_sigma = gr.Number(value=1.0,label='Canny Sigma Value',interactive=True)
                canny_low = gr.Number(value=100,label='Canny Low Threshold',interactive=True,precision=0)
                canny_high = gr.Number(value=200,label='Canny High Threshold',interactive=True,precision=0)
            with gr.Row():
                gaussian_ksize = gr.Number(value=3,label='Gaussian Kernel Size',interactive=True,precision=0)
                gaussian_sigmaX = gr.Number(value=1.0,label='Gaussian Sigma X Value',interactive=True)
                gaussian_sigmaY = gr.Number(value=1.0,label='Gaussian Sigma Y Value',interactive=True)
            with gr.Row():
                harris_block_size = gr.Number(value=2,label='Harris Corner Block Size',interactive=True,precision=0)
                harris_ksize = gr.Number(value=3,label='Harris Corner Kernel Size',interactive=True,precision=0)
                harris_k = gr.Slider(minimum=0.01, maximum=0.1, value=0.04, step=0.005, label='Harris Corner K value',interactive=True)
            with gr.Row():
                shi_max_corners = gr.Number(value=100,label='Shi-Tomasi Max Corners',interactive=True,precision=0)
                shi_quality_level = gr.Number(value=0.01,label='Shi-Tomasi Quality Level',interactive=True)
                shi_min_distance = gr.Number(value=10,label='Shi-Tomasi Min Distance',interactive=True,precision=0)
            with gr.Row():
                lbp_P = gr.Number(value=8,label='LBP P Value',interactive=True,precision=0)
                lbp_R = gr.Number(value=1,label='LBP R Value',interactive=True,precision=0)
            with gr.Row():
                gabor_ksize = gr.Number(value=21,label="Gabor Kernel Size",interactive=True,precision=0)
                gabor_sigma = gr.Number(value=5,label="Gabor Sigma",interactive=True,precision=0)
                gabor_theta = gr.Number(value=0,label="Gabor Theta",interactive=True,precision=0)
                gabor_lambda = gr.Number(value=10,label="Gabor Lambda",interactive=True,precision=0)
                gabor_gamma = gr.Number(value=0.5,label="Gabor Gamma",interactive=True)
        with gr.Column():
            train_btn = gr.Button("Train Model",variant='secondary',interactive=True)
            stop_btn = gr.Button("Stop Training")

        # --- Live output panels, one per model ---
        with gr.Column():
            with gr.Column():
                gr.Markdown("### CNN Training Log")
                cnn_log = gr.Textbox(label="CNN Log", interactive=False)
                cnn_fig = gr.Gallery(label="CNN Batch Visualization",interactive=False,object_fit='fill',columns=1)
                cnn_plots = gr.Gallery(label="CNN Training Performance",interactive=False,object_fit='fill',columns=1)
            with gr.Column():
                gr.Markdown("### Classical Training Log")
                classical_log = gr.Textbox(label="Classical Log", interactive=False)
                classical_fig = gr.Gallery(label="Classical Batch Visualization",interactive=False,object_fit='fill',columns=1)
                # FIX: this gallery was mislabeled "CNN Training Performance"
                # (copy-paste from the CNN column).
                classical_plots = gr.Gallery(label="Classical Training Performance",interactive=False,object_fit='fill',columns=1)
    stop_btn.click(fn=stop_training, inputs=[], outputs=[])
    # train_model is a generator, so outputs stream as training progresses.
    train_btn.click(fn=train_model,
                    inputs=[zip_file,batch_size,lr,epochs,seed,vis_every,
                            img_width,img_height,fc_num_layers,
                            in_channels,conv_hidden_dim,dropout,
                            classical_downsample,
                            hog_orientations,hog_pixels_per_cell,hog_cells_per_block,hog_block_norm,
                            canny_sigma,canny_low,canny_high,
                            gaussian_ksize,gaussian_sigmaX,gaussian_sigmaY,
                            harris_block_size,harris_ksize,harris_k,
                            shi_max_corners,shi_quality_level,shi_min_distance,
                            lbp_P,lbp_R,
                            gabor_ksize,gabor_sigma,gabor_theta,gabor_lambda,gabor_gamma],
                    outputs=[cnn_log,cnn_fig,classical_log,classical_fig,cnn_plots,classical_plots]
                    )
390
+ def make_figure_from_image(img):
391
+ fig, ax = plt.subplots(figsize=(8,8))
392
+ ax.imshow(img)
393
+ ax.axis("off")
394
+
395
+ plt.show()
396
+
397
+ return fig
398
+ def predict_image(upload,show_original,max_channels):
399
+ img = cv2.cvtColor(cv2.imread(upload),cv2.COLOR_BGR2RGB)
400
+ model_base_path = "./trained_model"
401
+ classic_model_path =os.path.join(model_base_path,'classic_model.pt')
402
+ cnn_model_path = os.path.join(model_base_path,'cnn_model.pt')
403
+ os.makedirs(model_base_path,exist_ok=True)
404
+ if os.path.exists(classic_model_path):
405
+ classic_model = load(classic_model_path,ClassicalFeatureExtractor,device=device)
406
+ else:
407
+ return "No Classical Model trained",None,None,None
408
+ if os.path.exists(cnn_model_path):
409
+ cnn_model = load(cnn_model_path,CNNFeatureExtractor,device=device)
410
+ else:
411
+ return "No CNN Model trained",None,None,None
412
+ cnn_predict = cnn_model.predict(img)
413
+ classic_predict = classic_model.predict(img)
414
+ cnn_features = cnn_model.visualize_feature(img,max_channels=max_channels)
415
+ classical_features = classic_model.visualize_feature(img,show_original=show_original)
416
+ return None,make_figure_from_image(img),cnn_predict,classic_predict,cnn_features,classical_features
417
+
418
    # --- Inference tab: upload an image, run both saved models, show features ---
    with gr.Tab("Inference"):
        with gr.Row():
            image_upload = gr.File(file_count='single',file_types=['image'],label='Upload Image to Infer',interactive=True)
            with gr.Column():
                gr.Markdown("# CNN Settings")
                with gr.Accordion(open=False):
                    cnn_max_channel_visual = gr.Number(value=8,precision=0,label='Max CNN Channels to Preview',interactive=True)
            with gr.Column():
                gr.Markdown("# Classical Settings")
                with gr.Accordion(open=False):
                    classic_show_original = gr.Checkbox(value=True,label='Show Original Image as Features')
        with gr.Column():
            gr.Markdown("# Predictions")
            # verbose carries the "No ... Model trained" message from predict_image.
            verbose = gr.Markdown()
            image_preview = gr.Plot(value=None,label="Input Image")
            cnn_features = gr.Gallery(label='CNN Extracted Features',columns=1,object_fit='fill',interactive=False)
            classical_features = gr.Gallery(label='Classical Extracted Features',columns=1,object_fit='fill',interactive=False)
            cnn_prediction=gr.Textbox(interactive=False,value='No Predictions',label='CNN Predictions')
            classical_prediction=gr.Textbox(interactive=False,value='No Predictions',label='Classical Model Predictions')
            prediction_btn = gr.Button('Predict',variant='primary')

        prediction_btn.click(
            fn=predict_image,
            inputs=[image_upload,classic_show_original,cnn_max_channel_visual],
            outputs=[verbose,image_preview,cnn_prediction,classical_prediction,cnn_features,classical_features]
        )
444
+
445
+
446
# Launch the Gradio app locally when run as a script (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio==5.5.0
2
+ pydantic==2.10.6
3
+ torch
4
+ torchvision
5
+ opencv-python
6
+ scikit-learn
7
+ matplotlib
8
+ scikit-image
src/__pycache__/dataloader.cpython-312.pyc ADDED
Binary file (2.57 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (22.9 kB). View file
 
src/__pycache__/trainer.cpython-312.pyc ADDED
Binary file (9.76 kB). View file
 
src/dataloader.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import torch
3
+ import os
4
+ import numpy as np
5
+ import cv2
6
+
7
def collate_fn(batch):
    """Collate (image, label) pairs: images stay a plain Python list
    (so variable shapes are allowed), labels are stacked into a tensor."""
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch])
    return images, labels
11
+
12
+
13
class ImageDataset(Dataset):
    """Folder-per-class image dataset.

    Expects ``root_path/<class_name>/<image files>``. Class ids are assigned
    by sorted folder name so the name -> id mapping is deterministic.
    """

    def __init__(self,root_path : str,img_size=(256,256)):
        # FIX: only directories are classes, and they are sorted — raw
        # os.listdir() order is filesystem-dependent (non-deterministic label
        # ids across machines) and would also treat stray files as classes.
        classes = sorted(
            entry for entry in os.listdir(root_path)
            if os.path.isdir(os.path.join(root_path, entry))
        )
        self.img_size = img_size
        self.classes = classes
        data = []
        for idx,class_name in enumerate(classes):
            class_path = os.path.join(root_path,class_name)
            # Sort files too, for a reproducible sample order.
            for file in sorted(os.listdir(class_path)):
                filepath = os.path.join(class_path,file)
                data.append({"image_path":filepath,"label":class_name,"id":idx})
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        """Return (H x W x 3 float32 RGB image scaled to [0, 1], integer class id)."""
        curr = self.data[idx]
        label = curr['id']
        img = cv2.imread(curr['image_path'])
        img = cv2.resize(img, self.img_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        return img, label
39
+
40
+
src/model.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import cv2
4
+ import numpy as np
5
+ from dataclasses import dataclass
6
+ from skimage.feature import hog,local_binary_pattern
7
+ import matplotlib.pyplot as plt
8
+ import os
9
+ import io
10
+ from PIL import Image
11
+
12
@dataclass
class Config:
    """Hyper-parameters shared by the CNN and classical feature extractors.

    FIX: fields are now type-annotated. Without annotations, @dataclass
    registered zero fields, so the decorator was inert (no generated
    __init__ fields, Config(in_channels=...) impossible) and every value was
    a plain class attribute. Defaults are unchanged; attribute access and
    assignment behave as before, so existing callers are unaffected.
    """
    img_size: tuple = (256, 256)          # (width, height) fed to cv2.resize
    in_channels: int = 3                  # input color channels for the CNN
    fc_num_layers: int = 3                # classifier-head depth
    conv_hidden_dim: int = 3              # number of conv blocks
    conv_kernel_size: int = 3
    dropout: float = 0.2
    classical_downsample: int = 1         # pyrDown passes before extraction

    # HOG
    hog_orientations: int = 9
    hog_pixels_per_cell: tuple = (16, 16)
    hog_cells_per_block: tuple = (2, 2)
    hog_block_norm: str = 'L2-Hys'

    # Canny
    canny_sigma: float = 1.0
    canny_low: int = 100
    canny_high: int = 200

    # Gaussian
    gaussian_ksize: tuple = (3, 3)
    gaussian_sigmaX: float = 1.0
    gaussian_sigmaY: float = 1.0

    # Harris corners
    harris_block_size: int = 2
    harris_ksize: int = 3
    harris_k: float = 0.04

    # Shi-Tomasi corners
    shi_max_corners: int = 100
    shi_quality_level: float = 0.01
    shi_min_distance: int = 10

    # LBP
    lbp_P: int = 8
    lbp_R: int = 1

    # Gabor filters
    gabor_ksize: int = 21
    gabor_sigma: int = 5
    gabor_theta: int = 0
    gabor_lambda: int = 10
    gabor_gamma: float = 0.5
57
+
58
class CNNFeatureExtractor(nn.Module):
    """Stack of Conv -> BatchNorm -> ReLU -> MaxPool blocks used as the learned
    feature backbone; channel width starts at 32 and doubles per block."""
    def __init__(self,config : Config):
        super().__init__()
        layers = []
        self.in_channels = config.in_channels
        in_channel = config.in_channels
        self.img_size = config.img_size
        out_channel = 32
        # conv_hidden_dim blocks; each MaxPool halves the spatial size.
        for i in range(config.conv_hidden_dim):
            layers.append(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=config.conv_kernel_size,stride=1,padding=1))
            layers.append(nn.BatchNorm2d(out_channel))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2))
            in_channel=out_channel
            out_channel*=2
        self.layers = nn.Sequential(*layers)
    def get_device(self):
        # Device of the module's parameters (assumes all on one device).
        return next(self.parameters()).device
    def forward(self,x):
        """Run the conv stack; accepts lists of HWC arrays, numpy arrays
        (2-D/3-D/4-D) or tensors and normalizes them to a (B,C,H,W) batch."""
        if isinstance(x, list):
            if isinstance(x[0], np.ndarray):
                x = np.stack(x, axis=0)
        if isinstance(x,np.ndarray):
            if len(x.shape) == 2:
                # NOTE(review): this 2-D path looks inconsistent — after
                # expand_dims the array is 4-D, but transpose(2, 0, 1) gives
                # only three axes; verify against a grayscale input.
                x = x[:, :, None]
                x = np.expand_dims(x, 0)
                x = x.transpose(2, 0, 1)
            elif len(x.shape) == 3:
                # Single HWC image -> CHW -> add batch axis.
                x = x.transpose(2, 0, 1)
                x = np.expand_dims(x, 0)
            elif x.ndim == 4:
                x = x.transpose(0, 3, 1, 2) # Change to (B,C,H,W)
            x = torch.from_numpy(x).float()
        elif isinstance(x, torch.Tensor):
            if x.ndim == 3:
                x = x.unsqueeze(0)
        x=x.to(self.get_device())
        return self.layers(x) # Always expects (B,C,H,W)
    def output(self):
        """Run a zero dummy batch through the stack; used to discover the
        backbone's output shape (e.g. to size the classifier head)."""
        self.eval()

        with torch.no_grad():
            x = torch.zeros(
                (1, self.in_channels, self.img_size[1], self.img_size[0]),
                device=self.get_device()
            )

            out = self(x)

        return out
    def visualize(self, input_image, max_channels=8,show=True):
        """Return one rendered activation grid (RGB array) per Conv2d layer.

        Runs a forward pass per conv layer with a temporary hook, draws up to
        max_channels activation maps, and optionally shows each figure.
        """
        self.eval()
        device = self.get_device()

        if isinstance(input_image, np.ndarray):
            x = torch.from_numpy(input_image).permute(2, 0, 1).float().unsqueeze(0).to(device) # HWC -> CHW -> B
        elif isinstance(input_image, torch.Tensor):
            x = input_image.unsqueeze(0).to(device) if input_image.ndim == 3 else input_image.to(device)
        else:
            raise TypeError("input_image must be np.ndarray or torch.Tensor")

        conv_layers = [(name, module) for name, module in self.named_modules() if isinstance(module, nn.Conv2d)]
        all_layer_images = []

        for name, layer in conv_layers:
            activations = []

            # Capture this layer's output during the forward pass.
            def hook_fn(module, input, output):
                activations.append(output.cpu().detach())

            handle = layer.register_forward_hook(hook_fn)
            _ = self(x)
            handle.remove()

            # First (only) batch element of the captured activation.
            act = activations[0][0]
            num_channels = min(act.shape[0], max_channels)

            fig, axes = plt.subplots(1, num_channels, figsize=(3*num_channels, 3))
            if num_channels == 1:
                axes = [axes]

            for i in range(num_channels):
                axes[i].imshow(act[i], cmap='gray')
                axes[i].axis('off')

            fig.suptitle(f'Layer: {name}', fontsize=14)
            if show:
                plt.show()

            # Rasterize the figure to an RGB array for gallery display.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            img = Image.open(buf).convert("RGB")
            all_layer_images.append(np.array(img))
            plt.close(fig)
        return all_layer_images
154
+
155
class ClassicalFeatureExtractor(nn.Module):
    """Hand-crafted feature extractor (HoG, Canny, Harris, Shi-Tomasi, LBP,
    Gabor) producing image-sized feature maps.

    Subclasses nn.Module only so it can be swapped in where the CNN backbone
    is expected; it defines no trainable parameters of its own.
    """
    def __init__(self, config : Config):
        super().__init__()
        self.img_size = config.img_size # (H, W) — NOTE(review): the app sets config.img_size as (width, height); verify which order is intended
        self.hog_orientations = config.hog_orientations
        # Number of pyrDown halvings applied before feature extraction.
        self.num_downsample = config.classical_downsample
        self.config = config
        # Display names, one per feature map produced by extract_features.
        self.feature_names = ['HoG','Canny Edge','Harris Corner','Shi-Tomasi corners','LBP','Gabor Filters']
        # Fallback device reported when the module has no parameters.
        self.device = 'cpu'

    def get_device(self):
        # Parameter-less module: report the fallback device instead of
        # next(parameters()), which would raise StopIteration.
        return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device
167
+
168
+
169
+ def extract_features(self, img):
170
+ cfg = self.config
171
+
172
+ # Convert to grayscale
173
+ min_h = cfg.hog_pixels_per_cell[0] * cfg.hog_cells_per_block[0]
174
+ min_w = cfg.hog_pixels_per_cell[1] * cfg.hog_cells_per_block[1]
175
+ gray = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
176
+
177
+ for _ in range(self.num_downsample):
178
+ h, w = gray.shape
179
+ if h <= min_h or w <= min_w:
180
+ break
181
+ gray = cv2.pyrDown(gray)
182
+
183
+ gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)
184
+
185
+ feature_list = []
186
+
187
+ # 1. HOG
188
+ _, hog_image = hog(
189
+ gray,
190
+ orientations=cfg.hog_orientations,
191
+ pixels_per_cell=cfg.hog_pixels_per_cell,
192
+ cells_per_block=cfg.hog_cells_per_block,
193
+ block_norm=cfg.hog_block_norm,
194
+ visualize=True
195
+ )
196
+ feature_list.append(hog_image)
197
+
198
+ # 2. Canny edges
199
+ edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
200
+ feature_list.append(edges)
201
+
202
+ # 3. Harris corners
203
+ harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
204
+ harris = cv2.dilate(harris, None)
205
+ harris = np.clip(harris, 0, 1)
206
+ feature_list.append(harris)
207
+
208
+ # 4. Shi-Tomasi corners
209
+ shi_corners = np.zeros_like(gray, dtype=np.float32)
210
+ keypoints = cv2.goodFeaturesToTrack(gray, maxCorners=cfg.shi_max_corners, qualityLevel=cfg.shi_quality_level, minDistance=cfg.shi_min_distance)
211
+ if keypoints is not None:
212
+ for kp in keypoints:
213
+ x, y = kp.ravel()
214
+ shi_corners[int(y), int(x)] = 1.0
215
+ feature_list.append(shi_corners)
216
+
217
+ # 5. LBP
218
+ lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
219
+ lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
220
+ feature_list.append(lbp)
221
+
222
+ # 6. Gabor filter
223
+ g_kernel = cv2.getGaborKernel((cfg.gabor_ksize, cfg.gabor_ksize), cfg.gabor_sigma, cfg.gabor_theta, cfg.gabor_lambda, cfg.gabor_gamma)
224
+ gabor_feat = cv2.filter2D(gray, cv2.CV_32F, g_kernel)
225
+ gabor_feat = (gabor_feat - gabor_feat.min()) / (gabor_feat.max() - gabor_feat.min() + 1e-8)
226
+ feature_list.append(gabor_feat)
227
+
228
+ # Stack all features along channel axis
229
+ features = np.stack(feature_list, axis=2)
230
+ return features.astype(np.float32)
231
+
232
+
233
+ def forward(self, x):
234
+ if isinstance(x, torch.Tensor):
235
+ x = x.cpu().numpy()
236
+ if isinstance(x, np.ndarray):
237
+ if x.ndim == 3:
238
+ x = np.expand_dims(x, 0)
239
+ elif x.ndim != 4:
240
+ raise ValueError(f"Expected input of shape HWC or BHWC, got {x.shape}")
241
+ elif isinstance(x, list):
242
+ x = np.stack(x, axis=0)
243
+
244
+ batch_features = []
245
+ for img in x:
246
+ if img.ndim != 3 or img.shape[2] != 3:
247
+ img = np.repeat(img[:, :, None], 3, axis=2)
248
+ feat = self.extract_features(img)
249
+ batch_features.append(feat)
250
+ batch_features = np.stack(batch_features, axis=0)
251
+ return torch.from_numpy(batch_features).float().to(self.get_device())
252
+
253
+ def visualize(self, img, show_original=True,show=True):
254
+ if img.ndim != 3 or img.shape[2] != 3:
255
+ img = np.repeat(img[:, :, None], 3, axis=2)
256
+
257
+ feature_stack = self.extract_features(img)
258
+ num_channels = feature_stack.shape[2]
259
+
260
+ outputs = []
261
+
262
+ def fig_to_pil(fig):
263
+ buf = io.BytesIO()
264
+ fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
265
+ buf.seek(0)
266
+
267
+ pil_img = Image.open(buf).copy()
268
+
269
+ buf.close()
270
+ plt.close(fig)
271
+
272
+ return pil_img
273
+
274
+ if show_original:
275
+ fig = plt.figure(figsize=(4, 4))
276
+ plt.imshow(img)
277
+ plt.title("Original")
278
+ plt.axis("off")
279
+ if show:
280
+ plt.show()
281
+ outputs.append(fig_to_pil(fig))
282
+
283
+ for c in range(num_channels):
284
+ fig = plt.figure(figsize=(4, 4))
285
+
286
+ plt.imshow(feature_stack[:, :, c], cmap="gray")
287
+ plt.title(f"Feature {self.feature_names[c]}")
288
+ plt.axis("off")
289
+ if show:
290
+ plt.show()
291
+ outputs.append(fig_to_pil(fig))
292
+
293
+ return outputs
294
+
295
+
296
+ def output(self):
297
+ """Return dummy output to compute in_features for FC head"""
298
+ dummy_img = np.zeros((1, self.img_size[1],self.img_size[0], 3), dtype=np.float32)
299
+ feat = self.forward(dummy_img)
300
+ return feat
301
+
302
+
303
+
304
class FullyConnectedHead(nn.Module):
    """MLP classification head.

    Builds up to ``config.fc_num_layers`` blocks of
    Linear -> BatchNorm1d -> ReLU -> Dropout with widths halving from 256,
    stopping early once the next width would not exceed the class count,
    then projects to ``len(classes)`` logits.
    """

    def __init__(self, in_features, classes, config: Config):
        super().__init__()
        self.classes = classes
        n_classes = len(classes)

        blocks = []
        width_in, width_out = in_features, 256
        for _ in range(config.fc_num_layers):
            blocks.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(config.dropout),
            ])
            width_in, width_out = width_out, width_out // 2
            # Stop shrinking once the next hidden width would not exceed
            # the number of classes.
            if width_out <= n_classes:
                break
        blocks.append(nn.Linear(width_in, n_classes))
        self.layers = nn.Sequential(*blocks)

    def get_device(self):
        """Device of this head's parameters."""
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor):
        """Move the input onto this head's device and produce class logits."""
        return self.layers(x.to(self.get_device()))
327
+
328
class Classifier(nn.Module):
    """Image classifier = feature backbone + flatten + fully-connected head.

    The head's input width is inferred at construction time by running the
    backbone's ``output()`` dummy pass and flattening the result.
    """

    def __init__(self, backbone, classes, config: Config):
        super().__init__()
        self.config = config
        self.classes = classes
        self.backbone = backbone
        self.flatten = nn.Flatten()
        # Infer in_features for the head from a dummy backbone forward pass.
        feat = backbone.output()
        flat = self.flatten(feat)
        in_features = flat.shape[1]
        self.head = FullyConnectedHead(in_features, classes, config)

    def get_device(self):
        """Device of this model's parameters."""
        return next(self.parameters()).device

    @torch.no_grad()
    def predict(self, x):
        """Resize a single image to the configured size and return the
        predicted class name (string from ``self.classes``)."""
        self.eval()
        target_size = self.config.img_size
        # cv2.resize expects (width, height); img_size appears to already be
        # in that order given how the backbone's output() indexes it -- TODO confirm.
        x = cv2.resize(x, target_size)
        logits = self.forward(x)
        probs = torch.softmax(logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()

        return self.classes[pred_idx]

    def forward(self, x):
        """Backbone features -> flatten -> head logits."""
        feat = self.backbone(x)
        feat = self.flatten(feat)
        return self.head(feat)

    def visualize_feature(self, img, return_img=True, **kwargs):
        """Resize ``img`` and delegate feature visualization to the backbone."""
        target_size = self.config.img_size
        img = cv2.resize(img, target_size)
        if return_img:
            return self.backbone.visualize(img, **kwargs)
        else:
            self.backbone.visualize(img, **kwargs)

    def save(self, path: str):
        """Persist weights, class names and config to ``path``.

        Bug fix: ``os.makedirs('')`` raises when ``path`` has no directory
        component (e.g. ``save("model.pt")``), so only create the parent
        directory when one is actually present.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({
            'model_state_dict': self.state_dict(),
            'classes': self.classes,
            'config': self.config
        }, path)
        print(f"Model saved to {path}")

    @staticmethod
    def load(path: str, backbone_class, device='cpu'):
        """Rebuild a Classifier from a checkpoint written by ``save``.

        SECURITY: ``weights_only=False`` unpickles arbitrary objects (needed
        here to restore the pickled ``Config``); only load checkpoints from
        trusted sources.
        """
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        config = checkpoint['config']
        classes = checkpoint['classes']
        backbone = backbone_class(config).to(device)
        model = Classifier(backbone, classes, config).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"Model loaded from {path}")
        return model
src/trainer.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.model import Classifier
2
+ from src.dataloader import ImageDataset,collate_fn
3
+ from torch.utils.data import DataLoader
4
+ import torch.optim as optim
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+ import matplotlib.pyplot as plt
8
+ import torch
9
+ import random
10
+ import numpy as np
11
+ import torch.nn as nn
12
+ import time
13
+
14
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: derive this worker's seed from torch's
    initial seed and apply it to both numpy and the stdlib ``random`` module
    so per-worker augmentation randomness is reproducible."""
    seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(seed)
    random.seed(seed)
18
+
19
class ModelTrainer:
    """Train/validate a ``Classifier`` on ``ImageDataset`` splits.

    ``train`` is a generator so a UI (e.g. Gradio) can consume per-epoch
    metrics and matplotlib figures as they are produced.
    """

    def __init__(self,model : Classifier,train_set : ImageDataset,val_set : ImageDataset = None, batch_size=32,lr = 1e-3,device='cpu',return_fig=False, seed=None):
        # Optional seeded generator makes training-loader shuffling reproducible.
        g = torch.Generator()
        if seed is not None:
            g.manual_seed(seed)

        self.train_loader = DataLoader(
            train_set,
            batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker,
            generator=g
        )

        self.device = device

        if val_set is not None:
            self.val_loader = DataLoader(
                val_set,
                batch_size,
                shuffle=False,
                collate_fn=collate_fn,
                worker_init_fn=seed_worker
            )
        else:
            self.val_loader = None
        self.class_names = model.classes
        self.model = model
        self.lr = lr
        self.optim = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
        self.optim.zero_grad()
        self.criterion = nn.CrossEntropyLoss()
        self.return_fig=return_fig

    def visualize_batch(self, imgs, preds, labels, class_names=None, max_samples=4):
        """Render up to ``max_samples`` predictions from a batch plus the
        backbone's feature maps for one sampled image.

        Returns a list of matplotlib figures when ``self.return_fig`` is
        True, otherwise shows them and returns None.
        """
        # Keep the untransformed batch around so one raw image can be fed to
        # the feature visualizer below.
        first_image = imgs
        if isinstance(imgs, list):
            imgs = np.stack(imgs, axis=0)
            # HWC -> CHW to match the tensor layout used for display below.
            imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2).float()

        imgs_np = imgs.cpu().numpy()
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()

        batch_size = imgs_np.shape[0]
        # Random subset of the batch to display.
        indices = random.sample(range(batch_size), min(max_samples, batch_size))
        first_image = first_image[indices[0]]
        fig_pred = plt.figure(figsize=(6 * len(indices), 5))
        grid = fig_pred.add_gridspec(1, len(indices))

        for col, idx in enumerate(indices):
            ax = fig_pred.add_subplot(grid[0, col])
            # CHW -> HWC for imshow.
            ax.imshow(imgs_np[idx].transpose(1, 2, 0))

            if class_names:
                title = f"P: {class_names[preds[idx]]} | T: {class_names[labels[idx]]}"
            else:
                title = f"P: {preds[idx]} | T: {labels[idx]}"

            ax.set_title(title)
            ax.axis("off")

        fig_pred.tight_layout()
        raw_features = self.model.visualize_feature(first_image,show=False)
        feature_figs = []

        # visualize_feature may yield matplotlib figures or PIL/ndarray
        # images depending on the backbone; normalize everything to figures.
        for f in raw_features:

            if isinstance(f, plt.Figure):
                feature_figs.append(f)
                continue

            # PIL images expose .mode; convert them to ndarrays first.
            if hasattr(f, "mode"):
                f = np.array(f)
            h, w = f.shape[:2]

            dpi = 100
            fig_w = max(4, w / dpi)
            fig_h = max(4, h / dpi)
            fig = plt.figure(figsize=(fig_w, fig_h), dpi=dpi)
            ax = fig.add_subplot(111)
            ax.imshow(f)
            ax.axis("off")
            feature_figs.append(fig)


        all_figs = [fig_pred] + feature_figs
        if not self.return_fig:
            plt.show()
            # NOTE(review): only fig_pred is closed here; feature figures
            # stay open until matplotlib reclaims them.
            plt.close(fig_pred)
        if self.return_fig:
            return all_figs
        else:
            return None


    def train_one_epoch(self):
        """Run one optimization pass over the training loader.

        Returns (average loss, accuracy) for the epoch.
        """
        self.model.train()
        total_loss = 0
        train_pbar = tqdm(self.train_loader, desc="Training",leave=False)
        correct = 0
        total = 0
        for imgs, labels in train_pbar:
            # Images are left on CPU; the model/backbone handles any
            # conversion and device placement itself.
            labels = labels.to(self.device)

            # Forward
            outputs = self.model(imgs)
            loss = self.criterion(outputs, labels)

            # Backward
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            total_loss += loss.item()
            train_pbar.set_postfix(acc=correct/total,loss=loss.item())

        avg_loss = total_loss / len(self.train_loader)
        avg_acc = correct / total
        return avg_loss,avg_acc
    def train(self, epochs=10, visualize_every=5):
        """Generator: yields (train_loss, train_acc, val_loss, val_acc, fig)
        per epoch, then a final (losses, accs, val_losses, val_accs, None)
        summary tuple of per-epoch history lists.

        Validation figures are produced on epoch 1 and every
        ``visualize_every`` epochs.
        """
        train_losses=[]
        train_accuracies=[]
        val_losses=[]
        val_accuracies=[]
        for epoch in range(1, epochs + 1):
            train_loss,train_acc = self.train_one_epoch()
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)
            if self.val_loader is not None:
                val_loss,val_acc,fig=self.validate(epoch, visualize=(epoch % visualize_every == 0 or epoch == 1))
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f} | Val Loss : {val_loss:.4f} | Val Acc : {val_acc:.4f}")
                yield train_loss,train_acc,val_loss,val_acc,fig
            else:
                print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc : {train_acc:.4f}")
                yield train_loss,train_acc,None,None,None
        yield train_losses,train_accuracies,val_losses,val_accuracies,None

    def validate(self,epoch, visualize=False):
        """Evaluate on the validation loader.

        Returns (average loss, accuracy, figure-or-None); returns None
        outright when no validation loader was configured.
        """
        if self.val_loader is None:
            return

        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        # First batch is cached for optional visualization after the loop.
        val_imgs_display = None
        val_preds_display = None
        val_labels_display = None

        val_pbar = tqdm(self.val_loader, desc="Validation",leave=False)
        fig = None
        with torch.no_grad():
            for imgs, labels in val_pbar:
                labels = labels.to(self.device)

                outputs = self.model(imgs)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()

                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                if visualize and val_imgs_display is None:
                    val_imgs_display = imgs
                    val_preds_display = preds
                    val_labels_display = labels

                val_pbar.set_postfix(loss=loss.item(), acc=correct / total)

        avg_loss = total_loss / len(self.val_loader)
        acc = correct / total

        if visualize and val_imgs_display is not None:
            fig = self.visualize_batch(val_imgs_display, val_preds_display, val_labels_display, self.class_names)

        return avg_loss,acc,fig