| |
| |
| |
| |
| import argparse |
| from argparse import ArgumentParser |
| from datasets import NAMES as DATASET_NAMES |
| from models import get_all_models |
|
|
def str2bool(v):
    """
    Converts a command-line style value into a boolean.
    Booleans are returned untouched; any other value is lower-cased and
    compared against the accepted spellings.
    :param v: a bool, or a string spelling of a boolean (case-insensitive)
    :return: True for 'yes', 'true', 't', 'y', '1';
             False for 'no', 'false', 'f', 'n', '0'
    :raises argparse.ArgumentTypeError: if the spelling is not recognised
    """
    # Already a boolean: nothing to parse.
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def add_gcil_args(parser: ArgumentParser) -> None:
    """
    Registers the options needed by the GCIL-CIFAR100 Dataset.
    :param parser: the parser instance
    """
    # (flag, keyword arguments) pairs, registered in declaration order.
    gcil_options = (
        ('--gil_seed', dict(type=int, default=1993,
                            help='Seed value for GIL-CIFAR task sampling')),
        ('--pretrain', dict(action='store_true', default=False,
                            help='whether to use pretrain')),
        ('--phase_class_upper', dict(default=50, type=int,
                                     help='the maximum number of classes')),
        ('--epoch_size', dict(default=1000, type=int,
                              help='Number of samples in one epoch')),
        ('--pretrain_class_nb', dict(default=0, type=int,
                                     help='the number of classes in first group')),
        ('--weight_dist', dict(default='unif', type=str,
                               help='what type of weight distribution assigned to classes to sample (unif or longtail)')),
    )
    for flag, settings in gcil_options:
        parser.add_argument(flag, **settings)
def add_experiment_args(parser: ArgumentParser) -> None:
    """
    Adds the arguments used by all the models.
    ``--dataset``, ``--model`` and ``--lr`` are mandatory; ``--n_epochs`` and
    ``--batch_size`` have no default and are None when omitted.
    :param parser: the parser instance
    """
    # Dataset and model selection.
    parser.add_argument('--dataset', type=str, required=True,
                        choices=DATASET_NAMES,
                        help='Which dataset to perform experiments on.')
    parser.add_argument('--half_data_in_first_task', action='store_true',
                        help='use half of data for first experience')
    parser.add_argument('--model', type=str, required=True,
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--resnet_width', type=float, default=1.0)

    parser.add_argument('--lr', type=float, required=True,
                        help='Learning rate.')

    # Optimizer settings.
    parser.add_argument('--optim_wd', type=float, default=0.,
                        help='optimizer weight decay.')
    parser.add_argument('--optim_mom', type=float, default=0.,
                        help='optimizer momentum.')
    parser.add_argument('--optim_nesterov', type=int, default=0,
                        help='optimizer nesterov momentum.')

    # Training schedule and execution device.
    # The original help for --n_epochs was a copy-paste of the --batch_size
    # text; it now describes the option it belongs to.
    parser.add_argument('--n_epochs', type=int,
                        help='Number of epochs.')
    parser.add_argument('--batch_size', type=int,
                        help='Batch size.')
    parser.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'])
    parser.add_argument('--device', type=str, default='cuda:0')
|
|
|
|
def add_management_args(parser: ArgumentParser) -> None:
    """
    Adds the run-management arguments: seed, notes, verbosity and logging
    switches, validation/metric toggles, debug mode, and run identification.
    All of them are optional.
    :param parser: the parser instance
    """
    parser.add_argument('--seed', type=int, default=None,
                        help='The random seed.')
    parser.add_argument('--notes', type=str, default=None,
                        help='Notes for this run.')

    # Integer 0/1 switches.
    parser.add_argument('--non_verbose', default=0, choices=[0, 1], type=int,
                        help='Make progress bars non verbose')
    # The help used to read 'Enable csv logging', contradicting the flag name.
    parser.add_argument('--disable_log', default=0, choices=[0, 1], type=int,
                        help='Disable csv logging')

    parser.add_argument('--validation', default=0, choices=[0, 1], type=int,
                        help='Test on the validation set')
    parser.add_argument('--ignore_other_metrics', default=0, choices=[0, 1], type=int,
                        help='disable additional metrics')
    parser.add_argument('--debug', action='store_true',
                        help='Run only a few forward steps per epoch')

    # Experiment / run identification.
    parser.add_argument('--experiment_name', type=str, default='Default')
    parser.add_argument('--parent_run_id', default=None, type=str,
                        help='mlflow parent run id, used for creating nested run in mlflow logger')
    parser.add_argument('--run_name', type=str, default=None)
    parser.add_argument('--n_tasks', type=int, default=10)
|
|
|
|
def add_rehearsal_args(parser: ArgumentParser) -> None:
    """
    Registers the options shared by all the rehearsal-based methods:
    a mandatory ``--buffer_size`` and an optional ``--minibatch_size``.
    :param parser: the parser instance
    """
    # Flag -> keyword arguments; insertion order is the registration order.
    rehearsal_options = {
        '--buffer_size': dict(type=int, required=True,
                              help='The size of the memory buffer.'),
        '--minibatch_size': dict(type=int,
                                 help='The batch size of the memory buffer.'),
    }
    for flag, settings in rehearsal_options.items():
        parser.add_argument(flag, **settings)
|
|