Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- app.py +7 -0
- deep_cnn_project.py +439 -0
app.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
def greet(name):
    """Return the greeting string "Hello <name>!!" for the given name."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
|
| 5 |
+
|
| 6 |
+
# Expose greet() through a minimal text-in / text-out Gradio interface.
demo = gr.Interface(fn=greet, outputs="text", inputs="text")
demo.launch()
|
deep_cnn_project.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""Deep CNN Project
|
| 3 |
+
|
| 4 |
+
Automatically generated by Colab.
|
| 5 |
+
|
| 6 |
+
Original file is located at
|
| 7 |
+
https://colab.research.google.com/drive/1VvAfokQ6mPAsBqBajbRZU13412Wg0A_m
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
# Attach Google Drive so the dataset stored under /content/drive is reachable.
from google.colab import drive

drive.mount('/content/drive')
|
| 12 |
+
|
| 13 |
+
"""<a id="import"></a>
|
| 14 |
+
# <center>We have imported important Modules here </center>
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import os
|
| 20 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
| 21 |
+
import time
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import cv2
|
| 24 |
+
import seaborn as sns
|
| 25 |
+
sns.set_style('darkgrid')
|
| 26 |
+
import shutil
|
| 27 |
+
from sklearn.metrics import confusion_matrix, classification_report
|
| 28 |
+
from sklearn.model_selection import train_test_split
|
| 29 |
+
import tensorflow as tf
|
| 30 |
+
from tensorflow import keras
|
| 31 |
+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
| 32 |
+
from tensorflow.keras.layers import Dense, Activation,Dropout,Conv2D, MaxPooling2D,BatchNormalization
|
| 33 |
+
from tensorflow.keras.optimizers import Adam, Adamax
|
| 34 |
+
from tensorflow.keras.metrics import categorical_crossentropy
|
| 35 |
+
from tensorflow.keras import regularizers
|
| 36 |
+
from tensorflow.keras.models import Model
|
| 37 |
+
import time
|
| 38 |
+
|
| 39 |
+
"""<a id="makedf"></a>
|
| 40 |
+
# <center>Read in images and create a dataframe of image paths and class labels</center>
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
# Walk the class-per-folder image tree and build a dataframe of
# (filepath, label) rows, then split it 75 / 12.5 / 12.5 into train,
# validation and test sets (stratified so class ratios are preserved).
sdir = r'/content/drive/MyDrive/archive (4)/Original Images/Original Images'
filepaths = []
labels = []
classlist = os.listdir(sdir)
for klass in classlist:
    classpath = os.path.join(sdir, klass)
    for f in os.listdir(classpath):
        filepaths.append(os.path.join(classpath, f))
        labels.append(klass)
Fseries = pd.Series(filepaths, name='filepaths')
Lseries = pd.Series(labels, name='labels')
df = pd.concat([Fseries, Lseries], axis=1)
train_df, dummy_df = train_test_split(df, train_size=.75, shuffle=True, random_state=123, stratify=df['labels'])
valid_df, test_df = train_test_split(dummy_df, train_size=.5, shuffle=True, random_state=123, stratify=dummy_df['labels'])
# BUG FIX: corrected the "lenght" typo in the report below.
print('train_df length: ', len(train_df), ' test_df length: ', len(test_df), ' valid_df length: ', len(valid_df))

# Number of classes and per-class image counts in train_df.
classes = sorted(list(train_df['labels'].unique()))
class_count = len(classes)
print('The number of classes in the dataset is: ', class_count)
groups = train_df.groupby('labels')
print('{0:^30s} {1:^13s}'.format('CLASS', 'IMAGE COUNT'))
countlist = []
classlist = []
for label in classes:  # reuse the sorted class list computed above instead of re-sorting
    group = groups.get_group(label)
    countlist.append(len(group))
    classlist.append(label)
    print('{0:^30s} {1:^13s}'.format(label, str(len(group))))

# Classes with the most / fewest training images.
max_value = np.max(countlist)
max_index = countlist.index(max_value)
max_class = classlist[max_index]
min_value = np.min(countlist)
min_index = countlist.index(min_value)
min_class = classlist[min_index]
print(max_class, ' has the most images= ', max_value, ' ', min_class, ' has the least images= ', min_value)

# Average height/width over a 100-image random sample of train_df.
ht = 0
wt = 0
train_df_sample = train_df.sample(n=100, random_state=123, axis=0)
for i in range(len(train_df_sample)):
    fpath = train_df_sample['filepaths'].iloc[i]
    img = plt.imread(fpath)
    shape = img.shape
    ht += shape[0]
    wt += shape[1]
print('average height= ', ht // 100, ' average width= ', wt // 100, 'aspect ratio= ', ht / wt)
|
| 94 |
+
|
| 95 |
+
"""<a id="balance"></a>
|
| 96 |
+
# <center>Balance train_df by creating augmented images</center>
|
| 97 |
+
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def balance(df, n, working_dir, img_size):
    """Top up every under-represented class with augmented images.

    For each class in *df* with fewer than *n* samples, augmented copies are
    generated (random flip / rotation / shift / zoom) and written under
    ``working_dir/aug/<class>``.  A new dataframe containing the original rows
    plus the augmented (filepath, label) rows is returned.

    Args:
        df: dataframe with 'filepaths' and 'labels' columns.
        n: target number of samples per class.
        working_dir: directory in which the 'aug' image tree is created.
        img_size: (height, width) of the generated images.

    Returns:
        A new dataframe; *df* itself is not modified.
    """
    df = df.copy()
    print('Initial length of dataframe is ', len(df))
    aug_dir = os.path.join(working_dir, 'aug')  # directory to store augmented images
    if os.path.isdir(aug_dir):  # start with an empty directory
        shutil.rmtree(aug_dir)
    # ROBUSTNESS: makedirs also creates missing parents (os.mkdir would fail
    # if working_dir itself does not exist yet).
    os.makedirs(aug_dir)
    for label in df['labels'].unique():
        os.makedirs(os.path.join(aug_dir, label))  # class directories within aug directory
    # create and store the augmented images
    total = 0
    gen = ImageDataGenerator(horizontal_flip=True, rotation_range=20, width_shift_range=.2,
                             height_shift_range=.2, zoom_range=.2)
    groups = df.groupby('labels')  # group by class
    for label in df['labels'].unique():
        group = groups.get_group(label)  # rows of this class only
        sample_count = len(group)
        if sample_count < n:  # class is below the target size
            aug_img_count = 0
            delta = n - sample_count  # number of augmented images to create
            target_dir = os.path.join(aug_dir, label)  # where to write the images
            msg = '{0:40s} for class {1:^30s} creating {2:^5s} augmented images'.format(' ', label, str(delta))
            print(msg, '\r', end='')  # prints over on the same line
            aug_gen = gen.flow_from_dataframe(group, x_col='filepaths', y_col=None, target_size=img_size,
                                              class_mode=None, batch_size=1, shuffle=False,
                                              save_to_dir=target_dir, save_prefix='aug-', color_mode='rgb',
                                              save_format='jpg')
            while aug_img_count < delta:
                images = next(aug_gen)
                aug_img_count += len(images)
            # BUG FIX: accumulate inside this branch only — at the loop level
            # a class that needed no augmentation would read a stale (or
            # undefined) aug_img_count from a previous iteration.
            total += aug_img_count
    print('Total Augmented images created= ', total)
    # create aug_df and merge with df to form the composite training set
    aug_fpaths = []
    aug_labels = []
    for klass in os.listdir(aug_dir):
        classpath = os.path.join(aug_dir, klass)
        for f in os.listdir(classpath):
            aug_fpaths.append(os.path.join(classpath, f))
            aug_labels.append(klass)
    Fseries = pd.Series(aug_fpaths, name='filepaths')
    Lseries = pd.Series(aug_labels, name='labels')
    aug_df = pd.concat([Fseries, Lseries], axis=1)
    df = pd.concat([df, aug_df], axis=0).reset_index(drop=True)
    print('Length of augmented dataframe is now ', len(df))
    return df
