Delete main.py
Browse files
main.py
DELETED
|
@@ -1,560 +0,0 @@
|
|
| 1 |
-
###### Train CIFAR10 with PyTorch. ######
|
| 2 |
-
|
| 3 |
-
### IMPORT DEPENDENCIES
|
| 4 |
-
|
| 5 |
-
from torch.utils.data import DataLoader
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.optim as optim
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
import torch.backends.cudnn as cudnn
|
| 11 |
-
import gradio as gr
|
| 12 |
-
import wandb
|
| 13 |
-
import math
|
| 14 |
-
import numpy as np
|
| 15 |
-
import matplotlib.pyplot as plt
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
import torchvision
|
| 19 |
-
import torchvision.transforms as transforms
|
| 20 |
-
import torchvision.models as models
|
| 21 |
-
import torch.optim.lr_scheduler as lr_scheduler
|
| 22 |
-
import os
|
| 23 |
-
import argparse
|
| 24 |
-
import torchattacks
|
| 25 |
-
|
| 26 |
-
from models import *
|
| 27 |
-
|
| 28 |
-
from tqdm import tqdm
|
| 29 |
-
from PIL import Image
|
| 30 |
-
import gradio as gr
|
| 31 |
-
|
| 32 |
-
# from utils import progress_bar
|
| 33 |
-
|
| 34 |
-
# CSS theme styling
|
| 35 |
-
theme = gr.themes.Base(
|
| 36 |
-
font=[gr.themes.GoogleFont('Montserrat'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
| 37 |
-
primary_hue="emerald",
|
| 38 |
-
secondary_hue="emerald",
|
| 39 |
-
neutral_hue="zinc"
|
| 40 |
-
).set(
|
| 41 |
-
body_text_color='*neutral_950',
|
| 42 |
-
body_text_color_subdued='*neutral_950',
|
| 43 |
-
block_shadow='*shadow_drop_lg',
|
| 44 |
-
button_shadow='*shadow_drop_lg',
|
| 45 |
-
block_title_text_color='*neutral_950',
|
| 46 |
-
block_title_text_weight='500',
|
| 47 |
-
slider_color='*secondary_600'
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
def normalize(img):
|
| 51 |
-
min_im = np.min(img)
|
| 52 |
-
np_img = img - min_im
|
| 53 |
-
max_im = np.max(np_img)
|
| 54 |
-
np_img /= max_im
|
| 55 |
-
return np_img
|
| 56 |
-
|
| 57 |
-
def imshow(img, fig_name = "test_input.png"):
|
| 58 |
-
try:
|
| 59 |
-
img = img.clone().detach().cpu().numpy()
|
| 60 |
-
except:
|
| 61 |
-
print('img already numpy')
|
| 62 |
-
|
| 63 |
-
plt.imshow(normalize(np.transpose(img, (1, 2, 0))))
|
| 64 |
-
plt.savefig(fig_name)
|
| 65 |
-
print(f'Figure saved as {fig_name}')
|
| 66 |
-
return fig_name
|
| 67 |
-
|
| 68 |
-
def class_names(class_num, class_list): # converts the raw number label to text
|
| 69 |
-
if (class_num < 0) and (class_num >= 10):
|
| 70 |
-
gr.Warning("Class List Error")
|
| 71 |
-
return
|
| 72 |
-
return class_list[class_num]
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
### MAIN FUNCTION
|
| 76 |
-
best_acc = 0
|
| 77 |
-
def main(drop_type, epochs_sldr, train_sldr, test_sldr, learning_rate, optimizer, sigma_sldr, adv_attack, username, scheduler):
|
| 78 |
-
|
| 79 |
-
## Input protection
|
| 80 |
-
if not drop_type:
|
| 81 |
-
gr.Warning("Please select a model from the dropdown.")
|
| 82 |
-
return
|
| 83 |
-
if not username:
|
| 84 |
-
gr.Warning("Please enter a WandB username.")
|
| 85 |
-
return
|
| 86 |
-
if(epochs_sldr % 1 != 0):
|
| 87 |
-
gr.Warning("Number of epochs must be an integer.")
|
| 88 |
-
return
|
| 89 |
-
if(train_sldr % 1 != 0):
|
| 90 |
-
gr.Warning("Training batch size must be an integer.")
|
| 91 |
-
return
|
| 92 |
-
if(test_sldr % 1 != 0):
|
| 93 |
-
gr.Warning("Testing batch size must be an integer.")
|
| 94 |
-
return
|
| 95 |
-
|
| 96 |
-
num_epochs = int(epochs_sldr)
|
| 97 |
-
global learn_batch
|
| 98 |
-
learn_batch = int(train_sldr)
|
| 99 |
-
global test_batch
|
| 100 |
-
test_batch = int(test_sldr)
|
| 101 |
-
learning_rate = float(learning_rate)
|
| 102 |
-
optimizer_choose = str(optimizer)
|
| 103 |
-
sigma = float(sigma_sldr)
|
| 104 |
-
attack = str(adv_attack)
|
| 105 |
-
scheduler_choose = str(scheduler)
|
| 106 |
-
|
| 107 |
-
# REPLACE ENTITY WITH USERNAME BELOW
|
| 108 |
-
wandb.init(entity=username, project="model-training")
|
| 109 |
-
|
| 110 |
-
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
|
| 111 |
-
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
|
| 112 |
-
parser.add_argument('--resume', '-r', action='store_true',
|
| 113 |
-
help='resume from checkpoint')
|
| 114 |
-
args = parser.parse_args()
|
| 115 |
-
|
| 116 |
-
if torch.cuda.is_available():
|
| 117 |
-
device = 'cuda'
|
| 118 |
-
gr.Info("Cuda detected - running on Cuda")
|
| 119 |
-
elif torch.backends.mps.is_available():
|
| 120 |
-
device = 'mps'
|
| 121 |
-
gr.Info("MPS detected - running on Metal")
|
| 122 |
-
else:
|
| 123 |
-
device = 'cpu'
|
| 124 |
-
gr.Info("No GPU Detected - running on CPU")
|
| 125 |
-
|
| 126 |
-
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
| 127 |
-
|
| 128 |
-
## Data
|
| 129 |
-
try:
|
| 130 |
-
print('==> Preparing data..')
|
| 131 |
-
transform_train = transforms.Compose([
|
| 132 |
-
transforms.RandomCrop(32, padding=4),
|
| 133 |
-
transforms.RandomHorizontalFlip(),
|
| 134 |
-
transforms.ToTensor(),
|
| 135 |
-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
| 136 |
-
])
|
| 137 |
-
|
| 138 |
-
transform_test = transforms.Compose([
|
| 139 |
-
transforms.ToTensor(),
|
| 140 |
-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
|
| 141 |
-
])
|
| 142 |
-
|
| 143 |
-
trainset = torchvision.datasets.CIFAR10(
|
| 144 |
-
root='./data', train=True, download=True, transform=transform_train)
|
| 145 |
-
trainloader = DataLoader(
|
| 146 |
-
trainset, batch_size=learn_batch, shuffle=True, num_workers=2)
|
| 147 |
-
|
| 148 |
-
testset = torchvision.datasets.CIFAR10(
|
| 149 |
-
root='./data', train=False, download=True, transform=transform_test)
|
| 150 |
-
testloader = DataLoader(
|
| 151 |
-
testset, batch_size=test_batch, shuffle=True, num_workers=2)
|
| 152 |
-
|
| 153 |
-
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
| 154 |
-
except Exception as e:
|
| 155 |
-
print(f"Error: {e}")
|
| 156 |
-
gr.Warning(f"Data Loading Error: {e}")
|
| 157 |
-
|
| 158 |
-
## Model
|
| 159 |
-
try:
|
| 160 |
-
print('==> Building model..')
|
| 161 |
-
net = models_dict.get(drop_type, None)
|
| 162 |
-
|
| 163 |
-
# Make list of models containing either classifer or fc functions
|
| 164 |
-
classifier_models = ['ConvNext_Small', 'ConvNext_Base', 'ConvNext_Large', 'DenseNet', 'EfficientNet_B0', 'MobileNetV2',
|
| 165 |
-
'MaxVit', 'MnasNet0_5', 'SqueezeNet', 'VGG19']
|
| 166 |
-
fc_models = ['GoogLeNet', 'InceptionNetV3', 'RegNet_X_400MF', 'ResNet18', 'ShuffleNet_V2_X0_5']
|
| 167 |
-
|
| 168 |
-
# Check dropdown choice for fc or classifier function implementation
|
| 169 |
-
if net in classifier_models:
|
| 170 |
-
num_ftrs = net.classifier[-1].in_features
|
| 171 |
-
net.classifier[-1] = torch.nn.Linear(num_ftrs, len(classes))
|
| 172 |
-
elif net in fc_models:
|
| 173 |
-
num_ftrs = net.fc.in_features
|
| 174 |
-
net.fc = torch.nn.Linear(num_ftrs, len(classes))
|
| 175 |
-
|
| 176 |
-
net = net.to(device)
|
| 177 |
-
|
| 178 |
-
except Exception as e:
|
| 179 |
-
print(f"Error: {e}")
|
| 180 |
-
gr.Warning(f"Model Building Error: {e}")
|
| 181 |
-
|
| 182 |
-
# if args.resume:
|
| 183 |
-
# # Load checkpoint.
|
| 184 |
-
# print('==> Resuming from checkpoint..')
|
| 185 |
-
# assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
|
| 186 |
-
# checkpoint = torch.load('./checkpoint/ckpt.pth')
|
| 187 |
-
# net.load_state_dict(checkpoint['net'])
|
| 188 |
-
# best_acc = checkpoint['acc']
|
| 189 |
-
# start_epoch = checkpoint['epoch']
|
| 190 |
-
|
| 191 |
-
SGDopt = optim.SGD(net.parameters(), lr=learning_rate,momentum=0.9, weight_decay=5e-4)
|
| 192 |
-
Adamopt = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=5e-4)
|
| 193 |
-
|
| 194 |
-
criterion = nn.CrossEntropyLoss()
|
| 195 |
-
|
| 196 |
-
if optimizer_choose == "SGD":
|
| 197 |
-
optimizer = SGDopt
|
| 198 |
-
elif optimizer_choose == "Adam":
|
| 199 |
-
optimizer = Adamopt
|
| 200 |
-
print (f'optimizer: {optimizer}')
|
| 201 |
-
|
| 202 |
-
#scheduler = lr_scheduler.LinearLR(optimizer, start_factor=learning_rate, end_factor=0.0001, total_iters=10)
|
| 203 |
-
if scheduler_choose == "CosineAnnealingLR":
|
| 204 |
-
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
|
| 205 |
-
elif scheduler_choose == "ReduceLROnPlateau":
|
| 206 |
-
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5)
|
| 207 |
-
elif scheduler_choose == "StepLR":
|
| 208 |
-
scheduler = lr_scheduler.StepLR(optimizer, step_size=30)
|
| 209 |
-
print (f'scheduler: {scheduler_choose}')
|
| 210 |
-
|
| 211 |
-
img_labels = [] # initialize list for label generation
|
| 212 |
-
raw_image_list = [] # initialize list for image generation
|
| 213 |
-
img_list1 = [] # initialize list for combined image/labels
|
| 214 |
-
img_list2 = [] # initialize list for gaussian image generation
|
| 215 |
-
img_list3 = [] # initialize list for adversarial attack image generation
|
| 216 |
-
|
| 217 |
-
# The following lists are used when generating all images in an epoch instead of 10:
|
| 218 |
-
full_img_labels = []
|
| 219 |
-
full_raw_image_list = []
|
| 220 |
-
full_img_list1 = []
|
| 221 |
-
|
| 222 |
-
adv_num = 1 # initialize adversarial image number for naming purposes
|
| 223 |
-
global gaussian_num
|
| 224 |
-
gaussian_num = 1 # initialize gaussian noise image number for naming purposes
|
| 225 |
-
|
| 226 |
-
for epoch in range(start_epoch, start_epoch+epochs_sldr):
|
| 227 |
-
if sigma == 0:
|
| 228 |
-
train(epoch, net, trainloader, device, optimizer, criterion, sigma)
|
| 229 |
-
else:
|
| 230 |
-
gaussian_fig = train(epoch, net, trainloader, device, optimizer, criterion, sigma)
|
| 231 |
-
acc, predicted = test(epoch, net, testloader, device, criterion)
|
| 232 |
-
|
| 233 |
-
if scheduler_choose == "ReduceLROnPlateau":
|
| 234 |
-
scheduler.step(metrics=acc)
|
| 235 |
-
elif not scheduler_choose == "None":
|
| 236 |
-
scheduler.step()
|
| 237 |
-
|
| 238 |
-
if (((epoch-1) % 10 == 0) or (epoch == 0)) and (epoch != 1): # generate images every 10 epochs (and the 0th epoch)
|
| 239 |
-
dataiter = iter(testloader)
|
| 240 |
-
imgs, labels = next(dataiter)
|
| 241 |
-
normalized_imgs = (imgs-imgs.min())/(imgs.max()-imgs.min())
|
| 242 |
-
atk = torchattacks.PGD(net, eps=0.00015, alpha=0.0000000000000001, steps=7)
|
| 243 |
-
if attack == "Yes":
|
| 244 |
-
if normalized_imgs is None:
|
| 245 |
-
print("error occured")
|
| 246 |
-
else:
|
| 247 |
-
print(torch.std(normalized_imgs))
|
| 248 |
-
atk.set_normalization_used(mean = torch.mean(normalized_imgs,axis=[0,2,3]), std=torch.std(normalized_imgs,axis=[0,2,3])/1.125)
|
| 249 |
-
adv_images = atk(imgs, labels)
|
| 250 |
-
fig_name = imshow(adv_images[0], fig_name = f'figures/adversarial_attack{adv_num}.png')
|
| 251 |
-
attack_fig = Image.open(fig_name)
|
| 252 |
-
for i in range(1): # generate 1 image per epoch
|
| 253 |
-
img_list3.append(attack_fig)
|
| 254 |
-
adv_num = adv_num + 1
|
| 255 |
-
for i in range(10): # generate 10 images per epoch
|
| 256 |
-
gradio_imgs = transforms.functional.to_pil_image(normalized_imgs[i])
|
| 257 |
-
raw_image_list.append(gradio_imgs)
|
| 258 |
-
predicted_text = class_names(predicted[i].item(), classes)
|
| 259 |
-
actual_text = class_names(labels[i].item(), classes)
|
| 260 |
-
label_text = f'Epoch: {epoch} | Predicted: {predicted_text} | Actual: {actual_text}'
|
| 261 |
-
img_labels.append(label_text)
|
| 262 |
-
for i in range(test_batch): # generate all images per epoch
|
| 263 |
-
full_gradio_imgs = transforms.functional.to_pil_image(normalized_imgs[i])
|
| 264 |
-
full_raw_image_list.append(full_gradio_imgs)
|
| 265 |
-
full_predicted_text = class_names(predicted[i].item(), classes)
|
| 266 |
-
full_actual_text = class_names(labels[i].item(), classes)
|
| 267 |
-
full_label_text = f'Epoch: {epoch} | Predicted: {full_predicted_text} | Actual: {full_actual_text}'
|
| 268 |
-
full_img_labels.append(full_label_text)
|
| 269 |
-
for i in range(len(raw_image_list)):
|
| 270 |
-
img_tuple = (raw_image_list[i], img_labels[i])
|
| 271 |
-
img_list1.append(img_tuple)
|
| 272 |
-
for i in range(len(full_raw_image_list)):
|
| 273 |
-
full_img_tuple = (full_raw_image_list[i], full_img_labels[i])
|
| 274 |
-
full_img_list1.append(full_img_tuple)
|
| 275 |
-
if sigma != 0:
|
| 276 |
-
for i in range(1): # generate 1 image per epoch
|
| 277 |
-
img_list2.append(gaussian_fig)
|
| 278 |
-
gaussian_num = gaussian_num + 1
|
| 279 |
-
if (sigma == 0) and (attack == "No"):
|
| 280 |
-
return str(acc)+"%", img_list1, full_img_list1, None, None
|
| 281 |
-
elif (sigma != 0) and (attack == "No"):
|
| 282 |
-
return str(acc)+"%", img_list1, full_img_list1, img_list2, None
|
| 283 |
-
elif (sigma == 0) and (attack == "Yes"):
|
| 284 |
-
return str(acc)+"%", img_list1, full_img_list1, None, img_list3
|
| 285 |
-
else:
|
| 286 |
-
return str(acc)+"%", img_list1, full_img_list1, img_list2, img_list3
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
### TRAINING
|
| 291 |
-
def train(epoch, net, trainloader, device, optimizer, criterion, sigma, progress=gr.Progress()):
|
| 292 |
-
try:
|
| 293 |
-
print('\nEpoch: %d' % epoch)
|
| 294 |
-
net.train()
|
| 295 |
-
train_loss = 0
|
| 296 |
-
correct = 0
|
| 297 |
-
total = 0
|
| 298 |
-
|
| 299 |
-
iter_float = 50000/learn_batch
|
| 300 |
-
iterations = math.ceil(iter_float)
|
| 301 |
-
iter_prog = 0
|
| 302 |
-
|
| 303 |
-
for batch_idx, (inputs, targets) in tqdm(enumerate(trainloader)):
|
| 304 |
-
if sigma == 0:
|
| 305 |
-
inputs, targets = inputs.to(device), targets.to(device)
|
| 306 |
-
optimizer.zero_grad()
|
| 307 |
-
outputs = net(inputs)
|
| 308 |
-
else:
|
| 309 |
-
noise = np.random.normal(0, sigma, inputs.shape)
|
| 310 |
-
inputs += torch.tensor(noise)
|
| 311 |
-
inputs, targets = inputs.to(device), targets.to(device)
|
| 312 |
-
optimizer.zero_grad()
|
| 313 |
-
outputs = net(inputs)
|
| 314 |
-
n_inputs = inputs.clone().detach().cpu().numpy()
|
| 315 |
-
if(batch_idx%99 == 0):
|
| 316 |
-
fig_name = imshow(n_inputs[0], fig_name= f'figures/gaussian_noise{gaussian_num}.png')
|
| 317 |
-
gaussian_fig = Image.open(fig_name)
|
| 318 |
-
|
| 319 |
-
loss = criterion(outputs, targets)
|
| 320 |
-
loss.backward()
|
| 321 |
-
optimizer.step()
|
| 322 |
-
|
| 323 |
-
train_loss += loss.item()
|
| 324 |
-
_, predicted = outputs.max(1)
|
| 325 |
-
total += targets.size(0)
|
| 326 |
-
correct += predicted.eq(targets).sum().item()
|
| 327 |
-
|
| 328 |
-
iter_prog = iter_prog + 1 # Iterating iteration amount
|
| 329 |
-
progress(iter_prog/iterations, desc=f"Training Epoch {epoch}", total=iterations)
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
# progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
| 333 |
-
# % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
| 334 |
-
|
| 335 |
-
except Exception as e:
|
| 336 |
-
print(f"Error: {e}")
|
| 337 |
-
gr.Warning(f"Training Error: {e}")
|
| 338 |
-
if sigma != 0:
|
| 339 |
-
return gaussian_fig
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
### TESTING
|
| 343 |
-
|
| 344 |
-
def test(epoch, net, testloader, device, criterion, progress = gr.Progress()):
|
| 345 |
-
try:
|
| 346 |
-
net.eval()
|
| 347 |
-
test_loss = 0
|
| 348 |
-
correct = 0
|
| 349 |
-
total = 0
|
| 350 |
-
|
| 351 |
-
iter_float = 10000/test_batch
|
| 352 |
-
iterations = math.ceil(iter_float)
|
| 353 |
-
iter_prog = 0
|
| 354 |
-
|
| 355 |
-
with torch.no_grad():
|
| 356 |
-
for batch_idx, (inputs, targets) in tqdm(enumerate(testloader)):
|
| 357 |
-
inputs, targets = inputs.to(device), targets.to(device)
|
| 358 |
-
outputs = net(inputs)
|
| 359 |
-
loss = criterion(outputs, targets)
|
| 360 |
-
|
| 361 |
-
test_loss += loss.item()
|
| 362 |
-
_, predicted = outputs.max(1)
|
| 363 |
-
total += targets.size(0)
|
| 364 |
-
correct += predicted.eq(targets).sum().item()
|
| 365 |
-
|
| 366 |
-
iter_prog = iter_prog + 1 # Iterating iteration amount
|
| 367 |
-
progress(iter_prog/iterations, desc=f"Testing Epoch {epoch}", total=iterations)
|
| 368 |
-
|
| 369 |
-
wandb.log({'epoch': epoch+1, 'loss': test_loss})
|
| 370 |
-
wandb.log({"acc": correct/total})
|
| 371 |
-
|
| 372 |
-
# progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
| 373 |
-
# % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
|
| 374 |
-
|
| 375 |
-
# Save checkpoint.
|
| 376 |
-
global best_acc
|
| 377 |
-
global acc
|
| 378 |
-
acc = 100.*correct/total
|
| 379 |
-
print(acc)
|
| 380 |
-
if acc > best_acc:
|
| 381 |
-
best_acc = acc
|
| 382 |
-
return best_acc, predicted
|
| 383 |
-
else:
|
| 384 |
-
return acc, predicted
|
| 385 |
-
# if acc > best_acc:
|
| 386 |
-
# print('Saving..')
|
| 387 |
-
# state = {
|
| 388 |
-
# 'net': net.state_dict(),
|
| 389 |
-
# 'acc': acc,
|
| 390 |
-
# 'epoch': epoch,
|
| 391 |
-
# }
|
| 392 |
-
# if not os.path.isdir('checkpoint'):
|
| 393 |
-
# os.mkdir('checkpoint')
|
| 394 |
-
# torch.save(state, './checkpoint/ckpt.pth')
|
| 395 |
-
# best_acc = acc
|
| 396 |
-
|
| 397 |
-
except Exception as e:
|
| 398 |
-
print(f"Error: {e}")
|
| 399 |
-
gr.Warning(f"Testing Error: {e}")
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
models_dict = {
|
| 403 |
-
#"AlexNet": models.AlexNet(weights=models.AlexNet_Weights.DEFAULT),
|
| 404 |
-
#"ConvNext_Small": models.convnext_small(weights=models.ConvNeXt_Small_Weights.DEFAULT),
|
| 405 |
-
#"ConvNext_Base": models.convnext_base(weights=models.ConvNeXt_Base_Weights.DEFAULT),
|
| 406 |
-
#"ConvNext_Large": models.convnext_large(weights=models.ConvNeXt_Large_Weights.DEFAULT),
|
| 407 |
-
"DenseNet": models.densenet121(weights=models.DenseNet121_Weights.DEFAULT),
|
| 408 |
-
#"EfficientNet_B0": models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT),
|
| 409 |
-
#"GoogLeNet": models.googlenet(weights=models.GoogLeNet_Weights.DEFAULT),
|
| 410 |
-
# "InceptionNetV3": models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT),
|
| 411 |
-
# "MaxVit": models.maxvit_t(weights=models.MaxVit_T_Weights.DEFAULT),
|
| 412 |
-
#"MnasNet0_5": models.mnasnet0_5(weights=models.MNASNet0_5_Weights.DEFAULT),
|
| 413 |
-
#"MobileNetV2": models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT),
|
| 414 |
-
"ResNet18": models.resnet18(weights=models.ResNet18_Weights.DEFAULT),
|
| 415 |
-
"ResNet50": models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
|
| 416 |
-
#"RegNet_X_400MF": models.regnet_x_400mf(weights=models.RegNet_X_400MF_Weights.DEFAULT),
|
| 417 |
-
#"ShuffleNet_V2_X0_5": models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT),
|
| 418 |
-
#"SqueezeNet": models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.DEFAULT),
|
| 419 |
-
"VGG19": models.vgg19(weights=models.VGG19_Weights.DEFAULT)
|
| 420 |
-
}
|
| 421 |
-
|
| 422 |
-
# Store dictionary keys into list for dropdown menu choices
|
| 423 |
-
names = list(models_dict.keys())
|
| 424 |
-
|
| 425 |
-
# Optimizer names
|
| 426 |
-
optimizers = ["SGD","Adam"]
|
| 427 |
-
|
| 428 |
-
# Scheduler names
|
| 429 |
-
schedulers = ["None","CosineAnnealingLR","ReduceLROnPlateau","StepLR"]
|
| 430 |
-
|
| 431 |
-
### GRADIO APP INTERFACE
|
| 432 |
-
|
| 433 |
-
def togglepicsettings(choice):
|
| 434 |
-
yes=gr.Gallery(visible=True)
|
| 435 |
-
no=gr.Gallery(visible=False)
|
| 436 |
-
if choice == "Yes":
|
| 437 |
-
return yes,no
|
| 438 |
-
else:
|
| 439 |
-
return no,yes
|
| 440 |
-
|
| 441 |
-
def settings(choice):
|
| 442 |
-
if choice == "Advanced":
|
| 443 |
-
advanced = [
|
| 444 |
-
gr.Slider(visible=True),
|
| 445 |
-
gr.Slider(visible=True),
|
| 446 |
-
gr.Slider(visible=True),
|
| 447 |
-
gr.Dropdown(visible=True),
|
| 448 |
-
gr.Dropdown(visible=True),
|
| 449 |
-
gr.Radio(visible=True)
|
| 450 |
-
]
|
| 451 |
-
return advanced
|
| 452 |
-
else:
|
| 453 |
-
basic = [
|
| 454 |
-
gr.Slider(visible=False),
|
| 455 |
-
gr.Slider(visible=False),
|
| 456 |
-
gr.Slider(visible=False),
|
| 457 |
-
gr.Dropdown(visible=False),
|
| 458 |
-
gr.Dropdown(visible=False),
|
| 459 |
-
gr.Radio(visible=False)
|
| 460 |
-
]
|
| 461 |
-
return basic
|
| 462 |
-
|
| 463 |
-
def attacks(choice):
|
| 464 |
-
if choice == "Yes":
|
| 465 |
-
yes = [
|
| 466 |
-
gr.Markdown(visible=True),
|
| 467 |
-
gr.Radio(visible=True),
|
| 468 |
-
gr.Radio(visible=True)
|
| 469 |
-
]
|
| 470 |
-
return yes
|
| 471 |
-
if choice == "No":
|
| 472 |
-
no = [
|
| 473 |
-
gr.Markdown(visible=False),
|
| 474 |
-
gr.Radio(visible=False),
|
| 475 |
-
gr.Radio(visible=False)
|
| 476 |
-
]
|
| 477 |
-
return no
|
| 478 |
-
|
| 479 |
-
def gaussian(choice):
|
| 480 |
-
if choice == "Yes":
|
| 481 |
-
yes = [
|
| 482 |
-
gr.Slider(visible=True),
|
| 483 |
-
gr.Gallery(visible=True),
|
| 484 |
-
]
|
| 485 |
-
return yes
|
| 486 |
-
else:
|
| 487 |
-
no = [
|
| 488 |
-
gr.Slider(visible=False),
|
| 489 |
-
gr.Gallery(visible=False),
|
| 490 |
-
]
|
| 491 |
-
return no
|
| 492 |
-
def adversarial(choice):
|
| 493 |
-
if choice == "Yes":
|
| 494 |
-
yes = gr.Gallery(visible=True)
|
| 495 |
-
return yes
|
| 496 |
-
else:
|
| 497 |
-
no = gr.Gallery(visible=False)
|
| 498 |
-
|
| 499 |
-
## Main app for functionality
|
| 500 |
-
with gr.Blocks(css=".caption-label {display:none}") as functionApp:
|
| 501 |
-
with gr.Row():
|
| 502 |
-
gr.Markdown("# CIFAR-10 Model Training GUI")
|
| 503 |
-
with gr.Row():
|
| 504 |
-
gr.Markdown("## Parameters")
|
| 505 |
-
with gr.Row():
|
| 506 |
-
inp = gr.Dropdown(choices=names, label="Training Model", value="ResNet18", info="Choose one of 13 common models provided in the dropdown to use for training.")
|
| 507 |
-
username = gr.Textbox(label="Weights and Biases", info="Enter your username or team name from the Weights and Biases API.")
|
| 508 |
-
epochs_sldr = gr.Slider(label="Number of Epochs", minimum=1, maximum=100, step=1, value=1, info="How many times the model will see the entire dataset during trianing.")
|
| 509 |
-
with gr.Column():
|
| 510 |
-
setting_radio = gr.Radio(["Basic", "Advanced"], label="Settings", value="Basic")
|
| 511 |
-
btn = gr.Button("Run")
|
| 512 |
-
with gr.Row():
|
| 513 |
-
train_sldr = gr.Slider(visible=False, label="Training Batch Size", minimum=1, maximum=1000, step=1, value=128, info="The number of training samples processed before the model's internal parameters are updated.")
|
| 514 |
-
test_sldr = gr.Slider(visible=False, label="Testing Batch Size", minimum=1, maximum=1000, step=1, value=100, info="The number of testing samples processed at once during the evaluation phase.")
|
| 515 |
-
learning_rate_sldr = gr.Slider(visible=False, label="Learning Rate", minimum=0.0001, maximum=0.1, step=0.0001, value=0.001, info="The learning rate of the optimization program.")
|
| 516 |
-
optimizer = gr.Dropdown(visible=False, label="Optimizer", choices=optimizers, value="SGD", info="The optimization algorithm used to minimize the loss function during training.")
|
| 517 |
-
scheduler = gr.Dropdown(visible=False, label="Scheduler", choices=schedulers, value="CosineAnnealingLR", info="The scheduler used to iteratively alter learning rate.")
|
| 518 |
-
use_attacks = gr.Radio(["Yes", "No"], visible=False, label="Use Attacking Methods?", value="No")
|
| 519 |
-
setting_radio.change(fn=settings, inputs=setting_radio, outputs=[train_sldr, test_sldr, learning_rate_sldr, optimizer, scheduler, use_attacks])
|
| 520 |
-
with gr.Row():
|
| 521 |
-
attack_method = gr.Markdown("## Attacking Methods", visible=False)
|
| 522 |
-
with gr.Row():
|
| 523 |
-
use_sigma = gr.Radio(["Yes","No"], visible=False, label="Use Gaussian Noise?", value="No")
|
| 524 |
-
sigma_sldr = gr.Slider(visible=False, label="Gaussian Noise", minimum=0, maximum=1, value=0, step=0.1, info="The sigma value of the gaussian noise eqaution. A value of 0 disables gaussian noise.")
|
| 525 |
-
adv_attack = gr.Radio(["Yes","No"], visible=False, label="Use Adversarial Attacks?", value="No")
|
| 526 |
-
with gr.Row():
|
| 527 |
-
gr.Markdown("## Training Results")
|
| 528 |
-
with gr.Row():
|
| 529 |
-
accuracy = gr.Textbox(label = "Accuracy", info="The validation accuracy of the trained model (accuracy evaluated on testing data).")
|
| 530 |
-
with gr.Column():
|
| 531 |
-
showpics = gr.Radio(["Yes","No"], visible = True, label = "Show all pictures?", value = "No")
|
| 532 |
-
pics = gr.Gallery(preview=False, selected_index=0, object_fit='contain', label="Testing Images")
|
| 533 |
-
allpics = gr.Gallery(preview=True, selected_index=0, object_fit='contain', label="Full Testing Images",visible = False)
|
| 534 |
-
showpics.change(fn=togglepicsettings, inputs=[showpics], outputs = [allpics, pics])
|
| 535 |
-
with gr.Row():
|
| 536 |
-
gaussian_pics = gr.Gallery(visible=False, preview=False, selected_index=0, object_fit='contain', label="Gaussian Noise")
|
| 537 |
-
attack_pics = gr.Gallery(visible=False, preview=False, selected_index=0, object_fit='contain', label="Adversarial Attack")
|
| 538 |
-
use_attacks.change(fn=attacks, inputs=use_attacks, outputs=[attack_method, use_sigma, adv_attack])
|
| 539 |
-
use_sigma.change(fn=gaussian, inputs=use_sigma, outputs=[sigma_sldr, gaussian_pics])
|
| 540 |
-
adv_attack.change(fn=adversarial, inputs=adv_attack, outputs=attack_pics)
|
| 541 |
-
btn.click(fn=main, inputs=[inp, epochs_sldr, train_sldr, test_sldr, learning_rate_sldr, optimizer, sigma_sldr, adv_attack, username, scheduler], outputs=[accuracy, pics, allpics, gaussian_pics, attack_pics])
|
| 542 |
-
|
| 543 |
-
## Documentation app (implemented as second tab)
|
| 544 |
-
|
| 545 |
-
markdown_file_path = 'documentation.md'
|
| 546 |
-
with open(markdown_file_path, 'r') as file:
|
| 547 |
-
markdown_content = file.read()
|
| 548 |
-
|
| 549 |
-
with gr.Blocks() as documentationApp:
|
| 550 |
-
with gr.Row():
|
| 551 |
-
gr.Markdown("# CIFAR-10 Training Interface Documentation")
|
| 552 |
-
with gr.Row():
|
| 553 |
-
gr.Markdown(markdown_content) # Can be collapesed in VSCode to hide paragraphs from view. Vscode can also wrap text.
|
| 554 |
-
|
| 555 |
-
### LAUNCH APP
|
| 556 |
-
|
| 557 |
-
if __name__ == '__main__':
|
| 558 |
-
mainApp = gr.TabbedInterface([functionApp, documentationApp], ["Welcome", "Documentation"], theme=theme)
|
| 559 |
-
mainApp.queue()
|
| 560 |
-
mainApp.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|