Upload 5 files
Browse files- augment.py +89 -0
- requirements.txt +10 -0
- test.py +192 -0
- train.py +332 -0
- utils.py +107 -0
augment.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
class RandAug:
    """Apply one randomly chosen image augmentation to a PIL image.

    Every call resizes the image to (img_size, img_size) and converts it to
    a tensor; the chosen augmentation is applied on top of that.

    Args:
        img_size: Target height/width of the output tensor.
        choice: Name of a fixed augmentation to always apply, or None to
            sample a new augmentation on every call.
    """

    def __init__(self, img_size, choice=None):
        # Augmentation options
        self.trans = ['identity', 'rotate', 'color', 'sharpness', 'blur', 'padding', 'perspective']
        self.img_size = img_size
        self.choice = choice

    def __call__(self, img):
        """Resize `img`, convert it to a tensor and apply one augmentation."""
        choice = self.choice
        if choice is None:
            # BUG FIX: the original stored the sampled augmentation back into
            # self.choice, so the first random draw was frozen and reused for
            # every subsequent image. Sampling into a local variable keeps
            # each call independently random.
            # Weights set 40% probability for the 'identity' augmentation choice
            choice = random.choices(self.trans, weights=(40, 10, 10, 10, 10, 10, 10))[0]

        # Every branch starts with the same resize + tensor conversion
        resize = transforms.Resize((self.img_size, self.img_size))

        if choice == 'identity':
            trans = transforms.Compose([
                resize,
                transforms.ToTensor()
            ])
        elif choice == 'rotate':
            degrees = random.uniform(0, 180)
            rand_fill = random.choice([0, 1])
            # expand=True can change the canvas size, so resize again after
            trans = transforms.Compose([
                resize,
                transforms.ToTensor(),
                transforms.RandomRotation(degrees, expand=True, fill=rand_fill),
                resize
            ])
        elif choice == 'color':
            rand_brightness = random.uniform(0, 0.3)
            rand_hue = random.uniform(0, 0.5)
            rand_contrast = random.uniform(0, 0.5)
            rand_saturation = random.uniform(0, 0.5)
            trans = transforms.Compose([
                resize,
                transforms.ToTensor(),
                transforms.ColorJitter(brightness=rand_brightness, contrast=rand_contrast, saturation=rand_saturation, hue=rand_hue)
            ])
        elif choice == 'sharpness':
            # Exponential distribution keeps most sharpness factors near 1
            sharpness = 1 + (np.random.exponential() / 2)
            trans = transforms.Compose([
                resize,
                transforms.ToTensor(),
                transforms.RandomAdjustSharpness(sharpness, p=1)
            ])
        elif choice == 'blur':
            kernel = random.choice([1, 3, 5])
            trans = transforms.Compose([
                resize,
                transforms.ToTensor(),
                transforms.GaussianBlur(kernel, sigma=(0.1, 2.0))
            ])
        elif choice == 'padding':
            pad = random.choice([3, 10, 25])
            rand_fill = random.choice([0, 1])
            # Pad grows the image, so resize back to the target size
            trans = transforms.Compose([
                resize,
                transforms.ToTensor(),
                transforms.Pad(pad, fill=rand_fill, padding_mode='constant'),
                resize
            ])
        elif choice == 'perspective':
            scale = random.uniform(0.1, 0.5)
            rand_fill = random.choice([0, 1])
            trans = transforms.Compose([
                resize,
                transforms.ToTensor(),
                transforms.RandomPerspective(distortion_scale=scale, p=1.0, fill=rand_fill),
                resize
            ])
        else:
            # Previously an unknown choice silently returned the raw PIL
            # image (not a tensor), which would fail downstream; fail fast.
            raise ValueError(f'Unknown augmentation choice: {choice}')

        return trans(img)
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu116
|
| 2 |
+
torch==1.12.1+cu116
|
| 3 |
+
torchvision==0.13.1+cu116
|
| 4 |
+
scikit-learn==1.0.2
|
| 5 |
+
numpy==1.21.6
|
| 6 |
+
pillow==9.3.0
|
| 7 |
+
matplotlib==3.5.3
|
| 8 |
+
onnx==1.13.0
|
| 9 |
+
onnxruntime==1.13.1
|
| 10 |
+
tqdm==4.64.1
# imported by test.py but previously missing from this file
pandas
seaborn
|
test.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
from __future__ import division
|
| 3 |
+
import torch
|
| 4 |
+
import onnxruntime
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torchvision
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
|
| 11 |
+
import seaborn as sn
|
| 12 |
+
import random
|
| 13 |
+
import time
|
| 14 |
+
import json
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from PIL import ImageFile
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import argparse
|
| 19 |
+
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Command-line configuration for the test run: data locations, the ONNX
# model to evaluate, and naming for the result files.
parser = argparse.ArgumentParser('arguments for testing the model')

