Shad0wKillar committed
Commit 59cf609 · verified · 1 Parent(s): a9867f4

Initial push

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+EfficientNet_B7_20percent.pth_wrong_pred.png filter=lfs diff=lfs merge=lfs -text
EfficientNet_B7_20percent.pth_confusion_matrix.png ADDED
EfficientNet_B7_20percent.pth_curves.png ADDED
EfficientNet_B7_20percent.pth_wrong_pred.png ADDED

Git LFS Details (EfficientNet_B7_20percent.pth_wrong_pred.png)

  • SHA256: 4c64095b381b9955b48c237dc4642ee452824a11917bf3a5159909e99a2b84ba
  • Pointer size: 131 bytes
  • Size of remote file: 258 kB
efficient_b7_20_percent.py ADDED
@@ -0,0 +1,432 @@
import torch
import torchvision
import torchinfo

import typing
import requests
import os
import zipfile
import mlxtend.plotting
import torchmetrics
from pathlib import Path
from timeit import default_timer as timer
from tqdm.auto import tqdm
import matplotlib

matplotlib.use("TkAgg")
from matplotlib import pyplot as plt


device = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_MODEL = False
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCH = 10
MODEL_PATH = Path("models")
MODEL_NAME = "EfficientNet_B7_20percent.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
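# NOTE: TRAIN_MODEL toggles the script's two modes: True runs the fine-tuning
# loop and saves the weights to MODEL_SAVE_PATH, while False skips training and
# loads the previously saved state_dict so the evaluation/plotting code below
# can be rerun without retraining.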

# Downloading the data here
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi_20_percent"

# If the image folder doesn't exist, download it and prepare it...
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download pizza, steak, sushi data
    with open(data_path / "pizza_steak_sushi_20_percent.zip", "wb") as f:
        request = requests.get(
            "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi_20_percent.zip"
        )
        print("Downloading pizza, steak, sushi data...")
        f.write(request.content)

    # Unzip pizza, steak, sushi data
    with zipfile.ZipFile(
        data_path / "pizza_steak_sushi_20_percent.zip", "r"
    ) as zip_ref:
        print("Unzipping pizza, steak, sushi data...")
        zip_ref.extractall(image_path)

    # Remove .zip file
    os.remove(data_path / "pizza_steak_sushi_20_percent.zip")
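# The archive extracts into train/ and test/ directories with one sub-folder
# per class (pizza, steak, sushi), which is the layout ImageFolder expects.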
train_dir = image_path / "train"
test_dir = image_path / "test"

manual_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
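# The mean/std above are the standard ImageNet statistics, matching the
# preprocessing the pretrained EfficientNet backbone was trained with.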

def create_dataloaders(
    train_dir: Path,
    test_dir: Path,
    batch_size: int,
    num_workers: int,
    transform: torchvision.transforms.Compose,
) -> tuple[
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    list[str],
    torchvision.datasets.ImageFolder,
    torchvision.datasets.ImageFolder,
]:
    train_data = torchvision.datasets.ImageFolder(
        train_dir,
        transform=transform,
    )
    test_data = torchvision.datasets.ImageFolder(
        test_dir,
        transform=transform,
    )

    class_names = train_data.classes

    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    test_dataloader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )

    return (
        train_dataloader,
        test_dataloader,
        class_names,
        train_data,
        test_data,
    )


(
    train_dataloader_manual_transform,
    test_dataloader_manual_transform,
    class_names_manual_transform,
    train_data,
    test_data,
) = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    num_workers=os.cpu_count() or 0,
    batch_size=BATCH_SIZE,
    transform=manual_transform,
)
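# The manual-transform loaders above are kept for comparison; note that
# train_data/test_data are rebound by the auto-transform call below, so the
# evaluation code later in the script sees the auto-transformed datasets.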
weights = torchvision.models.EfficientNet_B7_Weights.DEFAULT

auto_transform = weights.transforms()
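# weights.transforms() returns the inference preprocessing bundled with the
# checkpoint; for EfficientNet-B7 this uses a larger input size than the
# 224x224 manual transform (600x600 in current torchvision releases), so it is
# worth printing auto_transform if the exact input size matters.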
(
    train_dataloader,
    test_dataloader,
    class_names,
    train_data,
    test_data,
) = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    batch_size=BATCH_SIZE,
    num_workers=os.cpu_count() or 0,
    transform=auto_transform,
)

model = torchvision.models.efficientnet_b7(weights=weights).to(device)

torchinfo.summary(
    model=model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
)

for feature in model.features:
    print(feature)

for param in model.features.parameters():
    param.requires_grad = False
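# Freezing model.features turns this into feature extraction: only the
# classifier head replaced below receives gradient updates, which keeps
# training cheap on the small 20% subset.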
print(f"Classifier part has (before changing):\n{model.classifier}")

torch.manual_seed(37)
torch.cuda.manual_seed(37)
output_shape = len(class_names)
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=2560, out_features=output_shape, bias=True),
)

print(f"Classifier part has (after changing):\n{model.classifier}")

torchinfo.summary(
    model=model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
)

loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
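# Although model.parameters() is passed in full, the frozen backbone parameters
# have requires_grad=False and never receive gradients, so Adam effectively
# updates only the new classifier head.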


class Engine:
    def __init__(
        self,
        train_dataloader: torch.utils.data.DataLoader,
        test_dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        loss_fn: torch.nn.Module,
        device: typing.Literal["cuda", "cpu"],
        num_epoch: int,
    ):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.optim = optim
        self.loss_fn = loss_fn
        self.device = device
        self.num_epoch = num_epoch
        self.model = model.to(device)

    def _train_step(self) -> tuple[float, float]:
        self.model.train()
        loss_train = 0
        acc_train = 0

        for batch, (X, y) in enumerate(self.train_dataloader):
            X, y = X.to(self.device), y.to(self.device)

            train_pred = self.model(X)
            loss = self.loss_fn(train_pred, y)

            loss_train += loss.item()

            # Use the optimizer handed to the engine rather than the global one
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            pred_class = torch.argmax(torch.softmax(train_pred, dim=1), dim=1)
            acc = (pred_class == y).sum().item() / len(pred_class)

            acc_train += acc

            if batch % 2 == 0:
                print(f"{batch} batches have been processed...")

        loss_train = loss_train / len(self.train_dataloader)
        acc_train = acc_train / len(self.train_dataloader)

        return loss_train, acc_train

    def _test_step(self) -> tuple[float, float]:
        self.model.eval()
        loss_test = 0
        acc_test = 0

        with torch.inference_mode():
            for batch, (X, y) in enumerate(self.test_dataloader):
                X, y = X.to(self.device), y.to(self.device)

                test_pred = self.model(X)
                loss = self.loss_fn(test_pred, y)

                loss_test += loss.item()

                pred_class = torch.argmax(torch.softmax(test_pred, dim=1), dim=1)
                acc = (pred_class == y).sum().item() / len(pred_class)
                acc_test += acc

                if batch % 2 == 0:
                    print(f"{batch} batches have been processed...")

        loss_test = loss_test / len(self.test_dataloader)
        acc_test = acc_test / len(self.test_dataloader)

        return loss_test, acc_test

    def train(self) -> tuple[list[float], list[float], list[float], list[float]]:
        train_loss_list = []
        test_loss_list = []
        train_acc_list = []
        test_acc_list = []
        for epoch in tqdm(range(self.num_epoch)):
            print(f"{'*' * 6} EPOCH NUM: {epoch} {'*' * 6}")

            print("Starting the training...")
            train_loss, train_acc = self._train_step()
            print("Starting the testing...")
            test_loss, test_acc = self._test_step()
            print(
                f"Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f} "
                f"Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.3f}"
            )

            train_loss_list.append(train_loss)
            train_acc_list.append(train_acc)
            test_loss_list.append(test_loss)
            test_acc_list.append(test_acc)

        return train_loss_list, train_acc_list, test_loss_list, test_acc_list
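# A note on the metrics: Engine averages per-batch accuracy over batches, so
# with a smaller final batch this is a close approximation of, rather than
# exactly equal to, the per-sample accuracy.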


torch.manual_seed(37)
torch.cuda.manual_seed(37)
engine = Engine(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    model=model,
    optim=optim,
    loss_fn=loss_fn,
    num_epoch=NUM_EPOCH,
    device=device,
)


def plot_curves(
    train_loss: list[float],
    train_acc: list[float],
    test_loss: list[float],
    test_acc: list[float],
    num_epoch: int,
):
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    # Plotting loss curves
    ax[0].plot(range(num_epoch), train_loss, color="red", label="Train")
    ax[0].plot(range(num_epoch), test_loss, color="blue", label="Test")
    ax[0].set(xlabel="Epochs", ylabel="Loss", title="Train vs Test Loss")
    ax[0].legend()

    # Plotting accuracy curves
    ax[1].plot(range(num_epoch), train_acc, color="red", label="Train")
    ax[1].plot(range(num_epoch), test_acc, color="blue", label="Test")
    ax[1].set(xlabel="Epochs", ylabel="Accuracy", title="Train vs Test Accuracy")
    ax[1].legend()

    fig.suptitle("Loss and Accuracy Curve")
    plt.savefig(f"{MODEL_NAME}_curves.png")
    plt.show()

if TRAIN_MODEL:
    start_time = timer()
    train_loss, train_acc, test_loss, test_acc = engine.train()
    end_time = timer()
    print(f"INFO: Training process took {end_time - start_time:.3f} seconds.")

    MODEL_PATH.mkdir(parents=True, exist_ok=True)
    torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)

    plot_curves(train_loss, train_acc, test_loss, test_acc, NUM_EPOCH)

else:
    model.load_state_dict(
        torch.load(f=MODEL_SAVE_PATH, weights_only=True, map_location=device)
    )


# Plotting the Confusion Matrix
def give_predictions(
    test_dataloader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    device: typing.Literal["cuda", "cpu"],
) -> tuple[torch.Tensor, torch.Tensor]:
    print("Starting the testing...")
    model.to(device)

    predictions = []
    logits_prob = []
    model.eval()
    with torch.inference_mode():
        for X, y in tqdm(test_dataloader, desc="Doing Validation"):
            X, y = X.to(device), y.to(device)

            logits = model(X)

            pred = torch.argmax(torch.softmax(logits, dim=1), dim=1)
            logits_prob.append(torch.softmax(logits, dim=1).cpu())

            predictions.append(pred.cpu())

    return torch.cat(predictions), torch.cat(logits_prob)
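# Predictions and probabilities are moved to the CPU batch by batch, so tensors
# for the whole test set never accumulate on the GPU.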


# First we need predictions on the entire test dataset
test_preds, logits_prob = give_predictions(
    test_dataloader=test_dataloader, model=model, device=device
)

confmat = torchmetrics.ConfusionMatrix(num_classes=len(class_names), task="multiclass")
confmat_tensor = confmat(preds=test_preds, target=torch.tensor(test_data.targets))
fig, ax = mlxtend.plotting.plot_confusion_matrix(
    conf_mat=confmat_tensor.numpy(),
    class_names=class_names,
    figsize=(10, 7),
)
plt.savefig(f"{MODEL_NAME}_confusion_matrix.png")
plt.show()

# Getting the wrong predictions where the model was most confident.
pred_wrong = []
for i in range(len(test_preds)):
    if test_preds[i] != test_data.targets[i]:
        pred_wrong.append([test_data.targets[i], test_preds[i], logits_prob[i], i])

pred_wrong.sort(key=lambda x: x[2][x[1]], reverse=True)
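# The sort above orders by x[2][x[1]], the softmax probability assigned to the
# predicted class, so reverse=True puts the most confidently wrong examples first.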

# Un-transformed copy of the test set so the raw images can be plotted;
# the normalized tensors contain values below zero, which imshow can't display directly.
test_data_original = torchvision.datasets.ImageFolder(
    test_dir,
    transform=None,
)
if len(pred_wrong) >= 2:
    nrows, ncols = len(pred_wrong) // 2 if len(pred_wrong) // 2 < 5 else 5, 2
    # squeeze=False keeps ax two-dimensional even when nrows == 1, so the
    # ax[rows][cols] indexing below is always valid
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 8), squeeze=False)
    for rows in range(nrows):
        for cols in range(ncols):
            index_1d = rows * ncols + cols
            image, true_label_index = test_data_original[pred_wrong[index_1d][3]]
            true_label = class_names[true_label_index]
            pred_label_index = pred_wrong[index_1d][1]
            pred_label = class_names[pred_label_index]
            ax[rows][cols].imshow(image)
            ax[rows][cols].set_title(
                f"True: {true_label}:{pred_wrong[index_1d][2][true_label_index]:.2f} | Prediction: {pred_label}:{pred_wrong[index_1d][2][pred_label_index]:.2f}"
            )
            ax[rows][cols].axis("off")
    plt.savefig(f"{MODEL_NAME}_wrong_pred.png")
    plt.show()
elif len(pred_wrong) == 1:
    image, true_label_index = test_data_original[pred_wrong[0][3]]
    true_label = class_names[true_label_index]
    pred_label_index = pred_wrong[0][1]
    pred_label = class_names[pred_label_index]
    plt.imshow(image)

    plt.title(
        f"True: {true_label}:{pred_wrong[0][2][true_label_index]:.2f} | Prediction: {pred_label}:{pred_wrong[0][2][pred_label_index]:.2f}"
    )
    plt.axis(False)
    plt.savefig(f"{MODEL_NAME}_wrong_pred.png")
    plt.show()