# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from argparse import ArgumentParser

from datasets import NAMES as DATASET_NAMES
from models import get_all_models


def str2bool(v):
    """
    Converts a command-line value into a boolean; suitable as an argparse ``type=``.

    :param v: a bool (returned unchanged) or a string; accepted spellings are
              'yes'/'true'/'t'/'y'/'1' and 'no'/'false'/'f'/'n'/'0', case-insensitive
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean spelling
    """
    if isinstance(v, bool):
        return v
    # Normalise once instead of lower-casing for every comparison.
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def add_gcil_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments required for GCIL-CIFAR100 Dataset.
    :param parser: the parser instance
    """
    # arguments for GCIL-CIFAR100
    parser.add_argument('--gil_seed', type=int, default=1993,
                        help='Seed value for GIL-CIFAR task sampling')
    parser.add_argument('--pretrain', action='store_true', default=False,
                        help='whether to use pretrain')
    parser.add_argument('--phase_class_upper', default=50, type=int,
                        help='the maximum number of classes')
    parser.add_argument('--epoch_size', default=1000, type=int,
                        help='Number of samples in one epoch')
    parser.add_argument('--pretrain_class_nb', default=0, type=int,
                        help='the number of classes in first group')
    # Free-form string: the help text lists 'unif' and 'longtail', but no
    # `choices` are enforced here.
    parser.add_argument('--weight_dist', default='unif', type=str,
                        help='what type of weight distribution assigned to classes to sample (unif or longtail)')


def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.
    :param parser: the parser instance
    """
    parser.add_argument('--dataset', type=str, required=True,
                        choices=DATASET_NAMES,
                        help='Which dataset to perform experiments on.')
    parser.add_argument('--half_data_in_first_task', action='store_true',
                        help='use half of data for first experience')
    parser.add_argument('--model', type=str, required=True,
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--resnet_width', type=float, default=1.0)

    parser.add_argument('--lr', type=float, required=True,
                        help='Learning rate.')

    parser.add_argument('--optim_wd', type=float, default=0.,
                        help='optimizer weight decay.')
    parser.add_argument('--optim_mom', type=float, default=0.,
                        help='optimizer momentum.')
    # Integer flag (0/1 by convention); no `choices` restriction is applied.
    parser.add_argument('--optim_nesterov', type=int, default=0,
                        help='optimizer nesterov momentum.')

    # No defaults: both are None unless given on the command line.
    parser.add_argument('--n_epochs', type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', type=int,
                        help='Batch size.')

    # NOTE(review): 'dp'/'ddp' presumably select DataParallel /
    # DistributedDataParallel — confirm against the training loop.
    parser.add_argument('--distributed', type=str, default='no',
                        choices=['no', 'dp', 'ddp'])
    parser.add_argument('--device', type=str, default='cuda:0')


def add_management_args(parser: ArgumentParser) -> None:
    """
    Adds the run-management arguments (seeding, logging, validation,
    experiment naming, number of tasks).
    :param parser: the parser instance
    """
    parser.add_argument('--seed', type=int, default=None,
                        help='The random seed.')
    parser.add_argument('--notes', type=str, default=None,
                        help='Notes for this run.')

    # The following switches are 0/1 integers rather than store_true flags.
    parser.add_argument('--non_verbose', default=0, choices=[0, 1], type=int,
                        help='Make progress bars non verbose')
    parser.add_argument('--disable_log', default=0, choices=[0, 1], type=int,
                        help='Disable csv logging')
    parser.add_argument('--validation', default=0, choices=[0, 1], type=int,
                        help='Test on the validation set')
    parser.add_argument('--ignore_other_metrics', default=0, choices=[0, 1], type=int,
                        help='disable additional metrics')
    parser.add_argument('--debug', action='store_true',
                        help='Run only a few forward steps per epoch')

    parser.add_argument('--experiment_name', type=str, default='Default')
    parser.add_argument('--parent_run_id', default=None, type=str,
                        help='mlflow parent run id, used for creating nested run in mlflow logger')
    parser.add_argument('--run_name', type=str, default=None)
    parser.add_argument('--n_tasks', type=int, default=10)


def add_rehearsal_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the rehearsal-based methods
    :param parser: the parser instance
    """
    parser.add_argument('--buffer_size', type=int, required=True,
                        help='The size of the memory buffer.')
    # No default: None unless given on the command line.
    parser.add_argument('--minibatch_size', type=int,
                        help='The batch size of the memory buffer.')