parser.add_argument('--ts_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/test/",
                    help='path to test data')
parser.add_argument('--ts_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/test/",
                    help='path to test data')
parser.add_argument('--results_folder', type=str, default="./results/aug_28022024/",
                    help='Folder for saving results')
parser.add_argument('--model_path', type=str, default="/koodit/table_segmentation/empty_cell_detection/train/models/aug_b32_lr0001_28022024.onnx",
                    help='path to load model file from')
parser.add_argument('--batch_size', type=int, default=16,
                    help='batch_size')
parser.add_argument('--num_classes', type=int, default=2,
                    help='number of classes for classification')
parser.add_argument('--name', type=str, default='empty_cell_augment_28022024',
                    help='name given to result files')

# Wall-clock start time for the whole script (reported at the end)
start = time.time()

# nohup python test.py > logs/aug_test_28022024.txt 2>&1 &
# echo $! > output/save_pid.txt

# Fix random seeds for reproducibility
torch.manual_seed(67)
random.seed(67)

args = parser.parse_args()

# PIL settings to avoid errors caused by truncated and very large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_data():
    """Collect test image paths and their labels (0 = empty, 1 = ok)."""
    empty_files = list(Path(args.ts_empty_folder).glob('*.jpg'))
    ok_files = list(Path(args.ts_ok_folder).glob('*.jpg'))

    # Empty-cell images come first, labelled 0; ok images follow, labelled 1
    ts_files = empty_files + ok_files
    ts_labels = np.concatenate((np.zeros(len(empty_files)), np.ones(len(ok_files))))

    print('Test data with empty cells: ', len(empty_files))
    print('Test data without empty cells: ', len(ok_files))

    return ts_files, ts_labels
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def initialize_model():
    """Load the ONNX inference session and its expected input size (224)."""
    input_size = 224
    session = onnxruntime.InferenceSession(args.model_path)
    return session, input_size
|
| 83 |
+
|
| 84 |
+
# Function for getting precision, recall and F-score metrics
def get_precision_recall(y_true, y_pred):
    """Print per-class precision, recall and F-score (class 0=empty, 1=ok)."""
    precision, recall, fscore, _ = precision_recall_fscore_support(y_true, y_pred, average=None)

    print('\nPrecision for ok: %.2f' % precision[1])
    print('Recall for ok: %.2f' % recall[1])
    print('F-score for ok: %.2f' % fscore[1])

    print('Precision for empty: %.2f' % precision[0])
    print('Recall for empty: %.2f' % recall[0])
    print('F-score for empty: %.2f' % fscore[0])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def createConfusionMatrix(y_true, y_pred):
    """Return a seaborn heatmap figure of the confusion matrix."""
    class_names = np.array(['empty', 'ok'])

    # Build confusion matrix
    cf_matrix = confusion_matrix(y_true, y_pred)
    print(cf_matrix)
    df_cm = pd.DataFrame(cf_matrix, index=class_names, columns=class_names)
    plt.figure(figsize=(12, 7))
    return sn.heatmap(df_cm, annot=True).get_figure()
|
| 115 |
+
|
| 116 |
+
def save_preds(y_true, y_pred, paths):
    """Save a JSON mapping of misclassified image paths to their true labels."""
    # Indices where the prediction disagrees with the true label
    wrong = np.where(y_true != y_pred)
    incorrect_preds = dict(zip(paths[wrong], y_true[wrong].astype(str)))

    print(f'{len(incorrect_preds)} incorrect predictions')

    # Save file names and labels of incorrectly classified images
    with open(args.results_folder + args.name + '_incorrect_preds', "w") as fp:
        json.dump(incorrect_preds, fp)
