efficientnet-b3 / efficient_b3_20_percent.py
Shad0wKillar's picture
Initial Push
68871d6 verified
import torch
import torchvision
import torchinfo
import typing
import requests
import os
import zipfile
import mlxtend.plotting
import torchmetrics
from pathlib import Path
from timeit import default_timer as timer
from tqdm.auto import tqdm
import matplotlib
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# True: fine-tune the model and save its weights.
# False: skip training and load the previously saved weights instead.
TRAIN_MODEL = False

# Training hyperparameters.
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCH = 10

# Location of the saved state dict (directory, file name, and full path).
MODEL_PATH = Path("models")
MODEL_NAME = "EfficientNet_B3_20percent.pth"
MODEL_SAVE_PATH = MODEL_PATH.joinpath(MODEL_NAME)
# Download the pizza/steak/sushi (20 percent) dataset unless it is already
# present on disk.
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi_20_percent"

if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    # Only the parent directory is created up front. The image directory is
    # created after the archive has been downloaded and opened successfully,
    # so a failed run cannot leave an empty directory behind that the
    # `is_dir()` check above would later mistake for a complete dataset.
    data_path.mkdir(parents=True, exist_ok=True)
    zip_path = data_path / "pizza_steak_sushi_20_percent.zip"

    # Download pizza, steak, sushi data
    print("Downloading pizza, steak, sushi data...")
    request = requests.get(
        "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi_20_percent.zip",
        timeout=60,  # do not hang forever on a stalled connection
    )
    # Fail loudly on HTTP errors instead of saving an error page as the zip.
    request.raise_for_status()
    zip_path.write_bytes(request.content)

    try:
        # Unzip pizza, steak, sushi data
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            print("Unzipping pizza, steak, sushi data...")
            image_path.mkdir(parents=True, exist_ok=True)
            zip_ref.extractall(image_path)
    finally:
        # Remove the .zip file whether or not extraction succeeded.
        os.remove(zip_path)

