|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms, models |
|
|
from torch.utils.data import DataLoader, Subset |
|
|
from torchvision.datasets import ImageFolder |
|
|
from ClassUtils import CrosswalkDataset |
|
|
import numpy as np |
|
|
import random |
|
|
import time |
|
|
|
|
|
|
|
|
import warnings |
|
|
|
|
|
# Suppress DeprecationWarning output from every module; the module regex
# r".*" matches any module name.
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r".*")
|
|
|
|
|
|
|
|
# Training hyperparameters.
learning_rate = 4e-3
epoch_num = 25

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load ImageNet-pretrained VGG16. `weights` must be a member of the
# VGG16_Weights enum (e.g. `.DEFAULT`); passing the enum class itself is
# rejected by torchvision's weight validation with a TypeError.
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Replace the last classifier layer with a 2-output linear head, keeping the
# original layer's input width.
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, 2)

vgg16 = vgg16.to(device)

# nn.BCELoss expects probabilities in [0, 1] (not raw logits) and a target
# of the same shape as the input.
loss_function = nn.BCELoss()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Freeze the convolutional feature extractor BEFORE building the
    # optimiser, so the requires_grad filter below really excludes these
    # parameters and only the classifier is optimised.
    for param in vgg16.features.parameters():
        param.requires_grad = False

    optimiser = torch.optim.Adam(
        params=[p for p in vgg16.parameters() if p.requires_grad],
        lr=learning_rate,
    )

    training_dataset = CrosswalkDataset("zebra_annotations/classification_data")
    dataset_len = len(training_dataset)
    if dataset_len == 0:
        raise ValueError("training dataset is empty")

    # Train on a random subset of at most 25000 samples. range(dataset_len)
    # covers every index (a len-1 bound would never pick the last sample),
    # and min() avoids random.sample raising on smaller datasets.
    subset_size = min(25000, dataset_len)
    subset_indices = random.sample(range(dataset_len), subset_size)
    training_subset = Subset(training_dataset, subset_indices)
    training_loader = DataLoader(training_subset, batch_size=128, shuffle=True)

    vgg16.train()

    print(len(training_dataset))

    for epoch in range(epoch_num):
        running_loss = 0.0
        start_time = time.time()

        for images, gt in training_loader:
            images, gt = images.to(device), gt.to(device)

            # Sigmoid turns the head's outputs into probabilities for BCELoss.
            # NOTE(review): BCELoss needs gt to be a float tensor with the
            # same shape as the model output — confirm against CrosswalkDataset.
            classifications = torch.sigmoid(vgg16(images))
            loss = loss_function(classifications, gt)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            # loss is a mean over the batch; weight it by the batch size so the
            # epoch figure is a true per-image mean even when the final batch
            # is smaller than batch_size.
            running_loss += loss.item() * images.size(0)

            print(",,, ---")

        elapsed = time.time() - start_time

        print(f"\nEpoch {epoch + 1} of {epoch_num} has a per image loss of [{running_loss/len(training_subset):.4f}]")
        print(f"{elapsed:.6f}")

    # Persist both the full model weights and the trained 2-output head alone.
    torch.save(vgg16.state_dict(), "VGG16_Full_State_Dict.pth")
    torch.save(vgg16.classifier[6].state_dict(), "vgg16_binary_classifier_onlyHead.pth")
|
|
|