|
| 150 |
+
|
| 151 |
+
# Augmentation parameters: bring every class up to 200 training images.
n = 200                # target samples per class
working_dir = r'./'    # where the aug/ image tree is written
img_size = (224, 224)  # size of the augmented images
train_df = balance(train_df, n, working_dir, img_size)
|
| 155 |
+
|
| 156 |
+
"""<a id="generators"></a>
|
| 157 |
+
# <center>Create the train_gen, test_gen final_test_gen and valid_gen</center>
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
# Build the three data pipelines: training images get light augmentation,
# validation/test images are fed through unchanged.
batch_size = 20  # with (224, 224) EfficientNetB3 inputs this should not cause resource errors
trgen = ImageDataGenerator(horizontal_flip=True, rotation_range=20, width_shift_range=.2,
                           height_shift_range=.2, zoom_range=.2)
t_and_v_gen = ImageDataGenerator()
print('{0:70s} for train generator'.format(' '), '\r', end='')  # prints over on the same line
train_gen = trgen.flow_from_dataframe(train_df, x_col='filepaths', y_col='labels', target_size=img_size,
                                      class_mode='categorical', color_mode='rgb', shuffle=True,
                                      batch_size=batch_size)
print('{0:70s} for valid generator'.format(' '), '\r', end='')
valid_gen = t_and_v_gen.flow_from_dataframe(valid_df, x_col='filepaths', y_col='labels', target_size=img_size,
                                            class_mode='categorical', color_mode='rgb', shuffle=False,
                                            batch_size=batch_size)
# For the test generator pick the largest batch size <= 80 that divides the
# test set evenly, so batch_size * test_steps covers every sample exactly once.
length = len(test_df)
test_batch_size = sorted([int(length / n) for n in range(1, length + 1)
                          if length % n == 0 and length / n <= 80], reverse=True)[0]
test_steps = int(length / test_batch_size)
print('{0:70s} for test generator'.format(' '), '\r', end='')
test_gen = t_and_v_gen.flow_from_dataframe(test_df, x_col='filepaths', y_col='labels', target_size=img_size,
                                           class_mode='categorical', color_mode='rgb', shuffle=False,
                                           batch_size=test_batch_size)
# Bookkeeping pulled from the generators for later use.
classes = list(train_gen.class_indices.keys())
class_indices = list(train_gen.class_indices.values())
class_count = len(classes)
labels = test_gen.labels
print('test batch size: ', test_batch_size, ' test steps: ', test_steps, ' number of classes : ', class_count)
|
| 187 |
+
|
| 188 |
+
"""<a id="show"></a>
|
| 189 |
+
# <center>Create a function to show example training images</center>
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def show_image_samples(gen):
    """Display up to 25 images from the next batch of *gen*, titled by class."""
    class_names = list(gen.class_indices.keys())
    images, labels = next(gen)  # one sample batch from the generator
    plt.figure(figsize=(20, 20))
    count = min(len(labels), 25)  # at most a 5x5 grid
    for i in range(count):
        plt.subplot(5, 5, i + 1)
        plt.imshow(images[i] / 255)  # generator yields raw [0, 255] pixel values
        plt.title(class_names[np.argmax(labels[i])], color='blue', fontsize=14)
        plt.axis('off')
    plt.show()