train_dir = image_path / "train"
test_dir = image_path / "test"
# Per-channel statistics used by the normalisation step (these look like the
# commonly quoted ImageNet values — confirm if the backbone changes).
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Hand-written preprocessing pipeline: resize -> tensor -> normalise.
_manual_steps = [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
manual_transform = torchvision.transforms.Compose(_manual_steps)
def create_dataloaders(
    train_dir: Path,
    test_dir: Path,
    batch_size: int,
    num_workers: int,
    transform: torchvision.transforms.Compose,
) -> tuple[
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    list[str],
    torchvision.datasets.ImageFolder,
    torchvision.datasets.ImageFolder,
]:
    """Build train/test ``ImageFolder`` datasets and their dataloaders.

    Args:
        train_dir: Root directory of the training images.
        test_dir: Root directory of the test images.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker process count passed to both loaders.
        transform: Transform applied to every image of both datasets.

    Returns:
        ``(train_dataloader, test_dataloader, class_names, train_data,
        test_data)``. ``class_names`` is the ``classes`` attribute of the
        training dataset. Only the training loader shuffles.
    """

    def _make_loader(
        dataset: torchvision.datasets.ImageFolder, shuffle: bool
    ) -> torch.utils.data.DataLoader:
        # Both loaders share every setting except shuffling.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    train_data = torchvision.datasets.ImageFolder(train_dir, transform=transform)
    test_data = torchvision.datasets.ImageFolder(test_dir, transform=transform)

    train_dataloader = _make_loader(train_data, shuffle=True)
    test_dataloader = _make_loader(test_data, shuffle=False)

    return (
        train_dataloader,
        test_dataloader,
        train_data.classes,
        train_data,
        test_data,
    )
# Settings shared by both dataloader builds below.
_loader_settings = {
    "train_dir": train_dir,
    "test_dir": test_dir,
    "batch_size": BATCH_SIZE,
    "num_workers": os.cpu_count() or 0,  # cpu_count() may return None
}

# Dataloaders using the hand-written transform pipeline.
(
    train_dataloader_manual_transform,
    test_dataloader_manual_transform,
    class_names_manual_transform,
    train_data,
    test_data,
) = create_dataloaders(transform=manual_transform, **_loader_settings)

# Dataloaders using the transform bundled with the pretrained weights.
# NOTE: this second call rebinds `train_data` and `test_data`.
weights = torchvision.models.EfficientNet_B3_Weights.DEFAULT
auto_transform = weights.transforms()
(
    train_dataloader,
    test_dataloader,
    class_names,
    train_data,
    test_data,
) = create_dataloaders(transform=auto_transform, **_loader_settings)
# Load the pretrained EfficientNet-B3 and place it on the target device.
model = torchvision.models.efficientnet_b3(weights=weights).to(device)
torchinfo.summary(
    model=model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
)
for feature in model.features:
    print(feature)

# Freeze the feature extractor so only the new classifier head is trained.
for param in model.features.parameters():
    param.requires_grad = False

print(f"Classifier part has (before changing):\n{model.classifier}")

# Seed before creating the new head so its random initialisation is repeatable.
torch.manual_seed(37)
torch.cuda.manual_seed(37)

output_shape = len(class_names)
# Read the input width from the existing head rather than hard-coding it, so
# the replacement always matches the backbone's output size.
in_features = model.classifier[-1].in_features
# The new layers are created on the CPU; move them to `device` explicitly so
# the whole model lives on one device from this point on.
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=in_features, out_features=output_shape, bias=True),
).to(device)
print(f"Classifier part has (after changing):\n{model.classifier}")
torchinfo.summary(
    model=model,
    input_size=(32, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    row_settings=["var_names"],
)

loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
class Engine:
    """Run a train/evaluate loop for a classification model.

    Args:
        train_dataloader: Yields ``(inputs, labels)`` batches for training.
        test_dataloader: Yields ``(inputs, labels)`` batches for evaluation.
        model: Model to optimise; it is moved to ``device``.
        optim: Optimizer that updates ``model``'s parameters.
        loss_fn: Loss called as ``loss_fn(logits, labels)``.
        device: Device the model and every batch are moved to.
        num_epoch: Number of epochs ``train`` runs.
    """

    def __init__(
        self,
        train_dataloader: torch.utils.data.DataLoader,
        test_dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        loss_fn: torch.nn.Module,
        device: typing.Literal["cuda", "cpu"],
        num_epoch: int,
    ):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.optim = optim
        self.loss_fn = loss_fn
        self.device = device
        self.num_epoch = num_epoch
        self.model = model.to(device)

    def _train_step(self) -> tuple[float, float]:
        """Train for one pass over the training loader.

        Returns:
            ``(loss, accuracy)``, each the mean of the per-batch values.
        """
        self.model.train()
        loss_train = 0.0
        acc_train = 0.0
        for batch, (X, y) in enumerate(self.train_dataloader):
            X, y = X.to(self.device), y.to(self.device)
            train_pred = self.model(X)
            loss = self.loss_fn(train_pred, y)
            loss_train += loss.item()

            # Use the optimizer given to this instance, not a module-level
            # variable that merely happens to share the name.
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            # Softmax does not change which logit is largest, so argmax on
            # the raw logits gives the same predicted class.
            pred_class = torch.argmax(train_pred, dim=1)
            acc_train += (pred_class == y).sum().item() / len(pred_class)
            if batch % 2 == 0:
                print(f"{batch} batches have been processed...")
        loss_train = loss_train / len(self.train_dataloader)
        acc_train = acc_train / len(self.train_dataloader)
        return loss_train, acc_train

    def _test_step(self) -> tuple[float, float]:
        """Evaluate for one pass over the test loader without gradients.

        Returns:
            ``(loss, accuracy)``, each the mean of the per-batch values.
        """
        self.model.eval()
        loss_test = 0.0
        acc_test = 0.0
        with torch.inference_mode():
            for batch, (X, y) in enumerate(self.test_dataloader):
                X, y = X.to(self.device), y.to(self.device)
                test_pred = self.model(X)
                loss_test += self.loss_fn(test_pred, y).item()
                pred_class = torch.argmax(test_pred, dim=1)
                acc_test += (pred_class == y).sum().item() / len(pred_class)
                if batch % 2 == 0:
                    print(f"{batch} batches have been processed...")
        loss_test = loss_test / len(self.test_dataloader)
        acc_test = acc_test / len(self.test_dataloader)
        return loss_test, acc_test

    def train(self) -> tuple[list[float], list[float], list[float], list[float]]:
        """Run ``num_epoch`` epochs of training followed by evaluation.

        Returns:
            Per-epoch lists in the order ``(train_loss, train_acc,
            test_loss, test_acc)``.
        """
        train_loss_list = []
        test_loss_list = []
        train_acc_list = []
        test_acc_list = []
        for epoch in tqdm(range(self.num_epoch)):
            print(f"{'*' * 6} EPOCH NUM: {epoch} {'*' * 6}")
            print("Starting the training...")
            train_loss, train_acc = self._train_step()
            print("Starting the testing...")
            test_loss, test_acc = self._test_step()
            # The " | " between the two halves keeps the train accuracy and
            # the "Test Loss" label from running together.
            print(
                f"Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f} | "
                f"Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.3f}"
            )
            train_loss_list.append(train_loss)
            train_acc_list.append(train_acc)
            test_loss_list.append(test_loss)
            test_acc_list.append(test_acc)
        return train_loss_list, train_acc_list, test_loss_list, test_acc_list
# Re-seed the CPU and CUDA generators right before the engine is built.
_ENGINE_SEED = 37
torch.manual_seed(_ENGINE_SEED)
torch.cuda.manual_seed(_ENGINE_SEED)

# Bundle the model, data, optimizer and loss into the training engine.
engine = Engine(
    model=model,
    loss_fn=loss_fn,
    optim=optim,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    device=device,
    num_epoch=NUM_EPOCH,
)
def plot_curves(
    train_loss: list[float],
    train_acc: list[float],
    test_loss: list[float],
    test_acc: list[float],
    num_epoch: int,
):
    """Plot train/test loss and accuracy side by side, save, then show.

    Args:
        train_loss: Training loss per epoch.
        train_acc: Training accuracy per epoch.
        test_loss: Test loss per epoch.
        test_acc: Test accuracy per epoch.
        num_epoch: Number of epochs; used as the x-axis range.

    The figure is written to ``f"{MODEL_NAME}_curves.png"`` before being shown.
    """
    epochs = range(num_epoch)
    fig, (loss_axis, acc_axis) = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    # One row per panel: (axis, train series, test series, y label, title).
    panels = (
        (loss_axis, train_loss, test_loss, "Loss", "Train vs Test Loss"),
        (acc_axis, train_acc, test_acc, "Accuracy", "Train vs Test Accuracy"),
    )
    for axis, train_values, test_values, y_label, panel_title in panels:
        axis.plot(epochs, train_values, color="red", label="Train")
        axis.plot(epochs, test_values, color="blue", label="Test")
        axis.set(xlabel="Epochs", ylabel=y_label, title=panel_title)
        axis.legend()

    fig.suptitle("Loss and Accuracy Curve")
    plt.savefig(f"{MODEL_NAME}_curves.png")
    plt.show()
if not TRAIN_MODEL:
    # Skip training and restore the previously saved weights onto `device`.
    model.load_state_dict(
        torch.load(f=MODEL_SAVE_PATH, weights_only=True, map_location=device)
    )
else:
    # Train, report the wall-clock time, save the weights and plot the curves.
    start_time = timer()
    train_loss, train_acc, test_loss, test_acc = engine.train()
    end_time = timer()
    print(f"INFO: Training process took {end_time - start_time:.3f} seconds.")

    MODEL_PATH.mkdir(parents=True, exist_ok=True)
    torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)

    plot_curves(train_loss, train_acc, test_loss, test_acc, NUM_EPOCH)