|
| 128 |
+
|
| 129 |
+
# Initialize the model for this run
|
| 130 |
+
model, input_size = initialize_model()

# Print the model we just instantiated
#print(model_ft)

# Test-time preprocessing: resize to the model's input size and convert to
# tensor — no augmentation is applied at test time.
data_transforms = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor()
])

print("Initializing Datasets and Dataloaders...")

ts_files, ts_labels = get_data()
|
| 143 |
+
|
| 144 |
+
# Function for getting model predictions on test data
|
| 145 |
+
def test_model(model, ts_files, ts_labels):
    """Run ONNX-model inference on every test image.

    Returns (predicted labels, true labels, file paths) as numpy arrays.
    """
    since = time.time()
    label_preds = []
    true_labels = []
    paths = []
    n = len(ts_files)

    for idx, (file_path, label) in enumerate(zip(ts_files, ts_labels)):
        print(f'{idx}/{n}')
        # Preprocess: force RGB, resize + ToTensor, then add a batch dimension
        tensor = data_transforms(Image.open(file_path).convert("RGB")).unsqueeze(0)
        # ONNX Runtime consumes numpy arrays, not torch tensors
        ort_input = {model.get_inputs()[0].name: tensor.detach().cpu().numpy()}
        ort_output = model.run(None, ort_input)
        # Predicted class is the argmax over the output logits
        label_preds.append(np.argmax(ort_output[0], 1).item())
        true_labels.append(label)
        paths.append(str(file_path))

    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return np.array(label_preds), np.array(true_labels), np.array(paths)
|
| 173 |
+
|
| 174 |
+
ts_labels = np.array(ts_labels)

# Test model
y_pred, y_true, paths = test_model(model, ts_files, ts_labels)
# Saves information of incorrect predictions
save_preds(y_true, y_pred, paths)
# Calculates and prints precision, recall and F-score metrics
get_precision_recall(y_true, y_pred)

# Save confusion matrix to Tensorboard
#cm = createConfusionMatrix(y_true, y_pred)
#writer.add_figure("Confusion matrix", cm)
# Create and save a row-normalized confusion matrix of predictions vs true labels
conf_matrix = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true', display_labels=np.array(['empty', 'ok']))
plt.savefig(args.results_folder + args.name + '_conf_matrix.jpg', bbox_inches='tight')

# Report the total wall-clock time of the whole test run
end = time.time()
time_in_mins = (end - start) / 60
print('Time: %.2f minutes' % time_in_mins)
|
train.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
from __future__ import division
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from torchvision import models
|
| 8 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 9 |
+
from sklearn.utils import class_weight
|
| 10 |
+
from sklearn.metrics import precision_recall_fscore_support
|
| 11 |
+
import numpy as np
|
| 12 |
+
import time
|
| 13 |
+
import argparse
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from PIL import Image, ImageFile
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
from augment import RandAug
|
| 19 |
+
import utils
|
| 20 |
+
|
| 21 |
+
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Much of the code is a modified version of the code available at
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html


# nohup python train.py > logs/empty_cell_aug_28032024.txt 2>&1 &
# echo $! > logs/save_pid.txt

# Command-line configuration for training: data locations, hyperparameters,
# early stopping, and output/model naming.
parser = argparse.ArgumentParser('arguments for training')

parser.add_argument('--tr_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/train/",
                    help='path to training data with empty images')
parser.add_argument('--val_empty_folder', type=str, default="/data/taulukot/solukuvat/empty/val/",
                    help='path to validation data with empty images')
parser.add_argument('--tr_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/train/",
                    help='path to training data with ok images')
parser.add_argument('--val_ok_folder', type=str, default="/data/taulukot/solukuvat/ok/val/",
                    help='path to validation data with ok images')
parser.add_argument('--results_folder', type=str, default="results/28032024_aug/",
                    help='Folder for saving training results.')
parser.add_argument('--save_model_path', type=str, default="./models/",
                    help='Path for saving model file.')
parser.add_argument('--batch_size', type=int, default=32,
                    help='Batch size used for model training. ')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='Base learning rate.')
parser.add_argument('--device', type=str, default='cpu',
                    help='Defines whether the model is trained using cpu or gpu.')
parser.add_argument('--num_classes', type=int, default=2,
                    help='Number of classes used in classification.')