show_image_samples(train_gen)
|
| 213 |
+
|
| 214 |
+
"""<a id="model"></a>
|
| 215 |
+
# <center>Create a model using transfer learning with EfficientNetB3</center>
|
| 216 |
+
### NOTE experts advise you make the base model initially not trainable. Then train for some number of epochs
|
| 217 |
+
### then fine tune model by making base model trainable and run more epochs
|
| 218 |
+
### I have found this to be WRONG!!!!
|
| 219 |
+
### Making the base model trainable from the outset leads to faster convergence and a lower validation loss
|
| 220 |
+
### for the same number of total epochs!
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
# Transfer learning: EfficientNetB3 backbone (max-pooled features, no top),
# followed by a heavily regularized 256-unit head and a softmax output.
img_shape = (img_size[0], img_size[1], 3)
model_name = 'EfficientNetB3'
base_model = tf.keras.applications.efficientnet.EfficientNetB3(
    include_top=False, weights="imagenet", input_shape=img_shape, pooling='max')
# Contrary to the usual advice, the base model is left trainable from the
# outset — the author reports faster convergence and lower validation loss.
base_model.trainable = True
x = base_model.output
x = BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(x)
x = Dense(256, kernel_regularizer=regularizers.l2(l=0.016),
          activity_regularizer=regularizers.l1(0.006),
          bias_regularizer=regularizers.l1(0.006), activation='relu')(x)
x = Dropout(rate=.4, seed=123)(x)
output = Dense(class_count, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
lr = .001  # initial learning rate
model.compile(Adamax(learning_rate=lr), loss='categorical_crossentropy', metrics=['accuracy'])
|
| 237 |
+
|
| 238 |
+
"""<a id="callback"></a>
|
| 239 |
+
# <center>Create a custom Keras callback to continue and optionally set LR or halt training</center>
|
| 240 |
+
The LR_ASK callback is a convenient callback that allows you to continue training for ask_epoch more epochs or to halt training.
|
| 241 |
+
If you elect to continue training for more epochs you are given the option to retain the current learning rate (LR) or to
|
| 242 |
+
enter a new value for the learning rate. The form of use is:
|
| 243 |
+
ask=LR_ASK(model,epochs, ask_epoch) where:
|
| 244 |
+
* model is a string which is the name of your compiled model
|
| 245 |
+
* epochs is an integer which is the number of epochs to run specified in model.fit
|
| 246 |
+
* ask_epoch is an integer. If ask_epoch is set to a value say 5 then the model will train for 5 epochs.
|
| 247 |
+
then the user is asked to enter H to halt training, or enter an integer value. For example if you enter 4
|
| 248 |
+
training will continue for 4 more epochs to epoch 9 then you will be queried again. Once you enter an
|
| 249 |
+
integer value you are prompted to press ENTER to continue training using the current learning rate
|
| 250 |
+
or to enter a new value for the learning rate.
|
| 251 |
+
|
| 252 |
+
At the end of training the model weights are set to the weights for the epoch that achieved the lowest validation loss
|
| 253 |
+
"""
|
| 254 |
+
|
| 255 |
+
class LR_ASK(keras.callbacks.Callback):
    """Interactive training callback.

    At epoch ``ask_epoch`` the user is prompted either to halt training ('H')
    or to continue for a given number of additional epochs, optionally
    entering a new learning rate.  The weights from the epoch with the lowest
    validation loss are tracked and restored when training ends.
    """

    def __init__(self, model, epochs, ask_epoch):
        super(LR_ASK, self).__init__()
        self.model = model
        self.ask_epoch = ask_epoch  # next epoch at which to query the user
        self.epochs = epochs        # maximum epochs passed to model.fit
        self.ask = True             # if True query the user on the specified epoch
        self.lowest_vloss = np.inf  # lowest validation loss seen so far
        self.best_weights = self.model.get_weights()  # start from the initial weights
        self.best_epoch = 1

    def on_train_begin(self, logs=None):
        """Validate ask_epoch and announce the interaction schedule."""
        if self.ask_epoch == 0:
            print('you set ask_epoch = 0, ask_epoch will be set to 1', flush=True)
            self.ask_epoch = 1
        if self.ask_epoch >= self.epochs:
            # BUG FIX: the original printed the module globals `epochs` /
            # `ask_epoch` here instead of the instance attributes, raising
            # NameError (or printing wrong values) outside the notebook.
            print('ask_epoch >= epochs, will train for ', self.epochs, ' epochs', flush=True)
            self.ask = False  # do not query the user
        if self.epochs == 1:
            self.ask = False  # running only for 1 epoch so do not query user
        else:
            print('Training will proceed until epoch', self.ask_epoch, ' then you will be asked to')
            print(' enter H to halt training or enter an integer for how many more epochs to run then be asked again')
        self.start_time = time.time()  # time at which training started

    def on_train_end(self, logs=None):
        """Restore the best weights and report the elapsed training time."""
        print('loading model with weights from epoch ', self.best_epoch)
        self.model.set_weights(self.best_weights)
        tr_duration = time.time() - self.start_time
        hours = int(tr_duration // 3600)  # int() so the message reads "2 hours", not "2.0 hours"
        minutes = (tr_duration - hours * 3600) // 60
        seconds = tr_duration - (hours * 3600 + minutes * 60)
        msg = f'training elapsed time was {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds)'
        print(msg, flush=True)

    def on_epoch_end(self, epoch, logs=None):
        """Track the best validation loss and query the user when scheduled."""
        v_loss = logs.get('val_loss')  # validation loss for this epoch
        if v_loss < self.lowest_vloss:
            self.lowest_vloss = v_loss
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch + 1
            print(f'\n validation loss of {v_loss:7.4f} is below lowest loss, saving weights from epoch {str(epoch + 1):3s} as best weights')
        else:
            print(f'\n validation loss of {v_loss:7.4f} is above lowest loss of {self.lowest_vloss:7.4f} keeping weights from epoch {str(self.best_epoch)} as best weights')

        if self.ask:  # are the conditions right to query the user?
            if epoch + 1 == self.ask_epoch:  # is this the querying epoch?
                print('\n Enter H to end training or an integer for the number of additional epochs to run then ask again')
                ans = input()

                if ans == 'H' or ans == 'h' or ans == '0':  # halt training
                    print('you entered ', ans, ' Training halted on epoch ', epoch + 1, ' due to user input\n', flush=True)
                    self.model.stop_training = True
                else:  # continue for int(ans) more epochs
                    self.ask_epoch += int(ans)
                    if self.ask_epoch > self.epochs:
                        print('\nYou specified maximum epochs of as ', self.epochs, ' cannot train for ', self.ask_epoch, flush=True)
                    else:
                        print('you entered ', ans, ' Training will continue to epoch ', self.ask_epoch, flush=True)
                    # let the user keep or change the learning rate
                    lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
                    print(f'current LR is {lr:7.5f} hit enter to keep this LR or enter a new LR')
                    ans = input(' ')
                    if ans == '':
                        print(f'keeping current LR of {lr:7.5f}')
                    else:
                        new_lr = float(ans)
                        tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                        print(' changing LR to ', ans)
|
| 324 |
+
|
| 325 |
+
"""<a id="callbacks"></a>
|
| 326 |
+
# <center>Instantiate custom callback
|
| 327 |
+
"""
|
| 328 |
+
|
| 329 |
+
# Train for up to 40 epochs, prompting the user after the first 5.
epochs = 40
ask_epoch = 5
ask = LR_ASK(model, epochs, ask_epoch)
# rlronp=tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2,verbose=1)
# callbacks=[rlronp, ask]
callbacks = [ask]
|
| 335 |
+
|
| 336 |
+
"""<a id="train"></a>
|
| 337 |
+
# <center>Train the model
|
| 338 |
+
### Note unlike how you are told it is BETTER to make the base model trainable from the outset
|
| 339 |
+
### It will converge faster and have a lower validation loss
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
# Run training; the LR_ASK callback handles halting and LR adjustment.
history = model.fit(x=train_gen, epochs=epochs, verbose=1, callbacks=callbacks,
                    validation_data=valid_gen, validation_steps=None,
                    shuffle=False, initial_epoch=0)