# Plotting the Confusion Matrix
def give_predictions(
    test_dataloader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    device: typing.Literal["cuda", "cpu"],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``model`` over ``test_dataloader`` and collect its outputs.

    Args:
        test_dataloader: Yields ``(inputs, labels)`` batches; labels are ignored.
        model: Model to evaluate; it is moved to ``device`` and set to eval mode.
        device: Device used for the forward passes.

    Returns:
        ``(predictions, probabilities)``, both on the CPU and concatenated
        across batches: the argmax class per sample and the softmax output
        per sample.
    """
    print("Starting the testing...")
    model.to(device)
    model.eval()
    predictions = []
    probabilities = []
    with torch.inference_mode():
        # Labels are not needed here, so they are not copied to the device.
        for X, _ in tqdm(test_dataloader, desc="Doing Validation"):
            logits = model(X.to(device))
            # Compute softmax once and reuse it for both outputs.
            probs = torch.softmax(logits, dim=1)
            predictions.append(torch.argmax(probs, dim=1).cpu())
            probabilities.append(probs.cpu())
    return torch.cat(predictions), torch.cat(probabilities)
# Collect predictions (and class probabilities) for the whole test set.
test_preds, logits_prob = give_predictions(
    model=model,
    test_dataloader=test_dataloader,
    device=device,
)

# Build the multiclass confusion matrix from predictions vs. dataset targets.
confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=len(class_names))
_true_targets = torch.tensor(test_data.targets)
confmat_tensor = confmat(preds=test_preds, target=_true_targets)

# Draw it, save it next to the script, then display it.
fig, ax = mlxtend.plotting.plot_confusion_matrix(
    conf_mat=confmat_tensor.numpy(),
    class_names=class_names,
    figsize=(10, 7),
)
plt.savefig(f"{MODEL_NAME}_confusion_matrix.png")
plt.show()
# Collect the wrong predictions, most confident mistakes first.
# Each entry is [true target, predicted class, class probabilities, sample index].
pred_wrong = [
    [target, pred, probs, sample_index]
    for sample_index, (target, pred, probs) in enumerate(
        zip(test_data.targets, test_preds, logits_prob)
    )
    if pred != target
]


def _predicted_class_probability(entry):
    # Probability the model assigned to the (wrong) class it predicted.
    return entry[2][entry[1]]


pred_wrong = sorted(pred_wrong, key=_predicted_class_probability, reverse=True)

# Reload the test set without any transform so the images can be plotted
# as-is; the transformed images may contain invalid pixel values, e.g.
# values below zero.
test_data_original = torchvision.datasets.ImageFolder(root=test_dir, transform=None)
def _wrong_pred_panel(entry):
    """Return the untransformed image and title text for one `pred_wrong` entry."""
    _, pred_label_index, probs, sample_index = entry
    image, true_label_index = test_data_original[sample_index]
    true_label = class_names[true_label_index]
    pred_label = class_names[pred_label_index]
    title = (
        f"True: {true_label}:{probs[true_label_index]:.2f} | "
        f"Prediction: {pred_label}:{probs[pred_label_index]:.2f}"
    )
    return image, title


# Show the most confident mistakes: a grid of up to 5 rows x 2 columns when
# there are at least two, or a single image when there is exactly one.
num_wrong = len(pred_wrong)
if num_wrong >= 2:  # includes exactly two mistakes, which no branch handled before
    ncols = 2
    nrows = min(num_wrong // ncols, 5)
    # squeeze=False keeps `ax` two-dimensional even when nrows == 1, so the
    # ax[row][col] indexing below works for two or three mistakes as well.
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 8), squeeze=False)
    for rows in range(nrows):
        for cols in range(ncols):
            index_1d = rows * ncols + cols
            image, title = _wrong_pred_panel(pred_wrong[index_1d])
            ax[rows][cols].imshow(image)
            ax[rows][cols].set_title(title)
            ax[rows][cols].axis("off")
    plt.savefig(f"{MODEL_NAME}_wrong_pred.png")
    plt.show()
elif num_wrong == 1:
    image, title = _wrong_pred_panel(pred_wrong[0])
    plt.imshow(image)
    plt.title(title)
    plt.axis(False)
    plt.savefig(f"{MODEL_NAME}_wrong_pred.png")
    plt.show()