parser.add_argument('--num_epochs', type=int, default=15,
                    help='Number of training epochs.')
parser.add_argument('--random_seed', type=int, default=8765,
                    help='Number used for initializing random number generation.')
parser.add_argument('--early_stop_threshold', type=int, default=3,
                    help='Threshold value of epochs after which training stops if validation accuracy does not improve.')
parser.add_argument('--save_model_format', type=str, default='torch',
                    help='Defines the format for saving the model.')
parser.add_argument('--augment_choice', type=str, default=None,
                    help='Defines which image augmentation(s) are used. Defaults to randomly selected augmentations.')
# NOTE(review): the help text below says 'Current date.' but this argument is
# the model name — looks like a copy-paste slip from --date; confirm and fix.
parser.add_argument('--model_name', type=str, default='aug_b32_lr0001',
                    help='Current date.')
parser.add_argument('--date', type=str, default=time.strftime("%d%m%Y"),
                    help='Current date.')

args = parser.parse_args()

# PIL settings to avoid errors caused by truncated and large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# List for saving the names of damaged images
damaged_images = []
|
| 76 |
+
|
| 77 |
+
def get_datapaths():
    """Collect train/validation image paths and labels (0 = empty, 1 = ok)."""
    source_folders = (args.tr_empty_folder, args.tr_ok_folder,
                      args.val_empty_folder, args.val_ok_folder)
    tr_empty_files, tr_ok_files, val_empty_files, val_ok_files = (
        list(Path(folder).glob('*')) for folder in source_folders)

    # Empty-cell images are labelled 0, ok images 1; order matches the
    # concatenated file lists below
    tr_labels = np.concatenate((np.zeros(len(tr_empty_files)), np.ones(len(tr_ok_files))))
    val_labels = np.concatenate((np.zeros(len(val_empty_files)), np.ones(len(val_ok_files))))
    # Combine faulty and non-faulty images
    tr_files = tr_empty_files + tr_ok_files
    val_files = val_empty_files + val_ok_files

    print('\nTraining data with empty cells: ', len(tr_empty_files))
    print('Training data without empty cells: ', len(tr_ok_files))

    print('Validation data with empty cells: ', len(val_empty_files))
    print('Validation data without empty cells: ', len(val_ok_files))

    return {'tr_data': tr_files, 'tr_labels': tr_labels,
            'val_data': val_files, 'val_labels': val_labels}