|
| 344 |
+
|
| 345 |
+
"""<a id="plot"></a>
|
| 346 |
+
# <center>Define a function to plot the training data
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
def tr_plot(tr_data, start_epoch):
    """Plot the training/validation loss and accuracy curves.

    Args:
        tr_data: the keras History object returned by model.fit.
        start_epoch: epoch number the history starts at (for resumed runs).
    """
    tacc = tr_data.history['accuracy']
    tloss = tr_data.history['loss']
    vacc = tr_data.history['val_accuracy']
    vloss = tr_data.history['val_loss']
    Epoch_count = len(tacc) + start_epoch
    Epochs = [i + 1 for i in range(start_epoch, Epoch_count)]
    index_loss = np.argmin(vloss)  # epoch with the lowest validation loss
    val_lowest = vloss[index_loss]
    index_acc = np.argmax(vacc)    # epoch with the highest validation accuracy
    acc_highest = vacc[index_acc]
    plt.style.use('fivethirtyeight')
    sc_label = 'best epoch= ' + str(index_loss + 1 + start_epoch)
    vc_label = 'best epoch= ' + str(index_acc + 1 + start_epoch)
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 8))
    axes[0].plot(Epochs, tloss, 'r', label='Training loss')
    axes[0].plot(Epochs, vloss, 'g', label='Validation loss')
    axes[0].scatter(index_loss + 1 + start_epoch, val_lowest, s=150, c='blue', label=sc_label)
    axes[0].set_title('Training and Validation Loss')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[1].plot(Epochs, tacc, 'r', label='Training Accuracy')
    axes[1].plot(Epochs, vacc, 'g', label='Validation Accuracy')
    axes[1].scatter(index_acc + 1 + start_epoch, acc_highest, s=150, c='blue', label=vc_label)
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Accuracy')
    axes[1].legend()
    # BUG FIX: the original wrote `plt.tight_layout` without parentheses —
    # a no-op attribute access; call it so the layout is actually applied.
    plt.tight_layout()
    plt.show()


tr_plot(history, 0)
|
| 385 |
+
|
| 386 |
+
"""<a id="result"></a>
|
| 387 |
+
# <center>Make Predictions on the test set</a>
|
| 388 |
+
### Define a function which takes in a test generator and an integer test_steps
|
| 389 |
+
### and generates predictions on the test set including a confusion matrix
|
| 390 |
+
### and a classification report
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
def predictor(test_gen, test_steps):
    """Predict on the test generator and report accuracy, a confusion matrix
    (when there are at most 30 classes) and a classification report.

    Returns (errors, tests): misclassification count and sample count.
    """
    y_pred = []
    y_true = test_gen.labels  # integer class labels, in generator order
    classes = list(test_gen.class_indices.keys())
    class_count = len(classes)
    errors = 0
    preds = model.predict(test_gen, verbose=1)
    tests = len(preds)
    for i, p in enumerate(preds):
        pred_index = np.argmax(p)
        if pred_index != test_gen.labels[i]:  # a misclassification occurred
            errors += 1
        y_pred.append(pred_index)

    acc = (1 - errors / tests) * 100
    print(f'there were {errors} errors in {tests} tests for an accuracy of {acc:6.2f}')
    ypred = np.array(y_pred)
    ytrue = np.array(y_true)
    if class_count <= 30:
        # plot the confusion matrix
        cm = confusion_matrix(ytrue, ypred)
        plt.figure(figsize=(12, 8))
        sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Reds', cbar=False)
        plt.xticks(np.arange(class_count) + .5, classes, rotation=90)
        plt.yticks(np.arange(class_count) + .5, classes, rotation=0)
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Confusion Matrix")
        plt.show()
    clr = classification_report(y_true, y_pred, target_names=classes, digits=4)
    print("Classification Report:\n----------------------\n", clr)
    return errors, tests


errors, tests = predictor(test_gen, test_steps)
|
| 427 |
+
|
| 428 |
+
"""<a id="save"></a>
|
| 429 |
+
# <center>Save the model
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
# Save the trained model with the test accuracy embedded in the file name,
# e.g. "monkey pox_93.33.h5".
subject = 'monkey pox'
# IDIOM: format to 2 decimals directly instead of the original rfind-based
# string surgery (note: this rounds, the original truncated).
acc = f'{(1 - errors / tests) * 100:.2f}'
save_id = subject + '_' + acc + '.h5'
model_save_loc = os.path.join(working_dir, save_id)
model.save(model_save_loc)
print('model was saved as ', model_save_loc)
|