import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader, Subset
from ClassUtils import CrosswalkDataset
import random
import time
import warnings

# Torchvision's model utils emit a deprecation warning for the legacy
# 'pretrained' parameter; we use the newer 'weights' API, so silence it.
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)

# In a later version, these could be moved to a configuration file.
learning_rate = 4e-3
epoch_num = 25

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Modify the final fully connected layer to output binary class predictions
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, 2)

# Freeze as you see fit depending on the application you want to design
# for param in vgg16.features.parameters():
#     param.requires_grad = False
# for param in vgg16.classifier[:6].parameters():
#     param.requires_grad = False

vgg16 = vgg16.to(device)

# Expects sigmoid probabilities and float targets matching the two outputs
loss_function = nn.BCELoss()

# Prevents accidental loading of the whole training process in the background
if __name__ == "__main__":
    # Freeze the convolutional feature extractor before building the
    # optimiser, so only the classifier parameters are passed to it
    for param in vgg16.features.parameters():
        param.requires_grad = False

    # Takes only the classifier layers, which have not been frozen
    optimiser = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, vgg16.parameters()),
        lr=learning_rate
    )

    training_dataset = CrosswalkDataset("zebra_annotations/classification_data")
    # Train on a random 25,000-image subset of the dataset
    training_loader = DataLoader(
        Subset(training_dataset, random.sample(range(len(training_dataset)), 25000)),
        batch_size=128,
        shuffle=True
    )

    vgg16.train()
    print(f"Dataset size: {len(training_dataset)}")

    for epoch in range(epoch_num):
        running_loss = 0.0
        start_time = time.time()
        last_time = start_time

        for images, gt in training_loader:
            images, gt = images.to(device), gt.to(device)

            classifications = torch.sigmoid(vgg16(images))
            loss = loss_function(classifications, gt)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            running_loss += loss.item()
            last_time = time.time()

        print(f"\nEpoch {epoch + 1} of {epoch_num} has an average batch loss of "
              f"[{running_loss / len(training_loader):.4f}]")
        print(f"Epoch duration: {(last_time - start_time):.6f}s")

    # Includes the feature extraction layers
    torch.save(vgg16.state_dict(), "VGG16_Full_State_Dict.pth")
    # Only includes the final classifier layer - the 'head' whose weights you
    # can use to overwrite a fresh VGG16 if you don't want to store the whole
    # state dict file
    torch.save(vgg16.classifier[6].state_dict(), "vgg16_binary_classifier_onlyHead.pth")
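

# --- A minimal inference sketch, assuming the checkpoints saved above exist
# on disk. It shows how the head-only file can rebuild the classifier: start
# from the same pretrained backbone and overwrite just the final layer. The
# function names below are hypothetical, and inputs are assumed to be
# preprocessed the same way CrosswalkDataset preprocesses training images.
def load_for_inference(head_path="vgg16_binary_classifier_onlyHead.pth"):
    model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, 2)
    # Overwrite only the head; the backbone keeps its downloaded ImageNet weights
    model.classifier[6].load_state_dict(torch.load(head_path, map_location=device))
    return model.to(device).eval()


@torch.no_grad()
def classify(model, image_tensor):
    # image_tensor: a single preprocessed (3, H, W) tensor; returns the two
    # per-class sigmoid probabilities, mirroring the training-time activation
    probs = torch.sigmoid(model(image_tensor.unsqueeze(0).to(device)))
    return probs.squeeze(0)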