|
| 100 |
+
|
| 101 |
+
class ImageDataset(Dataset):
    """PyTorch Dataset class is used for generating training and validation datasets."""

    def __init__(self, img_paths, img_labels, transform=None, target_transform=None):
        """
        Args:
            img_paths: sequence of image file paths.
            img_labels: sequence of labels aligned with img_paths.
            transform: optional callable applied to the loaded RGB image.
            target_transform: optional callable applied to the label.
        """
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return (image, label), or None for an unreadable image file.

        None entries are filtered out later by collate_fn.
        """
        img_path = self.img_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            label = self.img_labels[idx]
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. The image is considered damaged
            # if reading it fails.
            damaged_images.append(img_path)
            return None
        if self.transform:
            # Already converted to RGB above — the original converted twice
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label
|
| 127 |
+
|
| 128 |
+
def initialize_model():
    """Function for initializing pretrained neural network model (DenseNet121)."""
    # DenseNet121 expects 224x224 inputs
    input_size = 224
    # Load ImageNet-pretrained weights, then swap in a fresh classifier
    # head sized for our number of classes
    model_ft = models.densenet121(weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1)
    model_ft.classifier = nn.Linear(model_ft.classifier.in_features, args.num_classes)

    return model_ft, input_size
|
| 136 |
+
|
| 137 |
+
def collate_fn(batch):
    """Collate a batch, dropping the None entries produced by damaged images."""
    valid_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(valid_samples)
|
| 142 |
+
|
| 143 |
+
def initialize_dataloaders(data_dict, input_size):
    """Function for initializing datasets and dataloaders."""
    # Training data gets the configured (possibly random) augmentation;
    # validation data is only resized ('identity').
    train_dataset = ImageDataset(img_paths=data_dict['tr_data'],
                                 img_labels=data_dict['tr_labels'],
                                 transform=RandAug(input_size, args.augment_choice))
    validation_dataset = ImageDataset(img_paths=data_dict['val_data'],
                                      img_labels=data_dict['val_labels'],
                                      transform=RandAug(input_size, 'identity'))

    loaders = {}
    for split, dataset in (('train', train_dataset), ('val', validation_dataset)):
        loaders[split] = DataLoader(dataset, collate_fn=collate_fn,
                                    batch_size=args.batch_size, shuffle=True, num_workers=4)
    return loaders
|
| 153 |
+
|
| 154 |
+
def get_criterion(data_dict):
    """Function for generating class weights and initializing the loss function."""
    labels = np.asarray(data_dict['tr_labels'])
    # Balanced class weights compensate for the unequal number of training
    # samples in the two classes
    weights = class_weight.compute_class_weight(class_weight='balanced',
                                                classes=np.unique(labels), y=labels)
    weight_tensor = torch.tensor(weights, dtype=torch.float).to(args.device)
    print('\nClass weights: ', weight_tensor.tolist())
    # Cross Entropy Loss function
    return nn.CrossEntropyLoss(weight=weight_tensor, reduction='mean')
|
| 166 |
+
|
| 167 |
+
def get_optimizer(model):
    """Function for initializing the optimizer and the LR scheduler."""
    # Model parameters are split into two groups: the classifier layer
    # and everything else
    classifier_param_names = {"classifier.weight", "classifier.bias"}
    backbone_params = [param for name, param in model.named_parameters()
                       if name not in classifier_param_names]
    # The classification head is trained with a 10x larger learning rate
    param_groups = [
        {'params': backbone_params, 'lr': args.lr},
        {'params': model.classifier.parameters(), 'lr': args.lr * 10},
    ]
    # Adam optimizer
    optimizer = torch.optim.Adam(param_groups, args.lr)
    # Scheduler reduces learning rate when validation accuracy does not improve for an epoch
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

    return optimizer, scheduler
|
| 186 |
+
|
| 187 |
+
def train_model(model, dataloaders, criterion, optimizer, scheduler=None):
    """Function for model training and validation.

    Runs up to args.num_epochs epochs. After each validation phase the model
    is saved if its weighted F1 score improved; training stops early when F1
    has not improved for more than args.early_stop_threshold epochs.

    Returns:
        dict of per-epoch metric histories (train/val loss, accuracy, F1)
        and the learning rates of both parameter groups.
    """
    since = time.time()
    # Lists for saving train and validation metrics for each epoch
    tr_loss_history = []
    tr_acc_history = []
    tr_f1_history = []
    val_loss_history = []
    val_acc_history = []
    val_f1_history = []
    # Lists for saving learning rates for the 2 parameter groups
    lr1_history = []
    lr2_history = []

    # Best F1 value and best epoch are saved in variables
    best_f1 = 0
    best_epoch = 0
    early_stop = False

    # Train / validation loop
    for epoch in tqdm(range(args.num_epochs)):
        # Save learning rates for the epoch
        lr1_history.append(optimizer.param_groups[0]["lr"])
        lr2_history.append(optimizer.param_groups[1]["lr"])

        print('Epoch {}/{}'.format(epoch+1, args.num_epochs))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train() # Set model to training mode
            else:
                model.eval() # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            running_f1 = 0.0

            # Iterate over data in batch
            for inputs, labels in dataloaders[phase]:
                # NOTE(review): this None check can never trigger — if
                # dataloaders[phase] were None, the for-statement above would
                # already have raised. Looks like leftover defensive code.
                if dataloaders[phase] is None:
                    continue
                else:
                    inputs = inputs.to(args.device)
                    labels = labels.long().to(args.device)

                    # Zero the parameter gradients
                    optimizer.zero_grad()

                    # Track history only in training phase
                    with torch.set_grad_enabled(phase == 'train'):
                        # Get model outputs and calculate loss
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        # Model predictions of the image labels for the batch
                        _, preds = torch.max(outputs, 1)

                        # Backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # Get weighted F1 score for the results
                    precision_recall_fscore = precision_recall_fscore_support(labels.data.detach().cpu().numpy(), preds.detach().cpu().numpy(), average='weighted', zero_division=0)
                    f1_score = precision_recall_fscore[2]

                    # update statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data).cpu()
                    running_f1 += f1_score

            # Calculate loss, accuracy and F1 score for the epoch.
            # Loss/accuracy are averaged over samples; F1 is averaged over
            # batches (hence division by len(dataloaders[phase])).
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            epoch_f1 = running_f1 / len(dataloaders[phase])

            print('\nEpoch {} - {} - Loss: {:.4f} Acc: {:.4f} F1: {:.4f}\n'.format(epoch+1, phase, epoch_loss, epoch_acc, epoch_f1))

            # Validation step
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                val_f1_history.append(epoch_f1)
                if epoch_f1 > best_f1:
                    print('\nF1 score {:.4f} improved from {:.4f}. Saving the model.\n'.format(epoch_f1, best_f1))
                    # Model with best F1 score is saved
                    utils.save_model(model, 224, args.save_model_format, args.save_model_path, args.model_name, args.date)
                    # NOTE(review): the model is re-sent to the device here —
                    # presumably utils.save_model moves it off-device (e.g.
                    # for ONNX export); confirm against utils.save_model.
                    model = model.to(args.device)
                    best_f1 = epoch_f1
                    best_epoch = epoch
                elif epoch - best_epoch > args.early_stop_threshold:
                    # terminates the training loop if validation accuracy has not improved
                    print("Early stopped training at epoch %d" % epoch)
                    # Set early stopping condition
                    early_stop = True
                    break
            elif phase == 'train':
                tr_acc_history.append(epoch_acc)
                tr_loss_history.append(epoch_loss)
                tr_f1_history.append(epoch_f1)

        # Break outer loop if early stopping condition is activated
        if early_stop:
            break
        # Take scheduler step (driven by the latest validation F1)
        if scheduler:
            scheduler.step(val_f1_history[-1])

    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation F1 score: {:.4f}'.format(best_f1))
    # Returns model with the weights from the best epoch (based on validation accuracy)
    hist_dict = {'tr_acc': tr_acc_history,
                 'val_acc': val_acc_history,
                 'val_loss': val_loss_history,
                 'val_f1': val_f1_history,
                 'tr_loss': tr_loss_history,
                 'tr_f1': tr_f1_history,
                 'lr1': lr1_history,
                 'lr2': lr2_history}

    return hist_dict
|
| 310 |
+
|
| 311 |
+
def main():
    """Entry point: set up data, model, loss and optimizer, then train."""
    # Set random seed(s) for reproducibility
    utils.set_seed(args.random_seed)
    # Load image paths and labels
    data_dict = get_datapaths()
    # Initialize the model and send it to the configured device
    model, input_size = initialize_model()
    model = model.to(args.device)

    print("\nInitializing Datasets and Dataloaders...")
    dataloaders_dict = initialize_dataloaders(data_dict, input_size)
    criterion = get_criterion(data_dict)
    optimizer, scheduler = get_optimizer(model)

    # Train and evaluate model, then report unreadable images and plot metrics
    hist_dict = train_model(model, dataloaders_dict, criterion, optimizer, scheduler)
    print('Damaged images: ', damaged_images)
    utils.plot_metrics(hist_dict, args.results_folder, args.date)

if __name__ == '__main__':
    main()
|
utils.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import onnx
|
| 3 |
+
import onnxruntime
|
| 4 |
+
import os
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
def set_seed(random_seed):
    """Seed every relevant RNG (numpy, random, torch CPU/CUDA) for reproducibility.

    Args:
        random_seed (int): Seed value applied to all libraries.
    """
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Seed ALL visible GPUs, not just the current device (multi-GPU safe;
    # identical behaviour to manual_seed on single-GPU/CPU-only setups).
    torch.cuda.manual_seed_all(random_seed)
    # When running on the CuDNN backend, two further options must be set
    # to make convolution algorithms deterministic (at some speed cost).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed.
    # NOTE(review): this only affects child processes — the current
    # interpreter's hash randomization was fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    print(f"Random seed set as {random_seed}")
| 21 |
+
|
| 22 |
+
def save_model(model, input_size, save_model_format, save_model_path, model_name, date):
    """Save the model in .onnx or .pth format.

    ONNX export code modified from
    https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

    Args:
        model: Trained torch.nn.Module to serialize.
        input_size (int): Height/width of the square input the model expects.
        save_model_format (str): 'onnx' for ONNX export; anything else saves .pth.
        save_model_path (str): Directory the model file is written into.
        model_name (str): Base name used in the output filename.
        date (str): Date string appended to the filename.
    """
    if save_model_format == 'onnx':
        onnx_model_path = os.path.join(save_model_path, model_name + '_' + date + '.onnx')
        # Export with a fixed batch size of 1; the batch axis is marked
        # dynamic below so any batch size works at inference time.
        batch_size = 1
        # Random input with the correct dimensions, only used for tracing.
        x = torch.randn(batch_size, 3, input_size, input_size, requires_grad=True)
        model = model.to('cpu')
        torch_out = model(x)

        # Export the model
        torch.onnx.export(model,                     # model being run
                          x,                         # model input (or a tuple for multiple inputs)
                          onnx_model_path,           # where to save the model (can be a file or file-like object)
                          export_params=True,        # store the trained parameter weights inside the model file
                          opset_version=10,          # the ONNX version to export the model to
                          do_constant_folding=True,  # whether to execute constant folding for optimization
                          input_names=['input'],     # the model's input names
                          output_names=['output'],   # the model's output names
                          dynamic_axes={'input': {0: 'batch_size'},   # variable length axes
                                        'output': {0: 'batch_size'}})

        print('ONNX model saved to ', onnx_model_path)
        # Validate the exported graph before trusting it.
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        print('ONNX model checked.')

        def to_numpy(tensor):
            return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

        onnx_session = onnxruntime.InferenceSession(onnx_model_path)
        # compute ONNX Runtime output prediction
        onnx_inputs = {onnx_session.get_inputs()[0].name: to_numpy(x)}
        onnx_out = onnx_session.run(None, onnx_inputs)
        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(to_numpy(torch_out), onnx_out[0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!\n")

    else:
        # Fix: the filename previously hard-coded 'densenet_' and ignored
        # model_name, inconsistent with the ONNX branch above.
        pytorch_model_path = os.path.join(save_model_path, model_name + '_' + date + '.pth')
        # NOTE: saves the full pickled module, not just the state_dict, so
        # loading requires the original class definition on the import path.
        torch.save(model, pytorch_model_path)
        print('Pytorch model saved to ', pytorch_model_path)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def plot_metrics(hist_dict, results_folder, date):
    """Plot the training and validation results and save them as .jpg files.

    Produces four figures: loss, accuracy, F1 score and learning rates.
    Output paths are built as results_folder + date + suffix, so
    results_folder is assumed to end with a path separator — TODO confirm
    against callers.

    Args:
        hist_dict (dict): Per-epoch history lists under keys 'tr_loss',
            'val_loss', 'tr_acc', 'val_acc', 'tr_f1', 'val_f1', 'lr1', 'lr2'.
        results_folder (str): Folder prefix for the output images.
        date (str): Date string embedded in each output filename.
    """
    epochs = range(1, len(hist_dict['tr_loss']) + 1)

    def _save_plot(series, title, ylabel, suffix):
        # series: list of (values, colour, label) curves drawn on one figure.
        for values, colour, label in series:
            plt.plot(epochs, values, colour, label=label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()
        # Filename convention kept identical to the original implementation.
        plt.savefig(results_folder + date + suffix, bbox_inches='tight')
        plt.close()

    _save_plot([(hist_dict['tr_loss'], 'g', 'Training loss'),
                (hist_dict['val_loss'], 'b', 'Validation loss')],
               'Training and Validation loss', 'Loss', '_tr_val_loss.jpg')

    _save_plot([(hist_dict['tr_acc'], 'g', 'Training accuracy'),
                (hist_dict['val_acc'], 'b', 'Validation accuracy')],
               'Training and Validation accuracy', 'Accuracy', '_tr_val_acc.jpg')

    _save_plot([(hist_dict['tr_f1'], 'g', 'Training F1 score'),
                (hist_dict['val_f1'], 'b', 'Validation F1 score')],
               'Training and Validation F1 score', 'F1 score', '_tr_val_f1.jpg')

    _save_plot([(hist_dict['lr1'], 'g', 'Backbone learning rate'),
                (hist_dict['lr2'], 'b', 'Classifier learning rate')],
               'Learning rate', 'Learning rate', '_learning_rate.jpg')