File size: 5,226 Bytes
c91d7b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
"""
This module helps train models of different architectures easily. Select a model architecture and a training dataset, and it trains and saves the corresponding model.
"""
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F #233
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from PIL import Image
def train(model, data, device, maxepoch, data_path = './', save_per_epoch = 10, seed = 100):
    """Train a network of the chosen architecture and save its weights periodically.

    The network's ``state_dict`` is written to the ``trained_models`` directory
    (relative to the current working directory, created if missing) as
    ``<data>_<model>_epoch_<epoch>.pt``.

    Parameters
    ----------
    model :
        architecture name (option:'CNN', 'ResNet18', 'ResNet34', 'ResNet50',
        'densenet', 'vgg11', 'vgg13', 'vgg16', 'vgg19')
    data :
        dataset name, forwarded to ``feed_dataset`` (option:'MNIST','CIFAR10')
    device :
        device the network is moved to (option:'cpu', 'cuda')
    maxepoch :
        number of training epochs
    data_path :
        dataset directory, forwarded to ``feed_dataset`` (default = './')
    save_per_epoch :
        a checkpoint is saved every ``save_per_epoch`` epochs and always after
        the last epoch (default = 10)
    seed :
        value passed to ``torch.manual_seed`` (default = 100)

    Raises
    ------
    ValueError
        if ``model`` is not one of the supported architecture names.

    Examples
    --------
    >>> import deeprobust.image.netmodels.train_model as trainmodel
    >>> trainmodel.train('CNN', 'MNIST', 'cuda', 20)
    """
    # Local import: only needed for the lazy, name-based module lookup below.
    import importlib

    # architecture name -> (module path, constructor name, constructor args)
    netmodels = 'deeprobust.image.netmodels.'
    architectures = {
        'CNN': (netmodels + 'CNN', 'Net', ()),
        'ResNet18': (netmodels + 'resnet', 'ResNet18', ()),
        'ResNet34': (netmodels + 'resnet', 'ResNet34', ()),
        'ResNet50': (netmodels + 'resnet', 'ResNet50', ()),
        'densenet': (netmodels + 'densenet', 'densenet_cifar', ()),
        'vgg11': (netmodels + 'vgg', 'VGG', ('VGG11',)),
        'vgg13': (netmodels + 'vgg', 'VGG', ('VGG13',)),
        'vgg16': (netmodels + 'vgg', 'VGG', ('VGG16',)),
        'vgg19': (netmodels + 'vgg', 'VGG', ('VGG19',)),
    }

    # Fail fast, before any dataset is loaded, instead of hitting an
    # unbound-variable error further down.
    if model not in architectures:
        raise ValueError("Unknown model '%s'; expected one of: %s"
                         % (model, ', '.join(sorted(architectures))))

    torch.manual_seed(seed)

    train_loader, test_loader = feed_dataset(data, data_path)

    # The model module is imported lazily so only the selected architecture is loaded.
    module_path, constructor, constructor_args = architectures[model]
    MODEL = importlib.import_module(module_path)
    train_net = getattr(MODEL, constructor)(*constructor_args).to(device)

    optimizer = optim.SGD(train_net.parameters(), lr=0.1, momentum=0.5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    save_dir = 'trained_models'
    for epoch in range(1, maxepoch + 1):
        print(epoch)
        # Each model module provides its own train/test routines.
        MODEL.train(train_net, device, train_loader, optimizer, epoch)
        MODEL.test(train_net, device, test_loader)

        if epoch % save_per_epoch == 0 or epoch == maxepoch:
            if os.path.isdir(save_dir):
                print('Save model.')
            else:
                os.mkdir(save_dir)
                print('Make directory and save model.')
            checkpoint_name = data + "_" + model + "_epoch_" + str(epoch) + ".pt"
            torch.save(train_net.state_dict(), os.path.join(save_dir, checkpoint_name))

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()
def feed_dataset(data, data_dict):
    """Build the training and test data loaders for a named dataset.

    Parameters
    ----------
    data :
        dataset name (option:'CIFAR10', 'MNIST')
    data_dict :
        directory handed to the torchvision dataset as its first argument;
        the dataset is requested with ``download = True``

    Returns
    -------
    (train_loader, test_loader)
        ``torch.utils.data.DataLoader`` objects; the training loader uses
        batch size 128 and the test loader batch size 1000, both with
        ``shuffle=True``.

    Raises
    ------
    NotImplementedError
        if ``data`` is 'ImageNet', for which no loader exists yet.
    ValueError
        if ``data`` is any other unrecognised dataset name.
    """
    if data == 'CIFAR10':
        # Training images get random-crop / horizontal-flip augmentation; the
        # test images are only converted to tensors. Normalization is
        # currently disabled for both (kept below for reference).
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_set = datasets.CIFAR10(data_dict, train=True, download=True,
                                     transform=transform_train)
        test_set = datasets.CIFAR10(data_dict, train=False, download=True,
                                    transform=transform_val)
    elif data == 'MNIST':
        # Same tensor conversion + normalization for both splits.
        transform_mnist = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_set = datasets.MNIST(data_dict, train=True, download=True,
                                   transform=transform_mnist)
        test_set = datasets.MNIST(data_dict, train=False, download=True,
                                  transform=transform_mnist)
    elif data == 'ImageNet':
        # Previously a silent `pass` that crashed at the return statement.
        raise NotImplementedError("The 'ImageNet' dataset is not supported yet.")
    else:
        raise ValueError("Unknown dataset '%s'; expected 'CIFAR10' or 'MNIST'." % (data,))

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True)
    return train_loader, test_loader
|