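""" Task proxy of TorchTask.

This module defines the command-line arguments shared by all tasks and the
'TaskProxy' class, which prepares the experiment environment, builds the data
loaders, and drives the training/validation loop of a task algorithm.
"""
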
import os
import time
import yaml
import copy
from datetime import datetime
import torch
import torchtask
from torchtask.utils import logger, cmd
from torchtask.nn import data as nndata
from torchtask.nn import lrer as nnlrer
from torchtask.nn import optimizer as nnoptimizer


def add_parser_arguments(parser):
# experimental arguments
parser.add_argument('--exp-id', type=str, default='', metavar='', help='exp - the unique id (or name) of experiment')
parser.add_argument('--resume', type=str, default='', metavar='', help='exp - the checkpoint file that will be resumed')
parser.add_argument('--validation', type=cmd.str2bool, default=False, metavar='', help='exp - validation only if True')
parser.add_argument('--out-path', type=str, default='', metavar='', help='exp - the path where the output of experiment is stored')
parser.add_argument('--visualize', type=cmd.str2bool, default=False, metavar='', help='exp - save the output images for visualization if True')
parser.add_argument('--debug', type=cmd.str2bool, default=False, metavar='', help='exp - experiment under debug mode if True')
parser.add_argument('--val-freq', type=int, default=1, metavar='', help='exp - validation frequency during training [unit: epoch]')
parser.add_argument('--log-freq', type=int, default=100, metavar='', help='exp - logging frequency during training and validation [unit: iteration]')
    parser.add_argument('--visual-freq', type=int, default=100, metavar='', help='exp - visualization frequency during training and validation [unit: iteration]')
parser.add_argument('--checkpoint-freq', type=int, default=1, metavar='', help='exp - checkpoint saving frequency during training [unit: epoch]')

    # dataset / dataloader arguments
parser.add_argument('--trainset', type=yaml.full_load, default={}, metavar='', help='data - path of the train dataset of format {dataset_type: [path1, path2]}')
parser.add_argument('--valset', type=yaml.full_load, default={}, metavar='', help='data - path of the validate dataset of format {dataset_type: [path1, path2]}')
parser.add_argument('--num-workers', type=int, default=1, metavar='', help='data - number of workers for the data loader on each GPU')
parser.add_argument('--im-size', type=int, default=None, help='data - target size of the input images')
parser.add_argument('--additionalset', type=yaml.full_load, default={}, metavar='', help='data - path of the extra additional dataset of format {dataset_type: [path1, path2]}')
parser.add_argument('--sublabeled-path', type=str, default='', metavar='', help='data - path of the file that stores the prefix of the labeled subset')
parser.add_argument('--ignore-additional', type=cmd.str2bool, default=True, metavar='', help='data - ignore (do not use) the additional samples during training if True')
    parser.add_argument('--short-ep', type=cmd.str2bool, default=False, metavar='', help='data - end each training epoch when the shorter data stream is exhausted if True')

    # task algorithm arguments
parser.add_argument('--trainer', type=str, default='', metavar='', help='task - the task algorithm used in experiment')
parser.add_argument('--models', type=yaml.full_load, default={}, metavar='', help='task - dict saves all {component_key: task_model} for the task algorithm')
parser.add_argument('--optimizers', type=yaml.full_load, default={}, metavar='', help='task - dict saves all {component_key: task_optimizer} for the task algorithm')
parser.add_argument('--lrers', type=yaml.full_load, default={}, metavar='', help='task - dict saves all {component_key: task_lrer} for the task algorithm')
    parser.add_argument('--criterions', type=yaml.full_load, default={}, metavar='', help='task - dict saves all {component_key: task_criterion} for the task algorithm')

    # training arguments
parser.add_argument('--epochs', type=int, default=1, metavar='', help='train/val - total epochs for training')
parser.add_argument('--batch-size', type=int, default=16, metavar='', help='train/val - total batch size for training/validation on each GPU')
parser.add_argument('--additional-batch-size', type=int, default=0, metavar='', help='train/val - number of additional samples in a mini-batch on each GPU')

    # arguments set automatically by the proxy code
parser.add_argument('--gpus', type=int, default=0, metavar='', help='autoset - number of GPUs for running [this argument is automatically set by code!]')
parser.add_argument('--task', type=str, default='', metavar='', help='autoset - name string of current task [this argument is automatically set by code!]')
parser.add_argument('--labeled-batch-size', type=int, default=None, metavar='', help='autoset - number of labeled samples in a mini-batch on each GPU [this argument is automatically set by code!]')
parser.add_argument('--checkpoint-path', type=str, default='', metavar='', help='autoset - the path used to save the checkpoint files [this argument is automatically set by code!]')
parser.add_argument('--visual-debug-path', type=str, default='', metavar='', help='autoset - the path used to save the debuging images for visualization [this argument is automatically set by code!]')
parser.add_argument('--visual-train-path', type=str, default='', metavar='', help='autoset - the path used to save the training images for visualization [this argument is automatically set by code!]')
parser.add_argument('--visual-val-path', type=str, default='', metavar='', help='autoset - the path used to save the validation images for visualization [this argument is automatically set by code!]')
parser.add_argument('--is-epoch-lrer', type=cmd.str2bool, default=None, metavar='', help='autoset - adjust the learning rate after (1) each epoch (if True) or each iter (if False) [this argument is automatically set by code!]')
parser.add_argument('--iters-per-epoch', type=int, default=None, metavar='', help='autoset - number of iterations inside each epoch [this argument is automatically set by code!]')
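
# Example (hypothetical dataset/component names) of the dict-style arguments
# parsed by 'yaml.full_load' above, as they would appear in a launch script:
#   --trainset "{cityscapes: [./dataset/cityscapes/train]}"
#   --models "{model: deeplabv2}" --optimizers "{model: sgd}"
#   --lrers "{model: steplr}" --criterions "{model: crossentropy}"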


class TaskProxy:
NAME = 'task' # specific name of task

    def __init__(self, args, func, data, model, criterion, trainer):
self.args = args # arguments dict for task-specific proxy
self.func = func # instance of 'TaskFunc'
self.data = data # instance of 'TaskData'
self.model = model # instance of 'TaskModel'
self.criterion = criterion # instance of 'TaskCriterion'
self.trainer = None
self.trainer_class = trainer
self.model_dict = {}
self.criterion_dict = {}
self.optimizer_dict = {}
self.lrer_dict = {}
self.train_loader = None # instance of 'torch.utils.data.DataLoader'
self.val_loader = None # instance of 'torch.utils.data.DataLoader'
self._init()

    def run(self):
self._run()

    def _run(self):
start_epoch = 0
if self.args.resume is not None and self.args.resume != '':
logger.log_info('Load checkpoint from: {0}'.format(self.args.resume))
start_epoch = self.trainer.load_checkpoint()
if self.args.validation:
if self.val_loader is None:
logger.log_err('No data loader for validation.\n'
'Please set right \'valset\' in the script.\n')
logger.log_info(['=' * 78, '\nStart to validate model\n', '=' * 78])
with torch.no_grad():
self.trainer.validate(self.val_loader, start_epoch - 1)
self.trainer.save_checkpoint(0)
return
        # NOTE: the first epoch index for 'train' and 'validate' is 0
for epoch in range(start_epoch, self.args.epochs):
timer = time.time()
logger.log_info(['=' * 78, '\nStart to train epoch-{0}\n'.format(epoch + 1), '=' * 78])
self.trainer.train(self.train_loader, epoch)
if (epoch + 1) % self.args.val_freq == 0 and self.val_loader is not None:
logger.log_info(['=' * 78, '\nStart to validate epoch-{0}\n'.format(epoch + 1), '=' * 78])
with torch.no_grad():
self.trainer.validate(self.val_loader, epoch)
if (epoch + 1) % self.args.checkpoint_freq == 0:
self.trainer.save_checkpoint(epoch + 1)
logger.log_info("Save checkpoint for epoch {0}".format(epoch + 1))
logger.log_info('Finish epoch in {0} seconds\n'.format(time.time() - timer))
logger.log_info('Finish experiment {0}\n'.format(self.args.exp_id))

    def _init(self):
        """ Initialize the task proxy.
        """
self._preprocess_arguments()
self._create_dataloader()
self._build_trainer()

    def _preprocess_arguments(self):
""" Preprocess the arguments in the script.
"""
# create the output folder to store the results
self.args.out_path = "{root}/{exp_id}/{date:%Y-%m-%d_%H-%M-%S}/".format(
root=self.args.out_path, exp_id=self.args.exp_id, date=datetime.now())
if not os.path.exists(self.args.out_path):
os.makedirs(self.args.out_path)
# prepare logger
exp_op = 'val' if self.args.validation else 'train'
logger.log_mode(self.args.debug)
logger.log_file(os.path.join(self.args.out_path, '{0}.log'.format(exp_op)), self.args.debug)
logger.log_info('Result folder: \n {0} \n'.format(self.args.out_path))
# print experimental args
cmd.print_args()
# set task name
self.args.task = self.NAME
# check the task-specific components dicts required by the task algorithm
if not len(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions):
logger.log_err('Condition:\n'
                           '\tlen(self.args.models) == len(self.args.optimizers) == len(self.args.lrers) == len(self.args.criterions)\n'
'is not satisfied in the script\n')
for (model, criterion, optimizer, lrer) in \
zip(self.args.models.values(), self.args.criterions.values(), self.args.optimizers.values(), self.args.lrers.values()):
if model not in self.model.__dict__:
                logger.log_err('Unsupported model: {0} for task: {1}\n'
'Please add the export function in task\'s \'model.py\'\n'.format(model, self.args.task))
elif criterion not in self.criterion.__dict__:
                logger.log_err('Unsupported criterion: {0} for task: {1}\n'
'Please add the export function in task\'s \'criterion.py\'\n'.format(criterion, self.args.task))
elif optimizer not in nnoptimizer.__dict__:
                logger.log_err('Unsupported optimizer: {0}\n'
'Please implement the optimizer wrapper in \'torchtask/nn/optimizer.py\'\n'.format(optimizer))
elif lrer not in nnlrer.__dict__:
                logger.log_err('Unsupported learning rate scheduler: {0}\n'
'Please implement lr scheduler wrapper in \'torchtask/nn/lrer.py\'\n'.format(lrer))
# check the types of lrers
for lrer in self.args.lrers.values():
if lrer in nnlrer.EPOCH_LRERS:
is_epoch_lrer = True
elif lrer in nnlrer.ITER_LRERS:
is_epoch_lrer = False
else:
logger.log_err('Unknown learning rate scheduler ({0}) type\n'
'Please add it into either EPOCH_LRERS or ITER_LRERS in \'torchtask/nn/lrer.py\'\n'
'TorchTask supports: \n'
' EPOCH_LRERS\t=>\t{1}\n ITER_LRERS\t=>\t{2}\n'.format(lrer, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))
if self.args.is_epoch_lrer is None:
self.args.is_epoch_lrer = is_epoch_lrer
elif self.args.is_epoch_lrer != is_epoch_lrer:
logger.log_err('Unmatched lr scheduler types\t=>\t{0}\n'
'All lrers of the task models should have the same types (either EPOCH_LRERS or ITER_LRERS)\n'
'TorchTask supports: \n'
' EPOCH_LRERS\t=>\t{1}\n ITER_LRERS\t=>\t{2}\n'
.format(self.args.lrers, nnlrer.EPOCH_LRERS, nnlrer.ITER_LRERS))
self.args.checkpoint_path = os.path.join(self.args.out_path, 'ckpt')
if not os.path.exists(self.args.checkpoint_path):
os.makedirs(self.args.checkpoint_path)
if self.args.visualize:
self.args.visual_debug_path = os.path.join(self.args.out_path, 'visualization/debug')
self.args.visual_train_path = os.path.join(self.args.out_path, 'visualization/train')
self.args.visual_val_path = os.path.join(self.args.out_path, 'visualization/val')
for vpath in [self.args.visual_debug_path, self.args.visual_train_path, self.args.visual_val_path]:
if not os.path.exists(vpath):
os.makedirs(vpath)
        # handle arguments for multi-GPU training
self.args.gpus = torch.cuda.device_count()
if self.args.gpus < 1:
            logger.log_err('No GPU detected\n'
                           'TorchTask requires at least one NVIDIA GPU\n')
logger.log_info('GPU: \n Total GPU(s): {0}'.format(self.args.gpus))
self.args.lr *= self.args.gpus
self.args.num_workers *= self.args.gpus
self.args.batch_size *= self.args.gpus
self.args.additional_batch_size *= self.args.gpus
# TODO: support unsupervised and self-supervised training
if self.args.additional_batch_size >= self.args.batch_size:
            logger.log_err('The argument \'additional_batch_size\' ({0}) should be smaller than \'batch_size\' ({1}) '
                           'since TorchTask only supports semi-supervised learning now\n'
                           .format(self.args.additional_batch_size, self.args.batch_size))
self.args.labeled_batch_size = self.args.batch_size - self.args.additional_batch_size
        logger.log_info(' Total learning rate: {0}\n Total labeled batch size: {1}\n'
' Total additional batch size: {2}\n Total data workers: {3}\n'.format(
self.args.lr, self.args.labeled_batch_size, self.args.additional_batch_size, self.args.num_workers))

    def _create_dataloader(self):
""" Create data loaders for experiment.
"""
# ---------------------------------------------------------------------
        # create dataloader for training
# ---------------------------------------------------------------------
# ignore_additional == False & additional_batch_size != 0
# means that both labeled and additional data are used
with_additional_data = not self.args.ignore_additional and self.args.additional_batch_size != 0
# ignore_additional == True & additional_batch_size == 0
# means that only the labeled data is used
without_additional_data = self.args.ignore_additional and self.args.additional_batch_size == 0
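        # Summary of the four flag combinations (the two invalid ones are
        # rejected below when not in validation mode):
        #   ignore_additional=False, additional_batch_size != 0  ->  labeled + additional data
        #   ignore_additional=True,  additional_batch_size == 0  ->  labeled data only
        #   any other combination                                ->  argument conflict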
labeled_train_samples, additional_train_samples = 0, 0
if not self.args.validation:
# ignore_additional == True & additional_batch_size != 0 -> error
if self.args.ignore_additional and self.args.additional_batch_size != 0:
logger.log_err('Arguments conflict => ignore_additional == True requires additional_batch_size == 0\n')
# ignore_additional == False & additional_batch_size == 0 -> error
if not self.args.ignore_additional and self.args.additional_batch_size == 0:
logger.log_err('Arguments conflict => ignore_additional == False requires additional_batch_size != 0\n')
# calculate the number of trainsets
trainset_num = 0
for key, value in self.args.trainset.items():
trainset_num += len(value)
# calculate the number of additionalsets
additionalset_num = 0
for key, value in self.args.additionalset.items():
additionalset_num += len(value)
# if only one labeled training set and without any additional set
if trainset_num == 1 and additionalset_num == 0:
trainset = self._load_dataset(list(self.args.trainset.keys())[0], list(self.args.trainset.values())[0][0])
labeled_train_samples = len(trainset.idxs)
# if the 'sublabeled_path' is given
sublabeled_prefix = None
if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
if not os.path.exists(self.args.sublabeled_path):
logger.log_err('Cannot find labeled file: {0}\n'.format(self.args.sublabeled_path))
else:
with open(self.args.sublabeled_path) as f:
sublabeled_prefix = [line.strip() for line in f.read().splitlines()]
sublabeled_prefix = None if len(sublabeled_prefix) == 0 else sublabeled_prefix
if sublabeled_prefix is not None:
# wrap the trainset by 'SplitUnlabeledWrapper'
trainset = nndata.SplitUnlabeledWrapper(
trainset, sublabeled_prefix, ignore_additional=self.args.ignore_additional)
labeled_train_samples = len(trainset.labeled_idxs)
additional_train_samples = len(trainset.additional_idxs)
# if 'sublabeled_prefix' is None but you want to use the additional data for training
elif with_additional_data:
logger.log_err('Try to use the additional samples without any task dataset wrapper\n')
            # if more than one labeled training set or any additional training set is given
elif trainset_num > 1 or additionalset_num > 0:
                # 'args.sublabeled_path' is disabled when multiple datasets are given
if self.args.sublabeled_path is not None and self.args.sublabeled_path != '':
logger.log_err('Multiple training datasets are given. \n'
'Inter-split additional set is not allowed.\n'
'Please remove the argument \'sublabeled_path\' in the script\n')
# load all training sets
labeled_sets = []
for set_name, set_dirs in self.args.trainset.items():
for set_dir in set_dirs:
labeled_sets.append(self._load_dataset(set_name, set_dir))
# load all extra additional sets
additional_sets = []
# if any extra additional set is given
if additionalset_num > 0:
for set_name, set_dirs in self.args.additionalset.items():
for set_dir in set_dirs:
additional_sets.append(self._load_dataset(set_name, set_dir))
                # if additionalset_num == 0 but you want to use the additional data for training
elif with_additional_data:
logger.log_err('Try to use the additional samples without any task dataset wrapper\n'
'Please add the argument \'additionalset\' in the script\n')
# wrap both 'labeled_set' and 'additional_set' by 'JointDatasetsWrapper'
trainset = nndata.JointDatasetsWrapper(
labeled_sets, additional_sets, ignore_additional=self.args.ignore_additional)
labeled_train_samples = len(trainset.labeled_idxs)
additional_train_samples = len(trainset.additional_idxs)
# if use labeled data only
if without_additional_data:
self.train_loader = torch.utils.data.DataLoader(trainset, batch_size=self.args.batch_size,
shuffle=True, num_workers=self.args.num_workers, pin_memory=True, drop_last=True)
# if use both labeled and additional data
elif with_additional_data:
train_sampler = nndata.TwoStreamBatchSampler(trainset.labeled_idxs, trainset.additional_idxs,
self.args.labeled_batch_size, self.args.additional_batch_size, short_ep=self.args.short_ep)
self.train_loader = torch.utils.data.DataLoader(trainset, batch_sampler=train_sampler,
num_workers=self.args.num_workers, pin_memory=True)
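            # NOTE: judging from the sampler's arguments, every mini-batch is
            # composed of 'labeled_batch_size' labeled samples plus
            # 'additional_batch_size' additional samples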

        # ---------------------------------------------------------------------
# create dataloader for validation
# ---------------------------------------------------------------------
        # calculate the number of valsets
        valset_num, val_samples = 0, 0
for key, value in self.args.valset.items():
valset_num += len(value)
# if only one validation set is given
if valset_num == 1:
valset = self._load_dataset(
list(self.args.valset.keys())[0], list(self.args.valset.values())[0][0], is_train=False)
val_samples = len(valset.idxs)
        # if more than one validation set is given
elif valset_num > 1:
valsets = []
for set_name, set_dirs in self.args.valset.items():
for set_dir in set_dirs:
valsets.append(self._load_dataset(set_name, set_dir, is_train=False))
valset = nndata.JointDatasetsWrapper(valsets, [], ignore_additional=True)
val_samples = len(valset.labeled_idxs)
        # NOTE: batch size is set to 1 during validation
        # create the loader only if at least one validation set is given
        if valset_num > 0:
            self.val_loader = torch.utils.data.DataLoader(valset, batch_size=1,
                shuffle=False, num_workers=0, pin_memory=True)
# check the data loaders
        if self.train_loader is None and not self.args.validation:
            logger.log_err('A train data loader is required when validation mode is off\n')
        elif self.val_loader is None and self.args.validation:
            logger.log_err('A validation data loader is required when validation mode is on\n')
        elif self.val_loader is None:
            logger.log_warn('No validation data loader given; validation will be skipped during training\n')
# set 'iters_per_epoch', which is required by ITER_LRERS
self.args.iters_per_epoch = len(self.train_loader) if self.train_loader is not None else -1
logger.log_info('Dataset:\n'
' Trainset\t=>\tlabeled samples = {0}\t\tadditional samples = {1}\n'
' Valset\t=>\tsamples = {2}\n'
.format(labeled_train_samples, additional_train_samples, val_samples))

    def _build_trainer(self):
""" Build the semi-supervised learning algorithm given in the script.
"""
for cname in self.args.models.keys():
self.model_dict[cname] = self.model.__dict__[self.args.models[cname]]()
self.criterion_dict[cname] = self.criterion.__dict__[self.args.criterions[cname]]()
self.lrer_dict[cname] = nnlrer.__dict__[self.args.lrers[cname]](self.args)
self.optimizer_dict[cname] = nnoptimizer.__dict__[self.args.optimizers[cname]](self.args)
logger.log_info('Trainer: \n {0}\n'.format(self.args.trainer))
logger.log_info('Models: ')
self.trainer = self.trainer_class.__dict__[self.args.trainer](
self.args, self.model_dict, self.optimizer_dict, self.lrer_dict, self.criterion_dict, self.func.task_func()(self.args))

    def _load_dataset(self, dataset_name, dataset_dir, is_train=True):
""" Load one dataset.
"""
        if dataset_name not in self.data.__dict__:
logger.log_err('Unknown dataset type: {0}\n'.format(dataset_name))
elif not os.path.exists(dataset_dir):
logger.log_err('Cannot find the path of dataset: {0}\n'.format(dataset_dir))
else:
dataset_args = copy.deepcopy(self.args)
if is_train:
dataset_args.trainset = {dataset_name: dataset_dir}
else:
dataset_args.valset = {dataset_name: dataset_dir}
return self.data.__dict__[dataset_name]()(dataset_args, is_train)
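

# A minimal usage sketch (the 'mytask' package below is an assumption -- each
# task is expected to provide its own 'func', 'data', 'model', 'criterion'
# and 'trainer' modules):
#
#   import argparse
#   from mytask import func, data, model, criterion, trainer
#
#   parser = argparse.ArgumentParser()
#   add_parser_arguments(parser)
#   args = parser.parse_args()
#
#   TaskProxy(args, func, data, model, criterion, trainer).run()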